diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py
index 483418192c..fa96ba07a5 100644
--- a/tests/api/test_ratelimiting.py
+++ b/tests/api/test_ratelimiting.py
@@ -5,38 +5,25 @@ from synapse.types import create_requester
from tests import unittest
-class TestRatelimiter(unittest.TestCase):
+class TestRatelimiter(unittest.HomeserverTestCase):
def test_allowed_via_can_do_action(self):
- limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
- allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=0)
- self.assertTrue(allowed)
- self.assertEquals(10.0, time_allowed)
-
- allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=5)
- self.assertFalse(allowed)
- self.assertEquals(10.0, time_allowed)
-
- allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=10)
- self.assertTrue(allowed)
- self.assertEquals(20.0, time_allowed)
-
- def test_allowed_user_via_can_requester_do_action(self):
- user_requester = create_requester("@user:example.com")
- limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
- allowed, time_allowed = limiter.can_requester_do_action(
- user_requester, _time_now_s=0
+ limiter = Ratelimiter(
+ store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
+ )
+ allowed, time_allowed = self.get_success_or_raise(
+ limiter.can_do_action(None, key="test_id", _time_now_s=0)
)
self.assertTrue(allowed)
self.assertEquals(10.0, time_allowed)
- allowed, time_allowed = limiter.can_requester_do_action(
- user_requester, _time_now_s=5
+ allowed, time_allowed = self.get_success_or_raise(
+ limiter.can_do_action(None, key="test_id", _time_now_s=5)
)
self.assertFalse(allowed)
self.assertEquals(10.0, time_allowed)
- allowed, time_allowed = limiter.can_requester_do_action(
- user_requester, _time_now_s=10
+ allowed, time_allowed = self.get_success_or_raise(
+ limiter.can_do_action(None, key="test_id", _time_now_s=10)
)
self.assertTrue(allowed)
self.assertEquals(20.0, time_allowed)
@@ -51,21 +38,23 @@ class TestRatelimiter(unittest.TestCase):
)
as_requester = create_requester("@user:example.com", app_service=appservice)
- limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
- allowed, time_allowed = limiter.can_requester_do_action(
- as_requester, _time_now_s=0
+ limiter = Ratelimiter(
+ store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
+ )
+ allowed, time_allowed = self.get_success_or_raise(
+ limiter.can_do_action(as_requester, _time_now_s=0)
)
self.assertTrue(allowed)
self.assertEquals(10.0, time_allowed)
- allowed, time_allowed = limiter.can_requester_do_action(
- as_requester, _time_now_s=5
+ allowed, time_allowed = self.get_success_or_raise(
+ limiter.can_do_action(as_requester, _time_now_s=5)
)
self.assertFalse(allowed)
self.assertEquals(10.0, time_allowed)
- allowed, time_allowed = limiter.can_requester_do_action(
- as_requester, _time_now_s=10
+ allowed, time_allowed = self.get_success_or_raise(
+ limiter.can_do_action(as_requester, _time_now_s=10)
)
self.assertTrue(allowed)
self.assertEquals(20.0, time_allowed)
@@ -80,73 +69,89 @@ class TestRatelimiter(unittest.TestCase):
)
as_requester = create_requester("@user:example.com", app_service=appservice)
- limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
- allowed, time_allowed = limiter.can_requester_do_action(
- as_requester, _time_now_s=0
+ limiter = Ratelimiter(
+ store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
+ )
+ allowed, time_allowed = self.get_success_or_raise(
+ limiter.can_do_action(as_requester, _time_now_s=0)
)
self.assertTrue(allowed)
self.assertEquals(-1, time_allowed)
- allowed, time_allowed = limiter.can_requester_do_action(
- as_requester, _time_now_s=5
+ allowed, time_allowed = self.get_success_or_raise(
+ limiter.can_do_action(as_requester, _time_now_s=5)
)
self.assertTrue(allowed)
self.assertEquals(-1, time_allowed)
- allowed, time_allowed = limiter.can_requester_do_action(
- as_requester, _time_now_s=10
+ allowed, time_allowed = self.get_success_or_raise(
+ limiter.can_do_action(as_requester, _time_now_s=10)
)
self.assertTrue(allowed)
self.assertEquals(-1, time_allowed)
def test_allowed_via_ratelimit(self):
- limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
+ limiter = Ratelimiter(
+ store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
+ )
# Shouldn't raise
- limiter.ratelimit(key="test_id", _time_now_s=0)
+ self.get_success_or_raise(limiter.ratelimit(None, key="test_id", _time_now_s=0))
# Should raise
with self.assertRaises(LimitExceededError) as context:
- limiter.ratelimit(key="test_id", _time_now_s=5)
+ self.get_success_or_raise(
+ limiter.ratelimit(None, key="test_id", _time_now_s=5)
+ )
self.assertEqual(context.exception.retry_after_ms, 5000)
# Shouldn't raise
- limiter.ratelimit(key="test_id", _time_now_s=10)
+ self.get_success_or_raise(
+ limiter.ratelimit(None, key="test_id", _time_now_s=10)
+ )
def test_allowed_via_can_do_action_and_overriding_parameters(self):
"""Test that we can override options of can_do_action that would otherwise fail
an action
"""
# Create a Ratelimiter with a very low allowed rate_hz and burst_count
- limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
+ limiter = Ratelimiter(
+ store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
+ )
# First attempt should be allowed
- allowed, time_allowed = limiter.can_do_action(
- ("test_id",),
- _time_now_s=0,
+ allowed, time_allowed = self.get_success_or_raise(
+ limiter.can_do_action(
+ None,
+ ("test_id",),
+ _time_now_s=0,
+ )
)
self.assertTrue(allowed)
self.assertEqual(10.0, time_allowed)
# Second attempt, 1s later, will fail
- allowed, time_allowed = limiter.can_do_action(
- ("test_id",),
- _time_now_s=1,
+ allowed, time_allowed = self.get_success_or_raise(
+ limiter.can_do_action(
+ None,
+ ("test_id",),
+ _time_now_s=1,
+ )
)
self.assertFalse(allowed)
self.assertEqual(10.0, time_allowed)
# But, if we allow 10 actions/sec for this request, we should be allowed
# to continue.
- allowed, time_allowed = limiter.can_do_action(
- ("test_id",), _time_now_s=1, rate_hz=10.0
+ allowed, time_allowed = self.get_success_or_raise(
+ limiter.can_do_action(None, ("test_id",), _time_now_s=1, rate_hz=10.0)
)
self.assertTrue(allowed)
self.assertEqual(1.1, time_allowed)
# Similarly if we allow a burst of 10 actions
- allowed, time_allowed = limiter.can_do_action(
- ("test_id",), _time_now_s=1, burst_count=10
+ allowed, time_allowed = self.get_success_or_raise(
+ limiter.can_do_action(None, ("test_id",), _time_now_s=1, burst_count=10)
)
self.assertTrue(allowed)
self.assertEqual(1.0, time_allowed)
@@ -156,29 +161,72 @@ class TestRatelimiter(unittest.TestCase):
fail an action
"""
# Create a Ratelimiter with a very low allowed rate_hz and burst_count
- limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
+ limiter = Ratelimiter(
+ store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
+ )
# First attempt should be allowed
- limiter.ratelimit(key=("test_id",), _time_now_s=0)
+ self.get_success_or_raise(
+ limiter.ratelimit(None, key=("test_id",), _time_now_s=0)
+ )
# Second attempt, 1s later, will fail
with self.assertRaises(LimitExceededError) as context:
- limiter.ratelimit(key=("test_id",), _time_now_s=1)
+ self.get_success_or_raise(
+ limiter.ratelimit(None, key=("test_id",), _time_now_s=1)
+ )
self.assertEqual(context.exception.retry_after_ms, 9000)
# But, if we allow 10 actions/sec for this request, we should be allowed
# to continue.
- limiter.ratelimit(key=("test_id",), _time_now_s=1, rate_hz=10.0)
+ self.get_success_or_raise(
+ limiter.ratelimit(None, key=("test_id",), _time_now_s=1, rate_hz=10.0)
+ )
# Similarly if we allow a burst of 10 actions
- limiter.ratelimit(key=("test_id",), _time_now_s=1, burst_count=10)
+ self.get_success_or_raise(
+ limiter.ratelimit(None, key=("test_id",), _time_now_s=1, burst_count=10)
+ )
def test_pruning(self):
- limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
- limiter.can_do_action(key="test_id_1", _time_now_s=0)
+ limiter = Ratelimiter(
+ store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
+ )
+ self.get_success_or_raise(
+ limiter.can_do_action(None, key="test_id_1", _time_now_s=0)
+ )
self.assertIn("test_id_1", limiter.actions)
- limiter.can_do_action(key="test_id_2", _time_now_s=10)
+ self.get_success_or_raise(
+ limiter.can_do_action(None, key="test_id_2", _time_now_s=10)
+ )
self.assertNotIn("test_id_1", limiter.actions)
+
+ def test_db_user_override(self):
+ """Test that users that have ratelimiting disabled in the DB aren't
+ ratelimited.
+ """
+ store = self.hs.get_datastore()
+
+ user_id = "@user:test"
+ requester = create_requester(user_id)
+
+ self.get_success(
+ store.db_pool.simple_insert(
+ table="ratelimit_override",
+ values={
+ "user_id": user_id,
+ "messages_per_second": None,
+ "burst_count": None,
+ },
+ desc="test_db_user_override",
+ )
+ )
+
+ limiter = Ratelimiter(store=store, clock=None, rate_hz=0.1, burst_count=1)
+
+ # Shouldn't raise
+ for _ in range(20):
+ self.get_success_or_raise(limiter.ratelimit(requester, _time_now_s=0))
diff --git a/tests/config/test_load.py b/tests/config/test_load.py
index 734a9983e8..c109425671 100644
--- a/tests/config/test_load.py
+++ b/tests/config/test_load.py
@@ -20,6 +20,7 @@ from io import StringIO
import yaml
+from synapse.config import ConfigError
from synapse.config.homeserver import HomeServerConfig
from tests import unittest
@@ -35,9 +36,9 @@ class ConfigLoadingTestCase(unittest.TestCase):
def test_load_fails_if_server_name_missing(self):
self.generate_config_and_remove_lines_containing("server_name")
- with self.assertRaises(Exception):
+ with self.assertRaises(ConfigError):
HomeServerConfig.load_config("", ["-c", self.file])
- with self.assertRaises(Exception):
+ with self.assertRaises(ConfigError):
HomeServerConfig.load_or_generate_config("", ["-c", self.file])
def test_generates_and_loads_macaroon_secret_key(self):
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 30fcc4c1bf..946482b7e7 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -16,6 +16,7 @@ import time
from mock import Mock
+import attr
import canonicaljson
import signedjson.key
import signedjson.sign
@@ -68,6 +69,11 @@ class MockPerspectiveServer:
signedjson.sign.sign_json(res, self.server_name, self.key)
+@attr.s(slots=True)
+class FakeRequest:
+ id = attr.ib()
+
+
@logcontext_clean
class KeyringTestCase(unittest.HomeserverTestCase):
def check_context(self, val, expected):
@@ -89,7 +95,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
first_lookup_deferred = Deferred()
async def first_lookup_fetch(keys_to_fetch):
- self.assertEquals(current_context().request, "context_11")
+ self.assertEquals(current_context().request.id, "context_11")
self.assertEqual(keys_to_fetch, {"server10": {get_key_id(key1): 0}})
await make_deferred_yieldable(first_lookup_deferred)
@@ -102,9 +108,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
mock_fetcher.get_keys.side_effect = first_lookup_fetch
async def first_lookup():
- with LoggingContext("context_11") as context_11:
- context_11.request = "context_11"
-
+ with LoggingContext("context_11", request=FakeRequest("context_11")):
res_deferreds = kr.verify_json_objects_for_server(
[("server10", json1, 0, "test10"), ("server11", {}, 0, "test11")]
)
@@ -130,7 +134,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
# should block rather than start a second call
async def second_lookup_fetch(keys_to_fetch):
- self.assertEquals(current_context().request, "context_12")
+ self.assertEquals(current_context().request.id, "context_12")
return {
"server10": {
get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100)
@@ -142,9 +146,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
second_lookup_state = [0]
async def second_lookup():
- with LoggingContext("context_12") as context_12:
- context_12.request = "context_12"
-
+ with LoggingContext("context_12", request=FakeRequest("context_12")):
res_deferreds_2 = kr.verify_json_objects_for_server(
[("server10", json1, 0, "test")]
)
@@ -589,10 +591,7 @@ def get_key_id(key):
@defer.inlineCallbacks
def run_in_context(f, *args, **kwargs):
- with LoggingContext("testctx") as ctx:
- # we set the "request" prop to make it easier to follow what's going on in the
- # logs.
- ctx.request = "testctx"
+ with LoggingContext("testctx"):
rv = yield f(*args, **kwargs)
return rv
diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py
new file mode 100644
index 0000000000..c6e547f11c
--- /dev/null
+++ b/tests/events/test_presence_router.py
@@ -0,0 +1,386 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
+
+from mock import Mock
+
+import attr
+
+from synapse.api.constants import EduTypes
+from synapse.events.presence_router import PresenceRouter
+from synapse.federation.units import Transaction
+from synapse.handlers.presence import UserPresenceState
+from synapse.module_api import ModuleApi
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, presence, room
+from synapse.types import JsonDict, StreamToken, create_requester
+
+from tests.handlers.test_sync import generate_sync_config
+from tests.unittest import FederatingHomeserverTestCase, TestCase, override_config
+
+
+@attr.s
+class PresenceRouterTestConfig:
+ users_who_should_receive_all_presence = attr.ib(type=List[str], default=[])
+
+
+class PresenceRouterTestModule:
+ def __init__(self, config: PresenceRouterTestConfig, module_api: ModuleApi):
+ self._config = config
+ self._module_api = module_api
+
+ async def get_users_for_states(
+ self, state_updates: Iterable[UserPresenceState]
+ ) -> Dict[str, Set[UserPresenceState]]:
+ users_to_state = {
+ user_id: set(state_updates)
+ for user_id in self._config.users_who_should_receive_all_presence
+ }
+ return users_to_state
+
+ async def get_interested_users(
+ self, user_id: str
+ ) -> Union[Set[str], PresenceRouter.ALL_USERS]:
+ if user_id in self._config.users_who_should_receive_all_presence:
+ return PresenceRouter.ALL_USERS
+
+ return set()
+
+ @staticmethod
+ def parse_config(config_dict: dict) -> PresenceRouterTestConfig:
+ """Parse a configuration dictionary from the homeserver config, do
+ some validation and return a typed PresenceRouterConfig.
+
+ Args:
+ config_dict: The configuration dictionary.
+
+ Returns:
+ A validated config object.
+ """
+ # Initialise a typed config object
+ config = PresenceRouterTestConfig()
+
+ config.users_who_should_receive_all_presence = config_dict.get(
+ "users_who_should_receive_all_presence"
+ )
+
+ return config
+
+
+class PresenceRouterTestCase(FederatingHomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ presence.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ return self.setup_test_homeserver(
+ federation_transport_client=Mock(spec=["send_transaction"]),
+ )
+
+ def prepare(self, reactor, clock, homeserver):
+ self.sync_handler = self.hs.get_sync_handler()
+ self.module_api = homeserver.get_module_api()
+
+ @override_config(
+ {
+ "presence": {
+ "presence_router": {
+ "module": __name__ + ".PresenceRouterTestModule",
+ "config": {
+ "users_who_should_receive_all_presence": [
+ "@presence_gobbler:test",
+ ]
+ },
+ }
+ },
+ "send_federation": True,
+ }
+ )
+ def test_receiving_all_presence(self):
+ """Test that a user that does not share a room with another other can receive
+ presence for them, due to presence routing.
+ """
+ # Create a user who should receive all presence of others
+ self.presence_receiving_user_id = self.register_user(
+ "presence_gobbler", "monkey"
+ )
+ self.presence_receiving_user_tok = self.login("presence_gobbler", "monkey")
+
+ # And two users who should not have any special routing
+ self.other_user_one_id = self.register_user("other_user_one", "monkey")
+ self.other_user_one_tok = self.login("other_user_one", "monkey")
+ self.other_user_two_id = self.register_user("other_user_two", "monkey")
+ self.other_user_two_tok = self.login("other_user_two", "monkey")
+
+ # Put the other two users in a room with each other
+ room_id = self.helper.create_room_as(
+ self.other_user_one_id, tok=self.other_user_one_tok
+ )
+
+ self.helper.invite(
+ room_id,
+ self.other_user_one_id,
+ self.other_user_two_id,
+ tok=self.other_user_one_tok,
+ )
+ self.helper.join(room_id, self.other_user_two_id, tok=self.other_user_two_tok)
+ # User one sends some presence
+ send_presence_update(
+ self,
+ self.other_user_one_id,
+ self.other_user_one_tok,
+ "online",
+ "boop",
+ )
+
+ # Check that the presence receiving user gets user one's presence when syncing
+ presence_updates, sync_token = sync_presence(
+ self, self.presence_receiving_user_id
+ )
+ self.assertEqual(len(presence_updates), 1)
+
+ presence_update = presence_updates[0] # type: UserPresenceState
+ self.assertEqual(presence_update.user_id, self.other_user_one_id)
+ self.assertEqual(presence_update.state, "online")
+ self.assertEqual(presence_update.status_msg, "boop")
+
+ # Have all three users send presence
+ send_presence_update(
+ self,
+ self.other_user_one_id,
+ self.other_user_one_tok,
+ "online",
+ "user_one",
+ )
+ send_presence_update(
+ self,
+ self.other_user_two_id,
+ self.other_user_two_tok,
+ "online",
+ "user_two",
+ )
+ send_presence_update(
+ self,
+ self.presence_receiving_user_id,
+ self.presence_receiving_user_tok,
+ "online",
+ "presence_gobbler",
+ )
+
+ # Check that the presence receiving user gets everyone's presence
+ presence_updates, _ = sync_presence(
+ self, self.presence_receiving_user_id, sync_token
+ )
+ self.assertEqual(len(presence_updates), 3)
+
+ # But that User One only get itself and User Two's presence
+ presence_updates, _ = sync_presence(self, self.other_user_one_id)
+ self.assertEqual(len(presence_updates), 2)
+
+ found = False
+ for update in presence_updates:
+ if update.user_id == self.other_user_two_id:
+ self.assertEqual(update.state, "online")
+ self.assertEqual(update.status_msg, "user_two")
+ found = True
+
+ self.assertTrue(found)
+
+ @override_config(
+ {
+ "presence": {
+ "presence_router": {
+ "module": __name__ + ".PresenceRouterTestModule",
+ "config": {
+ "users_who_should_receive_all_presence": [
+ "@presence_gobbler1:test",
+ "@presence_gobbler2:test",
+ "@far_away_person:island",
+ ]
+ },
+ }
+ },
+ "send_federation": True,
+ }
+ )
+ def test_send_local_online_presence_to_with_module(self):
+ """Tests that send_local_presence_to_users sends local online presence to a set
+ of specified local and remote users, with a custom PresenceRouter module enabled.
+ """
+ # Create a user who will send presence updates
+ self.other_user_id = self.register_user("other_user", "monkey")
+ self.other_user_tok = self.login("other_user", "monkey")
+
+ # And another two users that will also send out presence updates, as well as receive
+ # theirs and everyone else's
+ self.presence_receiving_user_one_id = self.register_user(
+ "presence_gobbler1", "monkey"
+ )
+ self.presence_receiving_user_one_tok = self.login("presence_gobbler1", "monkey")
+ self.presence_receiving_user_two_id = self.register_user(
+ "presence_gobbler2", "monkey"
+ )
+ self.presence_receiving_user_two_tok = self.login("presence_gobbler2", "monkey")
+
+ # Have all three users send some presence updates
+ send_presence_update(
+ self,
+ self.other_user_id,
+ self.other_user_tok,
+ "online",
+ "I'm online!",
+ )
+ send_presence_update(
+ self,
+ self.presence_receiving_user_one_id,
+ self.presence_receiving_user_one_tok,
+ "online",
+ "I'm also online!",
+ )
+ send_presence_update(
+ self,
+ self.presence_receiving_user_two_id,
+ self.presence_receiving_user_two_tok,
+ "unavailable",
+ "I'm in a meeting!",
+ )
+
+ # Mark each presence-receiving user for receiving all user presence
+ self.get_success(
+ self.module_api.send_local_online_presence_to(
+ [
+ self.presence_receiving_user_one_id,
+ self.presence_receiving_user_two_id,
+ ]
+ )
+ )
+
+ # Perform a sync for each user
+
+ # The other user should only receive their own presence
+ presence_updates, _ = sync_presence(self, self.other_user_id)
+ self.assertEqual(len(presence_updates), 1)
+
+ presence_update = presence_updates[0] # type: UserPresenceState
+ self.assertEqual(presence_update.user_id, self.other_user_id)
+ self.assertEqual(presence_update.state, "online")
+ self.assertEqual(presence_update.status_msg, "I'm online!")
+
+ # Whereas both presence receiving users should receive everyone's presence updates
+ presence_updates, _ = sync_presence(self, self.presence_receiving_user_one_id)
+ self.assertEqual(len(presence_updates), 3)
+ presence_updates, _ = sync_presence(self, self.presence_receiving_user_two_id)
+ self.assertEqual(len(presence_updates), 3)
+
+ # Test that sending to a remote user works
+ remote_user_id = "@far_away_person:island"
+
+ # Note that due to the remote user being in our module's
+ # users_who_should_receive_all_presence config, they would have
+ # received user presence updates already.
+ #
+ # Thus we reset the mock, and try sending all online local user
+ # presence again
+ self.hs.get_federation_transport_client().send_transaction.reset_mock()
+
+ # Broadcast local user online presence
+ self.get_success(
+ self.module_api.send_local_online_presence_to([remote_user_id])
+ )
+
+ # Check that the expected presence updates were sent
+ expected_users = [
+ self.other_user_id,
+ self.presence_receiving_user_one_id,
+ self.presence_receiving_user_two_id,
+ ]
+
+ calls = (
+ self.hs.get_federation_transport_client().send_transaction.call_args_list
+ )
+ for call in calls:
+ federation_transaction = call.args[0] # type: Transaction
+
+ # Get the sent EDUs in this transaction
+ edus = federation_transaction.get_dict()["edus"]
+
+ for edu in edus:
+ # Make sure we're only checking presence-type EDUs
+ if edu["edu_type"] != EduTypes.Presence:
+ continue
+
+ # EDUs can contain multiple presence updates
+ for presence_update in edu["content"]["push"]:
+ # Check for presence updates that contain the user IDs we're after
+ expected_users.remove(presence_update["user_id"])
+
+ # Ensure that no offline states are being sent out
+ self.assertNotEqual(presence_update["presence"], "offline")
+
+ self.assertEqual(len(expected_users), 0)
+
+
+def send_presence_update(
+ testcase: TestCase,
+ user_id: str,
+ access_token: str,
+ presence_state: str,
+ status_message: Optional[str] = None,
+) -> JsonDict:
+ # Build the presence body
+ body = {"presence": presence_state}
+ if status_message:
+ body["status_msg"] = status_message
+
+ # Update the user's presence state
+ channel = testcase.make_request(
+ "PUT", "/presence/%s/status" % (user_id,), body, access_token=access_token
+ )
+ testcase.assertEqual(channel.code, 200)
+
+ return channel.json_body
+
+
+def sync_presence(
+ testcase: TestCase,
+ user_id: str,
+ since_token: Optional[StreamToken] = None,
+) -> Tuple[List[UserPresenceState], StreamToken]:
+ """Perform a sync request for the given user and return the user presence updates
+ they've received, as well as the next_batch token.
+
+ This method assumes testcase.sync_handler points to the homeserver's sync handler.
+
+ Args:
+ testcase: The testcase that is currently being run.
+ user_id: The ID of the user to generate a sync response for.
+ since_token: An optional token to indicate from at what point to sync from.
+
+ Returns:
+ A tuple containing a list of presence updates, and the sync response's
+ next_batch token.
+ """
+ requester = create_requester(user_id)
+ sync_config = generate_sync_config(requester.user.to_string())
+ sync_result = testcase.get_success(
+ testcase.sync_handler.wait_for_sync_for_user(
+ requester, sync_config, since_token
+ )
+ )
+
+ return sync_result.presence, sync_result.next_batch
diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py
index 6f96cd7940..95eac6a5a3 100644
--- a/tests/federation/test_federation_catch_up.py
+++ b/tests/federation/test_federation_catch_up.py
@@ -2,6 +2,7 @@ from typing import List, Tuple
from mock import Mock
+from synapse.api.constants import EventTypes
from synapse.events import EventBase
from synapse.federation.sender import PerDestinationQueue, TransactionManager
from synapse.federation.units import Edu
@@ -421,3 +422,51 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
self.assertNotIn("zzzerver", woken)
# - all destinations are woken exactly once; they appear once in woken.
self.assertCountEqual(woken, server_names[:-1])
+
+ @override_config({"send_federation": True})
+ def test_not_latest_event(self):
+ """Test that we send the latest event in the room even if its not ours."""
+
+ per_dest_queue, sent_pdus = self.make_fake_destination_queue()
+
+ # Make a room with a local user, and two servers. One will go offline
+ # and one will send some events.
+ self.register_user("u1", "you the one")
+ u1_token = self.login("u1", "you the one")
+ room_1 = self.helper.create_room_as("u1", tok=u1_token)
+
+ self.get_success(
+ event_injection.inject_member_event(self.hs, room_1, "@user:host2", "join")
+ )
+ event_1 = self.get_success(
+ event_injection.inject_member_event(self.hs, room_1, "@user:host3", "join")
+ )
+
+ # First we send something from the local server, so that we notice the
+ # remote is down and go into catchup mode.
+ self.helper.send(room_1, "you hear me!!", tok=u1_token)
+
+ # Now simulate us receiving an event from the still online remote.
+ event_2 = self.get_success(
+ event_injection.inject_event(
+ self.hs,
+ type=EventTypes.Message,
+ sender="@user:host3",
+ room_id=room_1,
+ content={"msgtype": "m.text", "body": "Hello"},
+ )
+ )
+
+ self.get_success(
+ self.hs.get_datastore().set_destination_last_successful_stream_ordering(
+ "host2", event_1.internal_metadata.stream_ordering
+ )
+ )
+
+ self.get_success(per_dest_queue._catch_up_transmission_loop())
+
+ # We expect only the last message from the remote, event_2, to have been
+ # sent, rather than the last *local* event that was sent.
+ self.assertEqual(len(sent_pdus), 1)
+ self.assertEqual(sent_pdus[0].event_id, event_2.event_id)
+ self.assertFalse(per_dest_queue._catching_up)
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index 5e9c9c2e88..c7796fb837 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -989,6 +989,138 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
self.assertRenderedError("mapping_error", "localpart is invalid: ")
+ @override_config(
+ {
+ "oidc_config": {
+ **DEFAULT_CONFIG,
+ "attribute_requirements": [{"attribute": "test", "value": "foobar"}],
+ }
+ }
+ )
+ def test_attribute_requirements(self):
+ """The required attributes must be met from the OIDC userinfo response."""
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+
+ # userinfo lacking "test": "foobar" attribute should fail.
+ userinfo = {
+ "sub": "tester",
+ "username": "tester",
+ }
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ auth_handler.complete_sso_login.assert_not_called()
+
+ # userinfo with "test": "foobar" attribute should succeed.
+ userinfo = {
+ "sub": "tester",
+ "username": "tester",
+ "test": "foobar",
+ }
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+
+ # check that the auth handler got called as expected
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@tester:test", "oidc", ANY, ANY, None, new_user=True
+ )
+
+ @override_config(
+ {
+ "oidc_config": {
+ **DEFAULT_CONFIG,
+ "attribute_requirements": [{"attribute": "test", "value": "foobar"}],
+ }
+ }
+ )
+ def test_attribute_requirements_contains(self):
+ """Test that auth succeeds if userinfo attribute CONTAINS required value"""
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+ # userinfo with "test": ["foobar", "foo", "bar"] attribute should succeed.
+ userinfo = {
+ "sub": "tester",
+ "username": "tester",
+ "test": ["foobar", "foo", "bar"],
+ }
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+
+ # check that the auth handler got called as expected
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@tester:test", "oidc", ANY, ANY, None, new_user=True
+ )
+
+ @override_config(
+ {
+ "oidc_config": {
+ **DEFAULT_CONFIG,
+ "attribute_requirements": [{"attribute": "test", "value": "foobar"}],
+ }
+ }
+ )
+ def test_attribute_requirements_mismatch(self):
+ """
+ Test that auth fails if attributes exist but don't match,
+ or are non-string values.
+ """
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+ # userinfo with "test": "not_foobar" attribute should fail
+ userinfo = {
+ "sub": "tester",
+ "username": "tester",
+ "test": "not_foobar",
+ }
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ auth_handler.complete_sso_login.assert_not_called()
+
+ # userinfo with "test": ["foo", "bar"] attribute should fail
+ userinfo = {
+ "sub": "tester",
+ "username": "tester",
+ "test": ["foo", "bar"],
+ }
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ auth_handler.complete_sso_login.assert_not_called()
+
+ # userinfo with "test": False attribute should fail
+ # this is largely just to ensure we don't crash here
+ userinfo = {
+ "sub": "tester",
+ "username": "tester",
+ "test": False,
+ }
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ auth_handler.complete_sso_login.assert_not_called()
+
+ # userinfo with "test": None attribute should fail
+ # a value of None breaks the OIDC spec, but it's important to not crash here
+ userinfo = {
+ "sub": "tester",
+ "username": "tester",
+ "test": None,
+ }
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ auth_handler.complete_sso_login.assert_not_called()
+
+ # userinfo with "test": 1 attribute should fail
+ # this is largely just to ensure we don't crash here
+ userinfo = {
+ "sub": "tester",
+ "username": "tester",
+ "test": 1,
+ }
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ auth_handler.complete_sso_login.assert_not_called()
+
+ # userinfo with "test": 3.14 attribute should fail
+ # this is largely just to ensure we don't crash here
+ userinfo = {
+ "sub": "tester",
+ "username": "tester",
+ "test": 3.14,
+ }
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ auth_handler.complete_sso_login.assert_not_called()
+
def _generate_oidc_session_token(
self,
state: str,
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index 996c614198..77330f59a9 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -310,6 +310,26 @@ class PresenceTimeoutTestCase(unittest.TestCase):
self.assertIsNotNone(new_state)
self.assertEquals(new_state.state, PresenceState.UNAVAILABLE)
+ def test_busy_no_idle(self):
+ """
+ Tests that a user setting their presence to busy but idling doesn't turn their
+ presence state into unavailable.
+ """
+ user_id = "@foo:bar"
+ now = 5000000
+
+ state = UserPresenceState.default(user_id)
+ state = state.copy_and_replace(
+ state=PresenceState.BUSY,
+ last_active_ts=now - IDLE_TIMER - 1,
+ last_user_sync_ts=now,
+ )
+
+ new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
+
+ self.assertIsNotNone(new_state)
+ self.assertEquals(new_state.state, PresenceState.BUSY)
+
def test_sync_timeout(self):
user_id = "@foo:bar"
now = 5000000
diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index e62586142e..8e950f25c5 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -37,7 +37,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
def test_wait_for_sync_for_user_auth_blocking(self):
user_id1 = "@user1:test"
user_id2 = "@user2:test"
- sync_config = self._generate_sync_config(user_id1)
+ sync_config = generate_sync_config(user_id1)
requester = create_requester(user_id1)
self.reactor.advance(100) # So we get not 0 time
@@ -60,7 +60,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
self.auth_blocking._hs_disabled = False
- sync_config = self._generate_sync_config(user_id2)
+ sync_config = generate_sync_config(user_id2)
requester = create_requester(user_id2)
e = self.get_failure(
@@ -69,11 +69,12 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
)
self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
- def _generate_sync_config(self, user_id):
- return SyncConfig(
- user=UserID(user_id.split(":")[0][1:], user_id.split(":")[1]),
- filter_collection=DEFAULT_FILTER_COLLECTION,
- is_guest=False,
- request_key="request_key",
- device_id="device_id",
- )
+
+def generate_sync_config(user_id: str) -> SyncConfig:
+ return SyncConfig(
+ user=UserID(user_id.split(":")[0][1:], user_id.split(":")[1]),
+ filter_collection=DEFAULT_FILTER_COLLECTION,
+ is_guest=False,
+ request_key="request_key",
+ device_id="device_id",
+ )
diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py
index 505ffcd300..3ea8b5bec7 100644
--- a/tests/http/test_proxyagent.py
+++ b/tests/http/test_proxyagent.py
@@ -12,8 +12,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import base64
import logging
import os
+from typing import Optional
from unittest.mock import patch
import treq
@@ -242,6 +244,21 @@ class MatrixFederationAgentTests(TestCase):
@patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "unused.com"})
def test_https_request_via_proxy(self):
+ """Tests that TLS-encrypted requests can be made through a proxy"""
+ self._do_https_request_via_proxy(auth_credentials=None)
+
+ @patch.dict(
+ os.environ,
+ {"https_proxy": "bob:pinkponies@proxy.com", "no_proxy": "unused.com"},
+ )
+ def test_https_request_via_proxy_with_auth(self):
+ """Tests that authenticated, TLS-encrypted requests can be made through a proxy"""
+ self._do_https_request_via_proxy(auth_credentials="bob:pinkponies")
+
+ def _do_https_request_via_proxy(
+ self,
+ auth_credentials: Optional[str] = None,
+ ):
agent = ProxyAgent(
self.reactor,
contextFactory=get_test_https_policy(),
@@ -278,6 +295,22 @@ class MatrixFederationAgentTests(TestCase):
self.assertEqual(request.method, b"CONNECT")
self.assertEqual(request.path, b"test.com:443")
+ # Check whether auth credentials have been supplied to the proxy
+ proxy_auth_header_values = request.requestHeaders.getRawHeaders(
+ b"Proxy-Authorization"
+ )
+
+ if auth_credentials is not None:
+ # Compute the correct header value for Proxy-Authorization
+ encoded_credentials = base64.b64encode(b"bob:pinkponies")
+ expected_header_value = b"Basic " + encoded_credentials
+
+ # Validate the header's value
+ self.assertIn(expected_header_value, proxy_auth_header_values)
+ else:
+ # Check that the Proxy-Authorization header has not been supplied to the proxy
+ self.assertIsNone(proxy_auth_header_values)
+
# tell the proxy server not to close the connection
proxy_server.persistent = True
@@ -312,6 +345,13 @@ class MatrixFederationAgentTests(TestCase):
self.assertEqual(request.method, b"GET")
self.assertEqual(request.path, b"/abc")
self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
+
+ # Check that the destination server DID NOT receive proxy credentials
+ proxy_auth_header_values = request.requestHeaders.getRawHeaders(
+ b"Proxy-Authorization"
+ )
+ self.assertIsNone(proxy_auth_header_values)
+
request.write(b"result")
request.finish()
diff --git a/tests/logging/test_terse_json.py b/tests/logging/test_terse_json.py
index 48a74e2eee..bfe0d11c93 100644
--- a/tests/logging/test_terse_json.py
+++ b/tests/logging/test_terse_json.py
@@ -12,15 +12,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import json
import logging
-from io import StringIO
+from io import BytesIO, StringIO
+
+from mock import Mock, patch
+
+from twisted.web.server import Request
+from synapse.http.site import SynapseRequest
from synapse.logging._terse_json import JsonFormatter, TerseJsonFormatter
from synapse.logging.context import LoggingContext, LoggingContextFilter
from tests.logging import LoggerCleanupMixin
+from tests.server import FakeChannel
from tests.unittest import TestCase
@@ -120,7 +125,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
handler.addFilter(LoggingContextFilter())
logger = self.get_logger(handler)
- with LoggingContext(request="test"):
+ with LoggingContext("name"):
logger.info("Hello there, %s!", "wally")
log = self.get_log_line()
@@ -134,4 +139,61 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
]
self.assertCountEqual(log.keys(), expected_log_keys)
self.assertEqual(log["log"], "Hello there, wally!")
- self.assertEqual(log["request"], "test")
+ self.assertTrue(log["request"].startswith("name@"))
+
+ def test_with_request_context(self):
+ """
+ Information from the logging context request should be added to the JSON response.
+ """
+ handler = logging.StreamHandler(self.output)
+ handler.setFormatter(JsonFormatter())
+ handler.addFilter(LoggingContextFilter())
+ logger = self.get_logger(handler)
+
+ # A full request isn't needed here.
+ site = Mock(spec=["site_tag", "server_version_string", "getResourceFor"])
+ site.site_tag = "test-site"
+ site.server_version_string = "Server v1"
+ request = SynapseRequest(FakeChannel(site, None))
+ # Call requestReceived to finish instantiating the object.
+ request.content = BytesIO()
+ # Partially skip some of the internal processing of SynapseRequest.
+ request._started_processing = Mock()
+ request.request_metrics = Mock(spec=["name"])
+ with patch.object(Request, "render"):
+ request.requestReceived(b"POST", b"/_matrix/client/versions", b"1.1")
+
+ # Also set the requester to ensure the processing works.
+ request.requester = "@foo:test"
+
+ with LoggingContext(parent_context=request.logcontext):
+ logger.info("Hello there, %s!", "wally")
+
+ log = self.get_log_line()
+
+ # The terse logger includes additional request information, if possible.
+ expected_log_keys = [
+ "log",
+ "level",
+ "namespace",
+ "request",
+ "ip_address",
+ "site_tag",
+ "requester",
+ "authenticated_entity",
+ "method",
+ "url",
+ "protocol",
+ "user_agent",
+ ]
+ self.assertCountEqual(log.keys(), expected_log_keys)
+ self.assertEqual(log["log"], "Hello there, wally!")
+ self.assertTrue(log["request"].startswith("POST-"))
+ self.assertEqual(log["ip_address"], "127.0.0.1")
+ self.assertEqual(log["site_tag"], "test-site")
+ self.assertEqual(log["requester"], "@foo:test")
+ self.assertEqual(log["authenticated_entity"], "@foo:test")
+ self.assertEqual(log["method"], "POST")
+ self.assertEqual(log["url"], "/_matrix/client/versions")
+ self.assertEqual(log["protocol"], "1.1")
+ self.assertEqual(log["user_agent"], "")
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index edacd1b566..1d1fceeecf 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -14,25 +14,37 @@
# limitations under the License.
from mock import Mock
+from synapse.api.constants import EduTypes
from synapse.events import EventBase
+from synapse.federation.units import Transaction
+from synapse.handlers.presence import UserPresenceState
from synapse.rest import admin
-from synapse.rest.client.v1 import login, room
+from synapse.rest.client.v1 import login, presence, room
from synapse.types import create_requester
-from tests.unittest import HomeserverTestCase
+from tests.events.test_presence_router import send_presence_update, sync_presence
+from tests.test_utils.event_injection import inject_member_event
+from tests.unittest import FederatingHomeserverTestCase, override_config
-class ModuleApiTestCase(HomeserverTestCase):
+class ModuleApiTestCase(FederatingHomeserverTestCase):
servlets = [
admin.register_servlets,
login.register_servlets,
room.register_servlets,
+ presence.register_servlets,
]
def prepare(self, reactor, clock, homeserver):
self.store = homeserver.get_datastore()
self.module_api = homeserver.get_module_api()
self.event_creation_handler = homeserver.get_event_creation_handler()
+ self.sync_handler = homeserver.get_sync_handler()
+
+ def make_homeserver(self, reactor, clock):
+ return self.setup_test_homeserver(
+ federation_transport_client=Mock(spec=["send_transaction"]),
+ )
def test_can_register_user(self):
"""Tests that an external module can register a user"""
@@ -205,3 +217,160 @@ class ModuleApiTestCase(HomeserverTestCase):
)
)
self.assertFalse(is_in_public_rooms)
+
+ # The ability to send federation is required by send_local_online_presence_to.
+ @override_config({"send_federation": True})
+ def test_send_local_online_presence_to(self):
+ """Tests that send_local_presence_to_users sends local online presence to local users."""
+ # Create a user who will send presence updates
+ self.presence_receiver_id = self.register_user("presence_receiver", "monkey")
+ self.presence_receiver_tok = self.login("presence_receiver", "monkey")
+
+ # And another user that will send presence updates out
+ self.presence_sender_id = self.register_user("presence_sender", "monkey")
+ self.presence_sender_tok = self.login("presence_sender", "monkey")
+
+ # Put them in a room together so they will receive each other's presence updates
+ room_id = self.helper.create_room_as(
+ self.presence_receiver_id,
+ tok=self.presence_receiver_tok,
+ )
+ self.helper.join(room_id, self.presence_sender_id, tok=self.presence_sender_tok)
+
+ # Presence sender comes online
+ send_presence_update(
+ self,
+ self.presence_sender_id,
+ self.presence_sender_tok,
+ "online",
+ "I'm online!",
+ )
+
+ # Presence receiver should have received it
+ presence_updates, sync_token = sync_presence(self, self.presence_receiver_id)
+ self.assertEqual(len(presence_updates), 1)
+
+ presence_update = presence_updates[0] # type: UserPresenceState
+ self.assertEqual(presence_update.user_id, self.presence_sender_id)
+ self.assertEqual(presence_update.state, "online")
+
+ # Syncing again should result in no presence updates
+ presence_updates, sync_token = sync_presence(
+ self, self.presence_receiver_id, sync_token
+ )
+ self.assertEqual(len(presence_updates), 0)
+
+ # Trigger sending local online presence
+ self.get_success(
+ self.module_api.send_local_online_presence_to(
+ [
+ self.presence_receiver_id,
+ ]
+ )
+ )
+
+ # Presence receiver should have received online presence again
+ presence_updates, sync_token = sync_presence(
+ self, self.presence_receiver_id, sync_token
+ )
+ self.assertEqual(len(presence_updates), 1)
+
+ presence_update = presence_updates[0] # type: UserPresenceState
+ self.assertEqual(presence_update.user_id, self.presence_sender_id)
+ self.assertEqual(presence_update.state, "online")
+
+ # Presence sender goes offline
+ send_presence_update(
+ self,
+ self.presence_sender_id,
+ self.presence_sender_tok,
+ "offline",
+ "I slink back into the darkness.",
+ )
+
+ # Trigger sending local online presence
+ self.get_success(
+ self.module_api.send_local_online_presence_to(
+ [
+ self.presence_receiver_id,
+ ]
+ )
+ )
+
+ # Presence receiver should *not* have received offline state
+ presence_updates, sync_token = sync_presence(
+ self, self.presence_receiver_id, sync_token
+ )
+ self.assertEqual(len(presence_updates), 0)
+
+ @override_config({"send_federation": True})
+ def test_send_local_online_presence_to_federation(self):
+ """Tests that send_local_presence_to_users sends local online presence to remote users."""
+ # Create a user who will send presence updates
+ self.presence_sender_id = self.register_user("presence_sender", "monkey")
+ self.presence_sender_tok = self.login("presence_sender", "monkey")
+
+ # And a room they're a part of
+ room_id = self.helper.create_room_as(
+ self.presence_sender_id,
+ tok=self.presence_sender_tok,
+ )
+
+ # Mark them as online
+ send_presence_update(
+ self,
+ self.presence_sender_id,
+ self.presence_sender_tok,
+ "online",
+ "I'm online!",
+ )
+
+ # Make up a remote user to send presence to
+ remote_user_id = "@far_away_person:island"
+
+ # Create a join membership event for the remote user into the room.
+ # This allows presence information to flow from one user to the other.
+ self.get_success(
+ inject_member_event(
+ self.hs,
+ room_id,
+ sender=remote_user_id,
+ target=remote_user_id,
+ membership="join",
+ )
+ )
+
+ # The remote user would have received the existing room members' presence
+ # when they joined the room.
+ #
+ # Thus we reset the mock, and try sending online local user
+ # presence again
+ self.hs.get_federation_transport_client().send_transaction.reset_mock()
+
+ # Broadcast local user online presence
+ self.get_success(
+ self.module_api.send_local_online_presence_to([remote_user_id])
+ )
+
+ # Check that a presence update was sent as part of a federation transaction
+ found_update = False
+ calls = (
+ self.hs.get_federation_transport_client().send_transaction.call_args_list
+ )
+ for call in calls:
+ federation_transaction = call.args[0] # type: Transaction
+
+ # Get the sent EDUs in this transaction
+ edus = federation_transaction.get_dict()["edus"]
+
+ for edu in edus:
+ # Make sure we're only checking presence-type EDUs
+ if edu["edu_type"] != EduTypes.Presence:
+ continue
+
+ # EDUs can contain multiple presence updates
+ for presence_update in edu["content"]["push"]:
+ if presence_update["user_id"] == self.presence_sender_id:
+ found_update = True
+
+ self.assertTrue(found_update)
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 67b7913666..1d4a592862 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -44,7 +44,7 @@ from tests.server import FakeTransport
try:
import hiredis
except ImportError:
- hiredis = None
+ hiredis = None # type: ignore
logger = logging.getLogger(__name__)
diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py
index 5acfb3e53e..ca49d4dd3a 100644
--- a/tests/replication/tcp/streams/test_typing.py
+++ b/tests/replication/tcp/streams/test_typing.py
@@ -69,6 +69,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
self.assert_request_is_get_repl_stream_updates(request, "typing")
# The from token should be the token from the last RDATA we got.
+ assert request.args is not None
self.assertEqual(int(request.args[b"from_token"][0]), token)
self.test_handler.on_rdata.assert_called_once()
diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py
index 7ff11cde10..b0800f9840 100644
--- a/tests/replication/test_multi_media_repo.py
+++ b/tests/replication/test_multi_media_repo.py
@@ -15,7 +15,7 @@
import logging
import os
from binascii import unhexlify
-from typing import Tuple
+from typing import Optional, Tuple
from twisted.internet.protocol import Factory
from twisted.protocols.tls import TLSMemoryBIOFactory
@@ -32,7 +32,7 @@ from tests.server import FakeChannel, FakeSite, FakeTransport, make_request
logger = logging.getLogger(__name__)
-test_server_connection_factory = None
+test_server_connection_factory = None # type: Optional[TestServerTLSConnectionFactory]
class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index e58d5cf0db..0c9ec133c2 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -28,7 +28,7 @@ from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError
from synapse.api.room_versions import RoomVersions
from synapse.rest.client.v1 import login, logout, profile, room
from synapse.rest.client.v2_alpha import devices, sync
-from synapse.types import JsonDict
+from synapse.types import JsonDict, UserID
from tests import unittest
from tests.server import FakeSite, make_request
@@ -467,6 +467,8 @@ class UsersListTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v2/users"
def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
@@ -634,6 +636,26 @@ class UsersListTestCase(unittest.HomeserverTestCase):
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+ # unkown order_by
+ channel = self.make_request(
+ "GET",
+ self.url + "?order_by=bar",
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+
+ # invalid search order
+ channel = self.make_request(
+ "GET",
+ self.url + "?dir=bar",
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+
def test_limit(self):
"""
Testing list of users with limit
@@ -759,6 +781,103 @@ class UsersListTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(channel.json_body["users"]), 1)
self.assertNotIn("next_token", channel.json_body)
+ def test_order_by(self):
+ """
+ Testing order list with parameter `order_by`
+ """
+
+ user1 = self.register_user("user1", "pass1", admin=False, displayname="Name Z")
+ user2 = self.register_user("user2", "pass2", admin=False, displayname="Name Y")
+
+ # Modify user
+ self.get_success(self.store.set_user_deactivated_status(user1, True))
+ self.get_success(self.store.set_shadow_banned(UserID.from_string(user1), True))
+
+ # Set avatar URL to all users, that no user has a NULL value to avoid
+ # different sort order between SQlite and PostreSQL
+ self.get_success(self.store.set_profile_avatar_url("user1", "mxc://url3"))
+ self.get_success(self.store.set_profile_avatar_url("user2", "mxc://url2"))
+ self.get_success(self.store.set_profile_avatar_url("admin", "mxc://url1"))
+
+ # order by default (name)
+ self._order_test([self.admin_user, user1, user2], None)
+ self._order_test([self.admin_user, user1, user2], None, "f")
+ self._order_test([user2, user1, self.admin_user], None, "b")
+
+ # order by name
+ self._order_test([self.admin_user, user1, user2], "name")
+ self._order_test([self.admin_user, user1, user2], "name", "f")
+ self._order_test([user2, user1, self.admin_user], "name", "b")
+
+ # order by displayname
+ self._order_test([user2, user1, self.admin_user], "displayname")
+ self._order_test([user2, user1, self.admin_user], "displayname", "f")
+ self._order_test([self.admin_user, user1, user2], "displayname", "b")
+
+ # order by is_guest
+ # like sort by ascending name, as no guest user here
+ self._order_test([self.admin_user, user1, user2], "is_guest")
+ self._order_test([self.admin_user, user1, user2], "is_guest", "f")
+ self._order_test([self.admin_user, user1, user2], "is_guest", "b")
+
+ # order by admin
+ self._order_test([user1, user2, self.admin_user], "admin")
+ self._order_test([user1, user2, self.admin_user], "admin", "f")
+ self._order_test([self.admin_user, user1, user2], "admin", "b")
+
+ # order by deactivated
+ self._order_test([self.admin_user, user2, user1], "deactivated")
+ self._order_test([self.admin_user, user2, user1], "deactivated", "f")
+ self._order_test([user1, self.admin_user, user2], "deactivated", "b")
+
+ # order by user_type
+ # like sort by ascending name, as no special user type here
+ self._order_test([self.admin_user, user1, user2], "user_type")
+ self._order_test([self.admin_user, user1, user2], "user_type", "f")
+ self._order_test([self.admin_user, user1, user2], "is_guest", "b")
+
+ # order by shadow_banned
+ self._order_test([self.admin_user, user2, user1], "shadow_banned")
+ self._order_test([self.admin_user, user2, user1], "shadow_banned", "f")
+ self._order_test([user1, self.admin_user, user2], "shadow_banned", "b")
+
+ # order by avatar_url
+ self._order_test([self.admin_user, user2, user1], "avatar_url")
+ self._order_test([self.admin_user, user2, user1], "avatar_url", "f")
+ self._order_test([user1, user2, self.admin_user], "avatar_url", "b")
+
+ def _order_test(
+ self,
+ expected_user_list: List[str],
+ order_by: Optional[str],
+ dir: Optional[str] = None,
+ ):
+ """Request the list of users in a certain order. Assert that order is what
+ we expect
+ Args:
+ expected_user_list: The list of user_id in the order we expect to get
+ back from the server
+ order_by: The type of ordering to give the server
+ dir: The direction of ordering to give the server
+ """
+
+ url = self.url + "?deactivated=true&"
+ if order_by is not None:
+ url += "order_by=%s&" % (order_by,)
+ if dir is not None and dir in ("b", "f"):
+ url += "dir=%s" % (dir,)
+ channel = self.make_request(
+ "GET",
+ url.encode("ascii"),
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(channel.json_body["total"], len(expected_user_list))
+
+ returned_order = [row["name"] for row in channel.json_body["users"]]
+ self.assertEqual(expected_user_list, returned_order)
+ self._check_fields(channel.json_body["users"])
+
def _check_fields(self, content: JsonDict):
"""Checks that the expected user attributes are present in content
Args:
@@ -1003,12 +1122,23 @@ class UserRestTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
+ self.auth_handler = hs.get_auth_handler()
+ # create users and get access tokens
+ # regardless of whether password login or SSO is allowed
self.admin_user = self.register_user("admin", "pass", admin=True)
- self.admin_user_tok = self.login("admin", "pass")
+ self.admin_user_tok = self.get_success(
+ self.auth_handler.get_access_token_for_user_id(
+ self.admin_user, device_id=None, valid_until_ms=None
+ )
+ )
self.other_user = self.register_user("user", "pass", displayname="User")
- self.other_user_token = self.login("user", "pass")
+ self.other_user_token = self.get_success(
+ self.auth_handler.get_access_token_for_user_id(
+ self.other_user, device_id=None, valid_until_ms=None
+ )
+ )
self.url_other_user = "/_synapse/admin/v2/users/%s" % urllib.parse.quote(
self.other_user
)
@@ -1081,7 +1211,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
- self.assertEqual(True, channel.json_body["admin"])
+ self.assertTrue(channel.json_body["admin"])
self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
# Get user
@@ -1096,9 +1226,9 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
- self.assertEqual(True, channel.json_body["admin"])
- self.assertEqual(False, channel.json_body["is_guest"])
- self.assertEqual(False, channel.json_body["deactivated"])
+ self.assertTrue(channel.json_body["admin"])
+ self.assertFalse(channel.json_body["is_guest"])
+ self.assertFalse(channel.json_body["deactivated"])
self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
def test_create_user(self):
@@ -1130,7 +1260,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
- self.assertEqual(False, channel.json_body["admin"])
+ self.assertFalse(channel.json_body["admin"])
self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
# Get user
@@ -1145,10 +1275,10 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
- self.assertEqual(False, channel.json_body["admin"])
- self.assertEqual(False, channel.json_body["is_guest"])
- self.assertEqual(False, channel.json_body["deactivated"])
- self.assertEqual(False, channel.json_body["shadow_banned"])
+ self.assertFalse(channel.json_body["admin"])
+ self.assertFalse(channel.json_body["is_guest"])
+ self.assertFalse(channel.json_body["deactivated"])
+ self.assertFalse(channel.json_body["shadow_banned"])
self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
@override_config(
@@ -1197,7 +1327,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["name"])
- self.assertEqual(False, channel.json_body["admin"])
+ self.assertFalse(channel.json_body["admin"])
@override_config(
{"limit_usage_by_mau": True, "max_mau_value": 2, "mau_trial_days": 0}
@@ -1237,7 +1367,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Admin user is not blocked by mau anymore
self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["name"])
- self.assertEqual(False, channel.json_body["admin"])
+ self.assertFalse(channel.json_body["admin"])
@override_config(
{
@@ -1429,24 +1559,23 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
- self.assertEqual(False, channel.json_body["deactivated"])
+ self.assertFalse(channel.json_body["deactivated"])
self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
self.assertEqual("User", channel.json_body["displayname"])
# Deactivate user
- body = json.dumps({"deactivated": True})
-
channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
- content=body.encode(encoding="utf_8"),
+ content={"deactivated": True},
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
- self.assertEqual(True, channel.json_body["deactivated"])
+ self.assertTrue(channel.json_body["deactivated"])
+ self.assertIsNone(channel.json_body["password_hash"])
self.assertEqual(0, len(channel.json_body["threepids"]))
self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
self.assertEqual("User", channel.json_body["displayname"])
@@ -1461,7 +1590,8 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
- self.assertEqual(True, channel.json_body["deactivated"])
+ self.assertTrue(channel.json_body["deactivated"])
+ self.assertIsNone(channel.json_body["password_hash"])
self.assertEqual(0, len(channel.json_body["threepids"]))
self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
self.assertEqual("User", channel.json_body["displayname"])
@@ -1478,41 +1608,37 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertTrue(profile["display_name"] == "User")
# Deactivate user
- body = json.dumps({"deactivated": True})
-
channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
- content=body.encode(encoding="utf_8"),
+ content={"deactivated": True},
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
- self.assertEqual(True, channel.json_body["deactivated"])
+ self.assertTrue(channel.json_body["deactivated"])
# is not in user directory
profile = self.get_success(self.store.get_user_in_directory(self.other_user))
- self.assertTrue(profile is None)
+ self.assertIsNone(profile)
# Set new displayname user
- body = json.dumps({"displayname": "Foobar"})
-
channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
- content=body.encode(encoding="utf_8"),
+ content={"displayname": "Foobar"},
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
- self.assertEqual(True, channel.json_body["deactivated"])
+ self.assertTrue(channel.json_body["deactivated"])
self.assertEqual("Foobar", channel.json_body["displayname"])
# is not in user directory
profile = self.get_success(self.store.get_user_in_directory(self.other_user))
- self.assertTrue(profile is None)
+ self.assertIsNone(profile)
def test_reactivate_user(self):
"""
@@ -1520,48 +1646,92 @@ class UserRestTestCase(unittest.HomeserverTestCase):
"""
# Deactivate the user.
+ self._deactivate_user("@user:test")
+
+ # Attempt to reactivate the user (without a password).
channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
- content=json.dumps({"deactivated": True}).encode(encoding="utf_8"),
+ content={"deactivated": False},
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Reactivate the user.
+ channel = self.make_request(
+ "PUT",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ content={"deactivated": False, "password": "foo"},
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertFalse(channel.json_body["deactivated"])
+ self.assertIsNotNone(channel.json_body["password_hash"])
self._is_erased("@user:test", False)
- d = self.store.mark_user_erased("@user:test")
- self.assertIsNone(self.get_success(d))
- self._is_erased("@user:test", True)
- # Attempt to reactivate the user (without a password).
+ @override_config({"password_config": {"localdb_enabled": False}})
+ def test_reactivate_user_localdb_disabled(self):
+ """
+ Test reactivating another user when using SSO.
+ """
+
+ # Deactivate the user.
+ self._deactivate_user("@user:test")
+
+ # Reactivate the user with a password
channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
- content=json.dumps({"deactivated": False}).encode(encoding="utf_8"),
+ content={"deactivated": False, "password": "foo"},
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
- # Reactivate the user.
+ # Reactivate the user without a password.
channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
- content=json.dumps({"deactivated": False, "password": "foo"}).encode(
- encoding="utf_8"
- ),
+ content={"deactivated": False},
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertFalse(channel.json_body["deactivated"])
+ self.assertIsNone(channel.json_body["password_hash"])
+ self._is_erased("@user:test", False)
- # Get user
+ @override_config({"password_config": {"enabled": False}})
+ def test_reactivate_user_password_disabled(self):
+ """
+ Test reactivating another user when using SSO.
+ """
+
+ # Deactivate the user.
+ self._deactivate_user("@user:test")
+
+ # Reactivate the user with a password
channel = self.make_request(
- "GET",
+ "PUT",
self.url_other_user,
access_token=self.admin_user_tok,
+ content={"deactivated": False, "password": "foo"},
)
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+ # Reactivate the user without a password.
+ channel = self.make_request(
+ "PUT",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ content={"deactivated": False},
+ )
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
- self.assertEqual(False, channel.json_body["deactivated"])
+ self.assertFalse(channel.json_body["deactivated"])
+ self.assertIsNone(channel.json_body["password_hash"])
self._is_erased("@user:test", False)
def test_set_user_as_admin(self):
@@ -1570,18 +1740,16 @@ class UserRestTestCase(unittest.HomeserverTestCase):
"""
# Set a user as an admin
- body = json.dumps({"admin": True})
-
channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
- content=body.encode(encoding="utf_8"),
+ content={"admin": True},
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
- self.assertEqual(True, channel.json_body["admin"])
+ self.assertTrue(channel.json_body["admin"])
# Get user
channel = self.make_request(
@@ -1592,7 +1760,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
- self.assertEqual(True, channel.json_body["admin"])
+ self.assertTrue(channel.json_body["admin"])
def test_accidental_deactivation_prevention(self):
"""
@@ -1602,13 +1770,11 @@ class UserRestTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v2/users/@bob:test"
# Create user
- body = json.dumps({"password": "abc123"})
-
channel = self.make_request(
"PUT",
url,
access_token=self.admin_user_tok,
- content=body.encode(encoding="utf_8"),
+ content={"password": "abc123"},
)
self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
@@ -1628,13 +1794,11 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(0, channel.json_body["deactivated"])
# Change password (and use a str for deactivate instead of a bool)
- body = json.dumps({"password": "abc123", "deactivated": "false"}) # oops!
-
channel = self.make_request(
"PUT",
url,
access_token=self.admin_user_tok,
- content=body.encode(encoding="utf_8"),
+ content={"password": "abc123", "deactivated": "false"},
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
@@ -1653,7 +1817,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Ensure they're still alive
self.assertEqual(0, channel.json_body["deactivated"])
- def _is_erased(self, user_id, expect):
+ def _is_erased(self, user_id: str, expect: bool) -> None:
"""Assert that the user is erased or not"""
d = self.store.is_user_erased(user_id)
if expect:
@@ -1661,6 +1825,24 @@ class UserRestTestCase(unittest.HomeserverTestCase):
else:
self.assertFalse(self.get_success(d))
+ def _deactivate_user(self, user_id: str) -> None:
+ """Deactivate user and set as erased"""
+
+ # Deactivate the user.
+ channel = self.make_request(
+ "PUT",
+ "/_synapse/admin/v2/users/%s" % urllib.parse.quote(user_id),
+ access_token=self.admin_user_tok,
+ content={"deactivated": True},
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertTrue(channel.json_body["deactivated"])
+ self.assertIsNone(channel.json_body["password_hash"])
+ self._is_erased(user_id, False)
+ d = self.store.mark_user_erased(user_id)
+ self.assertIsNone(self.get_success(d))
+ self._is_erased(user_id, True)
+
class UserMembershipRestTestCase(unittest.HomeserverTestCase):
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index 227fffab58..bf39014277 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -161,6 +161,68 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
ev = channel.json_body
self.assertEqual(ev["content"]["x"], "y")
+ def test_message_edit(self):
+ """Ensure that the module doesn't cause issues with edited messages."""
+ # first patch the event checker so that it will modify the event
+ async def check(ev: EventBase, state):
+ d = ev.get_dict()
+ d["content"] = {
+ "msgtype": "m.text",
+ "body": d["content"]["body"].upper(),
+ }
+ return d
+
+ current_rules_module().check_event_allowed = check
+
+ # Send an event, then edit it.
+ channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/rooms/%s/send/modifyme/1" % self.room_id,
+ {
+ "msgtype": "m.text",
+ "body": "Original body",
+ },
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.result["code"], b"200", channel.result)
+ orig_event_id = channel.json_body["event_id"]
+
+ channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/rooms/%s/send/m.room.message/2" % self.room_id,
+ {
+ "m.new_content": {"msgtype": "m.text", "body": "Edited body"},
+ "m.relates_to": {
+ "rel_type": "m.replace",
+ "event_id": orig_event_id,
+ },
+ "msgtype": "m.text",
+ "body": "Edited body",
+ },
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.result["code"], b"200", channel.result)
+ edited_event_id = channel.json_body["event_id"]
+
+ # ... and check that they both got modified
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, orig_event_id),
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.result["code"], b"200", channel.result)
+ ev = channel.json_body
+ self.assertEqual(ev["content"]["body"], "ORIGINAL BODY")
+
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, edited_event_id),
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.result["code"], b"200", channel.result)
+ ev = channel.json_body
+ self.assertEqual(ev["content"]["body"], "EDITED BODY")
+
def test_send_event(self):
"""Tests that the module can send an event into a room via the module api"""
content = {
diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py
index 9734a2159a..ed433d9333 100644
--- a/tests/rest/client/v2_alpha/test_auth.py
+++ b/tests/rest/client/v2_alpha/test_auth.py
@@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Union
+from typing import Optional, Union
from twisted.internet.defer import succeed
@@ -74,7 +74,10 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
return channel
def recaptcha(
- self, session: str, expected_post_response: int, post_session: str = None
+ self,
+ session: str,
+ expected_post_response: int,
+ post_session: Optional[str] = None,
) -> None:
"""Get and respond to a fallback recaptcha. Returns the second request."""
if post_session is None:
diff --git a/tests/rest/client/v2_alpha/test_capabilities.py b/tests/rest/client/v2_alpha/test_capabilities.py
index e808339fb3..287a1a485c 100644
--- a/tests/rest/client/v2_alpha/test_capabilities.py
+++ b/tests/rest/client/v2_alpha/test_capabilities.py
@@ -18,6 +18,7 @@ from synapse.rest.client.v1 import login
from synapse.rest.client.v2_alpha import capabilities
from tests import unittest
+from tests.unittest import override_config
class CapabilitiesTestCase(unittest.HomeserverTestCase):
@@ -33,6 +34,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
hs = self.setup_test_homeserver()
self.store = hs.get_datastore()
self.config = hs.config
+ self.auth_handler = hs.get_auth_handler()
return hs
def test_check_auth_required(self):
@@ -56,7 +58,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
capabilities["m.room_versions"]["default"],
)
- def test_get_change_password_capabilities(self):
+ def test_get_change_password_capabilities_password_login(self):
localpart = "user"
password = "pass"
user = self.register_user(localpart, password)
@@ -66,10 +68,36 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
capabilities = channel.json_body["capabilities"]
self.assertEqual(channel.code, 200)
-
- # Test case where password is handled outside of Synapse
self.assertTrue(capabilities["m.change_password"]["enabled"])
- self.get_success(self.store.user_set_password_hash(user, None))
+
+ @override_config({"password_config": {"localdb_enabled": False}})
+ def test_get_change_password_capabilities_localdb_disabled(self):
+ localpart = "user"
+ password = "pass"
+ user = self.register_user(localpart, password)
+ access_token = self.get_success(
+ self.auth_handler.get_access_token_for_user_id(
+ user, device_id=None, valid_until_ms=None
+ )
+ )
+
+ channel = self.make_request("GET", self.url, access_token=access_token)
+ capabilities = channel.json_body["capabilities"]
+
+ self.assertEqual(channel.code, 200)
+ self.assertFalse(capabilities["m.change_password"]["enabled"])
+
+ @override_config({"password_config": {"enabled": False}})
+ def test_get_change_password_capabilities_password_disabled(self):
+ localpart = "user"
+ password = "pass"
+ user = self.register_user(localpart, password)
+ access_token = self.get_success(
+ self.auth_handler.get_access_token_for_user_id(
+ user, device_id=None, valid_until_ms=None
+ )
+ )
+
channel = self.make_request("GET", self.url, access_token=access_token)
capabilities = channel.json_body["capabilities"]
diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py
index 7c457754f1..e7bb5583fc 100644
--- a/tests/rest/client/v2_alpha/test_relations.py
+++ b/tests/rest/client/v2_alpha/test_relations.py
@@ -39,6 +39,11 @@ class RelationsTestCase(unittest.HomeserverTestCase):
# We need to enable msc1849 support for aggregations
config = self.default_config()
config["experimental_msc1849_support_enabled"] = True
+
+ # We enable frozen dicts as relations/edits change event contents, so we
+ # want to test that we don't modify the events in the caches.
+ config["use_frozen_dicts"] = True
+
return self.setup_test_homeserver(config=config)
def prepare(self, reactor, clock, hs):
@@ -518,6 +523,63 @@ class RelationsTestCase(unittest.HomeserverTestCase):
{"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict
)
+ def test_edit_reply(self):
+ """Test that editing a reply works."""
+
+ # Create a reply to edit.
+ channel = self._send_relation(
+ RelationTypes.REFERENCE,
+ "m.room.message",
+ content={"msgtype": "m.text", "body": "A reply!"},
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ reply = channel.json_body["event_id"]
+
+ new_body = {"msgtype": "m.text", "body": "I've been edited!"}
+ channel = self._send_relation(
+ RelationTypes.REPLACE,
+ "m.room.message",
+ content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
+ parent_id=reply,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ edit_event_id = channel.json_body["event_id"]
+
+ channel = self.make_request(
+ "GET",
+ "/rooms/%s/event/%s" % (self.room, reply),
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ # We expect to see the new body in the dict, as well as the reference
+ # metadata sill intact.
+ self.assertDictContainsSubset(new_body, channel.json_body["content"])
+ self.assertDictContainsSubset(
+ {
+ "m.relates_to": {
+ "event_id": self.parent_id,
+ "key": None,
+ "rel_type": "m.reference",
+ }
+ },
+ channel.json_body["content"],
+ )
+
+ # We expect that the edit relation appears in the unsigned relations
+ # section.
+ relations_dict = channel.json_body["unsigned"].get("m.relations")
+ self.assertIn(RelationTypes.REPLACE, relations_dict)
+
+ m_replace_dict = relations_dict[RelationTypes.REPLACE]
+ for key in ["event_id", "sender", "origin_server_ts"]:
+ self.assertIn(key, m_replace_dict)
+
+ self.assert_dict(
+ {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict
+ )
+
def test_relations_redaction_redacts_edits(self):
"""Test that edits of an event are redacted when the original event
is redacted.
diff --git a/tests/server.py b/tests/server.py
index 2287d20076..b535a5d886 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -2,7 +2,7 @@ import json
import logging
from collections import deque
from io import SEEK_END, BytesIO
-from typing import Callable, Iterable, MutableMapping, Optional, Tuple, Union
+from typing import Callable, Dict, Iterable, MutableMapping, Optional, Tuple, Union
import attr
from typing_extensions import Deque
@@ -13,8 +13,11 @@ from twisted.internet._resolver import SimpleResolverComplexifier
from twisted.internet.defer import Deferred, fail, succeed
from twisted.internet.error import DNSLookupError
from twisted.internet.interfaces import (
+ IHostnameResolver,
+ IProtocol,
+ IPullProducer,
+ IPushProducer,
IReactorPluggableNameResolver,
- IReactorTCP,
IResolverSimple,
ITransport,
)
@@ -45,11 +48,11 @@ class FakeChannel:
wire).
"""
- site = attr.ib(type=Site)
+ site = attr.ib(type=Union[Site, "FakeSite"])
_reactor = attr.ib()
result = attr.ib(type=dict, default=attr.Factory(dict))
_ip = attr.ib(type=str, default="127.0.0.1")
- _producer = None
+ _producer = None # type: Optional[Union[IPullProducer, IPushProducer]]
@property
def json_body(self):
@@ -159,7 +162,11 @@ class FakeChannel:
Any cookines found are added to the given dict
"""
- for h in self.headers.getRawHeaders("Set-Cookie"):
+ headers = self.headers.getRawHeaders("Set-Cookie")
+ if not headers:
+ return
+
+ for h in headers:
parts = h.split(";")
k, v = parts[0].split("=", maxsplit=1)
cookies[k] = v
@@ -311,8 +318,8 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
self._tcp_callbacks = {}
self._udp = []
- lookups = self.lookups = {}
- self._thread_callbacks = deque() # type: Deque[Callable[[], None]]()
+ lookups = self.lookups = {} # type: Dict[str, str]
+ self._thread_callbacks = deque() # type: Deque[Callable[[], None]]
@implementer(IResolverSimple)
class FakeResolver:
@@ -324,6 +331,9 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
self.nameResolver = SimpleResolverComplexifier(FakeResolver())
super().__init__()
+ def installNameResolver(self, resolver: IHostnameResolver) -> IHostnameResolver:
+ raise NotImplementedError()
+
def listenUDP(self, port, protocol, interface="", maxPacketSize=8196):
p = udp.Port(port, protocol, interface, maxPacketSize, self)
p.startListening()
@@ -593,7 +603,7 @@ class FakeTransport:
if self.disconnected:
return
- if getattr(self.other, "transport") is None:
+ if not hasattr(self.other, "transport"):
# the other has no transport yet; reschedule
if self.autoflush:
self._reactor.callLater(0.0, self.flush)
@@ -621,7 +631,9 @@ class FakeTransport:
self.disconnected = True
-def connect_client(reactor: IReactorTCP, client_id: int) -> AccumulatingProtocol:
+def connect_client(
+ reactor: ThreadedMemoryReactorClock, client_id: int
+) -> Tuple[IProtocol, AccumulatingProtocol]:
"""
Connect a client to a fake TCP transport.
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index dabc1c5f09..ef4cf8d0f1 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2016 OpenMarket Ltd
+# Copyright 2016-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,32 +13,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
-
import synapse.api.errors
-import tests.unittest
-import tests.utils
-
-
-class DeviceStoreTestCase(tests.unittest.TestCase):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.store = None # type: synapse.storage.DataStore
+from tests.unittest import HomeserverTestCase
- @defer.inlineCallbacks
- def setUp(self):
- hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
+class DeviceStoreTestCase(HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
- @defer.inlineCallbacks
def test_store_new_device(self):
- yield defer.ensureDeferred(
+ self.get_success(
self.store.store_device("user_id", "device_id", "display_name")
)
- res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
+ res = self.get_success(self.store.get_device("user_id", "device_id"))
self.assertDictContainsSubset(
{
"user_id": "user_id",
@@ -48,19 +37,18 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
res,
)
- @defer.inlineCallbacks
def test_get_devices_by_user(self):
- yield defer.ensureDeferred(
+ self.get_success(
self.store.store_device("user_id", "device1", "display_name 1")
)
- yield defer.ensureDeferred(
+ self.get_success(
self.store.store_device("user_id", "device2", "display_name 2")
)
- yield defer.ensureDeferred(
+ self.get_success(
self.store.store_device("user_id2", "device3", "display_name 3")
)
- res = yield defer.ensureDeferred(self.store.get_devices_by_user("user_id"))
+ res = self.get_success(self.store.get_devices_by_user("user_id"))
self.assertEqual(2, len(res.keys()))
self.assertDictContainsSubset(
{
@@ -79,43 +67,41 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
res["device2"],
)
- @defer.inlineCallbacks
def test_count_devices_by_users(self):
- yield defer.ensureDeferred(
+ self.get_success(
self.store.store_device("user_id", "device1", "display_name 1")
)
- yield defer.ensureDeferred(
+ self.get_success(
self.store.store_device("user_id", "device2", "display_name 2")
)
- yield defer.ensureDeferred(
+ self.get_success(
self.store.store_device("user_id2", "device3", "display_name 3")
)
- res = yield defer.ensureDeferred(self.store.count_devices_by_users())
+ res = self.get_success(self.store.count_devices_by_users())
self.assertEqual(0, res)
- res = yield defer.ensureDeferred(self.store.count_devices_by_users(["unknown"]))
+ res = self.get_success(self.store.count_devices_by_users(["unknown"]))
self.assertEqual(0, res)
- res = yield defer.ensureDeferred(self.store.count_devices_by_users(["user_id"]))
+ res = self.get_success(self.store.count_devices_by_users(["user_id"]))
self.assertEqual(2, res)
- res = yield defer.ensureDeferred(
+ res = self.get_success(
self.store.count_devices_by_users(["user_id", "user_id2"])
)
self.assertEqual(3, res)
- @defer.inlineCallbacks
def test_get_device_updates_by_remote(self):
device_ids = ["device_id1", "device_id2"]
# Add two device updates with a single stream_id
- yield defer.ensureDeferred(
+ self.get_success(
self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"])
)
# Get all device updates ever meant for this remote
- now_stream_id, device_updates = yield defer.ensureDeferred(
+ now_stream_id, device_updates = self.get_success(
self.store.get_device_updates_by_remote("somehost", -1, limit=100)
)
@@ -131,37 +117,35 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
}
self.assertEqual(received_device_ids, set(expected_device_ids))
- @defer.inlineCallbacks
def test_update_device(self):
- yield defer.ensureDeferred(
+ self.get_success(
self.store.store_device("user_id", "device_id", "display_name 1")
)
- res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
+ res = self.get_success(self.store.get_device("user_id", "device_id"))
self.assertEqual("display_name 1", res["display_name"])
# do a no-op first
- yield defer.ensureDeferred(self.store.update_device("user_id", "device_id"))
- res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
+ self.get_success(self.store.update_device("user_id", "device_id"))
+ res = self.get_success(self.store.get_device("user_id", "device_id"))
self.assertEqual("display_name 1", res["display_name"])
# do the update
- yield defer.ensureDeferred(
+ self.get_success(
self.store.update_device(
"user_id", "device_id", new_display_name="display_name 2"
)
)
# check it worked
- res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
+ res = self.get_success(self.store.get_device("user_id", "device_id"))
self.assertEqual("display_name 2", res["display_name"])
- @defer.inlineCallbacks
def test_update_unknown_device(self):
- with self.assertRaises(synapse.api.errors.StoreError) as cm:
- yield defer.ensureDeferred(
- self.store.update_device(
- "user_id", "unknown_device_id", new_display_name="display_name 2"
- )
- )
- self.assertEqual(404, cm.exception.code)
+ exc = self.get_failure(
+ self.store.update_device(
+ "user_id", "unknown_device_id", new_display_name="display_name 2"
+ ),
+ synapse.api.errors.StoreError,
+ )
+ self.assertEqual(404, exc.value.code)
diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py
index da93ca3980..0db233fd68 100644
--- a/tests/storage/test_directory.py
+++ b/tests/storage/test_directory.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,28 +13,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from twisted.internet import defer
-
from synapse.types import RoomAlias, RoomID
-from tests import unittest
-from tests.utils import setup_test_homeserver
+from tests.unittest import HomeserverTestCase
-class DirectoryStoreTestCase(unittest.TestCase):
- @defer.inlineCallbacks
- def setUp(self):
- hs = yield setup_test_homeserver(self.addCleanup)
-
+class DirectoryStoreTestCase(HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.room = RoomID.from_string("!abcde:test")
self.alias = RoomAlias.from_string("#my-room:test")
- @defer.inlineCallbacks
def test_room_to_alias(self):
- yield defer.ensureDeferred(
+ self.get_success(
self.store.create_room_alias_association(
room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
)
@@ -42,16 +34,11 @@ class DirectoryStoreTestCase(unittest.TestCase):
self.assertEquals(
["#my-room:test"],
- (
- yield defer.ensureDeferred(
- self.store.get_aliases_for_room(self.room.to_string())
- )
- ),
+ (self.get_success(self.store.get_aliases_for_room(self.room.to_string()))),
)
- @defer.inlineCallbacks
def test_alias_to_room(self):
- yield defer.ensureDeferred(
+ self.get_success(
self.store.create_room_alias_association(
room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
)
@@ -59,28 +46,19 @@ class DirectoryStoreTestCase(unittest.TestCase):
self.assertObjectHasAttributes(
{"room_id": self.room.to_string(), "servers": ["test"]},
- (
- yield defer.ensureDeferred(
- self.store.get_association_from_room_alias(self.alias)
- )
- ),
+ (self.get_success(self.store.get_association_from_room_alias(self.alias))),
)
- @defer.inlineCallbacks
def test_delete_alias(self):
- yield defer.ensureDeferred(
+ self.get_success(
self.store.create_room_alias_association(
room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
)
)
- room_id = yield defer.ensureDeferred(self.store.delete_room_alias(self.alias))
+ room_id = self.get_success(self.store.delete_room_alias(self.alias))
self.assertEqual(self.room.to_string(), room_id)
self.assertIsNone(
- (
- yield defer.ensureDeferred(
- self.store.get_association_from_room_alias(self.alias)
- )
- )
+ (self.get_success(self.store.get_association_from_room_alias(self.alias)))
)
diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py
index 3fc4bb13b6..1e54b940fd 100644
--- a/tests/storage/test_end_to_end_keys.py
+++ b/tests/storage/test_end_to_end_keys.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2016 OpenMarket Ltd
+# Copyright 2016-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,30 +13,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
+from tests.unittest import HomeserverTestCase
-import tests.unittest
-import tests.utils
-
-class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
- @defer.inlineCallbacks
- def setUp(self):
- hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
+class EndToEndKeyStoreTestCase(HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
- @defer.inlineCallbacks
def test_key_without_device_name(self):
now = 1470174257070
json = {"key": "value"}
- yield defer.ensureDeferred(self.store.store_device("user", "device", None))
+ self.get_success(self.store.store_device("user", "device", None))
- yield defer.ensureDeferred(
- self.store.set_e2e_device_keys("user", "device", now, json)
- )
+ self.get_success(self.store.set_e2e_device_keys("user", "device", now, json))
- res = yield defer.ensureDeferred(
+ res = self.get_success(
self.store.get_e2e_device_keys_for_cs_api((("user", "device"),))
)
self.assertIn("user", res)
@@ -44,38 +36,32 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
dev = res["user"]["device"]
self.assertDictContainsSubset(json, dev)
- @defer.inlineCallbacks
def test_reupload_key(self):
now = 1470174257070
json = {"key": "value"}
- yield defer.ensureDeferred(self.store.store_device("user", "device", None))
+ self.get_success(self.store.store_device("user", "device", None))
- changed = yield defer.ensureDeferred(
+ changed = self.get_success(
self.store.set_e2e_device_keys("user", "device", now, json)
)
self.assertTrue(changed)
# If we try to upload the same key then we should be told nothing
# changed
- changed = yield defer.ensureDeferred(
+ changed = self.get_success(
self.store.set_e2e_device_keys("user", "device", now, json)
)
self.assertFalse(changed)
- @defer.inlineCallbacks
def test_get_key_with_device_name(self):
now = 1470174257070
json = {"key": "value"}
- yield defer.ensureDeferred(
- self.store.set_e2e_device_keys("user", "device", now, json)
- )
- yield defer.ensureDeferred(
- self.store.store_device("user", "device", "display_name")
- )
+ self.get_success(self.store.set_e2e_device_keys("user", "device", now, json))
+ self.get_success(self.store.store_device("user", "device", "display_name"))
- res = yield defer.ensureDeferred(
+ res = self.get_success(
self.store.get_e2e_device_keys_for_cs_api((("user", "device"),))
)
self.assertIn("user", res)
@@ -85,29 +71,28 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
{"key": "value", "unsigned": {"device_display_name": "display_name"}}, dev
)
- @defer.inlineCallbacks
def test_multiple_devices(self):
now = 1470174257070
- yield defer.ensureDeferred(self.store.store_device("user1", "device1", None))
- yield defer.ensureDeferred(self.store.store_device("user1", "device2", None))
- yield defer.ensureDeferred(self.store.store_device("user2", "device1", None))
- yield defer.ensureDeferred(self.store.store_device("user2", "device2", None))
+ self.get_success(self.store.store_device("user1", "device1", None))
+ self.get_success(self.store.store_device("user1", "device2", None))
+ self.get_success(self.store.store_device("user2", "device1", None))
+ self.get_success(self.store.store_device("user2", "device2", None))
- yield defer.ensureDeferred(
+ self.get_success(
self.store.set_e2e_device_keys("user1", "device1", now, {"key": "json11"})
)
- yield defer.ensureDeferred(
+ self.get_success(
self.store.set_e2e_device_keys("user1", "device2", now, {"key": "json12"})
)
- yield defer.ensureDeferred(
+ self.get_success(
self.store.set_e2e_device_keys("user2", "device1", now, {"key": "json21"})
)
- yield defer.ensureDeferred(
+ self.get_success(
self.store.set_e2e_device_keys("user2", "device2", now, {"key": "json22"})
)
- res = yield defer.ensureDeferred(
+ res = self.get_success(
self.store.get_e2e_device_keys_for_cs_api(
(("user1", "device1"), ("user2", "device2"))
)
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index 485f1ee033..239f7c9faf 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2016 OpenMarket Ltd
+# Copyright 2016-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,10 +15,7 @@
from mock import Mock
-from twisted.internet import defer
-
-import tests.unittest
-import tests.utils
+from tests.unittest import HomeserverTestCase
USER_ID = "@user:example.com"
@@ -30,37 +27,31 @@ HIGHLIGHT = [
]
-class EventPushActionsStoreTestCase(tests.unittest.TestCase):
- @defer.inlineCallbacks
- def setUp(self):
- hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
+class EventPushActionsStoreTestCase(HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.persist_events_store = hs.get_datastores().persist_events
- @defer.inlineCallbacks
def test_get_unread_push_actions_for_user_in_range_for_http(self):
- yield defer.ensureDeferred(
+ self.get_success(
self.store.get_unread_push_actions_for_user_in_range_for_http(
USER_ID, 0, 1000, 20
)
)
- @defer.inlineCallbacks
def test_get_unread_push_actions_for_user_in_range_for_email(self):
- yield defer.ensureDeferred(
+ self.get_success(
self.store.get_unread_push_actions_for_user_in_range_for_email(
USER_ID, 0, 1000, 20
)
)
- @defer.inlineCallbacks
def test_count_aggregation(self):
room_id = "!foo:example.com"
user_id = "@user1235:example.com"
- @defer.inlineCallbacks
def _assert_counts(noitf_count, highlight_count):
- counts = yield defer.ensureDeferred(
+ counts = self.get_success(
self.store.db_pool.runInteraction(
"", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0
)
@@ -74,7 +65,6 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
},
)
- @defer.inlineCallbacks
def _inject_actions(stream, action):
event = Mock()
event.room_id = room_id
@@ -82,14 +72,14 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
event.internal_metadata.stream_ordering = stream
event.depth = stream
- yield defer.ensureDeferred(
+ self.get_success(
self.store.add_push_actions_to_staging(
event.event_id,
{user_id: action},
False,
)
)
- yield defer.ensureDeferred(
+ self.get_success(
self.store.db_pool.runInteraction(
"",
self.persist_events_store._set_push_actions_for_event_and_users_txn,
@@ -99,14 +89,14 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
)
def _rotate(stream):
- return defer.ensureDeferred(
+ self.get_success(
self.store.db_pool.runInteraction(
"", self.store._rotate_notifs_before_txn, stream
)
)
def _mark_read(stream, depth):
- return defer.ensureDeferred(
+ self.get_success(
self.store.db_pool.runInteraction(
"",
self.store._remove_old_push_actions_before_txn,
@@ -116,49 +106,48 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
)
)
- yield _assert_counts(0, 0)
- yield _inject_actions(1, PlAIN_NOTIF)
- yield _assert_counts(1, 0)
- yield _rotate(2)
- yield _assert_counts(1, 0)
+ _assert_counts(0, 0)
+ _inject_actions(1, PlAIN_NOTIF)
+ _assert_counts(1, 0)
+ _rotate(2)
+ _assert_counts(1, 0)
- yield _inject_actions(3, PlAIN_NOTIF)
- yield _assert_counts(2, 0)
- yield _rotate(4)
- yield _assert_counts(2, 0)
+ _inject_actions(3, PlAIN_NOTIF)
+ _assert_counts(2, 0)
+ _rotate(4)
+ _assert_counts(2, 0)
- yield _inject_actions(5, PlAIN_NOTIF)
- yield _mark_read(3, 3)
- yield _assert_counts(1, 0)
+ _inject_actions(5, PlAIN_NOTIF)
+ _mark_read(3, 3)
+ _assert_counts(1, 0)
- yield _mark_read(5, 5)
- yield _assert_counts(0, 0)
+ _mark_read(5, 5)
+ _assert_counts(0, 0)
- yield _inject_actions(6, PlAIN_NOTIF)
- yield _rotate(7)
+ _inject_actions(6, PlAIN_NOTIF)
+ _rotate(7)
- yield defer.ensureDeferred(
+ self.get_success(
self.store.db_pool.simple_delete(
table="event_push_actions", keyvalues={"1": 1}, desc=""
)
)
- yield _assert_counts(1, 0)
+ _assert_counts(1, 0)
- yield _mark_read(7, 7)
- yield _assert_counts(0, 0)
+ _mark_read(7, 7)
+ _assert_counts(0, 0)
- yield _inject_actions(8, HIGHLIGHT)
- yield _assert_counts(1, 1)
- yield _rotate(9)
- yield _assert_counts(1, 1)
- yield _rotate(10)
- yield _assert_counts(1, 1)
+ _inject_actions(8, HIGHLIGHT)
+ _assert_counts(1, 1)
+ _rotate(9)
+ _assert_counts(1, 1)
+ _rotate(10)
+ _assert_counts(1, 1)
- @defer.inlineCallbacks
def test_find_first_stream_ordering_after_ts(self):
def add_event(so, ts):
- return defer.ensureDeferred(
+ self.get_success(
self.store.db_pool.simple_insert(
"events",
{
@@ -177,24 +166,16 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
)
# start with the base case where there are no events in the table
- r = yield defer.ensureDeferred(
- self.store.find_first_stream_ordering_after_ts(11)
- )
+ r = self.get_success(self.store.find_first_stream_ordering_after_ts(11))
self.assertEqual(r, 0)
# now with one event
- yield add_event(2, 10)
- r = yield defer.ensureDeferred(
- self.store.find_first_stream_ordering_after_ts(9)
- )
+ add_event(2, 10)
+ r = self.get_success(self.store.find_first_stream_ordering_after_ts(9))
self.assertEqual(r, 2)
- r = yield defer.ensureDeferred(
- self.store.find_first_stream_ordering_after_ts(10)
- )
+ r = self.get_success(self.store.find_first_stream_ordering_after_ts(10))
self.assertEqual(r, 2)
- r = yield defer.ensureDeferred(
- self.store.find_first_stream_ordering_after_ts(11)
- )
+ r = self.get_success(self.store.find_first_stream_ordering_after_ts(11))
self.assertEqual(r, 3)
# add a bunch of dummy events to the events table
@@ -205,39 +186,27 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
(10, 130),
(20, 140),
):
- yield add_event(stream_ordering, ts)
+ add_event(stream_ordering, ts)
- r = yield defer.ensureDeferred(
- self.store.find_first_stream_ordering_after_ts(110)
- )
+ r = self.get_success(self.store.find_first_stream_ordering_after_ts(110))
self.assertEqual(r, 3, "First event after 110ms should be 3, was %i" % r)
# 4 and 5 are both after 120: we want 4 rather than 5
- r = yield defer.ensureDeferred(
- self.store.find_first_stream_ordering_after_ts(120)
- )
+ r = self.get_success(self.store.find_first_stream_ordering_after_ts(120))
self.assertEqual(r, 4, "First event after 120ms should be 4, was %i" % r)
- r = yield defer.ensureDeferred(
- self.store.find_first_stream_ordering_after_ts(129)
- )
+ r = self.get_success(self.store.find_first_stream_ordering_after_ts(129))
self.assertEqual(r, 10, "First event after 129ms should be 10, was %i" % r)
# check we can get the last event
- r = yield defer.ensureDeferred(
- self.store.find_first_stream_ordering_after_ts(140)
- )
+ r = self.get_success(self.store.find_first_stream_ordering_after_ts(140))
self.assertEqual(r, 20, "First event after 14ms should be 20, was %i" % r)
# off the end
- r = yield defer.ensureDeferred(
- self.store.find_first_stream_ordering_after_ts(160)
- )
+ r = self.get_success(self.store.find_first_stream_ordering_after_ts(160))
self.assertEqual(r, 21)
# check we can find an event at ordering zero
- yield add_event(0, 5)
- r = yield defer.ensureDeferred(
- self.store.find_first_stream_ordering_after_ts(1)
- )
+ add_event(0, 5)
+ r = self.get_success(self.store.find_first_stream_ordering_after_ts(1))
self.assertEqual(r, 0)
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index ea63bd56b4..d18ceb41a9 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,59 +13,50 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from twisted.internet import defer
-
from synapse.types import UserID
from tests import unittest
-from tests.utils import setup_test_homeserver
-
-class ProfileStoreTestCase(unittest.TestCase):
- @defer.inlineCallbacks
- def setUp(self):
- hs = yield setup_test_homeserver(self.addCleanup)
+class ProfileStoreTestCase(unittest.HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.u_frank = UserID.from_string("@frank:test")
- @defer.inlineCallbacks
def test_displayname(self):
- yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
+ self.get_success(self.store.create_profile(self.u_frank.localpart))
- yield defer.ensureDeferred(
+ self.get_success(
self.store.set_profile_displayname(self.u_frank.localpart, "Frank")
)
self.assertEquals(
"Frank",
(
- yield defer.ensureDeferred(
+ self.get_success(
self.store.get_profile_displayname(self.u_frank.localpart)
)
),
)
# test set to None
- yield defer.ensureDeferred(
+ self.get_success(
self.store.set_profile_displayname(self.u_frank.localpart, None)
)
self.assertIsNone(
(
- yield defer.ensureDeferred(
+ self.get_success(
self.store.get_profile_displayname(self.u_frank.localpart)
)
)
)
- @defer.inlineCallbacks
def test_avatar_url(self):
- yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
+ self.get_success(self.store.create_profile(self.u_frank.localpart))
- yield defer.ensureDeferred(
+ self.get_success(
self.store.set_profile_avatar_url(
self.u_frank.localpart, "http://my.site/here"
)
@@ -74,20 +65,20 @@ class ProfileStoreTestCase(unittest.TestCase):
self.assertEquals(
"http://my.site/here",
(
- yield defer.ensureDeferred(
+ self.get_success(
self.store.get_profile_avatar_url(self.u_frank.localpart)
)
),
)
# test set to None
- yield defer.ensureDeferred(
+ self.get_success(
self.store.set_profile_avatar_url(self.u_frank.localpart, None)
)
self.assertIsNone(
(
- yield defer.ensureDeferred(
+ self.get_success(
self.store.get_profile_avatar_url(self.u_frank.localpart)
)
)
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index b2a0e60856..2622207639 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -1,6 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,8 +15,6 @@
from canonicaljson import json
-from twisted.internet import defer
-
from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions
from synapse.types import RoomID, UserID
@@ -230,10 +227,9 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self._base_builder = base_builder
self._event_id = event_id
- @defer.inlineCallbacks
- def build(self, prev_event_ids, auth_event_ids):
- built_event = yield defer.ensureDeferred(
- self._base_builder.build(prev_event_ids, auth_event_ids)
+ async def build(self, prev_event_ids, auth_event_ids):
+ built_event = await self._base_builder.build(
+ prev_event_ids, auth_event_ids
)
built_event._event_id = self._event_id
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 4eb41c46e8..c82cf15bc2 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,21 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from twisted.internet import defer
-
from synapse.api.constants import UserTypes
from synapse.api.errors import ThreepidValidationError
-from tests import unittest
-from tests.utils import setup_test_homeserver
-
+from tests.unittest import HomeserverTestCase
-class RegistrationStoreTestCase(unittest.TestCase):
- @defer.inlineCallbacks
- def setUp(self):
- hs = yield setup_test_homeserver(self.addCleanup)
+class RegistrationStoreTestCase(HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.user_id = "@my-user:test"
@@ -35,9 +28,8 @@ class RegistrationStoreTestCase(unittest.TestCase):
self.pwhash = "{xx1}123456789"
self.device_id = "akgjhdjklgshg"
- @defer.inlineCallbacks
def test_register(self):
- yield defer.ensureDeferred(self.store.register_user(self.user_id, self.pwhash))
+ self.get_success(self.store.register_user(self.user_id, self.pwhash))
self.assertEquals(
{
@@ -49,93 +41,81 @@ class RegistrationStoreTestCase(unittest.TestCase):
"consent_version": None,
"consent_server_notice_sent": None,
"appservice_id": None,
- "creation_ts": 1000,
+ "creation_ts": 0,
"user_type": None,
"deactivated": 0,
"shadow_banned": 0,
},
- (yield defer.ensureDeferred(self.store.get_user_by_id(self.user_id))),
+ (self.get_success(self.store.get_user_by_id(self.user_id))),
)
- @defer.inlineCallbacks
def test_add_tokens(self):
- yield defer.ensureDeferred(self.store.register_user(self.user_id, self.pwhash))
- yield defer.ensureDeferred(
+ self.get_success(self.store.register_user(self.user_id, self.pwhash))
+ self.get_success(
self.store.add_access_token_to_user(
self.user_id, self.tokens[1], self.device_id, valid_until_ms=None
)
)
- result = yield defer.ensureDeferred(
- self.store.get_user_by_access_token(self.tokens[1])
- )
+ result = self.get_success(self.store.get_user_by_access_token(self.tokens[1]))
self.assertEqual(result.user_id, self.user_id)
self.assertEqual(result.device_id, self.device_id)
self.assertIsNotNone(result.token_id)
- @defer.inlineCallbacks
def test_user_delete_access_tokens(self):
# add some tokens
- yield defer.ensureDeferred(self.store.register_user(self.user_id, self.pwhash))
- yield defer.ensureDeferred(
+ self.get_success(self.store.register_user(self.user_id, self.pwhash))
+ self.get_success(
self.store.add_access_token_to_user(
self.user_id, self.tokens[0], device_id=None, valid_until_ms=None
)
)
- yield defer.ensureDeferred(
+ self.get_success(
self.store.add_access_token_to_user(
self.user_id, self.tokens[1], self.device_id, valid_until_ms=None
)
)
# now delete some
- yield defer.ensureDeferred(
+ self.get_success(
self.store.user_delete_access_tokens(self.user_id, device_id=self.device_id)
)
# check they were deleted
- user = yield defer.ensureDeferred(
- self.store.get_user_by_access_token(self.tokens[1])
- )
+ user = self.get_success(self.store.get_user_by_access_token(self.tokens[1]))
self.assertIsNone(user, "access token was not deleted by device_id")
# check the one not associated with the device was not deleted
- user = yield defer.ensureDeferred(
- self.store.get_user_by_access_token(self.tokens[0])
- )
+ user = self.get_success(self.store.get_user_by_access_token(self.tokens[0]))
self.assertEqual(self.user_id, user.user_id)
# now delete the rest
- yield defer.ensureDeferred(self.store.user_delete_access_tokens(self.user_id))
+ self.get_success(self.store.user_delete_access_tokens(self.user_id))
- user = yield defer.ensureDeferred(
- self.store.get_user_by_access_token(self.tokens[0])
- )
+ user = self.get_success(self.store.get_user_by_access_token(self.tokens[0]))
self.assertIsNone(user, "access token was not deleted without device_id")
- @defer.inlineCallbacks
def test_is_support_user(self):
TEST_USER = "@test:test"
SUPPORT_USER = "@support:test"
- res = yield defer.ensureDeferred(self.store.is_support_user(None))
+ res = self.get_success(self.store.is_support_user(None))
self.assertFalse(res)
- yield defer.ensureDeferred(
+ self.get_success(
self.store.register_user(user_id=TEST_USER, password_hash=None)
)
- res = yield defer.ensureDeferred(self.store.is_support_user(TEST_USER))
+ res = self.get_success(self.store.is_support_user(TEST_USER))
self.assertFalse(res)
- yield defer.ensureDeferred(
+ self.get_success(
self.store.register_user(
user_id=SUPPORT_USER, password_hash=None, user_type=UserTypes.SUPPORT
)
)
- res = yield defer.ensureDeferred(self.store.is_support_user(SUPPORT_USER))
+ res = self.get_success(self.store.is_support_user(SUPPORT_USER))
self.assertTrue(res)
- @defer.inlineCallbacks
def test_3pid_inhibit_invalid_validation_session_error(self):
"""Tests that enabling the configuration option to inhibit 3PID errors on
/requestToken also inhibits validation errors caused by an unknown session ID.
@@ -143,30 +123,28 @@ class RegistrationStoreTestCase(unittest.TestCase):
# Check that, with the config setting set to false (the default value), a
# validation error is caused by the unknown session ID.
- try:
- yield defer.ensureDeferred(
- self.store.validate_threepid_session(
- "fake_sid",
- "fake_client_secret",
- "fake_token",
- 0,
- )
- )
- except ThreepidValidationError as e:
- self.assertEquals(e.msg, "Unknown session_id", e)
+ e = self.get_failure(
+ self.store.validate_threepid_session(
+ "fake_sid",
+ "fake_client_secret",
+ "fake_token",
+ 0,
+ ),
+ ThreepidValidationError,
+ )
+ self.assertEquals(e.value.msg, "Unknown session_id", e)
# Set the config setting to true.
self.store._ignore_unknown_session_error = True
# Check that now the validation error is caused by the token not matching.
- try:
- yield defer.ensureDeferred(
- self.store.validate_threepid_session(
- "fake_sid",
- "fake_client_secret",
- "fake_token",
- 0,
- )
- )
- except ThreepidValidationError as e:
- self.assertEquals(e.msg, "Validation token not found or has expired", e)
+ e = self.get_failure(
+ self.store.validate_threepid_session(
+ "fake_sid",
+ "fake_client_secret",
+ "fake_token",
+ 0,
+ ),
+ ThreepidValidationError,
+ )
+ self.assertEquals(e.value.msg, "Validation token not found or has expired", e)
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index bc8400f240..0089d33c93 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,22 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from twisted.internet import defer
-
from synapse.api.constants import EventTypes
from synapse.api.room_versions import RoomVersions
from synapse.types import RoomAlias, RoomID, UserID
-from tests import unittest
-from tests.utils import setup_test_homeserver
-
+from tests.unittest import HomeserverTestCase
-class RoomStoreTestCase(unittest.TestCase):
- @defer.inlineCallbacks
- def setUp(self):
- hs = yield setup_test_homeserver(self.addCleanup)
+class RoomStoreTestCase(HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
# We can't test RoomStore on its own without the DirectoryStore, for
# management of the 'room_aliases' table
self.store = hs.get_datastore()
@@ -37,7 +30,7 @@ class RoomStoreTestCase(unittest.TestCase):
self.alias = RoomAlias.from_string("#a-room-name:test")
self.u_creator = UserID.from_string("@creator:test")
- yield defer.ensureDeferred(
+ self.get_success(
self.store.store_room(
self.room.to_string(),
room_creator_user_id=self.u_creator.to_string(),
@@ -46,7 +39,6 @@ class RoomStoreTestCase(unittest.TestCase):
)
)
- @defer.inlineCallbacks
def test_get_room(self):
self.assertDictContainsSubset(
{
@@ -54,16 +46,12 @@ class RoomStoreTestCase(unittest.TestCase):
"creator": self.u_creator.to_string(),
"is_public": True,
},
- (yield defer.ensureDeferred(self.store.get_room(self.room.to_string()))),
+ (self.get_success(self.store.get_room(self.room.to_string()))),
)
- @defer.inlineCallbacks
def test_get_room_unknown_room(self):
- self.assertIsNone(
- (yield defer.ensureDeferred(self.store.get_room("!uknown:test")))
- )
+ self.assertIsNone((self.get_success(self.store.get_room("!uknown:test"))))
- @defer.inlineCallbacks
def test_get_room_with_stats(self):
self.assertDictContainsSubset(
{
@@ -71,29 +59,17 @@ class RoomStoreTestCase(unittest.TestCase):
"creator": self.u_creator.to_string(),
"public": True,
},
- (
- yield defer.ensureDeferred(
- self.store.get_room_with_stats(self.room.to_string())
- )
- ),
+ (self.get_success(self.store.get_room_with_stats(self.room.to_string()))),
)
- @defer.inlineCallbacks
def test_get_room_with_stats_unknown_room(self):
self.assertIsNone(
- (
- yield defer.ensureDeferred(
- self.store.get_room_with_stats("!uknown:test")
- )
- ),
+ (self.get_success(self.store.get_room_with_stats("!uknown:test"))),
)
-class RoomEventsStoreTestCase(unittest.TestCase):
- @defer.inlineCallbacks
- def setUp(self):
- hs = setup_test_homeserver(self.addCleanup)
-
+class RoomEventsStoreTestCase(HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
# Room events need the full datastore, for persist_event() and
# get_room_state()
self.store = hs.get_datastore()
@@ -102,7 +78,7 @@ class RoomEventsStoreTestCase(unittest.TestCase):
self.room = RoomID.from_string("!abcde:test")
- yield defer.ensureDeferred(
+ self.get_success(
self.store.store_room(
self.room.to_string(),
room_creator_user_id="@creator:text",
@@ -111,23 +87,21 @@ class RoomEventsStoreTestCase(unittest.TestCase):
)
)
- @defer.inlineCallbacks
def inject_room_event(self, **kwargs):
- yield defer.ensureDeferred(
+ self.get_success(
self.storage.persistence.persist_event(
self.event_factory.create_event(room_id=self.room.to_string(), **kwargs)
)
)
- @defer.inlineCallbacks
def STALE_test_room_name(self):
name = "A-Room-Name"
- yield self.inject_room_event(
+ self.inject_room_event(
etype=EventTypes.Name, name=name, content={"name": name}, depth=1
)
- state = yield defer.ensureDeferred(
+ state = self.get_success(
self.store.get_current_state(room_id=self.room.to_string())
)
@@ -137,15 +111,14 @@ class RoomEventsStoreTestCase(unittest.TestCase):
state[0],
)
- @defer.inlineCallbacks
def STALE_test_room_topic(self):
topic = "A place for things"
- yield self.inject_room_event(
+ self.inject_room_event(
etype=EventTypes.Topic, topic=topic, content={"topic": topic}, depth=1
)
- state = yield defer.ensureDeferred(
+ state = self.get_success(
self.store.get_current_state(room_id=self.room.to_string())
)
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 8bd12fa847..f06b452fa9 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2018 New Vector Ltd
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,24 +15,18 @@
import logging
-from twisted.internet import defer
-
from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions
from synapse.storage.state import StateFilter
from synapse.types import RoomID, UserID
-import tests.unittest
-import tests.utils
+from tests.unittest import HomeserverTestCase
logger = logging.getLogger(__name__)
-class StateStoreTestCase(tests.unittest.TestCase):
- @defer.inlineCallbacks
- def setUp(self):
- hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
-
+class StateStoreTestCase(HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.storage = hs.get_storage()
self.state_datastore = self.storage.state.stores.state
@@ -44,7 +38,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.room = RoomID.from_string("!abc123:test")
- yield defer.ensureDeferred(
+ self.get_success(
self.store.store_room(
self.room.to_string(),
room_creator_user_id="@creator:text",
@@ -53,7 +47,6 @@ class StateStoreTestCase(tests.unittest.TestCase):
)
)
- @defer.inlineCallbacks
def inject_state_event(self, room, sender, typ, state_key, content):
builder = self.event_builder_factory.for_room_version(
RoomVersions.V1,
@@ -66,13 +59,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
},
)
- event, context = yield defer.ensureDeferred(
+ event, context = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
- yield defer.ensureDeferred(
- self.storage.persistence.persist_event(event, context)
- )
+ self.get_success(self.storage.persistence.persist_event(event, context))
return event
@@ -82,16 +73,13 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertEqual(s1[t].event_id, s2[t].event_id)
self.assertEqual(len(s1), len(s2))
- @defer.inlineCallbacks
def test_get_state_groups_ids(self):
- e1 = yield self.inject_state_event(
- self.room, self.u_alice, EventTypes.Create, "", {}
- )
- e2 = yield self.inject_state_event(
+ e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
+ e2 = self.inject_state_event(
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
)
- state_group_map = yield defer.ensureDeferred(
+ state_group_map = self.get_success(
self.storage.state.get_state_groups_ids(self.room, [e2.event_id])
)
self.assertEqual(len(state_group_map), 1)
@@ -101,16 +89,13 @@ class StateStoreTestCase(tests.unittest.TestCase):
{(EventTypes.Create, ""): e1.event_id, (EventTypes.Name, ""): e2.event_id},
)
- @defer.inlineCallbacks
def test_get_state_groups(self):
- e1 = yield self.inject_state_event(
- self.room, self.u_alice, EventTypes.Create, "", {}
- )
- e2 = yield self.inject_state_event(
+ e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
+ e2 = self.inject_state_event(
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
)
- state_group_map = yield defer.ensureDeferred(
+ state_group_map = self.get_success(
self.storage.state.get_state_groups(self.room, [e2.event_id])
)
self.assertEqual(len(state_group_map), 1)
@@ -118,32 +103,29 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertEqual({ev.event_id for ev in state_list}, {e1.event_id, e2.event_id})
- @defer.inlineCallbacks
def test_get_state_for_event(self):
# this defaults to a linear DAG as each new injection defaults to whatever
# forward extremities are currently in the DB for this room.
- e1 = yield self.inject_state_event(
- self.room, self.u_alice, EventTypes.Create, "", {}
- )
- e2 = yield self.inject_state_event(
+ e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
+ e2 = self.inject_state_event(
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
)
- e3 = yield self.inject_state_event(
+ e3 = self.inject_state_event(
self.room,
self.u_alice,
EventTypes.Member,
self.u_alice.to_string(),
{"membership": Membership.JOIN},
)
- e4 = yield self.inject_state_event(
+ e4 = self.inject_state_event(
self.room,
self.u_bob,
EventTypes.Member,
self.u_bob.to_string(),
{"membership": Membership.JOIN},
)
- e5 = yield self.inject_state_event(
+ e5 = self.inject_state_event(
self.room,
self.u_bob,
EventTypes.Member,
@@ -152,9 +134,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
)
# check we get the full state as of the final event
- state = yield defer.ensureDeferred(
- self.storage.state.get_state_for_event(e5.event_id)
- )
+ state = self.get_success(self.storage.state.get_state_for_event(e5.event_id))
self.assertIsNotNone(e4)
@@ -170,7 +150,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
)
# check we can filter to the m.room.name event (with a '' state key)
- state = yield defer.ensureDeferred(
+ state = self.get_success(
self.storage.state.get_state_for_event(
e5.event_id, StateFilter.from_types([(EventTypes.Name, "")])
)
@@ -179,7 +159,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
# check we can filter to the m.room.name event (with a wildcard None state key)
- state = yield defer.ensureDeferred(
+ state = self.get_success(
self.storage.state.get_state_for_event(
e5.event_id, StateFilter.from_types([(EventTypes.Name, None)])
)
@@ -188,7 +168,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
# check we can grab the m.room.member events (with a wildcard None state key)
- state = yield defer.ensureDeferred(
+ state = self.get_success(
self.storage.state.get_state_for_event(
e5.event_id, StateFilter.from_types([(EventTypes.Member, None)])
)
@@ -200,7 +180,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# check we can grab a specific room member without filtering out the
# other event types
- state = yield defer.ensureDeferred(
+ state = self.get_success(
self.storage.state.get_state_for_event(
e5.event_id,
state_filter=StateFilter(
@@ -220,7 +200,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
)
# check that we can grab everything except members
- state = yield defer.ensureDeferred(
+ state = self.get_success(
self.storage.state.get_state_for_event(
e5.event_id,
state_filter=StateFilter(
@@ -238,17 +218,14 @@ class StateStoreTestCase(tests.unittest.TestCase):
#######################################################
room_id = self.room.to_string()
- group_ids = yield defer.ensureDeferred(
+ group_ids = self.get_success(
self.storage.state.get_state_groups_ids(room_id, [e5.event_id])
)
group = list(group_ids.keys())[0]
# test _get_state_for_group_using_cache correctly filters out members
# with types=[]
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
@@ -265,10 +242,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
state_dict,
)
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
@@ -281,10 +255,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members
# with wildcard types
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
@@ -301,10 +272,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
state_dict,
)
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
@@ -324,10 +292,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members
# with specific types
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
@@ -344,10 +309,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
state_dict,
)
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
@@ -360,10 +322,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members
# with specific types
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
@@ -377,14 +336,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
#######################################################
# deliberately remove e2 (room name) from the _state_group_cache
- (
- is_all,
- known_absent,
- state_dict_ids,
- ) = self.state_datastore._state_group_cache.get(group)
+ cache_entry = self.state_datastore._state_group_cache.get(group)
+ state_dict_ids = cache_entry.value
- self.assertEqual(is_all, True)
- self.assertEqual(known_absent, set())
+ self.assertEqual(cache_entry.full, True)
+ self.assertEqual(cache_entry.known_absent, set())
self.assertDictEqual(
state_dict_ids,
{
@@ -403,14 +359,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
fetched_keys=((e1.type, e1.state_key),),
)
- (
- is_all,
- known_absent,
- state_dict_ids,
- ) = self.state_datastore._state_group_cache.get(group)
+ cache_entry = self.state_datastore._state_group_cache.get(group)
+ state_dict_ids = cache_entry.value
- self.assertEqual(is_all, False)
- self.assertEqual(known_absent, {(e1.type, e1.state_key)})
+ self.assertEqual(cache_entry.full, False)
+ self.assertEqual(cache_entry.known_absent, {(e1.type, e1.state_key)})
self.assertDictEqual(state_dict_ids, {(e1.type, e1.state_key): e1.event_id})
############################################
@@ -419,10 +372,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters out members
# with types=[]
room_id = self.room.to_string()
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
@@ -434,10 +384,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
room_id = self.room.to_string()
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
@@ -450,10 +397,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members
# wildcard types
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
@@ -464,10 +408,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertEqual(is_all, False)
self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
@@ -486,10 +427,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members
# with specific types
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
@@ -500,10 +438,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertEqual(is_all, False)
self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
@@ -516,10 +451,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members
# with specific types
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
@@ -530,10 +462,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertEqual(is_all, False)
self.assertDictEqual({}, state_dict)
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index a6f63f4aaf..019c5b7b14 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2018 New Vector Ltd
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,10 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
-
-from tests import unittest
-from tests.utils import setup_test_homeserver
+from tests.unittest import HomeserverTestCase, override_config
ALICE = "@alice:a"
BOB = "@bob:b"
@@ -25,73 +22,52 @@ BOBBY = "@bobby:a"
BELA = "@somenickname:a"
-class UserDirectoryStoreTestCase(unittest.TestCase):
- @defer.inlineCallbacks
- def setUp(self):
- self.hs = yield setup_test_homeserver(self.addCleanup)
- self.store = self.hs.get_datastore()
+class UserDirectoryStoreTestCase(HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
# alice and bob are both in !room_id. bobby is not but shares
# a homeserver with alice.
- yield defer.ensureDeferred(
- self.store.update_profile_in_user_dir(ALICE, "alice", None)
- )
- yield defer.ensureDeferred(
- self.store.update_profile_in_user_dir(BOB, "bob", None)
- )
- yield defer.ensureDeferred(
- self.store.update_profile_in_user_dir(BOBBY, "bobby", None)
- )
- yield defer.ensureDeferred(
- self.store.update_profile_in_user_dir(BELA, "Bela", None)
- )
- yield defer.ensureDeferred(
- self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB))
- )
+ self.get_success(self.store.update_profile_in_user_dir(ALICE, "alice", None))
+ self.get_success(self.store.update_profile_in_user_dir(BOB, "bob", None))
+ self.get_success(self.store.update_profile_in_user_dir(BOBBY, "bobby", None))
+ self.get_success(self.store.update_profile_in_user_dir(BELA, "Bela", None))
+ self.get_success(self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB)))
- @defer.inlineCallbacks
def test_search_user_dir(self):
# normally when alice searches the directory she should just find
# bob because bobby doesn't share a room with her.
- r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "bob", 10))
+ r = self.get_success(self.store.search_user_dir(ALICE, "bob", 10))
self.assertFalse(r["limited"])
self.assertEqual(1, len(r["results"]))
self.assertDictEqual(
r["results"][0], {"user_id": BOB, "display_name": "bob", "avatar_url": None}
)
- @defer.inlineCallbacks
+ @override_config({"user_directory": {"search_all_users": True}})
def test_search_user_dir_all_users(self):
- self.hs.config.user_directory_search_all_users = True
- try:
- r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "bob", 10))
- self.assertFalse(r["limited"])
- self.assertEqual(2, len(r["results"]))
- self.assertDictEqual(
- r["results"][0],
- {"user_id": BOB, "display_name": "bob", "avatar_url": None},
- )
- self.assertDictEqual(
- r["results"][1],
- {"user_id": BOBBY, "display_name": "bobby", "avatar_url": None},
- )
- finally:
- self.hs.config.user_directory_search_all_users = False
+ r = self.get_success(self.store.search_user_dir(ALICE, "bob", 10))
+ self.assertFalse(r["limited"])
+ self.assertEqual(2, len(r["results"]))
+ self.assertDictEqual(
+ r["results"][0],
+ {"user_id": BOB, "display_name": "bob", "avatar_url": None},
+ )
+ self.assertDictEqual(
+ r["results"][1],
+ {"user_id": BOBBY, "display_name": "bobby", "avatar_url": None},
+ )
- @defer.inlineCallbacks
+ @override_config({"user_directory": {"search_all_users": True}})
def test_search_user_dir_stop_words(self):
"""Tests that a user can look up another user by searching for the start if its
display name even if that name happens to be a common English word that would
usually be ignored in full text searches.
"""
- self.hs.config.user_directory_search_all_users = True
- try:
- r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "be", 10))
- self.assertFalse(r["limited"])
- self.assertEqual(1, len(r["results"]))
- self.assertDictEqual(
- r["results"][0],
- {"user_id": BELA, "display_name": "Bela", "avatar_url": None},
- )
- finally:
- self.hs.config.user_directory_search_all_users = False
+ r = self.get_success(self.store.search_user_dir(ALICE, "be", 10))
+ self.assertFalse(r["limited"])
+ self.assertEqual(1, len(r["results"]))
+ self.assertDictEqual(
+ r["results"][0],
+ {"user_id": BELA, "display_name": "Bela", "avatar_url": None},
+ )
diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py
index 3f2691ee6b..b5f18344dc 100644
--- a/tests/test_event_auth.py
+++ b/tests/test_event_auth.py
@@ -207,6 +207,226 @@ class EventAuthTestCase(unittest.TestCase):
do_sig_check=False,
)
+ def test_join_rules_public(self):
+ """
+ Test joining a public room.
+ """
+ creator = "@creator:example.com"
+ pleb = "@joiner:example.com"
+
+ auth_events = {
+ ("m.room.create", ""): _create_event(creator),
+ ("m.room.member", creator): _join_event(creator),
+ ("m.room.join_rules", ""): _join_rules_event(creator, "public"),
+ }
+
+ # Check join.
+ event_auth.check(
+ RoomVersions.V6,
+ _join_event(pleb),
+ auth_events,
+ do_sig_check=False,
+ )
+
+ # A user cannot be force-joined to a room.
+ with self.assertRaises(AuthError):
+ event_auth.check(
+ RoomVersions.V6,
+ _member_event(pleb, "join", sender=creator),
+ auth_events,
+ do_sig_check=False,
+ )
+
+ # Banned should be rejected.
+ auth_events[("m.room.member", pleb)] = _member_event(pleb, "ban")
+ with self.assertRaises(AuthError):
+ event_auth.check(
+ RoomVersions.V6,
+ _join_event(pleb),
+ auth_events,
+ do_sig_check=False,
+ )
+
+ # A user who left can re-join.
+ auth_events[("m.room.member", pleb)] = _member_event(pleb, "leave")
+ event_auth.check(
+ RoomVersions.V6,
+ _join_event(pleb),
+ auth_events,
+ do_sig_check=False,
+ )
+
+ # A user can send a join if they're in the room.
+ auth_events[("m.room.member", pleb)] = _member_event(pleb, "join")
+ event_auth.check(
+ RoomVersions.V6,
+ _join_event(pleb),
+ auth_events,
+ do_sig_check=False,
+ )
+
+ # A user can accept an invite.
+ auth_events[("m.room.member", pleb)] = _member_event(
+ pleb, "invite", sender=creator
+ )
+ event_auth.check(
+ RoomVersions.V6,
+ _join_event(pleb),
+ auth_events,
+ do_sig_check=False,
+ )
+
+ def test_join_rules_invite(self):
+ """
+ Test joining an invite only room.
+ """
+ creator = "@creator:example.com"
+ pleb = "@joiner:example.com"
+
+ auth_events = {
+ ("m.room.create", ""): _create_event(creator),
+ ("m.room.member", creator): _join_event(creator),
+ ("m.room.join_rules", ""): _join_rules_event(creator, "invite"),
+ }
+
+ # A join without an invite is rejected.
+ with self.assertRaises(AuthError):
+ event_auth.check(
+ RoomVersions.V6,
+ _join_event(pleb),
+ auth_events,
+ do_sig_check=False,
+ )
+
+ # A user cannot be force-joined to a room.
+ with self.assertRaises(AuthError):
+ event_auth.check(
+ RoomVersions.V6,
+ _member_event(pleb, "join", sender=creator),
+ auth_events,
+ do_sig_check=False,
+ )
+
+ # Banned should be rejected.
+ auth_events[("m.room.member", pleb)] = _member_event(pleb, "ban")
+ with self.assertRaises(AuthError):
+ event_auth.check(
+ RoomVersions.V6,
+ _join_event(pleb),
+ auth_events,
+ do_sig_check=False,
+ )
+
+ # A user who left cannot re-join.
+ auth_events[("m.room.member", pleb)] = _member_event(pleb, "leave")
+ with self.assertRaises(AuthError):
+ event_auth.check(
+ RoomVersions.V6,
+ _join_event(pleb),
+ auth_events,
+ do_sig_check=False,
+ )
+
+ # A user can send a join if they're in the room.
+ auth_events[("m.room.member", pleb)] = _member_event(pleb, "join")
+ event_auth.check(
+ RoomVersions.V6,
+ _join_event(pleb),
+ auth_events,
+ do_sig_check=False,
+ )
+
+ # A user can accept an invite.
+ auth_events[("m.room.member", pleb)] = _member_event(
+ pleb, "invite", sender=creator
+ )
+ event_auth.check(
+ RoomVersions.V6,
+ _join_event(pleb),
+ auth_events,
+ do_sig_check=False,
+ )
+
+ def test_join_rules_msc3083_restricted(self):
+ """
+ Test joining a restricted room from MSC3083.
+
+ This is pretty much the same test as public.
+ """
+ creator = "@creator:example.com"
+ pleb = "@joiner:example.com"
+
+ auth_events = {
+ ("m.room.create", ""): _create_event(creator),
+ ("m.room.member", creator): _join_event(creator),
+ ("m.room.join_rules", ""): _join_rules_event(creator, "restricted"),
+ }
+
+ # Older room versions don't understand this join rule
+ with self.assertRaises(AuthError):
+ event_auth.check(
+ RoomVersions.V6,
+ _join_event(pleb),
+ auth_events,
+ do_sig_check=False,
+ )
+
+ # Check join.
+ event_auth.check(
+ RoomVersions.MSC3083,
+ _join_event(pleb),
+ auth_events,
+ do_sig_check=False,
+ )
+
+ # A user cannot be force-joined to a room.
+ with self.assertRaises(AuthError):
+ event_auth.check(
+ RoomVersions.MSC3083,
+ _member_event(pleb, "join", sender=creator),
+ auth_events,
+ do_sig_check=False,
+ )
+
+ # Banned should be rejected.
+ auth_events[("m.room.member", pleb)] = _member_event(pleb, "ban")
+ with self.assertRaises(AuthError):
+ event_auth.check(
+ RoomVersions.MSC3083,
+ _join_event(pleb),
+ auth_events,
+ do_sig_check=False,
+ )
+
+ # A user who left can re-join.
+ auth_events[("m.room.member", pleb)] = _member_event(pleb, "leave")
+ event_auth.check(
+ RoomVersions.MSC3083,
+ _join_event(pleb),
+ auth_events,
+ do_sig_check=False,
+ )
+
+ # A user can send a join if they're in the room.
+ auth_events[("m.room.member", pleb)] = _member_event(pleb, "join")
+ event_auth.check(
+ RoomVersions.MSC3083,
+ _join_event(pleb),
+ auth_events,
+ do_sig_check=False,
+ )
+
+ # A user can accept an invite.
+ auth_events[("m.room.member", pleb)] = _member_event(
+ pleb, "invite", sender=creator
+ )
+ event_auth.check(
+ RoomVersions.MSC3083,
+ _join_event(pleb),
+ auth_events,
+ do_sig_check=False,
+ )
+
# helpers for making events
@@ -225,19 +445,24 @@ def _create_event(user_id):
)
-def _join_event(user_id):
+def _member_event(user_id, membership, sender=None):
return make_event_from_dict(
{
"room_id": TEST_ROOM_ID,
"event_id": _get_event_id(),
"type": "m.room.member",
- "sender": user_id,
+ "sender": sender or user_id,
"state_key": user_id,
- "content": {"membership": "join"},
+ "content": {"membership": membership},
+ "prev_events": [],
}
)
+def _join_event(user_id):
+ return _member_event(user_id, "join")
+
+
def _power_levels_event(sender, content):
return make_event_from_dict(
{
@@ -277,6 +502,21 @@ def _random_state_event(sender):
)
+def _join_rules_event(sender, join_rule):
+ return make_event_from_dict(
+ {
+ "room_id": TEST_ROOM_ID,
+ "event_id": _get_event_id(),
+ "type": "m.room.join_rules",
+ "sender": sender,
+ "state_key": "",
+ "content": {
+ "join_rule": join_rule,
+ },
+ }
+ )
+
+
event_count = 0
diff --git a/tests/unittest.py b/tests/unittest.py
index ca7031c724..57b6a395c7 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -32,6 +32,7 @@ from twisted.python.threadpool import ThreadPool
from twisted.trial import unittest
from twisted.web.resource import Resource
+from synapse import events
from synapse.api.constants import EventTypes, Membership
from synapse.config.homeserver import HomeServerConfig
from synapse.config.ratelimiting import FederationRateLimitConfig
@@ -140,7 +141,7 @@ class TestCase(unittest.TestCase):
try:
self.assertEquals(attrs[key], getattr(obj, key))
except AssertionError as e:
- raise (type(e))(e.message + " for '.%s'" % key)
+ raise (type(e))("Assert error for '.{}':".format(key)) from e
def assert_dict(self, required, actual):
"""Does a partial assert of a dict.
@@ -229,6 +230,11 @@ class HomeserverTestCase(TestCase):
self._hs_args = {"clock": self.clock, "reactor": self.reactor}
self.hs = self.make_homeserver(self.reactor, self.clock)
+ # Honour the `use_frozen_dicts` config option. We have to do this
+ # manually because this is taken care of in the app `start` code, which
+ # we don't run. Plus we want to reset it on tearDown.
+ events.USE_FROZEN_DICTS = self.hs.config.use_frozen_dicts
+
if self.hs is None:
raise Exception("No homeserver returned from make_homeserver.")
@@ -292,6 +298,10 @@ class HomeserverTestCase(TestCase):
if hasattr(self, "prepare"):
self.prepare(self.reactor, self.clock, self.hs)
+ def tearDown(self):
+ # Reset to not use frozen dicts.
+ events.USE_FROZEN_DICTS = False
+
def wait_on_thread(self, deferred, timeout=10):
"""
Wait until a Deferred is done, where it's waiting on a real thread.
@@ -461,7 +471,7 @@ class HomeserverTestCase(TestCase):
kwargs["config"] = config_obj
async def run_bg_updates():
- with LoggingContext("run_bg_updates", request="run_bg_updates-1"):
+ with LoggingContext("run_bg_updates"):
while not await stor.db_pool.updates.has_completed_background_updates():
await stor.db_pool.updates.do_next_background_update(1)
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index afb11b9caf..e434e21aee 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -661,14 +661,13 @@ class CachedListDescriptorTestCase(unittest.TestCase):
@descriptors.cachedList("fn", "args1")
async def list_fn(self, args1, arg2):
- assert current_context().request == "c1"
+ assert current_context().name == "c1"
# we want this to behave like an asynchronous function
await run_on_reactor()
- assert current_context().request == "c1"
+ assert current_context().name == "c1"
return self.mock(args1, arg2)
- with LoggingContext() as c1:
- c1.request = "c1"
+ with LoggingContext("c1") as c1:
obj = Cls()
obj.mock.return_value = {10: "fish", 20: "chips"}
d1 = obj.list_fn([10, 20], 2)
diff --git a/tests/util/test_dict_cache.py b/tests/util/test_dict_cache.py
index 34fdc9a43a..2f41333f4c 100644
--- a/tests/util/test_dict_cache.py
+++ b/tests/util/test_dict_cache.py
@@ -27,7 +27,9 @@ class DictCacheTestCase(unittest.TestCase):
key = "test_simple_cache_hit_full"
v = self.cache.get(key)
- self.assertEqual((False, set(), {}), v)
+ self.assertIs(v.full, False)
+ self.assertEqual(v.known_absent, set())
+ self.assertEqual({}, v.value)
seq = self.cache.sequence
test_value = {"test": "test_simple_cache_hit_full"}
diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py
index 58ee918f65..5d9c4665aa 100644
--- a/tests/util/test_logcontext.py
+++ b/tests/util/test_logcontext.py
@@ -17,11 +17,10 @@ from .. import unittest
class LoggingContextTestCase(unittest.TestCase):
def _check_test_key(self, value):
- self.assertEquals(current_context().request, value)
+ self.assertEquals(current_context().name, value)
def test_with_context(self):
- with LoggingContext() as context_one:
- context_one.request = "test"
+ with LoggingContext("test"):
self._check_test_key("test")
@defer.inlineCallbacks
@@ -30,15 +29,13 @@ class LoggingContextTestCase(unittest.TestCase):
@defer.inlineCallbacks
def competing_callback():
- with LoggingContext() as competing_context:
- competing_context.request = "competing"
+ with LoggingContext("competing"):
yield clock.sleep(0)
self._check_test_key("competing")
reactor.callLater(0, competing_callback)
- with LoggingContext() as context_one:
- context_one.request = "one"
+ with LoggingContext("one"):
yield clock.sleep(0)
self._check_test_key("one")
@@ -47,9 +44,7 @@ class LoggingContextTestCase(unittest.TestCase):
callback_completed = [False]
- with LoggingContext() as context_one:
- context_one.request = "one"
-
+ with LoggingContext("one"):
# fire off function, but don't wait on it.
d2 = run_in_background(function)
@@ -133,9 +128,7 @@ class LoggingContextTestCase(unittest.TestCase):
sentinel_context = current_context()
- with LoggingContext() as context_one:
- context_one.request = "one"
-
+ with LoggingContext("one"):
d1 = make_deferred_yieldable(blocking_function())
# make sure that the context was reset by make_deferred_yieldable
self.assertIs(current_context(), sentinel_context)
@@ -149,9 +142,7 @@ class LoggingContextTestCase(unittest.TestCase):
def test_make_deferred_yieldable_with_chained_deferreds(self):
sentinel_context = current_context()
- with LoggingContext() as context_one:
- context_one.request = "one"
-
+ with LoggingContext("one"):
d1 = make_deferred_yieldable(_chained_deferred_function())
# make sure that the context was reset by make_deferred_yieldable
self.assertIs(current_context(), sentinel_context)
@@ -166,9 +157,7 @@ class LoggingContextTestCase(unittest.TestCase):
"""Check that make_deferred_yieldable does the right thing when its
argument isn't actually a deferred"""
- with LoggingContext() as context_one:
- context_one.request = "one"
-
+ with LoggingContext("one"):
d1 = make_deferred_yieldable("bum")
self._check_test_key("one")
@@ -177,9 +166,9 @@ class LoggingContextTestCase(unittest.TestCase):
self._check_test_key("one")
def test_nested_logging_context(self):
- with LoggingContext(request="foo"):
+ with LoggingContext("foo"):
nested_context = nested_logging_context(suffix="bar")
- self.assertEqual(nested_context.request, "foo-bar")
+ self.assertEqual(nested_context.name, "foo-bar")
@defer.inlineCallbacks
def test_make_deferred_yieldable_with_await(self):
@@ -193,9 +182,7 @@ class LoggingContextTestCase(unittest.TestCase):
sentinel_context = current_context()
- with LoggingContext() as context_one:
- context_one.request = "one"
-
+ with LoggingContext("one"):
d1 = make_deferred_yieldable(blocking_function())
# make sure that the context was reset by make_deferred_yieldable
self.assertIs(current_context(), sentinel_context)
diff --git a/tests/utils.py b/tests/utils.py
index be80b13760..a141ee6496 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -122,7 +122,6 @@ def default_config(name, parse=False):
"enable_registration_captcha": False,
"macaroon_secret_key": "not even a little secret",
"trusted_third_party_id_servers": [],
- "room_invite_state_types": [],
"password_providers": [],
"worker_replication_url": "",
"worker_app": None,
|