summary refs log tree commit diff
path: root/synapse/federation/federation_server.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation/federation_server.py')
-rw-r--r--synapse/federation/federation_server.py212
1 files changed, 125 insertions, 87 deletions
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 2a589524a4..aba19639c7 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -21,10 +21,11 @@ from .units import Transaction, Edu
 
 from synapse.util.async import Linearizer
 from synapse.util.logutils import log_function
+from synapse.util.caches.response_cache import ResponseCache
 from synapse.events import FrozenEvent
 import synapse.metrics
 
-from synapse.api.errors import FederationError, SynapseError
+from synapse.api.errors import AuthError, FederationError, SynapseError
 
 from synapse.crypto.event_signing import compute_event_signature
 
@@ -48,7 +49,14 @@ class FederationServer(FederationBase):
     def __init__(self, hs):
         super(FederationServer, self).__init__(hs)
 
+        self.auth = hs.get_auth()
+
         self._room_pdu_linearizer = Linearizer()
+        self._server_linearizer = Linearizer()
+
+        # We cache responses to state queries, as they take a while and often
+        # come in waves.
+        self._state_resp_cache = ResponseCache(hs, timeout_ms=30000)
 
     def set_handler(self, handler):
         """Sets the handler that the replication layer will use to communicate
@@ -89,11 +97,14 @@ class FederationServer(FederationBase):
     @defer.inlineCallbacks
     @log_function
     def on_backfill_request(self, origin, room_id, versions, limit):
-        pdus = yield self.handler.on_backfill_request(
-            origin, room_id, versions, limit
-        )
+        with (yield self._server_linearizer.queue((origin, room_id))):
+            pdus = yield self.handler.on_backfill_request(
+                origin, room_id, versions, limit
+            )
 
-        defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
+            res = self._transaction_from_pdus(pdus).get_dict()
+
+        defer.returnValue((200, res))
 
     @defer.inlineCallbacks
     @log_function
@@ -184,32 +195,71 @@ class FederationServer(FederationBase):
     @defer.inlineCallbacks
     @log_function
     def on_context_state_request(self, origin, room_id, event_id):
-        if event_id:
-            pdus = yield self.handler.get_state_for_pdu(
-                origin, room_id, event_id,
-            )
-            auth_chain = yield self.store.get_auth_chain(
-                [pdu.event_id for pdu in pdus]
-            )
+        if not event_id:
+            raise NotImplementedError("Specify an event")
 
-            for event in auth_chain:
-                # We sign these again because there was a bug where we
-                # incorrectly signed things the first time round
-                if self.hs.is_mine_id(event.event_id):
-                    event.signatures.update(
-                        compute_event_signature(
-                            event,
-                            self.hs.hostname,
-                            self.hs.config.signing_key[0]
-                        )
-                    )
+        in_room = yield self.auth.check_host_in_room(room_id, origin)
+        if not in_room:
+            raise AuthError(403, "Host not in room.")
+
+        result = self._state_resp_cache.get((room_id, event_id))
+        if not result:
+            with (yield self._server_linearizer.queue((origin, room_id))):
+                resp = yield self._state_resp_cache.set(
+                    (room_id, event_id),
+                    self._on_context_state_request_compute(room_id, event_id)
+                )
         else:
+            resp = yield result
+
+        defer.returnValue((200, resp))
+
+    @defer.inlineCallbacks
+    def on_state_ids_request(self, origin, room_id, event_id):
+        if not event_id:
             raise NotImplementedError("Specify an event")
 
+        in_room = yield self.auth.check_host_in_room(room_id, origin)
+        if not in_room:
+            raise AuthError(403, "Host not in room.")
+
+        pdus = yield self.handler.get_state_for_pdu(
+            room_id, event_id,
+        )
+        auth_chain = yield self.store.get_auth_chain(
+            [pdu.event_id for pdu in pdus]
+        )
+
         defer.returnValue((200, {
+            "pdu_ids": [pdu.event_id for pdu in pdus],
+            "auth_chain_ids": [pdu.event_id for pdu in auth_chain],
+        }))
+
+    @defer.inlineCallbacks
+    def _on_context_state_request_compute(self, room_id, event_id):
+        pdus = yield self.handler.get_state_for_pdu(
+            room_id, event_id,
+        )
+        auth_chain = yield self.store.get_auth_chain(
+            [pdu.event_id for pdu in pdus]
+        )
+
+        for event in auth_chain:
+            # We sign these again because there was a bug where we
+            # incorrectly signed things the first time round
+            if self.hs.is_mine_id(event.event_id):
+                event.signatures.update(
+                    compute_event_signature(
+                        event,
+                        self.hs.hostname,
+                        self.hs.config.signing_key[0]
+                    )
+                )
+
+        defer.returnValue({
             "pdus": [pdu.get_pdu_json() for pdu in pdus],
             "auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],
-        }))
+        })
 
     @defer.inlineCallbacks
     @log_function
@@ -283,14 +333,16 @@ class FederationServer(FederationBase):
 
     @defer.inlineCallbacks
     def on_event_auth(self, origin, room_id, event_id):
-        time_now = self._clock.time_msec()
-        auth_pdus = yield self.handler.on_event_auth(event_id)
-        defer.returnValue((200, {
-            "auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus],
-        }))
+        with (yield self._server_linearizer.queue((origin, room_id))):
+            time_now = self._clock.time_msec()
+            auth_pdus = yield self.handler.on_event_auth(event_id)
+            res = {
+                "auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus],
+            }
+        defer.returnValue((200, res))
 
     @defer.inlineCallbacks
-    def on_query_auth_request(self, origin, content, event_id):
+    def on_query_auth_request(self, origin, content, room_id, event_id):
         """
         Content is a dict with keys::
             auth_chain (list): A list of events that give the auth chain.
