summary refs log tree commit diff
path: root/synapse/federation/transport
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation/transport')
-rw-r--r--synapse/federation/transport/server.py300
1 files changed, 130 insertions, 170 deletions
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index ece6dbcf62..eb3e30a189 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -122,14 +122,9 @@ class TransportLayerServer(object):
         Args:
             handler (TransportReceivedHandler)
         """
-        self.received_handler = handler
-
-        # This is when someone is trying to send us a bunch of data.
-        self.server.register_path(
-            "PUT",
-            re.compile("^" + PREFIX + "/send/([^/]*)/$"),
-            self._with_authentication(self._on_send_request)
-        )
+        FederationSendServlet(
+            handler, self._with_authentication, self.server_name
+        ).register(self.server)
 
     @log_function
     def register_request_handler(self, handler):
@@ -138,136 +133,48 @@ class TransportLayerServer(object):
         Args:
             handler (TransportRequestHandler)
         """
-        self.request_handler = handler
-
-        # This is for when someone asks us for everything since version X
-        self.server.register_path(
-            "GET",
-            re.compile("^" + PREFIX + "/pull/$"),
-            self._with_authentication(
-                lambda origin, content, query:
-                handler.on_pull_request(query["origin"][0], query["v"])
-            )
-        )
-
-        # This is when someone asks for a data item for a given server
-        # data_id pair.
-        self.server.register_path(
-            "GET",
-            re.compile("^" + PREFIX + "/event/([^/]*)/$"),
-            self._with_authentication(
-                lambda origin, content, query, event_id:
-                handler.on_pdu_request(origin, event_id)
-            )
-        )
-
-        # This is when someone asks for all data for a given context.
-        self.server.register_path(
-            "GET",
-            re.compile("^" + PREFIX + "/state/([^/]*)/$"),
-            self._with_authentication(
-                lambda origin, content, query, context:
-                handler.on_context_state_request(
-                    origin,
-                    context,
-                    query.get("event_id", [None])[0],
-                )
-            )
-        )
-
-        self.server.register_path(
-            "GET",
-            re.compile("^" + PREFIX + "/backfill/([^/]*)/$"),
-            self._with_authentication(
-                lambda origin, content, query, context:
-                self._on_backfill_request(
-                    origin, context, query["v"], query["limit"]
-                )
-            )
-        )
-
-        # This is when we receive a server-server Query
-        self.server.register_path(
-            "GET",
-            re.compile("^" + PREFIX + "/query/([^/]*)$"),
-            self._with_authentication(
-                lambda origin, content, query, query_type:
-                handler.on_query_request(
-                    query_type,
-                    {k: v[0].decode("utf-8") for k, v in query.items()}
-                )
-            )
-        )
-
-        self.server.register_path(
-            "GET",
-            re.compile("^" + PREFIX + "/make_join/([^/]*)/([^/]*)$"),
-            self._with_authentication(
-                lambda origin, content, query, context, user_id:
-                self._on_make_join_request(
-                    origin, content, query, context, user_id
-                )
-            )
-        )
-
-        self.server.register_path(
-            "GET",
-            re.compile("^" + PREFIX + "/event_auth/([^/]*)/([^/]*)$"),
-            self._with_authentication(
-                lambda origin, content, query, context, event_id:
-                handler.on_event_auth(
-                    origin, context, event_id,
-                )
-            )
-        )
-
-        self.server.register_path(
-            "PUT",
-            re.compile("^" + PREFIX + "/send_join/([^/]*)/([^/]*)$"),
-            self._with_authentication(
-                lambda origin, content, query, context, event_id:
-                self._on_send_join_request(
-                    origin, content, query,
-                )
-            )
-        )
-
-        self.server.register_path(
-            "PUT",
-            re.compile("^" + PREFIX + "/invite/([^/]*)/([^/]*)$"),
-            self._with_authentication(
-                lambda origin, content, query, context, event_id:
-                self._on_invite_request(
-                    origin, content, query,
-                )
-            )
-        )
-
-        self.server.register_path(
-            "POST",
-            re.compile("^" + PREFIX + "/query_auth/([^/]*)/([^/]*)$"),
-            self._with_authentication(
-                lambda origin, content, query, context, event_id:
-                self._on_query_auth_request(
-                    origin, content, event_id,
-                )
-            )
-        )
-
-        self.server.register_path(
-            "POST",
-            re.compile("^" + PREFIX + "/get_missing_events/([^/]*)/?$"),
-            self._with_authentication(
-                lambda origin, content, query, room_id:
-                self._get_missing_events(
-                    origin, content, room_id,
-                )
-            )
-        )
-
+        for servletclass in (
+                FederationPullServlet,
+                FederationEventServlet,
+                FederationStateServlet,
+                FederationBackfillServlet,
+                FederationQueryServlet,
+                FederationMakeJoinServlet,
+                FederationEventServlet,
+                FederationSendJoinServlet,
+                FederationInviteServlet,
+                FederationQueryAuthServlet,
+                FederationGetMissingEventsServlet,
+            ):
+            servletclass(handler, self._with_authentication).register(self.server)
+
+
+class BaseFederationServlet(object):
+    def __init__(self, handler, wrapper):
+        self.handler = handler
+        self.wrapper = wrapper
+
+    def register(self, server):
+        pattern = re.compile("^" + PREFIX + self.PATH)
+
+        for method in ("GET", "PUT", "POST"):
+            code = getattr(self, "on_%s" % (method), None)
+            if code is None:
+                continue
+
+            server.register_path(method, pattern, self.wrapper(code))
+
+
+class FederationSendServlet(BaseFederationServlet):
+    PATH = "/send/([^/]*)/$"
+
+    def __init__(self, handler, wrapper, server_name):
+        super(FederationSendServlet, self).__init__(handler, wrapper)
+        self.server_name = server_name
+
+    # This is when someone is trying to send us a bunch of data.
     @defer.inlineCallbacks
