summary refs log tree commit diff
path: root/tests/replication
diff options
context:
space:
mode:
Diffstat (limited to 'tests/replication')
-rw-r--r--tests/replication/tcp/streams/test_events.py10
-rw-r--r--tests/replication/tcp/streams/test_partial_state.py2
-rw-r--r--tests/replication/tcp/streams/test_typing.py4
-rw-r--r--tests/replication/tcp/test_handler.py1
-rw-r--r--tests/replication/test_federation_sender_shard.py2
-rw-r--r--tests/replication/test_pusher_shard.py1
6 files changed, 13 insertions, 7 deletions
diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py
index 043dbe76af..65ef4bb160 100644
--- a/tests/replication/tcp/streams/test_events.py
+++ b/tests/replication/tcp/streams/test_events.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, List, Optional
+from typing import Any, List, Optional, Sequence
 
 from twisted.test.proto_helpers import MemoryReactor
 
@@ -139,7 +139,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
         )
 
         # this is the point in the DAG where we make a fork
-        fork_point: List[str] = self.get_success(
+        fork_point: Sequence[str] = self.get_success(
             self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id)
         )
 
@@ -168,7 +168,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
         pl_event = self.get_success(
             inject_event(
                 self.hs,
-                prev_event_ids=prev_events,
+                prev_event_ids=list(prev_events),
                 type=EventTypes.PowerLevels,
                 state_key="",
                 sender=self.user_id,
@@ -294,7 +294,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
         )
 
         # this is the point in the DAG where we make a fork
-        fork_point: List[str] = self.get_success(
+        fork_point: Sequence[str] = self.get_success(
             self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id)
         )
 
@@ -323,7 +323,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
             e = self.get_success(
                 inject_event(
                     self.hs,
-                    prev_event_ids=prev_events,
+                    prev_event_ids=list(prev_events),
                     type=EventTypes.PowerLevels,
                     state_key="",
                     sender=self.user_id,
diff --git a/tests/replication/tcp/streams/test_partial_state.py b/tests/replication/tcp/streams/test_partial_state.py
index 38b5020ce0..452ac85069 100644
--- a/tests/replication/tcp/streams/test_partial_state.py
+++ b/tests/replication/tcp/streams/test_partial_state.py
@@ -37,7 +37,7 @@ class PartialStateStreamsTestCase(BaseMultiWorkerStreamTestCase):
         room_id = self.helper.create_room_as("@bob:test")
         # Mark the room as partial-stated.
         self.get_success(
-            self.store.store_partial_state_room(room_id, ["serv1", "serv2"], 0, "serv1")
+            self.store.store_partial_state_room(room_id, {"serv1", "serv2"}, 0, "serv1")
         )
 
         worker = self.make_worker_hs("synapse.app.generic_worker")
diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py
index 68de5d1cc2..5a38ac831f 100644
--- a/tests/replication/tcp/streams/test_typing.py
+++ b/tests/replication/tcp/streams/test_typing.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 from unittest.mock import Mock
 
-from synapse.handlers.typing import RoomMember
+from synapse.handlers.typing import RoomMember, TypingWriterHandler
 from synapse.replication.tcp.streams import TypingStream
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
@@ -33,6 +33,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
 
     def test_typing(self) -> None:
         typing = self.hs.get_typing_handler()
+        assert isinstance(typing, TypingWriterHandler)
 
         self.reconnect()
 
@@ -88,6 +89,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
         sends the proper position and RDATA).
         """
         typing = self.hs.get_typing_handler()
+        assert isinstance(typing, TypingWriterHandler)
 
         self.reconnect()
 
diff --git a/tests/replication/tcp/test_handler.py b/tests/replication/tcp/test_handler.py
index 6e4055cc21..bf927beb6a 100644
--- a/tests/replication/tcp/test_handler.py
+++ b/tests/replication/tcp/test_handler.py
@@ -127,6 +127,7 @@ class ChannelsTestCase(BaseMultiWorkerStreamTestCase):
 
         # ... updating the cache ID gen on the master still shouldn't cause the
         # deferred to wake up.
+        assert store._cache_id_gen is not None
         ctx = store._cache_id_gen.get_next()
         self.get_success(ctx.__aenter__())
         self.get_success(ctx.__aexit__(None, None, None))
diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py
index 89380e25b5..08703206a9 100644
--- a/tests/replication/test_federation_sender_shard.py
+++ b/tests/replication/test_federation_sender_shard.py
@@ -16,6 +16,7 @@ from unittest.mock import Mock
 
 from synapse.api.constants import EventTypes, Membership
 from synapse.events.builder import EventBuilderFactory
+from synapse.handlers.typing import TypingWriterHandler
 from synapse.rest.admin import register_servlets_for_client_rest_resource
 from synapse.rest.client import login, room
 from synapse.types import UserID, create_requester
@@ -174,6 +175,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
         token = self.login("user3", "pass")
 
         typing_handler = self.hs.get_typing_handler()
+        assert isinstance(typing_handler, TypingWriterHandler)
 
         sent_on_1 = False
         sent_on_2 = False
diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py
index 9345cfbeb2..0798b021c3 100644
--- a/tests/replication/test_pusher_shard.py
+++ b/tests/replication/test_pusher_shard.py
@@ -50,6 +50,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
         user_dict = self.get_success(
             self.hs.get_datastores().main.get_user_by_access_token(access_token)
         )
+        assert user_dict is not None
         token_id = user_dict.token_id
 
         self.get_success(