summary refs log tree commit diff
path: root/synapse/federation/transport
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation/transport')
-rw-r--r--synapse/federation/transport/server.py62
1 files changed, 37 insertions, 25 deletions
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index dc9f1e082b..39b18ae303 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -19,6 +19,7 @@ from synapse.api.urls import FEDERATION_PREFIX as PREFIX
 from synapse.api.errors import Codes, SynapseError
 from synapse.util.logutils import log_function
 
+import functools
 import logging
 import simplejson as json
 import re
@@ -30,8 +31,9 @@ logger = logging.getLogger(__name__)
 class TransportLayerServer(object):
     """Handles incoming federation HTTP requests"""
 
+    # A method just so we can pass 'self' as the authenticator to the Servlets
     @defer.inlineCallbacks
-    def _authenticate_request(self, request):
+    def authenticate_request(self, request):
         json_request = {
             "method": request.method,
             "uri": request.uri,
@@ -93,22 +95,6 @@ class TransportLayerServer(object):
 
         defer.returnValue((origin, content))
 
-    def _with_authentication(self, handler):
-        @defer.inlineCallbacks
-        def new_handler(request, *args, **kwargs):
-            try:
-                (origin, content) = yield self._authenticate_request(request)
-                with self.ratelimiter.ratelimit(origin) as d:
-                    yield d
-                    response = yield handler(
-                        origin, content, request.args, *args, **kwargs
-                    )
-            except:
-                logger.exception("_authenticate_request failed")
-                raise
-            defer.returnValue(response)
-        return new_handler
-
     @log_function
     def register_received_handler(self, handler):
         """ Register a handler that will be fired when we receive data.
@@ -116,8 +102,10 @@ class TransportLayerServer(object):
         Args:
             handler (TransportReceivedHandler)
         """
-        FederationSendServlet(
-            handler, self._with_authentication, self.server_name
+        FederationSendServlet(handler,
+            authenticator=self,
+            ratelimiter=self.ratelimiter,
+            server_name=self.server_name,
         ).register(self.server)
 
     @log_function
@@ -140,13 +128,37 @@ class TransportLayerServer(object):
                 FederationQueryAuthServlet,
                 FederationGetMissingEventsServlet,
             ):
-            servletclass(handler, self._with_authentication).register(self.server)
+            servletclass(handler,
+                authenticator=self,
+                ratelimiter=self.ratelimiter,
+            ).register(self.server)
 
 
 class BaseFederationServlet(object):
-    def __init__(self, handler, wrapper):
+    def __init__(self, handler, authenticator, ratelimiter):
         self.handler = handler
-        self.wrapper = wrapper
+        self.authenticator = authenticator
+        self.ratelimiter   = ratelimiter
+
+    def _wrap(self, code):
+        authenticator = self.authenticator
+        ratelimiter   = self.ratelimiter
+
+        @defer.inlineCallbacks
+        @functools.wraps(code)
+        def new_code(request, *args, **kwargs):
+            try:
+                (origin, content) = yield authenticator.authenticate_request(request)
+                with ratelimiter.ratelimit(origin) as d:
+                    yield d
+                    response = yield code(
+                        origin, content, request.args, *args, **kwargs
+                    )
+            except:
+                logger.exception("authenticate_request failed")
+                raise
+            defer.returnValue(response)
+        return new_code
 
     def register(self, server):
         pattern = re.compile("^" + PREFIX + self.PATH)
@@ -156,14 +168,14 @@ class BaseFederationServlet(object):
             if code is None:
                 continue
 
-            server.register_path(method, pattern, self.wrapper(code))
+            server.register_path(method, pattern, self._wrap(code))
 
 
 class FederationSendServlet(BaseFederationServlet):
     PATH = "/send/([^/]*)/$"
 
-    def __init__(self, handler, wrapper, server_name):
-        super(FederationSendServlet, self).__init__(handler, wrapper)
+    def __init__(self, handler, server_name, **kwargs):
+        super(FederationSendServlet, self).__init__(handler, **kwargs)
         self.server_name = server_name
 
     # This is when someone is trying to send us a bunch of data.