Compare commits

..

No commits in common. "feature/refactor" and "master" have entirely different histories.

21 changed files with 370 additions and 668 deletions

View File

@ -10,7 +10,7 @@ steps:
path: /build
commands:
- go mod download
- go install -mod=mod github.com/onsi/ginkgo/v2/ginkgo
- go install github.com/onsi/ginkgo/v2/ginkgo
- ginkgo --randomize-all --p --cover --coverprofile=cover.out .
- go tool cover -func=cover.out
environment:

1
.gitignore vendored
View File

@ -4,4 +4,3 @@ dlrouter-apt.yaml
*.yaml
!dlrouter.yaml
*.exe
cover.out

View File

@ -5,8 +5,7 @@ This repository contains a redirect service for Armbian downloads, apt, etc.
It uses multiple current technologies and best practices, including:
- Go 1.19
- Ginkgo v2 and Gomega testing framework
- Go 1.17/1.18
- GeoIP + Distance routing
- Server weighting, pooling (top x servers are served instead of a single one)
- Health checks (HTTP, TLS)
@ -14,28 +13,9 @@ It uses multiple current technologies and best practices, including:
Code Quality
------------
The code quality isn't the greatest/top tier. Work is being done towards cleaning it up and standardizing it, writing tests, etc.
The code quality isn't the greatest/top tier. All code lives in the "main" package and should be moved at some point.
All contributions are welcome, see the `check_test.go` file for example tests.
Checks
------
The supported checks are HTTP and TLS.
### HTTP
Verifies server accessibility via HTTP. If the server returns a forced redirect to an `https://` url, it is considered to be https-only.
If the server responds on the `https` url with a forced `http` redirect, it will be marked down due to misconfiguration. Requests should never downgrade.
### TLS
Certificate checking to ensure no servers are used which have invalid/expired certificates. This check is written to use the Mozilla ca certificate list, loaded on start/config load, to verify roots.
OS certificate trusts WERE being used to do this, however some issues with the date validation (which could be user error) caused the move to the ca bundle, which could be considered more usable.
Note: This downloads from github every startup/reload. This should be a reliable process, as long as Mozilla doesn't deprecate their repo. Their HG URL is super slow.
Regardless, it is meant to be simple and easy to understand.
Configuration
-------------
@ -72,19 +52,12 @@ cacheSize: 1024
# server = full url or host+path
# weight = int
# optional: latitude, longitude (float)
# optional: protocols (list/array)
servers:
- server: armbian.12z.eu/apt/
- server: armbian.chi.auroradev.org/apt/
weight: 15
latitude: 41.8879
longitude: -88.1995
# Example of a server with additional protocols (rsync)
# Useful for defining servers which could be used for rsync sources
- server: mirrors.dotsrc.org/armbian-apt/
weight: 15
protocols:
- rsync
````
## API

View File

