summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/handlers/federation_event.py43
-rw-r--r--synapse/storage/databases/main/event_federation.py2
-rw-r--r--tests/handlers/test_federation_event.py129
-rw-r--r--tests/test_utils/event_injection.py4
4 files changed, 143 insertions, 35 deletions
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index d4dfdc9929..118aaca01d 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -14,6 +14,7 @@
 
 import collections
 import itertools
+import json
 import logging
 from http import HTTPStatus
 from typing import (
@@ -659,7 +660,7 @@ class FederationEventHandler:
             # thrashing.
             reverse_chronological_events = events
             # `[::-1]` is just syntax to reverse the list and give us a copy
-            chronological_events = reverse_chronological_events[::-1]
+            # chronological_events = reverse_chronological_events[::-1]
 
             logger.info(
                 "backfill assumed reverse_chronological_events=%s",
@@ -715,8 +716,8 @@ class FederationEventHandler:
                 # Expecting to persist in chronological order here (oldest ->
                 # newest) so that events are persisted before they're referenced
                 # as a `prev_event`.
-                chronological_events,
-                # reverse_chronological_events,
+                # chronological_events,
+                reverse_chronological_events,
                 backfilled=True,
             )
 
@@ -869,17 +870,20 @@ class FederationEventHandler:
 
         logger.info(
             "backfill sorted_events=%s",
-            [
-                "event_id=%s,depth=%d,body=%s(%s),prevs=%s\n"
-                % (
-                    event.event_id,
-                    event.depth,
-                    event.content.get("body", event.type),
-                    getattr(event, "state_key", None),
-                    event.prev_event_ids(),
-                )
-                for event in sorted_events
-            ],
+            json.dumps(
+                [
+                    "event_id=%s,depth=%d,body=%s(%s),prevs=%s\n"
+                    % (
+                        event.event_id,
+                        event.depth,
+                        event.content.get("body", event.type),
+                        getattr(event, "state_key", None),
+                        event.prev_event_ids(),
+                    )
+                    for event in sorted_events
+                ],
+                indent=4,
+            ),
         )
 
         for ev in sorted_events:
@@ -1160,11 +1164,18 @@ class FederationEventHandler:
             destination, room_id, event_id=event_id
         )
 
-        logger.debug(
-            "state_ids returned %i state events, %i auth events",
+        logger.info(
+            "_get_state_ids_after_missing_prev_event(event_id=%s): state_ids returned %i state events, %i auth events",
+            event_id,
             len(state_event_ids),
             len(auth_event_ids),
         )
+        logger.info(
+            "_get_state_ids_after_missing_prev_event(event_id=%s): state_event_ids=%s auth_event_ids=%s",
+            event_id,
+            state_event_ids,
+            auth_event_ids,
+        )
 
         # Start by checking events we already have in the DB
         desired_events = set(state_event_ids)
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 6df15c1c8f..b8c26adb7b 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -972,7 +972,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
                     1000 * BACKFILL_EVENT_BACKOFF_UPPER_BOUND_SECONDS,
                 ),
             )
-            return cast(List[Tuple[str, int]], txn.fetchall())
+            return cast(List[Tuple[str, int, int]], txn.fetchall())
 
         return await self.db_pool.runInteraction(
             "get_insertion_event_backward_extremities_in_room",
diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py
index 34fb883e6b..f16ac9d1b1 100644
--- a/tests/handlers/test_federation_event.py
+++ b/tests/handlers/test_federation_event.py
@@ -11,9 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import pprint
 import json
-from typing import Optional
+from typing import Dict, List, Optional, Tuple
 from unittest import mock
 from unittest.mock import Mock, patch
 
@@ -871,19 +870,96 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
             )
 
     def test_process_pulled_events_asdf(self) -> None:
+        main_store = self.hs.get_datastores().main
+        state_storage_controller = self.hs.get_storage_controllers().state
+
         def _debug_event_string(event: EventBase) -> str:
             debug_body = event.content.get("body", event.type)
             maybe_state_key = getattr(event, "state_key", None)
             return f"event_id={event.event_id},depth={event.depth},body={debug_body}({maybe_state_key}),prevs={event.prev_event_ids()}"
 
