summary refs log tree commit diff
path: root/synapse/federation/transport
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation/transport')
-rw-r--r--synapse/federation/transport/__init__.py12
-rw-r--r--synapse/federation/transport/client.py67
-rw-r--r--synapse/federation/transport/server.py68
3 files changed, 121 insertions, 26 deletions
diff --git a/synapse/federation/transport/__init__.py b/synapse/federation/transport/__init__.py
index 6800ac46c5..2a671b9aec 100644
--- a/synapse/federation/transport/__init__.py
+++ b/synapse/federation/transport/__init__.py
@@ -24,6 +24,8 @@ communicate over a different (albeit still reliable) protocol.
 from .server import TransportLayerServer
 from .client import TransportLayerClient
 
+from synapse.util.ratelimitutils import FederationRateLimiter
+
 
 class TransportLayer(TransportLayerServer, TransportLayerClient):
     """This is a basic implementation of the transport layer that translates
@@ -55,8 +57,18 @@ class TransportLayer(TransportLayerServer, TransportLayerClient):
                 send requests
         """
         self.keyring = homeserver.get_keyring()
+        self.clock = homeserver.get_clock()
         self.server_name = server_name
         self.server = server
         self.client = client
         self.request_handler = None
         self.received_handler = None
+
+        self.ratelimiter = FederationRateLimiter(
+            self.clock,
+            window_size=homeserver.config.federation_rc_window_size,
+            sleep_limit=homeserver.config.federation_rc_sleep_limit,
+            sleep_msec=homeserver.config.federation_rc_sleep_delay,
+            reject_limit=homeserver.config.federation_rc_reject_limit,
+            concurrent_requests=homeserver.config.federation_rc_concurrent,
+        )
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index e634a3a213..80d03012b7 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -19,7 +19,6 @@ from synapse.api.urls import FEDERATION_PREFIX as PREFIX
 from synapse.util.logutils import log_function
 
 import logging
-import json
 
 
 logger = logging.getLogger(__name__)
@@ -129,7 +128,7 @@ class TransportLayerClient(object):
         # generated by the json_data_callback.
         json_data = transaction.get_dict()
 
