diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index 34f72ae795..28d77f0ca2 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from mock import Mock
+from unittest.mock import Mock
import pymacaroons
diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py
index 483418192c..fa96ba07a5 100644
--- a/tests/api/test_ratelimiting.py
+++ b/tests/api/test_ratelimiting.py
@@ -5,38 +5,25 @@ from synapse.types import create_requester
from tests import unittest
-class TestRatelimiter(unittest.TestCase):
+class TestRatelimiter(unittest.HomeserverTestCase):
def test_allowed_via_can_do_action(self):
- limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
- allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=0)
- self.assertTrue(allowed)
- self.assertEquals(10.0, time_allowed)
-
- allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=5)
- self.assertFalse(allowed)
- self.assertEquals(10.0, time_allowed)
-
- allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=10)
- self.assertTrue(allowed)
- self.assertEquals(20.0, time_allowed)
-
- def test_allowed_user_via_can_requester_do_action(self):
- user_requester = create_requester("@user:example.com")
- limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
- allowed, time_allowed = limiter.can_requester_do_action(
- user_requester, _time_now_s=0
+ limiter = Ratelimiter(
+ store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
+ )
+ allowed, time_allowed = self.get_success_or_raise(
+ limiter.can_do_action(None, key="test_id", _time_now_s=0)
)
self.assertTrue(allowed)
self.assertEquals(10.0, time_allowed)
- allowed, time_allowed = limiter.can_requester_do_action(
- user_requester, _time_now_s=5
+ allowed, time_allowed = self.get_success_or_raise(
+ limiter.can_do_action(None, key="test_id", _time_now_s=5)
)
self.assertFalse(allowed)
self.assertEquals(10.0, time_allowed)
- allowed, time_allowed = limiter.can_requester_do_action(
- user_requester, _time_now_s=10
+ allowed, time_allowed = self.get_success_or_raise(
+ limiter.can_do_action(None, key="test_id", _time_now_s=10)
)
self.assertTrue(allowed)
self.assertEquals(20.0, time_allowed)
@@ -51,21 +38,23 @@ class TestRatelimiter(unittest.TestCase):
)
as_requester = create_requester("@user:example.com", app_service=appservice)
- limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
- allowed, time_allowed = limiter.can_requester_do_action(
- as_requester, _time_now_s=0
+ limiter = Ratelimiter(
+ store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
+ )
+ allowed, time_allowed = self.get_success_or_raise(
+ limiter.can_do_action(as_requester, _time_now_s=0)
)
self.assertTrue(allowed)
self.assertEquals(10.0, time_allowed)
- allowed, time_allowed = limiter.can_requester_do_action(
- as_requester, _time_now_s=5
+ allowed, time_allowed = self.get_success_or_raise(
+ limiter.can_do_action(as_requester, _time_now_s=5)
)
self.assertFalse(allowed)
self.assertEquals(10.0, time_allowed)
- allowed, time_allowed = limiter.can_requester_do_action(
- as_requester, _time_now_s=10
+ allowed, time_allowed = self.get_success_or_raise(
+ limiter.can_do_action(as_requester, _time_now_s=10)
)
self.assertTrue(allowed)
self.assertEquals(20.0, time_allowed)
@@ -80,73 +69,89 @@ class TestRatelimiter(unittest.TestCase):
)
as_requester = create_requester("@user:example.com", app_service=appservice)
- limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
- allowed, time_allowed = limiter.can_requester_do_action(
- as_requester, _time_now_s=0
+ limiter = Ratelimiter(
+ store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
+ )
+ allowed, time_allowed = self.get_success_or_raise(
+ limiter.can_do_action(as_requester, _time_now_s=0)
)
self.assertTrue(allowed)
self.assertEquals(-1, time_allowed)
- allowed, time_allowed = limiter.can_requester_do_action(
- as_requester, _time_now_s=5
+ allowed, time_allowed = self.get_success_or_raise(
+ limiter.can_do_action(as_requester, _time_now_s=5)
)
self.assertTrue(allowed)
self.assertEquals(-1, time_allowed)
- allowed, time_allowed = limiter.can_requester_do_action(
- as_requester, _time_now_s=10
+ allowed, time_allowed = self.get_success_or_raise(
+ limiter.can_do_action(as_requester, _time_now_s=10)
)
self.assertTrue(allowed)
self.assertEquals(-1, time_allowed)
def test_allowed_via_ratelimit(self):
- limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
+ limiter = Ratelimiter(
+ store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
+ )
# Shouldn't raise
- limiter.ratelimit(key="test_id", _time_now_s=0)
+ self.get_success_or_raise(limiter.ratelimit(None, key="test_id", _time_now_s=0))
# Should raise
with self.assertRaises(LimitExceededError) as context:
- limiter.ratelimit(key="test_id", _time_now_s=5)
+ self.get_success_or_raise(
+ limiter.ratelimit(None, key="test_id", _time_now_s=5)
+ )
self.assertEqual(context.exception.retry_after_ms, 5000)
# Shouldn't raise
- limiter.ratelimit(key="test_id", _time_now_s=10)
+ self.get_success_or_raise(
+ limiter.ratelimit(None, key="test_id", _time_now_s=10)
+ )
def test_allowed_via_can_do_action_and_overriding_parameters(self):
"""Test that we can override options of can_do_action that would otherwise fail
an action
"""
# Create a Ratelimiter with a very low allowed rate_hz and burst_count
- limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
+ limiter = Ratelimiter(
+ store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
+ )
# First attempt should be allowed
- allowed, time_allowed = limiter.can_do_action(
- ("test_id",),
- _time_now_s=0,
+ allowed, time_allowed = self.get_success_or_raise(
+ limiter.can_do_action(
+ None,
+ ("test_id",),
+ _time_now_s=0,
+ )
)
self.assertTrue(allowed)
self.assertEqual(10.0, time_allowed)
# Second attempt, 1s later, will fail
- allowed, time_allowed = limiter.can_do_action(
- ("test_id",),
- _time_now_s=1,
+ allowed, time_allowed = self.get_success_or_raise(
+ limiter.can_do_action(
+ None,
+ ("test_id",),
+ _time_now_s=1,
+ )
)
self.assertFalse(allowed)
self.assertEqual(10.0, time_allowed)
# But, if we allow 10 actions/sec for this request, we should be allowed
# to continue.
- allowed, time_allowed = limiter.can_do_action(
- ("test_id",), _time_now_s=1, rate_hz=10.0
+ allowed, time_allowed = self.get_success_or_raise(
+ limiter.can_do_action(None, ("test_id",), _time_now_s=1, rate_hz=10.0)
)
self.assertTrue(allowed)
self.assertEqual(1.1, time_allowed)
# Similarly if we allow a burst of 10 actions
- allowed, time_allowed = limiter.can_do_action(
- ("test_id",), _time_now_s=1, burst_count=10
+ allowed, time_allowed = self.get_success_or_raise(
+ limiter.can_do_action(None, ("test_id",), _time_now_s=1, burst_count=10)
)
self.assertTrue(allowed)
self.assertEqual(1.0, time_allowed)
@@ -156,29 +161,72 @@ class TestRatelimiter(unittest.TestCase):
fail an action
"""
# Create a Ratelimiter with a very low allowed rate_hz and burst_count
- limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
+ limiter = Ratelimiter(
+ store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
+ )
# First attempt should be allowed
- limiter.ratelimit(key=("test_id",), _time_now_s=0)
+ self.get_success_or_raise(
+ limiter.ratelimit(None, key=("test_id",), _time_now_s=0)
+ )
# Second attempt, 1s later, will fail
with self.assertRaises(LimitExceededError) as context:
- limiter.ratelimit(key=("test_id",), _time_now_s=1)
+ self.get_success_or_raise(
+ limiter.ratelimit(None, key=("test_id",), _time_now_s=1)
+ )
self.assertEqual(context.exception.retry_after_ms, 9000)
# But, if we allow 10 actions/sec for this request, we should be allowed
# to continue.
- limiter.ratelimit(key=("test_id",), _time_now_s=1, rate_hz=10.0)
+ self.get_success_or_raise(
+ limiter.ratelimit(None, key=("test_id",), _time_now_s=1, rate_hz=10.0)
+ )
# Similarly if we allow a burst of 10 actions
- limiter.ratelimit(key=("test_id",), _time_now_s=1, burst_count=10)
+ self.get_success_or_raise(
+ limiter.ratelimit(None, key=("test_id",), _time_now_s=1, burst_count=10)
+ )
def test_pruning(self):
- limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
- limiter.can_do_action(key="test_id_1", _time_now_s=0)
+ limiter = Ratelimiter(
+ store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
+ )
+ self.get_success_or_raise(
+ limiter.can_do_action(None, key="test_id_1", _time_now_s=0)
+ )
self.assertIn("test_id_1", limiter.actions)
- limiter.can_do_action(key="test_id_2", _time_now_s=10)
+ self.get_success_or_raise(
+ limiter.can_do_action(None, key="test_id_2", _time_now_s=10)
+ )
self.assertNotIn("test_id_1", limiter.actions)
+
+ def test_db_user_override(self):
+ """Test that users that have ratelimiting disabled in the DB aren't
+ ratelimited.
+ """
+ store = self.hs.get_datastore()
+
+ user_id = "@user:test"
+ requester = create_requester(user_id)
+
+ self.get_success(
+ store.db_pool.simple_insert(
+ table="ratelimit_override",
+ values={
+ "user_id": user_id,
+ "messages_per_second": None,
+ "burst_count": None,
+ },
+ desc="test_db_user_override",
+ )
+ )
+
+ limiter = Ratelimiter(store=store, clock=None, rate_hz=0.1, burst_count=1)
+
+ # Shouldn't raise
+ for _ in range(20):
+ self.get_success_or_raise(limiter.ratelimit(requester, _time_now_s=0))
diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py
index 467033e201..33a37fe35e 100644
--- a/tests/app/test_openid_listener.py
+++ b/tests/app/test_openid_listener.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from mock import Mock, patch
+from unittest.mock import Mock, patch
from parameterized import parameterized
diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py
index 0bffeb1150..03a7440eec 100644
--- a/tests/appservice/test_appservice.py
+++ b/tests/appservice/test_appservice.py
@@ -13,8 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import re
-
-from mock import Mock
+from unittest.mock import Mock
from twisted.internet import defer
diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py
index 97f8cad0dd..3c27d797fb 100644
--- a/tests/appservice/test_scheduler.py
+++ b/tests/appservice/test_scheduler.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from mock import Mock
+from unittest.mock import Mock
from twisted.internet import defer
diff --git a/tests/config/test_load.py b/tests/config/test_load.py
index 734a9983e8..c109425671 100644
--- a/tests/config/test_load.py
+++ b/tests/config/test_load.py
@@ -20,6 +20,7 @@ from io import StringIO
import yaml
+from synapse.config import ConfigError
from synapse.config.homeserver import HomeServerConfig
from tests import unittest
@@ -35,9 +36,9 @@ class ConfigLoadingTestCase(unittest.TestCase):
def test_load_fails_if_server_name_missing(self):
self.generate_config_and_remove_lines_containing("server_name")
- with self.assertRaises(Exception):
+ with self.assertRaises(ConfigError):
HomeServerConfig.load_config("", ["-c", self.file])
- with self.assertRaises(Exception):
+ with self.assertRaises(ConfigError):
HomeServerConfig.load_or_generate_config("", ["-c", self.file])
def test_generates_and_loads_macaroon_secret_key(self):
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 30fcc4c1bf..a56063315b 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -13,9 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import time
+from unittest.mock import Mock
-from mock import Mock
-
+import attr
import canonicaljson
import signedjson.key
import signedjson.sign
@@ -68,6 +68,11 @@ class MockPerspectiveServer:
signedjson.sign.sign_json(res, self.server_name, self.key)
+@attr.s(slots=True)
+class FakeRequest:
+ id = attr.ib()
+
+
@logcontext_clean
class KeyringTestCase(unittest.HomeserverTestCase):
def check_context(self, val, expected):
@@ -89,7 +94,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
first_lookup_deferred = Deferred()
async def first_lookup_fetch(keys_to_fetch):
- self.assertEquals(current_context().request, "context_11")
+ self.assertEquals(current_context().request.id, "context_11")
self.assertEqual(keys_to_fetch, {"server10": {get_key_id(key1): 0}})
await make_deferred_yieldable(first_lookup_deferred)
@@ -102,9 +107,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
mock_fetcher.get_keys.side_effect = first_lookup_fetch
async def first_lookup():
- with LoggingContext("context_11") as context_11:
- context_11.request = "context_11"
-
+ with LoggingContext("context_11", request=FakeRequest("context_11")):
res_deferreds = kr.verify_json_objects_for_server(
[("server10", json1, 0, "test10"), ("server11", {}, 0, "test11")]
)
@@ -130,7 +133,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
# should block rather than start a second call
async def second_lookup_fetch(keys_to_fetch):
- self.assertEquals(current_context().request, "context_12")
+ self.assertEquals(current_context().request.id, "context_12")
return {
"server10": {
get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100)
@@ -142,9 +145,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
second_lookup_state = [0]
async def second_lookup():
- with LoggingContext("context_12") as context_12:
- context_12.request = "context_12"
-
+ with LoggingContext("context_12", request=FakeRequest("context_12")):
res_deferreds_2 = kr.verify_json_objects_for_server(
[("server10", json1, 0, "test")]
)
@@ -589,10 +590,7 @@ def get_key_id(key):
@defer.inlineCallbacks
def run_in_context(f, *args, **kwargs):
- with LoggingContext("testctx") as ctx:
- # we set the "request" prop to make it easier to follow what's going on in the
- # logs.
- ctx.request = "testctx"
+ with LoggingContext("testctx"):
rv = yield f(*args, **kwargs)
return rv
diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py
new file mode 100644
index 0000000000..c996ecc221
--- /dev/null
+++ b/tests/events/test_presence_router.py
@@ -0,0 +1,386 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
+from unittest.mock import Mock
+
+import attr
+
+from synapse.api.constants import EduTypes
+from synapse.events.presence_router import PresenceRouter
+from synapse.federation.units import Transaction
+from synapse.handlers.presence import UserPresenceState
+from synapse.module_api import ModuleApi
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, presence, room
+from synapse.types import JsonDict, StreamToken, create_requester
+
+from tests.handlers.test_sync import generate_sync_config
+from tests.unittest import FederatingHomeserverTestCase, TestCase, override_config
+
+
+@attr.s
+class PresenceRouterTestConfig:
+ users_who_should_receive_all_presence = attr.ib(type=List[str], default=[])
+
+
+class PresenceRouterTestModule:
+ def __init__(self, config: PresenceRouterTestConfig, module_api: ModuleApi):
+ self._config = config
+ self._module_api = module_api
+
+ async def get_users_for_states(
+ self, state_updates: Iterable[UserPresenceState]
+ ) -> Dict[str, Set[UserPresenceState]]:
+ users_to_state = {
+ user_id: set(state_updates)
+ for user_id in self._config.users_who_should_receive_all_presence
+ }
+ return users_to_state
+
+ async def get_interested_users(
+ self, user_id: str
+ ) -> Union[Set[str], PresenceRouter.ALL_USERS]:
+ if user_id in self._config.users_who_should_receive_all_presence:
+ return PresenceRouter.ALL_USERS
+
+ return set()
+
+ @staticmethod
+ def parse_config(config_dict: dict) -> PresenceRouterTestConfig:
+ """Parse a configuration dictionary from the homeserver config, do
+ some validation and return a typed PresenceRouterConfig.
+
+ Args:
+ config_dict: The configuration dictionary.
+
+ Returns:
+ A validated config object.
+ """
+ # Initialise a typed config object
+ config = PresenceRouterTestConfig()
+
+ config.users_who_should_receive_all_presence = config_dict.get(
+ "users_who_should_receive_all_presence"
+ )
+
+ return config
+
+
+class PresenceRouterTestCase(FederatingHomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ presence.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ return self.setup_test_homeserver(
+ federation_transport_client=Mock(spec=["send_transaction"]),
+ )
+
+ def prepare(self, reactor, clock, homeserver):
+ self.sync_handler = self.hs.get_sync_handler()
+ self.module_api = homeserver.get_module_api()
+
+ @override_config(
+ {
+ "presence": {
+ "presence_router": {
+ "module": __name__ + ".PresenceRouterTestModule",
+ "config": {
+ "users_who_should_receive_all_presence": [
+ "@presence_gobbler:test",
+ ]
+ },
+ }
+ },
+ "send_federation": True,
+ }
+ )
+ def test_receiving_all_presence(self):
+ """Test that a user that does not share a room with another other can receive
+ presence for them, due to presence routing.
+ """
+ # Create a user who should receive all presence of others
+ self.presence_receiving_user_id = self.register_user(
+ "presence_gobbler", "monkey"
+ )
+ self.presence_receiving_user_tok = self.login("presence_gobbler", "monkey")
+
+ # And two users who should not have any special routing
+ self.other_user_one_id = self.register_user("other_user_one", "monkey")
+ self.other_user_one_tok = self.login("other_user_one", "monkey")
+ self.other_user_two_id = self.register_user("other_user_two", "monkey")
+ self.other_user_two_tok = self.login("other_user_two", "monkey")
+
+ # Put the other two users in a room with each other
+ room_id = self.helper.create_room_as(
+ self.other_user_one_id, tok=self.other_user_one_tok
+ )
+
+ self.helper.invite(
+ room_id,
+ self.other_user_one_id,
+ self.other_user_two_id,
+ tok=self.other_user_one_tok,
+ )
+ self.helper.join(room_id, self.other_user_two_id, tok=self.other_user_two_tok)
+ # User one sends some presence
+ send_presence_update(
+ self,
+ self.other_user_one_id,
+ self.other_user_one_tok,
+ "online",
+ "boop",
+ )
+
+ # Check that the presence receiving user gets user one's presence when syncing
+ presence_updates, sync_token = sync_presence(
+ self, self.presence_receiving_user_id
+ )
+ self.assertEqual(len(presence_updates), 1)
+
+ presence_update = presence_updates[0] # type: UserPresenceState
+ self.assertEqual(presence_update.user_id, self.other_user_one_id)
+ self.assertEqual(presence_update.state, "online")
+ self.assertEqual(presence_update.status_msg, "boop")
+
+ # Have all three users send presence
+ send_presence_update(
+ self,
+ self.other_user_one_id,
+ self.other_user_one_tok,
+ "online",
+ "user_one",
+ )
+ send_presence_update(
+ self,
+ self.other_user_two_id,
+ self.other_user_two_tok,
+ "online",
+ "user_two",
+ )
+ send_presence_update(
+ self,
+ self.presence_receiving_user_id,
+ self.presence_receiving_user_tok,
+ "online",
+ "presence_gobbler",
+ )
+
+ # Check that the presence receiving user gets everyone's presence
+ presence_updates, _ = sync_presence(
+ self, self.presence_receiving_user_id, sync_token
+ )
+ self.assertEqual(len(presence_updates), 3)
+
+ # But that User One only get itself and User Two's presence
+ presence_updates, _ = sync_presence(self, self.other_user_one_id)
+ self.assertEqual(len(presence_updates), 2)
+
+ found = False
+ for update in presence_updates:
+ if update.user_id == self.other_user_two_id:
+ self.assertEqual(update.state, "online")
+ self.assertEqual(update.status_msg, "user_two")
+ found = True
+
+ self.assertTrue(found)
+
+ @override_config(
+ {
+ "presence": {
+ "presence_router": {
+ "module": __name__ + ".PresenceRouterTestModule",
+ "config": {
+ "users_who_should_receive_all_presence": [
+ "@presence_gobbler1:test",
+ "@presence_gobbler2:test",
+ "@far_away_person:island",
+ ]
+ },
+ }
+ },
+ "send_federation": True,
+ }
+ )
+ def test_send_local_online_presence_to_with_module(self):
+ """Tests that send_local_presence_to_users sends local online presence to a set
+ of specified local and remote users, with a custom PresenceRouter module enabled.
+ """
+ # Create a user who will send presence updates
+ self.other_user_id = self.register_user("other_user", "monkey")
+ self.other_user_tok = self.login("other_user", "monkey")
+
+ # And another two users that will also send out presence updates, as well as receive
+ # theirs and everyone else's
+ self.presence_receiving_user_one_id = self.register_user(
+ "presence_gobbler1", "monkey"
+ )
+ self.presence_receiving_user_one_tok = self.login("presence_gobbler1", "monkey")
+ self.presence_receiving_user_two_id = self.register_user(
+ "presence_gobbler2", "monkey"
+ )
+ self.presence_receiving_user_two_tok = self.login("presence_gobbler2", "monkey")
+
+ # Have all three users send some presence updates
+ send_presence_update(
+ self,
+ self.other_user_id,
+ self.other_user_tok,
+ "online",
+ "I'm online!",
+ )
+ send_presence_update(
+ self,
+ self.presence_receiving_user_one_id,
+ self.presence_receiving_user_one_tok,
+ "online",
+ "I'm also online!",
+ )
+ send_presence_update(
+ self,
+ self.presence_receiving_user_two_id,
+ self.presence_receiving_user_two_tok,
+ "unavailable",
+ "I'm in a meeting!",
+ )
+
+ # Mark each presence-receiving user for receiving all user presence
+ self.get_success(
+ self.module_api.send_local_online_presence_to(
+ [
+ self.presence_receiving_user_one_id,
+ self.presence_receiving_user_two_id,
+ ]
+ )
+ )
+
+ # Perform a sync for each user
+
+ # The other user should only receive their own presence
+ presence_updates, _ = sync_presence(self, self.other_user_id)
+ self.assertEqual(len(presence_updates), 1)
+
+ presence_update = presence_updates[0] # type: UserPresenceState
+ self.assertEqual(presence_update.user_id, self.other_user_id)
+ self.assertEqual(presence_update.state, "online")
+ self.assertEqual(presence_update.status_msg, "I'm online!")
+
+ # Whereas both presence receiving users should receive everyone's presence updates
+ presence_updates, _ = sync_presence(self, self.presence_receiving_user_one_id)
+ self.assertEqual(len(presence_updates), 3)
+ presence_updates, _ = sync_presence(self, self.presence_receiving_user_two_id)
+ self.assertEqual(len(presence_updates), 3)
+
+ # Test that sending to a remote user works
+ remote_user_id = "@far_away_person:island"
+
+ # Note that due to the remote user being in our module's
+ # users_who_should_receive_all_presence config, they would have
+ # received user presence updates already.
+ #
+ # Thus we reset the mock, and try sending all online local user
+ # presence again
+ self.hs.get_federation_transport_client().send_transaction.reset_mock()
+
+ # Broadcast local user online presence
+ self.get_success(
+ self.module_api.send_local_online_presence_to([remote_user_id])
+ )
+
+ # Check that the expected presence updates were sent
+ expected_users = [
+ self.other_user_id,
+ self.presence_receiving_user_one_id,
+ self.presence_receiving_user_two_id,
+ ]
+
+ calls = (
+ self.hs.get_federation_transport_client().send_transaction.call_args_list
+ )
+ for call in calls:
+ call_args = call[0]
+ federation_transaction = call_args[0] # type: Transaction
+
+ # Get the sent EDUs in this transaction
+ edus = federation_transaction.get_dict()["edus"]
+
+ for edu in edus:
+ # Make sure we're only checking presence-type EDUs
+ if edu["edu_type"] != EduTypes.Presence:
+ continue
+
+ # EDUs can contain multiple presence updates
+ for presence_update in edu["content"]["push"]:
+ # Check for presence updates that contain the user IDs we're after
+ expected_users.remove(presence_update["user_id"])
+
+ # Ensure that no offline states are being sent out
+ self.assertNotEqual(presence_update["presence"], "offline")
+
+ self.assertEqual(len(expected_users), 0)
+
+
+def send_presence_update(
+ testcase: TestCase,
+ user_id: str,
+ access_token: str,
+ presence_state: str,
+ status_message: Optional[str] = None,
+) -> JsonDict:
+ # Build the presence body
+ body = {"presence": presence_state}
+ if status_message:
+ body["status_msg"] = status_message
+
+ # Update the user's presence state
+ channel = testcase.make_request(
+ "PUT", "/presence/%s/status" % (user_id,), body, access_token=access_token
+ )
+ testcase.assertEqual(channel.code, 200)
+
+ return channel.json_body
+
+
+def sync_presence(
+ testcase: TestCase,
+ user_id: str,
+ since_token: Optional[StreamToken] = None,
+) -> Tuple[List[UserPresenceState], StreamToken]:
+ """Perform a sync request for the given user and return the user presence updates
+ they've received, as well as the next_batch token.
+
+ This method assumes testcase.sync_handler points to the homeserver's sync handler.
+
+ Args:
+ testcase: The testcase that is currently being run.
+ user_id: The ID of the user to generate a sync response for.
+ since_token: An optional token to indicate from at what point to sync from.
+
+ Returns:
+ A tuple containing a list of presence updates, and the sync response's
+ next_batch token.
+ """
+ requester = create_requester(user_id)
+ sync_config = generate_sync_config(requester.user.to_string())
+ sync_result = testcase.get_success(
+ testcase.sync_handler.wait_for_sync_for_user(
+ requester, sync_config, since_token
+ )
+ )
+
+ return sync_result.presence, sync_result.next_batch
diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py
index 8186b8ca01..701fa8379f 100644
--- a/tests/federation/test_complexity.py
+++ b/tests/federation/test_complexity.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from mock import Mock
+from unittest.mock import Mock
from synapse.api.errors import Codes, SynapseError
from synapse.rest import admin
diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py
index 95eac6a5a3..802c5ad299 100644
--- a/tests/federation/test_federation_catch_up.py
+++ b/tests/federation/test_federation_catch_up.py
@@ -1,6 +1,5 @@
from typing import List, Tuple
-
-from mock import Mock
+from unittest.mock import Mock
from synapse.api.constants import EventTypes
from synapse.events import EventBase
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index ecc3faa572..deb12433cf 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -13,8 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
-
-from mock import Mock
+from unittest.mock import Mock
from signedjson import key, sign
from signedjson.types import BaseKey, SigningKey
diff --git a/tests/handlers/test_admin.py b/tests/handlers/test_admin.py
index a01fdd0839..32669ae9ce 100644
--- a/tests/handlers/test_admin.py
+++ b/tests/handlers/test_admin.py
@@ -14,8 +14,7 @@
# limitations under the License.
from collections import Counter
-
-from mock import Mock
+from unittest.mock import Mock
import synapse.api.errors
import synapse.handlers.admin
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index d5d3fdd99a..6e325b24ce 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from mock import Mock
+from unittest.mock import Mock
from twisted.internet import defer
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index c9f889b511..321c5ba045 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from mock import Mock
+from unittest.mock import Mock
import pymacaroons
diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py
index 7975af243c..0444b26798 100644
--- a/tests/handlers/test_cas.py
+++ b/tests/handlers/test_cas.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from mock import Mock
+from unittest.mock import Mock
from synapse.handlers.cas_handler import CasResponse
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index fadec16e13..a8d0cf6603 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -14,7 +14,7 @@
# limitations under the License.
-from mock import Mock
+from unittest.mock import Mock
import synapse
import synapse.api.errors
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 5e86c5e56b..6915ac0205 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -14,7 +14,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import mock
+from unittest import mock
from signedjson import key as key, sign as sign
diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py
index d7498aa51a..07893302ec 100644
--- a/tests/handlers/test_e2e_room_keys.py
+++ b/tests/handlers/test_e2e_room_keys.py
@@ -16,8 +16,7 @@
# limitations under the License.
import copy
-
-import mock
+from unittest import mock
from synapse.api.errors import SynapseError
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index c7796fb837..8702ee70e0 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -14,10 +14,9 @@
# limitations under the License.
import json
import os
+from unittest.mock import ANY, Mock, patch
from urllib.parse import parse_qs, urlparse
-from mock import ANY, Mock, patch
-
import pymacaroons
from synapse.handlers.sso import MappingException
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index a98a65ae67..e28e4159eb 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -16,8 +16,7 @@
"""Tests for the password_auth_provider interface"""
from typing import Any, Type, Union
-
-from mock import Mock
+from unittest.mock import Mock
from twisted.internet import defer
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index 77330f59a9..9f16cc65fc 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -14,7 +14,7 @@
# limitations under the License.
-from mock import Mock, call
+from unittest.mock import Mock, call
from signedjson.key import generate_signing_key
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index cbbe7280c7..60f2458c98 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from mock import Mock
+from unittest.mock import Mock
import synapse.types
from synapse.api.errors import AuthError, SynapseError
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index 00a0bc5274..c30b414d99 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from mock import Mock
+from unittest.mock import Mock
from synapse.api.auth import Auth
from synapse.api.constants import UserTypes
diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py
index 30efd43b40..8cfc184fef 100644
--- a/tests/handlers/test_saml.py
+++ b/tests/handlers/test_saml.py
@@ -13,8 +13,7 @@
# limitations under the License.
from typing import Optional
-
-from mock import Mock
+from unittest.mock import Mock
import attr
diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index e62586142e..8e950f25c5 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -37,7 +37,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
def test_wait_for_sync_for_user_auth_blocking(self):
user_id1 = "@user1:test"
user_id2 = "@user2:test"
- sync_config = self._generate_sync_config(user_id1)
+ sync_config = generate_sync_config(user_id1)
requester = create_requester(user_id1)
self.reactor.advance(100) # So we get not 0 time
@@ -60,7 +60,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
self.auth_blocking._hs_disabled = False
- sync_config = self._generate_sync_config(user_id2)
+ sync_config = generate_sync_config(user_id2)
requester = create_requester(user_id2)
e = self.get_failure(
@@ -69,11 +69,12 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
)
self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
- def _generate_sync_config(self, user_id):
- return SyncConfig(
- user=UserID(user_id.split(":")[0][1:], user_id.split(":")[1]),
- filter_collection=DEFAULT_FILTER_COLLECTION,
- is_guest=False,
- request_key="request_key",
- device_id="device_id",
- )
+
+def generate_sync_config(user_id: str) -> SyncConfig:
+ return SyncConfig(
+ user=UserID(user_id.split(":")[0][1:], user_id.split(":")[1]),
+ filter_collection=DEFAULT_FILTER_COLLECTION,
+ is_guest=False,
+ request_key="request_key",
+ device_id="device_id",
+ )
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 24e7138196..9fa231a37a 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -16,8 +16,7 @@
import json
from typing import Dict
-
-from mock import ANY, Mock, call
+from unittest.mock import ANY, Mock, call
from twisted.internet import defer
from twisted.web.resource import Resource
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index dbe68bb058..67a8e49945 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from mock import Mock
+from unittest.mock import Mock
from twisted.internet import defer
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index 3972abb038..e6b20799e5 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -13,8 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-
-from mock import Mock
+from typing import Optional
+from unittest.mock import Mock
import treq
from netaddr import IPSet
@@ -180,7 +180,11 @@ class MatrixFederationAgentTests(unittest.TestCase):
_check_logcontext(context)
def _handle_well_known_connection(
- self, client_factory, expected_sni, content, response_headers={}
+ self,
+ client_factory,
+ expected_sni,
+ content,
+ response_headers: Optional[dict] = None,
):
"""Handle an outgoing HTTPs connection: wire it up to a server, check that the
request is for a .well-known, and send the response.
@@ -202,10 +206,12 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.assertEqual(
request.requestHeaders.getRawHeaders(b"user-agent"), [b"test-agent"]
)
- self._send_well_known_response(request, content, headers=response_headers)
+ self._send_well_known_response(request, content, headers=response_headers or {})
return well_known_server
- def _send_well_known_response(self, request, content, headers={}):
+ def _send_well_known_response(
+ self, request, content, headers: Optional[dict] = None
+ ):
"""Check that an incoming request looks like a valid .well-known request, and
send back the response.
"""
@@ -213,7 +219,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.assertEqual(request.path, b"/.well-known/matrix/server")
self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"testserv"])
# send back a response
- for k, v in headers.items():
+ for k, v in (headers or {}).items():
request.setHeader(k, v)
request.write(content)
request.finish()
diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py
index fee2985d35..466ce722d9 100644
--- a/tests/http/federation/test_srv_resolver.py
+++ b/tests/http/federation/test_srv_resolver.py
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from mock import Mock
+from unittest.mock import Mock
from twisted.internet import defer
from twisted.internet.defer import Deferred
diff --git a/tests/http/test_client.py b/tests/http/test_client.py
index 0ce181a51e..7e2f2a01cc 100644
--- a/tests/http/test_client.py
+++ b/tests/http/test_client.py
@@ -13,8 +13,7 @@
# limitations under the License.
from io import BytesIO
-
-from mock import Mock
+from unittest.mock import Mock
from netaddr import IPSet
diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py
index 9c52c8fdca..21c1297171 100644
--- a/tests/http/test_fedclient.py
+++ b/tests/http/test_fedclient.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from mock import Mock
+from unittest.mock import Mock
from netaddr import IPSet
from parameterized import parameterized
diff --git a/tests/http/test_servlet.py b/tests/http/test_servlet.py
index 45089158ce..f979c96f7c 100644
--- a/tests/http/test_servlet.py
+++ b/tests/http/test_servlet.py
@@ -14,8 +14,7 @@
# limitations under the License.
import json
from io import BytesIO
-
-from mock import Mock
+from unittest.mock import Mock
from synapse.api.errors import SynapseError
from synapse.http.servlet import (
diff --git a/tests/http/test_simple_client.py b/tests/http/test_simple_client.py
index a1cf0862d4..cc4cae320d 100644
--- a/tests/http/test_simple_client.py
+++ b/tests/http/test_simple_client.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from mock import Mock
+from unittest.mock import Mock
from netaddr import IPSet
diff --git a/tests/logging/test_terse_json.py b/tests/logging/test_terse_json.py
index 48a74e2eee..ecf873e2ab 100644
--- a/tests/logging/test_terse_json.py
+++ b/tests/logging/test_terse_json.py
@@ -12,15 +12,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import json
import logging
-from io import StringIO
+from io import BytesIO, StringIO
+from unittest.mock import Mock, patch
+
+from twisted.web.server import Request
+from synapse.http.site import SynapseRequest
from synapse.logging._terse_json import JsonFormatter, TerseJsonFormatter
from synapse.logging.context import LoggingContext, LoggingContextFilter
from tests.logging import LoggerCleanupMixin
+from tests.server import FakeChannel
from tests.unittest import TestCase
@@ -120,7 +124,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
handler.addFilter(LoggingContextFilter())
logger = self.get_logger(handler)
- with LoggingContext(request="test"):
+ with LoggingContext("name"):
logger.info("Hello there, %s!", "wally")
log = self.get_log_line()
@@ -134,4 +138,63 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
]
self.assertCountEqual(log.keys(), expected_log_keys)
self.assertEqual(log["log"], "Hello there, wally!")
- self.assertEqual(log["request"], "test")
+ self.assertEqual(log["request"], "name")
+
+ def test_with_request_context(self):
+ """
+ Information from the logging context request should be added to the JSON response.
+ """
+ handler = logging.StreamHandler(self.output)
+ handler.setFormatter(JsonFormatter())
+ handler.addFilter(LoggingContextFilter())
+ logger = self.get_logger(handler)
+
+ # A full request isn't needed here.
+ site = Mock(spec=["site_tag", "server_version_string", "getResourceFor"])
+ site.site_tag = "test-site"
+ site.server_version_string = "Server v1"
+ request = SynapseRequest(FakeChannel(site, None))
+ # Call requestReceived to finish instantiating the object.
+ request.content = BytesIO()
+ # Partially skip some of the internal processing of SynapseRequest.
+ request._started_processing = Mock()
+ request.request_metrics = Mock(spec=["name"])
+ with patch.object(Request, "render"):
+ request.requestReceived(b"POST", b"/_matrix/client/versions", b"1.1")
+
+ # Also set the requester to ensure the processing works.
+ request.requester = "@foo:test"
+
+ with LoggingContext(
+ request.get_request_id(), parent_context=request.logcontext
+ ):
+ logger.info("Hello there, %s!", "wally")
+
+ log = self.get_log_line()
+
+ # The terse logger includes additional request information, if possible.
+ expected_log_keys = [
+ "log",
+ "level",
+ "namespace",
+ "request",
+ "ip_address",
+ "site_tag",
+ "requester",
+ "authenticated_entity",
+ "method",
+ "url",
+ "protocol",
+ "user_agent",
+ ]
+ self.assertCountEqual(log.keys(), expected_log_keys)
+ self.assertEqual(log["log"], "Hello there, wally!")
+ self.assertTrue(log["request"].startswith("POST-"))
+ self.assertEqual(log["ip_address"], "127.0.0.1")
+ self.assertEqual(log["site_tag"], "test-site")
+ self.assertEqual(log["requester"], "@foo:test")
+ self.assertEqual(log["authenticated_entity"], "@foo:test")
+ self.assertEqual(log["method"], "POST")
+ self.assertEqual(log["url"], "/_matrix/client/versions")
+ self.assertEqual(log["protocol"], "1.1")
+ self.assertEqual(log["user_agent"], "")
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index edacd1b566..349f93560e 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -12,27 +12,39 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from mock import Mock
+from unittest.mock import Mock
+from synapse.api.constants import EduTypes
from synapse.events import EventBase
+from synapse.federation.units import Transaction
+from synapse.handlers.presence import UserPresenceState
from synapse.rest import admin
-from synapse.rest.client.v1 import login, room
+from synapse.rest.client.v1 import login, presence, room
from synapse.types import create_requester
-from tests.unittest import HomeserverTestCase
+from tests.events.test_presence_router import send_presence_update, sync_presence
+from tests.test_utils.event_injection import inject_member_event
+from tests.unittest import FederatingHomeserverTestCase, override_config
-class ModuleApiTestCase(HomeserverTestCase):
+class ModuleApiTestCase(FederatingHomeserverTestCase):
servlets = [
admin.register_servlets,
login.register_servlets,
room.register_servlets,
+ presence.register_servlets,
]
def prepare(self, reactor, clock, homeserver):
self.store = homeserver.get_datastore()
self.module_api = homeserver.get_module_api()
self.event_creation_handler = homeserver.get_event_creation_handler()
+ self.sync_handler = homeserver.get_sync_handler()
+
+ def make_homeserver(self, reactor, clock):
+ return self.setup_test_homeserver(
+ federation_transport_client=Mock(spec=["send_transaction"]),
+ )
def test_can_register_user(self):
"""Tests that an external module can register a user"""
@@ -205,3 +217,161 @@ class ModuleApiTestCase(HomeserverTestCase):
)
)
self.assertFalse(is_in_public_rooms)
+
+ # The ability to send federation is required by send_local_online_presence_to.
+ @override_config({"send_federation": True})
+ def test_send_local_online_presence_to(self):
+ """Tests that send_local_presence_to_users sends local online presence to local users."""
+ # Create a user who will send presence updates
+ self.presence_receiver_id = self.register_user("presence_receiver", "monkey")
+ self.presence_receiver_tok = self.login("presence_receiver", "monkey")
+
+ # And another user that will send presence updates out
+ self.presence_sender_id = self.register_user("presence_sender", "monkey")
+ self.presence_sender_tok = self.login("presence_sender", "monkey")
+
+ # Put them in a room together so they will receive each other's presence updates
+ room_id = self.helper.create_room_as(
+ self.presence_receiver_id,
+ tok=self.presence_receiver_tok,
+ )
+ self.helper.join(room_id, self.presence_sender_id, tok=self.presence_sender_tok)
+
+ # Presence sender comes online
+ send_presence_update(
+ self,
+ self.presence_sender_id,
+ self.presence_sender_tok,
+ "online",
+ "I'm online!",
+ )
+
+ # Presence receiver should have received it
+ presence_updates, sync_token = sync_presence(self, self.presence_receiver_id)
+ self.assertEqual(len(presence_updates), 1)
+
+ presence_update = presence_updates[0] # type: UserPresenceState
+ self.assertEqual(presence_update.user_id, self.presence_sender_id)
+ self.assertEqual(presence_update.state, "online")
+
+ # Syncing again should result in no presence updates
+ presence_updates, sync_token = sync_presence(
+ self, self.presence_receiver_id, sync_token
+ )
+ self.assertEqual(len(presence_updates), 0)
+
+ # Trigger sending local online presence
+ self.get_success(
+ self.module_api.send_local_online_presence_to(
+ [
+ self.presence_receiver_id,
+ ]
+ )
+ )
+
+ # Presence receiver should have received online presence again
+ presence_updates, sync_token = sync_presence(
+ self, self.presence_receiver_id, sync_token
+ )
+ self.assertEqual(len(presence_updates), 1)
+
+ presence_update = presence_updates[0] # type: UserPresenceState
+ self.assertEqual(presence_update.user_id, self.presence_sender_id)
+ self.assertEqual(presence_update.state, "online")
+
+ # Presence sender goes offline
+ send_presence_update(
+ self,
+ self.presence_sender_id,
+ self.presence_sender_tok,
+ "offline",
+ "I slink back into the darkness.",
+ )
+
+ # Trigger sending local online presence
+ self.get_success(
+ self.module_api.send_local_online_presence_to(
+ [
+ self.presence_receiver_id,
+ ]
+ )
+ )
+
+ # Presence receiver should *not* have received offline state
+ presence_updates, sync_token = sync_presence(
+ self, self.presence_receiver_id, sync_token
+ )
+ self.assertEqual(len(presence_updates), 0)
+
+ @override_config({"send_federation": True})
+ def test_send_local_online_presence_to_federation(self):
+ """Tests that send_local_presence_to_users sends local online presence to remote users."""
+ # Create a user who will send presence updates
+ self.presence_sender_id = self.register_user("presence_sender", "monkey")
+ self.presence_sender_tok = self.login("presence_sender", "monkey")
+
+ # And a room they're a part of
+ room_id = self.helper.create_room_as(
+ self.presence_sender_id,
+ tok=self.presence_sender_tok,
+ )
+
+ # Mark them as online
+ send_presence_update(
+ self,
+ self.presence_sender_id,
+ self.presence_sender_tok,
+ "online",
+ "I'm online!",
+ )
+
+ # Make up a remote user to send presence to
+ remote_user_id = "@far_away_person:island"
+
+ # Create a join membership event for the remote user into the room.
+ # This allows presence information to flow from one user to the other.
+ self.get_success(
+ inject_member_event(
+ self.hs,
+ room_id,
+ sender=remote_user_id,
+ target=remote_user_id,
+ membership="join",
+ )
+ )
+
+ # The remote user would have received the existing room members' presence
+ # when they joined the room.
+ #
+ # Thus we reset the mock, and try sending online local user
+ # presence again
+ self.hs.get_federation_transport_client().send_transaction.reset_mock()
+
+ # Broadcast local user online presence
+ self.get_success(
+ self.module_api.send_local_online_presence_to([remote_user_id])
+ )
+
+ # Check that a presence update was sent as part of a federation transaction
+ found_update = False
+ calls = (
+ self.hs.get_federation_transport_client().send_transaction.call_args_list
+ )
+ for call in calls:
+ call_args = call[0]
+ federation_transaction = call_args[0] # type: Transaction
+
+ # Get the sent EDUs in this transaction
+ edus = federation_transaction.get_dict()["edus"]
+
+ for edu in edus:
+ # Make sure we're only checking presence-type EDUs
+ if edu["edu_type"] != EduTypes.Presence:
+ continue
+
+ # EDUs can contain multiple presence updates
+ for presence_update in edu["content"]["push"]:
+ if presence_update["user_id"] == self.presence_sender_id:
+ found_update = True
+
+ self.assertTrue(found_update)
diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index a3b304d316..f590e8d21c 100644
--- a/tests/push/test_http.py
+++ b/tests/push/test_http.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from mock import Mock
+from unittest.mock import Mock
from twisted.internet.defer import Deferred
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 1d4a592862..aff19d9fb3 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -266,7 +266,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
return resource
def make_worker_hs(
- self, worker_app: str, extra_config: dict = {}, **kwargs
+ self, worker_app: str, extra_config: Optional[dict] = None, **kwargs
) -> HomeServer:
"""Make a new worker HS instance, correctly connecting replcation
stream to the master HS.
@@ -283,7 +283,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
config = self._get_worker_hs_config()
config["worker_app"] = worker_app
- config.update(extra_config)
+ config.update(extra_config or {})
worker_hs = self.setup_test_homeserver(
homeserver_to_use=GenericWorkerServer,
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index 56497b8476..83e89383f6 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from mock import Mock
+from unittest.mock import Mock
from tests.replication._base import BaseStreamTestCase
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 0ceb0f935c..db80a0bdbd 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from typing import Iterable, Optional
from canonicaljson import encode_canonical_json
@@ -332,15 +333,18 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
room_id=ROOM_ID,
type="m.room.message",
key=None,
- internal={},
+ internal: Optional[dict] = None,
depth=None,
- prev_events=[],
- auth_events=[],
- prev_state=[],
+ prev_events: Optional[list] = None,
+ auth_events: Optional[list] = None,
+ prev_state: Optional[list] = None,
redacts=None,
- push_actions=[],
- **content
+ push_actions: Iterable = frozenset(),
+ **content,
):
+ prev_events = prev_events or []
+ auth_events = auth_events or []
+ prev_state = prev_state or []
if depth is None:
depth = self.event_id
@@ -369,7 +373,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
if redacts is not None:
event_dict["redacts"] = redacts
- event = make_event_from_dict(event_dict, internal_metadata_dict=internal)
+ event = make_event_from_dict(event_dict, internal_metadata_dict=internal or {})
self.event_id += 1
state_handler = self.hs.get_state_handler()
diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py
index 56b062ecc1..7d848e41ff 100644
--- a/tests/replication/tcp/streams/test_receipts.py
+++ b/tests/replication/tcp/streams/test_receipts.py
@@ -15,7 +15,7 @@
# type: ignore
-from mock import Mock
+from unittest.mock import Mock
from synapse.replication.tcp.streams._base import ReceiptsStream
diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py
index ca49d4dd3a..4a0b342264 100644
--- a/tests/replication/tcp/streams/test_typing.py
+++ b/tests/replication/tcp/streams/test_typing.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from mock import Mock
+from unittest.mock import Mock
from synapse.handlers.typing import RoomMember
from synapse.replication.tcp.streams import TypingStream
diff --git a/tests/replication/test_federation_ack.py b/tests/replication/test_federation_ack.py
index 0d9e3bb11d..44ad5eec57 100644
--- a/tests/replication/test_federation_ack.py
+++ b/tests/replication/test_federation_ack.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import mock
+from unittest import mock
from synapse.app.generic_worker import GenericWorkerServer
from synapse.replication.tcp.commands import FederationAckCommand
diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py
index 2f2d117858..8ca595c3ee 100644
--- a/tests/replication/test_federation_sender_shard.py
+++ b/tests/replication/test_federation_sender_shard.py
@@ -13,8 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-
-from mock import Mock
+from unittest.mock import Mock
from synapse.api.constants import EventTypes, Membership
from synapse.events.builder import EventBuilderFactory
diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py
index ab2988a6ba..1f12bde1aa 100644
--- a/tests/replication/test_pusher_shard.py
+++ b/tests/replication/test_pusher_shard.py
@@ -13,8 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-
-from mock import Mock
+from unittest.mock import Mock
from twisted.internet import defer
diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py
index c9b773fbd2..6c2e1674cb 100644
--- a/tests/replication/test_sharded_event_persister.py
+++ b/tests/replication/test_sharded_event_persister.py
@@ -13,8 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-
-from mock import patch
+from unittest.mock import patch
from synapse.api.room_versions import RoomVersion
from synapse.rest import admin
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index 057e27372e..4abcbe3f55 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -17,8 +17,7 @@ import json
import os
import urllib.parse
from binascii import unhexlify
-
-from mock import Mock
+from unittest.mock import Mock
from twisted.internet.defer import Deferred
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index b55160b70a..85f77c0a65 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -16,8 +16,7 @@
import json
import urllib.parse
from typing import List, Optional
-
-from mock import Mock
+from unittest.mock import Mock
import synapse.rest.admin
from synapse.api.constants import EventTypes, Membership
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 79a05b519b..a7b600a1d4 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -19,8 +19,7 @@ import json
import urllib.parse
from binascii import unhexlify
from typing import List, Optional
-
-from mock import Mock
+from unittest.mock import Mock
import synapse.rest.admin
from synapse.api.constants import UserTypes
@@ -28,7 +27,7 @@ from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError
from synapse.api.room_versions import RoomVersions
from synapse.rest.client.v1 import login, logout, profile, room
from synapse.rest.client.v2_alpha import devices, sync
-from synapse.types import JsonDict
+from synapse.types import JsonDict, UserID
from tests import unittest
from tests.server import FakeSite, make_request
@@ -467,6 +466,8 @@ class UsersListTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v2/users"
def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
@@ -634,6 +635,26 @@ class UsersListTestCase(unittest.HomeserverTestCase):
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+ # unkown order_by
+ channel = self.make_request(
+ "GET",
+ self.url + "?order_by=bar",
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+
+ # invalid search order
+ channel = self.make_request(
+ "GET",
+ self.url + "?dir=bar",
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+
def test_limit(self):
"""
Testing list of users with limit
@@ -759,6 +780,103 @@ class UsersListTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(channel.json_body["users"]), 1)
self.assertNotIn("next_token", channel.json_body)
+ def test_order_by(self):
+ """
+ Testing order list with parameter `order_by`
+ """
+
+ user1 = self.register_user("user1", "pass1", admin=False, displayname="Name Z")
+ user2 = self.register_user("user2", "pass2", admin=False, displayname="Name Y")
+
+ # Modify user
+ self.get_success(self.store.set_user_deactivated_status(user1, True))
+ self.get_success(self.store.set_shadow_banned(UserID.from_string(user1), True))
+
+ # Set avatar URL to all users, that no user has a NULL value to avoid
+ # different sort order between SQlite and PostreSQL
+ self.get_success(self.store.set_profile_avatar_url("user1", "mxc://url3"))
+ self.get_success(self.store.set_profile_avatar_url("user2", "mxc://url2"))
+ self.get_success(self.store.set_profile_avatar_url("admin", "mxc://url1"))
+
+ # order by default (name)
+ self._order_test([self.admin_user, user1, user2], None)
+ self._order_test([self.admin_user, user1, user2], None, "f")
+ self._order_test([user2, user1, self.admin_user], None, "b")
+
+ # order by name
+ self._order_test([self.admin_user, user1, user2], "name")
+ self._order_test([self.admin_user, user1, user2], "name", "f")
+ self._order_test([user2, user1, self.admin_user], "name", "b")
+
+ # order by displayname
+ self._order_test([user2, user1, self.admin_user], "displayname")
+ self._order_test([user2, user1, self.admin_user], "displayname", "f")
+ self._order_test([self.admin_user, user1, user2], "displayname", "b")
+
+ # order by is_guest
+ # like sort by ascending name, as no guest user here
+ self._order_test([self.admin_user, user1, user2], "is_guest")
+ self._order_test([self.admin_user, user1, user2], "is_guest", "f")
+ self._order_test([self.admin_user, user1, user2], "is_guest", "b")
+
+ # order by admin
+ self._order_test([user1, user2, self.admin_user], "admin")
+ self._order_test([user1, user2, self.admin_user], "admin", "f")
+ self._order_test([self.admin_user, user1, user2], "admin", "b")
+
+ # order by deactivated
+ self._order_test([self.admin_user, user2, user1], "deactivated")
+ self._order_test([self.admin_user, user2, user1], "deactivated", "f")
+ self._order_test([user1, self.admin_user, user2], "deactivated", "b")
+
+ # order by user_type
+ # like sort by ascending name, as no special user type here
+ self._order_test([self.admin_user, user1, user2], "user_type")
+ self._order_test([self.admin_user, user1, user2], "user_type", "f")
+ self._order_test([self.admin_user, user1, user2], "is_guest", "b")
+
+ # order by shadow_banned
+ self._order_test([self.admin_user, user2, user1], "shadow_banned")
+ self._order_test([self.admin_user, user2, user1], "shadow_banned", "f")
+ self._order_test([user1, self.admin_user, user2], "shadow_banned", "b")
+
+ # order by avatar_url
+ self._order_test([self.admin_user, user2, user1], "avatar_url")
+ self._order_test([self.admin_user, user2, user1], "avatar_url", "f")
+ self._order_test([user1, user2, self.admin_user], "avatar_url", "b")
+
+ def _order_test(
+ self,
+ expected_user_list: List[str],
+ order_by: Optional[str],
+ dir: Optional[str] = None,
+ ):
+ """Request the list of users in a certain order. Assert that order is what
+ we expect
+ Args:
+ expected_user_list: The list of user_id in the order we expect to get
+ back from the server
+ order_by: The type of ordering to give the server
+ dir: The direction of ordering to give the server
+ """
+
+ url = self.url + "?deactivated=true&"
+ if order_by is not None:
+ url += "order_by=%s&" % (order_by,)
+ if dir is not None and dir in ("b", "f"):
+ url += "dir=%s" % (dir,)
+ channel = self.make_request(
+ "GET",
+ url.encode("ascii"),
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(channel.json_body["total"], len(expected_user_list))
+
+ returned_order = [row["name"] for row in channel.json_body["users"]]
+ self.assertEqual(expected_user_list, returned_order)
+ self._check_fields(channel.json_body["users"])
+
def _check_fields(self, content: JsonDict):
"""Checks that the expected user attributes are present in content
Args:
@@ -2908,3 +3026,287 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
# Ensure the user is shadow-banned (and the cache was cleared).
result = self.get_success(self.store.get_user_by_access_token(other_user_token))
self.assertTrue(result.shadow_banned)
+
+
+class RateLimitTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+ self.url = (
+ "/_synapse/admin/v1/users/%s/override_ratelimit"
+ % urllib.parse.quote(self.other_user)
+ )
+
+ def test_no_auth(self):
+ """
+ Try to get information of a user without authentication.
+ """
+ channel = self.make_request("GET", self.url, b"{}")
+
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ channel = self.make_request("POST", self.url, b"{}")
+
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ channel = self.make_request("DELETE", self.url, b"{}")
+
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_requester_is_no_admin(self):
+ """
+ If the user is not a server admin, an error is returned.
+ """
+ other_user_token = self.login("user", "pass")
+
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=other_user_token,
+ )
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=other_user_token,
+ )
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ channel = self.make_request(
+ "DELETE",
+ self.url,
+ access_token=other_user_token,
+ )
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_user_does_not_exist(self):
+ """
+ Tests that a lookup for a user that does not exist returns a 404
+ """
+ url = "/_synapse/admin/v1/users/@unknown_person:test/override_ratelimit"
+
+ channel = self.make_request(
+ "GET",
+ url,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+ channel = self.make_request(
+ "POST",
+ url,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+ channel = self.make_request(
+ "DELETE",
+ url,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+ def test_user_is_not_local(self):
+ """
+ Tests that a lookup for a user that is not a local returns a 400
+ """
+ url = (
+ "/_synapse/admin/v1/users/@unknown_person:unknown_domain/override_ratelimit"
+ )
+
+ channel = self.make_request(
+ "GET",
+ url,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual("Can only lookup local users", channel.json_body["error"])
+
+ channel = self.make_request(
+ "POST",
+ url,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(
+ "Only local users can be ratelimited", channel.json_body["error"]
+ )
+
+ channel = self.make_request(
+ "DELETE",
+ url,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(
+ "Only local users can be ratelimited", channel.json_body["error"]
+ )
+
+ def test_invalid_parameter(self):
+ """
+ If parameters are invalid, an error is returned.
+ """
+ # messages_per_second is a string
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={"messages_per_second": "string"},
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
+ # messages_per_second is negative
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={"messages_per_second": -1},
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
+ # burst_count is a string
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={"burst_count": "string"},
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
+ # burst_count is negative
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={"burst_count": -1},
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
+ def test_return_zero_when_null(self):
+ """
+ If values in database are `null` API should return an int `0`
+ """
+
+ self.get_success(
+ self.store.db_pool.simple_upsert(
+ table="ratelimit_override",
+ keyvalues={"user_id": self.other_user},
+ values={
+ "messages_per_second": None,
+ "burst_count": None,
+ },
+ )
+ )
+
+ # request status
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(0, channel.json_body["messages_per_second"])
+ self.assertEqual(0, channel.json_body["burst_count"])
+
+ def test_success(self):
+ """
+ Rate-limiting (set/update/delete) should succeed for an admin.
+ """
+ # request status
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertNotIn("messages_per_second", channel.json_body)
+ self.assertNotIn("burst_count", channel.json_body)
+
+ # set ratelimit
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={"messages_per_second": 10, "burst_count": 11},
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(10, channel.json_body["messages_per_second"])
+ self.assertEqual(11, channel.json_body["burst_count"])
+
+ # update ratelimit
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={"messages_per_second": 20, "burst_count": 21},
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(20, channel.json_body["messages_per_second"])
+ self.assertEqual(21, channel.json_body["burst_count"])
+
+ # request status
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(20, channel.json_body["messages_per_second"])
+ self.assertEqual(21, channel.json_body["burst_count"])
+
+ # delete ratelimit
+ channel = self.make_request(
+ "DELETE",
+ self.url,
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertNotIn("messages_per_second", channel.json_body)
+ self.assertNotIn("burst_count", channel.json_body)
+
+ # request status
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertNotIn("messages_per_second", channel.json_body)
+ self.assertNotIn("burst_count", channel.json_body)
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index b8285f3240..be1211dbce 100644
--- a/tests/rest/client/test_retention.py
+++ b/tests/rest/client/test_retention.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from mock import Mock
+from unittest.mock import Mock
from synapse.api.constants import EventTypes
from synapse.rest import admin
diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py
index d2cce44032..288ee12888 100644
--- a/tests/rest/client/test_shadow_banned.py
+++ b/tests/rest/client/test_shadow_banned.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from mock import Mock, patch
+from unittest.mock import Mock, patch
import synapse.rest.admin
from synapse.api.constants import EventTypes
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index bf39014277..a7ebe0c3e9 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -14,8 +14,7 @@
# limitations under the License.
import threading
from typing import Dict
-
-from mock import Mock
+from unittest.mock import Mock
from synapse.events import EventBase
from synapse.module_api import ModuleApi
diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py
index 171632e195..3b5747cb12 100644
--- a/tests/rest/client/test_transactions.py
+++ b/tests/rest/client/test_transactions.py
@@ -1,4 +1,4 @@
-from mock import Mock, call
+from unittest.mock import Mock, call
from twisted.internet import defer, reactor
diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py
index 2ae896db1e..87a18d2cb9 100644
--- a/tests/rest/client/v1/test_events.py
+++ b/tests/rest/client/v1/test_events.py
@@ -15,7 +15,7 @@
""" Tests REST events for /events paths."""
-from mock import Mock
+from unittest.mock import Mock
import synapse.rest.admin
from synapse.rest.client.v1 import events, login, room
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index 988821b16f..c7b79ab8a7 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -16,10 +16,9 @@
import time
import urllib.parse
from typing import Any, Dict, List, Optional, Union
+from unittest.mock import Mock
from urllib.parse import urlencode
-from mock import Mock
-
import pymacaroons
from twisted.web.resource import Resource
diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py
index 94a5154834..c136827f79 100644
--- a/tests/rest/client/v1/test_presence.py
+++ b/tests/rest/client/v1/test_presence.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from mock import Mock
+from unittest.mock import Mock
from twisted.internet import defer
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index ed65f645fc..4df20c90fd 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -19,10 +19,10 @@
"""Tests REST events for /rooms paths."""
import json
+from typing import Iterable
+from unittest.mock import Mock
from urllib import parse as urlparse
-from mock import Mock
-
import synapse.rest.admin
from synapse.api.constants import EventContentFields, EventTypes, Membership
from synapse.handlers.pagination import PurgeStatus
@@ -207,7 +207,9 @@ class RoomPermissionsTestCase(RoomBase):
)
self.assertEquals(403, channel.code, msg=channel.result["body"])
- def _test_get_membership(self, room=None, members=[], expect_code=None):
+ def _test_get_membership(
+ self, room=None, members: Iterable = frozenset(), expect_code=None
+ ):
for member in members:
path = "/rooms/%s/state/m.room.member/%s" % (room, member)
channel = self.make_request("GET", path)
diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py
index 329dbd06de..0b8f565121 100644
--- a/tests/rest/client/v1/test_typing.py
+++ b/tests/rest/client/v1/test_typing.py
@@ -16,7 +16,7 @@
"""Tests REST events for /rooms paths."""
-from mock import Mock
+from unittest.mock import Mock
from synapse.rest.client.v1 import room
from synapse.types import UserID
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 946740aa5d..a6a292b20c 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -21,8 +21,7 @@ import re
import time
import urllib.parse
from typing import Any, Dict, Mapping, MutableMapping, Optional
-
-from mock import patch
+from unittest.mock import patch
import attr
@@ -132,7 +131,7 @@ class RestHelper:
src: str,
targ: str,
membership: str,
- extra_data: dict = {},
+ extra_data: Optional[dict] = None,
tok: Optional[str] = None,
expect_code: int = 200,
) -> None:
@@ -156,7 +155,7 @@ class RestHelper:
path = path + "?access_token=%s" % tok
data = {"membership": membership}
- data.update(extra_data)
+ data.update(extra_data or {})
channel = make_request(
self.hs.get_reactor(),
@@ -187,7 +186,13 @@ class RestHelper:
)
def send_event(
- self, room_id, type, content={}, txn_id=None, tok=None, expect_code=200
+ self,
+ room_id,
+ type,
+ content: Optional[dict] = None,
+ txn_id=None,
+ tok=None,
+ expect_code=200,
):
if txn_id is None:
txn_id = "m%s" % (str(time.time()))
@@ -201,7 +206,7 @@ class RestHelper:
self.site,
"PUT",
path,
- json.dumps(content).encode("utf8"),
+ json.dumps(content or {}).encode("utf8"),
)
assert (
diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py
index 9734a2159a..ed433d9333 100644
--- a/tests/rest/client/v2_alpha/test_auth.py
+++ b/tests/rest/client/v2_alpha/test_auth.py
@@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Union
+from typing import Optional, Union
from twisted.internet.defer import succeed
@@ -74,7 +74,10 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
return channel
def recaptcha(
- self, session: str, expected_post_response: int, post_session: str = None
+ self,
+ session: str,
+ expected_post_response: int,
+ post_session: Optional[str] = None,
) -> None:
"""Get and respond to a fallback recaptcha. Returns the second request."""
if post_session is None:
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 2d4ce871eb..41e52c701f 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -14,7 +14,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import datetime
import json
import os
@@ -28,7 +27,7 @@ import pkg_resources
from twisted.internet import defer
import synapse.rest.admin
-from synapse.api.constants import LoginType
+from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType
from synapse.api.errors import Codes
from synapse.appservice import ApplicationService
from synapse.rest.client.v1 import login, logout
@@ -65,7 +64,9 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
)
self.hs.get_datastore().services_cache.append(appservice)
- request_data = json.dumps({"username": "as_user_kermit"})
+ request_data = json.dumps(
+ {"username": "as_user_kermit", "type": APP_SERVICE_REGISTRATION_TYPE}
+ )
channel = self.make_request(
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
@@ -75,9 +76,31 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
det_data = {"user_id": user_id, "home_server": self.hs.hostname}
self.assertDictContainsSubset(det_data, channel.json_body)
+ def test_POST_appservice_registration_no_type(self):
+ as_token = "i_am_an_app_service"
+
+ appservice = ApplicationService(
+ as_token,
+ self.hs.config.server_name,
+ id="1234",
+ namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
+ sender="@as:test",
+ )
+
+ self.hs.get_datastore().services_cache.append(appservice)
+ request_data = json.dumps({"username": "as_user_kermit"})
+
+ channel = self.make_request(
+ b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
+ )
+
+ self.assertEquals(channel.result["code"], b"400", channel.result)
+
def test_POST_appservice_registration_invalid(self):
self.appservice = None # no application service exists
- request_data = json.dumps({"username": "kermit"})
+ request_data = json.dumps(
+ {"username": "kermit", "type": APP_SERVICE_REGISTRATION_TYPE}
+ )
channel = self.make_request(
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
)
diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py
index e7bb5583fc..21ee436b91 100644
--- a/tests/rest/client/v2_alpha/test_relations.py
+++ b/tests/rest/client/v2_alpha/test_relations.py
@@ -16,6 +16,7 @@
import itertools
import json
import urllib
+from typing import Optional
from synapse.api.constants import EventTypes, RelationTypes
from synapse.rest import admin
@@ -681,7 +682,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
relation_type,
event_type,
key=None,
- content={},
+ content: Optional[dict] = None,
access_token=None,
parent_id=None,
):
@@ -713,7 +714,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
"POST",
"/_matrix/client/unstable/rooms/%s/send_relation/%s/%s/%s%s"
% (self.room, original_id, relation_type, event_type, query),
- json.dumps(content).encode("utf-8"),
+ json.dumps(content or {}).encode("utf-8"),
access_token=access_token,
)
return channel
diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py
index 9d0d0ef414..eb8687ce68 100644
--- a/tests/rest/key/v2/test_remote_key_resource.py
+++ b/tests/rest/key/v2/test_remote_key_resource.py
@@ -14,8 +14,7 @@
# limitations under the License.
import urllib.parse
from io import BytesIO, StringIO
-
-from mock import Mock
+from unittest.mock import Mock
import signedjson.key
from canonicaljson import encode_canonical_json
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index 9f77125fd4..375f0b7977 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -18,10 +18,9 @@ import tempfile
from binascii import unhexlify
from io import BytesIO
from typing import Optional
+from unittest.mock import Mock
from urllib import parse
-from mock import Mock
-
import attr
from parameterized import parameterized_class
from PIL import Image as Image
diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py
index 6968502433..9067463e54 100644
--- a/tests/rest/media/v1/test_url_preview.py
+++ b/tests/rest/media/v1/test_url_preview.py
@@ -15,8 +15,7 @@
import json
import os
import re
-
-from mock import patch
+from unittest.mock import patch
from twisted.internet._resolver import HostResolution
from twisted.internet.address import IPv4Address, IPv6Address
diff --git a/tests/scripts/test_new_matrix_user.py b/tests/scripts/test_new_matrix_user.py
index 6f56893f5e..885b95a51f 100644
--- a/tests/scripts/test_new_matrix_user.py
+++ b/tests/scripts/test_new_matrix_user.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from mock import Mock
+from unittest.mock import Mock
from synapse._scripts.register_new_matrix_user import request_registration
diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index d40d65b06a..450b4ec710 100644
--- a/tests/server_notices/test_resource_limits_server_notices.py
+++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from mock import Mock
+from unittest.mock import Mock
from twisted.internet import defer
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 1ce29af5fd..e755a4db62 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -15,8 +15,7 @@
import json
import os
import tempfile
-
-from mock import Mock
+from unittest.mock import Mock
import yaml
diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index 1b4fae0bb5..069db0edc4 100644
--- a/tests/storage/test_background_update.py
+++ b/tests/storage/test_background_update.py
@@ -1,4 +1,4 @@
-from mock import Mock
+from unittest.mock import Mock
from synapse.storage.background_updates import BackgroundUpdater
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index eac7e4dcd2..54e9e7f6fe 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -15,8 +15,7 @@
from collections import OrderedDict
-
-from mock import Mock
+from unittest.mock import Mock
from twisted.internet import defer
diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index 7791138688..b02fb32ced 100644
--- a/tests/storage/test_cleanup_extrems.py
+++ b/tests/storage/test_cleanup_extrems.py
@@ -14,9 +14,7 @@
# limitations under the License.
import os.path
-from unittest.mock import patch
-
-from mock import Mock
+from unittest.mock import Mock, patch
import synapse.rest.admin
from synapse.api.constants import EventTypes
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index 34e6526097..f7f75320ba 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from mock import Mock
+from unittest.mock import Mock
import synapse.rest.admin
from synapse.http.site import XForwardedForRequest
@@ -390,7 +390,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
class ClientIpAuthTestCase(unittest.HomeserverTestCase):
servlets = [
- synapse.rest.admin.register_servlets_for_client_rest_resource,
+ synapse.rest.admin.register_servlets,
login.register_servlets,
]
@@ -434,7 +434,7 @@ class ClientIpAuthTestCase(unittest.HomeserverTestCase):
self.reactor,
self.site,
"GET",
- "/_synapse/admin/v1/users/" + self.user_id,
+ "/_synapse/admin/v2/users/" + self.user_id,
access_token=access_token,
custom_headers=headers1.items(),
**make_request_args,
diff --git a/tests/storage/test_database.py b/tests/storage/test_database.py
index 5a77c84962..a906d30e73 100644
--- a/tests/storage/test_database.py
+++ b/tests/storage/test_database.py
@@ -36,17 +36,6 @@ def _stub_db_engine(**kwargs) -> BaseDatabaseEngine:
class TupleComparisonClauseTestCase(unittest.TestCase):
def test_native_tuple_comparison(self):
- db_engine = _stub_db_engine(supports_tuple_comparison=True)
- clause, args = make_tuple_comparison_clause(db_engine, [("a", 1), ("b", 2)])
+ clause, args = make_tuple_comparison_clause([("a", 1), ("b", 2)])
self.assertEqual(clause, "(a,b) > (?,?)")
self.assertEqual(args, [1, 2])
-
- def test_emulated_tuple_comparison(self):
- db_engine = _stub_db_engine(supports_tuple_comparison=False)
- clause, args = make_tuple_comparison_clause(
- db_engine, [("a", 1), ("b", 2), ("c", 3)]
- )
- self.assertEqual(
- clause, "(a >= ? AND (a > ? OR (b >= ? AND (b > ? OR c > ?))))"
- )
- self.assertEqual(args, [1, 1, 2, 2, 3])
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index dabc1c5f09..ef4cf8d0f1 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2016 OpenMarket Ltd
+# Copyright 2016-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,32 +13,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
-
import synapse.api.errors
-import tests.unittest
-import tests.utils
-
-
-class DeviceStoreTestCase(tests.unittest.TestCase):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.store = None # type: synapse.storage.DataStore
+from tests.unittest import HomeserverTestCase
- @defer.inlineCallbacks
- def setUp(self):
- hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
+class DeviceStoreTestCase(HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
- @defer.inlineCallbacks
def test_store_new_device(self):
- yield defer.ensureDeferred(
+ self.get_success(
self.store.store_device("user_id", "device_id", "display_name")
)
- res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
+ res = self.get_success(self.store.get_device("user_id", "device_id"))
self.assertDictContainsSubset(
{
"user_id": "user_id",
@@ -48,19 +37,18 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
res,
)
- @defer.inlineCallbacks
def test_get_devices_by_user(self):
- yield defer.ensureDeferred(
+ self.get_success(
self.store.store_device("user_id", "device1", "display_name 1")
)
- yield defer.ensureDeferred(
+ self.get_success(
self.store.store_device("user_id", "device2", "display_name 2")
)
- yield defer.ensureDeferred(
+ self.get_success(
self.store.store_device("user_id2", "device3", "display_name 3")
)
- res = yield defer.ensureDeferred(self.store.get_devices_by_user("user_id"))
+ res = self.get_success(self.store.get_devices_by_user("user_id"))
self.assertEqual(2, len(res.keys()))
self.assertDictContainsSubset(
{
@@ -79,43 +67,41 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
res["device2"],
)
- @defer.inlineCallbacks
def test_count_devices_by_users(self):
- yield defer.ensureDeferred(
+ self.get_success(
self.store.store_device("user_id", "device1", "display_name 1")
)
- yield defer.ensureDeferred(
+ self.get_success(
self.store.store_device("user_id", "device2", "display_name 2")
)
- yield defer.ensureDeferred(
+ self.get_success(
self.store.store_device("user_id2", "device3", "display_name 3")
)
- res = yield defer.ensureDeferred(self.store.count_devices_by_users())
+ res = self.get_success(self.store.count_devices_by_users())
self.assertEqual(0, res)
- res = yield defer.ensureDeferred(self.store.count_devices_by_users(["unknown"]))
+ res = self.get_success(self.store.count_devices_by_users(["unknown"]))
self.assertEqual(0, res)
- res = yield defer.ensureDeferred(self.store.count_devices_by_users(["user_id"]))
+ res = self.get_success(self.store.count_devices_by_users(["user_id"]))
self.assertEqual(2, res)
- res = yield defer.ensureDeferred(
+ res = self.get_success(
self.store.count_devices_by_users(["user_id", "user_id2"])
)
self.assertEqual(3, res)
- @defer.inlineCallbacks
def test_get_device_updates_by_remote(self):
device_ids = ["device_id1", "device_id2"]
# Add two device updates with a single stream_id
- yield defer.ensureDeferred(
+ self.get_success(
self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"])
)
# Get all device updates ever meant for this remote
- now_stream_id, device_updates = yield defer.ensureDeferred(
+ now_stream_id, device_updates = self.get_success(
self.store.get_device_updates_by_remote("somehost", -1, limit=100)
)
@@ -131,37 +117,35 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
}
self.assertEqual(received_device_ids, set(expected_device_ids))
- @defer.inlineCallbacks
def test_update_device(self):
- yield defer.ensureDeferred(
+ self.get_success(
self.store.store_device("user_id", "device_id", "display_name 1")
)
- res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
+ res = self.get_success(self.store.get_device("user_id", "device_id"))
self.assertEqual("display_name 1", res["display_name"])
# do a no-op first
- yield defer.ensureDeferred(self.store.update_device("user_id", "device_id"))
- res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
+ self.get_success(self.store.update_device("user_id", "device_id"))
+ res = self.get_success(self.store.get_device("user_id", "device_id"))
self.assertEqual("display_name 1", res["display_name"])
# do the update
- yield defer.ensureDeferred(
+ self.get_success(
self.store.update_device(
"user_id", "device_id", new_display_name="display_name 2"
)
)
# check it worked
- res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
+ res = self.get_success(self.store.get_device("user_id", "device_id"))
self.assertEqual("display_name 2", res["display_name"])
- @defer.inlineCallbacks
def test_update_unknown_device(self):
- with self.assertRaises(synapse.api.errors.StoreError) as cm:
- yield defer.ensureDeferred(
- self.store.update_device(
- "user_id", "unknown_device_id", new_display_name="display_name 2"
- )
- )
- self.assertEqual(404, cm.exception.code)
+ exc = self.get_failure(
+ self.store.update_device(
+ "user_id", "unknown_device_id", new_display_name="display_name 2"
+ ),
+ synapse.api.errors.StoreError,
+ )
+ self.assertEqual(404, exc.value.code)
diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py
index da93ca3980..0db233fd68 100644
--- a/tests/storage/test_directory.py
+++ b/tests/storage/test_directory.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,28 +13,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from twisted.internet import defer
-
from synapse.types import RoomAlias, RoomID
-from tests import unittest
-from tests.utils import setup_test_homeserver
+from tests.unittest import HomeserverTestCase
-class DirectoryStoreTestCase(unittest.TestCase):
- @defer.inlineCallbacks
- def setUp(self):
- hs = yield setup_test_homeserver(self.addCleanup)
-
+class DirectoryStoreTestCase(HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.room = RoomID.from_string("!abcde:test")
self.alias = RoomAlias.from_string("#my-room:test")
- @defer.inlineCallbacks
def test_room_to_alias(self):
- yield defer.ensureDeferred(
+ self.get_success(
self.store.create_room_alias_association(
room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
)
@@ -42,16 +34,11 @@ class DirectoryStoreTestCase(unittest.TestCase):
self.assertEquals(
["#my-room:test"],
- (
- yield defer.ensureDeferred(
- self.store.get_aliases_for_room(self.room.to_string())
- )
- ),
+ (self.get_success(self.store.get_aliases_for_room(self.room.to_string()))),
)
- @defer.inlineCallbacks
def test_alias_to_room(self):
- yield defer.ensureDeferred(
+ self.get_success(
self.store.create_room_alias_association(
room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
)
@@ -59,28 +46,19 @@ class DirectoryStoreTestCase(unittest.TestCase):
self.assertObjectHasAttributes(
{"room_id": self.room.to_string(), "servers": ["test"]},
- (
- yield defer.ensureDeferred(
- self.store.get_association_from_room_alias(self.alias)
- )
- ),
+ (self.get_success(self.store.get_association_from_room_alias(self.alias))),
)
- @defer.inlineCallbacks
def test_delete_alias(self):
- yield defer.ensureDeferred(
+ self.get_success(
self.store.create_room_alias_association(
room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
)
)
- room_id = yield defer.ensureDeferred(self.store.delete_room_alias(self.alias))
+ room_id = self.get_success(self.store.delete_room_alias(self.alias))
self.assertEqual(self.room.to_string(), room_id)
self.assertIsNone(
- (
- yield defer.ensureDeferred(
- self.store.get_association_from_room_alias(self.alias)
- )
- )
+ (self.get_success(self.store.get_association_from_room_alias(self.alias)))
)
diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py
index 3fc4bb13b6..1e54b940fd 100644
--- a/tests/storage/test_end_to_end_keys.py
+++ b/tests/storage/test_end_to_end_keys.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2016 OpenMarket Ltd
+# Copyright 2016-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,30 +13,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
+from tests.unittest import HomeserverTestCase
-import tests.unittest
-import tests.utils
-
-class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
- @defer.inlineCallbacks
- def setUp(self):
- hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
+class EndToEndKeyStoreTestCase(HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
- @defer.inlineCallbacks
def test_key_without_device_name(self):
now = 1470174257070
json = {"key": "value"}
- yield defer.ensureDeferred(self.store.store_device("user", "device", None))
+ self.get_success(self.store.store_device("user", "device", None))
- yield defer.ensureDeferred(
- self.store.set_e2e_device_keys("user", "device", now, json)
- )
+ self.get_success(self.store.set_e2e_device_keys("user", "device", now, json))
- res = yield defer.ensureDeferred(
+ res = self.get_success(
self.store.get_e2e_device_keys_for_cs_api((("user", "device"),))
)
self.assertIn("user", res)
@@ -44,38 +36,32 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
dev = res["user"]["device"]
self.assertDictContainsSubset(json, dev)
- @defer.inlineCallbacks
def test_reupload_key(self):
now = 1470174257070
json = {"key": "value"}
- yield defer.ensureDeferred(self.store.store_device("user", "device", None))
+ self.get_success(self.store.store_device("user", "device", None))
- changed = yield defer.ensureDeferred(
+ changed = self.get_success(
self.store.set_e2e_device_keys("user", "device", now, json)
)
self.assertTrue(changed)
# If we try to upload the same key then we should be told nothing
# changed
- changed = yield defer.ensureDeferred(
+ changed = self.get_success(
self.store.set_e2e_device_keys("user", "device", now, json)
)
self.assertFalse(changed)
- @defer.inlineCallbacks
def test_get_key_with_device_name(self):
now = 1470174257070
json = {"key": "value"}
- yield defer.ensureDeferred(
- self.store.set_e2e_device_keys("user", "device", now, json)
- )
- yield defer.ensureDeferred(
- self.store.store_device("user", "device", "display_name")
- )
+ self.get_success(self.store.set_e2e_device_keys("user", "device", now, json))
+ self.get_success(self.store.store_device("user", "device", "display_name"))
- res = yield defer.ensureDeferred(
+ res = self.get_success(
self.store.get_e2e_device_keys_for_cs_api((("user", "device"),))
)
self.assertIn("user", res)
@@ -85,29 +71,28 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
{"key": "value", "unsigned": {"device_display_name": "display_name"}}, dev
)
- @defer.inlineCallbacks
def test_multiple_devices(self):
now = 1470174257070
- yield defer.ensureDeferred(self.store.store_device("user1", "device1", None))
- yield defer.ensureDeferred(self.store.store_device("user1", "device2", None))
- yield defer.ensureDeferred(self.store.store_device("user2", "device1", None))
- yield defer.ensureDeferred(self.store.store_device("user2", "device2", None))
+ self.get_success(self.store.store_device("user1", "device1", None))
+ self.get_success(self.store.store_device("user1", "device2", None))
+ self.get_success(self.store.store_device("user2", "device1", None))
+ self.get_success(self.store.store_device("user2", "device2", None))
- yield defer.ensureDeferred(
+ self.get_success(
self.store.set_e2e_device_keys("user1", "device1", now, {"key": "json11"})
)
- yield defer.ensureDeferred(
+ self.get_success(
self.store.set_e2e_device_keys("user1", "device2", now, {"key": "json12"})
)
- yield defer.ensureDeferred(
+ self.get_success(
self.store.set_e2e_device_keys("user2", "device1", now, {"key": "json21"})
)
- yield defer.ensureDeferred(
+ self.get_success(
self.store.set_e2e_device_keys("user2", "device2", now, {"key": "json22"})
)
- res = yield defer.ensureDeferred(
+ res = self.get_success(
self.store.get_e2e_device_keys_for_cs_api(
(("user1", "device1"), ("user2", "device2"))
)
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index 485f1ee033..0289942f88 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2016 OpenMarket Ltd
+# Copyright 2016-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,12 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from mock import Mock
+from unittest.mock import Mock
-from twisted.internet import defer
-
-import tests.unittest
-import tests.utils
+from tests.unittest import HomeserverTestCase
USER_ID = "@user:example.com"
@@ -30,37 +27,31 @@ HIGHLIGHT = [
]
-class EventPushActionsStoreTestCase(tests.unittest.TestCase):
- @defer.inlineCallbacks
- def setUp(self):
- hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
+class EventPushActionsStoreTestCase(HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.persist_events_store = hs.get_datastores().persist_events
- @defer.inlineCallbacks
def test_get_unread_push_actions_for_user_in_range_for_http(self):
- yield defer.ensureDeferred(
+ self.get_success(
self.store.get_unread_push_actions_for_user_in_range_for_http(
USER_ID, 0, 1000, 20
)
)
- @defer.inlineCallbacks
def test_get_unread_push_actions_for_user_in_range_for_email(self):
- yield defer.ensureDeferred(
+ self.get_success(
self.store.get_unread_push_actions_for_user_in_range_for_email(
USER_ID, 0, 1000, 20
)
)
- @defer.inlineCallbacks
def test_count_aggregation(self):
room_id = "!foo:example.com"
user_id = "@user1235:example.com"
- @defer.inlineCallbacks
def _assert_counts(noitf_count, highlight_count):
- counts = yield defer.ensureDeferred(
+ counts = self.get_success(
self.store.db_pool.runInteraction(
"", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0
)
@@ -74,7 +65,6 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
},
)
- @defer.inlineCallbacks
def _inject_actions(stream, action):
event = Mock()
event.room_id = room_id
@@ -82,14 +72,14 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
event.internal_metadata.stream_ordering = stream
event.depth = stream
- yield defer.ensureDeferred(
+ self.get_success(
self.store.add_push_actions_to_staging(
event.event_id,
{user_id: action},
False,
)
)
- yield defer.ensureDeferred(
+ self.get_success(
self.store.db_pool.runInteraction(
"",
self.persist_events_store._set_push_actions_for_event_and_users_txn,
@@ -99,14 +89,14 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
)
def _rotate(stream):
- return defer.ensureDeferred(
+ self.get_success(
self.store.db_pool.runInteraction(
"", self.store._rotate_notifs_before_txn, stream
)
)
def _mark_read(stream, depth):
- return defer.ensureDeferred(
+ self.get_success(
self.store.db_pool.runInteraction(
"",
self.store._remove_old_push_actions_before_txn,
@@ -116,49 +106,48 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
)
)
- yield _assert_counts(0, 0)
- yield _inject_actions(1, PlAIN_NOTIF)
- yield _assert_counts(1, 0)
- yield _rotate(2)
- yield _assert_counts(1, 0)
+ _assert_counts(0, 0)
+ _inject_actions(1, PlAIN_NOTIF)
+ _assert_counts(1, 0)
+ _rotate(2)
+ _assert_counts(1, 0)
- yield _inject_actions(3, PlAIN_NOTIF)
- yield _assert_counts(2, 0)
- yield _rotate(4)
- yield _assert_counts(2, 0)
+ _inject_actions(3, PlAIN_NOTIF)
+ _assert_counts(2, 0)
+ _rotate(4)
+ _assert_counts(2, 0)
- yield _inject_actions(5, PlAIN_NOTIF)
- yield _mark_read(3, 3)
- yield _assert_counts(1, 0)
+ _inject_actions(5, PlAIN_NOTIF)
+ _mark_read(3, 3)
+ _assert_counts(1, 0)
- yield _mark_read(5, 5)
- yield _assert_counts(0, 0)
+ _mark_read(5, 5)
+ _assert_counts(0, 0)
- yield _inject_actions(6, PlAIN_NOTIF)
- yield _rotate(7)
+ _inject_actions(6, PlAIN_NOTIF)
+ _rotate(7)
- yield defer.ensureDeferred(
+ self.get_success(
self.store.db_pool.simple_delete(
table="event_push_actions", keyvalues={"1": 1}, desc=""
)
)
- yield _assert_counts(1, 0)
+ _assert_counts(1, 0)
- yield _mark_read(7, 7)
- yield _assert_counts(0, 0)
+ _mark_read(7, 7)
+ _assert_counts(0, 0)
- yield _inject_actions(8, HIGHLIGHT)
- yield _assert_counts(1, 1)
- yield _rotate(9)
- yield _assert_counts(1, 1)
- yield _rotate(10)
- yield _assert_counts(1, 1)
+ _inject_actions(8, HIGHLIGHT)
+ _assert_counts(1, 1)
+ _rotate(9)
+ _assert_counts(1, 1)
+ _rotate(10)
+ _assert_counts(1, 1)
- @defer.inlineCallbacks
def test_find_first_stream_ordering_after_ts(self):
def add_event(so, ts):
- return defer.ensureDeferred(
+ self.get_success(
self.store.db_pool.simple_insert(
"events",
{
@@ -177,24 +166,16 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
)
# start with the base case where there are no events in the table
- r = yield defer.ensureDeferred(
- self.store.find_first_stream_ordering_after_ts(11)
- )
+ r = self.get_success(self.store.find_first_stream_ordering_after_ts(11))
self.assertEqual(r, 0)
# now with one event
- yield add_event(2, 10)
- r = yield defer.ensureDeferred(
- self.store.find_first_stream_ordering_after_ts(9)
- )
+ add_event(2, 10)
+ r = self.get_success(self.store.find_first_stream_ordering_after_ts(9))
self.assertEqual(r, 2)
- r = yield defer.ensureDeferred(
- self.store.find_first_stream_ordering_after_ts(10)
- )
+ r = self.get_success(self.store.find_first_stream_ordering_after_ts(10))
self.assertEqual(r, 2)
- r = yield defer.ensureDeferred(
- self.store.find_first_stream_ordering_after_ts(11)
- )
+ r = self.get_success(self.store.find_first_stream_ordering_after_ts(11))
self.assertEqual(r, 3)
# add a bunch of dummy events to the events table
@@ -205,39 +186,27 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
(10, 130),
(20, 140),
):
- yield add_event(stream_ordering, ts)
+ add_event(stream_ordering, ts)
- r = yield defer.ensureDeferred(
- self.store.find_first_stream_ordering_after_ts(110)
- )
+ r = self.get_success(self.store.find_first_stream_ordering_after_ts(110))
self.assertEqual(r, 3, "First event after 110ms should be 3, was %i" % r)
# 4 and 5 are both after 120: we want 4 rather than 5
- r = yield defer.ensureDeferred(
- self.store.find_first_stream_ordering_after_ts(120)
- )
+ r = self.get_success(self.store.find_first_stream_ordering_after_ts(120))
self.assertEqual(r, 4, "First event after 120ms should be 4, was %i" % r)
- r = yield defer.ensureDeferred(
- self.store.find_first_stream_ordering_after_ts(129)
- )
+ r = self.get_success(self.store.find_first_stream_ordering_after_ts(129))
self.assertEqual(r, 10, "First event after 129ms should be 10, was %i" % r)
# check we can get the last event
- r = yield defer.ensureDeferred(
- self.store.find_first_stream_ordering_after_ts(140)
- )
+ r = self.get_success(self.store.find_first_stream_ordering_after_ts(140))
self.assertEqual(r, 20, "First event after 14ms should be 20, was %i" % r)
# off the end
- r = yield defer.ensureDeferred(
- self.store.find_first_stream_ordering_after_ts(160)
- )
+ r = self.get_success(self.store.find_first_stream_ordering_after_ts(160))
self.assertEqual(r, 21)
# check we can find an event at ordering zero
- yield add_event(0, 5)
- r = yield defer.ensureDeferred(
- self.store.find_first_stream_ordering_after_ts(1)
- )
+ add_event(0, 5)
+ r = self.get_success(self.store.find_first_stream_ordering_after_ts(1))
self.assertEqual(r, 0)
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index aad6bc907e..6c389fe9ac 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -12,6 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import List, Optional
+
from synapse.storage.database import DatabasePool
from synapse.storage.engines import IncorrectDatabaseSetup
from synapse.storage.util.id_generators import MultiWriterIdGenerator
@@ -43,7 +45,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
)
def _create_id_generator(
- self, instance_name="master", writers=["master"]
+ self, instance_name="master", writers: Optional[List[str]] = None
) -> MultiWriterIdGenerator:
def _create(conn):
return MultiWriterIdGenerator(
@@ -53,7 +55,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
instance_name=instance_name,
tables=[("foobar", "instance_name", "stream_id")],
sequence_name="foobar_seq",
- writers=writers,
+ writers=writers or ["master"],
)
return self.get_success_or_raise(self.db_pool.runWithConnection(_create))
@@ -476,7 +478,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
)
def _create_id_generator(
- self, instance_name="master", writers=["master"]
+ self, instance_name="master", writers: Optional[List[str]] = None
) -> MultiWriterIdGenerator:
def _create(conn):
return MultiWriterIdGenerator(
@@ -486,7 +488,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
instance_name=instance_name,
tables=[("foobar", "instance_name", "stream_id")],
sequence_name="foobar_seq",
- writers=writers,
+ writers=writers or ["master"],
positive=False,
)
@@ -612,7 +614,7 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
)
def _create_id_generator(
- self, instance_name="master", writers=["master"]
+ self, instance_name="master", writers: Optional[List[str]] = None
) -> MultiWriterIdGenerator:
def _create(conn):
return MultiWriterIdGenerator(
@@ -625,7 +627,7 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
("foobar2", "instance_name", "stream_id"),
],
sequence_name="foobar_seq",
- writers=writers,
+ writers=writers or ["master"],
)
return self.get_success_or_raise(self.db_pool.runWithConnection(_create))
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index 5858c7fcc4..47556791f4 100644
--- a/tests/storage/test_monthly_active_users.py
+++ b/tests/storage/test_monthly_active_users.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from mock import Mock
+from unittest.mock import Mock
from twisted.internet import defer
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index b7dde51224..c6256fce86 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,59 +13,50 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from twisted.internet import defer
-
from synapse.types import UserID
from tests import unittest
-from tests.utils import setup_test_homeserver
-
-class ProfileStoreTestCase(unittest.TestCase):
- @defer.inlineCallbacks
- def setUp(self):
- hs = yield setup_test_homeserver(self.addCleanup)
+class ProfileStoreTestCase(unittest.HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.u_frank = UserID.from_string("@frank:test")
- @defer.inlineCallbacks
def test_displayname(self):
- yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
+ self.get_success(self.store.create_profile(self.u_frank.localpart))
- yield defer.ensureDeferred(
- self.store.set_profile_displayname(self.u_frank.localpart, "Frank", 1)
+ self.get_success(
+ self.store.set_profile_displayname(self.u_frank.localpart, "Frank")
)
self.assertEquals(
"Frank",
(
- yield defer.ensureDeferred(
+ self.get_success(
self.store.get_profile_displayname(self.u_frank.localpart)
)
),
)
# test set to None
- yield defer.ensureDeferred(
- self.store.set_profile_displayname(self.u_frank.localpart, None, 2)
+ self.get_success(
+ self.store.set_profile_displayname(self.u_frank.localpart, None)
)
self.assertIsNone(
(
- yield defer.ensureDeferred(
+ self.get_success(
self.store.get_profile_displayname(self.u_frank.localpart)
)
)
)
- @defer.inlineCallbacks
def test_avatar_url(self):
- yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
+ self.get_success(self.store.create_profile(self.u_frank.localpart))
- yield defer.ensureDeferred(
+ self.get_success(
self.store.set_profile_avatar_url(
self.u_frank.localpart, "http://my.site/here", 1
)
@@ -74,20 +65,20 @@ class ProfileStoreTestCase(unittest.TestCase):
self.assertEquals(
"http://my.site/here",
(
- yield defer.ensureDeferred(
+ self.get_success(
self.store.get_profile_avatar_url(self.u_frank.localpart)
)
),
)
# test set to None
- yield defer.ensureDeferred(
- self.store.set_profile_avatar_url(self.u_frank.localpart, None, 2)
+ self.get_success(
+ self.store.set_profile_avatar_url(self.u_frank.localpart, None)
)
self.assertIsNone(
(
- yield defer.ensureDeferred(
+ self.get_success(
self.store.get_profile_avatar_url(self.u_frank.localpart)
)
)
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index b2a0e60856..2d2f58903c 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -1,6 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,11 +12,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Optional
from canonicaljson import json
-from twisted.internet import defer
-
from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions
from synapse.types import RoomID, UserID
@@ -50,10 +48,15 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.depth = 1
def inject_room_member(
- self, room, user, membership, replaces_state=None, extra_content={}
+ self,
+ room,
+ user,
+ membership,
+ replaces_state=None,
+ extra_content: Optional[dict] = None,
):
content = {"membership": membership}
- content.update(extra_content)
+ content.update(extra_content or {})
builder = self.event_builder_factory.for_room_version(
RoomVersions.V1,
{
@@ -230,10 +233,9 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self._base_builder = base_builder
self._event_id = event_id
- @defer.inlineCallbacks
- def build(self, prev_event_ids, auth_event_ids):
- built_event = yield defer.ensureDeferred(
- self._base_builder.build(prev_event_ids, auth_event_ids)
+ async def build(self, prev_event_ids, auth_event_ids):
+ built_event = await self._base_builder.build(
+ prev_event_ids, auth_event_ids
)
built_event._event_id = self._event_id
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 4eb41c46e8..c82cf15bc2 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,21 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from twisted.internet import defer
-
from synapse.api.constants import UserTypes
from synapse.api.errors import ThreepidValidationError
-from tests import unittest
-from tests.utils import setup_test_homeserver
-
+from tests.unittest import HomeserverTestCase
-class RegistrationStoreTestCase(unittest.TestCase):
- @defer.inlineCallbacks
- def setUp(self):
- hs = yield setup_test_homeserver(self.addCleanup)
+class RegistrationStoreTestCase(HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.user_id = "@my-user:test"
@@ -35,9 +28,8 @@ class RegistrationStoreTestCase(unittest.TestCase):
self.pwhash = "{xx1}123456789"
self.device_id = "akgjhdjklgshg"
- @defer.inlineCallbacks
def test_register(self):
- yield defer.ensureDeferred(self.store.register_user(self.user_id, self.pwhash))
+ self.get_success(self.store.register_user(self.user_id, self.pwhash))
self.assertEquals(
{
@@ -49,93 +41,81 @@ class RegistrationStoreTestCase(unittest.TestCase):
"consent_version": None,
"consent_server_notice_sent": None,
"appservice_id": None,
- "creation_ts": 1000,
+ "creation_ts": 0,
"user_type": None,
"deactivated": 0,
"shadow_banned": 0,
},
- (yield defer.ensureDeferred(self.store.get_user_by_id(self.user_id))),
+ (self.get_success(self.store.get_user_by_id(self.user_id))),
)
- @defer.inlineCallbacks
def test_add_tokens(self):
- yield defer.ensureDeferred(self.store.register_user(self.user_id, self.pwhash))
- yield defer.ensureDeferred(
+ self.get_success(self.store.register_user(self.user_id, self.pwhash))
+ self.get_success(
self.store.add_access_token_to_user(
self.user_id, self.tokens[1], self.device_id, valid_until_ms=None
)
)
- result = yield defer.ensureDeferred(
- self.store.get_user_by_access_token(self.tokens[1])
- )
+ result = self.get_success(self.store.get_user_by_access_token(self.tokens[1]))
self.assertEqual(result.user_id, self.user_id)
self.assertEqual(result.device_id, self.device_id)
self.assertIsNotNone(result.token_id)
- @defer.inlineCallbacks
def test_user_delete_access_tokens(self):
# add some tokens
- yield defer.ensureDeferred(self.store.register_user(self.user_id, self.pwhash))
- yield defer.ensureDeferred(
+ self.get_success(self.store.register_user(self.user_id, self.pwhash))
+ self.get_success(
self.store.add_access_token_to_user(
self.user_id, self.tokens[0], device_id=None, valid_until_ms=None
)
)
- yield defer.ensureDeferred(
+ self.get_success(
self.store.add_access_token_to_user(
self.user_id, self.tokens[1], self.device_id, valid_until_ms=None
)
)
# now delete some
- yield defer.ensureDeferred(
+ self.get_success(
self.store.user_delete_access_tokens(self.user_id, device_id=self.device_id)
)
# check they were deleted
- user = yield defer.ensureDeferred(
- self.store.get_user_by_access_token(self.tokens[1])
- )
+ user = self.get_success(self.store.get_user_by_access_token(self.tokens[1]))
self.assertIsNone(user, "access token was not deleted by device_id")
# check the one not associated with the device was not deleted
- user = yield defer.ensureDeferred(
- self.store.get_user_by_access_token(self.tokens[0])
- )
+ user = self.get_success(self.store.get_user_by_access_token(self.tokens[0]))
self.assertEqual(self.user_id, user.user_id)
# now delete the rest
- yield defer.ensureDeferred(self.store.user_delete_access_tokens(self.user_id))
+ self.get_success(self.store.user_delete_access_tokens(self.user_id))
- user = yield defer.ensureDeferred(
- self.store.get_user_by_access_token(self.tokens[0])
- )
+ user = self.get_success(self.store.get_user_by_access_token(self.tokens[0]))
self.assertIsNone(user, "access token was not deleted without device_id")
- @defer.inlineCallbacks
def test_is_support_user(self):
TEST_USER = "@test:test"
SUPPORT_USER = "@support:test"
- res = yield defer.ensureDeferred(self.store.is_support_user(None))
+ res = self.get_success(self.store.is_support_user(None))
self.assertFalse(res)
- yield defer.ensureDeferred(
+ self.get_success(
self.store.register_user(user_id=TEST_USER, password_hash=None)
)
- res = yield defer.ensureDeferred(self.store.is_support_user(TEST_USER))
+ res = self.get_success(self.store.is_support_user(TEST_USER))
self.assertFalse(res)
- yield defer.ensureDeferred(
+ self.get_success(
self.store.register_user(
user_id=SUPPORT_USER, password_hash=None, user_type=UserTypes.SUPPORT
)
)
- res = yield defer.ensureDeferred(self.store.is_support_user(SUPPORT_USER))
+ res = self.get_success(self.store.is_support_user(SUPPORT_USER))
self.assertTrue(res)
- @defer.inlineCallbacks
def test_3pid_inhibit_invalid_validation_session_error(self):
"""Tests that enabling the configuration option to inhibit 3PID errors on
/requestToken also inhibits validation errors caused by an unknown session ID.
@@ -143,30 +123,28 @@ class RegistrationStoreTestCase(unittest.TestCase):
# Check that, with the config setting set to false (the default value), a
# validation error is caused by the unknown session ID.
- try:
- yield defer.ensureDeferred(
- self.store.validate_threepid_session(
- "fake_sid",
- "fake_client_secret",
- "fake_token",
- 0,
- )
- )
- except ThreepidValidationError as e:
- self.assertEquals(e.msg, "Unknown session_id", e)
+ e = self.get_failure(
+ self.store.validate_threepid_session(
+ "fake_sid",
+ "fake_client_secret",
+ "fake_token",
+ 0,
+ ),
+ ThreepidValidationError,
+ )
+ self.assertEquals(e.value.msg, "Unknown session_id", e)
# Set the config setting to true.
self.store._ignore_unknown_session_error = True
# Check that now the validation error is caused by the token not matching.
- try:
- yield defer.ensureDeferred(
- self.store.validate_threepid_session(
- "fake_sid",
- "fake_client_secret",
- "fake_token",
- 0,
- )
- )
- except ThreepidValidationError as e:
- self.assertEquals(e.msg, "Validation token not found or has expired", e)
+ e = self.get_failure(
+ self.store.validate_threepid_session(
+ "fake_sid",
+ "fake_client_secret",
+ "fake_token",
+ 0,
+ ),
+ ThreepidValidationError,
+ )
+ self.assertEquals(e.value.msg, "Validation token not found or has expired", e)
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index bc8400f240..0089d33c93 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,22 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from twisted.internet import defer
-
from synapse.api.constants import EventTypes
from synapse.api.room_versions import RoomVersions
from synapse.types import RoomAlias, RoomID, UserID
-from tests import unittest
-from tests.utils import setup_test_homeserver
-
+from tests.unittest import HomeserverTestCase
-class RoomStoreTestCase(unittest.TestCase):
- @defer.inlineCallbacks
- def setUp(self):
- hs = yield setup_test_homeserver(self.addCleanup)
+class RoomStoreTestCase(HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
# We can't test RoomStore on its own without the DirectoryStore, for
# management of the 'room_aliases' table
self.store = hs.get_datastore()
@@ -37,7 +30,7 @@ class RoomStoreTestCase(unittest.TestCase):
self.alias = RoomAlias.from_string("#a-room-name:test")
self.u_creator = UserID.from_string("@creator:test")
- yield defer.ensureDeferred(
+ self.get_success(
self.store.store_room(
self.room.to_string(),
room_creator_user_id=self.u_creator.to_string(),
@@ -46,7 +39,6 @@ class RoomStoreTestCase(unittest.TestCase):
)
)
- @defer.inlineCallbacks
def test_get_room(self):
self.assertDictContainsSubset(
{
@@ -54,16 +46,12 @@ class RoomStoreTestCase(unittest.TestCase):
"creator": self.u_creator.to_string(),
"is_public": True,
},
- (yield defer.ensureDeferred(self.store.get_room(self.room.to_string()))),
+ (self.get_success(self.store.get_room(self.room.to_string()))),
)
- @defer.inlineCallbacks
def test_get_room_unknown_room(self):
- self.assertIsNone(
- (yield defer.ensureDeferred(self.store.get_room("!uknown:test")))
- )
+ self.assertIsNone((self.get_success(self.store.get_room("!uknown:test"))))
- @defer.inlineCallbacks
def test_get_room_with_stats(self):
self.assertDictContainsSubset(
{
@@ -71,29 +59,17 @@ class RoomStoreTestCase(unittest.TestCase):
"creator": self.u_creator.to_string(),
"public": True,
},
- (
- yield defer.ensureDeferred(
- self.store.get_room_with_stats(self.room.to_string())
- )
- ),
+ (self.get_success(self.store.get_room_with_stats(self.room.to_string()))),
)
- @defer.inlineCallbacks
def test_get_room_with_stats_unknown_room(self):
self.assertIsNone(
- (
- yield defer.ensureDeferred(
- self.store.get_room_with_stats("!uknown:test")
- )
- ),
+ (self.get_success(self.store.get_room_with_stats("!uknown:test"))),
)
-class RoomEventsStoreTestCase(unittest.TestCase):
- @defer.inlineCallbacks
- def setUp(self):
- hs = setup_test_homeserver(self.addCleanup)
-
+class RoomEventsStoreTestCase(HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
# Room events need the full datastore, for persist_event() and
# get_room_state()
self.store = hs.get_datastore()
@@ -102,7 +78,7 @@ class RoomEventsStoreTestCase(unittest.TestCase):
self.room = RoomID.from_string("!abcde:test")
- yield defer.ensureDeferred(
+ self.get_success(
self.store.store_room(
self.room.to_string(),
room_creator_user_id="@creator:text",
@@ -111,23 +87,21 @@ class RoomEventsStoreTestCase(unittest.TestCase):
)
)
- @defer.inlineCallbacks
def inject_room_event(self, **kwargs):
- yield defer.ensureDeferred(
+ self.get_success(
self.storage.persistence.persist_event(
self.event_factory.create_event(room_id=self.room.to_string(), **kwargs)
)
)
- @defer.inlineCallbacks
def STALE_test_room_name(self):
name = "A-Room-Name"
- yield self.inject_room_event(
+ self.inject_room_event(
etype=EventTypes.Name, name=name, content={"name": name}, depth=1
)
- state = yield defer.ensureDeferred(
+ state = self.get_success(
self.store.get_current_state(room_id=self.room.to_string())
)
@@ -137,15 +111,14 @@ class RoomEventsStoreTestCase(unittest.TestCase):
state[0],
)
- @defer.inlineCallbacks
def STALE_test_room_topic(self):
topic = "A place for things"
- yield self.inject_room_event(
+ self.inject_room_event(
etype=EventTypes.Topic, topic=topic, content={"topic": topic}, depth=1
)
- state = yield defer.ensureDeferred(
+ state = self.get_success(
self.store.get_current_state(room_id=self.room.to_string())
)
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 2471f1267d..f06b452fa9 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2018 New Vector Ltd
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,24 +15,18 @@
import logging
-from twisted.internet import defer
-
from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions
from synapse.storage.state import StateFilter
from synapse.types import RoomID, UserID
-import tests.unittest
-import tests.utils
+from tests.unittest import HomeserverTestCase
logger = logging.getLogger(__name__)
-class StateStoreTestCase(tests.unittest.TestCase):
- @defer.inlineCallbacks
- def setUp(self):
- hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
-
+class StateStoreTestCase(HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.storage = hs.get_storage()
self.state_datastore = self.storage.state.stores.state
@@ -44,7 +38,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.room = RoomID.from_string("!abc123:test")
- yield defer.ensureDeferred(
+ self.get_success(
self.store.store_room(
self.room.to_string(),
room_creator_user_id="@creator:text",
@@ -53,7 +47,6 @@ class StateStoreTestCase(tests.unittest.TestCase):
)
)
- @defer.inlineCallbacks
def inject_state_event(self, room, sender, typ, state_key, content):
builder = self.event_builder_factory.for_room_version(
RoomVersions.V1,
@@ -66,13 +59,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
},
)
- event, context = yield defer.ensureDeferred(
+ event, context = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
- yield defer.ensureDeferred(
- self.storage.persistence.persist_event(event, context)
- )
+ self.get_success(self.storage.persistence.persist_event(event, context))
return event
@@ -82,16 +73,13 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertEqual(s1[t].event_id, s2[t].event_id)
self.assertEqual(len(s1), len(s2))
- @defer.inlineCallbacks
def test_get_state_groups_ids(self):
- e1 = yield self.inject_state_event(
- self.room, self.u_alice, EventTypes.Create, "", {}
- )
- e2 = yield self.inject_state_event(
+ e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
+ e2 = self.inject_state_event(
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
)
- state_group_map = yield defer.ensureDeferred(
+ state_group_map = self.get_success(
self.storage.state.get_state_groups_ids(self.room, [e2.event_id])
)
self.assertEqual(len(state_group_map), 1)
@@ -101,16 +89,13 @@ class StateStoreTestCase(tests.unittest.TestCase):
{(EventTypes.Create, ""): e1.event_id, (EventTypes.Name, ""): e2.event_id},
)
- @defer.inlineCallbacks
def test_get_state_groups(self):
- e1 = yield self.inject_state_event(
- self.room, self.u_alice, EventTypes.Create, "", {}
- )
- e2 = yield self.inject_state_event(
+ e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
+ e2 = self.inject_state_event(
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
)
- state_group_map = yield defer.ensureDeferred(
+ state_group_map = self.get_success(
self.storage.state.get_state_groups(self.room, [e2.event_id])
)
self.assertEqual(len(state_group_map), 1)
@@ -118,32 +103,29 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertEqual({ev.event_id for ev in state_list}, {e1.event_id, e2.event_id})
- @defer.inlineCallbacks
def test_get_state_for_event(self):
# this defaults to a linear DAG as each new injection defaults to whatever
# forward extremities are currently in the DB for this room.
- e1 = yield self.inject_state_event(
- self.room, self.u_alice, EventTypes.Create, "", {}
- )
- e2 = yield self.inject_state_event(
+ e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
+ e2 = self.inject_state_event(
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
)
- e3 = yield self.inject_state_event(
+ e3 = self.inject_state_event(
self.room,
self.u_alice,
EventTypes.Member,
self.u_alice.to_string(),
{"membership": Membership.JOIN},
)
- e4 = yield self.inject_state_event(
+ e4 = self.inject_state_event(
self.room,
self.u_bob,
EventTypes.Member,
self.u_bob.to_string(),
{"membership": Membership.JOIN},
)
- e5 = yield self.inject_state_event(
+ e5 = self.inject_state_event(
self.room,
self.u_bob,
EventTypes.Member,
@@ -152,9 +134,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
)
# check we get the full state as of the final event
- state = yield defer.ensureDeferred(
- self.storage.state.get_state_for_event(e5.event_id)
- )
+ state = self.get_success(self.storage.state.get_state_for_event(e5.event_id))
self.assertIsNotNone(e4)
@@ -170,7 +150,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
)
# check we can filter to the m.room.name event (with a '' state key)
- state = yield defer.ensureDeferred(
+ state = self.get_success(
self.storage.state.get_state_for_event(
e5.event_id, StateFilter.from_types([(EventTypes.Name, "")])
)
@@ -179,7 +159,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
# check we can filter to the m.room.name event (with a wildcard None state key)
- state = yield defer.ensureDeferred(
+ state = self.get_success(
self.storage.state.get_state_for_event(
e5.event_id, StateFilter.from_types([(EventTypes.Name, None)])
)
@@ -188,7 +168,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
# check we can grab the m.room.member events (with a wildcard None state key)
- state = yield defer.ensureDeferred(
+ state = self.get_success(
self.storage.state.get_state_for_event(
e5.event_id, StateFilter.from_types([(EventTypes.Member, None)])
)
@@ -200,7 +180,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# check we can grab a specific room member without filtering out the
# other event types
- state = yield defer.ensureDeferred(
+ state = self.get_success(
self.storage.state.get_state_for_event(
e5.event_id,
state_filter=StateFilter(
@@ -220,7 +200,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
)
# check that we can grab everything except members
- state = yield defer.ensureDeferred(
+ state = self.get_success(
self.storage.state.get_state_for_event(
e5.event_id,
state_filter=StateFilter(
@@ -238,17 +218,14 @@ class StateStoreTestCase(tests.unittest.TestCase):
#######################################################
room_id = self.room.to_string()
- group_ids = yield defer.ensureDeferred(
+ group_ids = self.get_success(
self.storage.state.get_state_groups_ids(room_id, [e5.event_id])
)
group = list(group_ids.keys())[0]
# test _get_state_for_group_using_cache correctly filters out members
# with types=[]
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
@@ -265,10 +242,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
state_dict,
)
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
@@ -281,10 +255,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members
# with wildcard types
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
@@ -301,10 +272,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
state_dict,
)
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
@@ -324,10 +292,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members
# with specific types
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
@@ -344,10 +309,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
state_dict,
)
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
@@ -360,10 +322,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members
# with specific types
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
@@ -413,10 +372,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters out members
# with types=[]
room_id = self.room.to_string()
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
@@ -428,10 +384,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
room_id = self.room.to_string()
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
@@ -444,10 +397,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members
# wildcard types
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
@@ -458,10 +408,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertEqual(is_all, False)
self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
@@ -480,10 +427,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members
# with specific types
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
@@ -494,10 +438,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertEqual(is_all, False)
self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
@@ -510,10 +451,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members
# with specific types
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
@@ -524,10 +462,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertEqual(is_all, False)
self.assertDictEqual({}, state_dict)
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index a6f63f4aaf..019c5b7b14 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2018 New Vector Ltd
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,10 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
-
-from tests import unittest
-from tests.utils import setup_test_homeserver
+from tests.unittest import HomeserverTestCase, override_config
ALICE = "@alice:a"
BOB = "@bob:b"
@@ -25,73 +22,52 @@ BOBBY = "@bobby:a"
BELA = "@somenickname:a"
-class UserDirectoryStoreTestCase(unittest.TestCase):
- @defer.inlineCallbacks
- def setUp(self):
- self.hs = yield setup_test_homeserver(self.addCleanup)
- self.store = self.hs.get_datastore()
+class UserDirectoryStoreTestCase(HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
# alice and bob are both in !room_id. bobby is not but shares
# a homeserver with alice.
- yield defer.ensureDeferred(
- self.store.update_profile_in_user_dir(ALICE, "alice", None)
- )
- yield defer.ensureDeferred(
- self.store.update_profile_in_user_dir(BOB, "bob", None)
- )
- yield defer.ensureDeferred(
- self.store.update_profile_in_user_dir(BOBBY, "bobby", None)
- )
- yield defer.ensureDeferred(
- self.store.update_profile_in_user_dir(BELA, "Bela", None)
- )
- yield defer.ensureDeferred(
- self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB))
- )
+ self.get_success(self.store.update_profile_in_user_dir(ALICE, "alice", None))
+ self.get_success(self.store.update_profile_in_user_dir(BOB, "bob", None))
+ self.get_success(self.store.update_profile_in_user_dir(BOBBY, "bobby", None))
+ self.get_success(self.store.update_profile_in_user_dir(BELA, "Bela", None))
+ self.get_success(self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB)))
- @defer.inlineCallbacks
def test_search_user_dir(self):
# normally when alice searches the directory she should just find
# bob because bobby doesn't share a room with her.
- r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "bob", 10))
+ r = self.get_success(self.store.search_user_dir(ALICE, "bob", 10))
self.assertFalse(r["limited"])
self.assertEqual(1, len(r["results"]))
self.assertDictEqual(
r["results"][0], {"user_id": BOB, "display_name": "bob", "avatar_url": None}
)
- @defer.inlineCallbacks
+ @override_config({"user_directory": {"search_all_users": True}})
def test_search_user_dir_all_users(self):
- self.hs.config.user_directory_search_all_users = True
- try:
- r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "bob", 10))
- self.assertFalse(r["limited"])
- self.assertEqual(2, len(r["results"]))
- self.assertDictEqual(
- r["results"][0],
- {"user_id": BOB, "display_name": "bob", "avatar_url": None},
- )
- self.assertDictEqual(
- r["results"][1],
- {"user_id": BOBBY, "display_name": "bobby", "avatar_url": None},
- )
- finally:
- self.hs.config.user_directory_search_all_users = False
+ r = self.get_success(self.store.search_user_dir(ALICE, "bob", 10))
+ self.assertFalse(r["limited"])
+ self.assertEqual(2, len(r["results"]))
+ self.assertDictEqual(
+ r["results"][0],
+ {"user_id": BOB, "display_name": "bob", "avatar_url": None},
+ )
+ self.assertDictEqual(
+ r["results"][1],
+ {"user_id": BOBBY, "display_name": "bobby", "avatar_url": None},
+ )
- @defer.inlineCallbacks
+ @override_config({"user_directory": {"search_all_users": True}})
def test_search_user_dir_stop_words(self):
"""Tests that a user can look up another user by searching for the start if its
display name even if that name happens to be a common English word that would
usually be ignored in full text searches.
"""
- self.hs.config.user_directory_search_all_users = True
- try:
- r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "be", 10))
- self.assertFalse(r["limited"])
- self.assertEqual(1, len(r["results"]))
- self.assertDictEqual(
- r["results"][0],
- {"user_id": BELA, "display_name": "Bela", "avatar_url": None},
- )
- finally:
- self.hs.config.user_directory_search_all_users = False
+ r = self.get_success(self.store.search_user_dir(ALICE, "be", 10))
+ self.assertFalse(r["limited"])
+ self.assertEqual(1, len(r["results"]))
+ self.assertDictEqual(
+ r["results"][0],
+ {"user_id": BELA, "display_name": "Bela", "avatar_url": None},
+ )
diff --git a/tests/test_distributor.py b/tests/test_distributor.py
index b57f36e6ac..6a6cf709f6 100644
--- a/tests/test_distributor.py
+++ b/tests/test_distributor.py
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from mock import Mock, patch
+from unittest.mock import Mock, patch
from synapse.util.distributor import Distributor
diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py
index 3f2691ee6b..b5f18344dc 100644
--- a/tests/test_event_auth.py
+++ b/tests/test_event_auth.py
@@ -207,6 +207,226 @@ class EventAuthTestCase(unittest.TestCase):
do_sig_check=False,
)
+ def test_join_rules_public(self):
+ """
+ Test joining a public room.
+ """
+ creator = "@creator:example.com"
+ pleb = "@joiner:example.com"
+
+ auth_events = {
+ ("m.room.create", ""): _create_event(creator),
+ ("m.room.member", creator): _join_event(creator),
+ ("m.room.join_rules", ""): _join_rules_event(creator, "public"),
+ }
+
+ # Check join.
+ event_auth.check(
+ RoomVersions.V6,
+ _join_event(pleb),
+ auth_events,
+ do_sig_check=False,
+ )
+
+ # A user cannot be force-joined to a room.
+ with self.assertRaises(AuthError):
+ event_auth.check(
+ RoomVersions.V6,
+ _member_event(pleb, "join", sender=creator),
+ auth_events,
+ do_sig_check=False,
+ )
+
+ # Banned should be rejected.
+ auth_events[("m.room.member", pleb)] = _member_event(pleb, "ban")
+ with self.assertRaises(AuthError):
+ event_auth.check(
+ RoomVersions.V6,
+ _join_event(pleb),
+ auth_events,
+ do_sig_check=False,
+ )
+
+ # A user who left can re-join.
+ auth_events[("m.room.member", pleb)] = _member_event(pleb, "leave")
+ event_auth.check(
+ RoomVersions.V6,
+ _join_event(pleb),
+ auth_events,
+ do_sig_check=False,
+ )
+
+ # A user can send a join if they're in the room.
+ auth_events[("m.room.member", pleb)] = _member_event(pleb, "join")
+ event_auth.check(
+ RoomVersions.V6,
+ _join_event(pleb),
+ auth_events,
+ do_sig_check=False,
+ )
+
+ # A user can accept an invite.
+ auth_events[("m.room.member", pleb)] = _member_event(
+ pleb, "invite", sender=creator
+ )
+ event_auth.check(
+ RoomVersions.V6,
+ _join_event(pleb),
+ auth_events,
+ do_sig_check=False,
+ )
+
+ def test_join_rules_invite(self):
+ """
+ Test joining an invite only room.
+ """
+ creator = "@creator:example.com"
+ pleb = "@joiner:example.com"
+
+ auth_events = {
+ ("m.room.create", ""): _create_event(creator),
+ ("m.room.member", creator): _join_event(creator),
+ ("m.room.join_rules", ""): _join_rules_event(creator, "invite"),
+ }
+
+ # A join without an invite is rejected.
+ with self.assertRaises(AuthError):
+ event_auth.check(
+ RoomVersions.V6,
+ _join_event(pleb),
+ auth_events,
+ do_sig_check=False,
+ )
+
+ # A user cannot be force-joined to a room.
+ with self.assertRaises(AuthError):
+ event_auth.check(
+ RoomVersions.V6,
+ _member_event(pleb, "join", sender=creator),
+ auth_events,
+ do_sig_check=False,
+ )
+
+ # Banned should be rejected.
+ auth_events[("m.room.member", pleb)] = _member_event(pleb, "ban")
+ with self.assertRaises(AuthError):
+ event_auth.check(
+ RoomVersions.V6,
+ _join_event(pleb),
+ auth_events,
+ do_sig_check=False,
+ )
+
+ # A user who left cannot re-join.
+ auth_events[("m.room.member", pleb)] = _member_event(pleb, "leave")
+ with self.assertRaises(AuthError):
+ event_auth.check(
+ RoomVersions.V6,
+ _join_event(pleb),
+ auth_events,
+ do_sig_check=False,
+ )
+
+ # A user can send a join if they're in the room.
+ auth_events[("m.room.member", pleb)] = _member_event(pleb, "join")
+ event_auth.check(
+ RoomVersions.V6,
+ _join_event(pleb),
+ auth_events,
+ do_sig_check=False,
+ )
+
+ # A user can accept an invite.
+ auth_events[("m.room.member", pleb)] = _member_event(
+ pleb, "invite", sender=creator
+ )
+ event_auth.check(
+ RoomVersions.V6,
+ _join_event(pleb),
+ auth_events,
+ do_sig_check=False,
+ )
+
+ def test_join_rules_msc3083_restricted(self):
+ """
+ Test joining a restricted room from MSC3083.
+
+ This is pretty much the same test as public.
+ """
+ creator = "@creator:example.com"
+ pleb = "@joiner:example.com"
+
+ auth_events = {
+ ("m.room.create", ""): _create_event(creator),
+ ("m.room.member", creator): _join_event(creator),
+ ("m.room.join_rules", ""): _join_rules_event(creator, "restricted"),
+ }
+
+ # Older room versions don't understand this join rule
+ with self.assertRaises(AuthError):
+ event_auth.check(
+ RoomVersions.V6,
+ _join_event(pleb),
+ auth_events,
+ do_sig_check=False,
+ )
+
+ # Check join.
+ event_auth.check(
+ RoomVersions.MSC3083,
+ _join_event(pleb),
+ auth_events,
+ do_sig_check=False,
+ )
+
+ # A user cannot be force-joined to a room.
+ with self.assertRaises(AuthError):
+ event_auth.check(
+ RoomVersions.MSC3083,
+ _member_event(pleb, "join", sender=creator),
+ auth_events,
+ do_sig_check=False,
+ )
+
+ # Banned should be rejected.
+ auth_events[("m.room.member", pleb)] = _member_event(pleb, "ban")
+ with self.assertRaises(AuthError):
+ event_auth.check(
+ RoomVersions.MSC3083,
+ _join_event(pleb),
+ auth_events,
+ do_sig_check=False,
+ )
+
+ # A user who left can re-join.
+ auth_events[("m.room.member", pleb)] = _member_event(pleb, "leave")
+ event_auth.check(
+ RoomVersions.MSC3083,
+ _join_event(pleb),
+ auth_events,
+ do_sig_check=False,
+ )
+
+ # A user can send a join if they're in the room.
+ auth_events[("m.room.member", pleb)] = _member_event(pleb, "join")
+ event_auth.check(
+ RoomVersions.MSC3083,
+ _join_event(pleb),
+ auth_events,
+ do_sig_check=False,
+ )
+
+ # A user can accept an invite.
+ auth_events[("m.room.member", pleb)] = _member_event(
+ pleb, "invite", sender=creator
+ )
+ event_auth.check(
+ RoomVersions.MSC3083,
+ _join_event(pleb),
+ auth_events,
+ do_sig_check=False,
+ )
+
# helpers for making events
@@ -225,19 +445,24 @@ def _create_event(user_id):
)
-def _join_event(user_id):
+def _member_event(user_id, membership, sender=None):
return make_event_from_dict(
{
"room_id": TEST_ROOM_ID,
"event_id": _get_event_id(),
"type": "m.room.member",
- "sender": user_id,
+ "sender": sender or user_id,
"state_key": user_id,
- "content": {"membership": "join"},
+ "content": {"membership": membership},
+ "prev_events": [],
}
)
+def _join_event(user_id):
+ return _member_event(user_id, "join")
+
+
def _power_levels_event(sender, content):
return make_event_from_dict(
{
@@ -277,6 +502,21 @@ def _random_state_event(sender):
)
+def _join_rules_event(sender, join_rule):
+ return make_event_from_dict(
+ {
+ "room_id": TEST_ROOM_ID,
+ "event_id": _get_event_id(),
+ "type": "m.room.join_rules",
+ "sender": sender,
+ "state_key": "",
+ "content": {
+ "join_rule": join_rule,
+ },
+ }
+ )
+
+
event_count = 0
diff --git a/tests/test_federation.py b/tests/test_federation.py
index fc9aab32d0..382cedbd5d 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from mock import Mock
+from unittest.mock import Mock
from twisted.internet.defer import succeed
@@ -134,7 +134,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
}
)
- with LoggingContext():
+ with LoggingContext("test-context"):
failure = self.get_failure(
self.handler.on_receive_pdu(
"test.serv", lying_event, sent_to_us_directly=True
diff --git a/tests/test_mau.py b/tests/test_mau.py
index 75d28a42df..7d92a16a8d 100644
--- a/tests/test_mau.py
+++ b/tests/test_mau.py
@@ -15,9 +15,7 @@
"""Tests REST events for /rooms paths."""
-import json
-
-from synapse.api.constants import LoginType
+from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType
from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.appservice import ApplicationService
from synapse.rest.client.v2_alpha import register, sync
@@ -113,7 +111,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
)
)
- self.create_user("as_kermit4", token=as_token)
+ self.create_user("as_kermit4", token=as_token, appservice=True)
def test_allowed_after_a_month_mau(self):
# Create and sync so that the MAU counts get updated
@@ -232,14 +230,15 @@ class TestMauLimit(unittest.HomeserverTestCase):
self.reactor.advance(100)
self.assertEqual(2, self.successResultOf(count))
- def create_user(self, localpart, token=None):
- request_data = json.dumps(
- {
- "username": localpart,
- "password": "monkey",
- "auth": {"type": LoginType.DUMMY},
- }
- )
+ def create_user(self, localpart, token=None, appservice=False):
+ request_data = {
+ "username": localpart,
+ "password": "monkey",
+ "auth": {"type": LoginType.DUMMY},
+ }
+
+ if appservice:
+ request_data["type"] = APP_SERVICE_REGISTRATION_TYPE
channel = self.make_request(
"POST",
diff --git a/tests/test_phone_home.py b/tests/test_phone_home.py
index e7aed092c2..0f800a075b 100644
--- a/tests/test_phone_home.py
+++ b/tests/test_phone_home.py
@@ -14,8 +14,7 @@
# limitations under the License.
import resource
-
-import mock
+from unittest import mock
from synapse.app.phone_stats_home import phone_stats_home
diff --git a/tests/test_state.py b/tests/test_state.py
index 6227a3ba95..0d626f49f6 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -12,8 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from mock import Mock
+from typing import List, Optional
+from unittest.mock import Mock
from twisted.internet import defer
@@ -37,8 +37,8 @@ def create_event(
state_key=None,
depth=2,
event_id=None,
- prev_events=[],
- **kwargs
+ prev_events: Optional[List[str]] = None,
+ **kwargs,
):
global _next_event_id
@@ -58,7 +58,7 @@ def create_event(
"sender": "@user_id:example.com",
"room_id": "!room_id:example.com",
"depth": depth,
- "prev_events": prev_events,
+ "prev_events": prev_events or [],
}
if state_key is not None:
diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py
index a743cdc3a9..0df480db9f 100644
--- a/tests/test_terms_auth.py
+++ b/tests/test_terms_auth.py
@@ -13,8 +13,7 @@
# limitations under the License.
import json
-
-from mock import Mock
+from unittest.mock import Mock
from twisted.test.proto_helpers import MemoryReactorClock
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index 43898d8142..b557ffd692 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -21,8 +21,7 @@ import sys
import warnings
from asyncio import Future
from typing import Any, Awaitable, Callable, TypeVar
-
-from mock import Mock
+from unittest.mock import Mock
import attr
diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py
index c3c4a93e1f..3dfbf8f8a9 100644
--- a/tests/test_utils/event_injection.py
+++ b/tests/test_utils/event_injection.py
@@ -33,7 +33,7 @@ async def inject_member_event(
membership: str,
target: Optional[str] = None,
extra_content: Optional[dict] = None,
- **kwargs
+ **kwargs,
) -> EventBase:
"""Inject a membership event into a room."""
if target is None:
@@ -58,7 +58,7 @@ async def inject_event(
hs: synapse.server.HomeServer,
room_version: Optional[str] = None,
prev_event_ids: Optional[List[str]] = None,
- **kwargs
+ **kwargs,
) -> EventBase:
"""Inject a generic event into a room
@@ -83,7 +83,7 @@ async def create_event(
hs: synapse.server.HomeServer,
room_version: Optional[str] = None,
prev_event_ids: Optional[List[str]] = None,
- **kwargs
+ **kwargs,
) -> Tuple[EventBase, EventContext]:
if room_version is None:
room_version = await hs.get_datastore().get_room_version_id(kwargs["room_id"])
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index 510b630114..e502ac197e 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -13,8 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-
-from mock import Mock
+from typing import Optional
+from unittest.mock import Mock
from twisted.internet import defer
from twisted.internet.defer import succeed
@@ -147,9 +147,11 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
return event
@defer.inlineCallbacks
- def inject_room_member(self, user_id, membership="join", extra_content={}):
+ def inject_room_member(
+ self, user_id, membership="join", extra_content: Optional[dict] = None
+ ):
content = {"membership": membership}
- content.update(extra_content)
+ content.update(extra_content or {})
builder = self.event_builder_factory.for_room_version(
RoomVersions.V1,
{
diff --git a/tests/unittest.py b/tests/unittest.py
index 58a4daa1ec..92764434bd 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -21,8 +21,7 @@ import inspect
import logging
import time
from typing import Callable, Dict, Iterable, Optional, Tuple, Type, TypeVar, Union
-
-from mock import Mock, patch
+from unittest.mock import Mock, patch
from canonicaljson import json
@@ -471,7 +470,7 @@ class HomeserverTestCase(TestCase):
kwargs["config"] = config_obj
async def run_bg_updates():
- with LoggingContext("run_bg_updates", request="run_bg_updates-1"):
+ with LoggingContext("run_bg_updates"):
while not await stor.db_pool.updates.has_completed_background_updates():
await stor.db_pool.updates.do_next_background_update(1)
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index afb11b9caf..8c082e7432 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -15,8 +15,7 @@
# limitations under the License.
import logging
from typing import Set
-
-import mock
+from unittest import mock
from twisted.internet import defer, reactor
@@ -232,8 +231,7 @@ class DescriptorTestCase(unittest.TestCase):
@defer.inlineCallbacks
def do_lookup():
- with LoggingContext() as c1:
- c1.name = "c1"
+ with LoggingContext("c1") as c1:
r = yield obj.fn(1)
self.assertEqual(current_context(), c1)
return r
@@ -275,8 +273,7 @@ class DescriptorTestCase(unittest.TestCase):
@defer.inlineCallbacks
def do_lookup():
- with LoggingContext() as c1:
- c1.name = "c1"
+ with LoggingContext("c1") as c1:
try:
d = obj.fn(1)
self.assertEqual(
@@ -661,14 +658,13 @@ class CachedListDescriptorTestCase(unittest.TestCase):
@descriptors.cachedList("fn", "args1")
async def list_fn(self, args1, arg2):
- assert current_context().request == "c1"
+ assert current_context().name == "c1"
# we want this to behave like an asynchronous function
await run_on_reactor()
- assert current_context().request == "c1"
+ assert current_context().name == "c1"
return self.mock(args1, arg2)
- with LoggingContext() as c1:
- c1.request = "c1"
+ with LoggingContext("c1") as c1:
obj = Cls()
obj.mock.return_value = {10: "fish", 20: "chips"}
d1 = obj.list_fn([10, 20], 2)
diff --git a/tests/util/caches/test_ttlcache.py b/tests/util/caches/test_ttlcache.py
index 816795c136..23018081e5 100644
--- a/tests/util/caches/test_ttlcache.py
+++ b/tests/util/caches/test_ttlcache.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from mock import Mock
+from unittest.mock import Mock
from synapse.util.caches.ttlcache import TTLCache
diff --git a/tests/util/test_file_consumer.py b/tests/util/test_file_consumer.py
index 2012263184..d1372f6bc2 100644
--- a/tests/util/test_file_consumer.py
+++ b/tests/util/test_file_consumer.py
@@ -16,8 +16,7 @@
import threading
from io import StringIO
-
-from mock import NonCallableMock
+from unittest.mock import NonCallableMock
from twisted.internet import defer, reactor
diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py
index 58ee918f65..5d9c4665aa 100644
--- a/tests/util/test_logcontext.py
+++ b/tests/util/test_logcontext.py
@@ -17,11 +17,10 @@ from .. import unittest
class LoggingContextTestCase(unittest.TestCase):
def _check_test_key(self, value):
- self.assertEquals(current_context().request, value)
+ self.assertEquals(current_context().name, value)
def test_with_context(self):
- with LoggingContext() as context_one:
- context_one.request = "test"
+ with LoggingContext("test"):
self._check_test_key("test")
@defer.inlineCallbacks
@@ -30,15 +29,13 @@ class LoggingContextTestCase(unittest.TestCase):
@defer.inlineCallbacks
def competing_callback():
- with LoggingContext() as competing_context:
- competing_context.request = "competing"
+ with LoggingContext("competing"):
yield clock.sleep(0)
self._check_test_key("competing")
reactor.callLater(0, competing_callback)
- with LoggingContext() as context_one:
- context_one.request = "one"
+ with LoggingContext("one"):
yield clock.sleep(0)
self._check_test_key("one")
@@ -47,9 +44,7 @@ class LoggingContextTestCase(unittest.TestCase):
callback_completed = [False]
- with LoggingContext() as context_one:
- context_one.request = "one"
-
+ with LoggingContext("one"):
# fire off function, but don't wait on it.
d2 = run_in_background(function)
@@ -133,9 +128,7 @@ class LoggingContextTestCase(unittest.TestCase):
sentinel_context = current_context()
- with LoggingContext() as context_one:
- context_one.request = "one"
-
+ with LoggingContext("one"):
d1 = make_deferred_yieldable(blocking_function())
# make sure that the context was reset by make_deferred_yieldable
self.assertIs(current_context(), sentinel_context)
@@ -149,9 +142,7 @@ class LoggingContextTestCase(unittest.TestCase):
def test_make_deferred_yieldable_with_chained_deferreds(self):
sentinel_context = current_context()
- with LoggingContext() as context_one:
- context_one.request = "one"
-
+ with LoggingContext("one"):
d1 = make_deferred_yieldable(_chained_deferred_function())
# make sure that the context was reset by make_deferred_yieldable
self.assertIs(current_context(), sentinel_context)
@@ -166,9 +157,7 @@ class LoggingContextTestCase(unittest.TestCase):
"""Check that make_deferred_yieldable does the right thing when its
argument isn't actually a deferred"""
- with LoggingContext() as context_one:
- context_one.request = "one"
-
+ with LoggingContext("one"):
d1 = make_deferred_yieldable("bum")
self._check_test_key("one")
@@ -177,9 +166,9 @@ class LoggingContextTestCase(unittest.TestCase):
self._check_test_key("one")
def test_nested_logging_context(self):
- with LoggingContext(request="foo"):
+ with LoggingContext("foo"):
nested_context = nested_logging_context(suffix="bar")
- self.assertEqual(nested_context.request, "foo-bar")
+ self.assertEqual(nested_context.name, "foo-bar")
@defer.inlineCallbacks
def test_make_deferred_yieldable_with_await(self):
@@ -193,9 +182,7 @@ class LoggingContextTestCase(unittest.TestCase):
sentinel_context = current_context()
- with LoggingContext() as context_one:
- context_one.request = "one"
-
+ with LoggingContext("one"):
d1 = make_deferred_yieldable(blocking_function())
# make sure that the context was reset by make_deferred_yieldable
self.assertIs(current_context(), sentinel_context)
diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py
index a739a6aaaf..ce4f1cc30a 100644
--- a/tests/util/test_lrucache.py
+++ b/tests/util/test_lrucache.py
@@ -14,7 +14,7 @@
# limitations under the License.
-from mock import Mock
+from unittest.mock import Mock
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache
diff --git a/tests/util/test_ratelimitutils.py b/tests/util/test_ratelimitutils.py
index 4d1aee91d5..3fed55090a 100644
--- a/tests/util/test_ratelimitutils.py
+++ b/tests/util/test_ratelimitutils.py
@@ -12,6 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Optional
+
from synapse.config.homeserver import HomeServerConfig
from synapse.util.ratelimitutils import FederationRateLimiter
@@ -89,9 +91,9 @@ def _await_resolution(reactor, d):
return (reactor.seconds() - start_time) * 1000
-def build_rc_config(settings={}):
+def build_rc_config(settings: Optional[dict] = None):
config_dict = default_config("test")
- config_dict.update(settings)
+ config_dict.update(settings or {})
config = HomeServerConfig()
config.parse_config_dict(config_dict, "", "")
return config.rc_federation
diff --git a/tests/utils.py b/tests/utils.py
index 5d299f766f..65d7ad58d9 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -21,10 +21,9 @@ import time
import uuid
import warnings
from typing import Type
+from unittest.mock import Mock, patch
from urllib import parse as urlparse
-from mock import Mock, patch
-
from twisted.internet import defer
from synapse.api.constants import EventTypes
@@ -122,7 +121,6 @@ def default_config(name, parse=False):
"enable_registration_captcha": False,
"macaroon_secret_key": "not even a little secret",
"trusted_third_party_id_servers": [],
- "room_invite_state_types": [],
"password_providers": [],
"worker_replication_url": "",
"worker_app": None,
@@ -198,7 +196,7 @@ def setup_test_homeserver(
config=None,
reactor=None,
homeserver_to_use: Type[HomeServer] = TestHomeServer,
- **kwargs
+ **kwargs,
):
"""
Setup a homeserver suitable for running tests against. Keyword arguments
|