summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/3968.bugfix1
-rw-r--r--synapse/handlers/federation.py58
-rw-r--r--tests/test_federation.py106
3 files changed, 47 insertions, 118 deletions
diff --git a/changelog.d/3968.bugfix b/changelog.d/3968.bugfix
new file mode 100644
index 0000000000..18d43cd64e
--- /dev/null
+++ b/changelog.d/3968.bugfix
@@ -0,0 +1 @@
+Fix exceptions when processing incoming events over federation
\ No newline at end of file
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 38bebbf598..d05b63673f 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -106,7 +106,7 @@ class FederationHandler(BaseHandler):
 
         self.hs = hs
 
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastore()  # type: synapse.storage.DataStore
         self.federation_client = hs.get_federation_client()
         self.state_handler = hs.get_state_handler()
         self.server_name = hs.hostname
@@ -323,14 +323,22 @@ class FederationHandler(BaseHandler):
                         affected=pdu.event_id,
                     )
 
-                # Calculate the state of the previous events, and
-                # de-conflict them to find the current state.
-                state_groups = []
+                # Calculate the state after each of the previous events, and
+                # resolve them to find the correct state at the current event.
                 auth_chains = set()
+                event_map = {
+                    event_id: pdu,
+                }
                 try:
                     # Get the state of the events we know about
-                    ours = yield self.store.get_state_groups(room_id, list(seen))
-                    state_groups.append(ours)
+                    ours = yield self.store.get_state_groups_ids(room_id, seen)
+
+                    # state_maps is a list of mappings from (type, state_key) to event_id
+                    # type: list[dict[tuple[str, str], str]]
+                    state_maps = list(ours.values())
+
+                    # we don't need this any more, let's delete it.
+                    del ours
 
                     # Ask the remote server for the states we don't
                     # know about
@@ -350,28 +358,54 @@ class FederationHandler(BaseHandler):
                                 )
                             )
 
+                            # we want the state *after* p; get_state_for_room returns the
+                            # state *before* p.
+                            remote_event = yield self.federation_client.get_pdu(
+                                [origin], p, outlier=True,
+                            )
+
+                            if remote_event is None:
+                                raise Exception(
+                                    "Unable to get missing prev_event %s" % (p, )
+                                )
+
+                            if remote_event.is_state():
+                                remote_state.append(remote_event)
+
                             # XXX hrm I'm not convinced that duplicate events will compare
                             # for equality, so I'm not sure this does what the author
                             # hoped.
                             auth_chains.update(got_auth_chain)
 
-                            state_group = {
+                            remote_state_map = {
                                 (x.type, x.state_key): x.event_id for x in remote_state
                             }
-                            state_groups.append(state_group)
+                            state_maps.append(remote_state_map)
+
+                            for x in remote_state:
+                                event_map[x.event_id] = x
 
                     # Resolve any conflicting state
+                    @defer.inlineCallbacks
                     def fetch(ev_ids):
-                        return self.store.get_events(
-                            ev_ids, get_prev_content=False, check_redacted=False
+                        fetched = yield self.store.get_events(
+                            ev_ids, get_prev_content=False, check_redacted=False,
                         )
+                        # add any events we fetch here to the `event_map` so that we
+                        # can use them to build the state event list below.
+                        event_map.update(fetched)
+                        defer.returnValue(fetched)
 
                     room_version = yield self.store.get_room_version(room_id)
                     state_map = yield resolve_events_with_factory(
-                        room_version, state_groups, {event_id: pdu}, fetch
+                        room_version, state_maps, event_map, fetch,
                     )
 
-                    state = (yield self.store.get_events(state_map.values())).values()
+                    # we need to give _process_received_pdu the actual state events
+                    # rather than event ids, so generate that now.
+                    state = [
+                        event_map[e] for e in six.itervalues(state_map)
+                    ]
                     auth_chain = list(auth_chains)
                 except Exception:
                     logger.warn(
diff --git a/tests/test_federation.py b/tests/test_federation.py
index ff55c7a627..952a0a7b51 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -141,109 +141,3 @@ class MessageAcceptTests(unittest.TestCase):
             self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id
         )
         self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv")
-
-    def test_cant_hide_past_history(self):
-        """
-        If you send a message, you must be able to provide the direct
-        prev_events that said event references.
-        """
-
-        def post_json(destination, path, data, headers=None, timeout=0):
-            if path.startswith("/_matrix/federation/v1/get_missing_events/"):
-                return {
-                    "events": [
-                        {
-                            "room_id": self.room_id,
-                            "sender": "@baduser:test.serv",
-                            "event_id": "three:test.serv",
-                            "depth": 1000,
-                            "origin_server_ts": 1,
-                            "type": "m.room.message",
-                            "origin": "test.serv",
-                            "content": "hewwo?",
-                            "auth_events": [],
-                            "prev_events": [("four:test.serv", {})],
-                        }
-                    ]
-                }
-
-        self.http_client.post_json = post_json
-
-        def get_json(destination, path, args, headers=None):
-            if path.startswith("/_matrix/federation/v1/state_ids/"):
-                d = self.successResultOf(
-                    self.homeserver.datastore.get_state_ids_for_event("one:test.serv")
-                )
-
-                return succeed(
-                    {
-                        "pdu_ids": [
-                            y
-                            for x, y in d.items()
-                            if x == ("m.room.member", "@us:test")
-                        ],
-                        "auth_chain_ids": list(d.values()),
-                    }
-                )
-
-        self.http_client.get_json = get_json
-
-        # Figure out what the most recent event is
-        most_recent = self.successResultOf(
-            maybeDeferred(
-                self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id
-            )
-        )[0]
-
-        # Make a good event
-        good_event = FrozenEvent(
-            {
-                "room_id": self.room_id,
-                "sender": "@baduser:test.serv",
-                "event_id": "one:test.serv",
-                "depth": 1000,
-                "origin_server_ts": 1,
-                "type": "m.room.message",
-                "origin": "test.serv",
-                "content": "hewwo?",
-                "auth_events": [],
-                "prev_events": [(most_recent, {})],
-            }
-        )
-
-        with LoggingContext(request="good_event"):
-            d = self.handler.on_receive_pdu(
-                "test.serv", good_event, sent_to_us_directly=True
-            )
-            self.reactor.advance(1)
-            self.assertEqual(self.successResultOf(d), None)
-
-        bad_event = FrozenEvent(
-            {
-                "room_id": self.room_id,
-                "sender": "@baduser:test.serv",
-                "event_id": "two:test.serv",
-                "depth": 1000,
-                "origin_server_ts": 1,
-                "type": "m.room.message",
-                "origin": "test.serv",
-                "content": "hewwo?",
-                "auth_events": [],
-                "prev_events": [("one:test.serv", {}), ("three:test.serv", {})],
-            }
-        )
-
-        with LoggingContext(request="bad_event"):
-            d = self.handler.on_receive_pdu(
-                "test.serv", bad_event, sent_to_us_directly=True
-            )
-            self.reactor.advance(1)
-
-        extrem = maybeDeferred(
-            self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id
-        )
-        self.assertEqual(self.successResultOf(extrem)[0], "two:test.serv")
-
-        state = self.homeserver.get_state_handler().get_current_state_ids(self.room_id)
-        self.reactor.advance(1)
-        self.assertIn(("m.room.member", "@us:test"), self.successResultOf(state).keys())