summary refs log tree commit diff
path: root/synapse/federation
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation')
-rw-r--r--synapse/federation/federation_client.py113
-rw-r--r--synapse/federation/federation_server.py43
-rw-r--r--synapse/federation/transport/client.py22
-rw-r--r--synapse/federation/transport/server.py16
4 files changed, 169 insertions, 25 deletions
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index b06387051c..65778fd4ee 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -314,6 +314,40 @@ class FederationClient(FederationBase):
             Deferred: Results in a list of PDUs.
         """
 
+        try:
+            # First we try and ask for just the IDs, as thats far quicker if
+            # we have most of the state and auth_chain already.
+            # However, this may 404 if the other side has an old synapse.
+            result = yield self.transport_layer.get_room_state_ids(
+                destination, room_id, event_id=event_id,
+            )
+
+            state_event_ids = result["pdu_ids"]
+            auth_event_ids = result.get("auth_chain_ids", [])
+
+            fetched_events, failed_to_fetch = yield self.get_events(
+                [destination], room_id, set(state_event_ids + auth_event_ids)
+            )
+
+            if failed_to_fetch:
+                logger.warn("Failed to get %r", failed_to_fetch)
+
+            event_map = {
+                ev.event_id: ev for ev in fetched_events
+            }
+
+            pdus = [event_map[e_id] for e_id in state_event_ids]
+            auth_chain = [event_map[e_id] for e_id in auth_event_ids]
+
+            auth_chain.sort(key=lambda e: e.depth)
+
+            defer.returnValue((pdus, auth_chain))
+        except HttpResponseException as e:
+            if e.code == 400 or e.code == 404:
+                logger.info("Failed to use get_room_state_ids API, falling back")
+            else:
+                raise e
+
         result = yield self.transport_layer.get_room_state(
             destination, room_id, event_id=event_id,
         )
@@ -327,12 +361,26 @@ class FederationClient(FederationBase):
             for p in result.get("auth_chain", [])
         ]
 
+        seen_events = yield self.store.get_events([
+            ev.event_id for ev in itertools.chain(pdus, auth_chain)
+        ])
+
         signed_pdus = yield self._check_sigs_and_hash_and_fetch(
-            destination, pdus, outlier=True
+            destination,
+            [p for p in pdus if p.event_id not in seen_events],
+            outlier=True
+        )
+        signed_pdus.extend(
+            seen_events[p.event_id] for p in pdus if p.event_id in seen_events
         )
 
         signed_auth = yield self._check_sigs_and_hash_and_fetch(
-            destination, auth_chain, outlier=True
+            destination,
+            [p for p in auth_chain if p.event_id not in seen_events],
+            outlier=True
+        )
+        signed_auth.extend(
+            seen_events[p.event_id] for p in auth_chain if p.event_id in seen_events
         )
 
         signed_auth.sort(key=lambda e: e.depth)
@@ -340,6 +388,67 @@ class FederationClient(FederationBase):
         defer.returnValue((signed_pdus, signed_auth))
 
     @defer.inlineCallbacks
+    def get_events(self, destinations, room_id, event_ids, return_local=True):
+        """Fetch events from some remote destinations, checking if we already
+        have them.
+
+        Args:
+            destinations (list)
+            room_id (str)
+            event_ids (list)
+            return_local (bool): Whether to include events we already have in
+                the DB in the returned list of events
+
+        Returns:
+            Deferred: A deferred resolving to a 2-tuple where the first is a list of
+            events and the second is a list of event ids that we failed to fetch.
+        """
+        if return_local:
+            seen_events = yield self.store.get_events(event_ids)
+            signed_events = seen_events.values()
+        else:
+            seen_events = yield self.store.have_events(event_ids)
+            signed_events = []
+
+        failed_to_fetch = set()
+
+        missing_events = set(event_ids)
+        for k in seen_events:
+            missing_events.discard(k)
+
+        if not missing_events:
+            defer.returnValue((signed_events, failed_to_fetch))
+
+        def random_server_list():
+            srvs = list(destinations)
+            random.shuffle(srvs)
+            return srvs
+
+        batch_size = 20
+        missing_events = list(missing_events)
+        for i in xrange(0, len(missing_events), batch_size):
+            batch = set(missing_events[i:i + batch_size])
+
+            deferreds = [
+                self.get_pdu(
+                    destinations=random_server_list(),
+                    event_id=e_id,
+                )
+                for e_id in batch
+            ]
+
+            res = yield defer.DeferredList(deferreds, consumeErrors=True)
+            for success, result in res:
+                if success:
+                    signed_events.append(result)
+                    batch.discard(result.event_id)
+
+            # We removed all events we successfully fetched from `batch`
+            failed_to_fetch.update(batch)
+
+        defer.returnValue((signed_events, failed_to_fetch))
+
+    @defer.inlineCallbacks
     @log_function
     def get_event_auth(self, destination, room_id, event_id):
         res = yield self.transport_layer.get_event_auth(
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 612d274bdb..aba19639c7 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -215,6 +215,27 @@ class FederationServer(FederationBase):
         defer.returnValue((200, resp))
 
     @defer.inlineCallbacks
+    def on_state_ids_request(self, origin, room_id, event_id):
+        if not event_id:
+            raise NotImplementedError("Specify an event")
+
+        in_room = yield self.auth.check_host_in_room(room_id, origin)
+        if not in_room:
+            raise AuthError(403, "Host not in room.")
+
+        pdus = yield self.handler.get_state_for_pdu(
+            room_id, event_id,
+        )
+        auth_chain = yield self.store.get_auth_chain(
+            [pdu.event_id for pdu in pdus]
+        )
+
+        defer.returnValue((200, {
+            "pdu_ids": [pdu.event_id for pdu in pdus],
+            "auth_chain_ids": [pdu.event_id for pdu in auth_chain],
+        }))
+
+    @defer.inlineCallbacks
     def _on_context_state_request_compute(self, room_id, event_id):
         pdus = yield self.handler.get_state_for_pdu(
             room_id, event_id,
@@ -372,27 +393,9 @@ class FederationServer(FederationBase):
             (200, send_content)
         )
 
-    @defer.inlineCallbacks
     @log_function
     def on_query_client_keys(self, origin, content):
-        query = []
-        for user_id, device_ids in content.get("device_keys", {}).items():
-            if not device_ids:
-                query.append((user_id, None))
-            else:
-                for device_id in device_ids:
-                    query.append((user_id, device_id))
-
-        results = yield self.store.get_e2e_device_keys(query)
-
-        json_result = {}
-        for user_id, device_keys in results.items():
-            for device_id, json_bytes in device_keys.items():
-                json_result.setdefault(user_id, {})[device_id] = json.loads(
-                    json_bytes
-                )
-
-        defer.returnValue({"device_keys": json_result})
+        return self.on_query_request("client_keys", content)
 
     @defer.inlineCallbacks
     @log_function
@@ -602,7 +605,7 @@ class FederationServer(FederationBase):
                     origin, pdu.room_id, pdu.event_id,
                 )
             except:
-                logger.warn("Failed to get state for event: %s", pdu.event_id)
+                logger.exception("Failed to get state for event: %s", pdu.event_id)
 
         yield self.handler.on_receive_pdu(
             origin,
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index ebb698e278..3d088e43cb 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -55,6 +55,28 @@ class TransportLayerClient(object):
         )
 
     @log_function
+    def get_room_state_ids(self, destination, room_id, event_id):
+        """ Requests all state for a given room from the given server at the
+        given event. Returns the state's event_id's
+
+        Args:
+            destination (str): The host name of the remote home server we want
+                to get the state from.
+            context (str): The name of the context we want the state of
+            event_id (str): The event we want the context at.
+
+        Returns:
+            Deferred: Results in a dict received from the remote homeserver.
+        """
+        logger.debug("get_room_state_ids dest=%s, room=%s",
+                     destination, room_id)
+
+        path = PREFIX + "/state_ids/%s/" % room_id
+        return self.client.get_json(
+            destination, path=path, args={"event_id": event_id},
+        )
+
+    @log_function
     def get_event(self, destination, event_id, timeout=None):
         """ Requests the pdu with give id and origin from the given server.
 
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 26fa88ae84..0bc6e0801d 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -271,6 +271,17 @@ class FederationStateServlet(BaseFederationServlet):
         )
 
 
