summary refs log tree commit diff
path: root/tests/replication/tcp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/replication/tcp')
-rw-r--r--tests/replication/tcp/streams/test_account_data.py4
-rw-r--r--tests/replication/tcp/streams/test_events.py26
-rw-r--r--tests/replication/tcp/streams/test_federation.py2
-rw-r--r--tests/replication/tcp/streams/test_partial_state.py4
-rw-r--r--tests/replication/tcp/streams/test_typing.py37
-rw-r--r--tests/replication/tcp/test_commands.py6
-rw-r--r--tests/replication/tcp/test_handler.py62
-rw-r--r--tests/replication/tcp/test_remote_server_up.py8
8 files changed, 110 insertions, 39 deletions
diff --git a/tests/replication/tcp/streams/test_account_data.py b/tests/replication/tcp/streams/test_account_data.py
index 50fbff5f32..01df1be047 100644
--- a/tests/replication/tcp/streams/test_account_data.py
+++ b/tests/replication/tcp/streams/test_account_data.py
@@ -21,7 +21,7 @@ from tests.replication._base import BaseStreamTestCase
 
 
 class AccountDataStreamTestCase(BaseStreamTestCase):
-    def test_update_function_room_account_data_limit(self):
+    def test_update_function_room_account_data_limit(self) -> None:
         """Test replication with many room account data updates"""
         store = self.hs.get_datastores().main
 
@@ -67,7 +67,7 @@ class AccountDataStreamTestCase(BaseStreamTestCase):
 
         self.assertEqual([], received_rows)
 
-    def test_update_function_global_account_data_limit(self):
+    def test_update_function_global_account_data_limit(self) -> None:
         """Test replication with many global account data updates"""
         store = self.hs.get_datastores().main
 
diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py
index 641a94133b..65ef4bb160 100644
--- a/tests/replication/tcp/streams/test_events.py
+++ b/tests/replication/tcp/streams/test_events.py
@@ -12,7 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import List, Optional
+from typing import Any, List, Optional, Sequence
+
+from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.constants import EventTypes, Membership
 from synapse.events import EventBase
@@ -25,6 +27,8 @@ from synapse.replication.tcp.streams.events import (
 )
 from synapse.rest import admin
 from synapse.rest.client import login, room
+from synapse.server import HomeServer
+from synapse.util import Clock
 
 from tests.replication._base import BaseStreamTestCase
 from tests.test_utils.event_injection import inject_event, inject_member_event
@@ -37,7 +41,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
         room.register_servlets,
     ]
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         super().prepare(reactor, clock, hs)
         self.user_id = self.register_user("u1", "pass")
         self.user_tok = self.login("u1", "pass")
@@ -47,7 +51,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
         self.room_id = self.helper.create_room_as(tok=self.user_tok)
         self.test_handler.received_rdata_rows.clear()
 
-    def test_update_function_event_row_limit(self):
+    def test_update_function_event_row_limit(self) -> None:
         """Test replication with many non-state events
 
         Checks that all events are correctly replicated when there are lots of
@@ -102,7 +106,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
 
         self.assertEqual([], received_rows)
 
-    def test_update_function_huge_state_change(self):
+    def test_update_function_huge_state_change(self) -> None:
         """Test replication with many state events
 
         Ensures that all events are correctly replicated when there are lots of
@@ -135,7 +139,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
         )
 
         # this is the point in the DAG where we make a fork
-        fork_point: List[str] = self.get_success(
+        fork_point: Sequence[str] = self.get_success(
             self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id)
         )
 
