summary refs log tree commit diff
path: root/synapse/federation
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation')
-rw-r--r--synapse/federation/federation_server.py32
-rw-r--r--synapse/federation/transport/__init__.py12
-rw-r--r--synapse/federation/transport/server.py14
3 files changed, 35 insertions, 23 deletions
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 22b9663831..bc9bac809a 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -112,17 +112,19 @@ class FederationServer(FederationBase):
         logger.debug("[%s] Transaction is new", transaction.transaction_id)
 
         with PreserveLoggingContext():
-            dl = []
+            results = []
+
             for pdu in pdu_list:
                 d = self._handle_new_pdu(transaction.origin, pdu)
 
-                def handle_failure(failure):
-                    failure.trap(FederationError)
-                    self.send_failure(failure.value, transaction.origin)
-
-                d.addErrback(handle_failure)
-
-                dl.append(d)
+                try:
+                    yield d
+                    results.append({})
+                except FederationError as e:
+                    self.send_failure(e, transaction.origin)
+                    results.append({"error": str(e)})
+                except Exception as e:
+                    results.append({"error": str(e)})
 
             if hasattr(transaction, "edus"):
                 for edu in [Edu(**x) for x in transaction.edus]:
@@ -135,21 +137,11 @@ class FederationServer(FederationBase):
             for failure in getattr(transaction, "pdu_failures", []):
                 logger.info("Got failure %r", failure)
 
-            results = yield defer.DeferredList(dl, consumeErrors=True)
-
-        ret = []
-        for r in results:
-            if r[0]:
-                ret.append({})
-            else:
-                logger.exception(r[1])
-                ret.append({"error": str(r[1].value)})
-
-        logger.debug("Returning: %s", str(ret))
+        logger.debug("Returning: %s", str(results))
 
         response = {
             "pdus": dict(zip(
-                (p.event_id for p in pdu_list), ret
+                (p.event_id for p in pdu_list), results
             )),
         }
 
diff --git a/synapse/federation/transport/__init__.py b/synapse/federation/transport/__init__.py
index 6800ac46c5..2a671b9aec 100644
--- a/synapse/federation/transport/__init__.py
+++ b/synapse/federation/transport/__init__.py
@@ -24,6 +24,8 @@ communicate over a different (albeit still reliable) protocol.
 from .server import TransportLayerServer
 from .client import TransportLayerClient
 
+from synapse.util.ratelimitutils import FederationRateLimiter
+
 
 class TransportLayer(TransportLayerServer, TransportLayerClient):
     """This is a basic implementation of the transport layer that translates
@@ -55,8 +57,18 @@ class TransportLayer(TransportLayerServer, TransportLayerClient):
                 send requests
         """
         self.keyring = homeserver.get_keyring()
+        self.clock = homeserver.get_clock()
         self.server_name = server_name
         self.server = server
         self.client = client
         self.request_handler = None
         self.received_handler = None
+
+        self.ratelimiter = FederationRateLimiter(
+            self.clock,
+            window_size=homeserver.config.federation_rc_window_size,
+            sleep_limit=homeserver.config.federation_rc_sleep_limit,
+            sleep_msec=homeserver.config.federation_rc_sleep_delay,
+            reject_limit=homeserver.config.federation_rc_reject_limit,
+            concurrent_requests=homeserver.config.federation_rc_concurrent,
+        )
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 2ffb37aa18..fce9c0195e 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -98,15 +98,23 @@ class TransportLayerServer(object):
         def new_handler(request, *args, **kwargs):
             try:
                 (origin, content) = yield self._authenticate_request(request)
-                response = yield handler(
-                    origin, content, request.args, *args, **kwargs
-                )
+                with self.ratelimiter.ratelimit(origin) as d:
+                    yield d
+                    response = yield handler(
+                        origin, content, request.args, *args, **kwargs
+                    )
             except:
                 logger.exception("_authenticate_request failed")
                 raise
             defer.returnValue(response)
         return new_handler
 
+    def rate_limit_origin(self, handler):
+        def new_handler(origin, *args, **kwargs):
+            response = yield handler(origin, *args, **kwargs)
+            defer.returnValue(response)
+        return new_handler()
+
     @log_function
     def register_received_handler(self, handler):
         """ Register a handler that will be fired when we receive data.