summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/app/test_openid_listener.py8
-rw-r--r--tests/appservice/test_scheduler.py6
-rw-r--r--tests/crypto/test_keyring.py22
-rw-r--r--tests/events/test_presence_router.py22
-rw-r--r--tests/federation/test_complexity.py35
-rw-r--r--tests/federation/test_federation_catch_up.py32
-rw-r--r--tests/federation/test_federation_client.py4
-rw-r--r--tests/federation/test_federation_sender.py55
-rw-r--r--tests/handlers/test_admin.py27
-rw-r--r--tests/handlers/test_appservice.py2
-rw-r--r--tests/handlers/test_cas.py8
-rw-r--r--tests/handlers/test_e2e_keys.py55
-rw-r--r--tests/handlers/test_federation.py58
-rw-r--r--tests/handlers/test_federation_event.py6
-rw-r--r--tests/handlers/test_message.py11
-rw-r--r--tests/handlers/test_oidc.py4
-rw-r--r--tests/handlers/test_password_providers.py2
-rw-r--r--tests/handlers/test_register.py14
-rw-r--r--tests/handlers/test_saml.py14
-rw-r--r--tests/handlers/test_typing.py12
-rw-r--r--tests/handlers/test_user_directory.py42
-rw-r--r--tests/http/federation/test_matrix_federation_agent.py16
-rw-r--r--tests/http/test_proxyagent.py45
-rw-r--r--tests/logging/test_remote_handler.py17
-rw-r--r--tests/module_api/test_api.py136
-rw-r--r--tests/push/test_bulk_push_rule_evaluator.py18
-rw-r--r--tests/push/test_email.py51
-rw-r--r--tests/push/test_http.py45
-rw-r--r--tests/push/test_push_rule_evaluator.py227
-rw-r--r--tests/replication/tcp/streams/test_events.py10
-rw-r--r--tests/replication/tcp/streams/test_partial_state.py2
-rw-r--r--tests/replication/tcp/streams/test_typing.py4
-rw-r--r--tests/replication/tcp/test_handler.py1
-rw-r--r--tests/replication/test_federation_sender_shard.py2
-rw-r--r--tests/replication/test_pusher_shard.py1
-rw-r--r--tests/rest/admin/test_media.py9
-rw-r--r--tests/rest/admin/test_server_notice.py4
-rw-r--r--tests/rest/admin/test_user.py9
-rw-r--r--tests/rest/admin/test_username_available.py15
-rw-r--r--tests/rest/client/test_account.py2
-rw-r--r--tests/rest/client/test_auth.py17
-rw-r--r--tests/rest/client/test_filter.py4
-rw-r--r--tests/rest/client/test_presence.py10
-rw-r--r--tests/rest/client/test_register.py7
-rw-r--r--tests/rest/client/test_report_event.py12
-rw-r--r--tests/rest/client/test_retention.py6
-rw-r--r--tests/rest/client/test_rooms.py12
-rw-r--r--tests/rest/client/test_shadow_banned.py6
-rw-r--r--tests/rest/client/test_third_party_rules.py2
-rw-r--r--tests/rest/client/test_upgrade_room.py2
-rw-r--r--tests/rest/client/utils.py58
-rw-r--r--tests/rest/media/v1/test_media_storage.py49
-rw-r--r--tests/scripts/test_new_matrix_user.py25
-rw-r--r--tests/server.py253
-rw-r--r--tests/server_notices/test_consent.py14
-rw-r--r--tests/server_notices/test_resource_limits_server_notices.py85
-rw-r--r--tests/storage/databases/main/test_events_worker.py1
-rw-r--r--tests/storage/test_event_chain.py10
-rw-r--r--tests/storage/test_event_federation.py9
-rw-r--r--tests/storage/test_events.py8
-rw-r--r--tests/storage/test_keys.py61
-rw-r--r--tests/storage/test_purge.py2
-rw-r--r--tests/storage/test_receipts.py6
-rw-r--r--tests/storage/test_redaction.py24
-rw-r--r--tests/storage/test_room_search.py3
-rw-r--r--tests/storage/test_state.py4
-rw-r--r--tests/storage/test_stream.py4
-rw-r--r--tests/storage/test_unsafe_locale.py2
-rw-r--r--tests/storage/test_user_directory.py63
-rw-r--r--tests/test_distributor.py12
-rw-r--r--tests/test_event_auth.py32
-rw-r--r--tests/test_federation.py109
-rw-r--r--tests/test_mau.py35
-rw-r--r--tests/test_phone_home.py2
-rw-r--r--tests/test_rust.py2
-rw-r--r--tests/test_test_utils.py16
-rw-r--r--tests/test_types.py30
-rw-r--r--tests/test_utils/__init__.py26
-rw-r--r--tests/test_utils/event_injection.py15
-rw-r--r--tests/test_utils/html_parsers.py6
-rw-r--r--tests/test_utils/logging_setup.py4
-rw-r--r--tests/test_utils/oidc.py10
-rw-r--r--tests/test_visibility.py27
-rw-r--r--tests/unittest.py17
-rw-r--r--tests/util/test_retryutils.py2
-rw-r--r--tests/utils.py31
86 files changed, 1416 insertions, 772 deletions
diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py

