diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 8f985f8fe3..6c624977d7 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -102,7 +102,8 @@ class TransportLayerServer(object):
Args:
handler (TransportReceivedHandler)
"""
- FederationSendServlet(handler,
+ FederationSendServlet(
+ handler,
authenticator=self,
ratelimiter=self.ratelimiter,
server_name=self.server_name,
@@ -115,20 +116,9 @@ class TransportLayerServer(object):
Args:
handler (TransportRequestHandler)
"""
- for servletclass in (
- FederationPullServlet,
- FederationEventServlet,
- FederationStateServlet,
- FederationBackfillServlet,
- FederationQueryServlet,
- FederationMakeJoinServlet,
- FederationEventServlet,
- FederationSendJoinServlet,
- FederationInviteServlet,
- FederationQueryAuthServlet,
- FederationGetMissingEventsServlet,
- ):
- servletclass(handler,
+ for servletclass in SERVLET_CLASSES:
+ servletclass(
+ handler,
authenticator=self,
ratelimiter=self.ratelimiter,
).register(self.server)
@@ -138,11 +128,11 @@ class BaseFederationServlet(object):
def __init__(self, handler, authenticator, ratelimiter):
self.handler = handler
self.authenticator = authenticator
- self.ratelimiter = ratelimiter
+ self.ratelimiter = ratelimiter
def _wrap(self, code):
authenticator = self.authenticator
- ratelimiter = self.ratelimiter
+ ratelimiter = self.ratelimiter
@defer.inlineCallbacks
@functools.wraps(code)
@@ -249,7 +239,9 @@ class FederationStateServlet(BaseFederationServlet):
# This is when someone asks for all data for a given context.
def on_GET(self, origin, content, query, context):
- return self.handler.on_context_state_request(origin, context,
+ return self.handler.on_context_state_request(
+ origin,
+ context,
query.get("event_id", [None])[0],
)
@@ -274,7 +266,8 @@ class FederationQueryServlet(BaseFederationServlet):
# This is when we receive a server-server Query
def on_GET(self, origin, content, query, query_type):
- return self.handler.on_query_request(query_type,
+ return self.handler.on_query_request(
+ query_type,
{k: v[0].decode("utf-8") for k, v in query.items()}
)
@@ -350,3 +343,18 @@ class FederationGetMissingEventsServlet(BaseFederationServlet):
)
defer.returnValue((200, content))
+
+
+SERVLET_CLASSES = (
+ FederationPullServlet,
+ FederationEventServlet,
+ FederationStateServlet,
+ FederationBackfillServlet,
+ FederationQueryServlet,
+ FederationMakeJoinServlet,
+ FederationEventServlet,
+ FederationSendJoinServlet,
+ FederationInviteServlet,
+ FederationQueryAuthServlet,
+ FederationGetMissingEventsServlet,
+)
|