diff --git a/http.go b/http.go index f7f6c86..fd47234 100644 --- a/http.go +++ b/http.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "github.com/jmcvetta/randutil" + "github.com/spf13/viper" "net" "net/http" "net/url" @@ -117,6 +118,20 @@ func redirectHandler(w http.ResponseWriter, r *http.Request) { } func reloadHandler(w http.ResponseWriter, r *http.Request) { + token := r.Header.Get("Authorization") + + if token == "" || !strings.HasPrefix(token, "Bearer") || !strings.Contains(token, " ") { + w.WriteHeader(http.StatusUnauthorized) + return + } + + token = token[strings.Index(token, " ")+1:] + + if token != viper.GetString("reloadToken") { + w.WriteHeader(http.StatusUnauthorized) + return + } + reloadConfig() w.WriteHeader(http.StatusOK) diff --git a/main.go b/main.go index 42f9910..9900b1c 100644 --- a/main.go +++ b/main.go @@ -79,14 +79,20 @@ type ServerConfig struct { var ( configFlag = flag.String("config", "", "configuration file path") + flagDebug = flag.Bool("debug", false, "Enable debug logging") ) func main() { flag.Parse() + if *flagDebug { + log.SetLevel(log.DebugLevel) + } + viper.SetDefault("bind", ":8080") viper.SetDefault("cacheSize", 1024) viper.SetDefault("topChoices", 3) + viper.SetDefault("reloadKey", randSeq(32)) viper.SetConfigName("dlrouter") // name of config file (without extension) viper.SetConfigType("yaml") // REQUIRED if the config file does not have the extension in the name diff --git a/servers.go b/servers.go index b76553f..26aecaa 100644 --- a/servers.go +++ b/servers.go @@ -7,6 +7,7 @@ import ( "math" "net" "net/http" + "net/url" "runtime" "sort" "strings" @@ -33,7 +34,7 @@ type Server struct { } func (server *Server) checkStatus() { - req, err := http.NewRequest(http.MethodGet, "https://"+server.Host+"/"+strings.TrimLeft(server.Path, "/"), nil) + req, err := http.NewRequest(http.MethodGet, "http://"+server.Host+"/"+strings.TrimLeft(server.Path, "/"), nil) req.Header.Set("User-Agent", "ArmbianRouter/1.0 (Go "+runtime.Version()+")") @@ -72,6 +73,35 @@ func (server *Server) checkStatus() { } if res.StatusCode == http.StatusOK || res.StatusCode == http.StatusMovedPermanently || res.StatusCode == http.StatusFound || res.StatusCode == http.StatusNotFound { + if res.StatusCode == http.StatusMovedPermanently || res.StatusCode == http.StatusFound { + location := res.Header.Get("Location") + + responseFields["url"] = location + + log.WithFields(responseFields).Debug("Server responded with redirect") + + newUrl, err := url.Parse(location) + + if err != nil { + if server.Available { + log.WithFields(responseFields).Warning("Server returned invalid url") + server.Available = false + server.LastChange = time.Now() + } + return + } + + if newUrl.Scheme == "https" { + if server.Available { + responseFields["url"] = location + log.WithFields(responseFields).Warning("Server returned https url for http request") + server.Available = false + server.LastChange = time.Now() + } + return + } + } + if !server.Available { server.Available = true server.LastChange = time.Now() @@ -161,9 +191,15 @@ func (s ServerList) Closest(ip net.IP) (*Server, float64, error) { return c[i].Distance < c[j].Distance }) - choices := make([]randutil.Choice, topChoices) + choiceCount := topChoices - for i, item := range c[0:topChoices] { + if len(c) < topChoices { + choiceCount = len(c) + } + + choices := make([]randutil.Choice, choiceCount) + + for i, item := range c[0:choiceCount] { if item.Server == nil { continue } diff --git a/util.go b/util.go new file mode 100644 index 0000000..a79ece5 --- /dev/null +++ b/util.go @@ -0,0 +1,13 @@ +package main + +import "math/rand" + +var letters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") + +func randSeq(n int) string { + b := make([]rune, n) + for i := range b { + b[i] = letters[rand.Intn(len(letters))] + } + return string(b) +}