index 5d89ba94ad..2ee343d8a4 100644 --- a/tests/app/test_openid_listener.py +++ b/tests/app/test_openid_listener.py
@@ -67,7 +67,9 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase): } # Listen with the config - self.hs._listen_http(parse_listener_def(0, config)) + hs = self.hs + assert isinstance(hs, GenericWorkerServer) + hs._listen_http(parse_listener_def(0, config)) # Grab the resource from the site that was told to listen site = self.reactor.tcpServers[0][1] @@ -115,7 +117,9 @@ class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase): } # Listen with the config - self.hs._listener_http(self.hs.config, parse_listener_def(0, config)) + hs = self.hs + assert isinstance(hs, SynapseHomeServer) + hs._listener_http(self.hs.config, parse_listener_def(0, config)) # Grab the resource from the site that was told to listen site = self.reactor.tcpServers[0][1] diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py
index febcc1499d..e2a3bad065 100644 --- a/tests/appservice/test_scheduler.py +++ b/tests/appservice/test_scheduler.py
@@ -11,12 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, List, Optional, Sequence, Tuple, cast +from typing import List, Optional, Sequence, Tuple, cast from unittest.mock import Mock from typing_extensions import TypeAlias from twisted.internet import defer +from twisted.test.proto_helpers import MemoryReactor from synapse.appservice import ( ApplicationService, @@ -40,9 +41,6 @@ from tests.test_utils import simple_async_mock from ..utils import MockClock -if TYPE_CHECKING: - from twisted.internet.testing import MemoryReactor - class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): def setUp(self) -> None: diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 0e8af2da54..1b9696748f 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py
@@ -192,7 +192,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): key1 = signedjson.key.generate_signing_key("1") r = self.hs.get_datastores().main.store_server_verify_keys( "server9", - time.time() * 1000, + int(time.time() * 1000), [("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), 1000))], ) self.get_success(r) @@ -287,7 +287,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): key1 = signedjson.key.generate_signing_key("1") r = self.hs.get_datastores().main.store_server_verify_keys( "server9", - time.time() * 1000, + int(time.time() * 1000), # None is not a valid value in FetchKeyResult, but we're abusing this # API to insert null values into the database. The nulls get converted # to 0 when fetched in KeyStore.get_server_verify_keys. @@ -466,9 +466,9 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase): key_json = self.get_success( self.hs.get_datastores().main.get_server_keys_json([lookup_triplet]) ) - res = key_json[lookup_triplet] - self.assertEqual(len(res), 1) - res = res[0] + res_keys = key_json[lookup_triplet] + self.assertEqual(len(res_keys), 1) + res = res_keys[0] self.assertEqual(res["key_id"], testverifykey_id) self.assertEqual(res["from_server"], SERVER_NAME) self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000) @@ -584,9 +584,9 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): key_json = self.get_success( self.hs.get_datastores().main.get_server_keys_json([lookup_triplet]) ) - res = key_json[lookup_triplet] - self.assertEqual(len(res), 1) - res = res[0] + res_keys = key_json[lookup_triplet] + self.assertEqual(len(res_keys), 1) + res = res_keys[0] self.assertEqual(res["key_id"], testverifykey_id) self.assertEqual(res["from_server"], self.mock_perspective_server.server_name) self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000) @@ -705,9 +705,9 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): key_json = self.get_success( self.hs.get_datastores().main.get_server_keys_json([lookup_triplet]) ) - res = key_json[lookup_triplet] - self.assertEqual(len(res), 1) - res = res[0] + res_keys = key_json[lookup_triplet] + self.assertEqual(len(res_keys), 1) + res = res_keys[0] self.assertEqual(res["key_id"], testverifykey_id) self.assertEqual(res["from_server"], self.mock_perspective_server.server_name) self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000) diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py
index a9893def74..6fb1f1bd6e 100644 --- a/tests/events/test_presence_router.py +++ b/tests/events/test_presence_router.py
@@ -31,7 +31,11 @@ from synapse.util import Clock from tests.handlers.test_sync import generate_sync_config from tests.test_utils import simple_async_mock -from tests.unittest import FederatingHomeserverTestCase, override_config +from tests.unittest import ( + FederatingHomeserverTestCase, + HomeserverTestCase, + override_config, +) @attr.s @@ -152,11 +156,11 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: # Mock out the calls over federation. - fed_transport_client = Mock(spec=["send_transaction"]) - fed_transport_client.send_transaction = simple_async_mock({}) + self.fed_transport_client = Mock(spec=["send_transaction"]) + self.fed_transport_client.send_transaction = simple_async_mock({}) hs = self.setup_test_homeserver( - federation_transport_client=fed_transport_client, + federation_transport_client=self.fed_transport_client, ) load_legacy_presence_router(hs) @@ -418,7 +422,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): # # Thus we reset the mock, and try sending all online local user # presence again - self.hs.get_federation_transport_client().send_transaction.reset_mock() + self.fed_transport_client.send_transaction.reset_mock() # Broadcast local user online presence self.get_success( @@ -443,9 +447,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): } found_users = set() - calls = ( - self.hs.get_federation_transport_client().send_transaction.call_args_list - ) + calls = self.fed_transport_client.send_transaction.call_args_list for call in calls: call_args = call[0] federation_transaction: Transaction = call_args[0] @@ -470,7 +472,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): def send_presence_update( - testcase: FederatingHomeserverTestCase, + testcase: HomeserverTestCase, user_id: str, access_token: str, presence_state: str, @@ -491,7 +493,7 @@ def send_presence_update( def sync_presence( - testcase: FederatingHomeserverTestCase, + testcase: HomeserverTestCase, user_id: str, since_token: Optional[StreamToken] = None, ) -> Tuple[List[UserPresenceState], StreamToken]: diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py
index d667dd27bf..35dd9a20df 100644 --- a/tests/federation/test_complexity.py +++ b/tests/federation/test_complexity.py
@@ -17,7 +17,7 @@ from unittest.mock import Mock from synapse.api.errors import Codes, SynapseError from synapse.rest import admin from synapse.rest.client import login, room -from synapse.types import JsonDict, UserID +from synapse.types import JsonDict, UserID, create_requester from tests import unittest from tests.test_utils import make_awaitable @@ -56,7 +56,11 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): # Artificially raise the complexity store = self.hs.get_datastores().main - store.get_current_state_event_counts = lambda x: make_awaitable(500 * 1.23) + + async def get_current_state_event_counts(room_id: str) -> int: + return int(500 * 1.23) + + store.get_current_state_event_counts = get_current_state_event_counts # type: ignore[assignment] # Get the room complexity again -- make sure it's our artificial value channel = self.make_signed_federation_request( @@ -75,12 +79,12 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): # Mock out some things, because we don't want to test the whole join fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) - handler.federation_handler.do_invite_join = Mock( + handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment] return_value=make_awaitable(("", 1)) ) d = handler._remote_join( - None, + create_requester(u1), ["other.example.com"], "roomid", UserID.from_string(u1), @@ -106,12 +110,12 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): # Mock out some things, because we don't want to test the whole join fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) - handler.federation_handler.do_invite_join = Mock( + handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment] return_value=make_awaitable(("", 1)) ) d = handler._remote_join( - None, + create_requester(u1), ["other.example.com"], "roomid", UserID.from_string(u1), @@ -144,17 +148,18 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): # Mock out some things, because we don't want to test the whole join fed_transport.client.get_json = Mock(return_value=make_awaitable(None)) - handler.federation_handler.do_invite_join = Mock( + handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment] return_value=make_awaitable(("", 1)) ) # Artificially raise the complexity - self.hs.get_datastores().main.get_current_state_event_counts = ( - lambda x: make_awaitable(600) - ) + async def get_current_state_event_counts(room_id: str) -> int: + return 600 + + self.hs.get_datastores().main.get_current_state_event_counts = get_current_state_event_counts # type: ignore[assignment] d = handler._remote_join( - None, + create_requester(u1), ["other.example.com"], room_1, UserID.from_string(u1), @@ -200,12 +205,12 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase): # Mock out some things, because we don't want to test the whole join fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) - handler.federation_handler.do_invite_join = Mock( + handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment] return_value=make_awaitable(("", 1)) ) d = handler._remote_join( - None, + create_requester(u1), ["other.example.com"], "roomid", UserID.from_string(u1), @@ -230,12 +235,12 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase): # Mock out some things, because we don't want to test the whole join fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) - handler.federation_handler.do_invite_join = Mock( + handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment] return_value=make_awaitable(("", 1)) ) d = handler._remote_join( - None, + create_requester(u1), ["other.example.com"], "roomid", UserID.from_string(u1), diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py
index a986b15f0a..6381583c24 100644 --- a/tests/federation/test_federation_catch_up.py +++ b/tests/federation/test_federation_catch_up.py
@@ -5,7 +5,11 @@ from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import EventTypes from synapse.events import EventBase -from synapse.federation.sender import PerDestinationQueue, TransactionManager +from synapse.federation.sender import ( + FederationSender, + PerDestinationQueue, + TransactionManager, +) from synapse.federation.units import Edu, Transaction from synapse.rest import admin from synapse.rest.client import login, room @@ -33,8 +37,9 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): ] def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + self.federation_transport_client = Mock(spec=["send_transaction"]) return self.setup_test_homeserver( - federation_transport_client=Mock(spec=["send_transaction"]), + federation_transport_client=self.federation_transport_client, ) def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: @@ -52,10 +57,14 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): self.pdus: List[JsonDict] = [] self.failed_pdus: List[JsonDict] = [] self.is_online = True - self.hs.get_federation_transport_client().send_transaction.side_effect = ( + self.federation_transport_client.send_transaction.side_effect = ( self.record_transaction ) + federation_sender = hs.get_federation_sender() + assert isinstance(federation_sender, FederationSender) + self.federation_sender = federation_sender + def default_config(self) -> JsonDict: config = super().default_config() config["federation_sender_instances"] = None @@ -229,11 +238,11 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): # let's delete the federation transmission queue # (this pretends we are starting up fresh.) self.assertFalse( - self.hs.get_federation_sender() - ._per_destination_queues["host2"] - .transmission_loop_running + self.federation_sender._per_destination_queues[ + "host2" + ].transmission_loop_running ) - del self.hs.get_federation_sender()._per_destination_queues["host2"] + del self.federation_sender._per_destination_queues["host2"] # let's also clear any backoffs self.get_success( @@ -322,6 +331,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): # also fetch event 5 so we know its last_successful_stream_ordering later event_5 = self.get_success(self.hs.get_datastores().main.get_event(event_id_5)) + assert event_2.internal_metadata.stream_ordering is not None self.get_success( self.hs.get_datastores().main.set_destination_last_successful_stream_ordering( "host2", event_2.internal_metadata.stream_ordering @@ -425,15 +435,16 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): def wake_destination_track(destination: str) -> None: woken.append(destination) - self.hs.get_federation_sender().wake_destination = wake_destination_track + self.federation_sender.wake_destination = wake_destination_track # type: ignore[assignment] # cancel the pre-existing timer for _wake_destinations_needing_catchup # this is because we are calling it manually rather than waiting for it # to be called automatically - self.hs.get_federation_sender()._catchup_after_startup_timer.cancel() + assert self.federation_sender._catchup_after_startup_timer is not None + self.federation_sender._catchup_after_startup_timer.cancel() self.get_success( - self.hs.get_federation_sender()._wake_destinations_needing_catchup(), by=5.0 + self.federation_sender._wake_destinations_needing_catchup(), by=5.0 ) # ASSERT (_wake_destinations_needing_catchup): @@ -475,6 +486,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): ) ) + assert event_1.internal_metadata.stream_ordering is not None self.get_success( self.hs.get_datastores().main.set_destination_last_successful_stream_ordering( "host2", event_1.internal_metadata.stream_ordering diff --git a/tests/federation/test_federation_client.py b/tests/federation/test_federation_client.py
index 86e1236501..91694e4fca 100644 --- a/tests/federation/test_federation_client.py +++ b/tests/federation/test_federation_client.py
@@ -178,7 +178,7 @@ class FederationClientTest(FederatingHomeserverTestCase): RoomVersions.V9, ) ) - self.assertIsNotNone(pulled_pdu_info2) + assert pulled_pdu_info2 is not None remote_pdu2 = pulled_pdu_info2.pdu # Sanity check that we are working against the same event @@ -226,7 +226,7 @@ class FederationClientTest(FederatingHomeserverTestCase): RoomVersions.V9, ) ) - self.assertIsNotNone(pulled_pdu_info) + assert pulled_pdu_info is not None remote_pdu = pulled_pdu_info.pdu # check the right call got made to the agent diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index ddeffe1ad5..9e104fd96a 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py
@@ -22,6 +22,7 @@ from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import EduTypes, RoomEncryptionAlgorithms from synapse.federation.units import Transaction +from synapse.handlers.device import DeviceHandler from synapse.rest import admin from synapse.rest.client import login from synapse.server import HomeServer @@ -41,8 +42,9 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): """ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + self.federation_transport_client = Mock(spec=["send_transaction"]) hs = self.setup_test_homeserver( - federation_transport_client=Mock(spec=["send_transaction"]), + federation_transport_client=self.federation_transport_client, ) hs.get_storage_controllers().state.get_current_hosts_in_room = Mock( # type: ignore[assignment] @@ -61,9 +63,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): return config def test_send_receipts(self) -> None: - mock_send_transaction = ( - self.hs.get_federation_transport_client().send_transaction - ) + mock_send_transaction = self.federation_transport_client.send_transaction mock_send_transaction.return_value = make_awaitable({}) sender = self.hs.get_federation_sender() @@ -103,9 +103,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): ) def test_send_receipts_thread(self) -> None: - mock_send_transaction = ( - self.hs.get_federation_transport_client().send_transaction - ) + mock_send_transaction = self.federation_transport_client.send_transaction mock_send_transaction.return_value = make_awaitable({}) # Create receipts for: @@ -181,9 +179,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): def test_send_receipts_with_backoff(self) -> None: """Send two receipts in quick succession; the second should be flushed, but only after 20ms""" - mock_send_transaction = ( - self.hs.get_federation_transport_client().send_transaction - ) + mock_send_transaction = self.federation_transport_client.send_transaction mock_send_transaction.return_value = make_awaitable({}) sender = self.hs.get_federation_sender() @@ -277,10 +273,11 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): ] def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + self.federation_transport_client = Mock( + spec=["send_transaction", "query_user_devices"] + ) return self.setup_test_homeserver( - federation_transport_client=Mock( - spec=["send_transaction", "query_user_devices"] - ), + federation_transport_client=self.federation_transport_client, ) def default_config(self) -> JsonDict: @@ -310,9 +307,13 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): hs.get_datastores().main.get_current_hosts_in_room = get_current_hosts_in_room # type: ignore[assignment] + device_handler = hs.get_device_handler() + assert isinstance(device_handler, DeviceHandler) + self.device_handler = device_handler + # whenever send_transaction is called, record the edu data self.edus: List[JsonDict] = [] - self.hs.get_federation_transport_client().send_transaction.side_effect = ( + self.federation_transport_client.send_transaction.side_effect = ( self.record_transaction ) @@ -353,7 +354,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): # Send the server a device list EDU for the other user, this will cause # it to try and resync the device lists. - self.hs.get_federation_transport_client().query_user_devices.return_value = ( + self.federation_transport_client.query_user_devices.return_value = ( make_awaitable( { "stream_id": "1", @@ -364,7 +365,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): ) self.get_success( - self.hs.get_device_handler().device_list_updater.incoming_device_list_update( + self.device_handler.device_list_updater.incoming_device_list_update( "host2", { "user_id": "@user2:host2", @@ -507,9 +508,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D3", stream_id) # delete them again - self.get_success( - self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"]) - ) + self.get_success(self.device_handler.delete_devices(u1, ["D1", "D2", "D3"])) # We queue up device list updates to be sent over federation, so we # advance to clear the queue. @@ -533,7 +532,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): """If the destination server is unreachable, all the updates should get sent on recovery """ - mock_send_txn = self.hs.get_federation_transport_client().send_transaction + mock_send_txn = self.federation_transport_client.send_transaction mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail")) # create devices @@ -543,9 +542,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): self.login("user", "pass", device_id="D3") # delete them again - self.get_success( - self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"]) - ) + self.get_success(self.device_handler.delete_devices(u1, ["D1", "D2", "D3"])) # We queue up device list updates to be sent over federation, so we # advance to clear the queue. @@ -580,7 +577,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): This case tests the behaviour when the server has never been reachable. """ - mock_send_txn = self.hs.get_federation_transport_client().send_transaction + mock_send_txn = self.federation_transport_client.send_transaction mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail")) # create devices @@ -590,9 +587,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): self.login("user", "pass", device_id="D3") # delete them again - self.get_success( - self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"]) - ) + self.get_success(self.device_handler.delete_devices(u1, ["D1", "D2", "D3"])) # We queue up device list updates to be sent over federation, so we # advance to clear the queue. @@ -640,7 +635,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): self.check_device_update_edu(self.edus.pop(0), u1, "D1", None) # now the server goes offline - mock_send_txn = self.hs.get_federation_transport_client().send_transaction + mock_send_txn = self.federation_transport_client.send_transaction mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail")) self.login("user", "pass", device_id="D2") @@ -651,9 +646,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): self.reactor.advance(1) # delete them again - self.get_success( - self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"]) - ) + self.get_success(self.device_handler.delete_devices(u1, ["D1", "D2", "D3"])) self.assertGreaterEqual(mock_send_txn.call_count, 3) diff --git a/tests/handlers/test_admin.py b/tests/handlers/test_admin.py
index 6f300b8e11..1b97aaeed1 100644 --- a/tests/handlers/test_admin.py +++ b/tests/handlers/test_admin.py
@@ -296,3 +296,30 @@ class ExfiltrateData(unittest.HomeserverTestCase): self.assertEqual(args[0][0]["user_agent"], "user_agent") self.assertGreater(args[0][0]["last_seen"], 0) self.assertNotIn("access_token", args[0][0]) + + def test_account_data(self) -> None: + """Tests that user account data get exported.""" + # add account data + self.get_success( + self._store.add_account_data_for_user(self.user2, "m.global", {"a": 1}) + ) + self.get_success( + self._store.add_account_data_to_room( + self.user2, "test_room", "m.per_room", {"b": 2} + ) + ) + + writer = Mock() + + self.get_success(self.admin_handler.export_user_data(self.user2, writer)) + + # two calls, one call for user data and one call for room data + writer.write_account_data.assert_called() + + args = writer.write_account_data.call_args_list[0][0] + self.assertEqual(args[0], "global") + self.assertEqual(args[1]["m.global"]["a"], 1) + + args = writer.write_account_data.call_args_list[1][0] + self.assertEqual(args[0], "test_room") + self.assertEqual(args[1]["m.per_room"]["b"], 2) diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index a7495ab21a..9014e60577 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py
@@ -899,7 +899,7 @@ class ApplicationServicesHandlerDeviceListsTestCase(unittest.HomeserverTestCase) # Mock out application services, and allow defining our own in tests self._services: List[ApplicationService] = [] - self.hs.get_datastores().main.get_app_services = Mock( + self.hs.get_datastores().main.get_app_services = Mock( # type: ignore[assignment] return_value=self._services ) diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py
index 2733719d82..63aad0d10c 100644 --- a/tests/handlers/test_cas.py +++ b/tests/handlers/test_cas.py
@@ -61,7 +61,7 @@ class CasHandlerTestCase(HomeserverTestCase): # stub out the auth handler auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() + auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] cas_response = CasResponse("test_user", {}) request = _mock_request() @@ -89,7 +89,7 @@ class CasHandlerTestCase(HomeserverTestCase): # stub out the auth handler auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() + auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] # Map a user via SSO. cas_response = CasResponse("test_user", {}) @@ -129,7 +129,7 @@ class CasHandlerTestCase(HomeserverTestCase): # stub out the auth handler auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() + auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] cas_response = CasResponse("föö", {}) request = _mock_request() @@ -160,7 +160,7 @@ class CasHandlerTestCase(HomeserverTestCase): # stub out the auth handler auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() + auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] # The response doesn't have the proper userGroup or department. cas_response = CasResponse("test_user", {}) diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 95698bc275..6b4cba65d0 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py
@@ -23,6 +23,7 @@ from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import RoomEncryptionAlgorithms from synapse.api.errors import Codes, SynapseError +from synapse.handlers.device import DeviceHandler from synapse.server import HomeServer from synapse.types import JsonDict from synapse.util import Clock @@ -187,37 +188,37 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): ) # we should now have an unused alg1 key - res = self.get_success( + fallback_res = self.get_success( self.store.get_e2e_unused_fallback_key_types(local_user, device_id) ) - self.assertEqual(res, ["alg1"]) + self.assertEqual(fallback_res, ["alg1"]) # claiming an OTK when no OTKs are available should return the fallback # key - res = self.get_success( + claim_res = self.get_success( self.handler.claim_one_time_keys( {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None ) ) self.assertEqual( - res, + claim_res, {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}}, ) # we shouldn't have any unused fallback keys again - res = self.get_success( + unused_res = self.get_success( self.store.get_e2e_unused_fallback_key_types(local_user, device_id) ) - self.assertEqual(res, []) + self.assertEqual(unused_res, []) # claiming an OTK again should return the same fallback key - res = self.get_success( + claim_res = self.get_success( self.handler.claim_one_time_keys( {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None ) ) self.assertEqual( - res, + claim_res, {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}}, ) @@ -231,10 +232,10 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): ) ) - res = self.get_success( + unused_res = self.get_success( self.store.get_e2e_unused_fallback_key_types(local_user, device_id) ) - self.assertEqual(res, []) + self.assertEqual(unused_res, []) # uploading a new fallback key should result in an unused fallback key self.get_success( @@ -245,10 +246,10 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): ) ) - res = self.get_success( + unused_res = self.get_success( self.store.get_e2e_unused_fallback_key_types(local_user, device_id) ) - self.assertEqual(res, ["alg1"]) + self.assertEqual(unused_res, ["alg1"]) # if the user uploads a one-time key, the next claim should fetch the # one-time key, and then go back to the fallback @@ -258,23 +259,23 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): ) ) - res = self.get_success( + claim_res = self.get_success( self.handler.claim_one_time_keys( {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None ) ) self.assertEqual( - res, + claim_res, {"failures": {}, "one_time_keys": {local_user: {device_id: otk}}}, ) - res = self.get_success( + claim_res = self.get_success( self.handler.claim_one_time_keys( {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None ) ) self.assertEqual( - res, + claim_res, {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key2}}}, ) @@ -287,13 +288,13 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): ) ) - res = self.get_success( + claim_res = self.get_success( self.handler.claim_one_time_keys( {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None ) ) self.assertEqual( - res, + claim_res, {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}}, ) @@ -366,7 +367,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys1)) # upload two device keys, which will be signed later by the self-signing key - device_key_1 = { + device_key_1: JsonDict = { "user_id": local_user, "device_id": "abc", "algorithms": [ @@ -379,7 +380,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): }, "signatures": {local_user: {"ed25519:abc": "base64+signature"}}, } - device_key_2 = { + device_key_2: JsonDict = { "user_id": local_user, "device_id": "def", "algorithms": [ @@ -451,8 +452,10 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): } self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys1)) + device_handler = self.hs.get_device_handler() + assert isinstance(device_handler, DeviceHandler) e = self.get_failure( - self.hs.get_device_handler().check_device_registered( + device_handler.check_device_registered( user_id=local_user, device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk", initial_device_display_name="new display name", @@ -475,7 +478,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): device_id = "xyz" # private key: OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA device_pubkey = "NnHhnqiMFQkq969szYkooLaBAXW244ZOxgukCvm2ZeY" - device_key = { + device_key: JsonDict = { "user_id": local_user, "device_id": device_id, "algorithms": [ @@ -497,7 +500,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): # private key: 2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0 master_pubkey = "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk" - master_key = { + master_key: JsonDict = { "user_id": local_user, "usage": ["master"], "keys": {"ed25519:" + master_pubkey: master_pubkey}, @@ -540,7 +543,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): # the first user other_user = "@otherboris:" + self.hs.hostname other_master_pubkey = "fHZ3NPiKxoLQm5OoZbKa99SYxprOjNs4TwJUKP+twCM" - other_master_key = { + other_master_key: JsonDict = { # private key: oyw2ZUx0O4GifbfFYM0nQvj9CL0b8B7cyN4FprtK8OI "user_id": other_user, "usage": ["master"], @@ -702,7 +705,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY" remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ" - self.hs.get_federation_client().query_client_keys = mock.Mock( + self.hs.get_federation_client().query_client_keys = mock.Mock( # type: ignore[assignment] return_value=make_awaitable( { "device_keys": {remote_user_id: {}}, @@ -782,7 +785,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY" remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ" - self.hs.get_federation_client().query_user_devices = mock.Mock( + self.hs.get_federation_client().query_user_devices = mock.Mock( # type: ignore[assignment] return_value=make_awaitable( { "user_id": remote_user_id, diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index 57675fa407..bf0862ed54 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py
@@ -371,14 +371,14 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase): # We mock out the FederationClient.backfill method, to pretend that a remote # server has returned our fake event. federation_client_backfill_mock = Mock(return_value=make_awaitable([event])) - self.hs.get_federation_client().backfill = federation_client_backfill_mock + self.hs.get_federation_client().backfill = federation_client_backfill_mock # type: ignore[assignment] # We also mock the persist method with a side effect of itself. This allows us # to track when it has been called while preserving its function. persist_events_and_notify_mock = Mock( side_effect=self.hs.get_federation_event_handler().persist_events_and_notify ) - self.hs.get_federation_event_handler().persist_events_and_notify = ( + self.hs.get_federation_event_handler().persist_events_and_notify = ( # type: ignore[assignment] persist_events_and_notify_mock ) @@ -575,26 +575,6 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase): fed_client = fed_handler.federation_client room_id = "!room:example.com" - membership_event = make_event_from_dict( - { - "room_id": room_id, - "type": "m.room.member", - "sender": "@alice:test", - "state_key": "@alice:test", - "content": {"membership": "join"}, - }, - RoomVersions.V10, - ) - - mock_make_membership_event = Mock( - return_value=make_awaitable( - ( - "example.com", - membership_event, - RoomVersions.V10, - ) - ) - ) EVENT_CREATE = make_event_from_dict( { @@ -640,6 +620,26 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase): }, room_version=RoomVersions.V10, ) + membership_event = make_event_from_dict( + { + "room_id": room_id, + "type": "m.room.member", + "sender": "@alice:test", + "state_key": "@alice:test", + "content": {"membership": "join"}, + "prev_events": [EVENT_INVITATION_MEMBERSHIP.event_id], + }, + RoomVersions.V10, + ) + mock_make_membership_event = Mock( + return_value=make_awaitable( + ( + "example.com", + membership_event, + RoomVersions.V10, + ) + ) + ) mock_send_join = Mock( return_value=make_awaitable( SendJoinResult( @@ -712,12 +712,12 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase): fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room ), patch.object(store, "is_partial_state_room", mock_is_partial_state_room): # Start the partial state sync. - fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id") + fed_handler._start_partial_state_room_sync("hs1", {"hs2"}, "room_id") self.assertEqual(mock_sync_partial_state_room.call_count, 1) # Try to start another partial state sync. # Nothing should happen. - fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id") + fed_handler._start_partial_state_room_sync("hs3", {"hs2"}, "room_id") self.assertEqual(mock_sync_partial_state_room.call_count, 1) # End the partial state sync @@ -729,7 +729,7 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase): # The next attempt to start the partial state sync should work. is_partial_state = True - fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id") + fed_handler._start_partial_state_room_sync("hs3", {"hs2"}, "room_id") self.assertEqual(mock_sync_partial_state_room.call_count, 2) def test_partial_state_room_sync_restart(self) -> None: @@ -764,7 +764,7 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase): fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room ), patch.object(store, "is_partial_state_room", mock_is_partial_state_room): # Start the partial state sync. - fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id") + fed_handler._start_partial_state_room_sync("hs1", {"hs2"}, "room_id") self.assertEqual(mock_sync_partial_state_room.call_count, 1) # Fail the partial state sync. @@ -773,11 +773,11 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase): self.assertEqual(mock_sync_partial_state_room.call_count, 1) # Start the partial state sync again. - fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id") + fed_handler._start_partial_state_room_sync("hs1", {"hs2"}, "room_id") self.assertEqual(mock_sync_partial_state_room.call_count, 2) # Deduplicate another partial state sync. - fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id") + fed_handler._start_partial_state_room_sync("hs3", {"hs2"}, "room_id") self.assertEqual(mock_sync_partial_state_room.call_count, 2) # Fail the partial state sync. @@ -786,6 +786,6 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase): self.assertEqual(mock_sync_partial_state_room.call_count, 3) mock_sync_partial_state_room.assert_called_with( initial_destination="hs3", - other_destinations=["hs2"], + other_destinations={"hs2"}, room_id="room_id", ) diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py
index 70ea4d15d4..c067e5bfe3 100644 --- a/tests/handlers/test_federation_event.py +++ b/tests/handlers/test_federation_event.py
@@ -29,6 +29,7 @@ from synapse.logging.context import LoggingContext from synapse.rest import admin from synapse.rest.client import login, room from synapse.server import HomeServer +from synapse.state import StateResolutionStore from synapse.state.v2 import _mainline_sort, _reverse_topological_power_sort from synapse.types import JsonDict from synapse.util import Clock @@ -161,6 +162,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): if prev_exists_as_outlier: prev_event.internal_metadata.outlier = True persistence = self.hs.get_storage_controllers().persistence + assert persistence is not None self.get_success( persistence.persist_event( prev_event, @@ -861,7 +863,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): bert_member_event.event_id: bert_member_event, rejected_kick_event.event_id: rejected_kick_event, }, - state_res_store=main_store, + state_res_store=StateResolutionStore(main_store), ) ), [bert_member_event.event_id, rejected_kick_event.event_id], @@ -906,7 +908,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): rejected_power_levels_event.event_id, ], event_map={}, - state_res_store=main_store, + state_res_store=StateResolutionStore(main_store), full_conflicted_set=set(), ) ), diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py
index c4727ab917..69d384442f 100644 --- a/tests/handlers/test_message.py +++ b/tests/handlers/test_message.py
@@ -41,20 +41,21 @@ class EventCreationTestCase(unittest.HomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.handler = self.hs.get_event_creation_handler() - self._persist_event_storage_controller = ( - self.hs.get_storage_controllers().persistence - ) + persistence = self.hs.get_storage_controllers().persistence + assert persistence is not None + self._persist_event_storage_controller = persistence self.user_id = self.register_user("tester", "foobar") self.access_token = self.login("tester", "foobar") self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token) - self.info = self.get_success( + info = self.get_success( self.hs.get_datastores().main.get_user_by_access_token( self.access_token, ) ) - self.token_id = self.info.token_id + assert info is not None + self.token_id = info.token_id self.requester = create_requester(self.user_id, access_token_id=self.token_id) diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index adddbd002f..951caaa6b3 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py
@@ -150,7 +150,7 @@ class OidcHandlerTestCase(HomeserverTestCase): hs = self.setup_test_homeserver() self.hs_patcher = self.fake_server.patch_homeserver(hs=hs) - self.hs_patcher.start() + self.hs_patcher.start() # type: ignore[attr-defined] self.handler = hs.get_oidc_handler() self.provider = self.handler._providers["oidc"] @@ -170,7 +170,7 @@ class OidcHandlerTestCase(HomeserverTestCase): return hs def tearDown(self) -> None: - self.hs_patcher.stop() + self.hs_patcher.stop() # type: ignore[attr-defined] return super().tearDown() def reset_mocks(self) -> None: diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index 0916de64f5..aa91bc0a3d 100644 --- a/tests/handlers/test_password_providers.py +++ b/tests/handlers/test_password_providers.py
@@ -852,7 +852,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): username: The username to use for the test. registration: Whether to test with registration URLs. """ - self.hs.get_identity_handler().send_threepid_validation = Mock( + self.hs.get_identity_handler().send_threepid_validation = Mock( # type: ignore[assignment] return_value=make_awaitable(0), ) diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index b9332d97dc..1db99b3c00 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py
@@ -62,7 +62,7 @@ class TestSpamChecker: request_info: Collection[Tuple[str, str]], auth_provider_id: Optional[str], ) -> RegistrationBehaviour: - pass + return RegistrationBehaviour.ALLOW class DenyAll(TestSpamChecker): @@ -111,7 +111,7 @@ class TestLegacyRegistrationSpamChecker: username: Optional[str], request_info: Collection[Tuple[str, str]], ) -> RegistrationBehaviour: - pass + return RegistrationBehaviour.ALLOW class LegacyAllowAll(TestLegacyRegistrationSpamChecker): @@ -203,7 +203,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): @override_config({"limit_usage_by_mau": True}) def test_get_or_create_user_mau_not_blocked(self) -> None: - self.store.count_monthly_users = Mock( + self.store.count_monthly_users = Mock( # type: ignore[assignment] return_value=make_awaitable(self.hs.config.server.max_mau_value - 1) ) # Ensure does not throw exception @@ -304,7 +304,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): def test_auto_create_auto_join_rooms_when_user_is_the_first_real_user(self) -> None: room_alias_str = "#room:test" - self.store.count_real_users = Mock(return_value=make_awaitable(1)) + self.store.count_real_users = Mock(return_value=make_awaitable(1)) # type: ignore[assignment] self.store.is_real_user = Mock(return_value=make_awaitable(True)) user_id = self.get_success(self.handler.register_user(localpart="real")) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) @@ -319,7 +319,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): def test_auto_create_auto_join_rooms_when_user_is_not_the_first_real_user( self, ) -> None: - self.store.count_real_users = Mock(return_value=make_awaitable(2)) + self.store.count_real_users = Mock(return_value=make_awaitable(2)) # type: ignore[assignment] self.store.is_real_user = Mock(return_value=make_awaitable(True)) user_id = self.get_success(self.handler.register_user(localpart="real")) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) @@ -346,6 +346,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): # Ensure the room is properly not federated. room = self.get_success(self.store.get_room_with_stats(room_id["room_id"])) + assert room is not None self.assertFalse(room["federatable"]) self.assertFalse(room["public"]) self.assertEqual(room["join_rules"], "public") @@ -375,6 +376,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): # Ensure the room is properly a public room. room = self.get_success(self.store.get_room_with_stats(room_id["room_id"])) + assert room is not None self.assertEqual(room["join_rules"], "public") # Both users should be in the room. @@ -413,6 +415,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): # Ensure the room is properly a private room. room = self.get_success(self.store.get_room_with_stats(room_id["room_id"])) + assert room is not None self.assertFalse(room["public"]) self.assertEqual(room["join_rules"], "invite") self.assertEqual(room["guest_access"], "can_join") @@ -456,6 +459,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): # Ensure the room is properly a private room. room = self.get_success(self.store.get_room_with_stats(room_id["room_id"])) + assert room is not None self.assertFalse(room["public"]) self.assertEqual(room["join_rules"], "invite") self.assertEqual(room["guest_access"], "can_join") diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py
index 9b1b8b9f13..b5c772a7ae 100644 --- a/tests/handlers/test_saml.py +++ b/tests/handlers/test_saml.py
@@ -134,7 +134,7 @@ class SamlHandlerTestCase(HomeserverTestCase): # stub out the auth handler auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() + auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] # send a mocked-up SAML response to the callback saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"}) @@ -164,7 +164,7 @@ class SamlHandlerTestCase(HomeserverTestCase): # stub out the auth handler auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() + auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] # Map a user via SSO. saml_response = FakeAuthnResponse( @@ -206,11 +206,11 @@ class SamlHandlerTestCase(HomeserverTestCase): # stub out the auth handler auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() + auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] # mock out the error renderer too sso_handler = self.hs.get_sso_handler() - sso_handler.render_error = Mock(return_value=None) + sso_handler.render_error = Mock(return_value=None) # type: ignore[assignment] saml_response = FakeAuthnResponse({"uid": "test", "username": "föö"}) request = _mock_request() @@ -227,9 +227,9 @@ class SamlHandlerTestCase(HomeserverTestCase): # stub out the auth handler and error renderer auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() + auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] sso_handler = self.hs.get_sso_handler() - sso_handler.render_error = Mock(return_value=None) + sso_handler.render_error = Mock(return_value=None) # type: ignore[assignment] # register a user to occupy the first-choice MXID store = self.hs.get_datastores().main @@ -312,7 +312,7 @@ class SamlHandlerTestCase(HomeserverTestCase): # stub out the auth handler auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() + auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] # The response doesn't have the proper userGroup or department. saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"}) diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 1fe9563c98..94518a7196 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py
@@ -74,8 +74,8 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): mock_keyring.verify_json_for_server.return_value = make_awaitable(True) # we mock out the federation client too - mock_federation_client = Mock(spec=["put_json"]) - mock_federation_client.put_json.return_value = make_awaitable((200, "OK")) + self.mock_federation_client = Mock(spec=["put_json"]) + self.mock_federation_client.put_json.return_value = make_awaitable((200, "OK")) # the tests assume that we are starting at unix time 1000 reactor.pump((1000,)) @@ -83,7 +83,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.mock_hs_notifier = Mock() hs = self.setup_test_homeserver( notifier=self.mock_hs_notifier, - federation_http_client=mock_federation_client, + federation_http_client=self.mock_federation_client, keyring=mock_keyring, replication_streams={}, ) @@ -233,8 +233,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ) ) - put_json = self.hs.get_federation_http_client().put_json - put_json.assert_called_once_with( + self.mock_federation_client.put_json.assert_called_once_with( "farm", path="/_matrix/federation/v1/send/1000000", data=_expect_edu_transaction( @@ -349,8 +348,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])]) - put_json = self.hs.get_federation_http_client().put_json - put_json.assert_called_once_with( + self.mock_federation_client.put_json.assert_called_once_with( "farm", path="/_matrix/federation/v1/send/1000000", data=_expect_edu_transaction( diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 75fc5a17a4..a02c1c6227 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py
@@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple +from typing import Any, Tuple from unittest.mock import Mock, patch from urllib.parse import quote @@ -24,7 +24,7 @@ from synapse.appservice import ApplicationService from synapse.rest.client import login, register, room, user_directory from synapse.server import HomeServer from synapse.storage.roommember import ProfileInfo -from synapse.types import create_requester +from synapse.types import UserProfile, create_requester from synapse.util import Clock from tests import unittest @@ -34,6 +34,12 @@ from tests.test_utils.event_injection import inject_member_event from tests.unittest import override_config +# A spam checker which doesn't implement anything, so create a bare object. +class UselessSpamChecker: + def __init__(self, config: Any): + pass + + class UserDirectoryTestCase(unittest.HomeserverTestCase): """Tests the UserDirectoryHandler. @@ -186,6 +192,13 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self.helper.join(room, self.appservice.sender, tok=self.appservice.token) self._check_only_one_user_in_directory(user, room) + def test_search_term_with_colon_in_it_does_not_raise(self) -> None: + """ + Regression test: Test that search terms with colons in them are acceptable. + """ + u1 = self.register_user("user1", "pass") + self.get_success(self.handler.search_users(u1, "haha:paamayim-nekudotayim", 10)) + def test_user_not_in_users_table(self) -> None: """Unclear how it happens, but on matrix.org we've seen join events for users who aren't in the users table. Test that we don't fall over @@ -773,7 +786,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): s = self.get_success(self.handler.search_users(u1, "user2", 10)) self.assertEqual(len(s["results"]), 1) - async def allow_all(user_profile: ProfileInfo) -> bool: + async def allow_all(user_profile: UserProfile) -> bool: # Allow all users. return False @@ -787,7 +800,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self.assertEqual(len(s["results"]), 1) # Configure a spam checker that filters all users. - async def block_all(user_profile: ProfileInfo) -> bool: + async def block_all(user_profile: UserProfile) -> bool: # All users are spammy. return True @@ -797,6 +810,13 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): s = self.get_success(self.handler.search_users(u1, "user2", 10)) self.assertEqual(len(s["results"]), 0) + @override_config( + { + "spam_checker": { + "module": "tests.handlers.test_user_directory.UselessSpamChecker" + } + } + ) def test_legacy_spam_checker(self) -> None: """ A spam checker without the expected method should be ignored. @@ -825,11 +845,6 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self.assertEqual(shares_private, {(u1, u2, room), (u2, u1, room)}) self.assertEqual(public_users, set()) - # Configure a spam checker. - spam_checker = self.hs.get_spam_checker() - # The spam checker doesn't need any methods, so create a bare object. - spam_checker.spam_checker = object() - # We get one search result when searching for user2 by user1. s = self.get_success(self.handler.search_users(u1, "user2", 10)) self.assertEqual(len(s["results"]), 1) @@ -949,13 +964,14 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): }, ) - event, context = self.get_success( + event, unpersisted_context = self.get_success( self.event_creation_handler.create_new_client_event(builder) ) - self.get_success( - self.hs.get_storage_controllers().persistence.persist_event(event, context) - ) + context = self.get_success(unpersisted_context.persist(event)) + persistence = self.hs.get_storage_controllers().persistence + assert persistence is not None + self.get_success(persistence.persist_event(event, context)) def test_local_user_leaving_room_remains_in_user_directory(self) -> None: """We've chosen to simplify the user directory's implementation by diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index acfdcd3bca..eb7f53fee5 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -30,7 +30,7 @@ from twisted.internet.interfaces import ( IOpenSSLClientConnectionCreator, IProtocolFactory, ) -from twisted.internet.protocol import Factory +from twisted.internet.protocol import Factory, Protocol from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol from twisted.web._newclient import ResponseNeverReceived from twisted.web.client import Agent @@ -63,7 +63,7 @@ from tests.http import ( get_test_ca_cert_file, ) from tests.server import FakeTransport, ThreadedMemoryReactorClock -from tests.utils import default_config +from tests.utils import checked_cast, default_config logger = logging.getLogger(__name__) @@ -146,8 +146,10 @@ class MatrixFederationAgentTests(unittest.TestCase): # # Normally this would be done by the TCP socket code in Twisted, but we are # stubbing that out here. - client_protocol = client_factory.buildProtocol(dummy_address) - assert isinstance(client_protocol, _WrappingProtocol) + # NB: we use a checked_cast here to workaround https://github.com/Shoobx/mypy-zope/issues/91) + client_protocol = checked_cast( + _WrappingProtocol, client_factory.buildProtocol(dummy_address) + ) client_protocol.makeConnection( FakeTransport(server_protocol, self.reactor, client_protocol) ) @@ -446,7 +448,6 @@ class MatrixFederationAgentTests(unittest.TestCase): server_ssl_protocol = _wrap_server_factory_for_tls( _get_test_protocol_factory() ).buildProtocol(dummy_address) - assert isinstance(server_ssl_protocol, TLSMemoryBIOProtocol) # Tell the HTTP server to send outgoing traffic back via the proxy's transport. proxy_server_transport = proxy_server.transport @@ -465,7 +466,8 @@ class MatrixFederationAgentTests(unittest.TestCase): else: assert isinstance(proxy_server_transport, FakeTransport) client_protocol = proxy_server_transport.other - c2s_transport = client_protocol.transport + assert isinstance(client_protocol, Protocol) + c2s_transport = checked_cast(FakeTransport, client_protocol.transport) c2s_transport.other = server_ssl_protocol self.reactor.advance(0) @@ -1529,7 +1531,7 @@ def _check_logcontext(context: LoggingContextOrSentinel) -> None: def _wrap_server_factory_for_tls( factory: IProtocolFactory, sanlist: Optional[List[bytes]] = None -) -> IProtocolFactory: +) -> TLSMemoryBIOFactory: """Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory The resultant factory will create a TLS server which presents a certificate signed by our test CA, valid for the domains in `sanlist` diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py
index a817940730..cc175052ac 100644 --- a/tests/http/test_proxyagent.py +++ b/tests/http/test_proxyagent.py
@@ -28,7 +28,7 @@ from twisted.internet.endpoints import ( _WrappingProtocol, ) from twisted.internet.interfaces import IProtocol, IProtocolFactory -from twisted.internet.protocol import Factory +from twisted.internet.protocol import Factory, Protocol from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol from twisted.web.http import HTTPChannel @@ -43,6 +43,7 @@ from tests.http import ( ) from tests.server import FakeTransport, ThreadedMemoryReactorClock from tests.unittest import TestCase +from tests.utils import checked_cast logger = logging.getLogger(__name__) @@ -620,7 +621,6 @@ class MatrixFederationAgentTests(TestCase): server_ssl_protocol = _wrap_server_factory_for_tls( _get_test_protocol_factory() ).buildProtocol(dummy_address) - assert isinstance(server_ssl_protocol, TLSMemoryBIOProtocol) # Tell the HTTP server to send outgoing traffic back via the proxy's transport. proxy_server_transport = proxy_server.transport @@ -644,7 +644,8 @@ class MatrixFederationAgentTests(TestCase): else: assert isinstance(proxy_server_transport, FakeTransport) client_protocol = proxy_server_transport.other - c2s_transport = client_protocol.transport + assert isinstance(client_protocol, Protocol) + c2s_transport = checked_cast(FakeTransport, client_protocol.transport) c2s_transport.other = server_ssl_protocol self.reactor.advance(0) @@ -757,12 +758,14 @@ class MatrixFederationAgentTests(TestCase): assert isinstance(proxy_server, HTTPChannel) # fish the transports back out so that we can do the old switcheroo - s2c_transport = proxy_server.transport - assert isinstance(s2c_transport, FakeTransport) - client_protocol = s2c_transport.other - assert isinstance(client_protocol, _WrappingProtocol) - c2s_transport = client_protocol.transport - assert isinstance(c2s_transport, FakeTransport) + # To help mypy out with the various Protocols and wrappers and mocks, we do + # some explicit casting. Without the casts, we hit the bug I reported at + # https://github.com/Shoobx/mypy-zope/issues/91 . + # We also double-checked these casts at runtime (test-time) because I found it + # quite confusing to deduce these types in the first place! + s2c_transport = checked_cast(FakeTransport, proxy_server.transport) + client_protocol = checked_cast(_WrappingProtocol, s2c_transport.other) + c2s_transport = checked_cast(FakeTransport, client_protocol.transport) # the FakeTransport is async, so we need to pump the reactor self.reactor.advance(0) @@ -822,9 +825,9 @@ class MatrixFederationAgentTests(TestCase): @patch.dict(os.environ, {"http_proxy": "proxy.com:8888"}) def test_proxy_with_no_scheme(self) -> None: http_proxy_agent = ProxyAgent(self.reactor, use_proxy=True) - assert isinstance(http_proxy_agent.http_proxy_endpoint, HostnameEndpoint) - self.assertEqual(http_proxy_agent.http_proxy_endpoint._hostStr, "proxy.com") - self.assertEqual(http_proxy_agent.http_proxy_endpoint._port, 8888) + proxy_ep = checked_cast(HostnameEndpoint, http_proxy_agent.http_proxy_endpoint) + self.assertEqual(proxy_ep._hostStr, "proxy.com") + self.assertEqual(proxy_ep._port, 8888) @patch.dict(os.environ, {"http_proxy": "socks://proxy.com:8888"}) def test_proxy_with_unsupported_scheme(self) -> None: @@ -834,25 +837,21 @@ class MatrixFederationAgentTests(TestCase): @patch.dict(os.environ, {"http_proxy": "http://proxy.com:8888"}) def test_proxy_with_http_scheme(self) -> None: http_proxy_agent = ProxyAgent(self.reactor, use_proxy=True) - assert isinstance(http_proxy_agent.http_proxy_endpoint, HostnameEndpoint) - self.assertEqual(http_proxy_agent.http_proxy_endpoint._hostStr, "proxy.com") - self.assertEqual(http_proxy_agent.http_proxy_endpoint._port, 8888) + proxy_ep = checked_cast(HostnameEndpoint, http_proxy_agent.http_proxy_endpoint) + self.assertEqual(proxy_ep._hostStr, "proxy.com") + self.assertEqual(proxy_ep._port, 8888) @patch.dict(os.environ, {"http_proxy": "https://proxy.com:8888"}) def test_proxy_with_https_scheme(self) -> None: https_proxy_agent = ProxyAgent(self.reactor, use_proxy=True) - assert isinstance(https_proxy_agent.http_proxy_endpoint, _WrapperEndpoint) - self.assertEqual( - https_proxy_agent.http_proxy_endpoint._wrappedEndpoint._hostStr, "proxy.com" - ) - self.assertEqual( - https_proxy_agent.http_proxy_endpoint._wrappedEndpoint._port, 8888 - ) + proxy_ep = checked_cast(_WrapperEndpoint, https_proxy_agent.http_proxy_endpoint) + self.assertEqual(proxy_ep._wrappedEndpoint._hostStr, "proxy.com") + self.assertEqual(proxy_ep._wrappedEndpoint._port, 8888) def _wrap_server_factory_for_tls( factory: IProtocolFactory, sanlist: Optional[List[bytes]] = None -) -> IProtocolFactory: +) -> TLSMemoryBIOFactory: """Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory The resultant factory will create a TLS server which presents a certificate diff --git a/tests/logging/test_remote_handler.py b/tests/logging/test_remote_handler.py
index c08954d887..5191e31a8a 100644 --- a/tests/logging/test_remote_handler.py +++ b/tests/logging/test_remote_handler.py
@@ -21,6 +21,7 @@ from synapse.logging import RemoteHandler from tests.logging import LoggerCleanupMixin from tests.server import FakeTransport, get_clock from tests.unittest import TestCase +from tests.utils import checked_cast def connect_logging_client( @@ -56,8 +57,8 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase): client, server = connect_logging_client(self.reactor, 0) # Trigger data being sent - assert isinstance(client.transport, FakeTransport) - client.transport.flush() + client_transport = checked_cast(FakeTransport, client.transport) + client_transport.flush() # One log message, with a single trailing newline logs = server.data.decode("utf8").splitlines() @@ -89,8 +90,8 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase): # Allow the reconnection client, server = connect_logging_client(self.reactor, 0) - assert isinstance(client.transport, FakeTransport) - client.transport.flush() + client_transport = checked_cast(FakeTransport, client.transport) + client_transport.flush() # Only the 7 infos made it through, the debugs were elided logs = server.data.splitlines() @@ -123,8 +124,8 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase): # Allow the reconnection client, server = connect_logging_client(self.reactor, 0) - assert isinstance(client.transport, FakeTransport) - client.transport.flush() + client_transport = checked_cast(FakeTransport, client.transport) + client_transport.flush() # The 10 warnings made it through, the debugs and infos were elided logs = server.data.splitlines() @@ -148,8 +149,8 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase): # Allow the reconnection client, server = connect_logging_client(self.reactor, 0) - assert isinstance(client.transport, FakeTransport) - client.transport.flush() + client_transport = checked_cast(FakeTransport, client.transport) + client_transport.flush() # The first five and last five warnings made it through, the debugs and # infos were elided diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index 8f88c0117d..3a1929691e 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py
@@ -11,9 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any, Dict from unittest.mock import Mock from twisted.internet import defer +from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import EduTypes, EventTypes from synapse.api.errors import NotFoundError @@ -21,9 +23,12 @@ from synapse.events import EventBase from synapse.federation.units import Transaction from synapse.handlers.presence import UserPresenceState from synapse.handlers.push_rules import InvalidRuleException +from synapse.module_api import ModuleApi from synapse.rest import admin from synapse.rest.client import login, notifications, presence, profile, room -from synapse.types import create_requester +from synapse.server import HomeServer +from synapse.types import JsonDict, create_requester +from synapse.util import Clock from tests.events.test_presence_router import send_presence_update, sync_presence from tests.replication._base import BaseMultiWorkerStreamTestCase @@ -32,7 +37,19 @@ from tests.test_utils.event_injection import inject_member_event from tests.unittest import HomeserverTestCase, override_config -class ModuleApiTestCase(HomeserverTestCase): +class BaseModuleApiTestCase(HomeserverTestCase): + """Common properties of the two test case classes.""" + + module_api: ModuleApi + + # These are all written by _test_sending_local_online_presence_to_local_user. + presence_receiver_id: str + presence_receiver_tok: str + presence_sender_id: str + presence_sender_tok: str + + +class ModuleApiTestCase(BaseModuleApiTestCase): servlets = [ admin.register_servlets, login.register_servlets, @@ -42,23 +59,23 @@ class ModuleApiTestCase(HomeserverTestCase): notifications.register_servlets, ] - def prepare(self, reactor, clock, homeserver): - self.store = homeserver.get_datastores().main - self.module_api = homeserver.get_module_api() - self.event_creation_handler = homeserver.get_event_creation_handler() - self.sync_handler = homeserver.get_sync_handler() - self.auth_handler = homeserver.get_auth_handler() + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main + self.module_api = hs.get_module_api() + self.event_creation_handler = hs.get_event_creation_handler() + self.sync_handler = hs.get_sync_handler() + self.auth_handler = hs.get_auth_handler() - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: # Mock out the calls over federation. - fed_transport_client = Mock(spec=["send_transaction"]) - fed_transport_client.send_transaction = simple_async_mock({}) + self.fed_transport_client = Mock(spec=["send_transaction"]) + self.fed_transport_client.send_transaction = simple_async_mock({}) return self.setup_test_homeserver( - federation_transport_client=fed_transport_client, + federation_transport_client=self.fed_transport_client, ) - def test_can_register_user(self): + def test_can_register_user(self) -> None: """Tests that an external module can register a user""" # Register a new user user_id, access_token = self.get_success( @@ -88,16 +105,17 @@ class ModuleApiTestCase(HomeserverTestCase): displayname = self.get_success(self.store.get_profile_displayname("bob")) self.assertEqual(displayname, "Bobberino") - def test_can_register_admin_user(self): + def test_can_register_admin_user(self) -> None: user_id = self.register_user( "bob_module_admin", "1234", displayname="Bobberino Admin", admin=True ) found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id)) + assert found_user is not None self.assertEqual(found_user.user_id.to_string(), user_id) self.assertIdentical(found_user.is_admin, True) - def test_can_set_admin(self): + def test_can_set_admin(self) -> None: user_id = self.register_user( "alice_wants_admin", "1234", @@ -107,16 +125,17 @@ class ModuleApiTestCase(HomeserverTestCase): self.get_success(self.module_api.set_user_admin(user_id, True)) found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id)) + assert found_user is not None self.assertEqual(found_user.user_id.to_string(), user_id) self.assertIdentical(found_user.is_admin, True) - def test_can_set_displayname(self): + def test_can_set_displayname(self) -> None: localpart = "alice_wants_a_new_displayname" user_id = self.register_user( localpart, "1234", displayname="Alice", admin=False ) found_userinfo = self.get_success(self.module_api.get_userinfo_by_id(user_id)) - + assert found_userinfo is not None self.get_success( self.module_api.set_displayname( found_userinfo.user_id, "Bob", deactivation=False @@ -128,17 +147,18 @@ class ModuleApiTestCase(HomeserverTestCase): self.assertEqual(found_profile.display_name, "Bob") - def test_get_userinfo_by_id(self): + def test_get_userinfo_by_id(self) -> None: user_id = self.register_user("alice", "1234") found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id)) + assert found_user is not None self.assertEqual(found_user.user_id.to_string(), user_id) self.assertIdentical(found_user.is_admin, False) - def test_get_userinfo_by_id__no_user_found(self): + def test_get_userinfo_by_id__no_user_found(self) -> None: found_user = self.get_success(self.module_api.get_userinfo_by_id("@alice:test")) self.assertIsNone(found_user) - def test_get_user_ip_and_agents(self): + def test_get_user_ip_and_agents(self) -> None: user_id = self.register_user("test_get_user_ip_and_agents_user", "1234") # Initially, we should have no ip/agent for our user. @@ -185,7 +205,7 @@ class ModuleApiTestCase(HomeserverTestCase): # we should only find the second ip, agent. info = self.get_success( self.module_api.get_user_ip_and_agents( - user_id, (last_seen_1 + last_seen_2) / 2 + user_id, (last_seen_1 + last_seen_2) // 2 ) ) self.assertEqual(len(info), 1) @@ -200,7 +220,7 @@ class ModuleApiTestCase(HomeserverTestCase): ) self.assertEqual(info, []) - def test_get_user_ip_and_agents__no_user_found(self): + def test_get_user_ip_and_agents__no_user_found(self) -> None: info = self.get_success( self.module_api.get_user_ip_and_agents( "@test_get_user_ip_and_agents_user_nonexistent:example.com" @@ -208,10 +228,10 @@ class ModuleApiTestCase(HomeserverTestCase): ) self.assertEqual(info, []) - def test_sending_events_into_room(self): + def test_sending_events_into_room(self) -> None: """Tests that a module can send events into a room""" # Mock out create_and_send_nonmember_event to check whether events are being sent - self.event_creation_handler.create_and_send_nonmember_event = Mock( + self.event_creation_handler.create_and_send_nonmember_event = Mock( # type: ignore[assignment] spec=[], side_effect=self.event_creation_handler.create_and_send_nonmember_event, ) @@ -222,7 +242,7 @@ class ModuleApiTestCase(HomeserverTestCase): room_id = self.helper.create_room_as(user_id, tok=tok) # Create and send a non-state event - content = {"body": "I am a puppet", "msgtype": "m.text"} + content: JsonDict = {"body": "I am a puppet", "msgtype": "m.text"} event_dict = { "room_id": room_id, "type": "m.room.message", @@ -265,7 +285,7 @@ class ModuleApiTestCase(HomeserverTestCase): "sender": user_id, "state_key": "", } - event: EventBase = self.get_success( + event = self.get_success( self.module_api.create_and_send_event_into_room(event_dict) ) self.assertEqual(event.sender, user_id) @@ -303,7 +323,7 @@ class ModuleApiTestCase(HomeserverTestCase): self.module_api.create_and_send_event_into_room(event_dict), Exception ) - def test_public_rooms(self): + def test_public_rooms(self) -> None: """Tests that a room can be added and removed from the public rooms list, as well as have its public rooms directory state queried. """ @@ -350,13 +370,13 @@ class ModuleApiTestCase(HomeserverTestCase): ) self.assertFalse(is_in_public_rooms) - def test_send_local_online_presence_to(self): + def test_send_local_online_presence_to(self) -> None: # Test sending local online presence to users from the main process _test_sending_local_online_presence_to_local_user(self, test_with_workers=False) # Enable federation sending on the main process. @override_config({"federation_sender_instances": None}) - def test_send_local_online_presence_to_federation(self): + def test_send_local_online_presence_to_federation(self) -> None: """Tests that send_local_presence_to_users sends local online presence to remote users.""" # Create a user who will send presence updates self.presence_sender_id = self.register_user("presence_sender1", "monkey") @@ -397,7 +417,7 @@ class ModuleApiTestCase(HomeserverTestCase): # # Thus we reset the mock, and try sending online local user # presence again - self.hs.get_federation_transport_client().send_transaction.reset_mock() + self.fed_transport_client.send_transaction.reset_mock() # Broadcast local user online presence self.get_success( @@ -409,9 +429,7 @@ class ModuleApiTestCase(HomeserverTestCase): # Check that a presence update was sent as part of a federation transaction found_update = False - calls = ( - self.hs.get_federation_transport_client().send_transaction.call_args_list - ) + calls = self.fed_transport_client.send_transaction.call_args_list for call in calls: call_args = call[0] federation_transaction: Transaction = call_args[0] @@ -431,7 +449,7 @@ class ModuleApiTestCase(HomeserverTestCase): self.assertTrue(found_update) - def test_update_membership(self): + def test_update_membership(self) -> None: """Tests that the module API can update the membership of a user in a room.""" peter = self.register_user("peter", "hackme") lesley = self.register_user("lesley", "hackme") @@ -554,14 +572,14 @@ class ModuleApiTestCase(HomeserverTestCase): self.assertEqual(res["displayname"], "simone") self.assertIsNone(res["avatar_url"]) - def test_update_room_membership_remote_join(self): + def test_update_room_membership_remote_join(self) -> None: """Test that the module API can join a remote room.""" # Necessary to fake a remote join. fake_stream_id = 1 mocked_remote_join = simple_async_mock( return_value=("fake-event-id", fake_stream_id) ) - self.hs.get_room_member_handler()._remote_join = mocked_remote_join + self.hs.get_room_member_handler()._remote_join = mocked_remote_join # type: ignore[assignment] fake_remote_host = f"{self.module_api.server_name}-remote" # Given that the join is to be faked, we expect the relevant join event not to @@ -582,7 +600,7 @@ class ModuleApiTestCase(HomeserverTestCase): # Check that a remote join was attempted. self.assertEqual(mocked_remote_join.call_count, 1) - def test_get_room_state(self): + def test_get_room_state(self) -> None: """Tests that a module can retrieve the state of a room through the module API.""" user_id = self.register_user("peter", "hackme") tok = self.login("peter", "hackme") @@ -677,7 +695,7 @@ class ModuleApiTestCase(HomeserverTestCase): self.module_api.check_push_rule_actions(["foo"]) with self.assertRaises(InvalidRuleException): - self.module_api.check_push_rule_actions({"foo": "bar"}) + self.module_api.check_push_rule_actions([{"foo": "bar"}]) self.module_api.check_push_rule_actions(["notify"]) @@ -756,7 +774,7 @@ class ModuleApiTestCase(HomeserverTestCase): self.assertIsNone(room_alias) -class ModuleApiWorkerTestCase(BaseMultiWorkerStreamTestCase): +class ModuleApiWorkerTestCase(BaseModuleApiTestCase, BaseMultiWorkerStreamTestCase): """For testing ModuleApi functionality in a multi-worker setup""" servlets = [ @@ -766,7 +784,7 @@ class ModuleApiWorkerTestCase(BaseMultiWorkerStreamTestCase): presence.register_servlets, ] - def default_config(self): + def default_config(self) -> Dict[str, Any]: conf = super().default_config() conf["stream_writers"] = {"presence": ["presence_writer"]} conf["instance_map"] = { @@ -774,18 +792,18 @@ class ModuleApiWorkerTestCase(BaseMultiWorkerStreamTestCase): } return conf - def prepare(self, reactor, clock, homeserver): - self.module_api = homeserver.get_module_api() - self.sync_handler = homeserver.get_sync_handler() + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.module_api = hs.get_module_api() + self.sync_handler = hs.get_sync_handler() - def test_send_local_online_presence_to_workers(self): + def test_send_local_online_presence_to_workers(self) -> None: # Test sending local online presence to users from a worker process _test_sending_local_online_presence_to_local_user(self, test_with_workers=True) def _test_sending_local_online_presence_to_local_user( - test_case: HomeserverTestCase, test_with_workers: bool = False -): + test_case: BaseModuleApiTestCase, test_with_workers: bool = False +) -> None: """Tests that send_local_presence_to_users sends local online presence to local users. This simultaneously tests two different usecases: @@ -852,6 +870,7 @@ def _test_sending_local_online_presence_to_local_user( # Replicate the current sync presence token from the main process to the worker process. # We need to do this so that the worker process knows the current presence stream ID to # insert into the database when we call ModuleApi.send_local_online_presence_to. + assert isinstance(test_case, BaseMultiWorkerStreamTestCase) test_case.replicate() # Syncing again should result in no presence updates @@ -868,6 +887,7 @@ def _test_sending_local_online_presence_to_local_user( # Determine on which process (main or worker) to call ModuleApi.send_local_online_presence_to on if test_with_workers: + assert isinstance(test_case, BaseMultiWorkerStreamTestCase) module_api_to_use = worker_hs.get_module_api() else: module_api_to_use = test_case.module_api @@ -875,12 +895,11 @@ def _test_sending_local_online_presence_to_local_user( # Trigger sending local online presence. We expect this information # to be saved to the database where all processes can access it. # Note that we're syncing via the master. - d = module_api_to_use.send_local_online_presence_to( - [ - test_case.presence_receiver_id, - ] + d = defer.ensureDeferred( + module_api_to_use.send_local_online_presence_to( + [test_case.presence_receiver_id], + ) ) - d = defer.ensureDeferred(d) if test_with_workers: # In order for the required presence_set_state replication request to occur between the @@ -897,7 +916,7 @@ def _test_sending_local_online_presence_to_local_user( ) test_case.assertEqual(len(presence_updates), 1) - presence_update: UserPresenceState = presence_updates[0] + presence_update = presence_updates[0] test_case.assertEqual(presence_update.user_id, test_case.presence_sender_id) test_case.assertEqual(presence_update.state, "online") @@ -908,7 +927,7 @@ def _test_sending_local_online_presence_to_local_user( ) test_case.assertEqual(len(presence_updates), 1) - presence_update: UserPresenceState = presence_updates[0] + presence_update = presence_updates[0] test_case.assertEqual(presence_update.user_id, test_case.presence_sender_id) test_case.assertEqual(presence_update.state, "online") @@ -936,12 +955,13 @@ def _test_sending_local_online_presence_to_local_user( test_case.assertEqual(len(presence_updates), 1) # Now trigger sending local online presence. - d = module_api_to_use.send_local_online_presence_to( - [ - test_case.presence_receiver_id, - ] + d = defer.ensureDeferred( + module_api_to_use.send_local_online_presence_to( + [ + test_case.presence_receiver_id, + ] + ) ) - d = defer.ensureDeferred(d) if test_with_workers: # In order for the required presence_set_state replication request to occur between the diff --git a/tests/push/test_bulk_push_rule_evaluator.py b/tests/push/test_bulk_push_rule_evaluator.py
index 7567756135..199e3d7b70 100644 --- a/tests/push/test_bulk_push_rule_evaluator.py +++ b/tests/push/test_bulk_push_rule_evaluator.py
@@ -227,7 +227,14 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): ) return len(result) > 0 - @override_config({"experimental_features": {"msc3952_intentional_mentions": True}}) + @override_config( + { + "experimental_features": { + "msc3758_exact_event_match": True, + "msc3952_intentional_mentions": True, + } + } + ) def test_user_mentions(self) -> None: """Test the behavior of an event which includes invalid user mentions.""" bulk_evaluator = BulkPushRuleEvaluator(self.hs) @@ -323,7 +330,14 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): ) ) - @override_config({"experimental_features": {"msc3952_intentional_mentions": True}}) + @override_config( + { + "experimental_features": { + "msc3758_exact_event_match": True, + "msc3952_intentional_mentions": True, + } + } + ) def test_room_mentions(self) -> None: """Test the behavior of an event which includes invalid room mentions.""" bulk_evaluator = BulkPushRuleEvaluator(self.hs) diff --git a/tests/push/test_email.py b/tests/push/test_email.py
index ab8bb417e7..7563f33fdc 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py
@@ -23,6 +23,7 @@ from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin from synapse.api.errors import Codes, SynapseError +from synapse.push.emailpusher import EmailPusher from synapse.rest.client import login, room from synapse.server import HomeServer from synapse.util import Clock @@ -105,6 +106,7 @@ class EmailPusherTests(HomeserverTestCase): user_tuple = self.get_success( self.hs.get_datastores().main.get_user_by_access_token(self.access_token) ) + assert user_tuple is not None self.token_id = user_tuple.token_id # We need to add email to account before we can create a pusher. @@ -114,7 +116,7 @@ class EmailPusherTests(HomeserverTestCase): ) ) - self.pusher = self.get_success( + pusher = self.get_success( self.hs.get_pusherpool().add_or_update_pusher( user_id=self.user_id, access_token=self.token_id, @@ -127,6 +129,8 @@ class EmailPusherTests(HomeserverTestCase): data={}, ) ) + assert isinstance(pusher, EmailPusher) + self.pusher = pusher self.auth_handler = hs.get_auth_handler() self.store = hs.get_datastores().main @@ -375,10 +379,13 @@ class EmailPusherTests(HomeserverTestCase): ) # check that the pusher for that email address has been deleted - pushers = self.get_success( - self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id}) + pushers = list( + self.get_success( + self.hs.get_datastores().main.get_pushers_by( + {"user_name": self.user_id} + ) + ) ) - pushers = list(pushers) self.assertEqual(len(pushers), 0) def test_remove_unlinked_pushers_background_job(self) -> None: @@ -413,10 +420,13 @@ class EmailPusherTests(HomeserverTestCase): self.wait_for_background_updates() # Check that all pushers with unlinked addresses were deleted - pushers = self.get_success( - self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id}) + pushers = list( + self.get_success( + self.hs.get_datastores().main.get_pushers_by( + {"user_name": self.user_id} + ) + ) ) - pushers = list(pushers) self.assertEqual(len(pushers), 0) def _check_for_mail(self) -> Tuple[Sequence, Dict]: @@ -428,10 +438,13 @@ class EmailPusherTests(HomeserverTestCase): that notification. """ # Get the stream ordering before it gets sent - pushers = self.get_success( - self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id}) + pushers = list( + self.get_success( + self.hs.get_datastores().main.get_pushers_by( + {"user_name": self.user_id} + ) + ) ) - pushers = list(pushers) self.assertEqual(len(pushers), 1) last_stream_ordering = pushers[0].last_stream_ordering @@ -439,10 +452,13 @@ class EmailPusherTests(HomeserverTestCase): self.pump(10) # It hasn't succeeded yet, so the stream ordering shouldn't have moved - pushers = self.get_success( - self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id}) + pushers = list( + self.get_success( + self.hs.get_datastores().main.get_pushers_by( + {"user_name": self.user_id} + ) + ) ) - pushers = list(pushers) self.assertEqual(len(pushers), 1) self.assertEqual(last_stream_ordering, pushers[0].last_stream_ordering) @@ -458,10 +474,13 @@ class EmailPusherTests(HomeserverTestCase): self.assertEqual(len(self.email_attempts), 1) # The stream ordering has increased - pushers = self.get_success( - self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id}) + pushers = list( + self.get_success( + self.hs.get_datastores().main.get_pushers_by( + {"user_name": self.user_id} + ) + ) ) - pushers = list(pushers) self.assertEqual(len(pushers), 1) self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering) diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index 23447cc310..c280ddcdf6 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py
@@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Tuple +from typing import Any, List, Tuple from unittest.mock import Mock from twisted.internet.defer import Deferred @@ -22,7 +22,6 @@ from synapse.logging.context import make_deferred_yieldable from synapse.push import PusherConfig, PusherConfigException from synapse.rest.client import login, push_rule, pusher, receipts, room from synapse.server import HomeServer -from synapse.storage.databases.main.registration import TokenLookupResult from synapse.types import JsonDict from synapse.util import Clock @@ -67,9 +66,10 @@ class HTTPPusherTests(HomeserverTestCase): user_tuple = self.get_success( self.hs.get_datastores().main.get_user_by_access_token(access_token) ) + assert user_tuple is not None token_id = user_tuple.token_id - def test_data(data: Optional[JsonDict]) -> None: + def test_data(data: Any) -> None: self.get_failure( self.hs.get_pusherpool().add_or_update_pusher( user_id=user_id, @@ -113,6 +113,7 @@ class HTTPPusherTests(HomeserverTestCase): user_tuple = self.get_success( self.hs.get_datastores().main.get_user_by_access_token(access_token) ) + assert user_tuple is not None token_id = user_tuple.token_id self.get_success( @@ -140,10 +141,11 @@ class HTTPPusherTests(HomeserverTestCase): self.helper.send(room, body="There!", tok=other_access_token) # Get the stream ordering before it gets sent - pushers = self.get_success( - self.hs.get_datastores().main.get_pushers_by({"user_name": user_id}) + pushers = list( + self.get_success( + self.hs.get_datastores().main.get_pushers_by({"user_name": user_id}) + ) ) - pushers = list(pushers) self.assertEqual(len(pushers), 1) last_stream_ordering = pushers[0].last_stream_ordering @@ -151,10 +153,11 @@ class HTTPPusherTests(HomeserverTestCase): self.pump() # It hasn't succeeded yet, so the stream ordering shouldn't have moved - pushers = self.get_success( - self.hs.get_datastores().main.get_pushers_by({"user_name": user_id}) + pushers = list( + self.get_success( + self.hs.get_datastores().main.get_pushers_by({"user_name": user_id}) + ) ) - pushers = list(pushers) self.assertEqual(len(pushers), 1) self.assertEqual(last_stream_ordering, pushers[0].last_stream_ordering) @@ -172,10 +175,11 @@ class HTTPPusherTests(HomeserverTestCase): self.pump() # The stream ordering has increased - pushers = self.get_success( - self.hs.get_datastores().main.get_pushers_by({"user_name": user_id}) + pushers = list( + self.get_success( + self.hs.get_datastores().main.get_pushers_by({"user_name": user_id}) + ) ) - pushers = list(pushers) self.assertEqual(len(pushers), 1) self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering) last_stream_ordering = pushers[0].last_stream_ordering @@ -194,10 +198,11 @@ class HTTPPusherTests(HomeserverTestCase): self.pump() # The stream ordering has increased, again - pushers = self.get_success( - self.hs.get_datastores().main.get_pushers_by({"user_name": user_id}) + pushers = list( + self.get_success( + self.hs.get_datastores().main.get_pushers_by({"user_name": user_id}) + ) ) - pushers = list(pushers) self.assertEqual(len(pushers), 1) self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering) @@ -229,6 +234,7 @@ class HTTPPusherTests(HomeserverTestCase): user_tuple = self.get_success( self.hs.get_datastores().main.get_user_by_access_token(access_token) ) + assert user_tuple is not None token_id = user_tuple.token_id self.get_success( @@ -349,6 +355,7 @@ class HTTPPusherTests(HomeserverTestCase): user_tuple = self.get_success( self.hs.get_datastores().main.get_user_by_access_token(access_token) ) + assert user_tuple is not None token_id = user_tuple.token_id self.get_success( @@ -435,6 +442,7 @@ class HTTPPusherTests(HomeserverTestCase): user_tuple = self.get_success( self.hs.get_datastores().main.get_user_by_access_token(access_token) ) + assert user_tuple is not None token_id = user_tuple.token_id self.get_success( @@ -512,6 +520,7 @@ class HTTPPusherTests(HomeserverTestCase): user_tuple = self.get_success( self.hs.get_datastores().main.get_user_by_access_token(access_token) ) + assert user_tuple is not None token_id = user_tuple.token_id self.get_success( @@ -618,6 +627,7 @@ class HTTPPusherTests(HomeserverTestCase): user_tuple = self.get_success( self.hs.get_datastores().main.get_user_by_access_token(access_token) ) + assert user_tuple is not None token_id = user_tuple.token_id self.get_success( @@ -753,6 +763,7 @@ class HTTPPusherTests(HomeserverTestCase): user_tuple = self.get_success( self.hs.get_datastores().main.get_user_by_access_token(access_token) ) + assert user_tuple is not None token_id = user_tuple.token_id self.get_success( @@ -895,6 +906,7 @@ class HTTPPusherTests(HomeserverTestCase): user_tuple = self.get_success( self.hs.get_datastores().main.get_user_by_access_token(access_token) ) + assert user_tuple is not None token_id = user_tuple.token_id device_id = user_tuple.device_id @@ -941,9 +953,10 @@ class HTTPPusherTests(HomeserverTestCase): ) # Look up the user info for the access token so we can compare the device ID. - lookup_result: TokenLookupResult = self.get_success( + lookup_result = self.get_success( self.hs.get_datastores().main.get_user_by_access_token(access_token) ) + assert lookup_result is not None # Get the user's devices and check it has the correct device ID. channel = self.make_request("GET", "/pushers", access_token=access_token) diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py
index da33423871..d320a12f96 100644 --- a/tests/push/test_push_rule_evaluator.py +++ b/tests/push/test_push_rule_evaluator.py
@@ -32,6 +32,7 @@ from synapse.storage.databases.main.appservice import _make_exclusive_regex from synapse.synapse_rust.push import PushRuleEvaluator from synapse.types import JsonDict, JsonMapping, UserID from synapse.util import Clock +from synapse.util.frozenutils import freeze from tests import unittest from tests.test_utils.event_injection import create_event, inject_member_event @@ -48,17 +49,34 @@ class FlattenDictTestCase(unittest.TestCase): input = {"foo": {"bar": "abc"}} self.assertEqual({"foo.bar": "abc"}, _flatten_dict(input)) + # If a field has a dot in it, escape it. + input = {"m.foo": {"b\\ar": "abc"}} + self.assertEqual({"m.foo.b\\ar": "abc"}, _flatten_dict(input)) + self.assertEqual( + {"m\\.foo.b\\\\ar": "abc"}, + _flatten_dict(input, msc3783_escape_event_match_key=True), + ) + def test_non_string(self) -> None: - """Non-string items are dropped.""" + """String, booleans, ints, nulls and list of those should be kept while other items are dropped.""" input: Dict[str, Any] = { "woo": "woo", "foo": True, "bar": 1, "baz": None, - "fuzz": [], + "fuzz": ["woo", True, 1, None, [], {}], "boo": {}, } - self.assertEqual({"woo": "woo"}, _flatten_dict(input)) + self.assertEqual( + { + "woo": "woo", + "foo": True, + "bar": 1, + "baz": None, + "fuzz": ["woo", True, 1, None], + }, + _flatten_dict(input), + ) def test_event(self) -> None: """Events can also be flattened.""" @@ -78,9 +96,9 @@ class FlattenDictTestCase(unittest.TestCase): ) expected = { "content.msgtype": "m.text", - "content.body": "hello world!", + "content.body": "Hello world!", "content.format": "org.matrix.custom.html", - "content.formatted_body": "<h1>hello world!</h1>", + "content.formatted_body": "<h1>Hello world!</h1>", "room_id": "!test:test", "sender": "@alice:test", "type": "m.room.message", @@ -107,6 +125,7 @@ class FlattenDictTestCase(unittest.TestCase): "room_id": "!test:test", "sender": "@alice:test", "type": "m.room.message", + "content.org.matrix.msc1767.markup": [], } self.assertEqual(expected, _flatten_dict(event)) @@ -118,6 +137,7 @@ class FlattenDictTestCase(unittest.TestCase): "room_id": "!test:test", "sender": "@alice:test", "type": "m.room.message", + "content.org.matrix.msc1767.markup": [], } self.assertEqual(expected, _flatten_dict(event)) @@ -129,7 +149,6 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): *, has_mentions: bool = False, user_mentions: Optional[Set[str]] = None, - room_mention: bool = False, related_events: Optional[JsonDict] = None, ) -> PushRuleEvaluator: event = FrozenEvent( @@ -150,7 +169,6 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): _flatten_dict(event), has_mentions, user_mentions or set(), - room_mention, room_member_count, sender_power_level, cast(Dict[str, int], power_levels.get("notifications", {})), @@ -158,6 +176,8 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): related_event_match_enabled=True, room_version_feature_flags=event.room_version.msc3931_push_features, msc3931_enabled=True, + msc3758_exact_event_match=True, + msc3966_exact_event_property_contains=True, ) def test_display_name(self) -> None: @@ -210,27 +230,6 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): # Note that invalid data is tested at tests.push.test_bulk_push_rule_evaluator.TestBulkPushRuleEvaluator.test_mentions # since the BulkPushRuleEvaluator is what handles data sanitisation. - def test_room_mentions(self) -> None: - """Check for room mentions.""" - condition = {"kind": "org.matrix.msc3952.is_room_mention"} - - # No room mention shouldn't match. - evaluator = self._get_evaluator({}, has_mentions=True) - self.assertFalse(evaluator.matches(condition, None, None)) - - # Room mention should match. - evaluator = self._get_evaluator({}, has_mentions=True, room_mention=True) - self.assertTrue(evaluator.matches(condition, None, None)) - - # A room mention and user mention is valid. - evaluator = self._get_evaluator( - {}, has_mentions=True, user_mentions={"@another:test"}, room_mention=True - ) - self.assertTrue(evaluator.matches(condition, None, None)) - - # Note that invalid data is tested at tests.push.test_bulk_push_rule_evaluator.TestBulkPushRuleEvaluator.test_mentions - # since the BulkPushRuleEvaluator is what handles data sanitisation. - def _assert_matches( self, condition: JsonDict, content: JsonMapping, msg: Optional[str] = None ) -> None: @@ -402,6 +401,178 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): "pattern should not match before a newline", ) + def test_exact_event_match_string(self) -> None: + """Check that exact_event_match conditions work as expected for strings.""" + + # Test against a string value. + condition = { + "kind": "com.beeper.msc3758.exact_event_match", + "key": "content.value", + "value": "foobaz", + } + self._assert_matches( + condition, + {"value": "foobaz"}, + "exact value should match", + ) + self._assert_not_matches( + condition, + {"value": "FoobaZ"}, + "values should match and be case-sensitive", + ) + self._assert_not_matches( + condition, + {"value": "test foobaz test"}, + "values must exactly match", + ) + value: Any + for value in (True, False, 1, 1.1, None, [], {}): + self._assert_not_matches( + condition, + {"value": value}, + "incorrect types should not match", + ) + + # it should work on frozendicts too + self._assert_matches( + condition, + frozendict.frozendict({"value": "foobaz"}), + "values should match on frozendicts", + ) + + def test_exact_event_match_boolean(self) -> None: + """Check that exact_event_match conditions work as expected for booleans.""" + + # Test against a True boolean value. + condition = { + "kind": "com.beeper.msc3758.exact_event_match", + "key": "content.value", + "value": True, + } + self._assert_matches( + condition, + {"value": True}, + "exact value should match", + ) + self._assert_not_matches( + condition, + {"value": False}, + "incorrect values should not match", + ) + for value in ("foobaz", 1, 1.1, None, [], {}): + self._assert_not_matches( + condition, + {"value": value}, + "incorrect types should not match", + ) + + # Test against a False boolean value. + condition = { + "kind": "com.beeper.msc3758.exact_event_match", + "key": "content.value", + "value": False, + } + self._assert_matches( + condition, + {"value": False}, + "exact value should match", + ) + self._assert_not_matches( + condition, + {"value": True}, + "incorrect values should not match", + ) + # Choose false-y values to ensure there's no type coercion. + for value in ("", 0, 1.1, None, [], {}): + self._assert_not_matches( + condition, + {"value": value}, + "incorrect types should not match", + ) + + def test_exact_event_match_null(self) -> None: + """Check that exact_event_match conditions work as expected for null.""" + + condition = { + "kind": "com.beeper.msc3758.exact_event_match", + "key": "content.value", + "value": None, + } + self._assert_matches( + condition, + {"value": None}, + "exact value should match", + ) + for value in ("foobaz", True, False, 1, 1.1, [], {}): + self._assert_not_matches( + condition, + {"value": value}, + "incorrect types should not match", + ) + + def test_exact_event_match_integer(self) -> None: + """Check that exact_event_match conditions work as expected for integers.""" + + condition = { + "kind": "com.beeper.msc3758.exact_event_match", + "key": "content.value", + "value": 1, + } + self._assert_matches( + condition, + {"value": 1}, + "exact value should match", + ) + value: Any + for value in (1.1, -1, 0): + self._assert_not_matches( + condition, + {"value": value}, + "incorrect values should not match", + ) + for value in ("1", True, False, None, [], {}): + self._assert_not_matches( + condition, + {"value": value}, + "incorrect types should not match", + ) + + def test_exact_event_property_contains(self) -> None: + """Check that exact_event_property_contains conditions work as expected.""" + + condition = { + "kind": "org.matrix.msc3966.exact_event_property_contains", + "key": "content.value", + "value": "foobaz", + } + self._assert_matches( + condition, + {"value": ["foobaz"]}, + "exact value should match", + ) + self._assert_matches( + condition, + {"value": ["foobaz", "bugz"]}, + "extra values should match", + ) + self._assert_not_matches( + condition, + {"value": ["FoobaZ"]}, + "values should match and be case-sensitive", + ) + self._assert_not_matches( + condition, + {"value": "foobaz"}, + "does not search in a string", + ) + + # it should work on frozendicts too + self._assert_matches( + condition, + freeze({"value": ["foobaz"]}), + "values should match on frozendicts", + ) + def test_no_body(self) -> None: """Not having a body shouldn't break the evaluator.""" evaluator = self._get_evaluator({}) diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py
index 043dbe76af..65ef4bb160 100644 --- a/tests/replication/tcp/streams/test_events.py +++ b/tests/replication/tcp/streams/test_events.py
@@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional +from typing import Any, List, Optional, Sequence from twisted.test.proto_helpers import MemoryReactor @@ -139,7 +139,7 @@ class EventsStreamTestCase(BaseStreamTestCase): ) # this is the point in the DAG where we make a fork - fork_point: List[str] = self.get_success( + fork_point: Sequence[str] = self.get_success( self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id) ) @@ -168,7 +168,7 @@ class EventsStreamTestCase(BaseStreamTestCase): pl_event = self.get_success( inject_event( self.hs, - prev_event_ids=prev_events, + prev_event_ids=list(prev_events), type=EventTypes.PowerLevels, state_key="", sender=self.user_id, @@ -294,7 +294,7 @@ class EventsStreamTestCase(BaseStreamTestCase): ) # this is the point in the DAG where we make a fork - fork_point: List[str] = self.get_success( + fork_point: Sequence[str] = self.get_success( self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id) ) @@ -323,7 +323,7 @@ class EventsStreamTestCase(BaseStreamTestCase): e = self.get_success( inject_event( self.hs, - prev_event_ids=prev_events, + prev_event_ids=list(prev_events), type=EventTypes.PowerLevels, state_key="", sender=self.user_id, diff --git a/tests/replication/tcp/streams/test_partial_state.py b/tests/replication/tcp/streams/test_partial_state.py
index 38b5020ce0..452ac85069 100644 --- a/tests/replication/tcp/streams/test_partial_state.py +++ b/tests/replication/tcp/streams/test_partial_state.py
@@ -37,7 +37,7 @@ class PartialStateStreamsTestCase(BaseMultiWorkerStreamTestCase): room_id = self.helper.create_room_as("@bob:test") # Mark the room as partial-stated. self.get_success( - self.store.store_partial_state_room(room_id, ["serv1", "serv2"], 0, "serv1") + self.store.store_partial_state_room(room_id, {"serv1", "serv2"}, 0, "serv1") ) worker = self.make_worker_hs("synapse.app.generic_worker") diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py
index 68de5d1cc2..5a38ac831f 100644 --- a/tests/replication/tcp/streams/test_typing.py +++ b/tests/replication/tcp/streams/test_typing.py
@@ -13,7 +13,7 @@ # limitations under the License. from unittest.mock import Mock -from synapse.handlers.typing import RoomMember +from synapse.handlers.typing import RoomMember, TypingWriterHandler from synapse.replication.tcp.streams import TypingStream from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -33,6 +33,7 @@ class TypingStreamTestCase(BaseStreamTestCase): def test_typing(self) -> None: typing = self.hs.get_typing_handler() + assert isinstance(typing, TypingWriterHandler) self.reconnect() @@ -88,6 +89,7 @@ class TypingStreamTestCase(BaseStreamTestCase): sends the proper position and RDATA). """ typing = self.hs.get_typing_handler() + assert isinstance(typing, TypingWriterHandler) self.reconnect() diff --git a/tests/replication/tcp/test_handler.py b/tests/replication/tcp/test_handler.py
index 6e4055cc21..bf927beb6a 100644 --- a/tests/replication/tcp/test_handler.py +++ b/tests/replication/tcp/test_handler.py
@@ -127,6 +127,7 @@ class ChannelsTestCase(BaseMultiWorkerStreamTestCase): # ... updating the cache ID gen on the master still shouldn't cause the # deferred to wake up. + assert store._cache_id_gen is not None ctx = store._cache_id_gen.get_next() self.get_success(ctx.__aenter__()) self.get_success(ctx.__aexit__(None, None, None)) diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py
index 89380e25b5..08703206a9 100644 --- a/tests/replication/test_federation_sender_shard.py +++ b/tests/replication/test_federation_sender_shard.py
@@ -16,6 +16,7 @@ from unittest.mock import Mock from synapse.api.constants import EventTypes, Membership from synapse.events.builder import EventBuilderFactory +from synapse.handlers.typing import TypingWriterHandler from synapse.rest.admin import register_servlets_for_client_rest_resource from synapse.rest.client import login, room from synapse.types import UserID, create_requester @@ -174,6 +175,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): token = self.login("user3", "pass") typing_handler = self.hs.get_typing_handler() + assert isinstance(typing_handler, TypingWriterHandler) sent_on_1 = False sent_on_2 = False diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py
index 9345cfbeb2..0798b021c3 100644 --- a/tests/replication/test_pusher_shard.py +++ b/tests/replication/test_pusher_shard.py
@@ -50,6 +50,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase): user_dict = self.get_success( self.hs.get_datastores().main.get_user_by_access_token(access_token) ) + assert user_dict is not None token_id = user_dict.token_id self.get_success( diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py
index aadb31ca83..db77a45ae3 100644 --- a/tests/rest/admin/test_media.py +++ b/tests/rest/admin/test_media.py
@@ -213,7 +213,8 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.admin_user_tok = self.login("admin", "pass") self.filepaths = MediaFilePaths(hs.config.media.media_store_path) - self.url = "/_synapse/admin/v1/media/%s/delete" % self.server_name + self.url = "/_synapse/admin/v1/media/delete" + self.legacy_url = "/_synapse/admin/v1/media/%s/delete" % self.server_name # Move clock up to somewhat realistic time self.reactor.advance(1000000000) @@ -332,11 +333,13 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): channel.json_body["error"], ) - def test_delete_media_never_accessed(self) -> None: + @parameterized.expand([(True,), (False,)]) + def test_delete_media_never_accessed(self, use_legacy_url: bool) -> None: """ Tests that media deleted if it is older than `before_ts` and never accessed `last_access_ts` is `NULL` and `created_ts` < `before_ts` """ + url = self.legacy_url if use_legacy_url else self.url # upload and do not access server_and_media_id = self._create_media() @@ -351,7 +354,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): now_ms = self.clock.time_msec() channel = self.make_request( "POST", - self.url + "?before_ts=" + str(now_ms), + url + "?before_ts=" + str(now_ms), access_token=self.admin_user_tok, ) self.assertEqual(200, channel.code, msg=channel.json_body) diff --git a/tests/rest/admin/test_server_notice.py b/tests/rest/admin/test_server_notice.py
index a2f347f666..f71ff46d87 100644 --- a/tests/rest/admin/test_server_notice.py +++ b/tests/rest/admin/test_server_notice.py
@@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List +from typing import List, Sequence from twisted.test.proto_helpers import MemoryReactor @@ -558,7 +558,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): def _check_invite_and_join_status( self, user_id: str, expected_invites: int, expected_memberships: int - ) -> List[RoomsForUser]: + ) -> Sequence[RoomsForUser]: """Check invite and room membership status of a user. Args diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 5c1ced355f..f5b213219f 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py
@@ -2913,7 +2913,8 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): other_user_tok = self.login("user", "pass") event_builder_factory = self.hs.get_event_builder_factory() event_creation_handler = self.hs.get_event_creation_handler() - storage_controllers = self.hs.get_storage_controllers() + persistence = self.hs.get_storage_controllers().persistence + assert persistence is not None # Create two rooms, one with a local user only and one with both a local # and remote user. @@ -2934,11 +2935,13 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): }, ) - event, context = self.get_success( + event, unpersisted_context = self.get_success( event_creation_handler.create_new_client_event(builder) ) - self.get_success(storage_controllers.persistence.persist_event(event, context)) + context = self.get_success(unpersisted_context.persist(event)) + + self.get_success(persistence.persist_event(event, context)) # Now get rooms url = "/_synapse/admin/v1/users/@joiner:remote_hs/joined_rooms" diff --git a/tests/rest/admin/test_username_available.py b/tests/rest/admin/test_username_available.py
index 30f12f1bff..6c04e6c56c 100644 --- a/tests/rest/admin/test_username_available.py +++ b/tests/rest/admin/test_username_available.py
@@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional + from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin @@ -33,9 +35,14 @@ class UsernameAvailableTestCase(unittest.HomeserverTestCase): self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") - async def check_username(username: str) -> bool: - if username == "allowed": - return True + async def check_username( + localpart: str, + guest_access_token: Optional[str] = None, + assigned_user_id: Optional[str] = None, + inhibit_user_in_use_error: bool = False, + ) -> None: + if localpart == "allowed": + return raise SynapseError( 400, "User ID already taken.", @@ -43,7 +50,7 @@ class UsernameAvailableTestCase(unittest.HomeserverTestCase): ) handler = self.hs.get_registration_handler() - handler.check_username = check_username + handler.check_username = check_username # type: ignore[assignment] def test_username_available(self) -> None: """ diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py
index 88f255c9ee..e2ee1a1766 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py
@@ -1193,7 +1193,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase): return {} # Register a mock that will return the expected result depending on the remote. - self.hs.get_federation_http_client().post_json = Mock(side_effect=post_json) + self.hs.get_federation_http_client().post_json = Mock(side_effect=post_json) # type: ignore[assignment] # Check that we've got the correct response from the client-side endpoint. self._test_status( diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py
index 208ec44829..a144610078 100644 --- a/tests/rest/client/test_auth.py +++ b/tests/rest/client/test_auth.py
@@ -34,7 +34,7 @@ from synapse.util import Clock from tests import unittest from tests.handlers.test_oidc import HAS_OIDC from tests.rest.client.utils import TEST_OIDC_CONFIG, TEST_OIDC_ISSUER -from tests.server import FakeChannel, make_request +from tests.server import FakeChannel from tests.unittest import override_config, skip_unless @@ -43,6 +43,9 @@ class DummyRecaptchaChecker(UserInteractiveAuthChecker): super().__init__(hs) self.recaptcha_attempts: List[Tuple[dict, str]] = [] + def is_enabled(self) -> bool: + return True + def check_auth(self, authdict: dict, clientip: str) -> Any: self.recaptcha_attempts.append((authdict, clientip)) return succeed(True) @@ -1319,16 +1322,8 @@ class OidcBackchannelLogoutTests(unittest.HomeserverTestCase): channel = self.submit_logout_token(logout_token) self.assertEqual(channel.code, 200) - # Now try to exchange the login token - channel = make_request( - self.hs.get_reactor(), - self.site, - "POST", - "/login", - content={"type": "m.login.token", "token": login_token}, - ) - # It should have failed - self.assertEqual(channel.code, 403) + # Now try to exchange the login token, it should fail. + self.helper.login_via_token(login_token, 403) @override_config( { diff --git a/tests/rest/client/test_filter.py b/tests/rest/client/test_filter.py
index afc8d641be..830762fd53 100644 --- a/tests/rest/client/test_filter.py +++ b/tests/rest/client/test_filter.py
@@ -63,14 +63,14 @@ class FilterTestCase(unittest.HomeserverTestCase): def test_add_filter_non_local_user(self) -> None: _is_mine = self.hs.is_mine - self.hs.is_mine = lambda target_user: False + self.hs.is_mine = lambda target_user: False # type: ignore[assignment] channel = self.make_request( "POST", "/_matrix/client/r0/user/%s/filter" % (self.user_id), self.EXAMPLE_FILTER_JSON, ) - self.hs.is_mine = _is_mine + self.hs.is_mine = _is_mine # type: ignore[assignment] self.assertEqual(channel.code, 403) self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN) diff --git a/tests/rest/client/test_presence.py b/tests/rest/client/test_presence.py
index b3738a0304..67e16880e6 100644 --- a/tests/rest/client/test_presence.py +++ b/tests/rest/client/test_presence.py
@@ -36,14 +36,14 @@ class PresenceTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - presence_handler = Mock(spec=PresenceHandler) - presence_handler.set_state.return_value = make_awaitable(None) + self.presence_handler = Mock(spec=PresenceHandler) + self.presence_handler.set_state.return_value = make_awaitable(None) hs = self.setup_test_homeserver( "red", federation_http_client=None, federation_client=Mock(), - presence_handler=presence_handler, + presence_handler=self.presence_handler, ) return hs @@ -61,7 +61,7 @@ class PresenceTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.code, HTTPStatus.OK) - self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 1) + self.assertEqual(self.presence_handler.set_state.call_count, 1) @unittest.override_config({"use_presence": False}) def test_put_presence_disabled(self) -> None: @@ -76,4 +76,4 @@ class PresenceTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.code, HTTPStatus.OK) - self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 0) + self.assertEqual(self.presence_handler.set_state.call_count, 0) diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py
index 11cf3939d8..4c561f9525 100644 --- a/tests/rest/client/test_register.py +++ b/tests/rest/client/test_register.py
@@ -151,7 +151,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") def test_POST_guest_registration(self) -> None: - self.hs.config.key.macaroon_secret_key = "test" + self.hs.config.key.macaroon_secret_key = b"test" self.hs.config.registration.allow_guest_access = True channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") @@ -1166,12 +1166,15 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase): """ user_id = self.register_user("kermit_delta", "user") - self.hs.config.account_validity.startup_job_max_delta = self.max_delta + self.hs.config.account_validity.account_validity_startup_job_max_delta = ( + self.max_delta + ) now_ms = self.hs.get_clock().time_msec() self.get_success(self.store._set_expiration_date_when_missing()) res = self.get_success(self.store.get_expiration_ts_for_user(user_id)) + assert res is not None self.assertGreaterEqual(res, now_ms + self.validity_period - self.max_delta) self.assertLessEqual(res, now_ms + self.validity_period) diff --git a/tests/rest/client/test_report_event.py b/tests/rest/client/test_report_event.py
index 7cb1017a4a..1250685d39 100644 --- a/tests/rest/client/test_report_event.py +++ b/tests/rest/client/test_report_event.py
@@ -73,6 +73,18 @@ class ReportEventTestCase(unittest.HomeserverTestCase): data = {"reason": None, "score": None} self._assert_status(400, data) + def test_cannot_report_nonexistent_event(self) -> None: + """ + Tests that we don't accept event reports for events which do not exist. + """ + channel = self.make_request( + "POST", + f"rooms/{self.room_id}/report/$nonsenseeventid:test", + {"reason": "i am very sad"}, + access_token=self.other_user_tok, + ) + self.assertEqual(404, channel.code, msg=channel.result["body"]) + def _assert_status(self, response_status: int, data: JsonDict) -> None: channel = self.make_request( "POST", self.report_path, data, access_token=self.other_user_tok diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index 9c8c1889d3..d3e06bf6b3 100644 --- a/tests/rest/client/test_retention.py +++ b/tests/rest/client/test_retention.py
@@ -136,6 +136,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): # Send a first event, which should be filtered out at the end of the test. resp = self.helper.send(room_id=room_id, body="1", tok=self.token) first_event_id = resp.get("event_id") + assert isinstance(first_event_id, str) # Advance the time by 2 days. We're using the default retention policy, therefore # after this the first event will still be valid. @@ -144,6 +145,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): # Send another event, which shouldn't get filtered out. resp = self.helper.send(room_id=room_id, body="2", tok=self.token) valid_event_id = resp.get("event_id") + assert isinstance(valid_event_id, str) # Advance the time by another 2 days. After this, the first event should be # outdated but not the second one. @@ -229,7 +231,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): # Check that we can still access state events that were sent before the event that # has been purged. - self.get_event(room_id, create_event.event_id) + self.get_event(room_id, bool(create_event)) def get_event(self, event_id: str, expect_none: bool = False) -> JsonDict: event = self.get_success(self.store.get_event(event_id, allow_none=True)) @@ -238,7 +240,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): self.assertIsNone(event) return {} - self.assertIsNotNone(event) + assert event is not None time_now = self.clock.time_msec() serialized = self.serializer.serialize_event(event, time_now) diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index 9222cab198..cfad182b2f 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py
@@ -3382,8 +3382,8 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase): # a remote IS. We keep the mock for make_and_store_3pid_invite around so we # can check its call_count later on during the test. make_invite_mock = Mock(return_value=make_awaitable((Mock(event_id="abc"), 0))) - self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock - self.hs.get_identity_handler().lookup_3pid = Mock( + self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock # type: ignore[assignment] + self.hs.get_identity_handler().lookup_3pid = Mock( # type: ignore[assignment] return_value=make_awaitable(None), ) @@ -3443,8 +3443,8 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase): # a remote IS. We keep the mock for make_and_store_3pid_invite around so we # can check its call_count later on during the test. make_invite_mock = Mock(return_value=make_awaitable((Mock(event_id="abc"), 0))) - self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock - self.hs.get_identity_handler().lookup_3pid = Mock( + self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock # type: ignore[assignment] + self.hs.get_identity_handler().lookup_3pid = Mock( # type: ignore[assignment] return_value=make_awaitable(None), ) @@ -3563,8 +3563,10 @@ class TimestampLookupTestCase(unittest.HomeserverTestCase): ) event.internal_metadata.outlier = True + persistence = self._storage_controllers.persistence + assert persistence is not None self.get_success( - self._storage_controllers.persistence.persist_event( + persistence.persist_event( event, EventContext.for_outlier(self._storage_controllers) ) ) diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py
index c807a37bc2..8d2cdf8751 100644 --- a/tests/rest/client/test_shadow_banned.py +++ b/tests/rest/client/test_shadow_banned.py
@@ -84,7 +84,7 @@ class RoomTestCase(_ShadowBannedBase): def test_invite_3pid(self) -> None: """Ensure that a 3PID invite does not attempt to contact the identity server.""" identity_handler = self.hs.get_identity_handler() - identity_handler.lookup_3pid = Mock( + identity_handler.lookup_3pid = Mock( # type: ignore[assignment] side_effect=AssertionError("This should not get called") ) @@ -222,7 +222,7 @@ class RoomTestCase(_ShadowBannedBase): event_source.get_new_events( user=UserID.from_string(self.other_user_id), from_key=0, - limit=None, + limit=10, room_ids=[room_id], is_guest=False, ) @@ -286,6 +286,7 @@ class ProfileTestCase(_ShadowBannedBase): self.banned_user_id, ) ) + assert event is not None self.assertEqual( event.content, {"membership": "join", "displayname": original_display_name} ) @@ -321,6 +322,7 @@ class ProfileTestCase(_ShadowBannedBase): self.banned_user_id, ) ) + assert event is not None self.assertEqual( event.content, {"membership": "join", "displayname": original_display_name} ) diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index 3325d43a2f..5fa3440691 100644 --- a/tests/rest/client/test_third_party_rules.py +++ b/tests/rest/client/test_third_party_rules.py
@@ -425,7 +425,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): async def test_fn( event: EventBase, state_events: StateMap[EventBase] ) -> Tuple[bool, Optional[JsonDict]]: - if event.is_state and event.type == EventTypes.PowerLevels: + if event.is_state() and event.type == EventTypes.PowerLevels: await api.create_and_send_event_into_room( { "room_id": event.room_id, diff --git a/tests/rest/client/test_upgrade_room.py b/tests/rest/client/test_upgrade_room.py
index 5ec343dd7f..0b4c691318 100644 --- a/tests/rest/client/test_upgrade_room.py +++ b/tests/rest/client/test_upgrade_room.py
@@ -84,7 +84,7 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): self.room_id, EventTypes.Tombstone, "" ) ) - self.assertIsNotNone(tombstone_event) + assert tombstone_event is not None self.assertEqual(new_room_id, tombstone_event.content["replacement_room"]) # Check that the new room exists. diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index 8d6f2b6ff9..9532e5ddc1 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py
@@ -36,6 +36,7 @@ from urllib.parse import urlencode import attr from typing_extensions import Literal +from twisted.test.proto_helpers import MemoryReactorClock from twisted.web.resource import Resource from twisted.web.server import Site @@ -67,6 +68,7 @@ class RestHelper: """ hs: HomeServer + reactor: MemoryReactorClock site: Site auth_user_id: Optional[str] @@ -142,7 +144,7 @@ class RestHelper: path = path + "?access_token=%s" % tok channel = make_request( - self.hs.get_reactor(), + self.reactor, self.site, "POST", path, @@ -216,7 +218,7 @@ class RestHelper: data["reason"] = reason channel = make_request( - self.hs.get_reactor(), + self.reactor, self.site, "POST", path, @@ -313,7 +315,7 @@ class RestHelper: data.update(extra_data or {}) channel = make_request( - self.hs.get_reactor(), + self.reactor, self.site, "PUT", path, @@ -394,7 +396,7 @@ class RestHelper: path = path + "?access_token=%s" % tok channel = make_request( - self.hs.get_reactor(), + self.reactor, self.site, "PUT", path, @@ -433,7 +435,7 @@ class RestHelper: path = path + f"?access_token={tok}" channel = make_request( - self.hs.get_reactor(), + self.reactor, self.site, "GET", path, @@ -488,7 +490,7 @@ class RestHelper: if body is not None: content = json.dumps(body).encode("utf8") - channel = make_request(self.hs.get_reactor(), self.site, method, path, content) + channel = make_request(self.reactor, self.site, method, path, content) assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % ( expect_code, @@ -573,8 +575,8 @@ class RestHelper: image_length = len(image_data) path = "/_matrix/media/r0/upload?filename=%s" % (filename,) channel = make_request( - self.hs.get_reactor(), - FakeSite(resource, self.hs.get_reactor()), + self.reactor, + FakeSite(resource, self.reactor), "POST", path, content=image_data, @@ -603,7 +605,7 @@ class RestHelper: expect_code: The return code to expect from attempting the whoami request """ channel = make_request( - self.hs.get_reactor(), + self.reactor, self.site, "GET", "account/whoami", @@ -642,7 +644,7 @@ class RestHelper: ) -> Tuple[JsonDict, FakeAuthorizationGrant]: """Log in (as a new user) via OIDC - Returns the result of the final token login. + Returns the result of the final token login and the fake authorization grant. Requires that "oidc_config" in the homeserver config be set appropriately (TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a @@ -672,10 +674,28 @@ class RestHelper: assert m, channel.text_body login_token = m.group(1) - # finally, submit the matrix login token to the login API, which gives us our - # matrix access token and device id. + return self.login_via_token(login_token, expected_status), grant + + def login_via_token( + self, + login_token: str, + expected_status: int = 200, + ) -> JsonDict: + """Submit the matrix login token to the login API, which gives us our + matrix access token and device id.Log in (as a new user) via OIDC + + Returns the result of the token login. + + Requires that "oidc_config" in the homeserver config be set appropriately + (TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a + "public_base_url". + + Also requires the login servlet and the OIDC callback resource to be mounted at + the normal places. + """ + channel = make_request( - self.hs.get_reactor(), + self.reactor, self.site, "POST", "/login", @@ -684,7 +704,7 @@ class RestHelper: assert ( channel.code == expected_status ), f"unexpected status in response: {channel.code}" - return channel.json_body, grant + return channel.json_body def auth_via_oidc( self, @@ -805,7 +825,7 @@ class RestHelper: with fake_serer.patch_homeserver(hs=self.hs): # now hit the callback URI with the right params and a made-up code channel = make_request( - self.hs.get_reactor(), + self.reactor, self.site, "GET", callback_uri, @@ -849,7 +869,7 @@ class RestHelper: # is the easiest way of figuring out what the Host header ought to be set to # to keep Synapse happy. channel = make_request( - self.hs.get_reactor(), + self.reactor, self.site, "GET", uri, @@ -867,7 +887,7 @@ class RestHelper: location = get_location(channel) parts = urllib.parse.urlsplit(location) channel = make_request( - self.hs.get_reactor(), + self.reactor, self.site, "GET", urllib.parse.urlunsplit(("", "") + parts[2:]), @@ -900,9 +920,7 @@ class RestHelper: + urllib.parse.urlencode({"session": ui_auth_session_id}) ) # hit the redirect url (which will issue a cookie and state) - channel = make_request( - self.hs.get_reactor(), self.site, "GET", sso_redirect_endpoint - ) + channel = make_request(self.reactor, self.site, "GET", sso_redirect_endpoint) # that should serve a confirmation page assert channel.code == HTTPStatus.OK, channel.text_body channel.extract_cookies(cookies) diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index d18fc13c21..17a3b06a8e 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py
@@ -16,7 +16,7 @@ import shutil import tempfile from binascii import unhexlify from io import BytesIO -from typing import Any, BinaryIO, Dict, List, Optional, Union +from typing import Any, BinaryIO, ClassVar, Dict, List, Optional, Tuple, Union from unittest.mock import Mock from urllib import parse @@ -32,6 +32,7 @@ from twisted.test.proto_helpers import MemoryReactor from synapse.api.errors import Codes from synapse.events import EventBase from synapse.events.spamcheck import load_legacy_spam_checkers +from synapse.http.types import QueryParams from synapse.logging.context import make_deferred_yieldable from synapse.module_api import ModuleApi from synapse.rest import admin @@ -41,7 +42,7 @@ from synapse.rest.media.v1.filepath import MediaFilePaths from synapse.rest.media.v1.media_storage import MediaStorage, ReadableFileWrapper from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend from synapse.server import HomeServer -from synapse.types import RoomAlias +from synapse.types import JsonDict, RoomAlias from synapse.util import Clock from tests import unittest @@ -201,36 +202,46 @@ class _TestImage: ], ) class MediaRepoTests(unittest.HomeserverTestCase): - + test_image: ClassVar[_TestImage] hijack_auth = True user_id = "@test:user" def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - self.fetches = [] + self.fetches: List[ + Tuple[ + "Deferred[Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]]", + str, + str, + Optional[QueryParams], + ] + ] = [] def get_file( destination: str, path: str, output_stream: BinaryIO, - args: Optional[Dict[str, Union[str, List[str]]]] = None, + args: Optional[QueryParams] = None, + retry_on_dns_fail: bool = True, max_size: Optional[int] = None, - ) -> Deferred: - """ - Returns tuple[int,dict,str,int] of file length, response headers, - absolute URI, and response code. - """ + ignore_backoff: bool = False, + ) -> "Deferred[Tuple[int, Dict[bytes, List[bytes]]]]": + """A mock for MatrixFederationHttpClient.get_file.""" - def write_to(r): + def write_to( + r: Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]] + ) -> Tuple[int, Dict[bytes, List[bytes]]]: data, response = r output_stream.write(data) return response - d = Deferred() - d.addCallback(write_to) + d: Deferred[Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]] = Deferred() self.fetches.append((d, destination, path, args)) - return make_deferred_yieldable(d) + # Note that this callback changes the value held by d. + d_after_callback = d.addCallback(write_to) + return make_deferred_yieldable(d_after_callback) + # Mock out the homeserver's MatrixFederationHttpClient client = Mock() client.get_file = get_file @@ -461,6 +472,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): # Synapse should regenerate missing thumbnails. origin, media_id = self.media_id.split("/") info = self.get_success(self.store.get_cached_remote_media(origin, media_id)) + assert info is not None file_id = info["filesystem_id"] thumbnail_dir = self.media_repo.filepaths.remote_media_thumbnail_dir( @@ -581,7 +593,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): "thumbnail_method": method, "thumbnail_type": self.test_image.content_type, "thumbnail_length": 256, - "filesystem_id": f"thumbnail1{self.test_image.extension}", + "filesystem_id": f"thumbnail1{self.test_image.extension.decode()}", }, { "thumbnail_width": 32, @@ -589,10 +601,10 @@ class MediaRepoTests(unittest.HomeserverTestCase): "thumbnail_method": method, "thumbnail_type": self.test_image.content_type, "thumbnail_length": 256, - "filesystem_id": f"thumbnail2{self.test_image.extension}", + "filesystem_id": f"thumbnail2{self.test_image.extension.decode()}", }, ], - file_id=f"image{self.test_image.extension}", + file_id=f"image{self.test_image.extension.decode()}", url_cache=None, server_name=None, ) @@ -637,6 +649,7 @@ class TestSpamCheckerLegacy: self.config = config self.api = api + @staticmethod def parse_config(config: Dict[str, Any]) -> Dict[str, Any]: return config @@ -748,7 +761,7 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase): async def check_media_file_for_spam( self, file_wrapper: ReadableFileWrapper, file_info: FileInfo - ) -> Union[Codes, Literal["NOT_SPAM"]]: + ) -> Union[Codes, Literal["NOT_SPAM"], Tuple[Codes, JsonDict]]: buf = BytesIO() await file_wrapper.write_chunks_to(buf.write) diff --git a/tests/scripts/test_new_matrix_user.py b/tests/scripts/test_new_matrix_user.py
index 22f99c6ab1..3285f2433c 100644 --- a/tests/scripts/test_new_matrix_user.py +++ b/tests/scripts/test_new_matrix_user.py
@@ -12,29 +12,33 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List +from typing import List, Optional from unittest.mock import Mock, patch from synapse._scripts.register_new_matrix_user import request_registration +from synapse.types import JsonDict from tests.unittest import TestCase class RegisterTestCase(TestCase): - def test_success(self): + def test_success(self) -> None: """ The script will fetch a nonce, and then generate a MAC with it, and then post that MAC. """ - def get(url, verify=None): + def get(url: str, verify: Optional[bool] = None) -> Mock: r = Mock() r.status_code = 200 r.json = lambda: {"nonce": "a"} return r - def post(url, json=None, verify=None): + def post( + url: str, json: Optional[JsonDict] = None, verify: Optional[bool] = None + ) -> Mock: # Make sure we are sent the correct info + assert json is not None self.assertEqual(json["username"], "user") self.assertEqual(json["password"], "pass") self.assertEqual(json["nonce"], "a") @@ -70,12 +74,12 @@ class RegisterTestCase(TestCase): # sys.exit shouldn't have been called. self.assertEqual(err_code, []) - def test_failure_nonce(self): + def test_failure_nonce(self) -> None: """ If the script fails to fetch a nonce, it throws an error and quits. """ - def get(url, verify=None): + def get(url: str, verify: Optional[bool] = None) -> Mock: r = Mock() r.status_code = 404 r.reason = "Not Found" @@ -107,20 +111,23 @@ class RegisterTestCase(TestCase): self.assertIn("ERROR! Received 404 Not Found", out) self.assertNotIn("Success!", out) - def test_failure_post(self): + def test_failure_post(self) -> None: """ The script will fetch a nonce, and then if the final POST fails, will report an error and quit. """ - def get(url, verify=None): + def get(url: str, verify: Optional[bool] = None) -> Mock: r = Mock() r.status_code = 200 r.json = lambda: {"nonce": "a"} return r - def post(url, json=None, verify=None): + def post( + url: str, json: Optional[JsonDict] = None, verify: Optional[bool] = None + ) -> Mock: # Make sure we are sent the correct info + assert json is not None self.assertEqual(json["username"], "user") self.assertEqual(json["password"], "pass") self.assertEqual(json["nonce"], "a") diff --git a/tests/server.py b/tests/server.py
index 237bcad8ba..5de9722766 100644 --- a/tests/server.py +++ b/tests/server.py
@@ -22,20 +22,25 @@ import warnings from collections import deque from io import SEEK_END, BytesIO from typing import ( + Any, + Awaitable, Callable, Dict, Iterable, List, MutableMapping, Optional, + Sequence, Tuple, Type, + TypeVar, Union, + cast, ) from unittest.mock import Mock import attr -from typing_extensions import Deque +from typing_extensions import Deque, ParamSpec from zope.interface import implementer from twisted.internet import address, threads, udp @@ -44,8 +49,10 @@ from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed from twisted.internet.error import DNSLookupError from twisted.internet.interfaces import ( IAddress, + IConnector, IConsumer, IHostnameResolver, + IProducer, IProtocol, IPullProducer, IPushProducer, @@ -54,6 +61,8 @@ from twisted.internet.interfaces import ( IResolverSimple, ITransport, ) +from twisted.internet.protocol import ClientFactory, DatagramProtocol +from twisted.python import threadpool from twisted.python.failure import Failure from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock from twisted.web.http_headers import Headers @@ -61,6 +70,7 @@ from twisted.web.resource import IResource from twisted.web.server import Request, Site from synapse.config.database import DatabaseConnectionConfig +from synapse.config.homeserver import HomeServerConfig from synapse.events.presence_router import load_legacy_presence_router from synapse.events.spamcheck import load_legacy_spam_checkers from synapse.events.third_party_rules import load_legacy_third_party_event_rules @@ -88,6 +98,9 @@ from tests.utils import ( logger = logging.getLogger(__name__) +R = TypeVar("R") +P = ParamSpec("P") + # the type of thing that can be passed into `make_request` in the headers list CustomHeaderType = Tuple[Union[str, bytes], Union[str, bytes]] @@ -98,12 +111,14 @@ class TimedOutException(Exception): """ -@implementer(IConsumer) +@implementer(ITransport, IPushProducer, IConsumer) @attr.s(auto_attribs=True) class FakeChannel: """ A fake Twisted Web Channel (the part that interfaces with the wire). + + See twisted.web.http.HTTPChannel. """ site: Union[Site, "FakeSite"] @@ -142,7 +157,7 @@ class FakeChannel: Raises an exception if the request has not yet completed. """ - if not self.is_finished: + if not self.is_finished(): raise Exception("Request not yet completed") return self.result["body"].decode("utf8") @@ -165,27 +180,36 @@ class FakeChannel: h.addRawHeader(*i) return h - def writeHeaders(self, version, code, reason, headers): + def writeHeaders( + self, version: bytes, code: bytes, reason: bytes, headers: Headers + ) -> None: self.result["version"] = version self.result["code"] = code self.result["reason"] = reason self.result["headers"] = headers - def write(self, content: bytes) -> None: - assert isinstance(content, bytes), "Should be bytes! " + repr(content) + def write(self, data: bytes) -> None: + assert isinstance(data, bytes), "Should be bytes! " + repr(data) if "body" not in self.result: self.result["body"] = b"" - self.result["body"] += content + self.result["body"] += data + + def writeSequence(self, data: Iterable[bytes]) -> None: + for x in data: + self.write(x) + + def loseConnection(self) -> None: + self.unregisterProducer() + self.transport.loseConnection() # Type ignore: mypy doesn't like the fact that producer isn't an IProducer. - def registerProducer( # type: ignore[override] - self, - producer: Union[IPullProducer, IPushProducer], - streaming: bool, - ) -> None: - self._producer = producer + def registerProducer(self, producer: IProducer, streaming: bool) -> None: + # TODO This should ensure that the IProducer is an IPushProducer or + # IPullProducer, unfortunately twisted.protocols.basic.FileSender does + # implement those, but doesn't declare it. + self._producer = cast(Union[IPushProducer, IPullProducer], producer) self.producerStreaming = streaming def _produce() -> None: @@ -202,6 +226,16 @@ class FakeChannel: self._producer = None + def stopProducing(self) -> None: + if self._producer is not None: + self._producer.stopProducing() + + def pauseProducing(self) -> None: + raise NotImplementedError() + + def resumeProducing(self) -> None: + raise NotImplementedError() + def requestDone(self, _self: Request) -> None: self.result["done"] = True if isinstance(_self, SynapseRequest): @@ -281,12 +315,12 @@ class FakeSite: self.reactor = reactor self.experimental_cors_msc3886 = experimental_cors_msc3886 - def getResourceFor(self, request): + def getResourceFor(self, request: Request) -> IResource: return self._resource def make_request( - reactor, + reactor: MemoryReactorClock, site: Union[Site, FakeSite], method: Union[bytes, str], path: Union[bytes, str], @@ -409,19 +443,21 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): A MemoryReactorClock that supports callFromThread. """ - def __init__(self): + def __init__(self) -> None: self.threadpool = ThreadPool(self) self._tcp_callbacks: Dict[Tuple[str, int], Callable] = {} - self._udp = [] + self._udp: List[udp.Port] = [] self.lookups: Dict[str, str] = {} - self._thread_callbacks: Deque[Callable[[], None]] = deque() + self._thread_callbacks: Deque[Callable[..., R]] = deque() lookups = self.lookups @implementer(IResolverSimple) class FakeResolver: - def getHostByName(self, name, timeout=None): + def getHostByName( + self, name: str, timeout: Optional[Sequence[int]] = None + ) -> "Deferred[str]": if name not in lookups: return fail(DNSLookupError("OH NO: unknown %s" % (name,))) return succeed(lookups[name]) @@ -432,25 +468,44 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): def installNameResolver(self, resolver: IHostnameResolver) -> IHostnameResolver: raise NotImplementedError() - def listenUDP(self, port, protocol, interface="", maxPacketSize=8196): + def listenUDP( + self, + port: int, + protocol: DatagramProtocol, + interface: str = "", + maxPacketSize: int = 8196, + ) -> udp.Port: p = udp.Port(port, protocol, interface, maxPacketSize, self) p.startListening() self._udp.append(p) return p - def callFromThread(self, callback, *args, **kwargs): + def callFromThread( + self, callable: Callable[..., Any], *args: object, **kwargs: object + ) -> None: """ Make the callback fire in the next reactor iteration. """ - cb = lambda: callback(*args, **kwargs) + cb = lambda: callable(*args, **kwargs) # it's not safe to call callLater() here, so we append the callback to a # separate queue. self._thread_callbacks.append(cb) - def getThreadPool(self): - return self.threadpool + def callInThread( + self, callable: Callable[..., Any], *args: object, **kwargs: object + ) -> None: + raise NotImplementedError() + + def suggestThreadPoolSize(self, size: int) -> None: + raise NotImplementedError() + + def getThreadPool(self) -> "threadpool.ThreadPool": + # Cast to match super-class. + return cast(threadpool.ThreadPool, self.threadpool) - def add_tcp_client_callback(self, host: str, port: int, callback: Callable): + def add_tcp_client_callback( + self, host: str, port: int, callback: Callable[[], None] + ) -> None: """Add a callback that will be invoked when we receive a connection attempt to the given IP/port using `connectTCP`. @@ -459,7 +514,14 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): """ self._tcp_callbacks[(host, port)] = callback - def connectTCP(self, host: str, port: int, factory, timeout=30, bindAddress=None): + def connectTCP( + self, + host: str, + port: int, + factory: ClientFactory, + timeout: float = 30, + bindAddress: Optional[Tuple[str, int]] = None, + ) -> IConnector: """Fake L{IReactorTCP.connectTCP}.""" conn = super().connectTCP( @@ -472,7 +534,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): return conn - def advance(self, amount): + def advance(self, amount: float) -> None: # first advance our reactor's time, and run any "callLater" callbacks that # makes ready super().advance(amount) @@ -500,25 +562,33 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): class ThreadPool: """ Threadless thread pool. + + See twisted.python.threadpool.ThreadPool """ - def __init__(self, reactor): + def __init__(self, reactor: IReactorTime): self._reactor = reactor - def start(self): + def start(self) -> None: pass - def stop(self): + def stop(self) -> None: pass - def callInThreadWithCallback(self, onResult, function, *args, **kwargs): - def _(res): + def callInThreadWithCallback( + self, + onResult: Callable[[bool, Union[Failure, R]], None], + function: Callable[P, R], + *args: P.args, + **kwargs: P.kwargs, + ) -> "Deferred[None]": + def _(res: Any) -> None: if isinstance(res, Failure): onResult(False, res) else: onResult(True, res) - d = Deferred() + d: "Deferred[None]" = Deferred() d.addCallback(lambda x: function(*args, **kwargs)) d.addBoth(_) self._reactor.callLater(0, d.callback, True) @@ -535,7 +605,9 @@ def _make_test_homeserver_synchronous(server: HomeServer) -> None: for database in server.get_datastores().databases: pool = database._db_pool - def runWithConnection(func, *args, **kwargs): + def runWithConnection( + func: Callable[..., R], *args: Any, **kwargs: Any + ) -> Awaitable[R]: return threads.deferToThreadPool( pool._reactor, pool.threadpool, @@ -545,20 +617,23 @@ def _make_test_homeserver_synchronous(server: HomeServer) -> None: **kwargs, ) - def runInteraction(interaction, *args, **kwargs): + def runInteraction( + desc: str, func: Callable[..., R], *args: Any, **kwargs: Any + ) -> Awaitable[R]: return threads.deferToThreadPool( pool._reactor, pool.threadpool, pool._runInteraction, - interaction, + desc, + func, *args, **kwargs, ) - pool.runWithConnection = runWithConnection - pool.runInteraction = runInteraction + pool.runWithConnection = runWithConnection # type: ignore[assignment] + pool.runInteraction = runInteraction # type: ignore[assignment] # Replace the thread pool with a threadless 'thread' pool - pool.threadpool = ThreadPool(clock._reactor) + pool.threadpool = ThreadPool(clock._reactor) # type: ignore[assignment] pool.running = True # We've just changed the Databases to run DB transactions on the same @@ -573,7 +648,7 @@ def get_clock() -> Tuple[ThreadedMemoryReactorClock, Clock]: @implementer(ITransport) -@attr.s(cmp=False) +@attr.s(cmp=False, auto_attribs=True) class FakeTransport: """ A twisted.internet.interfaces.ITransport implementation which sends all its data @@ -588,48 +663,50 @@ class FakeTransport: If you want bidirectional communication, you'll need two instances. """ - other = attr.ib() + other: IProtocol """The Protocol object which will receive any data written to this transport. - - :type: twisted.internet.interfaces.IProtocol """ - _reactor = attr.ib() + _reactor: IReactorTime """Test reactor - - :type: twisted.internet.interfaces.IReactorTime """ - _protocol = attr.ib(default=None) + _protocol: Optional[IProtocol] = None """The Protocol which is producing data for this transport. Optional, but if set will get called back for connectionLost() notifications etc. """ - _peer_address: Optional[IAddress] = attr.ib(default=None) + _peer_address: IAddress = attr.Factory( + lambda: address.IPv4Address("TCP", "127.0.0.1", 5678) + ) """The value to be returned by getPeer""" - _host_address: Optional[IAddress] = attr.ib(default=None) + _host_address: IAddress = attr.Factory( + lambda: address.IPv4Address("TCP", "127.0.0.1", 1234) + ) """The value to be returned by getHost""" disconnecting = False disconnected = False connected = True - buffer = attr.ib(default=b"") - producer = attr.ib(default=None) - autoflush = attr.ib(default=True) + buffer: bytes = b"" + producer: Optional[IPushProducer] = None + autoflush: bool = True - def getPeer(self) -> Optional[IAddress]: + def getPeer(self) -> IAddress: return self._peer_address - def getHost(self) -> Optional[IAddress]: + def getHost(self) -> IAddress: return self._host_address - def loseConnection(self, reason=None): + def loseConnection(self) -> None: if not self.disconnecting: - logger.info("FakeTransport: loseConnection(%s)", reason) + logger.info("FakeTransport: loseConnection()") self.disconnecting = True if self._protocol: - self._protocol.connectionLost(reason) + self._protocol.connectionLost( + Failure(RuntimeError("FakeTransport.loseConnection()")) + ) # if we still have data to write, delay until that is done if self.buffer: @@ -640,38 +717,38 @@ class FakeTransport: self.connected = False self.disconnected = True - def abortConnection(self): + def abortConnection(self) -> None: logger.info("FakeTransport: abortConnection()") if not self.disconnecting: self.disconnecting = True if self._protocol: - self._protocol.connectionLost(None) + self._protocol.connectionLost(None) # type: ignore[arg-type] self.disconnected = True - def pauseProducing(self): + def pauseProducing(self) -> None: if not self.producer: return self.producer.pauseProducing() - def resumeProducing(self): + def resumeProducing(self) -> None: if not self.producer: return self.producer.resumeProducing() - def unregisterProducer(self): + def unregisterProducer(self) -> None: if not self.producer: return self.producer = None - def registerProducer(self, producer, streaming): + def registerProducer(self, producer: IPushProducer, streaming: bool) -> None: self.producer = producer self.producerStreaming = streaming - def _produce(): + def _produce() -> None: if not self.producer: # we've been unregistered return @@ -683,7 +760,7 @@ class FakeTransport: if not streaming: self._reactor.callLater(0.0, _produce) - def write(self, byt): + def write(self, byt: bytes) -> None: if self.disconnecting: raise Exception("Writing to disconnecting FakeTransport") @@ -695,11 +772,11 @@ class FakeTransport: if self.autoflush: self._reactor.callLater(0.0, self.flush) - def writeSequence(self, seq): + def writeSequence(self, seq: Iterable[bytes]) -> None: for x in seq: self.write(x) - def flush(self, maxbytes=None): + def flush(self, maxbytes: Optional[int] = None) -> None: if not self.buffer: # nothing to do. Don't write empty buffers: it upsets the # TLSMemoryBIOProtocol @@ -750,17 +827,17 @@ def connect_client( class TestHomeServer(HomeServer): - DATASTORE_CLASS = DataStore + DATASTORE_CLASS = DataStore # type: ignore[assignment] def setup_test_homeserver( - cleanup_func, - name="test", - config=None, - reactor=None, + cleanup_func: Callable[[Callable[[], None]], None], + name: str = "test", + config: Optional[HomeServerConfig] = None, + reactor: Optional[ISynapseReactor] = None, homeserver_to_use: Type[HomeServer] = TestHomeServer, - **kwargs, -): + **kwargs: Any, +) -> HomeServer: """ Setup a homeserver suitable for running tests against. Keyword arguments are passed to the Homeserver constructor. @@ -775,13 +852,14 @@ def setup_test_homeserver( HomeserverTestCase. """ if reactor is None: - from twisted.internet import reactor + from twisted.internet import reactor as _reactor + + reactor = cast(ISynapseReactor, _reactor) if config is None: config = default_config(name, parse=True) config.caches.resize_all_caches() - config.ldap_enabled = False if "clock" not in kwargs: kwargs["clock"] = MockClock() @@ -832,6 +910,8 @@ def setup_test_homeserver( # Create the database before we actually try and connect to it, based off # the template database we generate in setupdb() if isinstance(db_engine, PostgresEngine): + import psycopg2.extensions + db_conn = db_engine.module.connect( database=POSTGRES_BASE_DB, user=POSTGRES_USER, @@ -839,6 +919,7 @@ def setup_test_homeserver( port=POSTGRES_PORT, password=POSTGRES_PASSWORD, ) + assert isinstance(db_conn, psycopg2.extensions.connection) db_conn.autocommit = True cur = db_conn.cursor() cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,)) @@ -867,14 +948,15 @@ def setup_test_homeserver( hs.setup_background_tasks() if isinstance(db_engine, PostgresEngine): - database = hs.get_datastores().databases[0] + database_pool = hs.get_datastores().databases[0] # We need to do cleanup on PostgreSQL - def cleanup(): + def cleanup() -> None: import psycopg2 + import psycopg2.extensions # Close all the db pools - database._db_pool.close() + database_pool._db_pool.close() dropped = False @@ -886,6 +968,7 @@ def setup_test_homeserver( port=POSTGRES_PORT, password=POSTGRES_PASSWORD, ) + assert isinstance(db_conn, psycopg2.extensions.connection) db_conn.autocommit = True cur = db_conn.cursor() @@ -918,23 +1001,23 @@ def setup_test_homeserver( # Need to let the HS build an auth handler and then mess with it # because AuthHandler's constructor requires the HS, so we can't make one # beforehand and pass it in to the HS's constructor (chicken / egg) - async def hash(p): + async def hash(p: str) -> str: return hashlib.md5(p.encode("utf8")).hexdigest() - hs.get_auth_handler().hash = hash + hs.get_auth_handler().hash = hash # type: ignore[assignment] - async def validate_hash(p, h): + async def validate_hash(p: str, h: str) -> bool: return hashlib.md5(p.encode("utf8")).hexdigest() == h - hs.get_auth_handler().validate_hash = validate_hash + hs.get_auth_handler().validate_hash = validate_hash # type: ignore[assignment] # Make the threadpool and database transactions synchronous for testing. _make_test_homeserver_synchronous(hs) # Load any configured modules into the homeserver module_api = hs.get_module_api() - for module, config in hs.config.modules.loaded_modules: - module(config=config, api=module_api) + for module, module_config in hs.config.modules.loaded_modules: + module(config=module_config, api=module_api) load_legacy_spam_checkers(hs) load_legacy_third_party_event_rules(hs) diff --git a/tests/server_notices/test_consent.py b/tests/server_notices/test_consent.py
index 58b399a043..6540ed53f1 100644 --- a/tests/server_notices/test_consent.py +++ b/tests/server_notices/test_consent.py
@@ -14,8 +14,12 @@ import os +from twisted.test.proto_helpers import MemoryReactor + import synapse.rest.admin from synapse.rest.client import login, room, sync +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest @@ -29,7 +33,7 @@ class ConsentNoticesTests(unittest.HomeserverTestCase): room.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: tmpdir = self.mktemp() os.mkdir(tmpdir) @@ -53,15 +57,13 @@ class ConsentNoticesTests(unittest.HomeserverTestCase): "room_name": "Server Notices", } - hs = self.setup_test_homeserver(config=config) - - return hs + return self.setup_test_homeserver(config=config) - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.user_id = self.register_user("bob", "abc123") self.access_token = self.login("bob", "abc123") - def test_get_sync_message(self): + def test_get_sync_message(self) -> None: """ When user consent server notices are enabled, a sync will cause a notice to fire (in a room which the user is invited to). The notice contains diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index dadc6efcbf..d2bfa53eda 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -24,6 +24,8 @@ from synapse.server import HomeServer from synapse.server_notices.resource_limits_server_notices import ( ResourceLimitsServerNotices, ) +from synapse.server_notices.server_notices_sender import ServerNoticesSender +from synapse.types import JsonDict from synapse.util import Clock from tests import unittest @@ -33,7 +35,7 @@ from tests.utils import default_config class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): - def default_config(self): + def default_config(self) -> JsonDict: config = default_config("test") config.update( @@ -57,14 +59,15 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): return config def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.server_notices_sender = self.hs.get_server_notices_sender() + server_notices_sender = self.hs.get_server_notices_sender() + assert isinstance(server_notices_sender, ServerNoticesSender) # relying on [1] is far from ideal, but the only case where # ResourceLimitsServerNotices class needs to be isolated is this test, # general code should never have a reason to do so ... - self._rlsn = self.server_notices_sender._server_notices[1] - if not isinstance(self._rlsn, ResourceLimitsServerNotices): - raise Exception("Failed to find reference to ResourceLimitsServerNotices") + rlsn = list(server_notices_sender._server_notices)[1] + assert isinstance(rlsn, ResourceLimitsServerNotices) + self._rlsn = rlsn self._rlsn._store.user_last_seen_monthly_active = Mock( return_value=make_awaitable(1000) @@ -86,39 +89,43 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): self._rlsn._store.get_tags_for_room = Mock(return_value=make_awaitable({})) # type: ignore[assignment] @override_config({"hs_disabled": True}) - def test_maybe_send_server_notice_disabled_hs(self): + def test_maybe_send_server_notice_disabled_hs(self) -> None: """If the HS is disabled, we should not send notices""" self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) self._send_notice.assert_not_called() @override_config({"limit_usage_by_mau": False}) - def test_maybe_send_server_notice_to_user_flag_off(self): + def test_maybe_send_server_notice_to_user_flag_off(self) -> None: """If mau limiting is disabled, we should not send notices""" self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) self._send_notice.assert_not_called() - def test_maybe_send_server_notice_to_user_remove_blocked_notice(self): + def test_maybe_send_server_notice_to_user_remove_blocked_notice(self) -> None: """Test when user has blocked notice, but should have it removed""" - self._rlsn._auth_blocking.check_auth_blocking = Mock( + self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment] return_value=make_awaitable(None) ) mock_event = Mock( type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} ) - self._rlsn._store.get_events = Mock( + self._rlsn._store.get_events = Mock( # type: ignore[assignment] return_value=make_awaitable({"123": mock_event}) ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) # Would be better to check the content, but once == remove blocking event - self._rlsn._server_notices_manager.maybe_get_notice_room_for_user.assert_called_once() + maybe_get_notice_room_for_user = ( + self._rlsn._server_notices_manager.maybe_get_notice_room_for_user + ) + assert isinstance(maybe_get_notice_room_for_user, Mock) + maybe_get_notice_room_for_user.assert_called_once() self._send_notice.assert_called_once() - def test_maybe_send_server_notice_to_user_remove_blocked_notice_noop(self): + def test_maybe_send_server_notice_to_user_remove_blocked_notice_noop(self) -> None: """ Test when user has blocked notice, but notice ought to be there (NOOP) """ - self._rlsn._auth_blocking.check_auth_blocking = Mock( + self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment] return_value=make_awaitable(None), side_effect=ResourceLimitError(403, "foo"), ) @@ -126,7 +133,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): mock_event = Mock( type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} ) - self._rlsn._store.get_events = Mock( + self._rlsn._store.get_events = Mock( # type: ignore[assignment] return_value=make_awaitable({"123": mock_event}) ) @@ -134,11 +141,11 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): self._send_notice.assert_not_called() - def test_maybe_send_server_notice_to_user_add_blocked_notice(self): + def test_maybe_send_server_notice_to_user_add_blocked_notice(self) -> None: """ Test when user does not have blocked notice, but should have one """ - self._rlsn._auth_blocking.check_auth_blocking = Mock( + self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment] return_value=make_awaitable(None), side_effect=ResourceLimitError(403, "foo"), ) @@ -147,11 +154,11 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): # Would be better to check contents, but 2 calls == set blocking event self.assertEqual(self._send_notice.call_count, 2) - def test_maybe_send_server_notice_to_user_add_blocked_notice_noop(self): + def test_maybe_send_server_notice_to_user_add_blocked_notice_noop(self) -> None: """ Test when user does not have blocked notice, nor should they (NOOP) """ - self._rlsn._auth_blocking.check_auth_blocking = Mock( + self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment] return_value=make_awaitable(None) ) @@ -159,12 +166,12 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): self._send_notice.assert_not_called() - def test_maybe_send_server_notice_to_user_not_in_mau_cohort(self): + def test_maybe_send_server_notice_to_user_not_in_mau_cohort(self) -> None: """ Test when user is not part of the MAU cohort - this should not ever happen - but ... """ - self._rlsn._auth_blocking.check_auth_blocking = Mock( + self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment] return_value=make_awaitable(None) ) self._rlsn._store.user_last_seen_monthly_active = Mock( @@ -175,12 +182,14 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): self._send_notice.assert_not_called() @override_config({"mau_limit_alerting": False}) - def test_maybe_send_server_notice_when_alerting_suppressed_room_unblocked(self): + def test_maybe_send_server_notice_when_alerting_suppressed_room_unblocked( + self, + ) -> None: """ Test that when server is over MAU limit and alerting is suppressed, then an alert message is not sent into the room """ - self._rlsn._auth_blocking.check_auth_blocking = Mock( + self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment] return_value=make_awaitable(None), side_effect=ResourceLimitError( 403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER @@ -191,11 +200,11 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): self.assertEqual(self._send_notice.call_count, 0) @override_config({"mau_limit_alerting": False}) - def test_check_hs_disabled_unaffected_by_mau_alert_suppression(self): + def test_check_hs_disabled_unaffected_by_mau_alert_suppression(self) -> None: """ Test that when a server is disabled, that MAU limit alerting is ignored. """ - self._rlsn._auth_blocking.check_auth_blocking = Mock( + self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment] return_value=make_awaitable(None), side_effect=ResourceLimitError( 403, "foo", limit_type=LimitBlockingTypes.HS_DISABLED @@ -207,26 +216,28 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): self.assertEqual(self._send_notice.call_count, 2) @override_config({"mau_limit_alerting": False}) - def test_maybe_send_server_notice_when_alerting_suppressed_room_blocked(self): + def test_maybe_send_server_notice_when_alerting_suppressed_room_blocked( + self, + ) -> None: """ When the room is already in a blocked state, test that when alerting is suppressed that the room is returned to an unblocked state. """ - self._rlsn._auth_blocking.check_auth_blocking = Mock( + self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment] return_value=make_awaitable(None), side_effect=ResourceLimitError( 403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER ), ) - self._rlsn._server_notices_manager.__is_room_currently_blocked = Mock( + self._rlsn._is_room_currently_blocked = Mock( # type: ignore[assignment] return_value=make_awaitable((True, [])) ) mock_event = Mock( type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} ) - self._rlsn._store.get_events = Mock( + self._rlsn._store.get_events = Mock( # type: ignore[assignment] return_value=make_awaitable({"123": mock_event}) ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) @@ -242,7 +253,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): sync.register_servlets, ] - def default_config(self): + def default_config(self) -> JsonDict: c = super().default_config() c["server_notices"] = { "system_mxid_localpart": "server", @@ -257,20 +268,22 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = self.hs.get_datastores().main - self.server_notices_sender = self.hs.get_server_notices_sender() self.server_notices_manager = self.hs.get_server_notices_manager() self.event_source = self.hs.get_event_sources() + server_notices_sender = self.hs.get_server_notices_sender() + assert isinstance(server_notices_sender, ServerNoticesSender) + # relying on [1] is far from ideal, but the only case where # ResourceLimitsServerNotices class needs to be isolated is this test, # general code should never have a reason to do so ... - self._rlsn = self.server_notices_sender._server_notices[1] - if not isinstance(self._rlsn, ResourceLimitsServerNotices): - raise Exception("Failed to find reference to ResourceLimitsServerNotices") + rlsn = list(server_notices_sender._server_notices)[1] + assert isinstance(rlsn, ResourceLimitsServerNotices) + self._rlsn = rlsn self.user_id = "@user_id:test" - def test_server_notice_only_sent_once(self): + def test_server_notice_only_sent_once(self) -> None: self.store.get_monthly_active_count = Mock(return_value=make_awaitable(1000)) self.store.user_last_seen_monthly_active = Mock( @@ -306,7 +319,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): self.assertEqual(count, 1) - def test_no_invite_without_notice(self): + def test_no_invite_without_notice(self) -> None: """Tests that a user doesn't get invited to a server notices room without a server notice being sent. @@ -328,7 +341,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): m.assert_called_once_with(user_id) - def test_invite_with_notice(self): + def test_invite_with_notice(self) -> None: """Tests that, if the MAU limit is hit, the server notices user invites each user to a room in which it has sent a notice. """ diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py
index 9f33afcca0..9606ecc43b 100644 --- a/tests/storage/databases/main/test_events_worker.py +++ b/tests/storage/databases/main/test_events_worker.py
@@ -120,6 +120,7 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase): # Persist the event which should invalidate or prefill the # `have_seen_event` cache so we don't return stale values. persistence = self.hs.get_storage_controllers().persistence + assert persistence is not None self.get_success( persistence.persist_event( event, diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index c070278db8..a10e5fa8b1 100644 --- a/tests/storage/test_event_chain.py +++ b/tests/storage/test_event_chain.py
@@ -389,6 +389,7 @@ class EventChainStoreTestCase(HomeserverTestCase): """ persist_events_store = self.hs.get_datastores().persist_events + assert persist_events_store is not None for e in events: e.internal_metadata.stream_ordering = self._next_stream_ordering @@ -397,6 +398,7 @@ class EventChainStoreTestCase(HomeserverTestCase): def _persist(txn: LoggingTransaction) -> None: # We need to persist the events to the events and state_events # tables. + assert persist_events_store is not None persist_events_store._store_event_txn( txn, [(e, EventContext(self.hs.get_storage_controllers())) for e in events], @@ -540,7 +542,9 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): self.requester, events_and_context=[(event, context)] ) ) - state1 = set(self.get_success(context.get_current_state_ids()).values()) + state_ids1 = self.get_success(context.get_current_state_ids()) + assert state_ids1 is not None + state1 = set(state_ids1.values()) event, context = self.get_success( event_handler.create_event( @@ -560,7 +564,9 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): self.requester, events_and_context=[(event, context)] ) ) - state2 = set(self.get_success(context.get_current_state_ids()).values()) + state_ids2 = self.get_success(context.get_current_state_ids()) + assert state_ids2 is not None + state2 = set(state_ids2.values()) # Delete the chain cover info. diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 7fd3e01364..8fc7936ab0 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py
@@ -54,6 +54,9 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main + persist_events = hs.get_datastores().persist_events + assert persist_events is not None + self.persist_events = persist_events def test_get_prev_events_for_room(self) -> None: room_id = "@ROOM:local" @@ -226,7 +229,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): }, ) - self.hs.datastores.persist_events._persist_event_auth_chain_txn( + self.persist_events._persist_event_auth_chain_txn( txn, [ cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id])) @@ -445,7 +448,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): ) # Insert all events apart from 'B' - self.hs.datastores.persist_events._persist_event_auth_chain_txn( + self.persist_events._persist_event_auth_chain_txn( txn, [ cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id])) @@ -464,7 +467,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): updatevalues={"has_auth_chain_index": False}, ) - self.hs.datastores.persist_events._persist_event_auth_chain_txn( + self.persist_events._persist_event_auth_chain_txn( txn, [cast(EventBase, FakeEvent("b", room_id, auth_graph["b"]))], ) diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py
index 05661a537d..e67dd0589d 100644 --- a/tests/storage/test_events.py +++ b/tests/storage/test_events.py
@@ -40,7 +40,9 @@ class ExtremPruneTestCase(HomeserverTestCase): self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer ) -> None: self.state = self.hs.get_state_handler() - self._persistence = self.hs.get_storage_controllers().persistence + persistence = self.hs.get_storage_controllers().persistence + assert persistence is not None + self._persistence = persistence self._state_storage_controller = self.hs.get_storage_controllers().state self.store = self.hs.get_datastores().main @@ -374,7 +376,9 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase): self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer ) -> None: self.state = self.hs.get_state_handler() - self._persistence = self.hs.get_storage_controllers().persistence + persistence = self.hs.get_storage_controllers().persistence + assert persistence is not None + self._persistence = persistence self.store = self.hs.get_datastores().main def test_remote_user_rooms_cache_invalidated(self) -> None: diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py
index aa4b5bd3b1..ba68171ad7 100644 --- a/tests/storage/test_keys.py +++ b/tests/storage/test_keys.py
@@ -16,8 +16,6 @@ import signedjson.key import signedjson.types import unpaddedbase64 -from twisted.internet.defer import Deferred - from synapse.storage.keys import FetchKeyResult import tests.unittest @@ -44,20 +42,26 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase): key_id_1 = "ed25519:key1" key_id_2 = "ed25519:KEY_ID_2" - d = store.store_server_verify_keys( - "from_server", - 10, - [ - ("server1", key_id_1, FetchKeyResult(KEY_1, 100)), - ("server1", key_id_2, FetchKeyResult(KEY_2, 200)), - ], + self.get_success( + store.store_server_verify_keys( + "from_server", + 10, + [ + ("server1", key_id_1, FetchKeyResult(KEY_1, 100)), + ("server1", key_id_2, FetchKeyResult(KEY_2, 200)), + ], + ) ) - self.get_success(d) - d = store.get_server_verify_keys( - [("server1", key_id_1), ("server1", key_id_2), ("server1", "ed25519:key3")] + res = self.get_success( + store.get_server_verify_keys( + [ + ("server1", key_id_1), + ("server1", key_id_2), + ("server1", "ed25519:key3"), + ] + ) ) - res = self.get_success(d) self.assertEqual(len(res.keys()), 3) res1 = res[("server1", key_id_1)] @@ -82,18 +86,20 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase): key_id_1 = "ed25519:key1" key_id_2 = "ed25519:key2" - d = store.store_server_verify_keys( - "from_server", - 0, - [ - ("srv1", key_id_1, FetchKeyResult(KEY_1, 100)), - ("srv1", key_id_2, FetchKeyResult(KEY_2, 200)), - ], + self.get_success( + store.store_server_verify_keys( + "from_server", + 0, + [ + ("srv1", key_id_1, FetchKeyResult(KEY_1, 100)), + ("srv1", key_id_2, FetchKeyResult(KEY_2, 200)), + ], + ) ) - self.get_success(d) - d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)]) - res = self.get_success(d) + res = self.get_success( + store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)]) + ) self.assertEqual(len(res.keys()), 2) res1 = res[("srv1", key_id_1)] @@ -105,9 +111,7 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase): self.assertEqual(res2.valid_until_ts, 200) # we should be able to look up the same thing again without a db hit - res = store.get_server_verify_keys([("srv1", key_id_1)]) - if isinstance(res, Deferred): - res = self.successResultOf(res) + res = self.get_success(store.get_server_verify_keys([("srv1", key_id_1)])) self.assertEqual(len(res.keys()), 1) self.assertEqual(res[("srv1", key_id_1)].verify_key, KEY_1) @@ -119,8 +123,9 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase): ) self.get_success(d) - d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)]) - res = self.get_success(d) + res = self.get_success( + store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)]) + ) self.assertEqual(len(res.keys()), 2) res1 = res[("srv1", key_id_1)] diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index 010cc74c31..d8f42c5d05 100644 --- a/tests/storage/test_purge.py +++ b/tests/storage/test_purge.py
@@ -112,7 +112,7 @@ class PurgeTests(HomeserverTestCase): self.room_id, "m.room.create", "" ) ) - self.assertIsNotNone(create_event) + assert create_event is not None # Purge everything before this topological token self.get_success( diff --git a/tests/storage/test_receipts.py b/tests/storage/test_receipts.py
index d8d84152dc..12c17f1073 100644 --- a/tests/storage/test_receipts.py +++ b/tests/storage/test_receipts.py
@@ -37,9 +37,9 @@ class ReceiptTestCase(HomeserverTestCase): self.store = homeserver.get_datastores().main self.room_creator = homeserver.get_room_creation_handler() - self.persist_event_storage_controller = ( - self.hs.get_storage_controllers().persistence - ) + persist_event_storage_controller = self.hs.get_storage_controllers().persistence + assert persist_event_storage_controller is not None + self.persist_event_storage_controller = persist_event_storage_controller # Create a test user self.ourUser = UserID.from_string(OUR_USER_ID) diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index df4740f9d9..0100f7da14 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py
@@ -74,10 +74,12 @@ class RedactionTestCase(unittest.HomeserverTestCase): }, ) - event, context = self.get_success( + event, unpersisted_context = self.get_success( self.event_creation_handler.create_new_client_event(builder) ) + context = self.get_success(unpersisted_context.persist(event)) + self.get_success(self._persistence.persist_event(event, context)) return event @@ -96,10 +98,12 @@ class RedactionTestCase(unittest.HomeserverTestCase): }, ) - event, context = self.get_success( + event, unpersisted_context = self.get_success( self.event_creation_handler.create_new_client_event(builder) ) + context = self.get_success(unpersisted_context.persist(event)) + self.get_success(self._persistence.persist_event(event, context)) return event @@ -119,10 +123,12 @@ class RedactionTestCase(unittest.HomeserverTestCase): }, ) - event, context = self.get_success( + event, unpersisted_context = self.get_success( self.event_creation_handler.create_new_client_event(builder) ) + context = self.get_success(unpersisted_context.persist(event)) + self.get_success(self._persistence.persist_event(event, context)) return event @@ -259,7 +265,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): def internal_metadata(self) -> _EventInternalMetadata: return self._base_builder.internal_metadata - event_1, context_1 = self.get_success( + event_1, unpersisted_context_1 = self.get_success( self.event_creation_handler.create_new_client_event( cast( EventBuilder, @@ -280,9 +286,11 @@ class RedactionTestCase(unittest.HomeserverTestCase): ) ) + context_1 = self.get_success(unpersisted_context_1.persist(event_1)) + self.get_success(self._persistence.persist_event(event_1, context_1)) - event_2, context_2 = self.get_success( + event_2, unpersisted_context_2 = self.get_success( self.event_creation_handler.create_new_client_event( cast( EventBuilder, @@ -302,6 +310,8 @@ class RedactionTestCase(unittest.HomeserverTestCase): ) ) ) + + context_2 = self.get_success(unpersisted_context_2.persist(event_2)) self.get_success(self._persistence.persist_event(event_2, context_2)) # fetch one of the redactions @@ -421,10 +431,12 @@ class RedactionTestCase(unittest.HomeserverTestCase): }, ) - redaction_event, context = self.get_success( + redaction_event, unpersisted_context = self.get_success( self.event_creation_handler.create_new_client_event(builder) ) + context = self.get_success(unpersisted_context.persist(redaction_event)) + self.get_success(self._persistence.persist_event(redaction_event, context)) # Now lets jump to the future where we have censored the redaction event diff --git a/tests/storage/test_room_search.py b/tests/storage/test_room_search.py
index 14d872514d..f183c38477 100644 --- a/tests/storage/test_room_search.py +++ b/tests/storage/test_room_search.py
@@ -119,7 +119,6 @@ class EventSearchInsertionTest(HomeserverTestCase): "content": {"msgtype": "m.text", "body": 2}, "room_id": room_id, "sender": user_id, - "depth": prev_event.depth + 1, "prev_events": prev_event_ids, "origin_server_ts": self.clock.time_msec(), } @@ -134,7 +133,7 @@ class EventSearchInsertionTest(HomeserverTestCase): prev_state_map, for_verification=False, ), - depth=event_dict["depth"], + depth=prev_event.depth + 1, ) ) diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index bad7f0bc60..f730b888f7 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py
@@ -67,10 +67,12 @@ class StateStoreTestCase(HomeserverTestCase): }, ) - event, context = self.get_success( + event, unpersisted_context = self.get_success( self.event_creation_handler.create_new_client_event(builder) ) + context = self.get_success(unpersisted_context.persist(event)) + assert self.storage.persistence is not None self.get_success(self.storage.persistence.persist_event(event, context)) diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py
index bc090ebce0..05dc4f64b8 100644 --- a/tests/storage/test_stream.py +++ b/tests/storage/test_stream.py
@@ -16,7 +16,7 @@ from typing import List from twisted.test.proto_helpers import MemoryReactor -from synapse.api.constants import EventTypes, RelationTypes +from synapse.api.constants import Direction, EventTypes, RelationTypes from synapse.api.filtering import Filter from synapse.rest import admin from synapse.rest.client import login, room @@ -128,7 +128,7 @@ class PaginationTestCase(HomeserverTestCase): room_id=self.room_id, from_key=self.from_token.room_key, to_key=None, - direction="f", + direction=Direction.FORWARDS, limit=10, event_filter=Filter(self.hs, filter), ) diff --git a/tests/storage/test_unsafe_locale.py b/tests/storage/test_unsafe_locale.py
index ba53c22818..19da8a9b09 100644 --- a/tests/storage/test_unsafe_locale.py +++ b/tests/storage/test_unsafe_locale.py
@@ -14,6 +14,7 @@ from unittest.mock import MagicMock, patch from synapse.storage.database import make_conn +from synapse.storage.engines import PostgresEngine from synapse.storage.engines._base import IncorrectDatabaseSetup from tests.unittest import HomeserverTestCase @@ -38,6 +39,7 @@ class UnsafeLocaleTest(HomeserverTestCase): def test_safe_locale(self) -> None: database = self.hs.get_datastores().databases[0] + assert isinstance(database.engine, PostgresEngine) db_conn = make_conn(database._database_config, database.engine, "test_unsafe") with db_conn.cursor() as txn: diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index f1ca523d23..2d169684cf 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py
@@ -25,6 +25,11 @@ from synapse.rest.client import login, register, room from synapse.server import HomeServer from synapse.storage import DataStore from synapse.storage.background_updates import _BackgroundUpdateHandler +from synapse.storage.databases.main import user_directory +from synapse.storage.databases.main.user_directory import ( + _parse_words_with_icu, + _parse_words_with_regex, +) from synapse.storage.roommember import ProfileInfo from synapse.util import Clock @@ -42,7 +47,7 @@ ALICE = "@alice:a" BOB = "@bob:b" BOBBY = "@bobby:a" # The localpart isn't 'Bela' on purpose so we can test looking up display names. -BELA = "@somenickname:a" +BELA = "@somenickname:example.org" class GetUserDirectoryTables: @@ -423,6 +428,8 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase): class UserDirectoryStoreTestCase(HomeserverTestCase): + use_icu = False + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main @@ -434,6 +441,12 @@ class UserDirectoryStoreTestCase(HomeserverTestCase): self.get_success(self.store.update_profile_in_user_dir(BELA, "Bela", None)) self.get_success(self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB))) + self._restore_use_icu = user_directory.USE_ICU + user_directory.USE_ICU = self.use_icu + + def tearDown(self) -> None: + user_directory.USE_ICU = self._restore_use_icu + def test_search_user_dir(self) -> None: # normally when alice searches the directory she should just find # bob because bobby doesn't share a room with her. @@ -478,6 +491,26 @@ class UserDirectoryStoreTestCase(HomeserverTestCase): {"user_id": BELA, "display_name": "Bela", "avatar_url": None}, ) + @override_config({"user_directory": {"search_all_users": True}}) + def test_search_user_dir_start_of_user_id(self) -> None: + """Tests that a user can look up another user by searching for the start + of their user ID. + """ + r = self.get_success(self.store.search_user_dir(ALICE, "somenickname:exa", 10)) + self.assertFalse(r["limited"]) + self.assertEqual(1, len(r["results"])) + self.assertDictEqual( + r["results"][0], + {"user_id": BELA, "display_name": "Bela", "avatar_url": None}, + ) + + +class UserDirectoryStoreTestCaseWithIcu(UserDirectoryStoreTestCase): + use_icu = True + + if not icu: + skip = "Requires PyICU" + class UserDirectoryICUTestCase(HomeserverTestCase): if not icu: @@ -513,3 +546,31 @@ class UserDirectoryICUTestCase(HomeserverTestCase): r["results"][0], {"user_id": ALICE, "display_name": display_name, "avatar_url": None}, ) + + def test_icu_word_boundary_punctuation(self) -> None: + """ + Tests the behaviour of punctuation with the ICU tokeniser. + + Seems to depend on underlying version of ICU. + """ + + # Note: either tokenisation is fine, because Postgres actually splits + # words itself afterwards. + self.assertIn( + _parse_words_with_icu("lazy'fox jumped:over the.dog"), + ( + # ICU 66 on Ubuntu 20.04 + ["lazy'fox", "jumped", "over", "the", "dog"], + # ICU 70 on Ubuntu 22.04 + ["lazy'fox", "jumped:over", "the.dog"], + ), + ) + + def test_regex_word_boundary_punctuation(self) -> None: + """ + Tests the behaviour of punctuation with the non-ICU tokeniser + """ + self.assertEqual( + _parse_words_with_regex("lazy'fox jumped:over the.dog"), + ["lazy", "fox", "jumped", "over", "the", "dog"], + ) diff --git a/tests/test_distributor.py b/tests/test_distributor.py
index 31546ea52b..a248f1d277 100644 --- a/tests/test_distributor.py +++ b/tests/test_distributor.py
@@ -21,10 +21,10 @@ from . import unittest class DistributorTestCase(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.dist = Distributor() - def test_signal_dispatch(self): + def test_signal_dispatch(self) -> None: self.dist.declare("alert") observer = Mock() @@ -33,7 +33,7 @@ class DistributorTestCase(unittest.TestCase): self.dist.fire("alert", 1, 2, 3) observer.assert_called_with(1, 2, 3) - def test_signal_catch(self): + def test_signal_catch(self) -> None: self.dist.declare("alarm") observers = [Mock() for i in (1, 2)] @@ -51,7 +51,7 @@ class DistributorTestCase(unittest.TestCase): self.assertEqual(mock_logger.warning.call_count, 1) self.assertIsInstance(mock_logger.warning.call_args[0][0], str) - def test_signal_prereg(self): + def test_signal_prereg(self) -> None: observer = Mock() self.dist.observe("flare", observer) @@ -60,8 +60,8 @@ class DistributorTestCase(unittest.TestCase): observer.assert_called_with(4, 5) - def test_signal_undeclared(self): - def code(): + def test_signal_undeclared(self) -> None: + def code() -> None: self.dist.fire("notification") self.assertRaises(KeyError, code) diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py
index 0a7937f1cc..2860564afc 100644 --- a/tests/test_event_auth.py +++ b/tests/test_event_auth.py
@@ -31,13 +31,13 @@ from tests.test_utils import get_awaitable_result class _StubEventSourceStore: """A stub implementation of the EventSourceStore""" - def __init__(self): + def __init__(self) -> None: self._store: Dict[str, EventBase] = {} - def add_event(self, event: EventBase): + def add_event(self, event: EventBase) -> None: self._store[event.event_id] = event - def add_events(self, events: Iterable[EventBase]): + def add_events(self, events: Iterable[EventBase]) -> None: for event in events: self._store[event.event_id] = event @@ -59,7 +59,7 @@ class _StubEventSourceStore: class EventAuthTestCase(unittest.TestCase): - def test_rejected_auth_events(self): + def test_rejected_auth_events(self) -> None: """ Events that refer to rejected events in their auth events are rejected """ @@ -109,7 +109,7 @@ class EventAuthTestCase(unittest.TestCase): ) ) - def test_create_event_with_prev_events(self): + def test_create_event_with_prev_events(self) -> None: """A create event with prev_events should be rejected https://spec.matrix.org/v1.3/rooms/v9/#authorization-rules @@ -150,7 +150,7 @@ class EventAuthTestCase(unittest.TestCase): event_auth.check_state_independent_auth_rules(event_store, bad_event) ) - def test_duplicate_auth_events(self): + def test_duplicate_auth_events(self) -> None: """Events with duplicate auth_events should be rejected https://spec.matrix.org/v1.3/rooms/v9/#authorization-rules @@ -196,7 +196,7 @@ class EventAuthTestCase(unittest.TestCase): event_auth.check_state_independent_auth_rules(event_store, bad_event2) ) - def test_unexpected_auth_events(self): + def test_unexpected_auth_events(self) -> None: """Events with excess auth_events should be rejected https://spec.matrix.org/v1.3/rooms/v9/#authorization-rules @@ -236,7 +236,7 @@ class EventAuthTestCase(unittest.TestCase): event_auth.check_state_independent_auth_rules(event_store, bad_event) ) - def test_random_users_cannot_send_state_before_first_pl(self): + def test_random_users_cannot_send_state_before_first_pl(self) -> None: """ Check that, before the first PL lands, the creator is the only user that can send a state event. @@ -263,7 +263,7 @@ class EventAuthTestCase(unittest.TestCase): auth_events, ) - def test_state_default_level(self): + def test_state_default_level(self) -> None: """ Check that users above the state_default level can send state and those below cannot @@ -298,7 +298,7 @@ class EventAuthTestCase(unittest.TestCase): auth_events, ) - def test_alias_event(self): + def test_alias_event(self) -> None: """Alias events have special behavior up through room version 6.""" creator = "@creator:example.com" other = "@other:example.com" @@ -333,7 +333,7 @@ class EventAuthTestCase(unittest.TestCase): auth_events, ) - def test_msc2432_alias_event(self): + def test_msc2432_alias_event(self) -> None: """After MSC2432, alias events have no special behavior.""" creator = "@creator:example.com" other = "@other:example.com" @@ -366,7 +366,9 @@ class EventAuthTestCase(unittest.TestCase): ) @parameterized.expand([(RoomVersions.V1, True), (RoomVersions.V6, False)]) - def test_notifications(self, room_version: RoomVersion, allow_modification: bool): + def test_notifications( + self, room_version: RoomVersion, allow_modification: bool + ) -> None: """ Notifications power levels get checked due to MSC2209. """ @@ -395,7 +397,7 @@ class EventAuthTestCase(unittest.TestCase): with self.assertRaises(AuthError): event_auth.check_state_dependent_auth_rules(pl_event, auth_events) - def test_join_rules_public(self): + def test_join_rules_public(self) -> None: """ Test joining a public room. """ @@ -460,7 +462,7 @@ class EventAuthTestCase(unittest.TestCase): auth_events.values(), ) - def test_join_rules_invite(self): + def test_join_rules_invite(self) -> None: """ Test joining an invite only room. """ @@ -835,7 +837,7 @@ def _power_levels_event( ) -def _alias_event(room_version: RoomVersion, sender: str, **kwargs) -> EventBase: +def _alias_event(room_version: RoomVersion, sender: str, **kwargs: Any) -> EventBase: data = { "room_id": TEST_ROOM_ID, **_maybe_get_event_id_dict_for_room_version(room_version), diff --git a/tests/test_federation.py b/tests/test_federation.py
index 80e5c590d8..82dfd88b99 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py
@@ -12,53 +12,48 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Collection, List, Optional, Union from unittest.mock import Mock -from twisted.internet.defer import succeed +from twisted.test.proto_helpers import MemoryReactor from synapse.api.errors import FederationError -from synapse.api.room_versions import RoomVersions -from synapse.events import make_event_from_dict +from synapse.api.room_versions import RoomVersion, RoomVersions +from synapse.events import EventBase, make_event_from_dict +from synapse.events.snapshot import EventContext from synapse.federation.federation_base import event_from_pdu_json +from synapse.handlers.device import DeviceListUpdater +from synapse.http.types import QueryParams from synapse.logging.context import LoggingContext -from synapse.types import UserID, create_requester +from synapse.server import HomeServer +from synapse.types import JsonDict, UserID, create_requester from synapse.util import Clock from synapse.util.retryutils import NotRetryingDestination from tests import unittest -from tests.server import ThreadedMemoryReactorClock, setup_test_homeserver from tests.test_utils import make_awaitable class MessageAcceptTests(unittest.HomeserverTestCase): - def setUp(self): - + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.http_client = Mock() - self.reactor = ThreadedMemoryReactorClock() - self.hs_clock = Clock(self.reactor) - self.homeserver = setup_test_homeserver( - self.addCleanup, - federation_http_client=self.http_client, - clock=self.hs_clock, - reactor=self.reactor, - ) + return self.setup_test_homeserver(federation_http_client=self.http_client) + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: user_id = UserID("us", "test") our_user = create_requester(user_id) - room_creator = self.homeserver.get_room_creation_handler() + room_creator = self.hs.get_room_creation_handler() self.room_id = self.get_success( room_creator.create_room( our_user, room_creator._presets_dict["public_chat"], ratelimit=False ) )[0]["room_id"] - self.store = self.homeserver.get_datastores().main + self.store = self.hs.get_datastores().main # Figure out what the most recent event is most_recent = self.get_success( - self.homeserver.get_datastores().main.get_latest_event_ids_in_room( - self.room_id - ) + self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id) )[0] join_event = make_event_from_dict( @@ -78,17 +73,23 @@ class MessageAcceptTests(unittest.HomeserverTestCase): } ) - self.handler = self.homeserver.get_federation_handler() - federation_event_handler = self.homeserver.get_federation_event_handler() + self.handler = self.hs.get_federation_handler() + federation_event_handler = self.hs.get_federation_event_handler() - async def _check_event_auth(origin, event, context): + async def _check_event_auth( + origin: Optional[str], event: EventBase, context: EventContext + ) -> None: pass - federation_event_handler._check_event_auth = _check_event_auth - self.client = self.homeserver.get_federation_client() - self.client._check_sigs_and_hash_for_pulled_events_and_fetch = ( - lambda dest, pdus, **k: succeed(pdus) - ) + federation_event_handler._check_event_auth = _check_event_auth # type: ignore[assignment] + self.client = self.hs.get_federation_client() + + async def _check_sigs_and_hash_for_pulled_events_and_fetch( + dest: str, pdus: Collection[EventBase], room_version: RoomVersion + ) -> List[EventBase]: + return list(pdus) + + self.client._check_sigs_and_hash_for_pulled_events_and_fetch = _check_sigs_and_hash_for_pulled_events_and_fetch # type: ignore[assignment] # Send the join, it should return None (which is not an error) self.assertEqual( @@ -104,16 +105,25 @@ class MessageAcceptTests(unittest.HomeserverTestCase): "$join:test.serv", ) - def test_cant_hide_direct_ancestors(self): + def test_cant_hide_direct_ancestors(self) -> None: """ If you send a message, you must be able to provide the direct prev_events that said event references. """ - async def post_json(destination, path, data, headers=None, timeout=0): + async def post_json( + destination: str, + path: str, + data: Optional[JsonDict] = None, + long_retries: bool = False, + timeout: Optional[int] = None, + ignore_backoff: bool = False, + args: Optional[QueryParams] = None, + ) -> Union[JsonDict, list]: # If it asks us for new missing events, give them NOTHING if path.startswith("/_matrix/federation/v1/get_missing_events/"): return {"events": []} + return {} self.http_client.post_json = post_json @@ -138,7 +148,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): } ) - federation_event_handler = self.homeserver.get_federation_event_handler() + federation_event_handler = self.hs.get_federation_event_handler() with LoggingContext("test-context"): failure = self.get_failure( federation_event_handler.on_receive_pdu("test.serv", lying_event), @@ -158,7 +168,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): extrem = self.get_success(self.store.get_latest_event_ids_in_room(self.room_id)) self.assertEqual(extrem[0], "$join:test.serv") - def test_retry_device_list_resync(self): + def test_retry_device_list_resync(self) -> None: """Tests that device lists are marked as stale if they couldn't be synced, and that stale device lists are retried periodically. """ @@ -171,24 +181,27 @@ class MessageAcceptTests(unittest.HomeserverTestCase): # When this function is called, increment the number of resync attempts (only if # we're querying devices for the right user ID), then raise a # NotRetryingDestination error to fail the resync gracefully. - def query_user_devices(destination, user_id): + def query_user_devices( + destination: str, user_id: str, timeout: int = 30000 + ) -> JsonDict: if user_id == remote_user_id: self.resync_attempts += 1 raise NotRetryingDestination(0, 0, destination) # Register the mock on the federation client. - federation_client = self.homeserver.get_federation_client() - federation_client.query_user_devices = Mock(side_effect=query_user_devices) + federation_client = self.hs.get_federation_client() + federation_client.query_user_devices = Mock(side_effect=query_user_devices) # type: ignore[assignment] # Register a mock on the store so that the incoming update doesn't fail because # we don't share a room with the user. - store = self.homeserver.get_datastores().main + store = self.hs.get_datastores().main store.get_rooms_for_user = Mock(return_value=make_awaitable(["!someroom:test"])) # Manually inject a fake device list update. We need this update to include at # least one prev_id so that the user's device list will need to be retried. - device_list_updater = self.homeserver.get_device_handler().device_list_updater + device_list_updater = self.hs.get_device_handler().device_list_updater + assert isinstance(device_list_updater, DeviceListUpdater) self.get_success( device_list_updater.incoming_device_list_update( origin=remote_origin, @@ -218,7 +231,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): self.reactor.advance(30) self.assertEqual(self.resync_attempts, 2) - def test_cross_signing_keys_retry(self): + def test_cross_signing_keys_retry(self) -> None: """Tests that resyncing a device list correctly processes cross-signing keys from the remote server. """ @@ -227,8 +240,8 @@ class MessageAcceptTests(unittest.HomeserverTestCase): remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ" # Register mock device list retrieval on the federation client. - federation_client = self.homeserver.get_federation_client() - federation_client.query_user_devices = Mock( + federation_client = self.hs.get_federation_client() + federation_client.query_user_devices = Mock( # type: ignore[assignment] return_value=make_awaitable( { "user_id": remote_user_id, @@ -252,7 +265,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): ) # Resync the device list. - device_handler = self.homeserver.get_device_handler() + device_handler = self.hs.get_device_handler() self.get_success( device_handler.device_list_updater.user_device_resync(remote_user_id), ) @@ -261,16 +274,18 @@ class MessageAcceptTests(unittest.HomeserverTestCase): keys = self.get_success( self.store.get_e2e_cross_signing_keys_bulk(user_ids=[remote_user_id]), ) - self.assertTrue(remote_user_id in keys) + self.assertIn(remote_user_id, keys) + key = keys[remote_user_id] + assert key is not None # Check that the master key is the one returned by the mock. - master_key = keys[remote_user_id]["master"] + master_key = key["master"] self.assertEqual(len(master_key["keys"]), 1) self.assertTrue("ed25519:" + remote_master_key in master_key["keys"].keys()) self.assertTrue(remote_master_key in master_key["keys"].values()) # Check that the self-signing key is the one returned by the mock. - self_signing_key = keys[remote_user_id]["self_signing"] + self_signing_key = key["self_signing"] self.assertEqual(len(self_signing_key["keys"]), 1) self.assertTrue( "ed25519:" + remote_self_signing_key in self_signing_key["keys"].keys(), @@ -279,7 +294,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): class StripUnsignedFromEventsTestCase(unittest.TestCase): - def test_strip_unauthorized_unsigned_values(self): + def test_strip_unauthorized_unsigned_values(self) -> None: event1 = { "sender": "@baduser:test.serv", "state_key": "@baduser:test.serv", @@ -296,7 +311,7 @@ class StripUnsignedFromEventsTestCase(unittest.TestCase): # Make sure unauthorized fields are stripped from unsigned self.assertNotIn("more warez", filtered_event.unsigned) - def test_strip_event_maintains_allowed_fields(self): + def test_strip_event_maintains_allowed_fields(self) -> None: event2 = { "sender": "@baduser:test.serv", "state_key": "@baduser:test.serv", @@ -323,7 +338,7 @@ class StripUnsignedFromEventsTestCase(unittest.TestCase): self.assertIn("invite_room_state", filtered_event2.unsigned) self.assertEqual([], filtered_event2.unsigned["invite_room_state"]) - def test_strip_event_removes_fields_based_on_event_type(self): + def test_strip_event_removes_fields_based_on_event_type(self) -> None: event3 = { "sender": "@baduser:test.serv", "state_key": "@baduser:test.serv", diff --git a/tests/test_mau.py b/tests/test_mau.py
index f14fcb7db9..4e7665a22b 100644 --- a/tests/test_mau.py +++ b/tests/test_mau.py
@@ -14,12 +14,17 @@ """Tests REST events for /rooms paths.""" -from typing import List +from typing import List, Optional + +from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType from synapse.api.errors import Codes, HttpResponseException, SynapseError from synapse.appservice import ApplicationService from synapse.rest.client import register, sync +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util import Clock from tests import unittest from tests.unittest import override_config @@ -30,7 +35,7 @@ class TestMauLimit(unittest.HomeserverTestCase): servlets = [register.register_servlets, sync.register_servlets] - def default_config(self): + def default_config(self) -> JsonDict: config = default_config("test") config.update( @@ -53,10 +58,12 @@ class TestMauLimit(unittest.HomeserverTestCase): return config - def prepare(self, reactor, clock, homeserver): + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: self.store = homeserver.get_datastores().main - def test_simple_deny_mau(self): + def test_simple_deny_mau(self) -> None: # Create and sync so that the MAU counts get updated token1 = self.create_user("kermit1") self.do_sync_for_user(token1) @@ -75,7 +82,7 @@ class TestMauLimit(unittest.HomeserverTestCase): self.assertEqual(e.code, 403) self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) - def test_as_ignores_mau(self): + def test_as_ignores_mau(self) -> None: """Test that application services can still create users when the MAU limit has been reached. This only works when application service user ip tracking is disabled. @@ -113,7 +120,7 @@ class TestMauLimit(unittest.HomeserverTestCase): self.create_user("as_kermit4", token=as_token, appservice=True) - def test_allowed_after_a_month_mau(self): + def test_allowed_after_a_month_mau(self) -> None: # Create and sync so that the MAU counts get updated token1 = self.create_user("kermit1") self.do_sync_for_user(token1) @@ -132,7 +139,7 @@ class TestMauLimit(unittest.HomeserverTestCase): self.do_sync_for_user(token3) @override_config({"mau_trial_days": 1}) - def test_trial_delay(self): + def test_trial_delay(self) -> None: # We should be able to register more than the limit initially token1 = self.create_user("kermit1") self.do_sync_for_user(token1) @@ -165,7 +172,7 @@ class TestMauLimit(unittest.HomeserverTestCase): self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) @override_config({"mau_trial_days": 1}) - def test_trial_users_cant_come_back(self): + def test_trial_users_cant_come_back(self) -> None: self.hs.config.server.mau_trial_days = 1 # We should be able to register more than the limit initially @@ -216,7 +223,7 @@ class TestMauLimit(unittest.HomeserverTestCase): # max_mau_value should not matter {"max_mau_value": 1, "limit_usage_by_mau": False, "mau_stats_only": True} ) - def test_tracked_but_not_limited(self): + def test_tracked_but_not_limited(self) -> None: # Simply being able to create 2 users indicates that the # limit was not reached. token1 = self.create_user("kermit1") @@ -236,10 +243,10 @@ class TestMauLimit(unittest.HomeserverTestCase): "mau_appservice_trial_days": {"SomeASID": 1, "AnotherASID": 2}, } ) - def test_as_trial_days(self): + def test_as_trial_days(self) -> None: user_tokens: List[str] = [] - def advance_time_and_sync(): + def advance_time_and_sync() -> None: self.reactor.advance(24 * 60 * 61) for token in user_tokens: self.do_sync_for_user(token) @@ -300,7 +307,9 @@ class TestMauLimit(unittest.HomeserverTestCase): }, ) - def create_user(self, localpart, token=None, appservice=False): + def create_user( + self, localpart: str, token: Optional[str] = None, appservice: bool = False + ) -> str: request_data = { "username": localpart, "password": "monkey", @@ -326,7 +335,7 @@ class TestMauLimit(unittest.HomeserverTestCase): return access_token - def do_sync_for_user(self, token): + def do_sync_for_user(self, token: str) -> None: channel = self.make_request("GET", "/sync", access_token=token) if channel.code != 200: diff --git a/tests/test_phone_home.py b/tests/test_phone_home.py
index cc1a98f1c4..3f899b0d91 100644 --- a/tests/test_phone_home.py +++ b/tests/test_phone_home.py
@@ -33,7 +33,7 @@ class PhoneHomeStatsTestCase(HomeserverTestCase): If time doesn't move, don't error out. """ past_stats = [ - (self.hs.get_clock().time(), resource.getrusage(resource.RUSAGE_SELF)) + (int(self.hs.get_clock().time()), resource.getrusage(resource.RUSAGE_SELF)) ] stats: JsonDict = {} self.get_success(phone_stats_home(self.hs, stats, past_stats)) diff --git a/tests/test_rust.py b/tests/test_rust.py
index 55d8b6b28c..67443b6280 100644 --- a/tests/test_rust.py +++ b/tests/test_rust.py
@@ -6,6 +6,6 @@ from tests import unittest class RustTestCase(unittest.TestCase): """Basic tests to ensure that we can call into Rust code.""" - def test_basic(self): + def test_basic(self) -> None: result = sum_as_string(1, 2) self.assertEqual("3", result) diff --git a/tests/test_test_utils.py b/tests/test_test_utils.py
index d04bcae0fa..5cd698147e 100644 --- a/tests/test_test_utils.py +++ b/tests/test_test_utils.py
@@ -17,25 +17,25 @@ from tests.utils import MockClock class MockClockTestCase(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.clock = MockClock() - def test_advance_time(self): + def test_advance_time(self) -> None: start_time = self.clock.time() self.clock.advance_time(20) self.assertEqual(20, self.clock.time() - start_time) - def test_later(self): + def test_later(self) -> None: invoked = [0, 0] - def _cb0(): + def _cb0() -> None: invoked[0] = 1 self.clock.call_later(10, _cb0) - def _cb1(): + def _cb1() -> None: invoked[1] = 1 self.clock.call_later(20, _cb1) @@ -51,15 +51,15 @@ class MockClockTestCase(unittest.TestCase): self.assertTrue(invoked[1]) - def test_cancel_later(self): + def test_cancel_later(self) -> None: invoked = [0, 0] - def _cb0(): + def _cb0() -> None: invoked[0] = 1 t0 = self.clock.call_later(10, _cb0) - def _cb1(): + def _cb1() -> None: invoked[1] = 1 self.clock.call_later(20, _cb1) diff --git a/tests/test_types.py b/tests/test_types.py
index 1111169384..c491cc9a96 100644 --- a/tests/test_types.py +++ b/tests/test_types.py
@@ -43,34 +43,34 @@ class IsMineIDTests(unittest.HomeserverTestCase): class UserIDTestCase(unittest.HomeserverTestCase): - def test_parse(self): + def test_parse(self) -> None: user = UserID.from_string("@1234abcd:test") self.assertEqual("1234abcd", user.localpart) self.assertEqual("test", user.domain) self.assertEqual(True, self.hs.is_mine(user)) - def test_parse_rejects_empty_id(self): + def test_parse_rejects_empty_id(self) -> None: with self.assertRaises(SynapseError): UserID.from_string("") - def test_parse_rejects_missing_sigil(self): + def test_parse_rejects_missing_sigil(self) -> None: with self.assertRaises(SynapseError): UserID.from_string("alice:example.com") - def test_parse_rejects_missing_separator(self): + def test_parse_rejects_missing_separator(self) -> None: with self.assertRaises(SynapseError): UserID.from_string("@alice.example.com") - def test_validation_rejects_missing_domain(self): + def test_validation_rejects_missing_domain(self) -> None: self.assertFalse(UserID.is_valid("@alice:")) - def test_build(self): + def test_build(self) -> None: user = UserID("5678efgh", "my.domain") self.assertEqual(user.to_string(), "@5678efgh:my.domain") - def test_compare(self): + def test_compare(self) -> None: userA = UserID.from_string("@userA:my.domain") userAagain = UserID.from_string("@userA:my.domain") userB = UserID.from_string("@userB:my.domain") @@ -80,43 +80,43 @@ class UserIDTestCase(unittest.HomeserverTestCase): class RoomAliasTestCase(unittest.HomeserverTestCase): - def test_parse(self): + def test_parse(self) -> None: room = RoomAlias.from_string("#channel:test") self.assertEqual("channel", room.localpart) self.assertEqual("test", room.domain) self.assertEqual(True, self.hs.is_mine(room)) - def test_build(self): + def test_build(self) -> None: room = RoomAlias("channel", "my.domain") self.assertEqual(room.to_string(), "#channel:my.domain") - def test_validate(self): + def test_validate(self) -> None: id_string = "#test:domain,test" self.assertFalse(RoomAlias.is_valid(id_string)) class MapUsernameTestCase(unittest.TestCase): - def testPassThrough(self): + def test_pass_througuh(self) -> None: self.assertEqual(map_username_to_mxid_localpart("test1234"), "test1234") - def testUpperCase(self): + def test_upper_case(self) -> None: self.assertEqual(map_username_to_mxid_localpart("tEST_1234"), "test_1234") self.assertEqual( map_username_to_mxid_localpart("tEST_1234", case_sensitive=True), "t_e_s_t__1234", ) - def testSymbols(self): + def test_symbols(self) -> None: self.assertEqual( map_username_to_mxid_localpart("test=$?_1234"), "test=3d=24=3f_1234" ) - def testLeadingUnderscore(self): + def test_leading_underscore(self) -> None: self.assertEqual(map_username_to_mxid_localpart("_test_1234"), "=5ftest_1234") - def testNonAscii(self): + def test_non_ascii(self) -> None: # this should work with either a unicode or a bytes self.assertEqual(map_username_to_mxid_localpart("têst"), "t=c3=aast") self.assertEqual(map_username_to_mxid_localpart("têst".encode()), "t=c3=aast") diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index e62ebcc6a5..e5dae670a7 100644 --- a/tests/test_utils/__init__.py +++ b/tests/test_utils/__init__.py
@@ -20,12 +20,13 @@ import sys import warnings from asyncio import Future from binascii import unhexlify -from typing import Awaitable, Callable, Tuple, TypeVar +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple, TypeVar from unittest.mock import Mock import attr import zope.interface +from twisted.internet.interfaces import IProtocol from twisted.python.failure import Failure from twisted.web.client import ResponseDone from twisted.web.http import RESPONSES @@ -34,6 +35,9 @@ from twisted.web.iweb import IResponse from synapse.types import JsonDict +if TYPE_CHECKING: + from sys import UnraisableHookArgs + TV = TypeVar("TV") @@ -78,25 +82,29 @@ def setup_awaitable_errors() -> Callable[[], None]: unraisable_exceptions = [] orig_unraisablehook = sys.unraisablehook - def unraisablehook(unraisable): + def unraisablehook(unraisable: "UnraisableHookArgs") -> None: unraisable_exceptions.append(unraisable.exc_value) - def cleanup(): + def cleanup() -> None: """ A method to be used as a clean-up that fails a test-case if there are any new unraisable exceptions. """ sys.unraisablehook = orig_unraisablehook if unraisable_exceptions: - raise unraisable_exceptions.pop() + exc = unraisable_exceptions.pop() + assert exc is not None + raise exc sys.unraisablehook = unraisablehook return cleanup -def simple_async_mock(return_value=None, raises=None) -> Mock: +def simple_async_mock( + return_value: Optional[TV] = None, raises: Optional[Exception] = None +) -> Mock: # AsyncMock is not available in python3.5, this mimics part of its behaviour - async def cb(*args, **kwargs): + async def cb(*args: Any, **kwargs: Any) -> Optional[TV]: if raises: raise raises return return_value @@ -125,14 +133,14 @@ class FakeResponse: # type: ignore[misc] headers: Headers = attr.Factory(Headers) @property - def phrase(self): + def phrase(self) -> bytes: return RESPONSES.get(self.code, b"Unknown Status") @property - def length(self): + def length(self) -> int: return len(self.body) - def deliverBody(self, protocol): + def deliverBody(self, protocol: IProtocol) -> None: protocol.dataReceived(self.body) protocol.connectionLost(Failure(ResponseDone())) diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py
index 8027c7a856..a6330ed840 100644 --- a/tests/test_utils/event_injection.py +++ b/tests/test_utils/event_injection.py
@@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Tuple +from typing import Any, List, Optional, Tuple import synapse.server from synapse.api.constants import EventTypes @@ -32,7 +32,7 @@ async def inject_member_event( membership: str, target: Optional[str] = None, extra_content: Optional[dict] = None, - **kwargs, + **kwargs: Any, ) -> EventBase: """Inject a membership event into a room.""" if target is None: @@ -57,7 +57,7 @@ async def inject_event( hs: synapse.server.HomeServer, room_version: Optional[str] = None, prev_event_ids: Optional[List[str]] = None, - **kwargs, + **kwargs: Any, ) -> EventBase: """Inject a generic event into a room @@ -82,7 +82,7 @@ async def create_event( hs: synapse.server.HomeServer, room_version: Optional[str] = None, prev_event_ids: Optional[List[str]] = None, - **kwargs, + **kwargs: Any, ) -> Tuple[EventBase, EventContext]: if room_version is None: room_version = await hs.get_datastores().main.get_room_version_id( @@ -92,8 +92,13 @@ async def create_event( builder = hs.get_event_builder_factory().for_room_version( KNOWN_ROOM_VERSIONS[room_version], kwargs ) - event, context = await hs.get_event_creation_handler().create_new_client_event( + ( + event, + unpersisted_context, + ) = await hs.get_event_creation_handler().create_new_client_event( builder, prev_event_ids=prev_event_ids ) + context = await unpersisted_context.persist(event) + return event, context diff --git a/tests/test_utils/html_parsers.py b/tests/test_utils/html_parsers.py
index e878af5f12..189c697efb 100644 --- a/tests/test_utils/html_parsers.py +++ b/tests/test_utils/html_parsers.py
@@ -13,13 +13,13 @@ # limitations under the License. from html.parser import HTMLParser -from typing import Dict, Iterable, List, Optional, Tuple +from typing import Dict, Iterable, List, NoReturn, Optional, Tuple class TestHtmlParser(HTMLParser): """A generic HTML page parser which extracts useful things from the HTML""" - def __init__(self): + def __init__(self) -> None: super().__init__() # a list of links found in the doc @@ -48,5 +48,5 @@ class TestHtmlParser(HTMLParser): assert input_name self.hiddens[input_name] = attr_dict["value"] - def error(_, message): + def error(self, message: str) -> NoReturn: raise AssertionError(message) diff --git a/tests/test_utils/logging_setup.py b/tests/test_utils/logging_setup.py
index 304c7b98c5..b522163a34 100644 --- a/tests/test_utils/logging_setup.py +++ b/tests/test_utils/logging_setup.py
@@ -25,7 +25,7 @@ class ToTwistedHandler(logging.Handler): tx_log = twisted.logger.Logger() - def emit(self, record): + def emit(self, record: logging.LogRecord) -> None: log_entry = self.format(record) log_level = record.levelname.lower().replace("warning", "warn") self.tx_log.emit( @@ -33,7 +33,7 @@ class ToTwistedHandler(logging.Handler): ) -def setup_logging(): +def setup_logging() -> None: """Configure the python logging appropriately for the tests. (Logs will end up in _trial_temp.) diff --git a/tests/test_utils/oidc.py b/tests/test_utils/oidc.py
index 1461d23ee8..d555b24255 100644 --- a/tests/test_utils/oidc.py +++ b/tests/test_utils/oidc.py
@@ -14,7 +14,7 @@ import json -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, ContextManager, Dict, List, Optional, Tuple from unittest.mock import Mock, patch from urllib.parse import parse_qs @@ -77,14 +77,14 @@ class FakeOidcServer: self._id_token_overrides: Dict[str, Any] = {} - def reset_mocks(self): + def reset_mocks(self) -> None: self.request.reset_mock() self.get_jwks_handler.reset_mock() self.get_metadata_handler.reset_mock() self.get_userinfo_handler.reset_mock() self.post_token_handler.reset_mock() - def patch_homeserver(self, hs: HomeServer): + def patch_homeserver(self, hs: HomeServer) -> ContextManager[Mock]: """Patch the ``HomeServer`` HTTP client to handle requests through the ``FakeOidcServer``. This patch should be used whenever the HS is expected to perform request to the @@ -188,7 +188,7 @@ class FakeOidcServer: return self._sign(logout_token) - def id_token_override(self, overrides: dict): + def id_token_override(self, overrides: dict) -> ContextManager[dict]: """Temporarily patch the ID token generated by the token endpoint.""" return patch.object(self, "_id_token_overrides", overrides) @@ -247,7 +247,7 @@ class FakeOidcServer: metadata: bool = False, token: bool = False, userinfo: bool = False, - ): + ) -> ContextManager[Dict[str, Mock]]: """A context which makes a set of endpoints return a 500 error. Args: diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index d0b9ad5454..2801a950a8 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py
@@ -35,6 +35,8 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): self.event_creation_handler = self.hs.get_event_creation_handler() self.event_builder_factory = self.hs.get_event_builder_factory() self._storage_controllers = self.hs.get_storage_controllers() + assert self._storage_controllers.persistence is not None + self._persistence = self._storage_controllers.persistence self.get_success(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM")) @@ -175,12 +177,11 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): }, ) - event, context = self.get_success( + event, unpersisted_context = self.get_success( self.event_creation_handler.create_new_client_event(builder) ) - self.get_success( - self._storage_controllers.persistence.persist_event(event, context) - ) + context = self.get_success(unpersisted_context.persist(event)) + self.get_success(self._persistence.persist_event(event, context)) return event def _inject_room_member( @@ -202,13 +203,12 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): }, ) - event, context = self.get_success( + event, unpersisted_context = self.get_success( self.event_creation_handler.create_new_client_event(builder) ) + context = self.get_success(unpersisted_context.persist(event)) - self.get_success( - self._storage_controllers.persistence.persist_event(event, context) - ) + self.get_success(self._persistence.persist_event(event, context)) return event def _inject_message( @@ -226,13 +226,12 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): }, ) - event, context = self.get_success( + event, unpersisted_context = self.get_success( self.event_creation_handler.create_new_client_event(builder) ) + context = self.get_success(unpersisted_context.persist(event)) - self.get_success( - self._storage_controllers.persistence.persist_event(event, context) - ) + self.get_success(self._persistence.persist_event(event, context)) return event def _inject_outlier(self) -> EventBase: @@ -250,7 +249,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): event = self.get_success(builder.build(prev_event_ids=[], auth_event_ids=[])) event.internal_metadata.outlier = True self.get_success( - self._storage_controllers.persistence.persist_event( + self._persistence.persist_event( event, EventContext.for_outlier(self._storage_controllers) ) ) @@ -258,7 +257,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase): - def test_out_of_band_invite_rejection(self): + def test_out_of_band_invite_rejection(self) -> None: # this is where we have received an invite event over federation, and then # rejected it. invite_pdu = { diff --git a/tests/unittest.py b/tests/unittest.py
index fa92dd94eb..b21e7f1221 100644 --- a/tests/unittest.py +++ b/tests/unittest.py
@@ -45,7 +45,7 @@ from typing_extensions import Concatenate, ParamSpec, Protocol from twisted.internet.defer import Deferred, ensureDeferred from twisted.python.failure import Failure from twisted.python.threadpool import ThreadPool -from twisted.test.proto_helpers import MemoryReactor +from twisted.test.proto_helpers import MemoryReactor, MemoryReactorClock from twisted.trial import unittest from twisted.web.resource import Resource from twisted.web.server import Request @@ -82,7 +82,7 @@ from tests.server import ( ) from tests.test_utils import event_injection, setup_awaitable_errors from tests.test_utils.logging_setup import setup_logging -from tests.utils import default_config, setupdb +from tests.utils import checked_cast, default_config, setupdb setupdb() setup_logging() @@ -296,7 +296,12 @@ class HomeserverTestCase(TestCase): from tests.rest.client.utils import RestHelper - self.helper = RestHelper(self.hs, self.site, getattr(self, "user_id", None)) + self.helper = RestHelper( + self.hs, + checked_cast(MemoryReactorClock, self.hs.get_reactor()), + self.site, + getattr(self, "user_id", None), + ) if hasattr(self, "user_id"): if self.hijack_auth: @@ -315,7 +320,7 @@ class HomeserverTestCase(TestCase): # This has to be a function and not just a Mock, because # `self.helper.auth_user_id` is temporarily reassigned in some tests - async def get_requester(*args, **kwargs) -> Requester: + async def get_requester(*args: Any, **kwargs: Any) -> Requester: assert self.helper.auth_user_id is not None return create_requester( user_id=UserID.from_string(self.helper.auth_user_id), @@ -361,7 +366,9 @@ class HomeserverTestCase(TestCase): store.db_pool.updates.do_next_background_update(False), by=0.1 ) - def make_homeserver(self, reactor: ThreadedMemoryReactorClock, clock: Clock): + def make_homeserver( + self, reactor: ThreadedMemoryReactorClock, clock: Clock + ) -> HomeServer: """ Make and return a homeserver. diff --git a/tests/util/test_retryutils.py b/tests/util/test_retryutils.py
index 9529ee53c8..5f8f4e76b5 100644 --- a/tests/util/test_retryutils.py +++ b/tests/util/test_retryutils.py
@@ -54,6 +54,7 @@ class RetryLimiterTestCase(HomeserverTestCase): self.pump() new_timings = self.get_success(store.get_destination_retry_timings("test_dest")) + assert new_timings is not None self.assertEqual(new_timings.failure_ts, failure_ts) self.assertEqual(new_timings.retry_last_ts, failure_ts) self.assertEqual(new_timings.retry_interval, MIN_RETRY_INTERVAL) @@ -82,6 +83,7 @@ class RetryLimiterTestCase(HomeserverTestCase): self.pump() new_timings = self.get_success(store.get_destination_retry_timings("test_dest")) + assert new_timings is not None self.assertEqual(new_timings.failure_ts, failure_ts) self.assertEqual(new_timings.retry_last_ts, retry_ts) self.assertGreaterEqual( diff --git a/tests/utils.py b/tests/utils.py
index d76bf9716a..a0ac11bc5c 100644 --- a/tests/utils.py +++ b/tests/utils.py
@@ -15,7 +15,7 @@ import atexit import os -from typing import Any, Callable, Dict, List, Tuple, Union, overload +from typing import Any, Callable, Dict, List, Tuple, Type, TypeVar, Union, overload import attr from typing_extensions import Literal, ParamSpec @@ -335,6 +335,33 @@ async def create_room(hs: HomeServer, room_id: str, creator_id: str) -> None: }, ) - event, context = await event_creation_handler.create_new_client_event(builder) + event, unpersisted_context = await event_creation_handler.create_new_client_event( + builder + ) + context = await unpersisted_context.persist(event) await persistence_store.persist_event(event, context) + + +T = TypeVar("T") + + +def checked_cast(type: Type[T], x: object) -> T: + """A version of typing.cast that is checked at runtime. + + We have our own function for this for two reasons: + + 1. typing.cast itself is deliberately a no-op at runtime, see + https://docs.python.org/3/library/typing.html#typing.cast + 2. To help workaround a mypy-zope bug https://github.com/Shoobx/mypy-zope/issues/91 + where mypy would erroneously consider `isinstance(x, type)` to be false in all + circumstances. + + For this to make sense, `T` needs to be something that `isinstance` can check; see + https://docs.python.org/3/library/functions.html?highlight=isinstance#isinstance + https://docs.python.org/3/glossary.html#term-abstract-base-class + https://docs.python.org/3/library/typing.html#typing.runtime_checkable + for more details. + """ + assert isinstance(x, type) + return x