-    @log_function
-    def _on_send_request(self, origin, content, query, transaction_id):
+    def on_PUT(self, origin, content, query, transaction_id):
         """ Called on PUT /send/<transaction_id>/
 
         Args:
@@ -305,8 +212,7 @@ class TransportLayerServer(object):
             return
 
         try:
-            handler = self.received_handler
-            code, response = yield handler.on_incoming_transaction(
+            code, response = yield self.handler.on_incoming_transaction(
                 transaction_data
             )
         except:
@@ -315,65 +221,119 @@ class TransportLayerServer(object):
 
         defer.returnValue((code, response))
 
-    @log_function
-    def _on_backfill_request(self, origin, context, v_list, limits):
+
+class FederationPullServlet(BaseFederationServlet):
+    PATH = "/pull/$"
+
+    # This is for when someone asks us for everything since version X
+    def on_GET(self, origin, content, query):
+        return self.handler.on_pull_request(query["origin"][0], query["v"])
+
+
+class FederationEventServlet(BaseFederationServlet):
+    PATH = "/event/([^/]*)/$"
+
+    # This is when someone asks for a data item for a given server data_id pair.
+    def on_GET(self, origin, content, query, event_id):
+        return self.handler.on_pdu_request(origin, event_id)
+
+
+class FederationStateServlet(BaseFederationServlet):
+    PATH = "/state/([^/]*)/$"
+
+    # This is when someone asks for all data for a given context.
+    def on_GET(self, origin, content, query, context):
+        return self.handler.on_context_state_request(origin, context,
+            query.get("event_id", [None])[0],
+        )
+
+
+class FederationBackfillServlet(BaseFederationServlet):
+    PATH = "/backfill/([^/]*)/$"
+
+    def on_GET(self, origin, content, query, context):
+        versions = query["v"]
+        limits = query["limit"]
+
         if not limits:
-            return defer.succeed(
-                (400, {"error": "Did not include limit param"})
-            )
+            return defer.succeed((400, {"error": "Did not include limit param"}))
 
         limit = int(limits[-1])
 
-        versions = v_list
+        return self.handler.on_backfill_request(origin, context, versions, limit)
+
+
+class FederationQueryServlet(BaseFederationServlet):
+    PATH = "/query/([^/]*)$"
 
-        return self.request_handler.on_backfill_request(
-            origin, context, versions, limit
+    # This is when we receive a server-server Query
+    def on_GET(self, origin, content, query, query_type):
+        return self.handler.on_query_request(query_type,
+            {k: v[0].decode("utf-8") for k, v in query.items()}
         )
 
+
+class FederationMakeJoinServlet(BaseFederationServlet):
+    PATH = "/make_join/([^/]*)/([^/]*)$"
+
     @defer.inlineCallbacks
-    @log_function
-    def _on_make_join_request(self, origin, content, query, context, user_id):
-        content = yield self.request_handler.on_make_join_request(
-            context, user_id,
-        )
+    def on_GET(self, origin, content, query, context, user_id):
+        content = yield self.handler.on_make_join_request(context, user_id)
         defer.returnValue((200, content))
 
-    @defer.inlineCallbacks
-    @log_function
-    def _on_send_join_request(self, origin, content, query):
-        content = yield self.request_handler.on_send_join_request(
-            origin, content,
-        )
 
-        defer.returnValue((200, content))
+class FederationEventAuthServlet(BaseFederationServlet):
+    PATH = "/event_auth/([^/]*)/([^/]*)$"
+
+    def on_GET(self, origin, content, query, context, event_id):
+        return self.handler.on_event_auth(origin, context, event_id)
+
+
+class FederationSendJoinServlet(BaseFederationServlet):
+    PATH = "/send_join/([^/]*)/([^/]*)$"
 
     @defer.inlineCallbacks
-    @log_function
-    def _on_invite_request(self, origin, content, query):
-        content = yield self.request_handler.on_invite_request(
-            origin, content,
-        )
+    def on_PUT(self, origin, content, query, context, event_id):
+        # TODO(paul): assert that context/event_id parsed from path actually
+        #   match those given in content
+        content = yield self.handler.on_send_join_request(origin, content)
+        defer.returnValue((200, content))
+
+
+class FederationInviteServlet(BaseFederationServlet):
+    PATH = "/invite/([^/]*)/([^/]*)$"
 
+    @defer.inlineCallbacks
+    def on_PUT(self, origin, content, query, context, event_id):
+        # TODO(paul): assert that context/event_id parsed from path actually
+        #   match those given in content
+        content = yield self.handler.on_invite_request(origin, content)
         defer.returnValue((200, content))
 
+
+class FederationQueryAuthServlet(BaseFederationServlet):
+    PATH = "/query_auth/([^/]*)/([^/]*)$"
+
     @defer.inlineCallbacks
-    @log_function
-    def _on_query_auth_request(self, origin, content, event_id):
-        new_content = yield self.request_handler.on_query_auth_request(
+    def on_POST(self, origin, content, query, context, event_id):
+        new_content = yield self.handler.on_query_auth_request(
             origin, content, event_id
         )
 
         defer.returnValue((200, new_content))
 
+
+class FederationGetMissingEventsServlet(BaseFederationServlet):
+    PATH = "/get_missing_events/([^/]*)/?$"
+
     @defer.inlineCallbacks
-    @log_function
-    def _get_missing_events(self, origin, content, room_id):
+    def on_POST(self, origin, content, query, room_id):
         limit = int(content.get("limit", 10))
         min_depth = int(content.get("min_depth", 0))
         earliest_events = content.get("earliest_events", [])
         latest_events = content.get("latest_events", [])
 
-        content = yield self.request_handler.on_get_missing_events(
+        content = yield self.handler.on_get_missing_events(
             origin,
             room_id=room_id,
             earliest_events=earliest_events,