-        code, response = yield self.client.put_json(
+        response = yield self.client.put_json(
             transaction.destination,
             path=PREFIX + "/send/%s/" % transaction.transaction_id,
             data=json_data,
@@ -137,79 +136,105 @@ class TransportLayerClient(object):
         )
 
         logger.debug(
-            "send_data dest=%s, txid=%s, got response: %d",
-            transaction.destination, transaction.transaction_id, code
+            "send_data dest=%s, txid=%s, got response: 200",
+            transaction.destination, transaction.transaction_id,
         )
 
-        defer.returnValue((code, response))
+        defer.returnValue(response)
 
     @defer.inlineCallbacks
     @log_function
     def make_query(self, destination, query_type, args, retry_on_dns_fail):
         path = PREFIX + "/query/%s" % query_type
 
-        response = yield self.client.get_json(
+        content = yield self.client.get_json(
             destination=destination,
             path=path,
             args=args,
             retry_on_dns_fail=retry_on_dns_fail,
         )
 
-        defer.returnValue(response)
+        defer.returnValue(content)
 
     @defer.inlineCallbacks
     @log_function
     def make_join(self, destination, room_id, user_id, retry_on_dns_fail=True):
         path = PREFIX + "/make_join/%s/%s" % (room_id, user_id)
 
-        response = yield self.client.get_json(
+        content = yield self.client.get_json(
             destination=destination,
             path=path,
             retry_on_dns_fail=retry_on_dns_fail,
         )
 
-        defer.returnValue(response)
+        defer.returnValue(content)
 
     @defer.inlineCallbacks
     @log_function
     def send_join(self, destination, room_id, event_id, content):
         path = PREFIX + "/send_join/%s/%s" % (room_id, event_id)
 
-        code, content = yield self.client.put_json(
+        response = yield self.client.put_json(
             destination=destination,
             path=path,
             data=content,
         )
 
-        if not 200 <= code < 300:
-            raise RuntimeError("Got %d from send_join", code)
-
-        defer.returnValue(json.loads(content))
+        defer.returnValue(response)
 
     @defer.inlineCallbacks
     @log_function
     def send_invite(self, destination, room_id, event_id, content):
         path = PREFIX + "/invite/%s/%s" % (room_id, event_id)
 
-        code, content = yield self.client.put_json(
+        response = yield self.client.put_json(
             destination=destination,
             path=path,
             data=content,
         )
 
-        if not 200 <= code < 300:
-            raise RuntimeError("Got %d from send_invite", code)
-
-        defer.returnValue(json.loads(content))
+        defer.returnValue(response)
 
     @defer.inlineCallbacks
     @log_function
     def get_event_auth(self, destination, room_id, event_id):
         path = PREFIX + "/event_auth/%s/%s" % (room_id, event_id)
 
-        response = yield self.client.get_json(
+        content = yield self.client.get_json(
             destination=destination,
             path=path,
         )
 
-        defer.returnValue(response)
+        defer.returnValue(content)
+
+    @defer.inlineCallbacks
+    @log_function
+    def send_query_auth(self, destination, room_id, event_id, content):
+        path = PREFIX + "/query_auth/%s/%s" % (room_id, event_id)
+
+        content = yield self.client.post_json(
+            destination=destination,
+            path=path,
+            data=content,
+        )
+
+        defer.returnValue(content)
+
+    @defer.inlineCallbacks
+    @log_function
+    def get_missing_events(self, destination, room_id, earliest_events,
+                           latest_events, limit, min_depth):
+        path = PREFIX + "/get_missing_events/%s" % (room_id,)
+
+        content = yield self.client.post_json(
+            destination=destination,
+            path=path,
+            data={
+                "limit": int(limit),
+                "min_depth": int(min_depth),
+                "earliest_events": earliest_events,
+                "latest_events": latest_events,
+            }
+        )
+
+        defer.returnValue(content)
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index a380a6910b..ece6dbcf62 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -20,7 +20,7 @@ from synapse.api.errors import Codes, SynapseError
 from synapse.util.logutils import log_function
 
 import logging
-import json
+import simplejson as json
 import re
 
 
@@ -42,7 +42,7 @@ class TransportLayerServer(object):
         content = None
         origin = None
 
-        if request.method == "PUT":
+        if request.method in ["PUT", "POST"]:
             # TODO: Handle other method types? other content types?
             try:
                 content_bytes = request.content.read()
@@ -98,15 +98,23 @@ class TransportLayerServer(object):
         def new_handler(request, *args, **kwargs):
             try:
                 (origin, content) = yield self._authenticate_request(request)
-                response = yield handler(
-                    origin, content, request.args, *args, **kwargs
-                )
+                with self.ratelimiter.ratelimit(origin) as d:
+                    yield d
+                    response = yield handler(
+                        origin, content, request.args, *args, **kwargs
+                    )
             except:
                 logger.exception("_authenticate_request failed")
                 raise
             defer.returnValue(response)
         return new_handler
 
+    def rate_limit_origin(self, handler):
+        def new_handler(origin, *args, **kwargs):
+            response = yield handler(origin, *args, **kwargs)
+            defer.returnValue(response)
+        return new_handler()
+
     @log_function
     def register_received_handler(self, handler):
         """ Register a handler that will be fired when we receive data.
@@ -235,6 +243,28 @@ class TransportLayerServer(object):
             )
         )
 
+        self.server.register_path(
+            "POST",
+            re.compile("^" + PREFIX + "/query_auth/([^/]*)/([^/]*)$"),
+            self._with_authentication(
+                lambda origin, content, query, context, event_id:
+                self._on_query_auth_request(
+                    origin, content, event_id,
+                )
+            )
+        )
+
+        self.server.register_path(
+            "POST",
+            re.compile("^" + PREFIX + "/get_missing_events/([^/]*)/?$"),
+            self._with_authentication(
+                lambda origin, content, query, room_id:
+                self._get_missing_events(
+                    origin, content, room_id,
+                )
+            )
+        )
+
     @defer.inlineCallbacks
     @log_function
     def _on_send_request(self, origin, content, query, transaction_id):
@@ -325,3 +355,31 @@ class TransportLayerServer(object):
         )
 
         defer.returnValue((200, content))
+
+    @defer.inlineCallbacks
+    @log_function
+    def _on_query_auth_request(self, origin, content, event_id):
+        new_content = yield self.request_handler.on_query_auth_request(
+            origin, content, event_id
+        )
+
+        defer.returnValue((200, new_content))
+
+    @defer.inlineCallbacks
+    @log_function
+    def _get_missing_events(self, origin, content, room_id):
+        limit = int(content.get("limit", 10))
+        min_depth = int(content.get("min_depth", 0))
+        earliest_events = content.get("earliest_events", [])
+        latest_events = content.get("latest_events", [])
+
+        content = yield self.request_handler.on_get_missing_events(
+            origin,
+            room_id=room_id,
+            earliest_events=earliest_events,
+            latest_events=latest_events,
+            min_depth=min_depth,
+            limit=limit,
+        )
+
+        defer.returnValue((200, content))