summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--UPGRADE.rst4
-rw-r--r--synapse/federation/replication.py33
-rw-r--r--synapse/handlers/federation.py75
-rw-r--r--synapse/storage/__init__.py4
-rw-r--r--synapse/storage/_base.py6
-rw-r--r--synapse/storage/event_federation.py24
-rw-r--r--synapse/util/distributor.py12
-rw-r--r--tests/federation/test_federation.py2
-rw-r--r--tests/test_distributor.py27
9 files changed, 112 insertions, 75 deletions
diff --git a/UPGRADE.rst b/UPGRADE.rst
index a602a9f3eb..9618ad2d57 100644
--- a/UPGRADE.rst
+++ b/UPGRADE.rst
@@ -1,6 +1,10 @@
 Upgrading to v0.6.0
 ===================
 
+To pull in new dependencies, run::
+
+    python setup.py develop --user
+
 This update includes a change to the database schema. To upgrade you first need
 to upgrade the database by running::
 
diff --git a/synapse/federation/replication.py b/synapse/federation/replication.py
index 9f8aadccca..8abf67b1b5 100644
--- a/synapse/federation/replication.py
+++ b/synapse/federation/replication.py
@@ -256,31 +256,35 @@ class ReplicationLayer(object):
 
     @defer.inlineCallbacks
     @log_function
-    def get_state_for_context(self, destination, context, event_id=None):
+    def get_state_for_context(self, destination, context, event_id):
         """Requests all of the `current` state PDUs for a given context from
         a remote home server.
 
         Args:
             destination (str): The remote homeserver to query for the state.
             context (str): The context we're interested in.
+            event_id (str): The id of the event we want the state at.
 
         Returns:
             Deferred: Results in a list of PDUs.
         """
 
-        transaction_data = yield self.transport_layer.get_context_state(
+        result = yield self.transport_layer.get_context_state(
             destination,
             context,
             event_id=event_id,
         )
 
-        transaction = Transaction(**transaction_data)
         pdus = [
+            self.event_from_pdu_json(p, outlier=True) for p in result["pdus"]
+        ]
+
+        auth_chain = [
             self.event_from_pdu_json(p, outlier=True)
-            for p in transaction.pdus
+            for p in result.get("auth_chain", [])
         ]
 
-        defer.returnValue(pdus)
+        defer.returnValue((pdus, auth_chain))
 
     @defer.inlineCallbacks
     @log_function
@@ -383,10 +387,16 @@ class ReplicationLayer(object):
                 context,
                 event_id,
             )
+            auth_chain = yield self.store.get_auth_chain(
+                [pdu.event_id for pdu in pdus]
+            )
         else:
             raise NotImplementedError("Specify an event")
 
-        defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
+        defer.returnValue((200, {
+            "pdus": [pdu.get_pdu_json() for pdu in pdus],
+            "auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],
+        }))
 
     @defer.inlineCallbacks
     @log_function
@@ -562,8 +572,8 @@ class ReplicationLayer(object):
 
         already_seen = (
             existing and (
-                not existing.internal_metadata.outlier
-                or pdu.internal_metadata.outlier
+                not existing.internal_metadata.is_outlier()
+                or pdu.internal_metadata.is_outlier()
             )
         )
         if already_seen:
@@ -573,6 +583,8 @@ class ReplicationLayer(object):
 
         state = None
 
+        auth_chain = []
+
         # We need to make sure we have all the auth events.
         # for e_id, _ in pdu.auth_events:
         #     exists = yield self._get_persisted_pdu(
@@ -604,7 +616,7 @@ class ReplicationLayer(object):
         #             )
 
         # Get missing pdus if necessary.
-        if not pdu.internal_metadata.outlier:
+        if not pdu.internal_metadata.is_outlier():
             # We only backfill backwards to the min depth.
             min_depth = yield self.handler.get_min_depth_for_context(
                 pdu.room_id
@@ -645,7 +657,7 @@ class ReplicationLayer(object):
                     "_handle_new_pdu getting state for %s",
                     pdu.room_id
                 )
-                state = yield self.get_state_for_context(
+                state, auth_chain = yield self.get_state_for_context(
                     origin, pdu.room_id, pdu.event_id,
                 )
 
@@ -655,6 +667,7 @@ class ReplicationLayer(object):
                 pdu,
                 backfilled=backfilled,
                 state=state,
+                auth_chain=auth_chain,
             )
         else:
             ret = None
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 4aec3563ac..e23c5c2195 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -91,11 +91,12 @@ class FederationHandler(BaseHandler):
 
         yield run_on_reactor()
 
