-rw-r--r--  changelog.d/14812.bugfix             1
-rw-r--r--  synapse/rest/client/register.py      1
-rw-r--r--  synapse/server.py                    1
-rw-r--r--  synapse/util/ratelimitutils.py      34
-rw-r--r--  tests/util/test_ratelimitutils.py   45
5 files changed, 70 insertions, 12 deletions
diff --git a/changelog.d/14812.bugfix b/changelog.d/14812.bugfix
new file mode 100644
index 0000000000..94e0d70cbc
--- /dev/null
+++ b/changelog.d/14812.bugfix
@@ -0,0 +1 @@
+Fix a long-standing bug where Synapse would exhaust the stack when processing many federation requests where the remote homeserver has disconnected early.
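
For context on the bug described above: Twisted Deferreds run their callbacks synchronously when fired, so completing one queued request from inside the callback of the previous one grows the call stack by a frame per item. A rough, hypothetical sketch of that shape (not code from this patch):

    from twisted.internet import defer

    queue = [defer.Deferred() for _ in range(100)]

    def start_next(_result: object, index: int) -> None:
        # Firing the next Deferred from inside this callback adds one stack
        # frame per queued item; with enough items the stack is exhausted.
        if index < len(queue):
            queue[index].callback(None)

    for i, d in enumerate(queue):
        d.addCallback(start_next, i + 1)

    queue[0].callback(None)  # the whole queue drains recursively from here
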
diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py
index 3cb1e7e375..be696c304b 100644
--- a/synapse/rest/client/register.py
+++ b/synapse/rest/client/register.py
@@ -310,6 +310,7 @@ class UsernameAvailabilityRestServlet(RestServlet):
         self.hs = hs
         self.registration_handler = hs.get_registration_handler()
         self.ratelimiter = FederationRateLimiter(
+            hs.get_reactor(),
             hs.get_clock(),
             FederationRatelimitSettings(
                 # Time window of 2s
diff --git a/synapse/server.py b/synapse/server.py
index f4ab94c4f3..c8752baa5a 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -768,6 +768,7 @@ class HomeServer(metaclass=abc.ABCMeta):
     @cache_in_self
     def get_federation_ratelimiter(self) -> FederationRateLimiter:
         return FederationRateLimiter(
+            self.get_reactor(),
             self.get_clock(),
             config=self.config.ratelimiting.rc_federation,
             metrics_name="federation_servlets",
diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py
index 2aceb1a47f..bd72947bfe 100644
--- a/synapse/util/ratelimitutils.py
+++ b/synapse/util/ratelimitutils.py
@@ -34,6 +34,7 @@ from prometheus_client.core import Counter
 from typing_extensions import ContextManager
 
 from twisted.internet import defer
+from twisted.internet.interfaces import IReactorTime
 
 from synapse.api.errors import LimitExceededError
 from synapse.config.ratelimiting import FederationRatelimitSettings
@@ -146,12 +147,14 @@ class FederationRateLimiter:
 
     def __init__(
         self,
+        reactor: IReactorTime,
         clock: Clock,
         config: FederationRatelimitSettings,
         metrics_name: Optional[str] = None,
     ):
         """
         Args:
+            reactor
             clock
             config
             metrics_name: The name of the rate limiter so we can differentiate it
@@ -163,7 +166,7 @@ class FederationRateLimiter:
 
         def new_limiter() -> "_PerHostRatelimiter":
             return _PerHostRatelimiter(
-                clock=clock, config=config, metrics_name=metrics_name
+                reactor=reactor, clock=clock, config=config, metrics_name=metrics_name
             )
 
         self.ratelimiters: DefaultDict[
@@ -194,12 +197,14 @@ class FederationRateLimiter:
 class _PerHostRatelimiter:
     def __init__(
         self,
+        reactor: IReactorTime,
         clock: Clock,
         config: FederationRatelimitSettings,
         metrics_name: Optional[str] = None,
     ):
         """
         Args:
+            reactor
             clock
             config
             metrics_name: The name of the rate limiter so we can differentiate it
@@ -207,6 +212,7 @@ class _PerHostRatelimiter:
                 for this rate limiter.
                 from the rest in the metrics
         """
+        self.reactor = reactor
         self.clock = clock
         self.metrics_name = metrics_name
 
@@ -364,12 +370,22 @@ class _PerHostRatelimiter:
 
     def _on_exit(self, request_id: object) -> None:
         logger.debug("Ratelimit(%s) [%s]: Processed req", self.host, id(request_id))
-        self.current_processing.discard(request_id)
-        try:
-            # start processing the next item on the queue.
-            _, deferred = self.ready_request_queue.popitem(last=False)
 
-            with PreserveLoggingContext():
-                deferred.callback(None)
-        except KeyError:
-            pass
+        # When requests complete synchronously, we will recursively start the next
+        # request in the queue. To avoid stack exhaustion, we defer starting the next
+        # request until the next reactor tick.
+
+        def start_next_request() -> None:
+            # We only remove the completed request from the list when we're about to
+            # start the next one, otherwise we can allow extra requests through.
+            self.current_processing.discard(request_id)
+            try:
+                # start processing the next item on the queue.
+                _, deferred = self.ready_request_queue.popitem(last=False)
+
+                with PreserveLoggingContext():
+                    deferred.callback(None)
+            except KeyError:
+                pass
+
+        self.reactor.callLater(0.0, start_next_request)
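
The `_on_exit` change above is essentially a trampoline: instead of firing the next queued Deferred from inside the current callback, it hands `start_next_request` to the reactor via `callLater(0.0, ...)`, so the stack unwinds before the next request starts. A minimal, hypothetical sketch of the same pattern against Twisted's fake `IReactorTime` (`twisted.internet.task.Clock`), which also shows why the tests below now need a `reactor.advance(0.0)` before asserting on queued requests:

    from twisted.internet import defer
    from twisted.internet.task import Clock

    clock = Clock()  # stand-in for the homeserver reactor's IReactorTime
    queue = [defer.Deferred() for _ in range(100)]

    def start_next(_result: object, index: int) -> None:
        if index < len(queue):
            # Schedule the next item for the next tick instead of firing it
            # synchronously, so the call stack unwinds between items.
            clock.callLater(0.0, queue[index].callback, None)

    for i, d in enumerate(queue):
        d.addCallback(start_next, i + 1)

    queue[0].callback(None)  # only the first item fires on the caller's stack
    clock.advance(0.0)       # the clock then runs the rest iteratively
    assert all(d.called for d in queue)
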
diff --git a/tests/util/test_ratelimitutils.py b/tests/util/test_ratelimitutils.py
index 5b327b390e..2f3ea15b96 100644
--- a/tests/util/test_ratelimitutils.py
+++ b/tests/util/test_ratelimitutils.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 from typing import Optional
 
+from twisted.internet import defer
 from twisted.internet.defer import Deferred
 
 from synapse.config.homeserver import HomeServerConfig
@@ -29,7 +30,7 @@ class FederationRateLimiterTestCase(TestCase):
         """A simple test with the default values"""
         reactor, clock = get_clock()
         rc_config = build_rc_config()
-        ratelimiter = FederationRateLimiter(clock, rc_config)
+        ratelimiter = FederationRateLimiter(reactor, clock, rc_config)
 
         with ratelimiter.ratelimit("testhost") as d1:
             # shouldn't block
@@ -39,7 +40,7 @@ class FederationRateLimiterTestCase(TestCase):
         """Test what happens when we hit the concurrent limit"""
         reactor, clock = get_clock()
         rc_config = build_rc_config({"rc_federation": {"concurrent": 2}})
-        ratelimiter = FederationRateLimiter(clock, rc_config)
+        ratelimiter = FederationRateLimiter(reactor, clock, rc_config)
 
         with ratelimiter.ratelimit("testhost") as d1:
             # shouldn't block
@@ -57,6 +58,7 @@ class FederationRateLimiterTestCase(TestCase):
 
             # ... until we complete an earlier request
             cm2.__exit__(None, None, None)
+            reactor.advance(0.0)
             self.successResultOf(d3)
 
     def test_sleep_limit(self) -> None:
@@ -65,7 +67,7 @@ class FederationRateLimiterTestCase(TestCase):
         rc_config = build_rc_config(
             {"rc_federation": {"sleep_limit": 2, "sleep_delay": 500}}
         )
-        ratelimiter = FederationRateLimiter(clock, rc_config)
+        ratelimiter = FederationRateLimiter(reactor, clock, rc_config)
 
         with ratelimiter.ratelimit("testhost") as d1:
             # shouldn't block
@@ -81,6 +83,43 @@ class FederationRateLimiterTestCase(TestCase):
             sleep_time = _await_resolution(reactor, d3)
             self.assertAlmostEqual(sleep_time, 500, places=3)
 
+    def test_lots_of_queued_things(self) -> None:
+        """Tests lots of synchronous things queued up behind a slow thing.
+
+        The stack should *not* explode when the slow thing completes.
+        """
+        reactor, clock = get_clock()
+        rc_config = build_rc_config(
+            {
+                "rc_federation": {
+                    "sleep_limit": 1000000000,  # never sleep
+                    "reject_limit": 1000000000,  # never reject requests
+                    "concurrent": 1,
+                }
+            }
+        )
+        ratelimiter = FederationRateLimiter(reactor, clock, rc_config)
+
+        with ratelimiter.ratelimit("testhost") as d:
+            # shouldn't block
+            self.successResultOf(d)
+
+            async def task() -> None:
+                with ratelimiter.ratelimit("testhost") as d:
+                    await d
+
+            for _ in range(1, 100):
+                defer.ensureDeferred(task())
+
+            last_task = defer.ensureDeferred(task())
+
+            # Upon exiting the context manager, all the synchronous things will resume.
+            # If a stack overflow occurs, the final task will not complete.
+
+        # Wait for all the things to complete.
+        reactor.advance(0.0)
+        self.successResultOf(last_task)
+
 
 def _await_resolution(reactor: ThreadedMemoryReactorClock, d: Deferred) -> float:
     """advance the clock until the deferred completes.