@ -1,4 +1,4 @@
package redirector
package main
import (
"testing"

113
check.go
View File

@ -1,4 +1,4 @@
package redirector
package main
import (
"crypto/tls"
@ -10,26 +10,18 @@ import (
"net/http"
"net/url"
"runtime"
"strings"
"time"
)
var (
ErrHttpsRedirect = errors.New("unexpected forced https redirect")
ErrHttpRedirect = errors.New("unexpected redirect to insecure url")
ErrCertExpired = errors.New("certificate is expired")
)
func (r *Redirector) checkHttp(scheme string) ServerCheck {
return func(server *Server, logFields log.Fields) (bool, error) {
return r.checkHttpScheme(server, scheme, logFields)
}
}
// checkHttp checks a URL for validity, and checks redirects
func (r *Redirector) checkHttpScheme(server *Server, scheme string, logFields log.Fields) (bool, error) {
func checkHttp(server *Server, logFields log.Fields) (bool, error) {
u := &url.URL{
Scheme: scheme,
Scheme: "http",
Host: server.Host,
Path: server.Path,
}
@ -42,7 +34,7 @@ func (r *Redirector) checkHttpScheme(server *Server, scheme string, logFields lo
return false, err
}
res, err := r.config.checkClient.Do(req)
res, err := checkClient.Do(req)
if err != nil {
return false, err
@ -56,20 +48,13 @@ func (r *Redirector) checkHttpScheme(server *Server, scheme string, logFields lo
logFields["url"] = location
switch u.Scheme {
case "http":
res, err := r.checkRedirect(u.Scheme, location)
// Check that we don't redirect to https from a http url
if u.Scheme == "http" {
res, err := checkRedirect(location)
if !res || err != nil {
// If we don't support http, we remove it from supported protocols
server.Protocols = server.Protocols.Remove("http")
} else {
// Otherwise, we verify https support
r.checkProtocol(server, "https")
return res, err
}
case "https":
// We don't want to allow downgrading, so this is an error.
return r.checkRedirect(u.Scheme, location)
}
}
@ -81,20 +66,8 @@ func (r *Redirector) checkHttpScheme(server *Server, scheme string, logFields lo
return false, nil
}
func (r *Redirector) checkProtocol(server *Server, scheme string) {
res, err := r.checkHttpScheme(server, scheme, log.Fields{})
if !res || err != nil {
return
}
if !server.Protocols.Contains(scheme) {
server.Protocols = server.Protocols.Append(scheme)
}
}
// checkRedirect parses a location header response and checks the scheme
func (r *Redirector) checkRedirect(originatingScheme, locationHeader string) (bool, error) {
func checkRedirect(locationHeader string) (bool, error) {
newUrl, err := url.Parse(locationHeader)
if err != nil {
@ -103,41 +76,20 @@ func (r *Redirector) checkRedirect(originatingScheme, locationHeader string) (bo
if newUrl.Scheme == "https" {
return false, ErrHttpsRedirect
} else if originatingScheme == "https" && newUrl.Scheme == "http" {
return false, ErrHttpRedirect
}
return true, nil
}
// checkTLS checks tls certificates from a host, ensures they're valid, and not expired.
func (r *Redirector) checkTLS(server *Server, logFields log.Fields) (bool, error) {
var host, port string
var err error
if strings.Contains(server.Host, ":") {
host, port, err = net.SplitHostPort(server.Host)
if err != nil {
return false, err
}
} else {
host = server.Host
}
log.WithFields(log.Fields{
"server": server.Host,
"host": host,
"port": port,
}).Info("Checking TLS server")
func checkTLS(server *Server, logFields log.Fields) (bool, error) {
host, port, err := net.SplitHostPort(server.Host)
if port == "" {
port = "443"
}
conn, err := tls.Dial("tcp", host+":"+port, &tls.Config{
RootCAs: r.config.RootCAs,
})
conn, err := tls.Dial("tcp", host+":"+port, checkTLSConfig)
if err != nil {
return false, err
@ -155,38 +107,18 @@ func (r *Redirector) checkTLS(server *Server, logFields log.Fields) (bool, error
state := conn.ConnectionState()
peerPool := x509.NewCertPool()
for _, intermediate := range state.PeerCertificates {
if !intermediate.IsCA {
continue
}
peerPool.AddCert(intermediate)
}
opts := x509.VerifyOptions{
Roots: r.config.RootCAs,
Intermediates: peerPool,
CurrentTime: time.Now(),
CurrentTime: time.Now(),
}
// We want only the leaf certificate, as this will verify up the chain for us.
cert := state.PeerCertificates[0]
if _, err := cert.Verify(opts); err != nil {
logFields["peerCert"] = cert.Subject.String()
if authErr, ok := err.(x509.UnknownAuthorityError); ok {
logFields["authCert"] = authErr.Cert.Subject.String()
logFields["ca"] = authErr.Cert.Issuer
for _, cert := range state.PeerCertificates {
if _, err := cert.Verify(opts); err != nil {
logFields["peerCert"] = cert.Subject.String()
return false, err
}
if now.Before(cert.NotBefore) || now.After(cert.NotAfter) {
return false, err
}
return false, err
}
if now.Before(cert.NotBefore) || now.After(cert.NotAfter) {
logFields["peerCert"] = cert.Subject.String()
return false, err
}
for _, chain := range state.VerifiedChains {
@ -198,10 +130,5 @@ func (r *Redirector) checkTLS(server *Server, logFields log.Fields) (bool, error
}
}
// If https is valid, append it
if !server.Protocols.Contains("https") {
server.Protocols = server.Protocols.Append("https")
}
return true, nil
}

View File

@ -1,4 +1,4 @@
package redirector
package main
import (
"crypto/rand"
@ -58,14 +58,11 @@ var _ = Describe("Check suite", func() {
httpServer *httptest.Server
server *Server
handler http.HandlerFunc
r *Redirector
)
BeforeEach(func() {
httpServer = httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handler(w, r)
}))
r = New(&Config{})
r.config.SetRootCAs(x509.NewCertPool())
})
AfterEach(func() {
httpServer.Close()
@ -92,11 +89,22 @@ var _ = Describe("Check suite", func() {
w.WriteHeader(http.StatusOK)
}
res, err := r.checkHttpScheme(server, "http", log.Fields{})
res, err := checkHttp(server, log.Fields{})
Expect(res).To(BeTrue())
Expect(err).To(BeNil())
})
It("Should return an error when redirected to https", func() {
handler = func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Location", strings.Replace(httpServer.URL, "http://", "https://", -1))
w.WriteHeader(http.StatusMovedPermanently)
}
res, err := checkHttp(server, log.Fields{})
Expect(res).To(BeFalse())
Expect(err).To(Equal(ErrHttpsRedirect))
})
})
Context("TLS Checks", func() {
var (
@ -125,60 +133,50 @@ var _ = Describe("Check suite", func() {
Certificates: []tls.Certificate{tlsPair},
}
pool := x509.NewCertPool()
pool.AddCert(x509Cert)
r.config.SetRootCAs(pool)
httpServer.StartTLS()
setupServer()
}
Context("HTTPS Checks", func() {
BeforeEach(func() {
setupCerts(time.Now(), time.Now().Add(24*time.Hour))
})
It("Should return an error when redirected to http from https", func() {
handler = func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Location", strings.Replace(httpServer.URL, "https://", "http://", -1))
w.WriteHeader(http.StatusMovedPermanently)
}
logFields := log.Fields{}
res, err := r.checkHttpScheme(server, "https", logFields)
Expect(logFields["url"]).ToNot(BeEmpty())
Expect(logFields["url"]).ToNot(Equal(httpServer.URL))
Expect(err).To(Equal(ErrHttpRedirect))
Expect(res).To(BeFalse())
})
})
Context("CA Tests", func() {
BeforeEach(func() {
setupCerts(time.Now(), time.Now().Add(24*time.Hour))
})
It("Should fail due to invalid ca", func() {
r.config.SetRootCAs(x509.NewCertPool())
res, err := r.checkTLS(server, log.Fields{})
res, err := checkTLS(server, log.Fields{})
Expect(res).To(BeFalse())
Expect(err).ToNot(BeNil())
})
It("Should successfully validate certificates (valid ca, valid date/times, etc)", func() {
res, err := r.checkTLS(server, log.Fields{})
pool := x509.NewCertPool()
pool.AddCert(x509Cert)
checkTLSConfig = &tls.Config{RootCAs: pool}
res, err := checkTLS(server, log.Fields{})
Expect(res).To(BeFalse())
Expect(err).ToNot(BeNil())
checkTLSConfig = nil
})
})
Context("Expiration tests", func() {
AfterEach(func() {
checkTLSConfig = nil
})
It("Should fail due to not yet valid certificate", func() {
setupCerts(time.Now().Add(5*time.Hour), time.Now().Add(10*time.Hour))
// Trust our certs
pool := x509.NewCertPool()
pool.AddCert(x509Cert)
checkTLSConfig = &tls.Config{RootCAs: pool}
// Check TLS
res, err := r.checkTLS(server, log.Fields{})
res, err := checkTLS(server, log.Fields{})
Expect(res).To(BeFalse())
Expect(err).ToNot(BeNil())
@ -186,8 +184,15 @@ var _ = Describe("Check suite", func() {
It("Should fail due to expired certificate", func() {
setupCerts(time.Now().Add(-10*time.Hour), time.Now().Add(-5*time.Hour))
// Trust our certs
pool := x509.NewCertPool()
pool.AddCert(x509Cert)
checkTLSConfig = &tls.Config{RootCAs: pool}
// Check TLS
res, err := r.checkTLS(server, log.Fields{})
res, err := checkTLS(server, log.Fields{})
Expect(res).To(BeFalse())
Expect(err).ToNot(BeNil())

View File

@ -1,112 +0,0 @@
package main
import (
"flag"
"github.com/armbian/redirector"
"github.com/armbian/redirector/util"
log "github.com/sirupsen/logrus"
"github.com/spf13/viper"
"os"
"os/signal"
"syscall"
)
var (
configFlag = flag.String("config", "", "configuration file path")
flagDebug = flag.Bool("debug", false, "Enable debug logging")
)
func main() {
flag.Parse()
if *flagDebug {
log.SetLevel(log.DebugLevel)
}
viper.SetDefault("bind", ":8080")
viper.SetDefault("cacheSize", 1024)
viper.SetDefault("topChoices", 3)
viper.SetDefault("reloadKey", redirector.RandomSequence(32))
viper.SetConfigName("dlrouter") // name of config file (without extension)
viper.SetConfigType("yaml") // REQUIRED if the config file does not have the extension in the name
viper.AddConfigPath("/etc/dlrouter/") // path to look for the config file in
viper.AddConfigPath("$HOME/.dlrouter") // call multiple times to add many search paths
viper.AddConfigPath(".") // optionally look for config in the working directory
if *configFlag != "" {
viper.SetConfigFile(*configFlag)
}
config := &redirector.Config{}
loadConfig := func(fatal bool) {
log.Info("Reading configuration")
// Bind reload to reading in the viper config, then deserializing
if err := viper.ReadInConfig(); err != nil {
log.WithError(err).Error("Unable to unmarshal configuration")
if fatal {
os.Exit(1)
}
}
log.Info("Unmarshalling configuration")
if err := viper.Unmarshal(config); err != nil {
log.WithError(err).Error("Unable to unmarshal configuration")
if fatal {
os.Exit(1)
}
}
log.Info("Updating root certificates")
certs, err := util.LoadCACerts()
if err != nil {
log.WithError(err).Error("Unable to load certificates")
if fatal {
os.Exit(1)
}
}
config.RootCAs = certs
}
config.ReloadFunc = func() {
loadConfig(false)
}
loadConfig(true)
redir := redirector.New(config)
// Because we have a bind address, we can start it without the return value.
redir.Start()
log.Info("Ready")
c := make(chan os.Signal)
signal.Notify(c, syscall.SIGKILL, syscall.SIGTERM, syscall.SIGHUP)
for {
sig := <-c
if sig != syscall.SIGHUP {
break
}
loadConfig(false)
err := redir.ReloadConfig()
if err != nil {
log.WithError(err).Warning("Did not reload configuration due to error")
}
}
}

194
config.go
View File

@ -1,175 +1,109 @@
package redirector
package main
import (
"crypto/tls"
"crypto/x509"
lru "github.com/hashicorp/golang-lru"
"github.com/oschwald/maxminddb-golang"
"github.com/pkg/errors"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
log "github.com/sirupsen/logrus"
"github.com/spf13/viper"
"net"
"net/http"
"net/url"
"strings"
"sync"
"time"
)
type Config struct {
BindAddress string `mapstructure:"bind"`
GeoDBPath string `mapstructure:"geodb"`
MapFile string `mapstructure:"dl_map"`
CacheSize int `mapstructure:"cacheSize"`
TopChoices int `mapstructure:"topChoices"`
ReloadToken string `mapstructure:"reloadToken"`
ServerList []ServerConfig `mapstructure:"servers"`
ReloadFunc func()
RootCAs *x509.CertPool
checkClient *http.Client
}
// SetRootCAs sets the root ca files, and creates the http client for checks
// This **MUST** be called before r.checkClient is used.
func (c *Config) SetRootCAs(cas *x509.CertPool) {
c.RootCAs = cas
t := &http.Transport{
TLSClientConfig: &tls.Config{
RootCAs: cas,
},
}
c.checkClient = &http.Client{
Transport: t,
Timeout: 20 * time.Second,
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
}
}
type ProtocolList []string
func (p ProtocolList) Contains(value string) bool {
for _, val := range p {
if value == val {
return true
}
}
return false
}
func (p ProtocolList) Append(value string) ProtocolList {
return append(p, value)
}
func (p ProtocolList) Remove(value string) ProtocolList {
index := -1
for i, val := range p {
if value == val {
index = i
break
}
}
if index == -1 {
return p
}
p[index] = p[len(p)-1]
return p[:len(p)-1]
}
func (r *Redirector) ReloadConfig() error {
func reloadConfig() error {
log.Info("Loading configuration...")
var err error
err := viper.ReadInConfig() // Find and read the config file
// Load maxmind database
if r.db != nil {
err = r.db.Close()
if err != nil {
return errors.Wrap(err, "Unable to close database")
}
if err != nil { // Handle errors reading the config file
return errors.Wrap(err, "Unable to read configuration")
}
// db can be hot-reloaded if the file changed
r.db, err = maxminddb.Open(r.config.GeoDBPath)
// db will never be reloaded.
if db == nil {
// Load maxmind database
db, err = maxminddb.Open(viper.GetString("geodb"))
if err != nil {
return errors.Wrap(err, "Unable to open database")
if err != nil {
return errors.Wrap(err, "Unable to open database")
}
}
// Refresh server cache if size changed
if r.serverCache == nil {
r.serverCache, err = lru.New(r.config.CacheSize)
if serverCache == nil {
serverCache, err = lru.New(viper.GetInt("cacheSize"))
} else {
r.serverCache.Resize(r.config.CacheSize)
serverCache.Resize(viper.GetInt("cacheSize"))
}
// Purge the cache to ensure we don't have any invalid servers in it
r.serverCache.Purge()
serverCache.Purge()
// Set top choice count
topChoices = viper.GetInt("topChoices")
// Reload map file
if err := r.reloadMap(); err != nil {
if err := reloadMap(); err != nil {
return errors.Wrap(err, "Unable to load map file")
}
// Reload server list
if err := r.reloadServers(); err != nil {
if err := reloadServers(); err != nil {
return errors.Wrap(err, "Unable to load servers")
}
// Create mirror map
mirrors := make(map[string][]*Server)
for _, server := range r.servers {
for _, server := range servers {
mirrors[server.Continent] = append(mirrors[server.Continent], server)
}
mirrors["default"] = append(mirrors["NA"], mirrors["EU"]...)
r.regionMap = mirrors
regionMap = mirrors
hosts := make(map[string]*Server)
for _, server := range r.servers {
for _, server := range servers {
hosts[server.Host] = server
}
r.hostMap = hosts
hostMap = hosts
// Check top choices size
if r.config.TopChoices > len(r.servers) {
r.config.TopChoices = len(r.servers)
if topChoices > len(servers) {
topChoices = len(servers)
}
// Force check
go r.servers.Check(r.checks)
go servers.Check()
return nil
}
func (r *Redirector) reloadServers() error {
log.WithField("count", len(r.config.ServerList)).Info("Loading servers")
func reloadServers() error {
var serverList []ServerConfig
if err := viper.UnmarshalKey("servers", &serverList); err != nil {
return err
}
var wg sync.WaitGroup
existing := make(map[string]int)
for i, server := range r.servers {
for i, server := range servers {
existing[server.Host] = i
}
hosts := make(map[string]bool)
var hostsLock sync.Mutex
for _, server := range r.config.ServerList {
for _, server := range serverList {
wg.Add(1)
var prefix string
@ -188,6 +122,8 @@ func (r *Redirector) reloadServers() error {
return err
}
hosts[u.Host] = true
i := -1
if v, exists := existing[u.Host]; exists {
@ -197,28 +133,19 @@ func (r *Redirector) reloadServers() error {
go func(i int, server ServerConfig, u *url.URL) {
defer wg.Done()
s, err := r.addServer(server, u)
if err != nil {
log.WithError(err).Warning("Unable to add server")
return
}
hostsLock.Lock()
hosts[u.Host] = true
hostsLock.Unlock()
s := addServer(server, u)
if _, ok := existing[u.Host]; ok {
s.Redirects = r.servers[i].Redirects
s.Redirects = servers[i].Redirects
r.servers[i] = s
servers[i] = s
} else {
s.Redirects = promauto.NewCounter(prometheus.CounterOpts{
Name: "armbian_router_redirects_" + metricReplacer.Replace(u.Host),
Help: "The number of redirects for server " + u.Host,
})
r.servers = append(r.servers, s)
servers = append(servers, s)
log.WithFields(log.Fields{
"server": u.Host,
@ -233,16 +160,16 @@ func (r *Redirector) reloadServers() error {
wg.Wait()
// Remove servers that no longer exist in the config
for i := len(r.servers) - 1; i >= 0; i-- {
if _, exists := hosts[r.servers[i].Host]; exists {
for i := len(servers) - 1; i >= 0; i-- {
if _, exists := hosts[servers[i].Host]; exists {
continue
}
log.WithFields(log.Fields{
"server": r.servers[i].Host,
"server": servers[i].Host,
}).Info("Removed server")
r.servers = append(r.servers[:i], r.servers[i+1:]...)
servers = append(servers[:i], servers[i+1:]...)
}
return nil
@ -252,7 +179,7 @@ var metricReplacer = strings.NewReplacer(".", "_", "-", "_")
// addServer takes ServerConfig and constructs a server.
// This will create duplicate servers, but it will overwrite existing ones when changed.
func (r *Redirector) addServer(server ServerConfig, u *url.URL) (*Server, error) {
func addServer(server ServerConfig, u *url.URL) *Server {
s := &Server{
Available: true,
Host: u.Host,
@ -261,15 +188,6 @@ func (r *Redirector) addServer(server ServerConfig, u *url.URL) (*Server, error)
Longitude: server.Longitude,
Continent: server.Continent,
Weight: server.Weight,
Protocols: ProtocolList{"http", "https"},
}
if len(server.Protocols) > 0 {
for _, proto := range server.Protocols {
if !s.Protocols.Contains(proto) {
s.Protocols = s.Protocols.Append(proto)
}
}
}
// Defaults to 10 to allow servers to be set lower for lower priority
@ -284,11 +202,11 @@ func (r *Redirector) addServer(server ServerConfig, u *url.URL) (*Server, error)
"error": err,
"server": s.Host,
}).Warning("Could not resolve address")
return nil, err
return nil
}
var city City
err = r.db.Lookup(ips[0], &city)
err = db.Lookup(ips[0], &city)
if err != nil {
log.WithFields(log.Fields{
@ -296,7 +214,7 @@ func (r *Redirector) addServer(server ServerConfig, u *url.URL) (*Server, error)
"server": s.Host,
"ip": ips[0],
}).Warning("Could not geolocate address")
return nil, err
return nil
}
if s.Continent == "" {
@ -308,11 +226,11 @@ func (r *Redirector) addServer(server ServerConfig, u *url.URL) (*Server, error)
s.Longitude = city.Location.Longitude
}
return s, nil
return s
}
func (r *Redirector) reloadMap() error {
mapFile := r.config.MapFile
func reloadMap() error {
mapFile := viper.GetString("dl_map")
if mapFile == "" {
return nil
@ -326,7 +244,7 @@ func (r *Redirector) reloadMap() error {
return err
}
r.dlMap = newMap
dlMap = newMap
return nil
}

View File

@ -34,10 +34,6 @@ servers:
- server: mirrors.bfsu.edu.cn/armbian/
- server: mirrors.dotsrc.org/armbian-apt/
weight: 15
protocols:
- http
- https
- rsync
- server: mirrors.netix.net/armbian/apt/
- server: mirrors.nju.edu.cn/armbian/
- server: mirrors.sustech.edu.cn/armbian/

5
go.mod
View File

@ -1,11 +1,10 @@
module github.com/armbian/redirector
module meow.tf/armbian-router
go 1.19
go 1.17
require (
github.com/chi-middleware/logrus-logger v0.2.0
github.com/go-chi/chi/v5 v5.0.7
github.com/gwatts/rootcerts v0.0.0-20220501184621-6eac2dff0b8d
github.com/hashicorp/golang-lru v0.5.4
github.com/jmcvetta/randutil v0.0.0-20150817122601-2bb1b664bcff
github.com/onsi/ginkgo/v2 v2.1.4

5
go.sum
View File

@ -128,6 +128,7 @@ github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9
github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk=
github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A=
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 h1:p104kn46Q8WdvHunIJ9dAyjPVtrBPhSr3KT2yUst43I=
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
@ -197,6 +198,7 @@ github.com/google/pprof v0.0.0-20201023163331-3e6fc7fc9c4c/go.mod h1:kpwsk12EmLe
github.com/google/pprof v0.0.0-20201203190320-1bf35d6f28c2/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
github.com/google/pprof v0.0.0-20210122040257-d980be63207e/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
github.com/google/pprof v0.0.0-20210226084205-cbba55b83ad5/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE8dj7HMvPfh66eeA2JYW7eFpSE=
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
github.com/google/pprof v0.0.0-20210601050228-01bbb1931b22/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
github.com/google/pprof v0.0.0-20210609004039-a478d1d731e9/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
@ -208,8 +210,6 @@ github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5m
github.com/googleapis/gax-go/v2 v2.1.0/go.mod h1:Q3nei7sK6ybPYH7twZdmQpAd1MKb7pfu6SK+H1/DsU0=
github.com/googleapis/gax-go/v2 v2.1.1/go.mod h1:hddJymUZASv3XPyGkUpKj8pPO47Rmb0eJc8R6ouapiM=
github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw=
github.com/gwatts/rootcerts v0.0.0-20220501184621-6eac2dff0b8d h1:Kp5G1kHMb2fAD9OiqWDXro4qLB8bQ2NusoorYya4Lbo=
github.com/gwatts/rootcerts v0.0.0-20220501184621-6eac2dff0b8d/go.mod h1:5Kt9XkWvkGi2OHOq0QsGxebHmhCcqJ8KCbNg/a6+n+g=
github.com/hashicorp/consul/api v1.12.0/go.mod h1:6pVBMo0ebnYdt2S3H87XhekM/HHrUoTD2XXb/VrZVy0=
github.com/hashicorp/consul/sdk v0.8.0/go.mod h1:GBvyrGALthsZObzUGsfgHZQDXjg4lOjagTIwIR1vPms=
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
@ -680,6 +680,7 @@ golang.org/x/tools v0.1.2/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
golang.org/x/tools v0.1.3/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
golang.org/x/tools v0.1.4/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
golang.org/x/tools v0.1.10 h1:QjFRCZxdOhBJ/UNgnBZLbNV13DlbnK0quyivTnXJM20=
golang.org/x/tools v0.1.10/go.mod h1:Uh6Zz+xoGYZom868N8YTex3t7RhtHDBrE8Gzo9bV56E=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=

67
http.go
View File

@ -1,9 +1,10 @@
package redirector
package main
import (
"encoding/json"
"fmt"
"github.com/jmcvetta/randutil"
"github.com/spf13/viper"
"net"
"net/http"
"net/url"
@ -13,10 +14,10 @@ import (
)
// statusHandler is a simple handler that will always return 200 OK with a body of "OK"
func (r *Redirector) statusHandler(w http.ResponseWriter, req *http.Request) {
func statusHandler(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
if req.Method != http.MethodHead {
if r.Method != http.MethodHead {
w.Write([]byte("OK"))
}
}
@ -24,8 +25,8 @@ func (r *Redirector) statusHandler(w http.ResponseWriter, req *http.Request) {
// redirectHandler is the default "not found" handler which handles redirects
// if the environment variable OVERRIDE_IP is set, it will use that ip address
// this is useful for local testing when you're on the local network
func (r *Redirector) redirectHandler(w http.ResponseWriter, req *http.Request) {
ipStr, _, err := net.SplitHostPort(req.RemoteAddr)
func redirectHandler(w http.ResponseWriter, r *http.Request) {
ipStr, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
@ -49,11 +50,11 @@ func (r *Redirector) redirectHandler(w http.ResponseWriter, req *http.Request) {
// If the path has a prefix of region/NA, it will use specific regions instead
// of the default geographical distance
if strings.HasPrefix(req.URL.Path, "/region") {
parts := strings.Split(req.URL.Path, "/")
if strings.HasPrefix(r.URL.Path, "/region") {
parts := strings.Split(r.URL.Path, "/")
// region = parts[2]
if mirrors, ok := r.regionMap[parts[2]]; ok {
if mirrors, ok := regionMap[parts[2]]; ok {
choices := make([]randutil.Choice, len(mirrors))
for i, item := range mirrors {
@ -76,20 +77,13 @@ func (r *Redirector) redirectHandler(w http.ResponseWriter, req *http.Request) {
server = choice.Item.(*Server)
req.URL.Path = strings.Join(parts[3:], "/")
r.URL.Path = strings.Join(parts[3:], "/")
}
}
// If we don't have a scheme, we'll use http by default
scheme := req.URL.Scheme
if scheme == "" {
scheme = "http"
}
// If none of the above exceptions are matched, we use the geographical distance based on IP
if server == nil {
server, distance, err = r.servers.Closest(r, scheme, ip)
server, distance, err = servers.Closest(ip)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
@ -97,20 +91,27 @@ func (r *Redirector) redirectHandler(w http.ResponseWriter, req *http.Request) {
}
}
// If we don't have a scheme, we'll use https by default
scheme := r.URL.Scheme
if scheme == "" {
scheme = "https"
}
// redirectPath is a combination of server path (which can be something like /armbian)
// and the URL path.
// Example: /armbian + /some/path = /armbian/some/path
redirectPath := path.Join(server.Path, req.URL.Path)
redirectPath := path.Join(server.Path, r.URL.Path)
// If we have a dlMap, we map the url to a final path instead
if r.dlMap != nil {
if newPath, exists := r.dlMap[strings.TrimLeft(req.URL.Path, "/")]; exists {
if dlMap != nil {
if newPath, exists := dlMap[strings.TrimLeft(r.URL.Path, "/")]; exists {
downloadsMapped.Inc()
redirectPath = path.Join(server.Path, newPath)
}
}
if strings.HasSuffix(req.URL.Path, "/") && !strings.HasSuffix(redirectPath, "/") {
if strings.HasSuffix(r.URL.Path, "/") && !strings.HasSuffix(redirectPath, "/") {
redirectPath += "/"
}
@ -135,13 +136,15 @@ func (r *Redirector) redirectHandler(w http.ResponseWriter, req *http.Request) {
// reloadHandler is an http handler which lets us reload the server configuration
// It is only enabled when the reloadToken is set in the configuration
func (r *Redirector) reloadHandler(w http.ResponseWriter, req *http.Request) {
if r.config.ReloadToken == "" {
func reloadHandler(w http.ResponseWriter, r *http.Request) {
expectedToken := viper.GetString("reloadToken")
if expectedToken == "" {
w.WriteHeader(http.StatusUnauthorized)
return
}
token := req.Header.Get("Authorization")
token := r.Header.Get("Authorization")
if token == "" || !strings.HasPrefix(token, "Bearer") || !strings.Contains(token, " ") {
w.WriteHeader(http.StatusUnauthorized)
@ -150,12 +153,12 @@ func (r *Redirector) reloadHandler(w http.ResponseWriter, req *http.Request) {
token = token[strings.Index(token, " ")+1:]
if token != r.config.ReloadToken {
if token != expectedToken {
w.WriteHeader(http.StatusUnauthorized)
return
}
if err := r.ReloadConfig(); err != nil {
if err := reloadConfig(); err != nil {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(err.Error()))
return
@ -165,19 +168,19 @@ func (r *Redirector) reloadHandler(w http.ResponseWriter, req *http.Request) {
w.Write([]byte("OK"))
}
func (r *Redirector) dlMapHandler(w http.ResponseWriter, req *http.Request) {
if r.dlMap == nil {
func dlMapHandler(w http.ResponseWriter, r *http.Request) {
if dlMap == nil {
w.WriteHeader(http.StatusNotFound)
return
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(r.dlMap)
json.NewEncoder(w).Encode(dlMap)
}
func (r *Redirector) geoIPHandler(w http.ResponseWriter, req *http.Request) {
ipStr, _, err := net.SplitHostPort(req.RemoteAddr)
func geoIPHandler(w http.ResponseWriter, r *http.Request) {
ipStr, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
@ -187,7 +190,7 @@ func (r *Redirector) geoIPHandler(w http.ResponseWriter, req *http.Request) {
ip := net.ParseIP(ipStr)
var city City
err = r.db.Lookup(ip, &city)
err = db.Lookup(ip, &city)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)

154
main.go Normal file
View File

@ -0,0 +1,154 @@
package main
import (
"flag"
"github.com/chi-middleware/logrus-logger"
"github.com/go-chi/chi/v5"
lru "github.com/hashicorp/golang-lru"
"github.com/oschwald/maxminddb-golang"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"github.com/prometheus/client_golang/prometheus/promhttp"
log "github.com/sirupsen/logrus"
"github.com/spf13/viper"
"net/http"
"os"
"os/signal"
"syscall"
)
var (
db *maxminddb.Reader
servers ServerList
regionMap map[string][]*Server
hostMap map[string]*Server
dlMap map[string]string
topChoices int
redirectsServed = promauto.NewCounter(prometheus.CounterOpts{
Name: "armbian_router_redirects",
Help: "The total number of processed redirects",
})
downloadsMapped = promauto.NewCounter(prometheus.CounterOpts{
Name: "armbian_router_download_maps",
Help: "The total number of mapped download paths",
})
serverCache *lru.Cache
)
type LocationLookup struct {
Location struct {
Latitude float64 `maxminddb:"latitude"`
Longitude float64 `maxminddb:"longitude"`
} `maxminddb:"location"`
}
// City represents a MaxmindDB city
type City struct {
Continent struct {
Code string `maxminddb:"code" json:"code"`
GeoNameID uint `maxminddb:"geoname_id" json:"geoname_id"`
Names map[string]string `maxminddb:"names" json:"names"`
} `maxminddb:"continent" json:"continent"`
Country struct {
GeoNameID uint `maxminddb:"geoname_id" json:"geoname_id"`
IsoCode string `maxminddb:"iso_code" json:"iso_code"`
Names map[string]string `maxminddb:"names" json:"names"`
} `maxminddb:"country" json:"country"`
Location struct {
AccuracyRadius uint16 `maxminddb:"accuracy_radius" json:'accuracy_radius'`
Latitude float64 `maxminddb:"latitude" json:"latitude"`
Longitude float64 `maxminddb:"longitude" json:"longitude"`
} `maxminddb:"location"`
RegisteredCountry struct {
GeoNameID uint `maxminddb:"geoname_id" json:"geoname_id"`
IsoCode string `maxminddb:"iso_code" json:"iso_code"`
Names map[string]string `maxminddb:"names" json:"names"`
} `maxminddb:"registered_country" json:"registered_country"`
}
type ServerConfig struct {
Server string `mapstructure:"server" yaml:"server"`
Latitude float64 `mapstructure:"latitude" yaml:"latitude"`
Longitude float64 `mapstructure:"longitude" yaml:"longitude"`
Continent string `mapstructure:"continent"`
Weight int `mapstructure:"weight" yaml:"weight"`
}
var (
configFlag = flag.String("config", "", "configuration file path")
flagDebug = flag.Bool("debug", false, "Enable debug logging")
)
func main() {
flag.Parse()
if *flagDebug {
log.SetLevel(log.DebugLevel)
}
viper.SetDefault("bind", ":8080")
viper.SetDefault("cacheSize", 1024)
viper.SetDefault("topChoices", 3)
viper.SetDefault("reloadKey", randSeq(32))
viper.SetConfigName("dlrouter") // name of config file (without extension)
viper.SetConfigType("yaml") // REQUIRED if the config file does not have the extension in the name
viper.AddConfigPath("/etc/dlrouter/") // path to look for the config file in
viper.AddConfigPath("$HOME/.dlrouter") // call multiple times to add many search paths
viper.AddConfigPath(".") // optionally look for config in the working directory
if *configFlag != "" {
viper.SetConfigFile(*configFlag)
}
if err := reloadConfig(); err != nil {
log.WithError(err).Fatalln("Unable to load configuration")
}
// Start check loop
go servers.checkLoop()
log.Info("Starting")
r := chi.NewRouter()
r.Use(RealIPMiddleware)
r.Use(logger.Logger("router", log.StandardLogger()))
r.Head("/status", statusHandler)
r.Get("/status", statusHandler)
r.Get("/mirrors", legacyMirrorsHandler)
r.Get("/mirrors/{server}.svg", mirrorStatusHandler)
r.Get("/mirrors.json", mirrorsHandler)
r.Post("/reload", reloadHandler)
r.Get("/dl_map", dlMapHandler)
r.Get("/geoip", geoIPHandler)
r.Get("/metrics", promhttp.Handler().ServeHTTP)
r.NotFound(redirectHandler)
go http.ListenAndServe(viper.GetString("bind"), r)
log.Info("Ready")
c := make(chan os.Signal)
signal.Notify(c, syscall.SIGKILL, syscall.SIGTERM, syscall.SIGHUP)
for {
sig := <-c
if sig != syscall.SIGHUP {
break
}
err := reloadConfig()
if err != nil {
log.WithError(err).Warning("Did not reload configuration due to error")
}
}
}

2
map.go
View File

@ -1,4 +1,4 @@
package redirector
package main
import (
"encoding/csv"

View File

@ -1,4 +1,4 @@
package redirector
package main
import (
. "github.com/onsi/ginkgo/v2"

View File

@ -1,4 +1,4 @@
package middleware
package main
import (
"net"

View File

@ -1,4 +1,4 @@
package redirector
package main
import (
_ "embed"
@ -11,16 +11,16 @@ import (
// legacyMirrorsHandler will list the mirrors by region in the legacy format
// it is preferred to use mirrors.json, but this handler is here for build support
func (r *Redirector) legacyMirrorsHandler(w http.ResponseWriter, req *http.Request) {
func legacyMirrorsHandler(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
mirrorOutput := make(map[string][]string)
for region, mirrors := range r.regionMap {
for region, mirrors := range regionMap {
list := make([]string, len(mirrors))
for i, mirror := range mirrors {
list[i] = req.URL.Scheme + "://" + mirror.Host + "/" + strings.TrimLeft(mirror.Path, "/")
list[i] = r.URL.Scheme + "://" + mirror.Host + "/" + strings.TrimLeft(mirror.Path, "/")
}
mirrorOutput[region] = list
@ -30,9 +30,9 @@ func (r *Redirector) legacyMirrorsHandler(w http.ResponseWriter, req *http.Reque
}
// mirrorsHandler is a simple handler that will return the list of servers
func (r *Redirector) mirrorsHandler(w http.ResponseWriter, req *http.Request) {
func mirrorsHandler(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(r.servers)
json.NewEncoder(w).Encode(servers)
}
var (
@ -48,8 +48,8 @@ var (
// mirrorStatusHandler is a fancy svg-returning handler.
// it is used to display mirror statuses on a config repo of sorts
func (r *Redirector) mirrorStatusHandler(w http.ResponseWriter, req *http.Request) {
serverHost := chi.URLParam(req, "server")
func mirrorStatusHandler(w http.ResponseWriter, r *http.Request) {
serverHost := chi.URLParam(r, "server")
w.Header().Set("Content-Type", "image/svg+xml;charset=utf-8")
w.Header().Set("Cache-Control", "max-age=120")
@ -61,7 +61,7 @@ func (r *Redirector) mirrorStatusHandler(w http.ResponseWriter, req *http.Reques
serverHost = strings.Replace(serverHost, "_", ".", -1)
server, ok := r.hostMap[serverHost]
server, ok := hostMap[serverHost]
if !ok {
w.Header().Set("Content-Length", strconv.Itoa(len(statusUnknown)))
@ -77,7 +77,7 @@ func (r *Redirector) mirrorStatusHandler(w http.ResponseWriter, req *http.Reques
w.Header().Set("ETag", "\""+key+"\"")
if match := req.Header.Get("If-None-Match"); match != "" {
if match := r.Header.Get("If-None-Match"); match != "" {
if strings.Trim(match, "\"") == key {
w.WriteHeader(http.StatusNotModified)
return

View File

@ -1,131 +0,0 @@
package redirector
import (
"github.com/armbian/redirector/middleware"
"github.com/chi-middleware/logrus-logger"
"github.com/go-chi/chi/v5"
lru "github.com/hashicorp/golang-lru"
"github.com/oschwald/maxminddb-golang"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"github.com/prometheus/client_golang/prometheus/promhttp"
log "github.com/sirupsen/logrus"
"net/http"
)
var (
redirectsServed = promauto.NewCounter(prometheus.CounterOpts{
Name: "armbian_router_redirects",
Help: "The total number of processed redirects",
})
downloadsMapped = promauto.NewCounter(prometheus.CounterOpts{
Name: "armbian_router_download_maps",
Help: "The total number of mapped download paths",
})
)
type Redirector struct {
config *Config
db *maxminddb.Reader
servers ServerList
regionMap map[string][]*Server
hostMap map[string]*Server
dlMap map[string]string
topChoices int
serverCache *lru.Cache
checks []ServerCheck
checkClient *http.Client
}
type LocationLookup struct {
Location struct {
Latitude float64 `maxminddb:"latitude"`
Longitude float64 `maxminddb:"longitude"`
} `maxminddb:"location"`
}
// City represents a MaxmindDB city
type City struct {
Continent struct {
Code string `maxminddb:"code" json:"code"`
GeoNameID uint `maxminddb:"geoname_id" json:"geoname_id"`
Names map[string]string `maxminddb:"names" json:"names"`
} `maxminddb:"continent" json:"continent"`
Country struct {
GeoNameID uint `maxminddb:"geoname_id" json:"geoname_id"`
IsoCode string `maxminddb:"iso_code" json:"iso_code"`
Names map[string]string `maxminddb:"names" json:"names"`
} `maxminddb:"country" json:"country"`
Location struct {
AccuracyRadius uint16 `maxminddb:"accuracy_radius" json:'accuracy_radius'`
Latitude float64 `maxminddb:"latitude" json:"latitude"`
Longitude float64 `maxminddb:"longitude" json:"longitude"`
} `maxminddb:"location"`
RegisteredCountry struct {
GeoNameID uint `maxminddb:"geoname_id" json:"geoname_id"`
IsoCode string `maxminddb:"iso_code" json:"iso_code"`
Names map[string]string `maxminddb:"names" json:"names"`
} `maxminddb:"registered_country" json:"registered_country"`
}
type ServerConfig struct {
Server string `mapstructure:"server" yaml:"server"`
Latitude float64 `mapstructure:"latitude" yaml:"latitude"`
Longitude float64 `mapstructure:"longitude" yaml:"longitude"`
Continent string `mapstructure:"continent"`
Weight int `mapstructure:"weight" yaml:"weight"`
Protocols []string `mapstructure:"protocols" yaml:"protocols"`
}
// New creates a new instance of Redirector
func New(config *Config) *Redirector {
r := &Redirector{
config: config,
}
r.checks = []ServerCheck{
r.checkHttp("http"),
r.checkTLS,
}
return r
}
func (r *Redirector) Start() http.Handler {
if err := r.ReloadConfig(); err != nil {
log.WithError(err).Fatalln("Unable to load configuration")
}
log.Info("Starting check loop")
// Start check loop
go r.servers.checkLoop(r.checks)
log.Info("Setting up routes")
router := chi.NewRouter()
router.Use(middleware.RealIPMiddleware)
router.Use(logger.Logger("router", log.StandardLogger()))
router.Head("/status", r.statusHandler)
router.Get("/status", r.statusHandler)
router.Get("/mirrors", r.legacyMirrorsHandler)
router.Get("/mirrors/{server}.svg", r.mirrorStatusHandler)
router.Get("/mirrors.json", r.mirrorsHandler)
router.Post("/reload", r.reloadHandler)
router.Get("/dl_map", r.dlMapHandler)
router.Get("/geoip", r.geoIPHandler)
router.Get("/metrics", promhttp.Handler().ServeHTTP)
router.NotFound(r.redirectHandler)
if r.config.BindAddress != "" {
log.WithField("bind", r.config.BindAddress).Info("Binding to address")
go http.ListenAndServe(r.config.BindAddress, router)
}
return router
}

View File

@ -1,16 +1,34 @@
package redirector
package main
import (
"crypto/tls"
"github.com/jmcvetta/randutil"
"github.com/prometheus/client_golang/prometheus"
log "github.com/sirupsen/logrus"
"math"
"net"
"net/http"
"sort"
"sync"
"time"
)
var (
checkClient = &http.Client{
Timeout: 20 * time.Second,
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
}
checkTLSConfig *tls.Config = nil
checks = []serverCheck{
checkHttp,
checkTLS,
}
)
// Server represents a download server
type Server struct {
Available bool `json:"available"`
@ -20,15 +38,14 @@ type Server struct {
Longitude float64 `json:"longitude"`
Weight int `json:"weight"`
Continent string `json:"continent"`
Protocols ProtocolList `json:"protocols"`
Redirects prometheus.Counter `json:"-"`
LastChange time.Time `json:"lastChange"`
}
type ServerCheck func(server *Server, logFields log.Fields) (bool, error)
type serverCheck func(server *Server, logFields log.Fields) (bool, error)
// checkStatus runs all status checks against a server
func (server *Server) checkStatus(checks []ServerCheck) {
func (server *Server) checkStatus() {
logFields := log.Fields{
"host": server.Host,
}
@ -70,19 +87,19 @@ func (server *Server) checkStatus(checks []ServerCheck) {
type ServerList []*Server
func (s ServerList) checkLoop(checks []ServerCheck) {
func (s ServerList) checkLoop() {
t := time.NewTicker(60 * time.Second)
for {
<-t.C
s.Check(checks)
s.Check()
}
}
// Check will request the index from all servers
// If a server does not respond in 10 seconds, it is considered offline.
// This will wait until all checks are complete.
func (s ServerList) Check(checks []ServerCheck) {
func (s ServerList) Check() {
var wg sync.WaitGroup
for _, server := range s {
@ -91,7 +108,7 @@ func (s ServerList) Check(checks []ServerCheck) {
go func(server *Server) {
defer wg.Done()
server.checkStatus(checks)
server.checkStatus()
}(server)
}
@ -110,12 +127,12 @@ type DistanceList []ComputedDistance
// Closest will use GeoIP on the IP provided and find the closest servers.
// When we have a list of x servers closest, we can choose a random or weighted one.
// Return values are the closest server, the distance, and if an error occurred.
func (s ServerList) Closest(r *Redirector, scheme string, ip net.IP) (*Server, float64, error) {
choiceInterface, exists := r.serverCache.Get(scheme + "_" + ip.String())
func (s ServerList) Closest(ip net.IP) (*Server, float64, error) {
choiceInterface, exists := serverCache.Get(ip.String())
if !exists {
var city LocationLookup
err := r.db.Lookup(ip, &city)
err := db.Lookup(ip, &city)
if err != nil {
return nil, -1, err
@ -124,7 +141,7 @@ func (s ServerList) Closest(r *Redirector, scheme string, ip net.IP) (*Server, f
c := make(DistanceList, len(s))
for i, server := range s {
if !server.Available || !server.Protocols.Contains(scheme) {
if !server.Available {
continue
}
@ -141,9 +158,9 @@ func (s ServerList) Closest(r *Redirector, scheme string, ip net.IP) (*Server, f
return c[i].Distance < c[j].Distance
})
choiceCount := r.config.TopChoices
choiceCount := topChoices
if len(c) < r.config.TopChoices {
if len(c) < topChoices {
choiceCount = len(c)
}
@ -162,7 +179,7 @@ func (s ServerList) Closest(r *Redirector, scheme string, ip net.IP) (*Server, f
choiceInterface = choices
r.serverCache.Add(scheme+"_"+ip.String(), choiceInterface)
serverCache.Add(ip.String(), choiceInterface)
}
choice, err := randutil.WeightedChoice(choiceInterface.([]randutil.Choice))
@ -175,9 +192,9 @@ func (s ServerList) Closest(r *Redirector, scheme string, ip net.IP) (*Server, f
if !dist.Server.Available {
// Choose a new server and refresh cache
r.serverCache.Remove(scheme + "_" + ip.String())
serverCache.Remove(ip.String())
return s.Closest(r, scheme, ip)
return s.Closest(ip)
}
return dist.Server, dist.Distance, nil
@ -189,10 +206,9 @@ func hsin(theta float64) float64 {
}
// Distance function returns the distance (in meters) between two points of
//
// a given longitude and latitude relatively accurately (using a spherical
// approximation of the Earth) through the Haversine Distance Formula for
// great arc distance on a sphere with accuracy for small distances
// a given longitude and latitude relatively accurately (using a spherical
// approximation of the Earth) through the Haversine Distance Formula for
// great arc distance on a sphere with accuracy for small distances
//
// point coordinates are supplied in degrees and converted into rad. in the func
//

View File

@ -1,10 +1,10 @@
package redirector
package main
import "math/rand"
var letters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")
func RandomSequence(n int) string {
func randSeq(n int) string {
b := make([]rune, n)
for i := range b {
b[i] = letters[rand.Intn(len(letters))]

View File

@ -1,46 +0,0 @@
package util
import (
"crypto/x509"
"github.com/gwatts/rootcerts/certparse"
log "github.com/sirupsen/logrus"
"net/http"
)
const (
defaultDownloadURL = "https://github.com/mozilla/gecko-dev/blob/master/security/nss/lib/ckfw/builtins/certdata.txt?raw=true"
)
func LoadCACerts() (*x509.CertPool, error) {
res, err := http.Get(defaultDownloadURL)
if err != nil {
return nil, err
}
defer res.Body.Close()
certs, err := certparse.ReadTrustedCerts(res.Body)
if err != nil {
return nil, err
}
pool := x509.NewCertPool()
var count int
for _, cert := range certs {
if cert.Trust&certparse.ServerTrustedDelegator == 0 {
continue
}
count++
pool.AddCert(cert.Cert)
}
log.WithField("certs", count).Info("Loaded root cas")
return pool, nil
}