diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index ce7525e29c..ee48f9e546 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -15,15 +15,22 @@
# limitations under the License.
from typing import Optional
+from unittest import mock
from twisted.test.proto_helpers import MemoryReactor
+from synapse.api.constants import RoomEncryptionAlgorithms
from synapse.api.errors import NotFoundError, SynapseError
+from synapse.appservice import ApplicationService
from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN, DeviceHandler
from synapse.server import HomeServer
+from synapse.storage.databases.main.appservice import _make_exclusive_regex
+from synapse.types import JsonDict
from synapse.util import Clock
from tests import unittest
+from tests.test_utils import make_awaitable
+from tests.unittest import override_config
user1 = "@boris:aaa"
user2 = "@theresa:bbb"
@@ -31,7 +38,12 @@ user2 = "@theresa:bbb"
class DeviceTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
- hs = self.setup_test_homeserver("server", federation_http_client=None)
+ self.appservice_api = mock.Mock()
+ hs = self.setup_test_homeserver(
+ "server",
+ federation_http_client=None,
+ application_service_api=self.appservice_api,
+ )
handler = hs.get_device_handler()
assert isinstance(handler, DeviceHandler)
self.handler = handler
@@ -265,6 +277,127 @@ class DeviceTestCase(unittest.HomeserverTestCase):
)
self.reactor.advance(1000)
+ @override_config({"experimental_features": {"msc3984_appservice_key_query": True}})
+ def test_on_federation_query_user_devices_appservice(self) -> None:
+ """Test that querying of appservices for keys overrides responses from the database."""
+ local_user = "@boris:" + self.hs.hostname
+ device_1 = "abc"
+ device_2 = "def"
+ device_3 = "ghi"
+
+ # There are 3 devices:
+ #
+ # 1. One which is uploaded to the homeserver.
+ # 2. One which is uploaded to the homeserver, but a newer copy is returned
+ # by the appservice.
+ # 3. One which is only returned by the appservice.
+ device_key_1: JsonDict = {
+ "user_id": local_user,
+ "device_id": device_1,
+ "algorithms": [
+ "m.olm.curve25519-aes-sha2",
+ RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2,
+ ],
+ "keys": {
+ "ed25519:abc": "base64+ed25519+key",
+ "curve25519:abc": "base64+curve25519+key",
+ },
+ "signatures": {local_user: {"ed25519:abc": "base64+signature"}},
+ }
+ device_key_2a: JsonDict = {
+ "user_id": local_user,
+ "device_id": device_2,
+ "algorithms": [
+ "m.olm.curve25519-aes-sha2",
+ RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2,
+ ],
+ "keys": {
+ "ed25519:def": "base64+ed25519+key",
+ "curve25519:def": "base64+curve25519+key",
+ },
+ "signatures": {local_user: {"ed25519:def": "base64+signature"}},
+ }
+
+ device_key_2b: JsonDict = {
+ "user_id": local_user,
+ "device_id": device_2,
+ "algorithms": [
+ "m.olm.curve25519-aes-sha2",
+ RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2,
+ ],
+ # The device ID is the same (above), but the keys are different.
+ "keys": {
+ "ed25519:xyz": "base64+ed25519+key",
+ "curve25519:xyz": "base64+curve25519+key",
+ },
+ "signatures": {local_user: {"ed25519:xyz": "base64+signature"}},
+ }
+ device_key_3: JsonDict = {
+ "user_id": local_user,
+ "device_id": device_3,
+ "algorithms": [
+ "m.olm.curve25519-aes-sha2",
+ RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2,
+ ],
+ "keys": {
+ "ed25519:jkl": "base64+ed25519+key",
+ "curve25519:jkl": "base64+curve25519+key",
+ },
+ "signatures": {local_user: {"ed25519:jkl": "base64+signature"}},
+ }
+
+ # Upload keys for devices 1 & 2a.
+ e2e_keys_handler = self.hs.get_e2e_keys_handler()
+ self.get_success(
+ e2e_keys_handler.upload_keys_for_user(
+ local_user, device_1, {"device_keys": device_key_1}
+ )
+ )
+ self.get_success(
+ e2e_keys_handler.upload_keys_for_user(
+ local_user, device_2, {"device_keys": device_key_2a}
+ )
+ )
+
+ # Inject an appservice interested in this user.
+ appservice = ApplicationService(
+ token="i_am_an_app_service",
+ id="1234",
+ namespaces={"users": [{"regex": r"@boris:.+", "exclusive": True}]},
+ # Note: this user does not have to match the regex above
+ sender="@as_main:test",
+ )
+ self.hs.get_datastores().main.services_cache = [appservice]
+ self.hs.get_datastores().main.exclusive_user_regex = _make_exclusive_regex(
+ [appservice]
+ )
+
+ # Setup a response.
+ self.appservice_api.query_keys.return_value = make_awaitable(
+ {
+ "device_keys": {
+ local_user: {device_2: device_key_2b, device_3: device_key_3}
+ }
+ }
+ )
+
+ # Request all devices.
+ res = self.get_success(
+ self.handler.on_federation_query_user_devices(local_user)
+ )
+ self.assertIn("devices", res)
+ res_devices = res["devices"]
+ for device in res_devices:
+ device["keys"].pop("unsigned", None)
+ self.assertEqual(
+ res_devices,
+ [
+ {"device_id": device_1, "keys": device_key_1},
+ {"device_id": device_2, "keys": device_key_2b},
+ {"device_id": device_3, "keys": device_key_3},
+ ],
+ )
+
class DehydrationTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index aff1ec4758..73822b07a5 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -586,6 +586,19 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
d = self.store.is_support_user(user_id)
self.assertFalse(self.get_success(d))
+ def test_invalid_user_id(self) -> None:
+ invalid_user_id = "+abcd"
+ self.get_failure(
+ self.handler.register_user(localpart=invalid_user_id), SynapseError
+ )
+
+ @override_config({"experimental_features": {"msc4009_e164_mxids": True}})
+ def text_extended_user_ids(self) -> None:
+ """+ should be allowed according to MSC4009."""
+ valid_user_id = "+1234"
+ user_id = self.get_success(self.handler.register_user(localpart=valid_user_id))
+ self.assertEqual(user_id, valid_user_id)
+
def test_invalid_user_id_length(self) -> None:
invalid_user_id = "x" * 256
self.get_failure(
diff --git a/tests/handlers/test_room_member.py b/tests/handlers/test_room_member.py
index 6a38893b68..a444d822cd 100644
--- a/tests/handlers/test_room_member.py
+++ b/tests/handlers/test_room_member.py
@@ -333,6 +333,17 @@ class RoomMemberMasterHandlerTestCase(HomeserverTestCase):
self.get_success(self.store.is_locally_forgotten_room(self.room_id))
)
+ @override_config({"forget_rooms_on_leave": True})
+ def test_leave_and_auto_forget(self) -> None:
+ """Tests the `forget_rooms_on_leave` config option."""
+ self.helper.join(self.room_id, user=self.bob, tok=self.bob_token)
+
+ # alice is not the last room member that leaves and forgets the room
+ self.helper.leave(self.room_id, user=self.alice, tok=self.alice_token)
+ self.assertTrue(
+ self.get_success(self.store.did_forget(self.alice, self.room_id))
+ )
+
def test_leave_and_forget_last_user(self) -> None:
"""Tests that forget a room is successfully when the last user has left the room."""
diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index 99cec0836b..54f558742d 100644
--- a/tests/push/test_http.py
+++ b/tests/push/test_http.py
@@ -962,3 +962,40 @@ class HTTPPusherTests(HomeserverTestCase):
channel.json_body["pushers"][0]["org.matrix.msc3881.device_id"],
lookup_result.device_id,
)
+
+ @override_config({"push": {"jitter_delay": "10s"}})
+ def test_jitter(self) -> None:
+ """Tests that enabling jitter actually delays sending push."""
+ user_id, access_token = self._make_user_with_pusher("user")
+ other_user_id, other_access_token = self._make_user_with_pusher("otheruser")
+
+ room = self.helper.create_room_as(user_id, tok=access_token)
+ self.helper.join(room=room, user=other_user_id, tok=other_access_token)
+
+ # Send a message and check that it did not generate a push, as it should
+ # be delayed.
+ self.helper.send(room, body="Hi!", tok=other_access_token)
+ self.assertEqual(len(self.push_attempts), 0)
+
+ # Now advance time past the max jitter, and assert the message was sent.
+ self.reactor.advance(15)
+ self.assertEqual(len(self.push_attempts), 1)
+
+ self.push_attempts[0][0].callback({})
+
+ # Now we send a bunch of messages and assert that they were all sent
+ # within the 10s max delay.
+ for _ in range(10):
+ self.helper.send(room, body="Hi!", tok=other_access_token)
+
+ index = 1
+ for _ in range(11):
+ while len(self.push_attempts) > index:
+ self.push_attempts[index][0].callback({})
+ self.pump()
+ index += 1
+
+ self.reactor.advance(1)
+ self.pump()
+
+ self.assertEqual(len(self.push_attempts), 11)
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index 645a00b4b1..695e84357a 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -399,7 +399,7 @@ class ExperimentalFeaturesTestCase(unittest.HomeserverTestCase):
"PUT",
url,
content={
- "features": {"msc3026": True, "msc2654": True},
+ "features": {"msc3026": True, "msc3881": True},
},
access_token=self.admin_user_tok,
)
@@ -420,7 +420,7 @@ class ExperimentalFeaturesTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(
True,
- channel.json_body["features"]["msc2654"],
+ channel.json_body["features"]["msc3881"],
)
# test disabling a feature works
@@ -448,10 +448,6 @@ class ExperimentalFeaturesTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(
True,
- channel.json_body["features"]["msc2654"],
- )
- self.assertEqual(
- False,
channel.json_body["features"]["msc3881"],
)
self.assertEqual(
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index 753ecc8d16..e5ba5a9706 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -22,7 +22,9 @@ from synapse.api.errors import SynapseError
from synapse.api.room_versions import RoomVersion
from synapse.config.homeserver import HomeServerConfig
from synapse.events import EventBase
-from synapse.events.third_party_rules import load_legacy_third_party_event_rules
+from synapse.module_api.callbacks.third_party_event_rules_callbacks import (
+ load_legacy_third_party_event_rules,
+)
from synapse.rest import admin
from synapse.rest.client import account, login, profile, room
from synapse.server import HomeServer
@@ -146,7 +148,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
return ev.type != "foo.bar.forbidden", None
callback = Mock(spec=[], side_effect=check)
- self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [
+ self.hs.get_module_api_callbacks().third_party_event_rules._check_event_allowed_callbacks = [
callback
]
@@ -202,7 +204,9 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
) -> Tuple[bool, Optional[JsonDict]]:
raise NastyHackException(429, "message")
- self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check]
+ self.hs.get_module_api_callbacks().third_party_event_rules._check_event_allowed_callbacks = [
+ check
+ ]
# Make a request
channel = self.make_request(
@@ -229,7 +233,9 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
ev.content = {"x": "y"}
return True, None
- self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check]
+ self.hs.get_module_api_callbacks().third_party_event_rules._check_event_allowed_callbacks = [
+ check
+ ]
# now send the event
channel = self.make_request(
@@ -253,7 +259,9 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
d["content"] = {"x": "y"}
return True, d
- self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check]
+ self.hs.get_module_api_callbacks().third_party_event_rules._check_event_allowed_callbacks = [
+ check
+ ]
# now send the event
channel = self.make_request(
@@ -289,7 +297,9 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
}
return True, d
- self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check]
+ self.hs.get_module_api_callbacks().third_party_event_rules._check_event_allowed_callbacks = [
+ check
+ ]
# Send an event, then edit it.
channel = self.make_request(
@@ -440,7 +450,9 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
)
return True, None
- self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [test_fn]
+ self.hs.get_module_api_callbacks().third_party_event_rules._check_event_allowed_callbacks = [
+ test_fn
+ ]
# Sometimes the bug might not happen the first time the event type is added
# to the state but might happen when an event updates the state of the room for
@@ -466,7 +478,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
def test_on_new_event(self) -> None:
"""Test that the on_new_event callback is called on new events"""
on_new_event = Mock(make_awaitable(None))
- self.hs.get_third_party_event_rules()._on_new_event_callbacks.append(
+ self.hs.get_module_api_callbacks().third_party_event_rules._on_new_event_callbacks.append(
on_new_event
)
@@ -569,7 +581,9 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
# Register a mock callback.
m = Mock(return_value=make_awaitable(None))
- self.hs.get_third_party_event_rules()._on_profile_update_callbacks.append(m)
+ self.hs.get_module_api_callbacks().third_party_event_rules._on_profile_update_callbacks.append(
+ m
+ )
# Change the display name.
channel = self.make_request(
@@ -628,7 +642,9 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
# Register a mock callback.
m = Mock(return_value=make_awaitable(None))
- self.hs.get_third_party_event_rules()._on_profile_update_callbacks.append(m)
+ self.hs.get_module_api_callbacks().third_party_event_rules._on_profile_update_callbacks.append(
+ m
+ )
# Register an admin user.
self.register_user("admin", "password", admin=True)
@@ -667,7 +683,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
"""
# Register a mocked callback.
deactivation_mock = Mock(return_value=make_awaitable(None))
- third_party_rules = self.hs.get_third_party_event_rules()
+ third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules
third_party_rules._on_user_deactivation_status_changed_callbacks.append(
deactivation_mock,
)
@@ -675,7 +691,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
# deactivation code calls it in a way that let modules know the user is being
# deactivated.
profile_mock = Mock(return_value=make_awaitable(None))
- self.hs.get_third_party_event_rules()._on_profile_update_callbacks.append(
+ self.hs.get_module_api_callbacks().third_party_event_rules._on_profile_update_callbacks.append(
profile_mock,
)
@@ -725,7 +741,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
"""
# Register a mock callback.
m = Mock(return_value=make_awaitable(None))
- third_party_rules = self.hs.get_third_party_event_rules()
+ third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules
third_party_rules._on_user_deactivation_status_changed_callbacks.append(m)
# Register an admin user.
@@ -779,7 +795,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
"""
# Register a mocked callback.
deactivation_mock = Mock(return_value=make_awaitable(False))
- third_party_rules = self.hs.get_third_party_event_rules()
+ third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules
third_party_rules._check_can_deactivate_user_callbacks.append(
deactivation_mock,
)
@@ -825,7 +841,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
"""
# Register a mocked callback.
deactivation_mock = Mock(return_value=make_awaitable(False))
- third_party_rules = self.hs.get_third_party_event_rules()
+ third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules
third_party_rules._check_can_deactivate_user_callbacks.append(
deactivation_mock,
)
@@ -864,7 +880,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
"""
# Register a mocked callback.
shutdown_mock = Mock(return_value=make_awaitable(False))
- third_party_rules = self.hs.get_third_party_event_rules()
+ third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules
third_party_rules._check_can_shutdown_room_callbacks.append(
shutdown_mock,
)
@@ -900,7 +916,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
"""
# Register a mocked callback.
threepid_bind_mock = Mock(return_value=make_awaitable(None))
- third_party_rules = self.hs.get_third_party_event_rules()
+ third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules
third_party_rules._on_threepid_bind_callbacks.append(threepid_bind_mock)
# Register an admin user.
@@ -947,8 +963,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
on_remove_user_third_party_identifier_callback_mock = Mock(
return_value=make_awaitable(None)
)
- third_party_rules = self.hs.get_third_party_event_rules()
- third_party_rules.register_third_party_rules_callbacks(
+ self.hs.get_module_api().register_third_party_rules_callbacks(
on_add_user_third_party_identifier=on_add_user_third_party_identifier_callback_mock,
on_remove_user_third_party_identifier=on_remove_user_third_party_identifier_callback_mock,
)
@@ -1009,8 +1024,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
on_remove_user_third_party_identifier_callback_mock = Mock(
return_value=make_awaitable(None)
)
- third_party_rules = self.hs.get_third_party_event_rules()
- third_party_rules.register_third_party_rules_callbacks(
+ self.hs.get_module_api().register_third_party_rules_callbacks(
on_remove_user_third_party_identifier=on_remove_user_third_party_identifier_callback_mock,
)
diff --git a/tests/server.py b/tests/server.py
index a49dc90e32..7296f0a552 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -73,11 +73,13 @@ from twisted.web.server import Request, Site
from synapse.config.database import DatabaseConnectionConfig
from synapse.config.homeserver import HomeServerConfig
from synapse.events.presence_router import load_legacy_presence_router
-from synapse.events.third_party_rules import load_legacy_third_party_event_rules
from synapse.handlers.auth import load_legacy_password_auth_providers
from synapse.http.site import SynapseRequest
from synapse.logging.context import ContextResourceUsage
from synapse.module_api.callbacks.spamchecker_callbacks import load_legacy_spam_checkers
+from synapse.module_api.callbacks.third_party_event_rules_callbacks import (
+ load_legacy_third_party_event_rules,
+)
from synapse.server import HomeServer
from synapse.storage import DataStore
from synapse.storage.database import LoggingDatabaseConnection
diff --git a/tests/unittest.py b/tests/unittest.py
index ee2f78ab01..b6fdf69635 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -566,7 +566,9 @@ class HomeserverTestCase(TestCase):
client_ip,
)
- def setup_test_homeserver(self, *args: Any, **kwargs: Any) -> HomeServer:
+ def setup_test_homeserver(
+ self, name: Optional[str] = None, **kwargs: Any
+ ) -> HomeServer:
"""
Set up the test homeserver, meant to be called by the overridable
make_homeserver. It automatically passes through the test class's
@@ -585,15 +587,25 @@ class HomeserverTestCase(TestCase):
else:
config = kwargs["config"]
+ # The server name can be specified using either the `name` argument or a config
+ # override. The `name` argument takes precedence over any config overrides.
+ if name is not None:
+ config["server_name"] = name
+
# Parse the config from a config dict into a HomeServerConfig
config_obj = make_homeserver_config_obj(config)
kwargs["config"] = config_obj
+ # The server name in the config is now `name`, if provided, or the `server_name`
+ # from a config override, or the default of "test". Whichever it is, we
+ # construct a homeserver with a matching name.
+ kwargs["name"] = config_obj.server.server_name
+
async def run_bg_updates() -> None:
with LoggingContext("run_bg_updates"):
self.get_success(stor.db_pool.updates.run_background_updates(False))
- hs = setup_test_homeserver(self.addCleanup, *args, **kwargs)
+ hs = setup_test_homeserver(self.addCleanup, **kwargs)
stor = hs.get_datastores().main
# Run the database background updates, when running against "master".
|