summary refs log tree commit diff
path: root/synapse/federation
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation')
-rw-r--r--synapse/federation/federation_client.py19
-rw-r--r--synapse/federation/federation_server.py16
-rw-r--r--synapse/federation/transaction_queue.py31
-rw-r--r--synapse/federation/transport/server.py343
4 files changed, 218 insertions, 191 deletions
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index f131941f45..6811a0e3d1 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -25,6 +25,7 @@ from synapse.api.errors import (
 from synapse.util.expiringcache import ExpiringCache
 from synapse.util.logutils import log_function
 from synapse.events import FrozenEvent
+import synapse.metrics
 
 from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
 
@@ -36,9 +37,17 @@ import random
 logger = logging.getLogger(__name__)
 
 
+# synapse.federation.federation_client is a silly name
+metrics = synapse.metrics.get_metrics_for("synapse.federation.client")
+
+sent_pdus_destination_dist = metrics.register_distribution("sent_pdu_destinations")
+
+sent_edus_counter = metrics.register_counter("sent_edus")
+
+sent_queries_counter = metrics.register_counter("sent_queries", labels=["type"])
+
+
 class FederationClient(FederationBase):
-    def __init__(self):
-        self._get_pdu_cache = None
 
     def start_get_pdu_cache(self):
         self._get_pdu_cache = ExpiringCache(
@@ -68,6 +77,8 @@ class FederationClient(FederationBase):
         order = self._order
         self._order += 1
 
+        sent_pdus_destination_dist.inc_by(len(destinations))
+
         logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.event_id)
 
         # TODO, add errback, etc.
@@ -87,6 +98,8 @@ class FederationClient(FederationBase):
             content=content,
         )
 
+        sent_edus_counter.inc()
+
         # TODO, add errback, etc.
         self._transaction_queue.enqueue_edu(edu)
         return defer.succeed(None)
@@ -113,6 +126,8 @@ class FederationClient(FederationBase):
             a Deferred which will eventually yield a JSON object from the
             response
         """
+        sent_queries_counter.inc(query_type)
+
         return self.transport_layer.make_query(
             destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail
         )
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 9c7dcdba96..25c0014f97 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -22,6 +22,7 @@ from .units import Transaction, Edu
 from synapse.util.logutils import log_function
 from synapse.util.logcontext import PreserveLoggingContext
 from synapse.events import FrozenEvent
+import synapse.metrics
 
 from synapse.api.errors import FederationError, SynapseError
 
@@ -32,6 +33,15 @@ import logging
 
 logger = logging.getLogger(__name__)
 
+# synapse.federation.federation_server is a silly name
+metrics = synapse.metrics.get_metrics_for("synapse.federation.server")
+
+received_pdus_counter = metrics.register_counter("received_pdus")
+
+received_edus_counter = metrics.register_counter("received_edus")
+
+received_queries_counter = metrics.register_counter("received_queries", labels=["type"])
+
 
 class FederationServer(FederationBase):
     def set_handler(self, handler):
@@ -84,6 +94,8 @@ class FederationServer(FederationBase):
     def on_incoming_transaction(self, transaction_data):
         transaction = Transaction(**transaction_data)
 
+        received_pdus_counter.inc_by(len(transaction.pdus))
+
         for p in transaction.pdus:
             if "unsigned" in p:
                 unsigned = p["unsigned"]
@@ -153,6 +165,8 @@ class FederationServer(FederationBase):
         defer.returnValue((200, response))
 
     def received_edu(self, origin, edu_type, content):
+        received_edus_counter.inc()
+
         if edu_type in self.edu_handlers:
             self.edu_handlers[edu_type](origin, content)
         else:
@@ -204,6 +218,8 @@ class FederationServer(FederationBase):
 
     @defer.inlineCallbacks
     def on_query_request(self, query_type, args):
+        received_queries_counter.inc(query_type)
+
         if query_type in self.query_handlers:
             response = yield self.query_handlers[query_type](args)
             defer.returnValue((200, response))
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index 741a4e7a1a..4dccd93d0e 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -25,12 +25,15 @@ from synapse.util.logcontext import PreserveLoggingContext
 from synapse.util.retryutils import (
     get_retry_limiter, NotRetryingDestination,
 )
+import synapse.metrics
 
 import logging
 
 
 logger = logging.getLogger(__name__)
 
+metrics = synapse.metrics.get_metrics_for(__name__)
+
 
 class TransactionQueue(object):
     """This class makes sure we only have one transaction in flight at
