summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/federation/test_federation_catch_up.py87
-rw-r--r--tests/rest/client/test_third_party_rules.py19
-rw-r--r--tests/test_utils/event_injection.py31
-rw-r--r--tests/test_visibility.py40
4 files changed, 162 insertions, 15 deletions
diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py
index 6381583c24..391ae51707 100644
--- a/tests/federation/test_federation_catch_up.py
+++ b/tests/federation/test_federation_catch_up.py
@@ -1,4 +1,5 @@
-from typing import Callable, List, Optional, Tuple
+from typing import Callable, Collection, List, Optional, Tuple
+from unittest import mock
 from unittest.mock import Mock
 
 from twisted.test.proto_helpers import MemoryReactor
@@ -500,3 +501,87 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
         self.assertEqual(len(sent_pdus), 1)
         self.assertEqual(sent_pdus[0].event_id, event_2.event_id)
         self.assertFalse(per_dest_queue._catching_up)
+
+    def test_catch_up_is_not_blocked_by_remote_event_in_partial_state_room(
+        self,
+    ) -> None:
+        """Detects (part of?) https://github.com/matrix-org/synapse/issues/15220."""
+        # ARRANGE:
+        # - a local user (u1)
+        # - a room which contains u1 and two remote users, @u2:host2 and @u3:other
+        # - events in that room such that
+        #   - history visibility is restricted
+        #   - u1 sent message events e1 and e2
+        #   - afterwards, u3 sent a remote event e3
+        # - catchup to begin for host2; last successfully sent event was e1
+        per_dest_queue, sent_pdus = self.make_fake_destination_queue()
+
+        self.register_user("u1", "you the one")
+        u1_token = self.login("u1", "you the one")
+        room = self.helper.create_room_as("u1", tok=u1_token)
+        self.helper.send_state(
+            room_id=room,
+            event_type="m.room.history_visibility",
+            body={"history_visibility": "joined"},
+            tok=u1_token,
+        )
+        self.get_success(
+            event_injection.inject_member_event(self.hs, room, "@u2:host2", "join")
+        )
+        self.get_success(
+            event_injection.inject_member_event(self.hs, room, "@u3:other", "join")
+        )
+
+        # create some events
+        event_id_1 = self.helper.send(room, "hello", tok=u1_token)["event_id"]
+        event_id_2 = self.helper.send(room, "world", tok=u1_token)["event_id"]
+        # pretend that u3 changes their displayname
+        event_id_3 = self.get_success(
+            event_injection.inject_member_event(self.hs, room, "@u3:other", "join")
+        ).event_id
+
+        # destination_rooms should already be populated, but let us pretend that we already
+        # sent (successfully) up to and including event id 1
+        event_1 = self.get_success(self.hs.get_datastores().main.get_event(event_id_1))
+        assert event_1.internal_metadata.stream_ordering is not None
+        self.get_success(
+            self.hs.get_datastores().main.set_destination_last_successful_stream_ordering(
+                "host2", event_1.internal_metadata.stream_ordering
+            )
+        )
+
+        # also fetch event 2 so we can compare its stream ordering to the sender's
+        # last_successful_stream_ordering later
+        event_2 = self.get_success(self.hs.get_datastores().main.get_event(event_id_2))
+
+        # Mock event 3 as having partial state
+        self.get_success(
+            event_injection.mark_event_as_partial_state(self.hs, event_id_3, room)
+        )
+
+        # Fail the test if we block on full state for event 3.
+        async def mock_await_full_state(event_ids: Collection[str]) -> None:
+            if event_id_3 in event_ids:
+                raise AssertionError("Tried to await full state for event_id_3")
+
+        # ACT
+        with mock.patch.object(
+            self.hs.get_storage_controllers().state._partial_state_events_tracker,
+            "await_full_state",
+            mock_await_full_state,
+        ):
+            self.get_success(per_dest_queue._catch_up_transmission_loop())
+
+        # ASSERT
+        # We should have:
+        # - not sent event 3: it's not ours, and the room is partial stated
+        # - fallen back to sending event 2: it's the most recent event in the room
+        #   we tried to send to host2
+        # - completed catch-up
+        self.assertEqual(len(sent_pdus), 1)
+        self.assertEqual(sent_pdus[0].event_id, event_id_2)
+        self.assertFalse(per_dest_queue._catching_up)
+        self.assertEqual(
+            per_dest_queue._last_successful_stream_ordering,
+            event_2.internal_metadata.stream_ordering,
+        )
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index 3b99513707..753ecc8d16 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -941,18 +941,16 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         just before associating and removing a 3PID to/from an account.
         """
         # Pretend to be a Synapse module and register both callbacks as mocks.
-        third_party_rules = self.hs.get_third_party_event_rules()
         on_add_user_third_party_identifier_callback_mock = Mock(
             return_value=make_awaitable(None)
         )
         on_remove_user_third_party_identifier_callback_mock = Mock(
             return_value=make_awaitable(None)
         )
-        third_party_rules._on_threepid_bind_callbacks.append(
-            on_add_user_third_party_identifier_callback_mock
-        )
-        third_party_rules._on_threepid_bind_callbacks.append(
-            on_remove_user_third_party_identifier_callback_mock
+        third_party_rules = self.hs.get_third_party_event_rules()
+        third_party_rules.register_third_party_rules_callbacks(
+            on_add_user_third_party_identifier=on_add_user_third_party_identifier_callback_mock,
+            on_remove_user_third_party_identifier=on_remove_user_third_party_identifier_callback_mock,
         )
 
         # Register an admin user.
@@ -1008,12 +1006,12 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         when a user is deactivated and their third-party ID associations are deleted.
         """
         # Pretend to be a Synapse module and register both callbacks as mocks.