@@ -164,7 +168,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
         pl_event = self.get_success(
             inject_event(
                 self.hs,
-                prev_event_ids=prev_events,
+                prev_event_ids=list(prev_events),
                 type=EventTypes.PowerLevels,
                 state_key="",
                 sender=self.user_id,
@@ -256,7 +260,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
             # "None" indicates the state has been deleted
             self.assertIsNone(sr.event_id)
 
-    def test_update_function_state_row_limit(self):
+    def test_update_function_state_row_limit(self) -> None:
         """Test replication with many state events over several stream ids."""
 
         # we want to generate lots of state changes, but for this test, we want to
@@ -290,7 +294,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
         )
 
         # this is the point in the DAG where we make a fork
-        fork_point: List[str] = self.get_success(
+        fork_point: Sequence[str] = self.get_success(
             self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id)
         )
 
@@ -319,7 +323,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
             e = self.get_success(
                 inject_event(
                     self.hs,
-                    prev_event_ids=prev_events,
+                    prev_event_ids=list(prev_events),
                     type=EventTypes.PowerLevels,
                     state_key="",
                     sender=self.user_id,
@@ -376,7 +380,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
 
         self.assertEqual([], received_rows)
 
-    def test_backwards_stream_id(self):
+    def test_backwards_stream_id(self) -> None:
         """
         Test that RDATA that comes after the current position should be discarded.
         """
@@ -437,7 +441,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
     event_count = 0
 
     def _inject_test_event(
-        self, body: Optional[str] = None, sender: Optional[str] = None, **kwargs
+        self, body: Optional[str] = None, sender: Optional[str] = None, **kwargs: Any
     ) -> EventBase:
         if sender is None:
             sender = self.user_id
diff --git a/tests/replication/tcp/streams/test_federation.py b/tests/replication/tcp/streams/test_federation.py
index bcb82c9c80..cdbdfaf057 100644
--- a/tests/replication/tcp/streams/test_federation.py
+++ b/tests/replication/tcp/streams/test_federation.py
@@ -26,7 +26,7 @@ class FederationStreamTestCase(BaseStreamTestCase):
         config["federation_sender_instances"] = ["federation_sender1"]
         return config
 
-    def test_catchup(self):
+    def test_catchup(self) -> None:
         """Basic test of catchup on reconnect
 
         Makes sure that updates sent while we are offline are received later.
diff --git a/tests/replication/tcp/streams/test_partial_state.py b/tests/replication/tcp/streams/test_partial_state.py
index 2c10eab4db..452ac85069 100644
--- a/tests/replication/tcp/streams/test_partial_state.py
+++ b/tests/replication/tcp/streams/test_partial_state.py
@@ -23,7 +23,7 @@ class PartialStateStreamsTestCase(BaseMultiWorkerStreamTestCase):
     hijack_auth = True
     user_id = "@bob:test"
 
-    def setUp(self):
+    def setUp(self) -> None:
         super().setUp()
         self.store = self.hs.get_datastores().main
 
@@ -37,7 +37,7 @@ class PartialStateStreamsTestCase(BaseMultiWorkerStreamTestCase):
         room_id = self.helper.create_room_as("@bob:test")
         # Mark the room as partial-stated.
         self.get_success(
-            self.store.store_partial_state_room(room_id, ["serv1", "serv2"], 0, "serv1")
+            self.store.store_partial_state_room(room_id, {"serv1", "serv2"}, 0, "serv1")
         )
 
         worker = self.make_worker_hs("synapse.app.generic_worker")
diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py
index 9a229dd23f..5a38ac831f 100644
--- a/tests/replication/tcp/streams/test_typing.py
+++ b/tests/replication/tcp/streams/test_typing.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 from unittest.mock import Mock
 
-from synapse.handlers.typing import RoomMember
+from synapse.handlers.typing import RoomMember, TypingWriterHandler
 from synapse.replication.tcp.streams import TypingStream
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
@@ -27,11 +27,13 @@ ROOM_ID_2 = "!foo:blue"
 
 
 class TypingStreamTestCase(BaseStreamTestCase):
-    def _build_replication_data_handler(self):
-        return Mock(wraps=super()._build_replication_data_handler())
+    def _build_replication_data_handler(self) -> Mock:
+        self.mock_handler = Mock(wraps=super()._build_replication_data_handler())
+        return self.mock_handler
 
-    def test_typing(self):
+    def test_typing(self) -> None:
         typing = self.hs.get_typing_handler()
+        assert isinstance(typing, TypingWriterHandler)
 
         self.reconnect()
 
@@ -43,8 +45,8 @@ class TypingStreamTestCase(BaseStreamTestCase):
         request = self.handle_http_replication_attempt()
         self.assert_request_is_get_repl_stream_updates(request, "typing")
 
-        self.test_handler.on_rdata.assert_called_once()
-        stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+        self.mock_handler.on_rdata.assert_called_once()
+        stream_name, _, token, rdata_rows = self.mock_handler.on_rdata.call_args[0]
         self.assertEqual(stream_name, "typing")
         self.assertEqual(1, len(rdata_rows))
         row: TypingStream.TypingStreamRow = rdata_rows[0]
@@ -54,11 +56,11 @@ class TypingStreamTestCase(BaseStreamTestCase):
         # Now let's disconnect and insert some data.
         self.disconnect()
 
-        self.test_handler.on_rdata.reset_mock()
+        self.mock_handler.on_rdata.reset_mock()
 
         typing._push_update(member=RoomMember(ROOM_ID, USER_ID), typing=False)
 
-        self.test_handler.on_rdata.assert_not_called()
+        self.mock_handler.on_rdata.assert_not_called()
 
         self.reconnect()
         self.pump(0.1)
@@ -71,15 +73,15 @@ class TypingStreamTestCase(BaseStreamTestCase):
         assert request.args is not None
         self.assertEqual(int(request.args[b"from_token"][0]), token)
 
-        self.test_handler.on_rdata.assert_called_once()
-        stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+        self.mock_handler.on_rdata.assert_called_once()
+        stream_name, _, token, rdata_rows = self.mock_handler.on_rdata.call_args[0]
         self.assertEqual(stream_name, "typing")
         self.assertEqual(1, len(rdata_rows))
         row = rdata_rows[0]
         self.assertEqual(ROOM_ID, row.room_id)
         self.assertEqual([], row.user_ids)
 
-    def test_reset(self):
+    def test_reset(self) -> None:
         """
         Test what happens when a typing stream resets.
 
@@ -87,6 +89,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
         sends the proper position and RDATA).
         """
         typing = self.hs.get_typing_handler()
+        assert isinstance(typing, TypingWriterHandler)
 
         self.reconnect()
 
@@ -98,8 +101,8 @@ class TypingStreamTestCase(BaseStreamTestCase):
         request = self.handle_http_replication_attempt()
         self.assert_request_is_get_repl_stream_updates(request, "typing")
 
-        self.test_handler.on_rdata.assert_called_once()
-        stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+        self.mock_handler.on_rdata.assert_called_once()
+        stream_name, _, token, rdata_rows = self.mock_handler.on_rdata.call_args[0]
         self.assertEqual(stream_name, "typing")
         self.assertEqual(1, len(rdata_rows))
         row: TypingStream.TypingStreamRow = rdata_rows[0]
@@ -134,15 +137,15 @@ class TypingStreamTestCase(BaseStreamTestCase):
         self.assert_request_is_get_repl_stream_updates(request, "typing")
 
         # Reset the test code.
-        self.test_handler.on_rdata.reset_mock()
-        self.test_handler.on_rdata.assert_not_called()
+        self.mock_handler.on_rdata.reset_mock()
+        self.mock_handler.on_rdata.assert_not_called()
 
         # Push additional data.
         typing._push_update(member=RoomMember(ROOM_ID_2, USER_ID_2), typing=False)
         self.reactor.advance(0)
 
-        self.test_handler.on_rdata.assert_called_once()
-        stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+        self.mock_handler.on_rdata.assert_called_once()
+        stream_name, _, token, rdata_rows = self.mock_handler.on_rdata.call_args[0]
         self.assertEqual(stream_name, "typing")
         self.assertEqual(1, len(rdata_rows))
         row = rdata_rows[0]
diff --git a/tests/replication/tcp/test_commands.py b/tests/replication/tcp/test_commands.py
index cca7ebb719..5d6b72b16d 100644
--- a/tests/replication/tcp/test_commands.py
+++ b/tests/replication/tcp/test_commands.py
@@ -21,12 +21,12 @@ from tests.unittest import TestCase
 
 
 class ParseCommandTestCase(TestCase):
-    def test_parse_one_word_command(self):
+    def test_parse_one_word_command(self) -> None:
         line = "REPLICATE"
         cmd = parse_command_from_line(line)
         self.assertIsInstance(cmd, ReplicateCommand)
 
-    def test_parse_rdata(self):
+    def test_parse_rdata(self) -> None:
         line = 'RDATA events master 6287863 ["ev", ["$eventid", "!roomid", "type", null, null, null]]'
         cmd = parse_command_from_line(line)
         assert isinstance(cmd, RdataCommand)
@@ -34,7 +34,7 @@ class ParseCommandTestCase(TestCase):
         self.assertEqual(cmd.instance_name, "master")
         self.assertEqual(cmd.token, 6287863)
 
-    def test_parse_rdata_batch(self):
+    def test_parse_rdata_batch(self) -> None:
         line = 'RDATA presence master batch ["@foo:example.com", "online"]'
         cmd = parse_command_from_line(line)
         assert isinstance(cmd, RdataCommand)
diff --git a/tests/replication/tcp/test_handler.py b/tests/replication/tcp/test_handler.py
index 6e4055cc21..bab77b2df7 100644
--- a/tests/replication/tcp/test_handler.py
+++ b/tests/replication/tcp/test_handler.py
@@ -127,6 +127,7 @@ class ChannelsTestCase(BaseMultiWorkerStreamTestCase):
 
         # ... updating the cache ID gen on the master still shouldn't cause the
         # deferred to wake up.
+        assert store._cache_id_gen is not None
         ctx = store._cache_id_gen.get_next()
         self.get_success(ctx.__aenter__())
         self.get_success(ctx.__aexit__(None, None, None))
@@ -140,3 +141,64 @@ class ChannelsTestCase(BaseMultiWorkerStreamTestCase):
         self.get_success(ctx_worker1.__aexit__(None, None, None))
 
         self.assertTrue(d.called)
+
+    def test_wait_for_stream_position_rdata(self) -> None:
+        """Check that wait for stream position correctly waits for an update
+        from the correct instance, when RDATA is sent.
+        """
+        store = self.hs.get_datastores().main
+        cmd_handler = self.hs.get_replication_command_handler()
+        data_handler = self.hs.get_replication_data_handler()
+
+        worker1 = self.make_worker_hs(
+            "synapse.app.generic_worker",
+            extra_config={
+                "worker_name": "worker1",
+                "run_background_tasks_on": "worker1",
+                "redis": {"enabled": True},
+            },
+        )
+
+        cache_id_gen = worker1.get_datastores().main._cache_id_gen
+        assert cache_id_gen is not None
+
+        self.replicate()
+
+        # First, make sure the master knows that `worker1` exists.
+        initial_token = cache_id_gen.get_current_token()
+        cmd_handler.send_command(
+            PositionCommand("caches", "worker1", initial_token, initial_token)
+        )
+        self.replicate()
+
+        # `wait_for_stream_position` should only return once master receives a
+        # notification that `next_token2` has persisted.
+        ctx_worker1 = cache_id_gen.get_next_mult(2)
+        next_token1, next_token2 = self.get_success(ctx_worker1.__aenter__())
+
+        d = defer.ensureDeferred(
+            data_handler.wait_for_stream_position("worker1", "caches", next_token2)
+        )
+        self.assertFalse(d.called)
+
+        # Insert an entry into the cache stream with token `next_token1`, but
+        # not `next_token2`.
+        self.get_success(
+            store.db_pool.simple_insert(
+                table="cache_invalidation_stream_by_instance",
+                values={
+                    "stream_id": next_token1,
+                    "instance_name": "worker1",
+                    "cache_func": "foo",
+                    "keys": [],
+                    "invalidation_ts": 0,
+                },
+            )
+        )
+
+        # Finish the context manager, triggering the data to be sent to master.
+        self.get_success(ctx_worker1.__aexit__(None, None, None))
+
+        # Master should get told about `next_token2`, so the deferred should
+        # resolve.
+        self.assertTrue(d.called)
diff --git a/tests/replication/tcp/test_remote_server_up.py b/tests/replication/tcp/test_remote_server_up.py
index 545f11acd1..b75fc05fd5 100644
--- a/tests/replication/tcp/test_remote_server_up.py
+++ b/tests/replication/tcp/test_remote_server_up.py
@@ -16,15 +16,17 @@ from typing import Tuple
 
 from twisted.internet.address import IPv4Address
 from twisted.internet.interfaces import IProtocol
-from twisted.test.proto_helpers import StringTransport
+from twisted.test.proto_helpers import MemoryReactor, StringTransport
 
 from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
+from synapse.server import HomeServer
+from synapse.util import Clock
 
 from tests.unittest import HomeserverTestCase
 
 
 class RemoteServerUpTestCase(HomeserverTestCase):
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.factory = ReplicationStreamProtocolFactory(hs)
 
     def _make_client(self) -> Tuple[IProtocol, StringTransport]:
@@ -40,7 +42,7 @@ class RemoteServerUpTestCase(HomeserverTestCase):
 
         return proto, transport
 
-    def test_relay(self):
+    def test_relay(self) -> None:
         """Test that Synapse will relay REMOTE_SERVER_UP commands to all
         other connections, but not the one that sent it.
         """