@@ -54,11 +57,25 @@ class TransactionQueue(object):
         # done
         self.pending_transactions = {}
 
+        metrics.register_callback(
+            "pending_destinations",
+            lambda: len(self.pending_transactions),
+        )
+
         # Is a mapping from destination -> list of
         # tuple(pending pdus, deferred, order)
-        self.pending_pdus_by_dest = {}
+        self.pending_pdus_by_dest = pdus = {}
         # destination -> list of tuple(edu, deferred)
-        self.pending_edus_by_dest = {}
+        self.pending_edus_by_dest = edus = {}
+
+        metrics.register_callback(
+            "pending_pdus",
+            lambda: sum(map(len, pdus.values())),
+        )
+        metrics.register_callback(
+            "pending_edus",
+            lambda: sum(map(len, edus.values())),
+        )
 
         # destination -> list of tuple(failure, deferred)
         self.pending_failures_by_dest = {}
@@ -115,8 +132,8 @@ class TransactionQueue(object):
                 if not deferred.called:
                     deferred.errback(failure)
 
-            def log_failure(failure):
-                logger.warn("Failed to send pdu", failure.value)
+            def log_failure(f):
+                logger.warn("Failed to send pdu to %s: %s", destination, f.value)
 
             deferred.addErrback(log_failure)
 
@@ -143,8 +160,8 @@ class TransactionQueue(object):
             if not deferred.called:
                 deferred.errback(failure)
 
-        def log_failure(failure):
-            logger.warn("Failed to send pdu", failure.value)
+        def log_failure(f):
+            logger.warn("Failed to send edu to %s: %s", destination, f.value)
 
         deferred.addErrback(log_failure)
 
@@ -174,7 +191,7 @@ class TransactionQueue(object):
                 deferred.errback(f)
 
         def log_failure(f):
-            logger.warn("Failed to send pdu", f.value)
+            logger.warn("Failed to send failure to %s: %s", destination, f.value)
 
         deferred.addErrback(log_failure)
 
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index ece6dbcf62..7838a81362 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -19,6 +19,7 @@ from synapse.api.urls import FEDERATION_PREFIX as PREFIX
 from synapse.api.errors import Codes, SynapseError
 from synapse.util.logutils import log_function
 
+import functools
 import logging
 import simplejson as json
 import re
@@ -30,8 +31,9 @@ logger = logging.getLogger(__name__)
 class TransportLayerServer(object):
     """Handles incoming federation HTTP requests"""
 
+    # A method just so we can pass 'self' as the authenticator to the Servlets
     @defer.inlineCallbacks
