diff --git a/changelog.d/14680.misc b/changelog.d/14680.misc
new file mode 100644
index 0000000000..d44571b731
--- /dev/null
+++ b/changelog.d/14680.misc
@@ -0,0 +1 @@
+Add missing type hints.
diff --git a/mypy.ini b/mypy.ini
index 37acf589c9..1a37414e58 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -95,10 +95,7 @@ disallow_untyped_defs = True
[mypy-tests.federation.transport.test_client]
disallow_untyped_defs = True
-[mypy-tests.handlers.test_sso]
-disallow_untyped_defs = True
-
-[mypy-tests.handlers.test_user_directory]
+[mypy-tests.handlers.*]
disallow_untyped_defs = True
[mypy-tests.metrics.test_background_process_metrics]
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 8b9ef25d29..30f2d46c3c 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -2031,7 +2031,7 @@ class PasswordAuthProvider:
self.is_3pid_allowed_callbacks: List[IS_3PID_ALLOWED_CALLBACK] = []
# Mapping from login type to login parameters
- self._supported_login_types: Dict[str, Iterable[str]] = {}
+ self._supported_login_types: Dict[str, Tuple[str, ...]] = {}
# Mapping from login type to auth checker callbacks
self.auth_checker_callbacks: Dict[str, List[CHECK_AUTH_CALLBACK]] = {}
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index 57bfbd7734..a7495ab21a 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -31,7 +31,7 @@ from synapse.appservice import (
from synapse.handlers.appservice import ApplicationServicesHandler
from synapse.rest.client import login, receipts, register, room, sendtodevice
from synapse.server import HomeServer
-from synapse.types import RoomStreamToken
+from synapse.types import JsonDict, RoomStreamToken
from synapse.util import Clock
from synapse.util.stringutils import random_string
@@ -44,7 +44,7 @@ from tests.utils import MockClock
class AppServiceHandlerTestCase(unittest.TestCase):
"""Tests the ApplicationServicesHandler."""
- def setUp(self):
+ def setUp(self) -> None:
self.mock_store = Mock()
self.mock_as_api = Mock()
self.mock_scheduler = Mock()
@@ -61,7 +61,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
self.handler = ApplicationServicesHandler(hs)
self.event_source = hs.get_event_sources()
- def test_notify_interested_services(self):
+ def test_notify_interested_services(self) -> None:
interested_service = self._mkservice(is_interested_in_event=True)
services = [
self._mkservice(is_interested_in_event=False),
@@ -90,7 +90,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
interested_service, events=[event]
)
- def test_query_user_exists_unknown_user(self):
+ def test_query_user_exists_unknown_user(self) -> None:
user_id = "@someone:anywhere"
services = [self._mkservice(is_interested_in_event=True)]
services[0].is_interested_in_user.return_value = True
@@ -107,7 +107,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
self.mock_as_api.query_user.assert_called_once_with(services[0], user_id)
- def test_query_user_exists_known_user(self):
+ def test_query_user_exists_known_user(self) -> None:
user_id = "@someone:anywhere"
services = [self._mkservice(is_interested_in_event=True)]
services[0].is_interested_in_user.return_value = True
@@ -127,7 +127,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
"query_user called when it shouldn't have been.",
)
- def test_query_room_alias_exists(self):
+ def test_query_room_alias_exists(self) -> None:
room_alias_str = "#foo:bar"
room_alias = Mock()
room_alias.to_string.return_value = room_alias_str
@@ -157,7 +157,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
self.assertEqual(result.room_id, room_id)
self.assertEqual(result.servers, servers)
- def test_get_3pe_protocols_no_appservices(self):
+ def test_get_3pe_protocols_no_appservices(self) -> None:
self.mock_store.get_app_services.return_value = []
response = self.successResultOf(
defer.ensureDeferred(self.handler.get_3pe_protocols("my-protocol"))
@@ -165,7 +165,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
self.mock_as_api.get_3pe_protocol.assert_not_called()
self.assertEqual(response, {})
- def test_get_3pe_protocols_no_protocols(self):
+ def test_get_3pe_protocols_no_protocols(self) -> None:
service = self._mkservice(False, [])
self.mock_store.get_app_services.return_value = [service]
response = self.successResultOf(
@@ -174,7 +174,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
self.mock_as_api.get_3pe_protocol.assert_not_called()
self.assertEqual(response, {})
- def test_get_3pe_protocols_protocol_no_response(self):
+ def test_get_3pe_protocols_protocol_no_response(self) -> None:
service = self._mkservice(False, ["my-protocol"])
self.mock_store.get_app_services.return_value = [service]
self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(None)
@@ -186,7 +186,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
)
self.assertEqual(response, {})
- def test_get_3pe_protocols_select_one_protocol(self):
+ def test_get_3pe_protocols_select_one_protocol(self) -> None:
service = self._mkservice(False, ["my-protocol"])
self.mock_store.get_app_services.return_value = [service]
self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(
@@ -202,7 +202,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
response, {"my-protocol": {"x-protocol-data": 42, "instances": []}}
)
- def test_get_3pe_protocols_one_protocol(self):
+ def test_get_3pe_protocols_one_protocol(self) -> None:
service = self._mkservice(False, ["my-protocol"])
self.mock_store.get_app_services.return_value = [service]
self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(
@@ -218,7 +218,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
response, {"my-protocol": {"x-protocol-data": 42, "instances": []}}
)
- def test_get_3pe_protocols_multiple_protocol(self):
+ def test_get_3pe_protocols_multiple_protocol(self) -> None:
service_one = self._mkservice(False, ["my-protocol"])
service_two = self._mkservice(False, ["other-protocol"])
self.mock_store.get_app_services.return_value = [service_one, service_two]
@@ -237,11 +237,13 @@ class AppServiceHandlerTestCase(unittest.TestCase):
},
)
- def test_get_3pe_protocols_multiple_info(self):
+ def test_get_3pe_protocols_multiple_info(self) -> None:
service_one = self._mkservice(False, ["my-protocol"])
service_two = self._mkservice(False, ["my-protocol"])
- async def get_3pe_protocol(service, unusedProtocol):
+ async def get_3pe_protocol(
+ service: ApplicationService, protocol: str
+ ) -> Optional[JsonDict]:
if service == service_one:
return {
"x-protocol-data": 42,
@@ -276,7 +278,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
},
)
- def test_notify_interested_services_ephemeral(self):
+ def test_notify_interested_services_ephemeral(self) -> None:
"""
Test sending ephemeral events to the appservice handler are scheduled
to be pushed out to interested appservices, and that the stream ID is
@@ -306,7 +308,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
580,
)
- def test_notify_interested_services_ephemeral_out_of_order(self):
+ def test_notify_interested_services_ephemeral_out_of_order(self) -> None:
"""
Test sending out of order ephemeral events to the appservice handler
are ignored.
@@ -390,7 +392,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
receipts.register_servlets,
]
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.hs = hs
# Mock the ApplicationServiceScheduler's _TransactionController's send method so that
# we can track any outgoing ephemeral events
@@ -417,7 +419,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
"exclusive_as_user", "password", self.exclusive_as_user_device_id
)
- def _notify_interested_services(self):
+ def _notify_interested_services(self) -> None:
# This is normally set in `notify_interested_services` but we need to call the
# internal async version so the reactor gets pushed to completion.
self.hs.get_application_service_handler().current_max += 1
@@ -443,7 +445,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
)
def test_match_interesting_room_members(
self, interesting_user: str, should_notify: bool
- ):
+ ) -> None:
"""
Test to make sure that a interesting user (local or remote) in the room is
notified as expected when someone else in the room sends a message.
@@ -512,7 +514,9 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
else:
self.send_mock.assert_not_called()
- def test_application_services_receive_events_sent_by_interesting_local_user(self):
+ def test_application_services_receive_events_sent_by_interesting_local_user(
+ self,
+ ) -> None:
"""
Test to make sure that a messages sent from a local user can be interesting and
picked up by the appservice.
@@ -568,7 +572,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
self.assertEqual(events[0]["type"], "m.room.message")
self.assertEqual(events[0]["sender"], alice)
- def test_sending_read_receipt_batches_to_application_services(self):
+ def test_sending_read_receipt_batches_to_application_services(self) -> None:
"""Tests that a large batch of read receipts are sent correctly to
interested application services.
"""
@@ -644,7 +648,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
@unittest.override_config(
{"experimental_features": {"msc2409_to_device_messages_enabled": True}}
)
- def test_application_services_receive_local_to_device(self):
+ def test_application_services_receive_local_to_device(self) -> None:
"""
Test that when a user sends a to-device message to another user
that is an application service's user namespace, the
@@ -722,7 +726,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
@unittest.override_config(
{"experimental_features": {"msc2409_to_device_messages_enabled": True}}
)
- def test_application_services_receive_bursts_of_to_device(self):
+ def test_application_services_receive_bursts_of_to_device(self) -> None:
"""
Test that when a user sends >100 to-device messages at once, any
interested AS's will receive them in separate transactions.
@@ -913,7 +917,7 @@ class ApplicationServicesHandlerDeviceListsTestCase(unittest.HomeserverTestCase)
experimental_feature_enabled: bool,
as_supports_txn_extensions: bool,
as_should_receive_device_list_updates: bool,
- ):
+ ) -> None:
"""
Tests that an application service receives notice of changed device
lists for a user, when a user changes their device lists.
@@ -1070,7 +1074,7 @@ class ApplicationServicesHandlerOtkCountsTestCase(unittest.HomeserverTestCase):
and a room for the users to talk in.
"""
- async def preparation():
+ async def preparation() -> None:
await self._add_otks_for_device(self._sender_user, self._sender_device, 42)
await self._add_fallback_key_for_device(
self._sender_user, self._sender_device, used=True
diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py
index 2b21547d0f..2733719d82 100644
--- a/tests/handlers/test_cas.py
+++ b/tests/handlers/test_cas.py
@@ -199,7 +199,7 @@ class CasHandlerTestCase(HomeserverTestCase):
)
-def _mock_request():
+def _mock_request() -> Mock:
"""Returns a mock which will stand in as a SynapseRequest"""
mock = Mock(
spec=[
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index 3b72c4c9d0..90aec484c4 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -20,6 +20,7 @@ from twisted.test.proto_helpers import MemoryReactor
import synapse.api.errors
import synapse.rest.admin
from synapse.api.constants import EventTypes
+from synapse.events import EventBase
from synapse.rest.client import directory, login, room
from synapse.server import HomeServer
from synapse.types import JsonDict, RoomAlias, create_requester
@@ -201,7 +202,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
self.test_user_tok = self.login("user", "pass")
self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok)
- def _create_alias(self, user) -> None:
+ def _create_alias(self, user: str) -> None:
# Create a new alias to this room.
self.get_success(
self.store.create_room_alias_association(
@@ -324,7 +325,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
)
return room_alias
- def _set_canonical_alias(self, content) -> None:
+ def _set_canonical_alias(self, content: JsonDict) -> None:
"""Configure the canonical alias state on the room."""
self.helper.send_state(
self.room_id,
@@ -333,13 +334,15 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
tok=self.admin_user_tok,
)
- def _get_canonical_alias(self):
+ def _get_canonical_alias(self) -> EventBase:
"""Get the canonical alias state of the room."""
- return self.get_success(
+ result = self.get_success(
self._storage_controllers.state.get_current_state_event(
self.room_id, EventTypes.CanonicalAlias, ""
)
)
+ assert result is not None
+ return result
def test_remove_alias(self) -> None:
"""Removing an alias that is the canonical alias should remove it there too."""
@@ -349,8 +352,8 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
)
data = self._get_canonical_alias()
- self.assertEqual(data["content"]["alias"], self.test_alias)
- self.assertEqual(data["content"]["alt_aliases"], [self.test_alias])
+ self.assertEqual(data.content["alias"], self.test_alias)
+ self.assertEqual(data.content["alt_aliases"], [self.test_alias])
# Finally, delete the alias.
self.get_success(
@@ -360,8 +363,8 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
)
data = self._get_canonical_alias()
- self.assertNotIn("alias", data["content"])
- self.assertNotIn("alt_aliases", data["content"])
+ self.assertNotIn("alias", data.content)
+ self.assertNotIn("alt_aliases", data.content)
def test_remove_other_alias(self) -> None:
"""Removing an alias listed as in alt_aliases should remove it there too."""
@@ -378,9 +381,9 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
)
data = self._get_canonical_alias()
- self.assertEqual(data["content"]["alias"], self.test_alias)
+ self.assertEqual(data.content["alias"], self.test_alias)
self.assertEqual(
- data["content"]["alt_aliases"], [self.test_alias, other_test_alias]
+ data.content["alt_aliases"], [self.test_alias, other_test_alias]
)
# Delete the second alias.
@@ -391,8 +394,8 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
)
data = self._get_canonical_alias()
- self.assertEqual(data["content"]["alias"], self.test_alias)
- self.assertEqual(data["content"]["alt_aliases"], [self.test_alias])
+ self.assertEqual(data.content["alias"], self.test_alias)
+ self.assertEqual(data.content["alt_aliases"], [self.test_alias])
class TestCreateAliasACL(unittest.HomeserverTestCase):
diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py
index 9b7e7a8e9a..6c0b30de9e 100644
--- a/tests/handlers/test_e2e_room_keys.py
+++ b/tests/handlers/test_e2e_room_keys.py
@@ -17,7 +17,11 @@
import copy
from unittest import mock
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.api.errors import SynapseError
+from synapse.server import HomeServer
+from synapse.util import Clock
from tests import unittest
@@ -39,14 +43,14 @@ room_keys = {
class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
return self.setup_test_homeserver(replication_layer=mock.Mock())
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.handler = hs.get_e2e_room_keys_handler()
self.local_user = "@boris:" + hs.hostname
- def test_get_missing_current_version_info(self):
+ def test_get_missing_current_version_info(self) -> None:
"""Check that we get a 404 if we ask for info about the current version
if there is no version.
"""
@@ -56,7 +60,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
res = e.value.code
self.assertEqual(res, 404)
- def test_get_missing_version_info(self):
+ def test_get_missing_version_info(self) -> None:
"""Check that we get a 404 if we ask for info about a specific version
if it doesn't exist.
"""
@@ -67,9 +71,9 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
res = e.value.code
self.assertEqual(res, 404)
- def test_create_version(self):
+ def test_create_version(self) -> None:
"""Check that we can create and then retrieve versions."""
- res = self.get_success(
+ version = self.get_success(
self.handler.create_version(
self.local_user,
{
@@ -78,7 +82,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
},
)
)
- self.assertEqual(res, "1")
+ self.assertEqual(version, "1")
# check we can retrieve it as the current version
res = self.get_success(self.handler.get_version_info(self.local_user))
@@ -110,7 +114,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
)
# upload a new one...
- res = self.get_success(
+ version = self.get_success(
self.handler.create_version(
self.local_user,
{
@@ -119,7 +123,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
},
)
)
- self.assertEqual(res, "2")
+ self.assertEqual(version, "2")
# check we can retrieve it as the current version
res = self.get_success(self.handler.get_version_info(self.local_user))
@@ -134,7 +138,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
},
)
- def test_update_version(self):
+ def test_update_version(self) -> None:
"""Check that we can update versions."""
version = self.get_success(
self.handler.create_version(
@@ -173,7 +177,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
},
)
- def test_update_missing_version(self):
+ def test_update_missing_version(self) -> None:
"""Check that we get a 404 on updating nonexistent versions"""
e = self.get_failure(
self.handler.update_version(
@@ -190,7 +194,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
res = e.value.code
self.assertEqual(res, 404)
- def test_update_omitted_version(self):
+ def test_update_omitted_version(self) -> None:
"""Check that the update succeeds if the version is missing from the body"""
version = self.get_success(
self.handler.create_version(
@@ -227,7 +231,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
},
)
- def test_update_bad_version(self):
+ def test_update_bad_version(self) -> None:
"""Check that we get a 400 if the version in the body doesn't match"""
version = self.get_success(
self.handler.create_version(
@@ -255,7 +259,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
res = e.value.code
self.assertEqual(res, 400)
- def test_delete_missing_version(self):
+ def test_delete_missing_version(self) -> None:
"""Check that we get a 404 on deleting nonexistent versions"""
e = self.get_failure(
self.handler.delete_version(self.local_user, "1"), SynapseError
@@ -263,15 +267,15 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
res = e.value.code
self.assertEqual(res, 404)
- def test_delete_missing_current_version(self):
+ def test_delete_missing_current_version(self) -> None:
"""Check that we get a 404 on deleting nonexistent current version"""
e = self.get_failure(self.handler.delete_version(self.local_user), SynapseError)
res = e.value.code
self.assertEqual(res, 404)
- def test_delete_version(self):
+ def test_delete_version(self) -> None:
"""Check that we can create and then delete versions."""
- res = self.get_success(
+ version = self.get_success(
self.handler.create_version(
self.local_user,
{
@@ -280,7 +284,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
},
)
)
- self.assertEqual(res, "1")
+ self.assertEqual(version, "1")
# check we can delete it
self.get_success(self.handler.delete_version(self.local_user, "1"))
@@ -292,7 +296,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
res = e.value.code
self.assertEqual(res, 404)
- def test_get_missing_backup(self):
+ def test_get_missing_backup(self) -> None:
"""Check that we get a 404 on querying missing backup"""
e = self.get_failure(
self.handler.get_room_keys(self.local_user, "bogus_version"), SynapseError
@@ -300,7 +304,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
res = e.value.code
self.assertEqual(res, 404)
- def test_get_missing_room_keys(self):
+ def test_get_missing_room_keys(self) -> None:
"""Check we get an empty response from an empty backup"""
version = self.get_success(
self.handler.create_version(
@@ -319,7 +323,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
# TODO: test the locking semantics when uploading room_keys,
# although this is probably best done in sytest
- def test_upload_room_keys_no_versions(self):
+ def test_upload_room_keys_no_versions(self) -> None:
"""Check that we get a 404 on uploading keys when no versions are defined"""
e = self.get_failure(
self.handler.upload_room_keys(self.local_user, "no_version", room_keys),
@@ -328,7 +332,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
res = e.value.code
self.assertEqual(res, 404)
- def test_upload_room_keys_bogus_version(self):
+ def test_upload_room_keys_bogus_version(self) -> None:
"""Check that we get a 404 on uploading keys when an nonexistent version
is specified
"""
@@ -350,7 +354,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
res = e.value.code
self.assertEqual(res, 404)
- def test_upload_room_keys_wrong_version(self):
+ def test_upload_room_keys_wrong_version(self) -> None:
"""Check that we get a 403 on uploading keys for an old version"""
version = self.get_success(
self.handler.create_version(
@@ -380,7 +384,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
res = e.value.code
self.assertEqual(res, 403)
- def test_upload_room_keys_insert(self):
+ def test_upload_room_keys_insert(self) -> None:
"""Check that we can insert and retrieve keys for a session"""
version = self.get_success(
self.handler.create_version(
@@ -416,7 +420,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
)
self.assertDictEqual(res, room_keys)
- def test_upload_room_keys_merge(self):
+ def test_upload_room_keys_merge(self) -> None:
"""Check that we can upload a new room_key for an existing session and
have it correctly merged"""
version = self.get_success(
@@ -449,9 +453,11 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
self.handler.upload_room_keys(self.local_user, version, new_room_keys)
)
- res = self.get_success(self.handler.get_room_keys(self.local_user, version))
+ res_keys = self.get_success(
+ self.handler.get_room_keys(self.local_user, version)
+ )
self.assertEqual(
- res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"],
+ res_keys["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"],
"SSBBTSBBIEZJU0gK",
)
@@ -465,9 +471,12 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
self.handler.upload_room_keys(self.local_user, version, new_room_keys)
)
- res = self.get_success(self.handler.get_room_keys(self.local_user, version))
+ res_keys = self.get_success(
+ self.handler.get_room_keys(self.local_user, version)
+ )
self.assertEqual(
- res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new"
+ res_keys["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"],
+ "new",
)
# the etag should NOT be equal now, since the key changed
@@ -483,9 +492,12 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
self.handler.upload_room_keys(self.local_user, version, new_room_keys)
)
- res = self.get_success(self.handler.get_room_keys(self.local_user, version))
+ res_keys = self.get_success(
+ self.handler.get_room_keys(self.local_user, version)
+ )
self.assertEqual(
- res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new"
+ res_keys["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"],
+ "new",
)
# the etag should be the same since the session did not change
@@ -494,7 +506,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
# TODO: check edge cases as well as the common variations here
- def test_delete_room_keys(self):
+ def test_delete_room_keys(self) -> None:
"""Check that we can insert and delete keys for a session"""
version = self.get_success(
self.handler.create_version(
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index d00c69c229..cedbb9fafc 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -439,7 +439,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
user_id = self.register_user("kermit", "test")
tok = self.login("kermit", "test")
- def create_invite():
+ def create_invite() -> EventBase:
room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
room_version = self.get_success(self.store.get_room_version(room_id))
return event_from_pdu_json(
diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py
index e448cb1901..70ea4d15d4 100644
--- a/tests/handlers/test_federation_event.py
+++ b/tests/handlers/test_federation_event.py
@@ -14,6 +14,8 @@
from typing import Optional
from unittest import mock
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.api.errors import AuthError, StoreError
from synapse.api.room_versions import RoomVersion
from synapse.event_auth import (
@@ -26,8 +28,10 @@ from synapse.federation.transport.client import StateRequestResponse
from synapse.logging.context import LoggingContext
from synapse.rest import admin
from synapse.rest.client import login, room
+from synapse.server import HomeServer
from synapse.state.v2 import _mainline_sort, _reverse_topological_power_sort
from synapse.types import JsonDict
+from synapse.util import Clock
from tests import unittest
from tests.test_utils import event_injection, make_awaitable
@@ -40,7 +44,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
room.register_servlets,
]
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
# mock out the federation transport client
self.mock_federation_transport_client = mock.Mock(
spec=["get_room_state_ids", "get_room_state", "get_event", "backfill"]
@@ -165,7 +169,9 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
)
else:
- async def get_event(destination: str, event_id: str, timeout=None):
+ async def get_event(
+ destination: str, event_id: str, timeout: Optional[int] = None
+ ) -> JsonDict:
self.assertEqual(destination, self.OTHER_SERVER_NAME)
self.assertEqual(event_id, prev_event.event_id)
return {"pdus": [prev_event.get_pdu_json()]}
diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py
index 99384837d0..c4727ab917 100644
--- a/tests/handlers/test_message.py
+++ b/tests/handlers/test_message.py
@@ -14,12 +14,16 @@
import logging
from typing import Tuple
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.api.constants import EventTypes
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.rest import admin
from synapse.rest.client import login, room
+from synapse.server import HomeServer
from synapse.types import create_requester
+from synapse.util import Clock
from synapse.util.stringutils import random_string
from tests import unittest
@@ -35,7 +39,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
room.register_servlets,
]
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.handler = self.hs.get_event_creation_handler()
self._persist_event_storage_controller = (
self.hs.get_storage_controllers().persistence
@@ -94,7 +98,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
)
)
- def test_duplicated_txn_id(self):
+ def test_duplicated_txn_id(self) -> None:
"""Test that attempting to handle/persist an event with a transaction ID
that has already been persisted correctly returns the old event and does
*not* produce duplicate messages.
@@ -161,7 +165,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
# rather than the new one.
self.assertEqual(ret_event1.event_id, ret_event4.event_id)
- def test_duplicated_txn_id_one_call(self):
+ def test_duplicated_txn_id_one_call(self) -> None:
"""Test that we correctly handle duplicates that we try and persist at
the same time.
"""
@@ -185,7 +189,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(events), 2)
self.assertEqual(events[0].event_id, events[1].event_id)
- def test_when_empty_prev_events_allowed_create_event_with_empty_prev_events(self):
+ def test_when_empty_prev_events_allowed_create_event_with_empty_prev_events(
+ self,
+ ) -> None:
"""When we set allow_no_prev_events=True, should be able to create a
event without any prev_events (only auth_events).
"""
@@ -214,7 +220,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
def test_when_empty_prev_events_not_allowed_reject_event_with_empty_prev_events(
self,
- ):
+ ) -> None:
"""When we set allow_no_prev_events=False, shouldn't be able to create a
event without any prev_events even if it has auth_events. Expect an
exception to be raised.
@@ -245,7 +251,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
def test_when_empty_prev_events_allowed_reject_event_with_empty_prev_events_and_auth_events(
self,
- ):
+ ) -> None:
"""When we set allow_no_prev_events=True, should be able to create a
event without any prev_events or auth_events. Expect an exception to be
raised.
@@ -277,12 +283,12 @@ class ServerAclValidationTestCase(unittest.HomeserverTestCase):
room.register_servlets,
]
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.user_id = self.register_user("tester", "foobar")
self.access_token = self.login("tester", "foobar")
self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token)
- def test_allow_server_acl(self):
+ def test_allow_server_acl(self) -> None:
"""Test that sending an ACL that blocks everyone but ourselves works."""
self.helper.send_state(
@@ -293,7 +299,7 @@ class ServerAclValidationTestCase(unittest.HomeserverTestCase):
expect_code=200,
)
- def test_deny_server_acl_block_outselves(self):
+ def test_deny_server_acl_block_outselves(self) -> None:
"""Test that sending an ACL that blocks ourselves does not work."""
self.helper.send_state(
self.room_id,
@@ -303,7 +309,7 @@ class ServerAclValidationTestCase(unittest.HomeserverTestCase):
expect_code=400,
)
- def test_deny_redact_server_acl(self):
+ def test_deny_redact_server_acl(self) -> None:
"""Test that attempting to redact an ACL is blocked."""
body = self.helper.send_state(
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index 5955410524..49a1842b5c 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
-from typing import Any, Dict, Tuple
+from typing import Any, Awaitable, ContextManager, Dict, Optional, Tuple
from unittest.mock import ANY, Mock, patch
from urllib.parse import parse_qs, urlparse
@@ -23,7 +23,7 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.handlers.sso import MappingException
from synapse.http.site import SynapseRequest
from synapse.server import HomeServer
-from synapse.types import UserID
+from synapse.types import JsonDict, UserID
from synapse.util import Clock
from synapse.util.macaroons import get_value_from_macaroon
from synapse.util.stringutils import random_string
@@ -34,6 +34,10 @@ from tests.unittest import HomeserverTestCase, override_config
try:
import authlib # noqa: F401
+ from authlib.oidc.core import UserInfo
+ from authlib.oidc.discovery import OpenIDProviderMetadata
+
+ from synapse.handlers.oidc import Token, UserAttributeDict
HAS_OIDC = True
except ImportError:
@@ -70,29 +74,37 @@ EXPLICIT_ENDPOINT_CONFIG = {
class TestMappingProvider:
@staticmethod
- def parse_config(config):
- return
+ def parse_config(config: JsonDict) -> None:
+ return None
- def __init__(self, config):
+ def __init__(self, config: None):
pass
- def get_remote_user_id(self, userinfo):
+ def get_remote_user_id(self, userinfo: "UserInfo") -> str:
return userinfo["sub"]
- async def map_user_attributes(self, userinfo, token):
- return {"localpart": userinfo["username"], "display_name": None}
+ async def map_user_attributes(
+ self, userinfo: "UserInfo", token: "Token"
+ ) -> "UserAttributeDict":
+ # This is testing not providing the full map.
+ return {"localpart": userinfo["username"], "display_name": None} # type: ignore[typeddict-item]
# Do not include get_extra_attributes to test backwards compatibility paths.
class TestMappingProviderExtra(TestMappingProvider):
- async def get_extra_attributes(self, userinfo, token):
+ async def get_extra_attributes(
+ self, userinfo: "UserInfo", token: "Token"
+ ) -> JsonDict:
return {"phone": userinfo["phone"]}
class TestMappingProviderFailures(TestMappingProvider):
- async def map_user_attributes(self, userinfo, token, failures):
- return {
+ # Superclass is testing the legacy interface for map_user_attributes.
+ async def map_user_attributes( # type: ignore[override]
+ self, userinfo: "UserInfo", token: "Token", failures: int
+ ) -> "UserAttributeDict":
+ return { # type: ignore[typeddict-item]
"localpart": userinfo["username"] + (str(failures) if failures else ""),
"display_name": None,
}
@@ -161,13 +173,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.hs_patcher.stop()
return super().tearDown()
- def reset_mocks(self):
+ def reset_mocks(self) -> None:
"""Reset all the Mocks."""
self.fake_server.reset_mocks()
self.render_error.reset_mock()
self.complete_sso_login.reset_mock()
- def metadata_edit(self, values):
+ def metadata_edit(self, values: dict) -> ContextManager[Mock]:
"""Modify the result that will be returned by the well-known query"""
metadata = self.fake_server.get_metadata()
@@ -196,7 +208,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
session = self._generate_oidc_session_token(state, nonce, client_redirect_url)
return _build_callback_request(code, state, session), grant
- def assertRenderedError(self, error, error_description=None):
+ def assertRenderedError(
+ self, error: str, error_description: Optional[str] = None
+ ) -> Tuple[Any, ...]:
self.render_error.assert_called_once()
args = self.render_error.call_args[0]
self.assertEqual(args[1], error)
@@ -273,8 +287,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
"""Provider metadatas are extensively validated."""
h = self.provider
- def force_load_metadata():
- async def force_load():
+ def force_load_metadata() -> Awaitable[None]:
+ async def force_load() -> "OpenIDProviderMetadata":
return await h.load_metadata(force=True)
return get_awaitable_result(force_load())
@@ -1198,7 +1212,7 @@ def _build_callback_request(
state: str,
session: str,
ip_address: str = "10.0.0.1",
-):
+) -> Mock:
"""Builds a fake SynapseRequest to mock the browser callback
Returns a Mock object which looks like the SynapseRequest we get from a browser
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index 75934b1707..0916de64f5 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -15,12 +15,13 @@
"""Tests for the password_auth_provider interface"""
from http import HTTPStatus
-from typing import Any, Type, Union
+from typing import Any, Dict, List, Optional, Type, Union
from unittest.mock import Mock
import synapse
from synapse.api.constants import LoginType
from synapse.api.errors import Codes
+from synapse.handlers.account import AccountHandler
from synapse.module_api import ModuleApi
from synapse.rest.client import account, devices, login, logout, register
from synapse.types import JsonDict, UserID
@@ -44,13 +45,13 @@ class LegacyPasswordOnlyAuthProvider:
"""A legacy password_provider which only implements `check_password`."""
@staticmethod
- def parse_config(self):
+ def parse_config(config: JsonDict) -> None:
pass
- def __init__(self, config, account_handler):
+ def __init__(self, config: None, account_handler: AccountHandler):
pass
- def check_password(self, *args):
+ def check_password(self, *args: str) -> Mock:
return mock_password_provider.check_password(*args)
@@ -58,16 +59,16 @@ class LegacyCustomAuthProvider:
"""A legacy password_provider which implements a custom login type."""
@staticmethod
- def parse_config(self):
+ def parse_config(config: JsonDict) -> None:
pass
- def __init__(self, config, account_handler):
+ def __init__(self, config: None, account_handler: AccountHandler):
pass
- def get_supported_login_types(self):
+ def get_supported_login_types(self) -> Dict[str, List[str]]:
return {"test.login_type": ["test_field"]}
- def check_auth(self, *args):
+ def check_auth(self, *args: str) -> Mock:
return mock_password_provider.check_auth(*args)
@@ -75,15 +76,15 @@ class CustomAuthProvider:
"""A module which registers password_auth_provider callbacks for a custom login type."""
@staticmethod
- def parse_config(self):
+ def parse_config(config: JsonDict) -> None:
pass
- def __init__(self, config, api: ModuleApi):
+ def __init__(self, config: None, api: ModuleApi):
api.register_password_auth_provider_callbacks(
auth_checkers={("test.login_type", ("test_field",)): self.check_auth}
)
- def check_auth(self, *args):
+ def check_auth(self, *args: Any) -> Mock:
return mock_password_provider.check_auth(*args)
@@ -92,16 +93,16 @@ class LegacyPasswordCustomAuthProvider:
as a custom type."""
@staticmethod
- def parse_config(self):
+ def parse_config(config: JsonDict) -> None:
pass
- def __init__(self, config, account_handler):
+ def __init__(self, config: None, account_handler: AccountHandler):
pass
- def get_supported_login_types(self):
+ def get_supported_login_types(self) -> Dict[str, List[str]]:
return {"m.login.password": ["password"], "test.login_type": ["test_field"]}
- def check_auth(self, *args):
+ def check_auth(self, *args: str) -> Mock:
return mock_password_provider.check_auth(*args)
@@ -110,10 +111,10 @@ class PasswordCustomAuthProvider:
as well as a password login"""
@staticmethod
- def parse_config(self):
+ def parse_config(config: JsonDict) -> None:
pass
- def __init__(self, config, api: ModuleApi):
+ def __init__(self, config: None, api: ModuleApi):
api.register_password_auth_provider_callbacks(
auth_checkers={
("test.login_type", ("test_field",)): self.check_auth,
@@ -121,10 +122,10 @@ class PasswordCustomAuthProvider:
}
)
- def check_auth(self, *args):
+ def check_auth(self, *args: Any) -> Mock:
return mock_password_provider.check_auth(*args)
- def check_pass(self, *args):
+ def check_pass(self, *args: str) -> Mock:
return mock_password_provider.check_password(*args)
@@ -161,16 +162,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
CALLBACK_USERNAME = "get_username_for_registration"
CALLBACK_DISPLAYNAME = "get_displayname_for_registration"
- def setUp(self):
+ def setUp(self) -> None:
# we use a global mock device, so make sure we are starting with a clean slate
mock_password_provider.reset_mock()
super().setUp()
@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
- def test_password_only_auth_progiver_login_legacy(self):
+ def test_password_only_auth_progiver_login_legacy(self) -> None:
self.password_only_auth_provider_login_test_body()
- def password_only_auth_provider_login_test_body(self):
+ def password_only_auth_provider_login_test_body(self) -> None:
# login flows should only have m.login.password
flows = self._get_login_flows()
self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS)
@@ -201,10 +202,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
)
@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
- def test_password_only_auth_provider_ui_auth_legacy(self):
+ def test_password_only_auth_provider_ui_auth_legacy(self) -> None:
self.password_only_auth_provider_ui_auth_test_body()
- def password_only_auth_provider_ui_auth_test_body(self):
+ def password_only_auth_provider_ui_auth_test_body(self) -> None:
"""UI Auth should delegate correctly to the password provider"""
# create the user, otherwise access doesn't work
@@ -238,10 +239,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
- def test_local_user_fallback_login_legacy(self):
+ def test_local_user_fallback_login_legacy(self) -> None:
self.local_user_fallback_login_test_body()
- def local_user_fallback_login_test_body(self):
+ def local_user_fallback_login_test_body(self) -> None:
"""rejected login should fall back to local db"""
self.register_user("localuser", "localpass")
@@ -255,10 +256,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.assertEqual("@localuser:test", channel.json_body["user_id"])
@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
- def test_local_user_fallback_ui_auth_legacy(self):
+ def test_local_user_fallback_ui_auth_legacy(self) -> None:
self.local_user_fallback_ui_auth_test_body()
- def local_user_fallback_ui_auth_test_body(self):
+ def local_user_fallback_ui_auth_test_body(self) -> None:
"""rejected login should fall back to local db"""
self.register_user("localuser", "localpass")
@@ -298,10 +299,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"password_config": {"localdb_enabled": False},
}
)
- def test_no_local_user_fallback_login_legacy(self):
+ def test_no_local_user_fallback_login_legacy(self) -> None:
self.no_local_user_fallback_login_test_body()
- def no_local_user_fallback_login_test_body(self):
+ def no_local_user_fallback_login_test_body(self) -> None:
"""localdb_enabled can block login with the local password"""
self.register_user("localuser", "localpass")
@@ -320,10 +321,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"password_config": {"localdb_enabled": False},
}
)
- def test_no_local_user_fallback_ui_auth_legacy(self):
+ def test_no_local_user_fallback_ui_auth_legacy(self) -> None:
self.no_local_user_fallback_ui_auth_test_body()
- def no_local_user_fallback_ui_auth_test_body(self):
+ def no_local_user_fallback_ui_auth_test_body(self) -> None:
"""localdb_enabled can block ui auth with the local password"""
self.register_user("localuser", "localpass")
@@ -361,10 +362,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"password_config": {"enabled": False},
}
)
- def test_password_auth_disabled_legacy(self):
+ def test_password_auth_disabled_legacy(self) -> None:
self.password_auth_disabled_test_body()
- def password_auth_disabled_test_body(self):
+ def password_auth_disabled_test_body(self) -> None:
"""password auth doesn't work if it's disabled across the board"""
# login flows should be empty
flows = self._get_login_flows()
@@ -376,14 +377,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
mock_password_provider.check_password.assert_not_called()
@override_config(legacy_providers_config(LegacyCustomAuthProvider))
- def test_custom_auth_provider_login_legacy(self):
+ def test_custom_auth_provider_login_legacy(self) -> None:
self.custom_auth_provider_login_test_body()
@override_config(providers_config(CustomAuthProvider))
- def test_custom_auth_provider_login(self):
+ def test_custom_auth_provider_login(self) -> None:
self.custom_auth_provider_login_test_body()
- def custom_auth_provider_login_test_body(self):
+ def custom_auth_provider_login_test_body(self) -> None:
# login flows should have the custom flow and m.login.password, since we
# haven't disabled local password lookup.
# (password must come first, because reasons)
@@ -424,14 +425,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
)
@override_config(legacy_providers_config(LegacyCustomAuthProvider))
- def test_custom_auth_provider_ui_auth_legacy(self):
+ def test_custom_auth_provider_ui_auth_legacy(self) -> None:
self.custom_auth_provider_ui_auth_test_body()
@override_config(providers_config(CustomAuthProvider))
- def test_custom_auth_provider_ui_auth(self):
+ def test_custom_auth_provider_ui_auth(self) -> None:
self.custom_auth_provider_ui_auth_test_body()
- def custom_auth_provider_ui_auth_test_body(self):
+ def custom_auth_provider_ui_auth_test_body(self) -> None:
# register the user and log in twice, to get two devices
self.register_user("localuser", "localpass")
tok1 = self.login("localuser", "localpass")
@@ -486,14 +487,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
)
@override_config(legacy_providers_config(LegacyCustomAuthProvider))
- def test_custom_auth_provider_callback_legacy(self):
+ def test_custom_auth_provider_callback_legacy(self) -> None:
self.custom_auth_provider_callback_test_body()
@override_config(providers_config(CustomAuthProvider))
- def test_custom_auth_provider_callback(self):
+ def test_custom_auth_provider_callback(self) -> None:
self.custom_auth_provider_callback_test_body()
- def custom_auth_provider_callback_test_body(self):
+ def custom_auth_provider_callback_test_body(self) -> None:
callback = Mock(return_value=make_awaitable(None))
mock_password_provider.check_auth.return_value = make_awaitable(
@@ -521,16 +522,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"password_config": {"enabled": False},
}
)
- def test_custom_auth_password_disabled_legacy(self):
+ def test_custom_auth_password_disabled_legacy(self) -> None:
self.custom_auth_password_disabled_test_body()
@override_config(
{**providers_config(CustomAuthProvider), "password_config": {"enabled": False}}
)
- def test_custom_auth_password_disabled(self):
+ def test_custom_auth_password_disabled(self) -> None:
self.custom_auth_password_disabled_test_body()
- def custom_auth_password_disabled_test_body(self):
+ def custom_auth_password_disabled_test_body(self) -> None:
"""Test login with a custom auth provider where password login is disabled"""
self.register_user("localuser", "localpass")
@@ -548,7 +549,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"password_config": {"enabled": False, "localdb_enabled": False},
}
)
- def test_custom_auth_password_disabled_localdb_enabled_legacy(self):
+ def test_custom_auth_password_disabled_localdb_enabled_legacy(self) -> None:
self.custom_auth_password_disabled_localdb_enabled_test_body()
@override_config(
@@ -557,10 +558,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"password_config": {"enabled": False, "localdb_enabled": False},
}
)
- def test_custom_auth_password_disabled_localdb_enabled(self):
+ def test_custom_auth_password_disabled_localdb_enabled(self) -> None:
self.custom_auth_password_disabled_localdb_enabled_test_body()
- def custom_auth_password_disabled_localdb_enabled_test_body(self):
+ def custom_auth_password_disabled_localdb_enabled_test_body(self) -> None:
"""Check the localdb_enabled == enabled == False
Regression test for https://github.com/matrix-org/synapse/issues/8914: check
@@ -583,7 +584,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"password_config": {"enabled": False},
}
)
- def test_password_custom_auth_password_disabled_login_legacy(self):
+ def test_password_custom_auth_password_disabled_login_legacy(self) -> None:
self.password_custom_auth_password_disabled_login_test_body()
@override_config(
@@ -592,10 +593,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"password_config": {"enabled": False},
}
)
- def test_password_custom_auth_password_disabled_login(self):
+ def test_password_custom_auth_password_disabled_login(self) -> None:
self.password_custom_auth_password_disabled_login_test_body()
- def password_custom_auth_password_disabled_login_test_body(self):
+ def password_custom_auth_password_disabled_login_test_body(self) -> None:
"""log in with a custom auth provider which implements password, but password
login is disabled"""
self.register_user("localuser", "localpass")
@@ -615,7 +616,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"password_config": {"enabled": False},
}
)
- def test_password_custom_auth_password_disabled_ui_auth_legacy(self):
+ def test_password_custom_auth_password_disabled_ui_auth_legacy(self) -> None:
self.password_custom_auth_password_disabled_ui_auth_test_body()
@override_config(
@@ -624,10 +625,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"password_config": {"enabled": False},
}
)
- def test_password_custom_auth_password_disabled_ui_auth(self):
+ def test_password_custom_auth_password_disabled_ui_auth(self) -> None:
self.password_custom_auth_password_disabled_ui_auth_test_body()
- def password_custom_auth_password_disabled_ui_auth_test_body(self):
+ def password_custom_auth_password_disabled_ui_auth_test_body(self) -> None:
"""UI Auth with a custom auth provider which implements password, but password
login is disabled"""
# register the user and log in twice via the test login type to get two devices,
@@ -689,7 +690,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"password_config": {"localdb_enabled": False},
}
)
- def test_custom_auth_no_local_user_fallback_legacy(self):
+ def test_custom_auth_no_local_user_fallback_legacy(self) -> None:
self.custom_auth_no_local_user_fallback_test_body()
@override_config(
@@ -698,10 +699,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"password_config": {"localdb_enabled": False},
}
)
- def test_custom_auth_no_local_user_fallback(self):
+ def test_custom_auth_no_local_user_fallback(self) -> None:
self.custom_auth_no_local_user_fallback_test_body()
- def custom_auth_no_local_user_fallback_test_body(self):
+ def custom_auth_no_local_user_fallback_test_body(self) -> None:
"""Test login with a custom auth provider where the local db is disabled"""
self.register_user("localuser", "localpass")
@@ -713,14 +714,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
channel = self._send_password_login("localuser", "localpass")
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
- def test_on_logged_out(self):
+ def test_on_logged_out(self) -> None:
"""Tests that the on_logged_out callback is called when the user logs out."""
self.register_user("rin", "password")
tok = self.login("rin", "password")
self.called = False
- async def on_logged_out(user_id, device_id, access_token):
+ async def on_logged_out(
+ user_id: str, device_id: Optional[str], access_token: str
+ ) -> None:
self.called = True
on_logged_out = Mock(side_effect=on_logged_out)
@@ -738,7 +741,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
on_logged_out.assert_called_once()
self.assertTrue(self.called)
- def test_username(self):
+ def test_username(self) -> None:
"""Tests that the get_username_for_registration callback can define the username
of a user when registering.
"""
@@ -763,7 +766,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
mxid = channel.json_body["user_id"]
self.assertEqual(UserID.from_string(mxid).localpart, username + "-foo")
- def test_username_uia(self):
+ def test_username_uia(self) -> None:
"""Tests that the get_username_for_registration callback is only called at the
end of the UIA flow.
"""
@@ -782,7 +785,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# Set some email configuration so the test doesn't fail because of its absence.
@override_config({"email": {"notif_from": "noreply@test"}})
- def test_3pid_allowed(self):
+ def test_3pid_allowed(self) -> None:
"""Tests that an is_3pid_allowed_callbacks forbidding a 3PID makes Synapse refuse
to bind the new 3PID, and that one allowing a 3PID makes Synapse accept to bind
the 3PID. Also checks that the module is passed a boolean indicating whether the
@@ -791,7 +794,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self._test_3pid_allowed("rin", False)
self._test_3pid_allowed("kitay", True)
- def test_displayname(self):
+ def test_displayname(self) -> None:
"""Tests that the get_displayname_for_registration callback can define the
display name of a user when registering.
"""
@@ -820,7 +823,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.assertEqual(display_name, username + "-foo")
- def test_displayname_uia(self):
+ def test_displayname_uia(self) -> None:
"""Tests that the get_displayname_for_registration callback is only called at the
end of the UIA flow.
"""
@@ -841,7 +844,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# Check that the callback has been called.
m.assert_called_once()
- def _test_3pid_allowed(self, username: str, registration: bool):
+ def _test_3pid_allowed(self, username: str, registration: bool) -> None:
"""Tests that the "is_3pid_allowed" module callback is called correctly, using
either /register or /account URLs depending on the arguments.
@@ -907,7 +910,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
client is trying to register.
"""
- async def callback(uia_results, params):
+ async def callback(uia_results: JsonDict, params: JsonDict) -> str:
self.assertIn(LoginType.DUMMY, uia_results)
username = params["username"]
return username + "-foo"
@@ -950,12 +953,13 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
def _send_password_login(self, user: str, password: str) -> FakeChannel:
return self._send_login(type="m.login.password", user=user, password=password)
- def _send_login(self, type, user, **params) -> FakeChannel:
- params.update({"identifier": {"type": "m.id.user", "user": user}, "type": type})
+ def _send_login(self, type: str, user: str, **extra_params: str) -> FakeChannel:
+ params = {"identifier": {"type": "m.id.user", "user": user}, "type": type}
+ params.update(extra_params)
channel = self.make_request("POST", "/_matrix/client/r0/login", params)
return channel
- def _start_delete_device_session(self, access_token, device_id) -> str:
+ def _start_delete_device_session(self, access_token: str, device_id: str) -> str:
"""Make an initial delete device request, and return the UI Auth session ID"""
channel = self._delete_device(access_token, device_id)
self.assertEqual(channel.code, 401)
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index 584e7b8971..19f5322317 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -12,12 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Optional
+from typing import Optional, cast
from unittest.mock import Mock, call
from parameterized import parameterized
from signedjson.key import generate_signing_key
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.api.constants import EventTypes, Membership, PresenceState
from synapse.api.presence import UserPresenceState
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
@@ -35,7 +37,9 @@ from synapse.handlers.presence import (
)
from synapse.rest import admin
from synapse.rest.client import room
-from synapse.types import UserID, get_domain_from_id
+from synapse.server import HomeServer
+from synapse.types import JsonDict, UserID, get_domain_from_id
+from synapse.util import Clock
from tests import unittest
from tests.replication._base import BaseMultiWorkerStreamTestCase
@@ -44,10 +48,12 @@ from tests.replication._base import BaseMultiWorkerStreamTestCase
class PresenceUpdateTestCase(unittest.HomeserverTestCase):
servlets = [admin.register_servlets]
- def prepare(self, reactor, clock, homeserver):
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+ ) -> None:
self.store = homeserver.get_datastores().main
- def test_offline_to_online(self):
+ def test_offline_to_online(self) -> None:
wheel_timer = Mock()
user_id = "@foo:bar"
now = 5000000
@@ -85,7 +91,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
any_order=True,
)
- def test_online_to_online(self):
+ def test_online_to_online(self) -> None:
wheel_timer = Mock()
user_id = "@foo:bar"
now = 5000000
@@ -128,7 +134,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
any_order=True,
)
- def test_online_to_online_last_active_noop(self):
+ def test_online_to_online_last_active_noop(self) -> None:
wheel_timer = Mock()
user_id = "@foo:bar"
now = 5000000
@@ -173,7 +179,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
any_order=True,
)
- def test_online_to_online_last_active(self):
+ def test_online_to_online_last_active(self) -> None:
wheel_timer = Mock()
user_id = "@foo:bar"
now = 5000000
@@ -210,7 +216,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
any_order=True,
)
- def test_remote_ping_timer(self):
+ def test_remote_ping_timer(self) -> None:
wheel_timer = Mock()
user_id = "@foo:bar"
now = 5000000
@@ -244,7 +250,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
any_order=True,
)
- def test_online_to_offline(self):
+ def test_online_to_offline(self) -> None:
wheel_timer = Mock()
user_id = "@foo:bar"
now = 5000000
@@ -266,7 +272,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
self.assertEqual(wheel_timer.insert.call_count, 0)
- def test_online_to_idle(self):
+ def test_online_to_idle(self) -> None:
wheel_timer = Mock()
user_id = "@foo:bar"
now = 5000000
@@ -300,7 +306,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
any_order=True,
)
- def test_persisting_presence_updates(self):
+ def test_persisting_presence_updates(self) -> None:
"""Tests that the latest presence state for each user is persisted correctly"""
# Create some test users and presence states for them
presence_states = []
@@ -322,7 +328,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
self.get_success(self.store.update_presence(presence_states))
# Check that each update is present in the database
- db_presence_states = self.get_success(
+ db_presence_states_raw = self.get_success(
self.store.get_all_presence_updates(
instance_name="master",
last_id=0,
@@ -332,7 +338,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
)
# Extract presence update user ID and state information into lists of tuples
- db_presence_states = [(ps[0], ps[1]) for _, ps in db_presence_states[0]]
+ db_presence_states = [(ps[0], ps[1]) for _, ps in db_presence_states_raw[0]]
presence_states_compare = [(ps.user_id, ps.state) for ps in presence_states]
# Compare what we put into the storage with what we got out.
@@ -343,7 +349,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
class PresenceTimeoutTestCase(unittest.TestCase):
"""Tests different timers and that the timer does not change `status_msg` of user."""
- def test_idle_timer(self):
+ def test_idle_timer(self) -> None:
user_id = "@foo:bar"
status_msg = "I'm here!"
now = 5000000
@@ -363,7 +369,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
self.assertEqual(new_state.state, PresenceState.UNAVAILABLE)
self.assertEqual(new_state.status_msg, status_msg)
- def test_busy_no_idle(self):
+ def test_busy_no_idle(self) -> None:
"""
Tests that a user setting their presence to busy but idling doesn't turn their
presence state into unavailable.
@@ -387,7 +393,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
self.assertEqual(new_state.state, PresenceState.BUSY)
self.assertEqual(new_state.status_msg, status_msg)
- def test_sync_timeout(self):
+ def test_sync_timeout(self) -> None:
user_id = "@foo:bar"
status_msg = "I'm here!"
now = 5000000
@@ -407,7 +413,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
self.assertEqual(new_state.state, PresenceState.OFFLINE)
self.assertEqual(new_state.status_msg, status_msg)
- def test_sync_online(self):
+ def test_sync_online(self) -> None:
user_id = "@foo:bar"
status_msg = "I'm here!"
now = 5000000
@@ -429,7 +435,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
self.assertEqual(new_state.state, PresenceState.ONLINE)
self.assertEqual(new_state.status_msg, status_msg)
- def test_federation_ping(self):
+ def test_federation_ping(self) -> None:
user_id = "@foo:bar"
status_msg = "I'm here!"
now = 5000000
@@ -448,7 +454,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
self.assertIsNotNone(new_state)
self.assertEqual(state, new_state)
- def test_no_timeout(self):
+ def test_no_timeout(self) -> None:
user_id = "@foo:bar"
now = 5000000
@@ -464,7 +470,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
self.assertIsNone(new_state)
- def test_federation_timeout(self):
+ def test_federation_timeout(self) -> None:
user_id = "@foo:bar"
status_msg = "I'm here!"
now = 5000000
@@ -487,7 +493,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
self.assertEqual(new_state.state, PresenceState.OFFLINE)
self.assertEqual(new_state.status_msg, status_msg)
- def test_last_active(self):
+ def test_last_active(self) -> None:
user_id = "@foo:bar"
status_msg = "I'm here!"
now = 5000000
@@ -508,15 +514,15 @@ class PresenceTimeoutTestCase(unittest.TestCase):
class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.presence_handler = hs.get_presence_handler()
self.clock = hs.get_clock()
- def test_external_process_timeout(self):
+ def test_external_process_timeout(self) -> None:
"""Test that if an external process doesn't update the records for a while
we time out their syncing users presence.
"""
- process_id = 1
+ process_id = "1"
user_id = "@test:server"
# Notify handler that a user is now syncing.
@@ -544,7 +550,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
)
self.assertEqual(state.state, PresenceState.OFFLINE)
- def test_user_goes_offline_by_timeout_status_msg_remain(self):
+ def test_user_goes_offline_by_timeout_status_msg_remain(self) -> None:
"""Test that if a user doesn't update the records for a while
users presence goes `OFFLINE` because of timeout and `status_msg` remains.
"""
@@ -576,7 +582,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
self.assertEqual(state.state, PresenceState.OFFLINE)
self.assertEqual(state.status_msg, status_msg)
- def test_user_goes_offline_manually_with_no_status_msg(self):
+ def test_user_goes_offline_manually_with_no_status_msg(self) -> None:
"""Test that if a user change presence manually to `OFFLINE`
and no status is set, that `status_msg` is `None`.
"""
@@ -601,7 +607,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
self.assertEqual(state.state, PresenceState.OFFLINE)
self.assertEqual(state.status_msg, None)
- def test_user_goes_offline_manually_with_status_msg(self):
+ def test_user_goes_offline_manually_with_status_msg(self) -> None:
"""Test that if a user change presence manually to `OFFLINE`
and a status is set, that `status_msg` appears.
"""
@@ -618,7 +624,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
user_id, PresenceState.OFFLINE, "And now here."
)
- def test_user_reset_online_with_no_status(self):
+ def test_user_reset_online_with_no_status(self) -> None:
"""Test that if a user set again the presence manually
and no status is set, that `status_msg` is `None`.
"""
@@ -644,7 +650,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
self.assertEqual(state.state, PresenceState.ONLINE)
self.assertEqual(state.status_msg, None)
- def test_set_presence_with_status_msg_none(self):
+ def test_set_presence_with_status_msg_none(self) -> None:
"""Test that if a user set again the presence manually
and status is `None`, that `status_msg` is `None`.
"""
@@ -659,7 +665,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
# Mark user as online and `status_msg = None`
self._set_presencestate_with_status_msg(user_id, PresenceState.ONLINE, None)
- def test_set_presence_from_syncing_not_set(self):
+ def test_set_presence_from_syncing_not_set(self) -> None:
"""Test that presence is not set by syncing if affect_presence is false"""
user_id = "@test:server"
status_msg = "I'm here!"
@@ -680,7 +686,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
# and status message should still be the same
self.assertEqual(state.status_msg, status_msg)
- def test_set_presence_from_syncing_is_set(self):
+ def test_set_presence_from_syncing_is_set(self) -> None:
"""Test that presence is set by syncing if affect_presence is true"""
user_id = "@test:server"
status_msg = "I'm here!"
@@ -699,7 +705,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
# we should now be online
self.assertEqual(state.state, PresenceState.ONLINE)
- def test_set_presence_from_syncing_keeps_status(self):
+ def test_set_presence_from_syncing_keeps_status(self) -> None:
"""Test that presence set by syncing retains status message"""
user_id = "@test:server"
status_msg = "I'm here!"
@@ -726,7 +732,9 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
},
}
)
- def test_set_presence_from_syncing_keeps_busy(self, test_with_workers: bool):
+ def test_set_presence_from_syncing_keeps_busy(
+ self, test_with_workers: bool
+ ) -> None:
"""Test that presence set by syncing doesn't affect busy status
Args:
@@ -767,7 +775,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
def _set_presencestate_with_status_msg(
self, user_id: str, state: str, status_msg: Optional[str]
- ):
+ ) -> None:
"""Set a PresenceState and status_msg and check the result.
Args:
@@ -790,14 +798,14 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
class PresenceFederationQueueTestCase(unittest.HomeserverTestCase):
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.presence_handler = hs.get_presence_handler()
self.clock = hs.get_clock()
self.instance_name = hs.get_instance_name()
self.queue = self.presence_handler.get_federation_queue()
- def test_send_and_get(self):
+ def test_send_and_get(self) -> None:
state1 = UserPresenceState.default("@user1:test")
state2 = UserPresenceState.default("@user2:test")
state3 = UserPresenceState.default("@user3:test")
@@ -834,7 +842,7 @@ class PresenceFederationQueueTestCase(unittest.HomeserverTestCase):
self.assertFalse(limited)
self.assertCountEqual(rows, [])
- def test_send_and_get_split(self):
+ def test_send_and_get_split(self) -> None:
state1 = UserPresenceState.default("@user1:test")
state2 = UserPresenceState.default("@user2:test")
state3 = UserPresenceState.default("@user3:test")
@@ -877,7 +885,7 @@ class PresenceFederationQueueTestCase(unittest.HomeserverTestCase):
self.assertCountEqual(rows, expected_rows)
- def test_clear_queue_all(self):
+ def test_clear_queue_all(self) -> None:
state1 = UserPresenceState.default("@user1:test")
state2 = UserPresenceState.default("@user2:test")
state3 = UserPresenceState.default("@user3:test")
@@ -921,7 +929,7 @@ class PresenceFederationQueueTestCase(unittest.HomeserverTestCase):
self.assertCountEqual(rows, expected_rows)
- def test_partially_clear_queue(self):
+ def test_partially_clear_queue(self) -> None:
state1 = UserPresenceState.default("@user1:test")
state2 = UserPresenceState.default("@user2:test")
state3 = UserPresenceState.default("@user3:test")
@@ -982,7 +990,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
servlets = [room.register_servlets]
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver(
"server",
federation_http_client=None,
@@ -990,14 +998,14 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
)
return hs
- def default_config(self):
+ def default_config(self) -> JsonDict:
config = super().default_config()
# Enable federation sending on the main process.
config["federation_sender_instances"] = None
return config
- def prepare(self, reactor, clock, hs):
- self.federation_sender = hs.get_federation_sender()
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.federation_sender = cast(Mock, hs.get_federation_sender())
self.event_builder_factory = hs.get_event_builder_factory()
self.federation_event_handler = hs.get_federation_event_handler()
self.presence_handler = hs.get_presence_handler()
@@ -1013,7 +1021,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
# random key to use.
self.random_signing_key = generate_signing_key("ver")
- def test_remote_joins(self):
+ def test_remote_joins(self) -> None:
# We advance time to something that isn't 0, as we use 0 as a special
# value.
self.reactor.advance(1000000000000)
@@ -1061,7 +1069,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
destinations={"server3"}, states=[expected_state]
)
- def test_remote_gets_presence_when_local_user_joins(self):
+ def test_remote_gets_presence_when_local_user_joins(self) -> None:
# We advance time to something that isn't 0, as we use 0 as a special
# value.
self.reactor.advance(1000000000000)
@@ -1110,7 +1118,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
destinations={"server2", "server3"}, states=[expected_state]
)
- def _add_new_user(self, room_id, user_id):
+ def _add_new_user(self, room_id: str, user_id: str) -> None:
"""Add new user to the room by creating an event and poking the federation API."""
hostname = get_domain_from_id(user_id)
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 675aa023ac..7c174782da 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -332,7 +332,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
@unittest.override_config(
{"server_name": "test:8888", "allowed_avatar_mimetypes": ["image/png"]}
)
- def test_avatar_constraint_on_local_server_with_port(self):
+ def test_avatar_constraint_on_local_server_with_port(self) -> None:
"""Test that avatar metadata is correctly fetched when the media is on a local
server and the server has an explicit port.
@@ -376,7 +376,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.get_success(self.handler.check_avatar_size_and_mime_type(remote_mxc))
)
- def _setup_local_files(self, names_and_props: Dict[str, Dict[str, Any]]):
+ def _setup_local_files(self, names_and_props: Dict[str, Dict[str, Any]]) -> None:
"""Stores metadata about files in the database.
Args:
diff --git a/tests/handlers/test_receipts.py b/tests/handlers/test_receipts.py
index b55238650c..f60400ff8d 100644
--- a/tests/handlers/test_receipts.py
+++ b/tests/handlers/test_receipts.py
@@ -15,14 +15,18 @@
from copy import deepcopy
from typing import List
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.api.constants import EduTypes, ReceiptTypes
+from synapse.server import HomeServer
from synapse.types import JsonDict
+from synapse.util import Clock
from tests import unittest
class ReceiptsTestCase(unittest.HomeserverTestCase):
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.event_source = hs.get_event_sources().sources.receipt
def test_filters_out_private_receipt(self) -> None:
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index 765df75d91..b9332d97dc 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -12,8 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Any, Collection, List, Optional, Tuple
from unittest.mock import Mock
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.api.auth import Auth
from synapse.api.constants import UserTypes
from synapse.api.errors import (
@@ -22,8 +25,18 @@ from synapse.api.errors import (
ResourceLimitError,
SynapseError,
)
+from synapse.module_api import ModuleApi
+from synapse.server import HomeServer
from synapse.spam_checker_api import RegistrationBehaviour
-from synapse.types import RoomAlias, RoomID, UserID, create_requester
+from synapse.types import (
+ JsonDict,
+ Requester,
+ RoomAlias,
+ RoomID,
+ UserID,
+ create_requester,
+)
+from synapse.util import Clock
from tests.test_utils import make_awaitable
from tests.unittest import override_config
@@ -33,94 +46,98 @@ from .. import unittest
class TestSpamChecker:
- def __init__(self, config, api):
+ def __init__(self, config: None, api: ModuleApi):
api.register_spam_checker_callbacks(
check_registration_for_spam=self.check_registration_for_spam,
)
@staticmethod
- def parse_config(config):
- return config
+ def parse_config(config: JsonDict) -> None:
+ return None
async def check_registration_for_spam(
self,
- email_threepid,
- username,
- request_info,
- auth_provider_id,
- ):
+ email_threepid: Optional[dict],
+ username: Optional[str],
+ request_info: Collection[Tuple[str, str]],
+ auth_provider_id: Optional[str],
+ ) -> RegistrationBehaviour:
pass
class DenyAll(TestSpamChecker):
async def check_registration_for_spam(
self,
- email_threepid,
- username,
- request_info,
- auth_provider_id,
- ):
+ email_threepid: Optional[dict],
+ username: Optional[str],
+ request_info: Collection[Tuple[str, str]],
+ auth_provider_id: Optional[str],
+ ) -> RegistrationBehaviour:
return RegistrationBehaviour.DENY
class BanAll(TestSpamChecker):
async def check_registration_for_spam(
self,
- email_threepid,
- username,
- request_info,
- auth_provider_id,
- ):
+ email_threepid: Optional[dict],
+ username: Optional[str],
+ request_info: Collection[Tuple[str, str]],
+ auth_provider_id: Optional[str],
+ ) -> RegistrationBehaviour:
return RegistrationBehaviour.SHADOW_BAN
class BanBadIdPUser(TestSpamChecker):
async def check_registration_for_spam(
- self, email_threepid, username, request_info, auth_provider_id=None
- ):
+ self,
+ email_threepid: Optional[dict],
+ username: Optional[str],
+ request_info: Collection[Tuple[str, str]],
+ auth_provider_id: Optional[str] = None,
+ ) -> RegistrationBehaviour:
# Reject any user coming from CAS and whose username contains profanity
- if auth_provider_id == "cas" and "flimflob" in username:
+ if auth_provider_id == "cas" and username and "flimflob" in username:
return RegistrationBehaviour.DENY
return RegistrationBehaviour.ALLOW
class TestLegacyRegistrationSpamChecker:
- def __init__(self, config, api):
+ def __init__(self, config: None, api: ModuleApi):
pass
async def check_registration_for_spam(
self,
- email_threepid,
- username,
- request_info,
- ):
+ email_threepid: Optional[dict],
+ username: Optional[str],
+ request_info: Collection[Tuple[str, str]],
+ ) -> RegistrationBehaviour:
pass
class LegacyAllowAll(TestLegacyRegistrationSpamChecker):
async def check_registration_for_spam(
self,
- email_threepid,
- username,
- request_info,
- ):
+ email_threepid: Optional[dict],
+ username: Optional[str],
+ request_info: Collection[Tuple[str, str]],
+ ) -> RegistrationBehaviour:
return RegistrationBehaviour.ALLOW
class LegacyDenyAll(TestLegacyRegistrationSpamChecker):
async def check_registration_for_spam(
self,
- email_threepid,
- username,
- request_info,
- ):
+ email_threepid: Optional[dict],
+ username: Optional[str],
+ request_info: Collection[Tuple[str, str]],
+ ) -> RegistrationBehaviour:
return RegistrationBehaviour.DENY
class RegistrationTestCase(unittest.HomeserverTestCase):
"""Tests the RegistrationHandler."""
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs_config = self.default_config()
# some of the tests rely on us having a user consent version
@@ -145,7 +162,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
return hs
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.handler = self.hs.get_registration_handler()
self.store = self.hs.get_datastores().main
self.lots_of_users = 100
@@ -153,7 +170,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.requester = create_requester("@requester:test")
- def test_user_is_created_and_logged_in_if_doesnt_exist(self):
+ def test_user_is_created_and_logged_in_if_doesnt_exist(self) -> None:
frank = UserID.from_string("@frank:test")
user_id = frank.to_string()
requester = create_requester(user_id)
@@ -164,7 +181,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.assertIsInstance(result_token, str)
self.assertGreater(len(result_token), 20)
- def test_if_user_exists(self):
+ def test_if_user_exists(self) -> None:
store = self.hs.get_datastores().main
frank = UserID.from_string("@frank:test")
self.get_success(
@@ -180,12 +197,12 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.assertTrue(result_token is not None)
@override_config({"limit_usage_by_mau": False})
- def test_mau_limits_when_disabled(self):
+ def test_mau_limits_when_disabled(self) -> None:
# Ensure does not throw exception
self.get_success(self.get_or_create_user(self.requester, "a", "display_name"))
@override_config({"limit_usage_by_mau": True})
- def test_get_or_create_user_mau_not_blocked(self):
+ def test_get_or_create_user_mau_not_blocked(self) -> None:
self.store.count_monthly_users = Mock(
return_value=make_awaitable(self.hs.config.server.max_mau_value - 1)
)
@@ -193,7 +210,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.get_success(self.get_or_create_user(self.requester, "c", "User"))
@override_config({"limit_usage_by_mau": True})
- def test_get_or_create_user_mau_blocked(self):
+ def test_get_or_create_user_mau_blocked(self) -> None:
self.store.get_monthly_active_count = Mock(
return_value=make_awaitable(self.lots_of_users)
)
@@ -211,7 +228,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
)
@override_config({"limit_usage_by_mau": True})
- def test_register_mau_blocked(self):
+ def test_register_mau_blocked(self) -> None:
self.store.get_monthly_active_count = Mock(
return_value=make_awaitable(self.lots_of_users)
)
@@ -229,7 +246,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
@override_config(
{"auto_join_rooms": ["#room:test"], "auto_join_rooms_for_guests": False}
)
- def test_auto_join_rooms_for_guests(self):
+ def test_auto_join_rooms_for_guests(self) -> None:
user_id = self.get_success(
self.handler.register_user(localpart="jeff", make_guest=True),
)
@@ -237,7 +254,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(rooms), 0)
@override_config({"auto_join_rooms": ["#room:test"]})
- def test_auto_create_auto_join_rooms(self):
+ def test_auto_create_auto_join_rooms(self) -> None:
room_alias_str = "#room:test"
user_id = self.get_success(self.handler.register_user(localpart="jeff"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
@@ -249,7 +266,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(rooms), 1)
@override_config({"auto_join_rooms": []})
- def test_auto_create_auto_join_rooms_with_no_rooms(self):
+ def test_auto_create_auto_join_rooms_with_no_rooms(self) -> None:
frank = UserID.from_string("@frank:test")
user_id = self.get_success(self.handler.register_user(frank.localpart))
self.assertEqual(user_id, frank.to_string())
@@ -257,7 +274,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(rooms), 0)
@override_config({"auto_join_rooms": ["#room:another"]})
- def test_auto_create_auto_join_where_room_is_another_domain(self):
+ def test_auto_create_auto_join_where_room_is_another_domain(self) -> None:
frank = UserID.from_string("@frank:test")
user_id = self.get_success(self.handler.register_user(frank.localpart))
self.assertEqual(user_id, frank.to_string())
@@ -267,13 +284,13 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
@override_config(
{"auto_join_rooms": ["#room:test"], "autocreate_auto_join_rooms": False}
)
- def test_auto_create_auto_join_where_auto_create_is_false(self):
+ def test_auto_create_auto_join_where_auto_create_is_false(self) -> None:
user_id = self.get_success(self.handler.register_user(localpart="jeff"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
self.assertEqual(len(rooms), 0)
@override_config({"auto_join_rooms": ["#room:test"]})
- def test_auto_create_auto_join_rooms_when_user_is_not_a_real_user(self):
+ def test_auto_create_auto_join_rooms_when_user_is_not_a_real_user(self) -> None:
room_alias_str = "#room:test"
self.store.is_real_user = Mock(return_value=make_awaitable(False))
user_id = self.get_success(self.handler.register_user(localpart="support"))
@@ -284,7 +301,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.get_failure(directory_handler.get_association(room_alias), SynapseError)
@override_config({"auto_join_rooms": ["#room:test"]})
- def test_auto_create_auto_join_rooms_when_user_is_the_first_real_user(self):
+ def test_auto_create_auto_join_rooms_when_user_is_the_first_real_user(self) -> None:
room_alias_str = "#room:test"
self.store.count_real_users = Mock(return_value=make_awaitable(1))
@@ -299,7 +316,9 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(rooms), 1)
@override_config({"auto_join_rooms": ["#room:test"]})
- def test_auto_create_auto_join_rooms_when_user_is_not_the_first_real_user(self):
+ def test_auto_create_auto_join_rooms_when_user_is_not_the_first_real_user(
+ self,
+ ) -> None:
self.store.count_real_users = Mock(return_value=make_awaitable(2))
self.store.is_real_user = Mock(return_value=make_awaitable(True))
user_id = self.get_success(self.handler.register_user(localpart="real"))
@@ -312,7 +331,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
"autocreate_auto_join_rooms_federated": False,
}
)
- def test_auto_create_auto_join_rooms_federated(self):
+ def test_auto_create_auto_join_rooms_federated(self) -> None:
"""
Auto-created rooms that are private require an invite to go to the user
(instead of directly joining it).
@@ -339,7 +358,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
@override_config(
{"auto_join_rooms": ["#room:test"], "auto_join_mxid_localpart": "support"}
)
- def test_auto_join_mxid_localpart(self):
+ def test_auto_join_mxid_localpart(self) -> None:
"""
Ensure the user still needs up in the room created by a different user.
"""
@@ -376,7 +395,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
"auto_join_mxid_localpart": "support",
}
)
- def test_auto_create_auto_join_room_preset(self):
+ def test_auto_create_auto_join_room_preset(self) -> None:
"""
Auto-created rooms that are private require an invite to go to the user
(instead of directly joining it).
@@ -416,7 +435,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
"auto_join_mxid_localpart": "support",
}
)
- def test_auto_create_auto_join_room_preset_guest(self):
+ def test_auto_create_auto_join_room_preset_guest(self) -> None:
"""
Auto-created rooms that are private require an invite to go to the user
(instead of directly joining it).
@@ -454,7 +473,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
"auto_join_mxid_localpart": "support",
}
)
- def test_auto_create_auto_join_room_preset_invalid_permissions(self):
+ def test_auto_create_auto_join_room_preset_invalid_permissions(self) -> None:
"""
Auto-created rooms that are private require an invite, check that
registration doesn't completely break if the inviter doesn't have proper
@@ -525,7 +544,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
"auto_join_rooms": ["#room:test"],
},
)
- def test_auto_create_auto_join_where_no_consent(self):
+ def test_auto_create_auto_join_where_no_consent(self) -> None:
"""Test to ensure that the first user is not auto-joined to a room if
they have not given general consent.
"""
@@ -550,19 +569,19 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
self.assertEqual(len(rooms), 1)
- def test_register_support_user(self):
+ def test_register_support_user(self) -> None:
user_id = self.get_success(
self.handler.register_user(localpart="user", user_type=UserTypes.SUPPORT)
)
d = self.store.is_support_user(user_id)
self.assertTrue(self.get_success(d))
- def test_register_not_support_user(self):
+ def test_register_not_support_user(self) -> None:
user_id = self.get_success(self.handler.register_user(localpart="user"))
d = self.store.is_support_user(user_id)
self.assertFalse(self.get_success(d))
- def test_invalid_user_id_length(self):
+ def test_invalid_user_id_length(self) -> None:
invalid_user_id = "x" * 256
self.get_failure(
self.handler.register_user(localpart=invalid_user_id), SynapseError
@@ -577,7 +596,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
]
}
)
- def test_spam_checker_deny(self):
+ def test_spam_checker_deny(self) -> None:
"""A spam checker can deny registration, which results in an error."""
self.get_failure(self.handler.register_user(localpart="user"), SynapseError)
@@ -590,7 +609,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
]
}
)
- def test_spam_checker_legacy_allow(self):
+ def test_spam_checker_legacy_allow(self) -> None:
"""Tests that a legacy spam checker implementing the legacy 3-arg version of the
check_registration_for_spam callback is correctly called.
@@ -610,7 +629,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
]
}
)
- def test_spam_checker_legacy_deny(self):
+ def test_spam_checker_legacy_deny(self) -> None:
"""Tests that a legacy spam checker implementing the legacy 3-arg version of the
check_registration_for_spam callback is correctly called.
@@ -630,7 +649,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
]
}
)
- def test_spam_checker_shadow_ban(self):
+ def test_spam_checker_shadow_ban(self) -> None:
"""A spam checker can choose to shadow-ban a user, which allows registration to succeed."""
user_id = self.get_success(self.handler.register_user(localpart="user"))
@@ -660,7 +679,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
]
}
)
- def test_spam_checker_receives_sso_type(self):
+ def test_spam_checker_receives_sso_type(self) -> None:
"""Test rejecting registration based on SSO type"""
f = self.get_failure(
self.handler.register_user(localpart="bobflimflob", auth_provider_id="cas"),
@@ -678,8 +697,12 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
)
async def get_or_create_user(
- self, requester, localpart, displayname, password_hash=None
- ):
+ self,
+ requester: Requester,
+ localpart: str,
+ displayname: Optional[str],
+ password_hash: Optional[str] = None,
+ ) -> Tuple[str, str]:
"""Creates a new user if the user does not exist,
else revokes all previous access tokens and generates a new one.
@@ -734,13 +757,15 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
class RemoteAutoJoinTestCase(unittest.HomeserverTestCase):
"""Tests auto-join on remote rooms."""
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.room_id = "!roomid:remotetest"
- async def update_membership(*args, **kwargs):
+ async def update_membership(*args: Any, **kwargs: Any) -> None:
pass
- async def lookup_room_alias(*args, **kwargs):
+ async def lookup_room_alias(
+ *args: Any, **kwargs: Any
+ ) -> Tuple[RoomID, List[str]]:
return RoomID.from_string(self.room_id), ["remotetest"]
self.room_member_handler = Mock(spec=["update_membership", "lookup_room_alias"])
@@ -750,12 +775,12 @@ class RemoteAutoJoinTestCase(unittest.HomeserverTestCase):
hs = self.setup_test_homeserver(room_member_handler=self.room_member_handler)
return hs
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.handler = self.hs.get_registration_handler()
self.store = self.hs.get_datastores().main
@override_config({"auto_join_rooms": ["#room:remotetest"]})
- def test_auto_create_auto_join_remote_room(self):
+ def test_auto_create_auto_join_remote_room(self) -> None:
"""Tests that we don't attempt to create remote rooms, and that we don't attempt
to invite ourselves to rooms we're not in."""
diff --git a/tests/handlers/test_room.py b/tests/handlers/test_room.py
index fcde5dab72..df95490d3b 100644
--- a/tests/handlers/test_room.py
+++ b/tests/handlers/test_room.py
@@ -14,7 +14,7 @@ class EncryptedByDefaultTestCase(unittest.HomeserverTestCase):
]
@override_config({"encryption_enabled_by_default_for_room_type": "all"})
- def test_encrypted_by_default_config_option_all(self):
+ def test_encrypted_by_default_config_option_all(self) -> None:
"""Tests that invite-only and non-invite-only rooms have encryption enabled by
default when the config option encryption_enabled_by_default_for_room_type is "all".
"""
@@ -45,7 +45,7 @@ class EncryptedByDefaultTestCase(unittest.HomeserverTestCase):
self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT})
@override_config({"encryption_enabled_by_default_for_room_type": "invite"})
- def test_encrypted_by_default_config_option_invite(self):
+ def test_encrypted_by_default_config_option_invite(self) -> None:
"""Tests that only new, invite-only rooms have encryption enabled by default when
the config option encryption_enabled_by_default_for_room_type is "invite".
"""
@@ -76,7 +76,7 @@ class EncryptedByDefaultTestCase(unittest.HomeserverTestCase):
)
@override_config({"encryption_enabled_by_default_for_room_type": "off"})
- def test_encrypted_by_default_config_option_off(self):
+ def test_encrypted_by_default_config_option_off(self) -> None:
"""Tests that neither new invite-only nor non-invite-only rooms have encryption
enabled by default when the config option
encryption_enabled_by_default_for_room_type is "off".
diff --git a/tests/handlers/test_room_summary.py b/tests/handlers/test_room_summary.py
index aa650756e4..d907fcaf04 100644
--- a/tests/handlers/test_room_summary.py
+++ b/tests/handlers/test_room_summary.py
@@ -11,10 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Iterable, List, Optional, Tuple
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
from unittest import mock
from twisted.internet.defer import ensureDeferred
+from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import (
EventContentFields,
@@ -34,11 +35,14 @@ from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.types import JsonDict, UserID, create_requester
+from synapse.util import Clock
from tests import unittest
-def _create_event(room_id: str, order: Optional[Any] = None, origin_server_ts: int = 0):
+def _create_event(
+ room_id: str, order: Optional[Any] = None, origin_server_ts: int = 0
+) -> mock.Mock:
result = mock.Mock(name=room_id)
result.room_id = room_id
result.content = {}
@@ -48,40 +52,40 @@ def _create_event(room_id: str, order: Optional[Any] = None, origin_server_ts: i
return result
-def _order(*events):
+def _order(*events: mock.Mock) -> List[mock.Mock]:
return sorted(events, key=_child_events_comparison_key)
class TestSpaceSummarySort(unittest.TestCase):
- def test_no_order_last(self):
+ def test_no_order_last(self) -> None:
"""An event with no ordering is placed behind those with an ordering."""
ev1 = _create_event("!abc:test")
ev2 = _create_event("!xyz:test", "xyz")
self.assertEqual([ev2, ev1], _order(ev1, ev2))
- def test_order(self):
+ def test_order(self) -> None:
"""The ordering should be used."""
ev1 = _create_event("!abc:test", "xyz")
ev2 = _create_event("!xyz:test", "abc")
self.assertEqual([ev2, ev1], _order(ev1, ev2))
- def test_order_origin_server_ts(self):
+ def test_order_origin_server_ts(self) -> None:
"""Origin server is a tie-breaker for ordering."""
ev1 = _create_event("!abc:test", origin_server_ts=10)
ev2 = _create_event("!xyz:test", origin_server_ts=30)
self.assertEqual([ev1, ev2], _order(ev1, ev2))
- def test_order_room_id(self):
+ def test_order_room_id(self) -> None:
"""Room ID is a final tie-breaker for ordering."""
ev1 = _create_event("!abc:test")
ev2 = _create_event("!xyz:test")
self.assertEqual([ev1, ev2], _order(ev1, ev2))
- def test_invalid_ordering_type(self):
+ def test_invalid_ordering_type(self) -> None:
"""Invalid orderings are considered the same as missing."""
ev1 = _create_event("!abc:test", 1)
ev2 = _create_event("!xyz:test", "xyz")
@@ -97,7 +101,7 @@ class TestSpaceSummarySort(unittest.TestCase):
ev1 = _create_event("!abc:test", True)
self.assertEqual([ev2, ev1], _order(ev1, ev2))
- def test_invalid_ordering_value(self):
+ def test_invalid_ordering_value(self) -> None:
"""Invalid orderings are considered the same as missing."""
ev1 = _create_event("!abc:test", "foo\n")
ev2 = _create_event("!xyz:test", "xyz")
@@ -115,7 +119,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
login.register_servlets,
]
- def prepare(self, reactor, clock, hs: HomeServer):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.hs = hs
self.handler = self.hs.get_room_summary_handler()
@@ -223,7 +227,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
fed_handler.on_invite_request(fed_hostname, event, RoomVersions.V6)
)
- def test_simple_space(self):
+ def test_simple_space(self) -> None:
"""Test a simple space with a single room."""
# The result should have the space and the room in it, along with a link
# from space -> room.
@@ -234,7 +238,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
)
self._assert_hierarchy(result, expected)
- def test_large_space(self):
+ def test_large_space(self) -> None:
"""Test a space with a large number of rooms."""
rooms = [self.room]
# Make at least 51 rooms that are part of the space.
@@ -260,7 +264,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
result["rooms"] += result2["rooms"]
self._assert_hierarchy(result, expected)
- def test_visibility(self):
+ def test_visibility(self) -> None:
"""A user not in a space cannot inspect it."""
user2 = self.register_user("user2", "pass")
token2 = self.login("user2", "pass")
@@ -380,7 +384,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
self._assert_hierarchy(result2, [(self.space, [self.room])])
def _create_room_with_join_rule(
- self, join_rule: str, room_version: Optional[str] = None, **extra_content
+ self, join_rule: str, room_version: Optional[str] = None, **extra_content: Any
) -> str:
"""Create a room with the given join rule and add it to the space."""
room_id = self.helper.create_room_as(
@@ -403,7 +407,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
self._add_child(self.space, room_id, self.token)
return room_id
- def test_filtering(self):
+ def test_filtering(self) -> None:
"""
Rooms should be properly filtered to only include rooms the user has access to.
"""
@@ -476,7 +480,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
)
self._assert_hierarchy(result, expected)
- def test_complex_space(self):
+ def test_complex_space(self) -> None:
"""
Create a "complex" space to see how it handles things like loops and subspaces.
"""
@@ -516,7 +520,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
)
self._assert_hierarchy(result, expected)
- def test_pagination(self):
+ def test_pagination(self) -> None:
"""Test simple pagination works."""
room_ids = []
for i in range(1, 10):
@@ -553,7 +557,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
self._assert_hierarchy(result, expected)
self.assertNotIn("next_batch", result)
- def test_invalid_pagination_token(self):
+ def test_invalid_pagination_token(self) -> None:
"""An invalid pagination token, or changing other parameters, shoudl be rejected."""
room_ids = []
for i in range(1, 10):
@@ -604,7 +608,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
SynapseError,
)
- def test_max_depth(self):
+ def test_max_depth(self) -> None:
"""Create a deep tree to test the max depth against."""
spaces = [self.space]
rooms = [self.room]
@@ -659,7 +663,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
]
self._assert_hierarchy(result, expected)
- def test_unknown_room_version(self):
+ def test_unknown_room_version(self) -> None:
"""
If a room with an unknown room version is encountered it should not cause
the entire summary to skip.
@@ -685,7 +689,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
)
self._assert_hierarchy(result, expected)
- def test_fed_complex(self):
+ def test_fed_complex(self) -> None:
"""
Return data over federation and ensure that it is handled properly.
"""
@@ -722,7 +726,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
"world_readable": True,
}
- async def summarize_remote_room_hierarchy(_self, room, suggested_only):
+ async def summarize_remote_room_hierarchy(
+ _self: Any, room: Any, suggested_only: bool
+ ) -> Tuple[Optional[_RoomEntry], Dict[str, JsonDict], Set[str]]:
return requested_room_entry, {subroom: child_room}, set()
# Add a room to the space which is on another server.
@@ -744,7 +750,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
)
self._assert_hierarchy(result, expected)
- def test_fed_filtering(self):
+ def test_fed_filtering(self) -> None:
"""
Rooms returned over federation should be properly filtered to only include
rooms the user has access to.
@@ -853,7 +859,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
],
)
- async def summarize_remote_room_hierarchy(_self, room, suggested_only):
+ async def summarize_remote_room_hierarchy(
+ _self: Any, room: Any, suggested_only: bool
+ ) -> Tuple[Optional[_RoomEntry], Dict[str, JsonDict], Set[str]]:
return subspace_room_entry, dict(children_rooms), set()
# Add a room to the space which is on another server.
@@ -892,7 +900,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
)
self._assert_hierarchy(result, expected)
- def test_fed_invited(self):
+ def test_fed_invited(self) -> None:
"""
A room which the user was invited to should be included in the response.
@@ -915,7 +923,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
},
)
- async def summarize_remote_room_hierarchy(_self, room, suggested_only):
+ async def summarize_remote_room_hierarchy(
+ _self: Any, room: Any, suggested_only: bool
+ ) -> Tuple[Optional[_RoomEntry], Dict[str, JsonDict], Set[str]]:
return fed_room_entry, {}, set()
# Add a room to the space which is on another server.
@@ -936,7 +946,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
)
self._assert_hierarchy(result, expected)
- def test_fed_caching(self):
+ def test_fed_caching(self) -> None:
"""
Federation `/hierarchy` responses should be cached.
"""
@@ -1023,7 +1033,7 @@ class RoomSummaryTestCase(unittest.HomeserverTestCase):
login.register_servlets,
]
- def prepare(self, reactor, clock, hs: HomeServer):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.hs = hs
self.handler = self.hs.get_room_summary_handler()
@@ -1040,12 +1050,12 @@ class RoomSummaryTestCase(unittest.HomeserverTestCase):
tok=self.token,
)
- def test_own_room(self):
+ def test_own_room(self) -> None:
"""Test a simple room created by the requester."""
result = self.get_success(self.handler.get_room_summary(self.user, self.room))
self.assertEqual(result.get("room_id"), self.room)
- def test_visibility(self):
+ def test_visibility(self) -> None:
"""A user not in a private room cannot get its summary."""
user2 = self.register_user("user2", "pass")
token2 = self.login("user2", "pass")
@@ -1093,7 +1103,7 @@ class RoomSummaryTestCase(unittest.HomeserverTestCase):
result = self.get_success(self.handler.get_room_summary(user2, self.room))
self.assertEqual(result.get("room_id"), self.room)
- def test_fed(self):
+ def test_fed(self) -> None:
"""
Return data over federation and ensure that it is handled properly.
"""
@@ -1105,7 +1115,9 @@ class RoomSummaryTestCase(unittest.HomeserverTestCase):
{"room_id": fed_room, "world_readable": True},
)
- async def summarize_remote_room_hierarchy(_self, room, suggested_only):
+ async def summarize_remote_room_hierarchy(
+ _self: Any, room: Any, suggested_only: bool
+ ) -> Tuple[Optional[_RoomEntry], Dict[str, JsonDict], Set[str]]:
return requested_room_entry, {}, set()
with mock.patch(
diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py
index a0f84e2940..9b1b8b9f13 100644
--- a/tests/handlers/test_saml.py
+++ b/tests/handlers/test_saml.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Set, Tuple
from unittest.mock import Mock
import attr
@@ -20,7 +20,9 @@ import attr
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.errors import RedirectException
+from synapse.module_api import ModuleApi
from synapse.server import HomeServer
+from synapse.types import JsonDict
from synapse.util import Clock
from tests.test_utils import simple_async_mock
@@ -29,6 +31,7 @@ from tests.unittest import HomeserverTestCase, override_config
# Check if we have the dependencies to run the tests.
try:
import saml2.config
+ import saml2.response
from saml2.sigver import SigverError
has_saml2 = True
@@ -56,31 +59,39 @@ class FakeAuthnResponse:
class TestMappingProvider:
- def __init__(self, config, module):
+ def __init__(self, config: None, module: ModuleApi):
pass
@staticmethod
- def parse_config(config):
- return
+ def parse_config(config: JsonDict) -> None:
+ return None
@staticmethod
- def get_saml_attributes(config):
+ def get_saml_attributes(config: None) -> Tuple[Set[str], Set[str]]:
return {"uid"}, {"displayName"}
- def get_remote_user_id(self, saml_response, client_redirect_url):
+ def get_remote_user_id(
+ self, saml_response: "saml2.response.AuthnResponse", client_redirect_url: str
+ ) -> str:
return saml_response.ava["uid"]
def saml_response_to_user_attributes(
- self, saml_response, failures, client_redirect_url
- ):
+ self,
+ saml_response: "saml2.response.AuthnResponse",
+ failures: int,
+ client_redirect_url: str,
+ ) -> dict:
localpart = saml_response.ava["username"] + (str(failures) if failures else "")
return {"mxid_localpart": localpart, "displayname": None}
class TestRedirectMappingProvider(TestMappingProvider):
def saml_response_to_user_attributes(
- self, saml_response, failures, client_redirect_url
- ):
+ self,
+ saml_response: "saml2.response.AuthnResponse",
+ failures: int,
+ client_redirect_url: str,
+ ) -> dict:
raise RedirectException(b"https://custom-saml-redirect/")
@@ -347,7 +358,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
)
-def _mock_request():
+def _mock_request() -> Mock:
"""Returns a mock which will stand in as a SynapseRequest"""
mock = Mock(
spec=[
diff --git a/tests/handlers/test_send_email.py b/tests/handlers/test_send_email.py
index da4bf8b582..8b6e4a40b6 100644
--- a/tests/handlers/test_send_email.py
+++ b/tests/handlers/test_send_email.py
@@ -13,7 +13,7 @@
# limitations under the License.
-from typing import List, Tuple
+from typing import Callable, List, Tuple
from zope.interface import implementer
@@ -28,20 +28,27 @@ from tests.unittest import HomeserverTestCase, override_config
@implementer(interfaces.IMessageDelivery)
class _DummyMessageDelivery:
- def __init__(self):
+ def __init__(self) -> None:
# (recipient, message) tuples
self.messages: List[Tuple[smtp.Address, bytes]] = []
- def receivedHeader(self, helo, origin, recipients):
+ def receivedHeader(
+ self,
+ helo: Tuple[bytes, bytes],
+ origin: smtp.Address,
+ recipients: List[smtp.User],
+ ) -> None:
return None
- def validateFrom(self, helo, origin):
+ def validateFrom(
+ self, helo: Tuple[bytes, bytes], origin: smtp.Address
+ ) -> smtp.Address:
return origin
- def record_message(self, recipient: smtp.Address, message: bytes):
+ def record_message(self, recipient: smtp.Address, message: bytes) -> None:
self.messages.append((recipient, message))
- def validateTo(self, user: smtp.User):
+ def validateTo(self, user: smtp.User) -> Callable[[], interfaces.IMessageSMTP]:
return lambda: _DummyMessage(self, user)
@@ -56,20 +63,20 @@ class _DummyMessage:
self._user = user
self._buffer: List[bytes] = []
- def lineReceived(self, line):
+ def lineReceived(self, line: bytes) -> None:
self._buffer.append(line)
- def eomReceived(self):
+ def eomReceived(self) -> "defer.Deferred[bytes]":
message = b"\n".join(self._buffer) + b"\n"
self._delivery.record_message(self._user.dest, message)
return defer.succeed(b"saved")
- def connectionLost(self):
+ def connectionLost(self) -> None:
pass
class SendEmailHandlerTestCase(HomeserverTestCase):
- def test_send_email(self):
+ def test_send_email(self) -> None:
"""Happy-path test that we can send email to a non-TLS server."""
h = self.hs.get_send_email_handler()
d = ensureDeferred(
@@ -119,7 +126,7 @@ class SendEmailHandlerTestCase(HomeserverTestCase):
},
}
)
- def test_send_email_force_tls(self):
+ def test_send_email_force_tls(self) -> None:
"""Happy-path test that we can send email to an Implicit TLS server."""
h = self.hs.get_send_email_handler()
d = ensureDeferred(
diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py
index 05f9ec3c51..f1a50c5bcb 100644
--- a/tests/handlers/test_stats.py
+++ b/tests/handlers/test_stats.py
@@ -12,9 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Any, Dict, List, Optional
+
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.rest import admin
from synapse.rest.client import login, room
+from synapse.server import HomeServer
from synapse.storage.databases.main import stats
+from synapse.util import Clock
from tests import unittest
@@ -32,11 +38,11 @@ class StatsRoomTests(unittest.HomeserverTestCase):
login.register_servlets,
]
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.handler = self.hs.get_stats_handler()
- def _add_background_updates(self):
+ def _add_background_updates(self) -> None:
"""
Add the background updates we need to run.
"""
@@ -63,12 +69,14 @@ class StatsRoomTests(unittest.HomeserverTestCase):
)
)
- async def get_all_room_state(self):
+ async def get_all_room_state(self) -> List[Dict[str, Any]]:
return await self.store.db_pool.simple_select_list(
"room_stats_state", None, retcols=("name", "topic", "canonical_alias")
)
- def _get_current_stats(self, stats_type, stat_id):
+ def _get_current_stats(
+ self, stats_type: str, stat_id: str
+ ) -> Optional[Dict[str, Any]]:
table, id_col = stats.TYPE_TO_TABLE[stats_type]
cols = list(stats.ABSOLUTE_STATS_FIELDS[stats_type])
@@ -82,13 +90,13 @@ class StatsRoomTests(unittest.HomeserverTestCase):
)
)
- def _perform_background_initial_update(self):
+ def _perform_background_initial_update(self) -> None:
# Do the initial population of the stats via the background update
self._add_background_updates()
self.wait_for_background_updates()
- def test_initial_room(self):
+ def test_initial_room(self) -> None:
"""
The background updates will build the table from scratch.
"""
@@ -125,7 +133,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
self.assertEqual(len(r), 1)
self.assertEqual(r[0]["topic"], "foo")
- def test_create_user(self):
+ def test_create_user(self) -> None:
"""
When we create a user, it should have statistics already ready.
"""
@@ -134,12 +142,12 @@ class StatsRoomTests(unittest.HomeserverTestCase):
u1stats = self._get_current_stats("user", u1)
- self.assertIsNotNone(u1stats)
+ assert u1stats is not None
# not in any rooms by default
self.assertEqual(u1stats["joined_rooms"], 0)
- def test_create_room(self):
+ def test_create_room(self) -> None:
"""
When we create a room, it should have statistics already ready.
"""
@@ -153,8 +161,8 @@ class StatsRoomTests(unittest.HomeserverTestCase):
r2 = self.helper.create_room_as(u1, tok=u1token, is_public=False)
r2stats = self._get_current_stats("room", r2)
- self.assertIsNotNone(r1stats)
- self.assertIsNotNone(r2stats)
+ assert r1stats is not None
+ assert r2stats is not None
self.assertEqual(
r1stats["current_state_events"], EXPT_NUM_STATE_EVTS_IN_FRESH_PUBLIC_ROOM
@@ -171,7 +179,9 @@ class StatsRoomTests(unittest.HomeserverTestCase):
self.assertEqual(r2stats["invited_members"], 0)
self.assertEqual(r2stats["banned_members"], 0)
- def test_updating_profile_information_does_not_increase_joined_members_count(self):
+ def test_updating_profile_information_does_not_increase_joined_members_count(
+ self,
+ ) -> None:
"""
Check that the joined_members count does not increase when a user changes their
profile information (which is done by sending another join membership event into
@@ -186,6 +196,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
# Get the current room stats
r1stats_ante = self._get_current_stats("room", r1)
+ assert r1stats_ante is not None
# Send a profile update into the room
new_profile = {"displayname": "bob"}
@@ -195,6 +206,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
# Get the new room stats
r1stats_post = self._get_current_stats("room", r1)
+ assert r1stats_post is not None
# Ensure that the user count did not changed
self.assertEqual(r1stats_post["joined_members"], r1stats_ante["joined_members"])
@@ -202,7 +214,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
r1stats_post["local_users_in_room"], r1stats_ante["local_users_in_room"]
)
- def test_send_state_event_nonoverwriting(self):
+ def test_send_state_event_nonoverwriting(self) -> None:
"""
When we send a non-overwriting state event, it increments current_state_events
"""
@@ -218,19 +230,21 @@ class StatsRoomTests(unittest.HomeserverTestCase):
)
r1stats_ante = self._get_current_stats("room", r1)
+ assert r1stats_ante is not None
self.helper.send_state(
r1, "cat.hissing", {"value": False}, tok=u1token, state_key="moggy"
)
r1stats_post = self._get_current_stats("room", r1)
+ assert r1stats_post is not None
self.assertEqual(
r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
1,
)
- def test_join_first_time(self):
+ def test_join_first_time(self) -> None:
"""
When a user joins a room for the first time, current_state_events and
joined_members should increase by exactly 1.
@@ -246,10 +260,12 @@ class StatsRoomTests(unittest.HomeserverTestCase):
u2token = self.login("u2", "pass")
r1stats_ante = self._get_current_stats("room", r1)
+ assert r1stats_ante is not None
self.helper.join(r1, u2, tok=u2token)
r1stats_post = self._get_current_stats("room", r1)
+ assert r1stats_post is not None
self.assertEqual(
r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
@@ -259,7 +275,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
r1stats_post["joined_members"] - r1stats_ante["joined_members"], 1
)
- def test_join_after_leave(self):
+ def test_join_after_leave(self) -> None:
"""
When a user joins a room after being previously left,
joined_members should increase by exactly 1.
@@ -280,10 +296,12 @@ class StatsRoomTests(unittest.HomeserverTestCase):
self.helper.leave(r1, u2, tok=u2token)
r1stats_ante = self._get_current_stats("room", r1)
+ assert r1stats_ante is not None
self.helper.join(r1, u2, tok=u2token)
r1stats_post = self._get_current_stats("room", r1)
+ assert r1stats_post is not None
self.assertEqual(
r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
@@ -296,7 +314,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
r1stats_post["left_members"] - r1stats_ante["left_members"], -1
)
- def test_invited(self):
+ def test_invited(self) -> None:
"""
When a user invites another user, current_state_events and
invited_members should increase by exactly 1.
@@ -311,10 +329,12 @@ class StatsRoomTests(unittest.HomeserverTestCase):
u2 = self.register_user("u2", "pass")
r1stats_ante = self._get_current_stats("room", r1)
+ assert r1stats_ante is not None
self.helper.invite(r1, u1, u2, tok=u1token)
r1stats_post = self._get_current_stats("room", r1)
+ assert r1stats_post is not None
self.assertEqual(
r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
@@ -324,7 +344,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
r1stats_post["invited_members"] - r1stats_ante["invited_members"], +1
)
- def test_join_after_invite(self):
+ def test_join_after_invite(self) -> None:
"""
When a user joins a room after being invited and
joined_members should increase by exactly 1.
@@ -344,10 +364,12 @@ class StatsRoomTests(unittest.HomeserverTestCase):
self.helper.invite(r1, u1, u2, tok=u1token)
r1stats_ante = self._get_current_stats("room", r1)
+ assert r1stats_ante is not None
self.helper.join(r1, u2, tok=u2token)
r1stats_post = self._get_current_stats("room", r1)
+ assert r1stats_post is not None
self.assertEqual(
r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
@@ -360,7 +382,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
r1stats_post["invited_members"] - r1stats_ante["invited_members"], -1
)
- def test_left(self):
+ def test_left(self) -> None:
"""
When a user leaves a room after joining and
left_members should increase by exactly 1.
@@ -380,10 +402,12 @@ class StatsRoomTests(unittest.HomeserverTestCase):
self.helper.join(r1, u2, tok=u2token)
r1stats_ante = self._get_current_stats("room", r1)
+ assert r1stats_ante is not None
self.helper.leave(r1, u2, tok=u2token)
r1stats_post = self._get_current_stats("room", r1)
+ assert r1stats_post is not None
self.assertEqual(
r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
@@ -396,7 +420,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
r1stats_post["joined_members"] - r1stats_ante["joined_members"], -1
)
- def test_banned(self):
+ def test_banned(self) -> None:
"""
When a user is banned from a room after joining and
left_members should increase by exactly 1.
@@ -416,10 +440,12 @@ class StatsRoomTests(unittest.HomeserverTestCase):
self.helper.join(r1, u2, tok=u2token)
r1stats_ante = self._get_current_stats("room", r1)
+ assert r1stats_ante is not None
self.helper.change_membership(r1, u1, u2, "ban", tok=u1token)
r1stats_post = self._get_current_stats("room", r1)
+ assert r1stats_post is not None
self.assertEqual(
r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
@@ -432,7 +458,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
r1stats_post["joined_members"] - r1stats_ante["joined_members"], -1
)
- def test_initial_background_update(self):
+ def test_initial_background_update(self) -> None:
"""
Test that statistics can be generated by the initial background update
handler.
@@ -462,6 +488,9 @@ class StatsRoomTests(unittest.HomeserverTestCase):
r1stats = self._get_current_stats("room", r1)
u1stats = self._get_current_stats("user", u1)
+ assert r1stats is not None
+ assert u1stats is not None
+
self.assertEqual(r1stats["joined_members"], 1)
self.assertEqual(
r1stats["current_state_events"], EXPT_NUM_STATE_EVTS_IN_FRESH_PUBLIC_ROOM
@@ -469,7 +498,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
self.assertEqual(u1stats["joined_rooms"], 1)
- def test_incomplete_stats(self):
+ def test_incomplete_stats(self) -> None:
"""
This tests that we track incomplete statistics.
@@ -533,8 +562,11 @@ class StatsRoomTests(unittest.HomeserverTestCase):
self.wait_for_background_updates()
r1stats_complete = self._get_current_stats("room", r1)
+ assert r1stats_complete is not None
u1stats_complete = self._get_current_stats("user", u1)
+ assert u1stats_complete is not None
u2stats_complete = self._get_current_stats("user", u2)
+ assert u2stats_complete is not None
# now we make our assertions
diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index ab5c101eb7..0d9a3de92a 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -14,6 +14,8 @@
from typing import Optional
from unittest.mock import MagicMock, Mock, patch
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.api.constants import EventTypes, JoinRules
from synapse.api.errors import Codes, ResourceLimitError
from synapse.api.filtering import Filtering
@@ -23,6 +25,7 @@ from synapse.rest import admin
from synapse.rest.client import knock, login, room
from synapse.server import HomeServer
from synapse.types import UserID, create_requester
+from synapse.util import Clock
import tests.unittest
import tests.utils
@@ -39,7 +42,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
room.register_servlets,
]
- def prepare(self, reactor, clock, hs: HomeServer):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.sync_handler = self.hs.get_sync_handler()
self.store = self.hs.get_datastores().main
@@ -47,7 +50,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
# modify its config instead of the hs'
self.auth_blocking = self.hs.get_auth_blocking()
- def test_wait_for_sync_for_user_auth_blocking(self):
+ def test_wait_for_sync_for_user_auth_blocking(self) -> None:
user_id1 = "@user1:test"
user_id2 = "@user2:test"
sync_config = generate_sync_config(user_id1)
@@ -82,7 +85,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
)
self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
- def test_unknown_room_version(self):
+ def test_unknown_room_version(self) -> None:
"""
A room with an unknown room version should not break sync (and should be excluded).
"""
@@ -186,7 +189,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
self.assertNotIn(invite_room, [r.room_id for r in result.invited])
self.assertNotIn(knock_room, [r.room_id for r in result.knocked])
- def test_ban_wins_race_with_join(self):
+ def test_ban_wins_race_with_join(self) -> None:
"""Rooms shouldn't appear under "joined" if a join loses a race to a ban.
A complicated edge case. Imagine the following scenario:
|