diff options
Diffstat (limited to '')
-rw-r--r-- | synapse/federation/transport/server.py | 62 |
1 files changed, 37 insertions, 25 deletions
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index dc9f1e082b..39b18ae303 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -19,6 +19,7 @@ from synapse.api.urls import FEDERATION_PREFIX as PREFIX from synapse.api.errors import Codes, SynapseError from synapse.util.logutils import log_function +import functools import logging import simplejson as json import re @@ -30,8 +31,9 @@ logger = logging.getLogger(__name__) class TransportLayerServer(object): """Handles incoming federation HTTP requests""" + # A method just so we can pass 'self' as the authenticator to the Servlets @defer.inlineCallbacks - def _authenticate_request(self, request): + def authenticate_request(self, request): json_request = { "method": request.method, "uri": request.uri, @@ -93,22 +95,6 @@ class TransportLayerServer(object): defer.returnValue((origin, content)) - def _with_authentication(self, handler): - @defer.inlineCallbacks - def new_handler(request, *args, **kwargs): - try: - (origin, content) = yield self._authenticate_request(request) - with self.ratelimiter.ratelimit(origin) as d: - yield d - response = yield handler( - origin, content, request.args, *args, **kwargs - ) - except: - logger.exception("_authenticate_request failed") - raise - defer.returnValue(response) - return new_handler - @log_function def register_received_handler(self, handler): """ Register a handler that will be fired when we receive data. @@ -116,8 +102,10 @@ class TransportLayerServer(object): Args: handler (TransportReceivedHandler) """ - FederationSendServlet( - handler, self._with_authentication, self.server_name + FederationSendServlet(handler, + authenticator=self, + ratelimiter=self.ratelimiter, + server_name=self.server_name, ).register(self.server) @log_function @@ -140,13 +128,37 @@ class TransportLayerServer(object): FederationQueryAuthServlet, FederationGetMissingEventsServlet, ): - servletclass(handler, self._with_authentication).register(self.server) + servletclass(handler, + authenticator=self, + ratelimiter=self.ratelimiter, + ).register(self.server) class BaseFederationServlet(object): - def __init__(self, handler, wrapper): + def __init__(self, handler, authenticator, ratelimiter): self.handler = handler - self.wrapper = wrapper + self.authenticator = authenticator + self.ratelimiter = ratelimiter + + def _wrap(self, code): + authenticator = self.authenticator + ratelimiter = self.ratelimiter + + @defer.inlineCallbacks + @functools.wraps(code) + def new_code(request, *args, **kwargs): + try: + (origin, content) = yield authenticator.authenticate_request(request) + with ratelimiter.ratelimit(origin) as d: + yield d + response = yield code( + origin, content, request.args, *args, **kwargs + ) + except: + logger.exception("authenticate_request failed") + raise + defer.returnValue(response) + return new_code def register(self, server): pattern = re.compile("^" + PREFIX + self.PATH) @@ -156,14 +168,14 @@ class BaseFederationServlet(object): if code is None: continue - server.register_path(method, pattern, self.wrapper(code)) + server.register_path(method, pattern, self._wrap(code)) class FederationSendServlet(BaseFederationServlet): PATH = "/send/([^/]*)/$" - def __init__(self, handler, wrapper, server_name): - super(FederationSendServlet, self).__init__(handler, wrapper) + def __init__(self, handler, server_name, **kwargs): + super(FederationSendServlet, self).__init__(handler, **kwargs) self.server_name = server_name # This is when someone is trying to send us a bunch of data. |