-rw-r--r--  synapse/federation/federation_base.py    |   1
-rw-r--r--  synapse/federation/federation_client.py  |  31
-rw-r--r--  synapse/federation/transport/client.py   |   4
-rw-r--r--  synapse/handlers/federation.py            | 152
-rw-r--r--  synapse/http/matrixfederationclient.py    |  13
-rw-r--r--  synapse/storage/event_federation.py       |  62
 6 files changed, 177 insertions(+), 86 deletions(-)
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 5217d91aab..f0430b2cb1 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -80,6 +80,7 @@ class FederationBase(object):
                             destinations=[pdu.origin],
                             event_id=pdu.event_id,
                             outlier=outlier,
+                            timeout=10000,
                         )
 
                         if new_pdu:
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 3a7bc0c9a7..2f2bf3a134 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -22,6 +22,7 @@ from .units import Edu
 from synapse.api.errors import (
     CodeMessageException, HttpResponseException, SynapseError,
 )
+from synapse.util import unwrapFirstError
 from synapse.util.expiringcache import ExpiringCache
 from synapse.util.logutils import log_function
 from synapse.events import FrozenEvent
@@ -167,13 +168,17 @@ class FederationClient(FederationBase):
         for i, pdu in enumerate(pdus):
             pdus[i] = yield self._check_sigs_and_hash(pdu)
 
-            # FIXME: We should handle signature failures more gracefully.
+        # FIXME: We should handle signature failures more gracefully.
+        pdus[:] = yield defer.gatherResults(
+            [self._check_sigs_and_hash(pdu) for pdu in pdus],
+            consumeErrors=True,
+        ).addErrback(unwrapFirstError)
 
         defer.returnValue(pdus)
 
     @defer.inlineCallbacks
     @log_function
-    def get_pdu(self, destinations, event_id, outlier=False):
+    def get_pdu(self, destinations, event_id, outlier=False, timeout=None):
         """Requests the PDU with given origin and ID from the remote home
         servers.
 
@@ -212,7 +217,7 @@ class FederationClient(FederationBase):
 
                 with limiter:
                     transaction_data = yield self.transport_layer.get_event(
-                        destination, event_id
+                        destination, event_id, timeout=timeout,
                     )
 
                     logger.debug("transaction_data %r", transaction_data)
@@ -370,13 +375,17 @@ class FederationClient(FederationBase):
                     for p in content.get("auth_chain", [])
                 ]
 
-                signed_state = yield self._check_sigs_and_hash_and_fetch(
-                    destination, state, outlier=True
-                )
-
-                signed_auth = yield self._check_sigs_and_hash_and_fetch(
-                    destination, auth_chain, outlier=True
-                )
+                signed_state, signed_auth = yield defer.gatherResults(
+                    [
+                        self._check_sigs_and_hash_and_fetch(
+                            destination, state, outlier=True
+                        ),
+                        self._check_sigs_and_hash_and_fetch(
+                            destination, auth_chain, outlier=True
+                        )
+                    ],
+                    consumeErrors=True
+                ).addErrback(unwrapFirstError)
 
                 auth_chain.sort(key=lambda e: e.depth)
 
@@ -518,7 +527,7 @@ class FederationClient(FederationBase):
             # Are we missing any?
 
             seen_events = set(earliest_events_ids)
-            seen_events.update(e.event_id for e in signed_events)
+            seen_events.update(e.event_id for e in signed_events if e)
 
             missing_events = {}
             for e in itertools.chain(latest_events, signed_events):
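
Note on the pattern above: the serial "check one PDU per yield" loop becomes a
single defer.gatherResults call, so the signature/hash checks (and, further
down in this file, the parallel state and auth_chain fetches) run
concurrently. unwrapFirstError, newly imported from synapse.util, unwraps the
defer.FirstError that gatherResults raises so callers see the original
exception. A minimal, self-contained sketch of the pattern (illustrative only;
check_all is not a Synapse function, and the unwrapFirstError below is a rough
stand-in for the real helper):

    from twisted.internet import defer

    def unwrapFirstError(failure):
        # gatherResults wraps the first failure in defer.FirstError; unwrap
        # it so the original failure propagates to callers.
        failure.trap(defer.FirstError)
        return failure.value.subFailure

    @defer.inlineCallbacks
    def check_all(check, pdus):
        # Start every check at once and wait for all of them, instead of
        # yielding on each PDU in turn. consumeErrors=True stops Twisted
        # logging "Unhandled error in Deferred" for the extra failures.
        results = yield defer.gatherResults(
            [check(pdu) for pdu in pdus],
            consumeErrors=True,
        ).addErrback(unwrapFirstError)
        defer.returnValue(results)
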
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 80d03012b7..c2b53b78b2 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -50,7 +50,7 @@ class TransportLayerClient(object):
         )
 
     @log_function
-    def get_event(self, destination, event_id):
+    def get_event(self, destination, event_id, timeout=None):
         """ Requests the pdu with give id and origin from the given server.
 
         Args:
@@ -65,7 +65,7 @@ class TransportLayerClient(object):
                      destination, event_id)
 
         path = PREFIX + "/event/%s/" % (event_id, )
-        return self.client.get_json(destination, path=path)
+        return self.client.get_json(destination, path=path, timeout=timeout)
 
     @log_function
     def backfill(self, destination, room_id, event_tuples, limit):
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index d35d9f603c..46ce3699d7 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -230,27 +230,65 @@ class FederationHandler(BaseHandler):
         if not extremities:
             extremities = yield self.store.get_oldest_events_in_room(room_id)
 
-        pdus = yield self.replication_layer.backfill(
+        events = yield self.replication_layer.backfill(
             dest,
             room_id,
-            limit,
+            limit=limit,
             extremities=extremities,
         )
 
-        events = []
+        event_map = {e.event_id: e for e in events}
 
-        for pdu in pdus:
-            event = pdu
+        event_ids = set(e.event_id for e in events)
 
-            # FIXME (erikj): Not sure this actually works :/
-            context = yield self.state_handler.compute_event_context(event)
+        edges = [
+            ev.event_id
+            for ev in events
+            if set(e_id for e_id, _ in ev.prev_events) - event_ids
+        ]
 
-            events.append((event, context))
+        # For each edge get the current state.
 
-            yield self.store.persist_event(
-                event,
-                context=context,
-                backfilled=True
+        auth_events = {}
+        events_to_state = {}
+        for e_id in edges:
+            state, auth = yield self.replication_layer.get_state_for_room(
+                destination=dest,
+                room_id=room_id,
+                event_id=e_id
+            )
+            auth_events.update({a.event_id: a for a in auth})
+            events_to_state[e_id] = state
+
+        yield defer.gatherResults(
+            [
+                self._handle_new_event(dest, a)
+                for a in auth_events.values()
+            ],
+            consumeErrors=True,
+        ).addErrback(unwrapFirstError)
+
+        yield defer.gatherResults(
+            [
+                self._handle_new_event(
+                    dest, event_map[e_id],
+                    state=events_to_state[e_id],
+                    backfilled=True,
+                )
+                for e_id in events_to_state
+            ],
+            consumeErrors=True
+        ).addErrback(unwrapFirstError)
+
+        events.sort(key=lambda e: e.depth)
+
+        for event in events:
+            if event in events_to_state:
+                continue
+
+            yield self._handle_new_event(
+                dest, event,
+                backfilled=True,
             )
 
         defer.returnValue(events)
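
In the rewritten backfill above, an event is treated as an "edge" of the
returned batch when at least one of its prev_events is not itself in the
batch; state and auth events are fetched from the remote server for each
edge, and the batch is then persisted oldest-first via the depth sort. A
small sketch of the edge computation in isolation (find_edges is
illustrative, not a Synapse helper):

    def find_edges(events):
        # Edge events have at least one prev_event outside the backfilled
        # batch, so their state cannot be derived locally and must be
        # fetched from the remote server before persisting.
        event_ids = set(e.event_id for e in events)
        return [
            ev.event_id
            for ev in events
            if set(e_id for e_id, _ in ev.prev_events) - event_ids
        ]
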
@@ -347,7 +385,7 @@ class FederationHandler(BaseHandler):
                     logger.info(e.message)
                     continue
                 except Exception as e:
-                    logger.warn(
+                    logger.exception(
                         "Failed to backfill from %s because %s",
                         dom, e,
                     )
@@ -517,30 +555,14 @@ class FederationHandler(BaseHandler):
                 # FIXME
                 pass
 
-            for e in auth_chain:
-                e.internal_metadata.outlier = True
-
-                if e.event_id == event.event_id:
-                    continue
-
-                try:
-                    auth_ids = [e_id for e_id, _ in e.auth_events]
-                    auth = {
-                        (e.type, e.state_key): e for e in auth_chain
-                        if e.event_id in auth_ids
-                    }
-                    yield self._handle_new_event(
-                        origin, e, auth_events=auth
-                    )
-                except:
-                    logger.exception(
-                        "Failed to handle auth event %s",
-                        e.event_id,
-                    )
+            yield self._handle_auth_events(
+                origin, [e for e in auth_chain if e.event_id != event.event_id]
+            )
 
-            for e in state:
+            @defer.inlineCallbacks
+            def handle_state(e):
                 if e.event_id == event.event_id:
-                    continue
+                    return
 
                 e.internal_metadata.outlier = True
                 try:
@@ -558,6 +580,8 @@ class FederationHandler(BaseHandler):
                         e.event_id,
                     )
 
+            yield defer.DeferredList([handle_state(e) for e in state])
+
             auth_ids = [e_id for e_id, _ in event.auth_events]
             auth_events = {
                 (e.type, e.state_key): e for e in auth_chain
@@ -893,9 +917,12 @@ class FederationHandler(BaseHandler):
         # This is a hack to fix some old rooms where the initial join event
         # didn't reference the create event in its auth events.
         if event.type == EventTypes.Member and not event.auth_events:
-            if len(event.prev_events) == 1:
-                c = yield self.store.get_event(event.prev_events[0][0])
-                if c.type == EventTypes.Create:
+            if len(event.prev_events) == 1 and event.depth < 5:
+                c = yield self.store.get_event(
+                    event.prev_events[0][0],
+                    allow_none=True,
+                )
+                if c and c.type == EventTypes.Create:
                     auth_events[(c.type, c.state_key)] = c
 
         try:
@@ -1314,3 +1341,52 @@ class FederationHandler(BaseHandler):
             },
             "missing": [e.event_id for e in missing_locals],
         })
+
+    @defer.inlineCallbacks
+    def _handle_auth_events(self, origin, auth_events):
+        auth_ids_to_deferred = {}
+
+        def process_auth_ev(ev):
+            auth_ids = [e_id for e_id, _ in ev.auth_events]
+
+            prev_ds = [
+                auth_ids_to_deferred[i]
+                for i in auth_ids
+                if i in auth_ids_to_deferred
+            ]
+
+            d = defer.Deferred()
+
+            auth_ids_to_deferred[ev.event_id] = d
+
+            @defer.inlineCallbacks
+            def f(*_):
+                ev.internal_metadata.outlier = True
+
+                try:
+                    auth = {
+                        (e.type, e.state_key): e for e in auth_events
+                        if e.event_id in auth_ids
+                    }
+
+                    yield self._handle_new_event(
+                        origin, ev, auth_events=auth
+                    )
+                except:
+                    logger.exception(
+                        "Failed to handle auth event %s",
+                        ev.event_id,
+                    )
+
+                d.callback(None)
+
+            if prev_ds:
+                dx = defer.DeferredList(prev_ds)
+                dx.addBoth(f)
+            else:
+                f()
+
+        for e in auth_events:
+            process_auth_ev(e)
+
+        yield defer.DeferredList(auth_ids_to_deferred.values())
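
The new _handle_auth_events replaces the old serial loop: each auth event gets
its own Deferred, and an event's handler only runs once the Deferreds of the
auth events it references have fired, so independent chains are handled in
parallel while dependants still wait for their ancestors. A simplified sketch
of that dependency-ordering pattern (names here are illustrative, not Synapse
APIs; handle is assumed to return a Deferred):

    from twisted.internet import defer

    def process_in_dependency_order(events, deps_of, handle):
        # done maps event_id -> Deferred that fires once that event has
        # been handled, successfully or not.
        done = {}

        def schedule(ev):
            prev = [done[i] for i in deps_of(ev) if i in done]
            d = defer.Deferred()
            done[ev.event_id] = d

            def run(*_):
                # Fire d regardless of outcome so dependants are never
                # blocked by a failed event, mirroring the bare except:
                # in _handle_auth_events above.
                handle(ev).addBoth(lambda _: d.callback(None))

            if prev:
                defer.DeferredList(prev).addBoth(run)
            else:
                run()

        for ev in events:
            schedule(ev)

        # Fires once every scheduled event has been handled.
        return defer.DeferredList(done.values())
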
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index c99d237c73..312bbcc6b8 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -110,7 +110,8 @@ class MatrixFederationHttpClient(object):
     @defer.inlineCallbacks
     def _create_request(self, destination, method, path_bytes,
                         body_callback, headers_dict={}, param_bytes=b"",
-                        query_bytes=b"", retry_on_dns_fail=True):
+                        query_bytes=b"", retry_on_dns_fail=True,
+                        timeout=None):
         """ Creates and sends a request to the given url
         """
         headers_dict[b"User-Agent"] = [self.version_string]
@@ -158,7 +159,7 @@ class MatrixFederationHttpClient(object):
 
                 response = yield self.clock.time_bound_deferred(
                     request_deferred,
-                    time_out=60,
+                    time_out=timeout/1000. if timeout else 60,
                 )
 
                 logger.debug("Got response to %s", method)
@@ -181,7 +182,7 @@ class MatrixFederationHttpClient(object):
                     _flatten_response_never_received(e),
                 )
 
-                if retries_left:
+                if retries_left and not timeout:
                     yield sleep(2 ** (5 - retries_left))
                     retries_left -= 1
                 else:
@@ -334,7 +335,8 @@ class MatrixFederationHttpClient(object):
         defer.returnValue(json.loads(body))
 
     @defer.inlineCallbacks
-    def get_json(self, destination, path, args={}, retry_on_dns_fail=True):
+    def get_json(self, destination, path, args={}, retry_on_dns_fail=True,
+                 timeout=None):
         """ GETs some json from the given host homeserver and path
 
         Args:
@@ -370,7 +372,8 @@ class MatrixFederationHttpClient(object):
             path.encode("ascii"),
             query_bytes=query_bytes,
             body_callback=body_callback,
-            retry_on_dns_fail=retry_on_dns_fail
+            retry_on_dns_fail=retry_on_dns_fail,
+            timeout=timeout,
         )
 
         if 200 <= response.code < 300:
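
The timeout threaded through get_json and _create_request is in milliseconds
(get_pdu's caller in federation_base.py passes timeout=10000), while
clock.time_bound_deferred takes seconds, hence the timeout/1000. conversion
with a 60-second default; an explicit timeout also disables the retry loop so
the caller's deadline is not stretched by the backoff sleeps. A tiny
illustrative sketch of that decision (request_timing is not part of the
client):

    def request_timing(timeout_ms=None):
        # Returns (time_out in seconds for time_bound_deferred, whether the
        # retry/backoff loop may run). An explicit timeout disables retries.
        if timeout_ms:
            return timeout_ms / 1000., False
        return 60, True
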
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index 5d4b7843f3..80eff8e6f2 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -19,6 +19,7 @@ from ._base import SQLBaseStore, cached
 from syutil.base64util import encode_base64
 
 import logging
+from Queue import PriorityQueue
 
 
 logger = logging.getLogger(__name__)
@@ -330,12 +331,13 @@ class EventFederationStore(SQLBaseStore):
                 " WHERE event_id = ? AND room_id = ?"
                 " )"
                 " AND NOT EXISTS ("
-                " SELECT 1 FROM events WHERE event_id = ? AND room_id = ?"
+                " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? "
+                " AND outlier = ?"
                 " )"
             )
 
             txn.executemany(query, [
-                (e_id, room_id, e_id, room_id, e_id, room_id, )
+                (e_id, room_id, e_id, room_id, e_id, room_id, False)
                 for e_id, _ in prev_events
             ])
 
@@ -370,43 +372,43 @@ class EventFederationStore(SQLBaseStore):
             room_id, repr(event_list), limit
         )
 
-        event_results = event_list
+        event_results = set(event_list)
 
-        front = event_list
+        # We want to make sure that we do a breadth-first, "depth" ordered
+        # search.
 
         query = (
-            "SELECT prev_event_id FROM event_edges "
-            "WHERE room_id = ? AND event_id = ? "
-            "LIMIT ?"
+            "SELECT depth, prev_event_id FROM event_edges"
+            " INNER JOIN events"
+            " ON prev_event_id = events.event_id"
+            " AND event_edges.room_id = events.room_id"
+            " WHERE event_edges.room_id = ? AND event_edges.event_id = ?"
+            " LIMIT ?"
         )
 
-        # We iterate through all event_ids in `front` to select their previous
-        # events. These are dumped in `new_front`.
-        # We continue until we reach the limit *or* new_front is empty (i.e.,
-        # we've run out of things to select
-        while front and len(event_results) < limit:
+        queue = PriorityQueue()
 
-            new_front = []
-            for event_id in front:
-                logger.debug(
-                    "_backfill_interaction: id=%s",
-                    event_id
-                )
+        for event_id in event_list:
+            txn.execute(
+                query,
+                (room_id, event_id, limit - len(event_results))
+            )
 
-                txn.execute(
-                    query,
-                    (room_id, event_id, limit - len(event_results))
-                )
+            for row in txn.fetchall():
+                queue.put(row)
 
-                for row in txn.fetchall():
-                    logger.debug(
-                        "_backfill_interaction: got id=%s",
-                        *row
-                    )
-                    new_front.append(row[0])
+        while not queue.empty() and len(event_results) < limit:
+            _, event_id = queue.get_nowait()
 
-            front = new_front
-            event_results += new_front
+            event_results.add(event_id)
+
+            txn.execute(
+                query,
+                (room_id, event_id, limit - len(event_results))
+            )
+
+            for row in txn.fetchall():
+                queue.put(row)
 
         return event_results
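
The rewritten backfill walk in event_federation.py is breadth-first in depth
order: each row from the joined query carries the prev event's depth, rows go
onto a PriorityQueue, and the shallowest known event is expanded next until
the limit is reached. A simplified sketch of that walk, with fetch_prev
standing in for the SQL query above (illustrative only):

    from Queue import PriorityQueue  # Python 2, as in the import above

    def depth_ordered_backfill(event_list, fetch_prev, limit):
        # fetch_prev(event_id) yields (depth, prev_event_id) rows for the
        # given event, mirroring the joined SELECT above.
        event_results = set(event_list)
        queue = PriorityQueue()

        # Seed the queue with the prev events of the requested extremities.
        for event_id in event_list:
            for row in fetch_prev(event_id):
                queue.put(row)

        # Always expand the lowest-depth event next, so the walk proceeds
        # breadth-first by depth rather than down whichever branch happened
        # to be queried first.
        while not queue.empty() and len(event_results) < limit:
            _, event_id = queue.get_nowait()
            event_results.add(event_id)
            for row in fetch_prev(event_id):
                queue.put(row)

        return event_results
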