+class FederationStateIdsServlet(BaseFederationServlet):
+    PATH = "/state_ids/(?P<room_id>[^/]*)/"
+
+    def on_GET(self, origin, content, query, room_id):
+        return self.handler.on_state_ids_request(
+            origin,
+            room_id,
+            query.get("event_id", [None])[0],
+        )
+
+
 class FederationBackfillServlet(BaseFederationServlet):
     PATH = "/backfill/(?P<context>[^/]*)/"
 
@@ -367,10 +378,8 @@ class FederationThirdPartyInviteExchangeServlet(BaseFederationServlet):
 class FederationClientKeysQueryServlet(BaseFederationServlet):
     PATH = "/user/keys/query"
 
-    @defer.inlineCallbacks
     def on_POST(self, origin, content, query):
-        response = yield self.handler.on_query_client_keys(origin, content)
-        defer.returnValue((200, response))
+        return self.handler.on_query_client_keys(origin, content)
 
 
 class FederationClientKeysClaimServlet(BaseFederationServlet):
@@ -538,6 +547,7 @@ SERVLET_CLASSES = (
     FederationPullServlet,
     FederationEventServlet,
     FederationStateServlet,
+    FederationStateIdsServlet,
     FederationBackfillServlet,
     FederationQueryServlet,
     FederationMakeJoinServlet,