-    def _authenticate_request(self, request):
+    def authenticate_request(self, request):
         json_request = {
             "method": request.method,
             "uri": request.uri,
@@ -93,28 +95,6 @@ class TransportLayerServer(object):
 
         defer.returnValue((origin, content))
 
-    def _with_authentication(self, handler):
-        @defer.inlineCallbacks
-        def new_handler(request, *args, **kwargs):
-            try:
-                (origin, content) = yield self._authenticate_request(request)
-                with self.ratelimiter.ratelimit(origin) as d:
-                    yield d
-                    response = yield handler(
-                        origin, content, request.args, *args, **kwargs
-                    )
-            except:
-                logger.exception("_authenticate_request failed")
-                raise
-            defer.returnValue(response)
-        return new_handler
-
-    def rate_limit_origin(self, handler):
-        def new_handler(origin, *args, **kwargs):
-            response = yield handler(origin, *args, **kwargs)
-            defer.returnValue(response)
-        return new_handler()
-
     @log_function
     def register_received_handler(self, handler):
         """ Register a handler that will be fired when we receive data.
@@ -122,14 +102,12 @@ class TransportLayerServer(object):
         Args:
             handler (TransportReceivedHandler)
         """
-        self.received_handler = handler
-
-        # This is when someone is trying to send us a bunch of data.
-        self.server.register_path(
-            "PUT",
-            re.compile("^" + PREFIX + "/send/([^/]*)/$"),
-            self._with_authentication(self._on_send_request)
-        )
+        FederationSendServlet(
+            handler,
+            authenticator=self,
+            ratelimiter=self.ratelimiter,
+            server_name=self.server_name,
+        ).register(self.server)
 
     @log_function
     def register_request_handler(self, handler):
@@ -138,136 +116,65 @@ class TransportLayerServer(object):
         Args:
             handler (TransportRequestHandler)
         """
-        self.request_handler = handler
-
-        # This is for when someone asks us for everything since version X
-        self.server.register_path(
-            "GET",
-            re.compile("^" + PREFIX + "/pull/$"),
-            self._with_authentication(
-                lambda origin, content, query:
-                handler.on_pull_request(query["origin"][0], query["v"])
-            )
-        )
+        for servletclass in SERVLET_CLASSES:
+            servletclass(
+                handler,
+                authenticator=self,
+                ratelimiter=self.ratelimiter,
+            ).register(self.server)
 
-        # This is when someone asks for a data item for a given server
-        # data_id pair.
-        self.server.register_path(
-            "GET",
-            re.compile("^" + PREFIX + "/event/([^/]*)/$"),
-            self._with_authentication(
-                lambda origin, content, query, event_id:
-                handler.on_pdu_request(origin, event_id)
-            )
-        )
 
-        # This is when someone asks for all data for a given context.
-        self.server.register_path(
-            "GET",
-            re.compile("^" + PREFIX + "/state/([^/]*)/$"),
-            self._with_authentication(
-                lambda origin, content, query, context:
-                handler.on_context_state_request(
-                    origin,
-                    context,
-                    query.get("event_id", [None])[0],
-                )
-            )
-        )
+class BaseFederationServlet(object):
+    def __init__(self, handler, authenticator, ratelimiter):
+        self.handler = handler
+        self.authenticator = authenticator
+        self.ratelimiter = ratelimiter
 
-        self.server.register_path(
-            "GET",
-            re.compile("^" + PREFIX + "/backfill/([^/]*)/$"),
-            self._with_authentication(
-                lambda origin, content, query, context:
-                self._on_backfill_request(
-                    origin, context, query["v"], query["limit"]
-                )
-            )
-        )
+    def _wrap(self, code):
+        authenticator = self.authenticator
+        ratelimiter = self.ratelimiter
 
-        # This is when we receive a server-server Query
-        self.server.register_path(
-            "GET",
-            re.compile("^" + PREFIX + "/query/([^/]*)$"),
-            self._with_authentication(
-                lambda origin, content, query, query_type:
-                handler.on_query_request(
-                    query_type,
-                    {k: v[0].decode("utf-8") for k, v in query.items()}
-                )
-            )
-        )
+        @defer.inlineCallbacks
+        @functools.wraps(code)
+        def new_code(request, *args, **kwargs):
+            try:
+                (origin, content) = yield authenticator.authenticate_request(request)
+                with ratelimiter.ratelimit(origin) as d:
+                    yield d
+                    response = yield code(
+                        origin, content, request.args, *args, **kwargs
+                    )
+            except:
+                logger.exception("authenticate_request failed")
+                raise
+            defer.returnValue(response)
 
-        self.server.register_path(
-            "GET",
-            re.compile("^" + PREFIX + "/make_join/([^/]*)/([^/]*)$"),
-            self._with_authentication(
-                lambda origin, content, query, context, user_id:
-                self._on_make_join_request(
-                    origin, content, query, context, user_id
-                )
-            )
-        )
+        # Extra logic that functools.wraps() doesn't finish
+        new_code.__self__ = code.__self__
 
-        self.server.register_path(
-            "GET",
-            re.compile("^" + PREFIX + "/event_auth/([^/]*)/([^/]*)$"),
-            self._with_authentication(
-                lambda origin, content, query, context, event_id:
-                handler.on_event_auth(
-                    origin, context, event_id,
-                )
-            )
-        )
+        return new_code
 
-        self.server.register_path(
-            "PUT",
-            re.compile("^" + PREFIX + "/send_join/([^/]*)/([^/]*)$"),
-            self._with_authentication(
-                lambda origin, content, query, context, event_id:
-                self._on_send_join_request(
-                    origin, content, query,
-                )
-            )
-        )
+    def register(self, server):
+        pattern = re.compile("^" + PREFIX + self.PATH + "$")
 
-        self.server.register_path(
-            "PUT",
-            re.compile("^" + PREFIX + "/invite/([^/]*)/([^/]*)$"),
-            self._with_authentication(
-                lambda origin, content, query, context, event_id:
-                self._on_invite_request(
-                    origin, content, query,
-                )
-            )
-        )
+        for method in ("GET", "PUT", "POST"):
+            code = getattr(self, "on_%s" % (method), None)
+            if code is None:
+                continue
 
-        self.server.register_path(
-            "POST",
-            re.compile("^" + PREFIX + "/query_auth/([^/]*)/([^/]*)$"),
-            self._with_authentication(
-                lambda origin, content, query, context, event_id:
-                self._on_query_auth_request(
-                    origin, content, event_id,
-                )
-            )
-        )
+            server.register_path(method, pattern, self._wrap(code))
 
-        self.server.register_path(
-            "POST",
-            re.compile("^" + PREFIX + "/get_missing_events/([^/]*)/?$"),
-            self._with_authentication(
-                lambda origin, content, query, room_id:
-                self._get_missing_events(
-                    origin, content, room_id,
-                )
-            )
-        )
 
+class FederationSendServlet(BaseFederationServlet):
+    PATH = "/send/([^/]*)/"
+
+    def __init__(self, handler, server_name, **kwargs):
+        super(FederationSendServlet, self).__init__(handler, **kwargs)
+        self.server_name = server_name
+
+    # This is when someone is trying to send us a bunch of data.
     @defer.inlineCallbacks
-    @log_function
-    def _on_send_request(self, origin, content, query, transaction_id):
+    def on_PUT(self, origin, content, query, transaction_id):
         """ Called on PUT /send/<transaction_id>/
 
         Args:
@@ -305,8 +212,7 @@ class TransportLayerServer(object):
             return
 
         try:
-            handler = self.received_handler
-            code, response = yield handler.on_incoming_transaction(
+            code, response = yield self.handler.on_incoming_transaction(
                 transaction_data
             )
         except:
@@ -315,65 +221,123 @@ class TransportLayerServer(object):
 
         defer.returnValue((code, response))
 
-    @log_function
-    def _on_backfill_request(self, origin, context, v_list, limits):
+
+class FederationPullServlet(BaseFederationServlet):
+    PATH = "/pull/"
+
+    # This is for when someone asks us for everything since version X
+    def on_GET(self, origin, content, query):
+        return self.handler.on_pull_request(query["origin"][0], query["v"])
+
+
+class FederationEventServlet(BaseFederationServlet):
+    PATH = "/event/([^/]*)/"
+
+    # This is when someone asks for a data item for a given server data_id pair.
+    def on_GET(self, origin, content, query, event_id):
+        return self.handler.on_pdu_request(origin, event_id)
+
+
+class FederationStateServlet(BaseFederationServlet):
+    PATH = "/state/([^/]*)/"
+
+    # This is when someone asks for all data for a given context.
+    def on_GET(self, origin, content, query, context):
+        return self.handler.on_context_state_request(
+            origin,
+            context,
+            query.get("event_id", [None])[0],
+        )
+
+
+class FederationBackfillServlet(BaseFederationServlet):
+    PATH = "/backfill/([^/]*)/"
+
+    def on_GET(self, origin, content, query, context):
+        versions = query["v"]
+        limits = query["limit"]
+
         if not limits:
-            return defer.succeed(
-                (400, {"error": "Did not include limit param"})
-            )
+            return defer.succeed((400, {"error": "Did not include limit param"}))
 
         limit = int(limits[-1])
 
-        versions = v_list
+        return self.handler.on_backfill_request(origin, context, versions, limit)
 
-        return self.request_handler.on_backfill_request(
-            origin, context, versions, limit
+
+class FederationQueryServlet(BaseFederationServlet):
+    PATH = "/query/([^/]*)"
+
+    # This is when we receive a server-server Query
+    def on_GET(self, origin, content, query, query_type):
+        return self.handler.on_query_request(
+            query_type,
+            {k: v[0].decode("utf-8") for k, v in query.items()}
         )
 
+
+class FederationMakeJoinServlet(BaseFederationServlet):
+    PATH = "/make_join/([^/]*)/([^/]*)"
+
     @defer.inlineCallbacks
-    @log_function
-    def _on_make_join_request(self, origin, content, query, context, user_id):
-        content = yield self.request_handler.on_make_join_request(
-            context, user_id,
-        )
+    def on_GET(self, origin, content, query, context, user_id):
+        content = yield self.handler.on_make_join_request(context, user_id)
         defer.returnValue((200, content))
 
-    @defer.inlineCallbacks
-    @log_function
-    def _on_send_join_request(self, origin, content, query):
-        content = yield self.request_handler.on_send_join_request(
-            origin, content,
-        )
 
-        defer.returnValue((200, content))
+class FederationEventAuthServlet(BaseFederationServlet):
+    PATH = "/event_auth/([^/]*)/([^/]*)"
+
+    def on_GET(self, origin, content, query, context, event_id):
+        return self.handler.on_event_auth(origin, context, event_id)
+
+
+class FederationSendJoinServlet(BaseFederationServlet):
+    PATH = "/send_join/([^/]*)/([^/]*)"
 
     @defer.inlineCallbacks
-    @log_function
-    def _on_invite_request(self, origin, content, query):
-        content = yield self.request_handler.on_invite_request(
-            origin, content,
-        )
+    def on_PUT(self, origin, content, query, context, event_id):
+        # TODO(paul): assert that context/event_id parsed from path actually
+        #   match those given in content
+        content = yield self.handler.on_send_join_request(origin, content)
+        defer.returnValue((200, content))
+
+
+class FederationInviteServlet(BaseFederationServlet):
+    PATH = "/invite/([^/]*)/([^/]*)"
 
+    @defer.inlineCallbacks
+    def on_PUT(self, origin, content, query, context, event_id):
+        # TODO(paul): assert that context/event_id parsed from path actually
+        #   match those given in content
+        content = yield self.handler.on_invite_request(origin, content)
         defer.returnValue((200, content))
 
+
+class FederationQueryAuthServlet(BaseFederationServlet):
+    PATH = "/query_auth/([^/]*)/([^/]*)"
+
     @defer.inlineCallbacks
-    @log_function
-    def _on_query_auth_request(self, origin, content, event_id):
-        new_content = yield self.request_handler.on_query_auth_request(
+    def on_POST(self, origin, content, query, context, event_id):
+        new_content = yield self.handler.on_query_auth_request(
             origin, content, event_id
         )
 
         defer.returnValue((200, new_content))
 
+
+class FederationGetMissingEventsServlet(BaseFederationServlet):
+    # TODO(paul): Why does this path alone end with "/?" optional?
+    PATH = "/get_missing_events/([^/]*)/?"
+
     @defer.inlineCallbacks
-    @log_function
-    def _get_missing_events(self, origin, content, room_id):
+    def on_POST(self, origin, content, query, room_id):
         limit = int(content.get("limit", 10))
         min_depth = int(content.get("min_depth", 0))
         earliest_events = content.get("earliest_events", [])
         latest_events = content.get("latest_events", [])
 
-        content = yield self.request_handler.on_get_missing_events(
+        content = yield self.handler.on_get_missing_events(
             origin,
             room_id=room_id,
             earliest_events=earliest_events,
@@ -383,3 +347,18 @@ class TransportLayerServer(object):
         )
 
         defer.returnValue((200, content))
+
+
+SERVLET_CLASSES = (
+    FederationPullServlet,
+    FederationEventServlet,
+    FederationStateServlet,
+    FederationBackfillServlet,
+    FederationQueryServlet,
+    FederationMakeJoinServlet,
+    FederationEventServlet,
+    FederationSendJoinServlet,
+    FederationInviteServlet,
+    FederationQueryAuthServlet,
+    FederationGetMissingEventsServlet,
+)