summary refs log tree commit diff
path: root/synapse/federation/transport
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation/transport')
-rw-r--r--synapse/federation/transport/__init__.py52
-rw-r--r--synapse/federation/transport/client.py4
-rw-r--r--synapse/federation/transport/server.py82
3 files changed, 54 insertions, 84 deletions
diff --git a/synapse/federation/transport/__init__.py b/synapse/federation/transport/__init__.py
index 155a7d5870..d9fcc520a0 100644
--- a/synapse/federation/transport/__init__.py
+++ b/synapse/federation/transport/__init__.py
@@ -20,55 +20,3 @@ By default this is done over HTTPS (and all home servers are required to
 support HTTPS), however individual pairings of servers may decide to
 communicate over a different (albeit still reliable) protocol.
 """
-
-from .server import TransportLayerServer
-from .client import TransportLayerClient
-
-from synapse.util.ratelimitutils import FederationRateLimiter
-
-
-class TransportLayer(TransportLayerServer, TransportLayerClient):
-    """This is a basic implementation of the transport layer that translates
-    transactions and other requests to/from HTTP.
-
-    Attributes:
-        server_name (str): Local home server host
-
-        server (synapse.http.server.HttpServer): the http server to
-                register listeners on
-
-        client (synapse.http.client.HttpClient): the http client used to
-                send requests
-
-        request_handler (TransportRequestHandler): The handler to fire when we
-            receive requests for data.
-
-        received_handler (TransportReceivedHandler): The handler to fire when
-            we receive data.
-    """
-
-    def __init__(self, homeserver, server_name, server, client):
-        """
-        Args:
-            server_name (str): Local home server host
-            server (synapse.protocol.http.HttpServer): the http server to
-                register listeners on
-            client (synapse.protocol.http.HttpClient): the http client used to
-                send requests
-        """
-        self.keyring = homeserver.get_keyring()
-        self.clock = homeserver.get_clock()
-        self.server_name = server_name
-        self.server = server
-        self.client = client
-        self.request_handler = None
-        self.received_handler = None
-
-        self.ratelimiter = FederationRateLimiter(
-            self.clock,
-            window_size=homeserver.config.federation_rc_window_size,
-            sleep_limit=homeserver.config.federation_rc_sleep_limit,
-            sleep_msec=homeserver.config.federation_rc_sleep_delay,
-            reject_limit=homeserver.config.federation_rc_reject_limit,
-            concurrent_requests=homeserver.config.federation_rc_concurrent,
-        )
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 949d01dea8..2b5d40ea7f 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -28,6 +28,10 @@ logger = logging.getLogger(__name__)
 class TransportLayerClient(object):
     """Sends federation HTTP requests to other servers"""
 
+    def __init__(self, hs):
+        self.server_name = hs.hostname
+        self.client = hs.get_http_client()
+
     @log_function
     def get_room_state(self, destination, room_id, event_id):
         """ Requests all state for a given room from the given server at the
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 8dca0a7f6b..65e054f7dd 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -17,7 +17,8 @@ from twisted.internet import defer
 
 from synapse.api.urls import FEDERATION_PREFIX as PREFIX
 from synapse.api.errors import Codes, SynapseError
-from synapse.util.logutils import log_function
+from synapse.http.server import JsonResource
+from synapse.util.ratelimitutils import FederationRateLimiter
 
 import functools
 import logging
@@ -28,9 +29,41 @@ import re
 logger = logging.getLogger(__name__)
 
 
-class TransportLayerServer(object):
+class TransportLayerServer(JsonResource):
     """Handles incoming federation HTTP requests"""
 
+    def __init__(self, hs):
+        self.hs = hs
+        self.clock = hs.get_clock()
+
+        super(TransportLayerServer, self).__init__(hs)
+
+        self.authenticator = Authenticator(hs)
+        self.ratelimiter = FederationRateLimiter(
+            self.clock,
+            window_size=hs.config.federation_rc_window_size,
+            sleep_limit=hs.config.federation_rc_sleep_limit,
+            sleep_msec=hs.config.federation_rc_sleep_delay,
+            reject_limit=hs.config.federation_rc_reject_limit,
+            concurrent_requests=hs.config.federation_rc_concurrent,
+        )
+
+        self.register_servlets()
+
+    def register_servlets(self):
+        register_servlets(
+            self.hs,
+            resource=self,
+            ratelimiter=self.ratelimiter,
+            authenticator=self.authenticator,
+        )
+
+
+class Authenticator(object):
+    def __init__(self, hs):
+        self.keyring = hs.get_keyring()
+        self.server_name = hs.hostname
+
     # A method just so we can pass 'self' as the authenticator to the Servlets
     @defer.inlineCallbacks
     def authenticate_request(self, request):
@@ -98,37 +131,9 @@ class TransportLayerServer(object):
 
         defer.returnValue((origin, content))
 
-    @log_function
-    def register_received_handler(self, handler):
-        """ Register a handler that will be fired when we receive data.
-
-        Args:
-            handler (TransportReceivedHandler)
-        """
-        FederationSendServlet(
-            handler,
-            authenticator=self,
-            ratelimiter=self.ratelimiter,
-            server_name=self.server_name,
-        ).register(self.server)
-
-    @log_function
-    def register_request_handler(self, handler):
-        """ Register a handler that will be fired when we get asked for data.
-
-        Args:
-            handler (TransportRequestHandler)
-        """
-        for servletclass in SERVLET_CLASSES:
-            servletclass(
-                handler,
-                authenticator=self,
-                ratelimiter=self.ratelimiter,
-            ).register(self.server)
-
 
 class BaseFederationServlet(object):
-    def __init__(self, handler, authenticator, ratelimiter):
+    def __init__(self, handler, authenticator, ratelimiter, server_name):
         self.handler = handler
         self.authenticator = authenticator
         self.ratelimiter = ratelimiter
@@ -172,7 +177,9 @@ class FederationSendServlet(BaseFederationServlet):
     PATH = "/send/([^/]*)/"
 
     def __init__(self, handler, server_name, **kwargs):
-        super(FederationSendServlet, self).__init__(handler, **kwargs)
+        super(FederationSendServlet, self).__init__(
+            handler, server_name=server_name, **kwargs
+        )
         self.server_name = server_name
 
     # This is when someone is trying to send us a bunch of data.
@@ -432,6 +439,7 @@ class On3pidBindServlet(BaseFederationServlet):
 
 
 SERVLET_CLASSES = (
+    FederationSendServlet,
     FederationPullServlet,
     FederationEventServlet,
     FederationStateServlet,
@@ -451,3 +459,13 @@ SERVLET_CLASSES = (
     FederationThirdPartyInviteExchangeServlet,
     On3pidBindServlet,
 )
+
+
+def register_servlets(hs, resource, authenticator, ratelimiter):
+    for servletclass in SERVLET_CLASSES:
+        servletclass(
+            handler=hs.get_replication_layer(),
+            authenticator=authenticator,
+            ratelimiter=ratelimiter,
+            server_name=hs.hostname,
+        ).register(resource)