diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index ece6dbcf62..eb3e30a189 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -122,14 +122,9 @@ class TransportLayerServer(object):
Args:
handler (TransportReceivedHandler)
"""
- self.received_handler = handler
-
- # This is when someone is trying to send us a bunch of data.
- self.server.register_path(
- "PUT",
- re.compile("^" + PREFIX + "/send/([^/]*)/$"),
- self._with_authentication(self._on_send_request)
- )
+ FederationSendServlet(
+ handler, self._with_authentication, self.server_name
+ ).register(self.server)
@log_function
def register_request_handler(self, handler):
@@ -138,136 +133,48 @@ class TransportLayerServer(object):
Args:
handler (TransportRequestHandler)
"""
- self.request_handler = handler
-
- # This is for when someone asks us for everything since version X
- self.server.register_path(
- "GET",
- re.compile("^" + PREFIX + "/pull/$"),
- self._with_authentication(
- lambda origin, content, query:
- handler.on_pull_request(query["origin"][0], query["v"])
- )
- )
-
- # This is when someone asks for a data item for a given server
- # data_id pair.
- self.server.register_path(
- "GET",
- re.compile("^" + PREFIX + "/event/([^/]*)/$"),
- self._with_authentication(
- lambda origin, content, query, event_id:
- handler.on_pdu_request(origin, event_id)
- )
- )
-
- # This is when someone asks for all data for a given context.
- self.server.register_path(
- "GET",
- re.compile("^" + PREFIX + "/state/([^/]*)/$"),
- self._with_authentication(
- lambda origin, content, query, context:
- handler.on_context_state_request(
- origin,
- context,
- query.get("event_id", [None])[0],
- )
- )
- )
-
- self.server.register_path(
- "GET",
- re.compile("^" + PREFIX + "/backfill/([^/]*)/$"),
- self._with_authentication(
- lambda origin, content, query, context:
- self._on_backfill_request(
- origin, context, query["v"], query["limit"]
- )
- )
- )
-
- # This is when we receive a server-server Query
- self.server.register_path(
- "GET",
- re.compile("^" + PREFIX + "/query/([^/]*)$"),
- self._with_authentication(
- lambda origin, content, query, query_type:
- handler.on_query_request(
- query_type,
- {k: v[0].decode("utf-8") for k, v in query.items()}
- )
- )
- )
-
- self.server.register_path(
- "GET",
- re.compile("^" + PREFIX + "/make_join/([^/]*)/([^/]*)$"),
- self._with_authentication(
- lambda origin, content, query, context, user_id:
- self._on_make_join_request(
- origin, content, query, context, user_id
- )
- )
- )
-
- self.server.register_path(
- "GET",
- re.compile("^" + PREFIX + "/event_auth/([^/]*)/([^/]*)$"),
- self._with_authentication(
- lambda origin, content, query, context, event_id:
- handler.on_event_auth(
- origin, context, event_id,
- )
- )
- )
-
- self.server.register_path(
- "PUT",
- re.compile("^" + PREFIX + "/send_join/([^/]*)/([^/]*)$"),
- self._with_authentication(
- lambda origin, content, query, context, event_id:
- self._on_send_join_request(
- origin, content, query,
- )
- )
- )
-
- self.server.register_path(
- "PUT",
- re.compile("^" + PREFIX + "/invite/([^/]*)/([^/]*)$"),
- self._with_authentication(
- lambda origin, content, query, context, event_id:
- self._on_invite_request(
- origin, content, query,
- )
- )
- )
-
- self.server.register_path(
- "POST",
- re.compile("^" + PREFIX + "/query_auth/([^/]*)/([^/]*)$"),
- self._with_authentication(
- lambda origin, content, query, context, event_id:
- self._on_query_auth_request(
- origin, content, event_id,
- )
- )
- )
-
- self.server.register_path(
- "POST",
- re.compile("^" + PREFIX + "/get_missing_events/([^/]*)/?$"),
- self._with_authentication(
- lambda origin, content, query, room_id:
- self._get_missing_events(
- origin, content, room_id,
- )
- )
- )
-
+ for servletclass in (
+ FederationPullServlet,
+ FederationEventServlet,
+ FederationStateServlet,
+ FederationBackfillServlet,
+ FederationQueryServlet,
+ FederationMakeJoinServlet,
+ FederationEventServlet,
+ FederationSendJoinServlet,
+ FederationInviteServlet,
+ FederationQueryAuthServlet,
+ FederationGetMissingEventsServlet,
+ ):
+ servletclass(handler, self._with_authentication).register(self.server)
+
+
+class BaseFederationServlet(object):
+ def __init__(self, handler, wrapper):
+ self.handler = handler
+ self.wrapper = wrapper
+
+ def register(self, server):
+ pattern = re.compile("^" + PREFIX + self.PATH)
+
+ for method in ("GET", "PUT", "POST"):
+ code = getattr(self, "on_%s" % (method), None)
+ if code is None:
+ continue
+
+ server.register_path(method, pattern, self.wrapper(code))
+
+
+class FederationSendServlet(BaseFederationServlet):
+ PATH = "/send/([^/]*)/$"
+
+ def __init__(self, handler, wrapper, server_name):
+ super(FederationSendServlet, self).__init__(handler, wrapper)
+ self.server_name = server_name
+
+ # This is when someone is trying to send us a bunch of data.
@defer.inlineCallbacks
- @log_function
- def _on_send_request(self, origin, content, query, transaction_id):
+ def on_PUT(self, origin, content, query, transaction_id):
""" Called on PUT /send/<transaction_id>/
Args:
@@ -305,8 +212,7 @@ class TransportLayerServer(object):
return
try:
- handler = self.received_handler
- code, response = yield handler.on_incoming_transaction(
+ code, response = yield self.handler.on_incoming_transaction(
transaction_data
)
except:
@@ -315,65 +221,119 @@ class TransportLayerServer(object):
defer.returnValue((code, response))
- @log_function
- def _on_backfill_request(self, origin, context, v_list, limits):
+
+class FederationPullServlet(BaseFederationServlet):
+ PATH = "/pull/$"
+
+ # This is for when someone asks us for everything since version X
+ def on_GET(self, origin, content, query):
+ return self.handler.on_pull_request(query["origin"][0], query["v"])
+
+
+class FederationEventServlet(BaseFederationServlet):
+ PATH = "/event/([^/]*)/$"
+
+ # This is when someone asks for a data item for a given server data_id pair.
+ def on_GET(self, origin, content, query, event_id):
+ return self.handler.on_pdu_request(origin, event_id)
+
+
+class FederationStateServlet(BaseFederationServlet):
+ PATH = "/state/([^/]*)/$"
+
+ # This is when someone asks for all data for a given context.
+ def on_GET(self, origin, content, query, context):
+ return self.handler.on_context_state_request(origin, context,
+ query.get("event_id", [None])[0],
+ )
+
+
+class FederationBackfillServlet(BaseFederationServlet):
+ PATH = "/backfill/([^/]*)/$"
+
+ def on_GET(self, origin, content, query, context):
+ versions = query["v"]
+ limits = query["limit"]
+
if not limits:
- return defer.succeed(
- (400, {"error": "Did not include limit param"})
- )
+ return defer.succeed((400, {"error": "Did not include limit param"}))
limit = int(limits[-1])
- versions = v_list
+ return self.handler.on_backfill_request(origin, context, versions, limit)
+
+
+class FederationQueryServlet(BaseFederationServlet):
+ PATH = "/query/([^/]*)$"
- return self.request_handler.on_backfill_request(
- origin, context, versions, limit
+ # This is when we receive a server-server Query
+ def on_GET(self, origin, content, query, query_type):
+ return self.handler.on_query_request(query_type,
+ {k: v[0].decode("utf-8") for k, v in query.items()}
)
+
+class FederationMakeJoinServlet(BaseFederationServlet):
+ PATH = "/make_join/([^/]*)/([^/]*)$"
+
@defer.inlineCallbacks
- @log_function
- def _on_make_join_request(self, origin, content, query, context, user_id):
- content = yield self.request_handler.on_make_join_request(
- context, user_id,
- )
+ def on_GET(self, origin, content, query, context, user_id):
+ content = yield self.handler.on_make_join_request(context, user_id)
defer.returnValue((200, content))
- @defer.inlineCallbacks
- @log_function
- def _on_send_join_request(self, origin, content, query):
- content = yield self.request_handler.on_send_join_request(
- origin, content,
- )
- defer.returnValue((200, content))
+class FederationEventAuthServlet(BaseFederationServlet):
+ PATH = "/event_auth/([^/]*)/([^/]*)$"
+
+ def on_GET(self, origin, content, query, context, event_id):
+ return self.handler.on_event_auth(origin, context, event_id)
+
+
+class FederationSendJoinServlet(BaseFederationServlet):
+ PATH = "/send_join/([^/]*)/([^/]*)$"
@defer.inlineCallbacks
- @log_function
- def _on_invite_request(self, origin, content, query):
- content = yield self.request_handler.on_invite_request(
- origin, content,
- )
+ def on_PUT(self, origin, content, query, context, event_id):
+ # TODO(paul): assert that context/event_id parsed from path actually
+ # match those given in content
+ content = yield self.handler.on_send_join_request(origin, content)
+ defer.returnValue((200, content))
+
+
+class FederationInviteServlet(BaseFederationServlet):
+ PATH = "/invite/([^/]*)/([^/]*)$"
+ @defer.inlineCallbacks
+ def on_PUT(self, origin, content, query, context, event_id):
+ # TODO(paul): assert that context/event_id parsed from path actually
+ # match those given in content
+ content = yield self.handler.on_invite_request(origin, content)
defer.returnValue((200, content))
+
+class FederationQueryAuthServlet(BaseFederationServlet):
+ PATH = "/query_auth/([^/]*)/([^/]*)$"
+
@defer.inlineCallbacks
- @log_function
- def _on_query_auth_request(self, origin, content, event_id):
- new_content = yield self.request_handler.on_query_auth_request(
+ def on_POST(self, origin, content, query, context, event_id):
+ new_content = yield self.handler.on_query_auth_request(
origin, content, event_id
)
defer.returnValue((200, new_content))
+
+class FederationGetMissingEventsServlet(BaseFederationServlet):
+ PATH = "/get_missing_events/([^/]*)/?$"
+
@defer.inlineCallbacks
- @log_function
- def _get_missing_events(self, origin, content, room_id):
+ def on_POST(self, origin, content, query, room_id):
limit = int(content.get("limit", 10))
min_depth = int(content.get("min_depth", 0))
earliest_events = content.get("earliest_events", [])
latest_events = content.get("latest_events", [])
- content = yield self.request_handler.on_get_missing_events(
+ content = yield self.handler.on_get_missing_events(
origin,
room_id=room_id,
earliest_events=earliest_events,
|