summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/api/test_ratelimiting.py73
-rw-r--r--tests/handlers/test_oidc.py18
-rw-r--r--tests/handlers/test_profile.py56
-rw-r--r--tests/handlers/test_register.py52
-rw-r--r--tests/handlers/test_stats.py6
-rw-r--r--tests/handlers/test_typing.py30
-rw-r--r--tests/handlers/test_user_directory.py6
-rw-r--r--tests/http/test_fedclient.py2
-rw-r--r--tests/module_api/test_api.py2
-rw-r--r--tests/push/test_email.py2
-rw-r--r--tests/replication/test_federation_sender_shard.py4
-rw-r--r--tests/rest/client/test_retention.py96
-rw-r--r--tests/rest/client/test_shadow_banned.py312
-rw-r--r--tests/rest/client/v1/test_login.py16
-rw-r--r--tests/rest/client/v1/test_rooms.py88
-rw-r--r--tests/rest/client/v1/utils.py10
-rw-r--r--tests/rest/client/v2_alpha/test_register.py4
-rw-r--r--tests/server_notices/test_resource_limits_server_notices.py6
-rw-r--r--tests/storage/test_appservice.py3
-rw-r--r--tests/storage/test_background_update.py5
-rw-r--r--tests/storage/test_base.py60
-rw-r--r--tests/storage/test_cleanup_extrems.py5
-rw-r--r--tests/storage/test_devices.py8
-rw-r--r--tests/storage/test_directory.py6
-rw-r--r--tests/storage/test_id_generators.py100
-rw-r--r--tests/storage/test_main.py4
-rw-r--r--tests/storage/test_profile.py23
-rw-r--r--tests/storage/test_registration.py33
-rw-r--r--tests/storage/test_room.py20
-rw-r--r--tests/storage/test_roommember.py6
-rw-r--r--tests/test_federation.py50
-rw-r--r--tests/util/caches/test_descriptors.py12
-rw-r--r--tests/util/test_retryutils.py2
33 files changed, 943 insertions, 177 deletions
diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py

