summary refs log tree commit diff
path: root/tests/rest/client/test_sync.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest/client/test_sync.py')
-rw-r--r--tests/rest/client/test_sync.py47
1 files changed, 27 insertions, 20 deletions
diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py
index e062561365..69b4ef5378 100644
--- a/tests/rest/client/test_sync.py
+++ b/tests/rest/client/test_sync.py
@@ -13,9 +13,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import json
+from typing import List, Optional
 
 from parameterized import parameterized
 
+from twisted.test.proto_helpers import MemoryReactor
+
 import synapse.rest.admin
 from synapse.api.constants import (
     EventContentFields,
@@ -24,6 +27,9 @@ from synapse.api.constants import (
     RelationTypes,
 )
 from synapse.rest.client import devices, knock, login, read_marker, receipts, room, sync
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util import Clock
 
 from tests import unittest
 from tests.federation.transport.test_knocking import (
@@ -43,7 +49,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
         sync.register_servlets,
     ]
 
-    def test_sync_argless(self):
+    def test_sync_argless(self) -> None:
         channel = self.make_request("GET", "/sync")
 
         self.assertEqual(channel.code, 200)
@@ -58,7 +64,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase):
         sync.register_servlets,
     ]
 
-    def test_sync_filter_labels(self):
+    def test_sync_filter_labels(self) -> None:
         """Test that we can filter by a label."""
         sync_filter = json.dumps(
             {
@@ -77,7 +83,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase):
         self.assertEqual(events[0]["content"]["body"], "with right label", events[0])
         self.assertEqual(events[1]["content"]["body"], "with right label", events[1])
 
-    def test_sync_filter_not_labels(self):
+    def test_sync_filter_not_labels(self) -> None:
         """Test that we can filter by the absence of a label."""
         sync_filter = json.dumps(
             {
@@ -99,7 +105,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase):
             events[2]["content"]["body"], "with two wrong labels", events[2]
         )
 
-    def test_sync_filter_labels_not_labels(self):
+    def test_sync_filter_labels_not_labels(self) -> None:
         """Test that we can filter by both a label and the absence of another label."""
         sync_filter = json.dumps(
             {
@@ -118,7 +124,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase):
         self.assertEqual(len(events), 1, [event["content"] for event in events])
         self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0])
 
-    def _test_sync_filter_labels(self, sync_filter):
+    def _test_sync_filter_labels(self, sync_filter: str) -> List[JsonDict]:
         user_id = self.register_user("kermit", "test")
         tok = self.login("kermit", "test")
 
@@ -194,7 +200,7 @@ class SyncTypingTests(unittest.HomeserverTestCase):
     user_id = True
     hijack_auth = False
 
-    def test_sync_backwards_typing(self):
+    def test_sync_backwards_typing(self) -> None:
         """
         If the typing serial goes backwards and the typing handler is then reset
         (such as when the master restarts and sets the typing serial to 0), we
@@ -298,7 +304,7 @@ class SyncKnockTestCase(
         knock.register_servlets,
     ]
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.store = hs.get_datastores().main
         self.url = "/sync?since=%s"
         self.next_batch = "s0"
@@ -336,7 +342,7 @@ class SyncKnockTestCase(
         )
 
     @override_config({"experimental_features": {"msc2403_enabled": True}})
-    def test_knock_room_state(self):
+    def test_knock_room_state(self) -> None:
         """Tests that /sync returns state from a room after knocking on it."""
         # Knock on a room
         channel = self.make_request(
@@ -383,7 +389,7 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
         sync.register_servlets,
     ]
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.url = "/sync?since=%s"
         self.next_batch = "s0"
 
@@ -402,7 +408,7 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
         self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2)
 
     @override_config({"experimental_features": {"msc2285_enabled": True}})
-    def test_hidden_read_receipts(self):
+    def test_hidden_read_receipts(self) -> None:
         # Send a message as the first user
         res = self.helper.send(self.room_id, body="hello", tok=self.tok)
 
@@ -441,8 +447,8 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
         ]
     )
     def test_read_receipt_with_empty_body(
-        self, name, user_agent: str, expected_status_code: int
-    ):
+        self, name: str, user_agent: str, expected_status_code: int
+    ) -> None:
         # Send a message as the first user
         res = self.helper.send(self.room_id, body="hello", tok=self.tok)
 
@@ -455,11 +461,11 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
         )
         self.assertEqual(channel.code, expected_status_code)
 
-    def _get_read_receipt(self):
+    def _get_read_receipt(self) -> Optional[JsonDict]:
         """Syncs and returns the read receipt."""
 
         # Checks if event is a read receipt
-        def is_read_receipt(event):
+        def is_read_receipt(event: JsonDict) -> bool:
             return event["type"] == "m.receipt"
 
         # Sync
@@ -477,7 +483,8 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
         ephemeral_events = channel.json_body["rooms"]["join"][self.room_id][
             "ephemeral"
         ]["events"]
-        return next(filter(is_read_receipt, ephemeral_events), None)
+        receipt_event = filter(is_read_receipt, ephemeral_events)
+        return next(receipt_event, None)
 
 
 class UnreadMessagesTestCase(unittest.HomeserverTestCase):
@@ -490,7 +497,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
         receipts.register_servlets,
     ]
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.url = "/sync?since=%s"
         self.next_batch = "s0"
 
@@ -533,7 +540,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
             tok=self.tok,
         )
 
-    def test_unread_counts(self):
+    def test_unread_counts(self) -> None:
         """Tests that /sync returns the right value for the unread count (MSC2654)."""
 
         # Check that our own messages don't increase the unread count.
@@ -640,7 +647,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
         )
         self._check_unread_count(5)
 
-    def _check_unread_count(self, expected_count: int):
+    def _check_unread_count(self, expected_count: int) -> None:
         """Syncs and compares the unread count with the expected value."""
 
         channel = self.make_request(
@@ -669,7 +676,7 @@ class SyncCacheTestCase(unittest.HomeserverTestCase):
         sync.register_servlets,
     ]
 
-    def test_noop_sync_does_not_tightloop(self):
+    def test_noop_sync_does_not_tightloop(self) -> None:
         """If the sync times out, we shouldn't cache the result
 
         Essentially a regression test for #8518.
@@ -720,7 +727,7 @@ class DeviceListSyncTestCase(unittest.HomeserverTestCase):
         devices.register_servlets,
     ]
 
-    def test_user_with_no_rooms_receives_self_device_list_updates(self):
+    def test_user_with_no_rooms_receives_self_device_list_updates(self) -> None:
         """Tests that a user with no rooms still receives their own device list updates"""
         device_id = "TESTDEVICE"