summary refs log tree commit diff
path: root/synapse/federation/federation_server.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation/federation_server.py')
-rw-r--r--synapse/federation/federation_server.py280
1 files changed, 177 insertions, 103 deletions
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index f1d231b9d8..aba19639c7 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -19,11 +19,13 @@ from twisted.internet import defer
 from .federation_base import FederationBase
 from .units import Transaction, Edu
 
+from synapse.util.async import Linearizer
 from synapse.util.logutils import log_function
+from synapse.util.caches.response_cache import ResponseCache
 from synapse.events import FrozenEvent
 import synapse.metrics
 
-from synapse.api.errors import FederationError, SynapseError
+from synapse.api.errors import AuthError, FederationError, SynapseError
 
 from synapse.crypto.event_signing import compute_event_signature
 
@@ -44,6 +46,18 @@ received_queries_counter = metrics.register_counter("received_queries", labels=[
 
 
 class FederationServer(FederationBase):
+    def __init__(self, hs):
+        super(FederationServer, self).__init__(hs)
+
+        self.auth = hs.get_auth()
+
+        self._room_pdu_linearizer = Linearizer()
+        self._server_linearizer = Linearizer()
+
+        # We cache responses to state queries, as they take a while and often
+        # come in waves.
+        self._state_resp_cache = ResponseCache(hs, timeout_ms=30000)
+
     def set_handler(self, handler):
         """Sets the handler that the replication layer will use to communicate
         receipt of new PDUs from other home servers. The required methods are
@@ -83,11 +97,14 @@ class FederationServer(FederationBase):
     @defer.inlineCallbacks
     @log_function
     def on_backfill_request(self, origin, room_id, versions, limit):
-        pdus = yield self.handler.on_backfill_request(
-            origin, room_id, versions, limit
-        )
+        with (yield self._server_linearizer.queue((origin, room_id))):
+            pdus = yield self.handler.on_backfill_request(
+                origin, room_id, versions, limit
+            )
+
+            res = self._transaction_from_pdus(pdus).get_dict()
 
-        defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
+        defer.returnValue((200, res))
 
     @defer.inlineCallbacks
     @log_function
@@ -178,15 +195,59 @@ class FederationServer(FederationBase):
     @defer.inlineCallbacks
     @log_function
     def on_context_state_request(self, origin, room_id, event_id):
-        if event_id:
-            pdus = yield self.handler.get_state_for_pdu(
-                origin, room_id, event_id,
-            )
-            auth_chain = yield self.store.get_auth_chain(
-                [pdu.event_id for pdu in pdus]
-            )
+        if not event_id:
+            raise NotImplementedError("Specify an event")
+
+        in_room = yield self.auth.check_host_in_room(room_id, origin)
+        if not in_room:
+            raise AuthError(403, "Host not in room.")
+
+        result = self._state_resp_cache.get((room_id, event_id))
+        if not result:
+            with (yield self._server_linearizer.queue((origin, room_id))):
+                resp = yield self._state_resp_cache.set(
+                    (room_id, event_id),
+                    self._on_context_state_request_compute(room_id, event_id)
+                )
+        else:
+            resp = yield result
+
+        defer.returnValue((200, resp))
+
+    @defer.inlineCallbacks
+    def on_state_ids_request(self, origin, room_id, event_id):
+        if not event_id:
+            raise NotImplementedError("Specify an event")
+
+        in_room = yield self.auth.check_host_in_room(room_id, origin)
+        if not in_room:
+            raise AuthError(403, "Host not in room.")
 
-            for event in auth_chain:
+        pdus = yield self.handler.get_state_for_pdu(
+            room_id, event_id,
+        )
+        auth_chain = yield self.store.get_auth_chain(
+            [pdu.event_id for pdu in pdus]
+        )
+
+        defer.returnValue((200, {
+            "pdu_ids": [pdu.event_id for pdu in pdus],
+            "auth_chain_ids": [pdu.event_id for pdu in auth_chain],
+        }))
+
+    @defer.inlineCallbacks
+    def _on_context_state_request_compute(self, room_id, event_id):
+        pdus = yield self.handler.get_state_for_pdu(
+            room_id, event_id,
+        )
+        auth_chain = yield self.store.get_auth_chain(
+            [pdu.event_id for pdu in pdus]
+        )
+
+        for event in auth_chain:
+            # We sign these again because there was a bug where we
+            # incorrectly signed things the first time round
+            if self.hs.is_mine_id(event.event_id):
                 event.signatures.update(
                     compute_event_signature(
                         event,
@@ -194,13 +255,11 @@ class FederationServer(FederationBase):
                         self.hs.config.signing_key[0]
                     )
                 )
-        else:
-            raise NotImplementedError("Specify an event")
 
-        defer.returnValue((200, {
+        defer.returnValue({
             "pdus": [pdu.get_pdu_json() for pdu in pdus],
             "auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],
-        }))
+        })
 
     @defer.inlineCallbacks
     @log_function
@@ -274,14 +333,16 @@ class FederationServer(FederationBase):
 
     @defer.inlineCallbacks
     def on_event_auth(self, origin, room_id, event_id):
-        time_now = self._clock.time_msec()
-        auth_pdus = yield self.handler.on_event_auth(event_id)
-        defer.returnValue((200, {
-            "auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus],
-        }))
+        with (yield self._server_linearizer.queue((origin, room_id))):
+            time_now = self._clock.time_msec()
+            auth_pdus = yield self.handler.on_event_auth(event_id)
+            res = {
+                "auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus],
+            }
+        defer.returnValue((200, res))
 
     @defer.inlineCallbacks
-    def on_query_auth_request(self, origin, content, event_id):
+    def on_query_auth_request(self, origin, content, room_id, event_id):
         """
         Content is a dict with keys::
             auth_chain (list): A list of events that give the auth chain.
@@ -300,58 +361,41 @@ class FederationServer(FederationBase):
         Returns:
             Deferred: Results in `dict` with the same format as `content`
         """
-        auth_chain = [
-            self.event_from_pdu_json(e)
-            for e in content["auth_chain"]
-        ]
-
-        signed_auth = yield self._check_sigs_and_hash_and_fetch(
-            origin, auth_chain, outlier=True
-        )
+        with (yield self._server_linearizer.queue((origin, room_id))):
+            auth_chain = [
+                self.event_from_pdu_json(e)
+                for e in content["auth_chain"]
+            ]
+
+            signed_auth = yield self._check_sigs_and_hash_and_fetch(
+                origin, auth_chain, outlier=True
+            )
 
-        ret = yield self.handler.on_query_auth(
-            origin,
-            event_id,
-            signed_auth,
-            content.get("rejects", []),
-            content.get("missing", []),
-        )
+            ret = yield self.handler.on_query_auth(
+                origin,
+                event_id,
+                signed_auth,
+                content.get("rejects", []),
+                content.get("missing", []),
+            )
 
-        time_now = self._clock.time_msec()
-        send_content = {
-            "auth_chain": [
-                e.get_pdu_json(time_now)
-                for e in ret["auth_chain"]
-            ],
-            "rejects": ret.get("rejects", []),
-            "missing": ret.get("missing", []),
-        }
+            time_now = self._clock.time_msec()
+            send_content = {
+                "auth_chain": [
+                    e.get_pdu_json(time_now)
+                    for e in ret["auth_chain"]
+                ],
+                "rejects": ret.get("rejects", []),
+                "missing": ret.get("missing", []),
+            }
 
         defer.returnValue(
             (200, send_content)
         )
 
-    @defer.inlineCallbacks
     @log_function
     def on_query_client_keys(self, origin, content):
-        query = []
-        for user_id, device_ids in content.get("device_keys", {}).items():
-            if not device_ids:
-                query.append((user_id, None))
-            else:
-                for device_id in device_ids:
-                    query.append((user_id, device_id))
-
-        results = yield self.store.get_e2e_device_keys(query)
-
-        json_result = {}
-        for user_id, device_keys in results.items():
-            for device_id, json_bytes in device_keys.items():
-                json_result.setdefault(user_id, {})[device_id] = json.loads(
-                    json_bytes
-                )
-
-        defer.returnValue({"device_keys": json_result})
+        return self.on_query_request("client_keys", content)
 
     @defer.inlineCallbacks
     @log_function
@@ -377,11 +421,24 @@ class FederationServer(FederationBase):
     @log_function
     def on_get_missing_events(self, origin, room_id, earliest_events,
                               latest_events, limit, min_depth):
-        missing_events = yield self.handler.on_get_missing_events(
-            origin, room_id, earliest_events, latest_events, limit, min_depth
-        )
+        with (yield self._server_linearizer.queue((origin, room_id))):
+            logger.info(
+                "on_get_missing_events: earliest_events: %r, latest_events: %r,"
+                " limit: %d, min_depth: %d",
+                earliest_events, latest_events, limit, min_depth
+            )
+            missing_events = yield self.handler.on_get_missing_events(
+                origin, room_id, earliest_events, latest_events, limit, min_depth
+            )
 
-        time_now = self._clock.time_msec()
+            if len(missing_events) < 5:
+                logger.info(
+                    "Returning %d events: %r", len(missing_events), missing_events
+                )
+            else:
+                logger.info("Returning %d events", len(missing_events))
+
+            time_now = self._clock.time_msec()
 
         defer.returnValue({
             "events": [ev.get_pdu_json(time_now) for ev in missing_events],
@@ -481,42 +538,59 @@ class FederationServer(FederationBase):
                 pdu.internal_metadata.outlier = True
             elif min_depth and pdu.depth > min_depth:
                 if get_missing and prevs - seen:
-                    latest = yield self.store.get_latest_event_ids_in_room(
-                        pdu.room_id
-                    )
-
-                    # We add the prev events that we have seen to the latest
-                    # list to ensure the remote server doesn't give them to us
-                    latest = set(latest)
-                    latest |= seen
-
-                    missing_events = yield self.get_missing_events(
-                        origin,
-                        pdu.room_id,
-                        earliest_events_ids=list(latest),
-                        latest_events=[pdu],
-                        limit=10,
-                        min_depth=min_depth,
-                    )
-
-                    # We want to sort these by depth so we process them and
-                    # tell clients about them in order.
-                    missing_events.sort(key=lambda x: x.depth)
-
-                    for e in missing_events:
-                        yield self._handle_new_pdu(
-                            origin,
-                            e,
-                            get_missing=False
-                        )
-
-                    have_seen = yield self.store.have_events(
-                        [ev for ev, _ in pdu.prev_events]
-                    )
+                    # If we're missing stuff, ensure we only fetch stuff one
+                    # at a time.
+                    with (yield self._room_pdu_linearizer.queue(pdu.room_id)):
+                        # We recalculate seen, since it may have changed.
+                        have_seen = yield self.store.have_events(prevs)
+                        seen = set(have_seen.keys())
+
+                        if prevs - seen:
+                            latest = yield self.store.get_latest_event_ids_in_room(
+                                pdu.room_id
+                            )
+
+                            # We add the prev events that we have seen to the latest
+                            # list to ensure the remote server doesn't give them to us
+                            latest = set(latest)
+                            latest |= seen
+
+                            logger.info(
+                                "Missing %d events for room %r: %r...",
+                                len(prevs - seen), pdu.room_id, list(prevs - seen)[:5]
+                            )
+
+                            missing_events = yield self.get_missing_events(
+                                origin,
+                                pdu.room_id,
+                                earliest_events_ids=list(latest),
+                                latest_events=[pdu],
+                                limit=10,
+                                min_depth=min_depth,
+                            )
+
+                            # We want to sort these by depth so we process them and
+                            # tell clients about them in order.
+                            missing_events.sort(key=lambda x: x.depth)
+
+                            for e in missing_events:
+                                yield self._handle_new_pdu(
+                                    origin,
+                                    e,
+                                    get_missing=False
+                                )
+
+                            have_seen = yield self.store.have_events(
+                                [ev for ev, _ in pdu.prev_events]
+                            )
 
             prevs = {e_id for e_id, _ in pdu.prev_events}
             seen = set(have_seen.keys())
             if prevs - seen:
+                logger.info(
+                    "Still missing %d events for room %r: %r...",
+                    len(prevs - seen), pdu.room_id, list(prevs - seen)[:5]
+                )
                 fetch_state = True
 
         if fetch_state:
@@ -531,7 +605,7 @@ class FederationServer(FederationBase):
                     origin, pdu.room_id, pdu.event_id,
                 )
             except:
-                logger.warn("Failed to get state for event: %s", pdu.event_id)
+                logger.exception("Failed to get state for event: %s", pdu.event_id)
 
         yield self.handler.on_receive_pdu(
             origin,