-        OTHER_USER = f"@user:{self.OTHER_SERVER_NAME}"
-        main_store = self.hs.get_datastores().main
+        known_event_dict: Dict[str, Tuple[EventBase, List[EventBase]]] = {}
+
+        def _add_to_known_event_list(
+            event: EventBase, state_events: Optional[List[EventBase]] = None
+        ) -> None:
+            if state_events is None:
+                state_map = self.get_success(
+                    state_storage_controller.get_state_for_event(event.event_id)
+                )
+                state_events = list(state_map.values())
+
+            known_event_dict[event.event_id] = (event, state_events)
+
+        async def get_room_state_ids(
+            destination: str, room_id: str, event_id: str
+        ) -> JsonDict:
+            self.assertEqual(destination, self.OTHER_SERVER_NAME)
+            known_event_info = known_event_dict.get(event_id)
+            if known_event_info is None:
+                self.fail(f"Event ({event_id}) not part of our known events list")
+
+            known_event, known_event_state_list = known_event_info
+            logger.info(
+                "stubbed get_room_state_ids destination=%s event_id=%s auth_event_ids=%s",
+                destination,
+                event_id,
+                known_event.auth_event_ids(),
+            )
+
+            # self.assertEqual(event_id, missing_event.event_id)
+            return {
+                "pdu_ids": [
+                    state_event.event_id for state_event in known_event_state_list
+                ],
+                "auth_chain_ids": known_event.auth_event_ids(),
+            }
+
+        async def get_room_state(
+            room_version: RoomVersion, destination: str, room_id: str, event_id: str
+        ) -> StateRequestResponse:
+            self.assertEqual(destination, self.OTHER_SERVER_NAME)
+            known_event_info = known_event_dict.get(event_id)
+            if known_event_info is None:
+                self.fail(f"Event ({event_id}) not part of our known events list")
+
+            known_event, known_event_state_list = known_event_info
+            logger.info(
+                "stubbed get_room_state destination=%s event_id=%s auth_event_ids=%s",
+                destination,
+                event_id,
+                known_event.auth_event_ids(),
+            )
+
+            auth_event_ids = known_event.auth_event_ids()
+            auth_events = []
+            for auth_event_id in auth_event_ids:
+                known_event_info = known_event_dict.get(event_id)
+                if known_event_info is None:
+                    self.fail(
+                        f"Auth event ({auth_event_id}) is not part of our known events list"
+                    )
+                known_auth_event, _ = known_event_info
+                auth_events.append(known_auth_event)
+
+            return StateRequestResponse(
+                state=known_event_state_list,
+                auth_events=auth_events,
+            )
+
+        self.mock_federation_transport_client.get_room_state_ids.side_effect = (
+            get_room_state_ids
+        )
+        self.mock_federation_transport_client.get_room_state.side_effect = (
+            get_room_state
+        )
 
         # create the room
         room_creator = self.appservice.sender
         room_id = self.helper.create_room_as(
             room_creator=self.appservice.sender, tok=self.appservice.token
         )
+        room_version = self.get_success(main_store.get_room_version(room_id))
 
         user_alice = self.register_user("alice", "pass")
         alice_membership_event = self.get_success(
@@ -899,6 +975,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
                 content={"body": "eventBefore0", "msgtype": "m.text"},
             )
         )
+        _add_to_known_event_list(event_before0)
         event_before1 = self.get_success(
             inject_event(
                 self.hs,
@@ -908,6 +985,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
                 content={"body": "eventBefore1", "msgtype": "m.text"},
             )
         )
+        _add_to_known_event_list(event_before1)
 
         event_after0 = self.get_success(
             inject_event(
@@ -918,6 +996,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
                 content={"body": "eventAfter0", "msgtype": "m.text"},
             )
         )
+        _add_to_known_event_list(event_after0)
         event_after1 = self.get_success(
             inject_event(
                 self.hs,
@@ -927,8 +1006,8 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
                 content={"body": "eventAfter1", "msgtype": "m.text"},
             )
         )
+        _add_to_known_event_list(event_after1)
 
-        state_storage_controller = self.hs.get_storage_controllers().state
         state_map = self.get_success(
             state_storage_controller.get_state_for_event(event_before1.event_id)
         )
@@ -940,13 +1019,17 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
         assert pl_event is not None
         assert as_membership_event is not None
 