index d580e729c5..1e1f30d790 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py
@@ -1,4 +1,6 @@ from synapse.api.ratelimiting import LimitExceededError, Ratelimiter +from synapse.appservice import ApplicationService +from synapse.types import create_requester from tests import unittest @@ -18,6 +20,77 @@ class TestRatelimiter(unittest.TestCase): self.assertTrue(allowed) self.assertEquals(20.0, time_allowed) + def test_allowed_user_via_can_requester_do_action(self): + user_requester = create_requester("@user:example.com") + limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) + allowed, time_allowed = limiter.can_requester_do_action( + user_requester, _time_now_s=0 + ) + self.assertTrue(allowed) + self.assertEquals(10.0, time_allowed) + + allowed, time_allowed = limiter.can_requester_do_action( + user_requester, _time_now_s=5 + ) + self.assertFalse(allowed) + self.assertEquals(10.0, time_allowed) + + allowed, time_allowed = limiter.can_requester_do_action( + user_requester, _time_now_s=10 + ) + self.assertTrue(allowed) + self.assertEquals(20.0, time_allowed) + + def test_allowed_appservice_ratelimited_via_can_requester_do_action(self): + appservice = ApplicationService( + None, "example.com", id="foo", rate_limited=True, + ) + as_requester = create_requester("@user:example.com", app_service=appservice) + + limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) + allowed, time_allowed = limiter.can_requester_do_action( + as_requester, _time_now_s=0 + ) + self.assertTrue(allowed) + self.assertEquals(10.0, time_allowed) + + allowed, time_allowed = limiter.can_requester_do_action( + as_requester, _time_now_s=5 + ) + self.assertFalse(allowed) + self.assertEquals(10.0, time_allowed) + + allowed, time_allowed = limiter.can_requester_do_action( + as_requester, _time_now_s=10 + ) + self.assertTrue(allowed) + self.assertEquals(20.0, time_allowed) + + def test_allowed_appservice_via_can_requester_do_action(self): + appservice = ApplicationService( + None, "example.com", id="foo", rate_limited=False, + ) + as_requester = create_requester("@user:example.com", app_service=appservice) + + limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) + allowed, time_allowed = limiter.can_requester_do_action( + as_requester, _time_now_s=0 + ) + self.assertTrue(allowed) + self.assertEquals(-1, time_allowed) + + allowed, time_allowed = limiter.can_requester_do_action( + as_requester, _time_now_s=5 + ) + self.assertTrue(allowed) + self.assertEquals(-1, time_allowed) + + allowed, time_allowed = limiter.can_requester_do_action( + as_requester, _time_now_s=10 + ) + self.assertTrue(allowed) + self.assertEquals(-1, time_allowed) + def test_allowed_via_ratelimit(self): limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index 1bb25ab684..f92f3b8c15 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py
@@ -374,12 +374,16 @@ class OidcHandlerTestCase(HomeserverTestCase): self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo) self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id) self.handler._auth_handler.complete_sso_login = simple_async_mock() - request = Mock(spec=["args", "getCookie", "addCookie"]) + request = Mock( + spec=["args", "getCookie", "addCookie", "requestHeaders", "getClientIP"] + ) code = "code" state = "state" nonce = "nonce" client_redirect_url = "http://client/redirect" + user_agent = "Browser" + ip_address = "10.0.0.1" session = self.handler._generate_oidc_session_token( state=state, nonce=nonce, @@ -392,6 +396,10 @@ class OidcHandlerTestCase(HomeserverTestCase): request.args[b"code"] = [code.encode("utf-8")] request.args[b"state"] = [state.encode("utf-8")] + request.requestHeaders = Mock(spec=["getRawHeaders"]) + request.requestHeaders.getRawHeaders.return_value = [user_agent.encode("ascii")] + request.getClientIP.return_value = ip_address + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) self.handler._auth_handler.complete_sso_login.assert_called_once_with( @@ -399,7 +407,9 @@ class OidcHandlerTestCase(HomeserverTestCase): ) self.handler._exchange_code.assert_called_once_with(code) self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce) - self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token) + self.handler._map_userinfo_to_user.assert_called_once_with( + userinfo, token, user_agent, ip_address + ) self.handler._fetch_userinfo.assert_not_called() self.handler._render_error.assert_not_called() @@ -431,7 +441,9 @@ class OidcHandlerTestCase(HomeserverTestCase): ) self.handler._exchange_code.assert_called_once_with(code) self.handler._parse_id_token.assert_not_called() - self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token) + self.handler._map_userinfo_to_user.assert_called_once_with( + userinfo, token, user_agent, ip_address + ) self.handler._fetch_userinfo.assert_called_once_with(token) self.handler._render_error.assert_not_called() diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index b609b30d4a..60ebc95f3e 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py
@@ -71,7 +71,9 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_my_name(self): - yield self.store.set_profile_displayname(self.frank.localpart, "Frank") + yield defer.ensureDeferred( + self.store.set_profile_displayname(self.frank.localpart, "Frank") + ) displayname = yield defer.ensureDeferred( self.handler.get_displayname(self.frank) @@ -104,7 +106,12 @@ class ProfileTestCase(unittest.TestCase): ) self.assertEquals( - (yield self.store.get_profile_displayname(self.frank.localpart)), "Frank", + ( + yield defer.ensureDeferred( + self.store.get_profile_displayname(self.frank.localpart) + ) + ), + "Frank", ) @defer.inlineCallbacks @@ -112,10 +119,17 @@ class ProfileTestCase(unittest.TestCase): self.hs.config.enable_set_displayname = False # Setting displayname for the first time is allowed - yield self.store.set_profile_displayname(self.frank.localpart, "Frank") + yield defer.ensureDeferred( + self.store.set_profile_displayname(self.frank.localpart, "Frank") + ) self.assertEquals( - (yield self.store.get_profile_displayname(self.frank.localpart)), "Frank", + ( + yield defer.ensureDeferred( + self.store.get_profile_displayname(self.frank.localpart) + ) + ), + "Frank", ) # Setting displayname a second time is forbidden @@ -158,7 +172,9 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_incoming_fed_query(self): yield defer.ensureDeferred(self.store.create_profile("caroline")) - yield self.store.set_profile_displayname("caroline", "Caroline") + yield defer.ensureDeferred( + self.store.set_profile_displayname("caroline", "Caroline") + ) response = yield defer.ensureDeferred( self.query_handlers["profile"]( @@ -170,8 +186,10 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_my_avatar(self): - yield self.store.set_profile_avatar_url( - self.frank.localpart, "http://my.server/me.png" + yield defer.ensureDeferred( + self.store.set_profile_avatar_url( + self.frank.localpart, "http://my.server/me.png" + ) ) avatar_url = yield defer.ensureDeferred(self.handler.get_avatar_url(self.frank)) @@ -188,7 +206,11 @@ class ProfileTestCase(unittest.TestCase): ) self.assertEquals( - (yield self.store.get_profile_avatar_url(self.frank.localpart)), + ( + yield defer.ensureDeferred( + self.store.get_profile_avatar_url(self.frank.localpart) + ) + ), "http://my.server/pic.gif", ) @@ -202,7 +224,11 @@ class ProfileTestCase(unittest.TestCase): ) self.assertEquals( - (yield self.store.get_profile_avatar_url(self.frank.localpart)), + ( + yield defer.ensureDeferred( + self.store.get_profile_avatar_url(self.frank.localpart) + ) + ), "http://my.server/me.png", ) @@ -211,12 +237,18 @@ class ProfileTestCase(unittest.TestCase): self.hs.config.enable_set_avatar_url = False # Setting displayname for the first time is allowed - yield self.store.set_profile_avatar_url( - self.frank.localpart, "http://my.server/me.png" + yield defer.ensureDeferred( + self.store.set_profile_avatar_url( + self.frank.localpart, "http://my.server/me.png" + ) ) self.assertEquals( - (yield self.store.get_profile_avatar_url(self.frank.localpart)), + ( + yield defer.ensureDeferred( + self.store.get_profile_avatar_url(self.frank.localpart) + ) + ), "http://my.server/me.png", ) diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index e364b1bd62..5c92d0e8c9 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py
@@ -17,18 +17,21 @@ from mock import Mock from twisted.internet import defer +from synapse.api.auth import Auth from synapse.api.constants import UserTypes from synapse.api.errors import Codes, ResourceLimitError, SynapseError from synapse.handlers.register import RegistrationHandler +from synapse.spam_checker_api import RegistrationBehaviour from synapse.types import RoomAlias, UserID, create_requester from tests.test_utils import make_awaitable from tests.unittest import override_config +from tests.utils import mock_getRawHeaders from .. import unittest -class RegistrationHandlers(object): +class RegistrationHandlers: def __init__(self, hs): self.registration_handler = RegistrationHandler(hs) @@ -475,6 +478,53 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.handler.register_user(localpart=invalid_user_id), SynapseError ) + def test_spam_checker_deny(self): + """A spam checker can deny registration, which results in an error.""" + + class DenyAll: + def check_registration_for_spam( + self, email_threepid, username, request_info + ): + return RegistrationBehaviour.DENY + + # Configure a spam checker that denies all users. + spam_checker = self.hs.get_spam_checker() + spam_checker.spam_checkers = [DenyAll()] + + self.get_failure(self.handler.register_user(localpart="user"), SynapseError) + + def test_spam_checker_shadow_ban(self): + """A spam checker can choose to shadow-ban a user, which allows registration to succeed.""" + + class BanAll: + def check_registration_for_spam( + self, email_threepid, username, request_info + ): + return RegistrationBehaviour.SHADOW_BAN + + # Configure a spam checker that denies all users. + spam_checker = self.hs.get_spam_checker() + spam_checker.spam_checkers = [BanAll()] + + user_id = self.get_success(self.handler.register_user(localpart="user")) + + # Get an access token. + token = self.macaroon_generator.generate_access_token(user_id) + self.get_success( + self.store.add_access_token_to_user( + user_id=user_id, token=token, device_id=None, valid_until_ms=None + ) + ) + + # Ensure the user was marked as shadow-banned. + request = Mock(args={}) + request.args[b"access_token"] = [token.encode("ascii")] + request.requestHeaders.getRawHeaders = mock_getRawHeaders() + auth = Auth(self.hs) + requester = self.get_success(auth.get_user_by_req(request)) + + self.assertTrue(requester.shadow_banned) + async def get_or_create_user( self, requester, localpart, displayname, password_hash=None ): diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py
index 0e666492f6..a609f148c0 100644 --- a/tests/handlers/test_stats.py +++ b/tests/handlers/test_stats.py
@@ -81,8 +81,8 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) - def get_all_room_state(self): - return self.store.db_pool.simple_select_list( + async def get_all_room_state(self): + return await self.store.db_pool.simple_select_list( "room_stats_state", None, retcols=("name", "topic", "canonical_alias") ) @@ -256,7 +256,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): # self.handler.notify_new_event() # We need to let the delta processor advance… - self.pump(10 * 60) + self.reactor.advance(10 * 60) # Get the slices! There should be two -- day 1, and day 2. r = self.get_success(self.store.get_statistics_for_subject("room", room_1, 0)) diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index e01de158e5..81c1839637 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py
@@ -21,7 +21,7 @@ from mock import ANY, Mock, call from twisted.internet import defer from synapse.api.errors import AuthError -from synapse.types import UserID +from synapse.types import UserID, create_requester from tests import unittest from tests.test_utils import make_awaitable @@ -144,9 +144,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.datastore.get_users_in_room = get_users_in_room - self.datastore.get_user_directory_stream_pos.return_value = ( + self.datastore.get_user_directory_stream_pos.side_effect = ( # we deliberately return a non-None stream pos to avoid doing an initial_spam - defer.succeed(1) + lambda: make_awaitable(1) ) self.datastore.get_current_state_deltas.return_value = (0, None) @@ -167,7 +167,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.get_success( self.handler.started_typing( - target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=20000 + target_user=U_APPLE, + requester=create_requester(U_APPLE), + room_id=ROOM_ID, + timeout=20000, ) ) @@ -194,7 +197,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.get_success( self.handler.started_typing( - target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=20000 + target_user=U_APPLE, + requester=create_requester(U_APPLE), + room_id=ROOM_ID, + timeout=20000, ) ) @@ -269,7 +275,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.get_success( self.handler.stopped_typing( - target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID + target_user=U_APPLE, + requester=create_requester(U_APPLE), + room_id=ROOM_ID, ) ) @@ -309,7 +317,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.get_success( self.handler.started_typing( - target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=10000 + target_user=U_APPLE, + requester=create_requester(U_APPLE), + room_id=ROOM_ID, + timeout=10000, ) ) @@ -348,7 +359,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.get_success( self.handler.started_typing( - target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=10000 + target_user=U_APPLE, + requester=create_requester(U_APPLE), + room_id=ROOM_ID, + timeout=10000, ) ) diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 31ed89a5cd..87be94111f 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py
@@ -238,7 +238,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): def test_spam_checker(self): """ - A user which fails to the spam checks will not appear in search results. + A user which fails the spam checks will not appear in search results. """ u1 = self.register_user("user1", "pass") u1_token = self.login(u1, "pass") @@ -269,7 +269,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): # Configure a spam checker that does not filter any users. spam_checker = self.hs.get_spam_checker() - class AllowAll(object): + class AllowAll: def check_username_for_spam(self, user_profile): # Allow all users. return False @@ -282,7 +282,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self.assertEqual(len(s["results"]), 1) # Configure a spam checker that filters all users. - class BlockAll(object): + class BlockAll: def check_username_for_spam(self, user_profile): # All users are spammy. return True diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py
index ac598249e4..7f1dfb2128 100644 --- a/tests/http/test_fedclient.py +++ b/tests/http/test_fedclient.py
@@ -508,6 +508,6 @@ class FederationClientTests(HomeserverTestCase): self.assertFalse(conn.disconnecting) # wait for a while - self.pump(120) + self.reactor.advance(120) self.assertTrue(conn.disconnecting) diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index 807cd65dd6..04de0b9dbe 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py
@@ -35,7 +35,7 @@ class ModuleApiTestCase(HomeserverTestCase): # Check that the new user exists with all provided attributes self.assertEqual(user_id, "@bob:test") self.assertTrue(access_token) - self.assertTrue(self.store.get_user_by_id(user_id)) + self.assertTrue(self.get_success(self.store.get_user_by_id(user_id))) # Check that the email was assigned emails = self.get_success(self.store.user_get_threepids(user_id)) diff --git a/tests/push/test_email.py b/tests/push/test_email.py
index 83032cc9ea..227b0d32d0 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py
@@ -170,7 +170,7 @@ class EmailPusherTests(HomeserverTestCase): last_stream_ordering = pushers[0]["last_stream_ordering"] # Advance time a bit, so the pusher will register something has happened - self.pump(100) + self.pump(10) # It hasn't succeeded yet, so the stream ordering shouldn't have moved pushers = self.get_success( diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py
index 83f9aa291c..8b4982ecb1 100644 --- a/tests/replication/test_federation_sender_shard.py +++ b/tests/replication/test_federation_sender_shard.py
@@ -20,7 +20,7 @@ from synapse.api.constants import EventTypes, Membership from synapse.events.builder import EventBuilderFactory from synapse.rest.admin import register_servlets_for_client_rest_resource from synapse.rest.client.v1 import login, room -from synapse.types import UserID +from synapse.types import UserID, create_requester from tests.replication._base import BaseMultiWorkerStreamTestCase from tests.test_utils import make_awaitable @@ -175,7 +175,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): self.get_success( typing_handler.started_typing( target_user=UserID.from_string(user), - auth_user=UserID.from_string(user), + requester=create_requester(user), room_id=room, timeout=20000, ) diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index 0b191d13c6..7d3773ff78 100644 --- a/tests/rest/client/test_retention.py +++ b/tests/rest/client/test_retention.py
@@ -45,50 +45,63 @@ class RetentionTestCase(unittest.HomeserverTestCase): } self.hs = self.setup_test_homeserver(config=config) + return self.hs def prepare(self, reactor, clock, homeserver): self.user_id = self.register_user("user", "password") self.token = self.login("user", "password") - def test_retention_state_event(self): - """Tests that the server configuration can limit the values a user can set to the - room's retention policy. + self.store = self.hs.get_datastore() + self.serializer = self.hs.get_event_client_serializer() + self.clock = self.hs.get_clock() + + def test_retention_event_purged_with_state_event(self): + """Tests that expired events are correctly purged when the room's retention policy + is defined by a state event. """ room_id = self.helper.create_room_as(self.user_id, tok=self.token) + # Set the room's retention period to 2 days. + lifetime = one_day_ms * 2 self.helper.send_state( room_id=room_id, event_type=EventTypes.Retention, - body={"max_lifetime": one_day_ms * 4}, + body={"max_lifetime": lifetime}, tok=self.token, - expect_code=400, ) + self._test_retention_event_purged(room_id, one_day_ms * 1.5) + + def test_retention_event_purged_with_state_event_outside_allowed(self): + """Tests that the server configuration can override the policy for a room when + running the purge jobs. + """ + room_id = self.helper.create_room_as(self.user_id, tok=self.token) + + # Set a max_lifetime higher than the maximum allowed value. self.helper.send_state( room_id=room_id, event_type=EventTypes.Retention, - body={"max_lifetime": one_hour_ms}, + body={"max_lifetime": one_day_ms * 4}, tok=self.token, - expect_code=400, ) - def test_retention_event_purged_with_state_event(self): - """Tests that expired events are correctly purged when the room's retention policy - is defined by a state event. - """ - room_id = self.helper.create_room_as(self.user_id, tok=self.token) + # Check that the event is purged after waiting for the maximum allowed duration + # instead of the one specified in the room's policy. + self._test_retention_event_purged(room_id, one_day_ms * 1.5) - # Set the room's retention period to 2 days. - lifetime = one_day_ms * 2 + # Set a max_lifetime lower than the minimum allowed value. self.helper.send_state( room_id=room_id, event_type=EventTypes.Retention, - body={"max_lifetime": lifetime}, + body={"max_lifetime": one_hour_ms}, tok=self.token, ) - self._test_retention_event_purged(room_id, one_day_ms * 1.5) + # Check that the event is purged after waiting for the minimum allowed duration + # instead of the one specified in the room's policy. + self._test_retention_event_purged(room_id, one_day_ms * 0.5) def test_retention_event_purged_without_state_event(self): """Tests that expired events are correctly purged when the room's retention policy @@ -140,12 +153,32 @@ class RetentionTestCase(unittest.HomeserverTestCase): # That event should be the second, not outdated event. self.assertEqual(filtered_events[0].event_id, valid_event_id, filtered_events) - def _test_retention_event_purged(self, room_id, increment): + def _test_retention_event_purged(self, room_id: str, increment: float): + """Run the following test scenario to test the message retention policy support: + + 1. Send event 1 + 2. Increment time by `increment` + 3. Send event 2 + 4. Increment time by `increment` + 5. Check that event 1 has been purged + 6. Check that event 2 has not been purged + 7. Check that state events that were sent before event 1 aren't purged. + The main reason for sending a second event is because currently Synapse won't + purge the latest message in a room because it would otherwise result in a lack of + forward extremities for this room. It's also a good thing to ensure the purge jobs + aren't too greedy and purge messages they shouldn't. + + Args: + room_id: The ID of the room to test retention in. + increment: The number of milliseconds to advance the clock each time. Must be + defined so that events in the room aren't purged if they are `increment` + old but are purged if they are `increment * 2` old. + """ # Get the create event to, later, check that we can still access it. message_handler = self.hs.get_message_handler() create_event = self.get_success( message_handler.get_room_data( - self.user_id, room_id, EventTypes.Create, state_key="", is_guest=False + self.user_id, room_id, EventTypes.Create, state_key="" ) ) @@ -156,7 +189,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): expired_event_id = resp.get("event_id") # Check that we can retrieve the event. - expired_event = self.get_event(room_id, expired_event_id) + expired_event = self.get_event(expired_event_id) self.assertEqual( expired_event.get("content", {}).get("body"), "1", expired_event ) @@ -174,26 +207,31 @@ class RetentionTestCase(unittest.HomeserverTestCase): # one should still be kept. self.reactor.advance(increment / 1000) - # Check that the event has been purged from the database. - self.get_event(room_id, expired_event_id, expected_code=404) + # Check that the first event has been purged from the database, i.e. that we + # can't retrieve it anymore, because it has expired. + self.get_event(expired_event_id, expect_none=True) - # Check that the event that hasn't been purged can still be retrieved. - valid_event = self.get_event(room_id, valid_event_id) + # Check that the event that hasn't expired can still be retrieved. + valid_event = self.get_event(valid_event_id) self.assertEqual(valid_event.get("content", {}).get("body"), "2", valid_event) # Check that we can still access state events that were sent before the event that # has been purged. self.get_event(room_id, create_event.event_id) - def get_event(self, room_id, event_id, expected_code=200): - url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id) + def get_event(self, event_id, expect_none=False): + event = self.get_success(self.store.get_event(event_id, allow_none=True)) - request, channel = self.make_request("GET", url, access_token=self.token) - self.render(request) + if expect_none: + self.assertIsNone(event) + return {} - self.assertEqual(channel.code, expected_code, channel.result) + self.assertIsNotNone(event) - return channel.json_body + time_now = self.clock.time_msec() + serialized = self.get_success(self.serializer.serialize_event(event, time_now)) + + return serialized class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase): diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py new file mode 100644
index 0000000000..dfe4bf7762 --- /dev/null +++ b/tests/rest/client/test_shadow_banned.py
@@ -0,0 +1,312 @@ +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from mock import Mock, patch + +import synapse.rest.admin +from synapse.api.constants import EventTypes +from synapse.rest.client.v1 import directory, login, profile, room +from synapse.rest.client.v2_alpha import room_upgrade_rest_servlet + +from tests import unittest + + +class _ShadowBannedBase(unittest.HomeserverTestCase): + def prepare(self, reactor, clock, homeserver): + # Create two users, one of which is shadow-banned. + self.banned_user_id = self.register_user("banned", "test") + self.banned_access_token = self.login("banned", "test") + + self.store = self.hs.get_datastore() + + self.get_success( + self.store.db_pool.simple_update( + table="users", + keyvalues={"name": self.banned_user_id}, + updatevalues={"shadow_banned": True}, + desc="shadow_ban", + ) + ) + + self.other_user_id = self.register_user("otheruser", "pass") + self.other_access_token = self.login("otheruser", "pass") + + +# To avoid the tests timing out don't add a delay to "annoy the requester". +@patch("random.randint", new=lambda a, b: 0) +class RoomTestCase(_ShadowBannedBase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + directory.register_servlets, + login.register_servlets, + room.register_servlets, + room_upgrade_rest_servlet.register_servlets, + ] + + def test_invite(self): + """Invites from shadow-banned users don't actually get sent.""" + + # The create works fine. + room_id = self.helper.create_room_as( + self.banned_user_id, tok=self.banned_access_token + ) + + # Inviting the user completes successfully. + self.helper.invite( + room=room_id, + src=self.banned_user_id, + tok=self.banned_access_token, + targ=self.other_user_id, + ) + + # But the user wasn't actually invited. + invited_rooms = self.get_success( + self.store.get_invited_rooms_for_local_user(self.other_user_id) + ) + self.assertEqual(invited_rooms, []) + + def test_invite_3pid(self): + """Ensure that a 3PID invite does not attempt to contact the identity server.""" + identity_handler = self.hs.get_handlers().identity_handler + identity_handler.lookup_3pid = Mock( + side_effect=AssertionError("This should not get called") + ) + + # The create works fine. + room_id = self.helper.create_room_as( + self.banned_user_id, tok=self.banned_access_token + ) + + # Inviting the user completes successfully. + request, channel = self.make_request( + "POST", + "/rooms/%s/invite" % (room_id,), + {"id_server": "test", "medium": "email", "address": "test@test.test"}, + access_token=self.banned_access_token, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.result) + + # This should have raised an error earlier, but double check this wasn't called. + identity_handler.lookup_3pid.assert_not_called() + + def test_create_room(self): + """Invitations during a room creation should be discarded, but the room still gets created.""" + # The room creation is successful. + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/createRoom", + {"visibility": "public", "invite": [self.other_user_id]}, + access_token=self.banned_access_token, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.result) + room_id = channel.json_body["room_id"] + + # But the user wasn't actually invited. + invited_rooms = self.get_success( + self.store.get_invited_rooms_for_local_user(self.other_user_id) + ) + self.assertEqual(invited_rooms, []) + + # Since a real room was created, the other user should be able to join it. + self.helper.join(room_id, self.other_user_id, tok=self.other_access_token) + + # Both users should be in the room. + users = self.get_success(self.store.get_users_in_room(room_id)) + self.assertCountEqual(users, ["@banned:test", "@otheruser:test"]) + + def test_message(self): + """Messages from shadow-banned users don't actually get sent.""" + + room_id = self.helper.create_room_as( + self.other_user_id, tok=self.other_access_token + ) + + # The user should be in the room. + self.helper.join(room_id, self.banned_user_id, tok=self.banned_access_token) + + # Sending a message should complete successfully. + result = self.helper.send_event( + room_id=room_id, + type=EventTypes.Message, + content={"msgtype": "m.text", "body": "with right label"}, + tok=self.banned_access_token, + ) + self.assertIn("event_id", result) + event_id = result["event_id"] + + latest_events = self.get_success( + self.store.get_latest_event_ids_in_room(room_id) + ) + self.assertNotIn(event_id, latest_events) + + def test_upgrade(self): + """A room upgrade should fail, but look like it succeeded.""" + + # The create works fine. + room_id = self.helper.create_room_as( + self.banned_user_id, tok=self.banned_access_token + ) + + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/rooms/%s/upgrade" % (room_id,), + {"new_version": "6"}, + access_token=self.banned_access_token, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.result) + # A new room_id should be returned. + self.assertIn("replacement_room", channel.json_body) + + new_room_id = channel.json_body["replacement_room"] + + # It doesn't really matter what API we use here, we just want to assert + # that the room doesn't exist. + summary = self.get_success(self.store.get_room_summary(new_room_id)) + # The summary should be empty since the room doesn't exist. + self.assertEqual(summary, {}) + + def test_typing(self): + """Typing notifications should not be propagated into the room.""" + # The create works fine. + room_id = self.helper.create_room_as( + self.banned_user_id, tok=self.banned_access_token + ) + + request, channel = self.make_request( + "PUT", + "/rooms/%s/typing/%s" % (room_id, self.banned_user_id), + {"typing": True, "timeout": 30000}, + access_token=self.banned_access_token, + ) + self.render(request) + self.assertEquals(200, channel.code) + + # There should be no typing events. + event_source = self.hs.get_event_sources().sources["typing"] + self.assertEquals(event_source.get_current_key(), 0) + + # The other user can join and send typing events. + self.helper.join(room_id, self.other_user_id, tok=self.other_access_token) + + request, channel = self.make_request( + "PUT", + "/rooms/%s/typing/%s" % (room_id, self.other_user_id), + {"typing": True, "timeout": 30000}, + access_token=self.other_access_token, + ) + self.render(request) + self.assertEquals(200, channel.code) + + # These appear in the room. + self.assertEquals(event_source.get_current_key(), 1) + events = self.get_success( + event_source.get_new_events(from_key=0, room_ids=[room_id]) + ) + self.assertEquals( + events[0], + [ + { + "type": "m.typing", + "room_id": room_id, + "content": {"user_ids": [self.other_user_id]}, + } + ], + ) + + +# To avoid the tests timing out don't add a delay to "annoy the requester". +@patch("random.randint", new=lambda a, b: 0) +class ProfileTestCase(_ShadowBannedBase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + profile.register_servlets, + room.register_servlets, + ] + + def test_displayname(self): + """Profile changes should succeed, but don't end up in a room.""" + original_display_name = "banned" + new_display_name = "new name" + + # Join a room. + room_id = self.helper.create_room_as( + self.banned_user_id, tok=self.banned_access_token + ) + + # The update should succeed. + request, channel = self.make_request( + "PUT", + "/_matrix/client/r0/profile/%s/displayname" % (self.banned_user_id,), + {"displayname": new_display_name}, + access_token=self.banned_access_token, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.result) + self.assertEqual(channel.json_body, {}) + + # The user's display name should be updated. + request, channel = self.make_request( + "GET", "/profile/%s/displayname" % (self.banned_user_id,) + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.json_body["displayname"], new_display_name) + + # But the display name in the room should not be. + message_handler = self.hs.get_message_handler() + event = self.get_success( + message_handler.get_room_data( + self.banned_user_id, room_id, "m.room.member", self.banned_user_id, + ) + ) + self.assertEqual( + event.content, {"membership": "join", "displayname": original_display_name} + ) + + def test_room_displayname(self): + """Changes to state events for a room should be processed, but not end up in the room.""" + original_display_name = "banned" + new_display_name = "new name" + + # Join a room. + room_id = self.helper.create_room_as( + self.banned_user_id, tok=self.banned_access_token + ) + + # The update should succeed. + request, channel = self.make_request( + "PUT", + "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" + % (room_id, self.banned_user_id), + {"membership": "join", "displayname": new_display_name}, + access_token=self.banned_access_token, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.result) + self.assertIn("event_id", channel.json_body) + + # The display name in the room should not be changed. + message_handler = self.hs.get_message_handler() + event = self.get_success( + message_handler.get_room_data( + self.banned_user_id, room_id, "m.room.member", self.banned_user_id, + ) + ) + self.assertEqual( + event.content, {"membership": "join", "displayname": original_display_name} + ) diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index db52725cfe..2668662c9e 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py
@@ -62,8 +62,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): "identifier": {"type": "m.id.user", "user": "kermit" + str(i)}, "password": "monkey", } - request_data = json.dumps(params) - request, channel = self.make_request(b"POST", LOGIN_URL, request_data) + request, channel = self.make_request(b"POST", LOGIN_URL, params) self.render(request) if i == 5: @@ -76,14 +75,13 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): # than 1min. self.assertTrue(retry_after_ms < 6000) - self.reactor.advance(retry_after_ms / 1000.0) + self.reactor.advance(retry_after_ms / 1000.0 + 1.0) params = { "type": "m.login.password", "identifier": {"type": "m.id.user", "user": "kermit" + str(i)}, "password": "monkey", } - request_data = json.dumps(params) request, channel = self.make_request(b"POST", LOGIN_URL, params) self.render(request) @@ -111,8 +109,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): "identifier": {"type": "m.id.user", "user": "kermit"}, "password": "monkey", } - request_data = json.dumps(params) - request, channel = self.make_request(b"POST", LOGIN_URL, request_data) + request, channel = self.make_request(b"POST", LOGIN_URL, params) self.render(request) if i == 5: @@ -132,7 +129,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): "identifier": {"type": "m.id.user", "user": "kermit"}, "password": "monkey", } - request_data = json.dumps(params) request, channel = self.make_request(b"POST", LOGIN_URL, params) self.render(request) @@ -160,8 +156,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): "identifier": {"type": "m.id.user", "user": "kermit"}, "password": "notamonkey", } - request_data = json.dumps(params) - request, channel = self.make_request(b"POST", LOGIN_URL, request_data) + request, channel = self.make_request(b"POST", LOGIN_URL, params) self.render(request) if i == 5: @@ -174,14 +169,13 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): # than 1min. self.assertTrue(retry_after_ms < 6000) - self.reactor.advance(retry_after_ms / 1000.0) + self.reactor.advance(retry_after_ms / 1000.0 + 1.0) params = { "type": "m.login.password", "identifier": {"type": "m.id.user", "user": "kermit"}, "password": "notamonkey", } - request_data = json.dumps(params) request, channel = self.make_request(b"POST", LOGIN_URL, params) self.render(request) diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index ef6b775ed2..0a567b032f 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py
@@ -28,7 +28,7 @@ from synapse.api.constants import EventContentFields, EventTypes, Membership from synapse.handlers.pagination import PurgeStatus from synapse.rest.client.v1 import directory, login, profile, room from synapse.rest.client.v2_alpha import account -from synapse.types import JsonDict, RoomAlias +from synapse.types import JsonDict, RoomAlias, UserID from synapse.util.stringutils import random_string from tests import unittest @@ -675,6 +675,92 @@ class RoomMemberStateTestCase(RoomBase): self.assertEquals(json.loads(content), channel.json_body) +class RoomJoinRatelimitTestCase(RoomBase): + user_id = "@sid1:red" + + servlets = [ + profile.register_servlets, + room.register_servlets, + ] + + @unittest.override_config( + {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}} + ) + def test_join_local_ratelimit(self): + """Tests that local joins are actually rate-limited.""" + for i in range(3): + self.helper.create_room_as(self.user_id) + + self.helper.create_room_as(self.user_id, expect_code=429) + + @unittest.override_config( + {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}} + ) + def test_join_local_ratelimit_profile_change(self): + """Tests that sending a profile update into all of the user's joined rooms isn't + rate-limited by the rate-limiter on joins.""" + + # Create and join as many rooms as the rate-limiting config allows in a second. + room_ids = [ + self.helper.create_room_as(self.user_id), + self.helper.create_room_as(self.user_id), + self.helper.create_room_as(self.user_id), + ] + # Let some time for the rate-limiter to forget about our multi-join. + self.reactor.advance(2) + # Add one to make sure we're joined to more rooms than the config allows us to + # join in a second. + room_ids.append(self.helper.create_room_as(self.user_id)) + + # Create a profile for the user, since it hasn't been done on registration. + store = self.hs.get_datastore() + self.get_success( + store.create_profile(UserID.from_string(self.user_id).localpart) + ) + + # Update the display name for the user. + path = "/_matrix/client/r0/profile/%s/displayname" % self.user_id + request, channel = self.make_request("PUT", path, {"displayname": "John Doe"}) + self.render(request) + self.assertEquals(channel.code, 200, channel.json_body) + + # Check that all the rooms have been sent a profile update into. + for room_id in room_ids: + path = "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" % ( + room_id, + self.user_id, + ) + + request, channel = self.make_request("GET", path) + self.render(request) + self.assertEquals(channel.code, 200) + + self.assertIn("displayname", channel.json_body) + self.assertEquals(channel.json_body["displayname"], "John Doe") + + @unittest.override_config( + {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}} + ) + def test_join_local_ratelimit_idempotent(self): + """Tests that the room join endpoints remain idempotent despite rate-limiting + on room joins.""" + room_id = self.helper.create_room_as(self.user_id) + + # Let's test both paths to be sure. + paths_to_test = [ + "/_matrix/client/r0/rooms/%s/join", + "/_matrix/client/r0/join/%s", + ] + + for path in paths_to_test: + # Make sure we send more requests than the rate-limiting config would allow + # if all of these requests ended up joining the user to a room. + for i in range(4): + request, channel = self.make_request("POST", path % room_id, {}) + self.render(request) + self.assertEquals(channel.code, 200) + + class RoomMessagesTestCase(RoomBase): """ Tests /rooms/$room_id/messages/$user_id/$msg_id REST events. """ diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 8933b560d2..e66c9a4c4c 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py
@@ -39,7 +39,9 @@ class RestHelper(object): resource = attr.ib() auth_user_id = attr.ib() - def create_room_as(self, room_creator=None, is_public=True, tok=None): + def create_room_as( + self, room_creator=None, is_public=True, tok=None, expect_code=200, + ): temp_id = self.auth_user_id self.auth_user_id = room_creator path = "/_matrix/client/r0/createRoom" @@ -54,9 +56,11 @@ class RestHelper(object): ) render(request, self.resource, self.hs.get_reactor()) - assert channel.result["code"] == b"200", channel.result + assert channel.result["code"] == b"%d" % expect_code, channel.result self.auth_user_id = temp_id - return channel.json_body["room_id"] + + if expect_code == 200: + return channel.json_body["room_id"] def invite(self, room=None, src=None, targ=None, expect_code=200, tok=None): self.change_membership( diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 53a43038f0..2fc3a60fc5 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py
@@ -160,7 +160,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): else: self.assertEquals(channel.result["code"], b"200", channel.result) - self.reactor.advance(retry_after_ms / 1000.0) + self.reactor.advance(retry_after_ms / 1000.0 + 1.0) request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") self.render(request) @@ -186,7 +186,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): else: self.assertEquals(channel.result["code"], b"200", channel.result) - self.reactor.advance(retry_after_ms / 1000.0) + self.reactor.advance(retry_after_ms / 1000.0 + 1.0) request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") self.render(request) diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index 2858d13558..23db821fb7 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -104,7 +104,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} ) self._rlsn._store.get_events = Mock( - return_value=defer.succeed({"123": mock_event}) + return_value=make_awaitable({"123": mock_event}) ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) # Would be better to check the content, but once == remove blocking event @@ -122,7 +122,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} ) self._rlsn._store.get_events = Mock( - return_value=defer.succeed({"123": mock_event}) + return_value=make_awaitable({"123": mock_event}) ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) @@ -217,7 +217,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} ) self._rlsn._store.get_events = Mock( - return_value=defer.succeed({"123": mock_event}) + return_value=make_awaitable({"123": mock_event}) ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index a425e66f37..17fbde284a 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py
@@ -31,6 +31,7 @@ from synapse.storage.databases.main.appservice import ( ) from tests import unittest +from tests.test_utils import make_awaitable from tests.utils import setup_test_homeserver @@ -357,7 +358,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): other_events = [Mock(event_id="e5"), Mock(event_id="e6")] # we aren't testing store._base stuff here, so mock this out - self.store.get_events_as_list = Mock(return_value=defer.succeed(events)) + self.store.get_events_as_list = Mock(return_value=make_awaitable(events)) yield self._insert_txn(self.as_list[1]["id"], 9, other_events) yield self._insert_txn(service.id, 10, events) diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index 2efbc97c2e..1a1c59256c 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py
@@ -67,13 +67,12 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): # second step: complete the update # we should now get run with a much bigger number of items to update - @defer.inlineCallbacks - def update(progress, count): + async def update(progress, count): self.assertEqual(progress, {"my_key": 2}) self.assertAlmostEqual( count, target_background_update_duration_ms / duration_ms, places=0, ) - yield self.updates._end_background_update("test_update") + await self.updates._end_background_update("test_update") return count self.update_handler.side_effect = update diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index 13bcac743a..40ba652248 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py
@@ -97,8 +97,10 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.rowcount = 1 self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)])) - value = yield self.datastore.db_pool.simple_select_one_onecol( - table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol" + value = yield defer.ensureDeferred( + self.datastore.db_pool.simple_select_one_onecol( + table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol" + ) ) self.assertEquals("Value", value) @@ -111,10 +113,12 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.rowcount = 1 self.mock_txn.fetchone.return_value = (1, 2, 3) - ret = yield self.datastore.db_pool.simple_select_one( - table="tablename", - keyvalues={"keycol": "TheKey"}, - retcols=["colA", "colB", "colC"], + ret = yield defer.ensureDeferred( + self.datastore.db_pool.simple_select_one( + table="tablename", + keyvalues={"keycol": "TheKey"}, + retcols=["colA", "colB", "colC"], + ) ) self.assertEquals({"colA": 1, "colB": 2, "colC": 3}, ret) @@ -127,11 +131,13 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.rowcount = 0 self.mock_txn.fetchone.return_value = None - ret = yield self.datastore.db_pool.simple_select_one( - table="tablename", - keyvalues={"keycol": "Not here"}, - retcols=["colA"], - allow_none=True, + ret = yield defer.ensureDeferred( + self.datastore.db_pool.simple_select_one( + table="tablename", + keyvalues={"keycol": "Not here"}, + retcols=["colA"], + allow_none=True, + ) ) self.assertFalse(ret) @@ -142,8 +148,10 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)])) self.mock_txn.description = (("colA", None, None, None, None, None, None),) - ret = yield self.datastore.db_pool.simple_select_list( - table="tablename", keyvalues={"keycol": "A set"}, retcols=["colA"] + ret = yield defer.ensureDeferred( + self.datastore.db_pool.simple_select_list( + table="tablename", keyvalues={"keycol": "A set"}, retcols=["colA"] + ) ) self.assertEquals([{"colA": 1}, {"colA": 2}, {"colA": 3}], ret) @@ -155,10 +163,12 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_update_one_1col(self): self.mock_txn.rowcount = 1 - yield self.datastore.db_pool.simple_update_one( - table="tablename", - keyvalues={"keycol": "TheKey"}, - updatevalues={"columnname": "New Value"}, + yield defer.ensureDeferred( + self.datastore.db_pool.simple_update_one( + table="tablename", + keyvalues={"keycol": "TheKey"}, + updatevalues={"columnname": "New Value"}, + ) ) self.mock_txn.execute.assert_called_with( @@ -170,10 +180,12 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_update_one_4cols(self): self.mock_txn.rowcount = 1 - yield self.datastore.db_pool.simple_update_one( - table="tablename", - keyvalues=OrderedDict([("colA", 1), ("colB", 2)]), - updatevalues=OrderedDict([("colC", 3), ("colD", 4)]), + yield defer.ensureDeferred( + self.datastore.db_pool.simple_update_one( + table="tablename", + keyvalues=OrderedDict([("colA", 1), ("colB", 2)]), + updatevalues=OrderedDict([("colC", 3), ("colD", 4)]), + ) ) self.mock_txn.execute.assert_called_with( @@ -185,8 +197,10 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_delete_one(self): self.mock_txn.rowcount = 1 - yield self.datastore.db_pool.simple_delete_one( - table="tablename", keyvalues={"keycol": "Go away"} + yield defer.ensureDeferred( + self.datastore.db_pool.simple_delete_one( + table="tablename", keyvalues={"keycol": "Go away"} + ) ) self.mock_txn.execute.assert_called_with( diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index 8e9a650f9f..080761d1d2 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py
@@ -271,7 +271,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase): # Pump the reactor repeatedly so that the background updates have a # chance to run. - self.pump(10 * 60) + self.pump(20) latest_event_ids = self.get_success( self.store.get_latest_event_ids_in_room(self.room_id) @@ -353,6 +353,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase): self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion[ "3" ] = 300000 + self.event_creator_handler._expire_rooms_to_exclude_from_dummy_event_insertion() # All entries within time frame self.assertEqual( @@ -362,7 +363,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase): 3, ) # Oldest room to expire - self.pump(1) + self.pump(1.01) self.event_creator_handler._expire_rooms_to_exclude_from_dummy_event_insertion() self.assertEqual( len( diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index 87ed8f8cd1..34ae8c9da7 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py
@@ -38,7 +38,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase): self.store.store_device("user_id", "device_id", "display_name") ) - res = yield self.store.get_device("user_id", "device_id") + res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id")) self.assertDictContainsSubset( { "user_id": "user_id", @@ -111,12 +111,12 @@ class DeviceStoreTestCase(tests.unittest.TestCase): self.store.store_device("user_id", "device_id", "display_name 1") ) - res = yield self.store.get_device("user_id", "device_id") + res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id")) self.assertEqual("display_name 1", res["display_name"]) # do a no-op first yield defer.ensureDeferred(self.store.update_device("user_id", "device_id")) - res = yield self.store.get_device("user_id", "device_id") + res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id")) self.assertEqual("display_name 1", res["display_name"]) # do the update @@ -127,7 +127,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase): ) # check it worked - res = yield self.store.get_device("user_id", "device_id") + res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id")) self.assertEqual("display_name 2", res["display_name"]) @defer.inlineCallbacks diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py
index daac947cb2..da93ca3980 100644 --- a/tests/storage/test_directory.py +++ b/tests/storage/test_directory.py
@@ -42,7 +42,11 @@ class DirectoryStoreTestCase(unittest.TestCase): self.assertEquals( ["#my-room:test"], - (yield self.store.get_aliases_for_room(self.room.to_string())), + ( + yield defer.ensureDeferred( + self.store.get_aliases_for_room(self.room.to_string()) + ) + ), ) @defer.inlineCallbacks diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index e845410dae..14ce21c786 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py
@@ -58,6 +58,10 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): return self.get_success(self.db_pool.runWithConnection(_create)) def _insert_rows(self, instance_name: str, number: int): + """Insert N rows as the given instance, inserting with stream IDs pulled + from the postgres sequence. + """ + def _insert(txn): for _ in range(number): txn.execute( @@ -65,7 +69,20 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): (instance_name,), ) - self.get_success(self.db_pool.runInteraction("test_single_instance", _insert)) + self.get_success(self.db_pool.runInteraction("_insert_rows", _insert)) + + def _insert_row_with_id(self, instance_name: str, stream_id: int): + """Insert one row as the given instance with given stream_id, updating + the postgres sequence position to match. + """ + + def _insert(txn): + txn.execute( + "INSERT INTO foobar VALUES (?, ?)", (stream_id, instance_name,), + ) + txn.execute("SELECT setval('foobar_seq', ?)", (stream_id,)) + + self.get_success(self.db_pool.runInteraction("_insert_row_with_id", _insert)) def test_empty(self): """Test an ID generator against an empty database gives sensible @@ -88,7 +105,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): id_gen = self._create_id_generator() self.assertEqual(id_gen.get_positions(), {"master": 7}) - self.assertEqual(id_gen.get_current_token("master"), 7) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) # Try allocating a new ID gen and check that we only see position # advanced after we leave the context manager. @@ -98,12 +115,12 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self.assertEqual(stream_id, 8) self.assertEqual(id_gen.get_positions(), {"master": 7}) - self.assertEqual(id_gen.get_current_token("master"), 7) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) self.get_success(_get_next_async()) self.assertEqual(id_gen.get_positions(), {"master": 8}) - self.assertEqual(id_gen.get_current_token("master"), 8) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 8) def test_multi_instance(self): """Test that reads and writes from multiple processes are handled @@ -116,8 +133,8 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): second_id_gen = self._create_id_generator("second") self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7}) - self.assertEqual(first_id_gen.get_current_token("first"), 3) - self.assertEqual(first_id_gen.get_current_token("second"), 7) + self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3) + self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7) # Try allocating a new ID gen and check that we only see position # advanced after we leave the context manager. @@ -166,7 +183,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): id_gen = self._create_id_generator() self.assertEqual(id_gen.get_positions(), {"master": 7}) - self.assertEqual(id_gen.get_current_token("master"), 7) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) # Try allocating a new ID gen and check that we only see position # advanced after we leave the context manager. @@ -176,9 +193,74 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self.assertEqual(stream_id, 8) self.assertEqual(id_gen.get_positions(), {"master": 7}) - self.assertEqual(id_gen.get_current_token("master"), 7) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) self.get_success(self.db_pool.runInteraction("test", _get_next_txn)) self.assertEqual(id_gen.get_positions(), {"master": 8}) - self.assertEqual(id_gen.get_current_token("master"), 8) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 8) + + def test_get_persisted_upto_position(self): + """Test that `get_persisted_upto_position` correctly tracks updates to + positions. + """ + + # The following tests are a bit cheeky in that we notify about new + # positions via `advance` without *actually* advancing the postgres + # sequence. + + self._insert_row_with_id("first", 3) + self._insert_row_with_id("second", 5) + + id_gen = self._create_id_generator("first") + + self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5}) + + # Min is 3 and there is a gap between 5, so we expect it to be 3. + self.assertEqual(id_gen.get_persisted_upto_position(), 3) + + # We advance "first" straight to 6. Min is now 5 but there is no gap so + # we expect it to be 6 + id_gen.advance("first", 6) + self.assertEqual(id_gen.get_persisted_upto_position(), 6) + + # No gap, so we expect 7. + id_gen.advance("second", 7) + self.assertEqual(id_gen.get_persisted_upto_position(), 7) + + # We haven't seen 8 yet, so we expect 7 still. + id_gen.advance("second", 9) + self.assertEqual(id_gen.get_persisted_upto_position(), 7) + + # Now that we've seen 7, 8 and 9 we can got straight to 9. + id_gen.advance("first", 8) + self.assertEqual(id_gen.get_persisted_upto_position(), 9) + + # Jump forward with gaps. The minimum is 11, even though we haven't seen + # 10 we know that everything before 11 must be persisted. + id_gen.advance("first", 11) + id_gen.advance("second", 15) + self.assertEqual(id_gen.get_persisted_upto_position(), 11) + + def test_get_persisted_upto_position_get_next(self): + """Test that `get_persisted_upto_position` correctly tracks updates to + positions when `get_next` is called. + """ + + self._insert_row_with_id("first", 3) + self._insert_row_with_id("second", 5) + + id_gen = self._create_id_generator("first") + + self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5}) + + self.assertEqual(id_gen.get_persisted_upto_position(), 3) + with self.get_success(id_gen.get_next()) as stream_id: + self.assertEqual(stream_id, 6) + self.assertEqual(id_gen.get_persisted_upto_position(), 3) + + self.assertEqual(id_gen.get_persisted_upto_position(), 6) + + # We assume that so long as `get_next` does correctly advance the + # `persisted_upto_position` in this case, then it will be correct in the + # other cases that are tested above (since they'll hit the same code). diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py
index fbf8af940a..954338a592 100644 --- a/tests/storage/test_main.py +++ b/tests/storage/test_main.py
@@ -36,7 +36,9 @@ class DataStoreTestCase(unittest.TestCase): def test_get_users_paginate(self): yield self.store.register_user(self.user.to_string(), "pass") yield defer.ensureDeferred(self.store.create_profile(self.user.localpart)) - yield self.store.set_profile_displayname(self.user.localpart, self.displayname) + yield defer.ensureDeferred( + self.store.set_profile_displayname(self.user.localpart, self.displayname) + ) users, total = yield self.store.get_users_paginate( 0, 10, name="bc", guests=False diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index 9d5b8aa47d..3fd0a38cf5 100644 --- a/tests/storage/test_profile.py +++ b/tests/storage/test_profile.py
@@ -35,21 +35,34 @@ class ProfileStoreTestCase(unittest.TestCase): def test_displayname(self): yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart)) - yield self.store.set_profile_displayname(self.u_frank.localpart, "Frank") + yield defer.ensureDeferred( + self.store.set_profile_displayname(self.u_frank.localpart, "Frank") + ) self.assertEquals( - "Frank", (yield self.store.get_profile_displayname(self.u_frank.localpart)) + "Frank", + ( + yield defer.ensureDeferred( + self.store.get_profile_displayname(self.u_frank.localpart) + ) + ), ) @defer.inlineCallbacks def test_avatar_url(self): yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart)) - yield self.store.set_profile_avatar_url( - self.u_frank.localpart, "http://my.site/here" + yield defer.ensureDeferred( + self.store.set_profile_avatar_url( + self.u_frank.localpart, "http://my.site/here" + ) ) self.assertEquals( "http://my.site/here", - (yield self.store.get_profile_avatar_url(self.u_frank.localpart)), + ( + yield defer.ensureDeferred( + self.store.get_profile_avatar_url(self.u_frank.localpart) + ) + ), ) diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 840db66072..70c55cd650 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py
@@ -17,6 +17,7 @@ from twisted.internet import defer from synapse.api.constants import UserTypes +from synapse.api.errors import ThreepidValidationError from tests import unittest from tests.utils import setup_test_homeserver @@ -52,7 +53,7 @@ class RegistrationStoreTestCase(unittest.TestCase): "user_type": None, "deactivated": 0, }, - (yield self.store.get_user_by_id(self.user_id)), + (yield defer.ensureDeferred(self.store.get_user_by_id(self.user_id))), ) @defer.inlineCallbacks @@ -122,3 +123,33 @@ class RegistrationStoreTestCase(unittest.TestCase): ) res = yield self.store.is_support_user(SUPPORT_USER) self.assertTrue(res) + + @defer.inlineCallbacks + def test_3pid_inhibit_invalid_validation_session_error(self): + """Tests that enabling the configuration option to inhibit 3PID errors on + /requestToken also inhibits validation errors caused by an unknown session ID. + """ + + # Check that, with the config setting set to false (the default value), a + # validation error is caused by the unknown session ID. + try: + yield defer.ensureDeferred( + self.store.validate_threepid_session( + "fake_sid", "fake_client_secret", "fake_token", 0, + ) + ) + except ThreepidValidationError as e: + self.assertEquals(e.msg, "Unknown session_id", e) + + # Set the config setting to true. + self.store._ignore_unknown_session_error = True + + # Check that now the validation error is caused by the token not matching. + try: + yield defer.ensureDeferred( + self.store.validate_threepid_session( + "fake_sid", "fake_client_secret", "fake_token", 0, + ) + ) + except ThreepidValidationError as e: + self.assertEquals(e.msg, "Validation token not found or has expired", e) diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index d07b985a8e..bc8400f240 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py
@@ -54,12 +54,14 @@ class RoomStoreTestCase(unittest.TestCase): "creator": self.u_creator.to_string(), "is_public": True, }, - (yield self.store.get_room(self.room.to_string())), + (yield defer.ensureDeferred(self.store.get_room(self.room.to_string()))), ) @defer.inlineCallbacks def test_get_room_unknown_room(self): - self.assertIsNone((yield self.store.get_room("!uknown:test")),) + self.assertIsNone( + (yield defer.ensureDeferred(self.store.get_room("!uknown:test"))) + ) @defer.inlineCallbacks def test_get_room_with_stats(self): @@ -69,12 +71,22 @@ class RoomStoreTestCase(unittest.TestCase): "creator": self.u_creator.to_string(), "public": True, }, - (yield self.store.get_room_with_stats(self.room.to_string())), + ( + yield defer.ensureDeferred( + self.store.get_room_with_stats(self.room.to_string()) + ) + ), ) @defer.inlineCallbacks def test_get_room_with_stats_unknown_room(self): - self.assertIsNone((yield self.store.get_room_with_stats("!uknown:test")),) + self.assertIsNone( + ( + yield defer.ensureDeferred( + self.store.get_room_with_stats("!uknown:test") + ) + ), + ) class RoomEventsStoreTestCase(unittest.TestCase): diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index d98fe8754d..12ccc1f53e 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py
@@ -87,7 +87,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): self.inject_room_member(self.room, self.u_bob, Membership.JOIN) self.inject_room_member(self.room, self.u_charlie.to_string(), Membership.JOIN) - self.pump(20) + self.pump() self.assertTrue("_known_servers_count" not in self.store.__dict__.keys()) @@ -101,7 +101,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): # Initialises to 1 -- itself self.assertEqual(self.store._known_servers_count, 1) - self.pump(20) + self.pump() # No rooms have been joined, so technically the SQL returns 0, but it # will still say it knows about itself. @@ -111,7 +111,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): self.inject_room_member(self.room, self.u_bob, Membership.JOIN) self.inject_room_member(self.room, self.u_charlie.to_string(), Membership.JOIN) - self.pump(20) + self.pump(1) # It now knows about Charlie's server. self.assertEqual(self.store._known_servers_count, 2) diff --git a/tests/test_federation.py b/tests/test_federation.py
index 4a4548433f..27a7fc9ed7 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py
@@ -15,8 +15,9 @@ from mock import Mock -from twisted.internet.defer import ensureDeferred, maybeDeferred, succeed +from twisted.internet.defer import succeed +from synapse.api.errors import FederationError from synapse.events import make_event_from_dict from synapse.logging.context import LoggingContext from synapse.types import Requester, UserID @@ -44,22 +45,17 @@ class MessageAcceptTests(unittest.HomeserverTestCase): user_id = UserID("us", "test") our_user = Requester(user_id, None, False, False, None, None) room_creator = self.homeserver.get_room_creation_handler() - room_deferred = ensureDeferred( + self.room_id = self.get_success( room_creator.create_room( our_user, room_creator._presets_dict["public_chat"], ratelimit=False ) - ) - self.reactor.advance(0.1) - self.room_id = self.successResultOf(room_deferred)[0]["room_id"] + )[0]["room_id"] self.store = self.homeserver.get_datastore() # Figure out what the most recent event is - most_recent = self.successResultOf( - maybeDeferred( - self.homeserver.get_datastore().get_latest_event_ids_in_room, - self.room_id, - ) + most_recent = self.get_success( + self.homeserver.get_datastore().get_latest_event_ids_in_room(self.room_id) )[0] join_event = make_event_from_dict( @@ -89,19 +85,18 @@ class MessageAcceptTests(unittest.HomeserverTestCase): ) # Send the join, it should return None (which is not an error) - d = ensureDeferred( - self.handler.on_receive_pdu( - "test.serv", join_event, sent_to_us_directly=True - ) + self.assertEqual( + self.get_success( + self.handler.on_receive_pdu( + "test.serv", join_event, sent_to_us_directly=True + ) + ), + None, ) - self.reactor.advance(1) - self.assertEqual(self.successResultOf(d), None) # Make sure we actually joined the room self.assertEqual( - self.successResultOf( - maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id) - )[0], + self.get_success(self.store.get_latest_event_ids_in_room(self.room_id))[0], "$join:test.serv", ) @@ -119,8 +114,8 @@ class MessageAcceptTests(unittest.HomeserverTestCase): self.http_client.post_json = post_json # Figure out what the most recent event is - most_recent = self.successResultOf( - maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id) + most_recent = self.get_success( + self.store.get_latest_event_ids_in_room(self.room_id) )[0] # Now lie about an event @@ -140,17 +135,14 @@ class MessageAcceptTests(unittest.HomeserverTestCase): ) with LoggingContext(request="lying_event"): - d = ensureDeferred( + failure = self.get_failure( self.handler.on_receive_pdu( "test.serv", lying_event, sent_to_us_directly=True - ) + ), + FederationError, ) - # Step the reactor, so the database fetches come back - self.reactor.advance(1) - # on_receive_pdu should throw an error - failure = self.failureResultOf(d) self.assertEqual( failure.value.args[0], ( @@ -160,8 +152,8 @@ class MessageAcceptTests(unittest.HomeserverTestCase): ) # Make sure the invalid event isn't there - extrem = maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id) - self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv") + extrem = self.get_success(self.store.get_latest_event_ids_in_room(self.room_id)) + self.assertEqual(extrem[0], "$join:test.serv") def test_retry_device_list_resync(self): """Tests that device lists are marked as stale if they couldn't be synced, and diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 4d2b9e0d64..0363735d4f 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py
@@ -366,11 +366,11 @@ class CachedListDescriptorTestCase(unittest.TestCase): def fn(self, arg1, arg2): pass - @descriptors.cachedList("fn", "args1", inlineCallbacks=True) - def list_fn(self, args1, arg2): + @descriptors.cachedList("fn", "args1") + async def list_fn(self, args1, arg2): assert current_context().request == "c1" # we want this to behave like an asynchronous function - yield run_on_reactor() + await run_on_reactor() assert current_context().request == "c1" return self.mock(args1, arg2) @@ -416,10 +416,10 @@ class CachedListDescriptorTestCase(unittest.TestCase): def fn(self, arg1, arg2): pass - @descriptors.cachedList("fn", "args1", inlineCallbacks=True) - def list_fn(self, args1, arg2): + @descriptors.cachedList("fn", "args1") + async def list_fn(self, args1, arg2): # we want this to behave like an asynchronous function - yield run_on_reactor() + await run_on_reactor() return self.mock(args1, arg2) obj = Cls() diff --git a/tests/util/test_retryutils.py b/tests/util/test_retryutils.py
index bc42ffce88..5f46ed0cef 100644 --- a/tests/util/test_retryutils.py +++ b/tests/util/test_retryutils.py
@@ -91,7 +91,7 @@ class RetryLimiterTestCase(HomeserverTestCase): # # one more go, with success # - self.pump(MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 2.0) + self.reactor.advance(MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 2.0) limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store)) self.pump(1)