@@ -309,58 +361,41 @@ class FederationServer(FederationBase):
         Returns:
             Deferred: Results in `dict` with the same format as `content`
         """
-        auth_chain = [
-            self.event_from_pdu_json(e)
-            for e in content["auth_chain"]
-        ]
-
-        signed_auth = yield self._check_sigs_and_hash_and_fetch(
-            origin, auth_chain, outlier=True
-        )
+        with (yield self._server_linearizer.queue((origin, room_id))):
+            auth_chain = [
+                self.event_from_pdu_json(e)
+                for e in content["auth_chain"]
+            ]
+
+            signed_auth = yield self._check_sigs_and_hash_and_fetch(
+                origin, auth_chain, outlier=True
+            )
 
-        ret = yield self.handler.on_query_auth(
-            origin,
-            event_id,
-            signed_auth,
-            content.get("rejects", []),
-            content.get("missing", []),
-        )
+            ret = yield self.handler.on_query_auth(
+                origin,
+                event_id,
+                signed_auth,
+                content.get("rejects", []),
+                content.get("missing", []),
+            )
 
-        time_now = self._clock.time_msec()
-        send_content = {
-            "auth_chain": [
-                e.get_pdu_json(time_now)
-                for e in ret["auth_chain"]
-            ],
-            "rejects": ret.get("rejects", []),
-            "missing": ret.get("missing", []),
-        }
+            time_now = self._clock.time_msec()
+            send_content = {
+                "auth_chain": [
+                    e.get_pdu_json(time_now)
+                    for e in ret["auth_chain"]
+                ],
+                "rejects": ret.get("rejects", []),
+                "missing": ret.get("missing", []),
+            }
 
         defer.returnValue(
             (200, send_content)
         )
 
-    @defer.inlineCallbacks
     @log_function
     def on_query_client_keys(self, origin, content):
-        query = []
-        for user_id, device_ids in content.get("device_keys", {}).items():
-            if not device_ids:
-                query.append((user_id, None))
-            else:
-                for device_id in device_ids:
-                    query.append((user_id, device_id))
-
-        results = yield self.store.get_e2e_device_keys(query)
-
-        json_result = {}
-        for user_id, device_keys in results.items():
-            for device_id, json_bytes in device_keys.items():
-                json_result.setdefault(user_id, {})[device_id] = json.loads(
-                    json_bytes
-                )
-
-        defer.returnValue({"device_keys": json_result})
+        return self.on_query_request("client_keys", content)
 
     @defer.inlineCallbacks
     @log_function
@@ -386,21 +421,24 @@ class FederationServer(FederationBase):
     @log_function
     def on_get_missing_events(self, origin, room_id, earliest_events,
                               latest_events, limit, min_depth):
-        logger.info(
-            "on_get_missing_events: earliest_events: %r, latest_events: %r,"
-            " limit: %d, min_depth: %d",
-            earliest_events, latest_events, limit, min_depth
-        )
-        missing_events = yield self.handler.on_get_missing_events(
-            origin, room_id, earliest_events, latest_events, limit, min_depth
-        )
+        with (yield self._server_linearizer.queue((origin, room_id))):
+            logger.info(
+                "on_get_missing_events: earliest_events: %r, latest_events: %r,"
+                " limit: %d, min_depth: %d",
+                earliest_events, latest_events, limit, min_depth
+            )
+            missing_events = yield self.handler.on_get_missing_events(
+                origin, room_id, earliest_events, latest_events, limit, min_depth
+            )
 
-        if len(missing_events) < 5:
-            logger.info("Returning %d events: %r", len(missing_events), missing_events)
-        else:
-            logger.info("Returning %d events", len(missing_events))
+            if len(missing_events) < 5:
+                logger.info(
+                    "Returning %d events: %r", len(missing_events), missing_events
+                )
+            else:
+                logger.info("Returning %d events", len(missing_events))
 
-        time_now = self._clock.time_msec()
+            time_now = self._clock.time_msec()
 
         defer.returnValue({
             "events": [ev.get_pdu_json(time_now) for ev in missing_events],
@@ -567,7 +605,7 @@ class FederationServer(FederationBase):
                     origin, pdu.room_id, pdu.event_id,
                 )
             except:
-                logger.warn("Failed to get state for event: %s", pdu.event_id)
+                logger.exception("Failed to get state for event: %s", pdu.event_id)
 
         yield self.handler.on_receive_pdu(
             origin,