-        yield self.replication_layer.send_pdu(event, destinations)
+        self.replication_layer.send_pdu(event, destinations)
 
     @log_function
     @defer.inlineCallbacks
-    def on_receive_pdu(self, origin, pdu, backfilled, state=None):
+    def on_receive_pdu(self, origin, pdu, backfilled, state=None,
+                       auth_chain=None):
         """ Called by the ReplicationLayer when we have a new pdu. We need to
         do auth checks and put it through the StateHandler.
         """
@@ -150,40 +151,41 @@ class FederationHandler(BaseHandler):
         if not is_in_room and not event.internal_metadata.outlier:
             logger.debug("Got event for room we're not in.")
 
-            replication_layer = self.replication_layer
-            auth_chain = yield replication_layer.get_event_auth(
-                origin,
-                context=event.room_id,
-                event_id=event.event_id,
-            )
+            replication = self.replication_layer
+
+            if not state:
+                state, auth_chain = yield replication.get_state_for_context(
+                    origin, context=event.room_id, event_id=event.event_id,
+                )
+
+            if not auth_chain:
+                auth_chain = yield replication.get_event_auth(
+                    origin,
+                    context=event.room_id,
+                    event_id=event.event_id,
+                )
 
             for e in auth_chain:
                 e.internal_metadata.outlier = True
                 try:
-                    yield self._handle_new_event(e, fetch_missing=False)
+                    yield self._handle_new_event(e, fetch_auth_from=origin)
                 except:
                     logger.exception(
-                        "Failed to parse auth event %s",
+                        "Failed to handle auth event %s",
                         e.event_id,
                     )
 
-            if not state:
-                state = yield replication_layer.get_state_for_context(
-                    origin,
-                    context=event.room_id,
-                    event_id=event.event_id,
-                )
-
             current_state = state
 
         if state:
             for e in state:
+                logging.info("A :) %r", e)
                 e.internal_metadata.outlier = True
                 try:
                     yield self._handle_new_event(e)
                 except:
                     logger.exception(
-                        "Failed to parse state event %s",
+                        "Failed to handle state event %s",
                         e.event_id,
                     )
 
@@ -288,7 +290,7 @@ class FederationHandler(BaseHandler):
 
     @defer.inlineCallbacks
     def on_event_auth(self, event_id):
-        auth = yield self.store.get_auth_chain(event_id)
+        auth = yield self.store.get_auth_chain([event_id])
 
         for event in auth:
             event.signatures.update(
@@ -391,10 +393,10 @@ class FederationHandler(BaseHandler):
             for e in auth_chain:
                 e.internal_metadata.outlier = True
                 try:
-                    yield self._handle_new_event(e, fetch_missing=False)
+                    yield self._handle_new_event(e)
                 except:
                     logger.exception(
-                        "Failed to parse auth event %s",
+                        "Failed to handle auth event %s",
                         e.event_id,
                     )
 
@@ -403,12 +405,11 @@ class FederationHandler(BaseHandler):
                 e.internal_metadata.outlier = True
                 try:
                     yield self._handle_new_event(
-                        e,
-                        fetch_missing=True
+                        e, fetch_auth_from=target_host
                     )
                 except:
                     logger.exception(
-                        "Failed to parse state event %s",
+                        "Failed to handle state event %s",
                         e.event_id,
                     )
 
@@ -526,9 +527,12 @@ class FederationHandler(BaseHandler):
             event.signatures,
         )
 
-        yield self.replication_layer.send_pdu(new_pdu, destinations)
+        self.replication_layer.send_pdu(new_pdu, destinations)
 