-        third_party_rules = self.hs.get_third_party_event_rules()
         on_remove_user_third_party_identifier_callback_mock = Mock(
             return_value=make_awaitable(None)
         )
-        third_party_rules._on_threepid_bind_callbacks.append(
-            on_remove_user_third_party_identifier_callback_mock
+        third_party_rules = self.hs.get_third_party_event_rules()
+        third_party_rules.register_third_party_rules_callbacks(
+            on_remove_user_third_party_identifier=on_remove_user_third_party_identifier_callback_mock,
         )
 
         # Register an admin user.
@@ -1039,6 +1037,9 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         )
         self.assertEqual(channel.code, 200, channel.json_body)
 
+        # Check that the mock was not called on the act of adding a third-party ID.
+        on_remove_user_third_party_identifier_callback_mock.assert_not_called()
+
         # Now deactivate the user.
         channel = self.make_request(
             "PUT",
diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py
index a6330ed840..9679904c33 100644
--- a/tests/test_utils/event_injection.py
+++ b/tests/test_utils/event_injection.py
@@ -102,3 +102,34 @@ async def create_event(
     context = await unpersisted_context.persist(event)
 
     return event, context
+
+
+async def mark_event_as_partial_state(
+    hs: synapse.server.HomeServer,
+    event_id: str,
+    room_id: str,
+) -> None:
+    """
+    (Falsely) mark an event as having partial state.
+
+    Naughty, but occasionally useful when checking that partial state doesn't
+    block something from happening.
+
+    If the event already has partial state, this insert will fail (event_id is unique
+    in this table).
+    """
+    store = hs.get_datastores().main
+    await store.db_pool.simple_upsert(
+        table="partial_state_rooms",
+        keyvalues={"room_id": room_id},
+        values={},
+        insertion_values={"room_id": room_id},
+    )
+
+    await store.db_pool.simple_insert(
+        table="partial_state_events",
+        values={
+            "room_id": room_id,
+            "event_id": event_id,
+        },
+    )
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index 2801a950a8..9ed330f554 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -63,7 +63,13 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
 
         filtered = self.get_success(
             filter_events_for_server(
-                self._storage_controllers, "test_server", "hs", events_to_filter
+                self._storage_controllers,
+                "test_server",
+                "hs",
+                events_to_filter,
+                redact=True,
+                filter_out_erased_senders=True,
+                filter_out_remote_partial_state_events=True,
             )
         )
 
@@ -85,7 +91,13 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
         self.assertEqual(
             self.get_success(
                 filter_events_for_server(
-                    self._storage_controllers, "remote_hs", "hs", [outlier]
+                    self._storage_controllers,
+                    "remote_hs",
+                    "hs",
+                    [outlier],
+                    redact=True,
+                    filter_out_erased_senders=True,
+                    filter_out_remote_partial_state_events=True,
                 )
             ),
             [outlier],
@@ -96,7 +108,13 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
 
         filtered = self.get_success(
             filter_events_for_server(
-                self._storage_controllers, "remote_hs", "local_hs", [outlier, evt]
+                self._storage_controllers,
+                "remote_hs",
+                "local_hs",
+                [outlier, evt],
+                redact=True,
+                filter_out_erased_senders=True,
+                filter_out_remote_partial_state_events=True,
             )
         )
         self.assertEqual(len(filtered), 2, f"expected 2 results, got: {filtered}")
@@ -108,7 +126,13 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
         # be redacted)
         filtered = self.get_success(
             filter_events_for_server(
-                self._storage_controllers, "other_server", "local_hs", [outlier, evt]
+                self._storage_controllers,
+                "other_server",
+                "local_hs",
+                [outlier, evt],
+                redact=True,
+                filter_out_erased_senders=True,
+                filter_out_remote_partial_state_events=True,
             )
         )
         self.assertEqual(filtered[0], outlier)
@@ -143,7 +167,13 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
         # ... and the filtering happens.
         filtered = self.get_success(
             filter_events_for_server(
-                self._storage_controllers, "test_server", "local_hs", events_to_filter
+                self._storage_controllers,
+                "test_server",
+                "local_hs",
+                events_to_filter,
+                redact=True,
+                filter_out_erased_senders=True,
+                filter_out_remote_partial_state_events=True,
             )
         )