diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index dc9f1e082b..39b18ae303 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -19,6 +19,7 @@ from synapse.api.urls import FEDERATION_PREFIX as PREFIX
from synapse.api.errors import Codes, SynapseError
from synapse.util.logutils import log_function
+import functools
import logging
import simplejson as json
import re
@@ -30,8 +31,9 @@ logger = logging.getLogger(__name__)
class TransportLayerServer(object):
"""Handles incoming federation HTTP requests"""
+ # A method just so we can pass 'self' as the authenticator to the Servlets
@defer.inlineCallbacks
- def _authenticate_request(self, request):
+ def authenticate_request(self, request):
json_request = {
"method": request.method,
"uri": request.uri,
@@ -93,22 +95,6 @@ class TransportLayerServer(object):
defer.returnValue((origin, content))
- def _with_authentication(self, handler):
- @defer.inlineCallbacks
- def new_handler(request, *args, **kwargs):
- try:
- (origin, content) = yield self._authenticate_request(request)
- with self.ratelimiter.ratelimit(origin) as d:
- yield d
- response = yield handler(
- origin, content, request.args, *args, **kwargs
- )
- except:
- logger.exception("_authenticate_request failed")
- raise
- defer.returnValue(response)
- return new_handler
-
@log_function
def register_received_handler(self, handler):
""" Register a handler that will be fired when we receive data.
@@ -116,8 +102,10 @@ class TransportLayerServer(object):
Args:
handler (TransportReceivedHandler)
"""
- FederationSendServlet(
- handler, self._with_authentication, self.server_name
+ FederationSendServlet(handler,
+ authenticator=self,
+ ratelimiter=self.ratelimiter,
+ server_name=self.server_name,
).register(self.server)
@log_function
@@ -140,13 +128,37 @@ class TransportLayerServer(object):
FederationQueryAuthServlet,
FederationGetMissingEventsServlet,
):
- servletclass(handler, self._with_authentication).register(self.server)
+ servletclass(handler,
+ authenticator=self,
+ ratelimiter=self.ratelimiter,
+ ).register(self.server)
class BaseFederationServlet(object):
- def __init__(self, handler, wrapper):
+ def __init__(self, handler, authenticator, ratelimiter):
self.handler = handler
- self.wrapper = wrapper
+ self.authenticator = authenticator
+ self.ratelimiter = ratelimiter
+
+ def _wrap(self, code):
+ authenticator = self.authenticator
+ ratelimiter = self.ratelimiter
+
+ @defer.inlineCallbacks
+ @functools.wraps(code)
+ def new_code(request, *args, **kwargs):
+ try:
+ (origin, content) = yield authenticator.authenticate_request(request)
+ with ratelimiter.ratelimit(origin) as d:
+ yield d
+ response = yield code(
+ origin, content, request.args, *args, **kwargs
+ )
+ except:
+ logger.exception("authenticate_request failed")
+ raise
+ defer.returnValue(response)
+ return new_code
def register(self, server):
pattern = re.compile("^" + PREFIX + self.PATH)
@@ -156,14 +168,14 @@ class BaseFederationServlet(object):
if code is None:
continue
- server.register_path(method, pattern, self.wrapper(code))
+ server.register_path(method, pattern, self._wrap(code))
class FederationSendServlet(BaseFederationServlet):
PATH = "/send/([^/]*)/$"
- def __init__(self, handler, wrapper, server_name):
- super(FederationSendServlet, self).__init__(handler, wrapper)
+ def __init__(self, handler, server_name, **kwargs):
+ super(FederationSendServlet, self).__init__(handler, **kwargs)
self.server_name = server_name
# This is when someone is trying to send us a bunch of data.
|