diff options
Diffstat (limited to 'synapse/federation/transport')
-rw-r--r-- | synapse/federation/transport/server.py | 46 |
1 files changed, 27 insertions, 19 deletions
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 8f985f8fe3..6c624977d7 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -102,7 +102,8 @@ class TransportLayerServer(object): Args: handler (TransportReceivedHandler) """ - FederationSendServlet(handler, + FederationSendServlet( + handler, authenticator=self, ratelimiter=self.ratelimiter, server_name=self.server_name, @@ -115,20 +116,9 @@ class TransportLayerServer(object): Args: handler (TransportRequestHandler) """ - for servletclass in ( - FederationPullServlet, - FederationEventServlet, - FederationStateServlet, - FederationBackfillServlet, - FederationQueryServlet, - FederationMakeJoinServlet, - FederationEventServlet, - FederationSendJoinServlet, - FederationInviteServlet, - FederationQueryAuthServlet, - FederationGetMissingEventsServlet, - ): - servletclass(handler, + for servletclass in SERVLET_CLASSES: + servletclass( + handler, authenticator=self, ratelimiter=self.ratelimiter, ).register(self.server) @@ -138,11 +128,11 @@ class BaseFederationServlet(object): def __init__(self, handler, authenticator, ratelimiter): self.handler = handler self.authenticator = authenticator - self.ratelimiter = ratelimiter + self.ratelimiter = ratelimiter def _wrap(self, code): authenticator = self.authenticator - ratelimiter = self.ratelimiter + ratelimiter = self.ratelimiter @defer.inlineCallbacks @functools.wraps(code) @@ -249,7 +239,9 @@ class FederationStateServlet(BaseFederationServlet): # This is when someone asks for all data for a given context. def on_GET(self, origin, content, query, context): - return self.handler.on_context_state_request(origin, context, + return self.handler.on_context_state_request( + origin, + context, query.get("event_id", [None])[0], ) @@ -274,7 +266,8 @@ class FederationQueryServlet(BaseFederationServlet): # This is when we receive a server-server Query def on_GET(self, origin, content, query, query_type): - return self.handler.on_query_request(query_type, + return self.handler.on_query_request( + query_type, {k: v[0].decode("utf-8") for k, v in query.items()} ) @@ -350,3 +343,18 @@ class FederationGetMissingEventsServlet(BaseFederationServlet): ) defer.returnValue((200, content)) + + +SERVLET_CLASSES = ( + FederationPullServlet, + FederationEventServlet, + FederationStateServlet, + FederationBackfillServlet, + FederationQueryServlet, + FederationMakeJoinServlet, + FederationEventServlet, + FederationSendJoinServlet, + FederationInviteServlet, + FederationQueryAuthServlet, + FederationGetMissingEventsServlet, +) |