diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py
index d580e729c5..1e1f30d790 100644
--- a/tests/api/test_ratelimiting.py
+++ b/tests/api/test_ratelimiting.py
@@ -1,4 +1,6 @@
from synapse.api.ratelimiting import LimitExceededError, Ratelimiter
+from synapse.appservice import ApplicationService
+from synapse.types import create_requester
from tests import unittest
@@ -18,6 +20,77 @@ class TestRatelimiter(unittest.TestCase):
self.assertTrue(allowed)
self.assertEquals(20.0, time_allowed)
+ def test_allowed_user_via_can_requester_do_action(self):
+ user_requester = create_requester("@user:example.com")
+ limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
+ allowed, time_allowed = limiter.can_requester_do_action(
+ user_requester, _time_now_s=0
+ )
+ self.assertTrue(allowed)
+ self.assertEquals(10.0, time_allowed)
+
+ allowed, time_allowed = limiter.can_requester_do_action(
+ user_requester, _time_now_s=5
+ )
+ self.assertFalse(allowed)
+ self.assertEquals(10.0, time_allowed)
+
+ allowed, time_allowed = limiter.can_requester_do_action(
+ user_requester, _time_now_s=10
+ )
+ self.assertTrue(allowed)
+ self.assertEquals(20.0, time_allowed)
+
+ def test_allowed_appservice_ratelimited_via_can_requester_do_action(self):
+ appservice = ApplicationService(
+ None, "example.com", id="foo", rate_limited=True,
+ )
+ as_requester = create_requester("@user:example.com", app_service=appservice)
+
+ limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
+ allowed, time_allowed = limiter.can_requester_do_action(
+ as_requester, _time_now_s=0
+ )
+ self.assertTrue(allowed)
+ self.assertEquals(10.0, time_allowed)
+
+ allowed, time_allowed = limiter.can_requester_do_action(
+ as_requester, _time_now_s=5
+ )
+ self.assertFalse(allowed)
+ self.assertEquals(10.0, time_allowed)
+
+ allowed, time_allowed = limiter.can_requester_do_action(
+ as_requester, _time_now_s=10
+ )
+ self.assertTrue(allowed)
+ self.assertEquals(20.0, time_allowed)
+
+ def test_allowed_appservice_via_can_requester_do_action(self):
+ appservice = ApplicationService(
+ None, "example.com", id="foo", rate_limited=False,
+ )
+ as_requester = create_requester("@user:example.com", app_service=appservice)
+
+ limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
+ allowed, time_allowed = limiter.can_requester_do_action(
+ as_requester, _time_now_s=0
+ )
+ self.assertTrue(allowed)
+ self.assertEquals(-1, time_allowed)
+
+ allowed, time_allowed = limiter.can_requester_do_action(
+ as_requester, _time_now_s=5
+ )
+ self.assertTrue(allowed)
+ self.assertEquals(-1, time_allowed)
+
+ allowed, time_allowed = limiter.can_requester_do_action(
+ as_requester, _time_now_s=10
+ )
+ self.assertTrue(allowed)
+ self.assertEquals(-1, time_allowed)
+
def test_allowed_via_ratelimit(self):
limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index 1bb25ab684..f92f3b8c15 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -374,12 +374,16 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
self.handler._auth_handler.complete_sso_login = simple_async_mock()
- request = Mock(spec=["args", "getCookie", "addCookie"])
+ request = Mock(
+ spec=["args", "getCookie", "addCookie", "requestHeaders", "getClientIP"]
+ )
code = "code"
state = "state"
nonce = "nonce"
client_redirect_url = "http://client/redirect"
+ user_agent = "Browser"
+ ip_address = "10.0.0.1"
session = self.handler._generate_oidc_session_token(
state=state,
nonce=nonce,
@@ -392,6 +396,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
request.args[b"code"] = [code.encode("utf-8")]
request.args[b"state"] = [state.encode("utf-8")]
+ request.requestHeaders = Mock(spec=["getRawHeaders"])
+ request.requestHeaders.getRawHeaders.return_value = [user_agent.encode("ascii")]
+ request.getClientIP.return_value = ip_address
+
yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
self.handler._auth_handler.complete_sso_login.assert_called_once_with(
@@ -399,7 +407,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
self.handler._exchange_code.assert_called_once_with(code)
self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce)
- self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token)
+ self.handler._map_userinfo_to_user.assert_called_once_with(
+ userinfo, token, user_agent, ip_address
+ )
self.handler._fetch_userinfo.assert_not_called()
self.handler._render_error.assert_not_called()
@@ -431,7 +441,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
self.handler._exchange_code.assert_called_once_with(code)
self.handler._parse_id_token.assert_not_called()
- self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token)
+ self.handler._map_userinfo_to_user.assert_called_once_with(
+ userinfo, token, user_agent, ip_address
+ )
self.handler._fetch_userinfo.assert_called_once_with(token)
self.handler._render_error.assert_not_called()
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index b609b30d4a..60ebc95f3e 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -71,7 +71,9 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_get_my_name(self):
- yield self.store.set_profile_displayname(self.frank.localpart, "Frank")
+ yield defer.ensureDeferred(
+ self.store.set_profile_displayname(self.frank.localpart, "Frank")
+ )
displayname = yield defer.ensureDeferred(
self.handler.get_displayname(self.frank)
@@ -104,7 +106,12 @@ class ProfileTestCase(unittest.TestCase):
)
self.assertEquals(
- (yield self.store.get_profile_displayname(self.frank.localpart)), "Frank",
+ (
+ yield defer.ensureDeferred(
+ self.store.get_profile_displayname(self.frank.localpart)
+ )
+ ),
+ "Frank",
)
@defer.inlineCallbacks
@@ -112,10 +119,17 @@ class ProfileTestCase(unittest.TestCase):
self.hs.config.enable_set_displayname = False
# Setting displayname for the first time is allowed
- yield self.store.set_profile_displayname(self.frank.localpart, "Frank")
+ yield defer.ensureDeferred(
+ self.store.set_profile_displayname(self.frank.localpart, "Frank")
+ )
self.assertEquals(
- (yield self.store.get_profile_displayname(self.frank.localpart)), "Frank",
+ (
+ yield defer.ensureDeferred(
+ self.store.get_profile_displayname(self.frank.localpart)
+ )
+ ),
+ "Frank",
)
# Setting displayname a second time is forbidden
@@ -158,7 +172,9 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_incoming_fed_query(self):
yield defer.ensureDeferred(self.store.create_profile("caroline"))
- yield self.store.set_profile_displayname("caroline", "Caroline")
+ yield defer.ensureDeferred(
+ self.store.set_profile_displayname("caroline", "Caroline")
+ )
response = yield defer.ensureDeferred(
self.query_handlers["profile"](
@@ -170,8 +186,10 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_get_my_avatar(self):
- yield self.store.set_profile_avatar_url(
- self.frank.localpart, "http://my.server/me.png"
+ yield defer.ensureDeferred(
+ self.store.set_profile_avatar_url(
+ self.frank.localpart, "http://my.server/me.png"
+ )
)
avatar_url = yield defer.ensureDeferred(self.handler.get_avatar_url(self.frank))
@@ -188,7 +206,11 @@ class ProfileTestCase(unittest.TestCase):
)
self.assertEquals(
- (yield self.store.get_profile_avatar_url(self.frank.localpart)),
+ (
+ yield defer.ensureDeferred(
+ self.store.get_profile_avatar_url(self.frank.localpart)
+ )
+ ),
"http://my.server/pic.gif",
)
@@ -202,7 +224,11 @@ class ProfileTestCase(unittest.TestCase):
)
self.assertEquals(
- (yield self.store.get_profile_avatar_url(self.frank.localpart)),
+ (
+ yield defer.ensureDeferred(
+ self.store.get_profile_avatar_url(self.frank.localpart)
+ )
+ ),
"http://my.server/me.png",
)
@@ -211,12 +237,18 @@ class ProfileTestCase(unittest.TestCase):
self.hs.config.enable_set_avatar_url = False
# Setting displayname for the first time is allowed
- yield self.store.set_profile_avatar_url(
- self.frank.localpart, "http://my.server/me.png"
+ yield defer.ensureDeferred(
+ self.store.set_profile_avatar_url(
+ self.frank.localpart, "http://my.server/me.png"
+ )
)
self.assertEquals(
- (yield self.store.get_profile_avatar_url(self.frank.localpart)),
+ (
+ yield defer.ensureDeferred(
+ self.store.get_profile_avatar_url(self.frank.localpart)
+ )
+ ),
"http://my.server/me.png",
)
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index e364b1bd62..5c92d0e8c9 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -17,18 +17,21 @@ from mock import Mock
from twisted.internet import defer
+from synapse.api.auth import Auth
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, ResourceLimitError, SynapseError
from synapse.handlers.register import RegistrationHandler
+from synapse.spam_checker_api import RegistrationBehaviour
from synapse.types import RoomAlias, UserID, create_requester
from tests.test_utils import make_awaitable
from tests.unittest import override_config
+from tests.utils import mock_getRawHeaders
from .. import unittest
-class RegistrationHandlers(object):
+class RegistrationHandlers:
def __init__(self, hs):
self.registration_handler = RegistrationHandler(hs)
@@ -475,6 +478,53 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.handler.register_user(localpart=invalid_user_id), SynapseError
)
+ def test_spam_checker_deny(self):
+ """A spam checker can deny registration, which results in an error."""
+
+ class DenyAll:
+ def check_registration_for_spam(
+ self, email_threepid, username, request_info
+ ):
+ return RegistrationBehaviour.DENY
+
+ # Configure a spam checker that denies all users.
+ spam_checker = self.hs.get_spam_checker()
+ spam_checker.spam_checkers = [DenyAll()]
+
+ self.get_failure(self.handler.register_user(localpart="user"), SynapseError)
+
+ def test_spam_checker_shadow_ban(self):
+ """A spam checker can choose to shadow-ban a user, which allows registration to succeed."""
+
+ class BanAll:
+ def check_registration_for_spam(
+ self, email_threepid, username, request_info
+ ):
+ return RegistrationBehaviour.SHADOW_BAN
+
+ # Configure a spam checker that denies all users.
+ spam_checker = self.hs.get_spam_checker()
+ spam_checker.spam_checkers = [BanAll()]
+
+ user_id = self.get_success(self.handler.register_user(localpart="user"))
+
+ # Get an access token.
+ token = self.macaroon_generator.generate_access_token(user_id)
+ self.get_success(
+ self.store.add_access_token_to_user(
+ user_id=user_id, token=token, device_id=None, valid_until_ms=None
+ )
+ )
+
+ # Ensure the user was marked as shadow-banned.
+ request = Mock(args={})
+ request.args[b"access_token"] = [token.encode("ascii")]
+ request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+ auth = Auth(self.hs)
+ requester = self.get_success(auth.get_user_by_req(request))
+
+ self.assertTrue(requester.shadow_banned)
+
async def get_or_create_user(
self, requester, localpart, displayname, password_hash=None
):
diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py
index 0e666492f6..a609f148c0 100644
--- a/tests/handlers/test_stats.py
+++ b/tests/handlers/test_stats.py
@@ -81,8 +81,8 @@ class StatsRoomTests(unittest.HomeserverTestCase):
)
)
- def get_all_room_state(self):
- return self.store.db_pool.simple_select_list(
+ async def get_all_room_state(self):
+ return await self.store.db_pool.simple_select_list(
"room_stats_state", None, retcols=("name", "topic", "canonical_alias")
)
@@ -256,7 +256,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
# self.handler.notify_new_event()
# We need to let the delta processor advance…
- self.pump(10 * 60)
+ self.reactor.advance(10 * 60)
# Get the slices! There should be two -- day 1, and day 2.
r = self.get_success(self.store.get_statistics_for_subject("room", room_1, 0))
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index e01de158e5..81c1839637 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -21,7 +21,7 @@ from mock import ANY, Mock, call
from twisted.internet import defer
from synapse.api.errors import AuthError
-from synapse.types import UserID
+from synapse.types import UserID, create_requester
from tests import unittest
from tests.test_utils import make_awaitable
@@ -144,9 +144,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.datastore.get_users_in_room = get_users_in_room
- self.datastore.get_user_directory_stream_pos.return_value = (
+ self.datastore.get_user_directory_stream_pos.side_effect = (
# we deliberately return a non-None stream pos to avoid doing an initial_spam
- defer.succeed(1)
+ lambda: make_awaitable(1)
)
self.datastore.get_current_state_deltas.return_value = (0, None)
@@ -167,7 +167,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.get_success(
self.handler.started_typing(
- target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=20000
+ target_user=U_APPLE,
+ requester=create_requester(U_APPLE),
+ room_id=ROOM_ID,
+ timeout=20000,
)
)
@@ -194,7 +197,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.get_success(
self.handler.started_typing(
- target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=20000
+ target_user=U_APPLE,
+ requester=create_requester(U_APPLE),
+ room_id=ROOM_ID,
+ timeout=20000,
)
)
@@ -269,7 +275,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.get_success(
self.handler.stopped_typing(
- target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID
+ target_user=U_APPLE,
+ requester=create_requester(U_APPLE),
+ room_id=ROOM_ID,
)
)
@@ -309,7 +317,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.get_success(
self.handler.started_typing(
- target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=10000
+ target_user=U_APPLE,
+ requester=create_requester(U_APPLE),
+ room_id=ROOM_ID,
+ timeout=10000,
)
)
@@ -348,7 +359,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.get_success(
self.handler.started_typing(
- target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=10000
+ target_user=U_APPLE,
+ requester=create_requester(U_APPLE),
+ room_id=ROOM_ID,
+ timeout=10000,
)
)
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 31ed89a5cd..87be94111f 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -238,7 +238,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
def test_spam_checker(self):
"""
- A user which fails to the spam checks will not appear in search results.
+ A user which fails the spam checks will not appear in search results.
"""
u1 = self.register_user("user1", "pass")
u1_token = self.login(u1, "pass")
@@ -269,7 +269,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
# Configure a spam checker that does not filter any users.
spam_checker = self.hs.get_spam_checker()
- class AllowAll(object):
+ class AllowAll:
def check_username_for_spam(self, user_profile):
# Allow all users.
return False
@@ -282,7 +282,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(s["results"]), 1)
# Configure a spam checker that filters all users.
- class BlockAll(object):
+ class BlockAll:
def check_username_for_spam(self, user_profile):
# All users are spammy.
return True
diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py
index ac598249e4..7f1dfb2128 100644
--- a/tests/http/test_fedclient.py
+++ b/tests/http/test_fedclient.py
@@ -508,6 +508,6 @@ class FederationClientTests(HomeserverTestCase):
self.assertFalse(conn.disconnecting)
# wait for a while
- self.pump(120)
+ self.reactor.advance(120)
self.assertTrue(conn.disconnecting)
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index 807cd65dd6..04de0b9dbe 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -35,7 +35,7 @@ class ModuleApiTestCase(HomeserverTestCase):
# Check that the new user exists with all provided attributes
self.assertEqual(user_id, "@bob:test")
self.assertTrue(access_token)
- self.assertTrue(self.store.get_user_by_id(user_id))
+ self.assertTrue(self.get_success(self.store.get_user_by_id(user_id)))
# Check that the email was assigned
emails = self.get_success(self.store.user_get_threepids(user_id))
diff --git a/tests/push/test_email.py b/tests/push/test_email.py
index 83032cc9ea..227b0d32d0 100644
--- a/tests/push/test_email.py
+++ b/tests/push/test_email.py
@@ -170,7 +170,7 @@ class EmailPusherTests(HomeserverTestCase):
last_stream_ordering = pushers[0]["last_stream_ordering"]
# Advance time a bit, so the pusher will register something has happened
- self.pump(100)
+ self.pump(10)
# It hasn't succeeded yet, so the stream ordering shouldn't have moved
pushers = self.get_success(
diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py
index 83f9aa291c..8b4982ecb1 100644
--- a/tests/replication/test_federation_sender_shard.py
+++ b/tests/replication/test_federation_sender_shard.py
@@ -20,7 +20,7 @@ from synapse.api.constants import EventTypes, Membership
from synapse.events.builder import EventBuilderFactory
from synapse.rest.admin import register_servlets_for_client_rest_resource
from synapse.rest.client.v1 import login, room
-from synapse.types import UserID
+from synapse.types import UserID, create_requester
from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.test_utils import make_awaitable
@@ -175,7 +175,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
self.get_success(
typing_handler.started_typing(
target_user=UserID.from_string(user),
- auth_user=UserID.from_string(user),
+ requester=create_requester(user),
room_id=room,
timeout=20000,
)
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index 0b191d13c6..7d3773ff78 100644
--- a/tests/rest/client/test_retention.py
+++ b/tests/rest/client/test_retention.py
@@ -45,50 +45,63 @@ class RetentionTestCase(unittest.HomeserverTestCase):
}
self.hs = self.setup_test_homeserver(config=config)
+
return self.hs
def prepare(self, reactor, clock, homeserver):
self.user_id = self.register_user("user", "password")
self.token = self.login("user", "password")
- def test_retention_state_event(self):
- """Tests that the server configuration can limit the values a user can set to the
- room's retention policy.
+ self.store = self.hs.get_datastore()
+ self.serializer = self.hs.get_event_client_serializer()
+ self.clock = self.hs.get_clock()
+
+ def test_retention_event_purged_with_state_event(self):
+ """Tests that expired events are correctly purged when the room's retention policy
+ is defined by a state event.
"""
room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+ # Set the room's retention period to 2 days.
+ lifetime = one_day_ms * 2
self.helper.send_state(
room_id=room_id,
event_type=EventTypes.Retention,
- body={"max_lifetime": one_day_ms * 4},
+ body={"max_lifetime": lifetime},
tok=self.token,
- expect_code=400,
)
+ self._test_retention_event_purged(room_id, one_day_ms * 1.5)
+
+ def test_retention_event_purged_with_state_event_outside_allowed(self):
+ """Tests that the server configuration can override the policy for a room when
+ running the purge jobs.
+ """
+ room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+
+ # Set a max_lifetime higher than the maximum allowed value.
self.helper.send_state(
room_id=room_id,
event_type=EventTypes.Retention,
- body={"max_lifetime": one_hour_ms},
+ body={"max_lifetime": one_day_ms * 4},
tok=self.token,
- expect_code=400,
)
- def test_retention_event_purged_with_state_event(self):
- """Tests that expired events are correctly purged when the room's retention policy
- is defined by a state event.
- """
- room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+ # Check that the event is purged after waiting for the maximum allowed duration
+ # instead of the one specified in the room's policy.
+ self._test_retention_event_purged(room_id, one_day_ms * 1.5)
- # Set the room's retention period to 2 days.
- lifetime = one_day_ms * 2
+ # Set a max_lifetime lower than the minimum allowed value.
self.helper.send_state(
room_id=room_id,
event_type=EventTypes.Retention,
- body={"max_lifetime": lifetime},
+ body={"max_lifetime": one_hour_ms},
tok=self.token,
)
- self._test_retention_event_purged(room_id, one_day_ms * 1.5)
+ # Check that the event is purged after waiting for the minimum allowed duration
+ # instead of the one specified in the room's policy.
+ self._test_retention_event_purged(room_id, one_day_ms * 0.5)
def test_retention_event_purged_without_state_event(self):
"""Tests that expired events are correctly purged when the room's retention policy
@@ -140,12 +153,32 @@ class RetentionTestCase(unittest.HomeserverTestCase):
# That event should be the second, not outdated event.
self.assertEqual(filtered_events[0].event_id, valid_event_id, filtered_events)
- def _test_retention_event_purged(self, room_id, increment):
+ def _test_retention_event_purged(self, room_id: str, increment: float):
+ """Run the following test scenario to test the message retention policy support:
+
+ 1. Send event 1
+ 2. Increment time by `increment`
+ 3. Send event 2
+ 4. Increment time by `increment`
+ 5. Check that event 1 has been purged
+ 6. Check that event 2 has not been purged
+ 7. Check that state events that were sent before event 1 aren't purged.
+ The main reason for sending a second event is because currently Synapse won't
+ purge the latest message in a room because it would otherwise result in a lack of
+ forward extremities for this room. It's also a good thing to ensure the purge jobs
+ aren't too greedy and purge messages they shouldn't.
+
+ Args:
+ room_id: The ID of the room to test retention in.
+ increment: The number of milliseconds to advance the clock each time. Must be
+ defined so that events in the room aren't purged if they are `increment`
+ old but are purged if they are `increment * 2` old.
+ """
# Get the create event to, later, check that we can still access it.
message_handler = self.hs.get_message_handler()
create_event = self.get_success(
message_handler.get_room_data(
- self.user_id, room_id, EventTypes.Create, state_key="", is_guest=False
+ self.user_id, room_id, EventTypes.Create, state_key=""
)
)
@@ -156,7 +189,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
expired_event_id = resp.get("event_id")
# Check that we can retrieve the event.
- expired_event = self.get_event(room_id, expired_event_id)
+ expired_event = self.get_event(expired_event_id)
self.assertEqual(
expired_event.get("content", {}).get("body"), "1", expired_event
)
@@ -174,26 +207,31 @@ class RetentionTestCase(unittest.HomeserverTestCase):
# one should still be kept.
self.reactor.advance(increment / 1000)
- # Check that the event has been purged from the database.
- self.get_event(room_id, expired_event_id, expected_code=404)
+ # Check that the first event has been purged from the database, i.e. that we
+ # can't retrieve it anymore, because it has expired.
+ self.get_event(expired_event_id, expect_none=True)
- # Check that the event that hasn't been purged can still be retrieved.
- valid_event = self.get_event(room_id, valid_event_id)
+ # Check that the event that hasn't expired can still be retrieved.
+ valid_event = self.get_event(valid_event_id)
self.assertEqual(valid_event.get("content", {}).get("body"), "2", valid_event)
# Check that we can still access state events that were sent before the event that
# has been purged.
self.get_event(room_id, create_event.event_id)
- def get_event(self, room_id, event_id, expected_code=200):
- url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id)
+ def get_event(self, event_id, expect_none=False):
+ event = self.get_success(self.store.get_event(event_id, allow_none=True))
- request, channel = self.make_request("GET", url, access_token=self.token)
- self.render(request)
+ if expect_none:
+ self.assertIsNone(event)
+ return {}
- self.assertEqual(channel.code, expected_code, channel.result)
+ self.assertIsNotNone(event)
- return channel.json_body
+ time_now = self.clock.time_msec()
+ serialized = self.get_success(self.serializer.serialize_event(event, time_now))
+
+ return serialized
class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase):
diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py
new file mode 100644
index 0000000000..dfe4bf7762
--- /dev/null
+++ b/tests/rest/client/test_shadow_banned.py
@@ -0,0 +1,312 @@
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from mock import Mock, patch
+
+import synapse.rest.admin
+from synapse.api.constants import EventTypes
+from synapse.rest.client.v1 import directory, login, profile, room
+from synapse.rest.client.v2_alpha import room_upgrade_rest_servlet
+
+from tests import unittest
+
+
+class _ShadowBannedBase(unittest.HomeserverTestCase):
+ def prepare(self, reactor, clock, homeserver):
+ # Create two users, one of which is shadow-banned.
+ self.banned_user_id = self.register_user("banned", "test")
+ self.banned_access_token = self.login("banned", "test")
+
+ self.store = self.hs.get_datastore()
+
+ self.get_success(
+ self.store.db_pool.simple_update(
+ table="users",
+ keyvalues={"name": self.banned_user_id},
+ updatevalues={"shadow_banned": True},
+ desc="shadow_ban",
+ )
+ )
+
+ self.other_user_id = self.register_user("otheruser", "pass")
+ self.other_access_token = self.login("otheruser", "pass")
+
+
+# To avoid the tests timing out don't add a delay to "annoy the requester".
+@patch("random.randint", new=lambda a, b: 0)
+class RoomTestCase(_ShadowBannedBase):
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ directory.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ room_upgrade_rest_servlet.register_servlets,
+ ]
+
+ def test_invite(self):
+ """Invites from shadow-banned users don't actually get sent."""
+
+ # The create works fine.
+ room_id = self.helper.create_room_as(
+ self.banned_user_id, tok=self.banned_access_token
+ )
+
+ # Inviting the user completes successfully.
+ self.helper.invite(
+ room=room_id,
+ src=self.banned_user_id,
+ tok=self.banned_access_token,
+ targ=self.other_user_id,
+ )
+
+ # But the user wasn't actually invited.
+ invited_rooms = self.get_success(
+ self.store.get_invited_rooms_for_local_user(self.other_user_id)
+ )
+ self.assertEqual(invited_rooms, [])
+
+ def test_invite_3pid(self):
+ """Ensure that a 3PID invite does not attempt to contact the identity server."""
+ identity_handler = self.hs.get_handlers().identity_handler
+ identity_handler.lookup_3pid = Mock(
+ side_effect=AssertionError("This should not get called")
+ )
+
+ # The create works fine.
+ room_id = self.helper.create_room_as(
+ self.banned_user_id, tok=self.banned_access_token
+ )
+
+ # Inviting the user completes successfully.
+ request, channel = self.make_request(
+ "POST",
+ "/rooms/%s/invite" % (room_id,),
+ {"id_server": "test", "medium": "email", "address": "test@test.test"},
+ access_token=self.banned_access_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.result)
+
+ # This should have raised an error earlier, but double check this wasn't called.
+ identity_handler.lookup_3pid.assert_not_called()
+
+ def test_create_room(self):
+ """Invitations during a room creation should be discarded, but the room still gets created."""
+ # The room creation is successful.
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/createRoom",
+ {"visibility": "public", "invite": [self.other_user_id]},
+ access_token=self.banned_access_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.result)
+ room_id = channel.json_body["room_id"]
+
+ # But the user wasn't actually invited.
+ invited_rooms = self.get_success(
+ self.store.get_invited_rooms_for_local_user(self.other_user_id)
+ )
+ self.assertEqual(invited_rooms, [])
+
+ # Since a real room was created, the other user should be able to join it.
+ self.helper.join(room_id, self.other_user_id, tok=self.other_access_token)
+
+ # Both users should be in the room.
+ users = self.get_success(self.store.get_users_in_room(room_id))
+ self.assertCountEqual(users, ["@banned:test", "@otheruser:test"])
+
+ def test_message(self):
+ """Messages from shadow-banned users don't actually get sent."""
+
+ room_id = self.helper.create_room_as(
+ self.other_user_id, tok=self.other_access_token
+ )
+
+ # The user should be in the room.
+ self.helper.join(room_id, self.banned_user_id, tok=self.banned_access_token)
+
+ # Sending a message should complete successfully.
+ result = self.helper.send_event(
+ room_id=room_id,
+ type=EventTypes.Message,
+ content={"msgtype": "m.text", "body": "with right label"},
+ tok=self.banned_access_token,
+ )
+ self.assertIn("event_id", result)
+ event_id = result["event_id"]
+
+ latest_events = self.get_success(
+ self.store.get_latest_event_ids_in_room(room_id)
+ )
+ self.assertNotIn(event_id, latest_events)
+
+ def test_upgrade(self):
+ """A room upgrade should fail, but look like it succeeded."""
+
+ # The create works fine.
+ room_id = self.helper.create_room_as(
+ self.banned_user_id, tok=self.banned_access_token
+ )
+
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/rooms/%s/upgrade" % (room_id,),
+ {"new_version": "6"},
+ access_token=self.banned_access_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.result)
+ # A new room_id should be returned.
+ self.assertIn("replacement_room", channel.json_body)
+
+ new_room_id = channel.json_body["replacement_room"]
+
+ # It doesn't really matter what API we use here, we just want to assert
+ # that the room doesn't exist.
+ summary = self.get_success(self.store.get_room_summary(new_room_id))
+ # The summary should be empty since the room doesn't exist.
+ self.assertEqual(summary, {})
+
+ def test_typing(self):
+ """Typing notifications should not be propagated into the room."""
+ # The create works fine.
+ room_id = self.helper.create_room_as(
+ self.banned_user_id, tok=self.banned_access_token
+ )
+
+ request, channel = self.make_request(
+ "PUT",
+ "/rooms/%s/typing/%s" % (room_id, self.banned_user_id),
+ {"typing": True, "timeout": 30000},
+ access_token=self.banned_access_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code)
+
+ # There should be no typing events.
+ event_source = self.hs.get_event_sources().sources["typing"]
+ self.assertEquals(event_source.get_current_key(), 0)
+
+ # The other user can join and send typing events.
+ self.helper.join(room_id, self.other_user_id, tok=self.other_access_token)
+
+ request, channel = self.make_request(
+ "PUT",
+ "/rooms/%s/typing/%s" % (room_id, self.other_user_id),
+ {"typing": True, "timeout": 30000},
+ access_token=self.other_access_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code)
+
+ # These appear in the room.
+ self.assertEquals(event_source.get_current_key(), 1)
+ events = self.get_success(
+ event_source.get_new_events(from_key=0, room_ids=[room_id])
+ )
+ self.assertEquals(
+ events[0],
+ [
+ {
+ "type": "m.typing",
+ "room_id": room_id,
+ "content": {"user_ids": [self.other_user_id]},
+ }
+ ],
+ )
+
+
+# To avoid the tests timing out don't add a delay to "annoy the requester".
+@patch("random.randint", new=lambda a, b: 0)
+class ProfileTestCase(_ShadowBannedBase):
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ login.register_servlets,
+ profile.register_servlets,
+ room.register_servlets,
+ ]
+
+ def test_displayname(self):
+ """Profile changes should succeed, but don't end up in a room."""
+ original_display_name = "banned"
+ new_display_name = "new name"
+
+ # Join a room.
+ room_id = self.helper.create_room_as(
+ self.banned_user_id, tok=self.banned_access_token
+ )
+
+ # The update should succeed.
+ request, channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/profile/%s/displayname" % (self.banned_user_id,),
+ {"displayname": new_display_name},
+ access_token=self.banned_access_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.result)
+ self.assertEqual(channel.json_body, {})
+
+ # The user's display name should be updated.
+ request, channel = self.make_request(
+ "GET", "/profile/%s/displayname" % (self.banned_user_id,)
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(channel.json_body["displayname"], new_display_name)
+
+ # But the display name in the room should not be.
+ message_handler = self.hs.get_message_handler()
+ event = self.get_success(
+ message_handler.get_room_data(
+ self.banned_user_id, room_id, "m.room.member", self.banned_user_id,
+ )
+ )
+ self.assertEqual(
+ event.content, {"membership": "join", "displayname": original_display_name}
+ )
+
+ def test_room_displayname(self):
+ """Changes to state events for a room should be processed, but not end up in the room."""
+ original_display_name = "banned"
+ new_display_name = "new name"
+
+ # Join a room.
+ room_id = self.helper.create_room_as(
+ self.banned_user_id, tok=self.banned_access_token
+ )
+
+ # The update should succeed.
+ request, channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/rooms/%s/state/m.room.member/%s"
+ % (room_id, self.banned_user_id),
+ {"membership": "join", "displayname": new_display_name},
+ access_token=self.banned_access_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.result)
+ self.assertIn("event_id", channel.json_body)
+
+ # The display name in the room should not be changed.
+ message_handler = self.hs.get_message_handler()
+ event = self.get_success(
+ message_handler.get_room_data(
+ self.banned_user_id, room_id, "m.room.member", self.banned_user_id,
+ )
+ )
+ self.assertEqual(
+ event.content, {"membership": "join", "displayname": original_display_name}
+ )
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index db52725cfe..2668662c9e 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -62,8 +62,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit" + str(i)},
"password": "monkey",
}
- request_data = json.dumps(params)
- request, channel = self.make_request(b"POST", LOGIN_URL, request_data)
+ request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
if i == 5:
@@ -76,14 +75,13 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# than 1min.
self.assertTrue(retry_after_ms < 6000)
- self.reactor.advance(retry_after_ms / 1000.0)
+ self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
params = {
"type": "m.login.password",
"identifier": {"type": "m.id.user", "user": "kermit" + str(i)},
"password": "monkey",
}
- request_data = json.dumps(params)
request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
@@ -111,8 +109,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit"},
"password": "monkey",
}
- request_data = json.dumps(params)
- request, channel = self.make_request(b"POST", LOGIN_URL, request_data)
+ request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
if i == 5:
@@ -132,7 +129,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit"},
"password": "monkey",
}
- request_data = json.dumps(params)
request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
@@ -160,8 +156,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit"},
"password": "notamonkey",
}
- request_data = json.dumps(params)
- request, channel = self.make_request(b"POST", LOGIN_URL, request_data)
+ request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
if i == 5:
@@ -174,14 +169,13 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# than 1min.
self.assertTrue(retry_after_ms < 6000)
- self.reactor.advance(retry_after_ms / 1000.0)
+ self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
params = {
"type": "m.login.password",
"identifier": {"type": "m.id.user", "user": "kermit"},
"password": "notamonkey",
}
- request_data = json.dumps(params)
request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index ef6b775ed2..0a567b032f 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -28,7 +28,7 @@ from synapse.api.constants import EventContentFields, EventTypes, Membership
from synapse.handlers.pagination import PurgeStatus
from synapse.rest.client.v1 import directory, login, profile, room
from synapse.rest.client.v2_alpha import account
-from synapse.types import JsonDict, RoomAlias
+from synapse.types import JsonDict, RoomAlias, UserID
from synapse.util.stringutils import random_string
from tests import unittest
@@ -675,6 +675,92 @@ class RoomMemberStateTestCase(RoomBase):
self.assertEquals(json.loads(content), channel.json_body)
+class RoomJoinRatelimitTestCase(RoomBase):
+ user_id = "@sid1:red"
+
+ servlets = [
+ profile.register_servlets,
+ room.register_servlets,
+ ]
+
+ @unittest.override_config(
+ {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
+ )
+ def test_join_local_ratelimit(self):
+ """Tests that local joins are actually rate-limited."""
+ for i in range(3):
+ self.helper.create_room_as(self.user_id)
+
+ self.helper.create_room_as(self.user_id, expect_code=429)
+
+ @unittest.override_config(
+ {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
+ )
+ def test_join_local_ratelimit_profile_change(self):
+ """Tests that sending a profile update into all of the user's joined rooms isn't
+ rate-limited by the rate-limiter on joins."""
+
+ # Create and join as many rooms as the rate-limiting config allows in a second.
+ room_ids = [
+ self.helper.create_room_as(self.user_id),
+ self.helper.create_room_as(self.user_id),
+ self.helper.create_room_as(self.user_id),
+ ]
+ # Let some time for the rate-limiter to forget about our multi-join.
+ self.reactor.advance(2)
+ # Add one to make sure we're joined to more rooms than the config allows us to
+ # join in a second.
+ room_ids.append(self.helper.create_room_as(self.user_id))
+
+ # Create a profile for the user, since it hasn't been done on registration.
+ store = self.hs.get_datastore()
+ self.get_success(
+ store.create_profile(UserID.from_string(self.user_id).localpart)
+ )
+
+ # Update the display name for the user.
+ path = "/_matrix/client/r0/profile/%s/displayname" % self.user_id
+ request, channel = self.make_request("PUT", path, {"displayname": "John Doe"})
+ self.render(request)
+ self.assertEquals(channel.code, 200, channel.json_body)
+
+ # Check that all the rooms have been sent a profile update into.
+ for room_id in room_ids:
+ path = "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" % (
+ room_id,
+ self.user_id,
+ )
+
+ request, channel = self.make_request("GET", path)
+ self.render(request)
+ self.assertEquals(channel.code, 200)
+
+ self.assertIn("displayname", channel.json_body)
+ self.assertEquals(channel.json_body["displayname"], "John Doe")
+
+ @unittest.override_config(
+ {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
+ )
+ def test_join_local_ratelimit_idempotent(self):
+ """Tests that the room join endpoints remain idempotent despite rate-limiting
+ on room joins."""
+ room_id = self.helper.create_room_as(self.user_id)
+
+ # Let's test both paths to be sure.
+ paths_to_test = [
+ "/_matrix/client/r0/rooms/%s/join",
+ "/_matrix/client/r0/join/%s",
+ ]
+
+ for path in paths_to_test:
+ # Make sure we send more requests than the rate-limiting config would allow
+ # if all of these requests ended up joining the user to a room.
+ for i in range(4):
+ request, channel = self.make_request("POST", path % room_id, {})
+ self.render(request)
+ self.assertEquals(channel.code, 200)
+
+
class RoomMessagesTestCase(RoomBase):
""" Tests /rooms/$room_id/messages/$user_id/$msg_id REST events. """
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 8933b560d2..e66c9a4c4c 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -39,7 +39,9 @@ class RestHelper(object):
resource = attr.ib()
auth_user_id = attr.ib()
- def create_room_as(self, room_creator=None, is_public=True, tok=None):
+ def create_room_as(
+ self, room_creator=None, is_public=True, tok=None, expect_code=200,
+ ):
temp_id = self.auth_user_id
self.auth_user_id = room_creator
path = "/_matrix/client/r0/createRoom"
@@ -54,9 +56,11 @@ class RestHelper(object):
)
render(request, self.resource, self.hs.get_reactor())
- assert channel.result["code"] == b"200", channel.result
+ assert channel.result["code"] == b"%d" % expect_code, channel.result
self.auth_user_id = temp_id
- return channel.json_body["room_id"]
+
+ if expect_code == 200:
+ return channel.json_body["room_id"]
def invite(self, room=None, src=None, targ=None, expect_code=200, tok=None):
self.change_membership(
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 53a43038f0..2fc3a60fc5 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -160,7 +160,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
else:
self.assertEquals(channel.result["code"], b"200", channel.result)
- self.reactor.advance(retry_after_ms / 1000.0)
+ self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
self.render(request)
@@ -186,7 +186,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
else:
self.assertEquals(channel.result["code"], b"200", channel.result)
- self.reactor.advance(retry_after_ms / 1000.0)
+ self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
self.render(request)
diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index 2858d13558..23db821fb7 100644
--- a/tests/server_notices/test_resource_limits_server_notices.py
+++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -104,7 +104,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
)
self._rlsn._store.get_events = Mock(
- return_value=defer.succeed({"123": mock_event})
+ return_value=make_awaitable({"123": mock_event})
)
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
# Would be better to check the content, but once == remove blocking event
@@ -122,7 +122,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
)
self._rlsn._store.get_events = Mock(
- return_value=defer.succeed({"123": mock_event})
+ return_value=make_awaitable({"123": mock_event})
)
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@@ -217,7 +217,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
)
self._rlsn._store.get_events = Mock(
- return_value=defer.succeed({"123": mock_event})
+ return_value=make_awaitable({"123": mock_event})
)
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index a425e66f37..17fbde284a 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -31,6 +31,7 @@ from synapse.storage.databases.main.appservice import (
)
from tests import unittest
+from tests.test_utils import make_awaitable
from tests.utils import setup_test_homeserver
@@ -357,7 +358,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
other_events = [Mock(event_id="e5"), Mock(event_id="e6")]
# we aren't testing store._base stuff here, so mock this out
- self.store.get_events_as_list = Mock(return_value=defer.succeed(events))
+ self.store.get_events_as_list = Mock(return_value=make_awaitable(events))
yield self._insert_txn(self.as_list[1]["id"], 9, other_events)
yield self._insert_txn(service.id, 10, events)
diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index 2efbc97c2e..1a1c59256c 100644
--- a/tests/storage/test_background_update.py
+++ b/tests/storage/test_background_update.py
@@ -67,13 +67,12 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
# second step: complete the update
# we should now get run with a much bigger number of items to update
- @defer.inlineCallbacks
- def update(progress, count):
+ async def update(progress, count):
self.assertEqual(progress, {"my_key": 2})
self.assertAlmostEqual(
count, target_background_update_duration_ms / duration_ms, places=0,
)
- yield self.updates._end_background_update("test_update")
+ await self.updates._end_background_update("test_update")
return count
self.update_handler.side_effect = update
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index 13bcac743a..40ba652248 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -97,8 +97,10 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.rowcount = 1
self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)]))
- value = yield self.datastore.db_pool.simple_select_one_onecol(
- table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol"
+ value = yield defer.ensureDeferred(
+ self.datastore.db_pool.simple_select_one_onecol(
+ table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol"
+ )
)
self.assertEquals("Value", value)
@@ -111,10 +113,12 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.rowcount = 1
self.mock_txn.fetchone.return_value = (1, 2, 3)
- ret = yield self.datastore.db_pool.simple_select_one(
- table="tablename",
- keyvalues={"keycol": "TheKey"},
- retcols=["colA", "colB", "colC"],
+ ret = yield defer.ensureDeferred(
+ self.datastore.db_pool.simple_select_one(
+ table="tablename",
+ keyvalues={"keycol": "TheKey"},
+ retcols=["colA", "colB", "colC"],
+ )
)
self.assertEquals({"colA": 1, "colB": 2, "colC": 3}, ret)
@@ -127,11 +131,13 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.rowcount = 0
self.mock_txn.fetchone.return_value = None
- ret = yield self.datastore.db_pool.simple_select_one(
- table="tablename",
- keyvalues={"keycol": "Not here"},
- retcols=["colA"],
- allow_none=True,
+ ret = yield defer.ensureDeferred(
+ self.datastore.db_pool.simple_select_one(
+ table="tablename",
+ keyvalues={"keycol": "Not here"},
+ retcols=["colA"],
+ allow_none=True,
+ )
)
self.assertFalse(ret)
@@ -142,8 +148,10 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)]))
self.mock_txn.description = (("colA", None, None, None, None, None, None),)
- ret = yield self.datastore.db_pool.simple_select_list(
- table="tablename", keyvalues={"keycol": "A set"}, retcols=["colA"]
+ ret = yield defer.ensureDeferred(
+ self.datastore.db_pool.simple_select_list(
+ table="tablename", keyvalues={"keycol": "A set"}, retcols=["colA"]
+ )
)
self.assertEquals([{"colA": 1}, {"colA": 2}, {"colA": 3}], ret)
@@ -155,10 +163,12 @@ class SQLBaseStoreTestCase(unittest.TestCase):
def test_update_one_1col(self):
self.mock_txn.rowcount = 1
- yield self.datastore.db_pool.simple_update_one(
- table="tablename",
- keyvalues={"keycol": "TheKey"},
- updatevalues={"columnname": "New Value"},
+ yield defer.ensureDeferred(
+ self.datastore.db_pool.simple_update_one(
+ table="tablename",
+ keyvalues={"keycol": "TheKey"},
+ updatevalues={"columnname": "New Value"},
+ )
)
self.mock_txn.execute.assert_called_with(
@@ -170,10 +180,12 @@ class SQLBaseStoreTestCase(unittest.TestCase):
def test_update_one_4cols(self):
self.mock_txn.rowcount = 1
- yield self.datastore.db_pool.simple_update_one(
- table="tablename",
- keyvalues=OrderedDict([("colA", 1), ("colB", 2)]),
- updatevalues=OrderedDict([("colC", 3), ("colD", 4)]),
+ yield defer.ensureDeferred(
+ self.datastore.db_pool.simple_update_one(
+ table="tablename",
+ keyvalues=OrderedDict([("colA", 1), ("colB", 2)]),
+ updatevalues=OrderedDict([("colC", 3), ("colD", 4)]),
+ )
)
self.mock_txn.execute.assert_called_with(
@@ -185,8 +197,10 @@ class SQLBaseStoreTestCase(unittest.TestCase):
def test_delete_one(self):
self.mock_txn.rowcount = 1
- yield self.datastore.db_pool.simple_delete_one(
- table="tablename", keyvalues={"keycol": "Go away"}
+ yield defer.ensureDeferred(
+ self.datastore.db_pool.simple_delete_one(
+ table="tablename", keyvalues={"keycol": "Go away"}
+ )
)
self.mock_txn.execute.assert_called_with(
diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index 8e9a650f9f..080761d1d2 100644
--- a/tests/storage/test_cleanup_extrems.py
+++ b/tests/storage/test_cleanup_extrems.py
@@ -271,7 +271,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
# Pump the reactor repeatedly so that the background updates have a
# chance to run.
- self.pump(10 * 60)
+ self.pump(20)
latest_event_ids = self.get_success(
self.store.get_latest_event_ids_in_room(self.room_id)
@@ -353,6 +353,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion[
"3"
] = 300000
+
self.event_creator_handler._expire_rooms_to_exclude_from_dummy_event_insertion()
# All entries within time frame
self.assertEqual(
@@ -362,7 +363,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
3,
)
# Oldest room to expire
- self.pump(1)
+ self.pump(1.01)
self.event_creator_handler._expire_rooms_to_exclude_from_dummy_event_insertion()
self.assertEqual(
len(
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index 87ed8f8cd1..34ae8c9da7 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -38,7 +38,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
self.store.store_device("user_id", "device_id", "display_name")
)
- res = yield self.store.get_device("user_id", "device_id")
+ res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
self.assertDictContainsSubset(
{
"user_id": "user_id",
@@ -111,12 +111,12 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
self.store.store_device("user_id", "device_id", "display_name 1")
)
- res = yield self.store.get_device("user_id", "device_id")
+ res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
self.assertEqual("display_name 1", res["display_name"])
# do a no-op first
yield defer.ensureDeferred(self.store.update_device("user_id", "device_id"))
- res = yield self.store.get_device("user_id", "device_id")
+ res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
self.assertEqual("display_name 1", res["display_name"])
# do the update
@@ -127,7 +127,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
)
# check it worked
- res = yield self.store.get_device("user_id", "device_id")
+ res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
self.assertEqual("display_name 2", res["display_name"])
@defer.inlineCallbacks
diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py
index daac947cb2..da93ca3980 100644
--- a/tests/storage/test_directory.py
+++ b/tests/storage/test_directory.py
@@ -42,7 +42,11 @@ class DirectoryStoreTestCase(unittest.TestCase):
self.assertEquals(
["#my-room:test"],
- (yield self.store.get_aliases_for_room(self.room.to_string())),
+ (
+ yield defer.ensureDeferred(
+ self.store.get_aliases_for_room(self.room.to_string())
+ )
+ ),
)
@defer.inlineCallbacks
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index e845410dae..14ce21c786 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -58,6 +58,10 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
return self.get_success(self.db_pool.runWithConnection(_create))
def _insert_rows(self, instance_name: str, number: int):
+ """Insert N rows as the given instance, inserting with stream IDs pulled
+ from the postgres sequence.
+ """
+
def _insert(txn):
for _ in range(number):
txn.execute(
@@ -65,7 +69,20 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
(instance_name,),
)
- self.get_success(self.db_pool.runInteraction("test_single_instance", _insert))
+ self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
+
+ def _insert_row_with_id(self, instance_name: str, stream_id: int):
+ """Insert one row as the given instance with given stream_id, updating
+ the postgres sequence position to match.
+ """
+
+ def _insert(txn):
+ txn.execute(
+ "INSERT INTO foobar VALUES (?, ?)", (stream_id, instance_name,),
+ )
+ txn.execute("SELECT setval('foobar_seq', ?)", (stream_id,))
+
+ self.get_success(self.db_pool.runInteraction("_insert_row_with_id", _insert))
def test_empty(self):
"""Test an ID generator against an empty database gives sensible
@@ -88,7 +105,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
id_gen = self._create_id_generator()
self.assertEqual(id_gen.get_positions(), {"master": 7})
- self.assertEqual(id_gen.get_current_token("master"), 7)
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
# Try allocating a new ID gen and check that we only see position
# advanced after we leave the context manager.
@@ -98,12 +115,12 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(stream_id, 8)
self.assertEqual(id_gen.get_positions(), {"master": 7})
- self.assertEqual(id_gen.get_current_token("master"), 7)
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
self.get_success(_get_next_async())
self.assertEqual(id_gen.get_positions(), {"master": 8})
- self.assertEqual(id_gen.get_current_token("master"), 8)
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
def test_multi_instance(self):
"""Test that reads and writes from multiple processes are handled
@@ -116,8 +133,8 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
second_id_gen = self._create_id_generator("second")
self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
- self.assertEqual(first_id_gen.get_current_token("first"), 3)
- self.assertEqual(first_id_gen.get_current_token("second"), 7)
+ self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3)
+ self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
# Try allocating a new ID gen and check that we only see position
# advanced after we leave the context manager.
@@ -166,7 +183,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
id_gen = self._create_id_generator()
self.assertEqual(id_gen.get_positions(), {"master": 7})
- self.assertEqual(id_gen.get_current_token("master"), 7)
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
# Try allocating a new ID gen and check that we only see position
# advanced after we leave the context manager.
@@ -176,9 +193,74 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(stream_id, 8)
self.assertEqual(id_gen.get_positions(), {"master": 7})
- self.assertEqual(id_gen.get_current_token("master"), 7)
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
self.get_success(self.db_pool.runInteraction("test", _get_next_txn))
self.assertEqual(id_gen.get_positions(), {"master": 8})
- self.assertEqual(id_gen.get_current_token("master"), 8)
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
+
+ def test_get_persisted_upto_position(self):
+ """Test that `get_persisted_upto_position` correctly tracks updates to
+ positions.
+ """
+
+ # The following tests are a bit cheeky in that we notify about new
+ # positions via `advance` without *actually* advancing the postgres
+ # sequence.
+
+ self._insert_row_with_id("first", 3)
+ self._insert_row_with_id("second", 5)
+
+ id_gen = self._create_id_generator("first")
+
+ self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
+
+ # Min is 3 and there is a gap between 5, so we expect it to be 3.
+ self.assertEqual(id_gen.get_persisted_upto_position(), 3)
+
+ # We advance "first" straight to 6. Min is now 5 but there is no gap so
+ # we expect it to be 6
+ id_gen.advance("first", 6)
+ self.assertEqual(id_gen.get_persisted_upto_position(), 6)
+
+ # No gap, so we expect 7.
+ id_gen.advance("second", 7)
+ self.assertEqual(id_gen.get_persisted_upto_position(), 7)
+
+ # We haven't seen 8 yet, so we expect 7 still.
+ id_gen.advance("second", 9)
+ self.assertEqual(id_gen.get_persisted_upto_position(), 7)
+
+ # Now that we've seen 7, 8 and 9 we can got straight to 9.
+ id_gen.advance("first", 8)
+ self.assertEqual(id_gen.get_persisted_upto_position(), 9)
+
+ # Jump forward with gaps. The minimum is 11, even though we haven't seen
+ # 10 we know that everything before 11 must be persisted.
+ id_gen.advance("first", 11)
+ id_gen.advance("second", 15)
+ self.assertEqual(id_gen.get_persisted_upto_position(), 11)
+
+ def test_get_persisted_upto_position_get_next(self):
+ """Test that `get_persisted_upto_position` correctly tracks updates to
+ positions when `get_next` is called.
+ """
+
+ self._insert_row_with_id("first", 3)
+ self._insert_row_with_id("second", 5)
+
+ id_gen = self._create_id_generator("first")
+
+ self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
+
+ self.assertEqual(id_gen.get_persisted_upto_position(), 3)
+ with self.get_success(id_gen.get_next()) as stream_id:
+ self.assertEqual(stream_id, 6)
+ self.assertEqual(id_gen.get_persisted_upto_position(), 3)
+
+ self.assertEqual(id_gen.get_persisted_upto_position(), 6)
+
+ # We assume that so long as `get_next` does correctly advance the
+ # `persisted_upto_position` in this case, then it will be correct in the
+ # other cases that are tested above (since they'll hit the same code).
diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py
index fbf8af940a..954338a592 100644
--- a/tests/storage/test_main.py
+++ b/tests/storage/test_main.py
@@ -36,7 +36,9 @@ class DataStoreTestCase(unittest.TestCase):
def test_get_users_paginate(self):
yield self.store.register_user(self.user.to_string(), "pass")
yield defer.ensureDeferred(self.store.create_profile(self.user.localpart))
- yield self.store.set_profile_displayname(self.user.localpart, self.displayname)
+ yield defer.ensureDeferred(
+ self.store.set_profile_displayname(self.user.localpart, self.displayname)
+ )
users, total = yield self.store.get_users_paginate(
0, 10, name="bc", guests=False
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index 9d5b8aa47d..3fd0a38cf5 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -35,21 +35,34 @@ class ProfileStoreTestCase(unittest.TestCase):
def test_displayname(self):
yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
- yield self.store.set_profile_displayname(self.u_frank.localpart, "Frank")
+ yield defer.ensureDeferred(
+ self.store.set_profile_displayname(self.u_frank.localpart, "Frank")
+ )
self.assertEquals(
- "Frank", (yield self.store.get_profile_displayname(self.u_frank.localpart))
+ "Frank",
+ (
+ yield defer.ensureDeferred(
+ self.store.get_profile_displayname(self.u_frank.localpart)
+ )
+ ),
)
@defer.inlineCallbacks
def test_avatar_url(self):
yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
- yield self.store.set_profile_avatar_url(
- self.u_frank.localpart, "http://my.site/here"
+ yield defer.ensureDeferred(
+ self.store.set_profile_avatar_url(
+ self.u_frank.localpart, "http://my.site/here"
+ )
)
self.assertEquals(
"http://my.site/here",
- (yield self.store.get_profile_avatar_url(self.u_frank.localpart)),
+ (
+ yield defer.ensureDeferred(
+ self.store.get_profile_avatar_url(self.u_frank.localpart)
+ )
+ ),
)
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 840db66072..70c55cd650 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -17,6 +17,7 @@
from twisted.internet import defer
from synapse.api.constants import UserTypes
+from synapse.api.errors import ThreepidValidationError
from tests import unittest
from tests.utils import setup_test_homeserver
@@ -52,7 +53,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
"user_type": None,
"deactivated": 0,
},
- (yield self.store.get_user_by_id(self.user_id)),
+ (yield defer.ensureDeferred(self.store.get_user_by_id(self.user_id))),
)
@defer.inlineCallbacks
@@ -122,3 +123,33 @@ class RegistrationStoreTestCase(unittest.TestCase):
)
res = yield self.store.is_support_user(SUPPORT_USER)
self.assertTrue(res)
+
+ @defer.inlineCallbacks
+ def test_3pid_inhibit_invalid_validation_session_error(self):
+ """Tests that enabling the configuration option to inhibit 3PID errors on
+ /requestToken also inhibits validation errors caused by an unknown session ID.
+ """
+
+ # Check that, with the config setting set to false (the default value), a
+ # validation error is caused by the unknown session ID.
+ try:
+ yield defer.ensureDeferred(
+ self.store.validate_threepid_session(
+ "fake_sid", "fake_client_secret", "fake_token", 0,
+ )
+ )
+ except ThreepidValidationError as e:
+ self.assertEquals(e.msg, "Unknown session_id", e)
+
+ # Set the config setting to true.
+ self.store._ignore_unknown_session_error = True
+
+ # Check that now the validation error is caused by the token not matching.
+ try:
+ yield defer.ensureDeferred(
+ self.store.validate_threepid_session(
+ "fake_sid", "fake_client_secret", "fake_token", 0,
+ )
+ )
+ except ThreepidValidationError as e:
+ self.assertEquals(e.msg, "Validation token not found or has expired", e)
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index d07b985a8e..bc8400f240 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -54,12 +54,14 @@ class RoomStoreTestCase(unittest.TestCase):
"creator": self.u_creator.to_string(),
"is_public": True,
},
- (yield self.store.get_room(self.room.to_string())),
+ (yield defer.ensureDeferred(self.store.get_room(self.room.to_string()))),
)
@defer.inlineCallbacks
def test_get_room_unknown_room(self):
- self.assertIsNone((yield self.store.get_room("!uknown:test")),)
+ self.assertIsNone(
+ (yield defer.ensureDeferred(self.store.get_room("!uknown:test")))
+ )
@defer.inlineCallbacks
def test_get_room_with_stats(self):
@@ -69,12 +71,22 @@ class RoomStoreTestCase(unittest.TestCase):
"creator": self.u_creator.to_string(),
"public": True,
},
- (yield self.store.get_room_with_stats(self.room.to_string())),
+ (
+ yield defer.ensureDeferred(
+ self.store.get_room_with_stats(self.room.to_string())
+ )
+ ),
)
@defer.inlineCallbacks
def test_get_room_with_stats_unknown_room(self):
- self.assertIsNone((yield self.store.get_room_with_stats("!uknown:test")),)
+ self.assertIsNone(
+ (
+ yield defer.ensureDeferred(
+ self.store.get_room_with_stats("!uknown:test")
+ )
+ ),
+ )
class RoomEventsStoreTestCase(unittest.TestCase):
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index d98fe8754d..12ccc1f53e 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -87,7 +87,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
self.inject_room_member(self.room, self.u_bob, Membership.JOIN)
self.inject_room_member(self.room, self.u_charlie.to_string(), Membership.JOIN)
- self.pump(20)
+ self.pump()
self.assertTrue("_known_servers_count" not in self.store.__dict__.keys())
@@ -101,7 +101,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
# Initialises to 1 -- itself
self.assertEqual(self.store._known_servers_count, 1)
- self.pump(20)
+ self.pump()
# No rooms have been joined, so technically the SQL returns 0, but it
# will still say it knows about itself.
@@ -111,7 +111,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
self.inject_room_member(self.room, self.u_bob, Membership.JOIN)
self.inject_room_member(self.room, self.u_charlie.to_string(), Membership.JOIN)
- self.pump(20)
+ self.pump(1)
# It now knows about Charlie's server.
self.assertEqual(self.store._known_servers_count, 2)
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 4a4548433f..27a7fc9ed7 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -15,8 +15,9 @@
from mock import Mock
-from twisted.internet.defer import ensureDeferred, maybeDeferred, succeed
+from twisted.internet.defer import succeed
+from synapse.api.errors import FederationError
from synapse.events import make_event_from_dict
from synapse.logging.context import LoggingContext
from synapse.types import Requester, UserID
@@ -44,22 +45,17 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
user_id = UserID("us", "test")
our_user = Requester(user_id, None, False, False, None, None)
room_creator = self.homeserver.get_room_creation_handler()
- room_deferred = ensureDeferred(
+ self.room_id = self.get_success(
room_creator.create_room(
our_user, room_creator._presets_dict["public_chat"], ratelimit=False
)
- )
- self.reactor.advance(0.1)
- self.room_id = self.successResultOf(room_deferred)[0]["room_id"]
+ )[0]["room_id"]
self.store = self.homeserver.get_datastore()
# Figure out what the most recent event is
- most_recent = self.successResultOf(
- maybeDeferred(
- self.homeserver.get_datastore().get_latest_event_ids_in_room,
- self.room_id,
- )
+ most_recent = self.get_success(
+ self.homeserver.get_datastore().get_latest_event_ids_in_room(self.room_id)
)[0]
join_event = make_event_from_dict(
@@ -89,19 +85,18 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
)
# Send the join, it should return None (which is not an error)
- d = ensureDeferred(
- self.handler.on_receive_pdu(
- "test.serv", join_event, sent_to_us_directly=True
- )
+ self.assertEqual(
+ self.get_success(
+ self.handler.on_receive_pdu(
+ "test.serv", join_event, sent_to_us_directly=True
+ )
+ ),
+ None,
)
- self.reactor.advance(1)
- self.assertEqual(self.successResultOf(d), None)
# Make sure we actually joined the room
self.assertEqual(
- self.successResultOf(
- maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id)
- )[0],
+ self.get_success(self.store.get_latest_event_ids_in_room(self.room_id))[0],
"$join:test.serv",
)
@@ -119,8 +114,8 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
self.http_client.post_json = post_json
# Figure out what the most recent event is
- most_recent = self.successResultOf(
- maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id)
+ most_recent = self.get_success(
+ self.store.get_latest_event_ids_in_room(self.room_id)
)[0]
# Now lie about an event
@@ -140,17 +135,14 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
)
with LoggingContext(request="lying_event"):
- d = ensureDeferred(
+ failure = self.get_failure(
self.handler.on_receive_pdu(
"test.serv", lying_event, sent_to_us_directly=True
- )
+ ),
+ FederationError,
)
- # Step the reactor, so the database fetches come back
- self.reactor.advance(1)
-
# on_receive_pdu should throw an error
- failure = self.failureResultOf(d)
self.assertEqual(
failure.value.args[0],
(
@@ -160,8 +152,8 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
)
# Make sure the invalid event isn't there
- extrem = maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id)
- self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv")
+ extrem = self.get_success(self.store.get_latest_event_ids_in_room(self.room_id))
+ self.assertEqual(extrem[0], "$join:test.serv")
def test_retry_device_list_resync(self):
"""Tests that device lists are marked as stale if they couldn't be synced, and
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 4d2b9e0d64..0363735d4f 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -366,11 +366,11 @@ class CachedListDescriptorTestCase(unittest.TestCase):
def fn(self, arg1, arg2):
pass
- @descriptors.cachedList("fn", "args1", inlineCallbacks=True)
- def list_fn(self, args1, arg2):
+ @descriptors.cachedList("fn", "args1")
+ async def list_fn(self, args1, arg2):
assert current_context().request == "c1"
# we want this to behave like an asynchronous function
- yield run_on_reactor()
+ await run_on_reactor()
assert current_context().request == "c1"
return self.mock(args1, arg2)
@@ -416,10 +416,10 @@ class CachedListDescriptorTestCase(unittest.TestCase):
def fn(self, arg1, arg2):
pass
- @descriptors.cachedList("fn", "args1", inlineCallbacks=True)
- def list_fn(self, args1, arg2):
+ @descriptors.cachedList("fn", "args1")
+ async def list_fn(self, args1, arg2):
# we want this to behave like an asynchronous function
- yield run_on_reactor()
+ await run_on_reactor()
return self.mock(args1, arg2)
obj = Cls()
diff --git a/tests/util/test_retryutils.py b/tests/util/test_retryutils.py
index bc42ffce88..5f46ed0cef 100644
--- a/tests/util/test_retryutils.py
+++ b/tests/util/test_retryutils.py
@@ -91,7 +91,7 @@ class RetryLimiterTestCase(HomeserverTestCase):
#
# one more go, with success
#
- self.pump(MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 2.0)
+ self.reactor.advance(MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 2.0)
limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store))
self.pump(1)
|