diff options
Diffstat (limited to 'tests')
33 files changed, 1413 insertions, 116 deletions
diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py index 3d880c499d..1471cc1a28 100644 --- a/tests/federation/test_complexity.py +++ b/tests/federation/test_complexity.py @@ -77,11 +77,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = Mock( - side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999}) - ) + fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) handler.federation_handler.do_invite_join = Mock( - side_effect=lambda *args, **kwargs: make_awaitable(("", 1)) + return_value=make_awaitable(("", 1)) ) d = handler._remote_join( @@ -110,11 +108,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = Mock( - side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999}) - ) + fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) handler.federation_handler.do_invite_join = Mock( - side_effect=lambda *args, **kwargs: make_awaitable(("", 1)) + return_value=make_awaitable(("", 1)) ) d = handler._remote_join( @@ -150,11 +146,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = Mock( - side_effect=lambda *args, **kwargs: make_awaitable(None) - ) + fed_transport.client.get_json = Mock(return_value=make_awaitable(None)) handler.federation_handler.do_invite_join = Mock( - side_effect=lambda *args, **kwargs: make_awaitable(("", 1)) + return_value=make_awaitable(("", 1)) ) # Artificially raise the complexity @@ -208,11 +202,9 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase): fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = Mock( - side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999}) - ) + fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) handler.federation_handler.do_invite_join = Mock( - side_effect=lambda *args, **kwargs: make_awaitable(("", 1)) + return_value=make_awaitable(("", 1)) ) d = handler._remote_join( @@ -240,11 +232,9 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase): fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = Mock( - side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999}) - ) + fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) handler.federation_handler.do_invite_join = Mock( - side_effect=lambda *args, **kwargs: make_awaitable(("", 1)) + return_value=make_awaitable(("", 1)) ) d = handler._remote_join( diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py new file mode 100644 index 0000000000..1a3ccb263d --- /dev/null +++ b/tests/federation/test_federation_catch_up.py @@ -0,0 +1,422 @@ +from typing import List, Tuple + +from mock import Mock + +from synapse.events import EventBase +from synapse.federation.sender import PerDestinationQueue, TransactionManager +from synapse.federation.units import Edu +from synapse.rest import admin +from synapse.rest.client.v1 import login, room + +from tests.test_utils import event_injection, make_awaitable +from tests.unittest import FederatingHomeserverTestCase, override_config + + +class FederationCatchUpTestCases(FederatingHomeserverTestCase): + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + return self.setup_test_homeserver( + federation_transport_client=Mock(spec=["send_transaction"]), + ) + + def prepare(self, reactor, clock, hs): + # stub out get_current_hosts_in_room + state_handler = hs.get_state_handler() + + # This mock is crucial for destination_rooms to be populated. + state_handler.get_current_hosts_in_room = Mock( + return_value=make_awaitable(["test", "host2"]) + ) + + # whenever send_transaction is called, record the pdu data + self.pdus = [] + self.failed_pdus = [] + self.is_online = True + self.hs.get_federation_transport_client().send_transaction.side_effect = ( + self.record_transaction + ) + + async def record_transaction(self, txn, json_cb): + if self.is_online: + data = json_cb() + self.pdus.extend(data["pdus"]) + return {} + else: + data = json_cb() + self.failed_pdus.extend(data["pdus"]) + raise IOError("Failed to connect because this is a test!") + + def get_destination_room(self, room: str, destination: str = "host2") -> dict: + """ + Gets the destination_rooms entry for a (destination, room_id) pair. + + Args: + room: room ID + destination: what destination, default is "host2" + + Returns: + Dictionary of { event_id: str, stream_ordering: int } + """ + event_id, stream_ordering = self.get_success( + self.hs.get_datastore().db_pool.execute( + "test:get_destination_rooms", + None, + """ + SELECT event_id, stream_ordering + FROM destination_rooms dr + JOIN events USING (stream_ordering) + WHERE dr.destination = ? AND dr.room_id = ? + """, + destination, + room, + ) + )[0] + return {"event_id": event_id, "stream_ordering": stream_ordering} + + @override_config({"send_federation": True}) + def test_catch_up_destination_rooms_tracking(self): + """ + Tests that we populate the `destination_rooms` table as needed. + """ + self.register_user("u1", "you the one") + u1_token = self.login("u1", "you the one") + room = self.helper.create_room_as("u1", tok=u1_token) + + self.get_success( + event_injection.inject_member_event(self.hs, room, "@user:host2", "join") + ) + + event_id_1 = self.helper.send(room, "wombats!", tok=u1_token)["event_id"] + + row_1 = self.get_destination_room(room) + + event_id_2 = self.helper.send(room, "rabbits!", tok=u1_token)["event_id"] + + row_2 = self.get_destination_room(room) + + # check: events correctly registered in order + self.assertEqual(row_1["event_id"], event_id_1) + self.assertEqual(row_2["event_id"], event_id_2) + self.assertEqual(row_1["stream_ordering"], row_2["stream_ordering"] - 1) + + @override_config({"send_federation": True}) + def test_catch_up_last_successful_stream_ordering_tracking(self): + """ + Tests that we populate the `destination_rooms` table as needed. + """ + self.register_user("u1", "you the one") + u1_token = self.login("u1", "you the one") + room = self.helper.create_room_as("u1", tok=u1_token) + + # take the remote offline + self.is_online = False + + self.get_success( + event_injection.inject_member_event(self.hs, room, "@user:host2", "join") + ) + + self.helper.send(room, "wombats!", tok=u1_token) + self.pump() + + lsso_1 = self.get_success( + self.hs.get_datastore().get_destination_last_successful_stream_ordering( + "host2" + ) + ) + + self.assertIsNone( + lsso_1, + "There should be no last successful stream ordering for an always-offline destination", + ) + + # bring the remote online + self.is_online = True + + event_id_2 = self.helper.send(room, "rabbits!", tok=u1_token)["event_id"] + + lsso_2 = self.get_success( + self.hs.get_datastore().get_destination_last_successful_stream_ordering( + "host2" + ) + ) + row_2 = self.get_destination_room(room) + + self.assertEqual( + self.pdus[0]["content"]["body"], + "rabbits!", + "Test fault: didn't receive the right PDU", + ) + self.assertEqual( + row_2["event_id"], + event_id_2, + "Test fault: destination_rooms not updated correctly", + ) + self.assertEqual( + lsso_2, + row_2["stream_ordering"], + "Send succeeded but not marked as last_successful_stream_ordering", + ) + + @override_config({"send_federation": True}) # critical to federate + def test_catch_up_from_blank_state(self): + """ + Runs an overall test of federation catch-up from scratch. + Further tests will focus on more narrow aspects and edge-cases, but I + hope to provide an overall view with this test. + """ + # bring the other server online + self.is_online = True + + # let's make some events for the other server to receive + self.register_user("u1", "you the one") + u1_token = self.login("u1", "you the one") + room_1 = self.helper.create_room_as("u1", tok=u1_token) + room_2 = self.helper.create_room_as("u1", tok=u1_token) + + # also critical to federate + self.get_success( + event_injection.inject_member_event(self.hs, room_1, "@user:host2", "join") + ) + self.get_success( + event_injection.inject_member_event(self.hs, room_2, "@user:host2", "join") + ) + + self.helper.send_state( + room_1, event_type="m.room.topic", body={"topic": "wombat"}, tok=u1_token + ) + + # check: PDU received for topic event + self.assertEqual(len(self.pdus), 1) + self.assertEqual(self.pdus[0]["type"], "m.room.topic") + + # take the remote offline + self.is_online = False + + # send another event + self.helper.send(room_1, "hi user!", tok=u1_token) + + # check: things didn't go well since the remote is down + self.assertEqual(len(self.failed_pdus), 1) + self.assertEqual(self.failed_pdus[0]["content"]["body"], "hi user!") + + # let's delete the federation transmission queue + # (this pretends we are starting up fresh.) + self.assertFalse( + self.hs.get_federation_sender() + ._per_destination_queues["host2"] + .transmission_loop_running + ) + del self.hs.get_federation_sender()._per_destination_queues["host2"] + + # let's also clear any backoffs + self.get_success( + self.hs.get_datastore().set_destination_retry_timings("host2", None, 0, 0) + ) + + # bring the remote online and clear the received pdu list + self.is_online = True + self.pdus = [] + + # now we need to initiate a federation transaction somehow… + # to do that, let's send another event (because it's simple to do) + # (do it to another room otherwise the catch-up logic decides it doesn't + # need to catch up room_1 — something I overlooked when first writing + # this test) + self.helper.send(room_2, "wombats!", tok=u1_token) + + # we should now have received both PDUs + self.assertEqual(len(self.pdus), 2) + self.assertEqual(self.pdus[0]["content"]["body"], "hi user!") + self.assertEqual(self.pdus[1]["content"]["body"], "wombats!") + + def make_fake_destination_queue( + self, destination: str = "host2" + ) -> Tuple[PerDestinationQueue, List[EventBase]]: + """ + Makes a fake per-destination queue. + """ + transaction_manager = TransactionManager(self.hs) + per_dest_queue = PerDestinationQueue(self.hs, transaction_manager, destination) + results_list = [] + + async def fake_send( + destination_tm: str, + pending_pdus: List[EventBase], + _pending_edus: List[Edu], + ) -> bool: + assert destination == destination_tm + results_list.extend(pending_pdus) + return True # success! + + transaction_manager.send_new_transaction = fake_send + + return per_dest_queue, results_list + + @override_config({"send_federation": True}) + def test_catch_up_loop(self): + """ + Tests the behaviour of _catch_up_transmission_loop. + """ + + # ARRANGE: + # - a local user (u1) + # - 3 rooms which u1 is joined to (and remote user @user:host2 is + # joined to) + # - some events (1 to 5) in those rooms + # we have 'already sent' events 1 and 2 to host2 + per_dest_queue, sent_pdus = self.make_fake_destination_queue() + + self.register_user("u1", "you the one") + u1_token = self.login("u1", "you the one") + room_1 = self.helper.create_room_as("u1", tok=u1_token) + room_2 = self.helper.create_room_as("u1", tok=u1_token) + room_3 = self.helper.create_room_as("u1", tok=u1_token) + self.get_success( + event_injection.inject_member_event(self.hs, room_1, "@user:host2", "join") + ) + self.get_success( + event_injection.inject_member_event(self.hs, room_2, "@user:host2", "join") + ) + self.get_success( + event_injection.inject_member_event(self.hs, room_3, "@user:host2", "join") + ) + + # create some events + self.helper.send(room_1, "you hear me!!", tok=u1_token) + event_id_2 = self.helper.send(room_2, "wombats!", tok=u1_token)["event_id"] + self.helper.send(room_3, "Matrix!", tok=u1_token) + event_id_4 = self.helper.send(room_2, "rabbits!", tok=u1_token)["event_id"] + event_id_5 = self.helper.send(room_3, "Synapse!", tok=u1_token)["event_id"] + + # destination_rooms should already be populated, but let us pretend that we already + # sent (successfully) up to and including event id 2 + event_2 = self.get_success(self.hs.get_datastore().get_event(event_id_2)) + + # also fetch event 5 so we know its last_successful_stream_ordering later + event_5 = self.get_success(self.hs.get_datastore().get_event(event_id_5)) + + self.get_success( + self.hs.get_datastore().set_destination_last_successful_stream_ordering( + "host2", event_2.internal_metadata.stream_ordering + ) + ) + + # ACT + self.get_success(per_dest_queue._catch_up_transmission_loop()) + + # ASSERT, noticing in particular: + # - event 3 not sent out, because event 5 replaces it + # - order is least recent first, so event 5 comes after event 4 + # - catch-up is completed + self.assertEqual(len(sent_pdus), 2) + self.assertEqual(sent_pdus[0].event_id, event_id_4) + self.assertEqual(sent_pdus[1].event_id, event_id_5) + self.assertFalse(per_dest_queue._catching_up) + self.assertEqual( + per_dest_queue._last_successful_stream_ordering, + event_5.internal_metadata.stream_ordering, + ) + + @override_config({"send_federation": True}) + def test_catch_up_on_synapse_startup(self): + """ + Tests the behaviour of get_catch_up_outstanding_destinations and + _wake_destinations_needing_catchup. + """ + + # list of sorted server names (note that there are more servers than the batch + # size used in get_catch_up_outstanding_destinations). + server_names = ["server%02d" % number for number in range(42)] + ["zzzerver"] + + # ARRANGE: + # - a local user (u1) + # - a room which u1 is joined to (and remote users @user:serverXX are + # joined to) + + # mark the remotes as online + self.is_online = True + + self.register_user("u1", "you the one") + u1_token = self.login("u1", "you the one") + room_id = self.helper.create_room_as("u1", tok=u1_token) + + for server_name in server_names: + self.get_success( + event_injection.inject_member_event( + self.hs, room_id, "@user:%s" % server_name, "join" + ) + ) + + # create an event + self.helper.send(room_id, "deary me!", tok=u1_token) + + # ASSERT: + # - All servers are up to date so none should have outstanding catch-up + outstanding_when_successful = self.get_success( + self.hs.get_datastore().get_catch_up_outstanding_destinations(None) + ) + self.assertEqual(outstanding_when_successful, []) + + # ACT: + # - Make the remote servers unreachable + self.is_online = False + + # - Mark zzzerver as being backed-off from + now = self.clock.time_msec() + self.get_success( + self.hs.get_datastore().set_destination_retry_timings( + "zzzerver", now, now, 24 * 60 * 60 * 1000 # retry in 1 day + ) + ) + + # - Send an event + self.helper.send(room_id, "can anyone hear me?", tok=u1_token) + + # ASSERT (get_catch_up_outstanding_destinations): + # - all remotes are outstanding + # - they are returned in batches of 25, in order + outstanding_1 = self.get_success( + self.hs.get_datastore().get_catch_up_outstanding_destinations(None) + ) + + self.assertEqual(len(outstanding_1), 25) + self.assertEqual(outstanding_1, server_names[0:25]) + + outstanding_2 = self.get_success( + self.hs.get_datastore().get_catch_up_outstanding_destinations( + outstanding_1[-1] + ) + ) + self.assertNotIn("zzzerver", outstanding_2) + self.assertEqual(len(outstanding_2), 17) + self.assertEqual(outstanding_2, server_names[25:-1]) + + # ACT: call _wake_destinations_needing_catchup + + # patch wake_destination to just count the destinations instead + woken = [] + + def wake_destination_track(destination): + woken.append(destination) + + self.hs.get_federation_sender().wake_destination = wake_destination_track + + # cancel the pre-existing timer for _wake_destinations_needing_catchup + # this is because we are calling it manually rather than waiting for it + # to be called automatically + self.hs.get_federation_sender()._catchup_after_startup_timer.cancel() + + self.get_success( + self.hs.get_federation_sender()._wake_destinations_needing_catchup(), by=5.0 + ) + + # ASSERT (_wake_destinations_needing_catchup): + # - all remotes are woken up, save for zzzerver + self.assertNotIn("zzzerver", woken) + # - all destinations are woken exactly once; they appear once in woken. + self.assertCountEqual(woken, server_names[:-1]) diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index 5f512ff8bf..917762e6b6 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -34,7 +34,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): def make_homeserver(self, reactor, clock): mock_state_handler = Mock(spec=["get_current_hosts_in_room"]) # Ensure a new Awaitable is created for each call. - mock_state_handler.get_current_hosts_in_room.side_effect = lambda room_Id: make_awaitable( + mock_state_handler.get_current_hosts_in_room.return_value = make_awaitable( ["test", "host2"] ) return self.setup_test_homeserver( diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index c7efd3822d..97877c2e42 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -143,7 +143,7 @@ class AuthTestCase(unittest.TestCase): def test_mau_limits_exceeded_large(self): self.auth_blocking._limit_usage_by_mau = True self.hs.get_datastore().get_monthly_active_count = Mock( - side_effect=lambda: make_awaitable(self.large_number_of_users) + return_value=make_awaitable(self.large_number_of_users) ) with self.assertRaises(ResourceLimitError): @@ -154,7 +154,7 @@ class AuthTestCase(unittest.TestCase): ) self.hs.get_datastore().get_monthly_active_count = Mock( - side_effect=lambda: make_awaitable(self.large_number_of_users) + return_value=make_awaitable(self.large_number_of_users) ) with self.assertRaises(ResourceLimitError): yield defer.ensureDeferred( @@ -169,7 +169,7 @@ class AuthTestCase(unittest.TestCase): # If not in monthly active cohort self.hs.get_datastore().get_monthly_active_count = Mock( - side_effect=lambda: make_awaitable(self.auth_blocking._max_mau_value) + return_value=make_awaitable(self.auth_blocking._max_mau_value) ) with self.assertRaises(ResourceLimitError): yield defer.ensureDeferred( @@ -179,7 +179,7 @@ class AuthTestCase(unittest.TestCase): ) self.hs.get_datastore().get_monthly_active_count = Mock( - side_effect=lambda: make_awaitable(self.auth_blocking._max_mau_value) + return_value=make_awaitable(self.auth_blocking._max_mau_value) ) with self.assertRaises(ResourceLimitError): yield defer.ensureDeferred( @@ -189,10 +189,10 @@ class AuthTestCase(unittest.TestCase): ) # If in monthly active cohort self.hs.get_datastore().user_last_seen_monthly_active = Mock( - side_effect=lambda user_id: make_awaitable(self.hs.get_clock().time_msec()) + return_value=make_awaitable(self.hs.get_clock().time_msec()) ) self.hs.get_datastore().get_monthly_active_count = Mock( - side_effect=lambda: make_awaitable(self.auth_blocking._max_mau_value) + return_value=make_awaitable(self.auth_blocking._max_mau_value) ) yield defer.ensureDeferred( self.auth_handler.get_access_token_for_user_id( @@ -200,10 +200,10 @@ class AuthTestCase(unittest.TestCase): ) ) self.hs.get_datastore().user_last_seen_monthly_active = Mock( - side_effect=lambda user_id: make_awaitable(self.hs.get_clock().time_msec()) + return_value=make_awaitable(self.hs.get_clock().time_msec()) ) self.hs.get_datastore().get_monthly_active_count = Mock( - side_effect=lambda: make_awaitable(self.auth_blocking._max_mau_value) + return_value=make_awaitable(self.auth_blocking._max_mau_value) ) yield defer.ensureDeferred( self.auth_handler.validate_short_term_login_token_and_get_user_id( @@ -216,7 +216,7 @@ class AuthTestCase(unittest.TestCase): self.auth_blocking._limit_usage_by_mau = True self.hs.get_datastore().get_monthly_active_count = Mock( - side_effect=lambda: make_awaitable(self.small_number_of_users) + return_value=make_awaitable(self.small_number_of_users) ) # Ensure does not raise exception yield defer.ensureDeferred( @@ -226,7 +226,7 @@ class AuthTestCase(unittest.TestCase): ) self.hs.get_datastore().get_monthly_active_count = Mock( - side_effect=lambda: make_awaitable(self.small_number_of_users) + return_value=make_awaitable(self.small_number_of_users) ) yield defer.ensureDeferred( self.auth_handler.validate_short_term_login_token_and_get_user_id( diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index 6aa322bf3a..969d44c787 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -35,6 +35,17 @@ class DeviceTestCase(unittest.HomeserverTestCase): # These tests assume that it starts 1000 seconds in. self.reactor.advance(1000) + def test_device_is_created_with_invalid_name(self): + self.get_failure( + self.handler.check_device_registered( + user_id="@boris:foo", + device_id="foo", + initial_device_display_name="a" + * (synapse.handlers.device.MAX_DEVICE_DISPLAY_NAME_LEN + 1), + ), + synapse.api.errors.SynapseError, + ) + def test_device_is_created_if_doesnt_exist(self): res = self.get_success( self.handler.check_device_registered( diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 210ddcbb88..366dcfb670 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -30,7 +30,7 @@ from tests import unittest, utils class E2eKeysHandlerTestCase(unittest.TestCase): def __init__(self, *args, **kwargs): - super(E2eKeysHandlerTestCase, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) self.hs = None # type: synapse.server.HomeServer self.handler = None # type: synapse.handlers.e2e_keys.E2eKeysHandler diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py index 3362050ce0..7adde9b9de 100644 --- a/tests/handlers/test_e2e_room_keys.py +++ b/tests/handlers/test_e2e_room_keys.py @@ -47,7 +47,7 @@ room_keys = { class E2eRoomKeysHandlerTestCase(unittest.TestCase): def __init__(self, *args, **kwargs): - super(E2eRoomKeysHandlerTestCase, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) self.hs = None # type: synapse.server.HomeServer self.handler = None # type: synapse.handlers.e2e_keys.E2eRoomKeysHandler diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index eddf5e2498..cb7c0ed51a 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -100,7 +100,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): def test_get_or_create_user_mau_not_blocked(self): self.hs.config.limit_usage_by_mau = True self.store.count_monthly_users = Mock( - side_effect=lambda: make_awaitable(self.hs.config.max_mau_value - 1) + return_value=make_awaitable(self.hs.config.max_mau_value - 1) ) # Ensure does not throw exception self.get_success(self.get_or_create_user(self.requester, "c", "User")) @@ -108,7 +108,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): def test_get_or_create_user_mau_blocked(self): self.hs.config.limit_usage_by_mau = True self.store.get_monthly_active_count = Mock( - side_effect=lambda: make_awaitable(self.lots_of_users) + return_value=make_awaitable(self.lots_of_users) ) self.get_failure( self.get_or_create_user(self.requester, "b", "display_name"), @@ -116,7 +116,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): ) self.store.get_monthly_active_count = Mock( - side_effect=lambda: make_awaitable(self.hs.config.max_mau_value) + return_value=make_awaitable(self.hs.config.max_mau_value) ) self.get_failure( self.get_or_create_user(self.requester, "b", "display_name"), @@ -126,14 +126,14 @@ class RegistrationTestCase(unittest.HomeserverTestCase): def test_register_mau_blocked(self): self.hs.config.limit_usage_by_mau = True self.store.get_monthly_active_count = Mock( - side_effect=lambda: make_awaitable(self.lots_of_users) + return_value=make_awaitable(self.lots_of_users) ) self.get_failure( self.handler.register_user(localpart="local_part"), ResourceLimitError ) self.store.get_monthly_active_count = Mock( - side_effect=lambda: make_awaitable(self.hs.config.max_mau_value) + return_value=make_awaitable(self.hs.config.max_mau_value) ) self.get_failure( self.handler.register_user(localpart="local_part"), ResourceLimitError diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py index a609f148c0..312c0a0d41 100644 --- a/tests/handlers/test_stats.py +++ b/tests/handlers/test_stats.py @@ -54,7 +54,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.store.db_pool.simple_insert( "background_updates", { - "update_name": "populate_stats_process_rooms_2", + "update_name": "populate_stats_process_rooms", "progress_json": "{}", "depends_on": "populate_stats_prepare", }, @@ -66,7 +66,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): { "update_name": "populate_stats_process_users", "progress_json": "{}", - "depends_on": "populate_stats_process_rooms_2", + "depends_on": "populate_stats_process_rooms", }, ) ) @@ -219,10 +219,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.get_success( self.store.db_pool.simple_insert( "background_updates", - { - "update_name": "populate_stats_process_rooms_2", - "progress_json": "{}", - }, + {"update_name": "populate_stats_process_rooms", "progress_json": "{}"}, ) ) self.get_success( @@ -231,7 +228,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): { "update_name": "populate_stats_cleanup", "progress_json": "{}", - "depends_on": "populate_stats_process_rooms_2", + "depends_on": "populate_stats_process_rooms", }, ) ) @@ -728,7 +725,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.store.db_pool.simple_insert( "background_updates", { - "update_name": "populate_stats_process_rooms_2", + "update_name": "populate_stats_process_rooms", "progress_json": "{}", "depends_on": "populate_stats_prepare", }, @@ -740,7 +737,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): { "update_name": "populate_stats_process_users", "progress_json": "{}", - "depends_on": "populate_stats_process_rooms_2", + "depends_on": "populate_stats_process_rooms", }, ) ) diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 7bf15c4ba9..3fec09ea8a 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -73,6 +73,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): "delivered_txn", "get_received_txn_response", "set_received_txn_response", + "get_destination_last_successful_stream_ordering", "get_destination_retry_timings", "get_devices_by_remote", "maybe_store_room_on_invite", @@ -80,6 +81,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): "get_user_directory_stream_pos", "get_current_state_deltas", "get_device_updates_by_remote", + "get_room_max_stream_ordering", ] ) @@ -116,10 +118,14 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): retry_timings_res ) - self.datastore.get_device_updates_by_remote.side_effect = lambda destination, from_stream_id, limit: make_awaitable( + self.datastore.get_device_updates_by_remote.return_value = make_awaitable( (0, []) ) + self.datastore.get_destination_last_successful_stream_ordering.return_value = make_awaitable( + None + ) + def get_received_txn_response(*args): return defer.succeed(None) diff --git a/tests/http/__init__.py b/tests/http/__init__.py index 5d41443293..3e5a856584 100644 --- a/tests/http/__init__.py +++ b/tests/http/__init__.py @@ -145,7 +145,7 @@ class TestServerTLSConnectionFactory: self._cert_file = create_test_cert_file(sanlist) def serverConnectionForTLS(self, tlsProtocol): - ctx = SSL.Context(SSL.TLSv1_METHOD) + ctx = SSL.Context(SSL.SSLv23_METHOD) ctx.use_certificate_file(self._cert_file) ctx.use_privatekey_file(get_test_key_file()) return Connection(ctx, None) diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 561258a356..bc578411d6 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -58,7 +58,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): # Patch up the equality operator for events so that we can check # whether lists of events match using assertEquals self.unpatches = [patch__eq__(_EventInternalMetadata), patch__eq__(FrozenEvent)] - return super(SlavedEventStoreTestCase, self).setUp() + return super().setUp() def prepare(self, *args, **kwargs): super().prepare(*args, **kwargs) diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py index 8b4982ecb1..1d7edee5ba 100644 --- a/tests/replication/test_federation_sender_shard.py +++ b/tests/replication/test_federation_sender_shard.py @@ -45,7 +45,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): new event. """ mock_client = Mock(spec=["put_json"]) - mock_client.put_json.side_effect = lambda *_, **__: make_awaitable({}) + mock_client.put_json.return_value = make_awaitable({}) self.make_worker_hs( "synapse.app.federation_sender", @@ -73,7 +73,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): new events. """ mock_client1 = Mock(spec=["put_json"]) - mock_client1.put_json.side_effect = lambda *_, **__: make_awaitable({}) + mock_client1.put_json.return_value = make_awaitable({}) self.make_worker_hs( "synapse.app.federation_sender", { @@ -85,7 +85,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): ) mock_client2 = Mock(spec=["put_json"]) - mock_client2.put_json.side_effect = lambda *_, **__: make_awaitable({}) + mock_client2.put_json.return_value = make_awaitable({}) self.make_worker_hs( "synapse.app.federation_sender", { @@ -136,7 +136,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): new typing EDUs. """ mock_client1 = Mock(spec=["put_json"]) - mock_client1.put_json.side_effect = lambda *_, **__: make_awaitable({}) + mock_client1.put_json.return_value = make_awaitable({}) self.make_worker_hs( "synapse.app.federation_sender", { @@ -148,7 +148,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): ) mock_client2 = Mock(spec=["put_json"]) - mock_client2.put_json.side_effect = lambda *_, **__: make_awaitable({}) + mock_client2.put_json.return_value = make_awaitable({}) self.make_worker_hs( "synapse.app.federation_sender", { diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py index faa7f381a9..92c9058887 100644 --- a/tests/rest/admin/test_device.py +++ b/tests/rest/admin/test_device.py @@ -221,7 +221,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): self.render(request) self.assertEqual(400, channel.code, msg=channel.json_body) - self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + self.assertEqual(Codes.TOO_LARGE, channel.json_body["errcode"]) # Ensure the display name was not updated. request, channel = self.make_request( diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index 408c568a27..6dfc709dc5 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -1174,6 +1174,8 @@ class RoomTestCase(unittest.HomeserverTestCase): self.assertIn("room_id", channel.json_body) self.assertIn("name", channel.json_body) + self.assertIn("topic", channel.json_body) + self.assertIn("avatar", channel.json_body) self.assertIn("canonical_alias", channel.json_body) self.assertIn("joined_members", channel.json_body) self.assertIn("joined_local_members", channel.json_body) diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 160c630235..f96011fc1c 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -22,8 +22,8 @@ from mock import Mock import synapse.rest.admin from synapse.api.constants import UserTypes -from synapse.api.errors import HttpResponseException, ResourceLimitError -from synapse.rest.client.v1 import login +from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError +from synapse.rest.client.v1 import login, room from synapse.rest.client.v2_alpha import sync from tests import unittest @@ -337,7 +337,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): # Set monthly active users to the limit store.get_monthly_active_count = Mock( - side_effect=lambda: make_awaitable(self.hs.config.max_mau_value) + return_value=make_awaitable(self.hs.config.max_mau_value) ) # Check that the blocking of monthly active users is working as expected # The registration of a new user fails due to the limit @@ -591,7 +591,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Set monthly active users to the limit self.store.get_monthly_active_count = Mock( - side_effect=lambda: make_awaitable(self.hs.config.max_mau_value) + return_value=make_awaitable(self.hs.config.max_mau_value) ) # Check that the blocking of monthly active users is working as expected # The registration of a new user fails due to the limit @@ -631,7 +631,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Set monthly active users to the limit self.store.get_monthly_active_count = Mock( - side_effect=lambda: make_awaitable(self.hs.config.max_mau_value) + return_value=make_awaitable(self.hs.config.max_mau_value) ) # Check that the blocking of monthly active users is working as expected # The registration of a new user fails due to the limit @@ -995,3 +995,95 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Ensure they're still alive self.assertEqual(0, channel.json_body["deactivated"]) + + +class UserMembershipRestTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + sync.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.url = "/_synapse/admin/v1/users/%s/joined_rooms" % urllib.parse.quote( + self.other_user + ) + + def test_no_auth(self): + """ + Try to list rooms of an user without authentication. + """ + request, channel = self.make_request("GET", self.url, b"{}") + self.render(request) + + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_requester_is_no_admin(self): + """ + If the user is not a server admin, an error is returned. + """ + other_user_token = self.login("user", "pass") + + request, channel = self.make_request( + "GET", self.url, access_token=other_user_token, + ) + self.render(request) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_user_does_not_exist(self): + """ + Tests that a lookup for a user that does not exist returns a 404 + """ + url = "/_synapse/admin/v1/users/@unknown_person:test/joined_rooms" + request, channel = self.make_request( + "GET", url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + def test_user_is_not_local(self): + """ + Tests that a lookup for a user that is not a local returns a 400 + """ + url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/joined_rooms" + + request, channel = self.make_request( + "GET", url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual("Can only lookup local users", channel.json_body["error"]) + + def test_get_rooms(self): + """ + Tests that a normal lookup for rooms is successfully + """ + # Create rooms and join + other_user_tok = self.login("user", "pass") + number_rooms = 5 + for n in range(number_rooms): + self.helper.create_room_as(self.other_user, tok=other_user_tok) + + # Get rooms + request, channel = self.make_request( + "GET", self.url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(number_rooms, channel.json_body["total"]) + self.assertEqual(number_rooms, len(channel.json_body["joined_rooms"])) diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index 2668662c9e..5d987a30c7 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -7,8 +7,9 @@ from mock import Mock import jwt import synapse.rest.admin +from synapse.appservice import ApplicationService from synapse.rest.client.v1 import login, logout -from synapse.rest.client.v2_alpha import devices +from synapse.rest.client.v2_alpha import devices, register from synapse.rest.client.v2_alpha.account import WhoamiRestServlet from tests import unittest @@ -748,3 +749,134 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase): channel.json_body["error"], "JWT validation failed: Signature verification failed", ) + + +AS_USER = "as_user_alice" + + +class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): + servlets = [ + login.register_servlets, + register.register_servlets, + ] + + def register_as_user(self, username): + request, channel = self.make_request( + b"POST", + "/_matrix/client/r0/register?access_token=%s" % (self.service.token,), + {"username": username}, + ) + self.render(request) + + def make_homeserver(self, reactor, clock): + self.hs = self.setup_test_homeserver() + + self.service = ApplicationService( + id="unique_identifier", + token="some_token", + hostname="example.com", + sender="@asbot:example.com", + namespaces={ + ApplicationService.NS_USERS: [ + {"regex": r"@as_user.*", "exclusive": False} + ], + ApplicationService.NS_ROOMS: [], + ApplicationService.NS_ALIASES: [], + }, + ) + self.another_service = ApplicationService( + id="another__identifier", + token="another_token", + hostname="example.com", + sender="@as2bot:example.com", + namespaces={ + ApplicationService.NS_USERS: [ + {"regex": r"@as2_user.*", "exclusive": False} + ], + ApplicationService.NS_ROOMS: [], + ApplicationService.NS_ALIASES: [], + }, + ) + + self.hs.get_datastore().services_cache.append(self.service) + self.hs.get_datastore().services_cache.append(self.another_service) + return self.hs + + def test_login_appservice_user(self): + """Test that an appservice user can use /login + """ + self.register_as_user(AS_USER) + + params = { + "type": login.LoginRestServlet.APPSERVICE_TYPE, + "identifier": {"type": "m.id.user", "user": AS_USER}, + } + request, channel = self.make_request( + b"POST", LOGIN_URL, params, access_token=self.service.token + ) + + self.render(request) + self.assertEquals(channel.result["code"], b"200", channel.result) + + def test_login_appservice_user_bot(self): + """Test that the appservice bot can use /login + """ + self.register_as_user(AS_USER) + + params = { + "type": login.LoginRestServlet.APPSERVICE_TYPE, + "identifier": {"type": "m.id.user", "user": self.service.sender}, + } + request, channel = self.make_request( + b"POST", LOGIN_URL, params, access_token=self.service.token + ) + + self.render(request) + self.assertEquals(channel.result["code"], b"200", channel.result) + + def test_login_appservice_wrong_user(self): + """Test that non-as users cannot login with the as token + """ + self.register_as_user(AS_USER) + + params = { + "type": login.LoginRestServlet.APPSERVICE_TYPE, + "identifier": {"type": "m.id.user", "user": "fibble_wibble"}, + } + request, channel = self.make_request( + b"POST", LOGIN_URL, params, access_token=self.service.token + ) + + self.render(request) + self.assertEquals(channel.result["code"], b"403", channel.result) + + def test_login_appservice_wrong_as(self): + """Test that as users cannot login with wrong as token + """ + self.register_as_user(AS_USER) + + params = { + "type": login.LoginRestServlet.APPSERVICE_TYPE, + "identifier": {"type": "m.id.user", "user": AS_USER}, + } + request, channel = self.make_request( + b"POST", LOGIN_URL, params, access_token=self.another_service.token + ) + + self.render(request) + self.assertEquals(channel.result["code"], b"403", channel.result) + + def test_login_appservice_no_token(self): + """Test that users must provide a token when using the appservice + login method + """ + self.register_as_user(AS_USER) + + params = { + "type": login.LoginRestServlet.APPSERVICE_TYPE, + "identifier": {"type": "m.id.user", "user": AS_USER}, + } + request, channel = self.make_request(b"POST", LOGIN_URL, params) + + self.render(request) + self.assertEquals(channel.result["code"], b"401", channel.result) diff --git a/tests/rest/client/v1/test_push_rule_attrs.py b/tests/rest/client/v1/test_push_rule_attrs.py new file mode 100644 index 0000000000..081052f6a6 --- /dev/null +++ b/tests/rest/client/v1/test_push_rule_attrs.py @@ -0,0 +1,448 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import synapse +from synapse.api.errors import Codes +from synapse.rest.client.v1 import login, push_rule, room + +from tests.unittest import HomeserverTestCase + + +class PushRuleAttributesTestCase(HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + push_rule.register_servlets, + ] + hijack_auth = False + + def test_enabled_on_creation(self): + """ + Tests the GET and PUT of push rules' `enabled` endpoints. + Tests that a rule is enabled upon creation, even though a rule with that + ruleId existed previously and was disabled. + """ + self.register_user("user", "pass") + token = self.login("user", "pass") + + body = { + "conditions": [ + {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"} + ], + "actions": ["notify", {"set_tweak": "highlight"}], + } + + # PUT a new rule + request, channel = self.make_request( + "PUT", "/pushrules/global/override/best.friend", body, access_token=token + ) + self.render(request) + self.assertEqual(channel.code, 200) + + # GET enabled for that new rule + request, channel = self.make_request( + "GET", "/pushrules/global/override/best.friend/enabled", access_token=token + ) + self.render(request) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["enabled"], True) + + def test_enabled_on_recreation(self): + """ + Tests the GET and PUT of push rules' `enabled` endpoints. + Tests that a rule is enabled upon creation, even if a rule with that + ruleId existed previously and was disabled. + """ + self.register_user("user", "pass") + token = self.login("user", "pass") + + body = { + "conditions": [ + {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"} + ], + "actions": ["notify", {"set_tweak": "highlight"}], + } + + # PUT a new rule + request, channel = self.make_request( + "PUT", "/pushrules/global/override/best.friend", body, access_token=token + ) + self.render(request) + self.assertEqual(channel.code, 200) + + # disable the rule + request, channel = self.make_request( + "PUT", + "/pushrules/global/override/best.friend/enabled", + {"enabled": False}, + access_token=token, + ) + self.render(request) + self.assertEqual(channel.code, 200) + + # check rule disabled + request, channel = self.make_request( + "GET", "/pushrules/global/override/best.friend/enabled", access_token=token + ) + self.render(request) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["enabled"], False) + + # DELETE the rule + request, channel = self.make_request( + "DELETE", "/pushrules/global/override/best.friend", access_token=token + ) + self.render(request) + self.assertEqual(channel.code, 200) + + # PUT a new rule + request, channel = self.make_request( + "PUT", "/pushrules/global/override/best.friend", body, access_token=token + ) + self.render(request) + self.assertEqual(channel.code, 200) + + # GET enabled for that new rule + request, channel = self.make_request( + "GET", "/pushrules/global/override/best.friend/enabled", access_token=token + ) + self.render(request) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["enabled"], True) + + def test_enabled_disable(self): + """ + Tests the GET and PUT of push rules' `enabled` endpoints. + Tests that a rule is disabled and enabled when we ask for it. + """ + self.register_user("user", "pass") + token = self.login("user", "pass") + + body = { + "conditions": [ + {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"} + ], + "actions": ["notify", {"set_tweak": "highlight"}], + } + + # PUT a new rule + request, channel = self.make_request( + "PUT", "/pushrules/global/override/best.friend", body, access_token=token + ) + self.render(request) + self.assertEqual(channel.code, 200) + + # disable the rule + request, channel = self.make_request( + "PUT", + "/pushrules/global/override/best.friend/enabled", + {"enabled": False}, + access_token=token, + ) + self.render(request) + self.assertEqual(channel.code, 200) + + # check rule disabled + request, channel = self.make_request( + "GET", "/pushrules/global/override/best.friend/enabled", access_token=token + ) + self.render(request) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["enabled"], False) + + # re-enable the rule + request, channel = self.make_request( + "PUT", + "/pushrules/global/override/best.friend/enabled", + {"enabled": True}, + access_token=token, + ) + self.render(request) + self.assertEqual(channel.code, 200) + + # check rule enabled + request, channel = self.make_request( + "GET", "/pushrules/global/override/best.friend/enabled", access_token=token + ) + self.render(request) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["enabled"], True) + + def test_enabled_404_when_get_non_existent(self): + """ + Tests that `enabled` gives 404 when the rule doesn't exist. + """ + self.register_user("user", "pass") + token = self.login("user", "pass") + + body = { + "conditions": [ + {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"} + ], + "actions": ["notify", {"set_tweak": "highlight"}], + } + + # check 404 for never-heard-of rule + request, channel = self.make_request( + "GET", "/pushrules/global/override/best.friend/enabled", access_token=token + ) + self.render(request) + self.assertEqual(channel.code, 404) + self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) + + # PUT a new rule + request, channel = self.make_request( + "PUT", "/pushrules/global/override/best.friend", body, access_token=token + ) + self.render(request) + self.assertEqual(channel.code, 200) + + # GET enabled for that new rule + request, channel = self.make_request( + "GET", "/pushrules/global/override/best.friend/enabled", access_token=token + ) + self.render(request) + self.assertEqual(channel.code, 200) + + # DELETE the rule + request, channel = self.make_request( + "DELETE", "/pushrules/global/override/best.friend", access_token=token + ) + self.render(request) + self.assertEqual(channel.code, 200) + + # check 404 for deleted rule + request, channel = self.make_request( + "GET", "/pushrules/global/override/best.friend/enabled", access_token=token + ) + self.render(request) + self.assertEqual(channel.code, 404) + self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) + + def test_enabled_404_when_get_non_existent_server_rule(self): + """ + Tests that `enabled` gives 404 when the server-default rule doesn't exist. + """ + self.register_user("user", "pass") + token = self.login("user", "pass") + + # check 404 for never-heard-of rule + request, channel = self.make_request( + "GET", "/pushrules/global/override/.m.muahahaha/enabled", access_token=token + ) + self.render(request) + self.assertEqual(channel.code, 404) + self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) + + def test_enabled_404_when_put_non_existent_rule(self): + """ + Tests that `enabled` gives 404 when we put to a rule that doesn't exist. + """ + self.register_user("user", "pass") + token = self.login("user", "pass") + + # enable & check 404 for never-heard-of rule + request, channel = self.make_request( + "PUT", + "/pushrules/global/override/best.friend/enabled", + {"enabled": True}, + access_token=token, + ) + self.render(request) + self.assertEqual(channel.code, 404) + self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) + + def test_enabled_404_when_put_non_existent_server_rule(self): + """ + Tests that `enabled` gives 404 when we put to a server-default rule that doesn't exist. + """ + self.register_user("user", "pass") + token = self.login("user", "pass") + + # enable & check 404 for never-heard-of rule + request, channel = self.make_request( + "PUT", + "/pushrules/global/override/.m.muahahah/enabled", + {"enabled": True}, + access_token=token, + ) + self.render(request) + self.assertEqual(channel.code, 404) + self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) + + def test_actions_get(self): + """ + Tests that `actions` gives you what you expect on a fresh rule. + """ + self.register_user("user", "pass") + token = self.login("user", "pass") + + body = { + "conditions": [ + {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"} + ], + "actions": ["notify", {"set_tweak": "highlight"}], + } + + # PUT a new rule + request, channel = self.make_request( + "PUT", "/pushrules/global/override/best.friend", body, access_token=token + ) + self.render(request) + self.assertEqual(channel.code, 200) + + # GET actions for that new rule + request, channel = self.make_request( + "GET", "/pushrules/global/override/best.friend/actions", access_token=token + ) + self.render(request) + self.assertEqual(channel.code, 200) + self.assertEqual( + channel.json_body["actions"], ["notify", {"set_tweak": "highlight"}] + ) + + def test_actions_put(self): + """ + Tests that PUT on actions updates the value you'd get from GET. + """ + self.register_user("user", "pass") + token = self.login("user", "pass") + + body = { + "conditions": [ + {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"} + ], + "actions": ["notify", {"set_tweak": "highlight"}], + } + + # PUT a new rule + request, channel = self.make_request( + "PUT", "/pushrules/global/override/best.friend", body, access_token=token + ) + self.render(request) + self.assertEqual(channel.code, 200) + + # change the rule actions + request, channel = self.make_request( + "PUT", + "/pushrules/global/override/best.friend/actions", + {"actions": ["dont_notify"]}, + access_token=token, + ) + self.render(request) + self.assertEqual(channel.code, 200) + + # GET actions for that new rule + request, channel = self.make_request( + "GET", "/pushrules/global/override/best.friend/actions", access_token=token + ) + self.render(request) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["actions"], ["dont_notify"]) + + def test_actions_404_when_get_non_existent(self): + """ + Tests that `actions` gives 404 when the rule doesn't exist. + """ + self.register_user("user", "pass") + token = self.login("user", "pass") + + body = { + "conditions": [ + {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"} + ], + "actions": ["notify", {"set_tweak": "highlight"}], + } + + # check 404 for never-heard-of rule + request, channel = self.make_request( + "GET", "/pushrules/global/override/best.friend/enabled", access_token=token + ) + self.render(request) + self.assertEqual(channel.code, 404) + self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) + + # PUT a new rule + request, channel = self.make_request( + "PUT", "/pushrules/global/override/best.friend", body, access_token=token + ) + self.render(request) + self.assertEqual(channel.code, 200) + + # DELETE the rule + request, channel = self.make_request( + "DELETE", "/pushrules/global/override/best.friend", access_token=token + ) + self.render(request) + self.assertEqual(channel.code, 200) + + # check 404 for deleted rule + request, channel = self.make_request( + "GET", "/pushrules/global/override/best.friend/enabled", access_token=token + ) + self.render(request) + self.assertEqual(channel.code, 404) + self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) + + def test_actions_404_when_get_non_existent_server_rule(self): + """ + Tests that `actions` gives 404 when the server-default rule doesn't exist. + """ + self.register_user("user", "pass") + token = self.login("user", "pass") + + # check 404 for never-heard-of rule + request, channel = self.make_request( + "GET", "/pushrules/global/override/.m.muahahaha/actions", access_token=token + ) + self.render(request) + self.assertEqual(channel.code, 404) + self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) + + def test_actions_404_when_put_non_existent_rule(self): + """ + Tests that `actions` gives 404 when putting to a rule that doesn't exist. + """ + self.register_user("user", "pass") + token = self.login("user", "pass") + + # enable & check 404 for never-heard-of rule + request, channel = self.make_request( + "PUT", + "/pushrules/global/override/best.friend/actions", + {"actions": ["dont_notify"]}, + access_token=token, + ) + self.render(request) + self.assertEqual(channel.code, 404) + self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) + + def test_actions_404_when_put_non_existent_server_rule(self): + """ + Tests that `actions` gives 404 when putting to a server-default rule that doesn't exist. + """ + self.register_user("user", "pass") + token = self.login("user", "pass") + + # enable & check 404 for never-heard-of rule + request, channel = self.make_request( + "PUT", + "/pushrules/global/override/.m.muahahah/actions", + {"actions": ["dont_notify"]}, + access_token=token, + ) + self.render(request) + self.assertEqual(channel.code, 404) + self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py index 152a5182fa..93f899d861 100644 --- a/tests/rest/client/v2_alpha/test_account.py +++ b/tests/rest/client/v2_alpha/test_account.py @@ -14,11 +14,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import json import os import re from email.parser import Parser +from typing import Optional +from urllib.parse import urlencode import pkg_resources @@ -27,8 +28,10 @@ from synapse.api.constants import LoginType, Membership from synapse.api.errors import Codes from synapse.rest.client.v1 import login, room from synapse.rest.client.v2_alpha import account, register +from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource from tests import unittest +from tests.unittest import override_config class PasswordResetTestCase(unittest.HomeserverTestCase): @@ -69,6 +72,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() + self.submit_token_resource = PasswordResetSubmitTokenResource(hs) def test_basic_password_reset(self): """Test basic password reset flow @@ -250,8 +254,32 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): # Remove the host path = link.replace("https://example.com", "") + # Load the password reset confirmation page request, channel = self.make_request("GET", path, shorthand=False) - self.render(request) + request.render(self.submit_token_resource) + self.pump() + self.assertEquals(200, channel.code, channel.result) + + # Now POST to the same endpoint, mimicking the same behaviour as clicking the + # password reset confirm button + + # Send arguments as url-encoded form data, matching the template's behaviour + form_args = [] + for key, value_list in request.args.items(): + for value in value_list: + arg = (key, value) + form_args.append(arg) + + # Confirm the password reset + request, channel = self.make_request( + "POST", + path, + content=urlencode(form_args).encode("utf8"), + shorthand=False, + content_is_form=True, + ) + request.render(self.submit_token_resource) + self.pump() self.assertEquals(200, channel.code, channel.result) def _get_link_from_email(self): @@ -668,16 +696,104 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertFalse(channel.json_body["threepids"]) - def _request_token(self, email, client_secret): + @override_config({"next_link_domain_whitelist": None}) + def test_next_link(self): + """Tests a valid next_link parameter value with no whitelist (good case)""" + self._request_token( + "something@example.com", + "some_secret", + next_link="https://example.com/a/good/site", + expect_code=200, + ) + + @override_config({"next_link_domain_whitelist": None}) + def test_next_link_exotic_protocol(self): + """Tests using a esoteric protocol as a next_link parameter value. + Someone may be hosting a client on IPFS etc. + """ + self._request_token( + "something@example.com", + "some_secret", + next_link="some-protocol://abcdefghijklmopqrstuvwxyz", + expect_code=200, + ) + + @override_config({"next_link_domain_whitelist": None}) + def test_next_link_file_uri(self): + """Tests next_link parameters cannot be file URI""" + # Attempt to use a next_link value that points to the local disk + self._request_token( + "something@example.com", + "some_secret", + next_link="file:///host/path", + expect_code=400, + ) + + @override_config({"next_link_domain_whitelist": ["example.com", "example.org"]}) + def test_next_link_domain_whitelist(self): + """Tests next_link parameters must fit the whitelist if provided""" + self._request_token( + "something@example.com", + "some_secret", + next_link="https://example.com/some/good/page", + expect_code=200, + ) + + self._request_token( + "something@example.com", + "some_secret", + next_link="https://example.org/some/also/good/page", + expect_code=200, + ) + + self._request_token( + "something@example.com", + "some_secret", + next_link="https://bad.example.org/some/bad/page", + expect_code=400, + ) + + @override_config({"next_link_domain_whitelist": []}) + def test_empty_next_link_domain_whitelist(self): + """Tests an empty next_lint_domain_whitelist value, meaning next_link is essentially + disallowed + """ + self._request_token( + "something@example.com", + "some_secret", + next_link="https://example.com/a/page", + expect_code=400, + ) + + def _request_token( + self, + email: str, + client_secret: str, + next_link: Optional[str] = None, + expect_code: int = 200, + ) -> str: + """Request a validation token to add an email address to a user's account + + Args: + email: The email address to validate + client_secret: A secret string + next_link: A link to redirect the user to after validation + expect_code: Expected return code of the call + + Returns: + The ID of the new threepid validation session + """ + body = {"client_secret": client_secret, "email": email, "send_attempt": 1} + if next_link: + body["next_link"] = next_link + request, channel = self.make_request( - "POST", - b"account/3pid/email/requestToken", - {"client_secret": client_secret, "email": email, "send_attempt": 1}, + "POST", b"account/3pid/email/requestToken", body, ) self.render(request) - self.assertEquals(200, channel.code, channel.result) + self.assertEquals(expect_code, channel.code, channel.result) - return channel.json_body["sid"] + return channel.json_body.get("sid") def _request_token_invalid_email( self, email, expected_errcode, expected_error, client_secret="foobar", diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index f4f3e56777..5f897d49cf 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -120,12 +120,13 @@ class _TestImage: extension = attr.ib(type=bytes) expected_cropped = attr.ib(type=Optional[bytes]) expected_scaled = attr.ib(type=Optional[bytes]) + expected_found = attr.ib(default=True, type=bool) @parameterized_class( ("test_image",), [ - # smol png + # smoll png ( _TestImage( unhexlify( @@ -161,6 +162,8 @@ class _TestImage: None, ), ), + # an empty file + (_TestImage(b"", b"image/gif", b".gif", None, None, False,),), ], ) class MediaRepoTests(unittest.HomeserverTestCase): @@ -303,12 +306,16 @@ class MediaRepoTests(unittest.HomeserverTestCase): self.assertEqual(headers.getRawHeaders(b"Content-Disposition"), None) def test_thumbnail_crop(self): - self._test_thumbnail("crop", self.test_image.expected_cropped) + self._test_thumbnail( + "crop", self.test_image.expected_cropped, self.test_image.expected_found + ) def test_thumbnail_scale(self): - self._test_thumbnail("scale", self.test_image.expected_scaled) + self._test_thumbnail( + "scale", self.test_image.expected_scaled, self.test_image.expected_found + ) - def _test_thumbnail(self, method, expected_body): + def _test_thumbnail(self, method, expected_body, expected_found): params = "?width=32&height=32&method=" + method request, channel = self.make_request( "GET", self.media_id + params, shorthand=False @@ -325,11 +332,23 @@ class MediaRepoTests(unittest.HomeserverTestCase): ) self.pump() - self.assertEqual(channel.code, 200) - if expected_body is not None: + if expected_found: + self.assertEqual(channel.code, 200) + if expected_body is not None: + self.assertEqual( + channel.result["body"], expected_body, channel.result["body"] + ) + else: + # ensure that the result is at least some valid image + Image.open(BytesIO(channel.result["body"])) + else: + # A 404 with a JSON body. + self.assertEqual(channel.code, 404) self.assertEqual( - channel.result["body"], expected_body, channel.result["body"] + channel.json_body, + { + "errcode": "M_NOT_FOUND", + "error": "Not found [b'example.com', b'12345?width=32&height=32&method=%s']" + % method, + }, ) - else: - # ensure that the result is at least some valid image - Image.open(BytesIO(channel.result["body"])) diff --git a/tests/rest/test_well_known.py b/tests/rest/test_well_known.py index b090bb974c..dcd65c2a50 100644 --- a/tests/rest/test_well_known.py +++ b/tests/rest/test_well_known.py @@ -21,7 +21,7 @@ from tests import unittest class WellKnownTests(unittest.HomeserverTestCase): def setUp(self): - super(WellKnownTests, self).setUp() + super().setUp() # replace the JsonResource with a WellKnownResource self.resource = WellKnownResource(self.hs) diff --git a/tests/server.py b/tests/server.py index 48e45c6c8b..b404ad4e2a 100644 --- a/tests/server.py +++ b/tests/server.py @@ -1,6 +1,6 @@ import json import logging -from io import BytesIO +from io import SEEK_END, BytesIO import attr from zope.interface import implementer @@ -135,6 +135,7 @@ def make_request( request=SynapseRequest, shorthand=True, federation_auth_origin=None, + content_is_form=False, ): """ Make a web request using the given method and path, feed it the @@ -150,6 +151,8 @@ def make_request( with the usual REST API path, if it doesn't contain it. federation_auth_origin (bytes|None): if set to not-None, we will add a fake Authorization header pretenting to be the given server name. + content_is_form: Whether the content is URL encoded form data. Adds the + 'Content-Type': 'application/x-www-form-urlencoded' header. Returns: Tuple[synapse.http.site.SynapseRequest, channel] @@ -181,6 +184,8 @@ def make_request( req = request(channel) req.process = lambda: b"" req.content = BytesIO(content) + # Twisted expects to be at the end of the content when parsing the request. + req.content.seek(SEEK_END) req.postpath = list(map(unquote, path[1:].split(b"/"))) if access_token: @@ -195,7 +200,13 @@ def make_request( ) if content: - req.requestHeaders.addRawHeader(b"Content-Type", b"application/json") + if content_is_form: + req.requestHeaders.addRawHeader( + b"Content-Type", b"application/x-www-form-urlencoded" + ) + else: + # Assume the body is JSON + req.requestHeaders.addRawHeader(b"Content-Type", b"application/json") req.requestReceived(method, path, b"1.1") @@ -249,7 +260,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): return succeed(lookups[name]) self.nameResolver = SimpleResolverComplexifier(FakeResolver()) - super(ThreadedMemoryReactorClock, self).__init__() + super().__init__() def listenUDP(self, port, protocol, interface="", maxPacketSize=8196): p = udp.Port(port, protocol, interface, maxPacketSize, self) diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 973338ea71..6382b19dc3 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -67,7 +67,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): raise Exception("Failed to find reference to ResourceLimitsServerNotices") self._rlsn._store.user_last_seen_monthly_active = Mock( - side_effect=lambda user_id: make_awaitable(1000) + return_value=make_awaitable(1000) ) self._rlsn._server_notices_manager.send_notice = Mock( return_value=defer.succeed(Mock()) @@ -80,9 +80,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): return_value=defer.succeed("!something:localhost") ) self._rlsn._store.add_tag_to_room = Mock(return_value=defer.succeed(None)) - self._rlsn._store.get_tags_for_room = Mock( - side_effect=lambda user_id, room_id: make_awaitable({}) - ) + self._rlsn._store.get_tags_for_room = Mock(return_value=make_awaitable({})) @override_config({"hs_disabled": True}) def test_maybe_send_server_notice_disabled_hs(self): @@ -158,7 +156,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): """ self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None)) self._rlsn._store.user_last_seen_monthly_active = Mock( - side_effect=lambda user_id: make_awaitable(None) + return_value=make_awaitable(None) ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) @@ -261,12 +259,10 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): self.user_id = "@user_id:test" def test_server_notice_only_sent_once(self): - self.store.get_monthly_active_count = Mock( - side_effect=lambda: make_awaitable(1000) - ) + self.store.get_monthly_active_count = Mock(return_value=make_awaitable(1000)) self.store.user_last_seen_monthly_active = Mock( - side_effect=lambda user_id: make_awaitable(1000) + return_value=make_awaitable(1000) ) # Call the function multiple times to ensure we only send the notice once diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index cb808d4de4..46f94914ff 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -413,7 +413,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): # required for ApplicationServiceTransactionStoreTestCase tests class TestTransactionStore(ApplicationServiceTransactionStore, ApplicationServiceStore): def __init__(self, database: DatabasePool, db_conn, hs): - super(TestTransactionStore, self).__init__(database, db_conn, hs) + super().__init__(database, db_conn, hs) class ApplicationServiceStoreConfigTestCase(unittest.TestCase): diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index 370c247e16..755c70db31 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -154,7 +154,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): user_id = "@user:server" self.store.get_monthly_active_count = Mock( - side_effect=lambda: make_awaitable(lots_of_users) + return_value=make_awaitable(lots_of_users) ) self.get_success( self.store.insert_client_ip( diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index 34ae8c9da7..ecb00f4e02 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -23,7 +23,7 @@ import tests.utils class DeviceStoreTestCase(tests.unittest.TestCase): def __init__(self, *args, **kwargs): - super(DeviceStoreTestCase, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) self.store = None # type: synapse.storage.DataStore @defer.inlineCallbacks diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py index f0a8e32f1e..20636fc400 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py @@ -122,6 +122,56 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self.assertEqual(id_gen.get_positions(), {"master": 8}) self.assertEqual(id_gen.get_current_token_for_writer("master"), 8) + def test_out_of_order_finish(self): + """Test that IDs persisted out of order are correctly handled + """ + + # Prefill table with 7 rows written by 'master' + self._insert_rows("master", 7) + + id_gen = self._create_id_generator() + + self.assertEqual(id_gen.get_positions(), {"master": 7}) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) + + ctx1 = self.get_success(id_gen.get_next()) + ctx2 = self.get_success(id_gen.get_next()) + ctx3 = self.get_success(id_gen.get_next()) + ctx4 = self.get_success(id_gen.get_next()) + + s1 = ctx1.__enter__() + s2 = ctx2.__enter__() + s3 = ctx3.__enter__() + s4 = ctx4.__enter__() + + self.assertEqual(s1, 8) + self.assertEqual(s2, 9) + self.assertEqual(s3, 10) + self.assertEqual(s4, 11) + + self.assertEqual(id_gen.get_positions(), {"master": 7}) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) + + ctx2.__exit__(None, None, None) + + self.assertEqual(id_gen.get_positions(), {"master": 7}) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) + + ctx1.__exit__(None, None, None) + + self.assertEqual(id_gen.get_positions(), {"master": 9}) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 9) + + ctx4.__exit__(None, None, None) + + self.assertEqual(id_gen.get_positions(), {"master": 9}) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 9) + + ctx3.__exit__(None, None, None) + + self.assertEqual(id_gen.get_positions(), {"master": 11}) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 11) + def test_multi_instance(self): """Test that reads and writes from multiple processes are handled correctly. diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index 9870c74883..643072bbaf 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -231,9 +231,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): ) self.get_success(d) - self.store.upsert_monthly_active_user = Mock( - side_effect=lambda user_id: make_awaitable(None) - ) + self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) d = self.store.populate_monthly_active_users(user_id) self.get_success(d) @@ -241,9 +239,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.store.upsert_monthly_active_user.assert_not_called() def test_populate_monthly_users_should_update(self): - self.store.upsert_monthly_active_user = Mock( - side_effect=lambda user_id: make_awaitable(None) - ) + self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) self.store.is_trial_user = Mock(return_value=defer.succeed(False)) @@ -256,9 +252,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.store.upsert_monthly_active_user.assert_called_once() def test_populate_monthly_users_should_not_update(self): - self.store.upsert_monthly_active_user = Mock( - side_effect=lambda user_id: make_awaitable(None) - ) + self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) self.store.is_trial_user = Mock(return_value=defer.succeed(False)) self.store.user_last_seen_monthly_active = Mock( @@ -344,9 +338,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): @override_config({"limit_usage_by_mau": False, "mau_stats_only": False}) def test_no_users_when_not_tracking(self): - self.store.upsert_monthly_active_user = Mock( - side_effect=lambda user_id: make_awaitable(None) - ) + self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) self.get_success(self.store.populate_monthly_active_users("@user:sever")) diff --git a/tests/test_state.py b/tests/test_state.py index 2d58467932..80b0ccbc40 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -125,7 +125,7 @@ class StateGroupStore: class DictObj(dict): def __init__(self, **kwargs): - super(DictObj, self).__init__(kwargs) + super().__init__(kwargs) self.__dict__ = self diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py index 508aeba078..a298cc0fd3 100644 --- a/tests/test_utils/__init__.py +++ b/tests/test_utils/__init__.py @@ -17,6 +17,7 @@ """ Utilities for running the unit tests """ +from asyncio import Future from typing import Any, Awaitable, TypeVar TV = TypeVar("TV") @@ -38,6 +39,12 @@ def get_awaitable_result(awaitable: Awaitable[TV]) -> TV: raise Exception("awaitable has not yet completed") -async def make_awaitable(result: Any): - """Create an awaitable that just returns a result.""" - return result +def make_awaitable(result: Any) -> Awaitable[Any]: + """ + Makes an awaitable, suitable for mocking an `async` function. + This uses Futures as they can be awaited multiple times so can be returned + to multiple callers. + """ + future = Future() # type: ignore + future.set_result(result) + return future diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py index fb1ca90336..e93aa84405 100644 --- a/tests/test_utils/event_injection.py +++ b/tests/test_utils/event_injection.py @@ -71,7 +71,10 @@ async def inject_event( """ event, context = await create_event(hs, room_version, prev_event_ids, **kwargs) - await hs.get_storage().persistence.persist_event(event, context) + persistence = hs.get_storage().persistence + assert persistence is not None + + await persistence.persist_event(event, context) return event diff --git a/tests/test_utils/logging_setup.py b/tests/test_utils/logging_setup.py index 2d96b0fa8d..fdfb840b62 100644 --- a/tests/test_utils/logging_setup.py +++ b/tests/test_utils/logging_setup.py @@ -29,8 +29,7 @@ class ToTwistedHandler(logging.Handler): log_entry = self.format(record) log_level = record.levelname.lower().replace("warning", "warn") self.tx_log.emit( - twisted.logger.LogLevel.levelWithName(log_level), - log_entry.replace("{", r"(").replace("}", r")"), + twisted.logger.LogLevel.levelWithName(log_level), "{entry}", entry=log_entry ) diff --git a/tests/unittest.py b/tests/unittest.py index 3cb55a7e96..dabf69cff4 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -92,7 +92,7 @@ class TestCase(unittest.TestCase): root logger's logging level while that test (case|method) runs.""" def __init__(self, methodName, *args, **kwargs): - super(TestCase, self).__init__(methodName, *args, **kwargs) + super().__init__(methodName, *args, **kwargs) method = getattr(self, methodName) @@ -353,6 +353,7 @@ class HomeserverTestCase(TestCase): request: Type[T] = SynapseRequest, shorthand: bool = True, federation_auth_origin: str = None, + content_is_form: bool = False, ) -> Tuple[T, FakeChannel]: """ Create a SynapseRequest at the path using the method and containing the @@ -368,6 +369,8 @@ class HomeserverTestCase(TestCase): with the usual REST API path, if it doesn't contain it. federation_auth_origin (bytes|None): if set to not-None, we will add a fake Authorization header pretenting to be the given server name. + content_is_form: Whether the content is URL encoded form data. Adds the + 'Content-Type': 'application/x-www-form-urlencoded' header. Returns: Tuple[synapse.http.site.SynapseRequest, channel] @@ -384,6 +387,7 @@ class HomeserverTestCase(TestCase): request, shorthand, federation_auth_origin, + content_is_form, ) def render(self, request): |