-        auth_chain = yield self.store.get_auth_chain(event.event_id)
+        state_ids = [e.event_id for e in context.current_state.values()]
+        auth_chain = yield self.store.get_auth_chain(set(
+            [event.event_id] + state_ids
+        ))
 
         defer.returnValue({
             "state": context.current_state.values(),
@@ -678,7 +682,7 @@ class FederationHandler(BaseHandler):
 
     @defer.inlineCallbacks
     def _handle_new_event(self, event, state=None, backfilled=False,
-                          current_state=None, fetch_missing=True):
+                          current_state=None, fetch_auth_from=None):
 
         logger.debug(
             "_handle_new_event: Before annotate: %s, sigs: %s",
@@ -699,11 +703,20 @@ class FederationHandler(BaseHandler):
         known_ids = set(
             [s.event_id for s in context.auth_events.values()]
         )
+
         for e_id, _ in event.auth_events:
             if e_id not in known_ids:
-                e = yield self.store.get_event(
-                    e_id, allow_none=True,
-                )
+                e = yield self.store.get_event(e_id, allow_none=True)
+
+                if not e and fetch_auth_from is not None:
+                    # Grab the auth_chain over federation if we are missing
+                    # auth events.
+                    auth_chain = yield self.replication_layer.get_event_auth(
+                        fetch_auth_from, event.event_id, event.room_id
+                    )
+                    for auth_event in auth_chain:
+                        yield self._handle_new_event(auth_event)
+                    e = yield self.store.get_event(e_id, allow_none=True)
 
                 if not e:
                     # TODO: Do some conflict res to make sure that we're
@@ -713,7 +726,7 @@ class FederationHandler(BaseHandler):
                         event.event_id, e_id, known_ids,
                     )
                     # FIXME: How does raising AuthError work with federation?
-                    raise AuthError(403, "Auth events are stale")
+                    raise AuthError(403, "Cannot find auth event")
 
                 context.auth_events[(e.type, e.state_key)] = e
 
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 60c2d67425..e6bb665932 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -143,9 +143,7 @@ class DataStore(RoomMemberStore, RoomStore,
         elif event.type == EventTypes.Redaction:
             self._store_redaction(txn, event)
 
-        outlier = False
-        if hasattr(event.internal_metadata, "outlier"):
-            outlier = event.internal_metadata.outlier
+        outlier = event.internal_metadata.is_outlier()
 
         event_dict = {
             k: v
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 6dc857c4aa..e0d97f440b 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -488,11 +488,13 @@ class SQLBaseStore(object):
                 ev.unsigned["redacted_because"] = because
 
         if get_prev_content and "replaces_state" in ev.unsigned:
-            ev.unsigned["prev_content"] = self._get_event_txn(
+            prev = self._get_event_txn(
                 txn,
                 ev.unsigned["replaces_state"],
                 get_prev_content=False,
-            ).get_dict()["content"]
+            )
+            if prev:
+                ev.unsigned["prev_content"] = prev.get_dict()["content"]
 
         return ev
 
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index ced066f407..fb2eb21713 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -32,39 +32,33 @@ class EventFederationStore(SQLBaseStore):
     and backfilling from another server respectively.
     """
 
-    def get_auth_chain(self, event_id):
+    def get_auth_chain(self, event_ids):
         return self.runInteraction(
             "get_auth_chain",
             self._get_auth_chain_txn,
-            event_id
+            event_ids
         )
 
-    def _get_auth_chain_txn(self, txn, event_id):
-        results = self._get_auth_chain_ids_txn(txn, event_id)
+    def _get_auth_chain_txn(self, txn, event_ids):
+        results = self._get_auth_chain_ids_txn(txn, event_ids)
 
-        sql = "SELECT * FROM events WHERE event_id = ?"
-        rows = []
-        for ev_id in results:
-            c = txn.execute(sql, (ev_id,))
-            rows.extend(self.cursor_to_dict(c))
+        return self._get_events_txn(txn, results)
 
-        return self._parse_events_txn(txn, rows)
-
-    def get_auth_chain_ids(self, event_id):
+    def get_auth_chain_ids(self, event_ids):
         return self.runInteraction(
             "get_auth_chain_ids",
             self._get_auth_chain_ids_txn,
-            event_id
+            event_ids
         )
 
-    def _get_auth_chain_ids_txn(self, txn, event_id):
+    def _get_auth_chain_ids_txn(self, txn, event_ids):
         results = set()
 
         base_sql = (
             "SELECT auth_id FROM event_auth WHERE %s"
         )
 
-        front = set([event_id])
+        front = set(event_ids)
         while front:
             sql = base_sql % (
                 " OR ".join(["event_id=?"] * len(front)),
diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py
index 701ccdb781..6925ac96b6 100644
--- a/synapse/util/distributor.py
+++ b/synapse/util/distributor.py
@@ -115,10 +115,10 @@ class Signal(object):
                             failure.value,
                             failure.getTracebackObject()))
                     if not self.suppress_failures:
-                        raise failure
+                        failure.raiseException()
                 deferreds.append(d.addErrback(eb))
-
-            result = yield defer.DeferredList(
-                deferreds, fireOnOneErrback=not self.suppress_failures
-            )
-        defer.returnValue(result)
+            results = []
+            for deferred in deferreds:
+                result = yield deferred
+                results.append(result)
+        defer.returnValue(results)
diff --git a/tests/federation/test_federation.py b/tests/federation/test_federation.py
index 79ac1ce10d..3e484cd303 100644
--- a/tests/federation/test_federation.py
+++ b/tests/federation/test_federation.py
@@ -52,6 +52,7 @@ class FederationTestCase(unittest.TestCase):
             "get_received_txn_response",
             "set_received_txn_response",
             "get_destination_retry_timings",
+            "get_auth_chain",
         ])
         self.mock_persistence.get_received_txn_response.return_value = (
             defer.succeed(None)
@@ -59,6 +60,7 @@ class FederationTestCase(unittest.TestCase):
         self.mock_persistence.get_destination_retry_timings.return_value = (
             defer.succeed(DestinationsTable.EntryType("", 0, 0))
         )
+        self.mock_persistence.get_auth_chain.return_value = []
         self.mock_config = Mock()
         self.mock_config.signing_key = [MockKey()]
         self.clock = MockClock()
diff --git a/tests/test_distributor.py b/tests/test_distributor.py
index 39c5b8dff2..6a0095d850 100644
--- a/tests/test_distributor.py
+++ b/tests/test_distributor.py
@@ -13,12 +13,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from tests import unittest
+from . import unittest
 from twisted.internet import defer
 
 from mock import Mock, patch
 
 from synapse.util.distributor import Distributor
+from synapse.util.async import run_on_reactor
 
 
 class DistributorTestCase(unittest.TestCase):
@@ -26,6 +27,7 @@ class DistributorTestCase(unittest.TestCase):
     def setUp(self):
         self.dist = Distributor()
 
+    @defer.inlineCallbacks
     def test_signal_dispatch(self):
         self.dist.declare("alert")
 
@@ -33,10 +35,11 @@ class DistributorTestCase(unittest.TestCase):
         self.dist.observe("alert", observer)
 
         d = self.dist.fire("alert", 1, 2, 3)
-
+        yield d
         self.assertTrue(d.called)
         observer.assert_called_with(1, 2, 3)
 
+    @defer.inlineCallbacks
     def test_signal_dispatch_deferred(self):
         self.dist.declare("whine")
 
@@ -50,8 +53,10 @@ class DistributorTestCase(unittest.TestCase):
         self.assertFalse(d_outer.called)
 
         d_inner.callback(None)
+        yield d_outer
         self.assertTrue(d_outer.called)
 
+    @defer.inlineCallbacks
     def test_signal_catch(self):
         self.dist.declare("alarm")
 
@@ -65,6 +70,7 @@ class DistributorTestCase(unittest.TestCase):
                 spec=["warning"]
         ) as mock_logger:
             d = self.dist.fire("alarm", "Go")
+            yield d
             self.assertTrue(d.called)
 
             observers[0].assert_called_once("Go")
@@ -81,23 +87,28 @@ class DistributorTestCase(unittest.TestCase):
 
         self.dist.declare("whail")
 
-        observer = Mock()
-        observer.return_value = defer.fail(
-            Exception("Oopsie")
-        )
+        class MyException(Exception):
+            pass
+
+        @defer.inlineCallbacks
+        def observer():
+            yield run_on_reactor()
+            raise MyException("Oopsie")
 
         self.dist.observe("whail", observer)
 
         d = self.dist.fire("whail")
 
-        yield self.assertFailure(d, Exception)
+        yield self.assertFailure(d, MyException)
+        self.dist.suppress_failures = True
 
+    @defer.inlineCallbacks
     def test_signal_prereg(self):
         observer = Mock()
         self.dist.observe("flare", observer)
 
         self.dist.declare("flare")
-        self.dist.fire("flare", 4, 5)
+        yield self.dist.fire("flare", 4, 5)
 
         observer.assert_called_with(4, 5)