diff options
-rw-r--r-- | UPGRADE.rst | 4 | ||||
-rw-r--r-- | synapse/federation/replication.py | 33 | ||||
-rw-r--r-- | synapse/handlers/federation.py | 75 | ||||
-rw-r--r-- | synapse/storage/__init__.py | 4 | ||||
-rw-r--r-- | synapse/storage/_base.py | 6 | ||||
-rw-r--r-- | synapse/storage/event_federation.py | 24 | ||||
-rw-r--r-- | synapse/util/distributor.py | 12 | ||||
-rw-r--r-- | tests/federation/test_federation.py | 2 | ||||
-rw-r--r-- | tests/test_distributor.py | 27 |
9 files changed, 112 insertions, 75 deletions
diff --git a/UPGRADE.rst b/UPGRADE.rst index a602a9f3eb..9618ad2d57 100644 --- a/UPGRADE.rst +++ b/UPGRADE.rst @@ -1,6 +1,10 @@ Upgrading to v0.6.0 =================== +To pull in new dependencies, run:: + + python setup.py develop --user + This update includes a change to the database schema. To upgrade you first need to upgrade the database by running:: diff --git a/synapse/federation/replication.py b/synapse/federation/replication.py index 9f8aadccca..8abf67b1b5 100644 --- a/synapse/federation/replication.py +++ b/synapse/federation/replication.py @@ -256,31 +256,35 @@ class ReplicationLayer(object): @defer.inlineCallbacks @log_function - def get_state_for_context(self, destination, context, event_id=None): + def get_state_for_context(self, destination, context, event_id): """Requests all of the `current` state PDUs for a given context from a remote home server. Args: destination (str): The remote homeserver to query for the state. context (str): The context we're interested in. + event_id (str): The id of the event we want the state at. Returns: Deferred: Results in a list of PDUs. """ - transaction_data = yield self.transport_layer.get_context_state( + result = yield self.transport_layer.get_context_state( destination, context, event_id=event_id, ) - transaction = Transaction(**transaction_data) pdus = [ + self.event_from_pdu_json(p, outlier=True) for p in result["pdus"] + ] + + auth_chain = [ self.event_from_pdu_json(p, outlier=True) - for p in transaction.pdus + for p in result.get("auth_chain", []) ] - defer.returnValue(pdus) + defer.returnValue((pdus, auth_chain)) @defer.inlineCallbacks @log_function @@ -383,10 +387,16 @@ class ReplicationLayer(object): context, event_id, ) + auth_chain = yield self.store.get_auth_chain( + [pdu.event_id for pdu in pdus] + ) else: raise NotImplementedError("Specify an event") - defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict())) + defer.returnValue((200, { + "pdus": [pdu.get_pdu_json() for pdu in pdus], + "auth_chain": [pdu.get_pdu_json() for pdu in auth_chain], + })) @defer.inlineCallbacks @log_function @@ -562,8 +572,8 @@ class ReplicationLayer(object): already_seen = ( existing and ( - not existing.internal_metadata.outlier - or pdu.internal_metadata.outlier + not existing.internal_metadata.is_outlier() + or pdu.internal_metadata.is_outlier() ) ) if already_seen: @@ -573,6 +583,8 @@ class ReplicationLayer(object): state = None + auth_chain = [] + # We need to make sure we have all the auth events. # for e_id, _ in pdu.auth_events: # exists = yield self._get_persisted_pdu( @@ -604,7 +616,7 @@ class ReplicationLayer(object): # ) # Get missing pdus if necessary. - if not pdu.internal_metadata.outlier: + if not pdu.internal_metadata.is_outlier(): # We only backfill backwards to the min depth. min_depth = yield self.handler.get_min_depth_for_context( pdu.room_id @@ -645,7 +657,7 @@ class ReplicationLayer(object): "_handle_new_pdu getting state for %s", pdu.room_id ) - state = yield self.get_state_for_context( + state, auth_chain = yield self.get_state_for_context( origin, pdu.room_id, pdu.event_id, ) @@ -655,6 +667,7 @@ class ReplicationLayer(object): pdu, backfilled=backfilled, state=state, + auth_chain=auth_chain, ) else: ret = None diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 4aec3563ac..e23c5c2195 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -91,11 +91,12 @@ class FederationHandler(BaseHandler): yield run_on_reactor() - yield self.replication_layer.send_pdu(event, destinations) + self.replication_layer.send_pdu(event, destinations) @log_function @defer.inlineCallbacks - def on_receive_pdu(self, origin, pdu, backfilled, state=None): + def on_receive_pdu(self, origin, pdu, backfilled, state=None, + auth_chain=None): """ Called by the ReplicationLayer when we have a new pdu. We need to do auth checks and put it through the StateHandler. """ @@ -150,40 +151,41 @@ class FederationHandler(BaseHandler): if not is_in_room and not event.internal_metadata.outlier: logger.debug("Got event for room we're not in.") - replication_layer = self.replication_layer - auth_chain = yield replication_layer.get_event_auth( - origin, - context=event.room_id, - event_id=event.event_id, - ) + replication = self.replication_layer + + if not state: + state, auth_chain = yield replication.get_state_for_context( + origin, context=event.room_id, event_id=event.event_id, + ) + + if not auth_chain: + auth_chain = yield replication.get_event_auth( + origin, + context=event.room_id, + event_id=event.event_id, + ) for e in auth_chain: e.internal_metadata.outlier = True try: - yield self._handle_new_event(e, fetch_missing=False) + yield self._handle_new_event(e, fetch_auth_from=origin) except: logger.exception( - "Failed to parse auth event %s", + "Failed to handle auth event %s", e.event_id, ) - if not state: - state = yield replication_layer.get_state_for_context( - origin, - context=event.room_id, - event_id=event.event_id, - ) - current_state = state if state: for e in state: + logging.info("A :) %r", e) e.internal_metadata.outlier = True try: yield self._handle_new_event(e) except: logger.exception( - "Failed to parse state event %s", + "Failed to handle state event %s", e.event_id, ) @@ -288,7 +290,7 @@ class FederationHandler(BaseHandler): @defer.inlineCallbacks def on_event_auth(self, event_id): - auth = yield self.store.get_auth_chain(event_id) + auth = yield self.store.get_auth_chain([event_id]) for event in auth: event.signatures.update( @@ -391,10 +393,10 @@ class FederationHandler(BaseHandler): for e in auth_chain: e.internal_metadata.outlier = True try: - yield self._handle_new_event(e, fetch_missing=False) + yield self._handle_new_event(e) except: logger.exception( - "Failed to parse auth event %s", + "Failed to handle auth event %s", e.event_id, ) @@ -403,12 +405,11 @@ class FederationHandler(BaseHandler): e.internal_metadata.outlier = True try: yield self._handle_new_event( - e, - fetch_missing=True + e, fetch_auth_from=target_host ) except: logger.exception( - "Failed to parse state event %s", + "Failed to handle state event %s", e.event_id, ) @@ -526,9 +527,12 @@ class FederationHandler(BaseHandler): event.signatures, ) - yield self.replication_layer.send_pdu(new_pdu, destinations) + self.replication_layer.send_pdu(new_pdu, destinations) - auth_chain = yield self.store.get_auth_chain(event.event_id) + state_ids = [e.event_id for e in context.current_state.values()] + auth_chain = yield self.store.get_auth_chain(set( + [event.event_id] + state_ids + )) defer.returnValue({ "state": context.current_state.values(), @@ -678,7 +682,7 @@ class FederationHandler(BaseHandler): @defer.inlineCallbacks def _handle_new_event(self, event, state=None, backfilled=False, - current_state=None, fetch_missing=True): + current_state=None, fetch_auth_from=None): logger.debug( "_handle_new_event: Before annotate: %s, sigs: %s", @@ -699,11 +703,20 @@ class FederationHandler(BaseHandler): known_ids = set( [s.event_id for s in context.auth_events.values()] ) + for e_id, _ in event.auth_events: if e_id not in known_ids: - e = yield self.store.get_event( - e_id, allow_none=True, - ) + e = yield self.store.get_event(e_id, allow_none=True) + + if not e and fetch_auth_from is not None: + # Grab the auth_chain over federation if we are missing + # auth events. + auth_chain = yield self.replication_layer.get_event_auth( + fetch_auth_from, event.event_id, event.room_id + ) + for auth_event in auth_chain: + yield self._handle_new_event(auth_event) + e = yield self.store.get_event(e_id, allow_none=True) if not e: # TODO: Do some conflict res to make sure that we're @@ -713,7 +726,7 @@ class FederationHandler(BaseHandler): event.event_id, e_id, known_ids, ) # FIXME: How does raising AuthError work with federation? - raise AuthError(403, "Auth events are stale") + raise AuthError(403, "Cannot find auth event") context.auth_events[(e.type, e.state_key)] = e diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 60c2d67425..e6bb665932 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -143,9 +143,7 @@ class DataStore(RoomMemberStore, RoomStore, elif event.type == EventTypes.Redaction: self._store_redaction(txn, event) - outlier = False - if hasattr(event.internal_metadata, "outlier"): - outlier = event.internal_metadata.outlier + outlier = event.internal_metadata.is_outlier() event_dict = { k: v diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 6dc857c4aa..e0d97f440b 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -488,11 +488,13 @@ class SQLBaseStore(object): ev.unsigned["redacted_because"] = because if get_prev_content and "replaces_state" in ev.unsigned: - ev.unsigned["prev_content"] = self._get_event_txn( + prev = self._get_event_txn( txn, ev.unsigned["replaces_state"], get_prev_content=False, - ).get_dict()["content"] + ) + if prev: + ev.unsigned["prev_content"] = prev.get_dict()["content"] return ev diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index ced066f407..fb2eb21713 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -32,39 +32,33 @@ class EventFederationStore(SQLBaseStore): and backfilling from another server respectively. """ - def get_auth_chain(self, event_id): + def get_auth_chain(self, event_ids): return self.runInteraction( "get_auth_chain", self._get_auth_chain_txn, - event_id + event_ids ) - def _get_auth_chain_txn(self, txn, event_id): - results = self._get_auth_chain_ids_txn(txn, event_id) + def _get_auth_chain_txn(self, txn, event_ids): + results = self._get_auth_chain_ids_txn(txn, event_ids) - sql = "SELECT * FROM events WHERE event_id = ?" - rows = [] - for ev_id in results: - c = txn.execute(sql, (ev_id,)) - rows.extend(self.cursor_to_dict(c)) + return self._get_events_txn(txn, results) - return self._parse_events_txn(txn, rows) - - def get_auth_chain_ids(self, event_id): + def get_auth_chain_ids(self, event_ids): return self.runInteraction( "get_auth_chain_ids", self._get_auth_chain_ids_txn, - event_id + event_ids ) - def _get_auth_chain_ids_txn(self, txn, event_id): + def _get_auth_chain_ids_txn(self, txn, event_ids): results = set() base_sql = ( "SELECT auth_id FROM event_auth WHERE %s" ) - front = set([event_id]) + front = set(event_ids) while front: sql = base_sql % ( " OR ".join(["event_id=?"] * len(front)), diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py index 701ccdb781..6925ac96b6 100644 --- a/synapse/util/distributor.py +++ b/synapse/util/distributor.py @@ -115,10 +115,10 @@ class Signal(object): failure.value, failure.getTracebackObject())) if not self.suppress_failures: - raise failure + failure.raiseException() deferreds.append(d.addErrback(eb)) - - result = yield defer.DeferredList( - deferreds, fireOnOneErrback=not self.suppress_failures - ) - defer.returnValue(result) + results = [] + for deferred in deferreds: + result = yield deferred + results.append(result) + defer.returnValue(results) diff --git a/tests/federation/test_federation.py b/tests/federation/test_federation.py index 79ac1ce10d..3e484cd303 100644 --- a/tests/federation/test_federation.py +++ b/tests/federation/test_federation.py @@ -52,6 +52,7 @@ class FederationTestCase(unittest.TestCase): "get_received_txn_response", "set_received_txn_response", "get_destination_retry_timings", + "get_auth_chain", ]) self.mock_persistence.get_received_txn_response.return_value = ( defer.succeed(None) @@ -59,6 +60,7 @@ class FederationTestCase(unittest.TestCase): self.mock_persistence.get_destination_retry_timings.return_value = ( defer.succeed(DestinationsTable.EntryType("", 0, 0)) ) + self.mock_persistence.get_auth_chain.return_value = [] self.mock_config = Mock() self.mock_config.signing_key = [MockKey()] self.clock = MockClock() diff --git a/tests/test_distributor.py b/tests/test_distributor.py index 39c5b8dff2..6a0095d850 100644 --- a/tests/test_distributor.py +++ b/tests/test_distributor.py @@ -13,12 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from tests import unittest +from . import unittest from twisted.internet import defer from mock import Mock, patch from synapse.util.distributor import Distributor +from synapse.util.async import run_on_reactor class DistributorTestCase(unittest.TestCase): @@ -26,6 +27,7 @@ class DistributorTestCase(unittest.TestCase): def setUp(self): self.dist = Distributor() + @defer.inlineCallbacks def test_signal_dispatch(self): self.dist.declare("alert") @@ -33,10 +35,11 @@ class DistributorTestCase(unittest.TestCase): self.dist.observe("alert", observer) d = self.dist.fire("alert", 1, 2, 3) - + yield d self.assertTrue(d.called) observer.assert_called_with(1, 2, 3) + @defer.inlineCallbacks def test_signal_dispatch_deferred(self): self.dist.declare("whine") @@ -50,8 +53,10 @@ class DistributorTestCase(unittest.TestCase): self.assertFalse(d_outer.called) d_inner.callback(None) + yield d_outer self.assertTrue(d_outer.called) + @defer.inlineCallbacks def test_signal_catch(self): self.dist.declare("alarm") @@ -65,6 +70,7 @@ class DistributorTestCase(unittest.TestCase): spec=["warning"] ) as mock_logger: d = self.dist.fire("alarm", "Go") + yield d self.assertTrue(d.called) observers[0].assert_called_once("Go") @@ -81,23 +87,28 @@ class DistributorTestCase(unittest.TestCase): self.dist.declare("whail") - observer = Mock() - observer.return_value = defer.fail( - Exception("Oopsie") - ) + class MyException(Exception): + pass + + @defer.inlineCallbacks + def observer(): + yield run_on_reactor() + raise MyException("Oopsie") self.dist.observe("whail", observer) d = self.dist.fire("whail") - yield self.assertFailure(d, Exception) + yield self.assertFailure(d, MyException) + self.dist.suppress_failures = True + @defer.inlineCallbacks def test_signal_prereg(self): observer = Mock() self.dist.observe("flare", observer) self.dist.declare("flare") - self.dist.fire("flare", 4, 5) + yield self.dist.fire("flare", 4, 5) observer.assert_called_with(4, 5) |