diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index b518dace8a..7b4baddbf8 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -46,7 +46,6 @@ from synapse.logging.opentracing import (
)
from synapse.server import HomeServer
from synapse.types import ThirdPartyInstanceID, get_domain_from_id
-from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.versionstring import get_version_string
logger = logging.getLogger(__name__)
@@ -70,12 +69,10 @@ class TransportLayerServer(JsonResource):
self.clock = hs.get_clock()
self.servlet_groups = servlet_groups
- super(TransportLayerServer, self).__init__(hs, canonical_json=False)
+ super().__init__(hs, canonical_json=False)
self.authenticator = Authenticator(hs)
- self.ratelimiter = FederationRateLimiter(
- self.clock, config=hs.config.rc_federation
- )
+ self.ratelimiter = hs.get_federation_ratelimiter()
self.register_servlets()
@@ -101,7 +98,7 @@ class NoAuthenticationError(AuthenticationError):
pass
-class Authenticator(object):
+class Authenticator:
def __init__(self, hs: HomeServer):
self._clock = hs.get_clock()
self.keyring = hs.get_keyring()
@@ -229,7 +226,7 @@ def _parse_auth_header(header_bytes):
)
-class BaseFederationServlet(object):
+class BaseFederationServlet:
"""Abstract base class for federation servlet classes.
The servlet object should have a PATH attribute which takes the form of a regexp to
@@ -273,6 +270,8 @@ class BaseFederationServlet(object):
PREFIX = FEDERATION_V1_PREFIX # Allows specifying the API version
+ RATELIMIT = True # Whether to rate limit requests or not
+
def __init__(self, handler, authenticator, ratelimiter, server_name):
self.handler = handler
self.authenticator = authenticator
@@ -336,7 +335,7 @@ class BaseFederationServlet(object):
)
with scope:
- if origin:
+ if origin and self.RATELIMIT:
with ratelimiter.ratelimit(origin) as d:
await d
if request._disconnected:
@@ -373,10 +372,12 @@ class BaseFederationServlet(object):
class FederationSendServlet(BaseFederationServlet):
PATH = "/send/(?P<transaction_id>[^/]*)/?"
+ # We ratelimit manually in the handler as we queue up the requests and we
+ # don't want to fill up the ratelimiter with blocked requests.
+ RATELIMIT = False
+
def __init__(self, handler, server_name, **kwargs):
- super(FederationSendServlet, self).__init__(
- handler, server_name=server_name, **kwargs
- )
+ super().__init__(handler, server_name=server_name, **kwargs)
self.server_name = server_name
# This is when someone is trying to send us a bunch of data.
@@ -771,9 +772,7 @@ class PublicRoomList(BaseFederationServlet):
PATH = "/publicRooms"
def __init__(self, handler, authenticator, ratelimiter, server_name, allow_access):
- super(PublicRoomList, self).__init__(
- handler, authenticator, ratelimiter, server_name
- )
+ super().__init__(handler, authenticator, ratelimiter, server_name)
self.allow_access = allow_access
async def on_GET(self, origin, content, query):
|