diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 4ff04fc66b..013b9ee550 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -960,7 +960,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
appservice = ApplicationService(
token="i_am_an_app_service",
id="1234",
- namespaces={"users": [{"regex": r"@boris:*", "exclusive": True}]},
+ namespaces={"users": [{"regex": r"@boris:.+", "exclusive": True}]},
# Note: this user does not have to match the regex above
sender="@as_main:test",
)
@@ -1015,3 +1015,122 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
},
},
)
+
+ @override_config({"experimental_features": {"msc3984_appservice_key_query": True}})
+ def test_query_local_devices_appservice(self) -> None:
+        """Test that querying an appservice for keys overrides responses from the database."""
+ local_user = "@boris:" + self.hs.hostname
+ device_1 = "abc"
+ device_2 = "def"
+ device_3 = "ghi"
+
+ # There are 3 devices:
+ #
+ # 1. One which is uploaded to the homeserver.
+ # 2. One which is uploaded to the homeserver, but a newer copy is returned
+ # by the appservice.
+ # 3. One which is only returned by the appservice.
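+        #
+        # The appservice's response is expected to take precedence for devices 2
+        # and 3 (see the assertions at the end of this test).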
+ device_key_1: JsonDict = {
+ "user_id": local_user,
+ "device_id": device_1,
+ "algorithms": [
+ "m.olm.curve25519-aes-sha2",
+ RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2,
+ ],
+ "keys": {
+ "ed25519:abc": "base64+ed25519+key",
+ "curve25519:abc": "base64+curve25519+key",
+ },
+ "signatures": {local_user: {"ed25519:abc": "base64+signature"}},
+ }
+ device_key_2a: JsonDict = {
+ "user_id": local_user,
+ "device_id": device_2,
+ "algorithms": [
+ "m.olm.curve25519-aes-sha2",
+ RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2,
+ ],
+ "keys": {
+ "ed25519:def": "base64+ed25519+key",
+ "curve25519:def": "base64+curve25519+key",
+ },
+ "signatures": {local_user: {"ed25519:def": "base64+signature"}},
+ }
+
+ device_key_2b: JsonDict = {
+ "user_id": local_user,
+ "device_id": device_2,
+ "algorithms": [
+ "m.olm.curve25519-aes-sha2",
+ RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2,
+ ],
+ # The device ID is the same (above), but the keys are different.
+ "keys": {
+ "ed25519:xyz": "base64+ed25519+key",
+ "curve25519:xyz": "base64+curve25519+key",
+ },
+ "signatures": {local_user: {"ed25519:xyz": "base64+signature"}},
+ }
+ device_key_3: JsonDict = {
+ "user_id": local_user,
+ "device_id": device_3,
+ "algorithms": [
+ "m.olm.curve25519-aes-sha2",
+ RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2,
+ ],
+ "keys": {
+ "ed25519:jkl": "base64+ed25519+key",
+ "curve25519:jkl": "base64+curve25519+key",
+ },
+ "signatures": {local_user: {"ed25519:jkl": "base64+signature"}},
+ }
+
+ # Upload keys for devices 1 & 2a.
+ self.get_success(
+ self.handler.upload_keys_for_user(
+ local_user, device_1, {"device_keys": device_key_1}
+ )
+ )
+ self.get_success(
+ self.handler.upload_keys_for_user(
+ local_user, device_2, {"device_keys": device_key_2a}
+ )
+ )
+
+ # Inject an appservice interested in this user.
+ appservice = ApplicationService(
+ token="i_am_an_app_service",
+ id="1234",
+ namespaces={"users": [{"regex": r"@boris:.+", "exclusive": True}]},
+ # Note: this user does not have to match the regex above
+ sender="@as_main:test",
+ )
+ self.hs.get_datastores().main.services_cache = [appservice]
+ self.hs.get_datastores().main.exclusive_user_regex = _make_exclusive_regex(
+ [appservice]
+ )
+
+        # Set up a mocked response from the appservice.
+ self.appservice_api.query_keys.return_value = make_awaitable(
+ {
+ "device_keys": {
+ local_user: {device_2: device_key_2b, device_3: device_key_3}
+ }
+ }
+ )
+
+ # Request all devices.
+ res = self.get_success(self.handler.query_local_devices({local_user: None}))
+ self.assertIn(local_user, res)
+ for res_key in res[local_user].values():
+ res_key.pop("unsigned", None)
+ self.assertDictEqual(
+ res,
+ {
+ local_user: {
+ device_1: device_key_1,
+ device_2: device_key_2b,
+ device_3: device_key_3,
+ }
+ },
+ )
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index 951caaa6b3..0a8bae54fb 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -922,7 +922,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
auth_provider_session_id=None,
)
- @override_config({"oidc_config": DEFAULT_CONFIG})
+ @override_config({"oidc_config": {**DEFAULT_CONFIG, "enable_registration": True}})
def test_map_userinfo_to_user(self) -> None:
"""Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
userinfo: dict = {
@@ -975,6 +975,21 @@ class OidcHandlerTestCase(HomeserverTestCase):
"Mapping provider does not support de-duplicating Matrix IDs",
)
+ @override_config({"oidc_config": {**DEFAULT_CONFIG, "enable_registration": False}})
+ def test_map_userinfo_to_user_does_not_register_new_user(self) -> None:
+        """Ensure that new users are not registered when enable_registration is disabled."""
+ userinfo: dict = {
+ "sub": "test_user",
+ "username": "test_user",
+ }
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_not_called()
+ self.assertRenderedError(
+ "mapping_error",
+ "User does not exist and registrations are disabled",
+ )
+
@override_config({"oidc_config": {**DEFAULT_CONFIG, "allow_existing_users": True}})
def test_map_userinfo_to_existing_user(self) -> None:
"""Existing users can log in with OpenID Connect when allow_existing_users is True."""
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 46a8e2013e..0f1a8a145f 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -54,6 +54,10 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
if not hiredis:
skip = "Requires hiredis"
+ if not USE_POSTGRES_FOR_TESTS:
+ # Redis replication only takes place on Postgres
+ skip = "Requires Postgres"
+
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# build a replication server
server_factory = ReplicationStreamProtocolFactory(hs)
diff --git a/tests/replication/tcp/streams/test_account_data.py b/tests/replication/tcp/streams/test_account_data.py
index 01df1be047..b9075e3f20 100644
--- a/tests/replication/tcp/streams/test_account_data.py
+++ b/tests/replication/tcp/streams/test_account_data.py
@@ -37,11 +37,6 @@ class AccountDataStreamTestCase(BaseStreamTestCase):
# also one global update
self.get_success(store.add_account_data_for_user("test_user", "m.global", {}))
- # tell the notifier to catch up to avoid duplicate rows.
- # workaround for https://github.com/matrix-org/synapse/issues/7360
- # FIXME remove this when the above is fixed
- self.replicate()
-
# check we're testing what we think we are: no rows should yet have been
# received
self.assertEqual([], self.test_handler.received_rdata_rows)
diff --git a/tests/replication/tcp/streams/test_to_device.py b/tests/replication/tcp/streams/test_to_device.py
new file mode 100644
index 0000000000..fb9eac668f
--- /dev/null
+++ b/tests/replication/tcp/streams/test_to_device.py
@@ -0,0 +1,89 @@
+# Copyright 2023 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+import synapse
+from synapse.replication.tcp.streams._base import _STREAM_UPDATE_TARGET_ROW_COUNT
+from synapse.types import JsonDict
+
+from tests.replication._base import BaseStreamTestCase
+
+logger = logging.getLogger(__name__)
+
+
+class ToDeviceStreamTestCase(BaseStreamTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ synapse.rest.client.login.register_servlets,
+ ]
+
+ def test_to_device_stream(self) -> None:
+ store = self.hs.get_datastores().main
+
+ user1 = self.register_user("user1", "pass")
+ self.login("user1", "pass", "device")
+ user2 = self.register_user("user2", "pass")
+ self.login("user2", "pass", "device")
+
+        # connect to pull the updates related to user creation/login
+ self.reconnect()
+ self.replicate()
+ self.test_handler.received_rdata_rows.clear()
+ # disconnect so we can accumulate the updates without pulling them
+ self.disconnect()
+
+ msg: JsonDict = {}
+ msg["sender"] = "@sender:example.org"
+ msg["type"] = "m.new_device"
+
+ # add messages to the device inbox for user1 up until the
+ # limit defined for a stream update batch
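+        # (the extra message added for user2 below then pushes the stream past
+        # that limit)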
+ for i in range(0, _STREAM_UPDATE_TARGET_ROW_COUNT):
+ msg["content"] = {"device": {}}
+ messages = {user1: {"device": msg}}
+
+ self.get_success(
+ store.add_messages_from_remote_to_device_inbox(
+ "example.org",
+ f"{i}",
+ messages,
+ )
+ )
+
+ # add one more message, for user2 this time
+        # this message would have been dropped before the fix for #15335
+ msg["content"] = {"device": {}}
+ messages = {user2: {"device": msg}}
+
+ self.get_success(
+ store.add_messages_from_remote_to_device_inbox(
+ "example.org",
+ f"{_STREAM_UPDATE_TARGET_ROW_COUNT}",
+ messages,
+ )
+ )
+
+ # replication is disconnected so we shouldn't get any updates yet
+ self.assertEqual([], self.test_handler.received_rdata_rows)
+
+ # now reconnect to pull the updates
+ self.reconnect()
+ self.replicate()
+
+        # we should now receive the to_device updates
+        # for both user1 and user2
+ received_rows = self.test_handler.received_rdata_rows
+ self.assertEqual(len(received_rows), 2)
+ self.assertEqual(received_rows[0][2].entity, user1)
+ self.assertEqual(received_rows[1][2].entity, user2)
diff --git a/tests/server.py b/tests/server.py
index bb059630fa..b52ff1c463 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -16,6 +16,7 @@ import json
import logging
import os
import os.path
+import sqlite3
import time
import uuid
import warnings
@@ -79,7 +80,9 @@ from synapse.http.site import SynapseRequest
from synapse.logging.context import ContextResourceUsage
from synapse.server import HomeServer
from synapse.storage import DataStore
+from synapse.storage.database import LoggingDatabaseConnection
from synapse.storage.engines import PostgresEngine, create_engine
+from synapse.storage.prepare_database import prepare_database
from synapse.types import ISynapseReactor, JsonDict
from synapse.util import Clock
@@ -104,6 +107,10 @@ P = ParamSpec("P")
# the type of thing that can be passed into `make_request` in the headers list
CustomHeaderType = Tuple[Union[str, bytes], Union[str, bytes]]
+# A pre-prepared SQLite DB that is used as a template when creating a new SQLite
+# DB for each test run. This dramatically speeds up test setup when using SQLite.
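+# The template is created lazily by the first call to setup_test_homeserver() below
+# and is then reused by every subsequent test in the same process.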
+PREPPED_SQLITE_DB_CONN: Optional[LoggingDatabaseConnection] = None
+
class TimedOutException(Exception):
"""
@@ -899,6 +906,22 @@ def setup_test_homeserver(
"args": {"database": test_db_location, "cp_min": 1, "cp_max": 1},
}
+ # Check if we have set up a DB that we can use as a template.
+ global PREPPED_SQLITE_DB_CONN
+ if PREPPED_SQLITE_DB_CONN is None:
+ temp_engine = create_engine(database_config)
+ PREPPED_SQLITE_DB_CONN = LoggingDatabaseConnection(
+ sqlite3.connect(":memory:"), temp_engine, "PREPPED_CONN"
+ )
+
+ database = DatabaseConnectionConfig("master", database_config)
+ config.database.databases = [database]
+ prepare_database(
+ PREPPED_SQLITE_DB_CONN, create_engine(database_config), config
+ )
+
+ database_config["_TEST_PREPPED_CONN"] = PREPPED_SQLITE_DB_CONN
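+        # (Assumption: the SQLite connection setup elsewhere in this module copies
+        # this template into each fresh test database, e.g. via sqlite3's backup API,
+        # rather than re-running the schema scripts for every test.)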
+
if "db_txn_limit" in kwargs:
database_config["txn_limit"] = kwargs["db_txn_limit"]
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 3e1984c15c..81e50bdd55 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -1143,19 +1143,24 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
tok = self.login("alice", "test")
room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
+ failure_time = self.clock.time_msec()
self.get_success(
self.store.record_event_failed_pull_attempt(
room_id, "$failed_event_id", "fake cause"
)
)
- event_ids_to_backoff = self.get_success(
+ event_ids_with_backoff = self.get_success(
self.store.get_event_ids_to_not_pull_from_backoff(
room_id=room_id, event_ids=["$failed_event_id", "$normal_event_id"]
)
)
- self.assertEqual(event_ids_to_backoff, ["$failed_event_id"])
+ self.assertEqual(
+ event_ids_with_backoff,
+ # We expect a 2^1 hour backoff after a single failed attempt.
+ {"$failed_event_id": failure_time + 2 * 60 * 60 * 1000},
+ )
def test_get_event_ids_to_not_pull_from_backoff_retry_after_backoff_duration(
self,
@@ -1179,14 +1184,14 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
# attempt (2^1 hours).
self.reactor.advance(datetime.timedelta(hours=2).total_seconds())
- event_ids_to_backoff = self.get_success(
+ event_ids_with_backoff = self.get_success(
self.store.get_event_ids_to_not_pull_from_backoff(
room_id=room_id, event_ids=["$failed_event_id", "$normal_event_id"]
)
)
# Since this function only returns events we should backoff from, time has
# elapsed past the backoff range so there is no events to backoff from.
- self.assertEqual(event_ids_to_backoff, [])
+ self.assertEqual(event_ids_with_backoff, {})
@attr.s(auto_attribs=True)
diff --git a/tests/unittest.py b/tests/unittest.py
index f9160faa1d..8a16fd3665 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -146,6 +146,9 @@ class TestCase(unittest.TestCase):
% (current_context(),)
)
+        # Disable GC for the duration of the test. See below for why.
+ gc.disable()
+
old_level = logging.getLogger().level
if level is not None and old_level != level:
@@ -163,12 +166,19 @@ class TestCase(unittest.TestCase):
return orig()
+        # We want to force a GC to work around problems with deferreds leaking
+        # logcontexts when they are GCed (see the logcontext docs).
+ #
+ # The easiest way to do this would be to do a full GC after each test
+ # run, but that is very expensive. Instead, we disable GC (above) for
+ # the duration of the test so that we only need to run a gen-0 GC, which
+ # is a lot quicker.
+
@around(self)
def tearDown(orig: Callable[[], R]) -> R:
ret = orig()
- # force a GC to workaround problems with deferreds leaking logcontexts when
- # they are GCed (see the logcontext docs)
- gc.collect()
+ gc.collect(0)
+ gc.enable()
set_current_context(SENTINEL_CONTEXT)
return ret
|