diff --git a/integration_tests/performance_test.go b/integration_tests/performance_test.go index 08be778e..1b9def32 100644 --- a/integration_tests/performance_test.go +++ b/integration_tests/performance_test.go @@ -2,6 +2,7 @@ package integration import ( "net/http/httptest" + "os" "time" . "github.com/onsi/ginkgo/v2" @@ -22,15 +23,15 @@ var _ = Describe("Performance", func() { BeforeEach(func() { backend1 = startSimpleBackend("backend 1") backend2 = startSimpleBackend("backend 2") - addBackend("backend-1", backend1.URL) - addBackend("backend-2", backend2.URL) + os.Setenv("BACKEND_URL_backend-1", backend1.URL) + os.Setenv("BACKEND_URL_backend-2", backend2.URL) addRoute("/one", NewBackendRoute("backend-1")) addRoute("/two", NewBackendRoute("backend-2")) reloadRoutes(apiPort) }) AfterEach(func() { - backend1.Close() - backend2.Close() + os.Unsetenv("BACKEND_URL_backend-1") + os.Unsetenv("BACKEND_URL_backend-2") }) It("Router should not cause errors or much latency", func() { @@ -60,7 +61,8 @@ var _ = Describe("Performance", func() { It("Router should not cause errors or much latency", func() { slowBackend := startTarpitBackend(time.Second) defer slowBackend.Close() - addBackend("backend-slow", slowBackend.URL) + os.Setenv("BACKEND_URL_backend-slow", slowBackend.URL) + defer os.Unsetenv("BACKEND_URL_backend-slow") addRoute("/slow", NewBackendRoute("backend-slow")) reloadRoutes(apiPort) @@ -73,7 +75,9 @@ var _ = Describe("Performance", func() { Describe("with one downed backend hit separately", func() { It("Router should not cause errors or much latency", func() { - addBackend("backend-down", "http://127.0.0.1:3162/") + os.Setenv("BACKEND_URL_backend-down", "http://127.0.0.1:3162/") + defer os.Unsetenv("BACKEND_URL_backend-down") + addRoute("/down", NewBackendRoute("backend-down")) reloadRoutes(apiPort) @@ -98,13 +102,15 @@ var _ = Describe("Performance", func() { BeforeEach(func() { backend1 = startTarpitBackend(time.Second) backend2 = startTarpitBackend(time.Second) - addBackend("backend-1", backend1.URL) - addBackend("backend-2", backend2.URL) + os.Setenv("BACKEND_URL_backend-1", backend1.URL) + os.Setenv("BACKEND_URL_backend-2", backend2.URL) addRoute("/one", NewBackendRoute("backend-1")) addRoute("/two", NewBackendRoute("backend-2")) reloadRoutes(apiPort) }) AfterEach(func() { + os.Unsetenv("BACKEND_URL_backend-1") + os.Unsetenv("BACKEND_URL_backend-2") backend1.Close() backend2.Close() }) diff --git a/integration_tests/proxy_function_test.go b/integration_tests/proxy_function_test.go index cb383137..a51c6804 100644 --- a/integration_tests/proxy_function_test.go +++ b/integration_tests/proxy_function_test.go @@ -6,6 +6,7 @@ import ( "net/http/httptest" "net/textproto" "net/url" + "os" "strings" "time" @@ -19,7 +20,9 @@ var _ = Describe("Functioning as a reverse proxy", func() { Describe("connecting to the backend", func() { It("should return a 502 if the connection to the backend is refused", func() { - addBackend("not-running", "http://127.0.0.1:3164/") + os.Setenv("BACKEND_URL_not-running", "http://127.0.0.1:3164/") + defer os.Unsetenv("BACKEND_URL_not-running") + addRoute("/not-running", NewBackendRoute("not-running")) reloadRoutes(apiPort) @@ -45,7 +48,9 @@ var _ = Describe("Functioning as a reverse proxy", func() { Expect(err).NotTo(HaveOccurred()) defer stopRouter(3167) - addBackend("black-hole", "http://240.0.0.0:1234/") + os.Setenv("BACKEND_URL_black-hole", "http://240.0.0.0:1234/") + defer os.Unsetenv("BACKEND_URL_black-hole") + addRoute("/should-time-out", NewBackendRoute("black-hole")) reloadRoutes(3166) @@ -78,14 +83,16 @@ var _ = Describe("Functioning as a reverse proxy", func() { Expect(err).NotTo(HaveOccurred()) tarpit1 = startTarpitBackend(time.Second) tarpit2 = startTarpitBackend(100*time.Millisecond, 500*time.Millisecond) - addBackend("tarpit1", tarpit1.URL) - addBackend("tarpit2", tarpit2.URL) + os.Setenv("BACKEND_URL_tarpit1", tarpit1.URL) + os.Setenv("BACKEND_URL_tarpit2", tarpit2.URL) addRoute("/tarpit1", NewBackendRoute("tarpit1")) addRoute("/tarpit2", NewBackendRoute("tarpit2")) reloadRoutes(3166) }) AfterEach(func() { + os.Unsetenv("BACKEND_URL_tarpit1") + os.Unsetenv("BACKEND_URL_tarpit2") tarpit1.Close() tarpit2.Close() stopRouter(3167) @@ -119,12 +126,13 @@ var _ = Describe("Functioning as a reverse proxy", func() { Describe("header handling", func() { BeforeEach(func() { recorder = startRecordingBackend() - addBackend("backend", recorder.URL()) + os.Setenv("BACKEND_URL_backend", recorder.URL()) addRoute("/foo", NewBackendRoute("backend", "prefix")) reloadRoutes(apiPort) }) AfterEach(func() { + os.Unsetenv("BACKEND_URL_backend") recorder.Close() }) @@ -243,12 +251,13 @@ var _ = Describe("Functioning as a reverse proxy", func() { Describe("request verb, path, query and body handling", func() { BeforeEach(func() { recorder = startRecordingBackend() - addBackend("backend", recorder.URL()) + os.Setenv("BACKEND_URL_backend", recorder.URL()) addRoute("/foo", NewBackendRoute("backend", "prefix")) reloadRoutes(apiPort) }) AfterEach(func() { + os.Unsetenv("BACKEND_URL_backend") recorder.Close() }) @@ -299,12 +308,13 @@ var _ = Describe("Functioning as a reverse proxy", func() { Describe("handling a backend with a non '/' path", func() { BeforeEach(func() { recorder = startRecordingBackend() - addBackend("backend", recorder.URL()+"/something") + os.Setenv("BACKEND_URL_backend", recorder.URL()+"/something") addRoute("/foo/bar", NewBackendRoute("backend", "prefix")) reloadRoutes(apiPort) }) AfterEach(func() { + os.Unsetenv("BACKEND_URL_backend") recorder.Close() }) @@ -330,12 +340,13 @@ var _ = Describe("Functioning as a reverse proxy", func() { Describe("handling HTTP/1.0 requests", func() { BeforeEach(func() { recorder = startRecordingBackend() - addBackend("backend", recorder.URL()) + os.Setenv("BACKEND_URL_backend", recorder.URL()) addRoute("/foo", NewBackendRoute("backend", "prefix")) reloadRoutes(apiPort) }) AfterEach(func() { + os.Unsetenv("BACKEND_URL_backend") recorder.Close() }) @@ -365,12 +376,13 @@ var _ = Describe("Functioning as a reverse proxy", func() { err := startRouter(3167, 3166, []string{"ROUTER_TLS_SKIP_VERIFY=1"}) Expect(err).NotTo(HaveOccurred()) recorder = startRecordingTLSBackend() - addBackend("backend", recorder.URL()) + os.Setenv("BACKEND_URL_backend", recorder.URL()) addRoute("/foo", NewBackendRoute("backend", "prefix")) reloadRoutes(3166) }) AfterEach(func() { + os.Unsetenv("BACKEND_URL_backend") recorder.Close() stopRouter(3167) }) diff --git a/integration_tests/redirect_test.go b/integration_tests/redirect_test.go index 0179d4ac..1b8a2d33 100644 --- a/integration_tests/redirect_test.go +++ b/integration_tests/redirect_test.go @@ -1,6 +1,7 @@ package integration import ( + "os" "time" . "github.com/onsi/ginkgo/v2" @@ -223,13 +224,14 @@ var _ = Describe("Redirection", func() { BeforeEach(func() { recorder = startRecordingBackend() - addBackend("be", recorder.URL()) + os.Setenv("BACKEND_URL_be", recorder.URL()) addRoute("/guidance/keeping-a-pet-pig-or-micropig", NewBackendRoute("be", "exact")) addRoute("/GUIDANCE/keeping-a-pet-pig-or-micropig", NewBackendRoute("be", "exact")) reloadRoutes(apiPort) }) AfterEach(func() { + os.Unsetenv("BACKEND_URL_be") recorder.Close() }) diff --git a/integration_tests/route_helpers.go b/integration_tests/route_helpers.go index 84ed1e53..4d2795b8 100644 --- a/integration_tests/route_helpers.go +++ b/integration_tests/route_helpers.go @@ -6,7 +6,6 @@ import ( "time" "github.com/globalsign/mgo" - "github.com/globalsign/mgo/bson" // revive:disable:dot-imports . "github.com/onsi/ginkgo/v2" @@ -91,11 +90,6 @@ func initRouteHelper() error { return nil } -func addBackend(id, url string) { - err := routerDB.C("backends").Insert(bson.M{"backend_id": id, "backend_url": url}) - Expect(err).NotTo(HaveOccurred()) -} - func addRoute(path string, route Route) { route.IncomingPath = path diff --git a/integration_tests/route_loading_test.go b/integration_tests/route_loading_test.go index 827c1deb..f621c42a 100644 --- a/integration_tests/route_loading_test.go +++ b/integration_tests/route_loading_test.go @@ -1,8 +1,8 @@ package integration import ( - "fmt" "net/http/httptest" + "os" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" @@ -12,16 +12,17 @@ var _ = Describe("loading routes from the db", func() { var ( backend1 *httptest.Server backend2 *httptest.Server - backend3 *httptest.Server ) BeforeEach(func() { backend1 = startSimpleBackend("backend 1") backend2 = startSimpleBackend("backend 2") - addBackend("backend-1", backend1.URL) - addBackend("backend-2", backend2.URL) + os.Setenv("BACKEND_URL_backend-1", backend1.URL) + os.Setenv("BACKEND_URL_backend-2", backend2.URL) }) AfterEach(func() { + os.Unsetenv("BACKEND_URL_backend-1") + os.Unsetenv("BACKEND_URL_backend-2") backend1.Close() backend2.Close() }) @@ -73,34 +74,4 @@ var _ = Describe("loading routes from the db", func() { Expect(readBody(resp)).To(Equal("backend 1")) }) }) - - Context("a backend an env var overriding the backend_url", func() { - BeforeEach(func() { - // This tests the behaviour of backend.ParseURL overriding the backend_url - // provided in the DB with the value of an env var - blackHole := "240.0.0.0/foo" - backend3 = startSimpleBackend("backend 3") - addBackend("backend-3", blackHole) - - stopRouter(routerPort) - err := startRouter(routerPort, apiPort, []string{fmt.Sprintf("BACKEND_URL_backend-3=%s", backend3.URL)}) - Expect(err).NotTo(HaveOccurred()) - - addRoute("/oof", NewBackendRoute("backend-3")) - reloadRoutes(apiPort) - }) - - AfterEach(func() { - stopRouter(routerPort) - err := startRouter(routerPort, apiPort, nil) - Expect(err).NotTo(HaveOccurred()) - backend3.Close() - }) - - It("should send requests to the backend_url provided in the env var", func() { - resp := routerRequest(routerPort, "/oof") - Expect(resp.StatusCode).To(Equal(200)) - Expect(readBody(resp)).To(Equal("backend 3")) - }) - }) }) diff --git a/integration_tests/route_selection_test.go b/integration_tests/route_selection_test.go index 4b60e44b..677ebb1f 100644 --- a/integration_tests/route_selection_test.go +++ b/integration_tests/route_selection_test.go @@ -2,6 +2,7 @@ package integration import ( "net/http/httptest" + "os" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" @@ -19,14 +20,16 @@ var _ = Describe("Route selection", func() { BeforeEach(func() { backend1 = startSimpleBackend("backend 1") backend2 = startSimpleBackend("backend 2") - addBackend("backend-1", backend1.URL) - addBackend("backend-2", backend2.URL) + os.Setenv("BACKEND_URL_backend-1", backend1.URL) + os.Setenv("BACKEND_URL_backend-2", backend2.URL) addRoute("/foo", NewBackendRoute("backend-1")) addRoute("/bar", NewBackendRoute("backend-2")) addRoute("/baz", NewBackendRoute("backend-1")) reloadRoutes(apiPort) }) AfterEach(func() { + os.Unsetenv("BACKEND_URL_backend-1") + os.Unsetenv("BACKEND_URL_backend-2") backend1.Close() backend2.Close() }) @@ -68,14 +71,16 @@ var _ = Describe("Route selection", func() { BeforeEach(func() { backend1 = startSimpleBackend("backend 1") backend2 = startSimpleBackend("backend 2") - addBackend("backend-1", backend1.URL) - addBackend("backend-2", backend2.URL) + os.Setenv("BACKEND_URL_backend-1", backend1.URL) + os.Setenv("BACKEND_URL_backend-2", backend2.URL) addRoute("/foo", NewBackendRoute("backend-1", "prefix")) addRoute("/bar", NewBackendRoute("backend-2", "prefix")) addRoute("/baz", NewBackendRoute("backend-1", "prefix")) reloadRoutes(apiPort) }) AfterEach(func() { + os.Unsetenv("BACKEND_URL_backend-1") + os.Unsetenv("BACKEND_URL_backend-2") backend1.Close() backend2.Close() }) @@ -123,12 +128,14 @@ var _ = Describe("Route selection", func() { BeforeEach(func() { outer = startSimpleBackend("outer") inner = startSimpleBackend("inner") - addBackend("outer-backend", outer.URL) - addBackend("inner-backend", inner.URL) + os.Setenv("BACKEND_URL_outer-backend", outer.URL) + os.Setenv("BACKEND_URL_inner-backend", inner.URL) addRoute("/foo", NewBackendRoute("outer-backend", "prefix")) reloadRoutes(apiPort) }) AfterEach(func() { + os.Unsetenv("BACKEND_URL_outer-backend") + os.Unsetenv("BACKEND_URL_inner-backend") outer.Close() inner.Close() }) @@ -191,12 +198,13 @@ var _ = Describe("Route selection", func() { ) BeforeEach(func() { innerer = startSimpleBackend("innerer") - addBackend("innerer-backend", innerer.URL) + os.Setenv("BACKEND_URL_innerer-backend", innerer.URL) addRoute("/foo/bar", NewBackendRoute("inner-backend")) addRoute("/foo/bar/baz", NewBackendRoute("innerer-backend", "prefix")) reloadRoutes(apiPort) }) AfterEach(func() { + os.Unsetenv("BACKEND_URL_innerer-backend") innerer.Close() }) @@ -245,13 +253,15 @@ var _ = Describe("Route selection", func() { BeforeEach(func() { backend1 = startSimpleBackend("backend 1") backend2 = startSimpleBackend("backend 2") - addBackend("backend-1", backend1.URL) - addBackend("backend-2", backend2.URL) + os.Setenv("BACKEND_URL_backend-1", backend1.URL) + os.Setenv("BACKEND_URL_backend-2", backend2.URL) addRoute("/foo", NewBackendRoute("backend-1", "prefix")) addRoute("/foo", NewBackendRoute("backend-2")) reloadRoutes(apiPort) }) AfterEach(func() { + os.Unsetenv("BACKEND_URL_backend-1") + os.Unsetenv("BACKEND_URL_backend-2") backend1.Close() backend2.Close() }) @@ -276,11 +286,13 @@ var _ = Describe("Route selection", func() { BeforeEach(func() { root = startSimpleBackend("root backend") other = startSimpleBackend("other backend") - addBackend("root", root.URL) - addBackend("other", other.URL) + os.Setenv("BACKEND_URL_root", root.URL) + os.Setenv("BACKEND_URL_other", other.URL) addRoute("/foo", NewBackendRoute("other")) }) AfterEach(func() { + os.Unsetenv("BACKEND_URL_root") + os.Unsetenv("BACKEND_URL_other") root.Close() other.Close() }) @@ -323,13 +335,15 @@ var _ = Describe("Route selection", func() { BeforeEach(func() { root = startSimpleBackend("fallthrough") recorder = startRecordingBackend() - addBackend("root", root.URL) - addBackend("other", recorder.URL()) + os.Setenv("BACKEND_URL_root", root.URL) + os.Setenv("BACKEND_URL_other", recorder.URL()) addRoute("/", NewBackendRoute("root", "prefix")) addRoute("/foo/bar", NewBackendRoute("other", "prefix")) reloadRoutes(apiPort) }) AfterEach(func() { + os.Unsetenv("BACKEND_URL_root") + os.Unsetenv("BACKEND_URL_other") root.Close() recorder.Close() }) @@ -359,9 +373,10 @@ var _ = Describe("Route selection", func() { BeforeEach(func() { recorder = startRecordingBackend() - addBackend("backend", recorder.URL()) + os.Setenv("BACKEND_URL_backend", recorder.URL()) }) AfterEach(func() { + os.Unsetenv("BACKEND_URL_backend") recorder.Close() }) diff --git a/lib/backends.go b/lib/backends.go new file mode 100644 index 00000000..85c04c43 --- /dev/null +++ b/lib/backends.go @@ -0,0 +1,49 @@ +package router + +import ( + "fmt" + "net/http" + "net/url" + "os" + "strings" + "time" + + "github.com/alphagov/router/handlers" + "github.com/alphagov/router/logger" +) + +func loadBackendsFromEnv(connTimeout, headerTimeout time.Duration, logger logger.Logger) (backends map[string]http.Handler) { + backends = make(map[string]http.Handler) + + for _, envvar := range os.Environ() { + pair := strings.SplitN(envvar, "=", 2) + + if !strings.HasPrefix(pair[0], "BACKEND_URL_") { + continue + } + + backendID := strings.TrimPrefix(pair[0], "BACKEND_URL_") + backendURL := pair[1] + + if backendURL == "" { + logWarn(fmt.Errorf("router: couldn't find URL for backend %s, skipping", backendID)) + continue + } + + backend, err := url.Parse(backendURL) + if err != nil { + logWarn(fmt.Errorf("router: couldn't parse URL %s for backend %s (error: %w), skipping", backendURL, backendID, err)) + continue + } + + backends[backendID] = handlers.NewBackendHandler( + backendID, + backend, + connTimeout, + headerTimeout, + logger, + ) + } + + return +} diff --git a/lib/backends_test.go b/lib/backends_test.go new file mode 100644 index 00000000..5d1f5748 --- /dev/null +++ b/lib/backends_test.go @@ -0,0 +1,41 @@ +package router + +import ( + "os" + "time" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("Backends", func() { + Context("When calling loadBackendsFromEnv", func() { + It("should load backends from environment variables", func() { + os.Setenv("BACKEND_URL_testBackend", "http://example.com") + defer os.Unsetenv("BACKEND_URL_testBackend") + + backends := loadBackendsFromEnv(1*time.Second, 20*time.Second, nil) + + Expect(backends).To(HaveKey("testBackend")) + Expect(backends["testBackend"]).ToNot(BeNil()) + }) + + It("should skip backends with empty URLs", func() { + os.Setenv("BACKEND_URL_emptyBackend", "") + defer os.Unsetenv("BACKEND_URL_emptyBackend") + + backends := loadBackendsFromEnv(1*time.Second, 20*time.Second, nil) + + Expect(backends).ToNot(HaveKey("emptyBackend")) + }) + + It("should skip backends with invalid URLs", func() { + os.Setenv("BACKEND_URL_invalidBackend", "://invalid-url") + defer os.Unsetenv("BACKEND_URL_invalidBackend") + + backends := loadBackendsFromEnv(1*time.Second, 20*time.Second, nil) + + Expect(backends).ToNot(HaveKey("invalidBackend")) + }) + }) +}) diff --git a/lib/router.go b/lib/router.go index 5fbac7bd..d7890ce6 100644 --- a/lib/router.go +++ b/lib/router.go @@ -4,7 +4,6 @@ import ( "fmt" "net/http" "net/url" - "os" "strconv" "sync" "time" @@ -39,6 +38,7 @@ const ( // come from, Route and Backend should not contain bson fields. // MongoReplicaSet, MongoReplicaSetMember etc. should move out of this module. type Router struct { + backends map[string]http.Handler mux *triemux.Mux lock sync.RWMutex mongoReadToOptime bson.MongoTimestamp @@ -106,8 +106,11 @@ func NewRouter(o Options) (rt *Router, err error) { return nil, err } + backends := loadBackendsFromEnv(o.BackendConnTimeout, o.BackendHeaderTimeout, l) + reloadChan := make(chan bool, 1) rt = &Router{ + backends: backends, mux: triemux.NewMux(), mongoReadToOptime: mongoReadToOptime, logger: l, @@ -235,8 +238,7 @@ func (rt *Router) reloadRoutes(db *mgo.Database, currentOptime bson.MongoTimesta logInfo("router: reloading routes") newmux := triemux.NewMux() - backends := rt.loadBackends(db.C("backends")) - loadRoutes(db.C("routes"), newmux, backends) + loadRoutes(db.C("routes"), newmux, rt.backends) routeCount := newmux.RouteCount() rt.lock.Lock() @@ -286,39 +288,6 @@ func (rt *Router) shouldReload(currentMongoInstance MongoReplicaSetMember) bool return currentMongoInstance.Optime > rt.mongoReadToOptime } -// loadBackends is a helper function which loads backends from the -// passed mongo collection, constructs a Handler for each one, and returns -// them in map keyed on the backend_id -func (rt *Router) loadBackends(c *mgo.Collection) (backends map[string]http.Handler) { - backend := &Backend{} - backends = make(map[string]http.Handler) - - iter := c.Find(nil).Iter() - - for iter.Next(&backend) { - backendURL, err := backend.ParseURL() - if err != nil { - logWarn(fmt.Errorf("router: couldn't parse URL %s for backend %s "+ - "(error: %w), skipping", backend.BackendURL, backend.BackendID, err)) - continue - } - - backends[backend.BackendID] = handlers.NewBackendHandler( - backend.BackendID, - backendURL, - rt.opts.BackendConnTimeout, - rt.opts.BackendHeaderTimeout, - rt.logger, - ) - } - - if err := iter.Err(); err != nil { - panic(err) - } - - return -} - // loadRoutes is a helper function which loads routes from the passed mongo // collection and registers them with the passed proxy mux. func loadRoutes(c *mgo.Collection, mux *triemux.Mux, backends map[string]http.Handler) { @@ -378,14 +347,6 @@ func loadRoutes(c *mgo.Collection, mux *triemux.Mux, backends map[string]http.Ha } } -func (be *Backend) ParseURL() (*url.URL, error) { - backendURL := os.Getenv(fmt.Sprintf("BACKEND_URL_%s", be.BackendID)) - if backendURL == "" { - backendURL = be.BackendURL - } - return url.Parse(backendURL) -} - func shouldPreserveSegments(route *Route) bool { switch route.RouteType { case RouteTypeExact: