diff options
Diffstat (limited to 'tests')
43 files changed, 1174 insertions, 649 deletions
diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py index 1cbb059357..0b22afdc75 100644 --- a/tests/appservice/test_scheduler.py +++ b/tests/appservice/test_scheduler.py @@ -24,6 +24,7 @@ from synapse.appservice.scheduler import ( ) from synapse.logging.context import make_deferred_yieldable from synapse.server import HomeServer +from synapse.types import DeviceListUpdates from synapse.util import Clock from tests import unittest @@ -70,6 +71,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): to_device_messages=[], # txn made and saved one_time_key_counts={}, unused_fallback_keys={}, + device_list_summary=DeviceListUpdates(), ) self.assertEqual(0, len(self.txnctrl.recoverers)) # no recoverer made txn.complete.assert_called_once_with(self.store) # txn completed @@ -96,6 +98,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): to_device_messages=[], # txn made and saved one_time_key_counts={}, unused_fallback_keys={}, + device_list_summary=DeviceListUpdates(), ) self.assertEqual(0, txn.send.call_count) # txn not sent though self.assertEqual(0, txn.complete.call_count) # or completed @@ -124,6 +127,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): to_device_messages=[], one_time_key_counts={}, unused_fallback_keys={}, + device_list_summary=DeviceListUpdates(), ) self.assertEqual(1, self.recoverer_fn.call_count) # recoverer made self.assertEqual(1, self.recoverer.recover.call_count) # and invoked @@ -225,7 +229,9 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): service = Mock(id=4) event = Mock() self.scheduler.enqueue_for_appservice(service, events=[event]) - self.txn_ctrl.send.assert_called_once_with(service, [event], [], [], None, None) + self.txn_ctrl.send.assert_called_once_with( + service, [event], [], [], None, None, DeviceListUpdates() + ) def test_send_single_event_with_queue(self): d = defer.Deferred() @@ -240,12 +246,14 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): # (call enqueue_for_appservice multiple times deliberately) self.scheduler.enqueue_for_appservice(service, events=[event2]) self.scheduler.enqueue_for_appservice(service, events=[event3]) - self.txn_ctrl.send.assert_called_with(service, [event], [], [], None, None) + self.txn_ctrl.send.assert_called_with( + service, [event], [], [], None, None, DeviceListUpdates() + ) self.assertEqual(1, self.txn_ctrl.send.call_count) # Resolve the send event: expect the queued events to be sent d.callback(service) self.txn_ctrl.send.assert_called_with( - service, [event2, event3], [], [], None, None + service, [event2, event3], [], [], None, None, DeviceListUpdates() ) self.assertEqual(2, self.txn_ctrl.send.call_count) @@ -272,15 +280,21 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): # send events for different ASes and make sure they are sent self.scheduler.enqueue_for_appservice(srv1, events=[srv_1_event]) self.scheduler.enqueue_for_appservice(srv1, events=[srv_1_event2]) - self.txn_ctrl.send.assert_called_with(srv1, [srv_1_event], [], [], None, None) + self.txn_ctrl.send.assert_called_with( + srv1, [srv_1_event], [], [], None, None, DeviceListUpdates() + ) self.scheduler.enqueue_for_appservice(srv2, events=[srv_2_event]) self.scheduler.enqueue_for_appservice(srv2, events=[srv_2_event2]) - self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event], [], [], None, None) + self.txn_ctrl.send.assert_called_with( + srv2, [srv_2_event], [], [], None, None, DeviceListUpdates() + ) # make sure callbacks for a service only send queued events for THAT # service srv_2_defer.callback(srv2) - self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2], [], [], None, None) + self.txn_ctrl.send.assert_called_with( + srv2, [srv_2_event2], [], [], None, None, DeviceListUpdates() + ) self.assertEqual(3, self.txn_ctrl.send.call_count) def test_send_large_txns(self): @@ -300,17 +314,17 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): # Expect the first event to be sent immediately. self.txn_ctrl.send.assert_called_with( - service, [event_list[0]], [], [], None, None + service, [event_list[0]], [], [], None, None, DeviceListUpdates() ) srv_1_defer.callback(service) # Then send the next 100 events self.txn_ctrl.send.assert_called_with( - service, event_list[1:101], [], [], None, None + service, event_list[1:101], [], [], None, None, DeviceListUpdates() ) srv_2_defer.callback(service) # Then the final 99 events self.txn_ctrl.send.assert_called_with( - service, event_list[101:], [], [], None, None + service, event_list[101:], [], [], None, None, DeviceListUpdates() ) self.assertEqual(3, self.txn_ctrl.send.call_count) @@ -320,7 +334,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): event_list = [Mock(name="event")] self.scheduler.enqueue_for_appservice(service, ephemeral=event_list) self.txn_ctrl.send.assert_called_once_with( - service, [], event_list, [], None, None + service, [], event_list, [], None, None, DeviceListUpdates() ) def test_send_multiple_ephemeral_no_queue(self): @@ -329,7 +343,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): event_list = [Mock(name="event1"), Mock(name="event2"), Mock(name="event3")] self.scheduler.enqueue_for_appservice(service, ephemeral=event_list) self.txn_ctrl.send.assert_called_once_with( - service, [], event_list, [], None, None + service, [], event_list, [], None, None, DeviceListUpdates() ) def test_send_single_ephemeral_with_queue(self): @@ -345,13 +359,21 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): # Send more events: expect send() to NOT be called multiple times. self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_2) self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_3) - self.txn_ctrl.send.assert_called_with(service, [], event_list_1, [], None, None) + self.txn_ctrl.send.assert_called_with( + service, [], event_list_1, [], None, None, DeviceListUpdates() + ) self.assertEqual(1, self.txn_ctrl.send.call_count) # Resolve txn_ctrl.send d.callback(service) # Expect the queued events to be sent self.txn_ctrl.send.assert_called_with( - service, [], event_list_2 + event_list_3, [], None, None + service, + [], + event_list_2 + event_list_3, + [], + None, + None, + DeviceListUpdates(), ) self.assertEqual(2, self.txn_ctrl.send.call_count) @@ -365,8 +387,10 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): event_list = first_chunk + second_chunk self.scheduler.enqueue_for_appservice(service, ephemeral=event_list) self.txn_ctrl.send.assert_called_once_with( - service, [], first_chunk, [], None, None + service, [], first_chunk, [], None, None, DeviceListUpdates() ) d.callback(service) - self.txn_ctrl.send.assert_called_with(service, [], second_chunk, [], None, None) + self.txn_ctrl.send.assert_called_with( + service, [], second_chunk, [], None, None, DeviceListUpdates() + ) self.assertEqual(2, self.txn_ctrl.send.call_count) diff --git a/tests/crypto/test_event_signing.py b/tests/crypto/test_event_signing.py index 694020fbef..06e0545a4f 100644 --- a/tests/crypto/test_event_signing.py +++ b/tests/crypto/test_event_signing.py @@ -28,8 +28,8 @@ from tests import unittest SIGNING_KEY_SEED = decode_base64("YJDBA9Xnr2sVqXD9Vj7XVUnmFZcZrlw8Md7kMW+3XA1") KEY_ALG = "ed25519" -KEY_VER = 1 -KEY_NAME = "%s:%d" % (KEY_ALG, KEY_VER) +KEY_VER = "1" +KEY_NAME = "%s:%s" % (KEY_ALG, KEY_VER) HOSTNAME = "domain" @@ -39,7 +39,7 @@ class EventSigningTestCase(unittest.TestCase): # NB: `signedjson` expects `nacl.signing.SigningKey` instances which have been # monkeypatched to include new `alg` and `version` attributes. This is captured # by the `signedjson.types.SigningKey` protocol. - self.signing_key: signedjson.types.SigningKey = nacl.signing.SigningKey( + self.signing_key: signedjson.types.SigningKey = nacl.signing.SigningKey( # type: ignore[assignment] SIGNING_KEY_SEED ) self.signing_key.alg = KEY_ALG diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index e90592855a..a6e91956af 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -14,6 +14,7 @@ from typing import Optional from unittest.mock import Mock +from parameterized import parameterized_class from signedjson import key, sign from signedjson.types import BaseKey, SigningKey @@ -154,6 +155,12 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): ) +@parameterized_class( + [ + {"enable_room_poke_code_path": False}, + {"enable_room_poke_code_path": True}, + ] +) class FederationSenderDevicesTestCases(HomeserverTestCase): servlets = [ admin.register_servlets, @@ -168,17 +175,21 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): def default_config(self): c = super().default_config() c["send_federation"] = True + c["use_new_device_lists_changes_in_room"] = self.enable_room_poke_code_path return c def prepare(self, reactor, clock, hs): - # stub out get_users_who_share_room_with_user so that it claims that - # `@user2:host2` is in the room - def get_users_who_share_room_with_user(user_id): + # stub out `get_rooms_for_user` and `get_users_in_room` so that the + # server thinks the user shares a room with `@user2:host2` + def get_rooms_for_user(user_id): + return defer.succeed({"!room:host1"}) + + hs.get_datastores().main.get_rooms_for_user = get_rooms_for_user + + def get_users_in_room(room_id): return defer.succeed({"@user2:host2"}) - hs.get_datastores().main.get_users_who_share_room_with_user = ( - get_users_who_share_room_with_user - ) + hs.get_datastores().main.get_users_in_room = get_users_in_room # whenever send_transaction is called, record the edu data self.edus = [] diff --git a/tests/federation/transport/test_knocking.py b/tests/federation/transport/test_knocking.py index 648a01618e..d21c11b716 100644 --- a/tests/federation/transport/test_knocking.py +++ b/tests/federation/transport/test_knocking.py @@ -23,7 +23,7 @@ from synapse.server import HomeServer from synapse.types import RoomAlias from tests.test_utils import event_injection -from tests.unittest import FederatingHomeserverTestCase, TestCase, override_config +from tests.unittest import FederatingHomeserverTestCase, TestCase class KnockingStrippedStateEventHelperMixin(TestCase): @@ -221,7 +221,6 @@ class FederationKnockingTestCase( return super().prepare(reactor, clock, homeserver) - @override_config({"experimental_features": {"msc2403_enabled": True}}) def test_room_state_returned_when_knocking(self): """ Tests that specific, stripped state events from a room are returned after diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index cead9f90df..8c72cf6b30 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -15,6 +15,8 @@ from typing import Dict, Iterable, List, Optional from unittest.mock import Mock +from parameterized import parameterized + from twisted.internet import defer from twisted.test.proto_helpers import MemoryReactor @@ -471,6 +473,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): to_device_messages, _otks, _fbks, + _device_list_summary, ) = self.send_mock.call_args[0] # Assert that this was the same to-device message that local_user sent @@ -583,7 +586,15 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): service_id_to_message_count: Dict[str, int] = {} for call in self.send_mock.call_args_list: - service, _events, _ephemeral, to_device_messages, _otks, _fbks = call[0] + ( + service, + _events, + _ephemeral, + to_device_messages, + _otks, + _fbks, + _device_list_summary, + ) = call[0] # Check that this was made to an interested service self.assertIn(service, interested_appservices) @@ -627,6 +638,114 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): return appservice +class ApplicationServicesHandlerDeviceListsTestCase(unittest.HomeserverTestCase): + """ + Tests that the ApplicationServicesHandler sends device list updates to application + services correctly. + """ + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + # Allow us to modify cached feature flags mid-test + self.as_handler = hs.get_application_service_handler() + + # Mock ApplicationServiceApi's put_json, so we can verify the raw JSON that + # will be sent over the wire + self.put_json = simple_async_mock() + hs.get_application_service_api().put_json = self.put_json # type: ignore[assignment] + + # Mock out application services, and allow defining our own in tests + self._services: List[ApplicationService] = [] + self.hs.get_datastores().main.get_app_services = Mock( + return_value=self._services + ) + + # Test across a variety of configuration values + @parameterized.expand( + [ + (True, True, True), + (True, False, False), + (False, True, False), + (False, False, False), + ] + ) + def test_application_service_receives_device_list_updates( + self, + experimental_feature_enabled: bool, + as_supports_txn_extensions: bool, + as_should_receive_device_list_updates: bool, + ): + """ + Tests that an application service receives notice of changed device + lists for a user, when a user changes their device lists. + + Arguments above are populated by parameterized. + + Args: + as_should_receive_device_list_updates: Whether we expect the AS to receive the + device list changes. + experimental_feature_enabled: Whether the "msc3202_transaction_extensions" experimental + feature is enabled. This feature must be enabled for device lists to ASs to work. + as_supports_txn_extensions: Whether the application service has explicitly registered + to receive information defined by MSC3202 - which includes device list changes. + """ + # Change whether the experimental feature is enabled or disabled before making + # device list changes + self.as_handler._msc3202_transaction_extensions_enabled = ( + experimental_feature_enabled + ) + + # Create an appservice that is interested in "local_user" + appservice = ApplicationService( + token=random_string(10), + hostname="example.com", + id=random_string(10), + sender="@as:example.com", + rate_limited=False, + namespaces={ + ApplicationService.NS_USERS: [ + { + "regex": "@local_user:.+", + "exclusive": False, + } + ], + }, + supports_ephemeral=True, + msc3202_transaction_extensions=as_supports_txn_extensions, + # Must be set for Synapse to try pushing data to the AS + hs_token="abcde", + url="some_url", + ) + + # Register the application service + self._services.append(appservice) + + # Register a user on the homeserver + self.local_user = self.register_user("local_user", "password") + self.local_user_token = self.login("local_user", "password") + + if as_should_receive_device_list_updates: + # Ensure that the resulting JSON uses the unstable prefix and contains the + # expected users + self.put_json.assert_called_once() + json_body = self.put_json.call_args[1]["json_body"] + + # Our application service should have received a device list update with + # "local_user" in the "changed" list + device_list_dict = json_body.get("org.matrix.msc3202.device_lists", {}) + self.assertEqual([], device_list_dict["left"]) + self.assertEqual([self.local_user], device_list_dict["changed"]) + + else: + # No device list changes should have been sent out + self.put_json.assert_not_called() + + class ApplicationServicesHandlerOtkCountsTestCase(unittest.HomeserverTestCase): # Argument indices for pulling out arguments from a `send_mock`. ARG_OTK_COUNTS = 4 diff --git a/tests/handlers/test_deactivate_account.py b/tests/handlers/test_deactivate_account.py index 3a10791226..7586e472b5 100644 --- a/tests/handlers/test_deactivate_account.py +++ b/tests/handlers/test_deactivate_account.py @@ -44,21 +44,20 @@ class DeactivateAccountTestCase(HomeserverTestCase): Deactivates the account `self.user` using `self.token` and asserts that it returns a 200 success code. """ - req = self.get_success( - self.make_request( - "POST", - "account/deactivate", - { - "auth": { - "type": "m.login.password", - "user": self.user, - "password": "pass", - }, - "erase": True, + req = self.make_request( + "POST", + "account/deactivate", + { + "auth": { + "type": "m.login.password", + "user": self.user, + "password": "pass", }, - access_token=self.token, - ) + "erase": True, + }, + access_token=self.token, ) + self.assertEqual(req.code, HTTPStatus.OK, req) def test_global_account_data_deleted_upon_deactivation(self) -> None: diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index ac21a28c43..8c74ed1fcf 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -463,8 +463,10 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): res = e.value.code self.assertEqual(res, 400) - res = self.get_success(self.handler.query_local_devices({local_user: None})) - self.assertDictEqual(res, {local_user: {}}) + query_res = self.get_success( + self.handler.query_local_devices({local_user: None}) + ) + self.assertDictEqual(query_res, {local_user: {}}) def test_upload_signatures(self) -> None: """should check signatures that are uploaded""" diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index 4d65639a1e..060ba5f517 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -20,17 +20,17 @@ from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import EventTypes from synapse.api.errors import AuthError, Codes, LimitExceededError, SynapseError from synapse.api.room_versions import RoomVersions -from synapse.events import EventBase +from synapse.events import EventBase, make_event_from_dict from synapse.federation.federation_base import event_from_pdu_json from synapse.logging.context import LoggingContext, run_in_background from synapse.rest import admin from synapse.rest.client import login, room from synapse.server import HomeServer -from synapse.types import create_requester from synapse.util import Clock from synapse.util.stringutils import random_string from tests import unittest +from tests.test_utils import event_injection logger = logging.getLogger(__name__) @@ -39,7 +39,7 @@ def generate_fake_event_id() -> str: return "$fake_" + random_string(43) -class FederationTestCase(unittest.HomeserverTestCase): +class FederationTestCase(unittest.FederatingHomeserverTestCase): servlets = [ admin.register_servlets, login.register_servlets, @@ -219,41 +219,77 @@ class FederationTestCase(unittest.HomeserverTestCase): # create the room user_id = self.register_user("kermit", "test") tok = self.login("kermit", "test") - requester = create_requester(user_id) room_id = self.helper.create_room_as(room_creator=user_id, tok=tok) + room_version = self.get_success(self.store.get_room_version(room_id)) + + # we need a user on the remote server to be a member, so that we can send + # extremity-causing events. + self.get_success( + event_injection.inject_member_event( + self.hs, room_id, f"@user:{self.OTHER_SERVER_NAME}", "join" + ) + ) - ev1 = self.helper.send(room_id, "first message", tok=tok) + send_result = self.helper.send(room_id, "first message", tok=tok) + ev1 = self.get_success( + self.store.get_event(send_result["event_id"], allow_none=False) + ) + current_state = self.get_success( + self.store.get_events_as_list( + (self.get_success(self.store.get_current_state_ids(room_id))).values() + ) + ) # Create "many" backward extremities. The magic number we're trying to # create more than is 5 which corresponds to the number of backward # extremities we slice off in `_maybe_backfill_inner` + federation_event_handler = self.hs.get_federation_event_handler() for _ in range(0, 8): - event_handler = self.hs.get_event_creation_handler() - event, context = self.get_success( - event_handler.create_event( - requester, + event = make_event_from_dict( + self.add_hashes_and_signatures( { + "origin_server_ts": 1, "type": "m.room.message", "content": { "msgtype": "m.text", "body": "message connected to fake event", }, "room_id": room_id, - "sender": user_id, + "sender": f"@user:{self.OTHER_SERVER_NAME}", + "prev_events": [ + ev1.event_id, + # We're creating an backward extremity each time thanks + # to this fake event + generate_fake_event_id(), + ], + # lazy: *everything* is an auth event + "auth_events": [ev.event_id for ev in current_state], + "depth": ev1.depth + 1, }, - prev_event_ids=[ - ev1["event_id"], - # We're creating an backward extremity each time thanks - # to this fake event - generate_fake_event_id(), - ], - ) + room_version, + ), + room_version, ) + + # we poke this directly into _process_received_pdu, to avoid the + # federation handler wanting to backfill the fake event. self.get_success( - event_handler.handle_new_client_event(requester, event, context) + federation_event_handler._process_received_pdu( + self.OTHER_SERVER_NAME, event, state=current_state + ) ) + # we should now have 8 backwards extremities. + backwards_extremities = self.get_success( + self.store.db_pool.simple_select_list( + "event_backward_extremities", + keyvalues={"room_id": room_id}, + retcols=["event_id"], + ) + ) + self.assertEqual(len(backwards_extremities), 8) + current_depth = 1 limit = 100 with LoggingContext("receive_pdu"): @@ -339,7 +375,8 @@ class FederationTestCase(unittest.HomeserverTestCase): member_event.signatures = member_event_dict["signatures"] # Add the new member_event to the StateMap - prev_state_map[ + updated_state_map = dict(prev_state_map) + updated_state_map[ (member_event.type, member_event.state_key) ] = member_event.event_id auth_events.append(member_event) @@ -363,7 +400,7 @@ class FederationTestCase(unittest.HomeserverTestCase): prev_event_ids=message_event_dict["prev_events"], auth_event_ids=self._event_auth_handler.compute_auth_events( builder, - prev_state_map, + updated_state_map, for_verification=False, ), depth=message_event_dict["depth"], diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py new file mode 100644 index 0000000000..489ba57736 --- /dev/null +++ b/tests/handlers/test_federation_event.py @@ -0,0 +1,225 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from unittest import mock + +from synapse.events import make_event_from_dict +from synapse.events.snapshot import EventContext +from synapse.federation.transport.client import StateRequestResponse +from synapse.logging.context import LoggingContext +from synapse.rest import admin +from synapse.rest.client import login, room + +from tests import unittest +from tests.test_utils import event_injection, make_awaitable + + +class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + # mock out the federation transport client + self.mock_federation_transport_client = mock.Mock( + spec=["get_room_state_ids", "get_room_state", "get_event"] + ) + return super().setup_test_homeserver( + federation_transport_client=self.mock_federation_transport_client + ) + + def test_process_pulled_event_with_missing_state(self) -> None: + """Ensure that we correctly handle pulled events with lots of missing state + + In this test, we pretend we are processing a "pulled" event (eg, via backfill + or get_missing_events). The pulled event has a prev_event we haven't previously + seen, so the server requests the state at that prev_event. There is a lot + of state we don't have, so we expect the server to make a /state request. + + We check that the pulled event is correctly persisted, and that the state is + as we expect. + """ + return self._test_process_pulled_event_with_missing_state(False) + + def test_process_pulled_event_with_missing_state_where_prev_is_outlier( + self, + ) -> None: + """Ensure that we correctly handle pulled events with lots of missing state + + A slight modification to test_process_pulled_event_with_missing_state. Again + we have a "pulled" event which refers to a prev_event with lots of state, + but in this case we already have the prev_event (as an outlier, obviously - + if it were a regular event, we wouldn't need to request the state). + """ + return self._test_process_pulled_event_with_missing_state(True) + + def _test_process_pulled_event_with_missing_state( + self, prev_exists_as_outlier: bool + ) -> None: + OTHER_USER = f"@user:{self.OTHER_SERVER_NAME}" + main_store = self.hs.get_datastores().main + state_storage = self.hs.get_storage().state + + # create the room + user_id = self.register_user("kermit", "test") + tok = self.login("kermit", "test") + room_id = self.helper.create_room_as(room_creator=user_id, tok=tok) + room_version = self.get_success(main_store.get_room_version(room_id)) + + # allow the remote user to send state events + self.helper.send_state( + room_id, + "m.room.power_levels", + {"events_default": 0, "state_default": 0}, + tok=tok, + ) + + # add the remote user to the room + member_event = self.get_success( + event_injection.inject_member_event(self.hs, room_id, OTHER_USER, "join") + ) + + initial_state_map = self.get_success(main_store.get_current_state_ids(room_id)) + + auth_event_ids = [ + initial_state_map[("m.room.create", "")], + initial_state_map[("m.room.power_levels", "")], + initial_state_map[("m.room.join_rules", "")], + member_event.event_id, + ] + + # mock up a load of state events which we are missing + state_events = [ + make_event_from_dict( + self.add_hashes_and_signatures( + { + "type": "test_state_type", + "state_key": f"state_{i}", + "room_id": room_id, + "sender": OTHER_USER, + "prev_events": [member_event.event_id], + "auth_events": auth_event_ids, + "origin_server_ts": 1, + "depth": 10, + "content": {"body": f"state_{i}"}, + } + ), + room_version, + ) + for i in range(1, 10) + ] + + # this is the state that we are going to claim is active at the prev_event. + state_at_prev_event = state_events + self.get_success( + main_store.get_events_as_list(initial_state_map.values()) + ) + + # mock up a prev event. + # Depending on the test, we either persist this upfront (as an outlier), + # or let the server request it. + prev_event = make_event_from_dict( + self.add_hashes_and_signatures( + { + "type": "test_regular_type", + "room_id": room_id, + "sender": OTHER_USER, + "prev_events": [], + "auth_events": auth_event_ids, + "origin_server_ts": 1, + "depth": 11, + "content": {"body": "missing_prev"}, + } + ), + room_version, + ) + if prev_exists_as_outlier: + prev_event.internal_metadata.outlier = True + persistence = self.hs.get_storage().persistence + self.get_success( + persistence.persist_event(prev_event, EventContext.for_outlier()) + ) + else: + + async def get_event(destination: str, event_id: str, timeout=None): + self.assertEqual(destination, self.OTHER_SERVER_NAME) + self.assertEqual(event_id, prev_event.event_id) + return {"pdus": [prev_event.get_pdu_json()]} + + self.mock_federation_transport_client.get_event.side_effect = get_event + + # mock up a regular event to pass into _process_pulled_event + pulled_event = make_event_from_dict( + self.add_hashes_and_signatures( + { + "type": "test_regular_type", + "room_id": room_id, + "sender": OTHER_USER, + "prev_events": [prev_event.event_id], + "auth_events": auth_event_ids, + "origin_server_ts": 1, + "depth": 12, + "content": {"body": "pulled"}, + } + ), + room_version, + ) + + # we expect an outbound request to /state_ids, so stub that out + self.mock_federation_transport_client.get_room_state_ids.return_value = ( + make_awaitable( + { + "pdu_ids": [e.event_id for e in state_at_prev_event], + "auth_chain_ids": [], + } + ) + ) + + # we also expect an outbound request to /state + self.mock_federation_transport_client.get_room_state.return_value = ( + make_awaitable( + StateRequestResponse(auth_events=[], state=state_at_prev_event) + ) + ) + + # we have to bump the clock a bit, to keep the retry logic in + # FederationClient.get_pdu happy + self.reactor.advance(60000) + + # Finally, the call under test: send the pulled event into _process_pulled_event + with LoggingContext("test"): + self.get_success( + self.hs.get_federation_event_handler()._process_pulled_event( + self.OTHER_SERVER_NAME, pulled_event, backfilled=False + ) + ) + + # check that the event is correctly persisted + persisted = self.get_success(main_store.get_event(pulled_event.event_id)) + self.assertIsNotNone(persisted, "pulled event was not persisted at all") + self.assertFalse( + persisted.internal_metadata.is_outlier(), "pulled event was an outlier" + ) + + # check that the state at that event is as expected + state = self.get_success( + state_storage.get_state_ids_for_event(pulled_event.event_id) + ) + expected_state = { + (e.type, e.state_key): e.event_id for e in state_at_prev_event + } + self.assertEqual(state, expected_state) + + if prev_exists_as_outlier: + self.mock_federation_transport_client.get_event.assert_not_called() diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index 014815db6e..9684120c70 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -354,10 +354,11 @@ class OidcHandlerTestCase(HomeserverTestCase): req = Mock(spec=["cookies"]) req.cookies = [] - url = self.get_success( - self.provider.handle_redirect_request(req, b"http://client/redirect") + url = urlparse( + self.get_success( + self.provider.handle_redirect_request(req, b"http://client/redirect") + ) ) - url = urlparse(url) auth_endpoint = urlparse(AUTHORIZATION_ENDPOINT) self.assertEqual(url.scheme, auth_endpoint.scheme) diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 1ec105c373..f88c725a42 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -59,7 +59,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): self.bob = UserID.from_string("@4567:test") self.alice = UserID.from_string("@alice:remote") - self.get_success(self.register_user(self.frank.localpart, "frankpassword")) + self.register_user(self.frank.localpart, "frankpassword") self.handler = hs.get_profile_handler() diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index 3aedc0767b..865b8b7e47 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -158,9 +158,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): ) # Blow away caches (supported room versions can only change due to a restart). - self.get_success( - self.store.get_rooms_for_user_with_stream_ordering.invalidate_all() - ) + self.store.get_rooms_for_user_with_stream_ordering.invalidate_all() self.store._get_event_cache.clear() # The rooms should be excluded from the sync response. diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 92012cd6f7..c6e501c7be 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -351,6 +351,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self.handler.handle_local_profile_change(regular_user_id, profile_info) ) profile = self.get_success(self.store.get_user_in_directory(regular_user_id)) + assert profile is not None self.assertTrue(profile["display_name"] == display_name) def test_handle_local_profile_change_with_deactivated_user(self) -> None: @@ -369,6 +370,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): # profile is in directory profile = self.get_success(self.store.get_user_in_directory(r_user_id)) + assert profile is not None self.assertTrue(profile["display_name"] == display_name) # deactivate user diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py index 10dd94b549..9fd5d59c55 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py @@ -87,11 +87,23 @@ class ModuleApiTestCase(HomeserverTestCase): self.assertEqual(displayname, "Bobberino") def test_can_register_admin_user(self): - user_id = self.get_success( - self.register_user( - "bob_module_admin", "1234", displayname="Bobberino Admin", admin=True - ) + user_id = self.register_user( + "bob_module_admin", "1234", displayname="Bobberino Admin", admin=True ) + + found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id)) + self.assertEqual(found_user.user_id.to_string(), user_id) + self.assertIdentical(found_user.is_admin, True) + + def test_can_set_admin(self): + user_id = self.register_user( + "alice_wants_admin", + "1234", + displayname="Alice Powerhungry", + admin=False, + ) + + self.get_success(self.module_api.set_user_admin(user_id, True)) found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id)) self.assertEqual(found_user.user_id.to_string(), user_id) self.assertIdentical(found_user.is_admin, True) @@ -278,7 +290,7 @@ class ModuleApiTestCase(HomeserverTestCase): # Create a user and room to play with user_id = self.register_user("kermit", "monkey") tok = self.login("kermit", "monkey") - room_id = self.helper.create_room_as(user_id, tok=tok) + room_id = self.helper.create_room_as(user_id, tok=tok, is_public=False) # The room should not currently be in the public rooms directory is_in_public_rooms = self.get_success( diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 17dc42fd37..297a9e77f8 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -268,7 +268,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): event_source = RoomEventSource(self.hs) event_source.store = self.slaved_store - current_token = self.get_success(event_source.get_current_key()) + current_token = event_source.get_current_key() # gradually stream out the replication while repl_transport.buffer: @@ -277,7 +277,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): self.pump(0) prev_token = current_token - current_token = self.get_success(event_source.get_current_key()) + current_token = event_source.get_current_key() # attempt to replicate the behaviour of the sync handler. # diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py index 0d47dd0aff..e909e444ac 100644 --- a/tests/rest/admin/test_media.py +++ b/tests/rest/admin/test_media.py @@ -702,6 +702,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): """ media_info = self.get_success(self.store.get_local_media(self.media_id)) + assert media_info is not None self.assertFalse(media_info["quarantined_by"]) # quarantining @@ -715,6 +716,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): self.assertFalse(channel.json_body) media_info = self.get_success(self.store.get_local_media(self.media_id)) + assert media_info is not None self.assertTrue(media_info["quarantined_by"]) # remove from quarantine @@ -728,6 +730,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): self.assertFalse(channel.json_body) media_info = self.get_success(self.store.get_local_media(self.media_id)) + assert media_info is not None self.assertFalse(media_info["quarantined_by"]) def test_quarantine_protected_media(self) -> None: @@ -740,6 +743,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): # verify protection media_info = self.get_success(self.store.get_local_media(self.media_id)) + assert media_info is not None self.assertTrue(media_info["safe_from_quarantine"]) # quarantining @@ -754,6 +758,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): # verify that is not in quarantine media_info = self.get_success(self.store.get_local_media(self.media_id)) + assert media_info is not None self.assertFalse(media_info["quarantined_by"]) @@ -830,6 +835,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): """ media_info = self.get_success(self.store.get_local_media(self.media_id)) + assert media_info is not None self.assertFalse(media_info["safe_from_quarantine"]) # protect @@ -843,6 +849,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): self.assertFalse(channel.json_body) media_info = self.get_success(self.store.get_local_media(self.media_id)) + assert media_info is not None self.assertTrue(media_info["safe_from_quarantine"]) # unprotect @@ -856,6 +863,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): self.assertFalse(channel.json_body) media_info = self.get_success(self.store.get_local_media(self.media_id)) + assert media_info is not None self.assertFalse(media_info["safe_from_quarantine"]) diff --git a/tests/rest/admin/test_server_notice.py b/tests/rest/admin/test_server_notice.py index 2c855bff99..a53463c9ba 100644 --- a/tests/rest/admin/test_server_notice.py +++ b/tests/rest/admin/test_server_notice.py @@ -214,9 +214,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): self.assertEqual(messages[0]["sender"], "@notices:test") # invalidate cache of server notices room_ids - self.get_success( - self.server_notices_manager.get_or_create_notice_room_for_user.invalidate_all() - ) + self.server_notices_manager.get_or_create_notice_room_for_user.invalidate_all() # send second message channel = self.make_request( @@ -291,9 +289,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): # invalidate cache of server notices room_ids # if server tries to send to a cached room_id the user gets the message # in old room - self.get_success( - self.server_notices_manager.get_or_create_notice_room_for_user.invalidate_all() - ) + self.server_notices_manager.get_or_create_notice_room_for_user.invalidate_all() # send second message channel = self.make_request( @@ -380,9 +376,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): # invalidate cache of server notices room_ids # if server tries to send to a cached room_id it gives an error - self.get_success( - self.server_notices_manager.get_or_create_notice_room_for_user.invalidate_all() - ) + self.server_notices_manager.get_or_create_notice_room_for_user.invalidate_all() # send second message channel = self.make_request( diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index bef911d5df..0cdf1dec40 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -1590,10 +1590,9 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) - pushers = self.get_success( - self.store.get_pushers_by({"user_name": "@bob:test"}) + pushers = list( + self.get_success(self.store.get_pushers_by({"user_name": "@bob:test"})) ) - pushers = list(pushers) self.assertEqual(len(pushers), 1) self.assertEqual("@bob:test", pushers[0].user_name) @@ -1632,10 +1631,9 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) - pushers = self.get_success( - self.store.get_pushers_by({"user_name": "@bob:test"}) + pushers = list( + self.get_success(self.store.get_pushers_by({"user_name": "@bob:test"})) ) - pushers = list(pushers) self.assertEqual(len(pushers), 0) def test_set_password(self) -> None: @@ -2144,6 +2142,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): # is in user directory profile = self.get_success(self.store.get_user_in_directory(self.other_user)) + assert profile is not None self.assertTrue(profile["display_name"] == "User") # Deactivate user @@ -2711,6 +2710,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase): user_tuple = self.get_success( self.store.get_user_by_access_token(other_user_token) ) + assert user_tuple is not None token_id = user_tuple.token_id self.get_success( @@ -3676,6 +3676,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase): # The user starts off as not shadow-banned. other_user_token = self.login("user", "pass") result = self.get_success(self.store.get_user_by_access_token(other_user_token)) + assert result is not None self.assertFalse(result.shadow_banned) channel = self.make_request("POST", self.url, access_token=self.admin_user_tok) @@ -3684,6 +3685,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase): # Ensure the user is shadow-banned (and the cache was cleared). result = self.get_success(self.store.get_user_by_access_token(other_user_token)) + assert result is not None self.assertTrue(result.shadow_banned) # Un-shadow-ban the user. @@ -3695,6 +3697,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase): # Ensure the user is no longer shadow-banned (and the cache was cleared). result = self.get_success(self.store.get_user_by_access_token(other_user_token)) + assert result is not None self.assertFalse(result.shadow_banned) diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py index 27946febff..e00b5c171c 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py @@ -89,6 +89,17 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): self.store = hs.get_datastores().main self.submit_token_resource = PasswordResetSubmitTokenResource(hs) + def attempt_wrong_password_login(self, username: str, password: str) -> None: + """Attempts to login as the user with the given password, asserting + that the attempt *fails*. + """ + body = {"type": "m.login.password", "user": username, "password": password} + + channel = self.make_request( + "POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8") + ) + self.assertEqual(channel.code, 403, channel.result) + def test_basic_password_reset(self) -> None: """Test basic password reset flow""" old_password = "monkey" diff --git a/tests/rest/client/test_account_data.py b/tests/rest/client/test_account_data.py new file mode 100644 index 0000000000..d5b0640e7a --- /dev/null +++ b/tests/rest/client/test_account_data.py @@ -0,0 +1,75 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from unittest.mock import Mock + +from synapse.rest import admin +from synapse.rest.client import account_data, login, room + +from tests import unittest +from tests.test_utils import make_awaitable + + +class AccountDataTestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + account_data.register_servlets, + ] + + def test_on_account_data_updated_callback(self) -> None: + """Tests that the on_account_data_updated module callback is called correctly when + a user's account data changes. + """ + mocked_callback = Mock(return_value=make_awaitable(None)) + self.hs.get_account_data_handler()._on_account_data_updated_callbacks.append( + mocked_callback + ) + + user_id = self.register_user("user", "password") + tok = self.login("user", "password") + account_data_type = "org.matrix.foo" + account_data_content = {"bar": "baz"} + + # Change the user's global account data. + channel = self.make_request( + "PUT", + f"/user/{user_id}/account_data/{account_data_type}", + account_data_content, + access_token=tok, + ) + + # Test that the callback is called with the user ID, the new account data, and + # None as the room ID. + self.assertEqual(channel.code, 200, channel.result) + mocked_callback.assert_called_once_with( + user_id, None, account_data_type, account_data_content + ) + + # Change the user's room-specific account data. + room_id = self.helper.create_room_as(user_id, tok=tok) + channel = self.make_request( + "PUT", + f"/user/{user_id}/rooms/{room_id}/account_data/{account_data_type}", + account_data_content, + access_token=tok, + ) + + # Test that the callback is called with the user ID, the room ID and the new + # account data. + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(mocked_callback.call_count, 2) + mocked_callback.assert_called_with( + user_id, room_id, account_data_type, account_data_content + ) diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index fe97a0b3dd..419eef166a 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import itertools import urllib.parse from typing import Any, Callable, Dict, List, Optional, Tuple from unittest.mock import patch @@ -145,16 +144,6 @@ class BaseRelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(200, channel.code, channel.json_body) return channel.json_body["unsigned"].get("m.relations", {}) - def _get_aggregations(self) -> List[JsonDict]: - """Request /aggregations on the parent ID and includes the returned chunk.""" - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - return channel.json_body["chunk"] - def _find_event_in_chunk(self, events: List[JsonDict]) -> JsonDict: """ Find the parent event in a chunk of events and assert that it has the proper bundled aggregations. @@ -264,43 +253,6 @@ class RelationsTestCase(BaseRelationsTestCase): expected_response_code=400, ) - def test_aggregation(self) -> None: - """Test that annotations get correctly aggregated.""" - - self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token - ) - self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") - - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - - self.assertEqual( - channel.json_body, - { - "chunk": [ - {"type": "m.reaction", "key": "a", "count": 2}, - {"type": "m.reaction", "key": "b", "count": 1}, - ] - }, - ) - - def test_aggregation_must_be_annotation(self) -> None: - """Test that aggregations must be annotations.""" - - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/aggregations" - f"/{self.parent_id}/{RelationTypes.REPLACE}?limit=1", - access_token=self.user_token, - ) - self.assertEqual(400, channel.code, channel.json_body) - def test_ignore_invalid_room(self) -> None: """Test that we ignore invalid relations over federation.""" # Create another room and send a message in it. @@ -394,15 +346,6 @@ class RelationsTestCase(BaseRelationsTestCase): self.assertEqual(200, channel.code, channel.json_body) self.assertEqual(channel.json_body["chunk"], []) - # And when fetching aggregations. - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{room2}/aggregations/{parent_id}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - self.assertEqual(channel.json_body["chunk"], []) - # And for bundled aggregations. channel = self.make_request( "GET", @@ -717,15 +660,6 @@ class RelationsTestCase(BaseRelationsTestCase): self.assertEqual(200, channel.code, channel.json_body) self.assertNotIn("m.relations", channel.json_body["unsigned"]) - # But unknown relations can be directly queried. - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}?limit=1", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - self.assertEqual(channel.json_body["chunk"], []) - def test_background_update(self) -> None: """Test the event_arbitrary_relations background update.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍") @@ -941,131 +875,6 @@ class RelationPaginationTestCase(BaseRelationsTestCase): annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]] ) - def test_aggregation_pagination_groups(self) -> None: - """Test that we can paginate annotation groups correctly.""" - - # We need to create ten separate users to send each reaction. - access_tokens = [self.user_token, self.user2_token] - idx = 0 - while len(access_tokens) < 10: - user_id, token = self._create_user("test" + str(idx)) - idx += 1 - - self.helper.join(self.room, user=user_id, tok=token) - access_tokens.append(token) - - idx = 0 - sent_groups = {"👍": 10, "a": 7, "b": 5, "c": 3, "d": 2, "e": 1} - for key in itertools.chain.from_iterable( - itertools.repeat(key, num) for key, num in sent_groups.items() - ): - self._send_relation( - RelationTypes.ANNOTATION, - "m.reaction", - key=key, - access_token=access_tokens[idx], - ) - - idx += 1 - idx %= len(access_tokens) - - prev_token: Optional[str] = None - found_groups: Dict[str, int] = {} - for _ in range(20): - from_token = "" - if prev_token: - from_token = "&from=" + prev_token - - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}?limit=1{from_token}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - - self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) - - for groups in channel.json_body["chunk"]: - # We only expect reactions - self.assertEqual(groups["type"], "m.reaction", channel.json_body) - - # We should only see each key once - self.assertNotIn(groups["key"], found_groups, channel.json_body) - - found_groups[groups["key"]] = groups["count"] - - next_batch = channel.json_body.get("next_batch") - - self.assertNotEqual(prev_token, next_batch) - prev_token = next_batch - - if not prev_token: - break - - self.assertEqual(sent_groups, found_groups) - - def test_aggregation_pagination_within_group(self) -> None: - """Test that we can paginate within an annotation group.""" - - # We need to create ten separate users to send each reaction. - access_tokens = [self.user_token, self.user2_token] - idx = 0 - while len(access_tokens) < 10: - user_id, token = self._create_user("test" + str(idx)) - idx += 1 - - self.helper.join(self.room, user=user_id, tok=token) - access_tokens.append(token) - - idx = 0 - expected_event_ids = [] - for _ in range(10): - channel = self._send_relation( - RelationTypes.ANNOTATION, - "m.reaction", - key="👍", - access_token=access_tokens[idx], - ) - expected_event_ids.append(channel.json_body["event_id"]) - - idx += 1 - - # Also send a different type of reaction so that we test we don't see it - self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") - - prev_token = "" - found_event_ids: List[str] = [] - encoded_key = urllib.parse.quote_plus("👍".encode()) - for _ in range(20): - from_token = "" - if prev_token: - from_token = "&from=" + prev_token - - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}" - f"/aggregations/{self.parent_id}/{RelationTypes.ANNOTATION}" - f"/m.reaction/{encoded_key}?limit=1{from_token}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - - self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) - - found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) - - next_batch = channel.json_body.get("next_batch") - - self.assertNotEqual(prev_token, next_batch) - prev_token = next_batch - - if not prev_token: - break - - # We paginated backwards, so reverse - found_event_ids.reverse() - self.assertEqual(found_event_ids, expected_event_ids) - class BundledAggregationsTestCase(BaseRelationsTestCase): """ @@ -1453,10 +1262,6 @@ class RelationRedactionTestCase(BaseRelationsTestCase): {"chunk": [{"type": "m.reaction", "key": "a", "count": 2}]}, ) - # Both relations appear in the aggregation. - chunk = self._get_aggregations() - self.assertEqual(chunk, [{"type": "m.reaction", "key": "a", "count": 2}]) - # Redact one of the reactions. self._redact(to_redact_event_id) @@ -1469,10 +1274,6 @@ class RelationRedactionTestCase(BaseRelationsTestCase): {"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]}, ) - # The unredacted aggregation should still exist. - chunk = self._get_aggregations() - self.assertEqual(chunk, [{"type": "m.reaction", "key": "a", "count": 1}]) - def test_redact_relation_thread(self) -> None: """ Test that thread replies are properly handled after the thread reply redacted. @@ -1578,10 +1379,6 @@ class RelationRedactionTestCase(BaseRelationsTestCase): self.assertEqual(len(event_ids), 1) self.assertIn(RelationTypes.ANNOTATION, relations) - # The aggregation should exist. - chunk = self._get_aggregations() - self.assertEqual(chunk, [{"type": "m.reaction", "key": "👍", "count": 1}]) - # Redact the original event. self._redact(self.parent_id) @@ -1594,10 +1391,6 @@ class RelationRedactionTestCase(BaseRelationsTestCase): {"chunk": [{"type": "m.reaction", "key": "👍", "count": 1}]}, ) - # There's nothing to aggregate. - chunk = self._get_aggregations() - self.assertEqual(chunk, [{"count": 1, "key": "👍", "type": "m.reaction"}]) - @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) def test_redact_parent_thread(self) -> None: """ diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 3a9617d6da..6ff79b9e2e 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -982,7 +982,7 @@ class RoomJoinRatelimitTestCase(RoomBase): super().prepare(reactor, clock, hs) # profile changes expect that the user is actually registered user = UserID.from_string(self.user_id) - self.get_success(self.register_user(user.localpart, "supersecretpassword")) + self.register_user(user.localpart, "supersecretpassword") @unittest.override_config( {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}} diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py index 4351013952..773c16a54c 100644 --- a/tests/rest/client/test_sync.py +++ b/tests/rest/client/test_sync.py @@ -341,7 +341,6 @@ class SyncKnockTestCase( hs, self.room_id, self.user_id ) - @override_config({"experimental_features": {"msc2403_enabled": True}}) def test_knock_room_state(self) -> None: """Tests that /sync returns state from a room after knocking on it.""" # Knock on a room @@ -497,6 +496,11 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): receipts.register_servlets, ] + def default_config(self) -> JsonDict: + config = super().default_config() + config["experimental_features"] = {"msc2654_enabled": True} + return config + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.url = "/sync?since=%s" self.next_batch = "s0" @@ -772,3 +776,65 @@ class DeviceListSyncTestCase(unittest.HomeserverTestCase): self.assertIn( self.user_id, device_list_changes, incremental_sync_channel.json_body ) + + +class ExcludeRoomTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + sync.register_servlets, + room.register_servlets, + ] + + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: + self.user_id = self.register_user("user", "password") + self.tok = self.login("user", "password") + + self.excluded_room_id = self.helper.create_room_as(self.user_id, tok=self.tok) + self.included_room_id = self.helper.create_room_as(self.user_id, tok=self.tok) + + # We need to manually append the room ID, because we can't know the ID before + # creating the room, and we can't set the config after starting the homeserver. + self.hs.get_sync_handler().rooms_to_exclude.append(self.excluded_room_id) + + def test_join_leave(self) -> None: + """Tests that rooms are correctly excluded from the 'join' and 'leave' sections of + sync responses. + """ + channel = self.make_request("GET", "/sync", access_token=self.tok) + self.assertEqual(channel.code, 200, channel.result) + + self.assertNotIn(self.excluded_room_id, channel.json_body["rooms"]["join"]) + self.assertIn(self.included_room_id, channel.json_body["rooms"]["join"]) + + self.helper.leave(self.excluded_room_id, self.user_id, tok=self.tok) + self.helper.leave(self.included_room_id, self.user_id, tok=self.tok) + + channel = self.make_request( + "GET", + "/sync?since=" + channel.json_body["next_batch"], + access_token=self.tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + self.assertNotIn(self.excluded_room_id, channel.json_body["rooms"]["leave"]) + self.assertIn(self.included_room_id, channel.json_body["rooms"]["leave"]) + + def test_invite(self) -> None: + """Tests that rooms are correctly excluded from the 'invite' section of sync + responses. + """ + invitee = self.register_user("invitee", "password") + invitee_tok = self.login("invitee", "password") + + self.helper.invite(self.excluded_room_id, self.user_id, invitee, tok=self.tok) + self.helper.invite(self.included_room_id, self.user_id, invitee, tok=self.tok) + + channel = self.make_request("GET", "/sync", access_token=invitee_tok) + self.assertEqual(channel.code, 200, channel.result) + + self.assertNotIn(self.excluded_room_id, channel.json_body["rooms"]["invite"]) + self.assertIn(self.included_room_id, channel.json_body["rooms"]["invite"]) diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py index e7de67e3a3..5eb0f243f7 100644 --- a/tests/rest/client/test_third_party_rules.py +++ b/tests/rest/client/test_third_party_rules.py @@ -896,3 +896,44 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): # Check that the mock was called with the right room ID self.assertEqual(args[1], self.room_id) + + def test_on_threepid_bind(self) -> None: + """Tests that the on_threepid_bind module callback is called correctly after + associating a 3PID to an account. + """ + # Register a mocked callback. + threepid_bind_mock = Mock(return_value=make_awaitable(None)) + third_party_rules = self.hs.get_third_party_event_rules() + third_party_rules._on_threepid_bind_callbacks.append(threepid_bind_mock) + + # Register an admin user. + self.register_user("admin", "password", admin=True) + admin_tok = self.login("admin", "password") + + # Also register a normal user we can modify. + user_id = self.register_user("user", "password") + + # Add a 3PID to the user. + channel = self.make_request( + "PUT", + "/_synapse/admin/v2/users/%s" % user_id, + { + "threepids": [ + { + "medium": "email", + "address": "foo@example.com", + }, + ], + }, + access_token=admin_tok, + ) + + # Check that the shutdown was blocked + self.assertEqual(channel.code, 200, channel.json_body) + + # Check that the mock was called once. + threepid_bind_mock.assert_called_once() + args = threepid_bind_mock.call_args[0] + + # Check that the mock was called with the right parameters + self.assertEqual(args, (user_id, "email", "foo@example.com")) diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index 28663826fc..a0788b1bb0 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -88,7 +88,7 @@ class RestHelper: def create_room_as( self, room_creator: Optional[str] = None, - is_public: Optional[bool] = None, + is_public: Optional[bool] = True, room_version: Optional[str] = None, tok: Optional[str] = None, expect_code: int = HTTPStatus.OK, @@ -101,9 +101,12 @@ class RestHelper: Args: room_creator: The user ID to create the room with. is_public: If True, the `visibility` parameter will be set to - "public". If False, it will be set to "private". If left - unspecified, the server will set it to an appropriate default - (which should be "private" as per the CS spec). + "public". If False, it will be set to "private". + If None, doesn't specify the `visibility` parameter in which + case the server is supposed to make the room private according to + the CS API. + Defaults to public, since that is commonly needed in tests + for convenience where room privacy is not a problem. room_version: The room version to create the room as. Defaults to Synapse's default room version. tok: The access token to use in the request. diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py index 978c252f84..ac0ac06b7e 100644 --- a/tests/rest/key/v2/test_remote_key_resource.py +++ b/tests/rest/key/v2/test_remote_key_resource.py @@ -76,7 +76,7 @@ class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase): "verify_keys": { key_id: { "key": signedjson.key.encode_verify_key_base64( - signing_key.verify_key + signedjson.key.get_verify_key(signing_key) ) } }, @@ -175,7 +175,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase): % ( self.hs_signing_key.version, ): signedjson.key.encode_verify_key_base64( - self.hs_signing_key.verify_key + signedjson.key.get_verify_key(self.hs_signing_key) ) }, } @@ -229,7 +229,9 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase): assert isinstance(keyres, FetchKeyResult) self.assertEqual( signedjson.key.encode_verify_key_base64(keyres.verify_key), - signedjson.key.encode_verify_key_base64(testkey.verify_key), + signedjson.key.encode_verify_key_base64( + signedjson.key.get_verify_key(testkey) + ), ) def test_get_notary_key(self) -> None: @@ -251,7 +253,9 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase): assert isinstance(keyres, FetchKeyResult) self.assertEqual( signedjson.key.encode_verify_key_base64(keyres.verify_key), - signedjson.key.encode_verify_key_base64(testkey.verify_key), + signedjson.key.encode_verify_key_base64( + signedjson.key.get_verify_key(testkey) + ), ) def test_get_notary_keyserver_key(self) -> None: @@ -268,5 +272,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase): assert isinstance(keyres, FetchKeyResult) self.assertEqual( signedjson.key.encode_verify_key_base64(keyres.verify_key), - signedjson.key.encode_verify_key_base64(self.hs_signing_key.verify_key), + signedjson.key.encode_verify_key_base64( + signedjson.key.get_verify_key(self.hs_signing_key) + ), ) diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py index 5148c39874..3b24d0ace6 100644 --- a/tests/rest/media/v1/test_url_preview.py +++ b/tests/rest/media/v1/test_url_preview.py @@ -17,7 +17,7 @@ import json import os import re from typing import Any, Dict, Optional, Sequence, Tuple, Type -from urllib.parse import urlencode +from urllib.parse import quote, urlencode from twisted.internet._resolver import HostResolution from twisted.internet.address import IPv4Address, IPv6Address @@ -69,7 +69,6 @@ class URLPreviewTests(unittest.HomeserverTestCase): "2001:800::/21", ) config["url_preview_ip_range_whitelist"] = ("1.1.1.1",) - config["url_preview_url_blacklist"] = [] config["url_preview_accept_language"] = [ "en-UK", "en-US;q=0.9", @@ -1123,3 +1122,43 @@ class URLPreviewTests(unittest.HomeserverTestCase): os.path.exists(path), f"{os.path.relpath(path, self.media_store_path)} was not deleted", ) + + @unittest.override_config({"url_preview_url_blacklist": [{"port": "*"}]}) + def test_blacklist_port(self) -> None: + """Tests that blacklisting URLs with a port makes previewing such URLs + fail with a 403 error and doesn't impact other previews. + """ + self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] + + bad_url = quote("http://matrix.org:8888/foo") + good_url = quote("http://matrix.org/foo") + + channel = self.make_request( + "GET", + "preview_url?url=" + bad_url, + shorthand=False, + await_result=False, + ) + self.pump() + self.assertEqual(channel.code, 403, channel.result) + + channel = self.make_request( + "GET", + "preview_url?url=" + good_url, + shorthand=False, + await_result=False, + ) + self.pump() + + client = self.reactor.tcpClients[0][2].buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + client.dataReceived( + b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\nContent-Type: text/html\r\n\r\n" + % (len(self.end_content),) + + self.end_content + ) + + self.pump() + self.assertEqual(channel.code, 200) diff --git a/tests/server.py b/tests/server.py index 6ce2a17bf4..16559d2588 100644 --- a/tests/server.py +++ b/tests/server.py @@ -22,7 +22,6 @@ import warnings from collections import deque from io import SEEK_END, BytesIO from typing import ( - AnyStr, Callable, Dict, Iterable, @@ -77,6 +76,7 @@ from tests.utils import ( POSTGRES_BASE_DB, POSTGRES_HOST, POSTGRES_PASSWORD, + POSTGRES_PORT, POSTGRES_USER, SQLITE_PERSIST_DB, USE_POSTGRES_FOR_TESTS, @@ -86,6 +86,9 @@ from tests.utils import ( logger = logging.getLogger(__name__) +# the type of thing that can be passed into `make_request` in the headers list +CustomHeaderType = Tuple[Union[str, bytes], Union[str, bytes]] + class TimedOutException(Exception): """ @@ -260,7 +263,7 @@ def make_request( federation_auth_origin: Optional[bytes] = None, content_is_form: bool = False, await_result: bool = True, - custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, + custom_headers: Optional[Iterable[CustomHeaderType]] = None, client_ip: str = "127.0.0.1", ) -> FakeChannel: """ @@ -745,6 +748,7 @@ def setup_test_homeserver( "host": POSTGRES_HOST, "password": POSTGRES_PASSWORD, "user": POSTGRES_USER, + "port": POSTGRES_PORT, "cp_min": 1, "cp_max": 5, }, @@ -784,6 +788,7 @@ def setup_test_homeserver( database=POSTGRES_BASE_DB, user=POSTGRES_USER, host=POSTGRES_HOST, + port=POSTGRES_PORT, password=POSTGRES_PASSWORD, ) db_conn.autocommit = True @@ -831,6 +836,7 @@ def setup_test_homeserver( database=POSTGRES_BASE_DB, user=POSTGRES_USER, host=POSTGRES_HOST, + port=POSTGRES_PORT, password=POSTGRES_PASSWORD, ) db_conn.autocommit = True diff --git a/tests/storage/databases/main/test_lock.py b/tests/storage/databases/main/test_lock.py index 3ac4646969..74c6224eb6 100644 --- a/tests/storage/databases/main/test_lock.py +++ b/tests/storage/databases/main/test_lock.py @@ -28,7 +28,7 @@ class LockTestCase(unittest.HomeserverTestCase): """ # First to acquire this lock, so it should complete lock = self.get_success(self.store.try_acquire_lock("name", "key")) - self.assertIsNotNone(lock) + assert lock is not None # Enter the context manager self.get_success(lock.__aenter__()) @@ -45,7 +45,7 @@ class LockTestCase(unittest.HomeserverTestCase): # We can now acquire the lock again. lock3 = self.get_success(self.store.try_acquire_lock("name", "key")) - self.assertIsNotNone(lock3) + assert lock3 is not None self.get_success(lock3.__aenter__()) self.get_success(lock3.__aexit__(None, None, None)) @@ -53,7 +53,7 @@ class LockTestCase(unittest.HomeserverTestCase): """Test that we don't time out locks while they're still active""" lock = self.get_success(self.store.try_acquire_lock("name", "key")) - self.assertIsNotNone(lock) + assert lock is not None self.get_success(lock.__aenter__()) @@ -69,7 +69,7 @@ class LockTestCase(unittest.HomeserverTestCase): """Test that we time out locks if they're not updated for ages""" lock = self.get_success(self.store.try_acquire_lock("name", "key")) - self.assertIsNotNone(lock) + assert lock is not None self.get_success(lock.__aenter__()) diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index 4899cd5c36..366398e39d 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -14,12 +14,18 @@ # limitations under the License. import secrets +from typing import Any, Dict, Generator, List, Tuple + +from twisted.test.proto_helpers import MemoryReactor + +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest class UpsertManyTests(unittest.HomeserverTestCase): - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.storage = hs.get_datastores().main self.table_name = "table_" + secrets.token_hex(6) @@ -40,11 +46,13 @@ class UpsertManyTests(unittest.HomeserverTestCase): ) ) - def _dump_to_tuple(self, res): + def _dump_to_tuple( + self, res: List[Dict[str, Any]] + ) -> Generator[Tuple[int, str, str], None, None]: for i in res: yield (i["id"], i["username"], i["value"]) - def test_upsert_many(self): + def test_upsert_many(self) -> None: """ Upsert_many will perform the upsert operation across a batch of data. """ diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index ee599f4336..1bf93e79a7 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -31,6 +31,7 @@ from synapse.storage.databases.main.appservice import ( ApplicationServiceStore, ApplicationServiceTransactionStore, ) +from synapse.types import DeviceListUpdates from synapse.util import Clock from tests import unittest @@ -168,15 +169,6 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): (as_id, txn_id, json.dumps([e.event_id for e in events])), ) - def _set_last_txn(self, as_id, txn_id): - return self.db_pool.runOperation( - self.engine.convert_param_style( - "INSERT INTO application_services_state(as_id, last_txn, state) " - "VALUES(?,?,?)" - ), - (as_id, txn_id, ApplicationServiceState.UP.value), - ) - def test_get_appservice_state_none( self, ) -> None: @@ -267,65 +259,15 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")]) txn = self.get_success( defer.ensureDeferred( - self.store.create_appservice_txn(service, events, [], [], {}, {}) + self.store.create_appservice_txn( + service, events, [], [], {}, {}, DeviceListUpdates() + ) ) ) self.assertEqual(txn.id, 1) self.assertEqual(txn.events, events) self.assertEqual(txn.service, service) - def test_create_appservice_txn_older_last_txn( - self, - ) -> None: - service = Mock(id=self.as_list[0]["id"]) - events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")]) - self.get_success(self._set_last_txn(service.id, 9643)) # AS is falling behind - self.get_success(self._insert_txn(service.id, 9644, events)) - self.get_success(self._insert_txn(service.id, 9645, events)) - txn = self.get_success( - self.store.create_appservice_txn(service, events, [], [], {}, {}) - ) - self.assertEqual(txn.id, 9646) - self.assertEqual(txn.events, events) - self.assertEqual(txn.service, service) - - def test_create_appservice_txn_up_to_date_last_txn( - self, - ) -> None: - service = Mock(id=self.as_list[0]["id"]) - events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")]) - self.get_success(self._set_last_txn(service.id, 9643)) - txn = self.get_success( - self.store.create_appservice_txn(service, events, [], [], {}, {}) - ) - self.assertEqual(txn.id, 9644) - self.assertEqual(txn.events, events) - self.assertEqual(txn.service, service) - - def test_create_appservice_txn_up_fuzzing( - self, - ) -> None: - service = Mock(id=self.as_list[0]["id"]) - events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")]) - self.get_success(self._set_last_txn(service.id, 9643)) - - # dump in rows with higher IDs to make sure the queries aren't wrong. - self.get_success(self._set_last_txn(self.as_list[1]["id"], 119643)) - self.get_success(self._set_last_txn(self.as_list[2]["id"], 9)) - self.get_success(self._set_last_txn(self.as_list[3]["id"], 9643)) - self.get_success(self._insert_txn(self.as_list[1]["id"], 119644, events)) - self.get_success(self._insert_txn(self.as_list[1]["id"], 119645, events)) - self.get_success(self._insert_txn(self.as_list[1]["id"], 119646, events)) - self.get_success(self._insert_txn(self.as_list[2]["id"], 10, events)) - self.get_success(self._insert_txn(self.as_list[3]["id"], 9643, events)) - - txn = self.get_success( - self.store.create_appservice_txn(service, events, [], [], {}, {}) - ) - self.assertEqual(txn.id, 9644) - self.assertEqual(txn.events, events) - self.assertEqual(txn.service, service) - def test_complete_appservice_txn_first_txn( self, ) -> None: @@ -359,13 +301,13 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): ) self.assertEqual(0, len(res)) - def test_complete_appservice_txn_existing_in_state_table( + def test_complete_appservice_txn_updates_last_txn_state( self, ) -> None: service = Mock(id=self.as_list[0]["id"]) events = [Mock(event_id="e1"), Mock(event_id="e2")] txn_id = 5 - self.get_success(self._set_last_txn(service.id, 4)) + self._set_state(self.as_list[0]["id"], ApplicationServiceState.UP) self.get_success(self._insert_txn(service.id, txn_id, events)) self.get_success( self.store.complete_appservice_txn(txn_id=txn_id, service=service) @@ -416,6 +358,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): self.get_success(self._insert_txn(service.id, 12, other_events)) txn = self.get_success(self.store.get_oldest_unsent_txn(service)) + assert txn is not None self.assertEqual(service, txn.service) self.assertEqual(10, txn.id) self.assertEqual(events, txn.events) @@ -476,12 +419,12 @@ class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase): value = self.get_success( self.store.get_type_stream_id_for_appservice(self.service, "read_receipt") ) - self.assertEqual(value, 0) + self.assertEqual(value, 1) value = self.get_success( self.store.get_type_stream_id_for_appservice(self.service, "presence") ) - self.assertEqual(value, 0) + self.assertEqual(value, 1) def test_get_type_stream_id_for_appservice_invalid_type(self) -> None: self.get_failure( diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py index ce89c96912..b998ad42d9 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py @@ -68,6 +68,22 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): self.wait_for_background_updates() + def add_extremity(self, room_id: str, event_id: str) -> None: + """ + Add the given event as an extremity to the room. + """ + self.get_success( + self.hs.get_datastores().main.db_pool.simple_insert( + table="event_forward_extremities", + values={"room_id": room_id, "event_id": event_id}, + desc="test_add_extremity", + ) + ) + + self.hs.get_datastores().main.get_latest_event_ids_in_room.invalidate( + (room_id,) + ) + def test_soft_failed_extremities_handled_correctly(self): """Test that extremities are correctly calculated in the presence of soft failed events. @@ -250,7 +266,9 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase): self.user = UserID.from_string(self.register_user("user1", "password")) self.token1 = self.login("user1", "password") self.requester = create_requester(self.user) - info, _ = self.get_success(self.room_creator.create_room(self.requester, {})) + info, _ = self.get_success( + self.room_creator.create_room(self.requester, {"visibility": "public"}) + ) self.room_id = info["room_id"] self.event_creator = homeserver.get_event_creation_handler() homeserver.config.consent.user_consent_version = self.CONSENT_VERSION diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index 21ffc5a909..d1227dd4ac 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -96,7 +96,9 @@ class DeviceStoreTestCase(HomeserverTestCase): # Add two device updates with sequential `stream_id`s self.get_success( - self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"]) + self.store.add_device_change_to_streams( + "user_id", device_ids, ["somehost"], ["!some:room"] + ) ) # Get all device updates ever meant for this remote @@ -122,7 +124,9 @@ class DeviceStoreTestCase(HomeserverTestCase): "device_id5", ] self.get_success( - self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"]) + self.store.add_device_change_to_streams( + "user_id", device_ids, ["somehost"], ["!some:room"] + ) ) # Get device updates meant for this remote @@ -144,7 +148,9 @@ class DeviceStoreTestCase(HomeserverTestCase): # Add some more device updates to ensure it still resumes properly device_ids = ["device_id6", "device_id7"] self.get_success( - self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"]) + self.store.add_device_change_to_streams( + "user_id", device_ids, ["somehost"], ["!some:room"] + ) ) # Get the next batch of device updates @@ -220,7 +226,7 @@ class DeviceStoreTestCase(HomeserverTestCase): self.get_success( self.store.add_device_change_to_streams( - "@user_id:test", device_ids, ["somehost"] + "@user_id:test", device_ids, ["somehost"], ["!some:room"] ) ) diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py index 395396340b..2d8d1f860f 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py @@ -157,10 +157,10 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self.assertEqual(id_gen.get_positions(), {"master": 7}) self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) - ctx1 = self.get_success(id_gen.get_next()) - ctx2 = self.get_success(id_gen.get_next()) - ctx3 = self.get_success(id_gen.get_next()) - ctx4 = self.get_success(id_gen.get_next()) + ctx1 = id_gen.get_next() + ctx2 = id_gen.get_next() + ctx3 = id_gen.get_next() + ctx4 = id_gen.get_next() s1 = self.get_success(ctx1.__aenter__()) s2 = self.get_success(ctx2.__aenter__()) @@ -362,8 +362,8 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) # Persist two rows at once - ctx1 = self.get_success(id_gen.get_next()) - ctx2 = self.get_success(id_gen.get_next()) + ctx1 = id_gen.get_next() + ctx2 = id_gen.get_next() s1 = self.get_success(ctx1.__aenter__()) s2 = self.get_success(ctx2.__aenter__()) diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index 79648d45db..60c8d37594 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -11,11 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any, Dict, List from unittest.mock import Mock from twisted.internet import defer +from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import UserTypes +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest from tests.test_utils import make_awaitable @@ -24,7 +28,7 @@ from tests.unittest import default_config, override_config FORTY_DAYS = 40 * 24 * 60 * 60 -def gen_3pids(count): +def gen_3pids(count: int) -> List[Dict[str, Any]]: """Generate `count` threepids as a list.""" return [ {"medium": "email", "address": "user%i@matrix.org" % i} for i in range(count) @@ -32,7 +36,7 @@ def gen_3pids(count): class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): - def default_config(self): + def default_config(self) -> Dict[str, Any]: config = default_config("test") config.update({"limit_usage_by_mau": True, "max_mau_value": 50}) @@ -44,10 +48,10 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): return config - def prepare(self, reactor, clock, homeserver): - self.store = homeserver.get_datastores().main + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main # Advance the clock a bit - reactor.advance(FORTY_DAYS) + self.reactor.advance(FORTY_DAYS) @override_config({"max_mau_value": 3, "mau_limit_reserved_threepids": gen_3pids(3)}) def test_initialise_reserved_users(self): @@ -245,7 +249,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): ) self.get_success(d) - self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) + self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment] d = self.store.populate_monthly_active_users(user_id) self.get_success(d) @@ -253,9 +257,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.store.upsert_monthly_active_user.assert_not_called() def test_populate_monthly_users_should_update(self): - self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) + self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment] - self.store.is_trial_user = Mock(return_value=defer.succeed(False)) + self.store.is_trial_user = Mock(return_value=defer.succeed(False)) # type: ignore[assignment] self.store.user_last_seen_monthly_active = Mock( return_value=defer.succeed(None) @@ -266,9 +270,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.store.upsert_monthly_active_user.assert_called_once() def test_populate_monthly_users_should_not_update(self): - self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) + self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment] - self.store.is_trial_user = Mock(return_value=defer.succeed(False)) + self.store.is_trial_user = Mock(return_value=defer.succeed(False)) # type: ignore[assignment] self.store.user_last_seen_monthly_active = Mock( return_value=defer.succeed(self.hs.get_clock().time_msec()) ) @@ -324,16 +328,15 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): count = self.get_success(self.store.get_monthly_active_count()) self.assertEqual(count, 0) - d = self.store.register_user( - user_id=support_user_id, password_hash=None, user_type=UserTypes.SUPPORT + self.get_success( + self.store.register_user( + user_id=support_user_id, password_hash=None, user_type=UserTypes.SUPPORT + ) ) - self.get_success(d) - d = self.store.upsert_monthly_active_user(support_user_id) - self.get_success(d) + self.get_success(self.store.upsert_monthly_active_user(support_user_id)) - d = self.store.get_monthly_active_count() - count = self.get_success(d) + count = self.get_success(self.store.get_monthly_active_count()) self.assertEqual(count, 0) # Note that the max_mau_value setting should not matter. @@ -352,7 +355,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): @override_config({"limit_usage_by_mau": False, "mau_stats_only": False}) def test_no_users_when_not_tracking(self): - self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) + self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment] self.get_success(self.store.populate_monthly_active_users("@user:sever")) @@ -388,16 +391,16 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.store.register_user(user_id=native_user1, password_hash=None) ) - count = self.get_success(self.store.get_monthly_active_count_by_service()) - self.assertEqual(count, {}) + count1 = self.get_success(self.store.get_monthly_active_count_by_service()) + self.assertEqual(count1, {}) self.get_success(self.store.upsert_monthly_active_user(native_user1)) self.get_success(self.store.upsert_monthly_active_user(appservice1_user1)) self.get_success(self.store.upsert_monthly_active_user(appservice1_user2)) self.get_success(self.store.upsert_monthly_active_user(appservice2_user1)) - count = self.get_success(self.store.get_monthly_active_count()) - self.assertEqual(count, 1) + count2 = self.get_success(self.store.get_monthly_active_count()) + self.assertEqual(count2, 1) d = self.store.get_monthly_active_count_by_service() result = self.get_success(d) diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index 03e9cc7d4a..d8d17ef379 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -119,11 +119,9 @@ class RedactionTestCase(unittest.HomeserverTestCase): return event def test_redact(self): - self.get_success( - self.inject_room_member(self.room1, self.u_alice, Membership.JOIN) - ) + self.inject_room_member(self.room1, self.u_alice, Membership.JOIN) - msg_event = self.get_success(self.inject_message(self.room1, self.u_alice, "t")) + msg_event = self.inject_message(self.room1, self.u_alice, "t") # Check event has not been redacted: event = self.get_success(self.store.get_event(msg_event.event_id)) @@ -141,9 +139,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): # Redact event reason = "Because I said so" - self.get_success( - self.inject_redaction(self.room1, msg_event.event_id, self.u_alice, reason) - ) + self.inject_redaction(self.room1, msg_event.event_id, self.u_alice, reason) event = self.get_success(self.store.get_event(msg_event.event_id)) @@ -170,14 +166,10 @@ class RedactionTestCase(unittest.HomeserverTestCase): ) def test_redact_join(self): - self.get_success( - self.inject_room_member(self.room1, self.u_alice, Membership.JOIN) - ) + self.inject_room_member(self.room1, self.u_alice, Membership.JOIN) - msg_event = self.get_success( - self.inject_room_member( - self.room1, self.u_bob, Membership.JOIN, extra_content={"blue": "red"} - ) + msg_event = self.inject_room_member( + self.room1, self.u_bob, Membership.JOIN, extra_content={"blue": "red"} ) event = self.get_success(self.store.get_event(msg_event.event_id)) @@ -195,9 +187,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): # Redact event reason = "Because I said so" - self.get_success( - self.inject_redaction(self.room1, msg_event.event_id, self.u_alice, reason) - ) + self.inject_redaction(self.room1, msg_event.event_id, self.u_alice, reason) # Check redaction @@ -311,11 +301,9 @@ class RedactionTestCase(unittest.HomeserverTestCase): def test_redact_censor(self): """Test that a redacted event gets censored in the DB after a month""" - self.get_success( - self.inject_room_member(self.room1, self.u_alice, Membership.JOIN) - ) + self.inject_room_member(self.room1, self.u_alice, Membership.JOIN) - msg_event = self.get_success(self.inject_message(self.room1, self.u_alice, "t")) + msg_event = self.inject_message(self.room1, self.u_alice, "t") # Check event has not been redacted: event = self.get_success(self.store.get_event(msg_event.event_id)) @@ -333,9 +321,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): # Redact event reason = "Because I said so" - self.get_success( - self.inject_redaction(self.room1, msg_event.event_id, self.u_alice, reason) - ) + self.inject_redaction(self.room1, msg_event.event_id, self.u_alice, reason) event = self.get_success(self.store.get_event(msg_event.event_id)) @@ -381,25 +367,19 @@ class RedactionTestCase(unittest.HomeserverTestCase): def test_redact_redaction(self): """Tests that we can redact a redaction and can fetch it again.""" - self.get_success( - self.inject_room_member(self.room1, self.u_alice, Membership.JOIN) - ) + self.inject_room_member(self.room1, self.u_alice, Membership.JOIN) - msg_event = self.get_success(self.inject_message(self.room1, self.u_alice, "t")) + msg_event = self.inject_message(self.room1, self.u_alice, "t") - first_redact_event = self.get_success( - self.inject_redaction( - self.room1, msg_event.event_id, self.u_alice, "Redacting message" - ) + first_redact_event = self.inject_redaction( + self.room1, msg_event.event_id, self.u_alice, "Redacting message" ) - self.get_success( - self.inject_redaction( - self.room1, - first_redact_event.event_id, - self.u_alice, - "Redacting redaction", - ) + self.inject_redaction( + self.room1, + first_redact_event.event_id, + self.u_alice, + "Redacting redaction", ) # Now lets jump to the future where we have censored the redaction event @@ -414,9 +394,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): def test_store_redacted_redaction(self): """Tests that we can store a redacted redaction.""" - self.get_success( - self.inject_room_member(self.room1, self.u_alice, Membership.JOIN) - ) + self.inject_room_member(self.room1, self.u_alice, Membership.JOIN) builder = self.event_builder_factory.for_room_version( RoomVersions.V1, diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index b8f09a8ee0..a2a9c05f24 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -12,11 +12,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import Membership from synapse.rest.admin import register_servlets_for_client_rest_resource from synapse.rest.client import login, room +from synapse.server import HomeServer from synapse.types import UserID, create_requester +from synapse.util import Clock from tests import unittest from tests.server import TestHomeServer @@ -31,7 +34,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def prepare(self, reactor, clock, hs: TestHomeServer): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: TestHomeServer) -> None: # We can't test the RoomMemberStore on its own without the other event # storage logic @@ -44,7 +47,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): # User elsewhere on another host self.u_charlie = UserID.from_string("@charlie:elsewhere") - def test_one_member(self): + def test_one_member(self) -> None: # Alice creates the room, and is automatically joined self.room = self.helper.create_room_as(self.u_alice, tok=self.t_alice) @@ -57,7 +60,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): self.assertEqual([self.room], [m.room_id for m in rooms_for_user]) - def test_count_known_servers(self): + def test_count_known_servers(self) -> None: """ _count_known_servers will calculate how many servers are in a room. """ @@ -68,7 +71,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): servers = self.get_success(self.store._count_known_servers()) self.assertEqual(servers, 2) - def test_count_known_servers_stat_counter_disabled(self): + def test_count_known_servers_stat_counter_disabled(self) -> None: """ If enabled, the metrics for how many servers are known will be counted. """ @@ -85,7 +88,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): @unittest.override_config( {"enable_metrics": True, "metrics_flags": {"known_servers": True}} ) - def test_count_known_servers_stat_counter_enabled(self): + def test_count_known_servers_stat_counter_enabled(self) -> None: """ If enabled, the metrics for how many servers are known will be counted. """ @@ -107,7 +110,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): # It now knows about Charlie's server. self.assertEqual(self.store._known_servers_count, 2) - def test_get_joined_users_from_context(self): + def test_get_joined_users_from_context(self) -> None: room = self.helper.create_room_as(self.u_alice, tok=self.t_alice) bob_event = self.get_success( event_injection.inject_member_event( @@ -161,7 +164,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): ) self.assertEqual(users.keys(), {self.u_alice, self.u_bob}) - def test__null_byte_in_display_name_properly_handled(self): + def test__null_byte_in_display_name_properly_handled(self) -> None: room = self.helper.create_room_as(self.u_alice, tok=self.t_alice) res = self.get_success( @@ -211,11 +214,11 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase): - def prepare(self, reactor, clock, homeserver): - self.store = homeserver.get_datastores().main - self.room_creator = homeserver.get_room_creation_handler() + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main + self.room_creator = hs.get_room_creation_handler() - def test_can_rerun_update(self): + def test_can_rerun_update(self) -> None: # First make sure we have completed all updates. self.wait_for_background_updates() diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py index eaa0d7d749..52e41cdab4 100644 --- a/tests/storage/test_stream.py +++ b/tests/storage/test_stream.py @@ -110,9 +110,7 @@ class PaginationTestCase(HomeserverTestCase): def _filter_messages(self, filter: JsonDict) -> List[EventBase]: """Make a request to /messages with a filter, returns the chunk of events.""" - from_token = self.get_success( - self.hs.get_event_sources().get_current_token_for_pagination() - ) + from_token = self.hs.get_event_sources().get_current_token_for_pagination() events, next_key = self.get_success( self.hs.get_datastores().main.paginate_room_events( diff --git a/tests/test_phone_home.py b/tests/test_phone_home.py index 09707a74d7..b01cae6e5d 100644 --- a/tests/test_phone_home.py +++ b/tests/test_phone_home.py @@ -16,23 +16,24 @@ import resource from unittest import mock from synapse.app.phone_stats_home import phone_stats_home +from synapse.types import JsonDict from tests.unittest import HomeserverTestCase class PhoneHomeStatsTestCase(HomeserverTestCase): - def test_performance_frozen_clock(self): + def test_performance_frozen_clock(self) -> None: """ If time doesn't move, don't error out. """ past_stats = [ (self.hs.get_clock().time(), resource.getrusage(resource.RUSAGE_SELF)) ] - stats = {} + stats: JsonDict = {} self.get_success(phone_stats_home(self.hs, stats, past_stats)) self.assertEqual(stats["cpu_average"], 0) - def test_performance_100(self): + def test_performance_100(self) -> None: """ 1 second of usage over 1 second is 100% CPU usage. """ @@ -43,7 +44,8 @@ class PhoneHomeStatsTestCase(HomeserverTestCase): old_resource.ru_maxrss = real_res.ru_maxrss past_stats = [(self.hs.get_clock().time(), old_resource)] - stats = {} + stats: JsonDict = {} self.reactor.advance(1) - self.get_success(phone_stats_home(self.hs, stats, past_stats)) + # `old_resource` has type `Mock` instead of `struct_rusage` + self.get_success(phone_stats_home(self.hs, stats, past_stats)) # type: ignore[arg-type] self.assertApproximates(stats["cpu_average"], 100, tolerance=2.5) diff --git a/tests/test_visibility.py b/tests/test_visibility.py index 532e3fe9cd..d0230f9ebb 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -17,6 +17,7 @@ from unittest.mock import patch from synapse.api.room_versions import RoomVersions from synapse.events import EventBase, make_event_from_dict +from synapse.events.snapshot import EventContext from synapse.types import JsonDict, create_requester from synapse.visibility import filter_events_for_client, filter_events_for_server @@ -47,17 +48,15 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): # # before we do that, we persist some other events to act as state. - self.get_success(self._inject_visibility("@admin:hs", "joined")) + self._inject_visibility("@admin:hs", "joined") for i in range(0, 10): - self.get_success(self._inject_room_member("@resident%i:hs" % i)) + self._inject_room_member("@resident%i:hs" % i) events_to_filter = [] for i in range(0, 10): user = "@user%i:%s" % (i, "test_server" if i == 5 else "other_server") - evt = self.get_success( - self._inject_room_member(user, extra_content={"a": "b"}) - ) + evt = self._inject_room_member(user, extra_content={"a": "b"}) events_to_filter.append(evt) filtered = self.get_success( @@ -73,24 +72,57 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id) self.assertEqual(filtered[i].content["a"], "b") + def test_filter_outlier(self) -> None: + # outlier events must be returned, for the good of the collective federation + self._inject_room_member("@resident:remote_hs") + self._inject_visibility("@resident:remote_hs", "joined") + + outlier = self._inject_outlier() + self.assertEqual( + self.get_success( + filter_events_for_server(self.storage, "remote_hs", [outlier]) + ), + [outlier], + ) + + # it should also work when there are other events in the list + evt = self._inject_message("@unerased:local_hs") + + filtered = self.get_success( + filter_events_for_server(self.storage, "remote_hs", [outlier, evt]) + ) + self.assertEqual(len(filtered), 2, f"expected 2 results, got: {filtered}") + self.assertEqual(filtered[0], outlier) + self.assertEqual(filtered[1].event_id, evt.event_id) + self.assertEqual(filtered[1].content, evt.content) + + # ... but other servers should only be able to see the outlier (the other should + # be redacted) + filtered = self.get_success( + filter_events_for_server(self.storage, "other_server", [outlier, evt]) + ) + self.assertEqual(filtered[0], outlier) + self.assertEqual(filtered[1].event_id, evt.event_id) + self.assertNotIn("body", filtered[1].content) + def test_erased_user(self) -> None: # 4 message events, from erased and unerased users, with a membership # change in the middle of them. events_to_filter = [] - evt = self.get_success(self._inject_message("@unerased:local_hs")) + evt = self._inject_message("@unerased:local_hs") events_to_filter.append(evt) - evt = self.get_success(self._inject_message("@erased:local_hs")) + evt = self._inject_message("@erased:local_hs") events_to_filter.append(evt) - evt = self.get_success(self._inject_room_member("@joiner:remote_hs")) + evt = self._inject_room_member("@joiner:remote_hs") events_to_filter.append(evt) - evt = self.get_success(self._inject_message("@unerased:local_hs")) + evt = self._inject_message("@unerased:local_hs") events_to_filter.append(evt) - evt = self.get_success(self._inject_message("@erased:local_hs")) + evt = self._inject_message("@erased:local_hs") events_to_filter.append(evt) # the erasey user gets erased @@ -187,6 +219,25 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): self.get_success(self.storage.persistence.persist_event(event, context)) return event + def _inject_outlier(self) -> EventBase: + builder = self.event_builder_factory.for_room_version( + RoomVersions.V1, + { + "type": "m.room.member", + "sender": "@test:user", + "state_key": "@test:user", + "room_id": TEST_ROOM_ID, + "content": {"membership": "join"}, + }, + ) + + event = self.get_success(builder.build(prev_event_ids=[], auth_event_ids=[])) + event.internal_metadata.outlier = True + self.get_success( + self.storage.persistence.persist_event(event, EventContext.for_outlier()) + ) + return event + class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase): def test_out_of_band_invite_rejection(self): diff --git a/tests/unittest.py b/tests/unittest.py index 326895f4c9..9afa68c164 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -16,17 +16,17 @@ import gc import hashlib import hmac -import inspect import json import logging import secrets import time from typing import ( Any, - AnyStr, + Awaitable, Callable, ClassVar, Dict, + Generic, Iterable, List, Optional, @@ -40,6 +40,7 @@ from unittest.mock import Mock, patch import canonicaljson import signedjson.key import unpaddedbase64 +from typing_extensions import Protocol from twisted.internet.defer import Deferred, ensureDeferred from twisted.python.failure import Failure @@ -50,7 +51,7 @@ from twisted.web.resource import Resource from twisted.web.server import Request from synapse import events -from synapse.api.constants import EventTypes, Membership +from synapse.api.constants import EventTypes from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.config.homeserver import HomeServerConfig from synapse.config.server import DEFAULT_ROOM_VERSION @@ -71,7 +72,13 @@ from synapse.types import JsonDict, UserID, create_requester from synapse.util import Clock from synapse.util.httpresourcetree import create_resource_tree -from tests.server import FakeChannel, get_clock, make_request, setup_test_homeserver +from tests.server import ( + CustomHeaderType, + FakeChannel, + get_clock, + make_request, + setup_test_homeserver, +) from tests.test_utils import event_injection, setup_awaitable_errors from tests.test_utils.logging_setup import setup_logging from tests.utils import default_config, setupdb @@ -79,6 +86,17 @@ from tests.utils import default_config, setupdb setupdb() setup_logging() +TV = TypeVar("TV") +_ExcType = TypeVar("_ExcType", bound=BaseException, covariant=True) + + +class _TypedFailure(Generic[_ExcType], Protocol): + """Extension to twisted.Failure, where the 'value' has a certain type.""" + + @property + def value(self) -> _ExcType: + ... + def around(target): """A CLOS-style 'around' modifier, which wraps the original method of the @@ -277,6 +295,7 @@ class HomeserverTestCase(TestCase): if hasattr(self, "user_id"): if self.hijack_auth: + assert self.helper.auth_user_id is not None # We need a valid token ID to satisfy foreign key constraints. token_id = self.get_success( @@ -289,6 +308,7 @@ class HomeserverTestCase(TestCase): ) async def get_user_by_access_token(token=None, allow_guest=False): + assert self.helper.auth_user_id is not None return { "user": UserID.from_string(self.helper.auth_user_id), "token_id": token_id, @@ -296,6 +316,7 @@ class HomeserverTestCase(TestCase): } async def get_user_by_req(request, allow_guest=False, rights="access"): + assert self.helper.auth_user_id is not None return create_requester( UserID.from_string(self.helper.auth_user_id), token_id, @@ -312,7 +333,7 @@ class HomeserverTestCase(TestCase): ) if self.needs_threadpool: - self.reactor.threadpool = ThreadPool() + self.reactor.threadpool = ThreadPool() # type: ignore[assignment] self.addCleanup(self.reactor.threadpool.stop) self.reactor.threadpool.start() @@ -427,7 +448,7 @@ class HomeserverTestCase(TestCase): federation_auth_origin: Optional[bytes] = None, content_is_form: bool = False, await_result: bool = True, - custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, + custom_headers: Optional[Iterable[CustomHeaderType]] = None, client_ip: str = "127.0.0.1", ) -> FakeChannel: """ @@ -512,40 +533,36 @@ class HomeserverTestCase(TestCase): return hs - def pump(self, by=0.0): + def pump(self, by: float = 0.0) -> None: """ Pump the reactor enough that Deferreds will fire. """ self.reactor.pump([by] * 100) - def get_success(self, d, by=0.0): - if inspect.isawaitable(d): - d = ensureDeferred(d) - if not isinstance(d, Deferred): - return d + def get_success( + self, + d: Awaitable[TV], + by: float = 0.0, + ) -> TV: + deferred: Deferred[TV] = ensureDeferred(d) # type: ignore[arg-type] self.pump(by=by) - return self.successResultOf(d) + return self.successResultOf(deferred) - def get_failure(self, d, exc): + def get_failure( + self, d: Awaitable[Any], exc: Type[_ExcType] + ) -> _TypedFailure[_ExcType]: """ Run a Deferred and get a Failure from it. The failure must be of the type `exc`. """ - if inspect.isawaitable(d): - d = ensureDeferred(d) - if not isinstance(d, Deferred): - return d + deferred: Deferred[Any] = ensureDeferred(d) # type: ignore[arg-type] self.pump() - return self.failureResultOf(d, exc) + return self.failureResultOf(deferred, exc) - def get_success_or_raise(self, d, by=0.0): + def get_success_or_raise(self, d: Awaitable[TV], by: float = 0.0) -> TV: """Drive deferred to completion and return result or raise exception on failure. """ - - if inspect.isawaitable(d): - deferred = ensureDeferred(d) - if not isinstance(deferred, Deferred): - return d + deferred: Deferred[TV] = ensureDeferred(d) # type: ignore[arg-type] results: list = [] deferred.addBoth(results.append) @@ -653,11 +670,11 @@ class HomeserverTestCase(TestCase): def login( self, - username, - password, - device_id=None, - custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, - ): + username: str, + password: str, + device_id: Optional[str] = None, + custom_headers: Optional[Iterable[CustomHeaderType]] = None, + ) -> str: """ Log in a user, and get an access token. Requires the Login API be registered. @@ -679,18 +696,22 @@ class HomeserverTestCase(TestCase): return access_token def create_and_send_event( - self, room_id, user, soft_failed=False, prev_event_ids=None - ): + self, + room_id: str, + user: UserID, + soft_failed: bool = False, + prev_event_ids: Optional[List[str]] = None, + ) -> str: """ Create and send an event. Args: - soft_failed (bool): Whether to create a soft failed event or not - prev_event_ids (list[str]|None): Explicitly set the prev events, + soft_failed: Whether to create a soft failed event or not + prev_event_ids: Explicitly set the prev events, or if None just use the default Returns: - str: The new event's ID. + The new event's ID. """ event_creator = self.hs.get_event_creation_handler() requester = create_requester(user) @@ -717,34 +738,7 @@ class HomeserverTestCase(TestCase): return event.event_id - def add_extremity(self, room_id, event_id): - """ - Add the given event as an extremity to the room. - """ - self.get_success( - self.hs.get_datastores().main.db_pool.simple_insert( - table="event_forward_extremities", - values={"room_id": room_id, "event_id": event_id}, - desc="test_add_extremity", - ) - ) - - self.hs.get_datastores().main.get_latest_event_ids_in_room.invalidate( - (room_id,) - ) - - def attempt_wrong_password_login(self, username, password): - """Attempts to login as the user with the given password, asserting - that the attempt *fails*. - """ - body = {"type": "m.login.password", "user": username, "password": password} - - channel = self.make_request( - "POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8") - ) - self.assertEqual(channel.code, 403, channel.result) - - def inject_room_member(self, room: str, user: str, membership: Membership) -> None: + def inject_room_member(self, room: str, user: str, membership: str) -> None: """ Inject a membership event into a room. @@ -804,7 +798,7 @@ class FederatingHomeserverTestCase(HomeserverTestCase): path: str, content: Optional[JsonDict] = None, await_result: bool = True, - custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, + custom_headers: Optional[Iterable[CustomHeaderType]] = None, client_ip: str = "127.0.0.1", ) -> FakeChannel: """Make an inbound signed federation request to this server @@ -837,7 +831,7 @@ class FederatingHomeserverTestCase(HomeserverTestCase): self.site, method=method, path=path, - content=content, + content=content or "", shorthand=False, await_result=await_result, custom_headers=custom_headers, @@ -916,9 +910,6 @@ def override_config(extra_config): return decorator -TV = TypeVar("TV") - - def skip_unless(condition: bool, reason: str) -> Callable[[TV], TV]: """A test decorator which will skip the decorated test unless a condition is set diff --git a/tests/util/test_linearizer.py b/tests/util/test_linearizer.py index c4a3917b23..fa132391a1 100644 --- a/tests/util/test_linearizer.py +++ b/tests/util/test_linearizer.py @@ -13,160 +13,202 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Callable, Hashable, Tuple + from twisted.internet import defer, reactor -from twisted.internet.defer import CancelledError +from twisted.internet.base import ReactorBase +from twisted.internet.defer import CancelledError, Deferred from synapse.logging.context import LoggingContext, current_context -from synapse.util import Clock from synapse.util.async_helpers import Linearizer from tests import unittest class LinearizerTestCase(unittest.TestCase): - @defer.inlineCallbacks - def test_linearizer(self): + def _start_task( + self, linearizer: Linearizer, key: Hashable + ) -> Tuple["Deferred[None]", "Deferred[None]", Callable[[], None]]: + """Starts a task which acquires the linearizer lock, blocks, then completes. + + Args: + linearizer: The `Linearizer`. + key: The `Linearizer` key. + + Returns: + A tuple containing: + * A cancellable `Deferred` for the entire task. + * A `Deferred` that resolves once the task acquires the lock. + * A function that unblocks the task. Must be called by the caller + to allow the task to release the lock and complete. + """ + acquired_d: "Deferred[None]" = Deferred() + unblock_d: "Deferred[None]" = Deferred() + + async def task() -> None: + with await linearizer.queue(key): + acquired_d.callback(None) + await unblock_d + + d = defer.ensureDeferred(task()) + + def unblock() -> None: + unblock_d.callback(None) + # The next task, if it exists, will acquire the lock and require a kick of + # the reactor to advance. + self._pump() + + return d, acquired_d, unblock + + def _pump(self) -> None: + """Pump the reactor to advance `Linearizer`s.""" + assert isinstance(reactor, ReactorBase) + while reactor.getDelayedCalls(): + reactor.runUntilCurrent() + + def test_linearizer(self) -> None: + """Tests that a task is queued up behind an earlier task.""" linearizer = Linearizer() key = object() - d1 = linearizer.queue(key) - cm1 = yield d1 + _, acquired_d1, unblock1 = self._start_task(linearizer, key) + self.assertTrue(acquired_d1.called) + + _, acquired_d2, unblock2 = self._start_task(linearizer, key) + self.assertFalse(acquired_d2.called) - d2 = linearizer.queue(key) - self.assertFalse(d2.called) + # Once the first task is done, the second task can continue. + unblock1() + self.assertTrue(acquired_d2.called) - with cm1: - self.assertFalse(d2.called) + unblock2() - with (yield d2): - pass + def test_linearizer_is_queued(self) -> None: + """Tests `Linearizer.is_queued`. - @defer.inlineCallbacks - def test_linearizer_is_queued(self): + Runs through the same scenario as `test_linearizer`. + """ linearizer = Linearizer() key = object() - d1 = linearizer.queue(key) - cm1 = yield d1 + _, acquired_d1, unblock1 = self._start_task(linearizer, key) + self.assertTrue(acquired_d1.called) - # Since d1 gets called immediately, "is_queued" should return false. + # Since the first task acquires the lock immediately, "is_queued" should return + # false. self.assertFalse(linearizer.is_queued(key)) - d2 = linearizer.queue(key) - self.assertFalse(d2.called) + _, acquired_d2, unblock2 = self._start_task(linearizer, key) + self.assertFalse(acquired_d2.called) - # Now d2 is queued up behind successful completion of cm1 + # Now the second task is queued up behind the first. self.assertTrue(linearizer.is_queued(key)) - with cm1: - self.assertFalse(d2.called) - - # cm1 still not done, so d2 still queued. - self.assertTrue(linearizer.is_queued(key)) + unblock1() - # And now d2 is called and nothing is in the queue again + # And now the second task acquires the lock and nothing is in the queue again. + self.assertTrue(acquired_d2.called) self.assertFalse(linearizer.is_queued(key)) - with (yield d2): - self.assertFalse(linearizer.is_queued(key)) - + unblock2() self.assertFalse(linearizer.is_queued(key)) - def test_lots_of_queued_things(self): - # we have one slow thing, and lots of fast things queued up behind it. - # it should *not* explode the stack. + def test_lots_of_queued_things(self) -> None: + """Tests lots of fast things queued up behind a slow thing. + + The stack should *not* explode when the slow thing completes. + """ linearizer = Linearizer() + key = "" - @defer.inlineCallbacks - def func(i, sleep=False): + async def func(i: int) -> None: with LoggingContext("func(%s)" % i) as lc: - with (yield linearizer.queue("")): + with (await linearizer.queue(key)): self.assertEqual(current_context(), lc) - if sleep: - yield Clock(reactor).sleep(0) self.assertEqual(current_context(), lc) - func(0, sleep=True) + _, _, unblock = self._start_task(linearizer, key) for i in range(1, 100): - func(i) + defer.ensureDeferred(func(i)) - return func(1000) + d = defer.ensureDeferred(func(1000)) + unblock() + self.successResultOf(d) - @defer.inlineCallbacks - def test_multiple_entries(self): + def test_multiple_entries(self) -> None: + """Tests a `Linearizer` with a concurrency above 1.""" limiter = Linearizer(max_count=3) key = object() - d1 = limiter.queue(key) - cm1 = yield d1 - - d2 = limiter.queue(key) - cm2 = yield d2 - - d3 = limiter.queue(key) - cm3 = yield d3 - - d4 = limiter.queue(key) - self.assertFalse(d4.called) - - d5 = limiter.queue(key) - self.assertFalse(d5.called) + _, acquired_d1, unblock1 = self._start_task(limiter, key) + self.assertTrue(acquired_d1.called) - with cm1: - self.assertFalse(d4.called) - self.assertFalse(d5.called) + _, acquired_d2, unblock2 = self._start_task(limiter, key) + self.assertTrue(acquired_d2.called) - cm4 = yield d4 - self.assertFalse(d5.called) + _, acquired_d3, unblock3 = self._start_task(limiter, key) + self.assertTrue(acquired_d3.called) - with cm3: - self.assertFalse(d5.called) + # These next two tasks have to wait. + _, acquired_d4, unblock4 = self._start_task(limiter, key) + self.assertFalse(acquired_d4.called) - cm5 = yield d5 + _, acquired_d5, unblock5 = self._start_task(limiter, key) + self.assertFalse(acquired_d5.called) - with cm2: - pass + # Once the first task completes, the fourth task can continue. + unblock1() + self.assertTrue(acquired_d4.called) + self.assertFalse(acquired_d5.called) - with cm4: - pass + # Once the third task completes, the fifth task can continue. + unblock3() + self.assertTrue(acquired_d5.called) - with cm5: - pass + # Make all tasks finish. + unblock2() + unblock4() + unblock5() - d6 = limiter.queue(key) - with (yield d6): - pass + # The next task shouldn't have to wait. + _, acquired_d6, unblock6 = self._start_task(limiter, key) + self.assertTrue(acquired_d6) + unblock6() - @defer.inlineCallbacks - def test_cancellation(self): + def test_cancellation(self) -> None: + """Tests cancellation while waiting for a `Linearizer`.""" linearizer = Linearizer() key = object() - d1 = linearizer.queue(key) - cm1 = yield d1 + d1, acquired_d1, unblock1 = self._start_task(linearizer, key) + self.assertTrue(acquired_d1.called) - d2 = linearizer.queue(key) - self.assertFalse(d2.called) + # Create a second task, waiting for the first task. + d2, acquired_d2, _ = self._start_task(linearizer, key) + self.assertFalse(acquired_d2.called) - d3 = linearizer.queue(key) - self.assertFalse(d3.called) + # Create a third task, waiting for the second task. + d3, acquired_d3, unblock3 = self._start_task(linearizer, key) + self.assertFalse(acquired_d3.called) + # Cancel the waiting second task. d2.cancel() - with cm1: - pass + unblock1() + self.successResultOf(d1) self.assertTrue(d2.called) - try: - yield d2 - self.fail("Expected d2 to raise CancelledError") - except CancelledError: - pass - - with (yield d3): - pass + self.failureResultOf(d2, CancelledError) + + # The third task should continue running. + self.assertTrue( + acquired_d3.called, + "Third task did not get the lock after the second task was cancelled", + ) + unblock3() + self.successResultOf(d3) diff --git a/tests/utils.py b/tests/utils.py index f6b1d60371..d4ba3a9b99 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -35,6 +35,11 @@ LEAVE_DB = os.environ.get("SYNAPSE_LEAVE_DB", False) POSTGRES_USER = os.environ.get("SYNAPSE_POSTGRES_USER", None) POSTGRES_HOST = os.environ.get("SYNAPSE_POSTGRES_HOST", None) POSTGRES_PASSWORD = os.environ.get("SYNAPSE_POSTGRES_PASSWORD", None) +POSTGRES_PORT = ( + int(os.environ["SYNAPSE_POSTGRES_PORT"]) + if "SYNAPSE_POSTGRES_PORT" in os.environ + else None +) POSTGRES_BASE_DB = "_synapse_unit_tests_base_%s" % (os.getpid(),) # When debugging a specific test, it's occasionally useful to write the @@ -55,6 +60,7 @@ def setupdb(): db_conn = db_engine.module.connect( user=POSTGRES_USER, host=POSTGRES_HOST, + port=POSTGRES_PORT, password=POSTGRES_PASSWORD, dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE, ) @@ -73,6 +79,7 @@ def setupdb(): database=POSTGRES_BASE_DB, user=POSTGRES_USER, host=POSTGRES_HOST, + port=POSTGRES_PORT, password=POSTGRES_PASSWORD, ) db_conn = LoggingDatabaseConnection(db_conn, db_engine, "tests") @@ -83,6 +90,7 @@ def setupdb(): db_conn = db_engine.module.connect( user=POSTGRES_USER, host=POSTGRES_HOST, + port=POSTGRES_PORT, password=POSTGRES_PASSWORD, dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE, ) |