diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py
index e062561365..69b4ef5378 100644
--- a/tests/rest/client/test_sync.py
+++ b/tests/rest/client/test_sync.py
@@ -13,9 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
+from typing import List, Optional
from parameterized import parameterized
+from twisted.test.proto_helpers import MemoryReactor
+
import synapse.rest.admin
from synapse.api.constants import (
EventContentFields,
@@ -24,6 +27,9 @@ from synapse.api.constants import (
RelationTypes,
)
from synapse.rest.client import devices, knock, login, read_marker, receipts, room, sync
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util import Clock
from tests import unittest
from tests.federation.transport.test_knocking import (
@@ -43,7 +49,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
sync.register_servlets,
]
- def test_sync_argless(self):
+ def test_sync_argless(self) -> None:
channel = self.make_request("GET", "/sync")
self.assertEqual(channel.code, 200)
@@ -58,7 +64,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase):
sync.register_servlets,
]
- def test_sync_filter_labels(self):
+ def test_sync_filter_labels(self) -> None:
"""Test that we can filter by a label."""
sync_filter = json.dumps(
{
@@ -77,7 +83,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase):
self.assertEqual(events[0]["content"]["body"], "with right label", events[0])
self.assertEqual(events[1]["content"]["body"], "with right label", events[1])
- def test_sync_filter_not_labels(self):
+ def test_sync_filter_not_labels(self) -> None:
"""Test that we can filter by the absence of a label."""
sync_filter = json.dumps(
{
@@ -99,7 +105,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase):
events[2]["content"]["body"], "with two wrong labels", events[2]
)
- def test_sync_filter_labels_not_labels(self):
+ def test_sync_filter_labels_not_labels(self) -> None:
"""Test that we can filter by both a label and the absence of another label."""
sync_filter = json.dumps(
{
@@ -118,7 +124,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(events), 1, [event["content"] for event in events])
self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0])
- def _test_sync_filter_labels(self, sync_filter):
+ def _test_sync_filter_labels(self, sync_filter: str) -> List[JsonDict]:
user_id = self.register_user("kermit", "test")
tok = self.login("kermit", "test")
@@ -194,7 +200,7 @@ class SyncTypingTests(unittest.HomeserverTestCase):
user_id = True
hijack_auth = False
- def test_sync_backwards_typing(self):
+ def test_sync_backwards_typing(self) -> None:
"""
If the typing serial goes backwards and the typing handler is then reset
(such as when the master restarts and sets the typing serial to 0), we
@@ -298,7 +304,7 @@ class SyncKnockTestCase(
knock.register_servlets,
]
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.url = "/sync?since=%s"
self.next_batch = "s0"
@@ -336,7 +342,7 @@ class SyncKnockTestCase(
)
@override_config({"experimental_features": {"msc2403_enabled": True}})
- def test_knock_room_state(self):
+ def test_knock_room_state(self) -> None:
"""Tests that /sync returns state from a room after knocking on it."""
# Knock on a room
channel = self.make_request(
@@ -383,7 +389,7 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
sync.register_servlets,
]
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.url = "/sync?since=%s"
self.next_batch = "s0"
@@ -402,7 +408,7 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2)
@override_config({"experimental_features": {"msc2285_enabled": True}})
- def test_hidden_read_receipts(self):
+ def test_hidden_read_receipts(self) -> None:
# Send a message as the first user
res = self.helper.send(self.room_id, body="hello", tok=self.tok)
@@ -441,8 +447,8 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
]
)
def test_read_receipt_with_empty_body(
- self, name, user_agent: str, expected_status_code: int
- ):
+ self, name: str, user_agent: str, expected_status_code: int
+ ) -> None:
# Send a message as the first user
res = self.helper.send(self.room_id, body="hello", tok=self.tok)
@@ -455,11 +461,11 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, expected_status_code)
- def _get_read_receipt(self):
+ def _get_read_receipt(self) -> Optional[JsonDict]:
"""Syncs and returns the read receipt."""
# Checks if event is a read receipt
- def is_read_receipt(event):
+ def is_read_receipt(event: JsonDict) -> bool:
return event["type"] == "m.receipt"
# Sync
@@ -477,7 +483,8 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
ephemeral_events = channel.json_body["rooms"]["join"][self.room_id][
"ephemeral"
]["events"]
- return next(filter(is_read_receipt, ephemeral_events), None)
+ receipt_event = filter(is_read_receipt, ephemeral_events)
+ return next(receipt_event, None)
class UnreadMessagesTestCase(unittest.HomeserverTestCase):
@@ -490,7 +497,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
receipts.register_servlets,
]
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.url = "/sync?since=%s"
self.next_batch = "s0"
@@ -533,7 +540,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
tok=self.tok,
)
- def test_unread_counts(self):
+ def test_unread_counts(self) -> None:
"""Tests that /sync returns the right value for the unread count (MSC2654)."""
# Check that our own messages don't increase the unread count.
@@ -640,7 +647,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
)
self._check_unread_count(5)
- def _check_unread_count(self, expected_count: int):
+ def _check_unread_count(self, expected_count: int) -> None:
"""Syncs and compares the unread count with the expected value."""
channel = self.make_request(
@@ -669,7 +676,7 @@ class SyncCacheTestCase(unittest.HomeserverTestCase):
sync.register_servlets,
]
- def test_noop_sync_does_not_tightloop(self):
+ def test_noop_sync_does_not_tightloop(self) -> None:
"""If the sync times out, we shouldn't cache the result
Essentially a regression test for #8518.
@@ -720,7 +727,7 @@ class DeviceListSyncTestCase(unittest.HomeserverTestCase):
devices.register_servlets,
]
- def test_user_with_no_rooms_receives_self_device_list_updates(self):
+ def test_user_with_no_rooms_receives_self_device_list_updates(self) -> None:
"""Tests that a user with no rooms still receives their own device list updates"""
device_id = "TESTDEVICE"
|