+        for state_event in state_map.values():
+            _add_to_known_event_list(state_event)
+
         historical_auth_event_ids = [
             room_create_event.event_id,
             pl_event.event_id,
             as_membership_event.event_id,
         ]
+        historical_state_events = list(state_map.values())
         historical_state_event_ids = [
-            state_event.event_id for state_event in list(state_map.values())
+            state_event.event_id for state_event in historical_state_events
         ]
 
         inherited_depth = event_after0.depth
@@ -969,6 +1052,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
                 depth=inherited_depth,
             )
         )
+        _add_to_known_event_list(insertion_event, historical_state_events)
         historical_message_event, _ = self.get_success(
             create_event(
                 self.hs,
@@ -981,6 +1065,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
                 depth=inherited_depth,
             )
         )
+        _add_to_known_event_list(historical_message_event, historical_state_events)
         batch_event, _ = self.get_success(
             create_event(
                 self.hs,
@@ -996,7 +1081,8 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
                 depth=inherited_depth,
             )
         )
-        base_insertion_event, _ = self.get_success(
+        _add_to_known_event_list(batch_event, historical_state_events)
+        base_insertion_event, base_insertion_event_context = self.get_success(
             create_event(
                 self.hs,
                 room_id=room_id,
@@ -1012,9 +1098,10 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
                 depth=inherited_depth,
             )
         )
+        _add_to_known_event_list(base_insertion_event, historical_state_events)
 
         # Chronological
-        # pulled_events = [
+        # pulled_events: List[EventBase] = [
         #     # Beginning of room (oldest messages)
         #     # *list(state_map.values()),
         #     room_create_event,
@@ -1035,7 +1122,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
         # ]
 
         # The random pattern that may make it be expected
-        pulled_events = [
+        pulled_events: List[EventBase] = [
             # Beginning of room (oldest messages)
             # *list(state_map.values()),
             room_create_event,
@@ -1066,10 +1153,20 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
             ),
         )
 
+        for event, _ in known_event_dict.values():
+            if event.internal_metadata.outlier:
+                self.fail("Our pristine events should not be marked as an outlier")
+
         self.get_success(
             self.hs.get_federation_event_handler()._process_pulled_events(
                 self.OTHER_SERVER_NAME,
-                pulled_events,
+                [
+                    # Make copies of events since Synapse modifies the
+                    # internal_metadata in place and we want to keep our
+                    # pristine copies
+                    make_event_from_dict(pulled_event.get_pdu_json(), room_version)
+                    for pulled_event in pulled_events
+                ],
                 backfilled=True,
             )
         )
@@ -1110,15 +1207,15 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
             # Latest in the room (newest messages)
         ]
 
-        event_id_diff = set([event.event_id for event in expected_event_order]) - set(
-            [event.event_id for event in actual_events_in_room_chronological]
-        )
+        event_id_diff = {event.event_id for event in expected_event_order} - {
+            event.event_id for event in actual_events_in_room_chronological
+        }
         event_diff_ordered = [
             event for event in expected_event_order if event.event_id in event_id_diff
         ]
-        event_id_extra = set(
-            [event.event_id for event in actual_events_in_room_chronological]
-        ) - set([event.event_id for event in expected_event_order])
+        event_id_extra = {
+            event.event_id for event in actual_events_in_room_chronological
+        } - {event.event_id for event in expected_event_order}
         event_extra_ordered = [
             event
             for event in actual_events_in_room_chronological
diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py
index b2cb8d0be7..497ee188ca 100644
--- a/tests/test_utils/event_injection.py
+++ b/tests/test_utils/event_injection.py
@@ -12,7 +12,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from optparse import Option
 from typing import List, Optional, Tuple
 
 import synapse.server
@@ -104,7 +103,8 @@ async def create_event(
     )
     event, context = await hs.get_event_creation_handler().create_new_client_event(
         builder,
-        allow_no_prev_events=allow_no_prev_events,
+        # Why does this need another default to pass: `Argument "allow_no_prev_events" to "create_new_client_event" of "EventCreationHandler" has incompatible type "Optional[bool]"; expected "bool"`
+        allow_no_prev_events=allow_no_prev_events or False,
         prev_event_ids=prev_event_ids,
         auth_event_ids=auth_event_ids,
         state_event_ids=state_event_ids,