diff --git a/federation/server.go b/federation/server.go index 3e0ed3cf..36487a1a 100644 --- a/federation/server.go +++ b/federation/server.go @@ -528,7 +528,88 @@ func (s *Server) Mux() *mux.Router { return s.mux } -// Listen for federation server requests - call the returned function to gracefully close the server. +// Keep track of the ports that we've previously used so that we never use the same port +// (and therefore the same `server_name`) for two different servers. +// +// A `Server` is identified over federation solely by its `server_name` (which looks +// like `hostname:port` for these Complement engineered homeservers). When the OS recycles a +// freed port, a new Server could otherwise get a `server_name` that is identical to a +// previously torn-down one. +// +// To explain an actual situation where this becomes a problem: A real homeserver under +// test (that is participating in a room with the now-dead engineered homeserver) might +// still try to reach the dead server, but since the `server_name` is the same, it's now +// hitting the new server unexpectedly (cross-test pollution). +// +// This particularly happens when you try to share a `deployment` across many tests and +// then each test creates an engineered homeserver to interact against. +// +// Retiring each port for the lifetime of the process keeps stray requests pointed at a +// dead port (connection refused) instead of a live, unrelated server. +var ( + // Use a mutex so only one thread can advance `lastUsedPort` at a time. We don't want + // multiple threads clobbering `lastUsedPort`. + lastUsedPortMu sync.Mutex + // Start at 1024 (1023 + 1) to avoid the privileged ports used by the system + // + // Since we sequentially try each port, we just need to keep track of the last one we tried + lastUsedPort = 1023 +) + +// listenOnUnusedPort listens on an unused port that no other federation `Server` has +// used before in this process. 
+func listenOnUnusedPort(t ct.TestLike) net.Listener { + lastUsedPortMu.Lock() + defer lastUsedPortMu.Unlock() + + // We use this sequential port scan strategy over guess and check with an OS-assigned + // port (by using `:0`) as it's more efficient. The OS may recycle and re-use freed + // ports meaning we could regress to O(n^2) behavior trying to search for each new + // port we want to find. + // + // Using `:0` means an unused port is automatically picked for us (could be random, + // could be the next sequential unused port, we don't know). Ideally, we could ask for + // the next unused port after X to avoid a bunch of work. When using using `:0`, the + // pathological case that is O(n^2) is if OS hands back next lowest unused port + // sequentially which would mean we would have to probe and hold each listener until + // we finally got something new. + + // Try the whole port range (untested but it's probably fast to do so) + max_attempts := 65535 + var lastErr error + for i := 0; i < max_attempts; i++ { + port := lastUsedPort + 1 + if port > 65535 { + // If this ever becomes a problem, we can namespace used ports by `deployment` since + // that has to be passed into `NewServer(...)` anyway and the whole point of this is + // that a homeserver from the `deployment` doesn't try to reach out to a previous + // engineered homeserver it knows about. + // + // As another alternative, we could also wrap-around to the beginning of the port + // range again although that is slightly unsound. + ct.Fatalf( + t, "listenOnUnusedPort: could not find an unused port in the entire port range (0 - 65535). "+ + "(see comment here if you run into this). 
Last error: %s", lastErr, + ) + } + + // Check port availability + ln, err := net.Listen("tcp", fmt.Sprintf(":%d", port)) + lastUsedPort = port + if err != nil { + lastErr = err + // Port unavailable, skip + continue + } + + return ln + } + // Since we try the entire port range, we don't really expect to get here but we have + // it in case there is a programming error above + ct.Fatalf(t, "listenOnUnusedPort: Programming error") + return nil +} + func (s *Server) Listen() (cancel func()) { if s.listening { return @@ -536,10 +617,7 @@ func (s *Server) Listen() (cancel func()) { var wg sync.WaitGroup wg.Add(1) - ln, err := net.Listen("tcp", ":0") //nolint - if err != nil { - ct.Fatalf(s.t, "ListenFederationServer: net.Listen failed: %s", err) - } + ln := listenOnUnusedPort(s.t) port := ln.Addr().(*net.TCPAddr).Port s.serverName = spec.ServerName(fmt.Sprintf("%s:%d", s.serverName, port)) s.listening = true