diff --git a/tests/config/test_load.py b/tests/config/test_load.py
index 734a9983e8..c109425671 100644
--- a/tests/config/test_load.py
+++ b/tests/config/test_load.py
@@ -20,6 +20,7 @@ from io import StringIO
import yaml
+from synapse.config import ConfigError
from synapse.config.homeserver import HomeServerConfig
from tests import unittest
@@ -35,9 +36,9 @@ class ConfigLoadingTestCase(unittest.TestCase):
def test_load_fails_if_server_name_missing(self):
self.generate_config_and_remove_lines_containing("server_name")
- with self.assertRaises(Exception):
+ with self.assertRaises(ConfigError):
HomeServerConfig.load_config("", ["-c", self.file])
- with self.assertRaises(Exception):
+ with self.assertRaises(ConfigError):
HomeServerConfig.load_or_generate_config("", ["-c", self.file])
def test_generates_and_loads_macaroon_secret_key(self):
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 30fcc4c1bf..946482b7e7 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -16,6 +16,7 @@ import time
from mock import Mock
+import attr
import canonicaljson
import signedjson.key
import signedjson.sign
@@ -68,6 +69,11 @@ class MockPerspectiveServer:
signedjson.sign.sign_json(res, self.server_name, self.key)
+@attr.s(slots=True)
+class FakeRequest:
+ id = attr.ib()
+
+
@logcontext_clean
class KeyringTestCase(unittest.HomeserverTestCase):
def check_context(self, val, expected):
@@ -89,7 +95,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
first_lookup_deferred = Deferred()
async def first_lookup_fetch(keys_to_fetch):
- self.assertEquals(current_context().request, "context_11")
+ self.assertEquals(current_context().request.id, "context_11")
self.assertEqual(keys_to_fetch, {"server10": {get_key_id(key1): 0}})
await make_deferred_yieldable(first_lookup_deferred)
@@ -102,9 +108,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
mock_fetcher.get_keys.side_effect = first_lookup_fetch
async def first_lookup():
- with LoggingContext("context_11") as context_11:
- context_11.request = "context_11"
-
+ with LoggingContext("context_11", request=FakeRequest("context_11")):
res_deferreds = kr.verify_json_objects_for_server(
[("server10", json1, 0, "test10"), ("server11", {}, 0, "test11")]
)
@@ -130,7 +134,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
# should block rather than start a second call
async def second_lookup_fetch(keys_to_fetch):
- self.assertEquals(current_context().request, "context_12")
+ self.assertEquals(current_context().request.id, "context_12")
return {
"server10": {
get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100)
@@ -142,9 +146,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
second_lookup_state = [0]
async def second_lookup():
- with LoggingContext("context_12") as context_12:
- context_12.request = "context_12"
-
+ with LoggingContext("context_12", request=FakeRequest("context_12")):
res_deferreds_2 = kr.verify_json_objects_for_server(
[("server10", json1, 0, "test")]
)
@@ -589,10 +591,7 @@ def get_key_id(key):
@defer.inlineCallbacks
def run_in_context(f, *args, **kwargs):
- with LoggingContext("testctx") as ctx:
- # we set the "request" prop to make it easier to follow what's going on in the
- # logs.
- ctx.request = "testctx"
+ with LoggingContext("testctx"):
rv = yield f(*args, **kwargs)
return rv
diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py
new file mode 100644
index 0000000000..c6e547f11c
--- /dev/null
+++ b/tests/events/test_presence_router.py
@@ -0,0 +1,386 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
+
+from mock import Mock
+
+import attr
+
+from synapse.api.constants import EduTypes
+from synapse.events.presence_router import PresenceRouter
+from synapse.federation.units import Transaction
+from synapse.handlers.presence import UserPresenceState
+from synapse.module_api import ModuleApi
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, presence, room
+from synapse.types import JsonDict, StreamToken, create_requester
+
+from tests.handlers.test_sync import generate_sync_config
+from tests.unittest import FederatingHomeserverTestCase, TestCase, override_config
+
+
+@attr.s
+class PresenceRouterTestConfig:
+ users_who_should_receive_all_presence = attr.ib(type=List[str], default=[])
+
+
+class PresenceRouterTestModule:
+ def __init__(self, config: PresenceRouterTestConfig, module_api: ModuleApi):
+ self._config = config
+ self._module_api = module_api
+
+ async def get_users_for_states(
+ self, state_updates: Iterable[UserPresenceState]
+ ) -> Dict[str, Set[UserPresenceState]]:
+ users_to_state = {
+ user_id: set(state_updates)
+ for user_id in self._config.users_who_should_receive_all_presence
+ }
+ return users_to_state
+
+ async def get_interested_users(
+ self, user_id: str
+ ) -> Union[Set[str], PresenceRouter.ALL_USERS]:
+ if user_id in self._config.users_who_should_receive_all_presence:
+ return PresenceRouter.ALL_USERS
+
+ return set()
+
+ @staticmethod
+ def parse_config(config_dict: dict) -> PresenceRouterTestConfig:
+ """Parse a configuration dictionary from the homeserver config, do
+ some validation and return a typed PresenceRouterConfig.
+
+ Args:
+ config_dict: The configuration dictionary.
+
+ Returns:
+ A validated config object.
+ """
+ # Initialise a typed config object
+ config = PresenceRouterTestConfig()
+
+ config.users_who_should_receive_all_presence = config_dict.get(
+ "users_who_should_receive_all_presence"
+ )
+
+ return config
+
+
+class PresenceRouterTestCase(FederatingHomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ presence.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ return self.setup_test_homeserver(
+ federation_transport_client=Mock(spec=["send_transaction"]),
+ )
+
+ def prepare(self, reactor, clock, homeserver):
+ self.sync_handler = self.hs.get_sync_handler()
+ self.module_api = homeserver.get_module_api()
+
+ @override_config(
+ {
+ "presence": {
+ "presence_router": {
+ "module": __name__ + ".PresenceRouterTestModule",
+ "config": {
+ "users_who_should_receive_all_presence": [
+ "@presence_gobbler:test",
+ ]
+ },
+ }
+ },
+ "send_federation": True,
+ }
+ )
+ def test_receiving_all_presence(self):
+ """Test that a user that does not share a room with another other can receive
+ presence for them, due to presence routing.
+ """
+ # Create a user who should receive all presence of others
+ self.presence_receiving_user_id = self.register_user(
+ "presence_gobbler", "monkey"
+ )
+ self.presence_receiving_user_tok = self.login("presence_gobbler", "monkey")
+
+ # And two users who should not have any special routing
+ self.other_user_one_id = self.register_user("other_user_one", "monkey")
+ self.other_user_one_tok = self.login("other_user_one", "monkey")
+ self.other_user_two_id = self.register_user("other_user_two", "monkey")
+ self.other_user_two_tok = self.login("other_user_two", "monkey")
+
+ # Put the other two users in a room with each other
+ room_id = self.helper.create_room_as(
+ self.other_user_one_id, tok=self.other_user_one_tok
+ )
+
+ self.helper.invite(
+ room_id,
+ self.other_user_one_id,
+ self.other_user_two_id,
+ tok=self.other_user_one_tok,
+ )
+ self.helper.join(room_id, self.other_user_two_id, tok=self.other_user_two_tok)
+ # User one sends some presence
+ send_presence_update(
+ self,
+ self.other_user_one_id,
+ self.other_user_one_tok,
+ "online",
+ "boop",
+ )
+
+ # Check that the presence receiving user gets user one's presence when syncing
+ presence_updates, sync_token = sync_presence(
+ self, self.presence_receiving_user_id
+ )
+ self.assertEqual(len(presence_updates), 1)
+
+ presence_update = presence_updates[0] # type: UserPresenceState
+ self.assertEqual(presence_update.user_id, self.other_user_one_id)
+ self.assertEqual(presence_update.state, "online")
+ self.assertEqual(presence_update.status_msg, "boop")
+
+ # Have all three users send presence
+ send_presence_update(
+ self,
+ self.other_user_one_id,
+ self.other_user_one_tok,
+ "online",
+ "user_one",
+ )
+ send_presence_update(
+ self,
+ self.other_user_two_id,
+ self.other_user_two_tok,
+ "online",
+ "user_two",
+ )
+ send_presence_update(
+ self,
+ self.presence_receiving_user_id,
+ self.presence_receiving_user_tok,
+ "online",
+ "presence_gobbler",
+ )
+
+ # Check that the presence receiving user gets everyone's presence
+ presence_updates, _ = sync_presence(
+ self, self.presence_receiving_user_id, sync_token
+ )
+ self.assertEqual(len(presence_updates), 3)
+
+ # But that User One only get itself and User Two's presence
+ presence_updates, _ = sync_presence(self, self.other_user_one_id)
+ self.assertEqual(len(presence_updates), 2)
+
+ found = False
+ for update in presence_updates:
+ if update.user_id == self.other_user_two_id:
+ self.assertEqual(update.state, "online")
+ self.assertEqual(update.status_msg, "user_two")
+ found = True
+
+ self.assertTrue(found)
+
+ @override_config(
+ {
+ "presence": {
+ "presence_router": {
+ "module": __name__ + ".PresenceRouterTestModule",
+ "config": {
+ "users_who_should_receive_all_presence": [
+ "@presence_gobbler1:test",
+ "@presence_gobbler2:test",
+ "@far_away_person:island",
+ ]
+ },
+ }
+ },
+ "send_federation": True,
+ }
+ )
+ def test_send_local_online_presence_to_with_module(self):
+ """Tests that send_local_presence_to_users sends local online presence to a set
+ of specified local and remote users, with a custom PresenceRouter module enabled.
+ """
+ # Create a user who will send presence updates
+ self.other_user_id = self.register_user("other_user", "monkey")
+ self.other_user_tok = self.login("other_user", "monkey")
+
+ # And another two users that will also send out presence updates, as well as receive
+ # theirs and everyone else's
+ self.presence_receiving_user_one_id = self.register_user(
+ "presence_gobbler1", "monkey"
+ )
+ self.presence_receiving_user_one_tok = self.login("presence_gobbler1", "monkey")
+ self.presence_receiving_user_two_id = self.register_user(
+ "presence_gobbler2", "monkey"
+ )
+ self.presence_receiving_user_two_tok = self.login("presence_gobbler2", "monkey")
+
+ # Have all three users send some presence updates
+ send_presence_update(
+ self,
+ self.other_user_id,
+ self.other_user_tok,
+ "online",
+ "I'm online!",
+ )
+ send_presence_update(
+ self,
+ self.presence_receiving_user_one_id,
+ self.presence_receiving_user_one_tok,
+ "online",
+ "I'm also online!",
+ )
+ send_presence_update(
+ self,
+ self.presence_receiving_user_two_id,
+ self.presence_receiving_user_two_tok,
+ "unavailable",
+ "I'm in a meeting!",
+ )
+
+ # Mark each presence-receiving user for receiving all user presence
+ self.get_success(
+ self.module_api.send_local_online_presence_to(
+ [
+ self.presence_receiving_user_one_id,
+ self.presence_receiving_user_two_id,
+ ]
+ )
+ )
+
+ # Perform a sync for each user
+
+ # The other user should only receive their own presence
+ presence_updates, _ = sync_presence(self, self.other_user_id)
+ self.assertEqual(len(presence_updates), 1)
+
+ presence_update = presence_updates[0] # type: UserPresenceState
+ self.assertEqual(presence_update.user_id, self.other_user_id)
+ self.assertEqual(presence_update.state, "online")
+ self.assertEqual(presence_update.status_msg, "I'm online!")
+
+ # Whereas both presence receiving users should receive everyone's presence updates
+ presence_updates, _ = sync_presence(self, self.presence_receiving_user_one_id)
+ self.assertEqual(len(presence_updates), 3)
+ presence_updates, _ = sync_presence(self, self.presence_receiving_user_two_id)
+ self.assertEqual(len(presence_updates), 3)
+
+ # Test that sending to a remote user works
+ remote_user_id = "@far_away_person:island"
+
+ # Note that due to the remote user being in our module's
+ # users_who_should_receive_all_presence config, they would have
+ # received user presence updates already.
+ #
+ # Thus we reset the mock, and try sending all online local user
+ # presence again
+ self.hs.get_federation_transport_client().send_transaction.reset_mock()
+
+ # Broadcast local user online presence
+ self.get_success(
+ self.module_api.send_local_online_presence_to([remote_user_id])
+ )
+
+ # Check that the expected presence updates were sent
+ expected_users = [
+ self.other_user_id,
+ self.presence_receiving_user_one_id,
+ self.presence_receiving_user_two_id,
+ ]
+
+ calls = (
+ self.hs.get_federation_transport_client().send_transaction.call_args_list
+ )
+ for call in calls:
+ federation_transaction = call.args[0] # type: Transaction
+
+ # Get the sent EDUs in this transaction
+ edus = federation_transaction.get_dict()["edus"]
+
+ for edu in edus:
+ # Make sure we're only checking presence-type EDUs
+ if edu["edu_type"] != EduTypes.Presence:
+ continue
+
+ # EDUs can contain multiple presence updates
+ for presence_update in edu["content"]["push"]:
+ # Check for presence updates that contain the user IDs we're after
+ expected_users.remove(presence_update["user_id"])
+
+ # Ensure that no offline states are being sent out
+ self.assertNotEqual(presence_update["presence"], "offline")
+
+ self.assertEqual(len(expected_users), 0)
+
+
+def send_presence_update(
+ testcase: TestCase,
+ user_id: str,
+ access_token: str,
+ presence_state: str,
+ status_message: Optional[str] = None,
+) -> JsonDict:
+ # Build the presence body
+ body = {"presence": presence_state}
+ if status_message:
+ body["status_msg"] = status_message
+
+ # Update the user's presence state
+ channel = testcase.make_request(
+ "PUT", "/presence/%s/status" % (user_id,), body, access_token=access_token
+ )
+ testcase.assertEqual(channel.code, 200)
+
+ return channel.json_body
+
+
+def sync_presence(
+ testcase: TestCase,
+ user_id: str,
+ since_token: Optional[StreamToken] = None,
+) -> Tuple[List[UserPresenceState], StreamToken]:
+ """Perform a sync request for the given user and return the user presence updates
+ they've received, as well as the next_batch token.
+
+ This method assumes testcase.sync_handler points to the homeserver's sync handler.
+
+ Args:
+ testcase: The testcase that is currently being run.
+ user_id: The ID of the user to generate a sync response for.
+ since_token: An optional token to indicate from at what point to sync from.
+
+ Returns:
+ A tuple containing a list of presence updates, and the sync response's
+ next_batch token.
+ """
+ requester = create_requester(user_id)
+ sync_config = generate_sync_config(requester.user.to_string())
+ sync_result = testcase.get_success(
+ testcase.sync_handler.wait_for_sync_for_user(
+ requester, sync_config, since_token
+ )
+ )
+
+ return sync_result.presence, sync_result.next_batch
diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index e62586142e..8e950f25c5 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -37,7 +37,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
def test_wait_for_sync_for_user_auth_blocking(self):
user_id1 = "@user1:test"
user_id2 = "@user2:test"
- sync_config = self._generate_sync_config(user_id1)
+ sync_config = generate_sync_config(user_id1)
requester = create_requester(user_id1)
self.reactor.advance(100) # So we get not 0 time
@@ -60,7 +60,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
self.auth_blocking._hs_disabled = False
- sync_config = self._generate_sync_config(user_id2)
+ sync_config = generate_sync_config(user_id2)
requester = create_requester(user_id2)
e = self.get_failure(
@@ -69,11 +69,12 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
)
self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
- def _generate_sync_config(self, user_id):
- return SyncConfig(
- user=UserID(user_id.split(":")[0][1:], user_id.split(":")[1]),
- filter_collection=DEFAULT_FILTER_COLLECTION,
- is_guest=False,
- request_key="request_key",
- device_id="device_id",
- )
+
+def generate_sync_config(user_id: str) -> SyncConfig:
+ return SyncConfig(
+ user=UserID(user_id.split(":")[0][1:], user_id.split(":")[1]),
+ filter_collection=DEFAULT_FILTER_COLLECTION,
+ is_guest=False,
+ request_key="request_key",
+ device_id="device_id",
+ )
diff --git a/tests/logging/test_terse_json.py b/tests/logging/test_terse_json.py
index 48a74e2eee..bfe0d11c93 100644
--- a/tests/logging/test_terse_json.py
+++ b/tests/logging/test_terse_json.py
@@ -12,15 +12,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import json
import logging
-from io import StringIO
+from io import BytesIO, StringIO
+
+from mock import Mock, patch
+
+from twisted.web.server import Request
+from synapse.http.site import SynapseRequest
from synapse.logging._terse_json import JsonFormatter, TerseJsonFormatter
from synapse.logging.context import LoggingContext, LoggingContextFilter
from tests.logging import LoggerCleanupMixin
+from tests.server import FakeChannel
from tests.unittest import TestCase
@@ -120,7 +125,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
handler.addFilter(LoggingContextFilter())
logger = self.get_logger(handler)
- with LoggingContext(request="test"):
+ with LoggingContext("name"):
logger.info("Hello there, %s!", "wally")
log = self.get_log_line()
@@ -134,4 +139,61 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
]
self.assertCountEqual(log.keys(), expected_log_keys)
self.assertEqual(log["log"], "Hello there, wally!")
- self.assertEqual(log["request"], "test")
+ self.assertTrue(log["request"].startswith("name@"))
+
+ def test_with_request_context(self):
+ """
+ Information from the logging context request should be added to the JSON response.
+ """
+ handler = logging.StreamHandler(self.output)
+ handler.setFormatter(JsonFormatter())
+ handler.addFilter(LoggingContextFilter())
+ logger = self.get_logger(handler)
+
+ # A full request isn't needed here.
+ site = Mock(spec=["site_tag", "server_version_string", "getResourceFor"])
+ site.site_tag = "test-site"
+ site.server_version_string = "Server v1"
+ request = SynapseRequest(FakeChannel(site, None))
+ # Call requestReceived to finish instantiating the object.
+ request.content = BytesIO()
+ # Partially skip some of the internal processing of SynapseRequest.
+ request._started_processing = Mock()
+ request.request_metrics = Mock(spec=["name"])
+ with patch.object(Request, "render"):
+ request.requestReceived(b"POST", b"/_matrix/client/versions", b"1.1")
+
+ # Also set the requester to ensure the processing works.
+ request.requester = "@foo:test"
+
+ with LoggingContext(parent_context=request.logcontext):
+ logger.info("Hello there, %s!", "wally")
+
+ log = self.get_log_line()
+
+ # The terse logger includes additional request information, if possible.
+ expected_log_keys = [
+ "log",
+ "level",
+ "namespace",
+ "request",
+ "ip_address",
+ "site_tag",
+ "requester",
+ "authenticated_entity",
+ "method",
+ "url",
+ "protocol",
+ "user_agent",
+ ]
+ self.assertCountEqual(log.keys(), expected_log_keys)
+ self.assertEqual(log["log"], "Hello there, wally!")
+ self.assertTrue(log["request"].startswith("POST-"))
+ self.assertEqual(log["ip_address"], "127.0.0.1")
+ self.assertEqual(log["site_tag"], "test-site")
+ self.assertEqual(log["requester"], "@foo:test")
+ self.assertEqual(log["authenticated_entity"], "@foo:test")
+ self.assertEqual(log["method"], "POST")
+ self.assertEqual(log["url"], "/_matrix/client/versions")
+ self.assertEqual(log["protocol"], "1.1")
+ self.assertEqual(log["user_agent"], "")
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index edacd1b566..1d1fceeecf 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -14,25 +14,37 @@
# limitations under the License.
from mock import Mock
+from synapse.api.constants import EduTypes
from synapse.events import EventBase
+from synapse.federation.units import Transaction
+from synapse.handlers.presence import UserPresenceState
from synapse.rest import admin
-from synapse.rest.client.v1 import login, room
+from synapse.rest.client.v1 import login, presence, room
from synapse.types import create_requester
-from tests.unittest import HomeserverTestCase
+from tests.events.test_presence_router import send_presence_update, sync_presence
+from tests.test_utils.event_injection import inject_member_event
+from tests.unittest import FederatingHomeserverTestCase, override_config
-class ModuleApiTestCase(HomeserverTestCase):
+class ModuleApiTestCase(FederatingHomeserverTestCase):
servlets = [
admin.register_servlets,
login.register_servlets,
room.register_servlets,
+ presence.register_servlets,
]
def prepare(self, reactor, clock, homeserver):
self.store = homeserver.get_datastore()
self.module_api = homeserver.get_module_api()
self.event_creation_handler = homeserver.get_event_creation_handler()
+ self.sync_handler = homeserver.get_sync_handler()
+
+ def make_homeserver(self, reactor, clock):
+ return self.setup_test_homeserver(
+ federation_transport_client=Mock(spec=["send_transaction"]),
+ )
def test_can_register_user(self):
"""Tests that an external module can register a user"""
@@ -205,3 +217,160 @@ class ModuleApiTestCase(HomeserverTestCase):
)
)
self.assertFalse(is_in_public_rooms)
+
+ # The ability to send federation is required by send_local_online_presence_to.
+ @override_config({"send_federation": True})
+ def test_send_local_online_presence_to(self):
+ """Tests that send_local_presence_to_users sends local online presence to local users."""
+ # Create a user who will send presence updates
+ self.presence_receiver_id = self.register_user("presence_receiver", "monkey")
+ self.presence_receiver_tok = self.login("presence_receiver", "monkey")
+
+ # And another user that will send presence updates out
+ self.presence_sender_id = self.register_user("presence_sender", "monkey")
+ self.presence_sender_tok = self.login("presence_sender", "monkey")
+
+ # Put them in a room together so they will receive each other's presence updates
+ room_id = self.helper.create_room_as(
+ self.presence_receiver_id,
+ tok=self.presence_receiver_tok,
+ )
+ self.helper.join(room_id, self.presence_sender_id, tok=self.presence_sender_tok)
+
+ # Presence sender comes online
+ send_presence_update(
+ self,
+ self.presence_sender_id,
+ self.presence_sender_tok,
+ "online",
+ "I'm online!",
+ )
+
+ # Presence receiver should have received it
+ presence_updates, sync_token = sync_presence(self, self.presence_receiver_id)
+ self.assertEqual(len(presence_updates), 1)
+
+ presence_update = presence_updates[0] # type: UserPresenceState
+ self.assertEqual(presence_update.user_id, self.presence_sender_id)
+ self.assertEqual(presence_update.state, "online")
+
+ # Syncing again should result in no presence updates
+ presence_updates, sync_token = sync_presence(
+ self, self.presence_receiver_id, sync_token
+ )
+ self.assertEqual(len(presence_updates), 0)
+
+ # Trigger sending local online presence
+ self.get_success(
+ self.module_api.send_local_online_presence_to(
+ [
+ self.presence_receiver_id,
+ ]
+ )
+ )
+
+ # Presence receiver should have received online presence again
+ presence_updates, sync_token = sync_presence(
+ self, self.presence_receiver_id, sync_token
+ )
+ self.assertEqual(len(presence_updates), 1)
+
+ presence_update = presence_updates[0] # type: UserPresenceState
+ self.assertEqual(presence_update.user_id, self.presence_sender_id)
+ self.assertEqual(presence_update.state, "online")
+
+ # Presence sender goes offline
+ send_presence_update(
+ self,
+ self.presence_sender_id,
+ self.presence_sender_tok,
+ "offline",
+ "I slink back into the darkness.",
+ )
+
+ # Trigger sending local online presence
+ self.get_success(
+ self.module_api.send_local_online_presence_to(
+ [
+ self.presence_receiver_id,
+ ]
+ )
+ )
+
+ # Presence receiver should *not* have received offline state
+ presence_updates, sync_token = sync_presence(
+ self, self.presence_receiver_id, sync_token
+ )
+ self.assertEqual(len(presence_updates), 0)
+
+ @override_config({"send_federation": True})
+ def test_send_local_online_presence_to_federation(self):
+ """Tests that send_local_presence_to_users sends local online presence to remote users."""
+ # Create a user who will send presence updates
+ self.presence_sender_id = self.register_user("presence_sender", "monkey")
+ self.presence_sender_tok = self.login("presence_sender", "monkey")
+
+ # And a room they're a part of
+ room_id = self.helper.create_room_as(
+ self.presence_sender_id,
+ tok=self.presence_sender_tok,
+ )
+
+ # Mark them as online
+ send_presence_update(
+ self,
+ self.presence_sender_id,
+ self.presence_sender_tok,
+ "online",
+ "I'm online!",
+ )
+
+ # Make up a remote user to send presence to
+ remote_user_id = "@far_away_person:island"
+
+ # Create a join membership event for the remote user into the room.
+ # This allows presence information to flow from one user to the other.
+ self.get_success(
+ inject_member_event(
+ self.hs,
+ room_id,
+ sender=remote_user_id,
+ target=remote_user_id,
+ membership="join",
+ )
+ )
+
+ # The remote user would have received the existing room members' presence
+ # when they joined the room.
+ #
+ # Thus we reset the mock, and try sending online local user
+ # presence again
+ self.hs.get_federation_transport_client().send_transaction.reset_mock()
+
+ # Broadcast local user online presence
+ self.get_success(
+ self.module_api.send_local_online_presence_to([remote_user_id])
+ )
+
+ # Check that a presence update was sent as part of a federation transaction
+ found_update = False
+ calls = (
+ self.hs.get_federation_transport_client().send_transaction.call_args_list
+ )
+ for call in calls:
+ federation_transaction = call.args[0] # type: Transaction
+
+ # Get the sent EDUs in this transaction
+ edus = federation_transaction.get_dict()["edus"]
+
+ for edu in edus:
+ # Make sure we're only checking presence-type EDUs
+ if edu["edu_type"] != EduTypes.Presence:
+ continue
+
+ # EDUs can contain multiple presence updates
+ for presence_update in edu["content"]["push"]:
+ if presence_update["user_id"] == self.presence_sender_id:
+ found_update = True
+
+ self.assertTrue(found_update)
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index dabc1c5f09..ef4cf8d0f1 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2016 OpenMarket Ltd
+# Copyright 2016-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,32 +13,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
-
import synapse.api.errors
-import tests.unittest
-import tests.utils
-
-
-class DeviceStoreTestCase(tests.unittest.TestCase):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.store = None # type: synapse.storage.DataStore
+from tests.unittest import HomeserverTestCase
- @defer.inlineCallbacks
- def setUp(self):
- hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
+class DeviceStoreTestCase(HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
- @defer.inlineCallbacks
def test_store_new_device(self):
- yield defer.ensureDeferred(
+ self.get_success(
self.store.store_device("user_id", "device_id", "display_name")
)
- res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
+ res = self.get_success(self.store.get_device("user_id", "device_id"))
self.assertDictContainsSubset(
{
"user_id": "user_id",
@@ -48,19 +37,18 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
res,
)
- @defer.inlineCallbacks
def test_get_devices_by_user(self):
- yield defer.ensureDeferred(
+ self.get_success(
self.store.store_device("user_id", "device1", "display_name 1")
)
- yield defer.ensureDeferred(
+ self.get_success(
self.store.store_device("user_id", "device2", "display_name 2")
)
- yield defer.ensureDeferred(
+ self.get_success(
self.store.store_device("user_id2", "device3", "display_name 3")
)
- res = yield defer.ensureDeferred(self.store.get_devices_by_user("user_id"))
+ res = self.get_success(self.store.get_devices_by_user("user_id"))
self.assertEqual(2, len(res.keys()))
self.assertDictContainsSubset(
{
@@ -79,43 +67,41 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
res["device2"],
)
- @defer.inlineCallbacks
def test_count_devices_by_users(self):
- yield defer.ensureDeferred(
+ self.get_success(
self.store.store_device("user_id", "device1", "display_name 1")
)
- yield defer.ensureDeferred(
+ self.get_success(
self.store.store_device("user_id", "device2", "display_name 2")
)
- yield defer.ensureDeferred(
+ self.get_success(
self.store.store_device("user_id2", "device3", "display_name 3")
)
- res = yield defer.ensureDeferred(self.store.count_devices_by_users())
+ res = self.get_success(self.store.count_devices_by_users())
self.assertEqual(0, res)
- res = yield defer.ensureDeferred(self.store.count_devices_by_users(["unknown"]))
+ res = self.get_success(self.store.count_devices_by_users(["unknown"]))
self.assertEqual(0, res)
- res = yield defer.ensureDeferred(self.store.count_devices_by_users(["user_id"]))
+ res = self.get_success(self.store.count_devices_by_users(["user_id"]))
self.assertEqual(2, res)
- res = yield defer.ensureDeferred(
+ res = self.get_success(
self.store.count_devices_by_users(["user_id", "user_id2"])
)
self.assertEqual(3, res)
- @defer.inlineCallbacks
def test_get_device_updates_by_remote(self):
device_ids = ["device_id1", "device_id2"]
# Add two device updates with a single stream_id
- yield defer.ensureDeferred(
+ self.get_success(
self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"])
)
# Get all device updates ever meant for this remote
- now_stream_id, device_updates = yield defer.ensureDeferred(
+ now_stream_id, device_updates = self.get_success(
self.store.get_device_updates_by_remote("somehost", -1, limit=100)
)
@@ -131,37 +117,35 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
}
self.assertEqual(received_device_ids, set(expected_device_ids))
- @defer.inlineCallbacks
def test_update_device(self):
- yield defer.ensureDeferred(
+ self.get_success(
self.store.store_device("user_id", "device_id", "display_name 1")
)
- res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
+ res = self.get_success(self.store.get_device("user_id", "device_id"))
self.assertEqual("display_name 1", res["display_name"])
# do a no-op first
- yield defer.ensureDeferred(self.store.update_device("user_id", "device_id"))
- res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
+ self.get_success(self.store.update_device("user_id", "device_id"))
+ res = self.get_success(self.store.get_device("user_id", "device_id"))
self.assertEqual("display_name 1", res["display_name"])
# do the update
- yield defer.ensureDeferred(
+ self.get_success(
self.store.update_device(
"user_id", "device_id", new_display_name="display_name 2"
)
)
# check it worked
- res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
+ res = self.get_success(self.store.get_device("user_id", "device_id"))
self.assertEqual("display_name 2", res["display_name"])
- @defer.inlineCallbacks
def test_update_unknown_device(self):
- with self.assertRaises(synapse.api.errors.StoreError) as cm:
- yield defer.ensureDeferred(
- self.store.update_device(
- "user_id", "unknown_device_id", new_display_name="display_name 2"
- )
- )
- self.assertEqual(404, cm.exception.code)
+ exc = self.get_failure(
+ self.store.update_device(
+ "user_id", "unknown_device_id", new_display_name="display_name 2"
+ ),
+ synapse.api.errors.StoreError,
+ )
+ self.assertEqual(404, exc.value.code)
diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py
index da93ca3980..0db233fd68 100644
--- a/tests/storage/test_directory.py
+++ b/tests/storage/test_directory.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,28 +13,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from twisted.internet import defer
-
from synapse.types import RoomAlias, RoomID
-from tests import unittest
-from tests.utils import setup_test_homeserver
+from tests.unittest import HomeserverTestCase
-class DirectoryStoreTestCase(unittest.TestCase):
- @defer.inlineCallbacks
- def setUp(self):
- hs = yield setup_test_homeserver(self.addCleanup)
-
+class DirectoryStoreTestCase(HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.room = RoomID.from_string("!abcde:test")
self.alias = RoomAlias.from_string("#my-room:test")
- @defer.inlineCallbacks
def test_room_to_alias(self):
- yield defer.ensureDeferred(
+ self.get_success(
self.store.create_room_alias_association(
room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
)
@@ -42,16 +34,11 @@ class DirectoryStoreTestCase(unittest.TestCase):
self.assertEquals(
["#my-room:test"],
- (
- yield defer.ensureDeferred(
- self.store.get_aliases_for_room(self.room.to_string())
- )
- ),
+ (self.get_success(self.store.get_aliases_for_room(self.room.to_string()))),
)
- @defer.inlineCallbacks
def test_alias_to_room(self):
- yield defer.ensureDeferred(
+ self.get_success(
self.store.create_room_alias_association(
room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
)
@@ -59,28 +46,19 @@ class DirectoryStoreTestCase(unittest.TestCase):
self.assertObjectHasAttributes(
{"room_id": self.room.to_string(), "servers": ["test"]},
- (
- yield defer.ensureDeferred(
- self.store.get_association_from_room_alias(self.alias)
- )
- ),
+ (self.get_success(self.store.get_association_from_room_alias(self.alias))),
)
- @defer.inlineCallbacks
def test_delete_alias(self):
- yield defer.ensureDeferred(
+ self.get_success(
self.store.create_room_alias_association(
room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
)
)
- room_id = yield defer.ensureDeferred(self.store.delete_room_alias(self.alias))
+ room_id = self.get_success(self.store.delete_room_alias(self.alias))
self.assertEqual(self.room.to_string(), room_id)
self.assertIsNone(
- (
- yield defer.ensureDeferred(
- self.store.get_association_from_room_alias(self.alias)
- )
- )
+ (self.get_success(self.store.get_association_from_room_alias(self.alias)))
)
diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py
index 3fc4bb13b6..1e54b940fd 100644
--- a/tests/storage/test_end_to_end_keys.py
+++ b/tests/storage/test_end_to_end_keys.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2016 OpenMarket Ltd
+# Copyright 2016-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,30 +13,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
+from tests.unittest import HomeserverTestCase
-import tests.unittest
-import tests.utils
-
-class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
- @defer.inlineCallbacks
- def setUp(self):
- hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
+class EndToEndKeyStoreTestCase(HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
- @defer.inlineCallbacks
def test_key_without_device_name(self):
now = 1470174257070
json = {"key": "value"}
- yield defer.ensureDeferred(self.store.store_device("user", "device", None))
+ self.get_success(self.store.store_device("user", "device", None))
- yield defer.ensureDeferred(
- self.store.set_e2e_device_keys("user", "device", now, json)
- )
+ self.get_success(self.store.set_e2e_device_keys("user", "device", now, json))
- res = yield defer.ensureDeferred(
+ res = self.get_success(
self.store.get_e2e_device_keys_for_cs_api((("user", "device"),))
)
self.assertIn("user", res)
@@ -44,38 +36,32 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
dev = res["user"]["device"]
self.assertDictContainsSubset(json, dev)
- @defer.inlineCallbacks
def test_reupload_key(self):
now = 1470174257070
json = {"key": "value"}
- yield defer.ensureDeferred(self.store.store_device("user", "device", None))
+ self.get_success(self.store.store_device("user", "device", None))
- changed = yield defer.ensureDeferred(
+ changed = self.get_success(
self.store.set_e2e_device_keys("user", "device", now, json)
)
self.assertTrue(changed)
# If we try to upload the same key then we should be told nothing
# changed
- changed = yield defer.ensureDeferred(
+ changed = self.get_success(
self.store.set_e2e_device_keys("user", "device", now, json)
)
self.assertFalse(changed)
- @defer.inlineCallbacks
def test_get_key_with_device_name(self):
now = 1470174257070
json = {"key": "value"}
- yield defer.ensureDeferred(
- self.store.set_e2e_device_keys("user", "device", now, json)
- )
- yield defer.ensureDeferred(
- self.store.store_device("user", "device", "display_name")
- )
+ self.get_success(self.store.set_e2e_device_keys("user", "device", now, json))
+ self.get_success(self.store.store_device("user", "device", "display_name"))
- res = yield defer.ensureDeferred(
+ res = self.get_success(
self.store.get_e2e_device_keys_for_cs_api((("user", "device"),))
)
self.assertIn("user", res)
@@ -85,29 +71,28 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
{"key": "value", "unsigned": {"device_display_name": "display_name"}}, dev
)
- @defer.inlineCallbacks
def test_multiple_devices(self):
now = 1470174257070
- yield defer.ensureDeferred(self.store.store_device("user1", "device1", None))
- yield defer.ensureDeferred(self.store.store_device("user1", "device2", None))
- yield defer.ensureDeferred(self.store.store_device("user2", "device1", None))
- yield defer.ensureDeferred(self.store.store_device("user2", "device2", None))
+ self.get_success(self.store.store_device("user1", "device1", None))
+ self.get_success(self.store.store_device("user1", "device2", None))
+ self.get_success(self.store.store_device("user2", "device1", None))
+ self.get_success(self.store.store_device("user2", "device2", None))
- yield defer.ensureDeferred(
+ self.get_success(
self.store.set_e2e_device_keys("user1", "device1", now, {"key": "json11"})
)
- yield defer.ensureDeferred(
+ self.get_success(
self.store.set_e2e_device_keys("user1", "device2", now, {"key": "json12"})
)
- yield defer.ensureDeferred(
+ self.get_success(
self.store.set_e2e_device_keys("user2", "device1", now, {"key": "json21"})
)
- yield defer.ensureDeferred(
+ self.get_success(
self.store.set_e2e_device_keys("user2", "device2", now, {"key": "json22"})
)
- res = yield defer.ensureDeferred(
+ res = self.get_success(
self.store.get_e2e_device_keys_for_cs_api(
(("user1", "device1"), ("user2", "device2"))
)
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index 485f1ee033..239f7c9faf 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2016 OpenMarket Ltd
+# Copyright 2016-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,10 +15,7 @@
from mock import Mock
-from twisted.internet import defer
-
-import tests.unittest
-import tests.utils
+from tests.unittest import HomeserverTestCase
USER_ID = "@user:example.com"
@@ -30,37 +27,31 @@ HIGHLIGHT = [
]
-class EventPushActionsStoreTestCase(tests.unittest.TestCase):
- @defer.inlineCallbacks
- def setUp(self):
- hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
+class EventPushActionsStoreTestCase(HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.persist_events_store = hs.get_datastores().persist_events
- @defer.inlineCallbacks
def test_get_unread_push_actions_for_user_in_range_for_http(self):
- yield defer.ensureDeferred(
+ self.get_success(
self.store.get_unread_push_actions_for_user_in_range_for_http(
USER_ID, 0, 1000, 20
)
)
- @defer.inlineCallbacks
def test_get_unread_push_actions_for_user_in_range_for_email(self):
- yield defer.ensureDeferred(
+ self.get_success(
self.store.get_unread_push_actions_for_user_in_range_for_email(
USER_ID, 0, 1000, 20
)
)
- @defer.inlineCallbacks
def test_count_aggregation(self):
room_id = "!foo:example.com"
user_id = "@user1235:example.com"
- @defer.inlineCallbacks
def _assert_counts(noitf_count, highlight_count):
- counts = yield defer.ensureDeferred(
+ counts = self.get_success(
self.store.db_pool.runInteraction(
"", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0
)
@@ -74,7 +65,6 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
},
)
- @defer.inlineCallbacks
def _inject_actions(stream, action):
event = Mock()
event.room_id = room_id
@@ -82,14 +72,14 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
event.internal_metadata.stream_ordering = stream
event.depth = stream
- yield defer.ensureDeferred(
+ self.get_success(
self.store.add_push_actions_to_staging(
event.event_id,
{user_id: action},
False,
)
)
- yield defer.ensureDeferred(
+ self.get_success(
self.store.db_pool.runInteraction(
"",
self.persist_events_store._set_push_actions_for_event_and_users_txn,
@@ -99,14 +89,14 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
)
def _rotate(stream):
- return defer.ensureDeferred(
+ self.get_success(
self.store.db_pool.runInteraction(
"", self.store._rotate_notifs_before_txn, stream
)
)
def _mark_read(stream, depth):
- return defer.ensureDeferred(
+ self.get_success(
self.store.db_pool.runInteraction(
"",
self.store._remove_old_push_actions_before_txn,
@@ -116,49 +106,48 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
)
)
- yield _assert_counts(0, 0)
- yield _inject_actions(1, PlAIN_NOTIF)
- yield _assert_counts(1, 0)
- yield _rotate(2)
- yield _assert_counts(1, 0)
+ _assert_counts(0, 0)
+ _inject_actions(1, PlAIN_NOTIF)
+ _assert_counts(1, 0)
+ _rotate(2)
+ _assert_counts(1, 0)
- yield _inject_actions(3, PlAIN_NOTIF)
- yield _assert_counts(2, 0)
- yield _rotate(4)
- yield _assert_counts(2, 0)
+ _inject_actions(3, PlAIN_NOTIF)
+ _assert_counts(2, 0)
+ _rotate(4)
+ _assert_counts(2, 0)
- yield _inject_actions(5, PlAIN_NOTIF)
- yield _mark_read(3, 3)
- yield _assert_counts(1, 0)
+ _inject_actions(5, PlAIN_NOTIF)
+ _mark_read(3, 3)
+ _assert_counts(1, 0)
- yield _mark_read(5, 5)
- yield _assert_counts(0, 0)
+ _mark_read(5, 5)
+ _assert_counts(0, 0)
- yield _inject_actions(6, PlAIN_NOTIF)
- yield _rotate(7)
+ _inject_actions(6, PlAIN_NOTIF)
+ _rotate(7)
- yield defer.ensureDeferred(
+ self.get_success(
self.store.db_pool.simple_delete(
table="event_push_actions", keyvalues={"1": 1}, desc=""
)
)
- yield _assert_counts(1, 0)
+ _assert_counts(1, 0)
- yield _mark_read(7, 7)
- yield _assert_counts(0, 0)
+ _mark_read(7, 7)
+ _assert_counts(0, 0)
- yield _inject_actions(8, HIGHLIGHT)
- yield _assert_counts(1, 1)
- yield _rotate(9)
- yield _assert_counts(1, 1)
- yield _rotate(10)
- yield _assert_counts(1, 1)
+ _inject_actions(8, HIGHLIGHT)
+ _assert_counts(1, 1)
+ _rotate(9)
+ _assert_counts(1, 1)
+ _rotate(10)
+ _assert_counts(1, 1)
- @defer.inlineCallbacks
def test_find_first_stream_ordering_after_ts(self):
def add_event(so, ts):
- return defer.ensureDeferred(
+ self.get_success(
self.store.db_pool.simple_insert(
"events",
{
@@ -177,24 +166,16 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
)
# start with the base case where there are no events in the table
- r = yield defer.ensureDeferred(
- self.store.find_first_stream_ordering_after_ts(11)
- )
+ r = self.get_success(self.store.find_first_stream_ordering_after_ts(11))
self.assertEqual(r, 0)
# now with one event
- yield add_event(2, 10)
- r = yield defer.ensureDeferred(
- self.store.find_first_stream_ordering_after_ts(9)
- )
+ add_event(2, 10)
+ r = self.get_success(self.store.find_first_stream_ordering_after_ts(9))
self.assertEqual(r, 2)
- r = yield defer.ensureDeferred(
- self.store.find_first_stream_ordering_after_ts(10)
- )
+ r = self.get_success(self.store.find_first_stream_ordering_after_ts(10))
self.assertEqual(r, 2)
- r = yield defer.ensureDeferred(
- self.store.find_first_stream_ordering_after_ts(11)
- )
+ r = self.get_success(self.store.find_first_stream_ordering_after_ts(11))
self.assertEqual(r, 3)
# add a bunch of dummy events to the events table
@@ -205,39 +186,27 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
(10, 130),
(20, 140),
):
- yield add_event(stream_ordering, ts)
+ add_event(stream_ordering, ts)
- r = yield defer.ensureDeferred(
- self.store.find_first_stream_ordering_after_ts(110)
- )
+ r = self.get_success(self.store.find_first_stream_ordering_after_ts(110))
self.assertEqual(r, 3, "First event after 110ms should be 3, was %i" % r)
# 4 and 5 are both after 120: we want 4 rather than 5
- r = yield defer.ensureDeferred(
- self.store.find_first_stream_ordering_after_ts(120)
- )
+ r = self.get_success(self.store.find_first_stream_ordering_after_ts(120))
self.assertEqual(r, 4, "First event after 120ms should be 4, was %i" % r)
- r = yield defer.ensureDeferred(
- self.store.find_first_stream_ordering_after_ts(129)
- )
+ r = self.get_success(self.store.find_first_stream_ordering_after_ts(129))
self.assertEqual(r, 10, "First event after 129ms should be 10, was %i" % r)
# check we can get the last event
- r = yield defer.ensureDeferred(
- self.store.find_first_stream_ordering_after_ts(140)
- )
+ r = self.get_success(self.store.find_first_stream_ordering_after_ts(140))
self.assertEqual(r, 20, "First event after 14ms should be 20, was %i" % r)
# off the end
- r = yield defer.ensureDeferred(
- self.store.find_first_stream_ordering_after_ts(160)
- )
+ r = self.get_success(self.store.find_first_stream_ordering_after_ts(160))
self.assertEqual(r, 21)
# check we can find an event at ordering zero
- yield add_event(0, 5)
- r = yield defer.ensureDeferred(
- self.store.find_first_stream_ordering_after_ts(1)
- )
+ add_event(0, 5)
+ r = self.get_success(self.store.find_first_stream_ordering_after_ts(1))
self.assertEqual(r, 0)
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index ea63bd56b4..d18ceb41a9 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,59 +13,50 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from twisted.internet import defer
-
from synapse.types import UserID
from tests import unittest
-from tests.utils import setup_test_homeserver
-
-class ProfileStoreTestCase(unittest.TestCase):
- @defer.inlineCallbacks
- def setUp(self):
- hs = yield setup_test_homeserver(self.addCleanup)
+class ProfileStoreTestCase(unittest.HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.u_frank = UserID.from_string("@frank:test")
- @defer.inlineCallbacks
def test_displayname(self):
- yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
+ self.get_success(self.store.create_profile(self.u_frank.localpart))
- yield defer.ensureDeferred(
+ self.get_success(
self.store.set_profile_displayname(self.u_frank.localpart, "Frank")
)
self.assertEquals(
"Frank",
(
- yield defer.ensureDeferred(
+ self.get_success(
self.store.get_profile_displayname(self.u_frank.localpart)
)
),
)
# test set to None
- yield defer.ensureDeferred(
+ self.get_success(
self.store.set_profile_displayname(self.u_frank.localpart, None)
)
self.assertIsNone(
(
- yield defer.ensureDeferred(
+ self.get_success(
self.store.get_profile_displayname(self.u_frank.localpart)
)
)
)
- @defer.inlineCallbacks
def test_avatar_url(self):
- yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
+ self.get_success(self.store.create_profile(self.u_frank.localpart))
- yield defer.ensureDeferred(
+ self.get_success(
self.store.set_profile_avatar_url(
self.u_frank.localpart, "http://my.site/here"
)
@@ -74,20 +65,20 @@ class ProfileStoreTestCase(unittest.TestCase):
self.assertEquals(
"http://my.site/here",
(
- yield defer.ensureDeferred(
+ self.get_success(
self.store.get_profile_avatar_url(self.u_frank.localpart)
)
),
)
# test set to None
- yield defer.ensureDeferred(
+ self.get_success(
self.store.set_profile_avatar_url(self.u_frank.localpart, None)
)
self.assertIsNone(
(
- yield defer.ensureDeferred(
+ self.get_success(
self.store.get_profile_avatar_url(self.u_frank.localpart)
)
)
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index b2a0e60856..2622207639 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -1,6 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,8 +15,6 @@
from canonicaljson import json
-from twisted.internet import defer
-
from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions
from synapse.types import RoomID, UserID
@@ -230,10 +227,9 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self._base_builder = base_builder
self._event_id = event_id
- @defer.inlineCallbacks
- def build(self, prev_event_ids, auth_event_ids):
- built_event = yield defer.ensureDeferred(
- self._base_builder.build(prev_event_ids, auth_event_ids)
+ async def build(self, prev_event_ids, auth_event_ids):
+ built_event = await self._base_builder.build(
+ prev_event_ids, auth_event_ids
)
built_event._event_id = self._event_id
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 4eb41c46e8..c82cf15bc2 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,21 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from twisted.internet import defer
-
from synapse.api.constants import UserTypes
from synapse.api.errors import ThreepidValidationError
-from tests import unittest
-from tests.utils import setup_test_homeserver
-
+from tests.unittest import HomeserverTestCase
-class RegistrationStoreTestCase(unittest.TestCase):
- @defer.inlineCallbacks
- def setUp(self):
- hs = yield setup_test_homeserver(self.addCleanup)
+class RegistrationStoreTestCase(HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.user_id = "@my-user:test"
@@ -35,9 +28,8 @@ class RegistrationStoreTestCase(unittest.TestCase):
self.pwhash = "{xx1}123456789"
self.device_id = "akgjhdjklgshg"
- @defer.inlineCallbacks
def test_register(self):
- yield defer.ensureDeferred(self.store.register_user(self.user_id, self.pwhash))
+ self.get_success(self.store.register_user(self.user_id, self.pwhash))
self.assertEquals(
{
@@ -49,93 +41,81 @@ class RegistrationStoreTestCase(unittest.TestCase):
"consent_version": None,
"consent_server_notice_sent": None,
"appservice_id": None,
- "creation_ts": 1000,
+ "creation_ts": 0,
"user_type": None,
"deactivated": 0,
"shadow_banned": 0,
},
- (yield defer.ensureDeferred(self.store.get_user_by_id(self.user_id))),
+ (self.get_success(self.store.get_user_by_id(self.user_id))),
)
- @defer.inlineCallbacks
def test_add_tokens(self):
- yield defer.ensureDeferred(self.store.register_user(self.user_id, self.pwhash))
- yield defer.ensureDeferred(
+ self.get_success(self.store.register_user(self.user_id, self.pwhash))
+ self.get_success(
self.store.add_access_token_to_user(
self.user_id, self.tokens[1], self.device_id, valid_until_ms=None
)
)
- result = yield defer.ensureDeferred(
- self.store.get_user_by_access_token(self.tokens[1])
- )
+ result = self.get_success(self.store.get_user_by_access_token(self.tokens[1]))
self.assertEqual(result.user_id, self.user_id)
self.assertEqual(result.device_id, self.device_id)
self.assertIsNotNone(result.token_id)
- @defer.inlineCallbacks
def test_user_delete_access_tokens(self):
# add some tokens
- yield defer.ensureDeferred(self.store.register_user(self.user_id, self.pwhash))
- yield defer.ensureDeferred(
+ self.get_success(self.store.register_user(self.user_id, self.pwhash))
+ self.get_success(
self.store.add_access_token_to_user(
self.user_id, self.tokens[0], device_id=None, valid_until_ms=None
)
)
- yield defer.ensureDeferred(
+ self.get_success(
self.store.add_access_token_to_user(
self.user_id, self.tokens[1], self.device_id, valid_until_ms=None
)
)
# now delete some
- yield defer.ensureDeferred(
+ self.get_success(
self.store.user_delete_access_tokens(self.user_id, device_id=self.device_id)
)
# check they were deleted
- user = yield defer.ensureDeferred(
- self.store.get_user_by_access_token(self.tokens[1])
- )
+ user = self.get_success(self.store.get_user_by_access_token(self.tokens[1]))
self.assertIsNone(user, "access token was not deleted by device_id")
# check the one not associated with the device was not deleted
- user = yield defer.ensureDeferred(
- self.store.get_user_by_access_token(self.tokens[0])
- )
+ user = self.get_success(self.store.get_user_by_access_token(self.tokens[0]))
self.assertEqual(self.user_id, user.user_id)
# now delete the rest
- yield defer.ensureDeferred(self.store.user_delete_access_tokens(self.user_id))
+ self.get_success(self.store.user_delete_access_tokens(self.user_id))
- user = yield defer.ensureDeferred(
- self.store.get_user_by_access_token(self.tokens[0])
- )
+ user = self.get_success(self.store.get_user_by_access_token(self.tokens[0]))
self.assertIsNone(user, "access token was not deleted without device_id")
- @defer.inlineCallbacks
def test_is_support_user(self):
TEST_USER = "@test:test"
SUPPORT_USER = "@support:test"
- res = yield defer.ensureDeferred(self.store.is_support_user(None))
+ res = self.get_success(self.store.is_support_user(None))
self.assertFalse(res)
- yield defer.ensureDeferred(
+ self.get_success(
self.store.register_user(user_id=TEST_USER, password_hash=None)
)
- res = yield defer.ensureDeferred(self.store.is_support_user(TEST_USER))
+ res = self.get_success(self.store.is_support_user(TEST_USER))
self.assertFalse(res)
- yield defer.ensureDeferred(
+ self.get_success(
self.store.register_user(
user_id=SUPPORT_USER, password_hash=None, user_type=UserTypes.SUPPORT
)
)
- res = yield defer.ensureDeferred(self.store.is_support_user(SUPPORT_USER))
+ res = self.get_success(self.store.is_support_user(SUPPORT_USER))
self.assertTrue(res)
- @defer.inlineCallbacks
def test_3pid_inhibit_invalid_validation_session_error(self):
"""Tests that enabling the configuration option to inhibit 3PID errors on
/requestToken also inhibits validation errors caused by an unknown session ID.
@@ -143,30 +123,28 @@ class RegistrationStoreTestCase(unittest.TestCase):
# Check that, with the config setting set to false (the default value), a
# validation error is caused by the unknown session ID.
- try:
- yield defer.ensureDeferred(
- self.store.validate_threepid_session(
- "fake_sid",
- "fake_client_secret",
- "fake_token",
- 0,
- )
- )
- except ThreepidValidationError as e:
- self.assertEquals(e.msg, "Unknown session_id", e)
+ e = self.get_failure(
+ self.store.validate_threepid_session(
+ "fake_sid",
+ "fake_client_secret",
+ "fake_token",
+ 0,
+ ),
+ ThreepidValidationError,
+ )
+ self.assertEquals(e.value.msg, "Unknown session_id", e)
# Set the config setting to true.
self.store._ignore_unknown_session_error = True
# Check that now the validation error is caused by the token not matching.
- try:
- yield defer.ensureDeferred(
- self.store.validate_threepid_session(
- "fake_sid",
- "fake_client_secret",
- "fake_token",
- 0,
- )
- )
- except ThreepidValidationError as e:
- self.assertEquals(e.msg, "Validation token not found or has expired", e)
+ e = self.get_failure(
+ self.store.validate_threepid_session(
+ "fake_sid",
+ "fake_client_secret",
+ "fake_token",
+ 0,
+ ),
+ ThreepidValidationError,
+ )
+ self.assertEquals(e.value.msg, "Validation token not found or has expired", e)
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index bc8400f240..0089d33c93 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,22 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from twisted.internet import defer
-
from synapse.api.constants import EventTypes
from synapse.api.room_versions import RoomVersions
from synapse.types import RoomAlias, RoomID, UserID
-from tests import unittest
-from tests.utils import setup_test_homeserver
-
+from tests.unittest import HomeserverTestCase
-class RoomStoreTestCase(unittest.TestCase):
- @defer.inlineCallbacks
- def setUp(self):
- hs = yield setup_test_homeserver(self.addCleanup)
+class RoomStoreTestCase(HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
# We can't test RoomStore on its own without the DirectoryStore, for
# management of the 'room_aliases' table
self.store = hs.get_datastore()
@@ -37,7 +30,7 @@ class RoomStoreTestCase(unittest.TestCase):
self.alias = RoomAlias.from_string("#a-room-name:test")
self.u_creator = UserID.from_string("@creator:test")
- yield defer.ensureDeferred(
+ self.get_success(
self.store.store_room(
self.room.to_string(),
room_creator_user_id=self.u_creator.to_string(),
@@ -46,7 +39,6 @@ class RoomStoreTestCase(unittest.TestCase):
)
)
- @defer.inlineCallbacks
def test_get_room(self):
self.assertDictContainsSubset(
{
@@ -54,16 +46,12 @@ class RoomStoreTestCase(unittest.TestCase):
"creator": self.u_creator.to_string(),
"is_public": True,
},
- (yield defer.ensureDeferred(self.store.get_room(self.room.to_string()))),
+ (self.get_success(self.store.get_room(self.room.to_string()))),
)
- @defer.inlineCallbacks
def test_get_room_unknown_room(self):
- self.assertIsNone(
- (yield defer.ensureDeferred(self.store.get_room("!uknown:test")))
- )
+ self.assertIsNone((self.get_success(self.store.get_room("!uknown:test"))))
- @defer.inlineCallbacks
def test_get_room_with_stats(self):
self.assertDictContainsSubset(
{
@@ -71,29 +59,17 @@ class RoomStoreTestCase(unittest.TestCase):
"creator": self.u_creator.to_string(),
"public": True,
},
- (
- yield defer.ensureDeferred(
- self.store.get_room_with_stats(self.room.to_string())
- )
- ),
+ (self.get_success(self.store.get_room_with_stats(self.room.to_string()))),
)
- @defer.inlineCallbacks
def test_get_room_with_stats_unknown_room(self):
self.assertIsNone(
- (
- yield defer.ensureDeferred(
- self.store.get_room_with_stats("!uknown:test")
- )
- ),
+ (self.get_success(self.store.get_room_with_stats("!uknown:test"))),
)
-class RoomEventsStoreTestCase(unittest.TestCase):
- @defer.inlineCallbacks
- def setUp(self):
- hs = setup_test_homeserver(self.addCleanup)
-
+class RoomEventsStoreTestCase(HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
# Room events need the full datastore, for persist_event() and
# get_room_state()
self.store = hs.get_datastore()
@@ -102,7 +78,7 @@ class RoomEventsStoreTestCase(unittest.TestCase):
self.room = RoomID.from_string("!abcde:test")
- yield defer.ensureDeferred(
+ self.get_success(
self.store.store_room(
self.room.to_string(),
room_creator_user_id="@creator:text",
@@ -111,23 +87,21 @@ class RoomEventsStoreTestCase(unittest.TestCase):
)
)
- @defer.inlineCallbacks
def inject_room_event(self, **kwargs):
- yield defer.ensureDeferred(
+ self.get_success(
self.storage.persistence.persist_event(
self.event_factory.create_event(room_id=self.room.to_string(), **kwargs)
)
)
- @defer.inlineCallbacks
def STALE_test_room_name(self):
name = "A-Room-Name"
- yield self.inject_room_event(
+ self.inject_room_event(
etype=EventTypes.Name, name=name, content={"name": name}, depth=1
)
- state = yield defer.ensureDeferred(
+ state = self.get_success(
self.store.get_current_state(room_id=self.room.to_string())
)
@@ -137,15 +111,14 @@ class RoomEventsStoreTestCase(unittest.TestCase):
state[0],
)
- @defer.inlineCallbacks
def STALE_test_room_topic(self):
topic = "A place for things"
- yield self.inject_room_event(
+ self.inject_room_event(
etype=EventTypes.Topic, topic=topic, content={"topic": topic}, depth=1
)
- state = yield defer.ensureDeferred(
+ state = self.get_success(
self.store.get_current_state(room_id=self.room.to_string())
)
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 2471f1267d..f06b452fa9 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2018 New Vector Ltd
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,24 +15,18 @@
import logging
-from twisted.internet import defer
-
from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions
from synapse.storage.state import StateFilter
from synapse.types import RoomID, UserID
-import tests.unittest
-import tests.utils
+from tests.unittest import HomeserverTestCase
logger = logging.getLogger(__name__)
-class StateStoreTestCase(tests.unittest.TestCase):
- @defer.inlineCallbacks
- def setUp(self):
- hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
-
+class StateStoreTestCase(HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.storage = hs.get_storage()
self.state_datastore = self.storage.state.stores.state
@@ -44,7 +38,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.room = RoomID.from_string("!abc123:test")
- yield defer.ensureDeferred(
+ self.get_success(
self.store.store_room(
self.room.to_string(),
room_creator_user_id="@creator:text",
@@ -53,7 +47,6 @@ class StateStoreTestCase(tests.unittest.TestCase):
)
)
- @defer.inlineCallbacks
def inject_state_event(self, room, sender, typ, state_key, content):
builder = self.event_builder_factory.for_room_version(
RoomVersions.V1,
@@ -66,13 +59,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
},
)
- event, context = yield defer.ensureDeferred(
+ event, context = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
- yield defer.ensureDeferred(
- self.storage.persistence.persist_event(event, context)
- )
+ self.get_success(self.storage.persistence.persist_event(event, context))
return event
@@ -82,16 +73,13 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertEqual(s1[t].event_id, s2[t].event_id)
self.assertEqual(len(s1), len(s2))
- @defer.inlineCallbacks
def test_get_state_groups_ids(self):
- e1 = yield self.inject_state_event(
- self.room, self.u_alice, EventTypes.Create, "", {}
- )
- e2 = yield self.inject_state_event(
+ e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
+ e2 = self.inject_state_event(
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
)
- state_group_map = yield defer.ensureDeferred(
+ state_group_map = self.get_success(
self.storage.state.get_state_groups_ids(self.room, [e2.event_id])
)
self.assertEqual(len(state_group_map), 1)
@@ -101,16 +89,13 @@ class StateStoreTestCase(tests.unittest.TestCase):
{(EventTypes.Create, ""): e1.event_id, (EventTypes.Name, ""): e2.event_id},
)
- @defer.inlineCallbacks
def test_get_state_groups(self):
- e1 = yield self.inject_state_event(
- self.room, self.u_alice, EventTypes.Create, "", {}
- )
- e2 = yield self.inject_state_event(
+ e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
+ e2 = self.inject_state_event(
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
)
- state_group_map = yield defer.ensureDeferred(
+ state_group_map = self.get_success(
self.storage.state.get_state_groups(self.room, [e2.event_id])
)
self.assertEqual(len(state_group_map), 1)
@@ -118,32 +103,29 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertEqual({ev.event_id for ev in state_list}, {e1.event_id, e2.event_id})
- @defer.inlineCallbacks
def test_get_state_for_event(self):
# this defaults to a linear DAG as each new injection defaults to whatever
# forward extremities are currently in the DB for this room.
- e1 = yield self.inject_state_event(
- self.room, self.u_alice, EventTypes.Create, "", {}
- )
- e2 = yield self.inject_state_event(
+ e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
+ e2 = self.inject_state_event(
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
)
- e3 = yield self.inject_state_event(
+ e3 = self.inject_state_event(
self.room,
self.u_alice,
EventTypes.Member,
self.u_alice.to_string(),
{"membership": Membership.JOIN},
)
- e4 = yield self.inject_state_event(
+ e4 = self.inject_state_event(
self.room,
self.u_bob,
EventTypes.Member,
self.u_bob.to_string(),
{"membership": Membership.JOIN},
)
- e5 = yield self.inject_state_event(
+ e5 = self.inject_state_event(
self.room,
self.u_bob,
EventTypes.Member,
@@ -152,9 +134,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
)
# check we get the full state as of the final event
- state = yield defer.ensureDeferred(
- self.storage.state.get_state_for_event(e5.event_id)
- )
+ state = self.get_success(self.storage.state.get_state_for_event(e5.event_id))
self.assertIsNotNone(e4)
@@ -170,7 +150,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
)
# check we can filter to the m.room.name event (with a '' state key)
- state = yield defer.ensureDeferred(
+ state = self.get_success(
self.storage.state.get_state_for_event(
e5.event_id, StateFilter.from_types([(EventTypes.Name, "")])
)
@@ -179,7 +159,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
# check we can filter to the m.room.name event (with a wildcard None state key)
- state = yield defer.ensureDeferred(
+ state = self.get_success(
self.storage.state.get_state_for_event(
e5.event_id, StateFilter.from_types([(EventTypes.Name, None)])
)
@@ -188,7 +168,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
# check we can grab the m.room.member events (with a wildcard None state key)
- state = yield defer.ensureDeferred(
+ state = self.get_success(
self.storage.state.get_state_for_event(
e5.event_id, StateFilter.from_types([(EventTypes.Member, None)])
)
@@ -200,7 +180,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# check we can grab a specific room member without filtering out the
# other event types
- state = yield defer.ensureDeferred(
+ state = self.get_success(
self.storage.state.get_state_for_event(
e5.event_id,
state_filter=StateFilter(
@@ -220,7 +200,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
)
# check that we can grab everything except members
- state = yield defer.ensureDeferred(
+ state = self.get_success(
self.storage.state.get_state_for_event(
e5.event_id,
state_filter=StateFilter(
@@ -238,17 +218,14 @@ class StateStoreTestCase(tests.unittest.TestCase):
#######################################################
room_id = self.room.to_string()
- group_ids = yield defer.ensureDeferred(
+ group_ids = self.get_success(
self.storage.state.get_state_groups_ids(room_id, [e5.event_id])
)
group = list(group_ids.keys())[0]
# test _get_state_for_group_using_cache correctly filters out members
# with types=[]
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
@@ -265,10 +242,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
state_dict,
)
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
@@ -281,10 +255,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members
# with wildcard types
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
@@ -301,10 +272,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
state_dict,
)
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
@@ -324,10 +292,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members
# with specific types
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
@@ -344,10 +309,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
state_dict,
)
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
@@ -360,10 +322,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members
# with specific types
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
@@ -413,10 +372,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters out members
# with types=[]
room_id = self.room.to_string()
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
@@ -428,10 +384,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
room_id = self.room.to_string()
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
@@ -444,10 +397,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members
# wildcard types
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
@@ -458,10 +408,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertEqual(is_all, False)
self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
@@ -480,10 +427,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members
# with specific types
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
@@ -494,10 +438,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertEqual(is_all, False)
self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
@@ -510,10 +451,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_state_for_group_using_cache correctly filters in members
# with specific types
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
@@ -524,10 +462,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertEqual(is_all, False)
self.assertDictEqual({}, state_dict)
- (
- state_dict,
- is_all,
- ) = yield self.state_datastore._get_state_for_group_using_cache(
+ (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index a6f63f4aaf..019c5b7b14 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2018 New Vector Ltd
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,10 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
-
-from tests import unittest
-from tests.utils import setup_test_homeserver
+from tests.unittest import HomeserverTestCase, override_config
ALICE = "@alice:a"
BOB = "@bob:b"
@@ -25,73 +22,52 @@ BOBBY = "@bobby:a"
BELA = "@somenickname:a"
-class UserDirectoryStoreTestCase(unittest.TestCase):
- @defer.inlineCallbacks
- def setUp(self):
- self.hs = yield setup_test_homeserver(self.addCleanup)
- self.store = self.hs.get_datastore()
+class UserDirectoryStoreTestCase(HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
# alice and bob are both in !room_id. bobby is not but shares
# a homeserver with alice.
- yield defer.ensureDeferred(
- self.store.update_profile_in_user_dir(ALICE, "alice", None)
- )
- yield defer.ensureDeferred(
- self.store.update_profile_in_user_dir(BOB, "bob", None)
- )
- yield defer.ensureDeferred(
- self.store.update_profile_in_user_dir(BOBBY, "bobby", None)
- )
- yield defer.ensureDeferred(
- self.store.update_profile_in_user_dir(BELA, "Bela", None)
- )
- yield defer.ensureDeferred(
- self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB))
- )
+ self.get_success(self.store.update_profile_in_user_dir(ALICE, "alice", None))
+ self.get_success(self.store.update_profile_in_user_dir(BOB, "bob", None))
+ self.get_success(self.store.update_profile_in_user_dir(BOBBY, "bobby", None))
+ self.get_success(self.store.update_profile_in_user_dir(BELA, "Bela", None))
+ self.get_success(self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB)))
- @defer.inlineCallbacks
def test_search_user_dir(self):
# normally when alice searches the directory she should just find
# bob because bobby doesn't share a room with her.
- r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "bob", 10))
+ r = self.get_success(self.store.search_user_dir(ALICE, "bob", 10))
self.assertFalse(r["limited"])
self.assertEqual(1, len(r["results"]))
self.assertDictEqual(
r["results"][0], {"user_id": BOB, "display_name": "bob", "avatar_url": None}
)
- @defer.inlineCallbacks
+ @override_config({"user_directory": {"search_all_users": True}})
def test_search_user_dir_all_users(self):
- self.hs.config.user_directory_search_all_users = True
- try:
- r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "bob", 10))
- self.assertFalse(r["limited"])
- self.assertEqual(2, len(r["results"]))
- self.assertDictEqual(
- r["results"][0],
- {"user_id": BOB, "display_name": "bob", "avatar_url": None},
- )
- self.assertDictEqual(
- r["results"][1],
- {"user_id": BOBBY, "display_name": "bobby", "avatar_url": None},
- )
- finally:
- self.hs.config.user_directory_search_all_users = False
+ r = self.get_success(self.store.search_user_dir(ALICE, "bob", 10))
+ self.assertFalse(r["limited"])
+ self.assertEqual(2, len(r["results"]))
+ self.assertDictEqual(
+ r["results"][0],
+ {"user_id": BOB, "display_name": "bob", "avatar_url": None},
+ )
+ self.assertDictEqual(
+ r["results"][1],
+ {"user_id": BOBBY, "display_name": "bobby", "avatar_url": None},
+ )
- @defer.inlineCallbacks
+ @override_config({"user_directory": {"search_all_users": True}})
def test_search_user_dir_stop_words(self):
"""Tests that a user can look up another user by searching for the start if its
display name even if that name happens to be a common English word that would
usually be ignored in full text searches.
"""
- self.hs.config.user_directory_search_all_users = True
- try:
- r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "be", 10))
- self.assertFalse(r["limited"])
- self.assertEqual(1, len(r["results"]))
- self.assertDictEqual(
- r["results"][0],
- {"user_id": BELA, "display_name": "Bela", "avatar_url": None},
- )
- finally:
- self.hs.config.user_directory_search_all_users = False
+ r = self.get_success(self.store.search_user_dir(ALICE, "be", 10))
+ self.assertFalse(r["limited"])
+ self.assertEqual(1, len(r["results"]))
+ self.assertDictEqual(
+ r["results"][0],
+ {"user_id": BELA, "display_name": "Bela", "avatar_url": None},
+ )
diff --git a/tests/unittest.py b/tests/unittest.py
index 58a4daa1ec..57b6a395c7 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -471,7 +471,7 @@ class HomeserverTestCase(TestCase):
kwargs["config"] = config_obj
async def run_bg_updates():
- with LoggingContext("run_bg_updates", request="run_bg_updates-1"):
+ with LoggingContext("run_bg_updates"):
while not await stor.db_pool.updates.has_completed_background_updates():
await stor.db_pool.updates.do_next_background_update(1)
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index afb11b9caf..e434e21aee 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -661,14 +661,13 @@ class CachedListDescriptorTestCase(unittest.TestCase):
@descriptors.cachedList("fn", "args1")
async def list_fn(self, args1, arg2):
- assert current_context().request == "c1"
+ assert current_context().name == "c1"
# we want this to behave like an asynchronous function
await run_on_reactor()
- assert current_context().request == "c1"
+ assert current_context().name == "c1"
return self.mock(args1, arg2)
- with LoggingContext() as c1:
- c1.request = "c1"
+ with LoggingContext("c1") as c1:
obj = Cls()
obj.mock.return_value = {10: "fish", 20: "chips"}
d1 = obj.list_fn([10, 20], 2)
diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py
index 58ee918f65..5d9c4665aa 100644
--- a/tests/util/test_logcontext.py
+++ b/tests/util/test_logcontext.py
@@ -17,11 +17,10 @@ from .. import unittest
class LoggingContextTestCase(unittest.TestCase):
def _check_test_key(self, value):
- self.assertEquals(current_context().request, value)
+ self.assertEquals(current_context().name, value)
def test_with_context(self):
- with LoggingContext() as context_one:
- context_one.request = "test"
+ with LoggingContext("test"):
self._check_test_key("test")
@defer.inlineCallbacks
@@ -30,15 +29,13 @@ class LoggingContextTestCase(unittest.TestCase):
@defer.inlineCallbacks
def competing_callback():
- with LoggingContext() as competing_context:
- competing_context.request = "competing"
+ with LoggingContext("competing"):
yield clock.sleep(0)
self._check_test_key("competing")
reactor.callLater(0, competing_callback)
- with LoggingContext() as context_one:
- context_one.request = "one"
+ with LoggingContext("one"):
yield clock.sleep(0)
self._check_test_key("one")
@@ -47,9 +44,7 @@ class LoggingContextTestCase(unittest.TestCase):
callback_completed = [False]
- with LoggingContext() as context_one:
- context_one.request = "one"
-
+ with LoggingContext("one"):
# fire off function, but don't wait on it.
d2 = run_in_background(function)
@@ -133,9 +128,7 @@ class LoggingContextTestCase(unittest.TestCase):
sentinel_context = current_context()
- with LoggingContext() as context_one:
- context_one.request = "one"
-
+ with LoggingContext("one"):
d1 = make_deferred_yieldable(blocking_function())
# make sure that the context was reset by make_deferred_yieldable
self.assertIs(current_context(), sentinel_context)
@@ -149,9 +142,7 @@ class LoggingContextTestCase(unittest.TestCase):
def test_make_deferred_yieldable_with_chained_deferreds(self):
sentinel_context = current_context()
- with LoggingContext() as context_one:
- context_one.request = "one"
-
+ with LoggingContext("one"):
d1 = make_deferred_yieldable(_chained_deferred_function())
# make sure that the context was reset by make_deferred_yieldable
self.assertIs(current_context(), sentinel_context)
@@ -166,9 +157,7 @@ class LoggingContextTestCase(unittest.TestCase):
"""Check that make_deferred_yieldable does the right thing when its
argument isn't actually a deferred"""
- with LoggingContext() as context_one:
- context_one.request = "one"
-
+ with LoggingContext("one"):
d1 = make_deferred_yieldable("bum")
self._check_test_key("one")
@@ -177,9 +166,9 @@ class LoggingContextTestCase(unittest.TestCase):
self._check_test_key("one")
def test_nested_logging_context(self):
- with LoggingContext(request="foo"):
+ with LoggingContext("foo"):
nested_context = nested_logging_context(suffix="bar")
- self.assertEqual(nested_context.request, "foo-bar")
+ self.assertEqual(nested_context.name, "foo-bar")
@defer.inlineCallbacks
def test_make_deferred_yieldable_with_await(self):
@@ -193,9 +182,7 @@ class LoggingContextTestCase(unittest.TestCase):
sentinel_context = current_context()
- with LoggingContext() as context_one:
- context_one.request = "one"
-
+ with LoggingContext("one"):
d1 = make_deferred_yieldable(blocking_function())
# make sure that the context was reset by make_deferred_yieldable
self.assertIs(current_context(), sentinel_context)
|