diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index bd229cf7e9..751bf599ae 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -484,23 +484,6 @@ class AuthTestCase(unittest.HomeserverTestCase):
ResourceLimitError,
)
- def test_reserved_threepid(self) -> None:
- self.auth_blocking._limit_usage_by_mau = True
- self.auth_blocking._max_mau_value = 1
- self.store.get_monthly_active_count = AsyncMock(return_value=2)
- threepid = {"medium": "email", "address": "reserved@server.com"}
- unknown_threepid = {"medium": "email", "address": "unreserved@server.com"}
- self.auth_blocking._mau_limits_reserved_threepids = [threepid]
-
- self.get_failure(self.auth_blocking.check_auth_blocking(), ResourceLimitError)
-
- self.get_failure(
- self.auth_blocking.check_auth_blocking(threepid=unknown_threepid),
- ResourceLimitError,
- )
-
- self.get_success(self.auth_blocking.check_auth_blocking(threepid=threepid))
-
def test_hs_disabled(self) -> None:
self.auth_blocking._hs_disabled = True
self.auth_blocking._hs_disabled_message = "Reason for being disabled"
diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py
index a59e168db1..93f4f98916 100644
--- a/tests/api/test_ratelimiting.py
+++ b/tests/api/test_ratelimiting.py
@@ -1,6 +1,10 @@
+from typing import Optional
+
from synapse.api.ratelimiting import LimitExceededError, Ratelimiter
from synapse.appservice import ApplicationService
from synapse.config.ratelimiting import RatelimitSettings
+from synapse.module_api import RatelimitOverride
+from synapse.module_api.callbacks.ratelimit_callbacks import RatelimitModuleApiCallbacks
from synapse.types import create_requester
from tests import unittest
@@ -220,9 +224,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
self.assertIn("test_id_1", limiter.actions)
- self.get_success_or_raise(
- limiter.can_do_action(None, key="test_id_2", _time_now_s=10)
- )
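+ # Advance the clock so the entry for test_id_1 expires and is pruned.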
+ self.reactor.advance(60)
self.assertNotIn("test_id_1", limiter.actions)
@@ -442,3 +444,49 @@ class TestRatelimiter(unittest.HomeserverTestCase):
limiter.can_do_action(requester=None, key="a", _time_now_s=20.0)
)
self.assertTrue(success)
+
+ def test_get_ratelimit_override_for_user_callback(self) -> None:
+ test_user_id = "@user:test"
+ test_limiter_name = "name"
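+ # Callback registry that the limiter consults for per-user overrides.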
+ callbacks = RatelimitModuleApiCallbacks(self.hs)
+ requester = create_requester(test_user_id)
+ limiter = Ratelimiter(
+ store=self.hs.get_datastores().main,
+ clock=self.clock,
+ cfg=RatelimitSettings(
+ test_limiter_name,
+ per_second=0.1,
+ burst_count=3,
+ ),
+ ratelimit_callbacks=callbacks,
+ )
+
+ # Observe four actions, exceeding the burst_count.
+ limiter.record_action(requester=requester, n_actions=4, _time_now_s=0.0)
+
+ # We should be prevented from taking a new action now.
+ success, _ = self.get_success_or_raise(
+ limiter.can_do_action(requester=requester, _time_now_s=0.0)
+ )
+ self.assertFalse(success)
+
+ # Now register a callback that overrides the ratelimit for this user
+ # and limiter name.
+ async def get_ratelimit_override_for_user(
+ user_id: str, limiter_name: str
+ ) -> Optional[RatelimitOverride]:
+ if user_id == test_user_id:
+ return RatelimitOverride(
+ per_second=0.1,
+ burst_count=10,
+ )
+ return None
+
+ callbacks.register_callbacks(
+ get_ratelimit_override_for_user=get_ratelimit_override_for_user
+ )
+
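+ # The override raises the burst count to 10, so the previously blocked action now succeeds.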
+ success, _ = self.get_success_or_raise(
+ limiter.can_do_action(requester=requester, _time_now_s=0.0)
+ )
+ self.assertTrue(success)
diff --git a/tests/api/test_urls.py b/tests/api/test_urls.py
new file mode 100644
index 0000000000..ce156a05dc
--- /dev/null
+++ b/tests/api/test_urls.py
@@ -0,0 +1,55 @@
+#
+# This file is licensed under the Affero General Public License (AGPL) version 3.
+#
+# Copyright (C) 2024 New Vector, Ltd
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# See the GNU Affero General Public License for more details:
+# <https://www.gnu.org/licenses/agpl-3.0.html>.
+#
+
+
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.api.urls import LoginSSORedirectURIBuilder
+from synapse.server import HomeServer
+from synapse.util import Clock
+
+from tests.unittest import HomeserverTestCase
+
+# a (valid) URL with some annoying characters in it. %3D is =, %26 is &, %2B is +
+TRICKY_TEST_CLIENT_REDIRECT_URL = 'https://x?<ab c>&q"+%3D%2B"="fö%26=o"'
+
+
+class LoginSSORedirectURIBuilderTestCase(HomeserverTestCase):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.login_sso_redirect_url_builder = LoginSSORedirectURIBuilder(hs.config)
+
+ def test_no_idp_id(self) -> None:
+ self.assertEqual(
+ self.login_sso_redirect_url_builder.build_login_sso_redirect_uri(
+ idp_id=None, client_redirect_url="http://example.com/redirect"
+ ),
+ "https://test/_matrix/client/v3/login/sso/redirect?redirectUrl=http%3A%2F%2Fexample.com%2Fredirect",
+ )
+
+ def test_explicit_idp_id(self) -> None:
+ self.assertEqual(
+ self.login_sso_redirect_url_builder.build_login_sso_redirect_uri(
+ idp_id="oidc-github", client_redirect_url="http://example.com/redirect"
+ ),
+ "https://test/_matrix/client/v3/login/sso/redirect/oidc-github?redirectUrl=http%3A%2F%2Fexample.com%2Fredirect",
+ )
+
+ def test_tricky_redirect_uri(self) -> None:
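+ # The already percent-encoded %3D/%2B/%26 in the input are encoded again (e.g. %3D -> %253D).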
+ self.assertEqual(
+ self.login_sso_redirect_url_builder.build_login_sso_redirect_uri(
+ idp_id="oidc-github",
+ client_redirect_url=TRICKY_TEST_CLIENT_REDIRECT_URL,
+ ),
+ "https://test/_matrix/client/v3/login/sso/redirect/oidc-github?redirectUrl=https%3A%2F%2Fx%3F%3Cab+c%3E%26q%22%2B%253D%252B%22%3D%22f%C3%B6%2526%3Do%22",
+ )
diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py
index a1c7ccdd0b..a5bf7e0635 100644
--- a/tests/appservice/test_scheduler.py
+++ b/tests/appservice/test_scheduler.py
@@ -2,7 +2,7 @@
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2015, 2016 OpenMarket Ltd
-# Copyright (C) 2023 New Vector, Ltd
+# Copyright (C) 2023, 2025 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
@@ -150,7 +150,8 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
self.assertEqual(1, len(self.txnctrl.recoverers)) # and stored
self.assertEqual(0, txn.complete.call_count) # txn not completed
self.store.set_appservice_state.assert_called_once_with(
- service, ApplicationServiceState.DOWN # service marked as down
+ service,
+ ApplicationServiceState.DOWN, # service marked as down
)
@@ -233,6 +234,41 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
self.assertEqual(1, txn.complete.call_count)
self.callback.assert_called_once_with(self.recoverer)
+ def test_recover_force_retry(self) -> None:
+ txn = Mock()
+ txns = [txn, None]
+ pop_txn = False
+
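+ # Return the same txn until pop_txn is set, then drain the txns list (ending with None).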
+ def take_txn(
+ *args: object, **kwargs: object
+ ) -> "defer.Deferred[Optional[Mock]]":
+ if pop_txn:
+ return defer.succeed(txns.pop(0))
+ else:
+ return defer.succeed(txn)
+
+ self.store.get_oldest_unsent_txn = Mock(side_effect=take_txn)
+
+ # Start the recovery, and then fail the first attempt.
+ self.recoverer.recover()
+ self.assertEqual(0, self.store.get_oldest_unsent_txn.call_count)
+ txn.send = AsyncMock(return_value=False)
+ txn.complete = AsyncMock(return_value=None)
+ self.clock.advance_time(2)
+ self.assertEqual(1, txn.send.call_count)
+ self.assertEqual(0, txn.complete.call_count)
+ self.assertEqual(0, self.callback.call_count)
+
+ # Now allow the send to succeed, and force a retry.
+ pop_txn = True # returns the txn the first time, then no more.
+ txn.send = AsyncMock(return_value=True) # successfully send the txn
+ self.recoverer.force_retry()
+ self.assertEqual(1, txn.send.call_count) # the new mock resets the call count
+ self.assertEqual(1, txn.complete.call_count)
+
+ # Ensure we call the callback to say we're done!
+ self.callback.assert_called_once_with(self.recoverer)
+
# Corresponds to synapse.appservice.scheduler._TransactionController.send
TxnCtrlArgs: TypeAlias = """
diff --git a/tests/config/test_api.py b/tests/config/test_api.py
index 6773c9a277..e6cc3e21ed 100644
--- a/tests/config/test_api.py
+++ b/tests/config/test_api.py
@@ -3,6 +3,7 @@ from unittest import TestCase as StdlibTestCase
import yaml
from synapse.config import ConfigError
+from synapse.config._base import RootConfig
from synapse.config.api import ApiConfig
from synapse.types.state import StateFilter
@@ -19,7 +20,7 @@ DEFAULT_PREJOIN_STATE_PAIRS = {
class TestRoomPrejoinState(StdlibTestCase):
def read_config(self, source: str) -> ApiConfig:
- config = ApiConfig()
+ config = ApiConfig(RootConfig())
config.read_config(yaml.safe_load(source))
return config
diff --git a/tests/config/test_appservice.py b/tests/config/test_appservice.py
index e3021b59d8..2572681224 100644
--- a/tests/config/test_appservice.py
+++ b/tests/config/test_appservice.py
@@ -19,6 +19,7 @@
#
#
+from synapse.config._base import RootConfig
from synapse.config.appservice import AppServiceConfig, ConfigError
from tests.unittest import TestCase
@@ -36,12 +37,12 @@ class AppServiceConfigTest(TestCase):
["foo", "bar", False],
]:
with self.assertRaises(ConfigError):
- AppServiceConfig().read_config(
+ AppServiceConfig(RootConfig()).read_config(
{"app_service_config_files": invalid_value}
)
def test_valid_app_service_config_files(self) -> None:
- AppServiceConfig().read_config({"app_service_config_files": []})
- AppServiceConfig().read_config(
+ AppServiceConfig(RootConfig()).read_config({"app_service_config_files": []})
+ AppServiceConfig(RootConfig()).read_config(
{"app_service_config_files": ["/not/a/real/path", "/not/a/real/path/2"]}
)
diff --git a/tests/config/test_cache.py b/tests/config/test_cache.py
index 631263b5ca..aead73e059 100644
--- a/tests/config/test_cache.py
+++ b/tests/config/test_cache.py
@@ -19,6 +19,7 @@
#
#
+from synapse.config._base import RootConfig
from synapse.config.cache import CacheConfig, add_resizable_cache
from synapse.types import JsonDict
from synapse.util.caches.lrucache import LruCache
@@ -29,7 +30,7 @@ from tests.unittest import TestCase
class CacheConfigTests(TestCase):
def setUp(self) -> None:
# Reset caches before each test since there's global state involved.
- self.config = CacheConfig()
+ self.config = CacheConfig(RootConfig())
self.config.reset()
def tearDown(self) -> None:
diff --git a/tests/config/test_database.py b/tests/config/test_database.py
index b46519f84a..3fa5fff2b2 100644
--- a/tests/config/test_database.py
+++ b/tests/config/test_database.py
@@ -20,6 +20,7 @@
import yaml
+from synapse.config._base import RootConfig
from synapse.config.database import DatabaseConfig
from tests import unittest
@@ -28,7 +29,9 @@ from tests import unittest
class DatabaseConfigTestCase(unittest.TestCase):
def test_database_configured_correctly(self) -> None:
conf = yaml.safe_load(
- DatabaseConfig().generate_config_section(data_dir_path="/data_dir_path")
+ DatabaseConfig(RootConfig()).generate_config_section(
+ data_dir_path="/data_dir_path"
+ )
)
expected_database_conf = {
diff --git a/tests/config/test_load.py b/tests/config/test_load.py
index 479d2aab91..e06a961b16 100644
--- a/tests/config/test_load.py
+++ b/tests/config/test_load.py
@@ -19,17 +19,33 @@
# [This file includes modifications made by New Vector Limited]
#
#
+import tempfile
+from typing import Callable
+from unittest import mock
+
import yaml
+from parameterized import parameterized
from synapse.config import ConfigError
+from synapse.config._base import RootConfig
from synapse.config.homeserver import HomeServerConfig
from tests.config.utils import ConfigFileTestCase
+try:
+ import authlib
+except ImportError:
+ authlib = None
+
+try:
+ import hiredis
+except ImportError:
+ hiredis = None # type: ignore
+
class ConfigLoadingFileTestCase(ConfigFileTestCase):
def test_load_fails_if_server_name_missing(self) -> None:
- self.generate_config_and_remove_lines_containing("server_name")
+ self.generate_config_and_remove_lines_containing(["server_name"])
with self.assertRaises(ConfigError):
HomeServerConfig.load_config("", ["-c", self.config_file])
with self.assertRaises(ConfigError):
@@ -66,7 +82,7 @@ class ConfigLoadingFileTestCase(ConfigFileTestCase):
)
def test_load_succeeds_if_macaroon_secret_key_missing(self) -> None:
- self.generate_config_and_remove_lines_containing("macaroon")
+ self.generate_config_and_remove_lines_containing(["macaroon"])
config1 = HomeServerConfig.load_config("", ["-c", self.config_file])
config2 = HomeServerConfig.load_config("", ["-c", self.config_file])
config3 = HomeServerConfig.load_or_generate_config("", ["-c", self.config_file])
@@ -101,18 +117,180 @@ class ConfigLoadingFileTestCase(ConfigFileTestCase):
self.assertTrue(config3.registration.enable_registration)
def test_stats_enabled(self) -> None:
- self.generate_config_and_remove_lines_containing("enable_metrics")
+ self.generate_config_and_remove_lines_containing(["enable_metrics"])
self.add_lines_to_config(["enable_metrics: true"])
# The default Metrics Flags are off by default.
config = HomeServerConfig.load_config("", ["-c", self.config_file])
self.assertFalse(config.metrics.metrics_flags.known_servers)
- def test_depreciated_identity_server_flag_throws_error(self) -> None:
+ @parameterized.expand(
+ [
+ "turn_shared_secret_path: /does/not/exist",
+ "registration_shared_secret_path: /does/not/exist",
+ "macaroon_secret_key_path: /does/not/exist",
+ "form_secret_path: /does/not/exist",
+ "worker_replication_secret_path: /does/not/exist",
+ "experimental_features:\n msc3861:\n client_secret_path: /does/not/exist",
+ "experimental_features:\n msc3861:\n admin_token_path: /does/not/exist",
+ *["redis:\n enabled: true\n password_path: /does/not/exist"]
+ * (hiredis is not None),
+ ]
+ )
+ def test_secret_files_missing(self, config_str: str) -> None:
self.generate_config()
- # Needed to ensure that actual key/value pair added below don't end up on a line with a comment
- self.add_lines_to_config([" "])
- # Check that presence of "trust_identity_server_for_password" throws config error
- self.add_lines_to_config(["trust_identity_server_for_password_resets: true"])
+ self.add_lines_to_config(["", config_str])
+
with self.assertRaises(ConfigError):
HomeServerConfig.load_config("", ["-c", self.config_file])
+
+ @parameterized.expand(
+ [
+ (
+ "turn_shared_secret_path: {}",
+ lambda c: c.voip.turn_shared_secret.encode("utf-8"),
+ ),
+ (
+ "registration_shared_secret_path: {}",
+ lambda c: c.registration.registration_shared_secret.encode("utf-8"),
+ ),
+ (
+ "macaroon_secret_key_path: {}",
+ lambda c: c.key.macaroon_secret_key,
+ ),
+ (
+ "form_secret_path: {}",
+ lambda c: c.key.form_secret.encode("utf-8"),
+ ),
+ (
+ "worker_replication_secret_path: {}",
+ lambda c: c.worker.worker_replication_secret.encode("utf-8"),
+ ),
+ (
+ "experimental_features:\n msc3861:\n client_secret_path: {}",
+ lambda c: c.experimental.msc3861.client_secret().encode("utf-8"),
+ ),
+ (
+ "experimental_features:\n msc3861:\n admin_token_path: {}",
+ lambda c: c.experimental.msc3861.admin_token().encode("utf-8"),
+ ),
+ *[
+ (
+ "redis:\n enabled: true\n password_path: {}",
+ lambda c: c.redis.redis_password.encode("utf-8"),
+ )
+ ]
+ * (hiredis is not None),
+ ]
+ )
+ def test_secret_files_existing(
+ self, config_line: str, get_secret: Callable[[RootConfig], bytes]
+ ) -> None:
+ self.generate_config_and_remove_lines_containing(
+ ["form_secret", "macaroon_secret_key", "registration_shared_secret"]
+ )
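+ # buffering=0 flushes the secret to disk immediately so load_config can read it back.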
+ with tempfile.NamedTemporaryFile(buffering=0) as secret_file:
+ secret_file.write(b"53C237")
+
+ self.add_lines_to_config(["", config_line.format(secret_file.name)])
+ config = HomeServerConfig.load_config("", ["-c", self.config_file])
+
+ self.assertEqual(get_secret(config), b"53C237")
+
+ @parameterized.expand(
+ [
+ "turn_shared_secret: 53C237",
+ "registration_shared_secret: 53C237",
+ "macaroon_secret_key: 53C237",
+ "recaptcha_private_key: 53C237",
+ "recaptcha_public_key: ¬53C237",
+ "form_secret: 53C237",
+ "worker_replication_secret: 53C237",
+ *[
+ "experimental_features:\n"
+ " msc3861:\n"
+ " enabled: true\n"
+ " client_secret: 53C237"
+ ]
+ * (authlib is not None),
+ *[
+ "experimental_features:\n"
+ " msc3861:\n"
+ " enabled: true\n"
+ " client_auth_method: private_key_jwt\n"
+ ' jwk: {{"mock": "mock"}}'
+ ]
+ * (authlib is not None),
+ *[
+ "experimental_features:\n"
+ " msc3861:\n"
+ " enabled: true\n"
+ " admin_token: 53C237\n"
+ " client_secret_path: {secret_file}"
+ ]
+ * (authlib is not None),
+ *["redis:\n enabled: true\n password: 53C237"] * (hiredis is not None),
+ ]
+ )
+ def test_no_secrets_in_config(self, config_line: str) -> None:
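+ # Patch JWK parsing so the placeholder jwk value ({"mock": "mock"}) doesn't need to be a real key.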
+ if authlib is not None:
+ patcher = mock.patch("authlib.jose.rfc7517.JsonWebKey.import_key")
+ self.addCleanup(patcher.stop)
+ patcher.start()
+
+ with tempfile.NamedTemporaryFile(buffering=0) as secret_file:
+ # Only used to cut down on mocking in the admin_token case
+ secret_file.write(b"53C237")
+
+ self.generate_config_and_remove_lines_containing(
+ ["form_secret", "macaroon_secret_key", "registration_shared_secret"]
+ )
+ # Check strict mode with no offenders.
+ HomeServerConfig.load_config(
+ "", ["-c", self.config_file, "--no-secrets-in-config"]
+ )
+ self.add_lines_to_config(
+ ["", config_line.format(secret_file=secret_file.name)]
+ )
+ # Check strict mode with a single offender.
+ with self.assertRaises(ConfigError):
+ HomeServerConfig.load_config(
+ "", ["-c", self.config_file, "--no-secrets-in-config"]
+ )
+
+ # Check lenient mode with a single offender.
+ HomeServerConfig.load_config("", ["-c", self.config_file])
+
+ def test_no_secrets_in_config_but_in_files(self) -> None:
+ with tempfile.NamedTemporaryFile(buffering=0) as secret_file:
+ secret_file.write(b"53C237")
+
+ self.generate_config_and_remove_lines_containing(
+ ["form_secret", "macaroon_secret_key", "registration_shared_secret"]
+ )
+ self.add_lines_to_config(
+ [
+ "",
+ f"turn_shared_secret_path: {secret_file.name}",
+ f"registration_shared_secret_path: {secret_file.name}",
+ f"macaroon_secret_key_path: {secret_file.name}",
+ f"recaptcha_private_key_path: {secret_file.name}",
+ f"recaptcha_public_key_path: {secret_file.name}",
+ f"form_secret_path: {secret_file.name}",
+ f"worker_replication_secret_path: {secret_file.name}",
+ *[
+ "experimental_features:\n"
+ " msc3861:\n"
+ " enabled: true\n"
+ f" admin_token_path: {secret_file.name}\n"
+ f" client_secret_path: {secret_file.name}\n"
+ # f" jwk_path: {secret_file.name}"
+ ]
+ * (authlib is not None),
+ *[f"redis:\n enabled: true\n password_path: {secret_file.name}"]
+ * (hiredis is not None),
+ ]
+ )
+ HomeServerConfig.load_config(
+ "", ["-c", self.config_file, "--no-secrets-in-config"]
+ )
diff --git a/tests/config/test_oauth_delegation.py b/tests/config/test_oauth_delegation.py
index 713bddeb90..e1e4e008dd 100644
--- a/tests/config/test_oauth_delegation.py
+++ b/tests/config/test_oauth_delegation.py
@@ -205,17 +205,6 @@ class MSC3861OAuthDelegation(TestCase):
with self.assertRaises(ConfigError):
self.parse_config()
- def test_cas_sso_cannot_be_enabled(self) -> None:
- self.config_dict["cas_config"] = {
- "enabled": True,
- "server_url": "https://cas-server.com",
- "displayname_attribute": "name",
- "required_attributes": {"userGroup": "staff", "department": "None"},
- }
-
- with self.assertRaises(ConfigError):
- self.parse_config()
-
def test_auth_providers_cannot_be_enabled(self) -> None:
self.config_dict["modules"] = [
{
@@ -270,8 +259,3 @@ class MSC3861OAuthDelegation(TestCase):
self.config_dict["session_lifetime"] = "24h"
with self.assertRaises(ConfigError):
self.parse_config()
-
- def test_enable_3pid_changes_cannot_be_enabled(self) -> None:
- self.config_dict["enable_3pid_changes"] = True
- with self.assertRaises(ConfigError):
- self.parse_config()
diff --git a/tests/config/test_room_directory.py b/tests/config/test_room_directory.py
index e25f7787f4..5208381279 100644
--- a/tests/config/test_room_directory.py
+++ b/tests/config/test_room_directory.py
@@ -24,6 +24,7 @@ from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
import synapse.rest.client.login
import synapse.rest.client.room
+from synapse.config._base import RootConfig
from synapse.config.room_directory import RoomDirectoryConfig
from synapse.server import HomeServer
from synapse.util import Clock
@@ -63,7 +64,7 @@ class RoomDirectoryConfigTestCase(unittest.HomeserverTestCase):
"""
)
- rd_config = RoomDirectoryConfig()
+ rd_config = RoomDirectoryConfig(RootConfig())
rd_config.read_config(config)
self.assertFalse(
@@ -123,7 +124,7 @@ class RoomDirectoryConfigTestCase(unittest.HomeserverTestCase):
"""
)
- rd_config = RoomDirectoryConfig()
+ rd_config = RoomDirectoryConfig(RootConfig())
rd_config.read_config(config)
self.assertFalse(
diff --git a/tests/config/test_server.py b/tests/config/test_server.py
index 74073cfdc5..05faf8fcc9 100644
--- a/tests/config/test_server.py
+++ b/tests/config/test_server.py
@@ -20,27 +20,16 @@
import yaml
-from synapse.config._base import ConfigError
-from synapse.config.server import ServerConfig, generate_ip_set, is_threepid_reserved
+from synapse.config._base import ConfigError, RootConfig
+from synapse.config.server import ServerConfig, generate_ip_set
from tests import unittest
class ServerConfigTestCase(unittest.TestCase):
- def test_is_threepid_reserved(self) -> None:
- user1 = {"medium": "email", "address": "user1@example.com"}
- user2 = {"medium": "email", "address": "user2@example.com"}
- user3 = {"medium": "email", "address": "user3@example.com"}
- user1_msisdn = {"medium": "msisdn", "address": "447700000000"}
- config = [user1, user2]
-
- self.assertTrue(is_threepid_reserved(config, user1))
- self.assertFalse(is_threepid_reserved(config, user3))
- self.assertFalse(is_threepid_reserved(config, user1_msisdn))
-
def test_unsecure_listener_no_listeners_open_private_ports_false(self) -> None:
conf = yaml.safe_load(
- ServerConfig().generate_config_section(
+ ServerConfig(RootConfig()).generate_config_section(
"CONFDIR", "/data_dir_path", "che.org", False, None
)
)
@@ -60,7 +49,7 @@ class ServerConfigTestCase(unittest.TestCase):
def test_unsecure_listener_no_listeners_open_private_ports_true(self) -> None:
conf = yaml.safe_load(
- ServerConfig().generate_config_section(
+ ServerConfig(RootConfig()).generate_config_section(
"CONFDIR", "/data_dir_path", "che.org", True, None
)
)
@@ -94,7 +83,7 @@ class ServerConfigTestCase(unittest.TestCase):
]
conf = yaml.safe_load(
- ServerConfig().generate_config_section(
+ ServerConfig(RootConfig()).generate_config_section(
"CONFDIR", "/data_dir_path", "this.one.listens", True, listeners
)
)
@@ -128,7 +117,7 @@ class ServerConfigTestCase(unittest.TestCase):
expected_listeners[1]["bind_addresses"] = ["::1", "127.0.0.1"]
conf = yaml.safe_load(
- ServerConfig().generate_config_section(
+ ServerConfig(RootConfig()).generate_config_section(
"CONFDIR", "/data_dir_path", "this.one.listens", True, listeners
)
)
diff --git a/tests/config/test_workers.py b/tests/config/test_workers.py
index 64c0285d01..3a21975b89 100644
--- a/tests/config/test_workers.py
+++ b/tests/config/test_workers.py
@@ -47,7 +47,7 @@ class WorkerDutyConfigTestCase(TestCase):
"worker_app": worker_app,
**extras,
}
- worker_config.read_config(worker_config_dict)
+ worker_config.read_config(worker_config_dict, allow_secrets_in_config=True)
return worker_config
def test_old_configs_master(self) -> None:
diff --git a/tests/config/utils.py b/tests/config/utils.py
index 11140ff979..3cba4ac588 100644
--- a/tests/config/utils.py
+++ b/tests/config/utils.py
@@ -51,12 +51,13 @@ class ConfigFileTestCase(unittest.TestCase):
],
)
- def generate_config_and_remove_lines_containing(self, needle: str) -> None:
+ def generate_config_and_remove_lines_containing(self, needles: list[str]) -> None:
self.generate_config()
with open(self.config_file) as f:
contents = f.readlines()
- contents = [line for line in contents if needle not in line]
+ for needle in needles:
+ contents = [line for line in contents if needle not in line]
with open(self.config_file, "w") as f:
f.write("".join(contents))
diff --git a/tests/events/test_auto_accept_invites.py b/tests/events/test_auto_accept_invites.py
index 7fb4d4fa90..d2100e9903 100644
--- a/tests/events/test_auto_accept_invites.py
+++ b/tests/events/test_auto_accept_invites.py
@@ -31,6 +31,7 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes
from synapse.api.errors import SynapseError
+from synapse.config._base import RootConfig
from synapse.config.auto_accept_invites import AutoAcceptInvitesConfig
from synapse.events.auto_accept_invites import InviteAutoAccepter
from synapse.federation.federation_base import event_from_pdu_json
@@ -39,7 +40,7 @@ from synapse.module_api import ModuleApi
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
-from synapse.types import StreamToken, create_requester
+from synapse.types import StreamToken, UserID, UserInfo, create_requester
from synapse.util import Clock
from tests.handlers.test_sync import generate_sync_config
@@ -349,6 +350,169 @@ class AutoAcceptInvitesTestCase(FederatingHomeserverTestCase):
join_updates, _ = sync_join(self, invited_user_id)
self.assertEqual(len(join_updates), 0)
+ @override_config(
+ {
+ "auto_accept_invites": {
+ "enabled": True,
+ },
+ }
+ )
+ async def test_ignore_invite_for_missing_user(self) -> None:
+ """Tests that receiving an invite for a missing user is ignored."""
+ inviting_user_id = self.register_user("inviter", "pass")
+ inviting_user_tok = self.login("inviter", "pass")
+
+ # A local user ID that does not belong to any registered user
+ invited_user_id = "@fake:" + self.hs.config.server.server_name
+
+ # Create a room and send an invite to the other user
+ room_id = self.helper.create_room_as(
+ inviting_user_id,
+ tok=inviting_user_tok,
+ )
+
+ self.helper.invite(
+ room_id,
+ inviting_user_id,
+ invited_user_id,
+ tok=inviting_user_tok,
+ )
+
+ join_updates, _ = sync_join(self, inviting_user_id)
+ # Assert that the last event in the room was not a member event for the target user.
+ self.assertEqual(
+ join_updates[0].timeline.events[-1].content["membership"], "invite"
+ )
+
+ @override_config(
+ {
+ "auto_accept_invites": {
+ "enabled": True,
+ },
+ }
+ )
+ async def test_ignore_invite_for_deactivated_user(self) -> None:
+ """Tests that receiving an invite for a deactivated user is ignored."""
+ inviting_user_id = self.register_user("inviter", "pass", admin=True)
+ inviting_user_tok = self.login("inviter", "pass")
+
+ # A local user who receives an invite
+ invited_user_id = self.register_user("invitee", "pass")
+
+ # Create a room and send an invite to the other user
+ room_id = self.helper.create_room_as(
+ inviting_user_id,
+ tok=inviting_user_tok,
+ )
+
+ channel = self.make_request(
+ "PUT",
+ "/_synapse/admin/v2/users/%s" % invited_user_id,
+ {"deactivated": True},
+ access_token=inviting_user_tok,
+ )
+
+ assert channel.code == 200
+
+ self.helper.invite(
+ room_id,
+ inviting_user_id,
+ invited_user_id,
+ tok=inviting_user_tok,
+ )
+
+ join_updates, _ = sync_join(self, inviting_user_id)
+ # Assert that the last event is still the invite, i.e. it was not auto-accepted.
+ self.assertEqual(
+ join_updates[0].timeline.events[-1].content["membership"], "invite"
+ )
+
+ @override_config(
+ {
+ "auto_accept_invites": {
+ "enabled": True,
+ },
+ }
+ )
+ async def test_ignore_invite_for_suspended_user(self) -> None:
+ """Tests that receiving an invite for a suspended user is ignored."""
+ inviting_user_id = self.register_user("inviter", "pass", admin=True)
+ inviting_user_tok = self.login("inviter", "pass")
+
+ # A local user who receives an invite
+ invited_user_id = self.register_user("invitee", "pass")
+
+ # Create a room and send an invite to the other user
+ room_id = self.helper.create_room_as(
+ inviting_user_id,
+ tok=inviting_user_tok,
+ )
+
+ channel = self.make_request(
+ "PUT",
+ f"/_synapse/admin/v1/suspend/{invited_user_id}",
+ {"suspend": True},
+ access_token=inviting_user_tok,
+ )
+
+ assert channel.code == 200
+
+ self.helper.invite(
+ room_id,
+ inviting_user_id,
+ invited_user_id,
+ tok=inviting_user_tok,
+ )
+
+ join_updates, _ = sync_join(self, inviting_user_id)
+ # Assert that the last event is still the invite, i.e. it was not auto-accepted.
+ self.assertEqual(
+ join_updates[0].timeline.events[-1].content["membership"], "invite"
+ )
+
+ @override_config(
+ {
+ "auto_accept_invites": {
+ "enabled": True,
+ },
+ }
+ )
+ async def test_ignore_invite_for_locked_user(self) -> None:
+ """Tests that receiving an invite for a suspended user is ignored."""
+ inviting_user_id = self.register_user("inviter", "pass", admin=True)
+ inviting_user_tok = self.login("inviter", "pass")
+
+ # A local user who receives an invite
+ invited_user_id = self.register_user("invitee", "pass")
+
+ # Create a room and send an invite to the other user
+ room_id = self.helper.create_room_as(
+ inviting_user_id,
+ tok=inviting_user_tok,
+ )
+
+ channel = self.make_request(
+ "PUT",
+ f"/_synapse/admin/v2/users/{invited_user_id}",
+ {"locked": True},
+ access_token=inviting_user_tok,
+ )
+
+ assert channel.code == 200
+
+ self.helper.invite(
+ room_id,
+ inviting_user_id,
+ invited_user_id,
+ tok=inviting_user_tok,
+ )
+
+ join_updates, _ = sync_join(self, inviting_user_id)
+ # Assert that the last event is still the invite, i.e. it was not auto-accepted.
+ self.assertEqual(
+ join_updates[0].timeline.events[-1].content["membership"], "invite"
+ )
+
_request_key = 0
@@ -527,7 +691,7 @@ class InviteAutoAccepterInternalTestCase(TestCase):
"only_from_local_users": True,
}
}
- parsed_config = AutoAcceptInvitesConfig()
+ parsed_config = AutoAcceptInvitesConfig(RootConfig())
parsed_config.read_config(config)
self.assertTrue(parsed_config.enabled)
@@ -647,11 +811,27 @@ def create_module(
module_api.is_mine.side_effect = lambda a: a.split(":")[1] == "test"
module_api.worker_name = worker_name
module_api.sleep.return_value = make_multiple_awaitable(None)
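+ # The auto-accepter now looks up the invitee's account status, so return an active, unrestricted user by default.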
+ module_api.get_userinfo_by_id.return_value = UserInfo(
+ user_id=UserID.from_string("@user:test"),
+ is_admin=False,
+ is_guest=False,
+ consent_server_notice_sent=None,
+ consent_ts=None,
+ consent_version=None,
+ appservice_id=None,
+ creation_ts=0,
+ user_type=None,
+ is_deactivated=False,
+ locked=False,
+ is_shadow_banned=False,
+ approved=True,
+ suspended=False,
+ )
if config_override is None:
config_override = {}
- config = AutoAcceptInvitesConfig()
+ config = AutoAcceptInvitesConfig(RootConfig())
config.read_config(config_override)
return InviteAutoAccepter(config, module_api)
diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py
index 30f8787758..654e6521a2 100644
--- a/tests/events/test_utils.py
+++ b/tests/events/test_utils.py
@@ -756,7 +756,8 @@ class SerializeEventTestCase(stdlib_unittest.TestCase):
def test_event_fields_fail_if_fields_not_str(self) -> None:
with self.assertRaises(TypeError):
self.serialize(
- MockEvent(room_id="!foo:bar", content={"foo": "bar"}), ["room_id", 4] # type: ignore[list-item]
+ MockEvent(room_id="!foo:bar", content={"foo": "bar"}),
+ ["room_id", 4], # type: ignore[list-item]
)
diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py
index 9bd97e5d4e..87b9ffc0c6 100644
--- a/tests/federation/test_complexity.py
+++ b/tests/federation/test_complexity.py
@@ -158,7 +158,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
async def get_current_state_event_counts(room_id: str) -> int:
return 600
- self.hs.get_datastores().main.get_current_state_event_counts = get_current_state_event_counts # type: ignore[method-assign]
+ self.hs.get_datastores().main.get_current_state_event_counts = ( # type: ignore[method-assign]
+ get_current_state_event_counts
+ )
d = handler._remote_join(
create_requester(u1),
diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py
index 08214b0013..1e1ed8e642 100644
--- a/tests/federation/test_federation_catch_up.py
+++ b/tests/federation/test_federation_catch_up.py
@@ -401,7 +401,10 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
now = self.clock.time_msec()
self.get_success(
self.hs.get_datastores().main.set_destination_retry_timings(
- "zzzerver", now, now, 24 * 60 * 60 * 1000 # retry in 1 day
+ "zzzerver",
+ now,
+ now,
+ 24 * 60 * 60 * 1000, # retry in 1 day
)
)
diff --git a/tests/federation/test_federation_devices.py b/tests/federation/test_federation_devices.py
new file mode 100644
index 0000000000..ba27e69479
--- /dev/null
+++ b/tests/federation/test_federation_devices.py
@@ -0,0 +1,161 @@
+#
+# This file is licensed under the Affero General Public License (AGPL) version 3.
+#
+# Copyright (C) 2024 New Vector, Ltd
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# See the GNU Affero General Public License for more details:
+# <https://www.gnu.org/licenses/agpl-3.0.html>.
+#
+# Originally licensed under the Apache License, Version 2.0:
+# <http://www.apache.org/licenses/LICENSE-2.0>.
+#
+# [This file includes modifications made by New Vector Limited]
+#
+#
+
+import logging
+from unittest.mock import AsyncMock, Mock
+
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.handlers.device import DeviceListUpdater
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util import Clock
+from synapse.util.retryutils import NotRetryingDestination
+
+from tests import unittest
+
+logger = logging.getLogger(__name__)
+
+
+class DeviceListResyncTestCase(unittest.HomeserverTestCase):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.store = self.hs.get_datastores().main
+
+ def test_retry_device_list_resync(self) -> None:
+ """Tests that device lists are marked as stale if they couldn't be synced, and
+ that stale device lists are retried periodically.
+ """
+ remote_user_id = "@john:test_remote"
+ remote_origin = "test_remote"
+
+ # Track the number of attempts to resync the user's device list.
+ self.resync_attempts = 0
+
+ # When this function is called, increment the number of resync attempts (only if
+ # we're querying devices for the right user ID), then raise a
+ # NotRetryingDestination error to fail the resync gracefully.
+ def query_user_devices(
+ destination: str, user_id: str, timeout: int = 30000
+ ) -> JsonDict:
+ if user_id == remote_user_id:
+ self.resync_attempts += 1
+
+ raise NotRetryingDestination(0, 0, destination)
+
+ # Register the mock on the federation client.
+ federation_client = self.hs.get_federation_client()
+ federation_client.query_user_devices = Mock(side_effect=query_user_devices) # type: ignore[method-assign]
+
+ # Register a mock on the store so that the incoming update doesn't fail because
+ # we don't share a room with the user.
+ self.store.get_rooms_for_user = AsyncMock(return_value=["!someroom:test"])
+
+ # Manually inject a fake device list update. We need this update to include at
+ # least one prev_id so that the user's device list will need to be retried.
+ device_list_updater = self.hs.get_device_handler().device_list_updater
+ assert isinstance(device_list_updater, DeviceListUpdater)
+ self.get_success(
+ device_list_updater.incoming_device_list_update(
+ origin=remote_origin,
+ edu_content={
+ "deleted": False,
+ "device_display_name": "Mobile",
+ "device_id": "QBUAZIFURK",
+ "prev_id": [5],
+ "stream_id": 6,
+ "user_id": remote_user_id,
+ },
+ )
+ )
+
+ # Check that there was one resync attempt.
+ self.assertEqual(self.resync_attempts, 1)
+
+ # Check that the resync attempt failed and caused the user's device list to be
+ # marked as stale.
+ need_resync = self.get_success(
+ self.store.get_user_ids_requiring_device_list_resync()
+ )
+ self.assertIn(remote_user_id, need_resync)
+
+ # Check that waiting for 30 seconds caused Synapse to retry resyncing the device
+ # list.
+ self.reactor.advance(30)
+ self.assertEqual(self.resync_attempts, 2)
+
+ def test_cross_signing_keys_retry(self) -> None:
+ """Tests that resyncing a device list correctly processes cross-signing keys from
+ the remote server.
+ """
+ remote_user_id = "@john:test_remote"
+ remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY"
+ remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
+
+ # Register mock device list retrieval on the federation client.
+ federation_client = self.hs.get_federation_client()
+ federation_client.query_user_devices = AsyncMock( # type: ignore[method-assign]
+ return_value={
+ "user_id": remote_user_id,
+ "stream_id": 1,
+ "devices": [],
+ "master_key": {
+ "user_id": remote_user_id,
+ "usage": ["master"],
+ "keys": {"ed25519:" + remote_master_key: remote_master_key},
+ },
+ "self_signing_key": {
+ "user_id": remote_user_id,
+ "usage": ["self_signing"],
+ "keys": {
+ "ed25519:" + remote_self_signing_key: remote_self_signing_key
+ },
+ },
+ }
+ )
+
+ # Resync the device list.
+ device_handler = self.hs.get_device_handler()
+ self.get_success(
+ device_handler.device_list_updater.multi_user_device_resync(
+ [remote_user_id]
+ ),
+ )
+
+ # Retrieve the cross-signing keys for this user.
+ keys = self.get_success(
+ self.store.get_e2e_cross_signing_keys_bulk(user_ids=[remote_user_id]),
+ )
+ self.assertIn(remote_user_id, keys)
+ key = keys[remote_user_id]
+ assert key is not None
+
+ # Check that the master key is the one returned by the mock.
+ master_key = key["master"]
+ self.assertEqual(len(master_key["keys"]), 1)
+ self.assertTrue("ed25519:" + remote_master_key in master_key["keys"].keys())
+ self.assertTrue(remote_master_key in master_key["keys"].values())
+
+ # Check that the self-signing key is the one returned by the mock.
+ self_signing_key = key["self_signing"]
+ self.assertEqual(len(self_signing_key["keys"]), 1)
+ self.assertTrue(
+ "ed25519:" + remote_self_signing_key in self_signing_key["keys"].keys(),
+ )
+ self.assertTrue(remote_self_signing_key in self_signing_key["keys"].values())
diff --git a/tests/federation/test_federation_media.py b/tests/federation/test_federation_media.py
index 0dcf20f5f5..9c92003ce5 100644
--- a/tests/federation/test_federation_media.py
+++ b/tests/federation/test_federation_media.py
@@ -40,7 +40,6 @@ from tests.test_utils import SMALL_PNG
class FederationMediaDownloadsTest(unittest.FederatingHomeserverTestCase):
-
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
super().prepare(reactor, clock, hs)
self.test_dir = tempfile.mkdtemp(prefix="synapse-tests-")
@@ -148,9 +147,47 @@ class FederationMediaDownloadsTest(unittest.FederatingHomeserverTestCase):
found_file = any(SMALL_PNG in field for field in stripped_bytes)
self.assertTrue(found_file)
+ def test_federation_etag(self) -> None:
+ """Test that federation ETags work"""
+
+ content = io.BytesIO(b"file_to_stream")
+ content_uri = self.get_success(
+ self.media_repo.create_content(
+ "text/plain",
+ "test_upload",
+ content,
+ 46,
+ UserID.from_string("@user_id:whatever.org"),
+ )
+ )
+
+ channel = self.make_signed_federation_request(
+ "GET",
+ f"/_matrix/federation/v1/media/download/{content_uri.media_id}",
+ )
+ self.pump()
+ self.assertEqual(200, channel.code)
+
+ # We expect exactly one ETag header.
+ etags = channel.headers.getRawHeaders("ETag")
+ self.assertIsNotNone(etags)
+ assert etags is not None # For mypy
+ self.assertEqual(len(etags), 1)
+ etag = etags[0]
+
+ # Refetching with the etag should result in 304 and empty body.
+ channel = self.make_signed_federation_request(
+ "GET",
+ f"/_matrix/federation/v1/media/download/{content_uri.media_id}",
+ custom_headers=[("If-None-Match", etag)],
+ )
+ self.pump()
+ self.assertEqual(channel.code, 304)
+ self.assertEqual(channel.is_finished(), True)
+ self.assertNotIn("body", channel.result)
-class FederationThumbnailTest(unittest.FederatingHomeserverTestCase):
-
+class FederationThumbnailTest(unittest.FederatingHomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
super().prepare(reactor, clock, hs)
self.test_dir = tempfile.mkdtemp(prefix="synapse-tests-")
diff --git a/tests/federation/test_federation_out_of_band_membership.py b/tests/federation/test_federation_out_of_band_membership.py
new file mode 100644
index 0000000000..f77b8fe300
--- /dev/null
+++ b/tests/federation/test_federation_out_of_band_membership.py
@@ -0,0 +1,671 @@
+#
+# This file is licensed under the Affero General Public License (AGPL) version 3.
+#
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+# Copyright (C) 2023 New Vector, Ltd
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# See the GNU Affero General Public License for more details:
+# <https://www.gnu.org/licenses/agpl-3.0.html>.
+#
+# Originally licensed under the Apache License, Version 2.0:
+# <http://www.apache.org/licenses/LICENSE-2.0>.
+#
+# [This file includes modifications made by New Vector Limited]
+#
+#
+
+import logging
+import time
+import urllib.parse
+from http import HTTPStatus
+from typing import Any, Callable, Optional, Set, Tuple, TypeVar, Union
+from unittest.mock import Mock
+
+import attr
+from parameterized import parameterized
+
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.api.constants import EventContentFields, EventTypes, Membership
+from synapse.api.room_versions import RoomVersion, RoomVersions
+from synapse.events import EventBase, make_event_from_dict
+from synapse.events.utils import strip_event
+from synapse.federation.federation_base import (
+ event_from_pdu_json,
+)
+from synapse.federation.transport.client import SendJoinResponse
+from synapse.http.matrixfederationclient import (
+ ByteParser,
+)
+from synapse.http.types import QueryParams
+from synapse.rest import admin
+from synapse.rest.client import login, room, sync
+from synapse.server import HomeServer
+from synapse.types import JsonDict, MutableStateMap, StateMap
+from synapse.types.handlers.sliding_sync import (
+ StateValues,
+)
+from synapse.util import Clock
+
+from tests import unittest
+from tests.utils import test_timeout
+
+logger = logging.getLogger(__name__)
+
+
+def required_state_json_to_state_map(required_state: Any) -> StateMap[EventBase]:
+ state_map: MutableStateMap[EventBase] = {}
+
+ # Scrutinize the JSON values to ensure they're in the expected format
+ if isinstance(required_state, list):
+ for state_event_dict in required_state:
+ # Yell because we're in a test and this is unexpected
+ assert isinstance(state_event_dict, dict), (
+ "`required_state` should be a list of event dicts"
+ )
+
+ event_type = state_event_dict["type"]
+ event_state_key = state_event_dict["state_key"]
+
+ # Yell because we're in a test and this is unexpected
+ assert isinstance(event_type, str), (
+ "Each event in `required_state` should have a string `type`"
+ )
+ assert isinstance(event_state_key, str), (
+ "Each event in `required_state` should have a string `state_key`"
+ )
+
+ state_map[(event_type, event_state_key)] = make_event_from_dict(
+ state_event_dict
+ )
+ else:
+ # Yell because we're in a test and this is unexpected
+ raise AssertionError("`required_state` should be a list of event dicts")
+
+ return state_map
+
+
+@attr.s(slots=True, auto_attribs=True)
+class RemoteRoomJoinResult:
+ remote_room_id: str
+ room_version: RoomVersion
+ remote_room_creator_user_id: str
+ local_user1_id: str
+ local_user1_tok: str
+ state_map: StateMap[EventBase]
+
+
+class OutOfBandMembershipTests(unittest.FederatingHomeserverTestCase):
+ """
+ Tests to make sure that interactions with out-of-band membership (outliers) work as
+ expected.
+
+ - invites received over federation, before we join the room
+ - *rejections* for said invites
+
+ See the "Out-of-band membership events" section in
+ `docs/development/room-dag-concepts.md` for more information.
+ """
+
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ sync.register_servlets,
+ ]
+
+ sync_endpoint = "/_matrix/client/unstable/org.matrix.simplified_msc3575/sync"
+
+ def default_config(self) -> JsonDict:
+ conf = super().default_config()
+ # Federation sending is disabled by default in the test environment
+ # so we need to enable it like this.
+ conf["federation_sender_instances"] = ["master"]
+
+ return conf
+
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+ self.federation_http_client = Mock(
+ # The problem with using `spec=MatrixFederationHttpClient` here is that it
+ # requires everything to be mocked, which is a lot of work when the code
+ # only uses a few methods (`get_json` and `put_json`).
+ )
+ return self.setup_test_homeserver(
+ federation_http_client=self.federation_http_client
+ )
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ super().prepare(reactor, clock, hs)
+
+ self.store = self.hs.get_datastores().main
+ self.storage_controllers = hs.get_storage_controllers()
+
+ def do_sync(
+ self, sync_body: JsonDict, *, since: Optional[str] = None, tok: str
+ ) -> Tuple[JsonDict, str]:
+ """Do a sliding sync request with given body.
+
+ Asserts the request was successful.
+
+ Args:
+ sync_body: The full request body to use
+ since: Optional since token
+ tok: Access token to use
+
+ Returns:
+ A tuple of the response body and the `pos` field.
+ """
+
+ sync_path = self.sync_endpoint
+ if since:
+ sync_path += f"?pos={since}"
+
+ channel = self.make_request(
+ method="POST",
+ path=sync_path,
+ content=sync_body,
+ access_token=tok,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ return channel.json_body, channel.json_body["pos"]
+
+ def _invite_local_user_to_remote_room_and_join(self) -> RemoteRoomJoinResult:
+ """
+ Helper to reproduce this scenario:
+
+ 1. The remote user invites our local user to a room on their remote server (which
+ creates an out-of-band invite membership for user1 on our local server).
+ 2. The local user notices the invite from `/sync`.
+ 3. The local user joins the room.
+ 4. The local user can see that they are now joined to the room from `/sync`.
+ """
+
+ # Create a local user
+ local_user1_id = self.register_user("user1", "pass")
+ local_user1_tok = self.login(local_user1_id, "pass")
+
+ # Create a remote room
+ room_creator_user_id = f"@remote-user:{self.OTHER_SERVER_NAME}"
+ remote_room_id = f"!remote-room:{self.OTHER_SERVER_NAME}"
+ room_version = RoomVersions.V10
+
+ room_create_event = make_event_from_dict(
+ self.add_hashes_and_signatures_from_other_server(
+ {
+ "room_id": remote_room_id,
+ "sender": room_creator_user_id,
+ "depth": 1,
+ "origin_server_ts": 1,
+ "type": EventTypes.Create,
+ "state_key": "",
+ "content": {
+ # The `ROOM_CREATOR` field could be removed if we used a room
+ # version > 10 (in favor of relying on `sender`)
+ EventContentFields.ROOM_CREATOR: room_creator_user_id,
+ EventContentFields.ROOM_VERSION: room_version.identifier,
+ },
+ "auth_events": [],
+ "prev_events": [],
+ }
+ ),
+ room_version=room_version,
+ )
+
+ creator_membership_event = make_event_from_dict(
+ self.add_hashes_and_signatures_from_other_server(
+ {
+ "room_id": remote_room_id,
+ "sender": room_creator_user_id,
+ "depth": 2,
+ "origin_server_ts": 2,
+ "type": EventTypes.Member,
+ "state_key": room_creator_user_id,
+ "content": {"membership": Membership.JOIN},
+ "auth_events": [room_create_event.event_id],
+ "prev_events": [room_create_event.event_id],
+ }
+ ),
+ room_version=room_version,
+ )
+
+ # From the remote homeserver, invite user1 on the local homeserver
+ user1_invite_membership_event = make_event_from_dict(
+ self.add_hashes_and_signatures_from_other_server(
+ {
+ "room_id": remote_room_id,
+ "sender": room_creator_user_id,
+ "depth": 3,
+ "origin_server_ts": 3,
+ "type": EventTypes.Member,
+ "state_key": local_user1_id,
+ "content": {"membership": Membership.INVITE},
+ "auth_events": [
+ room_create_event.event_id,
+ creator_membership_event.event_id,
+ ],
+ "prev_events": [creator_membership_event.event_id],
+ }
+ ),
+ room_version=room_version,
+ )
+ channel = self.make_signed_federation_request(
+ "PUT",
+ f"/_matrix/federation/v2/invite/{remote_room_id}/{user1_invite_membership_event.event_id}",
+ content={
+ "event": user1_invite_membership_event.get_dict(),
+ "invite_room_state": [
+ strip_event(room_create_event),
+ ],
+ "room_version": room_version.identifier,
+ },
+ )
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
+
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": [(EventTypes.Member, StateValues.WILDCARD)],
+ "timeline_limit": 0,
+ }
+ }
+ }
+
+ # Sync until the local user1 can see the invite
+ with test_timeout(
+ 3,
+ "Unable to find user1's invite event in the room",
+ ):
+ while True:
+ response_body, _ = self.do_sync(sync_body, tok=local_user1_tok)
+ if (
+ remote_room_id in response_body["rooms"].keys()
+ # If they have `invite_state` for the room, they are invited
+ and len(
+ response_body["rooms"][remote_room_id].get("invite_state", [])
+ )
+ > 0
+ ):
+ break
+
+ # Prevent tight-looping to allow the `test_timeout` to work
+ time.sleep(0.1)
+
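+ # Unsigned join-event template for the mocked /make_join response; the "remote" server signs it in the /send_join mock below.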
+ user1_join_membership_event_template = make_event_from_dict(
+ {
+ "room_id": remote_room_id,
+ "sender": local_user1_id,
+ "depth": 4,
+ "origin_server_ts": 4,
+ "type": EventTypes.Member,
+ "state_key": local_user1_id,
+ "content": {"membership": Membership.JOIN},
+ "auth_events": [
+ room_create_event.event_id,
+ user1_invite_membership_event.event_id,
+ ],
+ "prev_events": [user1_invite_membership_event.event_id],
+ },
+ room_version=room_version,
+ )
+
+ T = TypeVar("T")
+
+ # Mock the remote homeserver responding to our HTTP requests
+ #
+ # We're going to mock the following endpoints so that user1 can join the remote room:
+ # - GET /_matrix/federation/v1/make_join/{room_id}/{user_id}
+ # - PUT /_matrix/federation/v2/send_join/{room_id}/{user_id}
+ #
+ async def get_json(
+ destination: str,
+ path: str,
+ args: Optional[QueryParams] = None,
+ retry_on_dns_fail: bool = True,
+ timeout: Optional[int] = None,
+ ignore_backoff: bool = False,
+ try_trailing_slash_on_400: bool = False,
+ parser: Optional[ByteParser[T]] = None,
+ ) -> Union[JsonDict, T]:
+ if (
+ path
+ == f"/_matrix/federation/v1/make_join/{urllib.parse.quote_plus(remote_room_id)}/{urllib.parse.quote_plus(local_user1_id)}"
+ ):
+ return {
+ "event": user1_join_membership_event_template.get_pdu_json(),
+ "room_version": room_version.identifier,
+ }
+
+ raise NotImplementedError(
+ "We have not mocked a response for `get_json(...)` for the following endpoint yet: "
+ + f"{destination}{path}"
+ )
+
+ self.federation_http_client.get_json.side_effect = get_json
+
+ # PDUs that hs1 sent to hs2
+ collected_pdus_from_hs1_federation_send: Set[str] = set()
+
+ async def put_json(
+ destination: str,
+ path: str,
+ args: Optional[QueryParams] = None,
+ data: Optional[JsonDict] = None,
+ json_data_callback: Optional[Callable[[], JsonDict]] = None,
+ long_retries: bool = False,
+ timeout: Optional[int] = None,
+ ignore_backoff: bool = False,
+ backoff_on_404: bool = False,
+ try_trailing_slash_on_400: bool = False,
+ parser: Optional[ByteParser[T]] = None,
+ backoff_on_all_error_codes: bool = False,
+ ) -> Union[JsonDict, T, SendJoinResponse]:
+ if (
+ path.startswith(
+ f"/_matrix/federation/v2/send_join/{urllib.parse.quote_plus(remote_room_id)}/"
+ )
+ and data is not None
+ and data.get("type") == EventTypes.Member
+ and data.get("state_key") == local_user1_id
+ # We're assuming this is a `ByteParser[SendJoinResponse]`
+ and parser is not None
+ ):
+ # As the remote server, we need to sign the event before sending it back
+ user1_join_membership_event_signed = make_event_from_dict(
+ self.add_hashes_and_signatures_from_other_server(data),
+ room_version=room_version,
+ )
+
+ # Since they passed in a `parser`, we need to return the type that
+ # they're expecting instead of just a `JsonDict`
+ return SendJoinResponse(
+ auth_events=[
+ room_create_event,
+ user1_invite_membership_event,
+ ],
+ state=[
+ room_create_event,
+ creator_membership_event,
+ user1_invite_membership_event,
+ ],
+ event_dict=user1_join_membership_event_signed.get_pdu_json(),
+ event=user1_join_membership_event_signed,
+ members_omitted=False,
+ servers_in_room=[
+ self.OTHER_SERVER_NAME,
+ ],
+ )
+
+ if path.startswith("/_matrix/federation/v1/send/") and data is not None:
+ for pdu in data.get("pdus", []):
+ event = event_from_pdu_json(pdu, room_version)
+ collected_pdus_from_hs1_federation_send.add(event.event_id)
+
+ # Just acknowledge everything hs1 is trying to send hs2
+ return {
+ event_from_pdu_json(pdu, room_version).event_id: {}
+ for pdu in data.get("pdus", [])
+ }
+
+ raise NotImplementedError(
+ "We have not mocked a response for `put_json(...)` for the following endpoint yet: "
+ + f"{destination}{path} with the following body data: {data}"
+ )
+
+ self.federation_http_client.put_json.side_effect = put_json
+
+ # User1 joins the room
+ self.helper.join(remote_room_id, local_user1_id, tok=local_user1_tok)
+
+ # Reset the mocks now that user1 has joined the room
+ self.federation_http_client.get_json.side_effect = None
+ self.federation_http_client.put_json.side_effect = None
+
+ # Sync until the local user1 can see that they are now joined to the room
+ with test_timeout(
+ 3,
+ "Unable to find user1's join event in the room",
+ ):
+ while True:
+ response_body, _ = self.do_sync(sync_body, tok=local_user1_tok)
+ if remote_room_id in response_body["rooms"].keys():
+ required_state_map = required_state_json_to_state_map(
+ response_body["rooms"][remote_room_id]["required_state"]
+ )
+ if (
+ required_state_map.get((EventTypes.Member, local_user1_id))
+ is not None
+ ):
+ break
+
+ # Prevent tight-looping to allow the `test_timeout` to work
+ time.sleep(0.1)
+
+ # Nothing needs to be sent from hs1 to hs2 since we already let the other
+ # homeserver know by doing the `/make_join` and `/send_join` dance.
+ self.assertIncludes(
+ collected_pdus_from_hs1_federation_send,
+ set(),
+ exact=True,
+ message="Didn't expect any events to be sent from hs1 over federation to hs2",
+ )
+
+ return RemoteRoomJoinResult(
+ remote_room_id=remote_room_id,
+ room_version=room_version,
+ remote_room_creator_user_id=room_creator_user_id,
+ local_user1_id=local_user1_id,
+ local_user1_tok=local_user1_tok,
+ state_map=self.get_success(
+ self.storage_controllers.state.get_current_state(remote_room_id)
+ ),
+ )
+
+ def test_can_join_from_out_of_band_invite(self) -> None:
+ """
+ Test to make sure that we can join a room that we were invited to over
+ federation; even if our server has never participated in the room before.
+ """
+ self._invite_local_user_to_remote_room_and_join()
+
+ @parameterized.expand(
+ [("accept invite", Membership.JOIN), ("reject invite", Membership.LEAVE)]
+ )
+ def test_can_x_from_out_of_band_invite_after_we_are_already_participating_in_the_room(
+ self, _test_description: str, membership_action: str
+ ) -> None:
+ """
+ Test to make sure that we can either a) join the room (accept the invite) or
+ b) reject the invite after being invited over federation, even if we are
+ already participating in the room.
+
+ This is a regression test to make sure we stress the scenario where even though
+ we are already participating in the room, local users can still react to invites
+ regardless of whether the remote server has told us about the invite event (via
+ a federation `/send` transaction) and we have de-outliered the invite event.
+ Previously, we would mistakenly throw an error saying the user wasn't in the
+ room when they tried to join or reject the invite.
+ """
+ remote_room_join_result = self._invite_local_user_to_remote_room_and_join()
+ remote_room_id = remote_room_join_result.remote_room_id
+ room_version = remote_room_join_result.room_version
+
+ # Create another local user
+ local_user2_id = self.register_user("user2", "pass")
+ local_user2_tok = self.login(local_user2_id, "pass")
+
+ T = TypeVar("T")
+
+ # PDUs that hs1 sent to hs2
+ collected_pdus_from_hs1_federation_send: Set[str] = set()
+
+ async def put_json(
+ destination: str,
+ path: str,
+ args: Optional[QueryParams] = None,
+ data: Optional[JsonDict] = None,
+ json_data_callback: Optional[Callable[[], JsonDict]] = None,
+ long_retries: bool = False,
+ timeout: Optional[int] = None,
+ ignore_backoff: bool = False,
+ backoff_on_404: bool = False,
+ try_trailing_slash_on_400: bool = False,
+ parser: Optional[ByteParser[T]] = None,
+ backoff_on_all_error_codes: bool = False,
+ ) -> Union[JsonDict, T]:
+ if path.startswith("/_matrix/federation/v1/send/") and data is not None:
+ for pdu in data.get("pdus", []):
+ event = event_from_pdu_json(pdu, room_version)
+ collected_pdus_from_hs1_federation_send.add(event.event_id)
+
+ # Just acknowledge everything hs1 is trying to send hs2
+ return {
+ event_from_pdu_json(pdu, room_version).event_id: {}
+ for pdu in data.get("pdus", [])
+ }
+
+ raise NotImplementedError(
+ "We have not mocked a response for `put_json(...)` for the following endpoint yet: "
+ + f"{destination}{path} with the following body data: {data}"
+ )
+
+ self.federation_http_client.put_json.side_effect = put_json
+
+        # From the remote homeserver, invite user2 on the local homeserver
+ user2_invite_membership_event = make_event_from_dict(
+ self.add_hashes_and_signatures_from_other_server(
+ {
+ "room_id": remote_room_id,
+ "sender": remote_room_join_result.remote_room_creator_user_id,
+ "depth": 5,
+ "origin_server_ts": 5,
+ "type": EventTypes.Member,
+ "state_key": local_user2_id,
+ "content": {"membership": Membership.INVITE},
+ "auth_events": [
+ remote_room_join_result.state_map[
+ (EventTypes.Create, "")
+ ].event_id,
+ remote_room_join_result.state_map[
+ (
+ EventTypes.Member,
+ remote_room_join_result.remote_room_creator_user_id,
+ )
+ ].event_id,
+ ],
+ "prev_events": [
+ remote_room_join_result.state_map[
+ (EventTypes.Member, remote_room_join_result.local_user1_id)
+ ].event_id
+ ],
+ }
+ ),
+ room_version=room_version,
+ )
+ channel = self.make_signed_federation_request(
+ "PUT",
+ f"/_matrix/federation/v2/invite/{remote_room_id}/{user2_invite_membership_event.event_id}",
+ content={
+ "event": user2_invite_membership_event.get_dict(),
+ "invite_room_state": [
+ strip_event(
+ remote_room_join_result.state_map[(EventTypes.Create, "")]
+ ),
+ ],
+ "room_version": room_version.identifier,
+ },
+ )
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
+
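+        # A sliding-sync request body that only tracks membership state, which is
+        # all we need to observe the invite and the membership change that follows.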
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": [(EventTypes.Member, StateValues.WILDCARD)],
+ "timeline_limit": 0,
+ }
+ }
+ }
+
+ # Sync until the local user2 can see the invite
+ with test_timeout(
+ 3,
+ "Unable to find user2's invite event in the room",
+ ):
+ while True:
+ response_body, _ = self.do_sync(sync_body, tok=local_user2_tok)
+ if (
+ remote_room_id in response_body["rooms"].keys()
+ # If they have `invite_state` for the room, they are invited
+ and len(
+ response_body["rooms"][remote_room_id].get("invite_state", [])
+ )
+ > 0
+ ):
+ break
+
+ # Prevent tight-looping to allow the `test_timeout` to work
+ time.sleep(0.1)
+
+ if membership_action == Membership.JOIN:
+ # User2 joins the room
+ join_event = self.helper.join(
+ remote_room_join_result.remote_room_id,
+ local_user2_id,
+ tok=local_user2_tok,
+ )
+ expected_pdu_event_id = join_event["event_id"]
+ elif membership_action == Membership.LEAVE:
+ # User2 rejects the invite
+ leave_event = self.helper.leave(
+ remote_room_join_result.remote_room_id,
+ local_user2_id,
+ tok=local_user2_tok,
+ )
+ expected_pdu_event_id = leave_event["event_id"]
+ else:
+ raise NotImplementedError(
+ "This test does not support this membership action yet"
+ )
+
+ # Sync until the local user2 can see their new membership in the room
+ with test_timeout(
+ 3,
+ "Unable to find user2's new membership event in the room",
+ ):
+ while True:
+ response_body, _ = self.do_sync(sync_body, tok=local_user2_tok)
+ if membership_action == Membership.JOIN:
+ if remote_room_id in response_body["rooms"].keys():
+ required_state_map = required_state_json_to_state_map(
+ response_body["rooms"][remote_room_id]["required_state"]
+ )
+ if (
+ required_state_map.get((EventTypes.Member, local_user2_id))
+ is not None
+ ):
+ break
+ elif membership_action == Membership.LEAVE:
+ if remote_room_id not in response_body["rooms"].keys():
+ break
+ else:
+ raise NotImplementedError(
+ "This test does not support this membership action yet"
+ )
+
+ # Prevent tight-looping to allow the `test_timeout` to work
+ time.sleep(0.1)
+
+ # Make sure that we let hs2 know about the new membership event
+ self.assertIncludes(
+ collected_pdus_from_hs1_federation_send,
+ {expected_pdu_event_id},
+ exact=True,
+            message="Expected the user2 membership event to be sent from hs1 over federation to hs2",
+ )
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index 6a8887fe74..cd906bbbc7 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -34,6 +34,7 @@ from synapse.handlers.device import DeviceHandler
from synapse.rest import admin
from synapse.rest.client import login
from synapse.server import HomeServer
+from synapse.storage.databases.main.events_worker import EventMetadata
from synapse.types import JsonDict, ReadReceipt
from synapse.util import Clock
@@ -55,12 +56,15 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
federation_transport_client=self.federation_transport_client,
)
- hs.get_storage_controllers().state.get_current_hosts_in_room = AsyncMock( # type: ignore[method-assign]
+ self.main_store = hs.get_datastores().main
+ self.state_controller = hs.get_storage_controllers().state
+
+ self.state_controller.get_current_hosts_in_room = AsyncMock( # type: ignore[method-assign]
return_value={"test", "host2"}
)
- hs.get_storage_controllers().state.get_current_hosts_in_room_or_partial_state_approximation = ( # type: ignore[method-assign]
- hs.get_storage_controllers().state.get_current_hosts_in_room
+ self.state_controller.get_current_hosts_in_room_or_partial_state_approximation = ( # type: ignore[method-assign]
+ self.state_controller.get_current_hosts_in_room
)
return hs
@@ -185,12 +189,15 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
],
)
- def test_send_receipts_with_backoff(self) -> None:
- """Send two receipts in quick succession; the second should be flushed, but
- only after 20ms"""
+ def test_send_receipts_with_backoff_small_room(self) -> None:
+        """Read receipts in small rooms should not be delayed"""
mock_send_transaction = self.federation_transport_client.send_transaction
mock_send_transaction.return_value = {}
+ self.state_controller.get_current_hosts_in_room_or_partial_state_approximation = AsyncMock( # type: ignore[method-assign]
+ return_value={"test", "host2"}
+ )
+
sender = self.hs.get_federation_sender()
receipt = ReadReceipt(
"room_id",
@@ -206,47 +213,104 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
# expect a call to send_transaction
mock_send_transaction.assert_called_once()
- json_cb = mock_send_transaction.call_args[0][1]
- data = json_cb()
- self.assertEqual(
- data["edus"],
- [
- {
- "edu_type": EduTypes.RECEIPT,
- "content": {
- "room_id": {
- "m.read": {
- "user_id": {
- "event_ids": ["event_id"],
- "data": {"ts": 1234},
- }
- }
- }
- },
- }
- ],
+ self._assert_edu_in_call(mock_send_transaction.call_args[0][1])
+
+ def test_send_receipts_with_backoff_recent_event(self) -> None:
+        """A read receipt for a recent message should not be delayed"""
+ mock_send_transaction = self.federation_transport_client.send_transaction
+ mock_send_transaction.return_value = {}
+
+ # Pretend this is a big room
+ self.state_controller.get_current_hosts_in_room_or_partial_state_approximation = AsyncMock( # type: ignore[method-assign]
+ return_value={"test"} | {f"host{i}" for i in range(20)}
)
+
+ self.main_store.get_metadata_for_event = AsyncMock(
+ return_value=EventMetadata(
+ received_ts=self.clock.time_msec(),
+ sender="@test:test",
+ )
+ )
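+        # The receipt will reference an event received "just now", so the
+        # recent-event fast path under test should flush it to every host
+        # straight away instead of delaying it.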
+
+ sender = self.hs.get_federation_sender()
+ receipt = ReadReceipt(
+ "room_id",
+ "m.read",
+ "user_id",
+ ["event_id"],
+ thread_id=None,
+ data={"ts": 1234},
+ )
+ self.get_success(sender.send_read_receipt(receipt))
+
+ self.pump()
+
+ # expect a call to send_transaction for each host
+ self.assertEqual(mock_send_transaction.call_count, 20)
+ self._assert_edu_in_call(mock_send_transaction.call_args.args[1])
+
mock_send_transaction.reset_mock()
- # send the second RR
+ def test_send_receipts_with_backoff_sender(self) -> None:
+        """A read receipt for a message should not be delayed to the message's
+        sender, but should be delayed to everyone else"""
+ mock_send_transaction = self.federation_transport_client.send_transaction
+ mock_send_transaction.return_value = {}
+
+ # Pretend this is a big room
+ self.state_controller.get_current_hosts_in_room_or_partial_state_approximation = AsyncMock( # type: ignore[method-assign]
+ return_value={"test"} | {f"host{i}" for i in range(20)}
+ )
+
+ self.main_store.get_metadata_for_event = AsyncMock(
+ return_value=EventMetadata(
+ received_ts=self.clock.time_msec() - 5 * 60_000,
+ sender="@test:host1",
+ )
+ )
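+        # The event is five minutes old and was sent from host1, so host1 should
+        # get the receipt straight away, while the remaining hosts are drip-fed
+        # under the per-room rate limit exercised below.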
+
+ sender = self.hs.get_federation_sender()
receipt = ReadReceipt(
"room_id",
"m.read",
"user_id",
- ["other_id"],
+ ["event_id"],
thread_id=None,
data={"ts": 1234},
)
- self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt)))
+ self.get_success(sender.send_read_receipt(receipt))
+
self.pump()
- mock_send_transaction.assert_not_called()
- self.reactor.advance(19)
- mock_send_transaction.assert_not_called()
+ # First, expect a call to send_transaction for the sending host
+ mock_send_transaction.assert_called()
- self.reactor.advance(10)
- mock_send_transaction.assert_called_once()
- json_cb = mock_send_transaction.call_args[0][1]
+ transaction = mock_send_transaction.call_args_list[0].args[0]
+ self.assertEqual(transaction.destination, "host1")
+ self._assert_edu_in_call(mock_send_transaction.call_args_list[0].args[1])
+
+ # We also expect a call to one of the other hosts, as the first
+ # destination to wake up.
+ self.assertEqual(mock_send_transaction.call_count, 2)
+ self._assert_edu_in_call(mock_send_transaction.call_args.args[1])
+
+ mock_send_transaction.reset_mock()
+
+ # We now expect to see 18 more transactions to the remaining hosts
+ # periodically.
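+        # Each reactor advance below covers 1/rate seconds (e.g. 0.02s if the
+        # default of 50 transactions per room per second applies), which should
+        # wake up one queued destination at a time.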
+ for _ in range(18):
+ self.reactor.advance(
+ 1.0
+ / self.hs.config.ratelimiting.federation_rr_transactions_per_room_per_second
+ )
+
+ mock_send_transaction.assert_called_once()
+ self._assert_edu_in_call(mock_send_transaction.call_args.args[1])
+ mock_send_transaction.reset_mock()
+
+ def _assert_edu_in_call(self, json_cb: Callable[[], JsonDict]) -> None:
+ """Assert that the given `json_cb` from a `send_transaction` has a
+ receipt in it."""
data = json_cb()
self.assertEqual(
data["edus"],
@@ -257,7 +321,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
"room_id": {
"m.read": {
"user_id": {
- "event_ids": ["other_id"],
+ "event_ids": ["event_id"],
"data": {"ts": 1234},
}
}
diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py
index 88261450b1..42dc844734 100644
--- a/tests/federation/test_federation_server.py
+++ b/tests/federation/test_federation_server.py
@@ -20,14 +20,21 @@
#
import logging
from http import HTTPStatus
+from typing import Optional, Union
+from unittest.mock import Mock
from parameterized import parameterized
from twisted.test.proto_helpers import MemoryReactor
-from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
+from synapse.api.constants import EventTypes, Membership
+from synapse.api.errors import FederationError
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
from synapse.config.server import DEFAULT_ROOM_VERSION
from synapse.events import EventBase, make_event_from_dict
+from synapse.federation.federation_base import event_from_pdu_json
+from synapse.http.types import QueryParams
+from synapse.logging.context import LoggingContext
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
@@ -85,6 +92,163 @@ class FederationServerTests(unittest.FederatingHomeserverTestCase):
self.assertEqual(500, channel.code, channel.result)
+def _create_acl_event(content: JsonDict) -> EventBase:
+ return make_event_from_dict(
+ {
+ "room_id": "!a:b",
+ "event_id": "$a:b",
+ "type": "m.room.server_acls",
+ "sender": "@a:b",
+ "content": content,
+ }
+ )
+
+
+class MessageAcceptTests(unittest.FederatingHomeserverTestCase):
+ """
+ Tests to make sure that we don't accept flawed events from federation (incoming).
+ """
+
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+ self.http_client = Mock()
+ return self.setup_test_homeserver(federation_http_client=self.http_client)
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ super().prepare(reactor, clock, hs)
+
+ self.store = self.hs.get_datastores().main
+ self.storage_controllers = hs.get_storage_controllers()
+ self.federation_event_handler = self.hs.get_federation_event_handler()
+
+ # Create a local room
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ self.room_id = self.helper.create_room_as(
+ user1_id, tok=user1_tok, is_public=True
+ )
+
+ state_map = self.get_success(
+ self.storage_controllers.state.get_current_state(self.room_id)
+ )
+
+        # Figure out what the forward extremities in the room are (the most recent
+        # events that no other event references yet)
+ forward_extremity_event_ids = self.get_success(
+ self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id)
+ )
+
+        # Join a remote user, who will attempt to send bad events, to the room
+ self.remote_bad_user_id = f"@baduser:{self.OTHER_SERVER_NAME}"
+ self.remote_bad_user_join_event = make_event_from_dict(
+ self.add_hashes_and_signatures_from_other_server(
+ {
+ "room_id": self.room_id,
+ "sender": self.remote_bad_user_id,
+ "state_key": self.remote_bad_user_id,
+ "depth": 1000,
+ "origin_server_ts": 1,
+ "type": EventTypes.Member,
+ "content": {"membership": Membership.JOIN},
+ "auth_events": [
+ state_map[(EventTypes.Create, "")].event_id,
+ state_map[(EventTypes.JoinRules, "")].event_id,
+ ],
+ "prev_events": list(forward_extremity_event_ids),
+ }
+ ),
+ room_version=RoomVersions.V10,
+ )
+
+        # Send the join; it should return None (which is not an error)
+ self.assertEqual(
+ self.get_success(
+ self.federation_event_handler.on_receive_pdu(
+ self.OTHER_SERVER_NAME, self.remote_bad_user_join_event
+ )
+ ),
+ None,
+ )
+
+ # Make sure we actually joined the room
+ self.assertEqual(
+ self.get_success(self.store.get_latest_event_ids_in_room(self.room_id)),
+ {self.remote_bad_user_join_event.event_id},
+ )
+
+ def test_cant_hide_direct_ancestors(self) -> None:
+ """
+ If you send a message, you must be able to provide the direct
+ prev_events that said event references.
+ """
+
+ async def post_json(
+ destination: str,
+ path: str,
+ data: Optional[JsonDict] = None,
+ long_retries: bool = False,
+ timeout: Optional[int] = None,
+ ignore_backoff: bool = False,
+ args: Optional[QueryParams] = None,
+ ) -> Union[JsonDict, list]:
+ # If it asks us for new missing events, give them NOTHING
+ if path.startswith("/_matrix/federation/v1/get_missing_events/"):
+ return {"events": []}
+ return {}
+
+ self.http_client.post_json = post_json
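+        # Returning no events from `/get_missing_events` means our server can
+        # never fill the gap behind the lying event, which is what should trigger
+        # the 403 below.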
+
+        # Figure out what the forward extremities in the room are (the most recent
+        # events that no other event references yet)
+ forward_extremity_event_ids = self.get_success(
+ self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id)
+ )
+
+ # Now lie about an event's prev_events
+ lying_event = make_event_from_dict(
+ self.add_hashes_and_signatures_from_other_server(
+ {
+ "room_id": self.room_id,
+ "sender": self.remote_bad_user_id,
+ "depth": 1000,
+ "origin_server_ts": 1,
+ "type": "m.room.message",
+ "content": {"body": "hewwo?"},
+ "auth_events": [],
+ "prev_events": ["$missing_prev_event"]
+ + list(forward_extremity_event_ids),
+ }
+ ),
+ room_version=RoomVersions.V10,
+ )
+
+ with LoggingContext("test-context"):
+ failure = self.get_failure(
+ self.federation_event_handler.on_receive_pdu(
+ self.OTHER_SERVER_NAME, lying_event
+ ),
+ FederationError,
+ )
+
+ # on_receive_pdu should throw an error
+ self.assertEqual(
+ failure.value.args[0],
+ (
+ "ERROR 403: Your server isn't divulging details about prev_events "
+ "referenced in this event."
+ ),
+ )
+
+ # Make sure the invalid event isn't there
+ extrem = self.get_success(self.store.get_latest_event_ids_in_room(self.room_id))
+ self.assertEqual(extrem, {self.remote_bad_user_join_event.event_id})
+
+
class ServerACLsTestCase(unittest.TestCase):
def test_blocked_server(self) -> None:
e = _create_acl_event({"allow": ["*"], "deny": ["evil.com"]})
@@ -355,13 +519,76 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
# is probably sufficient to reassure that the bucket is updated.
-def _create_acl_event(content: JsonDict) -> EventBase:
- return make_event_from_dict(
- {
- "room_id": "!a:b",
- "event_id": "$a:b",
- "type": "m.room.server_acls",
- "sender": "@a:b",
- "content": content,
+class StripUnsignedFromEventsTestCase(unittest.TestCase):
+ """
+ Test to make sure that we handle the raw JSON events from federation carefully and
+ strip anything that shouldn't be there.
+ """
+
+ def test_strip_unauthorized_unsigned_values(self) -> None:
+ event1 = {
+ "sender": "@baduser:test.serv",
+ "state_key": "@baduser:test.serv",
+ "event_id": "$event1:test.serv",
+ "depth": 1000,
+ "origin_server_ts": 1,
+ "type": "m.room.member",
+ "origin": "test.servx",
+ "content": {"membership": "join"},
+ "auth_events": [],
+ "unsigned": {"malicious garbage": "hackz", "more warez": "more hackz"},
}
- )
+ filtered_event = event_from_pdu_json(event1, RoomVersions.V1)
+ # Make sure unauthorized fields are stripped from unsigned
+ self.assertNotIn("more warez", filtered_event.unsigned)
+
+ def test_strip_event_maintains_allowed_fields(self) -> None:
+ event2 = {
+ "sender": "@baduser:test.serv",
+ "state_key": "@baduser:test.serv",
+ "event_id": "$event2:test.serv",
+ "depth": 1000,
+ "origin_server_ts": 1,
+ "type": "m.room.member",
+ "origin": "test.servx",
+ "auth_events": [],
+ "content": {"membership": "join"},
+ "unsigned": {
+ "malicious garbage": "hackz",
+ "more warez": "more hackz",
+ "age": 14,
+ "invite_room_state": [],
+ },
+ }
+
+ filtered_event2 = event_from_pdu_json(event2, RoomVersions.V1)
+ self.assertIn("age", filtered_event2.unsigned)
+ self.assertEqual(14, filtered_event2.unsigned["age"])
+ self.assertNotIn("more warez", filtered_event2.unsigned)
+        # `invite_room_state` is allowed in events of type m.room.member
+ self.assertIn("invite_room_state", filtered_event2.unsigned)
+ self.assertEqual([], filtered_event2.unsigned["invite_room_state"])
+
+ def test_strip_event_removes_fields_based_on_event_type(self) -> None:
+ event3 = {
+ "sender": "@baduser:test.serv",
+ "state_key": "@baduser:test.serv",
+ "event_id": "$event3:test.serv",
+ "depth": 1000,
+ "origin_server_ts": 1,
+ "type": "m.room.power_levels",
+ "origin": "test.servx",
+ "content": {},
+ "auth_events": [],
+ "unsigned": {
+ "malicious garbage": "hackz",
+ "more warez": "more hackz",
+ "age": 14,
+ "invite_room_state": [],
+ },
+ }
+ filtered_event3 = event_from_pdu_json(event3, RoomVersions.V1)
+ self.assertIn("age", filtered_event3.unsigned)
+        # The `invite_room_state` field is only permitted in m.room.member events
+ self.assertNotIn("invite_room_state", filtered_event3.unsigned)
+ self.assertNotIn("more warez", filtered_event3.unsigned)
diff --git a/tests/handlers/test_admin.py b/tests/handlers/test_admin.py
index 9ff853a83d..c5bff468e2 100644
--- a/tests/handlers/test_admin.py
+++ b/tests/handlers/test_admin.py
@@ -260,7 +260,6 @@ class ExfiltrateData(unittest.HomeserverTestCase):
self.assertEqual(args[0]["name"], self.user2)
self.assertIn("displayname", args[0])
self.assertIn("avatar_url", args[0])
- self.assertIn("threepids", args[0])
self.assertIn("external_ids", args[0])
self.assertIn("creation_ts", args[0])
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index 1eec0d43b7..1db630e9e4 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -1165,12 +1165,23 @@ class ApplicationServicesHandlerOtkCountsTestCase(unittest.HomeserverTestCase):
self.hs.get_datastores().main.services_cache = [self._service]
# Register some appservice users
- self._sender_user, self._sender_device = self.register_appservice_user(
+ user_id, device_id = self.register_appservice_user(
"as.sender", self._service_token
)
- self._namespaced_user, self._namespaced_device = self.register_appservice_user(
+ # With MSC4190 enabled, there will not be a device created
+        # during AS registration. However, MSC4190 is not enabled
+ # in this test. It may become the default behaviour in the
+ # future, in which case this test will need to be updated.
+ assert device_id is not None
+ self._sender_user = user_id
+ self._sender_device = device_id
+
+ user_id, device_id = self.register_appservice_user(
"_as_user1", self._service_token
)
+ assert device_id is not None
+ self._namespaced_user = user_id
+ self._namespaced_device = device_id
# Register a real user as well.
self._real_user = self.register_user("real.user", "meow")
diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py
deleted file mode 100644
index f41f7d36ad..0000000000
--- a/tests/handlers/test_cas.py
+++ /dev/null
@@ -1,239 +0,0 @@
-#
-# This file is licensed under the Affero General Public License (AGPL) version 3.
-#
-# Copyright 2020 The Matrix.org Foundation C.I.C.
-# Copyright (C) 2023 New Vector, Ltd
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Affero General Public License as
-# published by the Free Software Foundation, either version 3 of the
-# License, or (at your option) any later version.
-#
-# See the GNU Affero General Public License for more details:
-# <https://www.gnu.org/licenses/agpl-3.0.html>.
-#
-# Originally licensed under the Apache License, Version 2.0:
-# <http://www.apache.org/licenses/LICENSE-2.0>.
-#
-# [This file includes modifications made by New Vector Limited]
-#
-#
-from typing import Any, Dict
-from unittest.mock import AsyncMock, Mock
-
-from twisted.test.proto_helpers import MemoryReactor
-
-from synapse.handlers.cas import CasResponse
-from synapse.server import HomeServer
-from synapse.util import Clock
-
-from tests.unittest import HomeserverTestCase, override_config
-
-# These are a few constants that are used as config parameters in the tests.
-BASE_URL = "https://synapse/"
-SERVER_URL = "https://issuer/"
-
-
-class CasHandlerTestCase(HomeserverTestCase):
- def default_config(self) -> Dict[str, Any]:
- config = super().default_config()
- config["public_baseurl"] = BASE_URL
- cas_config = {
- "enabled": True,
- "server_url": SERVER_URL,
- "service_url": BASE_URL,
- }
-
- # Update this config with what's in the default config so that
- # override_config works as expected.
- cas_config.update(config.get("cas_config", {}))
- config["cas_config"] = cas_config
-
- return config
-
- def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
- hs = self.setup_test_homeserver()
-
- self.handler = hs.get_cas_handler()
-
- # Reduce the number of attempts when generating MXIDs.
- sso_handler = hs.get_sso_handler()
- sso_handler._MAP_USERNAME_RETRIES = 3
-
- return hs
-
- def test_map_cas_user_to_user(self) -> None:
- """Ensure that mapping the CAS user returned from a provider to an MXID works properly."""
-
- # stub out the auth handler
- auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign]
-
- cas_response = CasResponse("test_user", {})
- request = _mock_request()
- self.get_success(
- self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
- )
-
- # check that the auth handler got called as expected
- auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user:test",
- "cas",
- request,
- "redirect_uri",
- None,
- new_user=True,
- auth_provider_session_id=None,
- )
-
- def test_map_cas_user_to_existing_user(self) -> None:
- """Existing users can log in with CAS account."""
- store = self.hs.get_datastores().main
- self.get_success(
- store.register_user(user_id="@test_user:test", password_hash=None)
- )
-
- # stub out the auth handler
- auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign]
-
- # Map a user via SSO.
- cas_response = CasResponse("test_user", {})
- request = _mock_request()
- self.get_success(
- self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
- )
-
- # check that the auth handler got called as expected
- auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user:test",
- "cas",
- request,
- "redirect_uri",
- None,
- new_user=False,
- auth_provider_session_id=None,
- )
-
- # Subsequent calls should map to the same mxid.
- auth_handler.complete_sso_login.reset_mock()
- self.get_success(
- self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
- )
- auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user:test",
- "cas",
- request,
- "redirect_uri",
- None,
- new_user=False,
- auth_provider_session_id=None,
- )
-
- def test_map_cas_user_to_invalid_localpart(self) -> None:
- """CAS automaps invalid characters to base-64 encoding."""
-
- # stub out the auth handler
- auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign]
-
- cas_response = CasResponse("föö", {})
- request = _mock_request()
- self.get_success(
- self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
- )
-
- # check that the auth handler got called as expected
- auth_handler.complete_sso_login.assert_called_once_with(
- "@f=c3=b6=c3=b6:test",
- "cas",
- request,
- "redirect_uri",
- None,
- new_user=True,
- auth_provider_session_id=None,
- )
-
- @override_config(
- {
- "cas_config": {
- "required_attributes": {"userGroup": "staff", "department": None}
- }
- }
- )
- def test_required_attributes(self) -> None:
- """The required attributes must be met from the CAS response."""
-
- # stub out the auth handler
- auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign]
-
- # The response doesn't have the proper userGroup or department.
- cas_response = CasResponse("test_user", {})
- request = _mock_request()
- self.get_success(
- self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
- )
- auth_handler.complete_sso_login.assert_not_called()
-
- # The response doesn't have any department.
- cas_response = CasResponse("test_user", {"userGroup": ["staff"]})
- request.reset_mock()
- self.get_success(
- self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
- )
- auth_handler.complete_sso_login.assert_not_called()
-
- # Add the proper attributes and it should succeed.
- cas_response = CasResponse(
- "test_user", {"userGroup": ["staff", "admin"], "department": ["sales"]}
- )
- request.reset_mock()
- self.get_success(
- self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
- )
-
- # check that the auth handler got called as expected
- auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user:test",
- "cas",
- request,
- "redirect_uri",
- None,
- new_user=True,
- auth_provider_session_id=None,
- )
-
- @override_config({"cas_config": {"enable_registration": False}})
- def test_map_cas_user_does_not_register_new_user(self) -> None:
- """Ensures new users are not registered if the enabled registration flag is disabled."""
-
- # stub out the auth handler
- auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign]
-
- cas_response = CasResponse("test_user", {})
- request = _mock_request()
- self.get_success(
- self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
- )
-
- # check that the auth handler was not called as expected
- auth_handler.complete_sso_login.assert_not_called()
-
-
-def _mock_request() -> Mock:
- """Returns a mock which will stand in as a SynapseRequest"""
- mock = Mock(
- spec=[
- "finish",
- "getClientAddress",
- "getHeader",
- "setHeader",
- "setResponseCode",
- "write",
- ]
- )
- # `_disconnected` musn't be another `Mock`, otherwise it will be truthy.
- mock._disconnected = False
- return mock
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index 4a3e36ffde..b7058d8002 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -587,6 +587,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
self.room_list_handler = hs.get_room_list_handler()
self.directory_handler = hs.get_directory_handler()
+ @unittest.override_config({"room_list_publication_rules": [{"action": "allow"}]})
def test_disabling_room_list(self) -> None:
self.room_list_handler.enable_room_list_search = True
self.directory_handler.enable_room_list_search = True
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 8a3dfdcf75..70fc4263e7 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -19,6 +19,7 @@
# [This file includes modifications made by New Vector Limited]
#
#
+import time
from typing import Dict, Iterable
from unittest import mock
@@ -151,18 +152,30 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
def test_claim_one_time_key(self) -> None:
local_user = "@boris:" + self.hs.hostname
device_id = "xyz"
- keys = {"alg1:k1": "key1"}
-
res = self.get_success(
self.handler.upload_keys_for_user(
- local_user, device_id, {"one_time_keys": keys}
+ local_user, device_id, {"one_time_keys": {"alg1:k1": "key1"}}
)
)
self.assertDictEqual(
res, {"one_time_key_counts": {"alg1": 1, "signed_curve25519": 0}}
)
- res2 = self.get_success(
+ # Keys should be returned in the order they were uploaded. To test, advance time
+ # a little, then upload a second key with an earlier key ID; it should get
+ # returned second.
+ self.reactor.advance(1)
+ res = self.get_success(
+ self.handler.upload_keys_for_user(
+ local_user, device_id, {"one_time_keys": {"alg1:k0": "key0"}}
+ )
+ )
+ self.assertDictEqual(
+ res, {"one_time_key_counts": {"alg1": 2, "signed_curve25519": 0}}
+ )
+
+ # now claim both keys back. They should be in the same order
+ res = self.get_success(
self.handler.claim_one_time_keys(
{local_user: {device_id: {"alg1": 1}}},
self.requester,
@@ -171,12 +184,27 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
)
)
self.assertEqual(
- res2,
+ res,
{
"failures": {},
"one_time_keys": {local_user: {device_id: {"alg1:k1": "key1"}}},
},
)
+ res = self.get_success(
+ self.handler.claim_one_time_keys(
+ {local_user: {device_id: {"alg1": 1}}},
+ self.requester,
+ timeout=None,
+ always_include_fallback_keys=False,
+ )
+ )
+ self.assertEqual(
+ res,
+ {
+ "failures": {},
+ "one_time_keys": {local_user: {device_id: {"alg1:k0": "key0"}}},
+ },
+ )
def test_claim_one_time_key_bulk(self) -> None:
"""Like test_claim_one_time_key but claims multiple keys in one handler call."""
@@ -336,6 +364,47 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
counts_by_alg, expected_counts_by_alg, f"{user_id}:{device_id}"
)
+ def test_claim_one_time_key_bulk_ordering(self) -> None:
+        """Keys from the bulk claim call should be returned in the correct order"""
+
+ # Alice has lots of keys, uploaded in a specific order
+ alice = f"@alice:{self.hs.hostname}"
+ alice_dev = "alice_dev_1"
+
+ self.get_success(
+ self.handler.upload_keys_for_user(
+ alice,
+ alice_dev,
+ {"one_time_keys": {"alg1:k20": 20, "alg1:k21": 21, "alg1:k22": 22}},
+ )
+ )
+ # Advance time by 1s, to ensure that there is a difference in upload time.
+ self.reactor.advance(1)
+ self.get_success(
+ self.handler.upload_keys_for_user(
+ alice,
+ alice_dev,
+ {"one_time_keys": {"alg1:k10": 10, "alg1:k11": 11, "alg1:k12": 12}},
+ )
+ )
+
+ # Now claim some, and check we get the right ones.
+ claim_res = self.get_success(
+ self.handler.claim_one_time_keys(
+ {alice: {alice_dev: {"alg1": 2}}},
+ self.requester,
+ timeout=None,
+ always_include_fallback_keys=False,
+ )
+ )
+        # We should get the first-uploaded keys, even though they have later key
+        # IDs: a random set of two of k20, k21, k22.
+ self.assertEqual(claim_res["failures"], {})
+ claimed_keys = claim_res["one_time_keys"]["@alice:test"]["alice_dev_1"]
+ self.assertEqual(len(claimed_keys), 2)
+ for key_id in claimed_keys.keys():
+ self.assertIn(key_id, ["alg1:k20", "alg1:k21", "alg1:k22"])
+
def test_fallback_key(self) -> None:
local_user = "@boris:" + self.hs.hostname
device_id = "xyz"
@@ -1758,3 +1827,222 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
)
self.assertIs(exists, True)
self.assertIs(replaceable_without_uia, False)
+
+ def test_delete_old_one_time_keys(self) -> None:
+        """Test the scheduled task that clears out old OTKs"""
+
+ # We upload two sets of keys, one just over a week ago, and one just less than
+ # a week ago. Each batch contains some keys that match the deletion pattern
+ # (key IDs of 6 chars), and some that do not.
+ #
+ # Finally, set the scheduled task going, and check what gets deleted.
+
+ user_id = "@user000:" + self.hs.hostname
+ device_id = "xyz"
+
+ # The scheduled task should be for "now" in real, wallclock time, so
+ # set the test reactor to just over a week ago.
+ self.reactor.advance(time.time() - 7.5 * 24 * 3600)
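+        # (Timeline: start ~7.5 days in the past, upload the first batch, advance
+        # a day for the second batch, then advance the remaining ~6.5 days so the
+        # reactor passes "now" and the scheduled deletion task fires.)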
+
+ # Upload some keys
+ self.get_success(
+ self.handler.upload_keys_for_user(
+ user_id,
+ device_id,
+ {
+ "one_time_keys": {
+ # some keys to delete
+ "alg1:AAAAAA": "key1",
+ "alg2:AAAAAB": {"key": "key2", "signatures": {"k1": "sig1"}},
+ # A key to *not* delete
+ "alg2:AAAAAAAAAA": {"key": "key3"},
+ }
+ },
+ )
+ )
+
+ # A day passes
+ self.reactor.advance(24 * 3600)
+
+ # Upload some more keys
+ self.get_success(
+ self.handler.upload_keys_for_user(
+ user_id,
+ device_id,
+ {
+ "one_time_keys": {
+ # some keys which match the pattern
+ "alg1:BAAAAA": "key1",
+ "alg2:BAAAAB": {"key": "key2", "signatures": {"k1": "sig1"}},
+ # A key to *not* delete
+ "alg2:BAAAAAAAAA": {"key": "key3"},
+ }
+ },
+ )
+ )
+
+ # The rest of the week passes, which should set the scheduled task going.
+ self.reactor.advance(6.5 * 24 * 3600)
+
+ # Check what we're left with in the database
+ remaining_key_ids = {
+ row[0]
+ for row in self.get_success(
+ self.handler.store.db_pool.simple_select_list(
+ "e2e_one_time_keys_json", None, ["key_id"]
+ )
+ )
+ }
+ self.assertEqual(
+ remaining_key_ids, {"AAAAAAAAAA", "BAAAAA", "BAAAAB", "BAAAAAAAAA"}
+ )
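+        # Only the first batch's short (6-char) key IDs were uploaded more than a
+        # week before the task ran, so those are the ones deleted; the second
+        # batch's short IDs and both long IDs survive.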
+
+ @override_config(
+ {
+ "experimental_features": {
+ "msc4263_limit_key_queries_to_users_who_share_rooms": True
+ }
+ }
+ )
+ def test_query_devices_remote_restricted_not_in_shared_room(self) -> None:
+ """Tests that querying keys for a remote user that we don't share a room
+ with returns nothing.
+ """
+
+ remote_user_id = "@test:other"
+ local_user_id = "@test:test"
+
+ # Do *not* pretend we're sharing a room with the user we're querying.
+
+ remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY"
+ remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
+
+ self.hs.get_federation_client().query_client_keys = mock.AsyncMock( # type: ignore[method-assign]
+ return_value={
+ "device_keys": {remote_user_id: {}},
+ "master_keys": {
+ remote_user_id: {
+ "user_id": remote_user_id,
+ "usage": ["master"],
+ "keys": {"ed25519:" + remote_master_key: remote_master_key},
+ },
+ },
+ "self_signing_keys": {
+ remote_user_id: {
+ "user_id": remote_user_id,
+ "usage": ["self_signing"],
+ "keys": {
+ "ed25519:"
+ + remote_self_signing_key: remote_self_signing_key
+ },
+ }
+ },
+ }
+ )
+
+ e2e_handler = self.hs.get_e2e_keys_handler()
+
+ query_result = self.get_success(
+ e2e_handler.query_devices(
+ {
+ "device_keys": {remote_user_id: []},
+ },
+ timeout=10,
+ from_user_id=local_user_id,
+ from_device_id="some_device_id",
+ )
+ )
+
+ self.assertEqual(
+ query_result,
+ {
+ "device_keys": {},
+ "failures": {},
+ "master_keys": {},
+ "self_signing_keys": {},
+ "user_signing_keys": {},
+ },
+ )
+
+ @override_config(
+ {
+ "experimental_features": {
+ "msc4263_limit_key_queries_to_users_who_share_rooms": True
+ }
+ }
+ )
+ def test_query_devices_remote_restricted_in_shared_room(self) -> None:
+ """Tests that querying keys for a remote user that we share a room
+ with returns the cross signing keys correctly.
+ """
+
+ remote_user_id = "@test:other"
+ local_user_id = "@test:test"
+
+ # Pretend we're sharing a room with the user we're querying. If not,
+ # `query_devices` will filter out the user ID and `_query_devices_for_destination`
+ # will return early.
+ self.store.do_users_share_a_room_joined_or_invited = mock.AsyncMock( # type: ignore[method-assign]
+ return_value=[remote_user_id]
+ )
+ self.store.get_rooms_for_user = mock.AsyncMock(return_value={"some_room_id"})
+
+ remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY"
+ remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
+
+ self.hs.get_federation_client().query_user_devices = mock.AsyncMock( # type: ignore[method-assign]
+ return_value={
+ "user_id": remote_user_id,
+ "stream_id": 1,
+ "devices": [],
+ "master_key": {
+ "user_id": remote_user_id,
+ "usage": ["master"],
+ "keys": {"ed25519:" + remote_master_key: remote_master_key},
+ },
+ "self_signing_key": {
+ "user_id": remote_user_id,
+ "usage": ["self_signing"],
+ "keys": {
+ "ed25519:" + remote_self_signing_key: remote_self_signing_key
+ },
+ },
+ }
+ )
+
+ e2e_handler = self.hs.get_e2e_keys_handler()
+
+ query_result = self.get_success(
+ e2e_handler.query_devices(
+ {
+ "device_keys": {remote_user_id: []},
+ },
+ timeout=10,
+ from_user_id=local_user_id,
+ from_device_id="some_device_id",
+ )
+ )
+
+ self.assertEqual(query_result["failures"], {})
+ self.assertEqual(
+ query_result["master_keys"],
+ {
+ remote_user_id: {
+ "user_id": remote_user_id,
+ "usage": ["master"],
+ "keys": {"ed25519:" + remote_master_key: remote_master_key},
+ }
+ },
+ )
+ self.assertEqual(
+ query_result["self_signing_keys"],
+ {
+ remote_user_id: {
+ "user_id": remote_user_id,
+ "usage": ["self_signing"],
+ "keys": {
+ "ed25519:" + remote_self_signing_key: remote_self_signing_key
+ },
+ }
+ },
+ )
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index 3fe5b0a1b4..b64a8a86a2 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -44,7 +44,7 @@ from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.storage.databases.main.events_worker import EventCacheEntry
from synapse.util import Clock
-from synapse.util.stringutils import random_string
+from synapse.util.events import generate_fake_event_id
from tests import unittest
from tests.test_utils import event_injection
@@ -52,10 +52,6 @@ from tests.test_utils import event_injection
logger = logging.getLogger(__name__)
-def generate_fake_event_id() -> str:
- return "$fake_" + random_string(43)
-
-
class FederationTestCase(unittest.FederatingHomeserverTestCase):
servlets = [
admin.register_servlets,
@@ -665,9 +661,12 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
)
)
- with patch.object(
- fed_client, "make_membership_event", mock_make_membership_event
- ), patch.object(fed_client, "send_join", mock_send_join):
+ with (
+ patch.object(
+ fed_client, "make_membership_event", mock_make_membership_event
+ ),
+ patch.object(fed_client, "send_join", mock_send_join),
+ ):
# Join and check that our join event is rejected
# (The join event is rejected because it doesn't have any signatures)
join_exc = self.get_failure(
@@ -712,9 +711,12 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
fed_handler = self.hs.get_federation_handler()
store = self.hs.get_datastores().main
- with patch.object(
- fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room
- ), patch.object(store, "is_partial_state_room", mock_is_partial_state_room):
+ with (
+ patch.object(
+ fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room
+ ),
+ patch.object(store, "is_partial_state_room", mock_is_partial_state_room),
+ ):
# Start the partial state sync.
fed_handler._start_partial_state_room_sync("hs1", {"hs2"}, "room_id")
self.assertEqual(mock_sync_partial_state_room.call_count, 1)
@@ -764,9 +766,12 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
fed_handler = self.hs.get_federation_handler()
store = self.hs.get_datastores().main
- with patch.object(
- fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room
- ), patch.object(store, "is_partial_state_room", mock_is_partial_state_room):
+ with (
+ patch.object(
+ fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room
+ ),
+ patch.object(store, "is_partial_state_room", mock_is_partial_state_room),
+ ):
# Start the partial state sync.
fed_handler._start_partial_state_room_sync("hs1", {"hs2"}, "room_id")
self.assertEqual(mock_sync_partial_state_room.call_count, 1)
diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py
index 1b83aea579..51eca56c3b 100644
--- a/tests/handlers/test_federation_event.py
+++ b/tests/handlers/test_federation_event.py
@@ -288,13 +288,15 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
}
# We also expect an outbound request to /state
- self.mock_federation_transport_client.get_room_state.return_value = StateRequestResponse(
- # Mimic the other server not knowing about the state at all.
- # We want to cause Synapse to throw an error (`Unable to get
- # missing prev_event $fake_prev_event`) and fail to backfill
- # the pulled event.
- auth_events=[],
- state=[],
+ self.mock_federation_transport_client.get_room_state.return_value = (
+ StateRequestResponse(
+ # Mimic the other server not knowing about the state at all.
+ # We want to cause Synapse to throw an error (`Unable to get
+ # missing prev_event $fake_prev_event`) and fail to backfill
+ # the pulled event.
+ auth_events=[],
+ state=[],
+ )
)
pulled_event = make_event_from_dict(
@@ -373,7 +375,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
In this test, we pretend we are processing a "pulled" event via
backfill. The pulled event succesfully processes and the backward
- extremeties are updated along with clearing out any failed pull attempts
+ extremities are updated along with clearing out any failed pull attempts
for those old extremities.
We check that we correctly cleared failed pull attempts of the
@@ -805,6 +807,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
OTHER_USER = f"@user:{self.OTHER_SERVER_NAME}"
main_store = self.hs.get_datastores().main
+ state_deletion_store = self.hs.get_datastores().state_deletion
# Create the room.
kermit_user_id = self.register_user("kermit", "test")
@@ -956,7 +959,9 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
bert_member_event.event_id: bert_member_event,
rejected_kick_event.event_id: rejected_kick_event,
},
- state_res_store=StateResolutionStore(main_store),
+ state_res_store=StateResolutionStore(
+ main_store, state_deletion_store
+ ),
)
),
[bert_member_event.event_id, rejected_kick_event.event_id],
@@ -1001,7 +1006,9 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
rejected_power_levels_event.event_id,
],
event_map={},
- state_res_store=StateResolutionStore(main_store),
+ state_res_store=StateResolutionStore(
+ main_store, state_deletion_store
+ ),
full_conflicted_set=set(),
)
),
diff --git a/tests/handlers/test_oauth_delegation.py b/tests/handlers/test_oauth_delegation.py
index 036c539db2..37acb660e7 100644
--- a/tests/handlers/test_oauth_delegation.py
+++ b/tests/handlers/test_oauth_delegation.py
@@ -43,6 +43,7 @@ from synapse.api.errors import (
OAuthInsufficientScopeError,
SynapseError,
)
+from synapse.appservice import ApplicationService
from synapse.http.site import SynapseRequest
from synapse.rest import admin
from synapse.rest.client import account, devices, keys, login, logout, register
@@ -146,6 +147,16 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
return hs
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+ ) -> None:
+ # Provision the user and the device we use in the tests.
+ store = homeserver.get_datastores().main
+ self.get_success(store.register_user(USER_ID))
+ self.get_success(
+ store.store_device(USER_ID, DEVICE, initial_device_display_name=None)
+ )
+
def _assertParams(self) -> None:
"""Assert that the request parameters are correct."""
params = parse_qs(self.http_client.request.call_args[1]["data"].decode("utf-8"))
@@ -379,6 +390,44 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
)
self.assertEqual(requester.device_id, DEVICE)
+ def test_active_user_with_device_explicit_device_id(self) -> None:
+ """The handler should return a requester with normal user rights and a device ID, given explicitly, as supported by MAS 0.15+"""
+
+ self.http_client.request = AsyncMock(
+ return_value=FakeResponse.json(
+ code=200,
+ payload={
+ "active": True,
+ "sub": SUBJECT,
+ "scope": " ".join([MATRIX_USER_SCOPE]),
+ "device_id": DEVICE,
+ "username": USERNAME,
+ },
+ )
+ )
+ request = Mock(args={})
+ request.args[b"access_token"] = [b"mockAccessToken"]
+ request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+ requester = self.get_success(self.auth.get_user_by_req(request))
+ self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
+ self.http_client.request.assert_called_once_with(
+ method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
+ )
+        # It should have been called with the 'X-MAS-Supports-Device-Id: 1' header
+ self.assertEqual(
+ self.http_client.request.call_args[1]["headers"].getRawHeaders(
+ b"X-MAS-Supports-Device-Id",
+ ),
+ [b"1"],
+ )
+ self._assertParams()
+ self.assertEqual(requester.user.to_string(), "@%s:%s" % (USERNAME, SERVER_NAME))
+ self.assertEqual(requester.is_guest, False)
+ self.assertEqual(
+ get_awaitable_result(self.auth.is_server_admin(requester)), False
+ )
+ self.assertEqual(requester.device_id, DEVICE)
+
def test_multiple_devices(self) -> None:
"""The handler should raise an error if multiple devices are found in the scope."""
@@ -500,6 +549,44 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
self.assertEqual(error.value.code, 503)
+ def test_cached_expired_introspection(self) -> None:
+        """The handler should raise an error if the introspection response gives
+        an expiry time, the response is cached, and the token is then re-used
+        after it has expired."""
+
+ self.http_client.request = introspection_mock = AsyncMock(
+ return_value=FakeResponse.json(
+ code=200,
+ payload={
+ "active": True,
+ "sub": SUBJECT,
+ "scope": " ".join(
+ [
+ MATRIX_USER_SCOPE,
+ f"{MATRIX_DEVICE_SCOPE_PREFIX}AABBCC",
+ ]
+ ),
+ "username": USERNAME,
+ "expires_in": 60,
+ },
+ )
+ )
+ request = Mock(args={})
+ request.args[b"access_token"] = [b"mockAccessToken"]
+ request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+
+ # The first CS-API request causes a successful introspection
+ self.get_success(self.auth.get_user_by_req(request))
+ self.assertEqual(introspection_mock.call_count, 1)
+
+        # Advance the clock by 60 seconds so the token expires.
+ self.reactor.advance(60.0)
+
+ # Now the CS-API request fails because the token expired
+ self.get_failure(self.auth.get_user_by_req(request), InvalidClientTokenError)
+ # Ensure another introspection request was not sent
+ self.assertEqual(introspection_mock.call_count, 1)
+
def make_device_keys(self, user_id: str, device_id: str) -> JsonDict:
# We only generate a master key to simplify the test.
master_signing_key = generate_signing_key(device_id)
@@ -550,7 +637,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
access_token="mockAccessToken",
)
- self.assertEqual(channel.code, HTTPStatus.NOT_IMPLEMENTED, channel.json_body)
+ self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.json_body)
def expect_unauthorized(
self, method: str, path: str, content: Union[bytes, str, JsonDict] = ""
@@ -560,15 +647,31 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
self.assertEqual(channel.code, 401, channel.json_body)
def expect_unrecognized(
- self, method: str, path: str, content: Union[bytes, str, JsonDict] = ""
+ self,
+ method: str,
+ path: str,
+ content: Union[bytes, str, JsonDict] = "",
+ auth: bool = False,
) -> None:
- channel = self.make_request(method, path, content)
+ channel = self.make_request(
+ method, path, content, access_token="token" if auth else None
+ )
self.assertEqual(channel.code, 404, channel.json_body)
self.assertEqual(
channel.json_body["errcode"], Codes.UNRECOGNIZED, channel.json_body
)
+ def expect_forbidden(
+ self, method: str, path: str, content: Union[bytes, str, JsonDict] = ""
+ ) -> None:
+ channel = self.make_request(method, path, content)
+
+ self.assertEqual(channel.code, 403, channel.json_body)
+ self.assertEqual(
+ channel.json_body["errcode"], Codes.FORBIDDEN, channel.json_body
+ )
+
def test_uia_endpoints(self) -> None:
"""Test that endpoints that were removed in MSC2964 are no longer available."""
@@ -580,36 +683,6 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
"POST", "/_matrix/client/v3/keys/device_signing/upload"
)
- def test_3pid_endpoints(self) -> None:
- """Test that 3pid account management endpoints that were removed in MSC2964 are no longer available."""
-
- # Remains and requires auth:
- self.expect_unauthorized("GET", "/_matrix/client/v3/account/3pid")
- self.expect_unauthorized(
- "POST",
- "/_matrix/client/v3/account/3pid/bind",
- {
- "client_secret": "foo",
- "id_access_token": "bar",
- "id_server": "foo",
- "sid": "bar",
- },
- )
- self.expect_unauthorized("POST", "/_matrix/client/v3/account/3pid/unbind", {})
-
- # These are gone:
- self.expect_unrecognized(
- "POST", "/_matrix/client/v3/account/3pid"
- ) # deprecated
- self.expect_unrecognized("POST", "/_matrix/client/v3/account/3pid/add")
- self.expect_unrecognized("POST", "/_matrix/client/v3/account/3pid/delete")
- self.expect_unrecognized(
- "POST", "/_matrix/client/v3/account/3pid/email/requestToken"
- )
- self.expect_unrecognized(
- "POST", "/_matrix/client/v3/account/3pid/msisdn/requestToken"
- )
-
def test_account_management_endpoints_removed(self) -> None:
"""Test that account management endpoints that were removed in MSC2964 are no longer available."""
self.expect_unrecognized("POST", "/_matrix/client/v3/account/deactivate")
@@ -623,11 +696,35 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
def test_registration_endpoints_removed(self) -> None:
"""Test that registration endpoints that were removed in MSC2964 are no longer available."""
+ appservice = ApplicationService(
+ token="i_am_an_app_service",
+ id="1234",
+ namespaces={"users": [{"regex": r"@alice:.+", "exclusive": True}]},
+ sender="@as_main:test",
+ )
+
+ self.hs.get_datastores().main.services_cache = [appservice]
self.expect_unrecognized(
"GET", "/_matrix/client/v1/register/m.login.registration_token/validity"
)
+
+ # Registration is disabled
+ self.expect_forbidden(
+ "POST",
+ "/_matrix/client/v3/register",
+ {"username": "alice", "password": "hunter2"},
+ )
+
# This is still available for AS registrations
- # self.expect_unrecognized("POST", "/_matrix/client/v3/register")
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/v3/register",
+ {"username": "alice", "type": "m.login.application_service"},
+ shorthand=False,
+ access_token="i_am_an_app_service",
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
self.expect_unrecognized("GET", "/_matrix/client/v3/register/available")
self.expect_unrecognized(
"POST", "/_matrix/client/v3/register/email/requestToken"
@@ -648,8 +745,25 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
def test_device_management_endpoints_removed(self) -> None:
"""Test that device management endpoints that were removed in MSC2964 are no longer available."""
- self.expect_unrecognized("POST", "/_matrix/client/v3/delete_devices")
- self.expect_unrecognized("DELETE", "/_matrix/client/v3/devices/{DEVICE}")
+
+        # Because we still support these endpoints for ASes, the access token is
+        # checked before returning 404
+ self.http_client.request = AsyncMock(
+ return_value=FakeResponse.json(
+ code=200,
+ payload={
+ "active": True,
+ "sub": SUBJECT,
+ "scope": " ".join([MATRIX_USER_SCOPE, MATRIX_DEVICE_SCOPE]),
+ "username": USERNAME,
+ },
+ )
+ )
+
+ self.expect_unrecognized("POST", "/_matrix/client/v3/delete_devices", auth=True)
+ self.expect_unrecognized(
+            "DELETE", f"/_matrix/client/v3/devices/{DEVICE}", auth=True
+ )
def test_openid_endpoints_removed(self) -> None:
"""Test that OpenID id_token endpoints that were removed in MSC2964 are no longer available."""
@@ -772,7 +886,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
req = SynapseRequest(channel, self.site) # type: ignore[arg-type]
req.client.host = MAS_IPV4_ADDR
req.requestHeaders.addRawHeader(
- "Authorization", f"Bearer {self.auth._admin_token}"
+ "Authorization", f"Bearer {self.auth._admin_token()}"
)
req.requestHeaders.addRawHeader("User-Agent", MAS_USER_AGENT)
req.content = BytesIO(b"")
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index a81501979d..ff8e3c5cb6 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -57,6 +57,7 @@ CLIENT_ID = "test-client-id"
CLIENT_SECRET = "test-client-secret"
BASE_URL = "https://synapse/"
CALLBACK_URL = BASE_URL + "_synapse/client/oidc/callback"
+TEST_REDIRECT_URI = "https://test/oidc/callback"
SCOPES = ["openid"]
# config for common cases
@@ -70,12 +71,16 @@ DEFAULT_CONFIG = {
}
# extends the default config with explicit OAuth2 endpoints instead of using discovery
+#
+# We add "explicit" to the values to distinguish them from the discovered ones, so
+# that we can check that the explicit values override the discovered values.
EXPLICIT_ENDPOINT_CONFIG = {
**DEFAULT_CONFIG,
"discover": False,
- "authorization_endpoint": ISSUER + "authorize",
- "token_endpoint": ISSUER + "token",
- "jwks_uri": ISSUER + "jwks",
+ "authorization_endpoint": ISSUER + "authorize-explicit",
+ "token_endpoint": ISSUER + "token-explicit",
+ "jwks_uri": ISSUER + "jwks-explicit",
+ "id_token_signing_alg_values_supported": ["RS256", "<explicit>"],
}
@@ -259,12 +264,64 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.get_success(self.provider.load_metadata())
self.fake_server.get_metadata_handler.assert_not_called()
+ @override_config({"oidc_config": {**EXPLICIT_ENDPOINT_CONFIG, "discover": True}})
+ def test_discovery_with_explicit_config(self) -> None:
+ """
+        The handler should discover the endpoints from the OIDC discovery document,
+        but explicitly configured values should override the discovered ones.
+ """
+ # This would throw if some metadata were invalid
+ metadata = self.get_success(self.provider.load_metadata())
+ self.fake_server.get_metadata_handler.assert_called_once()
+
+ self.assertEqual(metadata.issuer, self.fake_server.issuer)
+ # It seems like authlib does not have that defined in its metadata models
+ self.assertEqual(
+ metadata.get("userinfo_endpoint"),
+ self.fake_server.userinfo_endpoint,
+ )
+
+ # Ensure the values are overridden correctly since these were configured
+ # explicitly
+ self.assertEqual(
+ metadata.authorization_endpoint,
+ EXPLICIT_ENDPOINT_CONFIG["authorization_endpoint"],
+ )
+ self.assertEqual(
+ metadata.token_endpoint, EXPLICIT_ENDPOINT_CONFIG["token_endpoint"]
+ )
+ self.assertEqual(metadata.jwks_uri, EXPLICIT_ENDPOINT_CONFIG["jwks_uri"])
+ self.assertEqual(
+ metadata.id_token_signing_alg_values_supported,
+ EXPLICIT_ENDPOINT_CONFIG["id_token_signing_alg_values_supported"],
+ )
+
+ # subsequent calls should be cached
+ self.reset_mocks()
+ self.get_success(self.provider.load_metadata())
+ self.fake_server.get_metadata_handler.assert_not_called()
+
@override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
def test_no_discovery(self) -> None:
"""When discovery is disabled, it should not try to load from discovery document."""
- self.get_success(self.provider.load_metadata())
+ metadata = self.get_success(self.provider.load_metadata())
self.fake_server.get_metadata_handler.assert_not_called()
+ # Ensure the values are overridden correctly since these were configured
+ # explicitly
+ self.assertEqual(
+ metadata.authorization_endpoint,
+ EXPLICIT_ENDPOINT_CONFIG["authorization_endpoint"],
+ )
+ self.assertEqual(
+ metadata.token_endpoint, EXPLICIT_ENDPOINT_CONFIG["token_endpoint"]
+ )
+ self.assertEqual(metadata.jwks_uri, EXPLICIT_ENDPOINT_CONFIG["jwks_uri"])
+ self.assertEqual(
+ metadata.id_token_signing_alg_values_supported,
+ EXPLICIT_ENDPOINT_CONFIG["id_token_signing_alg_values_supported"],
+ )
+
@override_config({"oidc_config": DEFAULT_CONFIG})
def test_load_jwks(self) -> None:
"""JWKS loading is done once (then cached) if used."""
@@ -427,6 +484,32 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertEqual(code_verifier, "")
self.assertEqual(redirect, "http://client/redirect")
+ @override_config(
+ {
+ "oidc_config": {
+ **DEFAULT_CONFIG,
+ "passthrough_authorization_parameters": ["additional_parameter"],
+ }
+ }
+ )
+ def test_passthrough_parameters(self) -> None:
+ """The redirect request has additional parameters, one is authorized, one is not"""
+ req = Mock(spec=["cookies", "args"])
+ req.cookies = []
+ req.args = {}
+ req.args[b"additional_parameter"] = ["a_value".encode("utf-8")]
+ req.args[b"not_authorized_parameter"] = ["any".encode("utf-8")]
+
+ url = urlparse(
+ self.get_success(
+ self.provider.handle_redirect_request(req, b"http://client/redirect")
+ )
+ )
+
+ params = parse_qs(url.query)
+ self.assertEqual(params["additional_parameter"], ["a_value"])
+        self.assertNotIn("not_authorized_parameter", params)
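+        # For illustration, the authorization URL query is expected to carry
+        # the standard OAuth2 parameters plus the passed-through one, roughly
+        # (state/nonce and other session details omitted):
+        #
+        #   response_type=code&client_id=test-client-id
+        #       &redirect_uri=<CALLBACK_URL>&additional_parameter=a_value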
+
@override_config({"oidc_config": DEFAULT_CONFIG})
def test_redirect_request_with_code_challenge(self) -> None:
"""The redirect request has the right arguments & generates a valid session cookie."""
@@ -530,6 +613,24 @@ class OidcHandlerTestCase(HomeserverTestCase):
code_verifier = get_value_from_macaroon(macaroon, "code_verifier")
self.assertEqual(code_verifier, "")
+ @override_config(
+ {"oidc_config": {**DEFAULT_CONFIG, "redirect_uri": TEST_REDIRECT_URI}}
+ )
+ def test_redirect_request_with_overridden_redirect_uri(self) -> None:
+ """The authorization endpoint redirect has the overridden `redirect_uri` value."""
+ req = Mock(spec=["cookies"])
+ req.cookies = []
+
+ url = urlparse(
+ self.get_success(
+ self.provider.handle_redirect_request(req, b"http://client/redirect")
+ )
+ )
+
+ # Ensure that the redirect_uri in the returned url has been overridden.
+ params = parse_qs(url.query)
+ self.assertEqual(params["redirect_uri"], [TEST_REDIRECT_URI])
+
@override_config({"oidc_config": DEFAULT_CONFIG})
def test_callback_error(self) -> None:
"""Errors from the provider returned in the callback are displayed."""
@@ -901,6 +1002,81 @@ class OidcHandlerTestCase(HomeserverTestCase):
{
"oidc_config": {
**DEFAULT_CONFIG,
+ "redirect_uri": TEST_REDIRECT_URI,
+ }
+ }
+ )
+ def test_code_exchange_with_overridden_redirect_uri(self) -> None:
+ """Code exchange behaves correctly and handles various error scenarios."""
+ # Set up a fake IdP with a token endpoint handler.
+ token = {
+ "type": "Bearer",
+ "access_token": "aabbcc",
+ }
+
+ self.fake_server.post_token_handler.side_effect = None
+ self.fake_server.post_token_handler.return_value = FakeResponse.json(
+ payload=token
+ )
+ code = "code"
+
+ # Exchange the code against the fake IdP.
+ self.get_success(self.provider._exchange_code(code, code_verifier=""))
+
+ # Check that the `redirect_uri` parameter provided matches our
+ # overridden config value.
+ kwargs = self.fake_server.request.call_args[1]
+ args = parse_qs(kwargs["data"].decode("utf-8"))
+ self.assertEqual(args["redirect_uri"], [TEST_REDIRECT_URI])
+
+ @override_config(
+ {
+ "oidc_config": {
+ **DEFAULT_CONFIG,
+ "redirect_uri": TEST_REDIRECT_URI,
+ }
+ }
+ )
+ def test_code_exchange_ignores_access_token(self) -> None:
+ """
+ Code exchange completes successfully and doesn't validate the `at_hash`
+ (access token hash) field of an ID token when the access token isn't
+ going to be used.
+
+ The access token won't be used in this test because Synapse (currently)
+ only needs it to fetch a user's metadata if it isn't included in the ID
+ token itself.
+
+ Because we have included "openid" in the requested scopes for this IdP
+        (see `SCOPES`), user metadata is included in the ID token. Thus the
+ access token isn't needed, and it's unnecessary for Synapse to validate
+ the access token.
+
+ This is a regression test for a situation where an upstream identity
+ provider was providing an invalid `at_hash` value, which Synapse errored
+ on, yet Synapse wasn't using the access token for anything.
+ """
+ # Exchange the code against the fake IdP.
+ userinfo = {
+ "sub": "foo",
+ "username": "foo",
+ "phone": "1234567",
+ }
+ with self.fake_server.id_token_override(
+ {
+ "at_hash": "invalid-hash",
+ }
+ ):
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+
+ # If no error was rendered, then we have success.
+ self.render_error.assert_not_called()
+
+ @override_config(
+ {
+ "oidc_config": {
+ **DEFAULT_CONFIG,
"user_mapping_provider": {
"module": __name__ + ".TestMappingProviderExtra"
},
@@ -1271,6 +1447,113 @@ class OidcHandlerTestCase(HomeserverTestCase):
{
"oidc_config": {
**DEFAULT_CONFIG,
+ "attribute_requirements": [
+ {"attribute": "test", "one_of": ["foo", "bar"]}
+ ],
+ }
+ }
+ )
+ def test_attribute_requirements_one_of_succeeds(self) -> None:
+ """Test that auth succeeds if userinfo attribute has multiple values and CONTAINS required value"""
+ # userinfo with "test": ["bar"] attribute should succeed.
+ userinfo = {
+ "sub": "tester",
+ "username": "tester",
+ "test": ["bar"],
+ }
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+
+ # check that the auth handler got called as expected
+ self.complete_sso_login.assert_called_once_with(
+ "@tester:test",
+ self.provider.idp_id,
+ request,
+ ANY,
+ None,
+ new_user=True,
+ auth_provider_session_id=None,
+ )
+
+ @override_config(
+ {
+ "oidc_config": {
+ **DEFAULT_CONFIG,
+ "attribute_requirements": [
+ {"attribute": "test", "one_of": ["foo", "bar"]}
+ ],
+ }
+ }
+ )
+ def test_attribute_requirements_one_of_fails(self) -> None:
+ """Test that auth fails if userinfo attribute has multiple values yet
+ DOES NOT CONTAIN a required value
+ """
+ # userinfo with "test": ["something else"] attribute should fail.
+ userinfo = {
+ "sub": "tester",
+ "username": "tester",
+ "test": ["something else"],
+ }
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_not_called()
+
+ @override_config(
+ {
+ "oidc_config": {
+ **DEFAULT_CONFIG,
+ "attribute_requirements": [{"attribute": "test"}],
+ }
+ }
+ )
+ def test_attribute_requirements_does_not_exist(self) -> None:
+ """OIDC login fails if the required attribute does not exist in the OIDC userinfo response."""
+ # userinfo lacking "test" attribute should fail.
+ userinfo = {
+ "sub": "tester",
+ "username": "tester",
+ }
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_not_called()
+
+ @override_config(
+ {
+ "oidc_config": {
+ **DEFAULT_CONFIG,
+ "attribute_requirements": [{"attribute": "test"}],
+ }
+ }
+ )
+ def test_attribute_requirements_exist(self) -> None:
+ """OIDC login succeeds if the required attribute exist (regardless of value)
+ in the OIDC userinfo response.
+ """
+ # userinfo with "test" attribute and random value should succeed.
+ userinfo = {
+ "sub": "tester",
+ "username": "tester",
+ "test": random_string(5), # value does not matter
+ }
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+
+ # check that the auth handler got called as expected
+ self.complete_sso_login.assert_called_once_with(
+ "@tester:test",
+ self.provider.idp_id,
+ request,
+ ANY,
+ None,
+ new_user=True,
+ auth_provider_session_id=None,
+ )
+
+ @override_config(
+ {
+ "oidc_config": {
+ **DEFAULT_CONFIG,
"attribute_requirements": [{"attribute": "test", "value": "foobar"}],
}
}
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index ed203eb299..d0351a8509 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -768,17 +768,6 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# Check that the callback has been called.
m.assert_called_once()
- # Set some email configuration so the test doesn't fail because of its absence.
- @override_config({"email": {"notif_from": "noreply@test"}})
- def test_3pid_allowed(self) -> None:
- """Tests that an is_3pid_allowed_callbacks forbidding a 3PID makes Synapse refuse
- to bind the new 3PID, and that one allowing a 3PID makes Synapse accept to bind
- the 3PID. Also checks that the module is passed a boolean indicating whether the
- user to bind this 3PID to is currently registering.
- """
- self._test_3pid_allowed("rin", False)
- self._test_3pid_allowed("kitay", True)
-
def test_displayname(self) -> None:
"""Tests that the get_displayname_for_registration callback can define the
display name of a user when registering.
@@ -829,66 +818,6 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# Check that the callback has been called.
m.assert_called_once()
- def _test_3pid_allowed(self, username: str, registration: bool) -> None:
- """Tests that the "is_3pid_allowed" module callback is called correctly, using
- either /register or /account URLs depending on the arguments.
-
- Args:
- username: The username to use for the test.
- registration: Whether to test with registration URLs.
- """
- self.hs.get_identity_handler().send_threepid_validation = AsyncMock( # type: ignore[method-assign]
- return_value=0
- )
-
- m = AsyncMock(return_value=False)
- self.hs.get_password_auth_provider().is_3pid_allowed_callbacks = [m]
-
- self.register_user(username, "password")
- tok = self.login(username, "password")
-
- if registration:
- url = "/register/email/requestToken"
- else:
- url = "/account/3pid/email/requestToken"
-
- channel = self.make_request(
- "POST",
- url,
- {
- "client_secret": "foo",
- "email": "foo@test.com",
- "send_attempt": 0,
- },
- access_token=tok,
- )
- self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result)
- self.assertEqual(
- channel.json_body["errcode"],
- Codes.THREEPID_DENIED,
- channel.json_body,
- )
-
- m.assert_called_once_with("email", "foo@test.com", registration)
-
- m = AsyncMock(return_value=True)
- self.hs.get_password_auth_provider().is_3pid_allowed_callbacks = [m]
-
- channel = self.make_request(
- "POST",
- url,
- {
- "client_secret": "foo",
- "email": "bar@test.com",
- "send_attempt": 0,
- },
- access_token=tok,
- )
- self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
- self.assertIn("sid", channel.json_body)
-
- m.assert_called_once_with("email", "bar@test.com", registration)
-
def _setup_get_name_for_registration(self, callback_name: str) -> Mock:
"""Registers either a get_username_for_registration callback or a
get_displayname_for_registration callback that appends "-foo" to the username the
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index cc630d606c..6b7bf112c2 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -23,14 +23,21 @@ from typing import Optional, cast
from unittest.mock import Mock, call
from parameterized import parameterized
-from signedjson.key import generate_signing_key
+from signedjson.key import (
+ encode_verify_key_base64,
+ generate_signing_key,
+ get_verify_key,
+)
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes, Membership, PresenceState
from synapse.api.presence import UserDevicePresenceState, UserPresenceState
-from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
-from synapse.events.builder import EventBuilder
+from synapse.api.room_versions import (
+ RoomVersion,
+)
+from synapse.crypto.event_signing import add_hashes_and_signatures
+from synapse.events import EventBase, make_event_from_dict
from synapse.federation.sender import FederationSender
from synapse.handlers.presence import (
BUSY_ONLINE_TIMEOUT,
@@ -45,18 +52,24 @@ from synapse.handlers.presence import (
handle_update,
)
from synapse.rest import admin
-from synapse.rest.client import room
+from synapse.rest.client import login, room, sync
from synapse.server import HomeServer
from synapse.storage.database import LoggingDatabaseConnection
+from synapse.storage.keys import FetchKeyResult
from synapse.types import JsonDict, UserID, get_domain_from_id
from synapse.util import Clock
from tests import unittest
from tests.replication._base import BaseMultiWorkerStreamTestCase
+from tests.unittest import override_config
class PresenceUpdateTestCase(unittest.HomeserverTestCase):
- servlets = [admin.register_servlets]
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ sync.register_servlets,
+ ]
def prepare(
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
@@ -425,6 +438,102 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
wheel_timer.insert.assert_not_called()
+ # `rc_presence` is set very high during unit tests to avoid ratelimiting
+ # subtly impacting unrelated tests. We set the ratelimiting back to a
+ # reasonable value for the tests specific to presence ratelimiting.
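+    # As a worked example of these settings: per_second=0.1 with burst_count=1
+    # means a single update exhausts the burst, and the allowance refills at
+    # one action per 1 / 0.1 = 10 seconds. This is why the non-ratelimited
+    # variant below advances the reactor well past that before the second
+    # update.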
+ @override_config(
+ {"rc_presence": {"per_user": {"per_second": 0.1, "burst_count": 1}}}
+ )
+ def test_over_ratelimit_offline_to_online_to_unavailable(self) -> None:
+ """
+        Send a presence update and check that it went through, then immediately
+        send another one and check that it was ignored.
+ """
+ self._test_ratelimit_offline_to_online_to_unavailable(ratelimited=True)
+
+ @override_config(
+ {"rc_presence": {"per_user": {"per_second": 0.1, "burst_count": 1}}}
+ )
+ def test_within_ratelimit_offline_to_online_to_unavailable(self) -> None:
+ """
+        Send a presence update and check that it went through, advance time a
+        sufficient amount, then send another presence update and check that it
+        also worked.
+ """
+ self._test_ratelimit_offline_to_online_to_unavailable(ratelimited=False)
+
+ def _test_ratelimit_offline_to_online_to_unavailable(
+ self, ratelimited: bool
+ ) -> None:
+ """Test rate limit for presence updates sent with sync requests.
+
+ Args:
+ ratelimited: Test rate limited case.
+ """
+ wheel_timer = Mock()
+ user_id = "@user:pass"
+ now = 5000000
+ sync_url = "/sync?access_token=%s&set_presence=%s"
+
+ # Register the user who syncs presence
+ user_id = self.register_user("user", "pass")
+ access_token = self.login("user", "pass")
+
+ # Get the handler (which kicks off a bunch of timers).
+ presence_handler = self.hs.get_presence_handler()
+
+ # Ensure the user is initially offline.
+ prev_state = UserPresenceState.default(user_id)
+ new_state = prev_state.copy_and_replace(
+ state=PresenceState.OFFLINE, last_active_ts=now
+ )
+
+ state, persist_and_notify, federation_ping = handle_update(
+ prev_state,
+ new_state,
+ is_mine=True,
+ wheel_timer=wheel_timer,
+ now=now,
+ persist=False,
+ )
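+        # (The values returned by `handle_update` are unused here: with
+        # persist=False the call only exercises the update path, and the
+        # assertions below go through `presence_handler.get_state` instead.)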
+
+ # Check that the user is offline.
+ state = self.get_success(
+ presence_handler.get_state(UserID.from_string(user_id))
+ )
+ self.assertEqual(state.state, PresenceState.OFFLINE)
+
+ # Send sync request with set_presence=online.
+ channel = self.make_request("GET", sync_url % (access_token, "online"))
+ self.assertEqual(200, channel.code)
+
+ # Assert the user is now online.
+ state = self.get_success(
+ presence_handler.get_state(UserID.from_string(user_id))
+ )
+ self.assertEqual(state.state, PresenceState.ONLINE)
+
+ if not ratelimited:
+ # Advance time a sufficient amount to avoid rate limiting.
+ self.reactor.advance(30)
+
+ # Send another sync request with set_presence=unavailable.
+ channel = self.make_request("GET", sync_url % (access_token, "unavailable"))
+ self.assertEqual(200, channel.code)
+
+ state = self.get_success(
+ presence_handler.get_state(UserID.from_string(user_id))
+ )
+
+ if ratelimited:
+ # Assert the user is still online and presence update was ignored.
+ self.assertEqual(state.state, PresenceState.ONLINE)
+ else:
+ # Assert the user is now unavailable.
+ self.assertEqual(state.state, PresenceState.UNAVAILABLE)
+
class PresenceTimeoutTestCase(unittest.TestCase):
"""Tests different timers and that the timer does not change `status_msg` of user."""
@@ -1107,7 +1216,9 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
),
]
],
- name_func=lambda testcase_func, param_num, params: f"{testcase_func.__name__}_{param_num}_{'workers' if params.args[5] else 'monolith'}",
+ name_func=lambda testcase_func,
+ param_num,
+ params: f"{testcase_func.__name__}_{param_num}_{'workers' if params.args[5] else 'monolith'}",
)
@unittest.override_config({"experimental_features": {"msc3026_enabled": True}})
def test_set_presence_from_syncing_multi_device(
@@ -1343,7 +1454,9 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
),
]
],
- name_func=lambda testcase_func, param_num, params: f"{testcase_func.__name__}_{param_num}_{'workers' if params.args[4] else 'monolith'}",
+ name_func=lambda testcase_func,
+ param_num,
+ params: f"{testcase_func.__name__}_{param_num}_{'workers' if params.args[4] else 'monolith'}",
)
@unittest.override_config({"experimental_features": {"msc3026_enabled": True}})
def test_set_presence_from_non_syncing_multi_device(
@@ -1821,6 +1934,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
# self.event_builder_for_2.hostname = "test2"
self.store = hs.get_datastores().main
+ self.storage_controllers = hs.get_storage_controllers()
self.state = hs.get_state_handler()
self._event_auth_handler = hs.get_event_auth_handler()
@@ -1936,29 +2050,35 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
hostname = get_domain_from_id(user_id)
- room_version = self.get_success(self.store.get_room_version_id(room_id))
+ room_version = self.get_success(self.store.get_room_version(room_id))
- builder = EventBuilder(
- state=self.state,
- event_auth_handler=self._event_auth_handler,
- store=self.store,
- clock=self.clock,
- hostname=hostname,
- signing_key=self.random_signing_key,
- room_version=KNOWN_ROOM_VERSIONS[room_version],
- room_id=room_id,
- type=EventTypes.Member,
- sender=user_id,
- state_key=user_id,
- content={"membership": Membership.JOIN},
+ state_map = self.get_success(
+ self.storage_controllers.state.get_current_state(room_id)
)
- prev_event_ids = self.get_success(
- self.store.get_latest_event_ids_in_room(room_id)
+        # Figure out what the forward extremities in the room are (the most
+        # recent events in the room which have no children in the DAG)
+ forward_extremity_event_ids = self.get_success(
+ self.hs.get_datastores().main.get_latest_event_ids_in_room(room_id)
)
- event = self.get_success(
- builder.build(prev_event_ids=list(prev_event_ids), auth_event_ids=None)
+ event = self.create_fake_event_from_remote_server(
+ remote_server_name=hostname,
+ event_dict={
+ "room_id": room_id,
+ "sender": user_id,
+ "type": EventTypes.Member,
+ "state_key": user_id,
+ "depth": 1000,
+ "origin_server_ts": 1,
+ "content": {"membership": Membership.JOIN},
+ "auth_events": [
+ state_map[(EventTypes.Create, "")].event_id,
+ state_map[(EventTypes.JoinRules, "")].event_id,
+ ],
+ "prev_events": list(forward_extremity_event_ids),
+ },
+ room_version=room_version,
)
self.get_success(self.federation_event_handler.on_receive_pdu(hostname, event))
@@ -1966,3 +2086,50 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
# Check that it was successfully persisted.
self.get_success(self.store.get_event(event.event_id))
+
+ def create_fake_event_from_remote_server(
+ self, remote_server_name: str, event_dict: JsonDict, room_version: RoomVersion
+ ) -> EventBase:
+ """
+        This is similar to what `FederatingHomeserverTestCase` does, but without
+        all of the extra baggage, and it lets us create events from many
+        different remote servers.
+ """
+
+ # poke the other server's signing key into the key store, so that we don't
+ # make requests for it
+ other_server_signature_key = generate_signing_key("test")
+ verify_key = get_verify_key(other_server_signature_key)
+ verify_key_id = "%s:%s" % (verify_key.alg, verify_key.version)
+
+ self.get_success(
+ self.hs.get_datastores().main.store_server_keys_response(
+ remote_server_name,
+ from_server=remote_server_name,
+ ts_added_ms=self.clock.time_msec(),
+ verify_keys={
+ verify_key_id: FetchKeyResult(
+ verify_key=verify_key,
+ valid_until_ts=self.clock.time_msec() + 10000,
+ ),
+ },
+ response_json={
+ "verify_keys": {
+ verify_key_id: {"key": encode_verify_key_base64(verify_key)}
+ }
+ },
+ )
+ )
+
+ add_hashes_and_signatures(
+ room_version=room_version,
+ event_dict=event_dict,
+ signature_name=remote_server_name,
+ signing_key=other_server_signature_key,
+ )
+ event = make_event_from_dict(
+ event_dict,
+ room_version=room_version,
+ )
+
+ return event
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index cb1c6fbb80..2b9b56da95 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -369,6 +369,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
time_now_ms=self.clock.time_msec(),
upload_name=None,
filesystem_id="xyz",
+ sha256="abcdefg12345",
)
)
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index 92487692db..99bd0de834 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -588,6 +588,29 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
d = self.store.is_support_user(user_id)
self.assertFalse(self.get_success(d))
+ def test_underscore_localpart_rejected_by_default(self) -> None:
+ for invalid_user_id in ("_", "_prefixed"):
+ with self.subTest(invalid_user_id=invalid_user_id):
+ self.get_failure(
+ self.handler.register_user(localpart=invalid_user_id),
+ SynapseError,
+ )
+
+ @override_config(
+ {
+ "allow_underscore_prefixed_localpart": True,
+ }
+ )
+ def test_underscore_localpart_allowed_if_configured(self) -> None:
+ for valid_user_id in ("_", "_prefixed"):
+ with self.subTest(valid_user_id=valid_user_id):
+ user_id = self.get_success(
+ self.handler.register_user(
+ localpart=valid_user_id,
+ ),
+ )
+ self.assertEqual(user_id, f"@{valid_user_id}:test")
+
def test_invalid_user_id(self) -> None:
invalid_user_id = "^abcd"
self.get_failure(
@@ -715,6 +738,41 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.handler.register_user(localpart="bobflimflob", auth_provider_id="saml")
)
+ def test_register_default_user_type(self) -> None:
+ """Test that the default user type is none when registering a user."""
+ user_id = self.get_success(self.handler.register_user(localpart="user"))
+ user_info = self.get_success(self.store.get_user_by_id(user_id))
+ assert user_info is not None
+ self.assertEqual(user_info.user_type, None)
+
+ def test_register_extra_user_types_valid(self) -> None:
+ """
+ Test that the specified user type is set correctly when registering a user.
+ n.b. No validation is done on the user type, so this test
+ is only to ensure that the user type can be set to any value.
+ """
+ user_id = self.get_success(
+ self.handler.register_user(localpart="user", user_type="anyvalue")
+ )
+ user_info = self.get_success(self.store.get_user_by_id(user_id))
+ assert user_info is not None
+ self.assertEqual(user_info.user_type, "anyvalue")
+
+ @override_config(
+ {
+ "user_types": {
+ "extra_user_types": ["extra1", "extra2"],
+ "default_user_type": "extra1",
+ }
+ }
+ )
+ def test_register_extra_user_types_with_default(self) -> None:
+ """Test that the default_user_type in config is set correctly when registering a user."""
+ user_id = self.get_success(self.handler.register_user(localpart="user"))
+ user_info = self.get_success(self.store.get_user_by_id(user_id))
+ assert user_info is not None
+ self.assertEqual(user_info.user_type, "extra1")
+
async def get_or_create_user(
self,
requester: Requester,
diff --git a/tests/handlers/test_room_list.py b/tests/handlers/test_room_list.py
index 4d22ef98c2..45cef09b22 100644
--- a/tests/handlers/test_room_list.py
+++ b/tests/handlers/test_room_list.py
@@ -6,6 +6,7 @@ from synapse.rest.client import directory, login, room
from synapse.types import JsonDict
from tests import unittest
+from tests.utils import default_config
class RoomListHandlerTestCase(unittest.HomeserverTestCase):
@@ -30,6 +31,11 @@ class RoomListHandlerTestCase(unittest.HomeserverTestCase):
assert channel.code == HTTPStatus.OK, f"couldn't publish room: {channel.result}"
return room_id
+ def default_config(self) -> JsonDict:
+ config = default_config("test")
+ config["room_list_publication_rules"] = [{"action": "allow"}]
+ return config
+
def test_acls_applied_to_room_directory_results(self) -> None:
"""
Creates 3 rooms. Room 2 has an ACL that only permits the homeservers
diff --git a/tests/handlers/test_room_member.py b/tests/handlers/test_room_member.py
index 213a66ed1a..d87fe9d62c 100644
--- a/tests/handlers/test_room_member.py
+++ b/tests/handlers/test_room_member.py
@@ -5,10 +5,13 @@ from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
import synapse.rest.client.login
import synapse.rest.client.room
-from synapse.api.constants import EventTypes, Membership
-from synapse.api.errors import LimitExceededError, SynapseError
+from synapse.api.constants import AccountDataTypes, EventTypes, Membership
+from synapse.api.errors import Codes, LimitExceededError, SynapseError
from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.events import FrozenEventV3
+from synapse.federation.federation_base import (
+ event_from_pdu_json,
+)
from synapse.federation.federation_client import SendJoinResult
from synapse.server import HomeServer
from synapse.types import UserID, create_requester
@@ -172,20 +175,25 @@ class TestJoinsLimitedByPerRoomRateLimiter(FederatingHomeserverTestCase):
)
)
- with patch.object(
- self.handler.federation_handler.federation_client,
- "make_membership_event",
- mock_make_membership_event,
- ), patch.object(
- self.handler.federation_handler.federation_client,
- "send_join",
- mock_send_join,
- ), patch(
- "synapse.event_auth._is_membership_change_allowed",
- return_value=None,
- ), patch(
- "synapse.handlers.federation_event.check_state_dependent_auth_rules",
- return_value=None,
+ with (
+ patch.object(
+ self.handler.federation_handler.federation_client,
+ "make_membership_event",
+ mock_make_membership_event,
+ ),
+ patch.object(
+ self.handler.federation_handler.federation_client,
+ "send_join",
+ mock_send_join,
+ ),
+ patch(
+ "synapse.event_auth._is_membership_change_allowed",
+ return_value=None,
+ ),
+ patch(
+ "synapse.handlers.federation_event.check_state_dependent_auth_rules",
+ return_value=None,
+ ),
):
self.get_success(
self.handler.update_membership(
@@ -380,9 +388,29 @@ class RoomMemberMasterHandlerTestCase(HomeserverTestCase):
)
def test_forget_when_not_left(self) -> None:
- """Tests that a user cannot not forgets a room that has not left."""
+ """Tests that a user cannot forget a room that they are still in."""
self.get_failure(self.handler.forget(self.alice_ID, self.room_id), SynapseError)
+ def test_nonlocal_room_user_action(self) -> None:
+ """
+ Test that non-local user ids cannot perform room actions through
+ this homeserver.
+ """
+ alien_user_id = UserID.from_string("@cheeky_monkey:matrix.org")
+ bad_room_id = f"{self.room_id}+BAD_ID"
+
+ exc = self.get_failure(
+ self.handler.update_membership(
+ create_requester(self.alice),
+ alien_user_id,
+ bad_room_id,
+ "unban",
+ ),
+ SynapseError,
+ ).value
+
+ self.assertEqual(exc.errcode, Codes.BAD_JSON)
+
def test_rejoin_forgotten_by_user(self) -> None:
"""Test that a user that has forgotten a room can do a re-join.
The room was not forgotten from the local server.
@@ -428,3 +456,165 @@ class RoomMemberMasterHandlerTestCase(HomeserverTestCase):
new_count = rows[0][0]
self.assertEqual(initial_count, new_count)
+
+
+class TestInviteFiltering(FederatingHomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ synapse.rest.client.login.register_servlets,
+ synapse.rest.client.room.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.handler = hs.get_room_member_handler()
+ self.fed_handler = hs.get_federation_handler()
+ self.store = hs.get_datastores().main
+
+ # Create three users.
+ self.alice = self.register_user("alice", "pass")
+ self.alice_token = self.login("alice", "pass")
+ self.bob = self.register_user("bob", "pass")
+ self.bob_token = self.login("bob", "pass")
+
+ @override_config({"experimental_features": {"msc4155_enabled": True}})
+    def test_msc4155_block_invite_local(self) -> None:
+ """Test that MSC4155 will block a user from being invited to a room"""
+ room_id = self.helper.create_room_as(self.alice, tok=self.alice_token)
+
+ self.get_success(
+ self.store.add_account_data_for_user(
+ self.bob,
+ AccountDataTypes.MSC4155_INVITE_PERMISSION_CONFIG,
+ {
+ "blocked_users": [self.alice],
+ },
+ )
+ )
+
+ f = self.get_failure(
+ self.handler.update_membership(
+ requester=create_requester(self.alice),
+ target=UserID.from_string(self.bob),
+ room_id=room_id,
+ action=Membership.INVITE,
+ ),
+ SynapseError,
+ ).value
+ self.assertEqual(f.code, 403)
+ self.assertEqual(f.errcode, "ORG.MATRIX.MSC4155.M_INVITE_BLOCKED")
+
+ @override_config({"experimental_features": {"msc4155_enabled": False}})
+ def test_msc4155_disabled_allow_invite_local(self) -> None:
+ """Test that MSC4155 will block a user from being invited to a room"""
+ room_id = self.helper.create_room_as(self.alice, tok=self.alice_token)
+
+ self.get_success(
+ self.store.add_account_data_for_user(
+ self.bob,
+ AccountDataTypes.MSC4155_INVITE_PERMISSION_CONFIG,
+ {
+ "blocked_users": [self.alice],
+ },
+ )
+ )
+
+ self.get_success(
+ self.handler.update_membership(
+ requester=create_requester(self.alice),
+ target=UserID.from_string(self.bob),
+ room_id=room_id,
+ action=Membership.INVITE,
+ ),
+ )
+
+ @override_config({"experimental_features": {"msc4155_enabled": True}})
+ def test_msc4155_block_invite_remote(self) -> None:
+ """Test that MSC4155 will block a remote user from being invited to a room"""
+ # A remote user who sends the invite
+ remote_server = "otherserver"
+ remote_user = "@otheruser:" + remote_server
+
+ self.get_success(
+ self.store.add_account_data_for_user(
+ self.bob,
+ AccountDataTypes.MSC4155_INVITE_PERMISSION_CONFIG,
+ {"blocked_users": [remote_user]},
+ )
+ )
+
+ room_id = self.helper.create_room_as(
+ room_creator=self.alice, tok=self.alice_token
+ )
+ room_version = self.get_success(self.store.get_room_version(room_id))
+
+ invite_event = event_from_pdu_json(
+ {
+ "type": EventTypes.Member,
+ "content": {"membership": "invite"},
+ "room_id": room_id,
+ "sender": remote_user,
+ "state_key": self.bob,
+ "depth": 32,
+ "prev_events": [],
+ "auth_events": [],
+ "origin_server_ts": self.clock.time_msec(),
+ },
+ room_version,
+ )
+
+ f = self.get_failure(
+ self.fed_handler.on_invite_request(
+ remote_server,
+ invite_event,
+ invite_event.room_version,
+ ),
+ SynapseError,
+ ).value
+ self.assertEqual(f.code, 403)
+ self.assertEqual(f.errcode, "ORG.MATRIX.MSC4155.M_INVITE_BLOCKED")
+
+ @override_config({"experimental_features": {"msc4155_enabled": True}})
+ def test_msc4155_block_invite_remote_server(self) -> None:
+ """Test that MSC4155 will block a remote server's user from being invited to a room"""
+ # A remote user who sends the invite
+ remote_server = "otherserver"
+ remote_user = "@otheruser:" + remote_server
+
+ self.get_success(
+ self.store.add_account_data_for_user(
+ self.bob,
+ AccountDataTypes.MSC4155_INVITE_PERMISSION_CONFIG,
+ {"blocked_servers": [remote_server]},
+ )
+ )
+
+ room_id = self.helper.create_room_as(
+ room_creator=self.alice, tok=self.alice_token
+ )
+ room_version = self.get_success(self.store.get_room_version(room_id))
+
+ invite_event = event_from_pdu_json(
+ {
+ "type": EventTypes.Member,
+ "content": {"membership": "invite"},
+ "room_id": room_id,
+ "sender": remote_user,
+ "state_key": self.bob,
+ "depth": 32,
+ "prev_events": [],
+ "auth_events": [],
+ "origin_server_ts": self.clock.time_msec(),
+ },
+ room_version,
+ )
+
+ f = self.get_failure(
+ self.fed_handler.on_invite_request(
+ remote_server,
+ invite_event,
+ invite_event.room_version,
+ ),
+ SynapseError,
+ ).value
+ self.assertEqual(f.code, 403)
+ self.assertEqual(f.errcode, "ORG.MATRIX.MSC4155.M_INVITE_BLOCKED")
diff --git a/tests/handlers/test_room_policy.py b/tests/handlers/test_room_policy.py
new file mode 100644
index 0000000000..26642c18ea
--- /dev/null
+++ b/tests/handlers/test_room_policy.py
@@ -0,0 +1,226 @@
+#
+# This file is licensed under the Affero General Public License (AGPL) version 3.
+#
+# Copyright (C) 2025 New Vector, Ltd
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# See the GNU Affero General Public License for more details:
+# <https://www.gnu.org/licenses/agpl-3.0.html>.
+#
+#
+from typing import Optional
+from unittest import mock
+
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.events import EventBase, make_event_from_dict
+from synapse.rest import admin
+from synapse.rest.client import login, room
+from synapse.server import HomeServer
+from synapse.types import JsonDict, UserID
+from synapse.types.handlers.policy_server import RECOMMENDATION_OK, RECOMMENDATION_SPAM
+from synapse.util import Clock
+
+from tests import unittest
+from tests.test_utils import event_injection
+
+
+class RoomPolicyTestCase(unittest.FederatingHomeserverTestCase):
+ """Tests room policy handler."""
+
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+ # mock out the federation transport client
+ self.mock_federation_transport_client = mock.Mock(
+ spec=["get_policy_recommendation_for_pdu"]
+ )
+ self.mock_federation_transport_client.get_policy_recommendation_for_pdu = (
+ mock.AsyncMock()
+ )
+ return super().setup_test_homeserver(
+ federation_transport_client=self.mock_federation_transport_client
+ )
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.hs = hs
+ self.handler = hs.get_room_policy_handler()
+ main_store = self.hs.get_datastores().main
+
+ # Create a room
+ self.creator = self.register_user("creator", "test1234")
+ self.creator_token = self.login("creator", "test1234")
+ self.room_id = self.helper.create_room_as(
+ room_creator=self.creator, tok=self.creator_token
+ )
+ room_version = self.get_success(main_store.get_room_version(self.room_id))
+
+ # Create some sample events
+ self.spammy_event = make_event_from_dict(
+ room_version=room_version,
+ internal_metadata_dict={},
+ event_dict={
+ "room_id": self.room_id,
+ "type": "m.room.message",
+ "sender": "@spammy:example.org",
+ "content": {
+ "msgtype": "m.text",
+ "body": "This is a spammy event.",
+ },
+ },
+ )
+ self.not_spammy_event = make_event_from_dict(
+ room_version=room_version,
+ internal_metadata_dict={},
+ event_dict={
+ "room_id": self.room_id,
+ "type": "m.room.message",
+ "sender": "@not_spammy:example.org",
+ "content": {
+ "msgtype": "m.text",
+ "body": "This is a NOT spammy event.",
+ },
+ },
+ )
+
+ # Prepare the policy server mock to decide spam vs not spam on those events
+ self.call_count = 0
+
+ async def get_policy_recommendation_for_pdu(
+ destination: str,
+ pdu: EventBase,
+ timeout: Optional[int] = None,
+ ) -> JsonDict:
+ self.call_count += 1
+ self.assertEqual(destination, self.OTHER_SERVER_NAME)
+ if pdu.event_id == self.spammy_event.event_id:
+ return {"recommendation": RECOMMENDATION_SPAM}
+ elif pdu.event_id == self.not_spammy_event.event_id:
+ return {"recommendation": RECOMMENDATION_OK}
+ else:
+ self.fail("Unexpected event ID")
+
+ self.mock_federation_transport_client.get_policy_recommendation_for_pdu.side_effect = get_policy_recommendation_for_pdu
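+        # (Wired up this way, every policy-server lookup in these tests is
+        # answered by the mock above, and `self.call_count` records how many
+        # times Synapse actually consulted the policy server.)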
+
+ def _add_policy_server_to_room(self) -> None:
+ # Inject a member event into the room
+ policy_user_id = f"@policy:{self.OTHER_SERVER_NAME}"
+ self.get_success(
+ event_injection.inject_member_event(
+ self.hs, self.room_id, policy_user_id, "join"
+ )
+ )
+ self.helper.send_state(
+ self.room_id,
+ "org.matrix.msc4284.policy",
+ {
+ "via": self.OTHER_SERVER_NAME,
+ },
+ tok=self.creator_token,
+ state_key="",
+ )
+
+ def test_no_policy_event_set(self) -> None:
+ # We don't need to modify the room state at all - we're testing the default
+ # case where a room doesn't use a policy server.
+ ok = self.get_success(self.handler.is_event_allowed(self.spammy_event))
+ self.assertEqual(ok, True)
+ self.assertEqual(self.call_count, 0)
+
+ def test_empty_policy_event_set(self) -> None:
+ self.helper.send_state(
+ self.room_id,
+ "org.matrix.msc4284.policy",
+ {
+ # empty content (no `via`)
+ },
+ tok=self.creator_token,
+ state_key="",
+ )
+
+ ok = self.get_success(self.handler.is_event_allowed(self.spammy_event))
+ self.assertEqual(ok, True)
+ self.assertEqual(self.call_count, 0)
+
+ def test_nonstring_policy_event_set(self) -> None:
+ self.helper.send_state(
+ self.room_id,
+ "org.matrix.msc4284.policy",
+ {
+ "via": 42, # should be a server name
+ },
+ tok=self.creator_token,
+ state_key="",
+ )
+
+ ok = self.get_success(self.handler.is_event_allowed(self.spammy_event))
+ self.assertEqual(ok, True)
+ self.assertEqual(self.call_count, 0)
+
+ def test_self_policy_event_set(self) -> None:
+ self.helper.send_state(
+ self.room_id,
+ "org.matrix.msc4284.policy",
+ {
+ # We ignore events when the policy server is ourselves (for now?)
+ "via": (UserID.from_string(self.creator)).domain,
+ },
+ tok=self.creator_token,
+ state_key="",
+ )
+
+ ok = self.get_success(self.handler.is_event_allowed(self.spammy_event))
+ self.assertEqual(ok, True)
+ self.assertEqual(self.call_count, 0)
+
+ def test_invalid_server_policy_event_set(self) -> None:
+ self.helper.send_state(
+ self.room_id,
+ "org.matrix.msc4284.policy",
+ {
+ "via": "|this| is *not* a (valid) server name.com",
+ },
+ tok=self.creator_token,
+ state_key="",
+ )
+
+ ok = self.get_success(self.handler.is_event_allowed(self.spammy_event))
+ self.assertEqual(ok, True)
+ self.assertEqual(self.call_count, 0)
+
+ def test_not_in_room_policy_event_set(self) -> None:
+ self.helper.send_state(
+ self.room_id,
+ "org.matrix.msc4284.policy",
+ {
+ "via": f"x.{self.OTHER_SERVER_NAME}",
+ },
+ tok=self.creator_token,
+ state_key="",
+ )
+
+ ok = self.get_success(self.handler.is_event_allowed(self.spammy_event))
+ self.assertEqual(ok, True)
+ self.assertEqual(self.call_count, 0)
+
+ def test_spammy_event_is_spam(self) -> None:
+ self._add_policy_server_to_room()
+
+ ok = self.get_success(self.handler.is_event_allowed(self.spammy_event))
+ self.assertEqual(ok, False)
+ self.assertEqual(self.call_count, 1)
+
+ def test_not_spammy_event_is_not_spam(self) -> None:
+ self._add_policy_server_to_room()
+
+ ok = self.get_success(self.handler.is_event_allowed(self.not_spammy_event))
+ self.assertEqual(ok, True)
+ self.assertEqual(self.call_count, 1)
diff --git a/tests/handlers/test_room_summary.py b/tests/handlers/test_room_summary.py
index 244a4e7689..b55fa1a8fd 100644
--- a/tests/handlers/test_room_summary.py
+++ b/tests/handlers/test_room_summary.py
@@ -757,6 +757,54 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
)
self._assert_hierarchy(result, expected)
+ def test_fed_root(self) -> None:
+ """
+        Test that the room hierarchy can be fetched when the requested root room is on a remote server.
+ """
+ fed_hostname = self.hs.hostname + "2"
+ fed_space = "#fed_space:" + fed_hostname
+ fed_subroom = "#fed_sub_room:" + fed_hostname
+
+ requested_room_entry = _RoomEntry(
+ fed_space,
+ {
+ "room_id": fed_space,
+ "world_readable": True,
+ "room_type": RoomTypes.SPACE,
+ },
+ [
+ {
+ "type": EventTypes.SpaceChild,
+ "room_id": fed_space,
+ "state_key": fed_subroom,
+ "content": {"via": [fed_hostname]},
+ }
+ ],
+ )
+ child_room = {
+ "room_id": fed_subroom,
+ "world_readable": True,
+ }
+
+ async def summarize_remote_room_hierarchy(
+ _self: Any, room: Any, suggested_only: bool
+ ) -> Tuple[Optional[_RoomEntry], Dict[str, JsonDict], Set[str]]:
+ return requested_room_entry, {fed_subroom: child_room}, set()
+
+ expected = [
+ (fed_space, [fed_subroom]),
+ (fed_subroom, ()),
+ ]
+
+ with mock.patch(
+ "synapse.handlers.room_summary.RoomSummaryHandler._summarize_remote_room_hierarchy",
+ new=summarize_remote_room_hierarchy,
+ ):
+ result = self.get_success(
+ self.handler.get_room_hierarchy(create_requester(self.user), fed_space)
+ )
+ self._assert_hierarchy(result, expected)
+
def test_fed_filtering(self) -> None:
"""
Rooms returned over federation should be properly filtered to only include
diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py
deleted file mode 100644
index 6ab8fda6e7..0000000000
--- a/tests/handlers/test_saml.py
+++ /dev/null
@@ -1,381 +0,0 @@
-#
-# This file is licensed under the Affero General Public License (AGPL) version 3.
-#
-# Copyright 2020 The Matrix.org Foundation C.I.C.
-# Copyright (C) 2023 New Vector, Ltd
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Affero General Public License as
-# published by the Free Software Foundation, either version 3 of the
-# License, or (at your option) any later version.
-#
-# See the GNU Affero General Public License for more details:
-# <https://www.gnu.org/licenses/agpl-3.0.html>.
-#
-# Originally licensed under the Apache License, Version 2.0:
-# <http://www.apache.org/licenses/LICENSE-2.0>.
-#
-# [This file includes modifications made by New Vector Limited]
-#
-#
-
-from typing import Any, Dict, Optional, Set, Tuple
-from unittest.mock import AsyncMock, Mock
-
-import attr
-
-from twisted.test.proto_helpers import MemoryReactor
-
-from synapse.api.errors import RedirectException
-from synapse.module_api import ModuleApi
-from synapse.server import HomeServer
-from synapse.types import JsonDict
-from synapse.util import Clock
-
-from tests.unittest import HomeserverTestCase, override_config
-
-# Check if we have the dependencies to run the tests.
-try:
- import saml2.config
- import saml2.response
- from saml2.sigver import SigverError
-
- has_saml2 = True
-
- # pysaml2 can be installed and imported, but might not be able to find xmlsec1.
- config = saml2.config.SPConfig()
- try:
- config.load({"metadata": {}})
- has_xmlsec1 = True
- except SigverError:
- has_xmlsec1 = False
-except ImportError:
- has_saml2 = False
- has_xmlsec1 = False
-
-# These are a few constants that are used as config parameters in the tests.
-BASE_URL = "https://synapse/"
-
-
-@attr.s
-class FakeAuthnResponse:
- ava = attr.ib(type=dict)
- assertions = attr.ib(type=list, factory=list)
- in_response_to = attr.ib(type=Optional[str], default=None)
-
-
-class TestMappingProvider:
- def __init__(self, config: None, module: ModuleApi):
- pass
-
- @staticmethod
- def parse_config(config: JsonDict) -> None:
- return None
-
- @staticmethod
- def get_saml_attributes(config: None) -> Tuple[Set[str], Set[str]]:
- return {"uid"}, {"displayName"}
-
- def get_remote_user_id(
- self, saml_response: "saml2.response.AuthnResponse", client_redirect_url: str
- ) -> str:
- return saml_response.ava["uid"]
-
- def saml_response_to_user_attributes(
- self,
- saml_response: "saml2.response.AuthnResponse",
- failures: int,
- client_redirect_url: str,
- ) -> dict:
- localpart = saml_response.ava["username"] + (str(failures) if failures else "")
- return {"mxid_localpart": localpart, "displayname": None}
-
-
-class TestRedirectMappingProvider(TestMappingProvider):
- def saml_response_to_user_attributes(
- self,
- saml_response: "saml2.response.AuthnResponse",
- failures: int,
- client_redirect_url: str,
- ) -> dict:
- raise RedirectException(b"https://custom-saml-redirect/")
-
-
-class SamlHandlerTestCase(HomeserverTestCase):
- def default_config(self) -> Dict[str, Any]:
- config = super().default_config()
- config["public_baseurl"] = BASE_URL
- saml_config: Dict[str, Any] = {
- "sp_config": {"metadata": {}},
- # Disable grandfathering.
- "grandfathered_mxid_source_attribute": None,
- "user_mapping_provider": {"module": __name__ + ".TestMappingProvider"},
- }
-
- # Update this config with what's in the default config so that
- # override_config works as expected.
- saml_config.update(config.get("saml2_config", {}))
- config["saml2_config"] = saml_config
-
- return config
-
- def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
- hs = self.setup_test_homeserver()
-
- self.handler = hs.get_saml_handler()
-
- # Reduce the number of attempts when generating MXIDs.
- sso_handler = hs.get_sso_handler()
- sso_handler._MAP_USERNAME_RETRIES = 3
-
- return hs
-
- if not has_saml2:
- skip = "Requires pysaml2"
- elif not has_xmlsec1:
- skip = "Requires xmlsec1"
-
- def test_map_saml_response_to_user(self) -> None:
- """Ensure that mapping the SAML response returned from a provider to an MXID works properly."""
-
- # stub out the auth handler
- auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign]
-
- # send a mocked-up SAML response to the callback
- saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})
- request = _mock_request()
- self.get_success(
- self.handler._handle_authn_response(request, saml_response, "redirect_uri")
- )
-
- # check that the auth handler got called as expected
- auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user:test",
- "saml",
- request,
- "redirect_uri",
- None,
- new_user=True,
- auth_provider_session_id=None,
- )
-
- @override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
- def test_map_saml_response_to_existing_user(self) -> None:
- """Existing users can log in with SAML account."""
- store = self.hs.get_datastores().main
- self.get_success(
- store.register_user(user_id="@test_user:test", password_hash=None)
- )
-
- # stub out the auth handler
- auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign]
-
- # Map a user via SSO.
- saml_response = FakeAuthnResponse(
- {"uid": "tester", "mxid": ["test_user"], "username": "test_user"}
- )
- request = _mock_request()
- self.get_success(
- self.handler._handle_authn_response(request, saml_response, "")
- )
-
- # check that the auth handler got called as expected
- auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user:test",
- "saml",
- request,
- "",
- None,
- new_user=False,
- auth_provider_session_id=None,
- )
-
- # Subsequent calls should map to the same mxid.
- auth_handler.complete_sso_login.reset_mock()
- self.get_success(
- self.handler._handle_authn_response(request, saml_response, "")
- )
- auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user:test",
- "saml",
- request,
- "",
- None,
- new_user=False,
- auth_provider_session_id=None,
- )
-
- def test_map_saml_response_to_invalid_localpart(self) -> None:
- """If the mapping provider generates an invalid localpart it should be rejected."""
-
- # stub out the auth handler
- auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign]
-
- # mock out the error renderer too
- sso_handler = self.hs.get_sso_handler()
- sso_handler.render_error = Mock(return_value=None) # type: ignore[method-assign]
-
- saml_response = FakeAuthnResponse({"uid": "test", "username": "föö"})
- request = _mock_request()
- self.get_success(
- self.handler._handle_authn_response(request, saml_response, ""),
- )
- sso_handler.render_error.assert_called_once_with(
- request, "mapping_error", "localpart is invalid: föö"
- )
- auth_handler.complete_sso_login.assert_not_called()
-
- def test_map_saml_response_to_user_retries(self) -> None:
- """The mapping provider can retry generating an MXID if the MXID is already in use."""
-
- # stub out the auth handler and error renderer
- auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign]
- sso_handler = self.hs.get_sso_handler()
- sso_handler.render_error = Mock(return_value=None) # type: ignore[method-assign]
-
- # register a user to occupy the first-choice MXID
- store = self.hs.get_datastores().main
- self.get_success(
- store.register_user(user_id="@test_user:test", password_hash=None)
- )
-
- # send the fake SAML response
- saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"})
- request = _mock_request()
- self.get_success(
- self.handler._handle_authn_response(request, saml_response, ""),
- )
-
- # test_user is already taken, so test_user1 gets registered instead.
- auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user1:test",
- "saml",
- request,
- "",
- None,
- new_user=True,
- auth_provider_session_id=None,
- )
- auth_handler.complete_sso_login.reset_mock()
-
- # Register all of the potential mxids for a particular SAML username.
- self.get_success(
- store.register_user(user_id="@tester:test", password_hash=None)
- )
- for i in range(1, 3):
- self.get_success(
- store.register_user(user_id="@tester%d:test" % i, password_hash=None)
- )
-
- # Now attempt to map to a username, this will fail since all potential usernames are taken.
- saml_response = FakeAuthnResponse({"uid": "tester", "username": "tester"})
- self.get_success(
- self.handler._handle_authn_response(request, saml_response, ""),
- )
- sso_handler.render_error.assert_called_once_with(
- request,
- "mapping_error",
- "Unable to generate a Matrix ID from the SSO response",
- )
- auth_handler.complete_sso_login.assert_not_called()
-
- @override_config(
- {
- "saml2_config": {
- "user_mapping_provider": {
- "module": __name__ + ".TestRedirectMappingProvider"
- },
- }
- }
- )
- def test_map_saml_response_redirect(self) -> None:
- """Test a mapping provider that raises a RedirectException"""
-
- saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"})
- request = _mock_request()
- e = self.get_failure(
- self.handler._handle_authn_response(request, saml_response, ""),
- RedirectException,
- )
- self.assertEqual(e.value.location, b"https://custom-saml-redirect/")
-
- @override_config(
- {
- "saml2_config": {
- "attribute_requirements": [
- {"attribute": "userGroup", "value": "staff"},
- {"attribute": "department", "value": "sales"},
- ],
- },
- }
- )
- def test_attribute_requirements(self) -> None:
- """The required attributes must be met from the SAML response."""
-
- # stub out the auth handler
- auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign]
-
- # The response doesn't have the proper userGroup or department.
- saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})
- request = _mock_request()
- self.get_success(
- self.handler._handle_authn_response(request, saml_response, "redirect_uri")
- )
- auth_handler.complete_sso_login.assert_not_called()
-
- # The response doesn't have the proper department.
- saml_response = FakeAuthnResponse(
- {"uid": "test_user", "username": "test_user", "userGroup": ["staff"]}
- )
- request = _mock_request()
- self.get_success(
- self.handler._handle_authn_response(request, saml_response, "redirect_uri")
- )
- auth_handler.complete_sso_login.assert_not_called()
-
- # Add the proper attributes and it should succeed.
- saml_response = FakeAuthnResponse(
- {
- "uid": "test_user",
- "username": "test_user",
- "userGroup": ["staff", "admin"],
- "department": ["sales"],
- }
- )
- request.reset_mock()
- self.get_success(
- self.handler._handle_authn_response(request, saml_response, "redirect_uri")
- )
-
- # check that the auth handler got called as expected
- auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user:test",
- "saml",
- request,
- "redirect_uri",
- None,
- new_user=True,
- auth_provider_session_id=None,
- )
-
-
-def _mock_request() -> Mock:
- """Returns a mock which will stand in as a SynapseRequest"""
- mock = Mock(
- spec=[
- "finish",
- "getClientAddress",
- "getHeader",
- "setHeader",
- "setResponseCode",
- "write",
- ]
- )
- # `_disconnected` musn't be another `Mock`, otherwise it will be truthy.
- mock._disconnected = False
- return mock
diff --git a/tests/handlers/test_send_email.py b/tests/handlers/test_send_email.py
deleted file mode 100644
index cedcea27d9..0000000000
--- a/tests/handlers/test_send_email.py
+++ /dev/null
@@ -1,230 +0,0 @@
-#
-# This file is licensed under the Affero General Public License (AGPL) version 3.
-#
-# Copyright 2021 The Matrix.org Foundation C.I.C.
-# Copyright (C) 2023 New Vector, Ltd
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Affero General Public License as
-# published by the Free Software Foundation, either version 3 of the
-# License, or (at your option) any later version.
-#
-# See the GNU Affero General Public License for more details:
-# <https://www.gnu.org/licenses/agpl-3.0.html>.
-#
-# Originally licensed under the Apache License, Version 2.0:
-# <http://www.apache.org/licenses/LICENSE-2.0>.
-#
-# [This file includes modifications made by New Vector Limited]
-#
-#
-
-
-from typing import Callable, List, Tuple, Type, Union
-from unittest.mock import patch
-
-from zope.interface import implementer
-
-from twisted.internet import defer
-from twisted.internet._sslverify import ClientTLSOptions
-from twisted.internet.address import IPv4Address, IPv6Address
-from twisted.internet.defer import ensureDeferred
-from twisted.internet.interfaces import IProtocolFactory
-from twisted.internet.ssl import ContextFactory
-from twisted.mail import interfaces, smtp
-
-from tests.server import FakeTransport
-from tests.unittest import HomeserverTestCase, override_config
-
-
-def TestingESMTPTLSClientFactory(
- contextFactory: ContextFactory,
- _connectWrapped: bool,
- wrappedProtocol: IProtocolFactory,
-) -> IProtocolFactory:
- """We use this to pass through in testing without using TLS, but
- saving the context information to check that it would have happened.
-
- Note that this is what the MemoryReactor does on connectSSL.
- It only saves the contextFactory, but starts the connection with the
- underlying Factory.
- See: L{twisted.internet.testing.MemoryReactor.connectSSL}"""
-
- wrappedProtocol._testingContextFactory = contextFactory # type: ignore[attr-defined]
- return wrappedProtocol
-
-
-@implementer(interfaces.IMessageDelivery)
-class _DummyMessageDelivery:
- def __init__(self) -> None:
- # (recipient, message) tuples
- self.messages: List[Tuple[smtp.Address, bytes]] = []
-
- def receivedHeader(
- self,
- helo: Tuple[bytes, bytes],
- origin: smtp.Address,
- recipients: List[smtp.User],
- ) -> None:
- return None
-
- def validateFrom(
- self, helo: Tuple[bytes, bytes], origin: smtp.Address
- ) -> smtp.Address:
- return origin
-
- def record_message(self, recipient: smtp.Address, message: bytes) -> None:
- self.messages.append((recipient, message))
-
- def validateTo(self, user: smtp.User) -> Callable[[], interfaces.IMessageSMTP]:
- return lambda: _DummyMessage(self, user)
-
-
-@implementer(interfaces.IMessageSMTP)
-class _DummyMessage:
- """IMessageSMTP implementation which saves the message delivered to it
- to the _DummyMessageDelivery object.
- """
-
- def __init__(self, delivery: _DummyMessageDelivery, user: smtp.User):
- self._delivery = delivery
- self._user = user
- self._buffer: List[bytes] = []
-
- def lineReceived(self, line: bytes) -> None:
- self._buffer.append(line)
-
- def eomReceived(self) -> "defer.Deferred[bytes]":
- message = b"\n".join(self._buffer) + b"\n"
- self._delivery.record_message(self._user.dest, message)
- return defer.succeed(b"saved")
-
- def connectionLost(self) -> None:
- pass
-
-
-class SendEmailHandlerTestCaseIPv4(HomeserverTestCase):
- ip_class: Union[Type[IPv4Address], Type[IPv6Address]] = IPv4Address
-
- def setUp(self) -> None:
- super().setUp()
- self.reactor.lookups["localhost"] = "127.0.0.1"
-
- def test_send_email(self) -> None:
- """Happy-path test that we can send email to a non-TLS server."""
- h = self.hs.get_send_email_handler()
- d = ensureDeferred(
- h.send_email(
- "foo@bar.com", "test subject", "Tests", "HTML content", "Text content"
- )
- )
- # there should be an attempt to connect to localhost:25
- self.assertEqual(len(self.reactor.tcpClients), 1)
- (host, port, client_factory, _timeout, _bindAddress) = self.reactor.tcpClients[
- 0
- ]
- self.assertEqual(host, self.reactor.lookups["localhost"])
- self.assertEqual(port, 25)
-
- # wire it up to an SMTP server
- message_delivery = _DummyMessageDelivery()
- server_protocol = smtp.ESMTP()
- server_protocol.delivery = message_delivery
- # make sure that the server uses the test reactor to set timeouts
- server_protocol.callLater = self.reactor.callLater # type: ignore[assignment]
-
- client_protocol = client_factory.buildProtocol(None)
- client_protocol.makeConnection(FakeTransport(server_protocol, self.reactor))
- server_protocol.makeConnection(
- FakeTransport(
- client_protocol,
- self.reactor,
- peer_address=self.ip_class(
- "TCP", self.reactor.lookups["localhost"], 1234
- ),
- )
- )
-
- # the message should now get delivered
- self.get_success(d, by=0.1)
-
- # check it arrived
- self.assertEqual(len(message_delivery.messages), 1)
- user, msg = message_delivery.messages.pop()
- self.assertEqual(str(user), "foo@bar.com")
- self.assertIn(b"Subject: test subject", msg)
-
- @patch(
- "synapse.handlers.send_email.TLSMemoryBIOFactory",
- TestingESMTPTLSClientFactory,
- )
- @override_config(
- {
- "email": {
- "notif_from": "noreply@test",
- "force_tls": True,
- },
- }
- )
- def test_send_email_force_tls(self) -> None:
- """Happy-path test that we can send email to an Implicit TLS server."""
- h = self.hs.get_send_email_handler()
- d = ensureDeferred(
- h.send_email(
- "foo@bar.com", "test subject", "Tests", "HTML content", "Text content"
- )
- )
- # there should be an attempt to connect to localhost:465
- self.assertEqual(len(self.reactor.tcpClients), 1)
- (
- host,
- port,
- client_factory,
- _timeout,
- _bindAddress,
- ) = self.reactor.tcpClients[0]
- self.assertEqual(host, self.reactor.lookups["localhost"])
- self.assertEqual(port, 465)
-        # We need to make sure that TLS is happening
- self.assertIsInstance(
- client_factory._wrappedFactory._testingContextFactory,
- ClientTLSOptions,
- )
- # And since we use endpoints, they go through reactor.connectTCP
- # which works differently to connectSSL on the testing reactor
-
- # wire it up to an SMTP server
- message_delivery = _DummyMessageDelivery()
- server_protocol = smtp.ESMTP()
- server_protocol.delivery = message_delivery
- # make sure that the server uses the test reactor to set timeouts
- server_protocol.callLater = self.reactor.callLater # type: ignore[assignment]
-
- client_protocol = client_factory.buildProtocol(None)
- client_protocol.makeConnection(FakeTransport(server_protocol, self.reactor))
- server_protocol.makeConnection(
- FakeTransport(
- client_protocol,
- self.reactor,
- peer_address=self.ip_class(
- "TCP", self.reactor.lookups["localhost"], 1234
- ),
- )
- )
-
- # the message should now get delivered
- self.get_success(d, by=0.1)
-
- # check it arrived
- self.assertEqual(len(message_delivery.messages), 1)
- user, msg = message_delivery.messages.pop()
- self.assertEqual(str(user), "foo@bar.com")
- self.assertIn(b"Subject: test subject", msg)
-
-
-class SendEmailHandlerTestCaseIPv6(SendEmailHandlerTestCaseIPv4):
- ip_class = IPv6Address
-
- def setUp(self) -> None:
- super().setUp()
- self.reactor.lookups["localhost"] = "::1"
diff --git a/tests/handlers/test_sliding_sync.py b/tests/handlers/test_sliding_sync.py
index 96da47f3b9..7144c58217 100644
--- a/tests/handlers/test_sliding_sync.py
+++ b/tests/handlers/test_sliding_sync.py
@@ -18,39 +18,40 @@
#
#
import logging
-from copy import deepcopy
-from typing import Dict, List, Optional
+from typing import AbstractSet, Dict, Mapping, Optional, Set, Tuple
from unittest.mock import patch
-from parameterized import parameterized
+import attr
+from parameterized import parameterized, parameterized_class
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import (
- AccountDataTypes,
- EventContentFields,
EventTypes,
JoinRules,
Membership,
- RoomTypes,
)
from synapse.api.room_versions import RoomVersions
-from synapse.events import StrippedStateEvent, make_event_from_dict
-from synapse.events.snapshot import EventContext
from synapse.handlers.sliding_sync import (
+ MAX_NUMBER_PREVIOUS_STATE_KEYS_TO_REMEMBER,
+ RoomsForUserType,
RoomSyncConfig,
StateValues,
- _RoomMembershipForUser,
+ _required_state_changes,
)
from synapse.rest import admin
from synapse.rest.client import knock, login, room
from synapse.server import HomeServer
from synapse.storage.util.id_generators import MultiWriterIdGenerator
-from synapse.types import JsonDict, StreamToken, UserID
-from synapse.types.handlers import SlidingSyncConfig
+from synapse.types import JsonDict, StateMap, StreamToken, UserID, create_requester
+from synapse.types.handlers.sliding_sync import PerConnectionState, SlidingSyncConfig
+from synapse.types.state import StateFilter
from synapse.util import Clock
+from tests import unittest
from tests.replication._base import BaseMultiWorkerStreamTestCase
+from tests.rest.client.sliding_sync.test_sliding_sync import SlidingSyncBase
+from tests.test_utils.event_injection import create_event
from tests.unittest import HomeserverTestCase, TestCase
logger = logging.getLogger(__name__)
@@ -566,31 +567,39 @@ class RoomSyncConfigTestCase(TestCase):
"""
Combine A into B and B into A to make sure we get the same result.
"""
- # Since we're mutating these in place, make a copy for each of our trials
- room_sync_config_a = deepcopy(a)
- room_sync_config_b = deepcopy(b)
-
- # Combine B into A
- room_sync_config_a.combine_room_sync_config(room_sync_config_b)
-
- self._assert_room_config_equal(room_sync_config_a, expected, "B into A")
-
- # Since we're mutating these in place, make a copy for each of our trials
- room_sync_config_a = deepcopy(a)
- room_sync_config_b = deepcopy(b)
-
- # Combine A into B
- room_sync_config_b.combine_room_sync_config(room_sync_config_a)
-
- self._assert_room_config_equal(room_sync_config_b, expected, "A into B")
-
-
-class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
+ combined_config = a.combine_room_sync_config(b)
+ self._assert_room_config_equal(combined_config, expected, "B into A")
+
+        combined_config = b.combine_room_sync_config(a)
+ self._assert_room_config_equal(combined_config, expected, "A into B")
+
+
+# FIXME: This can be removed once we bump `SCHEMA_COMPAT_VERSION` and run the
+# foreground update for
+# `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` (tracked by
+# https://github.com/element-hq/synapse/issues/17623)
+@parameterized_class(
+ ("use_new_tables",),
+ [
+ (True,),
+ (False,),
+ ],
+ class_name_func=lambda cls,
+ num,
+ params_dict: f"{cls.__name__}_{'new' if params_dict['use_new_tables'] else 'fallback'}",
+)
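+# Each test in this class runs twice: once with `use_new_tables=True` (the new
+# sliding sync tables) and once with the fallback path; `class_name_func` suffixes
+# the generated class names with `_new` or `_fallback` accordingly.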
+class ComputeInterestedRoomsTestCase(SlidingSyncBase):
"""
- Tests Sliding Sync handler `get_room_membership_for_user_at_to_token()` to make sure it returns
+ Tests Sliding Sync handler `compute_interested_rooms()` to make sure it returns
     the correct list of room IDs.
"""
+ # FIXME: We should refactor these tests to run against `compute_interested_rooms(...)`
+ # instead of just `get_room_membership_for_user_at_to_token(...)` which is only used
+    # in the fallback path (`_compute_interested_rooms_fallback(...)`). These scenarios
+    # do a good job of stressing that logic, and we shouldn't remove them just because
+    # we're removing the fallback path (tracked by
+    # https://github.com/element-hq/synapse/issues/17623).
+
servlets = [
admin.register_servlets,
knock.register_servlets,
@@ -609,6 +618,11 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
self.store = self.hs.get_datastores().main
self.event_sources = hs.get_event_sources()
self.storage_controllers = hs.get_storage_controllers()
+ persistence = self.hs.get_storage_controllers().persistence
+ assert persistence is not None
+ self.persistence = persistence
+
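+        # Defer to SlidingSyncBase.prepare() for the rest of the shared setup.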
+ super().prepare(reactor, clock, hs)
def test_no_rooms(self) -> None:
"""
@@ -619,15 +633,28 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
now_token = self.event_sources.get_current_token()
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
+ interested_rooms = self.get_success(
+ self.sliding_sync_handler.room_lists.compute_interested_rooms(
+ SlidingSyncConfig(
+ user=UserID.from_string(user1_id),
+ requester=create_requester(user_id=user1_id),
+ lists={
+ "foo-list": SlidingSyncConfig.SlidingSyncList(
+ ranges=[(0, 99)],
+ required_state=[],
+ timeline_limit=1,
+ )
+ },
+ conn_id=None,
+ ),
+ PerConnectionState(),
from_token=now_token,
to_token=now_token,
)
)
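+        # The requested list ("foo-list") comes back as a single op here; flatten
+        # its `room_ids` into a set for comparison.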
+ room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids)
- self.assertEqual(room_id_results.keys(), set())
+ self.assertIncludes(room_id_results, set(), exact=True)
def test_get_newly_joined_room(self) -> None:
"""
@@ -646,26 +673,48 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
after_room_token = self.event_sources.get_current_token()
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
+ interested_rooms = self.get_success(
+ self.sliding_sync_handler.room_lists.compute_interested_rooms(
+ SlidingSyncConfig(
+ user=UserID.from_string(user1_id),
+ requester=create_requester(user_id=user1_id),
+ lists={
+ "foo-list": SlidingSyncConfig.SlidingSyncList(
+ ranges=[(0, 99)],
+ required_state=[],
+ timeline_limit=1,
+ )
+ },
+ conn_id=None,
+ ),
+ PerConnectionState(),
from_token=before_room_token,
to_token=after_room_token,
)
)
+ room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids)
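+        # `newly_joined`/`newly_left` are now returned as sets of room IDs rather
+        # than per-room boolean flags.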
+ newly_joined = interested_rooms.newly_joined_rooms
+ newly_left = interested_rooms.newly_left_rooms
- self.assertEqual(room_id_results.keys(), {room_id})
+ self.assertIncludes(
+ room_id_results,
+ {room_id},
+ exact=True,
+ )
# It should be pointing to the join event (latest membership event in the
# from/to range)
self.assertEqual(
- room_id_results[room_id].event_id,
+ interested_rooms.room_membership_for_user_map[room_id].event_id,
join_response["event_id"],
)
- self.assertEqual(room_id_results[room_id].membership, Membership.JOIN)
+ self.assertEqual(
+ interested_rooms.room_membership_for_user_map[room_id].membership,
+ Membership.JOIN,
+ )
# We should be considered `newly_joined` because we joined during the token
# range
- self.assertEqual(room_id_results[room_id].newly_joined, True)
- self.assertEqual(room_id_results[room_id].newly_left, False)
+ self.assertTrue(room_id in newly_joined)
+ self.assertTrue(room_id not in newly_left)
def test_get_already_joined_room(self) -> None:
"""
@@ -681,25 +730,43 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
after_room_token = self.event_sources.get_current_token()
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
+ interested_rooms = self.get_success(
+ self.sliding_sync_handler.room_lists.compute_interested_rooms(
+ SlidingSyncConfig(
+ user=UserID.from_string(user1_id),
+ requester=create_requester(user_id=user1_id),
+ lists={
+ "foo-list": SlidingSyncConfig.SlidingSyncList(
+ ranges=[(0, 99)],
+ required_state=[],
+ timeline_limit=1,
+ )
+ },
+ conn_id=None,
+ ),
+ PerConnectionState(),
from_token=after_room_token,
to_token=after_room_token,
)
)
+ room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids)
+ newly_joined = interested_rooms.newly_joined_rooms
+ newly_left = interested_rooms.newly_left_rooms
- self.assertEqual(room_id_results.keys(), {room_id})
+ self.assertIncludes(room_id_results, {room_id}, exact=True)
# It should be pointing to the join event (latest membership event in the
# from/to range)
self.assertEqual(
- room_id_results[room_id].event_id,
+ interested_rooms.room_membership_for_user_map[room_id].event_id,
join_response["event_id"],
)
- self.assertEqual(room_id_results[room_id].membership, Membership.JOIN)
+ self.assertEqual(
+ interested_rooms.room_membership_for_user_map[room_id].membership,
+ Membership.JOIN,
+ )
# We should *NOT* be `newly_joined` because we joined before the token range
- self.assertEqual(room_id_results[room_id].newly_joined, False)
- self.assertEqual(room_id_results[room_id].newly_left, False)
+ self.assertTrue(room_id not in newly_joined)
+ self.assertTrue(room_id not in newly_left)
def test_get_invited_banned_knocked_room(self) -> None:
"""
@@ -755,48 +822,73 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
after_room_token = self.event_sources.get_current_token()
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
+ interested_rooms = self.get_success(
+ self.sliding_sync_handler.room_lists.compute_interested_rooms(
+ SlidingSyncConfig(
+ user=UserID.from_string(user1_id),
+ requester=create_requester(user_id=user1_id),
+ lists={
+ "foo-list": SlidingSyncConfig.SlidingSyncList(
+ ranges=[(0, 99)],
+ required_state=[],
+ timeline_limit=1,
+ )
+ },
+ conn_id=None,
+ ),
+ PerConnectionState(),
from_token=before_room_token,
to_token=after_room_token,
)
)
+ room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids)
+ newly_joined = interested_rooms.newly_joined_rooms
+ newly_left = interested_rooms.newly_left_rooms
         # Ensure that the invited, banned, and knocked rooms show up
- self.assertEqual(
- room_id_results.keys(),
+ self.assertIncludes(
+ room_id_results,
{
invited_room_id,
ban_room_id,
knock_room_id,
},
+ exact=True,
)
         # It should be pointing to the respective membership event (latest
# membership event in the from/to range)
self.assertEqual(
- room_id_results[invited_room_id].event_id,
+ interested_rooms.room_membership_for_user_map[invited_room_id].event_id,
invite_response["event_id"],
)
- self.assertEqual(room_id_results[invited_room_id].membership, Membership.INVITE)
- self.assertEqual(room_id_results[invited_room_id].newly_joined, False)
- self.assertEqual(room_id_results[invited_room_id].newly_left, False)
+ self.assertEqual(
+ interested_rooms.room_membership_for_user_map[invited_room_id].membership,
+ Membership.INVITE,
+ )
+ self.assertTrue(invited_room_id not in newly_joined)
+ self.assertTrue(invited_room_id not in newly_left)
self.assertEqual(
- room_id_results[ban_room_id].event_id,
+ interested_rooms.room_membership_for_user_map[ban_room_id].event_id,
ban_response["event_id"],
)
- self.assertEqual(room_id_results[ban_room_id].membership, Membership.BAN)
- self.assertEqual(room_id_results[ban_room_id].newly_joined, False)
- self.assertEqual(room_id_results[ban_room_id].newly_left, False)
+ self.assertEqual(
+ interested_rooms.room_membership_for_user_map[ban_room_id].membership,
+ Membership.BAN,
+ )
+ self.assertTrue(ban_room_id not in newly_joined)
+ self.assertTrue(ban_room_id not in newly_left)
self.assertEqual(
- room_id_results[knock_room_id].event_id,
+ interested_rooms.room_membership_for_user_map[knock_room_id].event_id,
knock_room_membership_state_event.event_id,
)
- self.assertEqual(room_id_results[knock_room_id].membership, Membership.KNOCK)
- self.assertEqual(room_id_results[knock_room_id].newly_joined, False)
- self.assertEqual(room_id_results[knock_room_id].newly_left, False)
+ self.assertEqual(
+ interested_rooms.room_membership_for_user_map[knock_room_id].membership,
+ Membership.KNOCK,
+ )
+ self.assertTrue(knock_room_id not in newly_joined)
+ self.assertTrue(knock_room_id not in newly_left)
def test_get_kicked_room(self) -> None:
"""
@@ -827,27 +919,47 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
after_kick_token = self.event_sources.get_current_token()
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
+ interested_rooms = self.get_success(
+ self.sliding_sync_handler.room_lists.compute_interested_rooms(
+ SlidingSyncConfig(
+ user=UserID.from_string(user1_id),
+ requester=create_requester(user_id=user1_id),
+ lists={
+ "foo-list": SlidingSyncConfig.SlidingSyncList(
+ ranges=[(0, 99)],
+ required_state=[],
+ timeline_limit=1,
+ )
+ },
+ conn_id=None,
+ ),
+ PerConnectionState(),
from_token=after_kick_token,
to_token=after_kick_token,
)
)
+ room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids)
+ newly_joined = interested_rooms.newly_joined_rooms
+ newly_left = interested_rooms.newly_left_rooms
# The kicked room should show up
- self.assertEqual(room_id_results.keys(), {kick_room_id})
+ self.assertIncludes(room_id_results, {kick_room_id}, exact=True)
# It should be pointing to the latest membership event in the from/to range
self.assertEqual(
- room_id_results[kick_room_id].event_id,
+ interested_rooms.room_membership_for_user_map[kick_room_id].event_id,
kick_response["event_id"],
)
- self.assertEqual(room_id_results[kick_room_id].membership, Membership.LEAVE)
- self.assertNotEqual(room_id_results[kick_room_id].sender, user1_id)
+ self.assertEqual(
+ interested_rooms.room_membership_for_user_map[kick_room_id].membership,
+ Membership.LEAVE,
+ )
+ self.assertNotEqual(
+ interested_rooms.room_membership_for_user_map[kick_room_id].sender, user1_id
+ )
         # We should *NOT* be `newly_joined` because we were not joined at the time
# of the `to_token`.
- self.assertEqual(room_id_results[kick_room_id].newly_joined, False)
- self.assertEqual(room_id_results[kick_room_id].newly_left, False)
+ self.assertTrue(kick_room_id not in newly_joined)
+ self.assertTrue(kick_room_id not in newly_left)
def test_forgotten_rooms(self) -> None:
"""
@@ -920,16 +1032,29 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
)
self.assertEqual(channel.code, 200, channel.result)
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
+ interested_rooms = self.get_success(
+ self.sliding_sync_handler.room_lists.compute_interested_rooms(
+ SlidingSyncConfig(
+ user=UserID.from_string(user1_id),
+ requester=create_requester(user_id=user1_id),
+ lists={
+ "foo-list": SlidingSyncConfig.SlidingSyncList(
+ ranges=[(0, 99)],
+ required_state=[],
+ timeline_limit=1,
+ )
+ },
+ conn_id=None,
+ ),
+ PerConnectionState(),
from_token=before_room_forgets,
to_token=before_room_forgets,
)
)
+ room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids)
# We shouldn't see the room because it was forgotten
- self.assertEqual(room_id_results.keys(), set())
+ self.assertIncludes(room_id_results, set(), exact=True)
def test_newly_left_rooms(self) -> None:
"""
@@ -940,7 +1065,7 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
# Leave before we calculate the `from_token`
room_id1 = self.helper.create_room_as(user1_id, tok=user1_tok)
- leave_response1 = self.helper.leave(room_id1, user1_id, tok=user1_tok)
+ _leave_response1 = self.helper.leave(room_id1, user1_id, tok=user1_tok)
after_room1_token = self.event_sources.get_current_token()
@@ -950,34 +1075,55 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
after_room2_token = self.event_sources.get_current_token()
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
+ interested_rooms = self.get_success(
+ self.sliding_sync_handler.room_lists.compute_interested_rooms(
+ SlidingSyncConfig(
+ user=UserID.from_string(user1_id),
+ requester=create_requester(user_id=user1_id),
+ lists={
+ "foo-list": SlidingSyncConfig.SlidingSyncList(
+ ranges=[(0, 99)],
+ required_state=[],
+ timeline_limit=1,
+ )
+ },
+ conn_id=None,
+ ),
+ PerConnectionState(),
from_token=after_room1_token,
to_token=after_room2_token,
)
)
+ room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids)
+ newly_joined = interested_rooms.newly_joined_rooms
+ newly_left = interested_rooms.newly_left_rooms
- self.assertEqual(room_id_results.keys(), {room_id1, room_id2})
-
- self.assertEqual(
- room_id_results[room_id1].event_id,
- leave_response1["event_id"],
+ # `room_id1` should not show up because it was left before the token range.
+ # `room_id2` should show up because it is `newly_left` within the token range.
+ self.assertIncludes(
+ room_id_results,
+ {room_id2},
+ exact=True,
+ message="Corresponding map to disambiguate the opaque room IDs: "
+ + str(
+ {
+ "room_id1": room_id1,
+ "room_id2": room_id2,
+ }
+ ),
)
- self.assertEqual(room_id_results[room_id1].membership, Membership.LEAVE)
- # We should *NOT* be `newly_joined` or `newly_left` because that happened before
- # the from/to range
- self.assertEqual(room_id_results[room_id1].newly_joined, False)
- self.assertEqual(room_id_results[room_id1].newly_left, False)
self.assertEqual(
- room_id_results[room_id2].event_id,
+ interested_rooms.room_membership_for_user_map[room_id2].event_id,
leave_response2["event_id"],
)
- self.assertEqual(room_id_results[room_id2].membership, Membership.LEAVE)
+ self.assertEqual(
+ interested_rooms.room_membership_for_user_map[room_id2].membership,
+ Membership.LEAVE,
+ )
# We should *NOT* be `newly_joined` because we are instead `newly_left`
- self.assertEqual(room_id_results[room_id2].newly_joined, False)
- self.assertEqual(room_id_results[room_id2].newly_left, True)
+ self.assertTrue(room_id2 not in newly_joined)
+ self.assertTrue(room_id2 in newly_left)
def test_no_joins_after_to_token(self) -> None:
"""
@@ -1000,24 +1146,42 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
room_id2 = self.helper.create_room_as(user2_id, tok=user2_tok)
self.helper.join(room_id2, user1_id, tok=user1_tok)
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
+ interested_rooms = self.get_success(
+ self.sliding_sync_handler.room_lists.compute_interested_rooms(
+ SlidingSyncConfig(
+ user=UserID.from_string(user1_id),
+ requester=create_requester(user_id=user1_id),
+ lists={
+ "foo-list": SlidingSyncConfig.SlidingSyncList(
+ ranges=[(0, 99)],
+ required_state=[],
+ timeline_limit=1,
+ )
+ },
+ conn_id=None,
+ ),
+ PerConnectionState(),
from_token=before_room1_token,
to_token=after_room1_token,
)
)
+ room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids)
+ newly_joined = interested_rooms.newly_joined_rooms
+ newly_left = interested_rooms.newly_left_rooms
- self.assertEqual(room_id_results.keys(), {room_id1})
+ self.assertIncludes(room_id_results, {room_id1}, exact=True)
# It should be pointing to the latest membership event in the from/to range
self.assertEqual(
- room_id_results[room_id1].event_id,
+ interested_rooms.room_membership_for_user_map[room_id1].event_id,
join_response1["event_id"],
)
- self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN)
+ self.assertEqual(
+ interested_rooms.room_membership_for_user_map[room_id1].membership,
+ Membership.JOIN,
+ )
# We should be `newly_joined` because we joined during the token range
- self.assertEqual(room_id_results[room_id1].newly_joined, True)
- self.assertEqual(room_id_results[room_id1].newly_left, False)
+ self.assertTrue(room_id1 in newly_joined)
+ self.assertTrue(room_id1 not in newly_left)
def test_join_during_range_and_left_room_after_to_token(self) -> None:
"""
@@ -1040,20 +1204,35 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
# Leave the room after we already have our tokens
leave_response = self.helper.leave(room_id1, user1_id, tok=user1_tok)
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
+ interested_rooms = self.get_success(
+ self.sliding_sync_handler.room_lists.compute_interested_rooms(
+ SlidingSyncConfig(
+ user=UserID.from_string(user1_id),
+ requester=create_requester(user_id=user1_id),
+ lists={
+ "foo-list": SlidingSyncConfig.SlidingSyncList(
+ ranges=[(0, 99)],
+ required_state=[],
+ timeline_limit=1,
+ )
+ },
+ conn_id=None,
+ ),
+ PerConnectionState(),
from_token=before_room1_token,
to_token=after_room1_token,
)
)
+ room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids)
+ newly_joined = interested_rooms.newly_joined_rooms
+ newly_left = interested_rooms.newly_left_rooms
# We should still see the room because we were joined during the
# from_token/to_token time period.
- self.assertEqual(room_id_results.keys(), {room_id1})
+ self.assertIncludes(room_id_results, {room_id1}, exact=True)
# It should be pointing to the latest membership event in the from/to range
self.assertEqual(
- room_id_results[room_id1].event_id,
+ interested_rooms.room_membership_for_user_map[room_id1].event_id,
join_response["event_id"],
"Corresponding map to disambiguate the opaque event IDs: "
+ str(
@@ -1063,10 +1242,13 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
}
),
)
- self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN)
+ self.assertEqual(
+ interested_rooms.room_membership_for_user_map[room_id1].membership,
+ Membership.JOIN,
+ )
# We should be `newly_joined` because we joined during the token range
- self.assertEqual(room_id_results[room_id1].newly_joined, True)
- self.assertEqual(room_id_results[room_id1].newly_left, False)
+ self.assertTrue(room_id1 in newly_joined)
+ self.assertTrue(room_id1 not in newly_left)
def test_join_before_range_and_left_room_after_to_token(self) -> None:
"""
@@ -1087,19 +1269,34 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
# Leave the room after we already have our tokens
leave_response = self.helper.leave(room_id1, user1_id, tok=user1_tok)
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
+ interested_rooms = self.get_success(
+ self.sliding_sync_handler.room_lists.compute_interested_rooms(
+ SlidingSyncConfig(
+ user=UserID.from_string(user1_id),
+ requester=create_requester(user_id=user1_id),
+ lists={
+ "foo-list": SlidingSyncConfig.SlidingSyncList(
+ ranges=[(0, 99)],
+ required_state=[],
+ timeline_limit=1,
+ )
+ },
+ conn_id=None,
+ ),
+ PerConnectionState(),
from_token=after_room1_token,
to_token=after_room1_token,
)
)
+ room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids)
+ newly_joined = interested_rooms.newly_joined_rooms
+ newly_left = interested_rooms.newly_left_rooms
# We should still see the room because we were joined before the `from_token`
- self.assertEqual(room_id_results.keys(), {room_id1})
+ self.assertIncludes(room_id_results, {room_id1}, exact=True)
# It should be pointing to the latest membership event in the from/to range
self.assertEqual(
- room_id_results[room_id1].event_id,
+ interested_rooms.room_membership_for_user_map[room_id1].event_id,
join_response["event_id"],
"Corresponding map to disambiguate the opaque event IDs: "
+ str(
@@ -1109,10 +1306,13 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
}
),
)
- self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN)
+ self.assertEqual(
+ interested_rooms.room_membership_for_user_map[room_id1].membership,
+ Membership.JOIN,
+ )
# We should *NOT* be `newly_joined` because we joined before the token range
- self.assertEqual(room_id_results[room_id1].newly_joined, False)
- self.assertEqual(room_id_results[room_id1].newly_left, False)
+ self.assertTrue(room_id1 not in newly_joined)
+ self.assertTrue(room_id1 not in newly_left)
def test_kicked_before_range_and_left_after_to_token(self) -> None:
"""
@@ -1151,19 +1351,34 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
join_response2 = self.helper.join(kick_room_id, user1_id, tok=user1_tok)
leave_response = self.helper.leave(kick_room_id, user1_id, tok=user1_tok)
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
+ interested_rooms = self.get_success(
+ self.sliding_sync_handler.room_lists.compute_interested_rooms(
+ SlidingSyncConfig(
+ user=UserID.from_string(user1_id),
+ requester=create_requester(user_id=user1_id),
+ lists={
+ "foo-list": SlidingSyncConfig.SlidingSyncList(
+ ranges=[(0, 99)],
+ required_state=[],
+ timeline_limit=1,
+ )
+ },
+ conn_id=None,
+ ),
+ PerConnectionState(),
from_token=after_kick_token,
to_token=after_kick_token,
)
)
+ room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids)
+ newly_joined = interested_rooms.newly_joined_rooms
+ newly_left = interested_rooms.newly_left_rooms
         # The kicked room should show up
- self.assertEqual(room_id_results.keys(), {kick_room_id})
+ self.assertIncludes(room_id_results, {kick_room_id}, exact=True)
# It should be pointing to the latest membership event in the from/to range
self.assertEqual(
- room_id_results[kick_room_id].event_id,
+ interested_rooms.room_membership_for_user_map[kick_room_id].event_id,
kick_response["event_id"],
"Corresponding map to disambiguate the opaque event IDs: "
+ str(
@@ -1175,11 +1390,16 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
}
),
)
- self.assertEqual(room_id_results[kick_room_id].membership, Membership.LEAVE)
- self.assertNotEqual(room_id_results[kick_room_id].sender, user1_id)
+ self.assertEqual(
+ interested_rooms.room_membership_for_user_map[kick_room_id].membership,
+ Membership.LEAVE,
+ )
+ self.assertNotEqual(
+ interested_rooms.room_membership_for_user_map[kick_room_id].sender, user1_id
+ )
# We should *NOT* be `newly_joined` because we were kicked
- self.assertEqual(room_id_results[kick_room_id].newly_joined, False)
- self.assertEqual(room_id_results[kick_room_id].newly_left, False)
+ self.assertTrue(kick_room_id not in newly_joined)
+ self.assertTrue(kick_room_id not in newly_left)
def test_newly_left_during_range_and_join_leave_after_to_token(self) -> None:
"""
@@ -1207,19 +1427,34 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
join_response2 = self.helper.join(room_id1, user1_id, tok=user1_tok)
leave_response2 = self.helper.leave(room_id1, user1_id, tok=user1_tok)
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
+ interested_rooms = self.get_success(
+ self.sliding_sync_handler.room_lists.compute_interested_rooms(
+ SlidingSyncConfig(
+ user=UserID.from_string(user1_id),
+ requester=create_requester(user_id=user1_id),
+ lists={
+ "foo-list": SlidingSyncConfig.SlidingSyncList(
+ ranges=[(0, 99)],
+ required_state=[],
+ timeline_limit=1,
+ )
+ },
+ conn_id=None,
+ ),
+ PerConnectionState(),
from_token=before_room1_token,
to_token=after_room1_token,
)
)
+ room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids)
+ newly_joined = interested_rooms.newly_joined_rooms
+ newly_left = interested_rooms.newly_left_rooms
# Room should still show up because it's newly_left during the from/to range
- self.assertEqual(room_id_results.keys(), {room_id1})
+ self.assertIncludes(room_id_results, {room_id1}, exact=True)
# It should be pointing to the latest membership event in the from/to range
self.assertEqual(
- room_id_results[room_id1].event_id,
+ interested_rooms.room_membership_for_user_map[room_id1].event_id,
leave_response1["event_id"],
"Corresponding map to disambiguate the opaque event IDs: "
+ str(
@@ -1231,11 +1466,14 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
}
),
)
- self.assertEqual(room_id_results[room_id1].membership, Membership.LEAVE)
+ self.assertEqual(
+ interested_rooms.room_membership_for_user_map[room_id1].membership,
+ Membership.LEAVE,
+ )
# We should *NOT* be `newly_joined` because we are actually `newly_left` during
# the token range
- self.assertEqual(room_id_results[room_id1].newly_joined, False)
- self.assertEqual(room_id_results[room_id1].newly_left, True)
+ self.assertTrue(room_id1 not in newly_joined)
+ self.assertTrue(room_id1 in newly_left)
def test_newly_left_during_range_and_join_after_to_token(self) -> None:
"""
@@ -1262,19 +1500,34 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
# Join the room after we already have our tokens
join_response2 = self.helper.join(room_id1, user1_id, tok=user1_tok)
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
+ interested_rooms = self.get_success(
+ self.sliding_sync_handler.room_lists.compute_interested_rooms(
+ SlidingSyncConfig(
+ user=UserID.from_string(user1_id),
+ requester=create_requester(user_id=user1_id),
+ lists={
+ "foo-list": SlidingSyncConfig.SlidingSyncList(
+ ranges=[(0, 99)],
+ required_state=[],
+ timeline_limit=1,
+ )
+ },
+ conn_id=None,
+ ),
+ PerConnectionState(),
from_token=before_room1_token,
to_token=after_room1_token,
)
)
+ room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids)
+ newly_joined = interested_rooms.newly_joined_rooms
+ newly_left = interested_rooms.newly_left_rooms
# Room should still show up because it's newly_left during the from/to range
- self.assertEqual(room_id_results.keys(), {room_id1})
+ self.assertIncludes(room_id_results, {room_id1}, exact=True)
# It should be pointing to the latest membership event in the from/to range
self.assertEqual(
- room_id_results[room_id1].event_id,
+ interested_rooms.room_membership_for_user_map[room_id1].event_id,
leave_response1["event_id"],
"Corresponding map to disambiguate the opaque event IDs: "
+ str(
@@ -1285,11 +1538,14 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
}
),
)
- self.assertEqual(room_id_results[room_id1].membership, Membership.LEAVE)
+ self.assertEqual(
+ interested_rooms.room_membership_for_user_map[room_id1].membership,
+ Membership.LEAVE,
+ )
# We should *NOT* be `newly_joined` because we are actually `newly_left` during
# the token range
- self.assertEqual(room_id_results[room_id1].newly_joined, False)
- self.assertEqual(room_id_results[room_id1].newly_left, True)
+ self.assertTrue(room_id1 not in newly_joined)
+ self.assertTrue(room_id1 in newly_left)
def test_no_from_token(self) -> None:
"""
@@ -1314,47 +1570,53 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
         # Join and leave room2 before the `to_token`
self.helper.join(room_id2, user1_id, tok=user1_tok)
- leave_response2 = self.helper.leave(room_id2, user1_id, tok=user1_tok)
+ _leave_response2 = self.helper.leave(room_id2, user1_id, tok=user1_tok)
after_room1_token = self.event_sources.get_current_token()
         # Join room2 again after we already have our tokens
self.helper.join(room_id2, user1_id, tok=user1_tok)
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
+ interested_rooms = self.get_success(
+ self.sliding_sync_handler.room_lists.compute_interested_rooms(
+ SlidingSyncConfig(
+ user=UserID.from_string(user1_id),
+ requester=create_requester(user_id=user1_id),
+ lists={
+ "foo-list": SlidingSyncConfig.SlidingSyncList(
+ ranges=[(0, 99)],
+ required_state=[],
+ timeline_limit=1,
+ )
+ },
+ conn_id=None,
+ ),
+ PerConnectionState(),
from_token=None,
to_token=after_room1_token,
)
)
+ room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids)
+ newly_joined = interested_rooms.newly_joined_rooms
+ newly_left = interested_rooms.newly_left_rooms
         # Only rooms we were still joined to at the `to_token` should show up
- self.assertEqual(room_id_results.keys(), {room_id1, room_id2})
+ self.assertIncludes(room_id_results, {room_id1}, exact=True)
# Room1
# It should be pointing to the latest membership event in the from/to range
self.assertEqual(
- room_id_results[room_id1].event_id,
+ interested_rooms.room_membership_for_user_map[room_id1].event_id,
join_response1["event_id"],
)
- self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN)
- # We should *NOT* be `newly_joined`/`newly_left` because there is no
- # `from_token` to define a "live" range to compare against
- self.assertEqual(room_id_results[room_id1].newly_joined, False)
- self.assertEqual(room_id_results[room_id1].newly_left, False)
-
- # Room2
- # It should be pointing to the latest membership event in the from/to range
self.assertEqual(
- room_id_results[room_id2].event_id,
- leave_response2["event_id"],
+ interested_rooms.room_membership_for_user_map[room_id1].membership,
+ Membership.JOIN,
)
- self.assertEqual(room_id_results[room_id2].membership, Membership.LEAVE)
# We should *NOT* be `newly_joined`/`newly_left` because there is no
# `from_token` to define a "live" range to compare against
- self.assertEqual(room_id_results[room_id2].newly_joined, False)
- self.assertEqual(room_id_results[room_id2].newly_left, False)
+ self.assertTrue(room_id1 not in newly_joined)
+ self.assertTrue(room_id1 not in newly_left)
def test_from_token_ahead_of_to_token(self) -> None:
"""
@@ -1378,7 +1640,7 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
         # Join and leave room2 before `to_token`
_join_room2_response1 = self.helper.join(room_id2, user1_id, tok=user1_tok)
- leave_room2_response1 = self.helper.leave(room_id2, user1_id, tok=user1_tok)
+ _leave_room2_response1 = self.helper.leave(room_id2, user1_id, tok=user1_tok)
# Note: These are purposely swapped. The `from_token` should come after
# the `to_token` in this test
@@ -1403,54 +1665,69 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
         # Join room4 after we already have our tokens
self.helper.join(room_id4, user1_id, tok=user1_tok)
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
+ interested_rooms = self.get_success(
+ self.sliding_sync_handler.room_lists.compute_interested_rooms(
+ SlidingSyncConfig(
+ user=UserID.from_string(user1_id),
+ requester=create_requester(user_id=user1_id),
+ lists={
+ "foo-list": SlidingSyncConfig.SlidingSyncList(
+ ranges=[(0, 99)],
+ required_state=[],
+ timeline_limit=1,
+ )
+ },
+ conn_id=None,
+ ),
+ PerConnectionState(),
from_token=from_token,
to_token=to_token,
)
)
+ room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids)
+ newly_joined = interested_rooms.newly_joined_rooms
+ newly_left = interested_rooms.newly_left_rooms
# In the "current" state snapshot, we're joined to all of the rooms but in the
# from/to token range...
self.assertIncludes(
- room_id_results.keys(),
+ room_id_results,
{
# Included because we were joined before both tokens
room_id1,
- # Included because we had membership before the to_token
- room_id2,
+ # Excluded because we left before the `from_token` and `to_token`
+ # room_id2,
# Excluded because we joined after the `to_token`
# room_id3,
# Excluded because we joined after the `to_token`
# room_id4,
},
exact=True,
+ message="Corresponding map to disambiguate the opaque room IDs: "
+ + str(
+ {
+ "room_id1": room_id1,
+ "room_id2": room_id2,
+ "room_id3": room_id3,
+ "room_id4": room_id4,
+ }
+ ),
)
# Room1
# It should be pointing to the latest membership event in the from/to range
self.assertEqual(
- room_id_results[room_id1].event_id,
+ interested_rooms.room_membership_for_user_map[room_id1].event_id,
join_room1_response1["event_id"],
)
- self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN)
- # We should *NOT* be `newly_joined`/`newly_left` because we joined `room1`
- # before either of the tokens
- self.assertEqual(room_id_results[room_id1].newly_joined, False)
- self.assertEqual(room_id_results[room_id1].newly_left, False)
-
- # Room2
- # It should be pointing to the latest membership event in the from/to range
self.assertEqual(
- room_id_results[room_id2].event_id,
- leave_room2_response1["event_id"],
+ interested_rooms.room_membership_for_user_map[room_id1].membership,
+ Membership.JOIN,
)
- self.assertEqual(room_id_results[room_id2].membership, Membership.LEAVE)
- # We should *NOT* be `newly_joined`/`newly_left` because we joined and left
- # `room1` before either of the tokens
- self.assertEqual(room_id_results[room_id2].newly_joined, False)
- self.assertEqual(room_id_results[room_id2].newly_left, False)
+ # We should *NOT* be `newly_joined`/`newly_left` because we joined `room1`
+ # before either of the tokens
+ self.assertTrue(room_id1 not in newly_joined)
+ self.assertTrue(room_id1 not in newly_left)
def test_leave_before_range_and_join_leave_after_to_token(self) -> None:
"""
@@ -1468,7 +1745,7 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True)
# Join and leave the room before the from/to range
self.helper.join(room_id1, user1_id, tok=user1_tok)
- leave_response = self.helper.leave(room_id1, user1_id, tok=user1_tok)
+ self.helper.leave(room_id1, user1_id, tok=user1_tok)
after_room1_token = self.event_sources.get_current_token()
@@ -1476,25 +1753,28 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
self.helper.join(room_id1, user1_id, tok=user1_tok)
self.helper.leave(room_id1, user1_id, tok=user1_tok)
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
+ interested_rooms = self.get_success(
+ self.sliding_sync_handler.room_lists.compute_interested_rooms(
+ SlidingSyncConfig(
+ user=UserID.from_string(user1_id),
+ requester=create_requester(user_id=user1_id),
+ lists={
+ "foo-list": SlidingSyncConfig.SlidingSyncList(
+ ranges=[(0, 99)],
+ required_state=[],
+ timeline_limit=1,
+ )
+ },
+ conn_id=None,
+ ),
+ PerConnectionState(),
from_token=after_room1_token,
to_token=after_room1_token,
)
)
+ room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids)
- self.assertEqual(room_id_results.keys(), {room_id1})
- # It should be pointing to the latest membership event in the from/to range
- self.assertEqual(
- room_id_results[room_id1].event_id,
- leave_response["event_id"],
- )
- self.assertEqual(room_id_results[room_id1].membership, Membership.LEAVE)
- # We should *NOT* be `newly_joined`/`newly_left` because we joined and left
- # `room1` before either of the tokens
- self.assertEqual(room_id_results[room_id1].newly_joined, False)
- self.assertEqual(room_id_results[room_id1].newly_left, False)
+ self.assertIncludes(room_id_results, set(), exact=True)
def test_leave_before_range_and_join_after_to_token(self) -> None:
"""
@@ -1512,32 +1792,35 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True)
# Join and leave the room before the from/to range
self.helper.join(room_id1, user1_id, tok=user1_tok)
- leave_response = self.helper.leave(room_id1, user1_id, tok=user1_tok)
+ self.helper.leave(room_id1, user1_id, tok=user1_tok)
after_room1_token = self.event_sources.get_current_token()
# Join the room after we already have our tokens
self.helper.join(room_id1, user1_id, tok=user1_tok)
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
+ interested_rooms = self.get_success(
+ self.sliding_sync_handler.room_lists.compute_interested_rooms(
+ SlidingSyncConfig(
+ user=UserID.from_string(user1_id),
+ requester=create_requester(user_id=user1_id),
+ lists={
+ "foo-list": SlidingSyncConfig.SlidingSyncList(
+ ranges=[(0, 99)],
+ required_state=[],
+ timeline_limit=1,
+ )
+ },
+ conn_id=None,
+ ),
+ PerConnectionState(),
from_token=after_room1_token,
to_token=after_room1_token,
)
)
+ room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids)
- self.assertEqual(room_id_results.keys(), {room_id1})
- # It should be pointing to the latest membership event in the from/to range
- self.assertEqual(
- room_id_results[room_id1].event_id,
- leave_response["event_id"],
- )
- self.assertEqual(room_id_results[room_id1].membership, Membership.LEAVE)
- # We should *NOT* be `newly_joined`/`newly_left` because we joined and left
- # `room1` before either of the tokens
- self.assertEqual(room_id_results[room_id1].newly_joined, False)
- self.assertEqual(room_id_results[room_id1].newly_left, False)
+ self.assertIncludes(room_id_results, set(), exact=True)
def test_join_leave_multiple_times_during_range_and_after_to_token(
self,
@@ -1569,19 +1852,34 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
join_response3 = self.helper.join(room_id1, user1_id, tok=user1_tok)
leave_response3 = self.helper.leave(room_id1, user1_id, tok=user1_tok)
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
+ interested_rooms = self.get_success(
+ self.sliding_sync_handler.room_lists.compute_interested_rooms(
+ SlidingSyncConfig(
+ user=UserID.from_string(user1_id),
+ requester=create_requester(user_id=user1_id),
+ lists={
+ "foo-list": SlidingSyncConfig.SlidingSyncList(
+ ranges=[(0, 99)],
+ required_state=[],
+ timeline_limit=1,
+ )
+ },
+ conn_id=None,
+ ),
+ PerConnectionState(),
from_token=before_room1_token,
to_token=after_room1_token,
)
)
+ room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids)
+ newly_joined = interested_rooms.newly_joined_rooms
+ newly_left = interested_rooms.newly_left_rooms
# Room should show up because it was newly_left and joined during the from/to range
- self.assertEqual(room_id_results.keys(), {room_id1})
+ self.assertIncludes(room_id_results, {room_id1}, exact=True)
# It should be pointing to the latest membership event in the from/to range
self.assertEqual(
- room_id_results[room_id1].event_id,
+ interested_rooms.room_membership_for_user_map[room_id1].event_id,
join_response2["event_id"],
"Corresponding map to disambiguate the opaque event IDs: "
+ str(
@@ -1595,12 +1893,15 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
}
),
)
- self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN)
+ self.assertEqual(
+ interested_rooms.room_membership_for_user_map[room_id1].membership,
+ Membership.JOIN,
+ )
# We should be `newly_joined` because we joined during the token range
- self.assertEqual(room_id_results[room_id1].newly_joined, True)
+ self.assertTrue(room_id1 in newly_joined)
# We should *NOT* be `newly_left` because we joined during the token range and
# was still joined at the end of the range
- self.assertEqual(room_id_results[room_id1].newly_left, False)
+ self.assertTrue(room_id1 not in newly_left)
def test_join_leave_multiple_times_before_range_and_after_to_token(
self,
@@ -1631,19 +1932,34 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
join_response3 = self.helper.join(room_id1, user1_id, tok=user1_tok)
leave_response3 = self.helper.leave(room_id1, user1_id, tok=user1_tok)
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
+ interested_rooms = self.get_success(
+ self.sliding_sync_handler.room_lists.compute_interested_rooms(
+ SlidingSyncConfig(
+ user=UserID.from_string(user1_id),
+ requester=create_requester(user_id=user1_id),
+ lists={
+ "foo-list": SlidingSyncConfig.SlidingSyncList(
+ ranges=[(0, 99)],
+ required_state=[],
+ timeline_limit=1,
+ )
+ },
+ conn_id=None,
+ ),
+ PerConnectionState(),
from_token=after_room1_token,
to_token=after_room1_token,
)
)
+ room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids)
+ newly_joined = interested_rooms.newly_joined_rooms
+ newly_left = interested_rooms.newly_left_rooms
# Room should show up because we were joined before the from/to range
- self.assertEqual(room_id_results.keys(), {room_id1})
+ self.assertIncludes(room_id_results, {room_id1}, exact=True)
# It should be pointing to the latest membership event in the from/to range
self.assertEqual(
- room_id_results[room_id1].event_id,
+ interested_rooms.room_membership_for_user_map[room_id1].event_id,
join_response2["event_id"],
"Corresponding map to disambiguate the opaque event IDs: "
+ str(
@@ -1657,10 +1973,13 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
}
),
)
- self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN)
+ self.assertEqual(
+ interested_rooms.room_membership_for_user_map[room_id1].membership,
+ Membership.JOIN,
+ )
# We should *NOT* be `newly_joined` because we joined before the token range
- self.assertEqual(room_id_results[room_id1].newly_joined, False)
- self.assertEqual(room_id_results[room_id1].newly_left, False)
+ self.assertTrue(room_id1 not in newly_joined)
+ self.assertTrue(room_id1 not in newly_left)
def test_invite_before_range_and_join_leave_after_to_token(
self,
@@ -1690,19 +2009,34 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
join_respsonse = self.helper.join(room_id1, user1_id, tok=user1_tok)
leave_response = self.helper.leave(room_id1, user1_id, tok=user1_tok)
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
+ interested_rooms = self.get_success(
+ self.sliding_sync_handler.room_lists.compute_interested_rooms(
+ SlidingSyncConfig(
+ user=UserID.from_string(user1_id),
+ requester=create_requester(user_id=user1_id),
+ lists={
+ "foo-list": SlidingSyncConfig.SlidingSyncList(
+ ranges=[(0, 99)],
+ required_state=[],
+ timeline_limit=1,
+ )
+ },
+ conn_id=None,
+ ),
+ PerConnectionState(),
from_token=after_room1_token,
to_token=after_room1_token,
)
)
+ room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids)
+ newly_joined = interested_rooms.newly_joined_rooms
+ newly_left = interested_rooms.newly_left_rooms
# Room should show up because we were invited before the from/to range
- self.assertEqual(room_id_results.keys(), {room_id1})
+ self.assertIncludes(room_id_results, {room_id1}, exact=True)
# It should be pointing to the latest membership event in the from/to range
self.assertEqual(
- room_id_results[room_id1].event_id,
+ interested_rooms.room_membership_for_user_map[room_id1].event_id,
invite_response["event_id"],
"Corresponding map to disambiguate the opaque event IDs: "
+ str(
@@ -1713,11 +2047,14 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
}
),
)
- self.assertEqual(room_id_results[room_id1].membership, Membership.INVITE)
+ self.assertEqual(
+ interested_rooms.room_membership_for_user_map[room_id1].membership,
+ Membership.INVITE,
+ )
# We should *NOT* be `newly_joined` because we were only invited before the
# token range
- self.assertEqual(room_id_results[room_id1].newly_joined, False)
- self.assertEqual(room_id_results[room_id1].newly_left, False)
+ self.assertTrue(room_id1 not in newly_joined)
+ self.assertTrue(room_id1 not in newly_left)
def test_join_and_display_name_changes_in_token_range(
self,
@@ -1764,19 +2101,34 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
tok=user1_tok,
)
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
+ interested_rooms = self.get_success(
+ self.sliding_sync_handler.room_lists.compute_interested_rooms(
+ SlidingSyncConfig(
+ user=UserID.from_string(user1_id),
+ requester=create_requester(user_id=user1_id),
+ lists={
+ "foo-list": SlidingSyncConfig.SlidingSyncList(
+ ranges=[(0, 99)],
+ required_state=[],
+ timeline_limit=1,
+ )
+ },
+ conn_id=None,
+ ),
+ PerConnectionState(),
from_token=before_room1_token,
to_token=after_room1_token,
)
)
+ room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids)
+ newly_joined = interested_rooms.newly_joined_rooms
+ newly_left = interested_rooms.newly_left_rooms
# Room should show up because we were joined during the from/to range
- self.assertEqual(room_id_results.keys(), {room_id1})
+ self.assertIncludes(room_id_results, {room_id1}, exact=True)
# It should be pointing to the latest membership event in the from/to range
self.assertEqual(
- room_id_results[room_id1].event_id,
+ interested_rooms.room_membership_for_user_map[room_id1].event_id,
displayname_change_during_token_range_response["event_id"],
"Corresponding map to disambiguate the opaque event IDs: "
+ str(
@@ -1791,10 +2143,13 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
}
),
)
- self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN)
+ self.assertEqual(
+ interested_rooms.room_membership_for_user_map[room_id1].membership,
+ Membership.JOIN,
+ )
# We should be `newly_joined` because we joined during the token range
- self.assertEqual(room_id_results[room_id1].newly_joined, True)
- self.assertEqual(room_id_results[room_id1].newly_left, False)
+ self.assertTrue(room_id1 in newly_joined)
+ self.assertTrue(room_id1 not in newly_left)
def test_display_name_changes_in_token_range(
self,
@@ -1829,19 +2184,34 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
after_change1_token = self.event_sources.get_current_token()
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
+ interested_rooms = self.get_success(
+ self.sliding_sync_handler.room_lists.compute_interested_rooms(
+ SlidingSyncConfig(
+ user=UserID.from_string(user1_id),
+ requester=create_requester(user_id=user1_id),
+ lists={
+ "foo-list": SlidingSyncConfig.SlidingSyncList(
+ ranges=[(0, 99)],
+ required_state=[],
+ timeline_limit=1,
+ )
+ },
+ conn_id=None,
+ ),
+ PerConnectionState(),
from_token=after_room1_token,
to_token=after_change1_token,
)
)
+ room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids)
+ newly_joined = interested_rooms.newly_joined_rooms
+ newly_left = interested_rooms.newly_left_rooms
# Room should show up because we were joined during the from/to range
- self.assertEqual(room_id_results.keys(), {room_id1})
+ self.assertIncludes(room_id_results, {room_id1}, exact=True)
# It should be pointing to the latest membership event in the from/to range
self.assertEqual(
- room_id_results[room_id1].event_id,
+ interested_rooms.room_membership_for_user_map[room_id1].event_id,
displayname_change_during_token_range_response["event_id"],
"Corresponding map to disambiguate the opaque event IDs: "
+ str(
@@ -1853,10 +2223,13 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
}
),
)
- self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN)
+ self.assertEqual(
+ interested_rooms.room_membership_for_user_map[room_id1].membership,
+ Membership.JOIN,
+ )
# We should *NOT* be `newly_joined` because we joined before the token range
- self.assertEqual(room_id_results[room_id1].newly_joined, False)
- self.assertEqual(room_id_results[room_id1].newly_left, False)
+ self.assertTrue(room_id1 not in newly_joined)
+ self.assertTrue(room_id1 not in newly_left)
def test_display_name_changes_before_and_after_token_range(
self,
@@ -1901,19 +2274,34 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
tok=user1_tok,
)
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
+ interested_rooms = self.get_success(
+ self.sliding_sync_handler.room_lists.compute_interested_rooms(
+ SlidingSyncConfig(
+ user=UserID.from_string(user1_id),
+ requester=create_requester(user_id=user1_id),
+ lists={
+ "foo-list": SlidingSyncConfig.SlidingSyncList(
+ ranges=[(0, 99)],
+ required_state=[],
+ timeline_limit=1,
+ )
+ },
+ conn_id=None,
+ ),
+ PerConnectionState(),
from_token=after_room1_token,
to_token=after_room1_token,
)
)
+ room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids)
+ newly_joined = interested_rooms.newly_joined_rooms
+ newly_left = interested_rooms.newly_left_rooms
# Room should show up because we were joined before the from/to range
- self.assertEqual(room_id_results.keys(), {room_id1})
+ self.assertIncludes(room_id_results, {room_id1}, exact=True)
# It should be pointing to the latest membership event in the from/to range
self.assertEqual(
- room_id_results[room_id1].event_id,
+ interested_rooms.room_membership_for_user_map[room_id1].event_id,
displayname_change_before_token_range_response["event_id"],
"Corresponding map to disambiguate the opaque event IDs: "
+ str(
@@ -1928,18 +2316,22 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
}
),
)
- self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN)
+ self.assertEqual(
+ interested_rooms.room_membership_for_user_map[room_id1].membership,
+ Membership.JOIN,
+ )
# We should *NOT* be `newly_joined` because we joined before the token range
- self.assertEqual(room_id_results[room_id1].newly_joined, False)
- self.assertEqual(room_id_results[room_id1].newly_left, False)
+ self.assertTrue(room_id1 not in newly_joined)
+ self.assertTrue(room_id1 not in newly_left)
- def test_display_name_changes_leave_after_token_range(
+ def test_newly_joined_display_name_changes_leave_after_token_range(
self,
) -> None:
"""
Test that we point to the correct membership event within the from/to range even
- if there are multiple `join` membership events in a row indicating
- `displayname`/`avatar_url` updates and we leave after the `to_token`.
+ if we are `newly_joined` and there are multiple `join` membership events in a
+ row indicating `displayname`/`avatar_url` updates and we leave after the
+ `to_token`.
See condition "1a)" comments in the `get_room_membership_for_user_at_to_token()` method.
"""
@@ -1954,6 +2346,7 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
# leave and can still re-join.
room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True)
join_response = self.helper.join(room_id1, user1_id, tok=user1_tok)
+
# Update the displayname during the token range
displayname_change_during_token_range_response = self.helper.send_state(
room_id1,
@@ -1983,19 +2376,34 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
# Leave after the token
self.helper.leave(room_id1, user1_id, tok=user1_tok)
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
+ interested_rooms = self.get_success(
+ self.sliding_sync_handler.room_lists.compute_interested_rooms(
+ SlidingSyncConfig(
+ user=UserID.from_string(user1_id),
+ requester=create_requester(user_id=user1_id),
+ lists={
+ "foo-list": SlidingSyncConfig.SlidingSyncList(
+ ranges=[(0, 99)],
+ required_state=[],
+ timeline_limit=1,
+ )
+ },
+ conn_id=None,
+ ),
+ PerConnectionState(),
from_token=before_room1_token,
to_token=after_room1_token,
)
)
+ room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids)
+ newly_joined = interested_rooms.newly_joined_rooms
+ newly_left = interested_rooms.newly_left_rooms
# Room should show up because we were joined during the from/to range
- self.assertEqual(room_id_results.keys(), {room_id1})
+ self.assertIncludes(room_id_results, {room_id1}, exact=True)
# It should be pointing to the latest membership event in the from/to range
self.assertEqual(
- room_id_results[room_id1].event_id,
+ interested_rooms.room_membership_for_user_map[room_id1].event_id,
displayname_change_during_token_range_response["event_id"],
"Corresponding map to disambiguate the opaque event IDs: "
+ str(
@@ -2010,10 +2418,117 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
}
),
)
- self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN)
+ self.assertEqual(
+ interested_rooms.room_membership_for_user_map[room_id1].membership,
+ Membership.JOIN,
+ )
# We should be `newly_joined` because we joined during the token range
- self.assertEqual(room_id_results[room_id1].newly_joined, True)
- self.assertEqual(room_id_results[room_id1].newly_left, False)
+ self.assertTrue(room_id1 in newly_joined)
+ self.assertTrue(room_id1 not in newly_left)
+
+ def test_display_name_changes_leave_after_token_range(
+ self,
+ ) -> None:
+ """
+ Test that we point to the correct membership event within the from/to range even
+ if there are multiple `join` membership events in a row indicating
+ `displayname`/`avatar_url` updates and we leave after the `to_token`.
+
+ See condition "1a)" comments in the `get_room_membership_for_user_at_to_token()` method.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ _before_room1_token = self.event_sources.get_current_token()
+
+ # We create the room with user2 so the room isn't left with no members when we
+ # leave and can still re-join.
+ room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True)
+ join_response = self.helper.join(room_id1, user1_id, tok=user1_tok)
+
+ after_join_token = self.event_sources.get_current_token()
+
+ # Update the displayname during the token range
+ displayname_change_during_token_range_response = self.helper.send_state(
+ room_id1,
+ event_type=EventTypes.Member,
+ state_key=user1_id,
+ body={
+ "membership": Membership.JOIN,
+ "displayname": "displayname during token range",
+ },
+ tok=user1_tok,
+ )
+
+ after_display_name_change_token = self.event_sources.get_current_token()
+
+ # Update the displayname after the token range
+ displayname_change_after_token_range_response = self.helper.send_state(
+ room_id1,
+ event_type=EventTypes.Member,
+ state_key=user1_id,
+ body={
+ "membership": Membership.JOIN,
+ "displayname": "displayname after token range",
+ },
+ tok=user1_tok,
+ )
+
+ # Leave after the token
+ self.helper.leave(room_id1, user1_id, tok=user1_tok)
+
+ interested_rooms = self.get_success(
+ self.sliding_sync_handler.room_lists.compute_interested_rooms(
+ SlidingSyncConfig(
+ user=UserID.from_string(user1_id),
+ requester=create_requester(user_id=user1_id),
+ lists={
+ "foo-list": SlidingSyncConfig.SlidingSyncList(
+ ranges=[(0, 99)],
+ required_state=[],
+ timeline_limit=1,
+ )
+ },
+ conn_id=None,
+ ),
+ PerConnectionState(),
+ from_token=after_join_token,
+ to_token=after_display_name_change_token,
+ )
+ )
+ room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids)
+ newly_joined = interested_rooms.newly_joined_rooms
+ newly_left = interested_rooms.newly_left_rooms
+
+ # Room should show up because we were joined during the from/to range
+ self.assertIncludes(room_id_results, {room_id1}, exact=True)
+ # It should be pointing to the latest membership event in the from/to range
+ self.assertEqual(
+ interested_rooms.room_membership_for_user_map[room_id1].event_id,
+ displayname_change_during_token_range_response["event_id"],
+ "Corresponding map to disambiguate the opaque event IDs: "
+ + str(
+ {
+ "join_response": join_response["event_id"],
+ "displayname_change_during_token_range_response": displayname_change_during_token_range_response[
+ "event_id"
+ ],
+ "displayname_change_after_token_range_response": displayname_change_after_token_range_response[
+ "event_id"
+ ],
+ }
+ ),
+ )
+ self.assertEqual(
+ interested_rooms.room_membership_for_user_map[room_id1].membership,
+ Membership.JOIN,
+ )
+ # We only changed our display name during the token range, so we shouldn't be
+ # considered `newly_joined` or `newly_left`
+ self.assertTrue(room_id1 not in newly_joined)
+ self.assertTrue(room_id1 not in newly_left)
def test_display_name_changes_join_after_token_range(
self,
@@ -2051,16 +2566,29 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
tok=user1_tok,
)
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
+ interested_rooms = self.get_success(
+ self.sliding_sync_handler.room_lists.compute_interested_rooms(
+ SlidingSyncConfig(
+ user=UserID.from_string(user1_id),
+ requester=create_requester(user_id=user1_id),
+ lists={
+ "foo-list": SlidingSyncConfig.SlidingSyncList(
+ ranges=[(0, 99)],
+ required_state=[],
+ timeline_limit=1,
+ )
+ },
+ conn_id=None,
+ ),
+ PerConnectionState(),
from_token=before_room1_token,
to_token=after_room1_token,
)
)
+ room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids)
# Room shouldn't show up because we joined after the from/to range
- self.assertEqual(room_id_results.keys(), set())
+ self.assertIncludes(room_id_results, set(), exact=True)
def test_newly_joined_with_leave_join_in_token_range(
self,
@@ -2087,26 +2615,44 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
after_more_changes_token = self.event_sources.get_current_token()
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
+ interested_rooms = self.get_success(
+ self.sliding_sync_handler.room_lists.compute_interested_rooms(
+ SlidingSyncConfig(
+ user=UserID.from_string(user1_id),
+ requester=create_requester(user_id=user1_id),
+ lists={
+ "foo-list": SlidingSyncConfig.SlidingSyncList(
+ ranges=[(0, 99)],
+ required_state=[],
+ timeline_limit=1,
+ )
+ },
+ conn_id=None,
+ ),
+ PerConnectionState(),
from_token=after_room1_token,
to_token=after_more_changes_token,
)
)
+ room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids)
+ newly_joined = interested_rooms.newly_joined_rooms
+ newly_left = interested_rooms.newly_left_rooms
# Room should show up because we were joined during the from/to range
- self.assertEqual(room_id_results.keys(), {room_id1})
+ self.assertIncludes(room_id_results, {room_id1}, exact=True)
# It should be pointing to the latest membership event in the from/to range
self.assertEqual(
- room_id_results[room_id1].event_id,
+ interested_rooms.room_membership_for_user_map[room_id1].event_id,
join_response2["event_id"],
)
- self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN)
+ self.assertEqual(
+ interested_rooms.room_membership_for_user_map[room_id1].membership,
+ Membership.JOIN,
+ )
# We should be considered `newly_joined` because there is some non-join event in
# between our join events.
- self.assertEqual(room_id_results[room_id1].newly_joined, True)
- self.assertEqual(room_id_results[room_id1].newly_left, False)
+ self.assertTrue(room_id1 in newly_joined)
+ self.assertTrue(room_id1 not in newly_left)
def test_newly_joined_only_joins_during_token_range(
self,
@@ -2152,19 +2698,34 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
after_room1_token = self.event_sources.get_current_token()
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
+ interested_rooms = self.get_success(
+ self.sliding_sync_handler.room_lists.compute_interested_rooms(
+ SlidingSyncConfig(
+ user=UserID.from_string(user1_id),
+ requester=create_requester(user_id=user1_id),
+ lists={
+ "foo-list": SlidingSyncConfig.SlidingSyncList(
+ ranges=[(0, 99)],
+ required_state=[],
+ timeline_limit=1,
+ )
+ },
+ conn_id=None,
+ ),
+ PerConnectionState(),
from_token=before_room1_token,
to_token=after_room1_token,
)
)
+ room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids)
+ newly_joined = interested_rooms.newly_joined_rooms
+ newly_left = interested_rooms.newly_left_rooms
# Room should show up because it was newly_left and joined during the from/to range
- self.assertEqual(room_id_results.keys(), {room_id1})
+ self.assertIncludes(room_id_results, {room_id1}, exact=True)
# It should be pointing to the latest membership event in the from/to range
self.assertEqual(
- room_id_results[room_id1].event_id,
+ interested_rooms.room_membership_for_user_map[room_id1].event_id,
displayname_change_during_token_range_response2["event_id"],
"Corresponding map to disambiguate the opaque event IDs: "
+ str(
@@ -2179,10 +2740,13 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
}
),
)
- self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN)
+ self.assertEqual(
+ interested_rooms.room_membership_for_user_map[room_id1].membership,
+ Membership.JOIN,
+ )
# We should be `newly_joined` because we first joined during the token range
- self.assertEqual(room_id_results[room_id1].newly_joined, True)
- self.assertEqual(room_id_results[room_id1].newly_left, False)
+ self.assertTrue(room_id1 in newly_joined)
+ self.assertTrue(room_id1 not in newly_left)
def test_multiple_rooms_are_not_confused(
self,
@@ -2205,7 +2769,7 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
# Invited and left the room before the token
self.helper.invite(room_id1, src=user2_id, targ=user1_id, tok=user2_tok)
- leave_room1_response = self.helper.leave(room_id1, user1_id, tok=user1_tok)
+ _leave_room1_response = self.helper.leave(room_id1, user1_id, tok=user1_tok)
# Invited to room2
invite_room2_response = self.helper.invite(
room_id2, src=user2_id, targ=user1_id, tok=user2_tok
@@ -2228,61 +2792,71 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
# Leave room3
self.helper.leave(room_id3, user1_id, tok=user1_tok)
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
+ interested_rooms = self.get_success(
+ self.sliding_sync_handler.room_lists.compute_interested_rooms(
+ SlidingSyncConfig(
+ user=UserID.from_string(user1_id),
+ requester=create_requester(user_id=user1_id),
+ lists={
+ "foo-list": SlidingSyncConfig.SlidingSyncList(
+ ranges=[(0, 99)],
+ required_state=[],
+ timeline_limit=1,
+ )
+ },
+ conn_id=None,
+ ),
+ PerConnectionState(),
from_token=before_room3_token,
to_token=after_room3_token,
)
)
+ room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids)
+ newly_joined = interested_rooms.newly_joined_rooms
+ newly_left = interested_rooms.newly_left_rooms
- self.assertEqual(
- room_id_results.keys(),
+ self.assertIncludes(
+ room_id_results,
{
- # Left before the from/to range
- room_id1,
+ # Excluded because we left before the from/to range
+ # room_id1,
# Invited before the from/to range
room_id2,
# `newly_left` during the from/to range
room_id3,
},
+ exact=True,
)
- # Room1
- # It should be pointing to the latest membership event in the from/to range
- self.assertEqual(
- room_id_results[room_id1].event_id,
- leave_room1_response["event_id"],
- )
- self.assertEqual(room_id_results[room_id1].membership, Membership.LEAVE)
- # We should *NOT* be `newly_joined`/`newly_left` because we were invited and left
- # before the token range
- self.assertEqual(room_id_results[room_id1].newly_joined, False)
- self.assertEqual(room_id_results[room_id1].newly_left, False)
-
# Room2
# It should be pointing to the latest membership event in the from/to range
self.assertEqual(
- room_id_results[room_id2].event_id,
+ interested_rooms.room_membership_for_user_map[room_id2].event_id,
invite_room2_response["event_id"],
)
- self.assertEqual(room_id_results[room_id2].membership, Membership.INVITE)
+ self.assertEqual(
+ interested_rooms.room_membership_for_user_map[room_id2].membership,
+ Membership.INVITE,
+ )
# We should *NOT* be `newly_joined`/`newly_left` because we were invited before
# the token range
- self.assertEqual(room_id_results[room_id2].newly_joined, False)
- self.assertEqual(room_id_results[room_id2].newly_left, False)
+ self.assertTrue(room_id2 not in newly_joined)
+ self.assertTrue(room_id2 not in newly_left)
# Room3
# It should be pointing to the latest membership event in the from/to range
self.assertEqual(
- room_id_results[room_id3].event_id,
+ interested_rooms.room_membership_for_user_map[room_id3].event_id,
leave_room3_response["event_id"],
)
- self.assertEqual(room_id_results[room_id3].membership, Membership.LEAVE)
+ self.assertEqual(
+ interested_rooms.room_membership_for_user_map[room_id3].membership,
+ Membership.LEAVE,
+ )
# We should be `newly_left` because we were invited and left during
# the token range
- self.assertEqual(room_id_results[room_id3].newly_joined, False)
- self.assertEqual(room_id_results[room_id3].newly_left, True)
+ self.assertTrue(room_id3 not in newly_joined)
+ self.assertTrue(room_id3 in newly_left)
def test_state_reset(self) -> None:
"""
@@ -2295,7 +2869,16 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
user2_tok = self.login(user2_id, "pass")
# The room where the state reset will happen
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
+ room_id1 = self.helper.create_room_as(
+ user2_id,
+ is_public=True,
+ tok=user2_tok,
+ )
+ # Create a dummy event for us to point back to for the state reset
+ dummy_event_response = self.helper.send(room_id1, "test", tok=user2_tok)
+ dummy_event_id = dummy_event_response["event_id"]
+
+ # Join after the dummy event
join_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok)
# Join another room so we don't hit the short-circuit and return early if they
@@ -2305,95 +2888,106 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
before_reset_token = self.event_sources.get_current_token()
- # Send another state event to make a position for the state reset to happen at
- dummy_state_response = self.helper.send_state(
- room_id1,
- event_type="foobarbaz",
- state_key="",
- body={"foo": "bar"},
- tok=user2_tok,
- )
- dummy_state_pos = self.get_success(
- self.store.get_position_for_event(dummy_state_response["event_id"])
- )
-
- # Mock a state reset removing the membership for user1 in the current state
- self.get_success(
- self.store.db_pool.simple_delete(
- table="current_state_events",
- keyvalues={
- "room_id": room_id1,
- "type": EventTypes.Member,
- "state_key": user1_id,
- },
- desc="state reset user in current_state_events",
+ # Trigger a state reset
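+ #
+ # Building this state event on top of the dummy event (which predates
+ # user1's join) forks the DAG, so when it is persisted, resolving the
+ # current state across the two branches drops user1's membership.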
+ join_rule_event, join_rule_context = self.get_success(
+ create_event(
+ self.hs,
+ prev_event_ids=[dummy_event_id],
+ type=EventTypes.JoinRules,
+ state_key="",
+ content={"join_rule": JoinRules.INVITE},
+ sender=user2_id,
+ room_id=room_id1,
+ room_version=self.get_success(self.store.get_room_version_id(room_id1)),
)
)
- self.get_success(
- self.store.db_pool.simple_delete(
- table="local_current_membership",
- keyvalues={
- "room_id": room_id1,
- "user_id": user1_id,
- },
- desc="state reset user in local_current_membership",
- )
- )
- self.get_success(
- self.store.db_pool.simple_insert(
- table="current_state_delta_stream",
- values={
- "stream_id": dummy_state_pos.stream,
- "room_id": room_id1,
- "type": EventTypes.Member,
- "state_key": user1_id,
- "event_id": None,
- "prev_event_id": join_response1["event_id"],
- "instance_name": dummy_state_pos.instance_name,
- },
- desc="state reset user in current_state_delta_stream",
- )
+ _, join_rule_event_pos, _ = self.get_success(
+ self.persistence.persist_event(join_rule_event, join_rule_context)
)
- # Manually bust the cache since we we're just manually messing with the database
- # and not causing an actual state reset.
- self.store._membership_stream_cache.entity_has_changed(
- user1_id, dummy_state_pos.stream
- )
+ # Ensure that the state reset worked and only user2 is in the room now
+ users_in_room = self.get_success(self.store.get_users_in_room(room_id1))
+ self.assertIncludes(set(users_in_room), {user2_id}, exact=True)
after_reset_token = self.event_sources.get_current_token()
# The function under test
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
+ interested_rooms = self.get_success(
+ self.sliding_sync_handler.room_lists.compute_interested_rooms(
+ SlidingSyncConfig(
+ user=UserID.from_string(user1_id),
+ requester=create_requester(user_id=user1_id),
+ lists={
+ "foo-list": SlidingSyncConfig.SlidingSyncList(
+ ranges=[(0, 99)],
+ required_state=[],
+ timeline_limit=1,
+ )
+ },
+ conn_id=None,
+ ),
+ PerConnectionState(),
from_token=before_reset_token,
to_token=after_reset_token,
)
)
+ room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids)
+ newly_joined = interested_rooms.newly_joined_rooms
+ newly_left = interested_rooms.newly_left_rooms
# Room1 should show up because it was `newly_left` via state reset during the from/to range
- self.assertEqual(room_id_results.keys(), {room_id1, room_id2})
+ self.assertIncludes(room_id_results, {room_id1, room_id2}, exact=True)
# It should be pointing to no event because we were removed from the room
# without a corresponding leave event
self.assertEqual(
- room_id_results[room_id1].event_id,
+ interested_rooms.room_membership_for_user_map[room_id1].event_id,
None,
+ "Corresponding map to disambiguate the opaque event IDs: "
+ + str(
+ {
+ "join_response1": join_response1["event_id"],
+ }
+ ),
)
# State reset caused us to leave the room and there is no corresponding leave event
- self.assertEqual(room_id_results[room_id1].membership, Membership.LEAVE)
+ self.assertEqual(
+ interested_rooms.room_membership_for_user_map[room_id1].membership,
+ Membership.LEAVE,
+ )
# We should *NOT* be `newly_joined` because we joined before the token range
- self.assertEqual(room_id_results[room_id1].newly_joined, False)
+ self.assertTrue(room_id1 not in newly_joined)
# We should be `newly_left` because we were removed via state reset during the from/to range
- self.assertEqual(room_id_results[room_id1].newly_left, True)
-
-
-class GetRoomMembershipForUserAtToTokenShardTestCase(BaseMultiWorkerStreamTestCase):
+ self.assertTrue(room_id1 in newly_left)
+
+
+# FIXME: This can be removed once we bump `SCHEMA_COMPAT_VERSION` and run the
+# foreground update for
+# `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` (tracked by
+# https://github.com/element-hq/synapse/issues/17623)
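+# Runs each test case twice: once against the new sliding sync tables and once
+# against the fallback code path.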
+@parameterized_class(
+ ("use_new_tables",),
+ [
+ (True,),
+ (False,),
+ ],
+ class_name_func=lambda cls,
+ num,
+ params_dict: f"{cls.__name__}_{'new' if params_dict['use_new_tables'] else 'fallback'}",
+)
+class ComputeInterestedRoomsShardTestCase(
+ BaseMultiWorkerStreamTestCase, SlidingSyncBase
+):
"""
- Tests Sliding Sync handler `get_room_membership_for_user_at_to_token()` to make sure it works with
+ Tests Sliding Sync handler `compute_interested_rooms()` to make sure it works with
sharded event stream_writers enabled
"""
+ # FIXME: We should refactor these tests to run against `compute_interested_rooms(...)`
+ # instead of just `get_room_membership_for_user_at_to_token(...)` which is only used
+ # in the fallback path (`_compute_interested_rooms_fallback(...)`). These scenarios
+ # stress that logic well and we shouldn't remove them just because we're removing
+ # the fallback path (tracked by https://github.com/element-hq/synapse/issues/17623).
+
servlets = [
admin.register_servlets_for_client_rest_resource,
room.register_servlets,
@@ -2488,7 +3082,7 @@ class GetRoomMembershipForUserAtToTokenShardTestCase(BaseMultiWorkerStreamTestCa
join_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok)
join_response2 = self.helper.join(room_id2, user1_id, tok=user1_tok)
# Leave room2
- leave_room2_response = self.helper.leave(room_id2, user1_id, tok=user1_tok)
+ _leave_room2_response = self.helper.leave(room_id2, user1_id, tok=user1_tok)
join_response3 = self.helper.join(room_id3, user1_id, tok=user1_tok)
# Leave room3
self.helper.leave(room_id3, user1_id, tok=user1_tok)
@@ -2578,60 +3172,77 @@ class GetRoomMembershipForUserAtToTokenShardTestCase(BaseMultiWorkerStreamTestCa
self.get_success(actx.__aexit__(None, None, None))
# The function under test
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
+ interested_rooms = self.get_success(
+ self.sliding_sync_handler.room_lists.compute_interested_rooms(
+ SlidingSyncConfig(
+ user=UserID.from_string(user1_id),
+ requester=create_requester(user_id=user1_id),
+ lists={
+ "foo-list": SlidingSyncConfig.SlidingSyncList(
+ ranges=[(0, 99)],
+ required_state=[],
+ timeline_limit=1,
+ )
+ },
+ conn_id=None,
+ ),
+ PerConnectionState(),
from_token=before_stuck_activity_token,
to_token=stuck_activity_token,
)
)
+ room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids)
+ newly_joined = interested_rooms.newly_joined_rooms
+ newly_left = interested_rooms.newly_left_rooms
- self.assertEqual(
- room_id_results.keys(),
+ self.assertIncludes(
+ room_id_results,
{
room_id1,
- room_id2,
+ # Excluded because we left before the from/to range and the second join
+ # event happened while worker2 was stuck and technically occurs after
+ # the `stuck_activity_token`.
+ # room_id2,
room_id3,
},
+ exact=True,
+ message="Corresponding map to disambiguate the opaque room IDs: "
+ + str(
+ {
+ "room_id1": room_id1,
+ "room_id2": room_id2,
+ "room_id3": room_id3,
+ }
+ ),
)
# Room1
# It should be pointing to the latest membership event in the from/to range
self.assertEqual(
- room_id_results[room_id1].event_id,
+ interested_rooms.room_membership_for_user_map[room_id1].event_id,
join_room1_response["event_id"],
)
- self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN)
- # We should be `newly_joined` because we joined during the token range
- self.assertEqual(room_id_results[room_id1].newly_joined, True)
- self.assertEqual(room_id_results[room_id1].newly_left, False)
-
- # Room2
- # It should be pointing to the latest membership event in the from/to range
self.assertEqual(
- room_id_results[room_id2].event_id,
- leave_room2_response["event_id"],
- )
- self.assertEqual(room_id_results[room_id2].membership, Membership.LEAVE)
- # room_id2 should *NOT* be considered `newly_left` because we left before the
- # from/to range and the join event during the range happened while worker2 was
- # stuck. This means that from the perspective of the master, where the
- # `stuck_activity_token` is generated, the stream position for worker2 wasn't
- # advanced to the join yet. Looking at the `instance_map`, the join technically
- # comes after `stuck_activity_token`.
- self.assertEqual(room_id_results[room_id2].newly_joined, False)
- self.assertEqual(room_id_results[room_id2].newly_left, False)
+ interested_rooms.room_membership_for_user_map[room_id1].membership,
+ Membership.JOIN,
+ )
+ # We should be `newly_joined` because we joined during the token range
+ self.assertTrue(room_id1 in newly_joined)
+ self.assertTrue(room_id1 not in newly_left)
# Room3
# It should be pointing to the latest membership event in the from/to range
self.assertEqual(
- room_id_results[room_id3].event_id,
+ interested_rooms.room_membership_for_user_map[room_id3].event_id,
join_on_worker3_response["event_id"],
)
- self.assertEqual(room_id_results[room_id3].membership, Membership.JOIN)
+ self.assertEqual(
+ interested_rooms.room_membership_for_user_map[room_id3].membership,
+ Membership.JOIN,
+ )
# We should be `newly_joined` because we joined during the token range
- self.assertEqual(room_id_results[room_id3].newly_joined, True)
- self.assertEqual(room_id_results[room_id3].newly_left, False)
+ self.assertTrue(room_id3 in newly_joined)
+ self.assertTrue(room_id3 not in newly_left)
class FilterRoomsRelevantForSyncTestCase(HomeserverTestCase):
@@ -2658,31 +3269,35 @@ class FilterRoomsRelevantForSyncTestCase(HomeserverTestCase):
self.store = self.hs.get_datastores().main
self.event_sources = hs.get_event_sources()
self.storage_controllers = hs.get_storage_controllers()
+ persistence = self.hs.get_storage_controllers().persistence
+ assert persistence is not None
+ self.persistence = persistence
def _get_sync_room_ids_for_user(
self,
user: UserID,
to_token: StreamToken,
from_token: Optional[StreamToken],
- ) -> Dict[str, _RoomMembershipForUser]:
+ ) -> Tuple[Dict[str, RoomsForUserType], AbstractSet[str], AbstractSet[str]]:
"""
Get the rooms the user should be syncing with
"""
- room_membership_for_user_map = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
+ room_membership_for_user_map, newly_joined, newly_left = self.get_success(
+ self.sliding_sync_handler.room_lists.get_room_membership_for_user_at_to_token(
user=user,
from_token=from_token,
to_token=to_token,
)
)
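+ # `newly_left` now has to be threaded through explicitly since the filter
+ # needs it to keep rooms the user has just left in the sync results.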
filtered_sync_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms_relevant_for_sync(
+ self.sliding_sync_handler.room_lists.filter_rooms_relevant_for_sync(
user=user,
room_membership_for_user_map=room_membership_for_user_map,
+ newly_left_room_ids=newly_left,
)
)
- return filtered_sync_room_map
+ return filtered_sync_room_map, newly_joined, newly_left
def test_no_rooms(self) -> None:
"""
@@ -2693,13 +3308,13 @@ class FilterRoomsRelevantForSyncTestCase(HomeserverTestCase):
now_token = self.event_sources.get_current_token()
- room_id_results = self._get_sync_room_ids_for_user(
+ room_id_results, newly_joined, newly_left = self._get_sync_room_ids_for_user(
UserID.from_string(user1_id),
from_token=now_token,
to_token=now_token,
)
- self.assertEqual(room_id_results.keys(), set())
+ self.assertIncludes(room_id_results.keys(), set(), exact=True)
def test_basic_rooms(self) -> None:
"""
@@ -2758,14 +3373,14 @@ class FilterRoomsRelevantForSyncTestCase(HomeserverTestCase):
after_room_token = self.event_sources.get_current_token()
- room_id_results = self._get_sync_room_ids_for_user(
+ room_id_results, newly_joined, newly_left = self._get_sync_room_ids_for_user(
UserID.from_string(user1_id),
from_token=before_room_token,
to_token=after_room_token,
)
# Ensure that the invited, ban, and knock rooms show up
- self.assertEqual(
+ self.assertIncludes(
room_id_results.keys(),
{
join_room_id,
@@ -2773,6 +3388,7 @@ class FilterRoomsRelevantForSyncTestCase(HomeserverTestCase):
ban_room_id,
knock_room_id,
},
+ exact=True,
)
# It should be pointing to the respective membership event (latest
# membership event in the from/to range)
@@ -2781,32 +3397,32 @@ class FilterRoomsRelevantForSyncTestCase(HomeserverTestCase):
join_response["event_id"],
)
self.assertEqual(room_id_results[join_room_id].membership, Membership.JOIN)
- self.assertEqual(room_id_results[join_room_id].newly_joined, True)
- self.assertEqual(room_id_results[join_room_id].newly_left, False)
+ self.assertTrue(join_room_id in newly_joined)
+ self.assertTrue(join_room_id not in newly_left)
self.assertEqual(
room_id_results[invited_room_id].event_id,
invite_response["event_id"],
)
self.assertEqual(room_id_results[invited_room_id].membership, Membership.INVITE)
- self.assertEqual(room_id_results[invited_room_id].newly_joined, False)
- self.assertEqual(room_id_results[invited_room_id].newly_left, False)
+ self.assertTrue(invited_room_id not in newly_joined)
+ self.assertTrue(invited_room_id not in newly_left)
self.assertEqual(
room_id_results[ban_room_id].event_id,
ban_response["event_id"],
)
self.assertEqual(room_id_results[ban_room_id].membership, Membership.BAN)
- self.assertEqual(room_id_results[ban_room_id].newly_joined, False)
- self.assertEqual(room_id_results[ban_room_id].newly_left, False)
+ self.assertTrue(ban_room_id not in newly_joined)
+ self.assertTrue(ban_room_id not in newly_left)
self.assertEqual(
room_id_results[knock_room_id].event_id,
knock_room_membership_state_event.event_id,
)
self.assertEqual(room_id_results[knock_room_id].membership, Membership.KNOCK)
- self.assertEqual(room_id_results[knock_room_id].newly_joined, False)
- self.assertEqual(room_id_results[knock_room_id].newly_left, False)
+ self.assertTrue(knock_room_id not in newly_joined)
+ self.assertTrue(knock_room_id not in newly_left)
def test_only_newly_left_rooms_show_up(self) -> None:
"""
@@ -2829,21 +3445,21 @@ class FilterRoomsRelevantForSyncTestCase(HomeserverTestCase):
after_room2_token = self.event_sources.get_current_token()
- room_id_results = self._get_sync_room_ids_for_user(
+ room_id_results, newly_joined, newly_left = self._get_sync_room_ids_for_user(
UserID.from_string(user1_id),
from_token=after_room1_token,
to_token=after_room2_token,
)
# Only the `newly_left` room should show up
- self.assertEqual(room_id_results.keys(), {room_id2})
+ self.assertIncludes(room_id_results.keys(), {room_id2}, exact=True)
self.assertEqual(
room_id_results[room_id2].event_id,
_leave_response2["event_id"],
)
# We should *NOT* be `newly_joined` because we are instead `newly_left`
- self.assertEqual(room_id_results[room_id2].newly_joined, False)
- self.assertEqual(room_id_results[room_id2].newly_left, True)
+ self.assertTrue(room_id2 not in newly_joined)
+ self.assertTrue(room_id2 in newly_left)
def test_get_kicked_room(self) -> None:
"""
@@ -2874,14 +3490,14 @@ class FilterRoomsRelevantForSyncTestCase(HomeserverTestCase):
after_kick_token = self.event_sources.get_current_token()
- room_id_results = self._get_sync_room_ids_for_user(
+ room_id_results, newly_joined, newly_left = self._get_sync_room_ids_for_user(
UserID.from_string(user1_id),
from_token=after_kick_token,
to_token=after_kick_token,
)
# The kicked room should show up
- self.assertEqual(room_id_results.keys(), {kick_room_id})
+ self.assertIncludes(room_id_results.keys(), {kick_room_id}, exact=True)
# It should be pointing to the latest membership event in the from/to range
self.assertEqual(
room_id_results[kick_room_id].event_id,
@@ -2891,8 +3507,8 @@ class FilterRoomsRelevantForSyncTestCase(HomeserverTestCase):
self.assertNotEqual(room_id_results[kick_room_id].sender, user1_id)
# We should *NOT* be `newly_joined` because we were not joined at the time
# of the `to_token`.
- self.assertEqual(room_id_results[kick_room_id].newly_joined, False)
- self.assertEqual(room_id_results[kick_room_id].newly_left, False)
+ self.assertTrue(kick_room_id not in newly_joined)
+ self.assertTrue(kick_room_id not in newly_left)
def test_state_reset(self) -> None:
"""
@@ -2905,8 +3521,17 @@ class FilterRoomsRelevantForSyncTestCase(HomeserverTestCase):
user2_tok = self.login(user2_id, "pass")
# The room where the state reset will happen
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
- join_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok)
+ room_id1 = self.helper.create_room_as(
+ user2_id,
+ is_public=True,
+ tok=user2_tok,
+ )
+ # Create a dummy event for us to point back to for the state reset
+ dummy_event_response = self.helper.send(room_id1, "test", tok=user2_tok)
+ dummy_event_id = dummy_event_response["event_id"]
+
+ # Join after the dummy event
+ self.helper.join(room_id1, user1_id, tok=user1_tok)
# Join another room so we don't hit the short-circuit and return early if they
# have no room membership
@@ -2915,73 +3540,38 @@ class FilterRoomsRelevantForSyncTestCase(HomeserverTestCase):
before_reset_token = self.event_sources.get_current_token()
- # Send another state event to make a position for the state reset to happen at
- dummy_state_response = self.helper.send_state(
- room_id1,
- event_type="foobarbaz",
- state_key="",
- body={"foo": "bar"},
- tok=user2_tok,
- )
- dummy_state_pos = self.get_success(
- self.store.get_position_for_event(dummy_state_response["event_id"])
- )
-
- # Mock a state reset removing the membership for user1 in the current state
- self.get_success(
- self.store.db_pool.simple_delete(
- table="current_state_events",
- keyvalues={
- "room_id": room_id1,
- "type": EventTypes.Member,
- "state_key": user1_id,
- },
- desc="state reset user in current_state_events",
+ # Trigger a state reset
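+ # (same approach as the state reset test above: fork the DAG from the
+ # pre-join dummy event so that resolving the current state drops user1's
+ # membership)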
+ join_rule_event, join_rule_context = self.get_success(
+ create_event(
+ self.hs,
+ prev_event_ids=[dummy_event_id],
+ type=EventTypes.JoinRules,
+ state_key="",
+ content={"join_rule": JoinRules.INVITE},
+ sender=user2_id,
+ room_id=room_id1,
+ room_version=self.get_success(self.store.get_room_version_id(room_id1)),
)
)
- self.get_success(
- self.store.db_pool.simple_delete(
- table="local_current_membership",
- keyvalues={
- "room_id": room_id1,
- "user_id": user1_id,
- },
- desc="state reset user in local_current_membership",
- )
- )
- self.get_success(
- self.store.db_pool.simple_insert(
- table="current_state_delta_stream",
- values={
- "stream_id": dummy_state_pos.stream,
- "room_id": room_id1,
- "type": EventTypes.Member,
- "state_key": user1_id,
- "event_id": None,
- "prev_event_id": join_response1["event_id"],
- "instance_name": dummy_state_pos.instance_name,
- },
- desc="state reset user in current_state_delta_stream",
- )
+ _, join_rule_event_pos, _ = self.get_success(
+ self.persistence.persist_event(join_rule_event, join_rule_context)
)
- # Manually bust the cache since we we're just manually messing with the database
- # and not causing an actual state reset.
- self.store._membership_stream_cache.entity_has_changed(
- user1_id, dummy_state_pos.stream
- )
+ # Ensure that the state reset worked and only user2 is in the room now
+ users_in_room = self.get_success(self.store.get_users_in_room(room_id1))
+ self.assertIncludes(set(users_in_room), {user2_id}, exact=True)
after_reset_token = self.event_sources.get_current_token()
# The function under test
- room_id_results = self._get_sync_room_ids_for_user(
+ room_id_results, newly_joined, newly_left = self._get_sync_room_ids_for_user(
UserID.from_string(user1_id),
from_token=before_reset_token,
to_token=after_reset_token,
)
# Room1 should show up because it was `newly_left` via state reset during the from/to range
- self.assertEqual(room_id_results.keys(), {room_id1, room_id2})
+ self.assertIncludes(room_id_results.keys(), {room_id1, room_id2}, exact=True)
# It should be pointing to no event because we were removed from the room
# without a corresponding leave event
self.assertEqual(
@@ -2991,1345 +3581,9 @@ class FilterRoomsRelevantForSyncTestCase(HomeserverTestCase):
# State reset caused us to leave the room and there is no corresponding leave event
self.assertEqual(room_id_results[room_id1].membership, Membership.LEAVE)
# We should *NOT* be `newly_joined` because we joined before the token range
- self.assertEqual(room_id_results[room_id1].newly_joined, False)
+ self.assertTrue(room_id1 not in newly_joined)
# We should be `newly_left` because we were removed via state reset during the from/to range
- self.assertEqual(room_id_results[room_id1].newly_left, True)
-
-
-class FilterRoomsTestCase(HomeserverTestCase):
- """
- Tests Sliding Sync handler `filter_rooms()` to make sure it includes/excludes rooms
- correctly.
- """
-
- servlets = [
- admin.register_servlets,
- knock.register_servlets,
- login.register_servlets,
- room.register_servlets,
- ]
-
- def default_config(self) -> JsonDict:
- config = super().default_config()
- # Enable sliding sync
- config["experimental_features"] = {"msc3575_enabled": True}
- return config
-
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.sliding_sync_handler = self.hs.get_sliding_sync_handler()
- self.store = self.hs.get_datastores().main
- self.event_sources = hs.get_event_sources()
-
- def _get_sync_room_ids_for_user(
- self,
- user: UserID,
- to_token: StreamToken,
- from_token: Optional[StreamToken],
- ) -> Dict[str, _RoomMembershipForUser]:
- """
- Get the rooms the user should be syncing with
- """
- room_membership_for_user_map = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- user=user,
- from_token=from_token,
- to_token=to_token,
- )
- )
- filtered_sync_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms_relevant_for_sync(
- user=user,
- room_membership_for_user_map=room_membership_for_user_map,
- )
- )
-
- return filtered_sync_room_map
-
- def _create_dm_room(
- self,
- inviter_user_id: str,
- inviter_tok: str,
- invitee_user_id: str,
- invitee_tok: str,
- ) -> str:
- """
- Helper to create a DM room as the "inviter" and invite the "invitee" user to the room. The
- "invitee" user also will join the room. The `m.direct` account data will be set
- for both users.
- """
-
- # Create a room and send an invite the other user
- room_id = self.helper.create_room_as(
- inviter_user_id,
- is_public=False,
- tok=inviter_tok,
- )
- self.helper.invite(
- room_id,
- src=inviter_user_id,
- targ=invitee_user_id,
- tok=inviter_tok,
- extra_data={"is_direct": True},
- )
- # Person that was invited joins the room
- self.helper.join(room_id, invitee_user_id, tok=invitee_tok)
-
- # Mimic the client setting the room as a direct message in the global account
- # data
- self.get_success(
- self.store.add_account_data_for_user(
- invitee_user_id,
- AccountDataTypes.DIRECT,
- {inviter_user_id: [room_id]},
- )
- )
- self.get_success(
- self.store.add_account_data_for_user(
- inviter_user_id,
- AccountDataTypes.DIRECT,
- {invitee_user_id: [room_id]},
- )
- )
-
- return room_id
-
- _remote_invite_count: int = 0
-
- def _create_remote_invite_room_for_user(
- self,
- invitee_user_id: str,
- unsigned_invite_room_state: Optional[List[StrippedStateEvent]],
- ) -> str:
- """
- Create a fake invite for a remote room and persist it.
-
- We don't have any state for these kind of rooms and can only rely on the
- stripped state included in the unsigned portion of the invite event to identify
- the room.
-
- Args:
- invitee_user_id: The person being invited
- unsigned_invite_room_state: List of stripped state events to assist the
- receiver in identifying the room.
-
- Returns:
- The room ID of the remote invite room
- """
- invite_room_id = f"!test_room{self._remote_invite_count}:remote_server"
-
- invite_event_dict = {
- "room_id": invite_room_id,
- "sender": "@inviter:remote_server",
- "state_key": invitee_user_id,
- "depth": 1,
- "origin_server_ts": 1,
- "type": EventTypes.Member,
- "content": {"membership": Membership.INVITE},
- "auth_events": [],
- "prev_events": [],
- }
- if unsigned_invite_room_state is not None:
- serialized_stripped_state_events = []
- for stripped_event in unsigned_invite_room_state:
- serialized_stripped_state_events.append(
- {
- "type": stripped_event.type,
- "state_key": stripped_event.state_key,
- "sender": stripped_event.sender,
- "content": stripped_event.content,
- }
- )
-
- invite_event_dict["unsigned"] = {
- "invite_room_state": serialized_stripped_state_events
- }
-
- invite_event = make_event_from_dict(
- invite_event_dict,
- room_version=RoomVersions.V10,
- )
- invite_event.internal_metadata.outlier = True
- invite_event.internal_metadata.out_of_band_membership = True
-
- self.get_success(
- self.store.maybe_store_room_on_outlier_membership(
- room_id=invite_room_id, room_version=invite_event.room_version
- )
- )
- context = EventContext.for_outlier(self.hs.get_storage_controllers())
- persist_controller = self.hs.get_storage_controllers().persistence
- assert persist_controller is not None
- self.get_success(persist_controller.persist_event(invite_event, context))
-
- self._remote_invite_count += 1
-
- return invite_room_id
-
- def test_filter_dm_rooms(self) -> None:
- """
- Test `filter.is_dm` for DM rooms
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- # Create a normal room
- room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
-
- # Create a DM room
- dm_room_id = self._create_dm_room(
- inviter_user_id=user1_id,
- inviter_tok=user1_tok,
- invitee_user_id=user2_id,
- invitee_tok=user2_tok,
- )
-
- after_rooms_token = self.event_sources.get_current_token()
-
- # Get the rooms the user should be syncing with
- sync_room_map = self._get_sync_room_ids_for_user(
- UserID.from_string(user1_id),
- from_token=None,
- to_token=after_rooms_token,
- )
-
- # Try with `is_dm=True`
- truthy_filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- is_dm=True,
- ),
- after_rooms_token,
- )
- )
-
- self.assertEqual(truthy_filtered_room_map.keys(), {dm_room_id})
-
- # Try with `is_dm=False`
- falsy_filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- is_dm=False,
- ),
- after_rooms_token,
- )
- )
-
- self.assertEqual(falsy_filtered_room_map.keys(), {room_id})
-
- def test_filter_encrypted_rooms(self) -> None:
- """
- Test `filter.is_encrypted` for encrypted rooms
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- # Create an unencrypted room
- room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
-
- # Create an encrypted room
- encrypted_room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
- self.helper.send_state(
- encrypted_room_id,
- EventTypes.RoomEncryption,
- {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"},
- tok=user1_tok,
- )
-
- after_rooms_token = self.event_sources.get_current_token()
-
- # Get the rooms the user should be syncing with
- sync_room_map = self._get_sync_room_ids_for_user(
- UserID.from_string(user1_id),
- from_token=None,
- to_token=after_rooms_token,
- )
-
- # Try with `is_encrypted=True`
- truthy_filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- is_encrypted=True,
- ),
- after_rooms_token,
- )
- )
-
- self.assertEqual(truthy_filtered_room_map.keys(), {encrypted_room_id})
-
- # Try with `is_encrypted=False`
- falsy_filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- is_encrypted=False,
- ),
- after_rooms_token,
- )
- )
-
- self.assertEqual(falsy_filtered_room_map.keys(), {room_id})
-
- def test_filter_encrypted_server_left_room(self) -> None:
- """
- Test that we can apply a `filter.is_encrypted` against a room that everyone has left.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- before_rooms_token = self.event_sources.get_current_token()
-
- # Create an unencrypted room
- room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
- # Leave the room
- self.helper.leave(room_id, user1_id, tok=user1_tok)
-
- # Create an encrypted room
- encrypted_room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
- self.helper.send_state(
- encrypted_room_id,
- EventTypes.RoomEncryption,
- {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"},
- tok=user1_tok,
- )
- # Leave the room
- self.helper.leave(encrypted_room_id, user1_id, tok=user1_tok)
-
- after_rooms_token = self.event_sources.get_current_token()
-
- # Get the rooms the user should be syncing with
- sync_room_map = self._get_sync_room_ids_for_user(
- UserID.from_string(user1_id),
- # We're using a `from_token` so that the room is considered `newly_left` and
- # appears in our list of relevant sync rooms
- from_token=before_rooms_token,
- to_token=after_rooms_token,
- )
-
- # Try with `is_encrypted=True`
- truthy_filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- is_encrypted=True,
- ),
- after_rooms_token,
- )
- )
-
- self.assertEqual(truthy_filtered_room_map.keys(), {encrypted_room_id})
-
- # Try with `is_encrypted=False`
- falsy_filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- is_encrypted=False,
- ),
- after_rooms_token,
- )
- )
-
- self.assertEqual(falsy_filtered_room_map.keys(), {room_id})
-
- def test_filter_encrypted_server_left_room2(self) -> None:
- """
- Test that we can apply a `filter.is_encrypted` against a room that everyone has
- left.
-
- There is still someone local who is invited to the rooms but that doesn't affect
- whether the server is participating in the room (users need to be joined).
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- _user2_tok = self.login(user2_id, "pass")
-
- before_rooms_token = self.event_sources.get_current_token()
-
- # Create an unencrypted room
- room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
- # Invite user2
- self.helper.invite(room_id, targ=user2_id, tok=user1_tok)
- # User1 leaves the room
- self.helper.leave(room_id, user1_id, tok=user1_tok)
-
- # Create an encrypted room
- encrypted_room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
- self.helper.send_state(
- encrypted_room_id,
- EventTypes.RoomEncryption,
- {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"},
- tok=user1_tok,
- )
- # Invite user2
- self.helper.invite(encrypted_room_id, targ=user2_id, tok=user1_tok)
- # User1 leaves the room
- self.helper.leave(encrypted_room_id, user1_id, tok=user1_tok)
-
- after_rooms_token = self.event_sources.get_current_token()
-
- # Get the rooms the user should be syncing with
- sync_room_map = self._get_sync_room_ids_for_user(
- UserID.from_string(user1_id),
- # We're using a `from_token` so that the room is considered `newly_left` and
- # appears in our list of relevant sync rooms
- from_token=before_rooms_token,
- to_token=after_rooms_token,
- )
-
- # Try with `is_encrypted=True`
- truthy_filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- is_encrypted=True,
- ),
- after_rooms_token,
- )
- )
-
- self.assertEqual(truthy_filtered_room_map.keys(), {encrypted_room_id})
-
- # Try with `is_encrypted=False`
- falsy_filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- is_encrypted=False,
- ),
- after_rooms_token,
- )
- )
-
- self.assertEqual(falsy_filtered_room_map.keys(), {room_id})
-
- def test_filter_encrypted_after_we_left(self) -> None:
- """
- Test that we can apply a `filter.is_encrypted` against a room that was encrypted
- after we left the room (make sure we don't just use the current state)
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- before_rooms_token = self.event_sources.get_current_token()
-
- # Create an unencrypted room
- room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
- # Leave the room
- self.helper.join(room_id, user1_id, tok=user1_tok)
- self.helper.leave(room_id, user1_id, tok=user1_tok)
-
- # Create a room that will be encrypted
- encrypted_after_we_left_room_id = self.helper.create_room_as(
- user2_id, tok=user2_tok
- )
- # Leave the room
- self.helper.join(encrypted_after_we_left_room_id, user1_id, tok=user1_tok)
- self.helper.leave(encrypted_after_we_left_room_id, user1_id, tok=user1_tok)
-
- # Encrypt the room after we've left
- self.helper.send_state(
- encrypted_after_we_left_room_id,
- EventTypes.RoomEncryption,
- {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"},
- tok=user2_tok,
- )
-
- after_rooms_token = self.event_sources.get_current_token()
-
- # Get the rooms the user should be syncing with
- sync_room_map = self._get_sync_room_ids_for_user(
- UserID.from_string(user1_id),
- # We're using a `from_token` so that the room is considered `newly_left` and
- # appears in our list of relevant sync rooms
- from_token=before_rooms_token,
- to_token=after_rooms_token,
- )
-
- # Try with `is_encrypted=True`
- truthy_filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- is_encrypted=True,
- ),
- after_rooms_token,
- )
- )
-
- # Even though we left the room before it was encrypted, we still see it because
- # someone else on our server is still participating in the room and we "leak"
- # the current state to the left user. But we consider the room encryption status
- # to not be a secret given it's often set at the start of the room and it's one
- # of the stripped state events that is normally handed out.
- self.assertEqual(
- truthy_filtered_room_map.keys(), {encrypted_after_we_left_room_id}
- )
-
- # Try with `is_encrypted=False`
- falsy_filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- is_encrypted=False,
- ),
- after_rooms_token,
- )
- )
-
- # Even though we left the room before it was encrypted... (see comment above)
- self.assertEqual(falsy_filtered_room_map.keys(), {room_id})
-
- def test_filter_encrypted_with_remote_invite_room_no_stripped_state(self) -> None:
- """
- Test that we can apply a `filter.is_encrypted` filter against a remote invite
- room without any `unsigned.invite_room_state` (stripped state).
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- # Create a remote invite room without any `unsigned.invite_room_state`
- _remote_invite_room_id = self._create_remote_invite_room_for_user(
- user1_id, None
- )
-
- # Create an unencrypted room
- room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
-
- # Create an encrypted room
- encrypted_room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
- self.helper.send_state(
- encrypted_room_id,
- EventTypes.RoomEncryption,
- {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"},
- tok=user1_tok,
- )
-
- after_rooms_token = self.event_sources.get_current_token()
-
- # Get the rooms the user should be syncing with
- sync_room_map = self._get_sync_room_ids_for_user(
- UserID.from_string(user1_id),
- from_token=None,
- to_token=after_rooms_token,
- )
-
- # Try with `is_encrypted=True`
- truthy_filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- is_encrypted=True,
- ),
- after_rooms_token,
- )
- )
-
- # `remote_invite_room_id` should not appear because we can't figure out whether
- # it is encrypted or not (no stripped state, `unsigned.invite_room_state`).
- self.assertEqual(truthy_filtered_room_map.keys(), {encrypted_room_id})
-
- # Try with `is_encrypted=False`
- falsy_filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- is_encrypted=False,
- ),
- after_rooms_token,
- )
- )
-
- # `remote_invite_room_id` should not appear because we can't figure out whether
- # it is encrypted or not (no stripped state, `unsigned.invite_room_state`).
- self.assertEqual(falsy_filtered_room_map.keys(), {room_id})
-
- def test_filter_encrypted_with_remote_invite_encrypted_room(self) -> None:
- """
- Test that we can apply a `filter.is_encrypted` filter against a remote invite
- encrypted room with some `unsigned.invite_room_state` (stripped state).
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- # Create a remote invite room with some `unsigned.invite_room_state`
- # indicating that the room is encrypted.
- remote_invite_room_id = self._create_remote_invite_room_for_user(
- user1_id,
- [
- StrippedStateEvent(
- type=EventTypes.Create,
- state_key="",
- sender="@inviter:remote_server",
- content={
- EventContentFields.ROOM_CREATOR: "@inviter:remote_server",
- EventContentFields.ROOM_VERSION: RoomVersions.V10.identifier,
- },
- ),
- StrippedStateEvent(
- type=EventTypes.RoomEncryption,
- state_key="",
- sender="@inviter:remote_server",
- content={
- EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2",
- },
- ),
- ],
- )
-
- # Create an unencrypted room
- room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
-
- # Create an encrypted room
- encrypted_room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
- self.helper.send_state(
- encrypted_room_id,
- EventTypes.RoomEncryption,
- {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"},
- tok=user1_tok,
- )
-
- after_rooms_token = self.event_sources.get_current_token()
-
- # Get the rooms the user should be syncing with
- sync_room_map = self._get_sync_room_ids_for_user(
- UserID.from_string(user1_id),
- from_token=None,
- to_token=after_rooms_token,
- )
-
- # Try with `is_encrypted=True`
- truthy_filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- is_encrypted=True,
- ),
- after_rooms_token,
- )
- )
-
- # `remote_invite_room_id` should appear here because it is encrypted
- # according to the stripped state
- self.assertEqual(
- truthy_filtered_room_map.keys(), {encrypted_room_id, remote_invite_room_id}
- )
-
- # Try with `is_encrypted=False`
- falsy_filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- is_encrypted=False,
- ),
- after_rooms_token,
- )
- )
-
- # `remote_invite_room_id` should not appear here because it is encrypted
- # according to the stripped state
- self.assertEqual(falsy_filtered_room_map.keys(), {room_id})
-
- def test_filter_encrypted_with_remote_invite_unencrypted_room(self) -> None:
- """
- Test that we can apply a `filter.is_encrypted` filter against a remote invite
- unencrypted room with some `unsigned.invite_room_state` (stripped state).
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- # Create a remote invite room with some `unsigned.invite_room_state`
- # but don't set any room encryption event.
- remote_invite_room_id = self._create_remote_invite_room_for_user(
- user1_id,
- [
- StrippedStateEvent(
- type=EventTypes.Create,
- state_key="",
- sender="@inviter:remote_server",
- content={
- EventContentFields.ROOM_CREATOR: "@inviter:remote_server",
- EventContentFields.ROOM_VERSION: RoomVersions.V10.identifier,
- },
- ),
- # No room encryption event
- ],
- )
-
- # Create an unencrypted room
- room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
-
- # Create an encrypted room
- encrypted_room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
- self.helper.send_state(
- encrypted_room_id,
- EventTypes.RoomEncryption,
- {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"},
- tok=user1_tok,
- )
-
- after_rooms_token = self.event_sources.get_current_token()
-
- # Get the rooms the user should be syncing with
- sync_room_map = self._get_sync_room_ids_for_user(
- UserID.from_string(user1_id),
- from_token=None,
- to_token=after_rooms_token,
- )
-
- # Try with `is_encrypted=True`
- truthy_filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- is_encrypted=True,
- ),
- after_rooms_token,
- )
- )
-
- # `remote_invite_room_id` should not appear here because it is unencrypted
- # according to the stripped state
- self.assertEqual(truthy_filtered_room_map.keys(), {encrypted_room_id})
-
- # Try with `is_encrypted=False`
- falsy_filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- is_encrypted=False,
- ),
- after_rooms_token,
- )
- )
-
- # `remote_invite_room_id` should appear because it is unencrypted according to
- # the stripped state
- self.assertEqual(
- falsy_filtered_room_map.keys(), {room_id, remote_invite_room_id}
- )
-
- def test_filter_invite_rooms(self) -> None:
- """
- Test `filter.is_invite` for rooms that the user has been invited to
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- # Create a normal room
- room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
- self.helper.join(room_id, user1_id, tok=user1_tok)
-
- # Create a room that user1 is invited to
- invite_room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
- self.helper.invite(invite_room_id, src=user2_id, targ=user1_id, tok=user2_tok)
-
- after_rooms_token = self.event_sources.get_current_token()
-
- # Get the rooms the user should be syncing with
- sync_room_map = self._get_sync_room_ids_for_user(
- UserID.from_string(user1_id),
- from_token=None,
- to_token=after_rooms_token,
- )
-
- # Try with `is_invite=True`
- truthy_filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- is_invite=True,
- ),
- after_rooms_token,
- )
- )
-
- self.assertEqual(truthy_filtered_room_map.keys(), {invite_room_id})
-
- # Try with `is_invite=False`
- falsy_filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- is_invite=False,
- ),
- after_rooms_token,
- )
- )
-
- self.assertEqual(falsy_filtered_room_map.keys(), {room_id})
-
- def test_filter_room_types(self) -> None:
- """
- Test `filter.room_types` for different room types
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- # Create a normal room (no room type)
- room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
-
- # Create a space room
- space_room_id = self.helper.create_room_as(
- user1_id,
- tok=user1_tok,
- extra_content={
- "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE}
- },
- )
-
- # Create an arbitrarily typed room
- foo_room_id = self.helper.create_room_as(
- user1_id,
- tok=user1_tok,
- extra_content={
- "creation_content": {
- EventContentFields.ROOM_TYPE: "org.matrix.foobarbaz"
- }
- },
- )
-
- after_rooms_token = self.event_sources.get_current_token()
-
- # Get the rooms the user should be syncing with
- sync_room_map = self._get_sync_room_ids_for_user(
- UserID.from_string(user1_id),
- from_token=None,
- to_token=after_rooms_token,
- )
-
- # Try finding only normal rooms
- filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(room_types=[None]),
- after_rooms_token,
- )
- )
-
- self.assertEqual(filtered_room_map.keys(), {room_id})
-
- # Try finding only spaces
- filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(room_types=[RoomTypes.SPACE]),
- after_rooms_token,
- )
- )
-
- self.assertEqual(filtered_room_map.keys(), {space_room_id})
-
- # Try finding normal rooms and spaces
- filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- room_types=[None, RoomTypes.SPACE]
- ),
- after_rooms_token,
- )
- )
-
- self.assertEqual(filtered_room_map.keys(), {room_id, space_room_id})
-
- # Try finding an arbitrary room type
- filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- room_types=["org.matrix.foobarbaz"]
- ),
- after_rooms_token,
- )
- )
-
- self.assertEqual(filtered_room_map.keys(), {foo_room_id})
-
- def test_filter_not_room_types(self) -> None:
- """
- Test `filter.not_room_types` for different room types
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- # Create a normal room (no room type)
- room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
-
- # Create a space room
- space_room_id = self.helper.create_room_as(
- user1_id,
- tok=user1_tok,
- extra_content={
- "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE}
- },
- )
-
- # Create an arbitrarily typed room
- foo_room_id = self.helper.create_room_as(
- user1_id,
- tok=user1_tok,
- extra_content={
- "creation_content": {
- EventContentFields.ROOM_TYPE: "org.matrix.foobarbaz"
- }
- },
- )
-
- after_rooms_token = self.event_sources.get_current_token()
-
- # Get the rooms the user should be syncing with
- sync_room_map = self._get_sync_room_ids_for_user(
- UserID.from_string(user1_id),
- from_token=None,
- to_token=after_rooms_token,
- )
-
- # Try finding *NOT* normal rooms
- filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(not_room_types=[None]),
- after_rooms_token,
- )
- )
-
- self.assertEqual(filtered_room_map.keys(), {space_room_id, foo_room_id})
-
- # Try finding *NOT* spaces
- filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- not_room_types=[RoomTypes.SPACE]
- ),
- after_rooms_token,
- )
- )
-
- self.assertEqual(filtered_room_map.keys(), {room_id, foo_room_id})
-
- # Try finding *NOT* normal rooms or spaces
- filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- not_room_types=[None, RoomTypes.SPACE]
- ),
- after_rooms_token,
- )
- )
-
- self.assertEqual(filtered_room_map.keys(), {foo_room_id})
-
- # Test how it behaves when we have both `room_types` and `not_room_types`.
- # `not_room_types` should win.
- filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- room_types=[None], not_room_types=[None]
- ),
- after_rooms_token,
- )
- )
-
- # Nothing matches because nothing is both a normal room and not a normal room
- self.assertEqual(filtered_room_map.keys(), set())
-
- # Test how it behaves when we have both `room_types` and `not_room_types`.
- # `not_room_types` should win.
- filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- room_types=[None, RoomTypes.SPACE], not_room_types=[None]
- ),
- after_rooms_token,
- )
- )
-
- self.assertEqual(filtered_room_map.keys(), {space_room_id})
-
- def test_filter_room_types_server_left_room(self) -> None:
- """
- Test that we can apply a `filter.room_types` against a room that everyone has left.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- before_rooms_token = self.event_sources.get_current_token()
-
- # Create a normal room (no room type)
- room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
- # Leave the room
- self.helper.leave(room_id, user1_id, tok=user1_tok)
-
- # Create a space room
- space_room_id = self.helper.create_room_as(
- user1_id,
- tok=user1_tok,
- extra_content={
- "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE}
- },
- )
- # Leave the room
- self.helper.leave(space_room_id, user1_id, tok=user1_tok)
-
- after_rooms_token = self.event_sources.get_current_token()
-
- # Get the rooms the user should be syncing with
- sync_room_map = self._get_sync_room_ids_for_user(
- UserID.from_string(user1_id),
- # We're using a `from_token` so that the room is considered `newly_left` and
- # appears in our list of relevant sync rooms
- from_token=before_rooms_token,
- to_token=after_rooms_token,
- )
-
- # Try finding only normal rooms
- filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(room_types=[None]),
- after_rooms_token,
- )
- )
-
- self.assertEqual(filtered_room_map.keys(), {room_id})
-
- # Try finding only spaces
- filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(room_types=[RoomTypes.SPACE]),
- after_rooms_token,
- )
- )
-
- self.assertEqual(filtered_room_map.keys(), {space_room_id})
-
- def test_filter_room_types_server_left_room2(self) -> None:
- """
- Test that we can apply a `filter.room_types` against a room that everyone has left.
-
- There is still someone local who is invited to the rooms but that doesn't affect
- whether the server is participating in the room (users need to be joined).
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- _user2_tok = self.login(user2_id, "pass")
-
- before_rooms_token = self.event_sources.get_current_token()
-
- # Create a normal room (no room type)
- room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
- # Invite user2
- self.helper.invite(room_id, targ=user2_id, tok=user1_tok)
- # User1 leaves the room
- self.helper.leave(room_id, user1_id, tok=user1_tok)
-
- # Create a space room
- space_room_id = self.helper.create_room_as(
- user1_id,
- tok=user1_tok,
- extra_content={
- "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE}
- },
- )
- # Invite user2
- self.helper.invite(space_room_id, targ=user2_id, tok=user1_tok)
- # User1 leaves the room
- self.helper.leave(space_room_id, user1_id, tok=user1_tok)
-
- after_rooms_token = self.event_sources.get_current_token()
-
- # Get the rooms the user should be syncing with
- sync_room_map = self._get_sync_room_ids_for_user(
- UserID.from_string(user1_id),
- # We're using a `from_token` so that the room is considered `newly_left` and
- # appears in our list of relevant sync rooms
- from_token=before_rooms_token,
- to_token=after_rooms_token,
- )
-
- # Try finding only normal rooms
- filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(room_types=[None]),
- after_rooms_token,
- )
- )
-
- self.assertEqual(filtered_room_map.keys(), {room_id})
-
- # Try finding only spaces
- filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(room_types=[RoomTypes.SPACE]),
- after_rooms_token,
- )
- )
-
- self.assertEqual(filtered_room_map.keys(), {space_room_id})
-
- def test_filter_room_types_with_remote_invite_room_no_stripped_state(self) -> None:
- """
- Test that we can apply a `filter.room_types` filter against a remote invite
- room without any `unsigned.invite_room_state` (stripped state).
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- # Create a remote invite room without any `unsigned.invite_room_state`
- _remote_invite_room_id = self._create_remote_invite_room_for_user(
- user1_id, None
- )
-
- # Create a normal room (no room type)
- room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
-
- # Create a space room
- space_room_id = self.helper.create_room_as(
- user1_id,
- tok=user1_tok,
- extra_content={
- "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE}
- },
- )
-
- after_rooms_token = self.event_sources.get_current_token()
-
- # Get the rooms the user should be syncing with
- sync_room_map = self._get_sync_room_ids_for_user(
- UserID.from_string(user1_id),
- from_token=None,
- to_token=after_rooms_token,
- )
-
- # Try finding only normal rooms
- filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(room_types=[None]),
- after_rooms_token,
- )
- )
-
- # `remote_invite_room_id` should not appear because we can't figure out what
- # room type it is (no stripped state, `unsigned.invite_room_state`)
- self.assertEqual(filtered_room_map.keys(), {room_id})
-
- # Try finding only spaces
- filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(room_types=[RoomTypes.SPACE]),
- after_rooms_token,
- )
- )
-
- # `remote_invite_room_id` should not appear because we can't figure out what
- # room type it is (no stripped state, `unsigned.invite_room_state`)
- self.assertEqual(filtered_room_map.keys(), {space_room_id})
-
- def test_filter_room_types_with_remote_invite_space(self) -> None:
- """
- Test that we can apply a `filter.room_types` filter against a remote invite
- to a space room with some `unsigned.invite_room_state` (stripped state).
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- # Create a remote invite room with some `unsigned.invite_room_state` indicating
- # that it is a space room
- remote_invite_room_id = self._create_remote_invite_room_for_user(
- user1_id,
- [
- StrippedStateEvent(
- type=EventTypes.Create,
- state_key="",
- sender="@inviter:remote_server",
- content={
- EventContentFields.ROOM_CREATOR: "@inviter:remote_server",
- EventContentFields.ROOM_VERSION: RoomVersions.V10.identifier,
- # Specify that it is a space room
- EventContentFields.ROOM_TYPE: RoomTypes.SPACE,
- },
- ),
- ],
- )
-
- # Create a normal room (no room type)
- room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
-
- # Create a space room
- space_room_id = self.helper.create_room_as(
- user1_id,
- tok=user1_tok,
- extra_content={
- "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE}
- },
- )
-
- after_rooms_token = self.event_sources.get_current_token()
-
- # Get the rooms the user should be syncing with
- sync_room_map = self._get_sync_room_ids_for_user(
- UserID.from_string(user1_id),
- from_token=None,
- to_token=after_rooms_token,
- )
-
- # Try finding only normal rooms
- filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(room_types=[None]),
- after_rooms_token,
- )
- )
-
- # `remote_invite_room_id` should not appear here because it is a space room
- # according to the stripped state
- self.assertEqual(filtered_room_map.keys(), {room_id})
-
- # Try finding only spaces
- filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(room_types=[RoomTypes.SPACE]),
- after_rooms_token,
- )
- )
-
- # `remote_invite_room_id` should appear here because it is a space room
- # according to the stripped state
- self.assertEqual(
- filtered_room_map.keys(), {space_room_id, remote_invite_room_id}
- )
-
- def test_filter_room_types_with_remote_invite_normal_room(self) -> None:
- """
- Test that we can apply a `filter.room_types` filter against a remote invite
- to a normal room with some `unsigned.invite_room_state` (stripped state).
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- # Create a remote invite room with some `unsigned.invite_room_state`
- # but the create event does not specify a room type (normal room)
- remote_invite_room_id = self._create_remote_invite_room_for_user(
- user1_id,
- [
- StrippedStateEvent(
- type=EventTypes.Create,
- state_key="",
- sender="@inviter:remote_server",
- content={
- EventContentFields.ROOM_CREATOR: "@inviter:remote_server",
- EventContentFields.ROOM_VERSION: RoomVersions.V10.identifier,
- # No room type means this is a normal room
- },
- ),
- ],
- )
-
- # Create a normal room (no room type)
- room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
-
- # Create a space room
- space_room_id = self.helper.create_room_as(
- user1_id,
- tok=user1_tok,
- extra_content={
- "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE}
- },
- )
-
- after_rooms_token = self.event_sources.get_current_token()
-
- # Get the rooms the user should be syncing with
- sync_room_map = self._get_sync_room_ids_for_user(
- UserID.from_string(user1_id),
- from_token=None,
- to_token=after_rooms_token,
- )
-
- # Try finding only normal rooms
- filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(room_types=[None]),
- after_rooms_token,
- )
- )
-
- # `remote_invite_room_id` should appear here because it is a normal room
- # according to the stripped state (no room type)
- self.assertEqual(filtered_room_map.keys(), {room_id, remote_invite_room_id})
-
- # Try finding only spaces
- filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(room_types=[RoomTypes.SPACE]),
- after_rooms_token,
- )
- )
-
- # `remote_invite_room_id` should not appear here because it is a normal room
- # according to the stripped state (no room type)
- self.assertEqual(filtered_room_map.keys(), {space_room_id})
+ self.assertIn(room_id1, newly_left)
class SortRoomsTestCase(HomeserverTestCase):
@@ -4361,25 +3615,26 @@ class SortRoomsTestCase(HomeserverTestCase):
user: UserID,
to_token: StreamToken,
from_token: Optional[StreamToken],
- ) -> Dict[str, _RoomMembershipForUser]:
+ ) -> Tuple[Dict[str, RoomsForUserType], AbstractSet[str], AbstractSet[str]]:
"""
Get the rooms the user should be syncing with
"""
- room_membership_for_user_map = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
+ room_membership_for_user_map, newly_joined, newly_left = self.get_success(
+ self.sliding_sync_handler.room_lists.get_room_membership_for_user_at_to_token(
user=user,
from_token=from_token,
to_token=to_token,
)
)
filtered_sync_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms_relevant_for_sync(
+ self.sliding_sync_handler.room_lists.filter_rooms_relevant_for_sync(
user=user,
room_membership_for_user_map=room_membership_for_user_map,
+ newly_left_room_ids=newly_left,
)
)
- return filtered_sync_room_map
+ return filtered_sync_room_map, newly_joined, newly_left
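A minimal sketch of how the reworked helper above is consumed, assuming its new
tuple return type (the set semantics follow the comments in these tests):

    room_map, newly_joined, newly_left = self._get_sync_room_ids_for_user(
        UserID.from_string(user1_id),
        from_token=None,
        to_token=after_rooms_token,
    )
    # room_map: Dict[str, RoomsForUserType], keyed by room ID.
    # newly_joined / newly_left: AbstractSet[str] of room IDs whose membership
    # changed between `from_token` and `to_token`; passing a `from_token` is
    # what lets a departed room count as newly_left.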
def test_sort_activity_basic(self) -> None:
"""
@@ -4400,7 +3655,7 @@ class SortRoomsTestCase(HomeserverTestCase):
after_rooms_token = self.event_sources.get_current_token()
# Get the rooms the user should be syncing with
- sync_room_map = self._get_sync_room_ids_for_user(
+ sync_room_map, newly_joined, newly_left = self._get_sync_room_ids_for_user(
UserID.from_string(user1_id),
from_token=None,
to_token=after_rooms_token,
@@ -4408,7 +3663,7 @@ class SortRoomsTestCase(HomeserverTestCase):
# Sort the rooms (what we're testing)
sorted_sync_rooms = self.get_success(
- self.sliding_sync_handler.sort_rooms(
+ self.sliding_sync_handler.room_lists.sort_rooms(
sync_room_map=sync_room_map,
to_token=after_rooms_token,
)
@@ -4481,7 +3736,7 @@ class SortRoomsTestCase(HomeserverTestCase):
self.helper.send(room_id3, "activity in room3", tok=user2_tok)
# Get the rooms the user should be syncing with
- sync_room_map = self._get_sync_room_ids_for_user(
+ sync_room_map, newly_joined, newly_left = self._get_sync_room_ids_for_user(
UserID.from_string(user1_id),
from_token=before_rooms_token,
to_token=after_rooms_token,
@@ -4489,7 +3744,7 @@ class SortRoomsTestCase(HomeserverTestCase):
# Sort the rooms (what we're testing)
sorted_sync_rooms = self.get_success(
- self.sliding_sync_handler.sort_rooms(
+ self.sliding_sync_handler.room_lists.sort_rooms(
sync_room_map=sync_room_map,
to_token=after_rooms_token,
)
@@ -4545,7 +3800,7 @@ class SortRoomsTestCase(HomeserverTestCase):
after_rooms_token = self.event_sources.get_current_token()
# Get the rooms the user should be syncing with
- sync_room_map = self._get_sync_room_ids_for_user(
+ sync_room_map, newly_joined, newly_left = self._get_sync_room_ids_for_user(
UserID.from_string(user1_id),
from_token=None,
to_token=after_rooms_token,
@@ -4553,7 +3808,7 @@ class SortRoomsTestCase(HomeserverTestCase):
# Sort the rooms (what we're testing)
sorted_sync_rooms = self.get_success(
- self.sliding_sync_handler.sort_rooms(
+ self.sliding_sync_handler.room_lists.sort_rooms(
sync_room_map=sync_room_map,
to_token=after_rooms_token,
)
@@ -4565,3 +3820,1071 @@ class SortRoomsTestCase(HomeserverTestCase):
# We only care about the *latest* event in the room.
[room_id1, room_id2],
)
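The block added below drives a table of `_required_state_changes` cases through
an attrs container. `@attr.s(slots=True, auto_attribs=True, frozen=True)` yields
an immutable, typed parameter record, roughly equivalent to this stdlib sketch
(illustrative only, not the code under review):

    from dataclasses import dataclass
    from typing import Dict, Set

    @dataclass(frozen=True)  # attrs' slots=True maps to dataclass(slots=True) on 3.10+
    class Params:  # hypothetical stand-in for RequiredStateChangesTestParameters
        previous_required_state_map: Dict[str, Set[str]]
        request_required_state_map: Dict[str, Set[str]]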
+
+
+@attr.s(slots=True, auto_attribs=True, frozen=True)
+class RequiredStateChangesTestParameters:
+ previous_required_state_map: Dict[str, Set[str]]
+ request_required_state_map: Dict[str, Set[str]]
+ state_deltas: StateMap[str]
+ expected_with_state_deltas: Tuple[
+ Optional[Mapping[str, AbstractSet[str]]], StateFilter
+ ]
+ expected_without_state_deltas: Tuple[
+ Optional[Mapping[str, AbstractSet[str]]], StateFilter
+ ]
+
+
+class RequiredStateChangesTestCase(unittest.TestCase):
+ """Test cases for `_required_state_changes`"""
+
+ @parameterized.expand(
+ [
+ (
+ "simple_no_change",
+ """Test no change to required state""",
+ RequiredStateChangesTestParameters(
+ previous_required_state_map={"type1": {"state_key"}},
+ request_required_state_map={"type1": {"state_key"}},
+ state_deltas={("type1", "state_key"): "$event_id"},
+ # No changes
+ expected_with_state_deltas=(None, StateFilter.none()),
+ expected_without_state_deltas=(None, StateFilter.none()),
+ ),
+ ),
+ (
+ "simple_add_type",
+ """Test adding a type to the config""",
+ RequiredStateChangesTestParameters(
+ previous_required_state_map={"type1": {"state_key"}},
+ request_required_state_map={
+ "type1": {"state_key"},
+ "type2": {"state_key"},
+ },
+ state_deltas={("type2", "state_key"): "$event_id"},
+ expected_with_state_deltas=(
+ # We've added a type so we should persist the changed required state
+ # config.
+ {"type1": {"state_key"}, "type2": {"state_key"}},
+ # We should see the new type added
+ StateFilter.from_types([("type2", "state_key")]),
+ ),
+ expected_without_state_deltas=(
+ {"type1": {"state_key"}, "type2": {"state_key"}},
+ StateFilter.from_types([("type2", "state_key")]),
+ ),
+ ),
+ ),
+ (
+ "simple_add_type_from_nothing",
+ """Test adding a type to the config when previously requesting nothing""",
+ RequiredStateChangesTestParameters(
+ previous_required_state_map={},
+ request_required_state_map={
+ "type1": {"state_key"},
+ "type2": {"state_key"},
+ },
+ state_deltas={("type2", "state_key"): "$event_id"},
+ expected_with_state_deltas=(
+ # We've added a type so we should persist the changed required state
+ # config.
+ {"type1": {"state_key"}, "type2": {"state_key"}},
+ # We should see the new types added
+ StateFilter.from_types(
+ [("type1", "state_key"), ("type2", "state_key")]
+ ),
+ ),
+ expected_without_state_deltas=(
+ {"type1": {"state_key"}, "type2": {"state_key"}},
+ StateFilter.from_types(
+ [("type1", "state_key"), ("type2", "state_key")]
+ ),
+ ),
+ ),
+ ),
+ (
+ "simple_add_state_key",
+ """Test adding a state key to the config""",
+ RequiredStateChangesTestParameters(
+ previous_required_state_map={"type": {"state_key1"}},
+ request_required_state_map={"type": {"state_key1", "state_key2"}},
+ state_deltas={("type", "state_key2"): "$event_id"},
+ expected_with_state_deltas=(
+ # We've added a key so we should persist the changed required state
+ # config.
+ {"type": {"state_key1", "state_key2"}},
+ # We should see the new state_keys added
+ StateFilter.from_types([("type", "state_key2")]),
+ ),
+ expected_without_state_deltas=(
+ {"type": {"state_key1", "state_key2"}},
+ StateFilter.from_types([("type", "state_key2")]),
+ ),
+ ),
+ ),
+ (
+ "simple_retain_previous_state_keys",
+ """Test adding a state key to the config and retaining a previously sent state_key""",
+ RequiredStateChangesTestParameters(
+ previous_required_state_map={"type": {"state_key1"}},
+ request_required_state_map={"type": {"state_key2", "state_key3"}},
+ state_deltas={("type", "state_key2"): "$event_id"},
+ expected_with_state_deltas=(
+ # We've added a key so we should persist the changed required state
+ # config.
+ #
+ # Retain `state_key1` from the `previous_required_state_map`
+ {"type": {"state_key1", "state_key2", "state_key3"}},
+ # We should see the new state_keys added
+ StateFilter.from_types(
+ [("type", "state_key2"), ("type", "state_key3")]
+ ),
+ ),
+ expected_without_state_deltas=(
+ {"type": {"state_key1", "state_key2", "state_key3"}},
+ StateFilter.from_types(
+ [("type", "state_key2"), ("type", "state_key3")]
+ ),
+ ),
+ ),
+ ),
+ (
+ "simple_remove_type",
+ """
+ Test that removing a type from the config when there is a matching state
+ delta does cause the persisted required state config to change.
+
+ Test that removing a type from the config when there are no matching state
+ deltas does *not* cause the persisted required state config to change.
+ """,
+ RequiredStateChangesTestParameters(
+ previous_required_state_map={
+ "type1": {"state_key"},
+ "type2": {"state_key"},
+ },
+ request_required_state_map={"type1": {"state_key"}},
+ state_deltas={("type2", "state_key"): "$event_id"},
+ expected_with_state_deltas=(
+ # Remove `type2` since there's been a change to that state,
+ # (persist the change to required state). That way next time,
+ # they request `type2`, we see that we haven't sent it before
+ # and send the new state. (we should still keep track that we've
+ # sent `type1` before).
+ {"type1": {"state_key"}},
+ # We don't need to request anything more if they are requesting
+ # less state now
+ StateFilter.none(),
+ ),
+ expected_without_state_deltas=(
+ # `type2` is no longer requested but since that state hasn't
+ # changed, nothing should change (we should still keep track
+ # that we've sent `type2` before).
+ None,
+ # We don't need to request anything more if they are requesting
+ # less state now
+ StateFilter.none(),
+ ),
+ ),
+ ),
+ (
+ "simple_remove_type_to_nothing",
+ """
+ Test removing a type from the config and no longer requesting any state
+ """,
+ RequiredStateChangesTestParameters(
+ previous_required_state_map={
+ "type1": {"state_key"},
+ "type2": {"state_key"},
+ },
+ request_required_state_map={},
+ state_deltas={("type2", "state_key"): "$event_id"},
+ expected_with_state_deltas=(
+ # Remove `type2` since there's been a change to that state,
+ # (persist the change to required state). That way next time,
+ # they request `type2`, we see that we haven't sent it before
+ # and send the new state. (we should still keep track that we've
+ # sent `type1` before).
+ {"type1": {"state_key"}},
+ # We don't need to request anything more if they are requesting
+ # less state now
+ StateFilter.none(),
+ ),
+ expected_without_state_deltas=(
+ # `type2` is no longer requested but since that state hasn't
+ # changed, nothing should change (we should still keep track
+ # that we've sent `type2` before).
+ None,
+ # We don't need to request anything more if they are requesting
+ # less state now
+ StateFilter.none(),
+ ),
+ ),
+ ),
+ (
+ "simple_remove_state_key",
+ """
+ Test removing a state_key from the config
+ """,
+ RequiredStateChangesTestParameters(
+ previous_required_state_map={"type": {"state_key1", "state_key2"}},
+ request_required_state_map={"type": {"state_key1"}},
+ state_deltas={("type", "state_key2"): "$event_id"},
+ expected_with_state_deltas=(
+ # Remove `(type, state_key2)` since there's been a change
+ # to that state (persist the change to required state).
+ # That way next time, they request `(type, state_key2)`, we see
+ # that we haven't sent it before and send the new state. (we
+ # should still keep track that we've sent `(type, state_key1)`
+ # before).
+ {"type": {"state_key1"}},
+ # We don't need to request anything more if they are requesting
+ # less state now
+ StateFilter.none(),
+ ),
+ expected_without_state_deltas=(
+ # `(type, state_key2)` is no longer requested but since that
+ # state hasn't changed, nothing should change (we should still
+ # keep track that we've sent `(type, state_key1)` and `(type,
+ # state_key2)` before).
+ None,
+ # We don't need to request anything more if they are requesting
+ # less state now
+ StateFilter.none(),
+ ),
+ ),
+ ),
+ (
+ "type_wildcards_add",
+ """
+ Test adding a wildcard type causes the persisted required state config
+ to change and we request everything.
+
+ If an event type wildcard has been added or removed, we don't try to do
+ anything fancy, and instead always update the effective room required
+ state config to match the request.
+ """,
+ RequiredStateChangesTestParameters(
+ previous_required_state_map={"type1": {"state_key2"}},
+ request_required_state_map={
+ "type1": {"state_key2"},
+ StateValues.WILDCARD: {"state_key"},
+ },
+ state_deltas={
+ ("other_type", "state_key"): "$event_id",
+ },
+ # We've added a wildcard, so we persist the change and request everything
+ expected_with_state_deltas=(
+ {"type1": {"state_key2"}, StateValues.WILDCARD: {"state_key"}},
+ StateFilter.all(),
+ ),
+ expected_without_state_deltas=(
+ {"type1": {"state_key2"}, StateValues.WILDCARD: {"state_key"}},
+ StateFilter.all(),
+ ),
+ ),
+ ),
+ (
+ "type_wildcards_remove",
+ """
+ Test removing a wildcard type causes the persisted required state config
+ to change and request nothing.
+
+ If an event type wildcard has been added or removed, we don't try to do
+ anything fancy, and instead always update the effective room required
+ state config to match the request.
+ """,
+ RequiredStateChangesTestParameters(
+ previous_required_state_map={
+ "type1": {"state_key2"},
+ StateValues.WILDCARD: {"state_key"},
+ },
+ request_required_state_map={"type1": {"state_key2"}},
+ state_deltas={
+ ("other_type", "state_key"): "$event_id",
+ },
+ # We've removed a type wildcard, so we persist the change but don't request anything
+ expected_with_state_deltas=(
+ {"type1": {"state_key2"}},
+ # We don't need to request anything more if they are requesting
+ # less state now
+ StateFilter.none(),
+ ),
+ expected_without_state_deltas=(
+ {"type1": {"state_key2"}},
+ # We don't need to request anything more if they are requesting
+ # less state now
+ StateFilter.none(),
+ ),
+ ),
+ ),
+ (
+ "state_key_wildcards_add",
+ """Test adding a wildcard state_key""",
+ RequiredStateChangesTestParameters(
+ previous_required_state_map={"type1": {"state_key"}},
+ request_required_state_map={
+ "type1": {"state_key"},
+ "type2": {StateValues.WILDCARD},
+ },
+ state_deltas={("type2", "state_key"): "$event_id"},
+ # We've added a wildcard state_key, so we persist the change and
+ # request all of the state for that type
+ expected_with_state_deltas=(
+ {"type1": {"state_key"}, "type2": {StateValues.WILDCARD}},
+ StateFilter.from_types([("type2", None)]),
+ ),
+ expected_without_state_deltas=(
+ {"type1": {"state_key"}, "type2": {StateValues.WILDCARD}},
+ StateFilter.from_types([("type2", None)]),
+ ),
+ ),
+ ),
+ (
+ "state_key_wildcards_remove",
+ """Test removing a wildcard state_key""",
+ RequiredStateChangesTestParameters(
+ previous_required_state_map={
+ "type1": {"state_key"},
+ "type2": {StateValues.WILDCARD},
+ },
+ request_required_state_map={"type1": {"state_key"}},
+ state_deltas={("type2", "state_key"): "$event_id"},
+ # We've removed a state_key wildcard, so we persist the change and
+ # request nothing
+ expected_with_state_deltas=(
+ {"type1": {"state_key"}},
+ # We don't need to request anything more if they are requesting
+ # less state now
+ StateFilter.none(),
+ ),
+ # We've removed a state_key wildcard but there have been no matching
+ # state changes, so nothing needs to change (we keep the previously
+ # persisted required state map as-is).
+ expected_without_state_deltas=(
+ None,
+ # We don't need to request anything more if they are requesting
+ # less state now
+ StateFilter.none(),
+ ),
+ ),
+ ),
+ (
+ "state_key_remove_some",
+ """
+ Test that removing state keys works when only some of the state keys have
+ changed
+ """,
+ RequiredStateChangesTestParameters(
+ previous_required_state_map={
+ "type1": {"state_key1", "state_key2", "state_key3"}
+ },
+ request_required_state_map={"type1": {"state_key1"}},
+ state_deltas={("type1", "state_key3"): "$event_id"},
+ expected_with_state_deltas=(
+ # We've removed some state keys from the type, but only state_key3 was
+ # changed so only that one should be removed.
+ {"type1": {"state_key1", "state_key2"}},
+ # We don't need to request anything more if they are requesting
+ # less state now
+ StateFilter.none(),
+ ),
+ expected_without_state_deltas=(
+ # No changes needed; keep the previously persisted
+ # required state map as-is
+ None,
+ # We don't need to request anything more if they are requesting
+ # less state now
+ StateFilter.none(),
+ ),
+ ),
+ ),
+ (
+ "state_key_me_add",
+ """
+ Test that adding state keys works when using "$ME"
+ """,
+ RequiredStateChangesTestParameters(
+ previous_required_state_map={},
+ request_required_state_map={"type1": {StateValues.ME}},
+ state_deltas={("type1", "@user:test"): "$event_id"},
+ expected_with_state_deltas=(
+ # We've added a type so we should persist the changed required state
+ # config.
+ {"type1": {StateValues.ME}},
+ # We should see the new state_keys added
+ StateFilter.from_types([("type1", "@user:test")]),
+ ),
+ expected_without_state_deltas=(
+ {"type1": {StateValues.ME}},
+ StateFilter.from_types([("type1", "@user:test")]),
+ ),
+ ),
+ ),
+ (
+ "state_key_me_remove",
+ """
+ Test that removing state keys works when using "$ME"
+ """,
+ RequiredStateChangesTestParameters(
+ previous_required_state_map={"type1": {StateValues.ME}},
+ request_required_state_map={},
+ state_deltas={("type1", "@user:test"): "$event_id"},
+ expected_with_state_deltas=(
+ # Remove `type1` since there's been a change to that state,
+ # (persist the change to required state). That way next time,
+ # they request `type1`, we see that we haven't sent it before
+ # and send the new state. (if we were tracking that we sent any
+ # other state, we should still keep track that).
+ {},
+ # We don't need to request anything more if they are requesting
+ # less state now
+ StateFilter.none(),
+ ),
+ expected_without_state_deltas=(
+ # `type1` is no longer requested but since that state hasn't
+ # changed, nothing should change (we should still keep track
+ # that we've sent `type1` before).
+ None,
+ # We don't need to request anything more if they are requesting
+ # less state now
+ StateFilter.none(),
+ ),
+ ),
+ ),
+ (
+ "state_key_user_id_add",
+ """
+ Test that adding state keys works when using your own user ID
+ """,
+ RequiredStateChangesTestParameters(
+ previous_required_state_map={},
+ request_required_state_map={"type1": {"@user:test"}},
+ state_deltas={("type1", "@user:test"): "$event_id"},
+ expected_with_state_deltas=(
+ # We've added a type so we should persist the changed required state
+ # config.
+ {"type1": {"@user:test"}},
+ # We should see the new state_keys added
+ StateFilter.from_types([("type1", "@user:test")]),
+ ),
+ expected_without_state_deltas=(
+ {"type1": {"@user:test"}},
+ StateFilter.from_types([("type1", "@user:test")]),
+ ),
+ ),
+ ),
+ (
+ "state_key_me_remove",
+ """
+ Test that removing state keys works when using your own user ID
+ """,
+ RequiredStateChangesTestParameters(
+ previous_required_state_map={"type1": {"@user:test"}},
+ request_required_state_map={},
+ state_deltas={("type1", "@user:test"): "$event_id"},
+ expected_with_state_deltas=(
+ # Remove `type1` since there's been a change to that state,
+ # (persist the change to required state). That way next time,
+ # they request `type1`, we see that we haven't sent it before
+ # and send the new state. (if we were tracking that we sent any
+ # other state, we should still keep track that).
+ {},
+ # We don't need to request anything more if they are requesting
+ # less state now
+ StateFilter.none(),
+ ),
+ expected_without_state_deltas=(
+ # `type1` is no longer requested but since that state hasn't
+ # changed, nothing should change (we should still keep track
+ # that we've sent `type1` before).
+ None,
+ # We don't need to request anything more if they are requesting
+ # less state now
+ StateFilter.none(),
+ ),
+ ),
+ ),
+ (
+ "state_key_lazy_add",
+ """
+ Test that adding state keys works when using "$LAZY"
+ """,
+ RequiredStateChangesTestParameters(
+ previous_required_state_map={},
+ request_required_state_map={EventTypes.Member: {StateValues.LAZY}},
+ state_deltas={(EventTypes.Member, "@user:test"): "$event_id"},
+ expected_with_state_deltas=(
+ # If a "$LAZY" has been added or removed we always update the
+ # required state to what was requested for simplicity.
+ {EventTypes.Member: {StateValues.LAZY}},
+ StateFilter.none(),
+ ),
+ expected_without_state_deltas=(
+ {EventTypes.Member: {StateValues.LAZY}},
+ StateFilter.none(),
+ ),
+ ),
+ ),
+ (
+ "state_key_lazy_remove",
+ """
+ Test that removing state keys works when using "$LAZY"
+ """,
+ RequiredStateChangesTestParameters(
+ previous_required_state_map={EventTypes.Member: {StateValues.LAZY}},
+ request_required_state_map={},
+ state_deltas={(EventTypes.Member, "@user:test"): "$event_id"},
+ expected_with_state_deltas=(
+ # If a "$LAZY" has been added or removed we always update the
+ # required state to what was requested for simplicity.
+ {},
+ # We don't need to request anything more if they are requesting
+ # less state now
+ StateFilter.none(),
+ ),
+ expected_without_state_deltas=(
+ # `EventTypes.Member` is no longer requested but since that
+ # state hasn't changed, nothing should change (we should still
+ # keep track that we've sent `EventTypes.Member` before).
+ None,
+ # We don't need to request anything more if they are requesting
+ # less state now
+ StateFilter.none(),
+ ),
+ ),
+ ),
+ (
+ "state_key_lazy_keep_previous_memberships_and_no_new_memberships",
+ """
+ This test mimics a request with lazy-loading room members enabled where
+ we have previously sent down user2 and user3's membership events and now
+ we're sending down another response without any timeline events.
+ """,
+ RequiredStateChangesTestParameters(
+ previous_required_state_map={
+ EventTypes.Member: {
+ StateValues.LAZY,
+ "@user2:test",
+ "@user3:test",
+ }
+ },
+ request_required_state_map={EventTypes.Member: {StateValues.LAZY}},
+ state_deltas={(EventTypes.Member, "@user2:test"): "$event_id"},
+ expected_with_state_deltas=(
+ # Remove "@user2:test" since that state has changed and is no
+ # longer being requested. Since something was removed,
+ # we should persist the change to required state. That way next
+ # time, they request "@user2:test", we see that we haven't sent
+ # it before and send the new state. (we should still keep track
+ # that we've sent specific `EventTypes.Member` before)
+ {
+ EventTypes.Member: {
+ StateValues.LAZY,
+ "@user3:test",
+ }
+ },
+ # We don't need to request anything more if they are requesting
+ # less state now
+ StateFilter.none(),
+ ),
+ expected_without_state_deltas=(
+ # We're not requesting any specific `EventTypes.Member` now but
+ # since that state hasn't changed, nothing should change (we
+ # should still keep track that we've sent specific
+ # `EventTypes.Member` before).
+ None,
+ # We don't need to request anything more if they are requesting
+ # less state now
+ StateFilter.none(),
+ ),
+ ),
+ ),
+ (
+ "state_key_lazy_keep_previous_memberships_with_new_memberships",
+ """
+ This test mimics a request with lazy-loading room members enabled where
+ we have previously sent down user2 and user3's membership events and now
+ we're sending down another response with a new event from user4.
+ """,
+ RequiredStateChangesTestParameters(
+ previous_required_state_map={
+ EventTypes.Member: {
+ StateValues.LAZY,
+ "@user2:test",
+ "@user3:test",
+ }
+ },
+ request_required_state_map={
+ EventTypes.Member: {StateValues.LAZY, "@user4:test"}
+ },
+ state_deltas={(EventTypes.Member, "@user2:test"): "$event_id"},
+ expected_with_state_deltas=(
+ # Since "@user4:test" was added, we should persist the changed
+ # required state config.
+ #
+ # Also remove "@user2:test" since that state has changed and is no
+ # longer being requested. Since something was removed,
+ # we also should persist the change to required state. That way next
+ # time, they request "@user2:test", we see that we haven't sent
+ # it before and send the new state. (we should still keep track
+ # that we've sent specific `EventTypes.Member` before)
+ {
+ EventTypes.Member: {
+ StateValues.LAZY,
+ "@user3:test",
+ "@user4:test",
+ }
+ },
+ # We should see the new state_keys added
+ StateFilter.from_types([(EventTypes.Member, "@user4:test")]),
+ ),
+ expected_without_state_deltas=(
+ # Since "@user4:test" was added, we should persist the changed
+ # required state config.
+ {
+ EventTypes.Member: {
+ StateValues.LAZY,
+ "@user2:test",
+ "@user3:test",
+ "@user4:test",
+ }
+ },
+ # We should see the new state_keys added
+ StateFilter.from_types([(EventTypes.Member, "@user4:test")]),
+ ),
+ ),
+ ),
+ (
+ "state_key_expand_lazy_keep_previous_memberships",
+ """
+ Test expanding the `required_state` to lazy-loading room members.
+ """,
+ RequiredStateChangesTestParameters(
+ previous_required_state_map={
+ EventTypes.Member: {"@user2:test", "@user3:test"}
+ },
+ request_required_state_map={EventTypes.Member: {StateValues.LAZY}},
+ state_deltas={(EventTypes.Member, "@user2:test"): "$event_id"},
+ expected_with_state_deltas=(
+ # Since `StateValues.LAZY` was added, we should persist the
+ # changed required state config.
+ #
+ # Also remove "@user2:test" since that state has changed and is no
+ # longer being requested. Since something was removed,
+ # we also should persist the change to required state. That way next
+ # time, they request "@user2:test", we see that we haven't sent
+ # it before and send the new state. (we should still keep track
+ # that we've sent specific `EventTypes.Member` before)
+ {
+ EventTypes.Member: {
+ StateValues.LAZY,
+ "@user3:test",
+ }
+ },
+ # We don't need to request anything more if they are requesting
+ # less state now
+ StateFilter.none(),
+ ),
+ expected_without_state_deltas=(
+ # Since `StateValues.LAZY` was added, we should persist the
+ # changed required state config.
+ {
+ EventTypes.Member: {
+ StateValues.LAZY,
+ "@user2:test",
+ "@user3:test",
+ }
+ },
+ # We don't need to request anything more if they are requesting
+ # less state now
+ StateFilter.none(),
+ ),
+ ),
+ ),
+ (
+ "state_key_retract_lazy_keep_previous_memberships_no_new_memberships",
+ """
+ Test retracting the `required_state` to no longer lazy-loading room members.
+ """,
+ RequiredStateChangesTestParameters(
+ previous_required_state_map={
+ EventTypes.Member: {
+ StateValues.LAZY,
+ "@user2:test",
+ "@user3:test",
+ }
+ },
+ request_required_state_map={},
+ state_deltas={(EventTypes.Member, "@user2:test"): "$event_id"},
+ expected_with_state_deltas=(
+ # Remove `EventTypes.Member` since there's been a change to that
+ # state, (persist the change to required state). That way next
+ # time, they request `EventTypes.Member`, we see that we haven't
+ # sent it before and send the new state. (if we were tracking
+ # that we sent any other state, we should still keep track
+ # that).
+ #
+ # This acts the same as the `simple_remove_type` test. It's
+ # possible that we could remember the specific `state_keys` that
+ # we have sent down before but this currently just acts the same
+ # as if a whole `type` was removed. Perhaps it's good that we
+ # "garbage collect" and forget what we've sent before for a
+ # given `type` when the client stops caring about a certain
+ # `type`.
+ {},
+ # We don't need to request anything more if they are requesting
+ # less state now
+ StateFilter.none(),
+ ),
+ expected_without_state_deltas=(
+ # `EventTypes.Member` is no longer requested but since that
+ # state hasn't changed, nothing should change (we should still
+ # keep track that we've sent `EventTypes.Member` before).
+ None,
+ # We don't need to request anything more if they are requesting
+ # less state now
+ StateFilter.none(),
+ ),
+ ),
+ ),
+ (
+ "state_key_retract_lazy_keep_previous_memberships_with_new_memberships",
+ """
+ Test retracting the `required_state` to no longer lazy-loading room members.
+ """,
+ RequiredStateChangesTestParameters(
+ previous_required_state_map={
+ EventTypes.Member: {
+ StateValues.LAZY,
+ "@user2:test",
+ "@user3:test",
+ }
+ },
+ request_required_state_map={EventTypes.Member: {"@user4:test"}},
+ state_deltas={(EventTypes.Member, "@user2:test"): "$event_id"},
+ expected_with_state_deltas=(
+ # Since "@user4:test" was added, we should persist the changed
+ # required state config.
+ #
+ # Also remove "@user2:test" since that state has changed and is no
+ # longer being requested. Since something was removed,
+ # we also should persist the change to required state. That way next
+ # time, they request "@user2:test", we see that we haven't sent
+ # it before and send the new state. (we should still keep track
+ # that we've sent specific `EventTypes.Member` before)
+ {
+ EventTypes.Member: {
+ "@user3:test",
+ "@user4:test",
+ }
+ },
+ # We should see the new state_keys added
+ StateFilter.from_types([(EventTypes.Member, "@user4:test")]),
+ ),
+ expected_without_state_deltas=(
+ # Since "@user4:test" was added, we should persist the changed
+ # required state config.
+ {
+ EventTypes.Member: {
+ "@user2:test",
+ "@user3:test",
+ "@user4:test",
+ }
+ },
+ # We should see the new state_keys added
+ StateFilter.from_types([(EventTypes.Member, "@user4:test")]),
+ ),
+ ),
+ ),
+ (
+ "type_wildcard_with_state_key_wildcard_to_explicit_state_keys",
+ """
+ Test switching from a wildcard ("*", "*") to explicit state keys
+ """,
+ RequiredStateChangesTestParameters(
+ previous_required_state_map={
+ StateValues.WILDCARD: {StateValues.WILDCARD}
+ },
+ request_required_state_map={
+ StateValues.WILDCARD: {"state_key1", "state_key2", "state_key3"}
+ },
+ state_deltas={("type1", "state_key1"): "$event_id"},
+ # If we were previously fetching everything ("*", "*"), always update the effective
+ # room required state config to match the request. And since we were previously
+ # already fetching everything, we don't have to fetch anything now that they've
+ # narrowed.
+ expected_with_state_deltas=(
+ {
+ StateValues.WILDCARD: {
+ "state_key1",
+ "state_key2",
+ "state_key3",
+ }
+ },
+ StateFilter.none(),
+ ),
+ expected_without_state_deltas=(
+ {
+ StateValues.WILDCARD: {
+ "state_key1",
+ "state_key2",
+ "state_key3",
+ }
+ },
+ StateFilter.none(),
+ ),
+ ),
+ ),
+ (
+ "type_wildcard_with_explicit_state_keys_to_wildcard_state_key",
+ """
+ Test switching from explicit to wildcard state keys ("*", "*")
+ """,
+ RequiredStateChangesTestParameters(
+ previous_required_state_map={
+ StateValues.WILDCARD: {"state_key1", "state_key2", "state_key3"}
+ },
+ request_required_state_map={
+ StateValues.WILDCARD: {StateValues.WILDCARD}
+ },
+ state_deltas={("type1", "state_key1"): "$event_id"},
+ # We've added a wildcard, so we persist the change and request everything
+ expected_with_state_deltas=(
+ {StateValues.WILDCARD: {StateValues.WILDCARD}},
+ StateFilter.all(),
+ ),
+ expected_without_state_deltas=(
+ {StateValues.WILDCARD: {StateValues.WILDCARD}},
+ StateFilter.all(),
+ ),
+ ),
+ ),
+ (
+ "state_key_wildcard_to_explicit_state_keys",
+ """Test switching from a wildcard to explicit state keys with a concrete type""",
+ RequiredStateChangesTestParameters(
+ previous_required_state_map={"type1": {StateValues.WILDCARD}},
+ request_required_state_map={
+ "type1": {"state_key1", "state_key2", "state_key3"}
+ },
+ state_deltas={("type1", "state_key1"): "$event_id"},
+ # If a state_key wildcard has been added or removed, we always
+ # update the effective room required state config to match the
+ # request. And since we were previously already fetching
+ # everything, we don't have to fetch anything now that they've
+ # narrowed.
+ expected_with_state_deltas=(
+ {
+ "type1": {
+ "state_key1",
+ "state_key2",
+ "state_key3",
+ }
+ },
+ StateFilter.none(),
+ ),
+ expected_without_state_deltas=(
+ {
+ "type1": {
+ "state_key1",
+ "state_key2",
+ "state_key3",
+ }
+ },
+ StateFilter.none(),
+ ),
+ ),
+ ),
+ (
+ "explicit_state_keys_to_wildcard_state_key",
+ """Test switching from a wildcard to explicit state keys with a concrete type""",
+ RequiredStateChangesTestParameters(
+ previous_required_state_map={
+ "type1": {"state_key1", "state_key2", "state_key3"}
+ },
+ request_required_state_map={"type1": {StateValues.WILDCARD}},
+ state_deltas={("type1", "state_key1"): "$event_id"},
+ # If a state_key wildcard has been added or removed, we always
+ # update the effective room required state config to match the
+ # request. And we need to request all of the state for that type
+ # because we previously only sent down a few keys.
+ expected_with_state_deltas=(
+ {"type1": {StateValues.WILDCARD, "state_key2", "state_key3"}},
+ StateFilter.from_types([("type1", None)]),
+ ),
+ expected_without_state_deltas=(
+ {
+ "type1": {
+ StateValues.WILDCARD,
+ "state_key1",
+ "state_key2",
+ "state_key3",
+ }
+ },
+ StateFilter.from_types([("type1", None)]),
+ ),
+ ),
+ ),
+ ]
+ )
+ def test_required_state_changes(
+ self,
+ _test_label: str,
+ _test_description: str,
+ test_parameters: RequiredStateChangesTestParameters,
+ ) -> None:
+ # Without `state_deltas`
+ changed_required_state_map, added_state_filter = _required_state_changes(
+ user_id="@user:test",
+ prev_required_state_map=test_parameters.previous_required_state_map,
+ request_required_state_map=test_parameters.request_required_state_map,
+ state_deltas={},
+ )
+
+ self.assertEqual(
+ changed_required_state_map,
+ test_parameters.expected_without_state_deltas[0],
+ "changed_required_state_map does not match (without state_deltas)",
+ )
+ self.assertEqual(
+ added_state_filter,
+ test_parameters.expected_without_state_deltas[1],
+ "added_state_filter does not match (without state_deltas)",
+ )
+
+ # With `state_deltas`
+ changed_required_state_map, added_state_filter = _required_state_changes(
+ user_id="@user:test",
+ prev_required_state_map=test_parameters.previous_required_state_map,
+ request_required_state_map=test_parameters.request_required_state_map,
+ state_deltas=test_parameters.state_deltas,
+ )
+
+ self.assertEqual(
+ changed_required_state_map,
+ test_parameters.expected_with_state_deltas[0],
+ "changed_required_state_map does not match (with state_deltas)",
+ )
+ self.assertEqual(
+ added_state_filter,
+ test_parameters.expected_with_state_deltas[1],
+ "added_state_filter does not match (with state_deltas)",
+ )
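Each case above hands `_required_state_changes` the previously persisted
required-state map, the newly requested map, and any state deltas, then checks
the returned pair. A minimal sketch of one row (`simple_add_type`), assuming
the imports this test module already uses:

    changed_map, added_filter = _required_state_changes(
        user_id="@user:test",
        prev_required_state_map={"type1": {"state_key"}},
        request_required_state_map={
            "type1": {"state_key"},
            "type2": {"state_key"},
        },
        state_deltas={},
    )
    # A new type was requested, so the merged map should be persisted...
    assert changed_map == {"type1": {"state_key"}, "type2": {"state_key"}}
    # ...and only the newly added pair needs fetching now.
    assert added_filter == StateFilter.from_types([("type2", "state_key")])

`None` in the first slot means "keep the previously persisted map as-is"; the
second slot is the extra state to fetch for this response. The `StateValues`
sentinels used in the table map to the request wire values named in the
docstrings: `WILDCARD` is "*", `ME` is "$ME", `LAZY` is "$LAZY".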
+
+ @parameterized.expand(
+ [
+ # Test with a normal arbitrary type (no special meaning)
+ ("arbitrary_type", "type", set()),
+ # Test with membership
+ ("membership", EventTypes.Member, set()),
+ # Test with lazy-loading room members
+ ("lazy_loading_membership", EventTypes.Member, {StateValues.LAZY}),
+ ]
+ )
+ def test_limit_retained_previous_state_keys(
+ self,
+ _test_label: str,
+ event_type: str,
+ extra_state_keys: Set[str],
+ ) -> None:
+ """
+ Test that we limit the number of state_keys that we remember but always include
+ the state_keys that we've just requested.
+ """
+ previous_required_state_map = {
+ event_type: {
+ # Prefix the state_keys we've "prev_"iously sent so they are easier to
+ # identify in our assertions.
+ f"prev_state_key{i}"
+ for i in range(MAX_NUMBER_PREVIOUS_STATE_KEYS_TO_REMEMBER - 30)
+ }
+ | extra_state_keys
+ }
+ request_required_state_map = {
+ event_type: {f"state_key{i}" for i in range(50)} | extra_state_keys
+ }
+
+ # (function under test)
+ changed_required_state_map, added_state_filter = _required_state_changes(
+ user_id="@user:test",
+ prev_required_state_map=previous_required_state_map,
+ request_required_state_map=request_required_state_map,
+ state_deltas={},
+ )
+ assert changed_required_state_map is not None
+
+ # We should only remember up to the maximum number of state keys
+ self.assertGreaterEqual(
+ len(changed_required_state_map[event_type]),
+ # Most of the time this will be `MAX_NUMBER_PREVIOUS_STATE_KEYS_TO_REMEMBER` but
+ # because we are just naively selecting enough previous state_keys to fill
+ # the limit, there might be some overlap in what's added back which means we
+ # might have slightly less than the limit.
+ #
+ # `extra_state_keys` overlaps in the previous and requested
+ # `required_state_map` so we might see this scenario.
+ MAX_NUMBER_PREVIOUS_STATE_KEYS_TO_REMEMBER - len(extra_state_keys),
+ )
+
+ # Should include all of the requested state
+ self.assertIncludes(
+ changed_required_state_map[event_type],
+ request_required_state_map[event_type],
+ )
+ # And the rest is filled with the previous state keys
+ #
+ # We can't assert the exact state_keys since we don't know the order so we just
+ # check that they all start with "prev_" and that we have the correct amount.
+ remaining_state_keys = (
+ changed_required_state_map[event_type]
+ - request_required_state_map[event_type]
+ )
+ self.assertGreater(
+ len(remaining_state_keys),
+ 0,
+ )
+ assert all(
+ state_key.startswith("prev_") for state_key in remaining_state_keys
+ ), "Remaining state_keys should be the previous state_keys"
+
+ def test_request_more_state_keys_than_remember_limit(self) -> None:
+ """
+ Test requesting more state_keys than fit in our limit to remember from previous
+ requests.
+ """
+ previous_required_state_map = {
+ "type": {
+ # Prefix the state_keys we've "prev_"iously sent so they are easier to
+ # identify in our assertions.
+ f"prev_state_key{i}"
+ for i in range(MAX_NUMBER_PREVIOUS_STATE_KEYS_TO_REMEMBER - 30)
+ }
+ }
+ request_required_state_map = {
+ "type": {
+ f"state_key{i}"
+ # Requesting more than the MAX_NUMBER_PREVIOUS_STATE_KEYS_TO_REMEMBER
+ for i in range(MAX_NUMBER_PREVIOUS_STATE_KEYS_TO_REMEMBER + 20)
+ }
+ }
+ # Ensure that we are requesting more than the limit
+ self.assertGreater(
+ len(request_required_state_map["type"]),
+ MAX_NUMBER_PREVIOUS_STATE_KEYS_TO_REMEMBER,
+ )
+
+ # (function under test)
+ changed_required_state_map, added_state_filter = _required_state_changes(
+ user_id="@user:test",
+ prev_required_state_map=previous_required_state_map,
+ request_required_state_map=request_required_state_map,
+ state_deltas={},
+ )
+ assert changed_required_state_map is not None
+
+ # Should include all of the requested state
+ self.assertIncludes(
+ changed_required_state_map["type"],
+ request_required_state_map["type"],
+ exact=True,
+ )
diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index fa55f76916..6b202dfbd5 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -17,10 +17,11 @@
# [This file includes modifications made by New Vector Limited]
#
#
+from http import HTTPStatus
from typing import Collection, ContextManager, List, Optional
from unittest.mock import AsyncMock, Mock, patch
-from parameterized import parameterized
+from parameterized import parameterized, parameterized_class
from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor
@@ -32,7 +33,13 @@ from synapse.api.room_versions import RoomVersion, RoomVersions
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.federation.federation_base import event_from_pdu_json
-from synapse.handlers.sync import SyncConfig, SyncRequestKey, SyncResult, SyncVersion
+from synapse.handlers.sync import (
+ SyncConfig,
+ SyncRequestKey,
+ SyncResult,
+ SyncVersion,
+ TimelineBatch,
+)
from synapse.rest import admin
from synapse.rest.client import knock, login, room
from synapse.server import HomeServer
@@ -58,9 +65,21 @@ def generate_request_key() -> SyncRequestKey:
return ("request_key", _request_key)
+@parameterized_class(
+ ("use_state_after",),
+ [
+ (True,),
+ (False,),
+ ],
+ class_name_func=lambda cls,
+ num,
+ params_dict: f"{cls.__name__}_{'state_after' if params_dict['use_state_after'] else 'state'}",
+)
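+# (Illustrative: with the `class_name_func` above, `parameterized_class` should
+# generate two test classes named `SyncTestCase_state_after` and
+# `SyncTestCase_state`.)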
class SyncTestCase(tests.unittest.HomeserverTestCase):
"""Tests Sync Handler."""
+ use_state_after: bool
+
servlets = [
admin.register_servlets,
knock.register_servlets,
@@ -79,7 +98,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
def test_wait_for_sync_for_user_auth_blocking(self) -> None:
user_id1 = "@user1:test"
user_id2 = "@user2:test"
- sync_config = generate_sync_config(user_id1)
+ sync_config = generate_sync_config(
+ user_id1, use_state_after=self.use_state_after
+ )
requester = create_requester(user_id1)
self.reactor.advance(100) # So we get not 0 time
@@ -112,7 +133,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
self.auth_blocking._hs_disabled = False
- sync_config = generate_sync_config(user_id2)
+ sync_config = generate_sync_config(
+ user_id2, use_state_after=self.use_state_after
+ )
requester = create_requester(user_id2)
e = self.get_failure(
@@ -141,7 +164,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
initial_result = self.get_success(
self.sync_handler.wait_for_sync_for_user(
requester,
- sync_config=generate_sync_config(user, device_id="dev"),
+ sync_config=generate_sync_config(
+ user, device_id="dev", use_state_after=self.use_state_after
+ ),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
)
@@ -175,7 +200,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
result = self.get_success(
self.sync_handler.wait_for_sync_for_user(
requester,
- sync_config=generate_sync_config(user),
+ sync_config=generate_sync_config(
+ user, use_state_after=self.use_state_after
+ ),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
)
@@ -188,7 +215,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
result = self.get_success(
self.sync_handler.wait_for_sync_for_user(
requester,
- sync_config=generate_sync_config(user, device_id="dev"),
+ sync_config=generate_sync_config(
+ user, device_id="dev", use_state_after=self.use_state_after
+ ),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
since_token=initial_result.next_batch,
@@ -220,7 +249,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
result = self.get_success(
self.sync_handler.wait_for_sync_for_user(
requester,
- sync_config=generate_sync_config(user),
+ sync_config=generate_sync_config(
+ user, use_state_after=self.use_state_after
+ ),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
)
@@ -233,7 +264,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
result = self.get_success(
self.sync_handler.wait_for_sync_for_user(
requester,
- sync_config=generate_sync_config(user, device_id="dev"),
+ sync_config=generate_sync_config(
+ user, device_id="dev", use_state_after=self.use_state_after
+ ),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
since_token=initial_result.next_batch,
@@ -276,7 +309,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
alice_sync_result: SyncResult = self.get_success(
self.sync_handler.wait_for_sync_for_user(
create_requester(owner),
- generate_sync_config(owner),
+ generate_sync_config(owner, use_state_after=self.use_state_after),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
)
@@ -296,7 +329,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
# Eve syncs.
eve_requester = create_requester(eve)
- eve_sync_config = generate_sync_config(eve)
+ eve_sync_config = generate_sync_config(
+ eve, use_state_after=self.use_state_after
+ )
eve_sync_after_ban: SyncResult = self.get_success(
self.sync_handler.wait_for_sync_for_user(
eve_requester,
@@ -313,7 +348,15 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
# the prev_events used when creating the join event, such that the ban does not
# precede the join.
with self._patch_get_latest_events([last_room_creation_event_id]):
- self.helper.join(room_id, eve, tok=eve_token)
+ self.helper.join(
+ room_id,
+ eve,
+ tok=eve_token,
+ # Previously this join succeeded, but we now expect it to fail at
+ # this point. The rest of the test covers the behaviour from when it
+ # used to succeed.
+ expect_code=HTTPStatus.FORBIDDEN,
+ )
# Eve makes a second, incremental sync.
eve_incremental_sync_after_join: SyncResult = self.get_success(
@@ -367,7 +410,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
initial_sync_result = self.get_success(
self.sync_handler.wait_for_sync_for_user(
alice_requester,
- generate_sync_config(alice),
+ generate_sync_config(alice, use_state_after=self.use_state_after),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
)
@@ -396,6 +439,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
filter_collection=FilterCollection(
self.hs, {"room": {"timeline": {"limit": 2}}}
),
+ use_state_after=self.use_state_after,
),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
@@ -442,7 +486,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
initial_sync_result = self.get_success(
self.sync_handler.wait_for_sync_for_user(
alice_requester,
- generate_sync_config(alice),
+ generate_sync_config(alice, use_state_after=self.use_state_after),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
)
@@ -481,6 +525,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
}
},
),
+ use_state_after=self.use_state_after,
),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
@@ -518,6 +563,8 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
... and a filter that means we only return 1 event, represented by the dashed
horizontal lines: `S2` must be included in the `state` section on the second sync.
+
+ When `use_state_after` is enabled, we expect to see `S2` in the first sync.
"""
alice = self.register_user("alice", "password")
alice_tok = self.login(alice, "password")
@@ -528,7 +575,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
initial_sync_result = self.get_success(
self.sync_handler.wait_for_sync_for_user(
alice_requester,
- generate_sync_config(alice),
+ generate_sync_config(alice, use_state_after=self.use_state_after),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
)
@@ -554,6 +601,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
filter_collection=FilterCollection(
self.hs, {"room": {"timeline": {"limit": 1}}}
),
+ use_state_after=self.use_state_after,
),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
@@ -567,10 +615,18 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
[e.event_id for e in room_sync.timeline.events],
[e3_event],
)
- self.assertEqual(
- [e.event_id for e in room_sync.state.values()],
- [],
- )
+
+ if self.use_state_after:
+ # When using `state_after` we get told about s2 immediately
+ self.assertEqual(
+ [e.event_id for e in room_sync.state.values()],
+ [s2_event],
+ )
+ else:
+ self.assertEqual(
+ [e.event_id for e in room_sync.state.values()],
+ [],
+ )
# Now send another event that points to S2, but not E3.
with self._patch_get_latest_events([s2_event]):
@@ -585,6 +641,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
filter_collection=FilterCollection(
self.hs, {"room": {"timeline": {"limit": 1}}}
),
+ use_state_after=self.use_state_after,
),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
@@ -598,10 +655,19 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
[e.event_id for e in room_sync.timeline.events],
[e4_event],
)
- self.assertEqual(
- [e.event_id for e in room_sync.state.values()],
- [s2_event],
- )
+
+ if self.use_state_after:
+ # When using `state_after` we were already told about s2, so it is
+ # not included again.
+ self.assertEqual(
+ [e.event_id for e in room_sync.state.values()],
+ [],
+ )
+ else:
+ self.assertEqual(
+ [e.event_id for e in room_sync.state.values()],
+ [s2_event],
+ )
def test_state_includes_changes_on_ungappy_syncs(self) -> None:
"""Test `state` where the sync is not gappy.
@@ -638,6 +704,8 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
This is the last chance for us to tell the client about S2, so it *must* be
included in the response.
+
+ When `use_state_after` is enabled, we expect to see `S2` in the first sync.
"""
alice = self.register_user("alice", "password")
alice_tok = self.login(alice, "password")
@@ -648,7 +716,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
initial_sync_result = self.get_success(
self.sync_handler.wait_for_sync_for_user(
alice_requester,
- generate_sync_config(alice),
+ generate_sync_config(alice, use_state_after=self.use_state_after),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
)
@@ -673,6 +741,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
filter_collection=FilterCollection(
self.hs, {"room": {"timeline": {"limit": 1}}}
),
+ use_state_after=self.use_state_after,
),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
@@ -684,7 +753,11 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
[e.event_id for e in room_sync.timeline.events],
[e3_event],
)
- self.assertNotIn(s2_event, [e.event_id for e in room_sync.state.values()])
+ if self.use_state_after:
+ # When using `state_after` we get told about s2 immediately
+ self.assertIn(s2_event, [e.event_id for e in room_sync.state.values()])
+ else:
+ self.assertNotIn(s2_event, [e.event_id for e in room_sync.state.values()])
# More events, E4 and E5
with self._patch_get_latest_events([e3_event]):
@@ -695,7 +768,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
incremental_sync = self.get_success(
self.sync_handler.wait_for_sync_for_user(
alice_requester,
- generate_sync_config(alice),
+ generate_sync_config(alice, use_state_after=self.use_state_after),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
since_token=initial_sync_result.next_batch,
@@ -710,10 +783,19 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
[e.event_id for e in room_sync.timeline.events],
[e4_event, e5_event],
)
- self.assertEqual(
- [e.event_id for e in room_sync.state.values()],
- [s2_event],
- )
+
+ if self.use_state_after:
+ # When using `state_after` we were already told about s2, so it is
+ # not included again.
+ self.assertEqual(
+ [e.event_id for e in room_sync.state.values()],
+ [],
+ )
+ else:
+ self.assertEqual(
+ [e.event_id for e in room_sync.state.values()],
+ [s2_event],
+ )
@parameterized.expand(
[
@@ -721,7 +803,8 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
(True, False),
(False, True),
(True, True),
- ]
+ ],
+ name_func=lambda func, num, p: f"{func.__name__}_{p.args[0]}_{p.args[1]}",
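+ # e.g. generates test names such as
+ # `test_archived_rooms_do_not_include_state_after_leave_True_False`.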
)
def test_archived_rooms_do_not_include_state_after_leave(
self, initial_sync: bool, empty_timeline: bool
@@ -749,7 +832,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
initial_sync_result = self.get_success(
self.sync_handler.wait_for_sync_for_user(
bob_requester,
- generate_sync_config(bob),
+ generate_sync_config(bob, use_state_after=self.use_state_after),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
)
@@ -780,7 +863,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
self.sync_handler.wait_for_sync_for_user(
bob_requester,
generate_sync_config(
- bob, filter_collection=FilterCollection(self.hs, filter_dict)
+ bob,
+ filter_collection=FilterCollection(self.hs, filter_dict),
+ use_state_after=self.use_state_after,
),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
@@ -791,7 +876,15 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
if empty_timeline:
# The timeline should be empty
self.assertEqual(sync_room_result.timeline.events, [])
+ else:
+ # The last three events in the timeline should be those leading up to the
+ # leave
+ self.assertEqual(
+ [e.event_id for e in sync_room_result.timeline.events[-3:]],
+ [before_message_event, before_state_event, leave_event],
+ )
+ if empty_timeline or self.use_state_after:
# And the state should include the leave event...
self.assertEqual(
sync_room_result.state[("m.room.member", bob)].event_id, leave_event
@@ -801,12 +894,6 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
sync_room_result.state[("test_state", "")].event_id, before_state_event
)
else:
- # The last three events in the timeline should be those leading up to the
- # leave
- self.assertEqual(
- [e.event_id for e in sync_room_result.timeline.events[-3:]],
- [before_message_event, before_state_event, leave_event],
- )
# ... And the state should be empty
self.assertEqual(sync_room_result.state, {})
@@ -843,7 +930,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
) -> List[EventBase]:
return list(pdus)
- self.client._check_sigs_and_hash_for_pulled_events_and_fetch = _check_sigs_and_hash_for_pulled_events_and_fetch # type: ignore[assignment]
+ self.client._check_sigs_and_hash_for_pulled_events_and_fetch = ( # type: ignore[method-assign]
+ _check_sigs_and_hash_for_pulled_events_and_fetch # type: ignore[assignment]
+ )
prev_events = self.get_success(self.store.get_prev_events_for_room(room_id))
@@ -877,7 +966,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
sync_result: SyncResult = self.get_success(
self.sync_handler.wait_for_sync_for_user(
create_requester(user),
- generate_sync_config(user),
+ generate_sync_config(user, use_state_after=self.use_state_after),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
)
@@ -926,7 +1015,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
private_sync_result: SyncResult = self.get_success(
self.sync_handler.wait_for_sync_for_user(
create_requester(user2),
- generate_sync_config(user2),
+ generate_sync_config(user2, use_state_after=self.use_state_after),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
)
@@ -952,7 +1041,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
sync_result: SyncResult = self.get_success(
self.sync_handler.wait_for_sync_for_user(
create_requester(user),
- generate_sync_config(user),
+ generate_sync_config(user, use_state_after=self.use_state_after),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
)
@@ -989,7 +1078,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
sync_d = defer.ensureDeferred(
self.sync_handler.wait_for_sync_for_user(
create_requester(user),
- generate_sync_config(user),
+ generate_sync_config(user, use_state_after=self.use_state_after),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
since_token=since_token,
@@ -1044,7 +1133,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
sync_d = defer.ensureDeferred(
self.sync_handler.wait_for_sync_for_user(
create_requester(user),
- generate_sync_config(user),
+ generate_sync_config(user, use_state_after=self.use_state_after),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
since_token=since_token,
@@ -1060,6 +1149,7 @@ def generate_sync_config(
user_id: str,
device_id: Optional[str] = "device_id",
filter_collection: Optional[FilterCollection] = None,
+ use_state_after: bool = False,
) -> SyncConfig:
"""Generate a sync config (with a unique request key).
@@ -1067,7 +1157,8 @@ def generate_sync_config(
user_id: user who is syncing.
device_id: device that is syncing. Defaults to "device_id".
filter_collection: filter to apply. Defaults to the default filter (ie,
- return everything, with a default limit)
+ return everything, with a default limit)
+ use_state_after: whether the `use_state_after` flag was set.
"""
if filter_collection is None:
filter_collection = Filtering(Mock()).DEFAULT_FILTER_COLLECTION
@@ -1077,4 +1168,138 @@ def generate_sync_config(
filter_collection=filter_collection,
is_guest=False,
device_id=device_id,
+ use_state_after=use_state_after,
)
+
+
+class SyncStateAfterTestCase(tests.unittest.HomeserverTestCase):
+ """Tests Sync Handler state behavior when using `use_state_after."""
+
+ servlets = [
+ admin.register_servlets,
+ knock.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.sync_handler = self.hs.get_sync_handler()
+ self.store = self.hs.get_datastores().main
+
+ # AuthBlocking reads from the hs' config on initialization. We need to
+ # modify its config instead of the hs'
+ self.auth_blocking = self.hs.get_auth_blocking()
+
+ def test_initial_sync_multiple_deltas(self) -> None:
+ """Test that if multiple state deltas have happened during processing of
+ a full state sync we return the correct state"""
+
+ user = self.register_user("user", "password")
+ tok = self.login("user", "password")
+
+ # Create a room as the user and set some custom state.
+ joined_room = self.helper.create_room_as(user, tok=tok)
+
+ first_state = self.helper.send_state(
+ joined_room, event_type="m.test_event", body={"num": 1}, tok=tok
+ )
+
+ # Take a snapshot of the stream token, to simulate doing an initial sync
+ # at this point.
+ end_stream_token = self.hs.get_event_sources().get_current_token()
+
+ # Send some state *after* the stream token
+ self.helper.send_state(
+ joined_room, event_type="m.test_event", body={"num": 2}, tok=tok
+ )
+
+ # Calculating the full state will return the first state, and not the
+ # second.
+ state = self.get_success(
+ self.sync_handler._compute_state_delta_for_full_sync(
+ room_id=joined_room,
+ sync_config=generate_sync_config(user, use_state_after=True),
+ batch=TimelineBatch(
+ prev_batch=end_stream_token, events=[], limited=True
+ ),
+ end_token=end_stream_token,
+ members_to_fetch=None,
+ timeline_state={},
+ joined=True,
+ )
+ )
+ self.assertEqual(state[("m.test_event", "")], first_state["event_id"])
+
+ def test_incremental_sync_multiple_deltas(self) -> None:
+ """Test that if multiple state deltas have happened since an incremental
+ state sync we return the correct state"""
+
+ user = self.register_user("user", "password")
+ tok = self.login("user", "password")
+
+ # Create a room as the user and set some custom state.
+ joined_room = self.helper.create_room_as(user, tok=tok)
+
+ # Take a snapshot of the stream token, to simulate doing an incremental sync
+ # from this point.
+ since_token = self.hs.get_event_sources().get_current_token()
+
+ self.helper.send_state(
+ joined_room, event_type="m.test_event", body={"num": 1}, tok=tok
+ )
+
+ # Send some state *after* the stream token
+ second_state = self.helper.send_state(
+ joined_room, event_type="m.test_event", body={"num": 2}, tok=tok
+ )
+
+ end_stream_token = self.hs.get_event_sources().get_current_token()
+
+ # Calculating the incremental state will return the second state, and not the
+ # first.
+ state = self.get_success(
+ self.sync_handler._compute_state_delta_for_incremental_sync(
+ room_id=joined_room,
+ sync_config=generate_sync_config(user, use_state_after=True),
+ batch=TimelineBatch(
+ prev_batch=end_stream_token, events=[], limited=True
+ ),
+ since_token=since_token,
+ end_token=end_stream_token,
+ members_to_fetch=None,
+ timeline_state={},
+ )
+ )
+ self.assertEqual(state[("m.test_event", "")], second_state["event_id"])
+
+ def test_incremental_sync_lazy_loaded_no_timeline(self) -> None:
+ """Test that lazy-loading with an empty timeline doesn't return the full
+ state.
+
+ There was a bug where an empty state filter would cause the DB to return
+ the full state, rather than an empty set.
+ """
+ user = self.register_user("user", "password")
+ tok = self.login("user", "password")
+
+ # Create a room as the user and set some custom state.
+ joined_room = self.helper.create_room_as(user, tok=tok)
+
+ since_token = self.hs.get_event_sources().get_current_token()
+ end_stream_token = self.hs.get_event_sources().get_current_token()
+
+ state = self.get_success(
+ self.sync_handler._compute_state_delta_for_incremental_sync(
+ room_id=joined_room,
+ sync_config=generate_sync_config(user, use_state_after=True),
+ batch=TimelineBatch(
+ prev_batch=end_stream_token, events=[], limited=True
+ ),
+ since_token=since_token,
+ end_token=end_stream_token,
+ members_to_fetch=set(),
+ timeline_state={},
+ )
+ )
+
+ self.assertEqual(state, {})
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 878d9683b6..b12ffc3665 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -796,6 +796,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
s = self.get_success(self.handler.search_users(u1, "user2", 10))
self.assertEqual(len(s["results"]), 1)
+ # Keep the old spam checker tests (without `requester_id`) for backwards compatibility.
async def allow_all(user_profile: UserProfile) -> bool:
# Allow all users.
return False
@@ -809,6 +810,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
s = self.get_success(self.handler.search_users(u1, "user2", 10))
self.assertEqual(len(s["results"]), 1)
+ # Keep the old spam checker tests (without `requester_id`) for backwards compatibility.
# Configure a spam checker that filters all users.
async def block_all(user_profile: UserProfile) -> bool:
# All users are spammy.
@@ -820,6 +822,40 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
s = self.get_success(self.handler.search_users(u1, "user2", 10))
self.assertEqual(len(s["results"]), 0)
+ async def allow_all_expects_requester_id(
+ user_profile: UserProfile, requester_id: str
+ ) -> bool:
+ self.assertEqual(requester_id, u1)
+ # Allow all users.
+ return False
+
+ # Configure a spam checker that does not filter any users.
+ spam_checker = self.hs.get_module_api_callbacks().spam_checker
+ spam_checker._check_username_for_spam_callbacks = [
+ allow_all_expects_requester_id
+ ]
+
+ # The results do not change:
+ # We get one search result when searching for user2 by user1.
+ s = self.get_success(self.handler.search_users(u1, "user2", 10))
+ self.assertEqual(len(s["results"]), 1)
+
+ # Configure a spam checker that filters all users.
+ async def block_all_expects_requester_id(
+ user_profile: UserProfile, requester_id: str
+ ) -> bool:
+ self.assertEqual(requester_id, u1)
+ # All users are spammy.
+ return True
+
+ spam_checker._check_username_for_spam_callbacks = [
+ block_all_expects_requester_id
+ ]
+
+ # User1 now gets no search results for any of the other users.
+ s = self.get_success(self.handler.search_users(u1, "user2", 10))
+ self.assertEqual(len(s["results"]), 0)
+
@override_config(
{
"spam_checker": {
@@ -956,6 +992,67 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
[self.assertIn(user, local_users) for user in received_user_id_ordering[:3]]
[self.assertIn(user, remote_users) for user in received_user_id_ordering[3:]]
+ @override_config(
+ {
+ "user_directory": {
+ "enabled": True,
+ "search_all_users": True,
+ "exclude_remote_users": True,
+ }
+ }
+ )
+ def test_exclude_remote_users(self) -> None:
+ """Tests that only local users are returned when
+ user_directory.exclude_remote_users is True.
+ """
+
+ # Create a room and few users to test the directory with
+ searching_user = self.register_user("searcher", "password")
+ searching_user_tok = self.login("searcher", "password")
+
+ room_id = self.helper.create_room_as(
+ searching_user,
+ room_version=RoomVersions.V1.identifier,
+ tok=searching_user_tok,
+ )
+
+ # Create a few local users and join them to the room
+ local_user_1 = self.register_user("user_xxxxx", "password")
+ local_user_2 = self.register_user("user_bbbbb", "password")
+ local_user_3 = self.register_user("user_zzzzz", "password")
+
+ self._add_user_to_room(room_id, RoomVersions.V1, local_user_1)
+ self._add_user_to_room(room_id, RoomVersions.V1, local_user_2)
+ self._add_user_to_room(room_id, RoomVersions.V1, local_user_3)
+
+ # Create a few "remote" users and join them to the room
+ remote_user_1 = "@user_aaaaa:remote_server"
+ remote_user_2 = "@user_yyyyy:remote_server"
+ remote_user_3 = "@user_ccccc:remote_server"
+ self._add_user_to_room(room_id, RoomVersions.V1, remote_user_1)
+ self._add_user_to_room(room_id, RoomVersions.V1, remote_user_2)
+ self._add_user_to_room(room_id, RoomVersions.V1, remote_user_3)
+
+ local_users = [local_user_1, local_user_2, local_user_3]
+ remote_users = [remote_user_1, remote_user_2, remote_user_3]
+
+ # The local searching user searches for the term "user", which other users have
+ # in their user id
+ results = self.get_success(
+ self.handler.search_users(searching_user, "user", 20)
+ )["results"]
+ received_user_ids = [result["user_id"] for result in results]
+
+ for user in local_users:
+ self.assertIn(
+ user, received_user_ids, f"Local user {user} not found in results"
+ )
+
+ for user in remote_users:
+ self.assertNotIn(
+ user, received_user_ids, f"Remote user {user} should not be in results"
+ )
+
def _add_user_to_room(
self,
room_id: str,
@@ -1081,10 +1178,10 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
for use_numeric in [False, True]:
if use_numeric:
prefix1 = f"{i}"
- prefix2 = f"{i+1}"
+ prefix2 = f"{i + 1}"
else:
prefix1 = f"a{i}"
- prefix2 = f"a{i+1}"
+ prefix2 = f"a{i + 1}"
local_user_1 = self.register_user(f"user{char}{prefix1}", "password")
local_user_2 = self.register_user(f"user{char}{prefix2}", "password")
diff --git a/tests/handlers/test_worker_lock.py b/tests/handlers/test_worker_lock.py
index 6e9a15c8ee..0691d3f99c 100644
--- a/tests/handlers/test_worker_lock.py
+++ b/tests/handlers/test_worker_lock.py
@@ -19,6 +19,9 @@
#
#
+import logging
+import platform
+
from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor
@@ -29,6 +32,8 @@ from tests import unittest
from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.utils import test_timeout
+logger = logging.getLogger(__name__)
+
class WorkerLockTestCase(unittest.HomeserverTestCase):
def prepare(
@@ -53,12 +58,27 @@ class WorkerLockTestCase(unittest.HomeserverTestCase):
def test_lock_contention(self) -> None:
"""Test lock contention when a lot of locks wait on a single worker"""
-
+ nb_locks_to_test = 500
+ current_machine = platform.machine().lower()
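+ # `platform.machine()` typically returns strings like "x86_64", "aarch64" or
+ # "riscv64", so a lowercase prefix match catches the RISC-V variants.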
+ if current_machine.startswith("riscv"):
+ # RISC-V runners are slower, so increase the timeout.
+ timeout_seconds = 15
+ # Log the adjustment so it is visible in CI logs.
+ logger.info(
+ f"Detected RISC-V architecture ({current_machine}). "
+ f"Adjusting test_lock_contention: timeout={timeout_seconds}s"
+ )
+ else:
+ # Settings for other architectures
+ timeout_seconds = 5
# It takes around 0.5s on a 5+ years old laptop
- with test_timeout(5):
- nb_locks = 500
- d = self._take_locks(nb_locks)
- self.assertEqual(self.get_success(d), nb_locks)
+ with test_timeout(timeout_seconds):
+ d = self._take_locks(nb_locks_to_test)
+ self.assertEqual(self.get_success(d), nb_locks_to_test)
async def _take_locks(self, nb_locks: int) -> int:
locks = [
diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py
index 8e8621e348..ffcbf4b3ca 100644
--- a/tests/http/federation/test_srv_resolver.py
+++ b/tests/http/federation/test_srv_resolver.py
@@ -93,9 +93,7 @@ class SrvResolverTestCase(unittest.TestCase):
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
servers: List[Server]
- servers = yield defer.ensureDeferred(
- resolver.resolve_service(service_name)
- ) # type: ignore[assignment]
+ servers = yield defer.ensureDeferred(resolver.resolve_service(service_name)) # type: ignore[assignment]
dns_client_mock.lookupService.assert_called_once_with(service_name)
@@ -122,9 +120,7 @@ class SrvResolverTestCase(unittest.TestCase):
)
servers: List[Server]
- servers = yield defer.ensureDeferred(
- resolver.resolve_service(service_name)
- ) # type: ignore[assignment]
+ servers = yield defer.ensureDeferred(resolver.resolve_service(service_name)) # type: ignore[assignment]
self.assertFalse(dns_client_mock.lookupService.called)
@@ -157,9 +153,7 @@ class SrvResolverTestCase(unittest.TestCase):
resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
servers: List[Server]
- servers = yield defer.ensureDeferred(
- resolver.resolve_service(service_name)
- ) # type: ignore[assignment]
+ servers = yield defer.ensureDeferred(resolver.resolve_service(service_name)) # type: ignore[assignment]
self.assertEqual(len(servers), 0)
self.assertEqual(len(cache), 0)
diff --git a/tests/http/server/_base.py b/tests/http/server/_base.py
index 731b0c4e59..dff5a5d262 100644
--- a/tests/http/server/_base.py
+++ b/tests/http/server/_base.py
@@ -27,6 +27,7 @@ from typing import (
Callable,
ContextManager,
Dict,
+ Generator,
List,
Optional,
Set,
@@ -49,7 +50,10 @@ from synapse.http.server import (
respond_with_json,
)
from synapse.http.site import SynapseRequest
-from synapse.logging.context import LoggingContext, make_deferred_yieldable
+from synapse.logging.context import (
+ LoggingContext,
+ make_deferred_yieldable,
+)
from synapse.types import JsonDict
from tests.server import FakeChannel, make_request
@@ -199,7 +203,7 @@ def make_request_with_cancellation_test(
#
# We would like to trigger a cancellation at the first `await`, re-run the
# request and cancel at the second `await`, and so on. By patching
- # `Deferred.__next__`, we can intercept `await`s, track which ones we have or
+ # `Deferred.__await__`, we can intercept `await`s, track which ones we have or
# have not seen, and force them to block when they wouldn't have.
# The set of previously seen `await`s.
@@ -211,7 +215,7 @@ def make_request_with_cancellation_test(
)
for request_number in itertools.count(1):
- deferred_patch = Deferred__next__Patch(seen_awaits, request_number)
+ deferred_patch = Deferred__await__Patch(seen_awaits, request_number)
try:
with mock.patch(
@@ -250,6 +254,8 @@ def make_request_with_cancellation_test(
)
if respond_mock.called:
+ _log_for_request(request_number, "--- response finished ---")
+
# The request ran to completion and we are done with testing it.
# `respond_with_json` writes the response asynchronously, so we
@@ -311,8 +317,8 @@ def make_request_with_cancellation_test(
assert False, "unreachable" # noqa: B011
-class Deferred__next__Patch:
- """A `Deferred.__next__` patch that will intercept `await`s and force them
+class Deferred__await__Patch:
+ """A `Deferred.__await__` patch that will intercept `await`s and force them
to block once it sees a new `await`.
When done with the patch, `unblock_awaits()` must be called to clean up after any
@@ -322,7 +328,7 @@ class Deferred__next__Patch:
Usage:
seen_awaits = set()
- deferred_patch = Deferred__next__Patch(seen_awaits, 1)
+ deferred_patch = Deferred__await__Patch(seen_awaits, 1)
try:
with deferred_patch.patch():
# do things
@@ -335,14 +341,14 @@ class Deferred__next__Patch:
"""
Args:
seen_awaits: The set of stack traces of `await`s that have been previously
- seen. When the `Deferred.__next__` patch sees a new `await`, it will add
+ seen. When the `Deferred.__await__` patch sees a new `await`, it will add
it to the set.
request_number: The request number to log against.
"""
self._request_number = request_number
self._seen_awaits = seen_awaits
- self._original_Deferred___next__ = Deferred.__next__ # type: ignore[misc,unused-ignore]
+ self._original_Deferred__await__ = Deferred.__await__ # type: ignore[misc,unused-ignore]
# The number of `await`s on `Deferred`s we have seen so far.
self.awaits_seen = 0
@@ -350,8 +356,13 @@ class Deferred__next__Patch:
# Whether we have seen a new `await` not in `seen_awaits`.
self.new_await_seen = False
+ # Whether to block new await points we see. This gets set to False once
+ # we have cancelled the request to allow things to run after
+ # cancellation.
+ self._block_new_awaits = True
+
# To force `await`s on resolved `Deferred`s to block, we make up a new
- # unresolved `Deferred` and return it out of `Deferred.__next__` /
+ # unresolved `Deferred` and return it out of `Deferred.__await__` /
# `coroutine.send()`. We have to resolve it later, in case the `await`ing
# coroutine is part of some shared processing, such as `@cached`.
self._to_unblock: Dict[Deferred, Union[object, Failure]] = {}
@@ -360,15 +371,15 @@ class Deferred__next__Patch:
self._previous_stack: List[inspect.FrameInfo] = []
def patch(self) -> ContextManager[Mock]:
- """Returns a context manager which patches `Deferred.__next__`."""
+ """Returns a context manager which patches `Deferred.__await__`."""
- def Deferred___next__(
- deferred: "Deferred[T]", value: object = None
- ) -> "Deferred[T]":
- """Intercepts `await`s on `Deferred`s and rigs them to block once we have
- seen enough of them.
+ def Deferred___await__(
+ deferred: "Deferred[T]",
+ ) -> Generator["Deferred[T]", None, T]:
+ """Intercepts calls to `__await__`, which returns a generator
+ yielding deferreds that we await on.
- `Deferred.__next__` will normally:
+ The generator for `__await__` will normally:
* return `self` if the `Deferred` is unresolved, in which case
`coroutine.send()` will return the `Deferred`, and
`_defer.inlineCallbacks` will stop running the coroutine until the
@@ -376,9 +387,43 @@ class Deferred__next__Patch:
* raise a `StopIteration(result)`, containing the result of the `await`.
* raise another exception, which will come out of the `await`.
"""
+
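+ # Roughly, `await deferred` is driven like this (sketch, not the
+ # patched code itself):
+ #
+ # gen = deferred.__await__()
+ # while True:
+ # yielded = gen.send(None) # StopIteration(result) ends the await
+ #
+ # Wrapping that generator therefore lets us observe (and block) every
+ # await point.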
+ # Get the original generator.
+ gen = self._original_Deferred__await__(deferred)
+
+ # Run the generator, handling each iteration to see if we need to
+ # block.
+ try:
+ while True:
+ # We've hit a new await point (or the deferred has
+ # completed), handle it.
+ handle_next_iteration(deferred)
+
+ # Continue on.
+ yield gen.send(None)
+ except StopIteration as e:
+ # We need to convert `StopIteration` into a normal return.
+ return e.value
+
+ def handle_next_iteration(
+ deferred: "Deferred[T]",
+ ) -> None:
+ """Intercepts `await`s on `Deferred`s and rigs them to block once we have
+ seen enough of them.
+
+ Args:
+ deferred: The deferred that we've captured and are intercepting
+ `await` calls within.
+ """
+ if not self._block_new_awaits:
+ # We're no longer blocking awaits points
+ return
+
self.awaits_seen += 1
- stack = _get_stack(skip_frames=1)
+ stack = _get_stack(
+ skip_frames=2 # Ignore this function and `Deferred___await__` in stack trace
+ )
stack_hash = _hash_stack(stack)
if stack_hash not in self._seen_awaits:
@@ -389,20 +434,29 @@ class Deferred__next__Patch:
if not self.new_await_seen:
# This `await` isn't interesting. Let it proceed normally.
+ _log_await_stack(
+ stack,
+ self._previous_stack,
+ self._request_number,
+ "already seen",
+ )
+
# Don't log the stack. It's been seen before in a previous run.
self._previous_stack = stack
- return self._original_Deferred___next__(deferred, value)
+ return
# We want to block at the current `await`.
if deferred.called and not deferred.paused:
- # This `Deferred` already has a result.
- # We return a new, unresolved, `Deferred` for `_inlineCallbacks` to wait
- # on. This blocks the coroutine that did this `await`.
+ # This `Deferred` already has a result. We chain a new,
+ # unresolved, `Deferred` to the end of this Deferred that it
+ # will wait on. This blocks the coroutine that did this `await`.
# We queue it up for unblocking later.
new_deferred: "Deferred[T]" = Deferred()
self._to_unblock[new_deferred] = deferred.result
+ deferred.addBoth(lambda _: make_deferred_yieldable(new_deferred))
+
_log_await_stack(
stack,
self._previous_stack,
@@ -411,7 +465,9 @@ class Deferred__next__Patch:
)
self._previous_stack = stack
- return make_deferred_yieldable(new_deferred)
+ # Continue iterating on the deferred now that we've blocked it
+ # again.
+ return
# This `Deferred` does not have a result yet.
# The `await` will block normally, so we don't have to do anything.
@@ -423,9 +479,9 @@ class Deferred__next__Patch:
)
self._previous_stack = stack
- return self._original_Deferred___next__(deferred, value)
+ return
- return mock.patch.object(Deferred, "__next__", new=Deferred___next__)
+ return mock.patch.object(Deferred, "__await__", new=Deferred___await__)
def unblock_awaits(self) -> None:
"""Unblocks any shared processing that we forced to block.
@@ -433,6 +489,9 @@ class Deferred__next__Patch:
Must be called when done, otherwise processing shared between multiple requests,
such as database queries started by `@cached`, will become permanently stuck.
"""
+ # Also disable blocking at future await points
+ self._block_new_awaits = False
+
to_unblock = self._to_unblock
self._to_unblock = {}
for deferred, result in to_unblock.items():
diff --git a/tests/http/test_client.py b/tests/http/test_client.py
index 721917f957..ac6470ebbd 100644
--- a/tests/http/test_client.py
+++ b/tests/http/test_client.py
@@ -49,8 +49,11 @@ from tests.unittest import TestCase
class ReadMultipartResponseTests(TestCase):
- data1 = b"\r\n\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: application/json\r\n\r\n{}\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: text/plain\r\nContent-Disposition: inline; filename=test_upload\r\n\r\nfile_"
- data2 = b"to_stream\r\n--6067d4698f8d40a0a794ea7d7379d53a--\r\n\r\n"
+ multipart_response_data1 = b"\r\n\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: application/json\r\n\r\n{}\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: text/plain\r\nContent-Disposition: inline; filename=test_upload\r\n\r\nfile_"
+ multipart_response_data2 = (
+ b"to_stream\r\n--6067d4698f8d40a0a794ea7d7379d53a--\r\n\r\n"
+ )
+ multipart_response_data_cased = b"\r\n\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\ncOntEnt-type: application/json\r\n\r\n{}\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-tyPe: text/plain\r\nconTent-dispOsition: inline; filename=test_upload\r\n\r\nfile_"
redirect_data = b"\r\n\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: application/json\r\n\r\n{}\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nLocation: https://cdn.example.org/ab/c1/2345.txt\r\n\r\n--6067d4698f8d40a0a794ea7d7379d53a--\r\n\r\n"
@@ -103,8 +106,31 @@ class ReadMultipartResponseTests(TestCase):
result, deferred, protocol = self._build_multipart_response(249, 250)
# Start sending data.
- protocol.dataReceived(self.data1)
- protocol.dataReceived(self.data2)
+ protocol.dataReceived(self.multipart_response_data1)
+ protocol.dataReceived(self.multipart_response_data2)
+ # Close the connection.
+ protocol.connectionLost(Failure(ResponseDone()))
+
+ multipart_response: MultipartResponse = deferred.result # type: ignore[assignment]
+
+ self.assertEqual(multipart_response.json, b"{}")
+ self.assertEqual(result.getvalue(), b"file_to_stream")
+ self.assertEqual(multipart_response.length, len(b"file_to_stream"))
+ self.assertEqual(multipart_response.content_type, b"text/plain")
+ self.assertEqual(
+ multipart_response.disposition, b"inline; filename=test_upload"
+ )
+
+ def test_parse_file_lowercase_headers(self) -> None:
+ """
+ Check that a multipart response containing a file is properly parsed
+ into its JSON and file parts when the HTTP headers are not in canonical case.
+ """
+ result, deferred, protocol = self._build_multipart_response(249, 250)
+
+ # Start sending data.
+ protocol.dataReceived(self.multipart_response_data_cased)
+ protocol.dataReceived(self.multipart_response_data2)
# Close the connection.
protocol.connectionLost(Failure(ResponseDone()))
@@ -143,7 +169,7 @@ class ReadMultipartResponseTests(TestCase):
result, deferred, protocol = self._build_multipart_response(UNKNOWN_LENGTH, 180)
# Start sending data.
- protocol.dataReceived(self.data1)
+ protocol.dataReceived(self.multipart_response_data1)
self.assertEqual(result.getvalue(), b"file_")
self._assert_error(deferred, protocol)
@@ -154,11 +180,11 @@ class ReadMultipartResponseTests(TestCase):
result, deferred, protocol = self._build_multipart_response(UNKNOWN_LENGTH, 180)
# Start sending data.
- protocol.dataReceived(self.data1)
+ protocol.dataReceived(self.multipart_response_data1)
self._assert_error(deferred, protocol)
# More data might have come in.
- protocol.dataReceived(self.data2)
+ protocol.dataReceived(self.multipart_response_data2)
self.assertEqual(result.getvalue(), b"file_")
self._assert_error(deferred, protocol)
@@ -172,7 +198,7 @@ class ReadMultipartResponseTests(TestCase):
self.assertFalse(deferred.called)
# Start sending data.
- protocol.dataReceived(self.data1)
+ protocol.dataReceived(self.multipart_response_data1)
self._assert_error(deferred, protocol)
self._cleanup_error(deferred)
@@ -181,7 +207,9 @@ class ReadMultipartResponseTests(TestCase):
class ReadBodyWithMaxSizeTests(TestCase):
- def _build_response(self, length: Union[int, str] = UNKNOWN_LENGTH) -> Tuple[
+ def _build_response(
+ self, length: Union[int, str] = UNKNOWN_LENGTH
+ ) -> Tuple[
BytesIO,
"Deferred[int]",
_DiscardBodyWithMaxSizeProtocol,
diff --git a/tests/http/test_matrixfederationclient.py b/tests/http/test_matrixfederationclient.py
index e2f033fdae..d5ebf10eac 100644
--- a/tests/http/test_matrixfederationclient.py
+++ b/tests/http/test_matrixfederationclient.py
@@ -17,6 +17,7 @@
# [This file includes modifications made by New Vector Limited]
#
#
+import io
from typing import Any, Dict, Generator
from unittest.mock import ANY, Mock, create_autospec
@@ -32,7 +33,9 @@ from twisted.web.http import HTTPChannel
from twisted.web.http_headers import Headers
from synapse.api.errors import HttpResponseException, RequestSendFailed
+from synapse.api.ratelimiting import Ratelimiter
from synapse.config._base import ConfigError
+from synapse.config.ratelimiting import RatelimitSettings
from synapse.http.matrixfederationclient import (
ByteParser,
MatrixFederationHttpClient,
@@ -337,6 +340,81 @@ class FederationClientTests(HomeserverTestCase):
r = self.successResultOf(d)
self.assertEqual(r.code, 200)
+ def test_authed_media_redirect_response(self) -> None:
+ """
+ Validate that, when following a `Location` redirect, the
+ maximum size is _not_ set to the initial response `Content-Length` and
+ the media file can be downloaded.
+ """
+ limiter = Ratelimiter(
+ store=self.hs.get_datastores().main,
+ clock=self.clock,
+ cfg=RatelimitSettings(key="", per_second=0.17, burst_count=1048576),
+ )
+
+ output_stream = io.BytesIO()
+
+ d = defer.ensureDeferred(
+ self.cl.federation_get_file(
+ "testserv:8008", "path", output_stream, limiter, "127.0.0.1", 10000
+ )
+ )
+
+ self.pump()
+
+ clients = self.reactor.tcpClients
+ self.assertEqual(len(clients), 1)
+ (host, port, factory, _timeout, _bindAddress) = clients[0]
+ self.assertEqual(host, "1.2.3.4")
+ self.assertEqual(port, 8008)
+
+ # complete the connection and wire it up to a fake transport
+ protocol = factory.buildProtocol(None)
+ transport = StringTransport()
+ protocol.makeConnection(transport)
+
+ # Deferred does not have a result
+ self.assertNoResult(d)
+
+ redirect_data = b"\r\n\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: application/json\r\n\r\n{}\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nLocation: http://testserv:8008/ab/c1/2345.txt\r\n\r\n--6067d4698f8d40a0a794ea7d7379d53a--\r\n\r\n"
+ protocol.dataReceived(
+ b"HTTP/1.1 200 OK\r\n"
+ b"Server: Fake\r\n"
+ b"Content-Length: %i\r\n"
+ b"Content-Type: multipart/mixed; boundary=6067d4698f8d40a0a794ea7d7379d53a\r\n\r\n"
+ % (len(redirect_data))
+ )
+ protocol.dataReceived(redirect_data)
+
+ # Still no result, not followed the redirect yet
+ self.assertNoResult(d)
+
+ # Now send the response returned by the server at `Location`
+ clients = self.reactor.tcpClients
+ (host, port, factory, _timeout, _bindAddress) = clients[1]
+ self.assertEqual(host, "1.2.3.4")
+ self.assertEqual(port, 8008)
+ protocol = factory.buildProtocol(None)
+ transport = StringTransport()
+ protocol.makeConnection(transport)
+
+ # make sure the length is longer than the initial response
+ data = b"Hello world!" * 30
+ protocol.dataReceived(
+ b"HTTP/1.1 200 OK\r\n"
+ b"Server: Fake\r\n"
+ b"Content-Length: %i\r\n"
+ b"Content-Type: text/plain\r\n"
+ b"\r\n"
+ b"%s\r\n"
+ b"\r\n" % (len(data), data)
+ )
+
+ # We should get a successful response
+ length, _, _ = self.successResultOf(d)
+ self.assertEqual(length, len(data))
+ self.assertEqual(output_stream.getvalue(), data)
+
@parameterized.expand(["get_json", "post_json", "delete_json", "put_json"])
def test_timeout_reading_body(self, method_name: str) -> None:
"""
@@ -358,8 +436,7 @@ class FederationClientTests(HomeserverTestCase):
# Send it the HTTP response
client.dataReceived(
- b"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n"
- b"Server: Fake\r\n\r\n"
+ b"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nServer: Fake\r\n\r\n"
)
# Push by enough to time it out
@@ -613,10 +690,7 @@ class FederationClientTests(HomeserverTestCase):
# Send it a huge HTTP response
protocol.dataReceived(
- b"HTTP/1.1 200 OK\r\n"
- b"Server: Fake\r\n"
- b"Content-Type: application/json\r\n"
- b"\r\n"
+ b"HTTP/1.1 200 OK\r\nServer: Fake\r\nContent-Type: application/json\r\n\r\n"
)
self.pump()
@@ -817,21 +891,30 @@ class FederationClientProxyTests(BaseMultiWorkerStreamTestCase):
)
# Fake `remoteserv:8008` responding to requests
- mock_agent_on_federation_sender.request.side_effect = lambda *args, **kwargs: defer.succeed(
- FakeResponse(
- code=200,
- body=b'{"foo": "bar"}',
- headers=Headers(
- {
- "Content-Type": ["application/json"],
- "Connection": ["close, X-Foo, X-Bar"],
- # Should be removed because it's defined in the `Connection` header
- "X-Foo": ["foo"],
- "X-Bar": ["bar"],
- # Should be removed because it's a hop-by-hop header
- "Proxy-Authorization": "abcdef",
- }
- ),
+ mock_agent_on_federation_sender.request.side_effect = (
+ lambda *args, **kwargs: defer.succeed(
+ FakeResponse(
+ code=200,
+ body=b'{"foo": "bar"}',
+ headers=Headers(
+ {
+ "Content-Type": ["application/json"],
+ "X-Test": ["test"],
+ # Define some hop-by-hop headers (try with varying casing to
+ # make sure we still match up the headers)
+ "Connection": ["close, X-fOo, X-Bar, X-baz"],
+ # Should be removed because it's defined in the `Connection` header
+ "X-Foo": ["foo"],
+ "X-Bar": ["bar"],
+ # (not in canonical case)
+ "x-baZ": ["baz"],
+ # Should be removed because it's a hop-by-hop header
+ "Proxy-Authorization": "abcdef",
+ # Should be removed because it's a hop-by-hop header (not in canonical case)
+ "transfer-EnCoDiNg": "abcdef",
+ }
+ ),
+ )
)
)
@@ -858,9 +941,17 @@ class FederationClientProxyTests(BaseMultiWorkerStreamTestCase):
header_names = set(headers.keys())
# Make sure the response does not include the hop-by-hop headers
- self.assertNotIn(b"X-Foo", header_names)
- self.assertNotIn(b"X-Bar", header_names)
- self.assertNotIn(b"Proxy-Authorization", header_names)
+ self.assertIncludes(
+ header_names,
+ {
+ b"Content-Type",
+ b"X-Test",
+ # Default headers from Twisted
+ b"Date",
+ b"Server",
+ },
+ exact=True,
+ )
# Make sure the response is as expected back on the main worker
self.assertEqual(res, {"foo": "bar"})
diff --git a/tests/http/test_proxy.py b/tests/http/test_proxy.py
index 5895270494..7110dcf9f9 100644
--- a/tests/http/test_proxy.py
+++ b/tests/http/test_proxy.py
@@ -22,27 +22,42 @@ from typing import Set
from parameterized import parameterized
-from synapse.http.proxy import parse_connection_header_value
+from synapse.http.proxy import (
+ HOP_BY_HOP_HEADERS_LOWERCASE,
+ parse_connection_header_value,
+)
from tests.unittest import TestCase
+def mix_case(s: str) -> str:
+ """
+ Mix up the case of each character in the string (upper or lower case)
+ """
+ return "".join(c.upper() if i % 2 == 0 else c.lower() for i, c in enumerate(s))
+
+
class ProxyTests(TestCase):
@parameterized.expand(
[
- [b"close, X-Foo, X-Bar", {"Close", "X-Foo", "X-Bar"}],
+ [b"close, X-Foo, X-Bar", {"close", "x-foo", "x-bar"}],
# No whitespace
- [b"close,X-Foo,X-Bar", {"Close", "X-Foo", "X-Bar"}],
+ [b"close,X-Foo,X-Bar", {"close", "x-foo", "x-bar"}],
# More whitespace
- [b"close, X-Foo, X-Bar", {"Close", "X-Foo", "X-Bar"}],
+ [b"close, X-Foo, X-Bar", {"close", "x-foo", "x-bar"}],
# "close" directive in not the first position
- [b"X-Foo, X-Bar, close", {"X-Foo", "X-Bar", "Close"}],
+ [b"X-Foo, X-Bar, close", {"x-foo", "x-bar", "close"}],
# Normalizes header capitalization
- [b"keep-alive, x-fOo, x-bAr", {"Keep-Alive", "X-Foo", "X-Bar"}],
+ [b"keep-alive, x-fOo, x-bAr", {"keep-alive", "x-foo", "x-bar"}],
# Handles header names with whitespace
[
b"keep-alive, x foo, x bar",
- {"Keep-Alive", "X foo", "X bar"},
+ {"keep-alive", "x foo", "x bar"},
+ ],
+ # Make sure we handle all of the hop-by-hop headers
+ [
+ mix_case(", ".join(HOP_BY_HOP_HEADERS_LOWERCASE)).encode("ascii"),
+ HOP_BY_HOP_HEADERS_LOWERCASE,
],
]
)
@@ -54,7 +69,8 @@ class ProxyTests(TestCase):
"""
Tests that the connection header value is parsed correctly
"""
- self.assertEqual(
+ self.assertIncludes(
expected_extra_headers_to_remove,
parse_connection_header_value(connection_header_value),
+ exact=True,
)
diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py
index f71e4c2b8f..80b0856a56 100644
--- a/tests/http/test_proxyagent.py
+++ b/tests/http/test_proxyagent.py
@@ -854,7 +854,7 @@ class MatrixFederationAgentTests(TestCase):
def test_proxy_with_no_scheme(self) -> None:
http_proxy_agent = ProxyAgent(self.reactor, use_proxy=True)
proxy_ep = checked_cast(HostnameEndpoint, http_proxy_agent.http_proxy_endpoint)
- self.assertEqual(proxy_ep._hostStr, "proxy.com")
+ self.assertEqual(proxy_ep._hostText, "proxy.com")
self.assertEqual(proxy_ep._port, 8888)
@patch.dict(os.environ, {"http_proxy": "socks://proxy.com:8888"})
@@ -866,14 +866,14 @@ class MatrixFederationAgentTests(TestCase):
def test_proxy_with_http_scheme(self) -> None:
http_proxy_agent = ProxyAgent(self.reactor, use_proxy=True)
proxy_ep = checked_cast(HostnameEndpoint, http_proxy_agent.http_proxy_endpoint)
- self.assertEqual(proxy_ep._hostStr, "proxy.com")
+ self.assertEqual(proxy_ep._hostText, "proxy.com")
self.assertEqual(proxy_ep._port, 8888)
@patch.dict(os.environ, {"http_proxy": "https://proxy.com:8888"})
def test_proxy_with_https_scheme(self) -> None:
https_proxy_agent = ProxyAgent(self.reactor, use_proxy=True)
proxy_ep = checked_cast(_WrapperEndpoint, https_proxy_agent.http_proxy_endpoint)
- self.assertEqual(proxy_ep._wrappedEndpoint._hostStr, "proxy.com")
+ self.assertEqual(proxy_ep._wrappedEndpoint._hostText, "proxy.com")
self.assertEqual(proxy_ep._wrappedEndpoint._port, 8888)
diff --git a/tests/http/test_servlet.py b/tests/http/test_servlet.py
index 18af2735fe..db39ecf244 100644
--- a/tests/http/test_servlet.py
+++ b/tests/http/test_servlet.py
@@ -76,7 +76,7 @@ class TestServletUtils(unittest.TestCase):
# Invalid UTF-8.
with self.assertRaises(SynapseError):
- parse_json_value_from_request(make_request(b"\xFF\x00"))
+ parse_json_value_from_request(make_request(b"\xff\x00"))
# Invalid JSON.
with self.assertRaises(SynapseError):
diff --git a/tests/http/test_site.py b/tests/http/test_site.py
index bfa26a329c..fc620c705a 100644
--- a/tests/http/test_site.py
+++ b/tests/http/test_site.py
@@ -90,3 +90,56 @@ class SynapseRequestTestCase(HomeserverTestCase):
# default max upload size is 50M, so it should drop on the next buffer after
# that.
self.assertEqual(sent, 50 * 1024 * 1024 + 1024)
+
+ def test_content_type_multipart(self) -> None:
+ """HTTP POST requests with `content-type: multipart/form-data` should be rejected"""
+ self.hs.start_listening()
+
+ # find the HTTP server which is configured to listen on port 0
+ (port, factory, _backlog, interface) = self.reactor.tcpServers[0]
+ self.assertEqual(interface, "::")
+ self.assertEqual(port, 0)
+
+ # as a control case, first send a regular request.
+
+ # complete the connection and wire it up to a fake transport
+ client_address = IPv6Address("TCP", "::1", 2345)
+ protocol = factory.buildProtocol(client_address)
+ transport = StringTransport()
+ protocol.makeConnection(transport)
+
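+ # The body below is an empty chunked body: "0" is the final (zero-length)
+ # chunk marker, followed by the terminating blank line.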
+ protocol.dataReceived(
+ b"POST / HTTP/1.1\r\n"
+ b"Connection: close\r\n"
+ b"Transfer-Encoding: chunked\r\n"
+ b"\r\n"
+ b"0\r\n"
+ b"\r\n"
+ )
+
+ while not transport.disconnecting:
+ self.reactor.advance(1)
+
+ # we should get a 404
+ self.assertRegex(transport.value().decode(), r"^HTTP/1\.1 404 ")
+
+ # now send request with content-type header
+ protocol = factory.buildProtocol(client_address)
+ transport = StringTransport()
+ protocol.makeConnection(transport)
+
+ protocol.dataReceived(
+ b"POST / HTTP/1.1\r\n"
+ b"Connection: close\r\n"
+ b"Transfer-Encoding: chunked\r\n"
+ b"Content-Type: multipart/form-data\r\n"
+ b"\r\n"
+ b"0\r\n"
+ b"\r\n"
+ )
+
+ while not transport.disconnecting:
+ self.reactor.advance(1)
+
+ # we should get a 415
+ self.assertRegex(transport.value().decode(), r"^HTTP/1\.1 415 ")
diff --git a/tests/logging/test_terse_json.py b/tests/logging/test_terse_json.py
index ff85e067b7..33b94cf9fa 100644
--- a/tests/logging/test_terse_json.py
+++ b/tests/logging/test_terse_json.py
@@ -164,7 +164,6 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
site.site_tag = "test-site"
site.server_version_string = "Server v1"
site.reactor = Mock()
- site.experimental_cors_msc3886 = False
request = SynapseRequest(
cast(HTTPChannel, FakeChannel(site, self.reactor)), site
)
diff --git a/tests/media/test_media_retention.py b/tests/media/test_media_retention.py
index 417d17ebd2..d8f4f57c8c 100644
--- a/tests/media/test_media_retention.py
+++ b/tests/media/test_media_retention.py
@@ -31,6 +31,9 @@ from synapse.rest.client import login, register, room
from synapse.server import HomeServer
from synapse.types import UserID
from synapse.util import Clock
+from synapse.util.stringutils import (
+ random_string,
+)
from tests import unittest
from tests.unittest import override_config
@@ -65,7 +68,6 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase):
# quarantined media) into both the local store and the remote cache, plus
# one additional local media that is marked as protected from quarantine.
media_repository = hs.get_media_repository()
- test_media_content = b"example string"
def _create_media_and_set_attributes(
last_accessed_ms: Optional[int],
@@ -73,12 +75,14 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase):
is_protected: Optional[bool] = False,
) -> MXCUri:
# "Upload" some media to the local media store
+ # Use random content so that each media item is distinct (and gets its own sha256).
+ random_content = bytes(random_string(24), "utf-8")
mxc_uri: MXCUri = self.get_success(
media_repository.create_content(
media_type="text/plain",
upload_name=None,
- content=io.BytesIO(test_media_content),
- content_length=len(test_media_content),
+ content=io.BytesIO(random_content),
+ content_length=len(random_content),
auth_user=UserID.from_string(test_user_id),
)
)
@@ -129,6 +133,7 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase):
time_now_ms=clock.time_msec(),
upload_name="testfile.txt",
filesystem_id="abcdefg12345",
+ sha256=random_string(24),
)
)
diff --git a/tests/media/test_media_storage.py b/tests/media/test_media_storage.py
index e55001fb40..2f7cf4569b 100644
--- a/tests/media/test_media_storage.py
+++ b/tests/media/test_media_storage.py
@@ -23,14 +23,13 @@ import shutil
import tempfile
from binascii import unhexlify
from io import BytesIO
-from typing import Any, BinaryIO, ClassVar, Dict, List, Optional, Tuple, Union
+from typing import Any, BinaryIO, ClassVar, Dict, List, Literal, Optional, Tuple, Union
from unittest.mock import MagicMock, Mock, patch
from urllib import parse
import attr
from parameterized import parameterized, parameterized_class
from PIL import Image as Image
-from typing_extensions import Literal
from twisted.internet import defer
from twisted.internet.defer import Deferred
@@ -43,6 +42,7 @@ from twisted.web.resource import Resource
from synapse.api.errors import Codes, HttpResponseException
from synapse.api.ratelimiting import Ratelimiter
from synapse.events import EventBase
+from synapse.http.client import ByteWriteable
from synapse.http.types import QueryParams
from synapse.logging.context import make_deferred_yieldable
from synapse.media._base import FileInfo, ThumbnailInfo
@@ -60,7 +60,7 @@ from synapse.util import Clock
from tests import unittest
from tests.server import FakeChannel
-from tests.test_utils import SMALL_PNG
+from tests.test_utils import SMALL_CMYK_JPEG, SMALL_PNG, SMALL_PNG_SHA256
from tests.unittest import override_config
from tests.utils import default_config
@@ -187,10 +187,70 @@ small_png_with_transparency = TestImage(
# different versions of Pillow.
)
-small_lossless_webp = TestImage(
+small_cmyk_jpeg = TestImage(
+ SMALL_CMYK_JPEG,
+ b"image/jpeg",
+ b".jpeg",
+ # These values were sourced simply by seeing what the tests produced at
+ # the time of writing. If this changes, the tests will fail.
+ unhexlify(
+ b"ffd8ffe000104a46494600010100000100010000ffdb00430006"
+ b"040506050406060506070706080a100a0a09090a140e0f0c1017"
+ b"141818171416161a1d251f1a1b231c1616202c20232627292a29"
+ b"191f2d302d283025282928ffdb0043010707070a080a130a0a13"
+ b"281a161a28282828282828282828282828282828282828282828"
+ b"2828282828282828282828282828282828282828282828282828"
+ b"2828ffc00011080020002003012200021101031101ffc4001f00"
+ b"0001050101010101010000000000000000010203040506070809"
+ b"0a0bffc400b5100002010303020403050504040000017d010203"
+ b"00041105122131410613516107227114328191a1082342b1c115"
+ b"52d1f02433627282090a161718191a25262728292a3435363738"
+ b"393a434445464748494a535455565758595a636465666768696a"
+ b"737475767778797a838485868788898a92939495969798999aa2"
+ b"a3a4a5a6a7a8a9aab2b3b4b5b6b7b8b9bac2c3c4c5c6c7c8c9ca"
+ b"d2d3d4d5d6d7d8d9dae1e2e3e4e5e6e7e8e9eaf1f2f3f4f5f6f7"
+ b"f8f9faffc4001f01000301010101010101010100000000000001"
+ b"02030405060708090a0bffc400b5110002010204040304070504"
+ b"0400010277000102031104052131061241510761711322328108"
+ b"144291a1b1c109233352f0156272d10a162434e125f11718191a"
+ b"262728292a35363738393a434445464748494a53545556575859"
+ b"5a636465666768696a737475767778797a82838485868788898a"
+ b"92939495969798999aa2a3a4a5a6a7a8a9aab2b3b4b5b6b7b8b9"
+ b"bac2c3c4c5c6c7c8c9cad2d3d4d5d6d7d8d9dae2e3e4e5e6e7e8"
+ b"e9eaf2f3f4f5f6f7f8f9faffda000c03010002110311003f00fa"
+ b"a68a28a0028a28a0028a28a0028a28a00fffd9"
+ ),
unhexlify(
- b"524946461a000000574542505650384c0d0000002f0000001007" b"1011118888fe0700"
+ b"ffd8ffe000104a46494600010100000100010000ffdb00430006"
+ b"040506050406060506070706080a100a0a09090a140e0f0c1017"
+ b"141818171416161a1d251f1a1b231c1616202c20232627292a29"
+ b"191f2d302d283025282928ffdb0043010707070a080a130a0a13"
+ b"281a161a28282828282828282828282828282828282828282828"
+ b"2828282828282828282828282828282828282828282828282828"
+ b"2828ffc00011080001000103012200021101031101ffc4001f00"
+ b"0001050101010101010000000000000000010203040506070809"
+ b"0a0bffc400b5100002010303020403050504040000017d010203"
+ b"00041105122131410613516107227114328191a1082342b1c115"
+ b"52d1f02433627282090a161718191a25262728292a3435363738"
+ b"393a434445464748494a535455565758595a636465666768696a"
+ b"737475767778797a838485868788898a92939495969798999aa2"
+ b"a3a4a5a6a7a8a9aab2b3b4b5b6b7b8b9bac2c3c4c5c6c7c8c9ca"
+ b"d2d3d4d5d6d7d8d9dae1e2e3e4e5e6e7e8e9eaf1f2f3f4f5f6f7"
+ b"f8f9faffc4001f01000301010101010101010100000000000001"
+ b"02030405060708090a0bffc400b5110002010204040304070504"
+ b"0400010277000102031104052131061241510761711322328108"
+ b"144291a1b1c109233352f0156272d10a162434e125f11718191a"
+ b"262728292a35363738393a434445464748494a53545556575859"
+ b"5a636465666768696a737475767778797a82838485868788898a"
+ b"92939495969798999aa2a3a4a5a6a7a8a9aab2b3b4b5b6b7b8b9"
+ b"bac2c3c4c5c6c7c8c9cad2d3d4d5d6d7d8d9dae2e3e4e5e6e7e8"
+ b"e9eaf2f3f4f5f6f7f8f9faffda000c03010002110311003f00fa"
+ b"a68a28a00fffd9"
),
+)
+
+small_lossless_webp = TestImage(
+ unhexlify(b"524946461a000000574542505650384c0d0000002f00000010071011118888fe0700"),
b"image/webp",
b".webp",
)
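
As the comment above notes, the expected JPEG bytes were captured from test output rather than derived. A hedged sketch of how one might regenerate candidate values with Pillow if the encoder output changes; Synapse's actual thumbnailer may use different modes and parameters, so treat this as an approximation only:

```python
from binascii import hexlify
from io import BytesIO

from PIL import Image


def thumbnail_hex(image_bytes: bytes, size: int = 32) -> bytes:
    """Thumbnail an image with Pillow and return the hex-encoded JPEG bytes."""
    img = Image.open(BytesIO(image_bytes))
    if img.mode != "RGB":
        img = img.convert("RGB")  # e.g. CMYK sources are converted before saving
    img.thumbnail((size, size))
    out = BytesIO()
    img.save(out, format="JPEG")
    return hexlify(out.getvalue())
```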
@@ -261,7 +321,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
"""A mock for MatrixFederationHttpClient.get_file."""
def write_to(
- r: Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]
+ r: Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]],
) -> Tuple[int, Dict[bytes, List[bytes]]]:
data, response = r
output_stream.write(data)
@@ -357,6 +417,11 @@ class MediaRepoTests(unittest.HomeserverTestCase):
return channel
+ @unittest.override_config(
+ {
+ "enable_authenticated_media": False,
+ }
+ )
def test_handle_missing_content_type(self) -> None:
channel = self._req(
b"attachment; filename=out" + self.test_image.extension,
@@ -368,6 +433,11 @@ class MediaRepoTests(unittest.HomeserverTestCase):
headers.getRawHeaders(b"Content-Type"), [b"application/octet-stream"]
)
+ @unittest.override_config(
+ {
+ "enable_authenticated_media": False,
+ }
+ )
def test_disposition_filename_ascii(self) -> None:
"""
If the filename is filename=<ascii> then Synapse will decode it as an
@@ -388,6 +458,11 @@ class MediaRepoTests(unittest.HomeserverTestCase):
],
)
+ @unittest.override_config(
+ {
+ "enable_authenticated_media": False,
+ }
+ )
def test_disposition_filenamestar_utf8escaped(self) -> None:
"""
If the filename is filename=*utf8''<utf8 escaped> then Synapse will
@@ -413,6 +488,11 @@ class MediaRepoTests(unittest.HomeserverTestCase):
],
)
+ @unittest.override_config(
+ {
+ "enable_authenticated_media": False,
+ }
+ )
def test_disposition_none(self) -> None:
"""
If there is no filename, Content-Disposition should only
@@ -429,6 +509,11 @@ class MediaRepoTests(unittest.HomeserverTestCase):
[b"inline" if self.test_image.is_inline else b"attachment"],
)
+ @unittest.override_config(
+ {
+ "enable_authenticated_media": False,
+ }
+ )
def test_thumbnail_crop(self) -> None:
"""Test that a cropped remote thumbnail is available."""
self._test_thumbnail(
@@ -438,6 +523,11 @@ class MediaRepoTests(unittest.HomeserverTestCase):
unable_to_thumbnail=self.test_image.unable_to_thumbnail,
)
+ @unittest.override_config(
+ {
+ "enable_authenticated_media": False,
+ }
+ )
def test_thumbnail_scale(self) -> None:
"""Test that a scaled remote thumbnail is available."""
self._test_thumbnail(
@@ -447,6 +537,11 @@ class MediaRepoTests(unittest.HomeserverTestCase):
unable_to_thumbnail=self.test_image.unable_to_thumbnail,
)
+ @unittest.override_config(
+ {
+ "enable_authenticated_media": False,
+ }
+ )
def test_invalid_type(self) -> None:
"""An invalid thumbnail type is never available."""
self._test_thumbnail(
@@ -457,7 +552,10 @@ class MediaRepoTests(unittest.HomeserverTestCase):
)
@unittest.override_config(
- {"thumbnail_sizes": [{"width": 32, "height": 32, "method": "scale"}]}
+ {
+ "thumbnail_sizes": [{"width": 32, "height": 32, "method": "scale"}],
+ "enable_authenticated_media": False,
+ },
)
def test_no_thumbnail_crop(self) -> None:
"""
@@ -471,7 +569,10 @@ class MediaRepoTests(unittest.HomeserverTestCase):
)
@unittest.override_config(
- {"thumbnail_sizes": [{"width": 32, "height": 32, "method": "crop"}]}
+ {
+ "thumbnail_sizes": [{"width": 32, "height": 32, "method": "crop"}],
+ "enable_authenticated_media": False,
+ }
)
def test_no_thumbnail_scale(self) -> None:
"""
@@ -484,6 +585,11 @@ class MediaRepoTests(unittest.HomeserverTestCase):
unable_to_thumbnail=self.test_image.unable_to_thumbnail,
)
+ @unittest.override_config(
+ {
+ "enable_authenticated_media": False,
+ }
+ )
def test_thumbnail_repeated_thumbnail(self) -> None:
"""Test that fetching the same thumbnail works, and deleting the on disk
thumbnail regenerates it.
@@ -658,6 +764,11 @@ class MediaRepoTests(unittest.HomeserverTestCase):
)
)
+ @unittest.override_config(
+ {
+ "enable_authenticated_media": False,
+ }
+ )
def test_x_robots_tag_header(self) -> None:
"""
Tests that the `X-Robots-Tag` header is present, which informs web crawlers
@@ -671,6 +782,11 @@ class MediaRepoTests(unittest.HomeserverTestCase):
[b"noindex, nofollow, noarchive, noimageindex"],
)
+ @unittest.override_config(
+ {
+ "enable_authenticated_media": False,
+ }
+ )
def test_cross_origin_resource_policy_header(self) -> None:
"""
Test that the Cross-Origin-Resource-Policy header is set to "cross-origin"
@@ -685,6 +801,11 @@ class MediaRepoTests(unittest.HomeserverTestCase):
[b"cross-origin"],
)
+ @unittest.override_config(
+ {
+ "enable_authenticated_media": False,
+ }
+ )
def test_unknown_v3_endpoint(self) -> None:
"""
If the v3 endpoint fails, try the r0 one.
@@ -923,6 +1044,11 @@ class RemoteDownloadLimiterTestCase(unittest.HomeserverTestCase):
d.callback(52428800)
return d
+ @override_config(
+ {
+ "enable_authenticated_media": False,
+ }
+ )
@patch(
"synapse.http.matrixfederationclient.read_body_with_max_size",
read_body_with_max_size_30MiB,
@@ -998,6 +1124,7 @@ class RemoteDownloadLimiterTestCase(unittest.HomeserverTestCase):
{
"remote_media_download_per_second": "50M",
"remote_media_download_burst_count": "50M",
+ "enable_authenticated_media": False,
}
)
@patch(
@@ -1057,7 +1184,12 @@ class RemoteDownloadLimiterTestCase(unittest.HomeserverTestCase):
)
assert channel.code == 200
- @override_config({"remote_media_download_burst_count": "87M"})
+ @override_config(
+ {
+ "remote_media_download_burst_count": "87M",
+ "enable_authenticated_media": False,
+ }
+ )
@patch(
"synapse.http.matrixfederationclient.read_body_with_max_size",
read_body_with_max_size_30MiB,
@@ -1097,7 +1229,7 @@ class RemoteDownloadLimiterTestCase(unittest.HomeserverTestCase):
)
assert channel2.code == 429
- @override_config({"max_upload_size": "29M"})
+ @override_config({"max_upload_size": "29M", "enable_authenticated_media": False})
@patch(
"synapse.http.matrixfederationclient.read_body_with_max_size",
read_body_with_max_size_30MiB,
@@ -1124,3 +1256,146 @@ class RemoteDownloadLimiterTestCase(unittest.HomeserverTestCase):
)
assert channel.code == 502
assert channel.json_body["errcode"] == "M_TOO_LARGE"
+
+
+def read_body(
+ response: IResponse, stream: ByteWriteable, max_size: Optional[int]
+) -> Deferred:
+ d: Deferred = defer.Deferred()
+ stream.write(SMALL_PNG)
+ d.callback(len(SMALL_PNG))
+ return d
+
+
+class MediaHashesTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ media.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.user = self.register_user("user", "pass")
+ self.tok = self.login("user", "pass")
+ self.store = hs.get_datastores().main
+ self.client = hs.get_federation_http_client()
+
+ def create_resource_dict(self) -> Dict[str, Resource]:
+ resources = super().create_resource_dict()
+ resources["/_matrix/media"] = self.hs.get_media_repository_resource()
+ return resources
+
+ def test_ensure_correct_sha256(self) -> None:
+ """Check that the hash does not change"""
+ media = self.helper.upload_media(SMALL_PNG, tok=self.tok, expect_code=200)
+ mxc = media.get("content_uri")
+ assert mxc
+ store_media = self.get_success(self.store.get_local_media(mxc[11:]))
+ assert store_media
+ self.assertEqual(
+ store_media.sha256,
+ SMALL_PNG_SHA256,
+ )
+
+ def test_ensure_multiple_correct_sha256(self) -> None:
+ """Check that two media items have the same hash."""
+ media_a = self.helper.upload_media(SMALL_PNG, tok=self.tok, expect_code=200)
+ mxc_a = media_a.get("content_uri")
+ assert mxc_a
+ store_media_a = self.get_success(self.store.get_local_media(mxc_a[11:]))
+ assert store_media_a
+
+ media_b = self.helper.upload_media(SMALL_PNG, tok=self.tok, expect_code=200)
+ mxc_b = media_b.get("content_uri")
+ assert mxc_b
+ store_media_b = self.get_success(self.store.get_local_media(mxc_b[11:]))
+ assert store_media_b
+
+ self.assertNotEqual(
+ store_media_a.media_id,
+ store_media_b.media_id,
+ )
+ self.assertEqual(
+ store_media_a.sha256,
+ store_media_b.sha256,
+ )
+
+ @override_config(
+ {
+ "enable_authenticated_media": False,
+ }
+ )
+ # Mock out actually reading the file body over federation.
+ @patch(
+ "synapse.http.matrixfederationclient.read_body_with_max_size",
+ read_body,
+ )
+ def test_ensure_correct_sha256_federated(self) -> None:
+ """Check that federated media have the same hash."""
+
+ # Mock getting a file over federation
+ async def _send_request(*args: Any, **kwargs: Any) -> IResponse:
+ resp = MagicMock(spec=IResponse)
+ resp.code = 200
+ resp.length = 500
+ resp.headers = Headers({"Content-Type": ["application/octet-stream"]})
+ resp.phrase = b"OK"
+ return resp
+
+ self.client._send_request = _send_request # type: ignore
+
+ # first request should go through
+ channel = self.make_request(
+ "GET",
+ "/_matrix/media/v3/download/remote.org/abc",
+ shorthand=False,
+ access_token=self.tok,
+ )
+ assert channel.code == 200
+ store_media = self.get_success(
+ self.store.get_cached_remote_media("remote.org", "abc")
+ )
+ assert store_media
+ self.assertEqual(
+ store_media.sha256,
+ SMALL_PNG_SHA256,
+ )
+
+
+class MediaRepoSizeModuleCallbackTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ login.register_servlets,
+ admin.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.user = self.register_user("user", "pass")
+ self.tok = self.login("user", "pass")
+ self.mock_result = True # Allow all uploads by default
+
+ hs.get_module_api().register_media_repository_callbacks(
+ is_user_allowed_to_upload_media_of_size=self.is_user_allowed_to_upload_media_of_size,
+ )
+
+ def create_resource_dict(self) -> Dict[str, Resource]:
+ resources = super().create_resource_dict()
+ resources["/_matrix/media"] = self.hs.get_media_repository_resource()
+ return resources
+
+ async def is_user_allowed_to_upload_media_of_size(
+ self, user_id: str, size: int
+ ) -> bool:
+ self.last_user_id = user_id
+ self.last_size = size
+ return self.mock_result
+
+ def test_upload_allowed(self) -> None:
+ self.helper.upload_media(SMALL_PNG, tok=self.tok, expect_code=200)
+ assert self.last_user_id == self.user
+ assert self.last_size == len(SMALL_PNG)
+
+ def test_upload_not_allowed(self) -> None:
+ self.mock_result = False
+ self.helper.upload_media(SMALL_PNG, tok=self.tok, expect_code=413)
+ assert self.last_user_id == self.user
+ assert self.last_size == len(SMALL_PNG)
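
For reference, `SMALL_PNG_SHA256` (asserted against in `MediaHashesTestCase` above) is presumably just the hex-encoded SHA-256 digest of the `SMALL_PNG` fixture bytes; a sketch of how such a constant can be derived:

```python
import hashlib

from tests.test_utils import SMALL_PNG

# The hex digest of the fixture bytes; expected to match SMALL_PNG_SHA256.
print(hashlib.sha256(SMALL_PNG).hexdigest())
```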
diff --git a/tests/media/test_oembed.py b/tests/media/test_oembed.py
index 29d4580697..b8265ff9ca 100644
--- a/tests/media/test_oembed.py
+++ b/tests/media/test_oembed.py
@@ -20,6 +20,7 @@
#
import json
+from typing import Any
from parameterized import parameterized
@@ -52,6 +53,7 @@ class OEmbedTests(HomeserverTestCase):
def test_version(self) -> None:
"""Accept versions that are similar to 1.0 as a string or int (or missing)."""
+ version: Any
for version in ("1.0", 1.0, 1):
result = self.parse_response({"version": version})
# An empty Open Graph response is an error, ensure the URL is included.
@@ -69,6 +71,7 @@ class OEmbedTests(HomeserverTestCase):
def test_cache_age(self) -> None:
"""Ensure a cache-age is parsed properly."""
+ cache_age: Any
# Correct-ish cache ages are allowed.
for cache_age in ("1", 1.0, 1):
result = self.parse_response({"cache_age": cache_age})
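
The `version: Any` / `cache_age: Any` annotations added above are needed because mypy infers a loop variable's type from the iterable, and these loops iterate over heterogeneous tuples. A tiny illustration:

```python
from typing import Any

value: Any  # without this, mypy infers a union from the tuple's element types
for value in ("1", 1.0, 1):
    print(type(value).__name__)  # str, float, int
```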
diff --git a/tests/metrics/test_metrics.py b/tests/metrics/test_metrics.py
index 80f24814e8..2e7004df3a 100644
--- a/tests/metrics/test_metrics.py
+++ b/tests/metrics/test_metrics.py
@@ -19,12 +19,11 @@
#
#
from importlib import metadata
-from typing import Dict, Tuple
+from typing import Dict, Protocol, Tuple
from unittest.mock import patch
from pkg_resources import parse_version
from prometheus_client.core import Sample
-from typing_extensions import Protocol
from synapse.app._base import _set_prometheus_client_use_created_metrics
from synapse.metrics import REGISTRY, InFlightGauge, generate_latest
diff --git a/tests/metrics/test_phone_home_stats.py b/tests/metrics/test_phone_home_stats.py
new file mode 100644
index 0000000000..5339d649df
--- /dev/null
+++ b/tests/metrics/test_phone_home_stats.py
@@ -0,0 +1,263 @@
+#
+# This file is licensed under the Affero General Public License (AGPL) version 3.
+#
+# Copyright (C) 2025 New Vector, Ltd
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# See the GNU Affero General Public License for more details:
+# <https://www.gnu.org/licenses/agpl-3.0.html>.
+
+import logging
+from unittest.mock import AsyncMock
+
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.app.phone_stats_home import (
+ PHONE_HOME_INTERVAL_SECONDS,
+ start_phone_stats_home,
+)
+from synapse.rest import admin, login, register, room
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util import Clock
+
+from tests import unittest
+from tests.server import ThreadedMemoryReactorClock
+
+TEST_REPORT_STATS_ENDPOINT = "https://fake.endpoint/stats"
+TEST_SERVER_CONTEXT = "test-server-context"
+
+
+class PhoneHomeStatsTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ register.register_servlets,
+ login.register_servlets,
+ ]
+
+ def make_homeserver(
+ self, reactor: ThreadedMemoryReactorClock, clock: Clock
+ ) -> HomeServer:
+ # Configure the homeserver to enable stats reporting.
+ config = self.default_config()
+ config["report_stats"] = True
+ config["report_stats_endpoint"] = TEST_REPORT_STATS_ENDPOINT
+
+ # Configure the server context so we can check it ends up being reported
+ config["server_context"] = TEST_SERVER_CONTEXT
+
+ # Allow guests to be registered
+ config["allow_guest_access"] = True
+
+ hs = self.setup_test_homeserver(config=config)
+
+ # Replace the proxied http client with a mock, so we can inspect outbound requests to
+ # the configured stats endpoint.
+ self.put_json_mock = AsyncMock(return_value={})
+ hs.get_proxied_http_client().put_json = self.put_json_mock # type: ignore[method-assign]
+ return hs
+
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+ ) -> None:
+ self.store = homeserver.get_datastores().main
+
+ # Wait for the background updates to add the database triggers that keep the
+ # `event_stats` table up-to-date.
+ self.wait_for_background_updates()
+
+ # Force stats reporting to occur
+ start_phone_stats_home(hs=homeserver)
+
+ super().prepare(reactor, clock, homeserver)
+
+ def _get_latest_phone_home_stats(self) -> JsonDict:
+ # Wait for `phone_stats_home` to be called again + a healthy margin (50s).
+ self.reactor.advance(2 * PHONE_HOME_INTERVAL_SECONDS + 50)
+
+ # Extract the reported stats from our http client mock
+ mock_calls = self.put_json_mock.call_args_list
+ report_stats_calls = []
+ for call in mock_calls:
+ if call.args[0] == TEST_REPORT_STATS_ENDPOINT:
+ report_stats_calls.append(call)
+
+ self.assertGreaterEqual(
+ len(report_stats_calls),
+ 1,
+ "Expected at least one call to the report_stats endpoint",
+ )
+
+ # Extract the phone home stats from the call
+ phone_home_stats = report_stats_calls[0].args[1]
+
+ return phone_home_stats
+
+ def _perform_user_actions(self) -> None:
+ """
+ Perform some actions on the homeserver that would bump the phone home
+ stats.
+
+ This creates a few users (including a guest), a room, and sends some messages.
+ Expected number of events:
+ - 10 unencrypted messages
+ - 5 encrypted messages
+ - 24 total events (including room state, etc)
+ """
+
+ # Create some users
+ user_1_mxid = self.register_user(
+ username="test_user_1",
+ password="test",
+ )
+ user_2_mxid = self.register_user(
+ username="test_user_2",
+ password="test",
+ )
+ # Note: `self.register_user` does not support guest registration, and updating the
+ # Admin API it calls to add a new parameter would cause the `mac` parameter to fail
+ # in a backwards-incompatible manner. Hence, we make a manual request here.
+ _guest_user_mxid = self.make_request(
+ method="POST",
+ path="/_matrix/client/v3/register?kind=guest",
+ content={
+ "username": "guest_user",
+ "password": "test",
+ },
+ shorthand=False,
+ )
+
+ # Log in to each user
+ user_1_token = self.login(username=user_1_mxid, password="test")
+ user_2_token = self.login(username=user_2_mxid, password="test")
+
+ # Create a room between the two users
+ room_1_id = self.helper.create_room_as(
+ is_public=False,
+ tok=user_1_token,
+ )
+
+ # Mark this room as end-to-end encrypted
+ self.helper.send_state(
+ room_id=room_1_id,
+ event_type="m.room.encryption",
+ body={
+ "algorithm": "m.megolm.v1.aes-sha2",
+ "rotation_period_ms": 604800000,
+ "rotation_period_msgs": 100,
+ },
+ state_key="",
+ tok=user_1_token,
+ )
+
+ # User 1 invites user 2
+ self.helper.invite(
+ room=room_1_id,
+ src=user_1_mxid,
+ targ=user_2_mxid,
+ tok=user_1_token,
+ )
+
+ # User 2 joins
+ self.helper.join(
+ room=room_1_id,
+ user=user_2_mxid,
+ tok=user_2_token,
+ )
+
+ # User 1 sends 10 unencrypted messages
+ for _ in range(10):
+ self.helper.send(
+ room_id=room_1_id,
+ body="Zoinks Scoob! A message!",
+ tok=user_1_token,
+ )
+
+ # User 2 sends 5 encrypted "messages"
+ for _ in range(5):
+ self.helper.send_event(
+ room_id=room_1_id,
+ type="m.room.encrypted",
+ content={
+ "algorithm": "m.olm.v1.curve25519-aes-sha2",
+ "sender_key": "some_key",
+ "ciphertext": {
+ "some_key": {
+ "type": 0,
+ "body": "encrypted_payload",
+ },
+ },
+ },
+ tok=user_2_token,
+ )
+
+ def test_phone_home_stats(self) -> None:
+ """
+ Test that the phone home stats contain the stats we expect based on
+ the scenario carried out in `_perform_user_actions`.
+ """
+ # Do things to bump the stats
+ self._perform_user_actions()
+
+ # Wait for the stats to be reported
+ phone_home_stats = self._get_latest_phone_home_stats()
+
+ self.assertEqual(
+ phone_home_stats["homeserver"], self.hs.config.server.server_name
+ )
+
+ self.assertTrue(isinstance(phone_home_stats["memory_rss"], int))
+ self.assertTrue(isinstance(phone_home_stats["cpu_average"], int))
+
+ self.assertEqual(phone_home_stats["server_context"], TEST_SERVER_CONTEXT)
+
+ self.assertTrue(isinstance(phone_home_stats["timestamp"], int))
+ self.assertTrue(isinstance(phone_home_stats["uptime_seconds"], int))
+ self.assertTrue(isinstance(phone_home_stats["python_version"], str))
+
+ # We expect only our test users to exist on the homeserver
+ self.assertEqual(phone_home_stats["total_users"], 3)
+ self.assertEqual(phone_home_stats["total_nonbridged_users"], 3)
+ self.assertEqual(phone_home_stats["daily_user_type_native"], 2)
+ self.assertEqual(phone_home_stats["daily_user_type_guest"], 1)
+ self.assertEqual(phone_home_stats["daily_user_type_bridged"], 0)
+ self.assertEqual(phone_home_stats["total_room_count"], 1)
+ self.assertEqual(phone_home_stats["daily_active_users"], 2)
+ self.assertEqual(phone_home_stats["monthly_active_users"], 2)
+ self.assertEqual(phone_home_stats["daily_active_rooms"], 1)
+ self.assertEqual(phone_home_stats["daily_active_e2ee_rooms"], 1)
+ self.assertEqual(phone_home_stats["daily_messages"], 10)
+ self.assertEqual(phone_home_stats["daily_e2ee_messages"], 5)
+ self.assertEqual(phone_home_stats["daily_sent_messages"], 10)
+ self.assertEqual(phone_home_stats["daily_sent_e2ee_messages"], 5)
+
+ # Our users have not been around for >30 days, hence these are all 0.
+ self.assertEqual(phone_home_stats["r30v2_users_all"], 0)
+ self.assertEqual(phone_home_stats["r30v2_users_android"], 0)
+ self.assertEqual(phone_home_stats["r30v2_users_ios"], 0)
+ self.assertEqual(phone_home_stats["r30v2_users_electron"], 0)
+ self.assertEqual(phone_home_stats["r30v2_users_web"], 0)
+ self.assertEqual(
+ phone_home_stats["cache_factor"], self.hs.config.caches.global_factor
+ )
+ self.assertEqual(
+ phone_home_stats["event_cache_size"],
+ self.hs.config.caches.event_cache_size,
+ )
+ self.assertEqual(
+ phone_home_stats["database_engine"],
+ self.hs.config.database.databases[0].config["name"],
+ )
+ self.assertEqual(
+ phone_home_stats["database_server_version"],
+ self.hs.get_datastores().main.database_engine.server_version,
+ )
+
+ synapse_logger = logging.getLogger("synapse")
+ log_level = synapse_logger.getEffectiveLevel()
+ self.assertEqual(phone_home_stats["log_level"], logging.getLevelName(log_level))
diff --git a/tests/module_api/test_account_data_manager.py b/tests/module_api/test_account_data_manager.py
index fd87eaffd0..1a1d5609b2 100644
--- a/tests/module_api/test_account_data_manager.py
+++ b/tests/module_api/test_account_data_manager.py
@@ -164,6 +164,8 @@ class ModuleApiTestCase(HomeserverTestCase):
# noinspection PyTypeChecker
self.get_success_or_raise(
self._module_api.account_data_manager.put_global(
- self.user_id, "test.data", 42 # type: ignore[arg-type]
+ self.user_id,
+ "test.data",
+ 42, # type: ignore[arg-type]
)
)
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index b6ba472d7d..85bd76b78f 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -87,7 +87,7 @@ class ModuleApiTestCase(BaseModuleApiTestCase):
# Register a new user
user_id, access_token = self.get_success(
self.module_api.register(
- "bob", displayname="Bobberino", emails=["bob@bobinator.bob"]
+ "bob", displayname="Bobberino"
)
)
@@ -96,18 +96,6 @@ class ModuleApiTestCase(BaseModuleApiTestCase):
self.assertTrue(access_token)
self.assertTrue(self.get_success(self.store.get_user_by_id(user_id)))
- # Check that the email was assigned
- emails = self.get_success(self.store.user_get_threepids(user_id))
- self.assertEqual(len(emails), 1)
-
- email = emails[0]
- self.assertEqual(email.medium, "email")
- self.assertEqual(email.address, "bob@bobinator.bob")
-
- # Should these be 0?
- self.assertEqual(email.validated_at, 0)
- self.assertEqual(email.added_at, 0)
-
# Check that the displayname was assigned
displayname = self.get_success(
self.store.get_profile_displayname(UserID.from_string("@bob:test"))
diff --git a/tests/module_api/test_spamchecker.py b/tests/module_api/test_spamchecker.py
new file mode 100644
index 0000000000..926fe30b43
--- /dev/null
+++ b/tests/module_api/test_spamchecker.py
@@ -0,0 +1,244 @@
+#
+# This file is licensed under the Affero General Public License (AGPL) version 3.
+#
+# Copyright (C) 2025 New Vector, Ltd
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# See the GNU Affero General Public License for more details:
+# <https://www.gnu.org/licenses/agpl-3.0.html>.
+#
+#
+from typing import Literal, Union
+
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.api.errors import Codes
+from synapse.config.server import DEFAULT_ROOM_VERSION
+from synapse.rest import admin, login, room, room_upgrade_rest_servlet
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util import Clock
+
+from tests.server import FakeChannel
+from tests.unittest import HomeserverTestCase
+
+
+class SpamCheckerTestCase(HomeserverTestCase):
+ servlets = [
+ room.register_servlets,
+ admin.register_servlets,
+ login.register_servlets,
+ room_upgrade_rest_servlet.register_servlets,
+ ]
+
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+ ) -> None:
+ self._module_api = homeserver.get_module_api()
+ self.user_id = self.register_user("user", "password")
+ self.token = self.login("user", "password")
+
+ def create_room(self, content: JsonDict) -> FakeChannel:
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/createRoom",
+ content,
+ access_token=self.token,
+ )
+
+ return channel
+
+ def test_may_user_create_room(self) -> None:
+ """Test that the may_user_create_room callback is called when a user
+ creates a room, and that it receives the correct parameters.
+ """
+
+ async def user_may_create_room(
+ user_id: str, room_config: JsonDict
+ ) -> Union[Literal["NOT_SPAM"], Codes]:
+ self.last_room_config = room_config
+ self.last_user_id = user_id
+ return "NOT_SPAM"
+
+ self._module_api.register_spam_checker_callbacks(
+ user_may_create_room=user_may_create_room
+ )
+
+ channel = self.create_room({"foo": "baa"})
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(self.last_user_id, self.user_id)
+ self.assertEqual(self.last_room_config["foo"], "baa")
+
+ def test_may_user_create_room_on_upgrade(self) -> None:
+ """Test that the may_user_create_room callback is called when a room is upgraded."""
+
+ # First, create a room to upgrade.
+ channel = self.create_room({"topic": "foo"})
+ self.assertEqual(channel.code, 200)
+ room_id = channel.json_body["room_id"]
+
+ async def user_may_create_room(
+ user_id: str, room_config: JsonDict
+ ) -> Union[Literal["NOT_SPAM"], Codes]:
+ self.last_room_config = room_config
+ self.last_user_id = user_id
+ return "NOT_SPAM"
+
+ # Register the callback for spam checking.
+ self._module_api.register_spam_checker_callbacks(
+ user_may_create_room=user_may_create_room
+ )
+
+ # Now upgrade the room.
+ channel = self.make_request(
+ "POST",
+ f"/_matrix/client/r0/rooms/{room_id}/upgrade",
+ # This will upgrade a room to the same version, but that's fine.
+ content={"new_version": DEFAULT_ROOM_VERSION},
+ access_token=self.token,
+ )
+
+ # Check that the callback was called and the room was upgraded.
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(self.last_user_id, self.user_id)
+ # Check that the initial state received by the callback contains the topic event.
+ self.assertTrue(
+ any(
+ event[0][0] == "m.room.topic" and event[1].get("topic") == "foo"
+ for event in self.last_room_config["initial_state"]
+ )
+ )
+
+ def test_may_user_create_room_disallowed(self) -> None:
+ """Test that the codes response from may_user_create_room callback is respected
+ and returned via the API.
+ """
+
+ async def user_may_create_room(
+ user_id: str, room_config: JsonDict
+ ) -> Union[Literal["NOT_SPAM"], Codes]:
+ self.last_room_config = room_config
+ self.last_user_id = user_id
+ return Codes.UNAUTHORIZED
+
+ self._module_api.register_spam_checker_callbacks(
+ user_may_create_room=user_may_create_room
+ )
+
+ channel = self.create_room({"foo": "baa"})
+ self.assertEqual(channel.code, 403)
+ self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED)
+ self.assertEqual(self.last_user_id, self.user_id)
+ self.assertEqual(self.last_room_config["foo"], "baa")
+
+ def test_may_user_create_room_compatibility(self) -> None:
+ """Test that the may_user_create_room callback is called when a user
+ creates a room for a module that uses the old callback signature
+ (without the `room_config` parameter)
+ """
+
+ async def user_may_create_room(
+ user_id: str,
+ ) -> Union[Literal["NOT_SPAM"], Codes]:
+ self.last_user_id = user_id
+ return "NOT_SPAM"
+
+ self._module_api.register_spam_checker_callbacks(
+ user_may_create_room=user_may_create_room
+ )
+
+ channel = self.create_room({"foo": "baa"})
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(self.last_user_id, self.user_id)
+
+ def test_user_may_send_state_event(self) -> None:
+ """Test that the user_may_send_state_event callback is called when a state event
+ is sent, and that it receives the correct parameters.
+ """
+
+ async def user_may_send_state_event(
+ user_id: str,
+ room_id: str,
+ event_type: str,
+ state_key: str,
+ content: JsonDict,
+ ) -> Union[Literal["NOT_SPAM"], Codes]:
+ self.last_user_id = user_id
+ self.last_room_id = room_id
+ self.last_event_type = event_type
+ self.last_state_key = state_key
+ self.last_content = content
+ return "NOT_SPAM"
+
+ self._module_api.register_spam_checker_callbacks(
+ user_may_send_state_event=user_may_send_state_event
+ )
+
+ channel = self.create_room({})
+ self.assertEqual(channel.code, 200)
+
+ room_id = channel.json_body["room_id"]
+
+ event_type = "test.event.type"
+ state_key = "test.state.key"
+ channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/rooms/%s/state/%s/%s"
+ % (
+ room_id,
+ event_type,
+ state_key,
+ ),
+ content={"foo": "bar"},
+ access_token=self.token,
+ )
+
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(self.last_user_id, self.user_id)
+ self.assertEqual(self.last_room_id, room_id)
+ self.assertEqual(self.last_event_type, event_type)
+ self.assertEqual(self.last_state_key, state_key)
+ self.assertEqual(self.last_content, {"foo": "bar"})
+
+ def test_user_may_send_state_event_disallows(self) -> None:
+ """Test that the user_may_send_state_event callback is called when a state event
+ is sent, and that the response is honoured.
+ """
+
+ async def user_may_send_state_event(
+ user_id: str,
+ room_id: str,
+ event_type: str,
+ state_key: str,
+ content: JsonDict,
+ ) -> Union[Literal["NOT_SPAM"], Codes]:
+ return Codes.FORBIDDEN
+
+ self._module_api.register_spam_checker_callbacks(
+ user_may_send_state_event=user_may_send_state_event
+ )
+
+ channel = self.create_room({})
+ self.assertEqual(channel.code, 200)
+
+ room_id = channel.json_body["room_id"]
+
+ event_type = "test.event.type"
+ state_key = "test.state.key"
+ channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/rooms/%s/state/%s/%s"
+ % (
+ room_id,
+ event_type,
+ state_key,
+ ),
+ content={"foo": "bar"},
+ access_token=self.token,
+ )
+
+ self.assertEqual(channel.code, 403)
+ self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN)
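
For context, the callback signatures exercised by these tests are what a third-party module would register. A hedged sketch of such a module; `ExampleSpamChecker` is hypothetical, but the `register_spam_checker_callbacks` keywords and signatures mirror the tests above:

```python
from typing import Literal, Union

from synapse.api.errors import Codes
from synapse.module_api import ModuleApi
from synapse.types import JsonDict


class ExampleSpamChecker:
    def __init__(self, config: dict, api: ModuleApi):
        api.register_spam_checker_callbacks(
            user_may_create_room=self.user_may_create_room,
            user_may_send_state_event=self.user_may_send_state_event,
        )

    async def user_may_create_room(
        self, user_id: str, room_config: JsonDict
    ) -> Union[Literal["NOT_SPAM"], Codes]:
        # Allow everything; a real module would inspect room_config here.
        return "NOT_SPAM"

    async def user_may_send_state_event(
        self,
        user_id: str,
        room_id: str,
        event_type: str,
        state_key: str,
        content: JsonDict,
    ) -> Union[Literal["NOT_SPAM"], Codes]:
        return "NOT_SPAM"
```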
diff --git a/tests/push/test_bulk_push_rule_evaluator.py b/tests/push/test_bulk_push_rule_evaluator.py
index fc73f3dc2a..16c1292812 100644
--- a/tests/push/test_bulk_push_rule_evaluator.py
+++ b/tests/push/test_bulk_push_rule_evaluator.py
@@ -120,9 +120,11 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
#
# We have seen stringy and null values for "room" in the wild, so presumably
# some of this validation was missing in the past.
- with patch("synapse.events.validator.validate_canonicaljson"), patch(
- "synapse.events.validator.jsonschema.validate"
- ), patch("synapse.handlers.event_auth.check_state_dependent_auth_rules"):
+ with (
+ patch("synapse.events.validator.validate_canonicaljson"),
+ patch("synapse.events.validator.jsonschema.validate"),
+ patch("synapse.handlers.event_auth.check_state_dependent_auth_rules"),
+ ):
pl_event_id = self.helper.send_state(
self.room_id,
"m.room.power_levels",
diff --git a/tests/push/test_email.py b/tests/push/test_email.py
index e0aab1c046..66ee6ca20f 100644
--- a/tests/push/test_email.py
+++ b/tests/push/test_email.py
@@ -31,9 +31,7 @@ from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
from synapse.api.errors import Codes, SynapseError
-from synapse.push.emailpusher import EmailPusher
from synapse.rest.client import login, room
-from synapse.rest.synapse.client.unsubscribe import UnsubscribeResource
from synapse.server import HomeServer
from synapse.util import Clock
@@ -44,524 +42,6 @@ from tests.unittest import HomeserverTestCase
@attr.s(auto_attribs=True)
class _User:
"Helper wrapper for user ID and access token"
+
id: str
token: str
-
-
-class EmailPusherTests(HomeserverTestCase):
- servlets = [
- synapse.rest.admin.register_servlets_for_client_rest_resource,
- room.register_servlets,
- login.register_servlets,
- ]
- hijack_auth = False
-
- def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
- config = self.default_config()
- config["email"] = {
- "enable_notifs": True,
- "template_dir": os.path.abspath(
- pkg_resources.resource_filename("synapse", "res/templates")
- ),
- "expiry_template_html": "notice_expiry.html",
- "expiry_template_text": "notice_expiry.txt",
- "notif_template_html": "notif_mail.html",
- "notif_template_text": "notif_mail.txt",
- "smtp_host": "127.0.0.1",
- "smtp_port": 20,
- "require_transport_security": False,
- "smtp_user": None,
- "smtp_pass": None,
- "app_name": "Matrix",
- "notif_from": "test@example.com",
- "riot_base_url": None,
- }
- config["public_baseurl"] = "http://aaa"
-
- hs = self.setup_test_homeserver(config=config)
-
- # List[Tuple[Deferred, args, kwargs]]
- self.email_attempts: List[Tuple[Deferred, Sequence, Dict]] = []
-
- def sendmail(*args: Any, **kwargs: Any) -> Deferred:
- # This mocks out synapse.reactor.send_email._sendmail.
- d: Deferred = Deferred()
- self.email_attempts.append((d, args, kwargs))
- return d
-
- hs.get_send_email_handler()._sendmail = sendmail # type: ignore[assignment]
-
- return hs
-
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- # Register the user who gets notified
- self.user_id = self.register_user("user", "pass")
- self.access_token = self.login("user", "pass")
-
- # Register other users
- self.others = [
- _User(
- id=self.register_user("otheruser1", "pass"),
- token=self.login("otheruser1", "pass"),
- ),
- _User(
- id=self.register_user("otheruser2", "pass"),
- token=self.login("otheruser2", "pass"),
- ),
- ]
-
- # Register the pusher
- user_tuple = self.get_success(
- self.hs.get_datastores().main.get_user_by_access_token(self.access_token)
- )
- assert user_tuple is not None
- self.device_id = user_tuple.device_id
-
- # We need to add email to account before we can create a pusher.
- self.get_success(
- hs.get_datastores().main.user_add_threepid(
- self.user_id, "email", "a@example.com", 0, 0
- )
- )
-
- pusher = self.get_success(
- self.hs.get_pusherpool().add_or_update_pusher(
- user_id=self.user_id,
- device_id=self.device_id,
- kind="email",
- app_id="m.email",
- app_display_name="Email Notifications",
- device_display_name="a@example.com",
- pushkey="a@example.com",
- lang=None,
- data={},
- )
- )
- assert isinstance(pusher, EmailPusher)
- self.pusher = pusher
-
- self.auth_handler = hs.get_auth_handler()
- self.store = hs.get_datastores().main
-
- def test_need_validated_email(self) -> None:
- """Test that we can only add an email pusher if the user has validated
- their email.
- """
- with self.assertRaises(SynapseError) as cm:
- self.get_success_or_raise(
- self.hs.get_pusherpool().add_or_update_pusher(
- user_id=self.user_id,
- device_id=self.device_id,
- kind="email",
- app_id="m.email",
- app_display_name="Email Notifications",
- device_display_name="b@example.com",
- pushkey="b@example.com",
- lang=None,
- data={},
- )
- )
-
- self.assertEqual(400, cm.exception.code)
- self.assertEqual(Codes.THREEPID_NOT_FOUND, cm.exception.errcode)
-
- def test_simple_sends_email(self) -> None:
- # Create a simple room with two users
- room = self.helper.create_room_as(self.user_id, tok=self.access_token)
- self.helper.invite(
- room=room, src=self.user_id, tok=self.access_token, targ=self.others[0].id
- )
- self.helper.join(room=room, user=self.others[0].id, tok=self.others[0].token)
-
- # The other user sends a single message.
- self.helper.send(room, body="Hi!", tok=self.others[0].token)
-
- # We should get emailed about that message
- self._check_for_mail()
-
- # The other user sends multiple messages.
- self.helper.send(room, body="Hi!", tok=self.others[0].token)
- self.helper.send(room, body="There!", tok=self.others[0].token)
-
- self._check_for_mail()
-
- @parameterized.expand([(False,), (True,)])
- def test_unsubscribe(self, use_post: bool) -> None:
- # Create a simple room with two users
- room = self.helper.create_room_as(self.user_id, tok=self.access_token)
- self.helper.invite(
- room=room, src=self.user_id, tok=self.access_token, targ=self.others[0].id
- )
- self.helper.join(room=room, user=self.others[0].id, tok=self.others[0].token)
-
- # The other user sends a single message.
- self.helper.send(room, body="Hi!", tok=self.others[0].token)
-
- # We should get emailed about that message
- args, kwargs = self._check_for_mail()
-
- # That email should contain an unsubscribe link in the body and header.
- msg: bytes = args[5]
-
- # Multipart: plain text, base 64 encoded; html, base 64 encoded
- multipart_msg = email.message_from_bytes(msg)
-
- # Extract the text (non-HTML) portion of the multipart Message,
- # as a Message.
- txt_message = multipart_msg.get_payload(i=0)
- assert isinstance(txt_message, email.message.Message)
-
- # Extract the actual bytes from the Message object, and decode them to a `str`.
- txt_bytes = txt_message.get_payload(decode=True)
- assert isinstance(txt_bytes, bytes)
- txt = txt_bytes.decode()
-
- # Do the same for the HTML portion of the multipart Message.
- html_message = multipart_msg.get_payload(i=1)
- assert isinstance(html_message, email.message.Message)
- html_bytes = html_message.get_payload(decode=True)
- assert isinstance(html_bytes, bytes)
- html = html_bytes.decode()
-
- self.assertIn("/_synapse/client/unsubscribe", txt)
- self.assertIn("/_synapse/client/unsubscribe", html)
-
- # The unsubscribe headers should exist.
- assert multipart_msg.get("List-Unsubscribe") is not None
- self.assertIsNotNone(multipart_msg.get("List-Unsubscribe-Post"))
-
- # Open the unsubscribe link.
- unsubscribe_link = multipart_msg["List-Unsubscribe"].strip("<>")
- unsubscribe_resource = UnsubscribeResource(self.hs)
- channel = make_request(
- self.reactor,
- FakeSite(unsubscribe_resource, self.reactor),
- "POST" if use_post else "GET",
- unsubscribe_link,
- shorthand=False,
- )
- self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
-
- # Ensure the pusher was removed.
- pushers = list(
- self.get_success(
- self.hs.get_datastores().main.get_pushers_by(
- {"user_name": self.user_id}
- )
- )
- )
- self.assertEqual(pushers, [])
-
- def test_invite_sends_email(self) -> None:
- # Create a room and invite the user to it
- room = self.helper.create_room_as(self.others[0].id, tok=self.others[0].token)
- self.helper.invite(
- room=room,
- src=self.others[0].id,
- tok=self.others[0].token,
- targ=self.user_id,
- )
-
- # We should get emailed about the invite
- self._check_for_mail()
-
- def test_invite_to_empty_room_sends_email(self) -> None:
- # Create a room and invite the user to it
- room = self.helper.create_room_as(self.others[0].id, tok=self.others[0].token)
- self.helper.invite(
- room=room,
- src=self.others[0].id,
- tok=self.others[0].token,
- targ=self.user_id,
- )
-
- # Then have the original user leave
- self.helper.leave(room, self.others[0].id, tok=self.others[0].token)
-
- # We should get emailed about the invite
- self._check_for_mail()
-
- def test_multiple_members_email(self) -> None:
- # We want to test multiple notifications, so we pause processing of push
- # while we send messages.
- self.pusher._pause_processing()
-
- # Create a simple room with multiple other users
- room = self.helper.create_room_as(self.user_id, tok=self.access_token)
-
- for other in self.others:
- self.helper.invite(
- room=room, src=self.user_id, tok=self.access_token, targ=other.id
- )
- self.helper.join(room=room, user=other.id, tok=other.token)
-
- # The other users send some messages
- self.helper.send(room, body="Hi!", tok=self.others[0].token)
- self.helper.send(room, body="There!", tok=self.others[1].token)
- self.helper.send(room, body="There!", tok=self.others[1].token)
-
- # Nothing should have happened yet, as we're paused.
- assert not self.email_attempts
-
- self.pusher._resume_processing()
-
- # We should get emailed about those messages
- self._check_for_mail()
-
- def test_multiple_rooms(self) -> None:
- # We want to test multiple notifications from multiple rooms, so we pause
- # processing of push while we send messages.
- self.pusher._pause_processing()
-
- # Create a simple room with multiple other users
- rooms = [
- self.helper.create_room_as(self.user_id, tok=self.access_token),
- self.helper.create_room_as(self.user_id, tok=self.access_token),
- ]
-
- for r, other in zip(rooms, self.others):
- self.helper.invite(
- room=r, src=self.user_id, tok=self.access_token, targ=other.id
- )
- self.helper.join(room=r, user=other.id, tok=other.token)
-
- # The other users send some messages
- self.helper.send(rooms[0], body="Hi!", tok=self.others[0].token)
- self.helper.send(rooms[1], body="There!", tok=self.others[1].token)
- self.helper.send(rooms[1], body="There!", tok=self.others[1].token)
-
- # Nothing should have happened yet, as we're paused.
- assert not self.email_attempts
-
- self.pusher._resume_processing()
-
- # We should get emailed about those messages
- self._check_for_mail()
-
- def test_room_notifications_include_avatar(self) -> None:
- # Create a room and set its avatar.
- room = self.helper.create_room_as(self.user_id, tok=self.access_token)
- self.helper.send_state(
- room, "m.room.avatar", {"url": "mxc://DUMMY_MEDIA_ID"}, self.access_token
- )
-
- # Invite two other uses.
- for other in self.others:
- self.helper.invite(
- room=room, src=self.user_id, tok=self.access_token, targ=other.id
- )
- self.helper.join(room=room, user=other.id, tok=other.token)
-
- # The other users send some messages.
- # TODO It seems that two messages are required to trigger an email?
- self.helper.send(room, body="Alpha", tok=self.others[0].token)
- self.helper.send(room, body="Beta", tok=self.others[1].token)
-
- # We should get emailed about those messages
- args, kwargs = self._check_for_mail()
-
- # That email should contain the room's avatar
- msg: bytes = args[5]
- # Multipart: plain text, base 64 encoded; html, base 64 encoded
-
- # Extract the html Message object from the Multipart Message.
- # We need the asserts to convince mypy that this is OK.
- html_message = email.message_from_bytes(msg).get_payload(i=1)
- assert isinstance(html_message, email.message.Message)
-
- # Extract the `bytes` from the html Message object, and decode to a `str`.
- html = html_message.get_payload(decode=True)
- assert isinstance(html, bytes)
- html = html.decode()
-
- self.assertIn("_matrix/media/v1/thumbnail/DUMMY_MEDIA_ID", html)
-
- def test_empty_room(self) -> None:
- """All users leaving a room shouldn't cause the pusher to break."""
- # Create a simple room with two users
- room = self.helper.create_room_as(self.user_id, tok=self.access_token)
- self.helper.invite(
- room=room, src=self.user_id, tok=self.access_token, targ=self.others[0].id
- )
- self.helper.join(room=room, user=self.others[0].id, tok=self.others[0].token)
-
- # The other user sends a single message.
- self.helper.send(room, body="Hi!", tok=self.others[0].token)
-
- # Leave the room before the message is processed.
- self.helper.leave(room, self.user_id, tok=self.access_token)
- self.helper.leave(room, self.others[0].id, tok=self.others[0].token)
-
- # We should get emailed about that message
- self._check_for_mail()
-
- def test_empty_room_multiple_messages(self) -> None:
- """All users leaving a room shouldn't cause the pusher to break."""
- # Create a simple room with two users
- room = self.helper.create_room_as(self.user_id, tok=self.access_token)
- self.helper.invite(
- room=room, src=self.user_id, tok=self.access_token, targ=self.others[0].id
- )
- self.helper.join(room=room, user=self.others[0].id, tok=self.others[0].token)
-
- # The other user sends a single message.
- self.helper.send(room, body="Hi!", tok=self.others[0].token)
- self.helper.send(room, body="There!", tok=self.others[0].token)
-
- # Leave the room before the message is processed.
- self.helper.leave(room, self.user_id, tok=self.access_token)
- self.helper.leave(room, self.others[0].id, tok=self.others[0].token)
-
- # We should get emailed about that message
- self._check_for_mail()
-
- def test_encrypted_message(self) -> None:
- room = self.helper.create_room_as(self.user_id, tok=self.access_token)
- self.helper.invite(
- room=room, src=self.user_id, tok=self.access_token, targ=self.others[0].id
- )
- self.helper.join(room=room, user=self.others[0].id, tok=self.others[0].token)
-
- # The other user sends some messages
- self.helper.send_event(room, "m.room.encrypted", {}, tok=self.others[0].token)
-
- # We should get emailed about that message
- self._check_for_mail()
-
- def test_no_email_sent_after_removed(self) -> None:
- # Create a simple room with two users
- room = self.helper.create_room_as(self.user_id, tok=self.access_token)
- self.helper.invite(
- room=room,
- src=self.user_id,
- tok=self.access_token,
- targ=self.others[0].id,
- )
- self.helper.join(
- room=room,
- user=self.others[0].id,
- tok=self.others[0].token,
- )
-
- # The other user sends a single message.
- self.helper.send(room, body="Hi!", tok=self.others[0].token)
-
- # We should get emailed about that message
- self._check_for_mail()
-
- # disassociate the user's email address
- self.get_success(
- self.auth_handler.delete_local_threepid(
- user_id=self.user_id, medium="email", address="a@example.com"
- )
- )
-
- # check that the pusher for that email address has been deleted
- pushers = list(
- self.get_success(
- self.hs.get_datastores().main.get_pushers_by(
- {"user_name": self.user_id}
- )
- )
- )
- self.assertEqual(len(pushers), 0)
-
- def test_remove_unlinked_pushers_background_job(self) -> None:
- """Checks that all existing pushers associated with unlinked email addresses are removed
- upon running the remove_deleted_email_pushers background update.
- """
- # disassociate the user's email address manually (without deleting the pusher).
- # This resembles the old behaviour, which the background update below is intended
- # to clean up.
- self.get_success(
- self.hs.get_datastores().main.user_delete_threepid(
- self.user_id, "email", "a@example.com"
- )
- )
-
- # Run the "remove_deleted_email_pushers" background job
- self.get_success(
- self.hs.get_datastores().main.db_pool.simple_insert(
- table="background_updates",
- values={
- "update_name": "remove_deleted_email_pushers",
- "progress_json": "{}",
- "depends_on": None,
- },
- )
- )
-
- # ... and tell the DataStore that it hasn't finished all updates yet
- self.hs.get_datastores().main.db_pool.updates._all_done = False
-
- # Now let's actually drive the updates to completion
- self.wait_for_background_updates()
-
- # Check that all pushers with unlinked addresses were deleted
- pushers = list(
- self.get_success(
- self.hs.get_datastores().main.get_pushers_by(
- {"user_name": self.user_id}
- )
- )
- )
- self.assertEqual(len(pushers), 0)
-
- def _check_for_mail(self) -> Tuple[Sequence, Dict]:
- """
- Assert that synapse sent off exactly one email notification.
-
- Returns:
- args and kwargs passed to synapse.reactor.send_email._sendmail for
- that notification.
- """
- # Get the stream ordering before it gets sent
- pushers = list(
- self.get_success(
- self.hs.get_datastores().main.get_pushers_by(
- {"user_name": self.user_id}
- )
- )
- )
- self.assertEqual(len(pushers), 1)
- last_stream_ordering = pushers[0].last_stream_ordering
-
- # Advance time a bit, so the pusher will register something has happened
- self.pump(10)
-
- # It hasn't succeeded yet, so the stream ordering shouldn't have moved
- pushers = list(
- self.get_success(
- self.hs.get_datastores().main.get_pushers_by(
- {"user_name": self.user_id}
- )
- )
- )
- self.assertEqual(len(pushers), 1)
- self.assertEqual(last_stream_ordering, pushers[0].last_stream_ordering)
-
- # One email was attempted to be sent
- self.assertEqual(len(self.email_attempts), 1)
-
- deferred, sendmail_args, sendmail_kwargs = self.email_attempts[0]
- # Make the email succeed
- deferred.callback(True)
- self.pump()
-
- # One email was attempted to be sent
- self.assertEqual(len(self.email_attempts), 1)
-
- # The stream ordering has increased
- pushers = list(
- self.get_success(
- self.hs.get_datastores().main.get_pushers_by(
- {"user_name": self.user_id}
- )
- )
- )
- self.assertEqual(len(pushers), 1)
- self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering)
-
- # Reset the attempts.
- self.email_attempts = []
- return sendmail_args, sendmail_kwargs
diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index bcca472617..b42fd284b6 100644
--- a/tests/push/test_http.py
+++ b/tests/push/test_http.py
@@ -17,9 +17,11 @@
# [This file includes modifications made by New Vector Limited]
#
#
-from typing import Any, List, Tuple
+from typing import Any, Dict, List, Tuple
from unittest.mock import Mock
+from parameterized import parameterized
+
from twisted.internet.defer import Deferred
from twisted.test.proto_helpers import MemoryReactor
@@ -1085,3 +1087,161 @@ class HTTPPusherTests(HomeserverTestCase):
self.pump()
self.assertEqual(len(self.push_attempts), 11)
+
+ @parameterized.expand(
+ [
+ # Badge count disabled
+ (True, True),
+ (True, False),
+ # Badge count enabled
+ (False, True),
+ (False, False),
+ ]
+ )
+ @override_config({"experimental_features": {"msc4076_enabled": True}})
+ def test_msc4076_badge_count(
+ self, disable_badge_count: bool, event_id_only: bool
+ ) -> None:
+ # Register the user who gets notified
+ user_id = self.register_user("user", "pass")
+ access_token = self.login("user", "pass")
+
+ # Register the user who sends the message
+ other_user_id = self.register_user("otheruser", "pass")
+ other_access_token = self.login("otheruser", "pass")
+
+ # Register the pusher (its configuration depends on the test parameters)
+ user_tuple = self.get_success(
+ self.hs.get_datastores().main.get_user_by_access_token(access_token)
+ )
+ assert user_tuple is not None
+ device_id = user_tuple.device_id
+
+ # Set the push data dict based on test input parameters
+ push_data: Dict[str, Any] = {
+ "url": "http://example.com/_matrix/push/v1/notify",
+ }
+ if disable_badge_count:
+ push_data["org.matrix.msc4076.disable_badge_count"] = True
+ if event_id_only:
+ push_data["format"] = "event_id_only"
+
+ self.get_success(
+ self.hs.get_pusherpool().add_or_update_pusher(
+ user_id=user_id,
+ device_id=device_id,
+ kind="http",
+ app_id="m.http",
+ app_display_name="HTTP Push Notifications",
+ device_display_name="pushy push",
+ pushkey="a@example.com",
+ lang=None,
+ data=push_data,
+ )
+ )
+
+ # Create a room
+ room = self.helper.create_room_as(user_id, tok=access_token)
+
+ # The other user joins
+ self.helper.join(room=room, user=other_user_id, tok=other_access_token)
+
+ # The other user sends a message
+ self.helper.send(room, body="Hi!", tok=other_access_token)
+
+ # Advance time a bit, so the pusher will register something has happened
+ self.pump()
+
+ # One push was attempted to be sent
+ self.assertEqual(len(self.push_attempts), 1)
+ self.assertEqual(
+ self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify"
+ )
+
+ if disable_badge_count:
+ # Verify that the notification DOESN'T contain a counts field
+ self.assertNotIn("counts", self.push_attempts[0][2]["notification"])
+ else:
+ # Ensure that the notification DOES contain a counts field
+ self.assertIn("counts", self.push_attempts[0][2]["notification"])
+ self.assertEqual(
+ self.push_attempts[0][2]["notification"]["counts"]["unread"], 1
+ )
+
+ def test_push_backoff(self) -> None:
+ """
+ The HTTP pusher will back off correctly if it fails to contact the push gateway.
+ """
+
+ # Register the user who gets notified
+ user_id = self.register_user("user", "pass")
+ access_token = self.login("user", "pass")
+
+ # Register the user who sends the message
+ other_user_id = self.register_user("otheruser", "pass")
+ other_access_token = self.login("otheruser", "pass")
+
+ # Register the pusher
+ user_tuple = self.get_success(
+ self.hs.get_datastores().main.get_user_by_access_token(access_token)
+ )
+ assert user_tuple is not None
+ device_id = user_tuple.device_id
+
+ self.get_success(
+ self.hs.get_pusherpool().add_or_update_pusher(
+ user_id=user_id,
+ device_id=device_id,
+ kind="http",
+ app_id="m.http",
+ app_display_name="HTTP Push Notifications",
+ device_display_name="pushy push",
+ pushkey="a@example.com",
+ lang=None,
+ data={"url": "http://example.com/_matrix/push/v1/notify"},
+ )
+ )
+
+ # Create a room with the other user
+ room = self.helper.create_room_as(user_id, tok=access_token)
+ self.helper.join(room=room, user=other_user_id, tok=other_access_token)
+
+ # The other user sends some messages
+ self.helper.send(room, body="Message 1", tok=other_access_token)
+
+ # One push attempt should have been made
+ self.assertEqual(len(self.push_attempts), 1)
+ self.assertEqual(
+ self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify"
+ )
+ self.assertEqual(
+ self.push_attempts[0][2]["notification"]["content"]["body"], "Message 1"
+ )
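+ # Resolve the pending push request with an empty JSON body, simulating success.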
+ self.push_attempts[0][0].callback({})
+ self.pump()
+
+ # Send another message, this time it fails
+ self.helper.send(room, body="Message 2", tok=other_access_token)
+ self.assertEqual(len(self.push_attempts), 2)
+ self.push_attempts[1][0].errback(Exception("couldn't connect"))
+ self.pump()
+
+ # Sending yet another message doesn't trigger a push immediately
+ self.helper.send(room, body="Message 3", tok=other_access_token)
+ self.pump()
+ self.assertEqual(len(self.push_attempts), 2)
+
+ # ... but waiting for a bit will cause more pushes
+ self.reactor.advance(10)
+ self.assertEqual(len(self.push_attempts), 3)
+ self.assertEqual(
+ self.push_attempts[2][2]["notification"]["content"]["body"], "Message 2"
+ )
+ self.push_attempts[2][0].callback({})
+ self.pump()
+
+ self.assertEqual(len(self.push_attempts), 4)
+ self.assertEqual(
+ self.push_attempts[3][2]["notification"]["content"]["body"], "Message 3"
+ )
+ self.push_attempts[3][0].callback({})
diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py
index 420fbea998..3898532acf 100644
--- a/tests/push/test_push_rule_evaluator.py
+++ b/tests/push/test_push_rule_evaluator.py
@@ -149,6 +149,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
content: JsonMapping,
*,
related_events: Optional[JsonDict] = None,
+ msc4210: bool = False,
) -> PushRuleEvaluator:
event = FrozenEvent(
{
@@ -174,6 +175,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
related_event_match_enabled=True,
room_version_feature_flags=event.room_version.msc3931_push_features,
msc3931_enabled=True,
+ msc4210_enabled=msc4210,
)
def test_display_name(self) -> None:
@@ -452,6 +454,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
{"value": False},
"incorrect values should not match",
)
+ value: Any
for value in ("foobaz", 1, 1.1, None, [], {}):
self._assert_not_matches(
condition,
@@ -492,6 +495,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
{"value": None},
"exact value should match",
)
+ value: Any
for value in ("foobaz", True, False, 1, 1.1, [], {}):
self._assert_not_matches(
condition,
diff --git a/tests/replication/http/test__base.py b/tests/replication/http/test__base.py
index 2eaad3707a..31d3163c01 100644
--- a/tests/replication/http/test__base.py
+++ b/tests/replication/http/test__base.py
@@ -46,7 +46,7 @@ class CancellableReplicationEndpoint(ReplicationEndpoint):
self.clock = hs.get_clock()
@staticmethod
- async def _serialize_payload() -> JsonDict:
+ async def _serialize_payload(**kwargs: ReplicationEndpoint) -> JsonDict:
return {}
@cancellable
@@ -68,7 +68,7 @@ class UncancellableReplicationEndpoint(ReplicationEndpoint):
self.clock = hs.get_clock()
@staticmethod
- async def _serialize_payload() -> JsonDict:
+ async def _serialize_payload(**kwargs: ReplicationEndpoint) -> JsonDict:
return {}
async def _handle_request( # type: ignore[override]
diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py
index fdc74efb5a..2a0189a4e1 100644
--- a/tests/replication/tcp/streams/test_events.py
+++ b/tests/replication/tcp/streams/test_events.py
@@ -324,7 +324,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
pls = self.helper.get_state(
self.room_id, EventTypes.PowerLevels, tok=self.user_tok
)
- pls["users"].update({u: 50 for u in user_ids})
+ pls["users"].update(dict.fromkeys(user_ids, 50))
self.helper.send_state(
self.room_id,
EventTypes.PowerLevels,
diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py
index 4429d0f4e2..58a7a9dc72 100644
--- a/tests/replication/test_federation_sender_shard.py
+++ b/tests/replication/test_federation_sender_shard.py
@@ -22,14 +22,26 @@ import logging
from unittest.mock import AsyncMock, Mock
from netaddr import IPSet
+from signedjson.key import (
+ encode_verify_key_base64,
+ generate_signing_key,
+ get_verify_key,
+)
+
+from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes, Membership
-from synapse.events.builder import EventBuilderFactory
+from synapse.api.room_versions import RoomVersion
+from synapse.crypto.event_signing import add_hashes_and_signatures
+from synapse.events import EventBase, make_event_from_dict
from synapse.handlers.typing import TypingWriterHandler
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
from synapse.rest.admin import register_servlets_for_client_rest_resource
from synapse.rest.client import login, room
-from synapse.types import UserID, create_requester
+from synapse.server import HomeServer
+from synapse.storage.keys import FetchKeyResult
+from synapse.types import JsonDict, UserID, create_requester
+from synapse.util import Clock
from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.server import get_clock
@@ -63,6 +75,9 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
ip_blocklist=IPSet(),
)
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
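+ # Stash the storage controllers so the tests below can read current room state.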
+ self.storage_controllers = hs.get_storage_controllers()
+
def test_send_event_single_sender(self) -> None:
"""Test that using a single federation sender worker correctly sends a
new event.
@@ -243,35 +258,92 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
self.assertTrue(sent_on_1)
self.assertTrue(sent_on_2)
+ def create_fake_event_from_remote_server(
+ self, remote_server_name: str, event_dict: JsonDict, room_version: RoomVersion
+ ) -> EventBase:
+ """
+ This is similar to what `FederatingHomeserverTestCase` does, but without
+ all of the extra baggage, and it lets us create events from many different
+ remote servers.
+ """
+
+ # poke the other server's signing key into the key store, so that we don't
+ # make requests for it
+ other_server_signature_key = generate_signing_key("test")
+ verify_key = get_verify_key(other_server_signature_key)
+ verify_key_id = "%s:%s" % (verify_key.alg, verify_key.version)
+
+ self.get_success(
+ self.hs.get_datastores().main.store_server_keys_response(
+ remote_server_name,
+ from_server=remote_server_name,
+ ts_added_ms=self.clock.time_msec(),
+ verify_keys={
+ verify_key_id: FetchKeyResult(
+ verify_key=verify_key,
+ valid_until_ts=self.clock.time_msec() + 10000,
+ ),
+ },
+ response_json={
+ "verify_keys": {
+ verify_key_id: {"key": encode_verify_key_base64(verify_key)}
+ }
+ },
+ )
+ )
+
+ add_hashes_and_signatures(
+ room_version=room_version,
+ event_dict=event_dict,
+ signature_name=remote_server_name,
+ signing_key=other_server_signature_key,
+ )
+ event = make_event_from_dict(
+ event_dict,
+ room_version=room_version,
+ )
+
+ return event
+
def create_room_with_remote_server(
self, user: str, token: str, remote_server: str = "other_server"
) -> str:
- room = self.helper.create_room_as(user, tok=token)
+ room_id = self.helper.create_room_as(user, tok=token)
store = self.hs.get_datastores().main
federation = self.hs.get_federation_event_handler()
- prev_event_ids = self.get_success(store.get_latest_event_ids_in_room(room))
- room_version = self.get_success(store.get_room_version(room))
+ room_version = self.get_success(store.get_room_version(room_id))
- factory = EventBuilderFactory(self.hs)
- factory.hostname = remote_server
+ state_map = self.get_success(
+ self.storage_controllers.state.get_current_state(room_id)
+ )
+
+ # Figure out the forward extremities in the room (the most recent events
+ # that have no children in the DAG)
+ prev_event_ids = self.get_success(store.get_latest_event_ids_in_room(room_id))
user_id = UserID("user", remote_server).to_string()
- event_dict = {
- "type": EventTypes.Member,
- "state_key": user_id,
- "content": {"membership": Membership.JOIN},
- "sender": user_id,
- "room_id": room,
- }
-
- builder = factory.for_room_version(room_version, event_dict)
- join_event = self.get_success(
- builder.build(prev_event_ids=list(prev_event_ids), auth_event_ids=None)
+ join_event = self.create_fake_event_from_remote_server(
+ remote_server_name=remote_server,
+ event_dict={
+ "room_id": room_id,
+ "sender": user_id,
+ "type": EventTypes.Member,
+ "state_key": user_id,
+ "depth": 1000,
+ "origin_server_ts": 1,
+ "content": {"membership": Membership.JOIN},
+ "auth_events": [
+ state_map[(EventTypes.Create, "")].event_id,
+ state_map[(EventTypes.JoinRules, "")].event_id,
+ ],
+ "prev_events": list(prev_event_ids),
+ },
+ room_version=room_version,
)
self.get_success(federation.on_send_membership_event(remote_server, join_event))
self.replicate()
- return room
+ return room_id
diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py
index 6fc4600c41..f36af877c4 100644
--- a/tests/replication/test_multi_media_repo.py
+++ b/tests/replication/test_multi_media_repo.py
@@ -40,6 +40,7 @@ from tests.http import (
from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.server import FakeChannel, FakeTransport, make_request
from tests.test_utils import SMALL_PNG
+from tests.unittest import override_config
logger = logging.getLogger(__name__)
@@ -148,6 +149,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
return channel, request
+ @override_config({"enable_authenticated_media": False})
def test_basic(self) -> None:
"""Test basic fetching of remote media from a single worker."""
hs1 = self.make_worker_hs("synapse.app.generic_worker")
@@ -164,6 +166,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
self.assertEqual(channel.code, 200)
self.assertEqual(channel.result["body"], b"Hello!")
+ @override_config({"enable_authenticated_media": False})
def test_download_simple_file_race(self) -> None:
"""Test that fetching remote media from two different processes at the
same time works.
@@ -203,6 +206,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
# We expect only one new file to have been persisted.
self.assertEqual(start_count + 1, self._count_remote_media())
+ @override_config({"enable_authenticated_media": False})
def test_download_image_race(self) -> None:
"""Test that fetching remote *images* from two different processes at
the same time works.
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index 6351326fff..fc2a6c569b 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -20,7 +20,7 @@
#
import urllib.parse
-from typing import Dict
+from typing import Dict, cast
from parameterized import parameterized
@@ -30,8 +30,9 @@ from twisted.web.resource import Resource
import synapse.rest.admin
from synapse.http.server import JsonResource
from synapse.rest.admin import VersionServlet
-from synapse.rest.client import login, room
+from synapse.rest.client import login, media, room
from synapse.server import HomeServer
+from synapse.types import UserID
from synapse.util import Clock
from tests import unittest
@@ -60,6 +61,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
synapse.rest.admin.register_servlets,
synapse.rest.admin.register_servlets_for_media_repo,
login.register_servlets,
+ media.register_servlets,
room.register_servlets,
]
@@ -74,7 +76,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
"""Ensure a piece of media is quarantined when trying to access it."""
channel = self.make_request(
"GET",
- f"/_matrix/media/v3/download/{server_and_media_id}",
+ f"/_matrix/client/v1/media/download/{server_and_media_id}",
shorthand=False,
access_token=admin_user_tok,
)
@@ -131,7 +133,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
# Attempt to access the media
channel = self.make_request(
"GET",
- f"/_matrix/media/v3/download/{server_name_and_media_id}",
+ f"/_matrix/client/v1/media/download/{server_name_and_media_id}",
shorthand=False,
access_token=non_admin_user_tok,
)
@@ -226,10 +228,25 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
# Upload some media
response_1 = self.helper.upload_media(SMALL_PNG, tok=non_admin_user_tok)
response_2 = self.helper.upload_media(SMALL_PNG, tok=non_admin_user_tok)
+ response_3 = self.helper.upload_media(SMALL_PNG, tok=non_admin_user_tok)
# Extract media IDs
server_and_media_id_1 = response_1["content_uri"][6:]
server_and_media_id_2 = response_2["content_uri"][6:]
+ server_and_media_id_3 = response_3["content_uri"][6:]
+
+ # Remove the hash from the media to simulate historic media.
+ self.get_success(
+ self.hs.get_datastores().main.update_local_media(
+ media_id=server_and_media_id_3.split("/")[1],
+ media_type="image/png",
+ upload_name=None,
+ media_length=123,
+ user_id=UserID.from_string(non_admin_user),
+ # Hack to force some media to have no hash.
+ sha256=cast(str, None),
+ )
+ )
# Quarantine all media by this user
url = "/_synapse/admin/v1/user/%s/media/quarantine" % urllib.parse.quote(
@@ -243,12 +260,13 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
self.pump(1.0)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(
- channel.json_body, {"num_quarantined": 2}, "Expected 2 quarantined items"
+ channel.json_body, {"num_quarantined": 3}, "Expected 3 quarantined items"
)
# Attempt to access each piece of media
self._ensure_quarantined(admin_user_tok, server_and_media_id_1)
self._ensure_quarantined(admin_user_tok, server_and_media_id_2)
+ self._ensure_quarantined(admin_user_tok, server_and_media_id_3)
def test_cannot_quarantine_safe_media(self) -> None:
self.register_user("user_admin", "pass", admin=True)
@@ -295,7 +313,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
# Attempt to access each piece of media
channel = self.make_request(
"GET",
- f"/_matrix/media/v3/download/{server_and_media_id_2}",
+ f"/_matrix/client/v1/media/download/{server_and_media_id_2}",
shorthand=False,
access_token=non_admin_user_tok,
)
diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py
index a88c77bd19..531162a6e9 100644
--- a/tests/rest/admin/test_device.py
+++ b/tests/rest/admin/test_device.py
@@ -27,7 +27,7 @@ from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
from synapse.api.errors import Codes
from synapse.handlers.device import DeviceHandler
-from synapse.rest.client import login
+from synapse.rest.client import devices, login
from synapse.server import HomeServer
from synapse.util import Clock
@@ -299,6 +299,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
class DevicesRestTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets,
+ devices.register_servlets,
login.register_servlets,
]
@@ -390,15 +391,63 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(0, channel.json_body["total"])
self.assertEqual(0, len(channel.json_body["devices"]))
+ @unittest.override_config(
+ {"experimental_features": {"msc2697_enabled": False, "msc3814_enabled": True}}
+ )
def test_get_devices(self) -> None:
"""
Tests that a normal lookup for devices is successful
"""
# Create devices
number_devices = 5
- for _ in range(number_devices):
+ # we create 2 fewer devices in the loop, because we will create another
+ # login after the loop, and we will create a dehydrated device
+ for _ in range(number_devices - 2):
self.login("user", "pass")
+ other_user_token = self.login("user", "pass")
+ dehydrated_device_url = (
+ "/_matrix/client/unstable/org.matrix.msc3814.v1/dehydrated_device"
+ )
+ content = {
+ "device_data": {
+ "algorithm": "m.dehydration.v1.olm",
+ },
+ "device_id": "dehydrated_device",
+ "initial_device_display_name": "foo bar",
+ "device_keys": {
+ "user_id": "@user:test",
+ "device_id": "dehydrated_device",
+ "valid_until_ts": "80",
+ "algorithms": [
+ "m.olm.curve25519-aes-sha2",
+ ],
+ "keys": {
+ "<algorithm>:<device_id>": "<key_base64>",
+ },
+ "signatures": {
+ "@user:test": {"<algorithm>:<device_id>": "<signature_base64>"}
+ },
+ },
+ "fallback_keys": {
+ "alg1:device1": "f4llb4ckk3y",
+ "signed_<algorithm>:<device_id>": {
+ "fallback": "true",
+ "key": "f4llb4ckk3y",
+ "signatures": {
+ "@user:test": {"<algorithm>:<device_id>": "<key_base64>"}
+ },
+ },
+ },
+ "one_time_keys": {"alg1:k1": "0net1m3k3y"},
+ }
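+ # Upload the dehydrated device via the MSC3814 unstable endpoint.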
+ self.make_request(
+ "PUT",
+ dehydrated_device_url,
+ access_token=other_user_token,
+ content=content,
+ )
+
# Get devices
channel = self.make_request(
"GET",
@@ -410,13 +459,22 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(number_devices, channel.json_body["total"])
self.assertEqual(number_devices, len(channel.json_body["devices"]))
self.assertEqual(self.other_user, channel.json_body["devices"][0]["user_id"])
- # Check that all fields are available
+ # Check that all fields are available, and that the dehydrated device is marked as dehydrated
+ found_dehydrated = False
for d in channel.json_body["devices"]:
self.assertIn("user_id", d)
self.assertIn("device_id", d)
self.assertIn("display_name", d)
self.assertIn("last_seen_ip", d)
self.assertIn("last_seen_ts", d)
+ if d["device_id"] == "dehydrated_device":
+ self.assertTrue(d.get("dehydrated"))
+ found_dehydrated = True
+ else:
+ # Either the field is not present, or set to False
+ self.assertFalse(d.get("dehydrated"))
+
+ self.assertTrue(found_dehydrated)
class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py
index feb410a11d..6047ce1f4a 100644
--- a/tests/rest/admin/test_event_reports.py
+++ b/tests/rest/admin/test_event_reports.py
@@ -378,6 +378,41 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(channel.json_body["event_reports"]), 1)
self.assertNotIn("next_token", channel.json_body)
+ def test_filter_against_event_sender(self) -> None:
+ """
+ Tests filtering by the sender of the reported event
+ """
+ # first grab all the reports
+ channel = self.make_request(
+ "GET",
+ self.url,
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(channel.code, 200)
+
+ # build the set of report IDs for events sent by one of the users
+ locally_filtered_report_ids = set()
+ for event_report in channel.json_body["event_reports"]:
+ if event_report["sender"] == self.other_user:
+ locally_filtered_report_ids.add(event_report["id"])
+
+ # grab the report ids by sender and compare to filtered report ids
+ channel = self.make_request(
+ "GET",
+ f"{self.url}?event_sender_user_id={self.other_user}",
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code)
+ self.assertEqual(channel.json_body["total"], len(locally_filtered_report_ids))
+
+ event_reports = channel.json_body["event_reports"]
+ server_filtered_report_ids = set()
+ for event_report in event_reports:
+ server_filtered_report_ids.add(event_report["id"])
+ self.assertIncludes(
+ locally_filtered_report_ids, server_filtered_report_ids, exact=True
+ )
+
def _create_event_and_report(self, room_id: str, user_tok: str) -> None:
"""Create and report events"""
resp = self.helper.send(room_id, tok=user_tok)
diff --git a/tests/rest/admin/test_federation.py b/tests/rest/admin/test_federation.py
index c2015774a1..d5ae3345f5 100644
--- a/tests/rest/admin/test_federation.py
+++ b/tests/rest/admin/test_federation.py
@@ -96,7 +96,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
- # unkown order_by
+ # unknown order_by
channel = self.make_request(
"GET",
self.url + "?order_by=bar",
diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py
index f378165513..da0e9749aa 100644
--- a/tests/rest/admin/test_media.py
+++ b/tests/rest/admin/test_media.py
@@ -35,7 +35,8 @@ from synapse.server import HomeServer
from synapse.util import Clock
from tests import unittest
-from tests.test_utils import SMALL_PNG
+from tests.test_utils import SMALL_CMYK_JPEG, SMALL_PNG
+from tests.unittest import override_config
VALID_TIMESTAMP = 1609459200000 # 2021-01-01 in milliseconds
INVALID_TIMESTAMP_IN_S = 1893456000 # 2030-01-01 in seconds
@@ -126,6 +127,7 @@ class DeleteMediaByIDTestCase(_AdminMediaTests):
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only delete local media", channel.json_body["error"])
+ @override_config({"enable_authenticated_media": False})
def test_delete_media(self) -> None:
"""
Tests that deleting media succeeds
@@ -371,6 +373,7 @@ class DeleteMediaByDateSizeTestCase(_AdminMediaTests):
self._access_media(server_and_media_id, False)
+ @override_config({"enable_authenticated_media": False})
def test_keep_media_by_date(self) -> None:
"""
Tests that media is not deleted if it is newer than `before_ts`
@@ -408,6 +411,7 @@ class DeleteMediaByDateSizeTestCase(_AdminMediaTests):
self._access_media(server_and_media_id, False)
+ @override_config({"enable_authenticated_media": False})
def test_keep_media_by_size(self) -> None:
"""
Tests that media is not deleted if its size is smaller than or equal
@@ -443,6 +447,7 @@ class DeleteMediaByDateSizeTestCase(_AdminMediaTests):
self._access_media(server_and_media_id, False)
+ @override_config({"enable_authenticated_media": False})
def test_keep_media_by_user_avatar(self) -> None:
"""
Tests that we do not delete media if it is used as a user avatar
@@ -487,6 +492,7 @@ class DeleteMediaByDateSizeTestCase(_AdminMediaTests):
self._access_media(server_and_media_id, False)
+ @override_config({"enable_authenticated_media": False})
def test_keep_media_by_room_avatar(self) -> None:
"""
Tests that we do not delete media if it is used as a room avatar
@@ -592,23 +598,27 @@ class DeleteMediaByDateSizeTestCase(_AdminMediaTests):
class QuarantineMediaByIDTestCase(_AdminMediaTests):
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.store = hs.get_datastores().main
- self.server_name = hs.hostname
-
- self.admin_user = self.register_user("admin", "pass", admin=True)
- self.admin_user_tok = self.login("admin", "pass")
-
+ def upload_media_and_return_media_id(self, data: bytes) -> str:
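+ """Upload the given media as the admin user and return its media ID."""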
# Upload some media into the room
response = self.helper.upload_media(
- SMALL_PNG,
+ data,
tok=self.admin_user_tok,
expect_code=200,
)
# Extract media ID from the response
server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://'
- self.media_id = server_and_media_id.split("/")[1]
+ return server_and_media_id.split("/")[1]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.store = hs.get_datastores().main
+ self.server_name = hs.hostname
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+ self.media_id = self.upload_media_and_return_media_id(SMALL_PNG)
+ self.media_id_2 = self.upload_media_and_return_media_id(SMALL_PNG)
+ self.media_id_3 = self.upload_media_and_return_media_id(SMALL_PNG)
+ self.media_id_other = self.upload_media_and_return_media_id(SMALL_CMYK_JPEG)
self.url = "/_synapse/admin/v1/media/%s/%s/%s"
@parameterized.expand(["quarantine", "unquarantine"])
@@ -680,6 +690,52 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests):
assert media_info is not None
self.assertFalse(media_info.quarantined_by)
+ def test_quarantine_media_match_hash(self) -> None:
+ """
+ Tests that quarantining a piece of media also quarantines all media with the same hash
+ """
+
+ media_info = self.get_success(self.store.get_local_media(self.media_id))
+ assert media_info is not None
+ self.assertFalse(media_info.quarantined_by)
+
+ # quarantine the media
+ channel = self.make_request(
+ "POST",
+ self.url % ("quarantine", self.server_name, self.media_id),
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertFalse(channel.json_body)
+
+ # Test that ALL similar media was quarantined.
+ for media in [self.media_id, self.media_id_2, self.media_id_3]:
+ media_info = self.get_success(self.store.get_local_media(media))
+ assert media_info is not None
+ self.assertTrue(media_info.quarantined_by)
+
+ # Test that other media was not.
+ media_info = self.get_success(self.store.get_local_media(self.media_id_other))
+ assert media_info is not None
+ self.assertFalse(media_info.quarantined_by)
+
+ # remove from quarantine
+ channel = self.make_request(
+ "POST",
+ self.url % ("unquarantine", self.server_name, self.media_id),
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertFalse(channel.json_body)
+
+ # Test that ALL similar media is now reset.
+ for media in [self.media_id, self.media_id_2, self.media_id_3]:
+ media_info = self.get_success(self.store.get_local_media(media))
+ assert media_info is not None
+ self.assertFalse(media_info.quarantined_by)
+
def test_quarantine_protected_media(self) -> None:
"""
Tests that quarantining protected media fails
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index 95ed736451..e22dfcba1b 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -369,6 +369,47 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
self._is_blocked(room_id)
+ def test_invited_users_not_joined_to_new_room(self) -> None:
+ """
+ Test that when a new room id is provided, users who are only invited to,
+ but have not joined, the original room are not moved to the new room.
+ """
+ invitee = self.register_user("invitee", "pass")
+
+ self.helper.invite(
+ self.room_id, self.other_user, invitee, tok=self.other_user_tok
+ )
+
+ # verify that user is invited
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/v3/rooms/{self.room_id}/members?membership=invite",
+ access_token=self.other_user_tok,
+ )
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(len(channel.json_body["chunk"]), 1)
+ invite = channel.json_body["chunk"][0]
+ self.assertEqual(invite["state_key"], invitee)
+
+ # shutdown room
+ channel = self.make_request(
+ "DELETE",
+ self.url,
+ {"new_room_user_id": self.admin_user},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(len(channel.json_body["kicked_users"]), 2)
+
+ # joined member is moved to new room but invited user is not
+ users_in_room = self.get_success(
+ self.store.get_users_in_room(channel.json_body["new_room_id"])
+ )
+ self.assertNotIn(invitee, users_in_room)
+ self.assertIn(self.other_user, users_in_room)
+ self._is_purged(self.room_id)
+ self._has_no_members(self.room_id)
+
def test_shutdown_room_consent(self) -> None:
"""Test that we can shutdown rooms with local users who have not
yet accepted the privacy policy. This used to fail when we tried to
@@ -758,6 +799,8 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
self.assertEqual(2, len(channel.json_body["results"]))
self.assertEqual("complete", channel.json_body["results"][0]["status"])
self.assertEqual("complete", channel.json_body["results"][1]["status"])
+ self.assertEqual(self.room_id, channel.json_body["results"][0]["room_id"])
+ self.assertEqual(self.room_id, channel.json_body["results"][1]["room_id"])
delete_ids = {delete_id1, delete_id2}
self.assertTrue(channel.json_body["results"][0]["delete_id"] in delete_ids)
delete_ids.remove(channel.json_body["results"][0]["delete_id"])
@@ -777,6 +820,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
self.assertEqual(1, len(channel.json_body["results"]))
self.assertEqual("complete", channel.json_body["results"][0]["status"])
self.assertEqual(delete_id2, channel.json_body["results"][0]["delete_id"])
+ self.assertEqual(self.room_id, channel.json_body["results"][0]["room_id"])
# get status after more than clearing time for all tasks
self.reactor.advance(TaskScheduler.KEEP_TASKS_FOR_MS / 1000 / 2)
@@ -1237,6 +1281,9 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
self.assertEqual(
delete_id, channel_room_id.json_body["results"][0]["delete_id"]
)
+ self.assertEqual(
+ self.room_id, channel_room_id.json_body["results"][0]["room_id"]
+ )
# get information by delete_id
channel_delete_id = self.make_request(
@@ -1249,6 +1296,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
channel_delete_id.code,
msg=channel_delete_id.json_body,
)
+ self.assertEqual(self.room_id, channel_delete_id.json_body["room_id"])
# test values that are the same in both responses
for content in [
@@ -1282,6 +1330,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
+ @unittest.override_config({"room_list_publication_rules": [{"action": "allow"}]})
def test_list_rooms(self) -> None:
"""Test that we can list rooms"""
# Create 3 test rooms
@@ -1311,7 +1360,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
# Check that response json body contains a "rooms" key
self.assertTrue(
"rooms" in channel.json_body,
- msg="Response body does not " "contain a 'rooms' key",
+ msg="Response body does not contain a 'rooms' key",
)
# Check that 3 rooms were returned
@@ -1795,6 +1844,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
self.assertEqual(room_id, channel.json_body["rooms"][0].get("room_id"))
self.assertEqual("ж", channel.json_body["rooms"][0].get("name"))
+ @unittest.override_config({"room_list_publication_rules": [{"action": "allow"}]})
def test_filter_public_rooms(self) -> None:
self.helper.create_room_as(
self.admin_user, tok=self.admin_user_tok, is_public=True
@@ -1872,6 +1922,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
self.assertEqual(1, response.json_body["total_rooms"])
self.assertEqual(1, len(response.json_body["rooms"]))
+ @unittest.override_config({"room_list_publication_rules": [{"action": "allow"}]})
def test_single_room(self) -> None:
"""Test that a single room can be requested correctly"""
# Create two test rooms
@@ -2035,6 +2086,52 @@ class RoomTestCase(unittest.HomeserverTestCase):
# the create_room already does the right thing, so no need to verify that we got
# the state events it created.
+ def test_room_state_param(self) -> None:
+ """Test that filtering by state event type works when requesting state"""
+ room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+ channel = self.make_request(
+ "GET",
+ f"/_synapse/admin/v1/rooms/{room_id}/state?type=m.room.member",
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code)
+ state = channel.json_body["state"]
+ # only one member has joined so there should be one membership event
+ self.assertEqual(1, len(state))
+ event = state[0]
+ self.assertEqual(event["type"], "m.room.member")
+ self.assertEqual(event["state_key"], self.admin_user)
+
+ def test_room_state_param_empty(self) -> None:
+ """Test that passing an empty string as state filter param returns no state events"""
+ room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+ channel = self.make_request(
+ "GET",
+ f"/_synapse/admin/v1/rooms/{room_id}/state?type=",
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code)
+ state = channel.json_body["state"]
+ self.assertEqual(5, len(state))
+
+ def test_room_state_param_not_in_room(self) -> None:
+ """
+ Test that passing a state filter param for a state event not in the room
+ returns no state events
+ """
+ room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+ channel = self.make_request(
+ "GET",
+ f"/_synapse/admin/v1/rooms/{room_id}/state?type=m.room.custom",
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code)
+ state = channel.json_body["state"]
+ self.assertEqual(0, len(state))
+
def _set_canonical_alias(
self, room_id: str, test_alias: str, admin_user_tok: str
) -> None:
@@ -3050,7 +3147,7 @@ PURGE_TABLES = [
"pusher_throttle",
"room_account_data",
"room_tags",
- # "state_groups", # Current impl leaves orphaned state groups around.
+ "state_groups",
"state_groups_state",
"federation_inbound_events_staging",
]
diff --git a/tests/rest/admin/test_scheduled_tasks.py b/tests/rest/admin/test_scheduled_tasks.py
new file mode 100644
index 0000000000..9654e9322b
--- /dev/null
+++ b/tests/rest/admin/test_scheduled_tasks.py
@@ -0,0 +1,192 @@
+#
+# This file is licensed under the Affero General Public License (AGPL) version 3.
+#
+# Copyright (C) 2025 New Vector, Ltd
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# See the GNU Affero General Public License for more details:
+# <https://www.gnu.org/licenses/agpl-3.0.html>.
+#
+#
+#
+from typing import Mapping, Optional, Sequence, Tuple
+
+from twisted.test.proto_helpers import MemoryReactor
+
+import synapse.rest.admin
+from synapse.api.errors import Codes
+from synapse.rest.client import login
+from synapse.server import HomeServer
+from synapse.types import JsonMapping, ScheduledTask, TaskStatus
+from synapse.util import Clock
+
+from tests import unittest
+
+
+class ScheduledTasksAdminApiTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.store = hs.get_datastores().main
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+ self._task_scheduler = hs.get_task_scheduler()
+
+ # create and schedule a few tasks
+ async def _test_task(
+ task: ScheduledTask,
+ ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]:
+ return TaskStatus.ACTIVE, None, None
+
+ async def _finished_test_task(
+ task: ScheduledTask,
+ ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]:
+ return TaskStatus.COMPLETE, None, None
+
+ async def _failed_test_task(
+ task: ScheduledTask,
+ ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]:
+ return TaskStatus.FAILED, None, "Everything failed"
+
+ self._task_scheduler.register_action(_test_task, "test_task")
+ self.get_success(
+ self._task_scheduler.schedule_task("test_task", resource_id="test")
+ )
+
+ self._task_scheduler.register_action(_finished_test_task, "finished_test_task")
+ self.get_success(
+ self._task_scheduler.schedule_task(
+ "finished_test_task", resource_id="finished_task"
+ )
+ )
+
+ self._task_scheduler.register_action(_failed_test_task, "failed_test_task")
+ self.get_success(
+ self._task_scheduler.schedule_task(
+ "failed_test_task", resource_id="failed_task"
+ )
+ )
+
+ def check_scheduled_tasks_response(self, scheduled_tasks: Sequence[Mapping]) -> list:
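+ """
+ Pick out the tasks scheduled in `prepare` from the response and
+ sanity-check their status and action, returning the matching entries.
+ """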
+ result = []
+ for task in scheduled_tasks:
+ if task["resource_id"] == "test":
+ self.assertEqual(task["status"], TaskStatus.ACTIVE)
+ self.assertEqual(task["action"], "test_task")
+ result.append(task)
+ if task["resource_id"] == "finished_task":
+ self.assertEqual(task["status"], TaskStatus.COMPLETE)
+ self.assertEqual(task["action"], "finished_test_task")
+ result.append(task)
+ if task["resource_id"] == "failed_task":
+ self.assertEqual(task["status"], TaskStatus.FAILED)
+ self.assertEqual(task["action"], "failed_test_task")
+ result.append(task)
+
+ return result
+
+ def test_requester_is_not_admin(self) -> None:
+ """
+ If the user is not a server admin, an error 403 is returned.
+ """
+
+ self.register_user("user", "pass", admin=False)
+ other_user_tok = self.login("user", "pass")
+
+ channel = self.make_request(
+ "GET",
+ "/_synapse/admin/v1/scheduled_tasks",
+ content={},
+ access_token=other_user_tok,
+ )
+
+ self.assertEqual(403, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_scheduled_tasks(self) -> None:
+ """
+ Test that endpoint returns scheduled tasks.
+ """
+
+ channel = self.make_request(
+ "GET",
+ "/_synapse/admin/v1/scheduled_tasks",
+ content={},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ scheduled_tasks = channel.json_body["scheduled_tasks"]
+
+ # make sure we got back all the scheduled tasks
+ found_tasks = self.check_scheduled_tasks_response(scheduled_tasks)
+ self.assertEqual(len(found_tasks), 3)
+
+ def test_filtering_scheduled_tasks(self) -> None:
+ """
+ Test that filtering the scheduled tasks response via query params works as expected.
+ """
+ # filter via job_status
+ channel = self.make_request(
+ "GET",
+ "/_synapse/admin/v1/scheduled_tasks?job_status=active",
+ content={},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ scheduled_tasks = channel.json_body["scheduled_tasks"]
+ found_tasks = self.check_scheduled_tasks_response(scheduled_tasks)
+
+ # only the active task should have been returned
+ self.assertEqual(len(found_tasks), 1)
+ self.assertEqual(found_tasks[0]["status"], "active")
+
+ # filter via action_name
+ channel = self.make_request(
+ "GET",
+ "/_synapse/admin/v1/scheduled_tasks?action_name=test_task",
+ content={},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ scheduled_tasks = channel.json_body["scheduled_tasks"]
+
+ # only test_task should have been returned
+ found_tasks = self.check_scheduled_tasks_response(scheduled_tasks)
+ self.assertEqual(len(found_tasks), 1)
+ self.assertEqual(found_tasks[0]["action"], "test_task")
+
+ # filter via max_timestamp
+ channel = self.make_request(
+ "GET",
+ "/_synapse/admin/v1/scheduled_tasks?max_timestamp=0",
+ content={},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ scheduled_tasks = channel.json_body["scheduled_tasks"]
+ found_tasks = self.check_scheduled_tasks_response(scheduled_tasks)
+
+ # none should have been returned
+ self.assertEqual(len(found_tasks), 0)
+
+ # filter via resource id
+ channel = self.make_request(
+ "GET",
+ "/_synapse/admin/v1/scheduled_tasks?resource_id=failed_task",
+ content={},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ scheduled_tasks = channel.json_body["scheduled_tasks"]
+ found_tasks = self.check_scheduled_tasks_response(scheduled_tasks)
+
+ # only the task with the matching resource id should have been returned
+ self.assertEqual(len(found_tasks), 1)
+ self.assertEqual(found_tasks[0]["resource_id"], "failed_task")
diff --git a/tests/rest/admin/test_server_notice.py b/tests/rest/admin/test_server_notice.py
index 2a1e42bbc8..150caeeee2 100644
--- a/tests/rest/admin/test_server_notice.py
+++ b/tests/rest/admin/test_server_notice.py
@@ -531,9 +531,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
# simulate a change in server config after a server restart.
new_display_name = "new display name"
- self.server_notices_manager._config.servernotices.server_notices_mxid_display_name = (
- new_display_name
- )
+ self.server_notices_manager._config.servernotices.server_notices_mxid_display_name = new_display_name
self.server_notices_manager.get_or_create_notice_room_for_user.cache.invalidate_all()
self.make_request(
@@ -577,9 +575,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
# simulate a change in server config after a server restart.
new_avatar_url = "test/new-url"
- self.server_notices_manager._config.servernotices.server_notices_mxid_avatar_url = (
- new_avatar_url
- )
+ self.server_notices_manager._config.servernotices.server_notices_mxid_avatar_url = new_avatar_url
self.server_notices_manager.get_or_create_notice_room_for_user.cache.invalidate_all()
self.make_request(
@@ -692,9 +688,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
# simulate a change in server config after a server restart.
new_avatar_url = "test/new-url"
- self.server_notices_manager._config.servernotices.server_notices_room_avatar_url = (
- new_avatar_url
- )
+ self.server_notices_manager._config.servernotices.server_notices_room_avatar_url = new_avatar_url
self.server_notices_manager.get_or_create_notice_room_for_user.cache.invalidate_all()
self.make_request(
diff --git a/tests/rest/admin/test_statistics.py b/tests/rest/admin/test_statistics.py
index 5f60e19e56..07ec49c4e5 100644
--- a/tests/rest/admin/test_statistics.py
+++ b/tests/rest/admin/test_statistics.py
@@ -82,7 +82,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
"""
If parameters are invalid, an error is returned.
"""
- # unkown order_by
+ # unknown order_by
channel = self.make_request(
"GET",
self.url + "?order_by=bar",
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 16bb4349f5..412718f06c 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -21,9 +21,12 @@
import hashlib
import hmac
+import json
import os
+import time
import urllib.parse
from binascii import unhexlify
+from http import HTTPStatus
from typing import Dict, List, Optional
from unittest.mock import AsyncMock, Mock, patch
@@ -33,7 +36,13 @@ from twisted.test.proto_helpers import MemoryReactor
from twisted.web.resource import Resource
import synapse.rest.admin
-from synapse.api.constants import ApprovalNoticeMedium, LoginType, UserTypes
+from synapse.api.constants import (
+ ApprovalNoticeMedium,
+ EventContentFields,
+ EventTypes,
+ LoginType,
+ UserTypes,
+)
from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError
from synapse.api.room_versions import RoomVersions
from synapse.media.filepath import MediaFilePaths
@@ -42,6 +51,7 @@ from synapse.rest.client import (
devices,
login,
logout,
+ media,
profile,
register,
room,
@@ -54,7 +64,9 @@ from synapse.types import JsonDict, UserID, create_requester
from synapse.util import Clock
from tests import unittest
+from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.test_utils import SMALL_PNG
+from tests.test_utils.event_injection import inject_event
from tests.unittest import override_config
@@ -316,6 +328,61 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid user type", channel.json_body["error"])
+ @override_config(
+ {
+ "user_types": {
+ "extra_user_types": ["extra1", "extra2"],
+ }
+ }
+ )
+ def test_extra_user_type(self) -> None:
+ """
+ Check that the extra user type can be used when registering a user.
+ """
+
+ def nonce_mac(user_type: str) -> tuple[str, str]:
+ """
+ Get a nonce and the expected HMAC for that nonce.
+ """
+ channel = self.make_request("GET", self.url)
+ nonce = channel.json_body["nonce"]
+
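+ # The MAC covers the nonce, username, password, admin flag and user type, NUL-separated.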
+ want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
+ want_mac.update(
+ nonce.encode("ascii")
+ + b"\x00alice\x00abc123\x00notadmin\x00"
+ + user_type.encode("ascii")
+ )
+ want_mac_str = want_mac.hexdigest()
+
+ return nonce, want_mac_str
+
+ nonce, mac = nonce_mac("extra1")
+ # Valid user_type
+ body = {
+ "nonce": nonce,
+ "username": "alice",
+ "password": "abc123",
+ "user_type": "extra1",
+ "mac": mac,
+ }
+ channel = self.make_request("POST", self.url, body)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ nonce, mac = nonce_mac("extra3")
+ # Invalid user_type
+ body = {
+ "nonce": nonce,
+ "username": "alice",
+ "password": "abc123",
+ "user_type": "extra3",
+ "mac": mac,
+ }
+ channel = self.make_request("POST", self.url, body)
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual("Invalid user type", channel.json_body["error"])
+
def test_displayname(self) -> None:
"""
Test that displayname of new user is set
@@ -715,7 +782,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
- # unkown order_by
+ # unknown order_by
channel = self.make_request(
"GET",
self.url + "?order_by=bar",
@@ -1174,6 +1241,80 @@ class UsersListTestCase(unittest.HomeserverTestCase):
not_user_types=["custom"],
)
+ @override_config(
+ {
+ "user_types": {
+ "extra_user_types": ["extra1", "extra2"],
+ }
+ }
+ )
+ def test_filter_not_user_types_with_extra(self) -> None:
+ """Tests that the endpoint handles the not_user_types param when extra_user_types are configured"""
+
+ regular_user_id = self.register_user("normalo", "secret")
+
+ extra1_user_id = self.register_user("extra1", "secret")
+ self.make_request(
+ "PUT",
+ "/_synapse/admin/v2/users/" + urllib.parse.quote(extra1_user_id),
+ {"user_type": "extra1"},
+ access_token=self.admin_user_tok,
+ )
+
+ def test_user_type(
+ expected_user_ids: List[str], not_user_types: Optional[List[str]] = None
+ ) -> None:
+ """Runs a test for the not_user_types param
+ Args:
+ expected_user_ids: Ids of the users that are expected to be returned
+ not_user_types: List of values for the not_user_types param
+ """
+
+ user_type_query = ""
+
+ if not_user_types is not None:
+ user_type_query = "&".join(
+ [f"not_user_type={u}" for u in not_user_types]
+ )
+
+ test_url = f"{self.url}?{user_type_query}"
+ channel = self.make_request(
+ "GET",
+ test_url,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, channel.code)
+ self.assertEqual(channel.json_body["total"], len(expected_user_ids))
+ self.assertEqual(
+ expected_user_ids,
+ [u["name"] for u in channel.json_body["users"]],
+ )
+
+ # Request without user_types → all users expected
+ test_user_type([self.admin_user, extra1_user_id, regular_user_id])
+
+ # Request and exclude extra1 user type
+ test_user_type(
+ [self.admin_user, regular_user_id],
+ not_user_types=["extra1"],
+ )
+
+ # Request and exclude extra1 and extra2 user types
+ test_user_type(
+ [self.admin_user, regular_user_id],
+ not_user_types=["extra1", "extra2"],
+ )
+
+ # Request and exclude users with no user type → only the extra1 user is expected
+ test_user_type([extra1_user_id], not_user_types=[""])
+
+ # Request and exclude an unregistered type → expect all users
+ test_user_type(
+ [self.admin_user, extra1_user_id, regular_user_id],
+ not_user_types=["extra3"],
+ )
+
def test_erasure_status(self) -> None:
# Create a new user.
user_id = self.register_user("eraseme", "eraseme")
@@ -1411,9 +1552,6 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
UserID.from_string("@user:test"), "mxc://servername/mediaid"
)
)
- self.get_success(
- self.store.user_add_threepid("@user:test", "email", "foo@bar.com", 0, 0)
- )
def test_no_auth(self) -> None:
"""
@@ -1500,7 +1638,6 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(False, channel.json_body["deactivated"])
- self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
self.assertEqual("User1", channel.json_body["displayname"])
self.assertFalse(channel.json_body["erased"])
@@ -1525,7 +1662,6 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(True, channel.json_body["deactivated"])
- self.assertEqual(0, len(channel.json_body["threepids"]))
self.assertIsNone(channel.json_body["avatar_url"])
self.assertIsNone(channel.json_body["displayname"])
self.assertTrue(channel.json_body["erased"])
@@ -1568,7 +1704,6 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(False, channel.json_body["deactivated"])
- self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
self.assertEqual("User1", channel.json_body["displayname"])
@@ -1592,7 +1727,6 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(True, channel.json_body["deactivated"])
- self.assertEqual(0, len(channel.json_body["threepids"]))
self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
self.assertEqual("User1", channel.json_body["displayname"])
@@ -1622,7 +1756,6 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(False, channel.json_body["deactivated"])
- self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
self.assertIsNone(channel.json_body["avatar_url"])
self.assertIsNone(channel.json_body["displayname"])
@@ -1646,7 +1779,6 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(True, channel.json_body["deactivated"])
- self.assertEqual(0, len(channel.json_body["threepids"]))
self.assertIsNone(channel.json_body["avatar_url"])
self.assertIsNone(channel.json_body["displayname"])
@@ -1817,25 +1949,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
- # threepids not valid
- channel = self.make_request(
- "PUT",
- self.url_other_user,
- access_token=self.admin_user_tok,
- content={"threepids": {"medium": "email", "wrong_address": "id"}},
- )
- self.assertEqual(400, channel.code, msg=channel.json_body)
- self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
-
- channel = self.make_request(
- "PUT",
- self.url_other_user,
- access_token=self.admin_user_tok,
- content={"threepids": {"address": "value"}},
- )
- self.assertEqual(400, channel.code, msg=channel.json_body)
- self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
-
def test_get_user(self) -> None:
"""
Test a simple get of a user.
@@ -1890,8 +2003,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("Bob's name", channel.json_body["displayname"])
- self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
- self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
self.assertTrue(channel.json_body["admin"])
self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
self._check_fields(channel.json_body)
@@ -1906,8 +2017,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("Bob's name", channel.json_body["displayname"])
- self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
- self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
self.assertTrue(channel.json_body["admin"])
self.assertFalse(channel.json_body["is_guest"])
self.assertFalse(channel.json_body["deactivated"])
@@ -1945,9 +2054,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("Bob's name", channel.json_body["displayname"])
- self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
- self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
- self.assertEqual(1, len(channel.json_body["threepids"]))
self.assertEqual(
"external_id1", channel.json_body["external_ids"][0]["external_id"]
)
@@ -1969,8 +2075,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("Bob's name", channel.json_body["displayname"])
- self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
- self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
self.assertFalse(channel.json_body["admin"])
self.assertFalse(channel.json_body["is_guest"])
self.assertFalse(channel.json_body["deactivated"])
@@ -2062,123 +2166,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertFalse(channel.json_body["admin"])
- @override_config(
- {
- "email": {
- "enable_notifs": True,
- "notif_for_new_users": True,
- "notif_from": "test@example.com",
- },
- "public_baseurl": "https://example.com",
- }
- )
- def test_create_user_email_notif_for_new_users(self) -> None:
- """
- Check that a new regular user is created successfully and
- got an email pusher.
- """
- url = self.url_prefix % "@bob:test"
-
- # Create user
- body = {
- "password": "abc123",
- # Note that the given email is not in canonical form.
- "threepids": [{"medium": "email", "address": "Bob@bob.bob"}],
- }
-
- channel = self.make_request(
- "PUT",
- url,
- access_token=self.admin_user_tok,
- content=body,
- )
-
- self.assertEqual(201, channel.code, msg=channel.json_body)
- self.assertEqual("@bob:test", channel.json_body["name"])
- self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
- self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
-
- pushers = list(
- self.get_success(self.store.get_pushers_by({"user_name": "@bob:test"}))
- )
- self.assertEqual(len(pushers), 1)
- self.assertEqual("@bob:test", pushers[0].user_name)
-
- @override_config(
- {
- "email": {
- "enable_notifs": False,
- "notif_for_new_users": False,
- "notif_from": "test@example.com",
- },
- "public_baseurl": "https://example.com",
- }
- )
- def test_create_user_email_no_notif_for_new_users(self) -> None:
- """
- Check that a new regular user is created successfully and
- got not an email pusher.
- """
- url = self.url_prefix % "@bob:test"
-
- # Create user
- body = {
- "password": "abc123",
- "threepids": [{"medium": "email", "address": "bob@bob.bob"}],
- }
-
- channel = self.make_request(
- "PUT",
- url,
- access_token=self.admin_user_tok,
- content=body,
- )
-
- self.assertEqual(201, channel.code, msg=channel.json_body)
- self.assertEqual("@bob:test", channel.json_body["name"])
- self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
- self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
-
- pushers = list(
- self.get_success(self.store.get_pushers_by({"user_name": "@bob:test"}))
- )
- self.assertEqual(len(pushers), 0)
-
- @override_config(
- {
- "email": {
- "enable_notifs": True,
- "notif_for_new_users": True,
- "notif_from": "test@example.com",
- },
- "public_baseurl": "https://example.com",
- }
- )
- def test_create_user_email_notif_for_new_users_with_msisdn_threepid(self) -> None:
- """
- Check that a new regular user is created successfully when they have a msisdn
- threepid and email notif_for_new_users is set to True.
- """
- url = self.url_prefix % "@bob:test"
-
- # Create user
- body = {
- "password": "abc123",
- "threepids": [{"medium": "msisdn", "address": "1234567890"}],
- }
-
- channel = self.make_request(
- "PUT",
- url,
- access_token=self.admin_user_tok,
- content=body,
- )
-
- self.assertEqual(201, channel.code, msg=channel.json_body)
- self.assertEqual("@bob:test", channel.json_body["name"])
- self.assertEqual("msisdn", channel.json_body["threepids"][0]["medium"])
- self.assertEqual("1234567890", channel.json_body["threepids"][0]["address"])
-
def test_set_password(self) -> None:
"""
Test setting a new password for another user.
@@ -2222,89 +2209,8 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual("foobar", channel.json_body["displayname"])
-
- def test_set_threepid(self) -> None:
- """
- Test setting threepid for an other user.
- """
-
- # Add two threepids to user
- channel = self.make_request(
- "PUT",
- self.url_other_user,
- access_token=self.admin_user_tok,
- content={
- "threepids": [
- {"medium": "email", "address": "bob1@bob.bob"},
- {"medium": "email", "address": "bob2@bob.bob"},
- ],
- },
- )
-
- self.assertEqual(200, channel.code, msg=channel.json_body)
- self.assertEqual("@user:test", channel.json_body["name"])
- self.assertEqual(2, len(channel.json_body["threepids"]))
- # result does not always have the same sort order, therefore it becomes sorted
- sorted_result = sorted(
- channel.json_body["threepids"], key=lambda k: k["address"]
- )
- self.assertEqual("email", sorted_result[0]["medium"])
- self.assertEqual("bob1@bob.bob", sorted_result[0]["address"])
- self.assertEqual("email", sorted_result[1]["medium"])
- self.assertEqual("bob2@bob.bob", sorted_result[1]["address"])
- self._check_fields(channel.json_body)
-
- # Set a new and remove a threepid
- channel = self.make_request(
- "PUT",
- self.url_other_user,
- access_token=self.admin_user_tok,
- content={
- "threepids": [
- {"medium": "email", "address": "bob2@bob.bob"},
- {"medium": "email", "address": "bob3@bob.bob"},
- ],
- },
- )
-
- self.assertEqual(200, channel.code, msg=channel.json_body)
- self.assertEqual("@user:test", channel.json_body["name"])
- self.assertEqual(2, len(channel.json_body["threepids"]))
- self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
- self.assertEqual("bob2@bob.bob", channel.json_body["threepids"][0]["address"])
- self.assertEqual("email", channel.json_body["threepids"][1]["medium"])
- self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][1]["address"])
- self._check_fields(channel.json_body)
-
- # Get user
- channel = self.make_request(
- "GET",
- self.url_other_user,
- access_token=self.admin_user_tok,
- )
-
- self.assertEqual(200, channel.code, msg=channel.json_body)
- self.assertEqual("@user:test", channel.json_body["name"])
- self.assertEqual(2, len(channel.json_body["threepids"]))
- self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
- self.assertEqual("bob2@bob.bob", channel.json_body["threepids"][0]["address"])
- self.assertEqual("email", channel.json_body["threepids"][1]["medium"])
- self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][1]["address"])
- self._check_fields(channel.json_body)
-
- # Remove threepids
- channel = self.make_request(
- "PUT",
- self.url_other_user,
- access_token=self.admin_user_tok,
- content={"threepids": []},
- )
- self.assertEqual(200, channel.code, msg=channel.json_body)
- self.assertEqual("@user:test", channel.json_body["name"])
- self.assertEqual(0, len(channel.json_body["threepids"]))
- self._check_fields(channel.json_body)
-
- def test_set_duplicate_threepid(self) -> None:
"""
Test setting the same threepid for a second user.
First user loses and second user gets mapping of this threepid.
@@ -2328,9 +2232,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(first_user, channel.json_body["name"])
- self.assertEqual(1, len(channel.json_body["threepids"]))
- self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
- self.assertEqual("bob1@bob.bob", channel.json_body["threepids"][0]["address"])
self._check_fields(channel.json_body)
# Add threepids to other user
@@ -2347,9 +2248,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
- self.assertEqual(1, len(channel.json_body["threepids"]))
- self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
- self.assertEqual("bob2@bob.bob", channel.json_body["threepids"][0]["address"])
self._check_fields(channel.json_body)
# Add two new threepids to other user
@@ -2369,15 +2267,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# other user has these two threepids
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
- self.assertEqual(2, len(channel.json_body["threepids"]))
- # result does not always have the same sort order, therefore it becomes sorted
- sorted_result = sorted(
- channel.json_body["threepids"], key=lambda k: k["address"]
- )
- self.assertEqual("email", sorted_result[0]["medium"])
- self.assertEqual("bob1@bob.bob", sorted_result[0]["address"])
- self.assertEqual("email", sorted_result[1]["medium"])
- self.assertEqual("bob3@bob.bob", sorted_result[1]["address"])
self._check_fields(channel.json_body)
# first_user has no threepid anymore
@@ -2388,7 +2277,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(first_user, channel.json_body["name"])
- self.assertEqual(0, len(channel.json_body["threepids"]))
self._check_fields(channel.json_body)
def test_set_external_id(self) -> None:
@@ -2623,9 +2511,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
UserID.from_string("@user:test"), "mxc://servername/mediaid"
)
)
- self.get_success(
- self.store.user_add_threepid("@user:test", "email", "foo@bar.com", 0, 0)
- )
# Get user
channel = self.make_request(
@@ -2637,7 +2522,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertFalse(channel.json_body["deactivated"])
- self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
self.assertEqual("User", channel.json_body["displayname"])
@@ -2652,7 +2536,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertTrue(channel.json_body["deactivated"])
- self.assertEqual(0, len(channel.json_body["threepids"]))
self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
self.assertEqual("User", channel.json_body["displayname"])
@@ -2671,7 +2554,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertTrue(channel.json_body["deactivated"])
- self.assertEqual(0, len(channel.json_body["threepids"]))
self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
self.assertEqual("User", channel.json_body["displayname"])
@@ -2965,22 +2847,18 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("@user:test", channel.json_body["name"])
self.assertTrue(channel.json_body["admin"])
- def test_set_user_type(self) -> None:
- """
- Test changing user type.
- """
-
- # Set to support type
+ def set_user_type(self, user_type: Optional[str]) -> None:
+ # Set to user_type
channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
- content={"user_type": UserTypes.SUPPORT},
+ content={"user_type": user_type},
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
- self.assertEqual(UserTypes.SUPPORT, channel.json_body["user_type"])
+ self.assertEqual(user_type, channel.json_body["user_type"])
# Get user
channel = self.make_request(
@@ -2991,30 +2869,44 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
- self.assertEqual(UserTypes.SUPPORT, channel.json_body["user_type"])
+ self.assertEqual(user_type, channel.json_body["user_type"])
+
+ def test_set_user_type(self) -> None:
+ """
+ Test changing user type.
+ """
+
+ # Set to support type
+ self.set_user_type(UserTypes.SUPPORT)
# Change back to a regular user
- channel = self.make_request(
- "PUT",
- self.url_other_user,
- access_token=self.admin_user_tok,
- content={"user_type": None},
- )
+ self.set_user_type(None)
- self.assertEqual(200, channel.code, msg=channel.json_body)
- self.assertEqual("@user:test", channel.json_body["name"])
- self.assertIsNone(channel.json_body["user_type"])
+ @override_config({"user_types": {"extra_user_types": ["extra1", "extra2"]}})
+ def test_set_user_type_with_extras(self) -> None:
+ """
+ Test changing user type with extra_user_types configured.
+ """
- # Get user
+ # Check that we can still set to support type
+ self.set_user_type(UserTypes.SUPPORT)
+
+ # Check that we can set to an extra user type
+ self.set_user_type("extra2")
+
+ # Change back to a regular user
+ self.set_user_type(None)
+
+ # Try setting to invalid type
channel = self.make_request(
- "GET",
+ "PUT",
self.url_other_user,
access_token=self.admin_user_tok,
+ content={"user_type": "extra3"},
)
- self.assertEqual(200, channel.code, msg=channel.json_body)
- self.assertEqual("@user:test", channel.json_body["name"])
- self.assertIsNone(channel.json_body["user_type"])
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual("Invalid user type", channel.json_body["error"])
def test_accidental_deactivation_prevention(self) -> None:
"""
@@ -3204,7 +3096,6 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content: Content dictionary to check
"""
self.assertIn("displayname", content)
- self.assertIn("threepids", content)
self.assertIn("avatar_url", content)
self.assertIn("admin", content)
self.assertIn("deactivated", content)
@@ -3217,6 +3108,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertIn("consent_ts", content)
self.assertIn("external_ids", content)
self.assertIn("last_seen_ts", content)
+ self.assertIn("suspended", content)
# This key was removed intentionally. Ensure it is not accidentally re-included.
self.assertNotIn("password_hash", content)
@@ -3513,6 +3405,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
+ media.register_servlets,
]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
@@ -3692,7 +3585,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
@parameterized.expand(["GET", "DELETE"])
def test_invalid_parameter(self, method: str) -> None:
"""If parameters are invalid, an error is returned."""
- # unkown order_by
+ # unknown order_by
channel = self.make_request(
method,
self.url + "?order_by=bar",
@@ -3887,9 +3780,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
image_data1 = SMALL_PNG
# Resolution: 1×1, MIME type: image/gif, Extension: gif, Size: 35 B
image_data2 = unhexlify(
- b"47494638376101000100800100000000"
- b"ffffff2c00000000010001000002024c"
- b"01003b"
+ b"47494638376101000100800100000000ffffff2c00000000010001000002024c01003b"
)
# Resolution: 1×1, MIME type: image/bmp, Extension: bmp, Size: 54 B
image_data3 = unhexlify(
@@ -4019,7 +3910,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
# Try to access a media and to create `last_access_ts`
channel = self.make_request(
"GET",
- f"/_matrix/media/v3/download/{server_and_media_id}",
+ f"/_matrix/client/v1/media/download/{server_and_media_id}",
shorthand=False,
access_token=user_token,
)
@@ -4858,100 +4749,6 @@ class UsersByExternalIdTestCase(unittest.HomeserverTestCase):
)
-class UsersByThreePidTestCase(unittest.HomeserverTestCase):
- servlets = [
- synapse.rest.admin.register_servlets,
- login.register_servlets,
- ]
-
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.store = hs.get_datastores().main
-
- self.admin_user = self.register_user("admin", "pass", admin=True)
- self.admin_user_tok = self.login("admin", "pass")
-
- self.other_user = self.register_user("user", "pass")
- self.get_success(
- self.store.user_add_threepid(
- self.other_user, "email", "user@email.com", 1, 1
- )
- )
- self.get_success(
- self.store.user_add_threepid(self.other_user, "msidn", "+1-12345678", 1, 1)
- )
-
- def test_no_auth(self) -> None:
- """Try to look up a user without authentication."""
- url = "/_synapse/admin/v1/threepid/email/users/user%40email.com"
-
- channel = self.make_request(
- "GET",
- url,
- )
-
- self.assertEqual(401, channel.code, msg=channel.json_body)
- self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
-
- def test_medium_does_not_exist(self) -> None:
- """Tests that both a lookup for a medium that does not exist and a user that
- doesn't exist with that third party ID returns a 404"""
- # test for unknown medium
- url = "/_synapse/admin/v1/threepid/publickey/users/unknown-key"
-
- channel = self.make_request(
- "GET",
- url,
- access_token=self.admin_user_tok,
- )
-
- self.assertEqual(404, channel.code, msg=channel.json_body)
- self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
-
- # test for unknown user with a known medium
- url = "/_synapse/admin/v1/threepid/email/users/unknown"
-
- channel = self.make_request(
- "GET",
- url,
- access_token=self.admin_user_tok,
- )
-
- self.assertEqual(404, channel.code, msg=channel.json_body)
- self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
-
- def test_success(self) -> None:
- """Tests a successful medium + address lookup"""
- # test for email medium with encoded value of user@email.com
- url = "/_synapse/admin/v1/threepid/email/users/user%40email.com"
-
- channel = self.make_request(
- "GET",
- url,
- access_token=self.admin_user_tok,
- )
-
- self.assertEqual(200, channel.code, msg=channel.json_body)
- self.assertEqual(
- {"user_id": self.other_user},
- channel.json_body,
- )
-
- # test for msidn medium with encoded value of +1-12345678
- url = "/_synapse/admin/v1/threepid/msidn/users/%2B1-12345678"
-
- channel = self.make_request(
- "GET",
- url,
- access_token=self.admin_user_tok,
- )
-
- self.assertEqual(200, channel.code, msg=channel.json_body)
- self.assertEqual(
- {"user_id": self.other_user},
- channel.json_body,
- )
-
-
class AllowCrossSigningReplacementTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets,
@@ -5024,7 +4821,6 @@ class UserSuspensionTestCase(unittest.HomeserverTestCase):
self.store = hs.get_datastores().main
- @override_config({"experimental_features": {"msc3823_account_suspension": True}})
def test_suspend_user(self) -> None:
# test that suspending user works
channel = self.make_request(
@@ -5089,3 +4885,766 @@ class UserSuspensionTestCase(unittest.HomeserverTestCase):
res5 = self.get_success(self.store.get_user_suspended_status(self.bad_user))
self.assertEqual(True, res5)
+
+
+class UserRedactionTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ admin.register_servlets,
+ room.register_servlets,
+ sync.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.admin = self.register_user("thomas", "pass", True)
+ self.admin_tok = self.login("thomas", "pass")
+
+ self.bad_user = self.register_user("teresa", "pass")
+ self.bad_user_tok = self.login("teresa", "pass")
+
+ self.store = hs.get_datastores().main
+
+ self.spam_checker = hs.get_module_api_callbacks().spam_checker
+
+ # create rooms - room versions 11+ store the `redacts` key in content while
+ # earlier ones don't so we use a mix of room versions
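+        # (e.g. a v11 redaction carries {"content": {"redacts": "$event_id"}},
+        # whereas v7 puts "redacts" at the top level of the event)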
+ self.rm1 = self.helper.create_room_as(
+ self.admin, tok=self.admin_tok, room_version="7"
+ )
+ self.rm2 = self.helper.create_room_as(self.admin, tok=self.admin_tok)
+ self.rm3 = self.helper.create_room_as(
+ self.admin, tok=self.admin_tok, room_version="11"
+ )
+
+ def test_redact_messages_all_rooms(self) -> None:
+ """
+        Test that a request to redact events in all rooms the user is a member of succeeds
+ """
+ # join rooms, send some messages
+ originals = []
+ for rm in [self.rm1, self.rm2, self.rm3]:
+ join = self.helper.join(rm, self.bad_user, tok=self.bad_user_tok)
+ originals.append(join["event_id"])
+ for i in range(15):
+ event = {"body": f"hello{i}", "msgtype": "m.text"}
+ res = self.helper.send_event(
+ rm, "m.room.message", event, tok=self.bad_user_tok, expect_code=200
+ )
+ originals.append(res["event_id"])
+
+ # redact all events in all rooms
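+        # (an empty `rooms` list means every room the user is a member of)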
+ channel = self.make_request(
+ "POST",
+ f"/_synapse/admin/v1/user/{self.bad_user}/redact",
+ content={"rooms": []},
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel.code, 200)
+
+ matched = []
+ for rm in [self.rm1, self.rm2, self.rm3]:
+ filter = json.dumps({"types": [EventTypes.Redaction]})
+ channel = self.make_request(
+ "GET",
+ f"rooms/{rm}/messages?filter={filter}&limit=50",
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel.code, 200)
+
+ for event in channel.json_body["chunk"]:
+ for event_id in originals:
+ if (
+ event["type"] == "m.room.redaction"
+ and event["redacts"] == event_id
+ ):
+ matched.append(event_id)
+ self.assertEqual(len(matched), len(originals))
+
+ def test_redact_messages_specific_rooms(self) -> None:
+ """
+        Test that a request to redact events in specified rooms the user is a member of succeeds
+ """
+
+ originals = []
+ for rm in [self.rm1, self.rm2, self.rm3]:
+ join = self.helper.join(rm, self.bad_user, tok=self.bad_user_tok)
+ originals.append(join["event_id"])
+ for i in range(15):
+ event = {"body": f"hello{i}", "msgtype": "m.text"}
+ res = self.helper.send_event(
+ rm, "m.room.message", event, tok=self.bad_user_tok
+ )
+ originals.append(res["event_id"])
+
+ # redact messages in rooms 1 and 3
+ channel = self.make_request(
+ "POST",
+ f"/_synapse/admin/v1/user/{self.bad_user}/redact",
+ content={"rooms": [self.rm1, self.rm3]},
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel.code, 200)
+
+ # messages in requested rooms are redacted
+ for rm in [self.rm1, self.rm3]:
+ filter = json.dumps({"types": [EventTypes.Redaction]})
+ channel = self.make_request(
+ "GET",
+ f"rooms/{rm}/messages?filter={filter}&limit=50",
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel.code, 200)
+
+ matches = []
+ for event in channel.json_body["chunk"]:
+ for event_id in originals:
+ if (
+ event["type"] == "m.room.redaction"
+ and event["redacts"] == event_id
+ ):
+ matches.append((event_id, event))
+            # we redacted 16 events per room (15 messages plus the join event)
+ self.assertEqual(len(matches), 16)
+
+ channel = self.make_request(
+ "GET", f"rooms/{self.rm2}/messages?limit=50", access_token=self.admin_tok
+ )
+ self.assertEqual(channel.code, 200)
+
+        # messages in the remaining room are not redacted
+ for event in channel.json_body["chunk"]:
+ if event["type"] == "m.room.redaction":
+ self.fail("found redaction in room 2")
+
+ def test_redact_status(self) -> None:
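+        """
+        Test that the redact_status endpoint reports completion and lists any
+        failed redactions.
+        """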
+ rm2_originals = []
+ for rm in [self.rm1, self.rm2, self.rm3]:
+ join = self.helper.join(rm, self.bad_user, tok=self.bad_user_tok)
+ if rm == self.rm2:
+ rm2_originals.append(join["event_id"])
+ for i in range(5):
+ event = {"body": f"hello{i}", "msgtype": "m.text"}
+ res = self.helper.send_event(
+ rm, "m.room.message", event, tok=self.bad_user_tok
+ )
+ if rm == self.rm2:
+ rm2_originals.append(res["event_id"])
+
+ # redact messages in rooms 1 and 3
+ channel = self.make_request(
+ "POST",
+ f"/_synapse/admin/v1/user/{self.bad_user}/redact",
+ content={"rooms": [self.rm1, self.rm3]},
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel.code, 200)
+ id = channel.json_body.get("redact_id")
+
+ channel2 = self.make_request(
+ "GET",
+ f"/_synapse/admin/v1/user/redact_status/{id}",
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel2.code, 200)
+ self.assertEqual(channel2.json_body.get("status"), "complete")
+ self.assertEqual(channel2.json_body.get("failed_redactions"), {})
+
+ # mock that will cause persisting the redaction events to fail
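+        # (returning a string from `check_event_for_spam` marks the event as
+        # spam, so the redaction events are rejected rather than persisted)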
+ async def check_event_for_spam(event: str) -> str:
+ return "spam"
+
+ self.spam_checker.check_event_for_spam = check_event_for_spam # type: ignore
+
+ channel3 = self.make_request(
+ "POST",
+ f"/_synapse/admin/v1/user/{self.bad_user}/redact",
+ content={"rooms": [self.rm2]},
+ access_token=self.admin_tok,
+ )
+        self.assertEqual(channel3.code, 200)
+ id = channel3.json_body.get("redact_id")
+
+ channel4 = self.make_request(
+ "GET",
+ f"/_synapse/admin/v1/user/redact_status/{id}",
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel4.code, 200)
+ self.assertEqual(channel4.json_body.get("status"), "complete")
+ failed_redactions = channel4.json_body.get("failed_redactions")
+ assert failed_redactions is not None
+ matched = []
+ for original in rm2_originals:
+ if failed_redactions.get(original) is not None:
+ matched.append(original)
+ self.assertEqual(len(matched), len(rm2_originals))
+
+ def test_admin_redact_works_if_user_kicked_or_banned(self) -> None:
+ originals1 = []
+ originals2 = []
+ for rm in [self.rm1, self.rm2, self.rm3]:
+ join = self.helper.join(rm, self.bad_user, tok=self.bad_user_tok)
+ if rm in [self.rm1, self.rm3]:
+ originals1.append(join["event_id"])
+ else:
+ originals2.append(join["event_id"])
+ for i in range(5):
+ event = {"body": f"hello{i}", "msgtype": "m.text"}
+ res = self.helper.send_event(
+ rm, "m.room.message", event, tok=self.bad_user_tok
+ )
+ if rm in [self.rm1, self.rm3]:
+ originals1.append(res["event_id"])
+ else:
+ originals2.append(res["event_id"])
+
+ # kick user from rooms 1 and 3
+ for r in [self.rm1, self.rm3]:
+ channel = self.make_request(
+ "POST",
+ f"/_matrix/client/r0/rooms/{r}/kick",
+ content={"reason": "being a bummer", "user_id": self.bad_user},
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+
+ # redact messages in room 1 and 3
+ channel1 = self.make_request(
+ "POST",
+ f"/_synapse/admin/v1/user/{self.bad_user}/redact",
+ content={"rooms": [self.rm1, self.rm3]},
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel1.code, 200)
+ id = channel1.json_body.get("redact_id")
+
+ # check that there were no failed redactions in room 1 and 3
+ channel2 = self.make_request(
+ "GET",
+ f"/_synapse/admin/v1/user/redact_status/{id}",
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel2.code, 200)
+ self.assertEqual(channel2.json_body.get("status"), "complete")
+ failed_redactions = channel2.json_body.get("failed_redactions")
+ self.assertEqual(failed_redactions, {})
+
+ # double check
+ for rm in [self.rm1, self.rm3]:
+ filter = json.dumps({"types": [EventTypes.Redaction]})
+ channel3 = self.make_request(
+ "GET",
+ f"rooms/{rm}/messages?filter={filter}&limit=50",
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel3.code, 200)
+
+ matches = []
+ for event in channel3.json_body["chunk"]:
+ for event_id in originals1:
+ if (
+ event["type"] == "m.room.redaction"
+ and event["redacts"] == event_id
+ ):
+ matches.append((event_id, event))
+            # we redacted 6 events per room (5 messages plus the join event)
+ self.assertEqual(len(matches), 6)
+
+ # ban user from room 2
+ channel4 = self.make_request(
+ "POST",
+ f"/_matrix/client/r0/rooms/{self.rm2}/ban",
+ content={"reason": "being a bummer", "user_id": self.bad_user},
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel4.code, HTTPStatus.OK, channel4.result)
+
+        # make a request to redact all of the user's messages
+ channel5 = self.make_request(
+ "POST",
+ f"/_synapse/admin/v1/user/{self.bad_user}/redact",
+ content={"rooms": []},
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel5.code, 200)
+ id2 = channel5.json_body.get("redact_id")
+
+ # check that there were no failed redactions in room 2
+ channel6 = self.make_request(
+ "GET",
+ f"/_synapse/admin/v1/user/redact_status/{id2}",
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel6.code, 200)
+ self.assertEqual(channel6.json_body.get("status"), "complete")
+ failed_redactions = channel6.json_body.get("failed_redactions")
+ self.assertEqual(failed_redactions, {})
+
+ # double check messages in room 2 were redacted
+ filter = json.dumps({"types": [EventTypes.Redaction]})
+ channel7 = self.make_request(
+ "GET",
+ f"rooms/{self.rm2}/messages?filter={filter}&limit=50",
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel7.code, 200)
+
+ matches = []
+ for event in channel7.json_body["chunk"]:
+ for event_id in originals2:
+ if event["type"] == "m.room.redaction" and event["redacts"] == event_id:
+ matches.append((event_id, event))
+        # we redacted 6 events (5 messages plus the join event)
+ self.assertEqual(len(matches), 6)
+
+ def test_redactions_for_remote_user_succeed_with_admin_priv_in_room(self) -> None:
+ """
+ Test that if the admin requester has privileges in a room, redaction requests
+ succeed for a remote user
+ """
+
+ # inject some messages from remote user and collect event ids
+ original_message_ids = []
+ for i in range(5):
+ event = self.get_success(
+ inject_event(
+ self.hs,
+ room_id=self.rm1,
+ type="m.room.message",
+ sender="@remote:remote_server",
+ content={"msgtype": "m.text", "body": f"nefarious_chatter{i}"},
+ )
+ )
+ original_message_ids.append(event.event_id)
+
+ # send a request to redact a remote user's messages in a room.
+ # the server admin created this room and has admin privilege in room
+ channel = self.make_request(
+ "POST",
+ "/_synapse/admin/v1/user/@remote:remote_server/redact",
+ content={"rooms": [self.rm1]},
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel.code, 200)
+ id = channel.json_body.get("redact_id")
+
+ # check that there were no failed redactions
+ channel = self.make_request(
+ "GET",
+ f"/_synapse/admin/v1/user/redact_status/{id}",
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body.get("status"), "complete")
+ failed_redactions = channel.json_body.get("failed_redactions")
+ self.assertEqual(failed_redactions, {})
+
+ filter = json.dumps({"types": [EventTypes.Redaction]})
+ channel = self.make_request(
+ "GET",
+ f"rooms/{self.rm1}/messages?filter={filter}&limit=50",
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel.code, 200)
+
+ for event in channel.json_body["chunk"]:
+ for event_id in original_message_ids:
+ if event["type"] == "m.room.redaction" and event["redacts"] == event_id:
+ original_message_ids.remove(event_id)
+ break
+ # we originally sent 5 messages so 5 should be redacted
+ self.assertEqual(len(original_message_ids), 0)
+
+ def test_redact_redacts_encrypted_messages(self) -> None:
+ """
+ Test that user's encrypted messages are redacted
+ """
+ encrypted_room = self.helper.create_room_as(
+ self.admin, tok=self.admin_tok, room_version="7"
+ )
+ self.helper.send_state(
+ encrypted_room,
+ EventTypes.RoomEncryption,
+ {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"},
+ tok=self.admin_tok,
+ )
+        # join the room and send some messages
+ originals = []
+ join = self.helper.join(encrypted_room, self.bad_user, tok=self.bad_user_tok)
+ originals.append(join["event_id"])
+ for _ in range(15):
+ res = self.helper.send_event(
+ encrypted_room, "m.room.encrypted", {}, tok=self.bad_user_tok
+ )
+ originals.append(res["event_id"])
+
+ # redact user's events
+ channel = self.make_request(
+ "POST",
+ f"/_synapse/admin/v1/user/{self.bad_user}/redact",
+ content={"rooms": []},
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel.code, 200)
+
+ matched = []
+ filter = json.dumps({"types": [EventTypes.Redaction]})
+ channel = self.make_request(
+ "GET",
+ f"rooms/{encrypted_room}/messages?filter={filter}&limit=50",
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel.code, 200)
+
+ for event in channel.json_body["chunk"]:
+ for event_id in originals:
+ if event["type"] == "m.room.redaction" and event["redacts"] == event_id:
+ matched.append(event_id)
+ self.assertEqual(len(matched), len(originals))
+
+
+class UserRedactionBackgroundTaskTestCase(BaseMultiWorkerStreamTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ admin.register_servlets,
+ room.register_servlets,
+ sync.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.admin = self.register_user("thomas", "pass", True)
+ self.admin_tok = self.login("thomas", "pass")
+
+ self.bad_user = self.register_user("teresa", "pass")
+ self.bad_user_tok = self.login("teresa", "pass")
+
+ # create rooms - room versions 11+ store the `redacts` key in content while
+ # earlier ones don't so we use a mix of room versions
+ self.rm1 = self.helper.create_room_as(
+ self.admin, tok=self.admin_tok, room_version="7"
+ )
+ self.rm2 = self.helper.create_room_as(self.admin, tok=self.admin_tok)
+ self.rm3 = self.helper.create_room_as(
+ self.admin, tok=self.admin_tok, room_version="11"
+ )
+
+ @override_config({"run_background_tasks_on": "worker1"})
+ def test_redact_messages_all_rooms(self) -> None:
+ """
+        Test that the redact task runs successfully when `run_background_tasks_on` is specified
+ """
+ self.make_worker_hs(
+ "synapse.app.generic_worker",
+ extra_config={
+ "worker_name": "worker1",
+ "run_background_tasks_on": "worker1",
+ "redis": {"enabled": True},
+ },
+ )
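+        # (the worker above runs the background tasks; redis is enabled so the
+        # two instances can replicate to each other)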
+
+ # join rooms, send some messages
+ original_event_ids = set()
+ for rm in [self.rm1, self.rm2, self.rm3]:
+ join = self.helper.join(rm, self.bad_user, tok=self.bad_user_tok)
+ original_event_ids.add(join["event_id"])
+ for i in range(15):
+ event = {"body": f"hello{i}", "msgtype": "m.text"}
+ res = self.helper.send_event(
+ rm, "m.room.message", event, tok=self.bad_user_tok, expect_code=200
+ )
+ original_event_ids.add(res["event_id"])
+
+ # redact all events in all rooms
+ channel = self.make_request(
+ "POST",
+ f"/_synapse/admin/v1/user/{self.bad_user}/redact",
+ content={"rooms": []},
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel.code, 200)
+ id = channel.json_body.get("redact_id")
+
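+        # poll the status endpoint until the redaction task, which runs on the
+        # background-tasks worker, reports completion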
+ timeout_s = 10
+ start_time = time.time()
+ redact_result = ""
+ while redact_result != "complete":
+ if start_time + timeout_s < time.time():
+ self.fail("Timed out waiting for redactions.")
+
+ channel2 = self.make_request(
+ "GET",
+ f"/_synapse/admin/v1/user/redact_status/{id}",
+ access_token=self.admin_tok,
+ )
+ redact_result = channel2.json_body["status"]
+ if redact_result == "failed":
+ self.fail("Redaction task failed.")
+
+ redaction_ids = set()
+ for rm in [self.rm1, self.rm2, self.rm3]:
+ filter = json.dumps({"types": [EventTypes.Redaction]})
+ channel = self.make_request(
+ "GET",
+ f"rooms/{rm}/messages?filter={filter}&limit=50",
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel.code, 200)
+
+ for event in channel.json_body["chunk"]:
+ if event["type"] == "m.room.redaction":
+ redaction_ids.add(event["redacts"])
+
+ self.assertIncludes(redaction_ids, original_event_ids, exact=True)
+
+
+class GetInvitesFromUserTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ admin.register_servlets,
+ room.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.admin = self.register_user("thomas", "pass", True)
+ self.admin_tok = self.login("thomas", "pass")
+
+ self.bad_user = self.register_user("teresa", "pass")
+ self.bad_user_tok = self.login("teresa", "pass")
+
+ self.random_users = []
+ for i in range(4):
+ self.random_users.append(self.register_user(f"user{i}", f"pass{i}"))
+
+ self.room1 = self.helper.create_room_as(self.bad_user, tok=self.bad_user_tok)
+ self.room2 = self.helper.create_room_as(self.bad_user, tok=self.bad_user_tok)
+ self.room3 = self.helper.create_room_as(self.bad_user, tok=self.bad_user_tok)
+
+ @unittest.override_config(
+ {"rc_invites": {"per_issuer": {"per_second": 1000, "burst_count": 1000}}}
+ )
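+    # (the invite ratelimiter is relaxed above so the test can send many
+    # invites in quick succession)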
+ def test_get_user_invite_count_new_invites_test_case(self) -> None:
+ """
+ Test that new invites that arrive after a provided timestamp are counted
+ """
+ # grab a current timestamp
+ before_invites_sent_ts = self.hs.get_clock().time_msec()
+
+ # bad user sends some invites
+ for room_id in [self.room1, self.room2]:
+ for user in self.random_users:
+ self.helper.invite(room_id, self.bad_user, user, tok=self.bad_user_tok)
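+        # 2 rooms x 4 users = 8 invites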
+
+ # fetch using timestamp, all should be returned
+ channel = self.make_request(
+ "GET",
+ f"/_synapse/admin/v1/users/{self.bad_user}/sent_invite_count?from_ts={before_invites_sent_ts}",
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body["invite_count"], 8)
+
+        # send some more invites; they should be counted in addition to the original 8 when using the same timestamp
+ for user in self.random_users:
+ self.helper.invite(
+ self.room3, src=self.bad_user, targ=user, tok=self.bad_user_tok
+ )
+
+ channel = self.make_request(
+ "GET",
+ f"/_synapse/admin/v1/users/{self.bad_user}/sent_invite_count?from_ts={before_invites_sent_ts}",
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body["invite_count"], 12)
+
+ def test_get_user_invite_count_invites_before_ts_test_case(self) -> None:
+ """
+ Test that invites sent before provided ts are not counted
+ """
+ # bad user sends some invites
+ for room_id in [self.room1, self.room2]:
+ for user in self.random_users:
+ self.helper.invite(room_id, self.bad_user, user, tok=self.bad_user_tok)
+
+ # add a msec between last invite and ts
+ after_invites_sent_ts = self.hs.get_clock().time_msec() + 1
+
+ # fetch invites with timestamp, none should be returned
+ channel = self.make_request(
+ "GET",
+ f"/_synapse/admin/v1/users/{self.bad_user}/sent_invite_count?from_ts={after_invites_sent_ts}",
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body["invite_count"], 0)
+
+ def test_user_invite_count_kick_ban_not_counted(self) -> None:
+ """
+ Test that kicks and bans are not counted in invite count
+ """
+ to_kick_user_id = self.register_user("kick_me", "pass")
+ to_kick_tok = self.login("kick_me", "pass")
+
+ self.helper.join(self.room1, to_kick_user_id, tok=to_kick_tok)
+
+ # grab a current timestamp
+ before_invites_sent_ts = self.hs.get_clock().time_msec()
+
+ # bad user sends some invites (8)
+ for room_id in [self.room1, self.room2]:
+ for user in self.random_users:
+ self.helper.invite(
+ room_id, src=self.bad_user, targ=user, tok=self.bad_user_tok
+ )
+
+ # fetch using timestamp, all invites sent should be counted
+ channel = self.make_request(
+ "GET",
+ f"/_synapse/admin/v1/users/{self.bad_user}/sent_invite_count?from_ts={before_invites_sent_ts}",
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body["invite_count"], 8)
+
+        # send a kick and some bans and make sure these aren't counted against the invite total
+ for user in self.random_users:
+ self.helper.ban(
+ self.room1, src=self.bad_user, targ=user, tok=self.bad_user_tok
+ )
+
+ channel = self.make_request(
+ "POST",
+ f"/_matrix/client/v3/rooms/{self.room1}/kick",
+ content={"user_id": to_kick_user_id},
+ access_token=self.bad_user_tok,
+ )
+ self.assertEqual(channel.code, 200)
+
+ channel = self.make_request(
+ "GET",
+ f"/_synapse/admin/v1/users/{self.bad_user}/sent_invite_count?from_ts={before_invites_sent_ts}",
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body["invite_count"], 8)
+
+
+class GetCumulativeJoinedRoomCountForUserTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ admin.register_servlets,
+ room.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.admin = self.register_user("thomas", "pass", True)
+ self.admin_tok = self.login("thomas", "pass")
+
+ self.bad_user = self.register_user("teresa", "pass")
+ self.bad_user_tok = self.login("teresa", "pass")
+
+ def test_user_cumulative_joined_room_count(self) -> None:
+ """
+        Tests that the proper count is returned from the /cumulative_joined_room_count endpoint
+ """
+ # Create rooms and join, grab timestamp before room creation
+ before_room_creation_timestamp = self.hs.get_clock().time_msec()
+
+ joined_rooms = []
+ for _ in range(3):
+ room = self.helper.create_room_as(self.admin, tok=self.admin_tok)
+ self.helper.join(
+ room, user=self.bad_user, expect_code=200, tok=self.bad_user_tok
+ )
+ joined_rooms.append(room)
+
+ # get a timestamp after room creation and join, add a msec between last join and ts
+ after_room_creation = self.hs.get_clock().time_msec() + 1
+
+        # Get the count using this timestamp; it should be zero since all rooms were
+        # created and joined before the provided timestamp
+ channel = self.make_request(
+ "GET",
+ f"/_synapse/admin/v1/users/{self.bad_user}/cumulative_joined_room_count?from_ts={int(after_room_creation)}",
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(0, channel.json_body["cumulative_joined_room_count"])
+
+        # fetch the count with the older timestamp from before the rooms were created
+        # and joined; this should count all of them
+ channel = self.make_request(
+ "GET",
+ f"/_synapse/admin/v1/users/{self.bad_user}/cumulative_joined_room_count?from_ts={int(before_room_creation_timestamp)}",
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(
+ len(joined_rooms), channel.json_body["cumulative_joined_room_count"]
+ )
+
+ def test_user_joined_room_count_includes_left_and_banned_rooms(self) -> None:
+ """
+        Tests that the proper count is returned from the /cumulative_joined_room_count
+        endpoint when the user has left or been banned from joined rooms
+ """
+ # Create rooms and join, grab timestamp before room creation
+ before_room_creation_timestamp = self.hs.get_clock().time_msec()
+
+ joined_rooms = []
+ for _ in range(3):
+ room = self.helper.create_room_as(self.admin, tok=self.admin_tok)
+ self.helper.join(
+ room, user=self.bad_user, expect_code=200, tok=self.bad_user_tok
+ )
+ joined_rooms.append(room)
+
+        # fetch the count with the older timestamp from before the rooms were created and joined
+ channel = self.make_request(
+ "GET",
+ f"/_synapse/admin/v1/users/{self.bad_user}/cumulative_joined_room_count?from_ts={int(before_room_creation_timestamp)}",
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(
+ len(joined_rooms), channel.json_body["cumulative_joined_room_count"]
+ )
+
+ # have the user banned from/leave the joined rooms
+ self.helper.ban(
+ joined_rooms[0],
+ src=self.admin,
+ targ=self.bad_user,
+ expect_code=200,
+ tok=self.admin_tok,
+ )
+ self.helper.change_membership(
+ joined_rooms[1],
+ src=self.bad_user,
+ targ=self.bad_user,
+ membership="leave",
+ expect_code=200,
+ tok=self.bad_user_tok,
+ )
+ self.helper.ban(
+ joined_rooms[2],
+ src=self.admin,
+ targ=self.bad_user,
+ expect_code=200,
+ tok=self.admin_tok,
+ )
+
+        # fetch the joined room count again; it should still match the number of rooms the user joined
+ channel = self.make_request(
+ "GET",
+ f"/_synapse/admin/v1/users/{self.bad_user}/cumulative_joined_room_count?from_ts={int(before_room_creation_timestamp)}",
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(
+ len(joined_rooms), channel.json_body["cumulative_joined_room_count"]
+ )
diff --git a/tests/rest/client/sliding_sync/test_connection_tracking.py b/tests/rest/client/sliding_sync/test_connection_tracking.py
index 6863c32f7c..5b819103c2 100644
--- a/tests/rest/client/sliding_sync/test_connection_tracking.py
+++ b/tests/rest/client/sliding_sync/test_connection_tracking.py
@@ -13,7 +13,7 @@
#
import logging
-from parameterized import parameterized
+from parameterized import parameterized, parameterized_class
from twisted.test.proto_helpers import MemoryReactor
@@ -28,6 +28,20 @@ from tests.rest.client.sliding_sync.test_sliding_sync import SlidingSyncBase
logger = logging.getLogger(__name__)
+# FIXME: This can be removed once we bump `SCHEMA_COMPAT_VERSION` and run the
+# foreground update for
+# `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` (tracked by
+# https://github.com/element-hq/synapse/issues/17623)
+@parameterized_class(
+ ("use_new_tables",),
+ [
+ (True,),
+ (False,),
+ ],
+ class_name_func=lambda cls,
+ num,
+ params_dict: f"{cls.__name__}_{'new' if params_dict['use_new_tables'] else 'fallback'}",
+)
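+# (the class_name_func above generates SlidingSyncConnectionTrackingTestCase_new
+# and SlidingSyncConnectionTrackingTestCase_fallback test classes)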
class SlidingSyncConnectionTrackingTestCase(SlidingSyncBase):
"""
Test connection tracking in the Sliding Sync API.
@@ -44,6 +58,8 @@ class SlidingSyncConnectionTrackingTestCase(SlidingSyncBase):
self.store = hs.get_datastores().main
self.storage_controllers = hs.get_storage_controllers()
+ super().prepare(reactor, clock, hs)
+
def test_rooms_required_state_incremental_sync_LIVE(self) -> None:
"""Test that we only get state updates in incremental sync for rooms
we've already seen (LIVE).
diff --git a/tests/rest/client/sliding_sync/test_extension_account_data.py b/tests/rest/client/sliding_sync/test_extension_account_data.py
index 3482a5f887..799fbb1856 100644
--- a/tests/rest/client/sliding_sync/test_extension_account_data.py
+++ b/tests/rest/client/sliding_sync/test_extension_account_data.py
@@ -11,8 +11,12 @@
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
+import enum
import logging
+from parameterized import parameterized, parameterized_class
+from typing_extensions import assert_never
+
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
@@ -28,6 +32,25 @@ from tests.server import TimedOutException
logger = logging.getLogger(__name__)
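+# Parameterizes whether a test run adds new room tags or removes existing ones.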
+class TagAction(enum.Enum):
+ ADD = enum.auto()
+ REMOVE = enum.auto()
+
+
+# FIXME: This can be removed once we bump `SCHEMA_COMPAT_VERSION` and run the
+# foreground update for
+# `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` (tracked by
+# https://github.com/element-hq/synapse/issues/17623)
+@parameterized_class(
+ ("use_new_tables",),
+ [
+ (True,),
+ (False,),
+ ],
+ class_name_func=lambda cls,
+ num,
+ params_dict: f"{cls.__name__}_{'new' if params_dict['use_new_tables'] else 'fallback'}",
+)
class SlidingSyncAccountDataExtensionTestCase(SlidingSyncBase):
"""Tests for the account_data sliding sync extension"""
@@ -43,6 +66,8 @@ class SlidingSyncAccountDataExtensionTestCase(SlidingSyncBase):
self.store = hs.get_datastores().main
self.account_data_handler = hs.get_account_data_handler()
+ super().prepare(reactor, clock, hs)
+
def test_no_data_initial_sync(self) -> None:
"""
Test that enabling the account_data extension works during an initial sync,
@@ -62,18 +87,23 @@ class SlidingSyncAccountDataExtensionTestCase(SlidingSyncBase):
}
response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ global_account_data_map = {
+ global_event["type"]: global_event["content"]
+ for global_event in response_body["extensions"]["account_data"].get(
+ "global"
+ )
+ }
self.assertIncludes(
- {
- global_event["type"]
- for global_event in response_body["extensions"]["account_data"].get(
- "global"
- )
- },
+ global_account_data_map.keys(),
# Even though we don't have any global account data set, Synapse saves some
# default push rules for us.
{AccountDataTypes.PUSH_RULES},
exact=True,
)
+        # Push rules are a giant chunk of JSON data, so we will just assume the value is correct if the key is here.
+ # global_account_data_map[AccountDataTypes.PUSH_RULES]
+
+ # No room account data for this test
self.assertIncludes(
response_body["extensions"]["account_data"].get("rooms").keys(),
set(),
@@ -103,16 +133,19 @@ class SlidingSyncAccountDataExtensionTestCase(SlidingSyncBase):
# There has been no account data changes since the `from_token` so we shouldn't
# see any account data here.
+ global_account_data_map = {
+ global_event["type"]: global_event["content"]
+ for global_event in response_body["extensions"]["account_data"].get(
+ "global"
+ )
+ }
self.assertIncludes(
- {
- global_event["type"]
- for global_event in response_body["extensions"]["account_data"].get(
- "global"
- )
- },
+ global_account_data_map.keys(),
set(),
exact=True,
)
+
+ # No room account data for this test
self.assertIncludes(
response_body["extensions"]["account_data"].get("rooms").keys(),
set(),
@@ -147,16 +180,24 @@ class SlidingSyncAccountDataExtensionTestCase(SlidingSyncBase):
response_body, _ = self.do_sync(sync_body, tok=user1_tok)
# It should show us all of the global account data
+ global_account_data_map = {
+ global_event["type"]: global_event["content"]
+ for global_event in response_body["extensions"]["account_data"].get(
+ "global"
+ )
+ }
self.assertIncludes(
- {
- global_event["type"]
- for global_event in response_body["extensions"]["account_data"].get(
- "global"
- )
- },
+ global_account_data_map.keys(),
{AccountDataTypes.PUSH_RULES, "org.matrix.foobarbaz"},
exact=True,
)
+        # Push rules are a giant chunk of JSON data, so we will just assume the value is correct if the key is here.
+ # global_account_data_map[AccountDataTypes.PUSH_RULES]
+ self.assertEqual(
+ global_account_data_map["org.matrix.foobarbaz"], {"foo": "bar"}
+ )
+
+ # No room account data for this test
self.assertIncludes(
response_body["extensions"]["account_data"].get("rooms").keys(),
set(),
@@ -202,17 +243,23 @@ class SlidingSyncAccountDataExtensionTestCase(SlidingSyncBase):
# Make an incremental Sliding Sync request with the account_data extension enabled
response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
+ global_account_data_map = {
+ global_event["type"]: global_event["content"]
+ for global_event in response_body["extensions"]["account_data"].get(
+ "global"
+ )
+ }
self.assertIncludes(
- {
- global_event["type"]
- for global_event in response_body["extensions"]["account_data"].get(
- "global"
- )
- },
+ global_account_data_map.keys(),
# We should only see the new global account data that happened after the `from_token`
{"org.matrix.doodardaz"},
exact=True,
)
+ self.assertEqual(
+ global_account_data_map["org.matrix.doodardaz"], {"doo": "dar"}
+ )
+
+ # No room account data for this test
self.assertIncludes(
response_body["extensions"]["account_data"].get("rooms").keys(),
set(),
@@ -237,6 +284,15 @@ class SlidingSyncAccountDataExtensionTestCase(SlidingSyncBase):
content={"roo": "rar"},
)
)
+ # Add a room tag to mark the room as a favourite
+ self.get_success(
+ self.account_data_handler.add_tag_to_room(
+ user_id=user1_id,
+ room_id=room_id1,
+ tag="m.favourite",
+ content={},
+ )
+ )
# Create another room with some room account data
room_id2 = self.helper.create_room_as(user1_id, tok=user1_tok)
@@ -248,6 +304,15 @@ class SlidingSyncAccountDataExtensionTestCase(SlidingSyncBase):
content={"roo": "rar"},
)
)
+ # Add a room tag to mark the room as a favourite
+ self.get_success(
+ self.account_data_handler.add_tag_to_room(
+ user_id=user1_id,
+ room_id=room_id2,
+ tag="m.favourite",
+ content={},
+ )
+ )
# Make an initial Sliding Sync request with the account_data extension enabled
sync_body = {
@@ -276,21 +341,36 @@ class SlidingSyncAccountDataExtensionTestCase(SlidingSyncBase):
{room_id1},
exact=True,
)
+ account_data_map = {
+ event["type"]: event["content"]
+ for event in response_body["extensions"]["account_data"]
+ .get("rooms")
+ .get(room_id1)
+ }
self.assertIncludes(
- {
- event["type"]
- for event in response_body["extensions"]["account_data"]
- .get("rooms")
- .get(room_id1)
- },
- {"org.matrix.roorarraz"},
+ account_data_map.keys(),
+ {"org.matrix.roorarraz", AccountDataTypes.TAG},
exact=True,
)
+ self.assertEqual(account_data_map["org.matrix.roorarraz"], {"roo": "rar"})
+ self.assertEqual(
+ account_data_map[AccountDataTypes.TAG], {"tags": {"m.favourite": {}}}
+ )
- def test_room_account_data_incremental_sync(self) -> None:
+ @parameterized.expand(
+ [
+ ("add tags", TagAction.ADD),
+ ("remove tags", TagAction.REMOVE),
+ ]
+ )
+ def test_room_account_data_incremental_sync(
+ self, test_description: str, tag_action: TagAction
+ ) -> None:
"""
On incremental sync, we return all account data for a given room but only for
rooms that we request and are being returned in the Sliding Sync response.
+
+ (HaveSentRoomFlag.LIVE)
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
@@ -305,6 +385,15 @@ class SlidingSyncAccountDataExtensionTestCase(SlidingSyncBase):
content={"roo": "rar"},
)
)
+ # Add a room tag to mark the room as a favourite
+ self.get_success(
+ self.account_data_handler.add_tag_to_room(
+ user_id=user1_id,
+ room_id=room_id1,
+ tag="m.favourite",
+ content={},
+ )
+ )
# Create another room with some room account data
room_id2 = self.helper.create_room_as(user1_id, tok=user1_tok)
@@ -316,6 +405,15 @@ class SlidingSyncAccountDataExtensionTestCase(SlidingSyncBase):
content={"roo": "rar"},
)
)
+ # Add a room tag to mark the room as a favourite
+ self.get_success(
+ self.account_data_handler.add_tag_to_room(
+ user_id=user1_id,
+ room_id=room_id2,
+ tag="m.favourite",
+ content={},
+ )
+ )
sync_body = {
"lists": {},
@@ -351,6 +449,42 @@ class SlidingSyncAccountDataExtensionTestCase(SlidingSyncBase):
content={"roo": "rar"},
)
)
+ if tag_action == TagAction.ADD:
+ # Add another room tag
+ self.get_success(
+ self.account_data_handler.add_tag_to_room(
+ user_id=user1_id,
+ room_id=room_id1,
+ tag="m.server_notice",
+ content={},
+ )
+ )
+ self.get_success(
+ self.account_data_handler.add_tag_to_room(
+ user_id=user1_id,
+ room_id=room_id2,
+ tag="m.server_notice",
+ content={},
+ )
+ )
+ elif tag_action == TagAction.REMOVE:
+ # Remove the room tag
+ self.get_success(
+ self.account_data_handler.remove_tag_from_room(
+ user_id=user1_id,
+ room_id=room_id1,
+ tag="m.favourite",
+ )
+ )
+ self.get_success(
+ self.account_data_handler.remove_tag_from_room(
+ user_id=user1_id,
+ room_id=room_id2,
+ tag="m.favourite",
+ )
+ )
+ else:
+ assert_never(tag_action)
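+            # (`assert_never` raises if reached at runtime and lets type
+            # checkers verify that every TagAction member is handled)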
# Make an incremental Sliding Sync request with the account_data extension enabled
response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
@@ -365,17 +499,444 @@ class SlidingSyncAccountDataExtensionTestCase(SlidingSyncBase):
exact=True,
)
# We should only see the new room account data that happened after the `from_token`
+ account_data_map = {
+ event["type"]: event["content"]
+ for event in response_body["extensions"]["account_data"]
+ .get("rooms")
+ .get(room_id1)
+ }
self.assertIncludes(
+ account_data_map.keys(),
+ {"org.matrix.roorarraz2", AccountDataTypes.TAG},
+ exact=True,
+ )
+ self.assertEqual(account_data_map["org.matrix.roorarraz2"], {"roo": "rar"})
+ if tag_action == TagAction.ADD:
+ self.assertEqual(
+ account_data_map[AccountDataTypes.TAG],
+ {"tags": {"m.favourite": {}, "m.server_notice": {}}},
+ )
+ elif tag_action == TagAction.REMOVE:
+ # If we previously showed the client that the room has tags, when it no
+ # longer has tags, we need to show them an empty map.
+ self.assertEqual(
+ account_data_map[AccountDataTypes.TAG],
+ {"tags": {}},
+ )
+ else:
+ assert_never(tag_action)
+
+ @parameterized.expand(
+ [
+ ("add tags", TagAction.ADD),
+ ("remove tags", TagAction.REMOVE),
+ ]
+ )
+ def test_room_account_data_incremental_sync_out_of_range_never(
+ self, test_description: str, tag_action: TagAction
+ ) -> None:
+ """Tests that we don't return account data for rooms that are out of
+ range, but then do send all account data once they're in range.
+
+ (initial/HaveSentRoomFlag.NEVER)
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ # Create a room and add some room account data
+ room_id1 = self.helper.create_room_as(user1_id, tok=user1_tok)
+ self.get_success(
+ self.account_data_handler.add_account_data_to_room(
+ user_id=user1_id,
+ room_id=room_id1,
+ account_data_type="org.matrix.roorarraz",
+ content={"roo": "rar"},
+ )
+ )
+ # Add a room tag to mark the room as a favourite
+ self.get_success(
+ self.account_data_handler.add_tag_to_room(
+ user_id=user1_id,
+ room_id=room_id1,
+ tag="m.favourite",
+ content={},
+ )
+ )
+
+ # Create another room with some room account data
+ room_id2 = self.helper.create_room_as(user1_id, tok=user1_tok)
+ self.get_success(
+ self.account_data_handler.add_account_data_to_room(
+ user_id=user1_id,
+ room_id=room_id2,
+ account_data_type="org.matrix.roorarraz",
+ content={"roo": "rar"},
+ )
+ )
+ # Add a room tag to mark the room as a favourite
+ self.get_success(
+ self.account_data_handler.add_tag_to_room(
+ user_id=user1_id,
+ room_id=room_id2,
+ tag="m.favourite",
+ content={},
+ )
+ )
+
+ # Now send a message into room1 so that it is at the top of the list
+ self.helper.send(room_id1, body="new event", tok=user1_tok)
+
+ # Make a SS request for only the top room.
+ sync_body = {
+ "lists": {
+ "main": {
+ "ranges": [[0, 0]],
+ "required_state": [],
+ "timeline_limit": 0,
+ }
+ },
+ "extensions": {
+ "account_data": {
+ "enabled": True,
+ "lists": ["main"],
+ }
+ },
+ }
+ response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+
+ # Only room1 should be in the response since it's the latest room with activity
+ # and our range only includes 1 room.
+ self.assertIncludes(
+ response_body["extensions"]["account_data"].get("rooms").keys(),
+ {room_id1},
+ exact=True,
+ )
+
+ # Add some other room account data
+ self.get_success(
+ self.account_data_handler.add_account_data_to_room(
+ user_id=user1_id,
+ room_id=room_id1,
+ account_data_type="org.matrix.roorarraz2",
+ content={"roo": "rar"},
+ )
+ )
+ self.get_success(
+ self.account_data_handler.add_account_data_to_room(
+ user_id=user1_id,
+ room_id=room_id2,
+ account_data_type="org.matrix.roorarraz2",
+ content={"roo": "rar"},
+ )
+ )
+ if tag_action == TagAction.ADD:
+ # Add another room tag
+ self.get_success(
+ self.account_data_handler.add_tag_to_room(
+ user_id=user1_id,
+ room_id=room_id1,
+ tag="m.server_notice",
+ content={},
+ )
+ )
+ self.get_success(
+ self.account_data_handler.add_tag_to_room(
+ user_id=user1_id,
+ room_id=room_id2,
+ tag="m.server_notice",
+ content={},
+ )
+ )
+ elif tag_action == TagAction.REMOVE:
+ # Remove the room tag
+ self.get_success(
+ self.account_data_handler.remove_tag_from_room(
+ user_id=user1_id,
+ room_id=room_id1,
+ tag="m.favourite",
+ )
+ )
+ self.get_success(
+ self.account_data_handler.remove_tag_from_room(
+ user_id=user1_id,
+ room_id=room_id2,
+ tag="m.favourite",
+ )
+ )
+ else:
+ assert_never(tag_action)
+
+ # Move room2 into range.
+ self.helper.send(room_id2, body="new event", tok=user1_tok)
+
+ # Make an incremental Sliding Sync request with the account_data extension enabled
+ response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
+
+ self.assertIsNotNone(response_body["extensions"]["account_data"].get("global"))
+ # We expect to see the account data of room2, as that has the most
+ # recent update.
+ self.assertIncludes(
+ response_body["extensions"]["account_data"].get("rooms").keys(),
+ {room_id2},
+ exact=True,
+ )
+ # Since this is the first time we're seeing room2 down sync, we should see all
+ # room account data for it.
+ account_data_map = {
+ event["type"]: event["content"]
+ for event in response_body["extensions"]["account_data"]
+ .get("rooms")
+ .get(room_id2)
+ }
+ expected_account_data_keys = {
+ "org.matrix.roorarraz",
+ "org.matrix.roorarraz2",
+ }
+ if tag_action == TagAction.ADD:
+ expected_account_data_keys.add(AccountDataTypes.TAG)
+ self.assertIncludes(
+ account_data_map.keys(),
+ expected_account_data_keys,
+ exact=True,
+ )
+ self.assertEqual(account_data_map["org.matrix.roorarraz"], {"roo": "rar"})
+ self.assertEqual(account_data_map["org.matrix.roorarraz2"], {"roo": "rar"})
+ if tag_action == TagAction.ADD:
+ self.assertEqual(
+ account_data_map[AccountDataTypes.TAG],
+ {"tags": {"m.favourite": {}, "m.server_notice": {}}},
+ )
+ elif tag_action == TagAction.REMOVE:
+ # Since we never told the client about the room tags, we don't need to say
+ # anything if there are no tags now (the client doesn't need an update).
+ self.assertIsNone(
+ account_data_map.get(AccountDataTypes.TAG),
+ account_data_map,
+ )
+ else:
+ assert_never(tag_action)
+
+ @parameterized.expand(
+ [
+ ("add tags", TagAction.ADD),
+ ("remove tags", TagAction.REMOVE),
+ ]
+ )
+ def test_room_account_data_incremental_sync_out_of_range_previously(
+ self, test_description: str, tag_action: TagAction
+ ) -> None:
+ """Tests that we don't return account data for rooms that fall out of
+    range, but then do send all account data that has changed once they're back in range.
+
+ (HaveSentRoomFlag.PREVIOUSLY)
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ # Create a room and add some room account data
+ room_id1 = self.helper.create_room_as(user1_id, tok=user1_tok)
+ self.get_success(
+ self.account_data_handler.add_account_data_to_room(
+ user_id=user1_id,
+ room_id=room_id1,
+ account_data_type="org.matrix.roorarraz",
+ content={"roo": "rar"},
+ )
+ )
+ # Add a room tag to mark the room as a favourite
+ self.get_success(
+ self.account_data_handler.add_tag_to_room(
+ user_id=user1_id,
+ room_id=room_id1,
+ tag="m.favourite",
+ content={},
+ )
+ )
+
+ # Create another room with some room account data
+ room_id2 = self.helper.create_room_as(user1_id, tok=user1_tok)
+ self.get_success(
+ self.account_data_handler.add_account_data_to_room(
+ user_id=user1_id,
+ room_id=room_id2,
+ account_data_type="org.matrix.roorarraz",
+ content={"roo": "rar"},
+ )
+ )
+ # Add a room tag to mark the room as a favourite
+ self.get_success(
+ self.account_data_handler.add_tag_to_room(
+ user_id=user1_id,
+ room_id=room_id2,
+ tag="m.favourite",
+ content={},
+ )
+ )
+
+ # Make an initial Sliding Sync request for only room1 and room2.
+ sync_body = {
+ "lists": {},
+ "room_subscriptions": {
+ room_id1: {
+ "required_state": [],
+ "timeline_limit": 0,
+ },
+ room_id2: {
+ "required_state": [],
+ "timeline_limit": 0,
+ },
+ },
+ "extensions": {
+ "account_data": {
+ "enabled": True,
+ "rooms": [room_id1, room_id2],
+ }
+ },
+ }
+ response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+
+ # Both rooms show up because we have a room subscription for each and they're
+ # requested in the `account_data` extension.
+ self.assertIncludes(
+ response_body["extensions"]["account_data"].get("rooms").keys(),
+ {room_id1, room_id2},
+ exact=True,
+ )
+
+ # Add some other room account data
+ self.get_success(
+ self.account_data_handler.add_account_data_to_room(
+ user_id=user1_id,
+ room_id=room_id1,
+ account_data_type="org.matrix.roorarraz2",
+ content={"roo": "rar"},
+ )
+ )
+ self.get_success(
+ self.account_data_handler.add_account_data_to_room(
+ user_id=user1_id,
+ room_id=room_id2,
+ account_data_type="org.matrix.roorarraz2",
+ content={"roo": "rar"},
+ )
+ )
+ if tag_action == TagAction.ADD:
+ # Add another room tag
+ self.get_success(
+ self.account_data_handler.add_tag_to_room(
+ user_id=user1_id,
+ room_id=room_id1,
+ tag="m.server_notice",
+ content={},
+ )
+ )
+ self.get_success(
+ self.account_data_handler.add_tag_to_room(
+ user_id=user1_id,
+ room_id=room_id2,
+ tag="m.server_notice",
+ content={},
+ )
+ )
+ elif tag_action == TagAction.REMOVE:
+ # Remove the room tag
+ self.get_success(
+ self.account_data_handler.remove_tag_from_room(
+ user_id=user1_id,
+ room_id=room_id1,
+ tag="m.favourite",
+ )
+ )
+ self.get_success(
+ self.account_data_handler.remove_tag_from_room(
+ user_id=user1_id,
+ room_id=room_id2,
+ tag="m.favourite",
+ )
+ )
+ else:
+ assert_never(tag_action)
+
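+ # Subscribing to only room1 pushes room2 out of range. Since room2 has
+ # already been sent down this connection, it moves to the PREVIOUSLY state
+ # rather than being forgotten entirely.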
+ # Make an incremental Sliding Sync request for just room1
+ response_body, from_token = self.do_sync(
{
- event["type"]
- for event in response_body["extensions"]["account_data"]
- .get("rooms")
- .get(room_id1)
+ **sync_body,
+ "room_subscriptions": {
+ room_id1: {
+ "required_state": [],
+ "timeline_limit": 0,
+ },
+ },
},
- {"org.matrix.roorarraz2"},
+ since=from_token,
+ tok=user1_tok,
+ )
+
+ # Only room1 shows up because we only have a room subscription for room1 now.
+ self.assertIncludes(
+ response_body["extensions"]["account_data"].get("rooms").keys(),
+ {room_id1},
exact=True,
)
+ # Make an incremental Sliding Sync request for just room2 now
+ response_body, from_token = self.do_sync(
+ {
+ **sync_body,
+ "room_subscriptions": {
+ room_id2: {
+ "required_state": [],
+ "timeline_limit": 0,
+ },
+ },
+ },
+ since=from_token,
+ tok=user1_tok,
+ )
+
+ # Only room2 shows up because we only have a room subscription for room2 now.
+ self.assertIncludes(
+ response_body["extensions"]["account_data"].get("rooms").keys(),
+ {room_id2},
+ exact=True,
+ )
+
+ self.assertIsNotNone(response_body["extensions"]["account_data"].get("global"))
+ # Check for room account data for room2
+ self.assertIncludes(
+ response_body["extensions"]["account_data"].get("rooms").keys(),
+ {room_id2},
+ exact=True,
+ )
+ # We should see all room account data updates for room2 since the last
+ # time it was sent down sync
+ account_data_map = {
+ event["type"]: event["content"]
+ for event in response_body["extensions"]["account_data"]
+ .get("rooms")
+ .get(room_id2)
+ }
+ self.assertIncludes(
+ account_data_map.keys(),
+ {"org.matrix.roorarraz2", AccountDataTypes.TAG},
+ exact=True,
+ )
+ self.assertEqual(account_data_map["org.matrix.roorarraz2"], {"roo": "rar"})
+ if tag_action == TagAction.ADD:
+ self.assertEqual(
+ account_data_map[AccountDataTypes.TAG],
+ {"tags": {"m.favourite": {}, "m.server_notice": {}}},
+ )
+ elif tag_action == TagAction.REMOVE:
+ # If we previously showed the client that the room has tags, when it no
+ # longer has tags, we need to show them an empty map.
+ self.assertEqual(
+ account_data_map[AccountDataTypes.TAG],
+ {"tags": {}},
+ )
+ else:
+ assert_never(tag_action)
+
def test_wait_for_new_data(self) -> None:
"""
Test to make sure that the Sliding Sync request waits for new data to arrive.
diff --git a/tests/rest/client/sliding_sync/test_extension_e2ee.py b/tests/rest/client/sliding_sync/test_extension_e2ee.py
index 320f8c788f..7ce6592d8f 100644
--- a/tests/rest/client/sliding_sync/test_extension_e2ee.py
+++ b/tests/rest/client/sliding_sync/test_extension_e2ee.py
@@ -13,6 +13,8 @@
#
import logging
+from parameterized import parameterized_class
+
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
@@ -27,6 +29,20 @@ from tests.server import TimedOutException
logger = logging.getLogger(__name__)
+# FIXME: This can be removed once we bump `SCHEMA_COMPAT_VERSION` and run the
+# foreground update for
+# `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` (tracked by
+# https://github.com/element-hq/synapse/issues/17623)
+@parameterized_class(
+ ("use_new_tables",),
+ [
+ (True,),
+ (False,),
+ ],
+ class_name_func=lambda cls,
+ num,
+ params_dict: f"{cls.__name__}_{'new' if params_dict['use_new_tables'] else 'fallback'}",
+)
class SlidingSyncE2eeExtensionTestCase(SlidingSyncBase):
"""Tests for the e2ee sliding sync extension"""
@@ -42,6 +58,8 @@ class SlidingSyncE2eeExtensionTestCase(SlidingSyncBase):
self.store = hs.get_datastores().main
self.e2e_keys_handler = hs.get_e2e_keys_handler()
+ super().prepare(reactor, clock, hs)
+
def test_no_data_initial_sync(self) -> None:
"""
Test that enabling e2ee extension works during an intitial sync, even if there
diff --git a/tests/rest/client/sliding_sync/test_extension_receipts.py b/tests/rest/client/sliding_sync/test_extension_receipts.py
index 65fbac260e..6e7700b533 100644
--- a/tests/rest/client/sliding_sync/test_extension_receipts.py
+++ b/tests/rest/client/sliding_sync/test_extension_receipts.py
@@ -13,6 +13,8 @@
#
import logging
+from parameterized import parameterized_class
+
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
@@ -28,6 +30,20 @@ from tests.server import TimedOutException
logger = logging.getLogger(__name__)
+# FIXME: This can be removed once we bump `SCHEMA_COMPAT_VERSION` and run the
+# foreground update for
+# `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` (tracked by
+# https://github.com/element-hq/synapse/issues/17623)
+@parameterized_class(
+ ("use_new_tables",),
+ [
+ (True,),
+ (False,),
+ ],
+ class_name_func=lambda cls,
+ num,
+ params_dict: f"{cls.__name__}_{'new' if params_dict['use_new_tables'] else 'fallback'}",
+)
class SlidingSyncReceiptsExtensionTestCase(SlidingSyncBase):
"""Tests for the receipts sliding sync extension"""
@@ -42,6 +58,8 @@ class SlidingSyncReceiptsExtensionTestCase(SlidingSyncBase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
+ super().prepare(reactor, clock, hs)
+
def test_no_data_initial_sync(self) -> None:
"""
Test that enabling the receipts extension works during an intitial sync,
@@ -677,3 +695,240 @@ class SlidingSyncReceiptsExtensionTestCase(SlidingSyncBase):
set(),
exact=True,
)
+
+ def test_receipts_incremental_sync_out_of_range(self) -> None:
+ """Tests that we don't return read receipts for rooms that fall out of
+ range, but then do send all read receipts once they're back in range.
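+
+ This relies on the per-connection tracking of what has already been sent
+ down for each room: receipts that arrive while a room is out of range are
+ delivered once the room re-enters the range.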
+ """
+
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
+ self.helper.join(room_id1, user1_id, tok=user1_tok)
+ room_id2 = self.helper.create_room_as(user2_id, tok=user2_tok)
+ self.helper.join(room_id2, user1_id, tok=user1_tok)
+
+ # Send a message and read receipt into room2
+ event_response = self.helper.send(room_id2, body="new event", tok=user2_tok)
+ room2_event_id = event_response["event_id"]
+
+ self.helper.send_read_receipt(room_id2, room2_event_id, tok=user1_tok)
+
+ # Now send a message into room1 so that it is at the top of the list
+ self.helper.send(room_id1, body="new event", tok=user2_tok)
+
+ # Make a SS request for only the top room.
+ sync_body = {
+ "lists": {
+ "main": {
+ "ranges": [[0, 0]],
+ "required_state": [],
+ "timeline_limit": 5,
+ }
+ },
+ "extensions": {
+ "receipts": {
+ "enabled": True,
+ }
+ },
+ }
+ response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+
+ # The receipt is in room2, but only room1 is returned, so we don't
+ # expect to get the receipt.
+ self.assertIncludes(
+ response_body["extensions"]["receipts"].get("rooms").keys(),
+ set(),
+ exact=True,
+ )
+
+ # Move room2 into range.
+ self.helper.send(room_id2, body="new event", tok=user2_tok)
+
+ response_body, from_token = self.do_sync(
+ sync_body, since=from_token, tok=user1_tok
+ )
+
+ # We expect to see the read receipt of room2, as that has the most
+ # recent update.
+ self.assertIncludes(
+ response_body["extensions"]["receipts"].get("rooms").keys(),
+ {room_id2},
+ exact=True,
+ )
+ receipt = response_body["extensions"]["receipts"]["rooms"][room_id2]
+ self.assertIncludes(
+ receipt["content"][room2_event_id][ReceiptTypes.READ].keys(),
+ {user1_id},
+ exact=True,
+ )
+
+ # Send a message into room1 to bump it to the top, but also send a
+ # receipt in room2
+ self.helper.send(room_id1, body="new event", tok=user2_tok)
+ self.helper.send_read_receipt(room_id2, room2_event_id, tok=user2_tok)
+
+ # We don't expect to see the new read receipt.
+ response_body, from_token = self.do_sync(
+ sync_body, since=from_token, tok=user1_tok
+ )
+ self.assertIncludes(
+ response_body["extensions"]["receipts"].get("rooms").keys(),
+ set(),
+ exact=True,
+ )
+
+ # But if we send a new message into room2, we expect to get the missing receipts
+ self.helper.send(room_id2, body="new event", tok=user2_tok)
+
+ response_body, from_token = self.do_sync(
+ sync_body, since=from_token, tok=user1_tok
+ )
+ self.assertIncludes(
+ response_body["extensions"]["receipts"].get("rooms").keys(),
+ {room_id2},
+ exact=True,
+ )
+
+ # We should only see the new receipt
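+ # (user1's receipt was already sent down in an earlier response, so only
+ # user2's receipt is new for this connection)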
+ receipt = response_body["extensions"]["receipts"]["rooms"][room_id2]
+ self.assertIncludes(
+ receipt["content"][room2_event_id][ReceiptTypes.READ].keys(),
+ {user2_id},
+ exact=True,
+ )
+
+ def test_return_own_read_receipts(self) -> None:
+ """Test that we always send the user's own read receipts in initial
+ rooms, even if the receipts don't match events in the timeline.
+ """
+
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
+ self.helper.join(room_id1, user1_id, tok=user1_tok)
+
+ # Send a message and read receipts into room1
+ event_response = self.helper.send(room_id1, body="new event", tok=user2_tok)
+ room1_event_id = event_response["event_id"]
+
+ self.helper.send_read_receipt(room_id1, room1_event_id, tok=user1_tok)
+ self.helper.send_read_receipt(room_id1, room1_event_id, tok=user2_tok)
+
+ # Now send a message so the above message is not in the timeline.
+ self.helper.send(room_id1, body="new event", tok=user2_tok)
+
+ # Make a SS request for only the latest message.
+ sync_body = {
+ "lists": {
+ "main": {
+ "ranges": [[0, 0]],
+ "required_state": [],
+ "timeline_limit": 1,
+ }
+ },
+ "extensions": {
+ "receipts": {
+ "enabled": True,
+ }
+ },
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+
+ # We should get our own receipt in room1, even though it's not in the
+ # timeline limit.
+ self.assertIncludes(
+ response_body["extensions"]["receipts"].get("rooms").keys(),
+ {room_id1},
+ exact=True,
+ )
+
+ # We should only see our read receipt, not the other user's.
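+ # (other users' receipts are only included when they target events in the
+ # returned timeline, whereas the requesting user's own receipts are always
+ # sent)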
+ receipt = response_body["extensions"]["receipts"]["rooms"][room_id1]
+ self.assertIncludes(
+ receipt["content"][room1_event_id][ReceiptTypes.READ].keys(),
+ {user1_id},
+ exact=True,
+ )
+
+ def test_read_receipts_expanded_timeline(self) -> None:
+ """Test that we get read receipts when we expand the timeline limit (`unstable_expanded_timeline`)."""
+
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
+ self.helper.join(room_id1, user1_id, tok=user1_tok)
+
+ # Send a message and read receipt into room1
+ event_response = self.helper.send(room_id1, body="new event", tok=user2_tok)
+ room1_event_id = event_response["event_id"]
+
+ self.helper.send_read_receipt(room_id1, room1_event_id, tok=user2_tok)
+
+ # Now send a message so the above message is not in the timeline.
+ self.helper.send(room_id1, body="new event", tok=user2_tok)
+
+ # Make a SS request for only the latest message.
+ sync_body = {
+ "lists": {
+ "main": {
+ "ranges": [[0, 0]],
+ "required_state": [],
+ "timeline_limit": 1,
+ }
+ },
+ "extensions": {
+ "receipts": {
+ "enabled": True,
+ }
+ },
+ }
+ response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+
+ # We shouldn't see user2's read receipt, as it's not in the timeline
+ self.assertIncludes(
+ response_body["extensions"]["receipts"].get("rooms").keys(),
+ set(),
+ exact=True,
+ )
+
+ # Now do another request with a room subscription that has an increased timeline limit
+ sync_body["room_subscriptions"] = {
+ room_id1: {
+ "required_state": [],
+ "timeline_limit": 2,
+ }
+ }
+
+ response_body, from_token = self.do_sync(
+ sync_body, since=from_token, tok=user1_tok
+ )
+
+ # Assert that we did actually get an expanded timeline
+ room_response = response_body["rooms"][room_id1]
+ self.assertNotIn("initial", room_response)
+ self.assertEqual(room_response["unstable_expanded_timeline"], True)
+
+ # We should now see user2's read receipt, as it's in the expanded timeline
+ self.assertIncludes(
+ response_body["extensions"]["receipts"].get("rooms").keys(),
+ {room_id1},
+ exact=True,
+ )
+
+ # We should see user2's read receipt now that its event is in the timeline.
+ receipt = response_body["extensions"]["receipts"]["rooms"][room_id1]
+ self.assertIncludes(
+ receipt["content"][room1_event_id][ReceiptTypes.READ].keys(),
+ {user2_id},
+ exact=True,
+ )
diff --git a/tests/rest/client/sliding_sync/test_extension_to_device.py b/tests/rest/client/sliding_sync/test_extension_to_device.py
index f8500812ea..790abb739d 100644
--- a/tests/rest/client/sliding_sync/test_extension_to_device.py
+++ b/tests/rest/client/sliding_sync/test_extension_to_device.py
@@ -14,6 +14,8 @@
import logging
from typing import List
+from parameterized import parameterized_class
+
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
@@ -28,6 +30,20 @@ from tests.server import TimedOutException
logger = logging.getLogger(__name__)
+# FIXME: This can be removed once we bump `SCHEMA_COMPAT_VERSION` and run the
+# foreground update for
+# `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` (tracked by
+# https://github.com/element-hq/synapse/issues/17623)
+@parameterized_class(
+ ("use_new_tables",),
+ [
+ (True,),
+ (False,),
+ ],
+ class_name_func=lambda cls,
+ num,
+ params_dict: f"{cls.__name__}_{'new' if params_dict['use_new_tables'] else 'fallback'}",
+)
class SlidingSyncToDeviceExtensionTestCase(SlidingSyncBase):
"""Tests for the to-device sliding sync extension"""
@@ -40,6 +56,7 @@ class SlidingSyncToDeviceExtensionTestCase(SlidingSyncBase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
+ super().prepare(reactor, clock, hs)
def _assert_to_device_response(
self, response_body: JsonDict, expected_messages: List[JsonDict]
diff --git a/tests/rest/client/sliding_sync/test_extension_typing.py b/tests/rest/client/sliding_sync/test_extension_typing.py
index 7f523e0f10..f87c3c8b17 100644
--- a/tests/rest/client/sliding_sync/test_extension_typing.py
+++ b/tests/rest/client/sliding_sync/test_extension_typing.py
@@ -13,6 +13,8 @@
#
import logging
+from parameterized import parameterized_class
+
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
@@ -28,6 +30,20 @@ from tests.server import TimedOutException
logger = logging.getLogger(__name__)
+# FIXME: This can be removed once we bump `SCHEMA_COMPAT_VERSION` and run the
+# foreground update for
+# `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` (tracked by
+# https://github.com/element-hq/synapse/issues/17623)
+@parameterized_class(
+ ("use_new_tables",),
+ [
+ (True,),
+ (False,),
+ ],
+ class_name_func=lambda cls,
+ num,
+ params_dict: f"{cls.__name__}_{'new' if params_dict['use_new_tables'] else 'fallback'}",
+)
class SlidingSyncTypingExtensionTestCase(SlidingSyncBase):
"""Tests for the typing notification sliding sync extension"""
@@ -41,6 +57,8 @@ class SlidingSyncTypingExtensionTestCase(SlidingSyncBase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
+ super().prepare(reactor, clock, hs)
+
def test_no_data_initial_sync(self) -> None:
"""
Test that enabling the typing extension works during an intitial sync,
diff --git a/tests/rest/client/sliding_sync/test_extensions.py b/tests/rest/client/sliding_sync/test_extensions.py
index 68f6661334..30230e5c4b 100644
--- a/tests/rest/client/sliding_sync/test_extensions.py
+++ b/tests/rest/client/sliding_sync/test_extensions.py
@@ -14,7 +14,7 @@
import logging
from typing import Literal
-from parameterized import parameterized
+from parameterized import parameterized, parameterized_class
from typing_extensions import assert_never
from twisted.test.proto_helpers import MemoryReactor
@@ -30,6 +30,20 @@ from tests.rest.client.sliding_sync.test_sliding_sync import SlidingSyncBase
logger = logging.getLogger(__name__)
+# FIXME: This can be removed once we bump `SCHEMA_COMPAT_VERSION` and run the
+# foreground update for
+# `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` (tracked by
+# https://github.com/element-hq/synapse/issues/17623)
+@parameterized_class(
+ ("use_new_tables",),
+ [
+ (True,),
+ (False,),
+ ],
+ class_name_func=lambda cls,
+ num,
+ params_dict: f"{cls.__name__}_{'new' if params_dict['use_new_tables'] else 'fallback'}",
+)
class SlidingSyncExtensionsTestCase(SlidingSyncBase):
"""
Test general extensions behavior in the Sliding Sync API. Each extension has their
@@ -49,6 +63,8 @@ class SlidingSyncExtensionsTestCase(SlidingSyncBase):
self.storage_controllers = hs.get_storage_controllers()
self.account_data_handler = hs.get_account_data_handler()
+ super().prepare(reactor, clock, hs)
+
# Any extensions that use `lists`/`rooms` should be tested here
@parameterized.expand([("account_data",), ("receipts",), ("typing",)])
def test_extensions_lists_rooms_relevant_rooms(
@@ -120,19 +136,26 @@ class SlidingSyncExtensionsTestCase(SlidingSyncBase):
"foo-list": {
"ranges": [[0, 1]],
"required_state": [],
- "timeline_limit": 0,
+ # We set this to `1` because we're testing `receipts` which
+ # interact with the `timeline`. With receipts, when a room
+ # hasn't been sent down the connection before or it appears
+ # as `initial: true`, we only include receipts for events in
+ # the timeline to avoid bloating and blowing up the sync
+ # response as the number of users in the room increases.
+ # (this behavior is part of the spec)
+ "timeline_limit": 1,
},
# We expect this list range to include room5, room4, room3
"bar-list": {
"ranges": [[0, 2]],
"required_state": [],
- "timeline_limit": 0,
+ "timeline_limit": 1,
},
},
"room_subscriptions": {
room_id1: {
"required_state": [],
- "timeline_limit": 0,
+ "timeline_limit": 1,
}
},
}
diff --git a/tests/rest/client/sliding_sync/test_lists_filters.py b/tests/rest/client/sliding_sync/test_lists_filters.py
new file mode 100644
index 0000000000..c59f6aedc4
--- /dev/null
+++ b/tests/rest/client/sliding_sync/test_lists_filters.py
@@ -0,0 +1,1975 @@
+#
+# This file is licensed under the Affero General Public License (AGPL) version 3.
+#
+# Copyright (C) 2024 New Vector, Ltd
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# See the GNU Affero General Public License for more details:
+# <https://www.gnu.org/licenses/agpl-3.0.html>.
+#
+import logging
+
+from parameterized import parameterized_class
+
+from twisted.test.proto_helpers import MemoryReactor
+
+import synapse.rest.admin
+from synapse.api.constants import (
+ EventContentFields,
+ EventTypes,
+ RoomTypes,
+)
+from synapse.api.room_versions import RoomVersions
+from synapse.events import StrippedStateEvent
+from synapse.rest.client import login, room, sync, tags
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util import Clock
+
+from tests.rest.client.sliding_sync.test_sliding_sync import SlidingSyncBase
+
+logger = logging.getLogger(__name__)
+
+
+# FIXME: This can be removed once we bump `SCHEMA_COMPAT_VERSION` and run the
+# foreground update for
+# `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` (tracked by
+# https://github.com/element-hq/synapse/issues/17623)
+@parameterized_class(
+ ("use_new_tables",),
+ [
+ (True,),
+ (False,),
+ ],
+ class_name_func=lambda cls,
+ num,
+ params_dict: f"{cls.__name__}_{'new' if params_dict['use_new_tables'] else 'fallback'}",
+)
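+# (generates `SlidingSyncFiltersTestCase_new` and `SlidingSyncFiltersTestCase_fallback`)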
+class SlidingSyncFiltersTestCase(SlidingSyncBase):
+ """
+ Test `filters` in the Sliding Sync API to make sure it includes/excludes rooms
+ correctly.
+ """
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ sync.register_servlets,
+ tags.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.store = hs.get_datastores().main
+ self.event_sources = hs.get_event_sources()
+ self.storage_controllers = hs.get_storage_controllers()
+ self.account_data_handler = hs.get_account_data_handler()
+
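+ # `SlidingSyncBase.prepare` handles the `use_new_tables` parameterization
+ # (see the FIXME above).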
+ super().prepare(reactor, clock, hs)
+
+ def test_multiple_filters_and_multiple_lists(self) -> None:
+ """
+ Test that filters apply to `lists` in various scenarios.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ # Create a DM room
+ joined_dm_room_id = self._create_dm_room(
+ inviter_user_id=user1_id,
+ inviter_tok=user1_tok,
+ invitee_user_id=user2_id,
+ invitee_tok=user2_tok,
+ should_join_room=True,
+ )
+ invited_dm_room_id = self._create_dm_room(
+ inviter_user_id=user1_id,
+ inviter_tok=user1_tok,
+ invitee_user_id=user2_id,
+ invitee_tok=user2_tok,
+ should_join_room=False,
+ )
+
+ # Create a normal room
+ room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
+ self.helper.join(room_id, user1_id, tok=user1_tok)
+
+ # Create a room that user1 is invited to
+ invite_room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
+ self.helper.invite(invite_room_id, src=user2_id, targ=user1_id, tok=user2_tok)
+
+ # Make the Sliding Sync request
+ sync_body = {
+ "lists": {
+ # Absence of filters does not imply "False" values
+ "all": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 1,
+ "filters": {},
+ },
+ # Test single truthy filter
+ "dms": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 1,
+ "filters": {"is_dm": True},
+ },
+ # Test single falsy filter
+ "non-dms": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 1,
+ "filters": {"is_dm": False},
+ },
+ # Test how multiple filters should stack (AND'd together)
+ "room-invites": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 1,
+ "filters": {"is_dm": False, "is_invite": True},
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+
+ # Make sure it has the lists we requested
+ self.assertIncludes(
+ response_body["lists"].keys(),
+ {"all", "dms", "non-dms", "room-invites"},
+ exact=True,
+ )
+
+ # Make sure the lists have the correct rooms
+ self.assertIncludes(
+ set(response_body["lists"]["all"]["ops"][0]["room_ids"]),
+ {
+ invite_room_id,
+ room_id,
+ invited_dm_room_id,
+ joined_dm_room_id,
+ },
+ exact=True,
+ )
+ self.assertIncludes(
+ set(response_body["lists"]["dms"]["ops"][0]["room_ids"]),
+ {invited_dm_room_id, joined_dm_room_id},
+ exact=True,
+ )
+ self.assertIncludes(
+ set(response_body["lists"]["non-dms"]["ops"][0]["room_ids"]),
+ {invite_room_id, room_id},
+ exact=True,
+ )
+ self.assertIncludes(
+ set(response_body["lists"]["room-invites"]["ops"][0]["room_ids"]),
+ {invite_room_id},
+ exact=True,
+ )
+
+ def test_filters_regardless_of_membership_server_left_room(self) -> None:
+ """
+ Test that filters apply to rooms regardless of membership. We're also
+ compounding the problem by having all of the local users leave the room causing
+ our server to leave the room.
+
+ We want to make sure that if someone is filtering rooms and leaves a room, they
+ still get that final update down sync telling them that they left.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ # Create a normal room
+ room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
+ self.helper.join(room_id, user1_id, tok=user1_tok)
+
+ # Create an encrypted space room
+ space_room_id = self.helper.create_room_as(
+ user2_id,
+ tok=user2_tok,
+ extra_content={
+ "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE}
+ },
+ )
+ self.helper.send_state(
+ space_room_id,
+ EventTypes.RoomEncryption,
+ {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"},
+ tok=user2_tok,
+ )
+ self.helper.join(space_room_id, user1_id, tok=user1_tok)
+
+ # Make an initial Sliding Sync request
+ sync_body = {
+ "lists": {
+ "all-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {},
+ },
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 1,
+ "filters": {
+ "is_encrypted": True,
+ "room_types": [RoomTypes.SPACE],
+ },
+ },
+ }
+ }
+ response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+
+ # Make sure the response has the lists we requested
+ self.assertIncludes(
+ response_body["lists"].keys(),
+ {"all-list", "foo-list"},
+ )
+
+ # Make sure the lists have the correct rooms
+ self.assertIncludes(
+ set(response_body["lists"]["all-list"]["ops"][0]["room_ids"]),
+ {space_room_id, room_id},
+ exact=True,
+ )
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {space_room_id},
+ exact=True,
+ )
+
+ # Everyone leaves the encrypted space room
+ self.helper.leave(space_room_id, user1_id, tok=user1_tok)
+ self.helper.leave(space_room_id, user2_id, tok=user2_tok)
+
+ # Make an incremental Sliding Sync request
+ sync_body = {
+ "lists": {
+ "all-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {},
+ },
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 1,
+ "filters": {
+ "is_encrypted": True,
+ "room_types": [RoomTypes.SPACE],
+ },
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
+
+ # Make sure the response has the lists we requested
+ self.assertIncludes(
+ response_body["lists"].keys(),
+ {"all-list", "foo-list"},
+ exact=True,
+ )
+
+ # Make sure the lists have the correct rooms even though we `newly_left`
+ self.assertIncludes(
+ set(response_body["lists"]["all-list"]["ops"][0]["room_ids"]),
+ {space_room_id, room_id},
+ exact=True,
+ )
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {space_room_id},
+ exact=True,
+ )
+
+ def test_filters_is_dm(self) -> None:
+ """
+ Test `filters.is_dm` for DM rooms
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ # Create a normal room
+ room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+
+ # Create a DM room
+ dm_room_id = self._create_dm_room(
+ inviter_user_id=user1_id,
+ inviter_tok=user1_tok,
+ invitee_user_id=user2_id,
+ invitee_tok=user2_tok,
+ )
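+ # (the `_create_dm_room` helper also sets the `m.direct` account data for
+ # both users, which is what the `is_dm` filter is based on)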
+
+ # Try with `is_dm=True`
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "is_dm": True,
+ },
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {dm_room_id},
+ exact=True,
+ )
+
+ # Try with `is_dm=False`
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "is_dm": False,
+ },
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {room_id},
+ exact=True,
+ )
+
+ def test_filters_is_encrypted(self) -> None:
+ """
+ Test `filters.is_encrypted` for encrypted rooms
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ # Create an unencrypted room
+ room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+
+ # Create an encrypted room
+ encrypted_room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+ self.helper.send_state(
+ encrypted_room_id,
+ EventTypes.RoomEncryption,
+ {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"},
+ tok=user1_tok,
+ )
+
+ # Try with `is_encrypted=True`
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "is_encrypted": True,
+ },
+ },
+ }
+ }
+ response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+ # Only the encrypted room appears
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {encrypted_room_id},
+ exact=True,
+ )
+
+ # Try with `is_encrypted=False`
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "is_encrypted": False,
+ },
+ },
+ }
+ }
+ response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+ # Only the unencrypted room appears
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {room_id},
+ exact=True,
+ )
+
+ def test_filters_is_encrypted_server_left_room(self) -> None:
+ """
+ Test that we can apply a `filters.is_encrypted` against a room that everyone has left.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ # Get a token before we create any rooms
+ sync_body: JsonDict = {
+ "lists": {},
+ }
+ response_body, before_rooms_token = self.do_sync(sync_body, tok=user1_tok)
+
+ # Create an unencrypted room
+ room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+ # Leave the room
+ self.helper.leave(room_id, user1_id, tok=user1_tok)
+
+ # Create an encrypted room
+ encrypted_room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+ self.helper.send_state(
+ encrypted_room_id,
+ EventTypes.RoomEncryption,
+ {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"},
+ tok=user1_tok,
+ )
+ # Leave the room
+ self.helper.leave(encrypted_room_id, user1_id, tok=user1_tok)
+
+ # Try with `is_encrypted=True`
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "is_encrypted": True,
+ },
+ },
+ }
+ }
+ # Use an incremental sync so that the room is considered `newly_left` and shows
+ # up down sync
+ response_body, _ = self.do_sync(
+ sync_body, since=before_rooms_token, tok=user1_tok
+ )
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {encrypted_room_id},
+ exact=True,
+ )
+
+ # Try with `is_encrypted=False`
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "is_encrypted": False,
+ },
+ },
+ }
+ }
+ # Use an incremental sync so that the room is considered `newly_left` and shows
+ # up down sync
+ response_body, _ = self.do_sync(
+ sync_body, since=before_rooms_token, tok=user1_tok
+ )
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {room_id},
+ exact=True,
+ )
+
+ def test_filters_is_encrypted_server_left_room2(self) -> None:
+ """
+ Test that we can apply a `filters.is_encrypted` against a room that everyone has
+ left.
+
+ A local user is still invited to the rooms, but that doesn't affect
+ whether the server is participating in the room (users need to be joined).
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ _user2_tok = self.login(user2_id, "pass")
+
+ # Get a token before we create any rooms
+ sync_body: JsonDict = {
+ "lists": {},
+ }
+ response_body, before_rooms_token = self.do_sync(sync_body, tok=user1_tok)
+
+ # Create an unencrypted room
+ room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+ # Invite user2
+ self.helper.invite(room_id, targ=user2_id, tok=user1_tok)
+ # User1 leaves the room
+ self.helper.leave(room_id, user1_id, tok=user1_tok)
+
+ # Create an encrypted room
+ encrypted_room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+ self.helper.send_state(
+ encrypted_room_id,
+ EventTypes.RoomEncryption,
+ {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"},
+ tok=user1_tok,
+ )
+ # Invite user2
+ self.helper.invite(encrypted_room_id, targ=user2_id, tok=user1_tok)
+ # User1 leaves the room
+ self.helper.leave(encrypted_room_id, user1_id, tok=user1_tok)
+
+ # Try with `is_encrypted=True`
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "is_encrypted": True,
+ },
+ },
+ }
+ }
+ # Use an incremental sync so that the room is considered `newly_left` and shows
+ # up down sync
+ response_body, _ = self.do_sync(
+ sync_body, since=before_rooms_token, tok=user1_tok
+ )
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {encrypted_room_id},
+ exact=True,
+ )
+
+ # Try with `is_encrypted=False`
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "is_encrypted": False,
+ },
+ },
+ }
+ }
+ # Use an incremental sync so that the room is considered `newly_left` and shows
+ # up down sync
+ response_body, _ = self.do_sync(
+ sync_body, since=before_rooms_token, tok=user1_tok
+ )
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {room_id},
+ exact=True,
+ )
+
+ def test_filters_is_encrypted_after_we_left(self) -> None:
+ """
+ Test that we can apply a `filters.is_encrypted` against a room that was encrypted
+ after we left the room (make sure we don't just use the current state)
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ # Get a token before we create any rooms
+ sync_body: JsonDict = {
+ "lists": {},
+ }
+ response_body, before_rooms_token = self.do_sync(sync_body, tok=user1_tok)
+
+ # Create an unencrypted room
+ room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
+ # Join and then leave the room
+ self.helper.join(room_id, user1_id, tok=user1_tok)
+ self.helper.leave(room_id, user1_id, tok=user1_tok)
+
+ # Create a room that will be encrypted
+ encrypted_after_we_left_room_id = self.helper.create_room_as(
+ user2_id, tok=user2_tok
+ )
+ # Join and then leave the room
+ self.helper.join(encrypted_after_we_left_room_id, user1_id, tok=user1_tok)
+ self.helper.leave(encrypted_after_we_left_room_id, user1_id, tok=user1_tok)
+
+ # Encrypt the room after we've left
+ self.helper.send_state(
+ encrypted_after_we_left_room_id,
+ EventTypes.RoomEncryption,
+ {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"},
+ tok=user2_tok,
+ )
+
+ # Try with `is_encrypted=True`
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "is_encrypted": True,
+ },
+ },
+ }
+ }
+ response_body, _ = self.do_sync(
+ sync_body, since=before_rooms_token, tok=user1_tok
+ )
+ if self.use_new_tables:
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ set(),
+ exact=True,
+ )
+ else:
+ # Even though we left the room before it was encrypted, we still see it because
+ # someone else on our server is still participating in the room and we "leak"
+ # the current state to the left user. But we consider the room encryption status
+ # to not be a secret given it's often set at the start of the room and it's one
+ # of the stripped state events that is normally handed out.
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {encrypted_after_we_left_room_id},
+ exact=True,
+ )
+
+ # Try with `is_encrypted=False`
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "is_encrypted": False,
+ },
+ },
+ }
+ }
+ response_body, _ = self.do_sync(
+ sync_body, since=before_rooms_token, tok=user1_tok
+ )
+ if self.use_new_tables:
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {room_id, encrypted_after_we_left_room_id},
+ exact=True,
+ )
+ else:
+ # Even though we left the room before it was encrypted... (see comment above)
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {room_id},
+ exact=True,
+ )
+
+ def test_filters_is_encrypted_with_remote_invite_room_no_stripped_state(
+ self,
+ ) -> None:
+ """
+ Test that we can apply a `filters.is_encrypted` filter against a remote invite
+ room without any `unsigned.invite_room_state` (stripped state).
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ # Create a remote invite room without any `unsigned.invite_room_state`
+ _remote_invite_room_id = self._create_remote_invite_room_for_user(
+ user1_id, None
+ )
+
+ # Create an unencrypted room
+ room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+
+ # Create an encrypted room
+ encrypted_room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+ self.helper.send_state(
+ encrypted_room_id,
+ EventTypes.RoomEncryption,
+ {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"},
+ tok=user1_tok,
+ )
+
+ # Try with `is_encrypted=True`
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "is_encrypted": True,
+ },
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ # `remote_invite_room_id` should not appear because we can't figure out whether
+ # it is encrypted or not (no stripped state, `unsigned.invite_room_state`).
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {encrypted_room_id},
+ exact=True,
+ )
+
+ # Try with `is_encrypted=False`
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "is_encrypted": False,
+ },
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ # `remote_invite_room_id` should not appear because we can't figure out whether
+ # it is encrypted or not (no stripped state, `unsigned.invite_room_state`).
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {room_id},
+ exact=True,
+ )
+
+ def test_filters_is_encrypted_with_remote_invite_encrypted_room(self) -> None:
+ """
+ Test that we can apply a `filters.is_encrypted` filter against a remote invite
+ encrypted room with some `unsigned.invite_room_state` (stripped state).
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ # Create a remote invite room with some `unsigned.invite_room_state`
+ # indicating that the room is encrypted.
+ remote_invite_room_id = self._create_remote_invite_room_for_user(
+ user1_id,
+ [
+ StrippedStateEvent(
+ type=EventTypes.Create,
+ state_key="",
+ sender="@inviter:remote_server",
+ content={
+ EventContentFields.ROOM_CREATOR: "@inviter:remote_server",
+ EventContentFields.ROOM_VERSION: RoomVersions.V10.identifier,
+ },
+ ),
+ StrippedStateEvent(
+ type=EventTypes.RoomEncryption,
+ state_key="",
+ sender="@inviter:remote_server",
+ content={
+ EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2",
+ },
+ ),
+ ],
+ )
+
+ # Create an unencrypted room
+ room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+
+ # Create an encrypted room
+ encrypted_room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+ self.helper.send_state(
+ encrypted_room_id,
+ EventTypes.RoomEncryption,
+ {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"},
+ tok=user1_tok,
+ )
+
+ # Try with `is_encrypted=True`
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "is_encrypted": True,
+ },
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ # `remote_invite_room_id` should appear here because it is encrypted
+ # according to the stripped state
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {encrypted_room_id, remote_invite_room_id},
+ exact=True,
+ )
+
+ # Try with `is_encrypted=False`
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "is_encrypted": False,
+ },
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ # `remote_invite_room_id` should not appear here because it is encrypted
+ # according to the stripped state
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {room_id},
+ exact=True,
+ )
+
+ def test_filters_is_encrypted_with_remote_invite_unencrypted_room(self) -> None:
+ """
+ Test that we can apply a `filters.is_encrypted` filter against a remote invite
+ unencrypted room with some `unsigned.invite_room_state` (stripped state).
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ # Create a remote invite room with some `unsigned.invite_room_state`
+ # but don't set any room encryption event.
+ remote_invite_room_id = self._create_remote_invite_room_for_user(
+ user1_id,
+ [
+ StrippedStateEvent(
+ type=EventTypes.Create,
+ state_key="",
+ sender="@inviter:remote_server",
+ content={
+ EventContentFields.ROOM_CREATOR: "@inviter:remote_server",
+ EventContentFields.ROOM_VERSION: RoomVersions.V10.identifier,
+ },
+ ),
+ # No room encryption event
+ ],
+ )
+
+ # Create an unencrypted room
+ room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+
+ # Create an encrypted room
+ encrypted_room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+ self.helper.send_state(
+ encrypted_room_id,
+ EventTypes.RoomEncryption,
+ {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"},
+ tok=user1_tok,
+ )
+
+ # Try with `is_encrypted=True`
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "is_encrypted": True,
+ },
+ },
+ }
+ }
+ response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+ # `remote_invite_room_id` should not appear here because it is unencrypted
+ # according to the stripped state
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {encrypted_room_id},
+ exact=True,
+ )
+
+ # Try with `is_encrypted=False`
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "is_encrypted": False,
+ },
+ },
+ }
+ }
+ response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+ # `remote_invite_room_id` should appear because it is unencrypted according to
+ # the stripped state
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {room_id, remote_invite_room_id},
+ exact=True,
+ )
+
+ def test_filters_is_encrypted_updated(self) -> None:
+ """
+ Make sure we get rooms if the encrypted room status is updated for a joined room
+ (`filters.is_encrypted`)
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "is_encrypted": True,
+ },
+ },
+ }
+ }
+ response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+
+ # No rooms are encrypted yet
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ set(),
+ exact=True,
+ )
+
+ # Update the encryption status
+ self.helper.send_state(
+ room_id,
+ EventTypes.RoomEncryption,
+ {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"},
+ tok=user1_tok,
+ )
+
+ # We should see the room now because it's encrypted
+ response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {room_id},
+ exact=True,
+ )
+
+ def test_filters_is_invite_rooms(self) -> None:
+ """
+ Test `filters.is_invite` for rooms that the user has been invited to
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ # Create a normal room
+ room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
+ self.helper.join(room_id, user1_id, tok=user1_tok)
+
+ # Create a room that user1 is invited to
+ invite_room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
+ self.helper.invite(invite_room_id, src=user2_id, targ=user1_id, tok=user2_tok)
+
+ # Try with `is_invite=True`
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "is_invite": True,
+ },
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {invite_room_id},
+ exact=True,
+ )
+
+ # Try with `is_invite=False`
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "is_invite": False,
+ },
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {room_id},
+ exact=True,
+ )
+
+ def test_filters_room_types(self) -> None:
+ """
+ Test `filters.room_types` for different room types
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ # Create a normal room (no room type)
+ room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+
+ # Create a space room
+ space_room_id = self.helper.create_room_as(
+ user1_id,
+ tok=user1_tok,
+ extra_content={
+ "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE}
+ },
+ )
+
+ # Create an arbitrarily typed room
+ foo_room_id = self.helper.create_room_as(
+ user1_id,
+ tok=user1_tok,
+ extra_content={
+ "creation_content": {
+ EventContentFields.ROOM_TYPE: "org.matrix.foobarbaz"
+ }
+ },
+ )
+
+ # Try finding only normal rooms
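+ # (a `None` entry in `room_types` matches rooms whose create event has no
+ # room type set)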
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "room_types": [None],
+ },
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {room_id},
+ exact=True,
+ )
+
+ # Try finding only spaces
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "room_types": [RoomTypes.SPACE],
+ },
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {space_room_id},
+ exact=True,
+ )
+
+ # Try finding normal rooms and spaces
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "room_types": [None, RoomTypes.SPACE],
+ },
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {room_id, space_room_id},
+ exact=True,
+ )
+
+ # Try finding an arbitrary room type
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "room_types": ["org.matrix.foobarbaz"],
+ },
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {foo_room_id},
+ exact=True,
+ )
+
+ # Just make sure we know what happens when you specify an empty list of room_types
+ # (we should find nothing)
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "room_types": [],
+ },
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ set(),
+ exact=True,
+ )
+
+ def test_filters_not_room_types(self) -> None:
+ """
+ Test `filters.not_room_types` for different room types
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ # Create a normal room (no room type)
+ room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+
+ # Create a space room
+ space_room_id = self.helper.create_room_as(
+ user1_id,
+ tok=user1_tok,
+ extra_content={
+ "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE}
+ },
+ )
+
+ # Create an arbitrarily typed room
+ foo_room_id = self.helper.create_room_as(
+ user1_id,
+ tok=user1_tok,
+ extra_content={
+ "creation_content": {
+ EventContentFields.ROOM_TYPE: "org.matrix.foobarbaz"
+ }
+ },
+ )
+
+ # Try finding *NOT* normal rooms
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "not_room_types": [None],
+ },
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {space_room_id, foo_room_id},
+ exact=True,
+ )
+
+ # Try finding *NOT* spaces
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "not_room_types": [RoomTypes.SPACE],
+ },
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {room_id, foo_room_id},
+ exact=True,
+ )
+
+ # Try finding *NOT* normal rooms or spaces
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "not_room_types": [None, RoomTypes.SPACE],
+ },
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {foo_room_id},
+ exact=True,
+ )
+
+ # Test how it behaves when we have both `room_types` and `not_room_types`.
+ # `not_room_types` should win.
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "room_types": [None],
+ "not_room_types": [None],
+ },
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ # Nothing matches because nothing is both a normal room and not a normal room
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ set(),
+ exact=True,
+ )
+
+ # Test how it behaves when we have both `room_types` and `not_room_types`.
+ # `not_room_types` should win.
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "room_types": [None, RoomTypes.SPACE],
+ "not_room_types": [None],
+ },
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
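+ # Only the space is left: the normal room matches `room_types` but is then
+ # excluded by `not_room_types`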
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {space_room_id},
+ exact=True,
+ )
+
+ # Just make sure we know what happens when you specify an empty list of not_room_types
+ # (we should find all of the rooms)
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "not_room_types": [],
+ },
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {room_id, foo_room_id, space_room_id},
+ exact=True,
+ )
+
+ def test_filters_room_types_server_left_room(self) -> None:
+ """
+ Test that we can apply a `filters.room_types` against a room that everyone has left.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ # Get a token before we create any rooms
+ sync_body: JsonDict = {
+ "lists": {},
+ }
+ response_body, before_rooms_token = self.do_sync(sync_body, tok=user1_tok)
+
+ # Create a normal room (no room type)
+ room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+ # Leave the room
+ self.helper.leave(room_id, user1_id, tok=user1_tok)
+
+ # Create a space room
+ space_room_id = self.helper.create_room_as(
+ user1_id,
+ tok=user1_tok,
+ extra_content={
+ "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE}
+ },
+ )
+ # Leave the room
+ self.helper.leave(space_room_id, user1_id, tok=user1_tok)
+
+ # Try finding only normal rooms
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "room_types": [None],
+ },
+ },
+ }
+ }
+ # Use an incremental sync so that the room is considered `newly_left` and shows
+ # up down sync
+ response_body, _ = self.do_sync(
+ sync_body, since=before_rooms_token, tok=user1_tok
+ )
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {room_id},
+ exact=True,
+ )
+
+ # Try finding only spaces
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "room_types": [RoomTypes.SPACE],
+ },
+ },
+ }
+ }
+ # Use an incremental sync so that the room is considered `newly_left` and shows
+ # up down sync
+ response_body, _ = self.do_sync(
+ sync_body, since=before_rooms_token, tok=user1_tok
+ )
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {space_room_id},
+ exact=True,
+ )
+
+ def test_filters_room_types_server_left_room2(self) -> None:
+ """
+ Test that we can apply a `filters.room_types` against a room that everyone has left.
+
+ A local user is still invited to the rooms, but that doesn't affect
+ whether the server is participating in the room (users need to be joined).
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ _user2_tok = self.login(user2_id, "pass")
+
+ # Get a token before we create any rooms
+ sync_body: JsonDict = {
+ "lists": {},
+ }
+ response_body, before_rooms_token = self.do_sync(sync_body, tok=user1_tok)
+
+ # Create a normal room (no room type)
+ room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+ # Invite user2
+ self.helper.invite(room_id, targ=user2_id, tok=user1_tok)
+ # User1 leaves the room
+ self.helper.leave(room_id, user1_id, tok=user1_tok)
+
+ # Create a space room
+ space_room_id = self.helper.create_room_as(
+ user1_id,
+ tok=user1_tok,
+ extra_content={
+ "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE}
+ },
+ )
+ # Invite user2
+ self.helper.invite(space_room_id, targ=user2_id, tok=user1_tok)
+ # User1 leaves the room
+ self.helper.leave(space_room_id, user1_id, tok=user1_tok)
+
+ # Try finding only normal rooms
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "room_types": [None],
+ },
+ },
+ }
+ }
+        # Use an incremental sync so that the room is considered `newly_left` and
+        # comes down sync
+ response_body, _ = self.do_sync(
+ sync_body, since=before_rooms_token, tok=user1_tok
+ )
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {room_id},
+ exact=True,
+ )
+
+ # Try finding only spaces
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "room_types": [RoomTypes.SPACE],
+ },
+ },
+ }
+ }
+        # Use an incremental sync so that the room is considered `newly_left` and
+        # comes down sync
+ response_body, _ = self.do_sync(
+ sync_body, since=before_rooms_token, tok=user1_tok
+ )
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {space_room_id},
+ exact=True,
+ )
+
+ def test_filters_room_types_with_remote_invite_room_no_stripped_state(self) -> None:
+ """
+ Test that we can apply a `filters.room_types` filter against a remote invite
+ room without any `unsigned.invite_room_state` (stripped state).
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ # Create a remote invite room without any `unsigned.invite_room_state`
+ _remote_invite_room_id = self._create_remote_invite_room_for_user(
+ user1_id, None
+ )
+
+ # Create a normal room (no room type)
+ room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+
+ # Create a space room
+ space_room_id = self.helper.create_room_as(
+ user1_id,
+ tok=user1_tok,
+ extra_content={
+ "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE}
+ },
+ )
+
+ # Try finding only normal rooms
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "room_types": [None],
+ },
+ },
+ }
+ }
+ # `remote_invite_room_id` should not appear because we can't figure out what
+ # room type it is (no stripped state, `unsigned.invite_room_state`)
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {room_id},
+ exact=True,
+ )
+
+ # Try finding only spaces
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "room_types": [RoomTypes.SPACE],
+ },
+ },
+ }
+ }
+ # `remote_invite_room_id` should not appear because we can't figure out what
+ # room type it is (no stripped state, `unsigned.invite_room_state`)
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {space_room_id},
+ exact=True,
+ )
+
+ def test_filters_room_types_with_remote_invite_space(self) -> None:
+ """
+ Test that we can apply a `filters.room_types` filter against a remote invite
+ to a space room with some `unsigned.invite_room_state` (stripped state).
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ # Create a remote invite room with some `unsigned.invite_room_state` indicating
+ # that it is a space room
+ remote_invite_room_id = self._create_remote_invite_room_for_user(
+ user1_id,
+ [
+ StrippedStateEvent(
+ type=EventTypes.Create,
+ state_key="",
+ sender="@inviter:remote_server",
+ content={
+ EventContentFields.ROOM_CREATOR: "@inviter:remote_server",
+ EventContentFields.ROOM_VERSION: RoomVersions.V10.identifier,
+ # Specify that it is a space room
+ EventContentFields.ROOM_TYPE: RoomTypes.SPACE,
+ },
+ ),
+ ],
+ )
+
+ # Create a normal room (no room type)
+ room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+
+ # Create a space room
+ space_room_id = self.helper.create_room_as(
+ user1_id,
+ tok=user1_tok,
+ extra_content={
+ "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE}
+ },
+ )
+
+ # Try finding only normal rooms
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "room_types": [None],
+ },
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ # `remote_invite_room_id` should not appear here because it is a space room
+ # according to the stripped state
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {room_id},
+ exact=True,
+ )
+
+ # Try finding only spaces
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "room_types": [RoomTypes.SPACE],
+ },
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ # `remote_invite_room_id` should appear here because it is a space room
+ # according to the stripped state
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {space_room_id, remote_invite_room_id},
+ exact=True,
+ )
+
+ def test_filters_room_types_with_remote_invite_normal_room(self) -> None:
+ """
+ Test that we can apply a `filters.room_types` filter against a remote invite
+ to a normal room with some `unsigned.invite_room_state` (stripped state).
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ # Create a remote invite room with some `unsigned.invite_room_state`
+ # but the create event does not specify a room type (normal room)
+ remote_invite_room_id = self._create_remote_invite_room_for_user(
+ user1_id,
+ [
+ StrippedStateEvent(
+ type=EventTypes.Create,
+ state_key="",
+ sender="@inviter:remote_server",
+ content={
+ EventContentFields.ROOM_CREATOR: "@inviter:remote_server",
+ EventContentFields.ROOM_VERSION: RoomVersions.V10.identifier,
+ # No room type means this is a normal room
+ },
+ ),
+ ],
+ )
+
+ # Create a normal room (no room type)
+ room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+
+ # Create a space room
+ space_room_id = self.helper.create_room_as(
+ user1_id,
+ tok=user1_tok,
+ extra_content={
+ "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE}
+ },
+ )
+
+ # Try finding only normal rooms
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "room_types": [None],
+ },
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ # `remote_invite_room_id` should appear here because it is a normal room
+ # according to the stripped state (no room type)
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {room_id, remote_invite_room_id},
+ exact=True,
+ )
+
+ # Try finding only spaces
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "room_types": [RoomTypes.SPACE],
+ },
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ # `remote_invite_room_id` should not appear here because it is a normal room
+ # according to the stripped state (no room type)
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {space_room_id},
+ exact=True,
+ )
+
+ def _add_tag_to_room(
+ self, *, room_id: str, user_id: str, access_token: str, tag_name: str
+ ) -> None:
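+        # Set the given tag on the room on behalf of the given user via the room
+        # tagging endpoint.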
+ channel = self.make_request(
+ method="PUT",
+ path=f"/user/{user_id}/rooms/{room_id}/tags/{tag_name}",
+ content={},
+ access_token=access_token,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ def test_filters_tags(self) -> None:
+ """
+ Test `filters.tags` for rooms with given tags
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ # Create a room with no tags
+ self.helper.create_room_as(user1_id, tok=user1_tok)
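+        # (We deliberately discard its room ID; this room should never match any of
+        # the tag filters below.)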
+
+ # Create some rooms with tags
+ foo_room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+ bar_room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+        # Create a room with multiple tags
+ foobar_room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+
+ # Add the "foo" tag to the foo room
+ self._add_tag_to_room(
+ room_id=foo_room_id,
+ user_id=user1_id,
+ access_token=user1_tok,
+ tag_name="foo",
+ )
+ # Add the "bar" tag to the bar room
+ self._add_tag_to_room(
+ room_id=bar_room_id,
+ user_id=user1_id,
+ access_token=user1_tok,
+ tag_name="bar",
+ )
+ # Add both "foo" and "bar" tags to the foobar room
+ self._add_tag_to_room(
+ room_id=foobar_room_id,
+ user_id=user1_id,
+ access_token=user1_tok,
+ tag_name="foo",
+ )
+ self._add_tag_to_room(
+ room_id=foobar_room_id,
+ user_id=user1_id,
+ access_token=user1_tok,
+ tag_name="bar",
+ )
+
+ # Try finding rooms with the "foo" tag
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "tags": ["foo"],
+ },
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {foo_room_id, foobar_room_id},
+ exact=True,
+ )
+
+ # Try finding rooms with either "foo" or "bar" tags
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "tags": ["foo", "bar"],
+ },
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {foo_room_id, bar_room_id, foobar_room_id},
+ exact=True,
+ )
+
+ # Try with a random tag we didn't add
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "tags": ["flomp"],
+ },
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ # No rooms should match
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ set(),
+ exact=True,
+ )
+
+ # Just make sure we know what happens when you specify an empty list of tags
+ # (we should find nothing)
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "tags": [],
+ },
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ set(),
+ exact=True,
+ )
+
+ def test_filters_not_tags(self) -> None:
+ """
+ Test `filters.not_tags` for excluding rooms with given tags
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ # Create a room with no tags
+ untagged_room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+
+ # Create some rooms with tags
+ foo_room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+ bar_room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+        # Create a room with multiple tags
+ foobar_room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+
+ # Add the "foo" tag to the foo room
+ self._add_tag_to_room(
+ room_id=foo_room_id,
+ user_id=user1_id,
+ access_token=user1_tok,
+ tag_name="foo",
+ )
+ # Add the "bar" tag to the bar room
+ self._add_tag_to_room(
+ room_id=bar_room_id,
+ user_id=user1_id,
+ access_token=user1_tok,
+ tag_name="bar",
+ )
+ # Add both "foo" and "bar" tags to the foobar room
+ self._add_tag_to_room(
+ room_id=foobar_room_id,
+ user_id=user1_id,
+ access_token=user1_tok,
+ tag_name="foo",
+ )
+ self._add_tag_to_room(
+ room_id=foobar_room_id,
+ user_id=user1_id,
+ access_token=user1_tok,
+ tag_name="bar",
+ )
+
+ # Try finding rooms without the "foo" tag
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "not_tags": ["foo"],
+ },
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {untagged_room_id, bar_room_id},
+ exact=True,
+ )
+
+ # Try finding rooms without either "foo" or "bar" tags
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "not_tags": ["foo", "bar"],
+ },
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {untagged_room_id},
+ exact=True,
+ )
+
+ # Test how it behaves when we have both `tags` and `not_tags`.
+ # `not_tags` should win.
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "tags": ["foo"],
+ "not_tags": ["foo"],
+ },
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ # Nothing matches because nothing is both tagged with "foo" and not tagged with "foo"
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ set(),
+ exact=True,
+ )
+
+ # Just make sure we know what happens when you specify an empty list of not_tags
+ # (we should find all of the rooms)
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ "filters": {
+ "not_tags": [],
+ },
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {untagged_room_id, foo_room_id, bar_room_id, foobar_room_id},
+ exact=True,
+ )
diff --git a/tests/rest/client/sliding_sync/test_room_subscriptions.py b/tests/rest/client/sliding_sync/test_room_subscriptions.py
index cc17b0b354..285fdaaf78 100644
--- a/tests/rest/client/sliding_sync/test_room_subscriptions.py
+++ b/tests/rest/client/sliding_sync/test_room_subscriptions.py
@@ -14,6 +14,8 @@
import logging
from http import HTTPStatus
+from parameterized import parameterized_class
+
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
@@ -27,6 +29,20 @@ from tests.rest.client.sliding_sync.test_sliding_sync import SlidingSyncBase
logger = logging.getLogger(__name__)
+# FIXME: This can be removed once we bump `SCHEMA_COMPAT_VERSION` and run the
+# foreground update for
+# `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` (tracked by
+# https://github.com/element-hq/synapse/issues/17623)
+@parameterized_class(
+ ("use_new_tables",),
+ [
+ (True,),
+ (False,),
+ ],
+ class_name_func=lambda cls,
+ num,
+ params_dict: f"{cls.__name__}_{'new' if params_dict['use_new_tables'] else 'fallback'}",
+)
class SlidingSyncRoomSubscriptionsTestCase(SlidingSyncBase):
"""
Test `room_subscriptions` in the Sliding Sync API.
@@ -43,6 +59,8 @@ class SlidingSyncRoomSubscriptionsTestCase(SlidingSyncBase):
self.store = hs.get_datastores().main
self.storage_controllers = hs.get_storage_controllers()
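+        # Defer to the base class, which presumably applies the `use_new_tables`
+        # parameter from the `@parameterized_class` decorator above.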
+ super().prepare(reactor, clock, hs)
+
def test_room_subscriptions_with_join_membership(self) -> None:
"""
Test `room_subscriptions` with a joined room should give us timeline and current
diff --git a/tests/rest/client/sliding_sync/test_rooms_invites.py b/tests/rest/client/sliding_sync/test_rooms_invites.py
index f08ffaf674..882762ca29 100644
--- a/tests/rest/client/sliding_sync/test_rooms_invites.py
+++ b/tests/rest/client/sliding_sync/test_rooms_invites.py
@@ -13,6 +13,8 @@
#
import logging
+from parameterized import parameterized_class
+
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
@@ -27,6 +29,20 @@ from tests.rest.client.sliding_sync.test_sliding_sync import SlidingSyncBase
logger = logging.getLogger(__name__)
+# FIXME: This can be removed once we bump `SCHEMA_COMPAT_VERSION` and run the
+# foreground update for
+# `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` (tracked by
+# https://github.com/element-hq/synapse/issues/17623)
+@parameterized_class(
+ ("use_new_tables",),
+ [
+ (True,),
+ (False,),
+ ],
+ class_name_func=lambda cls,
+ num,
+ params_dict: f"{cls.__name__}_{'new' if params_dict['use_new_tables'] else 'fallback'}",
+)
class SlidingSyncRoomsInvitesTestCase(SlidingSyncBase):
"""
Test to make sure the `rooms` response looks good for invites in the Sliding Sync API.
@@ -49,6 +65,8 @@ class SlidingSyncRoomsInvitesTestCase(SlidingSyncBase):
self.store = hs.get_datastores().main
self.storage_controllers = hs.get_storage_controllers()
+ super().prepare(reactor, clock, hs)
+
def test_rooms_invite_shared_history_initial_sync(self) -> None:
"""
Test that `rooms` we are invited to have some stripped `invite_state` during an
diff --git a/tests/rest/client/sliding_sync/test_rooms_meta.py b/tests/rest/client/sliding_sync/test_rooms_meta.py
index 04f11c0524..0a8b2c02c2 100644
--- a/tests/rest/client/sliding_sync/test_rooms_meta.py
+++ b/tests/rest/client/sliding_sync/test_rooms_meta.py
@@ -13,10 +13,12 @@
#
import logging
+from parameterized import parameterized, parameterized_class
+
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EventContentFields, EventTypes, Membership
from synapse.api.room_versions import RoomVersions
from synapse.rest.client import login, room, sync
from synapse.server import HomeServer
@@ -28,6 +30,20 @@ from tests.test_utils.event_injection import create_event
logger = logging.getLogger(__name__)
+# FIXME: This can be removed once we bump `SCHEMA_COMPAT_VERSION` and run the
+# foreground update for
+# `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` (tracked by
+# https://github.com/element-hq/synapse/issues/17623)
+@parameterized_class(
+ ("use_new_tables",),
+ [
+ (True,),
+ (False,),
+ ],
+ class_name_func=lambda cls,
+ num,
+ params_dict: f"{cls.__name__}_{'new' if params_dict['use_new_tables'] else 'fallback'}",
+)
class SlidingSyncRoomsMetaTestCase(SlidingSyncBase):
"""
Test rooms meta info like name, avatar, joined_count, invited_count, is_dm,
@@ -44,11 +60,18 @@ class SlidingSyncRoomsMetaTestCase(SlidingSyncBase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.storage_controllers = hs.get_storage_controllers()
+ self.state_handler = self.hs.get_state_handler()
+ persistence = self.hs.get_storage_controllers().persistence
+ assert persistence is not None
+ self.persistence = persistence
+
+ super().prepare(reactor, clock, hs)
- def test_rooms_meta_when_joined(self) -> None:
+ def test_rooms_meta_when_joined_initial(self) -> None:
"""
- Test that the `rooms` `name` and `avatar` are included in the response and
- reflect the current state of the room when the user is joined to the room.
+ Test that the `rooms` `name` and `avatar` are included in the initial sync
+ response and reflect the current state of the room when the user is joined to
+ the room.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
@@ -85,6 +108,7 @@ class SlidingSyncRoomsMetaTestCase(SlidingSyncBase):
response_body, _ = self.do_sync(sync_body, tok=user1_tok)
# Reflect the current state of the room
+ self.assertEqual(response_body["rooms"][room_id1]["initial"], True)
self.assertEqual(
response_body["rooms"][room_id1]["name"],
"my super room",
@@ -107,6 +131,178 @@ class SlidingSyncRoomsMetaTestCase(SlidingSyncBase):
response_body["rooms"][room_id1].get("is_dm"),
)
+ def test_rooms_meta_when_joined_incremental_no_change(self) -> None:
+ """
+ Test that the `rooms` `name` and `avatar` aren't included in an incremental sync
+ response if they haven't changed.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ room_id1 = self.helper.create_room_as(
+ user2_id,
+ tok=user2_tok,
+ extra_content={
+ "name": "my super room",
+ },
+ )
+ # Set the room avatar URL
+ self.helper.send_state(
+ room_id1,
+ EventTypes.RoomAvatar,
+ {"url": "mxc://DUMMY_MEDIA_ID"},
+ tok=user2_tok,
+ )
+
+ self.helper.join(room_id1, user1_id, tok=user1_tok)
+
+ # Make the Sliding Sync request
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": [],
+ # This needs to be set to one so the `RoomResult` isn't empty and
+ # the room comes down incremental sync when we send a new message.
+ "timeline_limit": 1,
+ }
+ }
+ }
+ response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+
+ # Send a message to make the room come down sync
+ self.helper.send(room_id1, "message in room1", tok=user2_tok)
+
+ # Incremental sync
+ response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
+
+ # We should only see changed meta info (nothing changed so we shouldn't see any
+ # of these fields)
+ self.assertNotIn(
+ "initial",
+ response_body["rooms"][room_id1],
+ )
+ self.assertNotIn(
+ "name",
+ response_body["rooms"][room_id1],
+ )
+ self.assertNotIn(
+ "avatar",
+ response_body["rooms"][room_id1],
+ )
+ self.assertNotIn(
+ "joined_count",
+ response_body["rooms"][room_id1],
+ )
+ self.assertNotIn(
+ "invited_count",
+ response_body["rooms"][room_id1],
+ )
+ self.assertIsNone(
+ response_body["rooms"][room_id1].get("is_dm"),
+ )
+
+ @parameterized.expand(
+ [
+ ("in_required_state", True),
+ ("not_in_required_state", False),
+ ]
+ )
+ def test_rooms_meta_when_joined_incremental_with_state_change(
+ self, test_description: str, include_changed_state_in_required_state: bool
+ ) -> None:
+ """
+ Test that the `rooms` `name` and `avatar` are included in an incremental sync
+ response if they changed.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ room_id1 = self.helper.create_room_as(
+ user2_id,
+ tok=user2_tok,
+ extra_content={
+ "name": "my super room",
+ },
+ )
+ # Set the room avatar URL
+ self.helper.send_state(
+ room_id1,
+ EventTypes.RoomAvatar,
+ {"url": "mxc://DUMMY_MEDIA_ID"},
+ tok=user2_tok,
+ )
+
+ self.helper.join(room_id1, user1_id, tok=user1_tok)
+
+ # Make the Sliding Sync request
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": (
+ [[EventTypes.Name, ""], [EventTypes.RoomAvatar, ""]]
+ # Conditionally include the changed state in the
+ # `required_state` to make sure whether we request it or not,
+ # the new room name still flows down to the client.
+ if include_changed_state_in_required_state
+ else []
+ ),
+ "timeline_limit": 0,
+ }
+ }
+ }
+ response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+
+ # Update the room name
+ self.helper.send_state(
+ room_id1,
+ EventTypes.Name,
+ {EventContentFields.ROOM_NAME: "my super duper room"},
+ tok=user2_tok,
+ )
+ # Update the room avatar URL
+ self.helper.send_state(
+ room_id1,
+ EventTypes.RoomAvatar,
+ {"url": "mxc://DUMMY_MEDIA_ID_UPDATED"},
+ tok=user2_tok,
+ )
+
+ # Incremental sync
+ response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
+
+ # We should only see changed meta info (the room name and avatar)
+ self.assertNotIn(
+ "initial",
+ response_body["rooms"][room_id1],
+ )
+ self.assertEqual(
+ response_body["rooms"][room_id1]["name"],
+ "my super duper room",
+ response_body["rooms"][room_id1],
+ )
+ self.assertEqual(
+ response_body["rooms"][room_id1]["avatar"],
+ "mxc://DUMMY_MEDIA_ID_UPDATED",
+ response_body["rooms"][room_id1],
+ )
+ self.assertNotIn(
+ "joined_count",
+ response_body["rooms"][room_id1],
+ )
+ self.assertNotIn(
+ "invited_count",
+ response_body["rooms"][room_id1],
+ )
+ self.assertIsNone(
+ response_body["rooms"][room_id1].get("is_dm"),
+ )
+
def test_rooms_meta_when_invited(self) -> None:
"""
Test that the `rooms` `name` and `avatar` are included in the response and
@@ -164,6 +360,7 @@ class SlidingSyncRoomsMetaTestCase(SlidingSyncBase):
# This should still reflect the current state of the room even when the user is
# invited.
+ self.assertEqual(response_body["rooms"][room_id1]["initial"], True)
self.assertEqual(
response_body["rooms"][room_id1]["name"],
"my super duper room",
@@ -174,14 +371,17 @@ class SlidingSyncRoomsMetaTestCase(SlidingSyncBase):
"mxc://UPDATED_DUMMY_MEDIA_ID",
response_body["rooms"][room_id1],
)
- self.assertEqual(
- response_body["rooms"][room_id1]["joined_count"],
- 1,
+
+ # We don't give extra room information to invitees
+ self.assertNotIn(
+ "joined_count",
+ response_body["rooms"][room_id1],
)
- self.assertEqual(
- response_body["rooms"][room_id1]["invited_count"],
- 1,
+ self.assertNotIn(
+ "invited_count",
+ response_body["rooms"][room_id1],
)
+
self.assertIsNone(
response_body["rooms"][room_id1].get("is_dm"),
)
@@ -242,6 +442,7 @@ class SlidingSyncRoomsMetaTestCase(SlidingSyncBase):
response_body, _ = self.do_sync(sync_body, tok=user1_tok)
# Reflect the state of the room at the time of leaving
+ self.assertEqual(response_body["rooms"][room_id1]["initial"], True)
self.assertEqual(
response_body["rooms"][room_id1]["name"],
"my super room",
@@ -252,15 +453,16 @@ class SlidingSyncRoomsMetaTestCase(SlidingSyncBase):
"mxc://DUMMY_MEDIA_ID",
response_body["rooms"][room_id1],
)
- self.assertEqual(
- response_body["rooms"][room_id1]["joined_count"],
- # FIXME: The actual number should be "1" (user2) but we currently don't
- # support this for rooms where the user has left/been banned.
- 0,
+
+        # FIXME: We possibly want to return joined and invited counts for rooms
+        # you're banned from
+ self.assertNotIn(
+ "joined_count",
+ response_body["rooms"][room_id1],
)
- self.assertEqual(
- response_body["rooms"][room_id1]["invited_count"],
- 0,
+ self.assertNotIn(
+ "invited_count",
+ response_body["rooms"][room_id1],
)
self.assertIsNone(
response_body["rooms"][room_id1].get("is_dm"),
@@ -316,6 +518,7 @@ class SlidingSyncRoomsMetaTestCase(SlidingSyncBase):
# Room1 has a name so we shouldn't see any `heroes` which the client would use
# the calculate the room name themselves.
+ self.assertEqual(response_body["rooms"][room_id1]["initial"], True)
self.assertEqual(
response_body["rooms"][room_id1]["name"],
"my super room",
@@ -332,6 +535,7 @@ class SlidingSyncRoomsMetaTestCase(SlidingSyncBase):
)
# Room2 doesn't have a name so we should see `heroes` populated
+ self.assertEqual(response_body["rooms"][room_id2]["initial"], True)
self.assertIsNone(response_body["rooms"][room_id2].get("name"))
self.assertCountEqual(
[
@@ -403,6 +607,7 @@ class SlidingSyncRoomsMetaTestCase(SlidingSyncBase):
response_body, _ = self.do_sync(sync_body, tok=user1_tok)
# Room2 doesn't have a name so we should see `heroes` populated
+ self.assertEqual(response_body["rooms"][room_id1]["initial"], True)
self.assertIsNone(response_body["rooms"][room_id1].get("name"))
self.assertCountEqual(
[
@@ -475,7 +680,8 @@ class SlidingSyncRoomsMetaTestCase(SlidingSyncBase):
}
response_body, _ = self.do_sync(sync_body, tok=user1_tok)
- # Room2 doesn't have a name so we should see `heroes` populated
+ # Room doesn't have a name so we should see `heroes` populated
+ self.assertEqual(response_body["rooms"][room_id1]["initial"], True)
self.assertIsNone(response_body["rooms"][room_id1].get("name"))
self.assertCountEqual(
[
@@ -490,20 +696,175 @@ class SlidingSyncRoomsMetaTestCase(SlidingSyncBase):
[],
)
+        # FIXME: We possibly want to return joined and invited counts for rooms
+        # you're banned from
+ self.assertNotIn(
+ "joined_count",
+ response_body["rooms"][room_id1],
+ )
+ self.assertNotIn(
+ "invited_count",
+ response_body["rooms"][room_id1],
+ )
+
+ def test_rooms_meta_heroes_incremental_sync_no_change(self) -> None:
+ """
+ Test that the `rooms` `heroes` aren't included in an incremental sync
+ response if they haven't changed.
+
+ (when the room doesn't have a room name set)
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+ user3_id = self.register_user("user3", "pass")
+ _user3_tok = self.login(user3_id, "pass")
+
+ room_id = self.helper.create_room_as(
+ user2_id,
+ tok=user2_tok,
+ extra_content={
+ # No room name set so that `heroes` is populated
+ #
+ # "name": "my super room2",
+ },
+ )
+ self.helper.join(room_id, user1_id, tok=user1_tok)
+ # User3 is invited
+ self.helper.invite(room_id, src=user2_id, targ=user3_id, tok=user2_tok)
+
+ # Make the Sliding Sync request
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": [],
+ # This needs to be set to one so the `RoomResult` isn't empty and
+ # the room comes down incremental sync when we send a new message.
+ "timeline_limit": 1,
+ }
+ }
+ }
+ response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+
+ # Send a message to make the room come down sync
+ self.helper.send(room_id, "message in room", tok=user2_tok)
+
+ # Incremental sync
+ response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
+
+        # This is an incremental sync and the second time we have seen this room, so
+        # it isn't `initial`
+ self.assertNotIn(
+ "initial",
+ response_body["rooms"][room_id],
+ )
+        # The room shouldn't have a room name because we're testing the `heroes`
+        # field, which only has a chance to appear if the room doesn't have a name.
+ self.assertNotIn(
+ "name",
+ response_body["rooms"][room_id],
+ )
+ # No change to heroes
+ self.assertNotIn(
+ "heroes",
+ response_body["rooms"][room_id],
+ )
+ # No change to member counts
+ self.assertNotIn(
+ "joined_count",
+ response_body["rooms"][room_id],
+ )
+ self.assertNotIn(
+ "invited_count",
+ response_body["rooms"][room_id],
+ )
+ # We didn't request any state so we shouldn't see any `required_state`
+ self.assertNotIn(
+ "required_state",
+ response_body["rooms"][room_id],
+ )
+
+ def test_rooms_meta_heroes_incremental_sync_with_membership_change(self) -> None:
+ """
+ Test that the `rooms` `heroes` are included in an incremental sync response if
+ the membership has changed.
+
+ (when the room doesn't have a room name set)
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+ user3_id = self.register_user("user3", "pass")
+ user3_tok = self.login(user3_id, "pass")
+
+ room_id = self.helper.create_room_as(
+ user2_id,
+ tok=user2_tok,
+ extra_content={
+ # No room name set so that `heroes` is populated
+ #
+ # "name": "my super room2",
+ },
+ )
+ self.helper.join(room_id, user1_id, tok=user1_tok)
+ # User3 is invited
+ self.helper.invite(room_id, src=user2_id, targ=user3_id, tok=user2_tok)
+
+ # Make the Sliding Sync request
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": [],
+ "timeline_limit": 0,
+ }
+ }
+ }
+ response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+
+ # User3 joins (membership change)
+ self.helper.join(room_id, user3_id, tok=user3_tok)
+
+ # Incremental sync
+ response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
+
+        # This is an incremental sync and the second time we have seen this room, so
+        # it isn't `initial`
+ self.assertNotIn(
+ "initial",
+ response_body["rooms"][room_id],
+ )
+        # The room shouldn't have a room name because we're testing the `heroes`
+        # field, which only has a chance to appear if the room doesn't have a name.
+ self.assertNotIn(
+ "name",
+ response_body["rooms"][room_id],
+ )
+ # Membership change so we should see heroes and membership counts
+ self.assertCountEqual(
+ [
+ hero["user_id"]
+ for hero in response_body["rooms"][room_id].get("heroes", [])
+ ],
+ # Heroes shouldn't include the user themselves (we shouldn't see user1)
+ [user2_id, user3_id],
+ )
self.assertEqual(
- response_body["rooms"][room_id1]["joined_count"],
- # FIXME: The actual number should be "1" (user2) but we currently don't
- # support this for rooms where the user has left/been banned.
- 0,
+ response_body["rooms"][room_id]["joined_count"],
+ 3,
)
self.assertEqual(
- response_body["rooms"][room_id1]["invited_count"],
- # We shouldn't see user5 since they were invited after user1 was banned.
- #
- # FIXME: The actual number should be "1" (user3) but we currently don't
- # support this for rooms where the user has left/been banned.
+ response_body["rooms"][room_id]["invited_count"],
0,
)
+ # We didn't request any state so we shouldn't see any `required_state`
+ self.assertNotIn(
+ "required_state",
+ response_body["rooms"][room_id],
+ )
def test_rooms_bump_stamp(self) -> None:
"""
@@ -566,19 +927,17 @@ class SlidingSyncRoomsMetaTestCase(SlidingSyncBase):
)
# Make sure the list includes the rooms in the right order
- self.assertListEqual(
- list(response_body["lists"]["foo-list"]["ops"]),
- [
- {
- "op": "SYNC",
- "range": [0, 1],
- # room1 sorts before room2 because it has the latest event (the
- # reaction)
- "room_ids": [room_id1, room_id2],
- }
- ],
+ self.assertEqual(
+ len(response_body["lists"]["foo-list"]["ops"]),
+ 1,
response_body["lists"]["foo-list"],
)
+ op = response_body["lists"]["foo-list"]["ops"][0]
+ self.assertEqual(op["op"], "SYNC")
+ self.assertEqual(op["range"], [0, 1])
+ # Note that we don't sort the rooms when the range includes all of the rooms, so
+ # we just assert that the rooms are included
+ self.assertIncludes(set(op["room_ids"]), {room_id1, room_id2}, exact=True)
# The `bump_stamp` for room1 should point at the latest message (not the
# reaction since it's not one of the `DEFAULT_BUMP_EVENT_TYPES`)
@@ -600,16 +959,16 @@ class SlidingSyncRoomsMetaTestCase(SlidingSyncBase):
Test that `bump_stamp` ignores backfilled events, i.e. events with a
negative stream ordering.
"""
-
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
# Create a remote room
creator = "@user:other"
room_id = "!foo:other"
+ room_version = RoomVersions.V10
shared_kwargs = {
"room_id": room_id,
- "room_version": "10",
+ "room_version": room_version.identifier,
}
create_tuple = self.get_success(
@@ -618,6 +977,12 @@ class SlidingSyncRoomsMetaTestCase(SlidingSyncBase):
prev_event_ids=[],
type=EventTypes.Create,
state_key="",
+ content={
+ # The `ROOM_CREATOR` field could be removed if we used a room
+ # version > 10 (in favor of relying on `sender`)
+ EventContentFields.ROOM_CREATOR: creator,
+ EventContentFields.ROOM_VERSION: room_version.identifier,
+ },
sender=creator,
**shared_kwargs,
)
@@ -667,22 +1032,29 @@ class SlidingSyncRoomsMetaTestCase(SlidingSyncBase):
]
# Ensure the local HS knows the room version
- self.get_success(
- self.store.store_room(room_id, creator, False, RoomVersions.V10)
- )
+ self.get_success(self.store.store_room(room_id, creator, False, room_version))
# Persist these events as backfilled events.
- persistence = self.hs.get_storage_controllers().persistence
- assert persistence is not None
-
for event, context in remote_events_and_contexts:
- self.get_success(persistence.persist_event(event, context, backfilled=True))
+ self.get_success(
+ self.persistence.persist_event(event, context, backfilled=True)
+ )
- # Now we join the local user to the room
- join_tuple = self.get_success(
+ # Now we join the local user to the room. We want to make this feel as close to
+ # the real `process_remote_join()` as possible but we'd like to avoid some of
+ # the auth checks that would be done in the real code.
+ #
+ # FIXME: The test was originally written using this less-real
+ # `persist_event(...)` shortcut but it would be nice to use the real remote join
+ # process in a `FederatingHomeserverTestCase`.
+ flawed_join_tuple = self.get_success(
create_event(
self.hs,
prev_event_ids=[invite_tuple[0].event_id],
+                # This doesn't correctly create an `EventContext` that includes
+                # both of these state events. I assume it's because we're working on
+                # our local homeserver, which has the remote state set as `outlier`.
+                # We have to create our own EventContext below to get this right.
auth_event_ids=[create_tuple[0].event_id, invite_tuple[0].event_id],
type=EventTypes.Member,
state_key=user1_id,
@@ -691,7 +1063,22 @@ class SlidingSyncRoomsMetaTestCase(SlidingSyncBase):
**shared_kwargs,
)
)
- self.get_success(persistence.persist_event(*join_tuple))
+        # We have to create our own context to get the state set correctly. If we use
+        # the `EventContext` from the `flawed_join_tuple`, the `current_state_events`
+        # table will only have the join event in it, which should never happen on our
+        # real server.
+ join_event = flawed_join_tuple[0]
+ join_context = self.get_success(
+ self.state_handler.compute_event_context(
+ join_event,
+ state_ids_before_event={
+ (e.type, e.state_key): e.event_id
+ for e in [create_tuple[0], invite_tuple[0]]
+ },
+ partial_state=False,
+ )
+ )
+ self.get_success(self.persistence.persist_event(join_event, join_context))
# Doing an SS request should return a positive `bump_stamp`, even though
# the only event that matches the bump types has as negative stream
@@ -708,3 +1095,244 @@ class SlidingSyncRoomsMetaTestCase(SlidingSyncBase):
response_body, _ = self.do_sync(sync_body, tok=user1_tok)
self.assertGreater(response_body["rooms"][room_id]["bump_stamp"], 0)
+
+ def test_rooms_bump_stamp_no_change_incremental(self) -> None:
+ """Test that the bump stamp is omitted if there has been no change"""
+
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ room_id1 = self.helper.create_room_as(
+ user1_id,
+ tok=user1_tok,
+ )
+
+ # Make the Sliding Sync request
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": [],
+ "timeline_limit": 100,
+ }
+ }
+ }
+ response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+
+ # Initial sync so we expect to see a bump stamp
+ self.assertIn("bump_stamp", response_body["rooms"][room_id1])
+
+ # Send an event that is not in the bump events list
+ self.helper.send_event(
+ room_id1, type="org.matrix.test", content={}, tok=user1_tok
+ )
+
+ response_body, from_token = self.do_sync(
+ sync_body, since=from_token, tok=user1_tok
+ )
+
+        # There hasn't been a change to the bump stamp, so it's omitted from the
+        # response
+ self.assertNotIn("bump_stamp", response_body["rooms"][room_id1])
+
+ def test_rooms_bump_stamp_change_incremental(self) -> None:
+ """Test that the bump stamp is included if there has been a change, even
+        if it's not in the timeline"""
+
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ room_id1 = self.helper.create_room_as(
+ user1_id,
+ tok=user1_tok,
+ )
+
+ # Make the Sliding Sync request
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": [],
+ "timeline_limit": 2,
+ }
+ }
+ }
+ response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+
+ # Initial sync so we expect to see a bump stamp
+ self.assertIn("bump_stamp", response_body["rooms"][room_id1])
+ first_bump_stamp = response_body["rooms"][room_id1]["bump_stamp"]
+
+ # Send a bump event at the start.
+ self.helper.send(room_id1, "test", tok=user1_tok)
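+        # (A plain `m.room.message` should count as a bump event, unlike the
+        # `org.matrix.test` events sent below.)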
+
+ # Send events that are not in the bump events list to fill the timeline
+ for _ in range(5):
+ self.helper.send_event(
+ room_id1, type="org.matrix.test", content={}, tok=user1_tok
+ )
+
+ response_body, from_token = self.do_sync(
+ sync_body, since=from_token, tok=user1_tok
+ )
+
+        # There was a bump event in the timeline gap, so we should see the bump
+        # stamp updated.
+ self.assertIn("bump_stamp", response_body["rooms"][room_id1])
+ second_bump_stamp = response_body["rooms"][room_id1]["bump_stamp"]
+
+ self.assertGreater(second_bump_stamp, first_bump_stamp)
+
+ def test_rooms_bump_stamp_invites(self) -> None:
+ """
+ Test that `bump_stamp` is present and points to the membership event,
+ and not later events, for non-joined rooms
+ """
+
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ room_id = self.helper.create_room_as(
+ user2_id,
+ tok=user2_tok,
+ )
+
+ # Invite user1 to the room
+ invite_response = self.helper.invite(room_id, user2_id, user1_id, tok=user2_tok)
+
+ # More messages happen after the invite
+ self.helper.send(room_id, "message in room1", tok=user2_tok)
+
+ # We expect the bump_stamp to match the invite.
+ invite_pos = self.get_success(
+ self.store.get_position_for_event(invite_response["event_id"])
+ )
+
+ # Doing an SS request should return a `bump_stamp` of the invite event,
+ # rather than the message that was sent after.
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": [],
+ "timeline_limit": 5,
+ }
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+
+ self.assertEqual(
+ response_body["rooms"][room_id]["bump_stamp"], invite_pos.stream
+ )
+
+ def test_rooms_meta_is_dm(self) -> None:
+ """
+ Test `rooms` `is_dm` is correctly set for DM rooms.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ # Create a DM room
+ joined_dm_room_id = self._create_dm_room(
+ inviter_user_id=user1_id,
+ inviter_tok=user1_tok,
+ invitee_user_id=user2_id,
+ invitee_tok=user2_tok,
+ should_join_room=True,
+ )
+ invited_dm_room_id = self._create_dm_room(
+ inviter_user_id=user1_id,
+ inviter_tok=user1_tok,
+ invitee_user_id=user2_id,
+ invitee_tok=user2_tok,
+ should_join_room=False,
+ )
+
+ # Create a normal room
+ room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
+ self.helper.join(room_id, user1_id, tok=user1_tok)
+
+ # Create a room that user1 is invited to
+ invite_room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
+ self.helper.invite(invite_room_id, src=user2_id, targ=user1_id, tok=user2_tok)
+
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ }
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+
+ # Ensure DM's are correctly marked
+ self.assertDictEqual(
+ {
+ room_id: room.get("is_dm")
+ for room_id, room in response_body["rooms"].items()
+ },
+ {
+ invite_room_id: None,
+ room_id: None,
+ invited_dm_room_id: True,
+ joined_dm_room_id: True,
+ },
+ )
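+        # (`is_dm` is expected to reflect whether the room is marked as a DM in the
+        # user's `m.direct` account data, which the `_create_dm_room` helper
+        # presumably sets up.)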
+
+ def test_old_room_with_unknown_room_version(self) -> None:
+ """Test that an old room with unknown room version does not break
+ sync."""
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ # We first create a standard room, then we'll change the room version in
+ # the DB.
+ room_id = self.helper.create_room_as(
+ user1_id,
+ tok=user1_tok,
+ )
+
+ # Poke the database and update the room version to an unknown one.
+ self.get_success(
+ self.hs.get_datastores().main.db_pool.simple_update(
+ "rooms",
+ keyvalues={"room_id": room_id},
+ updatevalues={"room_version": "unknown-room-version"},
+ desc="updated-room-version",
+ )
+ )
+
+        # Invalidate the method's cache so that it returns the newly updated room
+        # version instead of the cached one.
+ self.hs.get_datastores().main.get_room_version_id.invalidate((room_id,))
+
+ # For old unknown room versions we won't have an entry in this table
+ # (due to us skipping unknown room versions in the background update).
+ self.get_success(
+ self.store.db_pool.simple_delete(
+ table="sliding_sync_joined_rooms",
+ keyvalues={"room_id": room_id},
+ desc="delete_sliding_room",
+ )
+ )
+
+ # Also invalidate some caches to ensure we pull things from the DB.
+ self.store._events_stream_cache._entity_to_key.pop(room_id)
+ self.store._get_max_event_pos.invalidate((room_id,))
+
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": [],
+ "timeline_limit": 5,
+ }
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
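+        # The request completing successfully is the assertion here: the room with
+        # the unknown room version must not cause the sync to error.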
diff --git a/tests/rest/client/sliding_sync/test_rooms_required_state.py b/tests/rest/client/sliding_sync/test_rooms_required_state.py
index a13cad223f..ba46c5a93c 100644
--- a/tests/rest/client/sliding_sync/test_rooms_required_state.py
+++ b/tests/rest/client/sliding_sync/test_rooms_required_state.py
@@ -11,16 +11,17 @@
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
+import enum
import logging
-from parameterized import parameterized
+from parameterized import parameterized, parameterized_class
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EventContentFields, EventTypes, JoinRules, Membership
from synapse.handlers.sliding_sync import StateValues
-from synapse.rest.client import login, room, sync
+from synapse.rest.client import knock, login, room, sync
from synapse.server import HomeServer
from synapse.util import Clock
@@ -30,6 +31,31 @@ from tests.test_utils.event_injection import mark_event_as_partial_state
logger = logging.getLogger(__name__)
+# Inherit from `str` so that the values show up in the test description when we
+# pass them as the first parameter to `@parameterized.expand(...)`
+class MembershipAction(str, enum.Enum):
+ INVITE = "invite"
+ JOIN = "join"
+ KNOCK = "knock"
+ LEAVE = "leave"
+ BAN = "ban"
+ KICK = "kick"
+
+
+# FIXME: This can be removed once we bump `SCHEMA_COMPAT_VERSION` and run the
+# foreground update for
+# `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` (tracked by
+# https://github.com/element-hq/synapse/issues/17623)
+@parameterized_class(
+ ("use_new_tables",),
+ [
+ (True,),
+ (False,),
+ ],
+ class_name_func=lambda cls,
+ num,
+ params_dict: f"{cls.__name__}_{'new' if params_dict['use_new_tables'] else 'fallback'}",
+)
class SlidingSyncRoomsRequiredStateTestCase(SlidingSyncBase):
"""
Test `rooms.required_state` in the Sliding Sync API.
@@ -38,6 +64,7 @@ class SlidingSyncRoomsRequiredStateTestCase(SlidingSyncBase):
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
+ knock.register_servlets,
room.register_servlets,
sync.register_servlets,
]
@@ -46,6 +73,8 @@ class SlidingSyncRoomsRequiredStateTestCase(SlidingSyncBase):
self.store = hs.get_datastores().main
self.storage_controllers = hs.get_storage_controllers()
+ super().prepare(reactor, clock, hs)
+
def test_rooms_no_required_state(self) -> None:
"""
Empty `rooms.required_state` should not return any state events in the room
@@ -191,8 +220,14 @@ class SlidingSyncRoomsRequiredStateTestCase(SlidingSyncBase):
}
_, from_token = self.do_sync(sync_body, tok=user1_tok)
- # Reset the in-memory cache
- self.hs.get_sliding_sync_handler().connection_store._connections.clear()
+ # Reset the positions
+ self.get_success(
+ self.store.db_pool.simple_delete(
+ table="sliding_sync_connections",
+ keyvalues={"user_id": user1_id},
+ desc="clear_sliding_sync_connections_cache",
+ )
+ )
# Make the Sliding Sync request
channel = self.make_request(
@@ -359,10 +394,10 @@ class SlidingSyncRoomsRequiredStateTestCase(SlidingSyncBase):
)
self.assertIsNone(response_body["rooms"][room_id1].get("invite_state"))
- def test_rooms_required_state_lazy_loading_room_members(self) -> None:
+ def test_rooms_required_state_lazy_loading_room_members_initial_sync(self) -> None:
"""
- Test `rooms.required_state` returns people relevant to the timeline when
- lazy-loading room members, `["m.room.member","$LAZY"]`.
+ On initial sync, test `rooms.required_state` returns people relevant to the
+ timeline when lazy-loading room members, `["m.room.member","$LAZY"]`.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
@@ -410,6 +445,402 @@ class SlidingSyncRoomsRequiredStateTestCase(SlidingSyncBase):
)
self.assertIsNone(response_body["rooms"][room_id1].get("invite_state"))
+ def test_rooms_required_state_lazy_loading_room_members_incremental_sync(
+ self,
+ ) -> None:
+ """
+ On incremental sync, test `rooms.required_state` returns people relevant to the
+ timeline when lazy-loading room members, `["m.room.member","$LAZY"]`.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+ user3_id = self.register_user("user3", "pass")
+ user3_tok = self.login(user3_id, "pass")
+ user4_id = self.register_user("user4", "pass")
+ user4_tok = self.login(user4_id, "pass")
+
+ room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
+ self.helper.join(room_id1, user1_id, tok=user1_tok)
+ self.helper.join(room_id1, user3_id, tok=user3_tok)
+ self.helper.join(room_id1, user4_id, tok=user4_tok)
+
+ self.helper.send(room_id1, "1", tok=user2_tok)
+ self.helper.send(room_id1, "2", tok=user2_tok)
+ self.helper.send(room_id1, "3", tok=user2_tok)
+
+ # Make the Sliding Sync request with lazy loading for the room members
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": [
+ [EventTypes.Create, ""],
+ [EventTypes.Member, StateValues.LAZY],
+ ],
+ "timeline_limit": 3,
+ }
+ }
+ }
+ response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+
+ # Send more timeline events into the room
+ self.helper.send(room_id1, "4", tok=user2_tok)
+ self.helper.send(room_id1, "5", tok=user4_tok)
+ self.helper.send(room_id1, "6", tok=user4_tok)
+
+ # Make an incremental Sliding Sync request
+ response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
+
+ state_map = self.get_success(
+ self.storage_controllers.state.get_current_state(room_id1)
+ )
+
+ # Only user2 and user4 sent events in the last 3 events we see in the `timeline`
+ # but since we've seen user2 in the last sync (and their membership hasn't
+ # changed), we should only see user4 here.
+ self._assertRequiredStateIncludes(
+ response_body["rooms"][room_id1]["required_state"],
+ {
+ state_map[(EventTypes.Member, user4_id)],
+ },
+ exact=True,
+ )
+ self.assertIsNone(response_body["rooms"][room_id1].get("invite_state"))
+
+ @parameterized.expand(
+ [
+ (MembershipAction.LEAVE,),
+ (MembershipAction.INVITE,),
+ (MembershipAction.KNOCK,),
+ (MembershipAction.JOIN,),
+ (MembershipAction.BAN,),
+ (MembershipAction.KICK,),
+ ]
+ )
+ def test_rooms_required_state_changed_membership_in_timeline_lazy_loading_room_members_incremental_sync(
+ self,
+ room_membership_action: str,
+ ) -> None:
+ """
+ On incremental sync, test `rooms.required_state` returns people relevant to the
+ timeline when lazy-loading room members, `["m.room.member","$LAZY"]` **including
+ changes to membership**.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+ user3_id = self.register_user("user3", "pass")
+ user3_tok = self.login(user3_id, "pass")
+ user4_id = self.register_user("user4", "pass")
+ user4_tok = self.login(user4_id, "pass")
+ user5_id = self.register_user("user5", "pass")
+ user5_tok = self.login(user5_id, "pass")
+
+ room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True)
+ # If we're testing knocks, set the room to knock
+ if room_membership_action == MembershipAction.KNOCK:
+ self.helper.send_state(
+ room_id1,
+ EventTypes.JoinRules,
+ {"join_rule": JoinRules.KNOCK},
+ tok=user2_tok,
+ )
+
+ # Join the test users to the room
+ self.helper.invite(room_id1, src=user2_id, targ=user1_id, tok=user2_tok)
+ self.helper.join(room_id1, user1_id, tok=user1_tok)
+ self.helper.invite(room_id1, src=user2_id, targ=user3_id, tok=user2_tok)
+ self.helper.join(room_id1, user3_id, tok=user3_tok)
+ self.helper.invite(room_id1, src=user2_id, targ=user4_id, tok=user2_tok)
+ self.helper.join(room_id1, user4_id, tok=user4_tok)
+ if room_membership_action in (
+ MembershipAction.LEAVE,
+ MembershipAction.BAN,
+ MembershipAction.JOIN,
+ ):
+ self.helper.invite(room_id1, src=user2_id, targ=user5_id, tok=user2_tok)
+ self.helper.join(room_id1, user5_id, tok=user5_tok)
+
+ # Send some messages to fill up the space
+ self.helper.send(room_id1, "1", tok=user2_tok)
+ self.helper.send(room_id1, "2", tok=user2_tok)
+ self.helper.send(room_id1, "3", tok=user2_tok)
+
+ # Make the Sliding Sync request with lazy loading for the room members
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": [
+ [EventTypes.Create, ""],
+ [EventTypes.Member, StateValues.LAZY],
+ ],
+ "timeline_limit": 3,
+ }
+ }
+ }
+ response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+
+ # Send more timeline events into the room
+ self.helper.send(room_id1, "4", tok=user2_tok)
+ self.helper.send(room_id1, "5", tok=user4_tok)
+ # The third event will be our membership event concerning user5
+ if room_membership_action == MembershipAction.LEAVE:
+ # User 5 leaves
+ self.helper.leave(room_id1, user5_id, tok=user5_tok)
+ elif room_membership_action == MembershipAction.INVITE:
+ # User 5 is invited
+ self.helper.invite(room_id1, src=user2_id, targ=user5_id, tok=user2_tok)
+ elif room_membership_action == MembershipAction.KNOCK:
+ # User 5 knocks
+ self.helper.knock(room_id1, user5_id, tok=user5_tok)
+ # The admin of the room accepts the knock
+ self.helper.invite(room_id1, src=user2_id, targ=user5_id, tok=user2_tok)
+ elif room_membership_action == MembershipAction.JOIN:
+ # Update the display name of user5 (causing a membership change)
+ self.helper.send_state(
+ room_id1,
+ event_type=EventTypes.Member,
+ state_key=user5_id,
+ body={
+ EventContentFields.MEMBERSHIP: Membership.JOIN,
+ EventContentFields.MEMBERSHIP_DISPLAYNAME: "quick changer",
+ },
+ tok=user5_tok,
+ )
+ elif room_membership_action == MembershipAction.BAN:
+ self.helper.ban(room_id1, src=user2_id, targ=user5_id, tok=user2_tok)
+ elif room_membership_action == MembershipAction.KICK:
+ # Kick user5 from the room
+ self.helper.change_membership(
+ room=room_id1,
+ src=user2_id,
+ targ=user5_id,
+ tok=user2_tok,
+ membership=Membership.LEAVE,
+ extra_data={
+ "reason": "Bad manners",
+ },
+ )
+ else:
+ raise AssertionError(
+ f"Unknown room_membership_action: {room_membership_action}"
+ )
+
+ # Make an incremental Sliding Sync request
+ response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
+
+ state_map = self.get_success(
+ self.storage_controllers.state.get_current_state(room_id1)
+ )
+
+ # Only user2, user4, and user5 sent events in the last 3 events we see in the
+ # `timeline`.
+ self._assertRequiredStateIncludes(
+ response_body["rooms"][room_id1]["required_state"],
+ {
+                # This appears because *some* membership in the room changed, causing
+                # the heroes to be recalculated, and is thrown in because we have it.
+                # But this is technically optional and not needed because we've
+                # already seen user2 in the last sync (and their membership hasn't
+                # changed).
+ state_map[(EventTypes.Member, user2_id)],
+ # Appears because there is a message in the timeline from this user
+ state_map[(EventTypes.Member, user4_id)],
+ # Appears because there is a membership event in the timeline from this user
+ state_map[(EventTypes.Member, user5_id)],
+ },
+ exact=True,
+ )
+ self.assertIsNone(response_body["rooms"][room_id1].get("invite_state"))
+
+ def test_rooms_required_state_expand_lazy_loading_room_members_incremental_sync(
+ self,
+ ) -> None:
+ """
+ Test that when we expand the `required_state` to include lazy-loading room
+ members, it returns people relevant to the timeline.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+ user3_id = self.register_user("user3", "pass")
+ user3_tok = self.login(user3_id, "pass")
+ user4_id = self.register_user("user4", "pass")
+ user4_tok = self.login(user4_id, "pass")
+
+ room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
+ self.helper.join(room_id1, user1_id, tok=user1_tok)
+ self.helper.join(room_id1, user3_id, tok=user3_tok)
+ self.helper.join(room_id1, user4_id, tok=user4_tok)
+
+ self.helper.send(room_id1, "1", tok=user2_tok)
+ self.helper.send(room_id1, "2", tok=user2_tok)
+ self.helper.send(room_id1, "3", tok=user2_tok)
+
+ # Make the Sliding Sync request *without* lazy loading for the room members
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": [
+ [EventTypes.Create, ""],
+ ],
+ "timeline_limit": 3,
+ }
+ }
+ }
+ response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+
+ # Send more timeline events into the room
+ self.helper.send(room_id1, "4", tok=user2_tok)
+ self.helper.send(room_id1, "5", tok=user4_tok)
+ self.helper.send(room_id1, "6", tok=user4_tok)
+
+ # Expand `required_state` and make an incremental Sliding Sync request *with*
+ # lazy-loading room members
+ sync_body["lists"]["foo-list"]["required_state"] = [
+ [EventTypes.Create, ""],
+ [EventTypes.Member, StateValues.LAZY],
+ ]
+ response_body, from_token = self.do_sync(
+ sync_body, since=from_token, tok=user1_tok
+ )
+
+ state_map = self.get_success(
+ self.storage_controllers.state.get_current_state(room_id1)
+ )
+
+ # Only user2 and user4 sent events in the last 3 events we see in the `timeline`
+ # and we haven't seen any membership before this sync so we should see both
+ # users.
+ self._assertRequiredStateIncludes(
+ response_body["rooms"][room_id1]["required_state"],
+ {
+ state_map[(EventTypes.Member, user2_id)],
+ state_map[(EventTypes.Member, user4_id)],
+ },
+ exact=True,
+ )
+ self.assertIsNone(response_body["rooms"][room_id1].get("invite_state"))
+
+ # Send a message so the room comes down sync.
+ self.helper.send(room_id1, "7", tok=user2_tok)
+ self.helper.send(room_id1, "8", tok=user4_tok)
+ self.helper.send(room_id1, "9", tok=user4_tok)
+
+ # Make another incremental Sliding Sync request
+ response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
+
+ # Only user2 and user4 sent events in the last 3 events we see in the `timeline`
+ # but since we've seen both memberships in the last sync, they shouldn't appear
+ # again.
+ self._assertRequiredStateIncludes(
+ response_body["rooms"][room_id1].get("required_state", []),
+ set(),
+ exact=True,
+ )
+ self.assertIsNone(response_body["rooms"][room_id1].get("invite_state"))
+
+ def test_rooms_required_state_expand_retract_expand_lazy_loading_room_members_incremental_sync(
+ self,
+ ) -> None:
+ """
+ Test that when we expand the `required_state` to include lazy-loading room
+ members and then retract it, memberships we've already sent aren't repeated.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+ user3_id = self.register_user("user3", "pass")
+ user3_tok = self.login(user3_id, "pass")
+ user4_id = self.register_user("user4", "pass")
+ user4_tok = self.login(user4_id, "pass")
+
+ room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
+ self.helper.join(room_id1, user1_id, tok=user1_tok)
+ self.helper.join(room_id1, user3_id, tok=user3_tok)
+ self.helper.join(room_id1, user4_id, tok=user4_tok)
+
+ self.helper.send(room_id1, "1", tok=user2_tok)
+ self.helper.send(room_id1, "2", tok=user2_tok)
+ self.helper.send(room_id1, "3", tok=user2_tok)
+
+ # Make the Sliding Sync request *without* lazy loading for the room members
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": [
+ [EventTypes.Create, ""],
+ ],
+ "timeline_limit": 3,
+ }
+ }
+ }
+ response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+
+ # Send more timeline events into the room
+ self.helper.send(room_id1, "4", tok=user2_tok)
+ self.helper.send(room_id1, "5", tok=user4_tok)
+ self.helper.send(room_id1, "6", tok=user4_tok)
+
+ # Expand `required_state` and make an incremental Sliding Sync request *with*
+ # lazy-loading room members
+ sync_body["lists"]["foo-list"]["required_state"] = [
+ [EventTypes.Create, ""],
+ [EventTypes.Member, StateValues.LAZY],
+ ]
+ response_body, from_token = self.do_sync(
+ sync_body, since=from_token, tok=user1_tok
+ )
+
+ state_map = self.get_success(
+ self.storage_controllers.state.get_current_state(room_id1)
+ )
+
+ # Only user2 and user4 sent events in the last 3 events we see in the `timeline`
+ # and we haven't seen any membership before this sync so we should see both
+ # users because we're lazy-loading the room members.
+ self._assertRequiredStateIncludes(
+ response_body["rooms"][room_id1]["required_state"],
+ {
+ state_map[(EventTypes.Member, user2_id)],
+ state_map[(EventTypes.Member, user4_id)],
+ },
+ exact=True,
+ )
+
+ # Send a message so the room comes down sync.
+ self.helper.send(room_id1, "msg", tok=user4_tok)
+
+ # Retract `required_state` and make an incremental Sliding Sync request
+ # requesting a few memberships
+ sync_body["lists"]["foo-list"]["required_state"] = [
+ [EventTypes.Create, ""],
+ [EventTypes.Member, StateValues.ME],
+ [EventTypes.Member, user2_id],
+ ]
+ response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
+
+ state_map = self.get_success(
+ self.storage_controllers.state.get_current_state(room_id1)
+ )
+
+ # We've seen user2's membership in the last sync so we shouldn't see it here
+ # even though it's requested. We should only see user1's membership.
+ self._assertRequiredStateIncludes(
+ response_body["rooms"][room_id1]["required_state"],
+ {
+ state_map[(EventTypes.Member, user1_id)],
+ },
+ exact=True,
+ )
+
def test_rooms_required_state_me(self) -> None:
"""
Test `rooms.required_state` correctly handles $ME.
@@ -480,9 +911,10 @@ class SlidingSyncRoomsRequiredStateTestCase(SlidingSyncBase):
self.assertIsNone(response_body["rooms"][room_id1].get("invite_state"))
@parameterized.expand([(Membership.LEAVE,), (Membership.BAN,)])
- def test_rooms_required_state_leave_ban(self, stop_membership: str) -> None:
+ def test_rooms_required_state_leave_ban_initial(self, stop_membership: str) -> None:
"""
- Test `rooms.required_state` should not return state past a leave/ban event.
+ Test `rooms.required_state` should not return state past a leave/ban event when
+ it's the first "initial" time the room is being sent down the connection.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
@@ -517,6 +949,13 @@ class SlidingSyncRoomsRequiredStateTestCase(SlidingSyncBase):
body={"foo": "bar"},
tok=user2_tok,
)
+ self.helper.send_state(
+ room_id1,
+ event_type="org.matrix.bar_state",
+ state_key="",
+ body={"bar": "bar"},
+ tok=user2_tok,
+ )
if stop_membership == Membership.LEAVE:
# User 1 leaves
@@ -525,6 +964,8 @@ class SlidingSyncRoomsRequiredStateTestCase(SlidingSyncBase):
# User 1 is banned
self.helper.ban(room_id1, src=user2_id, targ=user1_id, tok=user2_tok)
+ # Get the state_map before we change the state as this is the final state we
+ # expect User1 to be able to see
state_map = self.get_success(
self.storage_controllers.state.get_current_state(room_id1)
)
@@ -537,12 +978,36 @@ class SlidingSyncRoomsRequiredStateTestCase(SlidingSyncBase):
body={"foo": "qux"},
tok=user2_tok,
)
+ self.helper.send_state(
+ room_id1,
+ event_type="org.matrix.bar_state",
+ state_key="",
+ body={"bar": "qux"},
+ tok=user2_tok,
+ )
self.helper.leave(room_id1, user3_id, tok=user3_tok)
- # Make the Sliding Sync request with lazy loading for the room members
+ # Make an incremental Sliding Sync request
+ #
+ # Also expand the required state to include the `org.matrix.bar_state` event.
+ # This is just an extra complication for the test.
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": [
+ [EventTypes.Create, ""],
+ [EventTypes.Member, "*"],
+ ["org.matrix.foo_state", ""],
+ ["org.matrix.bar_state", ""],
+ ],
+ "timeline_limit": 3,
+ }
+ }
+ }
response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
- # Only user2 and user3 sent events in the 3 events we see in the `timeline`
+ # We should only see the state up to the leave/ban event
self._assertRequiredStateIncludes(
response_body["rooms"][room_id1]["required_state"],
{
@@ -551,6 +1016,126 @@ class SlidingSyncRoomsRequiredStateTestCase(SlidingSyncBase):
state_map[(EventTypes.Member, user2_id)],
state_map[(EventTypes.Member, user3_id)],
state_map[("org.matrix.foo_state", "")],
+ state_map[("org.matrix.bar_state", "")],
+ },
+ exact=True,
+ )
+ self.assertIsNone(response_body["rooms"][room_id1].get("invite_state"))
+
+ @parameterized.expand([(Membership.LEAVE,), (Membership.BAN,)])
+ def test_rooms_required_state_leave_ban_incremental(
+ self, stop_membership: str
+ ) -> None:
+ """
+ Test `rooms.required_state` should not return state past a leave/ban event on
+ incremental sync.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+ user3_id = self.register_user("user3", "pass")
+ user3_tok = self.login(user3_id, "pass")
+
+ room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
+ self.helper.join(room_id1, user1_id, tok=user1_tok)
+ self.helper.join(room_id1, user3_id, tok=user3_tok)
+
+ self.helper.send_state(
+ room_id1,
+ event_type="org.matrix.foo_state",
+ state_key="",
+ body={"foo": "bar"},
+ tok=user2_tok,
+ )
+ self.helper.send_state(
+ room_id1,
+ event_type="org.matrix.bar_state",
+ state_key="",
+ body={"bar": "bar"},
+ tok=user2_tok,
+ )
+
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": [
+ [EventTypes.Create, ""],
+ [EventTypes.Member, "*"],
+ ["org.matrix.foo_state", ""],
+ ],
+ "timeline_limit": 3,
+ }
+ }
+ }
+ _, from_token = self.do_sync(sync_body, tok=user1_tok)
+
+ if stop_membership == Membership.LEAVE:
+ # User 1 leaves
+ self.helper.leave(room_id1, user1_id, tok=user1_tok)
+ elif stop_membership == Membership.BAN:
+ # User 1 is banned
+ self.helper.ban(room_id1, src=user2_id, targ=user1_id, tok=user2_tok)
+
+ # Get the state_map before we change the state as this is the final state we
+ # expect User1 to be able to see
+ state_map = self.get_success(
+ self.storage_controllers.state.get_current_state(room_id1)
+ )
+
+ # Change the state after user 1 leaves
+ self.helper.send_state(
+ room_id1,
+ event_type="org.matrix.foo_state",
+ state_key="",
+ body={"foo": "qux"},
+ tok=user2_tok,
+ )
+ self.helper.send_state(
+ room_id1,
+ event_type="org.matrix.bar_state",
+ state_key="",
+ body={"bar": "qux"},
+ tok=user2_tok,
+ )
+ self.helper.leave(room_id1, user3_id, tok=user3_tok)
+
+ # Make an incremental Sliding Sync request
+ #
+ # Also expand the required state to include the `org.matrix.bar_state` event.
+ # This is just an extra complication for the test.
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": [
+ [EventTypes.Create, ""],
+ [EventTypes.Member, "*"],
+ ["org.matrix.foo_state", ""],
+ ["org.matrix.bar_state", ""],
+ ],
+ "timeline_limit": 3,
+ }
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
+
+ # User1 should only see the state up to the leave/ban event
+ self._assertRequiredStateIncludes(
+ response_body["rooms"][room_id1]["required_state"],
+ {
+ # User1 should see their leave/ban membership
+ state_map[(EventTypes.Member, user1_id)],
+ state_map[("org.matrix.bar_state", "")],
+ # The commented out state events were already returned in the initial
+ # sync so we shouldn't see them again on the incremental sync. And we
+ # shouldn't see the state events that changed after the leave/ban event.
+ #
+ # state_map[(EventTypes.Create, "")],
+ # state_map[(EventTypes.Member, user2_id)],
+ # state_map[(EventTypes.Member, user3_id)],
+ # state_map[("org.matrix.foo_state", "")],
},
exact=True,
)
@@ -631,8 +1216,7 @@ class SlidingSyncRoomsRequiredStateTestCase(SlidingSyncBase):
def test_rooms_required_state_partial_state(self) -> None:
"""
- Test partially-stated room are excluded unless `rooms.required_state` is
- lazy-loading room members.
+ Test partially-stated rooms are excluded if they require full state.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
@@ -649,13 +1233,63 @@ class SlidingSyncRoomsRequiredStateTestCase(SlidingSyncBase):
mark_event_as_partial_state(self.hs, join_response2["event_id"], room_id2)
)
- # Make the Sliding Sync request (NOT lazy-loading room members)
+ # Make the Sliding Sync request with examples where `must_await_full_state()` is
+ # `False`
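+ # (i.e. cases where the requested state can be computed from a
+ # partially-stated room, such as local-only or lazy-loaded memberships)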
sync_body = {
"lists": {
- "foo-list": {
+ "no-state-list": {
+ "ranges": [[0, 1]],
+ "required_state": [],
+ "timeline_limit": 0,
+ },
+ "other-state-list": {
+ "ranges": [[0, 1]],
+ "required_state": [
+ [EventTypes.Create, ""],
+ ],
+ "timeline_limit": 0,
+ },
+ "lazy-load-list": {
"ranges": [[0, 1]],
"required_state": [
[EventTypes.Create, ""],
+ # Lazy-load room members
+ [EventTypes.Member, StateValues.LAZY],
+ # Local member
+ [EventTypes.Member, user2_id],
+ ],
+ "timeline_limit": 0,
+ },
+ "local-members-only-list": {
+ "ranges": [[0, 1]],
+ "required_state": [
+ # Own user ID
+ [EventTypes.Member, user1_id],
+ # Local member
+ [EventTypes.Member, user2_id],
+ ],
+ "timeline_limit": 0,
+ },
+ "me-list": {
+ "ranges": [[0, 1]],
+ "required_state": [
+ # Own user ID
+ [EventTypes.Member, StateValues.ME],
+ # Local member
+ [EventTypes.Member, user2_id],
+ ],
+ "timeline_limit": 0,
+ },
+ "wildcard-type-local-state-key-list": {
+ "ranges": [[0, 1]],
+ "required_state": [
+ ["*", user1_id],
+ # Not a user ID
+ ["*", "foobarbaz"],
+ # Not a user ID
+ ["*", "foo.bar.baz"],
+ # Not a user ID
+ ["*", "@foo"],
],
"timeline_limit": 0,
},
@@ -663,29 +1297,89 @@ class SlidingSyncRoomsRequiredStateTestCase(SlidingSyncBase):
}
response_body, _ = self.do_sync(sync_body, tok=user1_tok)
- # Make sure the list includes room1 but room2 is excluded because it's still
- # partially-stated
- self.assertListEqual(
- list(response_body["lists"]["foo-list"]["ops"]),
- [
- {
- "op": "SYNC",
- "range": [0, 1],
- "room_ids": [room_id1],
+ # The list should include both rooms now because we don't need full state
+ for list_key in response_body["lists"].keys():
+ self.assertIncludes(
+ set(response_body["lists"][list_key]["ops"][0]["room_ids"]),
+ {room_id2, room_id1},
+ exact=True,
+ message=f"Expected all rooms to show up for list_key={list_key}. Response "
+ + str(response_body["lists"][list_key]),
+ )
+
+ # Take each of the list variants and apply them to room subscriptions to make
+ # sure the same rules apply
+ for list_key in sync_body["lists"].keys():
+ sync_body_for_subscriptions = {
+ "room_subscriptions": {
+ room_id1: {
+ "required_state": sync_body["lists"][list_key][
+ "required_state"
+ ],
+ "timeline_limit": 0,
+ },
+ room_id2: {
+ "required_state": sync_body["lists"][list_key][
+ "required_state"
+ ],
+ "timeline_limit": 0,
+ },
}
- ],
- response_body["lists"]["foo-list"],
- )
+ }
+ response_body, _ = self.do_sync(sync_body_for_subscriptions, tok=user1_tok)
+
+ self.assertIncludes(
+ set(response_body["rooms"].keys()),
+ {room_id2, room_id1},
+ exact=True,
+ message=f"Expected all rooms to show up for test_key={list_key}.",
+ )
- # Make the Sliding Sync request (with lazy-loading room members)
+ # =====================================================================
+
+ # Make the Sliding Sync request with examples where `must_await_full_state()` is
+ # `True`
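+ # (i.e. cases referencing remote members or wildcard state, which can only
+ # be answered once the room is no longer partially stated)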
sync_body = {
"lists": {
- "foo-list": {
+ "wildcard-list": {
+ "ranges": [[0, 1]],
+ "required_state": [
+ ["*", "*"],
+ ],
+ "timeline_limit": 0,
+ },
+ "wildcard-type-remote-state-key-list": {
+ "ranges": [[0, 1]],
+ "required_state": [
+ ["*", "@some:remote"],
+ # Not a user ID
+ ["*", "foobarbaz"],
+ # Not a user ID
+ ["*", "foo.bar.baz"],
+ # Not a user ID
+ ["*", "@foo"],
+ ],
+ "timeline_limit": 0,
+ },
+ "remote-member-list": {
+ "ranges": [[0, 1]],
+ "required_state": [
+ # Own user ID
+ [EventTypes.Member, user1_id],
+ # Remote member
+ [EventTypes.Member, "@some:remote"],
+ # Local member
+ [EventTypes.Member, user2_id],
+ ],
+ "timeline_limit": 0,
+ },
+ "lazy-but-remote-member-list": {
"ranges": [[0, 1]],
"required_state": [
- [EventTypes.Create, ""],
# Lazy-load room members
[EventTypes.Member, StateValues.LAZY],
+ # Remote member
+ [EventTypes.Member, "@some:remote"],
],
"timeline_limit": 0,
},
@@ -693,15 +1387,302 @@ class SlidingSyncRoomsRequiredStateTestCase(SlidingSyncBase):
}
response_body, _ = self.do_sync(sync_body, tok=user1_tok)
- # The list should include both rooms now because we're lazy-loading room members
- self.assertListEqual(
- list(response_body["lists"]["foo-list"]["ops"]),
- [
- {
- "op": "SYNC",
- "range": [0, 1],
- "room_ids": [room_id2, room_id1],
+ # Make sure the list includes room1 but room2 is excluded because it's still
+ # partially-stated
+ for list_key in response_body["lists"].keys():
+ self.assertIncludes(
+ set(response_body["lists"][list_key]["ops"][0]["room_ids"]),
+ {room_id1},
+ exact=True,
+ message=f"Expected only fully-stated rooms to show up for list_key={list_key}. Response "
+ + str(response_body["lists"][list_key]),
+ )
+
+ # Take each of the list variants and apply them to room subscriptions to make
+ # sure the same rules apply
+ for list_key in sync_body["lists"].keys():
+ sync_body_for_subscriptions = {
+ "room_subscriptions": {
+ room_id1: {
+ "required_state": sync_body["lists"][list_key][
+ "required_state"
+ ],
+ "timeline_limit": 0,
+ },
+ room_id2: {
+ "required_state": sync_body["lists"][list_key][
+ "required_state"
+ ],
+ "timeline_limit": 0,
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body_for_subscriptions, tok=user1_tok)
+
+ self.assertIncludes(
+ set(response_body["rooms"].keys()),
+ {room_id1},
+ exact=True,
+ message=f"Expected only fully-stated rooms to show up for test_key={list_key}.",
+ )
+
+ def test_rooms_required_state_expand(self) -> None:
+ """Test that when we expand the required state argument we get the
+ expanded state, and not just the changes to the new expanded."""
+
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ # Create a room with a room name.
+ room_id1 = self.helper.create_room_as(
+ user1_id, tok=user1_tok, extra_content={"name": "Foo"}
+ )
+
+ # Only request the state event to begin with
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": [
+ [EventTypes.Create, ""],
+ ],
+ "timeline_limit": 1,
}
- ],
- response_body["lists"]["foo-list"],
+ }
+ }
+ response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+
+ state_map = self.get_success(
+ self.storage_controllers.state.get_current_state(room_id1)
)
+
+ self._assertRequiredStateIncludes(
+ response_body["rooms"][room_id1]["required_state"],
+ {
+ state_map[(EventTypes.Create, "")],
+ },
+ exact=True,
+ )
+
+ # Send a message so the room comes down sync.
+ self.helper.send(room_id1, "msg", tok=user1_tok)
+
+ # Update the sliding sync requests to include the room name
+ sync_body["lists"]["foo-list"]["required_state"] = [
+ [EventTypes.Create, ""],
+ [EventTypes.Name, ""],
+ ]
+ response_body, from_token = self.do_sync(
+ sync_body, since=from_token, tok=user1_tok
+ )
+
+ # We should see the room name, even though it hasn't changed since the
+ # last sync.
+ self._assertRequiredStateIncludes(
+ response_body["rooms"][room_id1]["required_state"],
+ {
+ state_map[(EventTypes.Name, "")],
+ },
+ exact=True,
+ )
+
+ # Send a message so the room comes down sync.
+ self.helper.send(room_id1, "msg", tok=user1_tok)
+
+ # We should not see any state changes.
+ response_body, from_token = self.do_sync(
+ sync_body, since=from_token, tok=user1_tok
+ )
+ self.assertIsNone(response_body["rooms"][room_id1].get("required_state"))
+
+ def test_rooms_required_state_expand_retract_expand(self) -> None:
+ """Test that when expanding, retracting and then expanding the required
+ state, we get the changes that happened."""
+
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ # Create a room with a room name.
+ room_id1 = self.helper.create_room_as(
+ user1_id, tok=user1_tok, extra_content={"name": "Foo"}
+ )
+
+ # Only request the state event to begin with
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": [
+ [EventTypes.Create, ""],
+ ],
+ "timeline_limit": 1,
+ }
+ }
+ }
+ response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+
+ state_map = self.get_success(
+ self.storage_controllers.state.get_current_state(room_id1)
+ )
+
+ self._assertRequiredStateIncludes(
+ response_body["rooms"][room_id1]["required_state"],
+ {
+ state_map[(EventTypes.Create, "")],
+ },
+ exact=True,
+ )
+
+ # Send a message so the room comes down sync.
+ self.helper.send(room_id1, "msg", tok=user1_tok)
+
+ # Update the sliding sync requests to include the room name
+ sync_body["lists"]["foo-list"]["required_state"] = [
+ [EventTypes.Create, ""],
+ [EventTypes.Name, ""],
+ ]
+ response_body, from_token = self.do_sync(
+ sync_body, since=from_token, tok=user1_tok
+ )
+
+ # We should see the room name, even though it hasn't changed since the
+ # last sync.
+ self._assertRequiredStateIncludes(
+ response_body["rooms"][room_id1]["required_state"],
+ {
+ state_map[(EventTypes.Name, "")],
+ },
+ exact=True,
+ )
+
+ # Update the room name
+ self.helper.send_state(
+ room_id1, EventTypes.Name, {"name": "Bar"}, state_key="", tok=user1_tok
+ )
+
+ # Update the sliding sync requests to exclude the room name again
+ sync_body["lists"]["foo-list"]["required_state"] = [
+ [EventTypes.Create, ""],
+ ]
+ response_body, from_token = self.do_sync(
+ sync_body, since=from_token, tok=user1_tok
+ )
+
+ # We should not see the updated room name in state (though it will be in
+ # the timeline).
+ self.assertIsNone(response_body["rooms"][room_id1].get("required_state"))
+
+ # Send a message so the room comes down sync.
+ self.helper.send(room_id1, "msg", tok=user1_tok)
+
+ # Update the sliding sync requests to include the room name again
+ sync_body["lists"]["foo-list"]["required_state"] = [
+ [EventTypes.Create, ""],
+ [EventTypes.Name, ""],
+ ]
+ response_body, from_token = self.do_sync(
+ sync_body, since=from_token, tok=user1_tok
+ )
+
+ # We should see the *new* room name, even though it hasn't changed since
+ # the last sync (the name was updated while this event type was retracted).
+ state_map = self.get_success(
+ self.storage_controllers.state.get_current_state(room_id1)
+ )
+ self._assertRequiredStateIncludes(
+ response_body["rooms"][room_id1]["required_state"],
+ {
+ state_map[(EventTypes.Name, "")],
+ },
+ exact=True,
+ )
+
+ def test_rooms_required_state_expand_deduplicate(self) -> None:
+ """Test that when expanding, retracting and then expanding the required
+ state, we don't get the state down again if it hasn't changed"""
+
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ # Create a room with a room name.
+ room_id1 = self.helper.create_room_as(
+ user1_id, tok=user1_tok, extra_content={"name": "Foo"}
+ )
+
+ # Only request the state event to begin with
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": [
+ [EventTypes.Create, ""],
+ ],
+ "timeline_limit": 1,
+ }
+ }
+ }
+ response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+
+ state_map = self.get_success(
+ self.storage_controllers.state.get_current_state(room_id1)
+ )
+
+ self._assertRequiredStateIncludes(
+ response_body["rooms"][room_id1]["required_state"],
+ {
+ state_map[(EventTypes.Create, "")],
+ },
+ exact=True,
+ )
+
+ # Send a message so the room comes down sync.
+ self.helper.send(room_id1, "msg", tok=user1_tok)
+
+ # Update the sliding sync requests to include the room name
+ sync_body["lists"]["foo-list"]["required_state"] = [
+ [EventTypes.Create, ""],
+ [EventTypes.Name, ""],
+ ]
+ response_body, from_token = self.do_sync(
+ sync_body, since=from_token, tok=user1_tok
+ )
+
+ # We should see the room name, even though it hasn't changed since the
+ # last sync.
+ self._assertRequiredStateIncludes(
+ response_body["rooms"][room_id1]["required_state"],
+ {
+ state_map[(EventTypes.Name, "")],
+ },
+ exact=True,
+ )
+
+ # Send a message so the room comes down sync.
+ self.helper.send(room_id1, "msg", tok=user1_tok)
+
+ # Update the sliding sync requests to exclude the room name again
+ sync_body["lists"]["foo-list"]["required_state"] = [
+ [EventTypes.Create, ""],
+ ]
+ response_body, from_token = self.do_sync(
+ sync_body, since=from_token, tok=user1_tok
+ )
+
+ # We should not see any state updates
+ self.assertIsNone(response_body["rooms"][room_id1].get("required_state"))
+
+ # Send a message so the room comes down sync.
+ self.helper.send(room_id1, "msg", tok=user1_tok)
+
+ # Update the sliding sync requests to include the room name again
+ sync_body["lists"]["foo-list"]["required_state"] = [
+ [EventTypes.Create, ""],
+ [EventTypes.Name, ""],
+ ]
+ response_body, from_token = self.do_sync(
+ sync_body, since=from_token, tok=user1_tok
+ )
+
+ # We should not see the room name again, as we have already sent that
+ # down.
+ self.assertIsNone(response_body["rooms"][room_id1].get("required_state"))
diff --git a/tests/rest/client/sliding_sync/test_rooms_timeline.py b/tests/rest/client/sliding_sync/test_rooms_timeline.py
index 2e9586ca73..535420209b 100644
--- a/tests/rest/client/sliding_sync/test_rooms_timeline.py
+++ b/tests/rest/client/sliding_sync/test_rooms_timeline.py
@@ -14,12 +14,15 @@
import logging
from typing import List, Optional
+from parameterized import parameterized_class
+
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
+from synapse.api.constants import EventTypes
from synapse.rest.client import login, room, sync
from synapse.server import HomeServer
-from synapse.types import StreamToken, StrSequence
+from synapse.types import StrSequence
from synapse.util import Clock
from tests.rest.client.sliding_sync.test_sliding_sync import SlidingSyncBase
@@ -27,6 +30,20 @@ from tests.rest.client.sliding_sync.test_sliding_sync import SlidingSyncBase
logger = logging.getLogger(__name__)
+# FIXME: This can be removed once we bump `SCHEMA_COMPAT_VERSION` and run the
+# foreground update for
+# `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` (tracked by
+# https://github.com/element-hq/synapse/issues/17623)
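+# (This creates two copies of the test case below, named e.g.
+# `SlidingSyncRoomsTimelineTestCase_new` and
+# `SlidingSyncRoomsTimelineTestCase_fallback`, one per table mode.)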
+@parameterized_class(
+ ("use_new_tables",),
+ [
+ (True,),
+ (False,),
+ ],
+ class_name_func=lambda cls,
+ num,
+ params_dict: f"{cls.__name__}_{'new' if params_dict['use_new_tables'] else 'fallback'}",
+)
class SlidingSyncRoomsTimelineTestCase(SlidingSyncBase):
"""
Test `rooms.timeline` in the Sliding Sync API.
@@ -43,6 +60,8 @@ class SlidingSyncRoomsTimelineTestCase(SlidingSyncBase):
self.store = hs.get_datastores().main
self.storage_controllers = hs.get_storage_controllers()
+ super().prepare(reactor, clock, hs)
+
def _assertListEqual(
self,
actual_items: StrSequence,
@@ -130,16 +149,10 @@ class SlidingSyncRoomsTimelineTestCase(SlidingSyncBase):
user2_tok = self.login(user2_id, "pass")
room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
- self.helper.send(room_id1, "activity1", tok=user2_tok)
- self.helper.send(room_id1, "activity2", tok=user2_tok)
+ event_response1 = self.helper.send(room_id1, "activity1", tok=user2_tok)
+ event_response2 = self.helper.send(room_id1, "activity2", tok=user2_tok)
event_response3 = self.helper.send(room_id1, "activity3", tok=user2_tok)
- event_pos3 = self.get_success(
- self.store.get_position_for_event(event_response3["event_id"])
- )
event_response4 = self.helper.send(room_id1, "activity4", tok=user2_tok)
- event_pos4 = self.get_success(
- self.store.get_position_for_event(event_response4["event_id"])
- )
event_response5 = self.helper.send(room_id1, "activity5", tok=user2_tok)
user1_join_response = self.helper.join(room_id1, user1_id, tok=user1_tok)
@@ -177,27 +190,23 @@ class SlidingSyncRoomsTimelineTestCase(SlidingSyncBase):
)
# Check to make sure the `prev_batch` points at the right place
- prev_batch_token = self.get_success(
- StreamToken.from_string(
- self.store, response_body["rooms"][room_id1]["prev_batch"]
- )
- )
- prev_batch_room_stream_token_serialized = self.get_success(
- prev_batch_token.room_key.to_string(self.store)
+ prev_batch_token = response_body["rooms"][room_id1]["prev_batch"]
+
+ # If we use the `prev_batch` token to look backwards we should see
+ # `event3` and older next.
+ channel = self.make_request(
+ "GET",
+ f"/rooms/{room_id1}/messages?from={prev_batch_token}&dir=b&limit=3",
+ access_token=user1_tok,
)
- # If we use the `prev_batch` token to look backwards, we should see `event3`
- # next so make sure the token encompasses it
- self.assertEqual(
- event_pos3.persisted_after(prev_batch_token.room_key),
- False,
- f"`prev_batch` token {prev_batch_room_stream_token_serialized} should be >= event_pos3={self.get_success(event_pos3.to_room_stream_token().to_string(self.store))}",
- )
- # If we use the `prev_batch` token to look backwards, we shouldn't see `event4`
- # anymore since it was just returned in this response.
- self.assertEqual(
- event_pos4.persisted_after(prev_batch_token.room_key),
- True,
- f"`prev_batch` token {prev_batch_room_stream_token_serialized} should be < event_pos4={self.get_success(event_pos4.to_room_stream_token().to_string(self.store))}",
+ self.assertEqual(channel.code, 200, channel.json_body)
+ self.assertListEqual(
+ [
+ event_response3["event_id"],
+ event_response2["event_id"],
+ event_response1["event_id"],
+ ],
+ [ev["event_id"] for ev in channel.json_body["chunk"]],
)
# With no `from_token` (initial sync), it's all historical since there is no
@@ -300,8 +309,8 @@ class SlidingSyncRoomsTimelineTestCase(SlidingSyncBase):
self.assertEqual(
response_body["rooms"][room_id1]["limited"],
False,
- f'Our `timeline_limit` was {sync_body["lists"]["foo-list"]["timeline_limit"]} '
- + f'and {len(response_body["rooms"][room_id1]["timeline"])} events were returned in the timeline. '
+ f"Our `timeline_limit` was {sync_body['lists']['foo-list']['timeline_limit']} "
+ + f"and {len(response_body['rooms'][room_id1]['timeline'])} events were returned in the timeline. "
+ str(response_body["rooms"][room_id1]),
)
# Check to make sure the latest events are returned
@@ -378,7 +387,7 @@ class SlidingSyncRoomsTimelineTestCase(SlidingSyncBase):
response_body["rooms"][room_id1]["limited"],
True,
f"Our `timeline_limit` was {timeline_limit} "
- + f'and {len(response_body["rooms"][room_id1]["timeline"])} events were returned in the timeline. '
+ + f"and {len(response_body['rooms'][room_id1]['timeline'])} events were returned in the timeline. "
+ str(response_body["rooms"][room_id1]),
)
# Check to make sure that the "live" and historical events are returned
@@ -573,3 +582,138 @@ class SlidingSyncRoomsTimelineTestCase(SlidingSyncBase):
# Nothing to see for this banned user in the room in the token range
self.assertIsNone(response_body["rooms"].get(room_id1))
+
+ def test_increasing_timeline_range_sends_more_messages(self) -> None:
+ """
+ Test that increasing the timeline limit via room subscriptions sends the
+ room down with more messages in a limited sync.
+ """
+
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ room_id1 = self.helper.create_room_as(user1_id, tok=user1_tok)
+
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": [[EventTypes.Create, ""]],
+ "timeline_limit": 1,
+ }
+ }
+ }
+
+ message_events = []
+ for _ in range(10):
+ resp = self.helper.send(room_id1, "msg", tok=user1_tok)
+ message_events.append(resp["event_id"])
+
+ # Make the first Sliding Sync request
+ response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+ room_response = response_body["rooms"][room_id1]
+
+ self.assertEqual(room_response["initial"], True)
+ self.assertNotIn("unstable_expanded_timeline", room_response)
+ self.assertEqual(room_response["limited"], True)
+
+ # We only expect the last message at first
+ self._assertTimelineEqual(
+ room_id=room_id1,
+ actual_event_ids=[event["event_id"] for event in room_response["timeline"]],
+ expected_event_ids=message_events[-1:],
+ message=str(room_response["timeline"]),
+ )
+
+ # We also expect to get the create event state.
+ state_map = self.get_success(
+ self.storage_controllers.state.get_current_state(room_id1)
+ )
+ self._assertRequiredStateIncludes(
+ room_response["required_state"],
+ {state_map[(EventTypes.Create, "")]},
+ exact=True,
+ )
+
+ # Now do another request with a room subscription with an increased timeline limit
+ sync_body["room_subscriptions"] = {
+ room_id1: {
+ "required_state": [],
+ "timeline_limit": 10,
+ }
+ }
+
+ response_body, from_token = self.do_sync(
+ sync_body, since=from_token, tok=user1_tok
+ )
+ room_response = response_body["rooms"][room_id1]
+
+ self.assertNotIn("initial", room_response)
+ self.assertEqual(room_response["unstable_expanded_timeline"], True)
+ self.assertEqual(room_response["limited"], True)
+
+ # Now we expect all the messages
+ self._assertTimelineEqual(
+ room_id=room_id1,
+ actual_event_ids=[event["event_id"] for event in room_response["timeline"]],
+ expected_event_ids=message_events,
+ message=str(room_response["timeline"]),
+ )
+
+ # We don't expect to get the room create down, as nothing has changed.
+ self.assertNotIn("required_state", room_response)
+
+ # Decreasing the timeline limit shouldn't resend any events
+ sync_body["room_subscriptions"] = {
+ room_id1: {
+ "required_state": [],
+ "timeline_limit": 5,
+ }
+ }
+
+ event_response = self.helper.send(room_id1, "msg", tok=user1_tok)
+ latest_event_id = event_response["event_id"]
+
+ response_body, from_token = self.do_sync(
+ sync_body, since=from_token, tok=user1_tok
+ )
+ room_response = response_body["rooms"][room_id1]
+
+ self.assertNotIn("initial", room_response)
+ self.assertNotIn("unstable_expanded_timeline", room_response)
+ self.assertEqual(room_response["limited"], False)
+
+ self._assertTimelineEqual(
+ room_id=room_id1,
+ actual_event_ids=[event["event_id"] for event in room_response["timeline"]],
+ expected_event_ids=[latest_event_id],
+ message=str(room_response["timeline"]),
+ )
+
+ # Increasing the limit back to what it was before should also not resend
+ # any events
+ sync_body["room_subscriptions"] = {
+ room_id1: {
+ "required_state": [],
+ "timeline_limit": 10,
+ }
+ }
+
+ event_response = self.helper.send(room_id1, "msg", tok=user1_tok)
+ latest_event_id = event_response["event_id"]
+
+ response_body, from_token = self.do_sync(
+ sync_body, since=from_token, tok=user1_tok
+ )
+ room_response = response_body["rooms"][room_id1]
+
+ self.assertNotIn("initial", room_response)
+ self.assertNotIn("unstable_expanded_timeline", room_response)
+ self.assertEqual(room_response["limited"], False)
+
+ self._assertTimelineEqual(
+ room_id=room_id1,
+ actual_event_ids=[event["event_id"] for event in room_response["timeline"]],
+ expected_event_ids=[latest_event_id],
+ message=str(room_response["timeline"]),
+ )
diff --git a/tests/rest/client/sliding_sync/test_sliding_sync.py b/tests/rest/client/sliding_sync/test_sliding_sync.py
index cb7638c5ba..dcec5b4cf0 100644
--- a/tests/rest/client/sliding_sync/test_sliding_sync.py
+++ b/tests/rest/client/sliding_sync/test_sliding_sync.py
@@ -13,7 +13,9 @@
#
import logging
from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple
+from unittest.mock import AsyncMock
+from parameterized import parameterized, parameterized_class
from typing_extensions import assert_never
from twisted.test.proto_helpers import MemoryReactor
@@ -23,10 +25,15 @@ from synapse.api.constants import (
AccountDataTypes,
EventContentFields,
EventTypes,
+ JoinRules,
+ Membership,
RoomTypes,
)
-from synapse.events import EventBase
-from synapse.rest.client import devices, login, receipts, room, sync
+from synapse.api.room_versions import RoomVersions
+from synapse.events import EventBase, StrippedStateEvent, make_event_from_dict
+from synapse.events.snapshot import EventContext
+from synapse.handlers.sliding_sync import StateValues
+from synapse.rest.client import account_data, devices, login, receipts, room, sync
from synapse.server import HomeServer
from synapse.types import (
JsonDict,
@@ -40,6 +47,7 @@ from synapse.util.stringutils import random_string
from tests import unittest
from tests.server import TimedOutException
+from tests.test_utils.event_injection import create_event
logger = logging.getLogger(__name__)
@@ -47,8 +55,25 @@ logger = logging.getLogger(__name__)
class SlidingSyncBase(unittest.HomeserverTestCase):
"""Base class for sliding sync test cases"""
+ # Flag as to whether to use the new sliding sync tables or not
+ #
+ # FIXME: This can be removed once we bump `SCHEMA_COMPAT_VERSION` and run the
+ # foreground update for
+ # `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` (tracked by
+ # https://github.com/element-hq/synapse/issues/17623)
+ use_new_tables: bool = True
+
sync_endpoint = "/_matrix/client/unstable/org.matrix.simplified_msc3575/sync"
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ # FIXME: This can be removed once we bump `SCHEMA_COMPAT_VERSION` and run the
+ # foreground update for
+ # `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` (tracked by
+ # https://github.com/element-hq/synapse/issues/17623)
+ hs.get_datastores().main.have_finished_sliding_sync_background_jobs = AsyncMock( # type: ignore[method-assign]
+ return_value=self.use_new_tables
+ )
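+ # Pretending the background jobs are (or aren't) finished steers the
+ # sliding sync handler onto the new-table or fallback code path.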
+
def default_config(self) -> JsonDict:
config = super().default_config()
# Enable sliding sync
@@ -122,6 +147,172 @@ class SlidingSyncBase(unittest.HomeserverTestCase):
message=str(actual_required_state),
)
+ def _add_new_dm_to_global_account_data(
+ self, source_user_id: str, target_user_id: str, target_room_id: str
+ ) -> None:
+ """
+ Helper to handle inserting a new DM for the source user into global account data
+ (handles all of the list merging).
+
+ Args:
+ source_user_id: The user ID of the DM mapping we're going to update
+ target_user_id: User ID of the person the DM is with
+ target_room_id: Room ID of the DM
+ """
+ store = self.hs.get_datastores().main
+
+ # Get the current DM map
+ existing_dm_map = self.get_success(
+ store.get_global_account_data_by_type_for_user(
+ source_user_id, AccountDataTypes.DIRECT
+ )
+ )
+ # Scrutinize the account data since it has no concrete type. We're just copying
+ # everything into a known type. It should be a mapping from user ID to a list of
+ # room IDs. Ignore anything else.
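+ # For example: {"@user:test": ["!roomA:test", "!roomB:test"]}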
+ new_dm_map: Dict[str, List[str]] = {}
+ if isinstance(existing_dm_map, dict):
+ for user_id, room_ids in existing_dm_map.items():
+ if isinstance(user_id, str) and isinstance(room_ids, list):
+ for room_id in room_ids:
+ if isinstance(room_id, str):
+ new_dm_map[user_id] = new_dm_map.get(user_id, []) + [
+ room_id
+ ]
+
+ # Add the new DM to the map
+ new_dm_map[target_user_id] = new_dm_map.get(target_user_id, []) + [
+ target_room_id
+ ]
+ # Save the DM map to global account data
+ self.get_success(
+ store.add_account_data_for_user(
+ source_user_id,
+ AccountDataTypes.DIRECT,
+ new_dm_map,
+ )
+ )
+
+ def _create_dm_room(
+ self,
+ inviter_user_id: str,
+ inviter_tok: str,
+ invitee_user_id: str,
+ invitee_tok: str,
+ should_join_room: bool = True,
+ ) -> str:
+ """
+ Helper to create a DM room as the "inviter" and invite the "invitee" user to
+ the room. The "invitee" user will also join the room unless `should_join_room`
+ is False. The `m.direct` account data will be set for both users.
+ """
+ # Create a room and send an invite to the other user
+ room_id = self.helper.create_room_as(
+ inviter_user_id,
+ is_public=False,
+ tok=inviter_tok,
+ )
+ self.helper.invite(
+ room_id,
+ src=inviter_user_id,
+ targ=invitee_user_id,
+ tok=inviter_tok,
+ extra_data={"is_direct": True},
+ )
+ if should_join_room:
+ # Person that was invited joins the room
+ self.helper.join(room_id, invitee_user_id, tok=invitee_tok)
+
+ # Mimic the client setting the room as a direct message in the global account
+ # data for both users.
+ self._add_new_dm_to_global_account_data(
+ invitee_user_id, inviter_user_id, room_id
+ )
+ self._add_new_dm_to_global_account_data(
+ inviter_user_id, invitee_user_id, room_id
+ )
+
+ return room_id
+
+ _remote_invite_count: int = 0
+
+ def _create_remote_invite_room_for_user(
+ self,
+ invitee_user_id: str,
+ unsigned_invite_room_state: Optional[List[StrippedStateEvent]],
+ invite_room_id: Optional[str] = None,
+ ) -> str:
+ """
+ Create a fake invite for a remote room and persist it.
+
+ We don't have any state for these kind of rooms and can only rely on the
+ stripped state included in the unsigned portion of the invite event to identify
+ the room.
+
+ Args:
+ invitee_user_id: The person being invited
+ unsigned_invite_room_state: List of stripped state events to assist the
+ receiver in identifying the room.
+ invite_room_id: Optional remote room ID to be invited to. When unset, we
+ will generate one.
+
+ Returns:
+ The room ID of the remote invite room
+ """
+ store = self.hs.get_datastores().main
+
+ if invite_room_id is None:
+ invite_room_id = f"!test_room{self._remote_invite_count}:remote_server"
+
+ invite_event_dict = {
+ "room_id": invite_room_id,
+ "sender": "@inviter:remote_server",
+ "state_key": invitee_user_id,
+ # Just keep advancing the depth
+ "depth": self._remote_invite_count,
+ "origin_server_ts": 1,
+ "type": EventTypes.Member,
+ "content": {"membership": Membership.INVITE},
+ "auth_events": [],
+ "prev_events": [],
+ }
+ if unsigned_invite_room_state is not None:
+ serialized_stripped_state_events = []
+ for stripped_event in unsigned_invite_room_state:
+ serialized_stripped_state_events.append(
+ {
+ "type": stripped_event.type,
+ "state_key": stripped_event.state_key,
+ "sender": stripped_event.sender,
+ "content": stripped_event.content,
+ }
+ )
+
+ invite_event_dict["unsigned"] = {
+ "invite_room_state": serialized_stripped_state_events
+ }
+
+ invite_event = make_event_from_dict(
+ invite_event_dict,
+ room_version=RoomVersions.V10,
+ )
+ invite_event.internal_metadata.outlier = True
+ invite_event.internal_metadata.out_of_band_membership = True
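+ # (Out-of-band memberships, such as invites received over federation,
+ # arrive without the surrounding room DAG, so the event is persisted as
+ # an outlier.)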
+
+ self.get_success(
+ store.maybe_store_room_on_outlier_membership(
+ room_id=invite_room_id, room_version=invite_event.room_version
+ )
+ )
+ context = EventContext.for_outlier(self.hs.get_storage_controllers())
+ persist_controller = self.hs.get_storage_controllers().persistence
+ assert persist_controller is not None
+ self.get_success(persist_controller.persist_event(invite_event, context))
+
+ self._remote_invite_count += 1
+
+ return invite_room_id
+
def _bump_notifier_wait_for_events(
self,
user_id: str,
@@ -203,6 +394,20 @@ class SlidingSyncBase(unittest.HomeserverTestCase):
)
+# FIXME: This can be removed once we bump `SCHEMA_COMPAT_VERSION` and run the
+# foreground update for
+# `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` (tracked by
+# https://github.com/element-hq/synapse/issues/17623)
+@parameterized_class(
+ ("use_new_tables",),
+ [
+ (True,),
+ (False,),
+ ],
+ class_name_func=lambda cls,
+ num,
+ params_dict: f"{cls.__name__}_{'new' if params_dict['use_new_tables'] else 'fallback'}",
+)
class SlidingSyncTestCase(SlidingSyncBase):
"""
Tests regarding MSC3575 Sliding Sync `/sync` endpoint.
@@ -218,6 +423,7 @@ class SlidingSyncTestCase(SlidingSyncBase):
sync.register_servlets,
devices.register_servlets,
receipts.register_servlets,
+ account_data.register_servlets,
]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
@@ -225,93 +431,11 @@ class SlidingSyncTestCase(SlidingSyncBase):
self.event_sources = hs.get_event_sources()
self.storage_controllers = hs.get_storage_controllers()
self.account_data_handler = hs.get_account_data_handler()
+ persistence = self.hs.get_storage_controllers().persistence
+ assert persistence is not None
+ self.persistence = persistence
- def _add_new_dm_to_global_account_data(
- self, source_user_id: str, target_user_id: str, target_room_id: str
- ) -> None:
- """
- Helper to handle inserting a new DM for the source user into global account data
- (handles all of the list merging).
-
- Args:
- source_user_id: The user ID of the DM mapping we're going to update
- target_user_id: User ID of the person the DM is with
- target_room_id: Room ID of the DM
- """
-
- # Get the current DM map
- existing_dm_map = self.get_success(
- self.store.get_global_account_data_by_type_for_user(
- source_user_id, AccountDataTypes.DIRECT
- )
- )
- # Scrutinize the account data since it has no concrete type. We're just copying
- # everything into a known type. It should be a mapping from user ID to a list of
- # room IDs. Ignore anything else.
- new_dm_map: Dict[str, List[str]] = {}
- if isinstance(existing_dm_map, dict):
- for user_id, room_ids in existing_dm_map.items():
- if isinstance(user_id, str) and isinstance(room_ids, list):
- for room_id in room_ids:
- if isinstance(room_id, str):
- new_dm_map[user_id] = new_dm_map.get(user_id, []) + [
- room_id
- ]
-
- # Add the new DM to the map
- new_dm_map[target_user_id] = new_dm_map.get(target_user_id, []) + [
- target_room_id
- ]
- # Save the DM map to global account data
- self.get_success(
- self.store.add_account_data_for_user(
- source_user_id,
- AccountDataTypes.DIRECT,
- new_dm_map,
- )
- )
-
- def _create_dm_room(
- self,
- inviter_user_id: str,
- inviter_tok: str,
- invitee_user_id: str,
- invitee_tok: str,
- should_join_room: bool = True,
- ) -> str:
- """
- Helper to create a DM room as the "inviter" and invite the "invitee" user to the
- room. The "invitee" user also will join the room. The `m.direct` account data
- will be set for both users.
- """
-
- # Create a room and send an invite the other user
- room_id = self.helper.create_room_as(
- inviter_user_id,
- is_public=False,
- tok=inviter_tok,
- )
- self.helper.invite(
- room_id,
- src=inviter_user_id,
- targ=invitee_user_id,
- tok=inviter_tok,
- extra_data={"is_direct": True},
- )
- if should_join_room:
- # Person that was invited joins the room
- self.helper.join(room_id, invitee_user_id, tok=invitee_tok)
-
- # Mimic the client setting the room as a direct message in the global account
- # data for both users.
- self._add_new_dm_to_global_account_data(
- invitee_user_id, inviter_user_id, room_id
- )
- self._add_new_dm_to_global_account_data(
- inviter_user_id, invitee_user_id, room_id
- )
-
- return room_id
+ super().prepare(reactor, clock, hs)
def test_sync_list(self) -> None:
"""
@@ -512,288 +636,326 @@ class SlidingSyncTestCase(SlidingSyncBase):
# There should be no room sent down.
self.assertFalse(channel.json_body["rooms"])
- def test_filter_list(self) -> None:
+ def test_forgotten_up_to_date(self) -> None:
"""
- Test that filters apply to `lists`
+ Make sure we get up-to-date `forgotten` status for rooms
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
user2_id = self.register_user("user2", "pass")
user2_tok = self.login(user2_id, "pass")
- # Create a DM room
- joined_dm_room_id = self._create_dm_room(
- inviter_user_id=user1_id,
- inviter_tok=user1_tok,
- invitee_user_id=user2_id,
- invitee_tok=user2_tok,
- should_join_room=True,
- )
- invited_dm_room_id = self._create_dm_room(
- inviter_user_id=user1_id,
- inviter_tok=user1_tok,
- invitee_user_id=user2_id,
- invitee_tok=user2_tok,
- should_join_room=False,
- )
-
- # Create a normal room
room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
- self.helper.join(room_id, user1_id, tok=user1_tok)
- # Create a room that user1 is invited to
- invite_room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
- self.helper.invite(invite_room_id, src=user2_id, targ=user1_id, tok=user2_tok)
+ # User1 is banned from the room (was never in the room)
+ self.helper.ban(room_id, src=user2_id, targ=user1_id, tok=user2_tok)
- # Make the Sliding Sync request
sync_body = {
"lists": {
- # Absense of filters does not imply "False" values
- "all": {
+ "foo-list": {
"ranges": [[0, 99]],
"required_state": [],
- "timeline_limit": 1,
+ "timeline_limit": 0,
"filters": {},
},
- # Test single truthy filter
- "dms": {
- "ranges": [[0, 99]],
- "required_state": [],
- "timeline_limit": 1,
- "filters": {"is_dm": True},
- },
- # Test single falsy filter
- "non-dms": {
- "ranges": [[0, 99]],
- "required_state": [],
- "timeline_limit": 1,
- "filters": {"is_dm": False},
- },
- # Test how multiple filters should stack (AND'd together)
- "room-invites": {
- "ranges": [[0, 99]],
- "required_state": [],
- "timeline_limit": 1,
- "filters": {"is_dm": False, "is_invite": True},
- },
}
}
- response_body, _ = self.do_sync(sync_body, tok=user1_tok)
-
- # Make sure it has the foo-list we requested
- self.assertListEqual(
- list(response_body["lists"].keys()),
- ["all", "dms", "non-dms", "room-invites"],
- response_body["lists"].keys(),
+ response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {room_id},
+ exact=True,
)
- # Make sure the lists have the correct rooms
- self.assertListEqual(
- list(response_body["lists"]["all"]["ops"]),
- [
- {
- "op": "SYNC",
- "range": [0, 99],
- "room_ids": [
- invite_room_id,
- room_id,
- invited_dm_room_id,
- joined_dm_room_id,
- ],
- }
- ],
- list(response_body["lists"]["all"]),
- )
- self.assertListEqual(
- list(response_body["lists"]["dms"]["ops"]),
- [
- {
- "op": "SYNC",
- "range": [0, 99],
- "room_ids": [invited_dm_room_id, joined_dm_room_id],
- }
- ],
- list(response_body["lists"]["dms"]),
- )
- self.assertListEqual(
- list(response_body["lists"]["non-dms"]["ops"]),
- [
- {
- "op": "SYNC",
- "range": [0, 99],
- "room_ids": [invite_room_id, room_id],
- }
- ],
- list(response_body["lists"]["non-dms"]),
- )
- self.assertListEqual(
- list(response_body["lists"]["room-invites"]["ops"]),
- [
- {
- "op": "SYNC",
- "range": [0, 99],
- "room_ids": [invite_room_id],
- }
- ],
- list(response_body["lists"]["room-invites"]),
+ # User1 forgets the room
+ channel = self.make_request(
+ "POST",
+ f"/_matrix/client/r0/rooms/{room_id}/forget",
+ content={},
+ access_token=user1_tok,
)
+ self.assertEqual(channel.code, 200, channel.result)
- # Ensure DM's are correctly marked
- self.assertDictEqual(
- {
- room_id: room.get("is_dm")
- for room_id, room in response_body["rooms"].items()
- },
- {
- invite_room_id: None,
- room_id: None,
- invited_dm_room_id: True,
- joined_dm_room_id: True,
- },
+ # We should no longer see the forgotten room
+ response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ set(),
+ exact=True,
)
- def test_filter_regardless_of_membership_server_left_room(self) -> None:
+ def test_rejoin_forgotten_room(self) -> None:
"""
- Test that filters apply to rooms regardless of membership. We're also
- compounding the problem by having all of the local users leave the room causing
- our server to leave the room.
-
- We want to make sure that if someone is filtering rooms, and leaves, you still
- get that final update down sync that you left.
+ Make sure we can see a forgotten room again if we rejoin (or receive any new
+ membership, like an invite), i.e. the room is no longer forgotten.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
user2_id = self.register_user("user2", "pass")
user2_tok = self.login(user2_id, "pass")
- # Create a normal room
- room_id = self.helper.create_room_as(user1_id, tok=user2_tok)
+ room_id = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True)
+ # User1 joins the room
self.helper.join(room_id, user1_id, tok=user1_tok)
- # Create an encrypted space room
- space_room_id = self.helper.create_room_as(
- user2_id,
- tok=user2_tok,
- extra_content={
- "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE}
- },
- )
- self.helper.send_state(
- space_room_id,
- EventTypes.RoomEncryption,
- {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"},
- tok=user2_tok,
+ # Make the Sliding Sync request
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ }
+ }
+ }
+ response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+ # We should see the room (like normal)
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {room_id},
+ exact=True,
)
- self.helper.join(space_room_id, user1_id, tok=user1_tok)
- # Make an initial Sliding Sync request
+ # Leave and forget the room
+ self.helper.leave(room_id, user1_id, tok=user1_tok)
+ # User1 forgets the room
channel = self.make_request(
"POST",
- self.sync_endpoint,
- {
- "lists": {
- "all-list": {
- "ranges": [[0, 99]],
- "required_state": [],
- "timeline_limit": 0,
- "filters": {},
- },
- "foo-list": {
- "ranges": [[0, 99]],
- "required_state": [],
- "timeline_limit": 1,
- "filters": {
- "is_encrypted": True,
- "room_types": [RoomTypes.SPACE],
- },
- },
+ f"/_matrix/client/r0/rooms/{room_id}/forget",
+ content={},
+ access_token=user1_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ # Re-join the room
+ self.helper.join(room_id, user1_id, tok=user1_tok)
+
+ # We should see the room again after re-joining
+ response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {room_id},
+ exact=True,
+ )
+
+ def test_invited_to_forgotten_remote_room(self) -> None:
+ """
+ Make sure we can see a forgotten room again if we are invited again
+ (remote/federated out-of-band memberships)
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ # Create a remote room invite (out-of-band membership)
+ room_id = self._create_remote_invite_room_for_user(user1_id, None)
+
+ # Make the Sliding Sync request
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
}
- },
+ }
+ }
+ response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+ # We should see the room (like normal)
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {room_id},
+ exact=True,
+ )
+
+ # Leave and forget the room
+ self.helper.leave(room_id, user1_id, tok=user1_tok)
+ # User1 forgets the room
+ channel = self.make_request(
+ "POST",
+ f"/_matrix/client/r0/rooms/{room_id}/forget",
+ content={},
access_token=user1_tok,
)
- self.assertEqual(channel.code, 200, channel.json_body)
- from_token = channel.json_body["pos"]
+ self.assertEqual(channel.code, 200, channel.result)
- # Make sure the response has the lists we requested
- self.assertListEqual(
- list(channel.json_body["lists"].keys()),
- ["all-list", "foo-list"],
- channel.json_body["lists"].keys(),
+ # Get invited to the room again
+ self._create_remote_invite_room_for_user(user1_id, None, invite_room_id=room_id)
+
+ # We should see the room again after being re-invited
+ response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {room_id},
+ exact=True,
)
- # Make sure the lists have the correct rooms
- self.assertListEqual(
- list(channel.json_body["lists"]["all-list"]["ops"]),
- [
- {
- "op": "SYNC",
- "range": [0, 99],
- "room_ids": [space_room_id, room_id],
+ def test_reject_remote_invite(self) -> None:
+ """Test that rejecting a remote invite comes down incremental sync"""
+
+ user_id = self.register_user("user1", "pass")
+ user_tok = self.login(user_id, "pass")
+
+ # Create a remote room invite (out-of-band membership)
+ room_id = "!room:remote.server"
+ self._create_remote_invite_room_for_user(user_id, None, room_id)
+
+ # Make the Sliding Sync request
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": [(EventTypes.Member, StateValues.ME)],
+ "timeline_limit": 3,
}
- ],
+ }
+ }
+ response_body, from_token = self.do_sync(sync_body, tok=user_tok)
+ # We should see the room (like normal)
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {room_id},
+ exact=True,
)
- self.assertListEqual(
- list(channel.json_body["lists"]["foo-list"]["ops"]),
- [
+
+ # Reject the remote room invite
+ self.helper.leave(room_id, user_id, tok=user_tok)
+
+ # Sync again after rejecting the invite
+ response_body, _ = self.do_sync(sync_body, since=from_token, tok=user_tok)
+
+        # The fix that adds the leave event to incremental sync when rejecting a
+        # remote invite relies on the new tables in order to work.
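+        # (`use_new_tables` is the parameterized flag from the test base that
+        # switches between the new sliding sync tables and the fallback path
+        # exercised further below.)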
+ if self.use_new_tables:
+ # We should see the newly_left room
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {room_id},
+ exact=True,
+ )
+ # We should see the leave state for the room so clients don't end up with stuck
+ # invites
+ self.assertIncludes(
{
- "op": "SYNC",
- "range": [0, 99],
- "room_ids": [space_room_id],
- }
- ],
+ (
+ state["type"],
+ state["state_key"],
+ state["content"].get("membership"),
+ )
+ for state in response_body["rooms"][room_id]["required_state"]
+ },
+ {(EventTypes.Member, user_id, Membership.LEAVE)},
+ exact=True,
+ )
+
+ def test_ignored_user_invites_initial_sync(self) -> None:
+ """
+        Make sure we ignore invites from users in our `m.ignored_user_list` account
+        data on initial sync.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ # Create a room that user1 is already in
+ room_id1 = self.helper.create_room_as(user1_id, tok=user1_tok)
+
+ # Create a room that user2 is already in
+ room_id2 = self.helper.create_room_as(user2_id, tok=user2_tok)
+
+ # User1 is invited to room_id2
+ self.helper.invite(room_id2, src=user2_id, targ=user1_id, tok=user2_tok)
+
+ # Sync once before we ignore to make sure the rooms can show up
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ },
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ # room_id2 shows up because we haven't ignored the user yet
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {room_id1, room_id2},
+ exact=True,
)
- # Everyone leaves the encrypted space room
- self.helper.leave(space_room_id, user1_id, tok=user1_tok)
- self.helper.leave(space_room_id, user2_id, tok=user2_tok)
+ # User1 ignores user2
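+        # (Ignoring works via the `m.ignored_user_list` account data event, whose
+        # content maps each ignored user ID to an empty object.)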
+ channel = self.make_request(
+ "PUT",
+ f"/_matrix/client/v3/user/{user1_id}/account_data/{AccountDataTypes.IGNORED_USER_LIST}",
+ content={"ignored_users": {user2_id: {}}},
+ access_token=user1_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ # Sync again (initial sync)
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+ # The invite for room_id2 should no longer show up because user2 is ignored
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {room_id1},
+ exact=True,
+ )
+
+ def test_ignored_user_invites_incremental_sync(self) -> None:
+ """
+        Make sure we ignore invites from users in our `m.ignored_user_list` account
+        data on incremental sync.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ # Create a room that user1 is already in
+ room_id1 = self.helper.create_room_as(user1_id, tok=user1_tok)
+
+ # Create a room that user2 is already in
+ room_id2 = self.helper.create_room_as(user2_id, tok=user2_tok)
- # Make an incremental Sliding Sync request
+ # User1 ignores user2
channel = self.make_request(
- "POST",
- self.sync_endpoint + f"?pos={from_token}",
- {
- "lists": {
- "all-list": {
- "ranges": [[0, 99]],
- "required_state": [],
- "timeline_limit": 0,
- "filters": {},
- },
- "foo-list": {
- "ranges": [[0, 99]],
- "required_state": [],
- "timeline_limit": 1,
- "filters": {
- "is_encrypted": True,
- "room_types": [RoomTypes.SPACE],
- },
- },
- }
- },
+ "PUT",
+ f"/_matrix/client/v3/user/{user1_id}/account_data/{AccountDataTypes.IGNORED_USER_LIST}",
+ content={"ignored_users": {user2_id: {}}},
access_token=user1_tok,
)
- self.assertEqual(channel.code, 200, channel.json_body)
+ self.assertEqual(channel.code, 200, channel.result)
- # Make sure the lists have the correct rooms even though we `newly_left`
- self.assertListEqual(
- list(channel.json_body["lists"]["all-list"]["ops"]),
- [
- {
- "op": "SYNC",
- "range": [0, 99],
- "room_ids": [space_room_id, room_id],
- }
- ],
+ # Initial sync
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 0,
+ },
+ }
+ }
+ response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+ # User1 only has membership in room_id1 at this point
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {room_id1},
+ exact=True,
)
- self.assertListEqual(
- list(channel.json_body["lists"]["foo-list"]["ops"]),
- [
- {
- "op": "SYNC",
- "range": [0, 99],
- "room_ids": [space_room_id],
- }
- ],
+
+ # User1 is invited to room_id2 after the initial sync
+ self.helper.invite(room_id2, src=user2_id, targ=user1_id, tok=user2_tok)
+
+ # Sync again (incremental sync)
+ response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
+ # The invite for room_id2 doesn't show up because user2 is ignored
+ self.assertIncludes(
+ set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+ {room_id1},
+ exact=True,
)
def test_sort_list(self) -> None:
@@ -812,11 +974,11 @@ class SlidingSyncTestCase(SlidingSyncBase):
self.helper.send(room_id1, "activity in room1", tok=user1_tok)
self.helper.send(room_id2, "activity in room2", tok=user1_tok)
- # Make the Sliding Sync request
+ # Make the Sliding Sync request where the range includes *some* of the rooms
sync_body = {
"lists": {
"foo-list": {
- "ranges": [[0, 99]],
+ "ranges": [[0, 1]],
"required_state": [],
"timeline_limit": 1,
}
@@ -825,25 +987,56 @@ class SlidingSyncTestCase(SlidingSyncBase):
response_body, _ = self.do_sync(sync_body, tok=user1_tok)
# Make sure it has the foo-list we requested
- self.assertListEqual(
- list(response_body["lists"].keys()),
- ["foo-list"],
+ self.assertIncludes(
response_body["lists"].keys(),
+ {"foo-list"},
)
-
- # Make sure the list is sorted in the way we expect
+        # Make sure the list is sorted in the way we expect (we only sort when the range
+        # doesn't include all of the rooms)
self.assertListEqual(
list(response_body["lists"]["foo-list"]["ops"]),
[
{
"op": "SYNC",
- "range": [0, 99],
- "room_ids": [room_id2, room_id1, room_id3],
+ "range": [0, 1],
+ "room_ids": [room_id2, room_id1],
}
],
response_body["lists"]["foo-list"],
)
+ # Make the Sliding Sync request where the range includes *all* of the rooms
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 1,
+ }
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+
+ # Make sure it has the foo-list we requested
+ self.assertIncludes(
+ response_body["lists"].keys(),
+ {"foo-list"},
+ )
+ # Since the range includes all of the rooms, we don't sort the list
+ self.assertEqual(
+ len(response_body["lists"]["foo-list"]["ops"]),
+ 1,
+ response_body["lists"]["foo-list"],
+ )
+ op = response_body["lists"]["foo-list"]["ops"][0]
+ self.assertEqual(op["op"], "SYNC")
+ self.assertEqual(op["range"], [0, 99])
+ # Note that we don't sort the rooms when the range includes all of the rooms, so
+ # we just assert that the rooms are included
+ self.assertIncludes(
+ set(op["room_ids"]), {room_id1, room_id2, room_id3}, exact=True
+ )
+
def test_sliced_windows(self) -> None:
"""
Test that the `lists` `ranges` are sliced correctly. Both sides of each range
@@ -972,3 +1165,454 @@ class SlidingSyncTestCase(SlidingSyncBase):
# Make the Sliding Sync request
response_body, _ = self.do_sync(sync_body, tok=user1_tok)
self.assertEqual(response_body["rooms"][room_id1]["initial"], True)
+
+ def test_state_reset_room_comes_down_incremental_sync(self) -> None:
+ """Test that a room that we were state reset out of comes down
+ incremental sync"""
+
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ room_id1 = self.helper.create_room_as(
+ user2_id,
+ is_public=True,
+ tok=user2_tok,
+ extra_content={
+ "name": "my super room",
+ },
+ )
+
+ # Create an event for us to point back to for the state reset
+ event_response = self.helper.send(room_id1, "test", tok=user2_tok)
+ event_id = event_response["event_id"]
+
+ self.helper.join(room_id1, user1_id, tok=user1_tok)
+
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": [
+ # Request all state just to see what we get back when we are
+ # state reset out of the room
+ [StateValues.WILDCARD, StateValues.WILDCARD]
+ ],
+ "timeline_limit": 1,
+ }
+ }
+ }
+
+ # Make the Sliding Sync request
+ response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+ # Make sure we see room1
+ self.assertIncludes(set(response_body["rooms"].keys()), {room_id1}, exact=True)
+ self.assertEqual(response_body["rooms"][room_id1]["initial"], True)
+
+ # Trigger a state reset
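+        # (Sketch of the technique: we hand-craft a new `m.room.join_rules` state
+        # event whose `prev_events` point at an event from *before* user1 joined.
+        # The state resolved at that event doesn't contain user1's membership, so
+        # once it is persisted as the forward extremity, user1 has effectively
+        # been "state reset" out of the room.)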
+ join_rule_event, join_rule_context = self.get_success(
+ create_event(
+ self.hs,
+ prev_event_ids=[event_id],
+ type=EventTypes.JoinRules,
+ state_key="",
+ content={"join_rule": JoinRules.INVITE},
+ sender=user2_id,
+ room_id=room_id1,
+ room_version=self.get_success(self.store.get_room_version_id(room_id1)),
+ )
+ )
+ _, join_rule_event_pos, _ = self.get_success(
+ self.persistence.persist_event(join_rule_event, join_rule_context)
+ )
+
+ # Ensure that the state reset worked and only user2 is in the room now
+ users_in_room = self.get_success(self.store.get_users_in_room(room_id1))
+ self.assertIncludes(set(users_in_room), {user2_id}, exact=True)
+
+ state_map_at_reset = self.get_success(
+ self.storage_controllers.state.get_current_state(room_id1)
+ )
+
+ # Update the state after user1 was state reset out of the room
+ self.helper.send_state(
+ room_id1,
+ EventTypes.Name,
+ {EventContentFields.ROOM_NAME: "my super duper room"},
+ tok=user2_tok,
+ )
+
+ # Make another Sliding Sync request (incremental)
+ response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
+
+ # Expect to see room1 because it is `newly_left` thanks to being state reset out
+ # of it since the last time we synced. We need to let the client know that
+ # something happened and that they are no longer in the room.
+ self.assertIncludes(set(response_body["rooms"].keys()), {room_id1}, exact=True)
+ # We set `initial=True` to indicate that the client should reset the state they
+ # have about the room
+ self.assertEqual(response_body["rooms"][room_id1]["initial"], True)
+ # They shouldn't see anything past the state reset
+ self._assertRequiredStateIncludes(
+ response_body["rooms"][room_id1]["required_state"],
+ # We should see all the state events in the room
+ state_map_at_reset.values(),
+ exact=True,
+ )
+        # The `bump_stamp` should be the position where the state reset happened
+ self.assertEqual(
+ response_body["rooms"][room_id1]["bump_stamp"],
+ join_rule_event_pos.stream,
+ response_body["rooms"][room_id1],
+ )
+
+        # Other, less important things. We just want to check what these are so we know
+ # what happens in a state reset scenario.
+ #
+ # Room name was set at the time of the state reset so we should still be able to
+ # see it.
+ self.assertEqual(response_body["rooms"][room_id1]["name"], "my super room")
+ # Could be set but there is no avatar for this room
+ self.assertIsNone(
+ response_body["rooms"][room_id1].get("avatar"),
+ response_body["rooms"][room_id1],
+ )
+ # Could be set but this room isn't marked as a DM
+ self.assertIsNone(
+ response_body["rooms"][room_id1].get("is_dm"),
+ response_body["rooms"][room_id1],
+ )
+        # Empty timeline because we are not in the room at all (the events are all
+        # being filtered out)
+ self.assertIsNone(
+ response_body["rooms"][room_id1].get("timeline"),
+ response_body["rooms"][room_id1],
+ )
+ # `limited` since we're not providing any timeline events but there are some in
+ # the room.
+ self.assertEqual(response_body["rooms"][room_id1]["limited"], True)
+ # User is no longer in the room so they can't see this info
+ self.assertIsNone(
+ response_body["rooms"][room_id1].get("joined_count"),
+ response_body["rooms"][room_id1],
+ )
+ self.assertIsNone(
+ response_body["rooms"][room_id1].get("invited_count"),
+ response_body["rooms"][room_id1],
+ )
+
+ def test_state_reset_previously_room_comes_down_incremental_sync_with_filters(
+ self,
+ ) -> None:
+ """
+ Test that a room that we were state reset out of should always be sent down
+ regardless of the filters if it has been sent down the connection before.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ # Create a space room
+ space_room_id = self.helper.create_room_as(
+ user2_id,
+ tok=user2_tok,
+ extra_content={
+ "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE},
+ "name": "my super space",
+ },
+ )
+
+ # Create an event for us to point back to for the state reset
+ event_response = self.helper.send(space_room_id, "test", tok=user2_tok)
+ event_id = event_response["event_id"]
+
+ self.helper.join(space_room_id, user1_id, tok=user1_tok)
+
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 1]],
+ "required_state": [
+ # Request all state just to see what we get back when we are
+ # state reset out of the room
+ [StateValues.WILDCARD, StateValues.WILDCARD]
+ ],
+ "timeline_limit": 1,
+ "filters": {
+ "room_types": [RoomTypes.SPACE],
+ },
+ }
+ }
+ }
+
+ # Make the Sliding Sync request
+ response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+ # Make sure we see room1
+ self.assertIncludes(
+ set(response_body["rooms"].keys()), {space_room_id}, exact=True
+ )
+ self.assertEqual(response_body["rooms"][space_room_id]["initial"], True)
+
+ # Trigger a state reset
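+        # (Same technique as in the previous test: point the new join rules
+        # event's `prev_events` at an event from before user1 joined.)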
+ join_rule_event, join_rule_context = self.get_success(
+ create_event(
+ self.hs,
+ prev_event_ids=[event_id],
+ type=EventTypes.JoinRules,
+ state_key="",
+ content={"join_rule": JoinRules.INVITE},
+ sender=user2_id,
+ room_id=space_room_id,
+ room_version=self.get_success(
+ self.store.get_room_version_id(space_room_id)
+ ),
+ )
+ )
+ _, join_rule_event_pos, _ = self.get_success(
+ self.persistence.persist_event(join_rule_event, join_rule_context)
+ )
+
+ # Ensure that the state reset worked and only user2 is in the room now
+ users_in_room = self.get_success(self.store.get_users_in_room(space_room_id))
+ self.assertIncludes(set(users_in_room), {user2_id}, exact=True)
+
+ state_map_at_reset = self.get_success(
+ self.storage_controllers.state.get_current_state(space_room_id)
+ )
+
+ # Update the state after user1 was state reset out of the room
+ self.helper.send_state(
+ space_room_id,
+ EventTypes.Name,
+ {EventContentFields.ROOM_NAME: "my super duper space"},
+ tok=user2_tok,
+ )
+
+ # User2 also leaves the room so the server is no longer participating in the room
+ # and we don't have access to current state
+ self.helper.leave(space_room_id, user2_id, tok=user2_tok)
+
+ # Make another Sliding Sync request (incremental)
+ response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
+
+ # Expect to see room1 because it is `newly_left` thanks to being state reset out
+ # of it since the last time we synced. We need to let the client know that
+ # something happened and that they are no longer in the room.
+ self.assertIncludes(
+ set(response_body["rooms"].keys()), {space_room_id}, exact=True
+ )
+ # We set `initial=True` to indicate that the client should reset the state they
+ # have about the room
+ self.assertEqual(response_body["rooms"][space_room_id]["initial"], True)
+ # They shouldn't see anything past the state reset
+ self._assertRequiredStateIncludes(
+ response_body["rooms"][space_room_id]["required_state"],
+ # We should see all the state events in the room
+ state_map_at_reset.values(),
+ exact=True,
+ )
+        # The `bump_stamp` should be the position where the state reset happened
+ self.assertEqual(
+ response_body["rooms"][space_room_id]["bump_stamp"],
+ join_rule_event_pos.stream,
+ response_body["rooms"][space_room_id],
+ )
+
+        # Other, less important things. We just want to check what these are so we know
+ # what happens in a state reset scenario.
+ #
+ # Room name was set at the time of the state reset so we should still be able to
+ # see it.
+ self.assertEqual(
+ response_body["rooms"][space_room_id]["name"], "my super space"
+ )
+ # Could be set but there is no avatar for this room
+ self.assertIsNone(
+ response_body["rooms"][space_room_id].get("avatar"),
+ response_body["rooms"][space_room_id],
+ )
+ # Could be set but this room isn't marked as a DM
+ self.assertIsNone(
+ response_body["rooms"][space_room_id].get("is_dm"),
+ response_body["rooms"][space_room_id],
+ )
+        # Empty timeline because we are not in the room at all (the events are all
+        # being filtered out)
+ self.assertIsNone(
+ response_body["rooms"][space_room_id].get("timeline"),
+ response_body["rooms"][space_room_id],
+ )
+ # `limited` since we're not providing any timeline events but there are some in
+ # the room.
+ self.assertEqual(response_body["rooms"][space_room_id]["limited"], True)
+ # User is no longer in the room so they can't see this info
+ self.assertIsNone(
+ response_body["rooms"][space_room_id].get("joined_count"),
+ response_body["rooms"][space_room_id],
+ )
+ self.assertIsNone(
+ response_body["rooms"][space_room_id].get("invited_count"),
+ response_body["rooms"][space_room_id],
+ )
+
+ @parameterized.expand(
+ [
+ ("server_leaves_room", True),
+ ("server_participating_in_room", False),
+ ]
+ )
+ def test_state_reset_never_room_incremental_sync_with_filters(
+ self, test_description: str, server_leaves_room: bool
+ ) -> None:
+ """
+        Test that a room that we were state reset out of, and that was never sent down
+        the connection before, is only sent down if we can still figure out the state.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ # Create a space room
+ space_room_id = self.helper.create_room_as(
+ user2_id,
+ tok=user2_tok,
+ extra_content={
+ "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE},
+ "name": "my super space",
+ },
+ )
+
+ # Create another space room
+ space_room_id2 = self.helper.create_room_as(
+ user2_id,
+ tok=user2_tok,
+ extra_content={
+ "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE},
+ },
+ )
+
+ # Create an event for us to point back to for the state reset
+ event_response = self.helper.send(space_room_id, "test", tok=user2_tok)
+ event_id = event_response["event_id"]
+
+ # User1 joins the rooms
+ #
+ self.helper.join(space_room_id, user1_id, tok=user1_tok)
+ # Join space_room_id2 so that it is at the top of the list
+ self.helper.join(space_room_id2, user1_id, tok=user1_tok)
+
+        # Make a Sliding Sync request for only the top room.
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ "ranges": [[0, 0]],
+ "required_state": [
+ # Request all state just to see what we get back when we are
+ # state reset out of the room
+ [StateValues.WILDCARD, StateValues.WILDCARD]
+ ],
+ "timeline_limit": 1,
+ "filters": {
+ "room_types": [RoomTypes.SPACE],
+ },
+ }
+ }
+ }
+
+ # Make the Sliding Sync request
+ response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+ # Make sure we only see space_room_id2
+ self.assertIncludes(
+ set(response_body["rooms"].keys()), {space_room_id2}, exact=True
+ )
+ self.assertEqual(response_body["rooms"][space_room_id2]["initial"], True)
+
+        # Just create some activity in space_room_id2 so it appears when we do an
+        # incremental sync again
+ self.helper.send(space_room_id2, "test", tok=user2_tok)
+
+ # Trigger a state reset
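+        # (Same state reset technique as the tests above: the new join rules
+        # event's `prev_events` point at an event from before user1 joined.)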
+ join_rule_event, join_rule_context = self.get_success(
+ create_event(
+ self.hs,
+ prev_event_ids=[event_id],
+ type=EventTypes.JoinRules,
+ state_key="",
+ content={"join_rule": JoinRules.INVITE},
+ sender=user2_id,
+ room_id=space_room_id,
+ room_version=self.get_success(
+ self.store.get_room_version_id(space_room_id)
+ ),
+ )
+ )
+ _, join_rule_event_pos, _ = self.get_success(
+ self.persistence.persist_event(join_rule_event, join_rule_context)
+ )
+
+ # Ensure that the state reset worked and only user2 is in the room now
+ users_in_room = self.get_success(self.store.get_users_in_room(space_room_id))
+ self.assertIncludes(set(users_in_room), {user2_id}, exact=True)
+
+ # Update the state after user1 was state reset out of the room.
+ # This will also bump it to the top of the list.
+ self.helper.send_state(
+ space_room_id,
+ EventTypes.Name,
+ {EventContentFields.ROOM_NAME: "my super duper space"},
+ tok=user2_tok,
+ )
+
+ if server_leaves_room:
+ # User2 also leaves the room so the server is no longer participating in the room
+ # and we don't have access to current state
+ self.helper.leave(space_room_id, user2_id, tok=user2_tok)
+
+ # Make another Sliding Sync request (incremental)
+ sync_body = {
+ "lists": {
+ "foo-list": {
+ # Expand the range to include all rooms
+ "ranges": [[0, 1]],
+ "required_state": [
+ # Request all state just to see what we get back when we are
+ # state reset out of the room
+ [StateValues.WILDCARD, StateValues.WILDCARD]
+ ],
+ "timeline_limit": 1,
+ "filters": {
+ "room_types": [RoomTypes.SPACE],
+ },
+ }
+ }
+ }
+ response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
+
+ if self.use_new_tables:
+ if server_leaves_room:
+ # We still only expect to see space_room_id2 because even though we were state
+ # reset out of space_room_id, it was never sent down the connection before so we
+ # don't need to bother the client with it.
+ self.assertIncludes(
+ set(response_body["rooms"].keys()), {space_room_id2}, exact=True
+ )
+ else:
+ # Both rooms show up because we can figure out the state for the
+ # `filters.room_types` if someone is still in the room (we look at the
+ # current state because `room_type` never changes).
+ self.assertIncludes(
+ set(response_body["rooms"].keys()),
+ {space_room_id, space_room_id2},
+ exact=True,
+ )
+ else:
+ # Both rooms show up because we can actually take the time to figure out the
+ # state for the `filters.room_types` in the fallback path (we look at
+ # historical state for `LEAVE` membership).
+ self.assertIncludes(
+ set(response_body["rooms"].keys()),
+ {space_room_id, space_room_id2},
+ exact=True,
+ )
diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py
index a85ea994de..33611e8a8c 100644
--- a/tests/rest/client/test_account.py
+++ b/tests/rest/client/test_account.py
@@ -36,7 +36,6 @@ from synapse.api.errors import Codes, HttpResponseException
from synapse.appservice import ApplicationService
from synapse.rest import admin
from synapse.rest.client import account, login, register, room
-from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource
from synapse.server import HomeServer
from synapse.storage._base import db_to_json
from synapse.types import JsonDict, UserID
@@ -47,430 +46,404 @@ from tests.server import FakeSite, make_request
from tests.unittest import override_config
-class PasswordResetTestCase(unittest.HomeserverTestCase):
- servlets = [
- account.register_servlets,
- synapse.rest.admin.register_servlets_for_client_rest_resource,
- register.register_servlets,
- login.register_servlets,
- ]
-
- def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
- config = self.default_config()
-
- # Email config.
- config["email"] = {
- "enable_notifs": False,
- "template_dir": os.path.abspath(
- pkg_resources.resource_filename("synapse", "res/templates")
- ),
- "smtp_host": "127.0.0.1",
- "smtp_port": 20,
- "require_transport_security": False,
- "smtp_user": None,
- "smtp_pass": None,
- "notif_from": "test@example.com",
- }
- config["public_baseurl"] = "https://example.com"
-
- hs = self.setup_test_homeserver(config=config)
-
- async def sendmail(
- reactor: IReactorTCP,
- smtphost: str,
- smtpport: int,
- from_addr: str,
- to_addr: str,
- msg_bytes: bytes,
- *args: Any,
- **kwargs: Any,
- ) -> None:
- self.email_attempts.append(msg_bytes)
-
- self.email_attempts: List[bytes] = []
- hs.get_send_email_handler()._sendmail = sendmail
-
- return hs
-
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.store = hs.get_datastores().main
- self.submit_token_resource = PasswordResetSubmitTokenResource(hs)
-
- def attempt_wrong_password_login(self, username: str, password: str) -> None:
- """Attempts to login as the user with the given password, asserting
- that the attempt *fails*.
- """
- body = {"type": "m.login.password", "user": username, "password": password}
-
- channel = self.make_request("POST", "/_matrix/client/r0/login", body)
- self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result)
-
- def test_basic_password_reset(self) -> None:
- """Test basic password reset flow"""
- old_password = "monkey"
- new_password = "kangeroo"
-
- user_id = self.register_user("kermit", old_password)
- self.login("kermit", old_password)
-
- email = "test@example.com"
-
- # Add a threepid
- self.get_success(
- self.store.user_add_threepid(
- user_id=user_id,
- medium="email",
- address=email,
- validated_at=0,
- added_at=0,
- )
- )
-
- client_secret = "foobar"
- session_id = self._request_token(email, client_secret)
-
- self.assertEqual(len(self.email_attempts), 1)
- link = self._get_link_from_email()
-
- self._validate_token(link)
-
- self._reset_password(new_password, session_id, client_secret)
-
- # Assert we can log in with the new password
- self.login("kermit", new_password)
-
- # Assert we can't log in with the old password
- self.attempt_wrong_password_login("kermit", old_password)
-
- # Check that the UI Auth information doesn't store the password in the database.
- #
- # Note that we don't have the UI Auth session ID, so just pull out the single
- # row.
- result = self.get_success(
- self.store.db_pool.simple_select_one_onecol(
- "ui_auth_sessions", keyvalues={}, retcol="clientdict"
- )
- )
- client_dict = db_to_json(result)
- self.assertNotIn("new_password", client_dict)
-
- @override_config({"rc_3pid_validation": {"burst_count": 3}})
- def test_ratelimit_by_email(self) -> None:
- """Test that we ratelimit /requestToken for the same email."""
- old_password = "monkey"
- new_password = "kangeroo"
-
- user_id = self.register_user("kermit", old_password)
- self.login("kermit", old_password)
-
- email = "test1@example.com"
-
- # Add a threepid
- self.get_success(
- self.store.user_add_threepid(
- user_id=user_id,
- medium="email",
- address=email,
- validated_at=0,
- added_at=0,
- )
- )
-
- def reset(ip: str) -> None:
- client_secret = "foobar"
- session_id = self._request_token(email, client_secret, ip)
-
- self.assertEqual(len(self.email_attempts), 1)
- link = self._get_link_from_email()
-
- self._validate_token(link)
-
- self._reset_password(new_password, session_id, client_secret)
-
- self.email_attempts.clear()
-
- # We expect to be able to make three requests before getting rate
- # limited.
- #
- # We change IPs to ensure that we're not being ratelimited due to the
- # same IP
- reset("127.0.0.1")
- reset("127.0.0.2")
- reset("127.0.0.3")
-
- with self.assertRaises(HttpResponseException) as cm:
- reset("127.0.0.4")
-
- self.assertEqual(cm.exception.code, 429)
-
- def test_basic_password_reset_canonicalise_email(self) -> None:
- """Test basic password reset flow
- Request password reset with different spelling
- """
- old_password = "monkey"
- new_password = "kangeroo"
-
- user_id = self.register_user("kermit", old_password)
- self.login("kermit", old_password)
-
- email_profile = "test@example.com"
- email_passwort_reset = "TEST@EXAMPLE.COM"
-
- # Add a threepid
- self.get_success(
- self.store.user_add_threepid(
- user_id=user_id,
- medium="email",
- address=email_profile,
- validated_at=0,
- added_at=0,
- )
- )
-
- client_secret = "foobar"
- session_id = self._request_token(email_passwort_reset, client_secret)
-
- self.assertEqual(len(self.email_attempts), 1)
- link = self._get_link_from_email()
-
- self._validate_token(link)
-
- self._reset_password(new_password, session_id, client_secret)
-
- # Assert we can log in with the new password
- self.login("kermit", new_password)
-
- # Assert we can't log in with the old password
- self.attempt_wrong_password_login("kermit", old_password)
-
- def test_cant_reset_password_without_clicking_link(self) -> None:
- """Test that we do actually need to click the link in the email"""
- old_password = "monkey"
- new_password = "kangeroo"
-
- user_id = self.register_user("kermit", old_password)
- self.login("kermit", old_password)
-
- email = "test@example.com"
+# class PasswordResetTestCase(unittest.HomeserverTestCase):
+# servlets = [
+# account.register_servlets,
+# synapse.rest.admin.register_servlets_for_client_rest_resource,
+# register.register_servlets,
+# login.register_servlets,
+# ]
+
+# def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+# config = self.default_config()
+
+# # Email config.
+# config["email"] = {
+# "enable_notifs": False,
+# "template_dir": os.path.abspath(
+# pkg_resources.resource_filename("synapse", "res/templates")
+# ),
+# "smtp_host": "127.0.0.1",
+# "smtp_port": 20,
+# "require_transport_security": False,
+# "smtp_user": None,
+# "smtp_pass": None,
+# "notif_from": "test@example.com",
+# }
+# config["public_baseurl"] = "https://example.com"
+
+# hs = self.setup_test_homeserver(config=config)
- # Add a threepid
- self.get_success(
- self.store.user_add_threepid(
- user_id=user_id,
- medium="email",
- address=email,
- validated_at=0,
- added_at=0,
- )
- )
-
- client_secret = "foobar"
- session_id = self._request_token(email, client_secret)
-
- self.assertEqual(len(self.email_attempts), 1)
-
- # Attempt to reset password without clicking the link
- self._reset_password(new_password, session_id, client_secret, expected_code=401)
-
- # Assert we can log in with the old password
- self.login("kermit", old_password)
-
- # Assert we can't log in with the new password
- self.attempt_wrong_password_login("kermit", new_password)
-
- def test_no_valid_token(self) -> None:
- """Test that we do actually need to request a token and can't just
- make a session up.
- """
- old_password = "monkey"
- new_password = "kangeroo"
-
- user_id = self.register_user("kermit", old_password)
- self.login("kermit", old_password)
-
- email = "test@example.com"
-
- # Add a threepid
- self.get_success(
- self.store.user_add_threepid(
- user_id=user_id,
- medium="email",
- address=email,
- validated_at=0,
- added_at=0,
- )
- )
-
- client_secret = "foobar"
- session_id = "weasle"
+# async def sendmail(
+# reactor: IReactorTCP,
+# smtphost: str,
+# smtpport: int,
+# from_addr: str,
+# to_addr: str,
+# msg_bytes: bytes,
+# *args: Any,
+# **kwargs: Any,
+# ) -> None:
+# self.email_attempts.append(msg_bytes)
- # Attempt to reset password without even requesting an email
- self._reset_password(new_password, session_id, client_secret, expected_code=401)
+# self.email_attempts: List[bytes] = []
+
+# return hs
- # Assert we can log in with the old password
- self.login("kermit", old_password)
+# def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+# self.store = hs.get_datastores().main
+
+# def attempt_wrong_password_login(self, username: str, password: str) -> None:
+# """Attempts to login as the user with the given password, asserting
+# that the attempt *fails*.
+# """
+# body = {"type": "m.login.password", "user": username, "password": password}
+
+# channel = self.make_request("POST", "/_matrix/client/r0/login", body)
+# self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result)
+
+# def test_basic_password_reset(self) -> None:
+# """Test basic password reset flow"""
+# old_password = "monkey"
+# new_password = "kangeroo"
+
+# user_id = self.register_user("kermit", old_password)
+# self.login("kermit", old_password)
- # Assert we can't log in with the new password
- self.attempt_wrong_password_login("kermit", new_password)
- @unittest.override_config({"request_token_inhibit_3pid_errors": True})
- def test_password_reset_bad_email_inhibit_error(self) -> None:
- """Test that triggering a password reset with an email address that isn't bound
- to an account doesn't leak the lack of binding for that address if configured
- that way.
- """
- self.register_user("kermit", "monkey")
- self.login("kermit", "monkey")
+# client_secret = "foobar"
+# session_id = self._request_token(email, client_secret)
- email = "test@example.com"
+# self.assertEqual(len(self.email_attempts), 1)
+# link = self._get_link_from_email()
- client_secret = "foobar"
- session_id = self._request_token(email, client_secret)
+# self._validate_token(link)
- self.assertIsNotNone(session_id)
+# self._reset_password(new_password, session_id, client_secret)
- def test_password_reset_redirection(self) -> None:
- """Test basic password reset flow"""
- old_password = "monkey"
+# # Assert we can log in with the new password
+# self.login("kermit", new_password)
- user_id = self.register_user("kermit", old_password)
- self.login("kermit", old_password)
+# # Assert we can't log in with the old password
+# self.attempt_wrong_password_login("kermit", old_password)
- email = "test@example.com"
+# # Check that the UI Auth information doesn't store the password in the database.
+# #
+# # Note that we don't have the UI Auth session ID, so just pull out the single
+# # row.
+# result = self.get_success(
+# self.store.db_pool.simple_select_one_onecol(
+# "ui_auth_sessions", keyvalues={}, retcol="clientdict"
+# )
+# )
+# client_dict = db_to_json(result)
+# self.assertNotIn("new_password", client_dict)
+
+# @override_config({"rc_3pid_validation": {"burst_count": 3}})
+# def test_ratelimit_by_email(self) -> None:
+# """Test that we ratelimit /requestToken for the same email."""
+# old_password = "monkey"
+# new_password = "kangeroo"
- # Add a threepid
- self.get_success(
- self.store.user_add_threepid(
- user_id=user_id,
- medium="email",
- address=email,
- validated_at=0,
- added_at=0,
- )
- )
+# user_id = self.register_user("kermit", old_password)
+# self.login("kermit", old_password)
+
+
+# def reset(ip: str) -> None:
+# client_secret = "foobar"
+# session_id = self._request_token(email, client_secret, ip)
- client_secret = "foobar"
- next_link = "http://example.com"
- self._request_token(email, client_secret, "127.0.0.1", next_link)
+# self.assertEqual(len(self.email_attempts), 1)
+# link = self._get_link_from_email()
+
+# self._validate_token(link)
+
+# self._reset_password(new_password, session_id, client_secret)
+
+# self.email_attempts.clear()
- self.assertEqual(len(self.email_attempts), 1)
- link = self._get_link_from_email()
+# # We expect to be able to make three requests before getting rate
+# # limited.
+# #
+# # We change IPs to ensure that we're not being ratelimited due to the
+# # same IP
+# reset("127.0.0.1")
+# reset("127.0.0.2")
+# reset("127.0.0.3")
- self._validate_token(link, next_link)
+# with self.assertRaises(HttpResponseException) as cm:
+# reset("127.0.0.4")
- def _request_token(
- self,
- email: str,
- client_secret: str,
- ip: str = "127.0.0.1",
- next_link: Optional[str] = None,
- ) -> str:
- body = {"client_secret": client_secret, "email": email, "send_attempt": 1}
- if next_link is not None:
- body["next_link"] = next_link
- channel = self.make_request(
- "POST",
- b"account/password/email/requestToken",
- body,
- client_ip=ip,
- )
+# self.assertEqual(cm.exception.code, 429)
- if channel.code != 200:
- raise HttpResponseException(
- channel.code,
- channel.result["reason"],
- channel.result["body"],
- )
+# def test_basic_password_reset_canonicalise_email(self) -> None:
+# """Test basic password reset flow
+# Request password reset with different spelling
+# """
+# old_password = "monkey"
+# new_password = "kangeroo"
- return channel.json_body["sid"]
+# user_id = self.register_user("kermit", old_password)
+# self.login("kermit", old_password)
- def _validate_token(self, link: str, next_link: Optional[str] = None) -> None:
- # Remove the host
- path = link.replace("https://example.com", "")
+# email_profile = "test@example.com"
+# email_passwort_reset = "TEST@EXAMPLE.COM"
- # Load the password reset confirmation page
- channel = make_request(
- self.reactor,
- FakeSite(self.submit_token_resource, self.reactor),
- "GET",
- path,
- shorthand=False,
- )
+# # Add a threepid
+# self.get_success(
+# self.store.user_add_threepid(
+# user_id=user_id,
+# medium="email",
+# address=email_profile,
+# validated_at=0,
+# added_at=0,
+# )
+# )
- self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
+# client_secret = "foobar"
+# session_id = self._request_token(email_passwort_reset, client_secret)
- # Now POST to the same endpoint, mimicking the same behaviour as clicking the
- # password reset confirm button
+# self.assertEqual(len(self.email_attempts), 1)
+# link = self._get_link_from_email()
- # Confirm the password reset
- channel = make_request(
- self.reactor,
- FakeSite(self.submit_token_resource, self.reactor),
- "POST",
- path,
- content=b"",
- shorthand=False,
- content_is_form=True,
- )
- self.assertEqual(
- HTTPStatus.OK if next_link is None else HTTPStatus.FOUND,
- channel.code,
- channel.result,
- )
-
- def _get_link_from_email(self) -> str:
- assert self.email_attempts, "No emails have been sent"
+# self._validate_token(link)
- raw_msg = self.email_attempts[-1].decode("UTF-8")
- mail = Parser().parsestr(raw_msg)
+# self._reset_password(new_password, session_id, client_secret)
- text = None
- for part in mail.walk():
- if part.get_content_type() == "text/plain":
- text = part.get_payload(decode=True)
- if text is not None:
- # According to the logic table in `get_payload`, we know that
- # the result of `get_payload` will be `bytes`, but mypy doesn't
- # know this and complains. Thus, we assert the type.
- assert isinstance(text, bytes)
- text = text.decode("UTF-8")
+# # Assert we can log in with the new password
+# self.login("kermit", new_password)
- break
+# # Assert we can't log in with the old password
+# self.attempt_wrong_password_login("kermit", old_password)
- if not text:
- self.fail("Could not find text portion of email to parse")
+# def test_cant_reset_password_without_clicking_link(self) -> None:
+# """Test that we do actually need to click the link in the email"""
+# old_password = "monkey"
+# new_password = "kangeroo"
+
+# user_id = self.register_user("kermit", old_password)
+# self.login("kermit", old_password)
+
+# email = "test@example.com"
+
+# # Add a threepid
+# self.get_success(
+# self.store.user_add_threepid(
+# user_id=user_id,
+# medium="email",
+# address=email,
+# validated_at=0,
+# added_at=0,
+# )
+# )
- # `text` must be a `str`, after being decoded and determined just above
- # to not be `None` or an empty `str`.
- assert isinstance(text, str)
-
- match = re.search(r"https://example.com\S+", text)
- assert match, "Could not find link in email"
-
- return match.group(0)
-
- def _reset_password(
- self,
- new_password: str,
- session_id: str,
- client_secret: str,
- expected_code: int = HTTPStatus.OK,
- ) -> None:
- channel = self.make_request(
- "POST",
- b"account/password",
- {
- "new_password": new_password,
- "auth": {
- "type": LoginType.EMAIL_IDENTITY,
- "threepid_creds": {
- "client_secret": client_secret,
- "sid": session_id,
- },
- },
- },
- )
- self.assertEqual(expected_code, channel.code, channel.result)
+# client_secret = "foobar"
+# session_id = self._request_token(email, client_secret)
+
+# self.assertEqual(len(self.email_attempts), 1)
+
+# # Attempt to reset password without clicking the link
+# self._reset_password(new_password, session_id, client_secret, expected_code=401)
+
+# # Assert we can log in with the old password
+# self.login("kermit", old_password)
+
+# # Assert we can't log in with the new password
+# self.attempt_wrong_password_login("kermit", new_password)
+
+# def test_no_valid_token(self) -> None:
+# """Test that we do actually need to request a token and can't just
+# make a session up.
+# """
+# old_password = "monkey"
+# new_password = "kangeroo"
+
+# user_id = self.register_user("kermit", old_password)
+# self.login("kermit", old_password)
+
+# email = "test@example.com"
+
+# # Add a threepid
+# self.get_success(
+# self.store.user_add_threepid(
+# user_id=user_id,
+# medium="email",
+# address=email,
+# validated_at=0,
+# added_at=0,
+# )
+# )
+
+# client_secret = "foobar"
+# session_id = "weasle"
+
+# # Attempt to reset password without even requesting an email
+# self._reset_password(new_password, session_id, client_secret, expected_code=401)
+
+# # Assert we can log in with the old password
+# self.login("kermit", old_password)
+
+# # Assert we can't log in with the new password
+# self.attempt_wrong_password_login("kermit", new_password)
+
+# @unittest.override_config({"request_token_inhibit_3pid_errors": True})
+# def test_password_reset_bad_email_inhibit_error(self) -> None:
+# """Test that triggering a password reset with an email address that isn't bound
+# to an account doesn't leak the lack of binding for that address if configured
+# that way.
+# """
+# self.register_user("kermit", "monkey")
+# self.login("kermit", "monkey")
+
+# email = "test@example.com"
+
+# client_secret = "foobar"
+# session_id = self._request_token(email, client_secret)
+
+# self.assertIsNotNone(session_id)
+
+# def test_password_reset_redirection(self) -> None:
+# """Test basic password reset flow"""
+# old_password = "monkey"
+
+# user_id = self.register_user("kermit", old_password)
+# self.login("kermit", old_password)
+
+# email = "test@example.com"
+
+# # Add a threepid
+# self.get_success(
+# self.store.user_add_threepid(
+# user_id=user_id,
+# medium="email",
+# address=email,
+# validated_at=0,
+# added_at=0,
+# )
+# )
+
+# client_secret = "foobar"
+# next_link = "http://example.com"
+# self._request_token(email, client_secret, "127.0.0.1", next_link)
+
+# self.assertEqual(len(self.email_attempts), 1)
+# link = self._get_link_from_email()
+
+# self._validate_token(link, next_link)
+
+# def _request_token(
+# self,
+# email: str,
+# client_secret: str,
+# ip: str = "127.0.0.1",
+# next_link: Optional[str] = None,
+# ) -> str:
+# body = {"client_secret": client_secret, "email": email, "send_attempt": 1}
+# if next_link is not None:
+# body["next_link"] = next_link
+# channel = self.make_request(
+# "POST",
+# b"account/password/email/requestToken",
+# body,
+# client_ip=ip,
+# )
+
+# if channel.code != 200:
+# raise HttpResponseException(
+# channel.code,
+# channel.result["reason"],
+# channel.result["body"],
+# )
+
+# return channel.json_body["sid"]
+
+# def _validate_token(self, link: str, next_link: Optional[str] = None) -> None:
+# # Remove the host
+# path = link.replace("https://example.com", "")
+
+# # Load the password reset confirmation page
+# channel = make_request(
+# self.reactor,
+# FakeSite(self.submit_token_resource, self.reactor),
+# "GET",
+# path,
+# shorthand=False,
+# )
+
+# self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
+
+# # Now POST to the same endpoint, mimicking the same behaviour as clicking the
+# # password reset confirm button
+
+# # Confirm the password reset
+# channel = make_request(
+# self.reactor,
+# FakeSite(self.submit_token_resource, self.reactor),
+# "POST",
+# path,
+# content=b"",
+# shorthand=False,
+# content_is_form=True,
+# )
+# self.assertEqual(
+# HTTPStatus.OK if next_link is None else HTTPStatus.FOUND,
+# channel.code,
+# channel.result,
+# )
+
+# def _get_link_from_email(self) -> str:
+# assert self.email_attempts, "No emails have been sent"
+
+# raw_msg = self.email_attempts[-1].decode("UTF-8")
+# mail = Parser().parsestr(raw_msg)
+
+# text = None
+# for part in mail.walk():
+# if part.get_content_type() == "text/plain":
+# text = part.get_payload(decode=True)
+# if text is not None:
+# # According to the logic table in `get_payload`, we know that
+# # the result of `get_payload` will be `bytes`, but mypy doesn't
+# # know this and complains. Thus, we assert the type.
+# assert isinstance(text, bytes)
+# text = text.decode("UTF-8")
+
+# break
+
+# if not text:
+# self.fail("Could not find text portion of email to parse")
+
+# # `text` must be a `str`, after being decoded and determined just above
+# # to not be `None` or an empty `str`.
+# assert isinstance(text, str)
+
+# match = re.search(r"https://example.com\S+", text)
+# assert match, "Could not find link in email"
+
+# return match.group(0)
+
+# def _reset_password(
+# self,
+# new_password: str,
+# session_id: str,
+# client_secret: str,
+# expected_code: int = HTTPStatus.OK,
+# ) -> None:
+# channel = self.make_request(
+# "POST",
+# b"account/password",
+# {
+# "new_password": new_password,
+# "auth": {
+# "type": LoginType.EMAIL_IDENTITY,
+# "threepid_creds": {
+# "client_secret": client_secret,
+# "sid": session_id,
+# },
+# },
+# },
+# )
+# self.assertEqual(expected_code, channel.code, channel.result)
class DeactivateTestCase(unittest.HomeserverTestCase):
@@ -787,503 +760,6 @@ class WhoamiTestCase(unittest.HomeserverTestCase):
return channel.json_body
-class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
- servlets = [
- account.register_servlets,
- login.register_servlets,
- synapse.rest.admin.register_servlets_for_client_rest_resource,
- ]
-
- def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
- config = self.default_config()
-
- # Email config.
- config["email"] = {
- "enable_notifs": False,
- "template_dir": os.path.abspath(
- pkg_resources.resource_filename("synapse", "res/templates")
- ),
- "smtp_host": "127.0.0.1",
- "smtp_port": 20,
- "require_transport_security": False,
- "smtp_user": None,
- "smtp_pass": None,
- "notif_from": "test@example.com",
- }
- config["public_baseurl"] = "https://example.com"
-
- self.hs = self.setup_test_homeserver(config=config)
-
- async def sendmail(
- reactor: IReactorTCP,
- smtphost: str,
- smtpport: int,
- from_addr: str,
- to_addr: str,
- msg_bytes: bytes,
- *args: Any,
- **kwargs: Any,
- ) -> None:
- self.email_attempts.append(msg_bytes)
-
- self.email_attempts: List[bytes] = []
- self.hs.get_send_email_handler()._sendmail = sendmail
-
- return self.hs
-
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.store = hs.get_datastores().main
-
- self.user_id = self.register_user("kermit", "test")
- self.user_id_tok = self.login("kermit", "test")
- self.email = "test@example.com"
- self.url_3pid = b"account/3pid"
-
- def test_add_valid_email(self) -> None:
- self._add_email(self.email, self.email)
-
- def test_add_valid_email_second_time(self) -> None:
- self._add_email(self.email, self.email)
- self._request_token_invalid_email(
- self.email,
- expected_errcode=Codes.THREEPID_IN_USE,
- expected_error="Email is already in use",
- )
-
- def test_add_valid_email_second_time_canonicalise(self) -> None:
- self._add_email(self.email, self.email)
- self._request_token_invalid_email(
- "TEST@EXAMPLE.COM",
- expected_errcode=Codes.THREEPID_IN_USE,
- expected_error="Email is already in use",
- )
-
- def test_add_email_no_at(self) -> None:
- self._request_token_invalid_email(
- "address-without-at.bar",
- expected_errcode=Codes.BAD_JSON,
- expected_error="Unable to parse email address",
- )
-
- def test_add_email_two_at(self) -> None:
- self._request_token_invalid_email(
- "foo@foo@test.bar",
- expected_errcode=Codes.BAD_JSON,
- expected_error="Unable to parse email address",
- )
-
- def test_add_email_bad_format(self) -> None:
- self._request_token_invalid_email(
- "user@bad.example.net@good.example.com",
- expected_errcode=Codes.BAD_JSON,
- expected_error="Unable to parse email address",
- )
-
- def test_add_email_domain_to_lower(self) -> None:
- self._add_email("foo@TEST.BAR", "foo@test.bar")
-
- def test_add_email_domain_with_umlaut(self) -> None:
- self._add_email("foo@Öumlaut.com", "foo@öumlaut.com")
-
- def test_add_email_address_casefold(self) -> None:
- self._add_email("Strauß@Example.com", "strauss@example.com")
-
- def test_address_trim(self) -> None:
- self._add_email(" foo@test.bar ", "foo@test.bar")
-
- @override_config({"rc_3pid_validation": {"burst_count": 3}})
- def test_ratelimit_by_ip(self) -> None:
- """Tests that adding emails is ratelimited by IP"""
-
- # We expect to be able to set three emails before getting ratelimited.
- self._add_email("foo1@test.bar", "foo1@test.bar")
- self._add_email("foo2@test.bar", "foo2@test.bar")
- self._add_email("foo3@test.bar", "foo3@test.bar")
-
- with self.assertRaises(HttpResponseException) as cm:
- self._add_email("foo4@test.bar", "foo4@test.bar")
-
- self.assertEqual(cm.exception.code, 429)
-
- def test_add_email_if_disabled(self) -> None:
- """Test adding email to profile when doing so is disallowed"""
- self.hs.config.registration.enable_3pid_changes = False
-
- client_secret = "foobar"
- channel = self.make_request(
- "POST",
- b"/_matrix/client/unstable/account/3pid/email/requestToken",
- {
- "client_secret": client_secret,
- "email": "test@example.com",
- "send_attempt": 1,
- },
- )
-
- self.assertEqual(
- HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
- )
-
- self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
-
- def test_delete_email(self) -> None:
- """Test deleting an email from profile"""
- # Add a threepid
- self.get_success(
- self.store.user_add_threepid(
- user_id=self.user_id,
- medium="email",
- address=self.email,
- validated_at=0,
- added_at=0,
- )
- )
-
- channel = self.make_request(
- "POST",
- b"account/3pid/delete",
- {"medium": "email", "address": self.email},
- access_token=self.user_id_tok,
- )
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
-
- # Get user
- channel = self.make_request(
- "GET",
- self.url_3pid,
- access_token=self.user_id_tok,
- )
-
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
- self.assertFalse(channel.json_body["threepids"])
-
- def test_delete_email_if_disabled(self) -> None:
- """Test deleting an email from profile when disallowed"""
- self.hs.config.registration.enable_3pid_changes = False
-
- # Add a threepid
- self.get_success(
- self.store.user_add_threepid(
- user_id=self.user_id,
- medium="email",
- address=self.email,
- validated_at=0,
- added_at=0,
- )
- )
-
- channel = self.make_request(
- "POST",
- b"account/3pid/delete",
- {"medium": "email", "address": self.email},
- access_token=self.user_id_tok,
- )
-
- self.assertEqual(
- HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
- )
- self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
-
- # Get user
- channel = self.make_request(
- "GET",
- self.url_3pid,
- access_token=self.user_id_tok,
- )
-
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
- self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
- self.assertEqual(self.email, channel.json_body["threepids"][0]["address"])
-
- def test_cant_add_email_without_clicking_link(self) -> None:
- """Test that we do actually need to click the link in the email"""
- client_secret = "foobar"
- session_id = self._request_token(self.email, client_secret)
-
- self.assertEqual(len(self.email_attempts), 1)
-
- # Attempt to add email without clicking the link
- channel = self.make_request(
- "POST",
- b"/_matrix/client/unstable/account/3pid/add",
- {
- "client_secret": client_secret,
- "sid": session_id,
- "auth": {
- "type": "m.login.password",
- "user": self.user_id,
- "password": "test",
- },
- },
- access_token=self.user_id_tok,
- )
- self.assertEqual(
- HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
- )
- self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"])
-
- # Get user
- channel = self.make_request(
- "GET",
- self.url_3pid,
- access_token=self.user_id_tok,
- )
-
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
- self.assertFalse(channel.json_body["threepids"])
-
- def test_no_valid_token(self) -> None:
- """Test that we do actually need to request a token and can't just
- make a session up.
- """
- client_secret = "foobar"
- session_id = "weasle"
-
- # Attempt to add email without even requesting an email
- channel = self.make_request(
- "POST",
- b"/_matrix/client/unstable/account/3pid/add",
- {
- "client_secret": client_secret,
- "sid": session_id,
- "auth": {
- "type": "m.login.password",
- "user": self.user_id,
- "password": "test",
- },
- },
- access_token=self.user_id_tok,
- )
- self.assertEqual(
- HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
- )
- self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"])
-
- # Get user
- channel = self.make_request(
- "GET",
- self.url_3pid,
- access_token=self.user_id_tok,
- )
-
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
- self.assertFalse(channel.json_body["threepids"])
-
- @override_config({"next_link_domain_whitelist": None})
- def test_next_link(self) -> None:
- """Tests a valid next_link parameter value with no whitelist (good case)"""
- self._request_token(
- "something@example.com",
- "some_secret",
- next_link="https://example.com/a/good/site",
- expect_code=HTTPStatus.OK,
- )
-
- @override_config({"next_link_domain_whitelist": None})
- def test_next_link_exotic_protocol(self) -> None:
- """Tests using a esoteric protocol as a next_link parameter value.
- Someone may be hosting a client on IPFS etc.
- """
- self._request_token(
- "something@example.com",
- "some_secret",
- next_link="some-protocol://abcdefghijklmopqrstuvwxyz",
- expect_code=HTTPStatus.OK,
- )
-
- @override_config({"next_link_domain_whitelist": None})
- def test_next_link_file_uri(self) -> None:
- """Tests next_link parameters cannot be file URI"""
- # Attempt to use a next_link value that points to the local disk
- self._request_token(
- "something@example.com",
- "some_secret",
- next_link="file:///host/path",
- expect_code=HTTPStatus.BAD_REQUEST,
- )
-
- @override_config({"next_link_domain_whitelist": ["example.com", "example.org"]})
- def test_next_link_domain_whitelist(self) -> None:
- """Tests next_link parameters must fit the whitelist if provided"""
-
- # Ensure not providing a next_link parameter still works
- self._request_token(
- "something@example.com",
- "some_secret",
- next_link=None,
- expect_code=HTTPStatus.OK,
- )
-
- self._request_token(
- "something@example.com",
- "some_secret",
- next_link="https://example.com/some/good/page",
- expect_code=HTTPStatus.OK,
- )
-
- self._request_token(
- "something@example.com",
- "some_secret",
- next_link="https://example.org/some/also/good/page",
- expect_code=HTTPStatus.OK,
- )
-
- self._request_token(
- "something@example.com",
- "some_secret",
- next_link="https://bad.example.org/some/bad/page",
- expect_code=HTTPStatus.BAD_REQUEST,
- )
-
- @override_config({"next_link_domain_whitelist": []})
- def test_empty_next_link_domain_whitelist(self) -> None:
- """Tests an empty next_lint_domain_whitelist value, meaning next_link is essentially
- disallowed
- """
- self._request_token(
- "something@example.com",
- "some_secret",
- next_link="https://example.com/a/page",
- expect_code=HTTPStatus.BAD_REQUEST,
- )
-
- def _request_token(
- self,
- email: str,
- client_secret: str,
- next_link: Optional[str] = None,
- expect_code: int = HTTPStatus.OK,
- ) -> Optional[str]:
- """Request a validation token to add an email address to a user's account
-
- Args:
- email: The email address to validate
- client_secret: A secret string
- next_link: A link to redirect the user to after validation
- expect_code: Expected return code of the call
-
- Returns:
- The ID of the new threepid validation session, or None if the response
- did not contain a session ID.
- """
- body = {"client_secret": client_secret, "email": email, "send_attempt": 1}
- if next_link:
- body["next_link"] = next_link
-
- channel = self.make_request(
- "POST",
- b"account/3pid/email/requestToken",
- body,
- )
-
- if channel.code != expect_code:
- raise HttpResponseException(
- channel.code,
- channel.result["reason"],
- channel.result["body"],
- )
-
- return channel.json_body.get("sid")
-
- def _request_token_invalid_email(
- self,
- email: str,
- expected_errcode: str,
- expected_error: str,
- client_secret: str = "foobar",
- ) -> None:
- channel = self.make_request(
- "POST",
- b"account/3pid/email/requestToken",
- {"client_secret": client_secret, "email": email, "send_attempt": 1},
- )
- self.assertEqual(
- HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
- )
- self.assertEqual(expected_errcode, channel.json_body["errcode"])
- self.assertIn(expected_error, channel.json_body["error"])
-
- def _validate_token(self, link: str) -> None:
- # Remove the host
- path = link.replace("https://example.com", "")
-
- channel = self.make_request("GET", path, shorthand=False)
- self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
-
- def _get_link_from_email(self) -> str:
- assert self.email_attempts, "No emails have been sent"
-
- raw_msg = self.email_attempts[-1].decode("UTF-8")
- mail = Parser().parsestr(raw_msg)
-
- text = None
- for part in mail.walk():
- if part.get_content_type() == "text/plain":
- text = part.get_payload(decode=True)
- if text is not None:
- # According to the logic table in `get_payload`, we know that
- # the result of `get_payload` will be `bytes`, but mypy doesn't
- # know this and complains. Thus, we assert the type.
- assert isinstance(text, bytes)
- text = text.decode("UTF-8")
-
- break
-
- if not text:
- self.fail("Could not find text portion of email to parse")
-
- # `text` must be a `str`, after being decoded and determined just above
- # to not be `None` or an empty `str`.
- assert isinstance(text, str)
-
- match = re.search(r"https://example.com\S+", text)
- assert match, "Could not find link in email"
-
- return match.group(0)
-
- def _add_email(self, request_email: str, expected_email: str) -> None:
- """Test adding an email to profile"""
- previous_email_attempts = len(self.email_attempts)
-
- client_secret = "foobar"
- session_id = self._request_token(request_email, client_secret)
-
- self.assertEqual(len(self.email_attempts) - previous_email_attempts, 1)
- link = self._get_link_from_email()
-
- self._validate_token(link)
-
- channel = self.make_request(
- "POST",
- b"/_matrix/client/unstable/account/3pid/add",
- {
- "client_secret": client_secret,
- "sid": session_id,
- "auth": {
- "type": "m.login.password",
- "user": self.user_id,
- "password": "test",
- },
- },
- access_token=self.user_id_tok,
- )
-
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
-
- # Get user
- channel = self.make_request(
- "GET",
- self.url_3pid,
- access_token=self.user_id_tok,
- )
-
- self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
- self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
-
- threepids = {threepid["address"] for threepid in channel.json_body["threepids"]}
- self.assertIn(expected_email, threepids)
-
-
class AccountStatusTestCase(unittest.HomeserverTestCase):
servlets = [
account.register_servlets,
diff --git a/tests/rest/client/test_auth_issuer.py b/tests/rest/client/test_auth_issuer.py
deleted file mode 100644
index 964baeec32..0000000000
--- a/tests/rest/client/test_auth_issuer.py
+++ /dev/null
@@ -1,59 +0,0 @@
-# Copyright 2023 The Matrix.org Foundation C.I.C.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from http import HTTPStatus
-
-from synapse.rest.client import auth_issuer
-
-from tests.unittest import HomeserverTestCase, override_config, skip_unless
-from tests.utils import HAS_AUTHLIB
-
-ISSUER = "https://account.example.com/"
-
-
-class AuthIssuerTestCase(HomeserverTestCase):
- servlets = [
- auth_issuer.register_servlets,
- ]
-
- def test_returns_404_when_msc3861_disabled(self) -> None:
- # Make an unauthenticated request for the discovery info.
- channel = self.make_request(
- "GET",
- "/_matrix/client/unstable/org.matrix.msc2965/auth_issuer",
- )
- self.assertEqual(channel.code, HTTPStatus.NOT_FOUND)
-
- @skip_unless(HAS_AUTHLIB, "requires authlib")
- @override_config(
- {
- "disable_registration": True,
- "experimental_features": {
- "msc3861": {
- "enabled": True,
- "issuer": ISSUER,
- "client_id": "David Lister",
- "client_auth_method": "client_secret_post",
- "client_secret": "Who shot Mister Burns?",
- }
- },
- }
- )
- def test_returns_issuer_when_oidc_enabled(self) -> None:
- # Make an unauthenticated request for the discovery info.
- channel = self.make_request(
- "GET",
- "/_matrix/client/unstable/org.matrix.msc2965/auth_issuer",
- )
- self.assertEqual(channel.code, HTTPStatus.OK)
- self.assertEqual(channel.json_body, {"issuer": ISSUER})
diff --git a/tests/rest/client/test_auth_metadata.py b/tests/rest/client/test_auth_metadata.py
new file mode 100644
index 0000000000..a935533b09
--- /dev/null
+++ b/tests/rest/client/test_auth_metadata.py
@@ -0,0 +1,140 @@
+#
+# This file is licensed under the Affero General Public License (AGPL) version 3.
+#
+# Copyright 2023 The Matrix.org Foundation C.I.C.
+# Copyright (C) 2023-2025 New Vector, Ltd
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# See the GNU Affero General Public License for more details:
+# <https://www.gnu.org/licenses/agpl-3.0.html>.
+#
+# Originally licensed under the Apache License, Version 2.0:
+# <http://www.apache.org/licenses/LICENSE-2.0>.
+#
+# [This file includes modifications made by New Vector Limited]
+#
+from http import HTTPStatus
+from unittest.mock import AsyncMock
+
+from synapse.rest.client import auth_metadata
+
+from tests.unittest import HomeserverTestCase, override_config, skip_unless
+from tests.utils import HAS_AUTHLIB
+
+ISSUER = "https://account.example.com/"
+
+
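+# The auth_issuer endpoint derives its response from the upstream OIDC
+# discovery document and caches the result; the tests below stub out the
+# proxied HTTP client's well-known fetch to verify both behaviours.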
+class AuthIssuerTestCase(HomeserverTestCase):
+ servlets = [
+ auth_metadata.register_servlets,
+ ]
+
+ def test_returns_404_when_msc3861_disabled(self) -> None:
+ # Make an unauthenticated request for the discovery info.
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/org.matrix.msc2965/auth_issuer",
+ )
+ self.assertEqual(channel.code, HTTPStatus.NOT_FOUND)
+
+ @skip_unless(HAS_AUTHLIB, "requires authlib")
+ @override_config(
+ {
+ "disable_registration": True,
+ "experimental_features": {
+ "msc3861": {
+ "enabled": True,
+ "issuer": ISSUER,
+ "client_id": "David Lister",
+ "client_auth_method": "client_secret_post",
+ "client_secret": "Who shot Mister Burns?",
+ }
+ },
+ }
+ )
+ def test_returns_issuer_when_oidc_enabled(self) -> None:
+ # Patch the HTTP client to return the issuer metadata
+ req_mock = AsyncMock(return_value={"issuer": ISSUER})
+ self.hs.get_proxied_http_client().get_json = req_mock # type: ignore[method-assign]
+
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/org.matrix.msc2965/auth_issuer",
+ )
+
+ self.assertEqual(channel.code, HTTPStatus.OK)
+ self.assertEqual(channel.json_body, {"issuer": ISSUER})
+
+ req_mock.assert_called_with(
+ "https://account.example.com/.well-known/openid-configuration"
+ )
+ req_mock.reset_mock()
+
+        # On the second call it should use the cached value
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/org.matrix.msc2965/auth_issuer",
+ )
+
+ self.assertEqual(channel.code, HTTPStatus.OK)
+ self.assertEqual(channel.json_body, {"issuer": ISSUER})
+ req_mock.assert_not_called()
+
+
+class AuthMetadataTestCase(HomeserverTestCase):
+ servlets = [
+ auth_metadata.register_servlets,
+ ]
+
+ def test_returns_404_when_msc3861_disabled(self) -> None:
+ # Make an unauthenticated request for the discovery info.
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/org.matrix.msc2965/auth_metadata",
+ )
+ self.assertEqual(channel.code, HTTPStatus.NOT_FOUND)
+
+ @skip_unless(HAS_AUTHLIB, "requires authlib")
+ @override_config(
+ {
+ "disable_registration": True,
+ "experimental_features": {
+ "msc3861": {
+ "enabled": True,
+ "issuer": ISSUER,
+ "client_id": "David Lister",
+ "client_auth_method": "client_secret_post",
+ "client_secret": "Who shot Mister Burns?",
+ }
+ },
+ }
+ )
+    def test_returns_metadata_when_oidc_enabled(self) -> None:
+        # Patch the HTTP client to return the upstream OIDC metadata
+ req_mock = AsyncMock(
+ return_value={
+ "issuer": ISSUER,
+ "authorization_endpoint": "https://example.com/auth",
+ "token_endpoint": "https://example.com/token",
+ }
+ )
+ self.hs.get_proxied_http_client().get_json = req_mock # type: ignore[method-assign]
+
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/org.matrix.msc2965/auth_metadata",
+ )
+
+ self.assertEqual(channel.code, HTTPStatus.OK)
+ self.assertEqual(
+ channel.json_body,
+ {
+ "issuer": ISSUER,
+ "authorization_endpoint": "https://example.com/auth",
+ "token_endpoint": "https://example.com/token",
+ },
+ )
diff --git a/tests/rest/client/test_capabilities.py b/tests/rest/client/test_capabilities.py
index bbe8ab1a7c..1cfaf4fbd7 100644
--- a/tests/rest/client/test_capabilities.py
+++ b/tests/rest/client/test_capabilities.py
@@ -118,7 +118,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
self.assertTrue(capabilities["m.change_password"]["enabled"])
self.assertTrue(capabilities["m.set_displayname"]["enabled"])
self.assertTrue(capabilities["m.set_avatar_url"]["enabled"])
- self.assertTrue(capabilities["m.3pid_changes"]["enabled"])
+ self.assertFalse(capabilities["m.3pid_changes"]["enabled"])
@override_config({"enable_set_displayname": False})
def test_get_set_displayname_capabilities_displayname_disabled(self) -> None:
@@ -142,56 +142,49 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, HTTPStatus.OK)
self.assertFalse(capabilities["m.set_avatar_url"]["enabled"])
- @override_config({"enable_3pid_changes": False})
- def test_get_change_3pid_capabilities_3pid_disabled(self) -> None:
- """Test if change 3pid is disabled that the server responds it."""
+ @override_config(
+ {
+ "enable_set_displayname": False,
+ "experimental_features": {"msc4133_enabled": True},
+ }
+ )
+ def test_get_set_displayname_capabilities_displayname_disabled_msc4133(
+ self,
+ ) -> None:
+ """Test if set displayname is disabled that the server responds it."""
access_token = self.login(self.localpart, self.password)
channel = self.make_request("GET", self.url, access_token=access_token)
capabilities = channel.json_body["capabilities"]
self.assertEqual(channel.code, HTTPStatus.OK)
- self.assertFalse(capabilities["m.3pid_changes"]["enabled"])
-
- @override_config({"experimental_features": {"msc3244_enabled": False}})
- def test_get_does_not_include_msc3244_fields_when_disabled(self) -> None:
- access_token = self.get_success(
- self.auth_handler.create_access_token_for_user_id(
- self.user, device_id=None, valid_until_ms=None
- )
- )
-
- channel = self.make_request("GET", self.url, access_token=access_token)
- capabilities = channel.json_body["capabilities"]
-
- self.assertEqual(channel.code, 200)
- self.assertNotIn(
- "org.matrix.msc3244.room_capabilities", capabilities["m.room_versions"]
+ self.assertFalse(capabilities["m.set_displayname"]["enabled"])
+ self.assertTrue(capabilities["uk.tcpip.msc4133.profile_fields"]["enabled"])
+ self.assertEqual(
+ capabilities["uk.tcpip.msc4133.profile_fields"]["disallowed"],
+ ["displayname"],
)
- def test_get_does_include_msc3244_fields_when_enabled(self) -> None:
- access_token = self.get_success(
- self.auth_handler.create_access_token_for_user_id(
- self.user, device_id=None, valid_until_ms=None
- )
- )
+ @override_config(
+ {
+ "enable_set_avatar_url": False,
+ "experimental_features": {"msc4133_enabled": True},
+ }
+ )
+ def test_get_set_avatar_url_capabilities_avatar_url_disabled_msc4133(self) -> None:
+ """Test if set avatar_url is disabled that the server responds it."""
+ access_token = self.login(self.localpart, self.password)
channel = self.make_request("GET", self.url, access_token=access_token)
capabilities = channel.json_body["capabilities"]
- self.assertEqual(channel.code, 200)
- for details in capabilities["m.room_versions"][
- "org.matrix.msc3244.room_capabilities"
- ].values():
- if details["preferred"] is not None:
- self.assertTrue(
- details["preferred"] in KNOWN_ROOM_VERSIONS,
- str(details["preferred"]),
- )
-
- self.assertGreater(len(details["support"]), 0)
- for room_version in details["support"]:
- self.assertTrue(room_version in KNOWN_ROOM_VERSIONS, str(room_version))
+ self.assertEqual(channel.code, HTTPStatus.OK)
+ self.assertFalse(capabilities["m.set_avatar_url"]["enabled"])
+ self.assertTrue(capabilities["uk.tcpip.msc4133.profile_fields"]["enabled"])
+ self.assertEqual(
+ capabilities["uk.tcpip.msc4133.profile_fields"]["disallowed"],
+ ["avatar_url"],
+ )
def test_get_get_token_login_fields_when_disabled(self) -> None:
"""By default login via an existing session is disabled."""
diff --git a/tests/rest/client/test_delayed_events.py b/tests/rest/client/test_delayed_events.py
new file mode 100644
index 0000000000..9f9d241f12
--- /dev/null
+++ b/tests/rest/client/test_delayed_events.py
@@ -0,0 +1,610 @@
+#
+# This file is licensed under the Affero General Public License (AGPL) version 3.
+#
+# Copyright (C) 2024 New Vector, Ltd
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# See the GNU Affero General Public License for more details:
+# <https://www.gnu.org/licenses/agpl-3.0.html>.
+#
+
+"""Tests REST events for /delayed_events paths."""
+
+from http import HTTPStatus
+from typing import List
+
+from parameterized import parameterized
+
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.api.errors import Codes
+from synapse.rest import admin
+from synapse.rest.client import delayed_events, login, room, versions
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util import Clock
+
+from tests import unittest
+from tests.unittest import HomeserverTestCase
+
+PATH_PREFIX = "/_matrix/client/unstable/org.matrix.msc4140/delayed_events"
+
+_EVENT_TYPE = "com.example.test"
+
+
+class DelayedEventsUnstableSupportTestCase(HomeserverTestCase):
+ servlets = [versions.register_servlets]
+
+ def test_false_by_default(self) -> None:
+ channel = self.make_request("GET", "/_matrix/client/versions")
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertFalse(channel.json_body["unstable_features"]["org.matrix.msc4140"])
+
+ @unittest.override_config({"max_event_delay_duration": "24h"})
+ def test_true_if_enabled(self) -> None:
+ channel = self.make_request("GET", "/_matrix/client/versions")
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertTrue(channel.json_body["unstable_features"]["org.matrix.msc4140"])
+
+
+class DelayedEventsTestCase(HomeserverTestCase):
+ """Tests getting and managing delayed events."""
+
+ servlets = [
+ admin.register_servlets,
+ delayed_events.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def default_config(self) -> JsonDict:
+ config = super().default_config()
+ config["max_event_delay_duration"] = "24h"
+ return config
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.user1_user_id = self.register_user("user1", "pass")
+ self.user1_access_token = self.login("user1", "pass")
+ self.user2_user_id = self.register_user("user2", "pass")
+ self.user2_access_token = self.login("user2", "pass")
+
+ self.room_id = self.helper.create_room_as(
+ self.user1_user_id,
+ tok=self.user1_access_token,
+ extra_content={
+ "preset": "public_chat",
+ "power_level_content_override": {
+ "events": {
+ _EVENT_TYPE: 0,
+ }
+ },
+ },
+ )
+
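+        # A second user joins so that later tests can verify that state sent by
+        # a different user cancels user1's pending delayed state.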
+ self.helper.join(
+ room=self.room_id, user=self.user2_user_id, tok=self.user2_access_token
+ )
+
+ def test_delayed_events_empty_on_startup(self) -> None:
+ self.assertListEqual([], self._get_delayed_events())
+
+ def test_delayed_state_events_are_sent_on_timeout(self) -> None:
+ state_key = "to_send_on_timeout"
+
+ setter_key = "setter"
+ setter_expected = "on_timeout"
+ channel = self.make_request(
+ "PUT",
+ _get_path_for_delayed_state(self.room_id, _EVENT_TYPE, state_key, 900),
+ {
+ setter_key: setter_expected,
+ },
+ self.user1_access_token,
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
+ events = self._get_delayed_events()
+ self.assertEqual(1, len(events), events)
+ content = self._get_delayed_event_content(events[0])
+ self.assertEqual(setter_expected, content.get(setter_key), content)
+ self.helper.get_state(
+ self.room_id,
+ _EVENT_TYPE,
+ self.user1_access_token,
+ state_key=state_key,
+ expect_code=HTTPStatus.NOT_FOUND,
+ )
+
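+        # The delay is specified in milliseconds (900ms here), while
+        # reactor.advance takes seconds, so advancing by 1 second fires it.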
+ self.reactor.advance(1)
+ self.assertListEqual([], self._get_delayed_events())
+ content = self.helper.get_state(
+ self.room_id,
+ _EVENT_TYPE,
+ self.user1_access_token,
+ state_key=state_key,
+ )
+ self.assertEqual(setter_expected, content.get(setter_key), content)
+
+ @unittest.override_config(
+ {"rc_delayed_event_mgmt": {"per_second": 0.5, "burst_count": 1}}
+ )
+ def test_get_delayed_events_ratelimit(self) -> None:
+ args = ("GET", PATH_PREFIX, b"", self.user1_access_token)
+
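+        # With a burst_count of 1, the first request succeeds and an immediate
+        # second request should be ratelimited.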
+ channel = self.make_request(*args)
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
+
+ channel = self.make_request(*args)
+ self.assertEqual(HTTPStatus.TOO_MANY_REQUESTS, channel.code, channel.result)
+
+        # Add the current user to the ratelimit overrides, exempting them from ratelimiting.
+ self.get_success(
+ self.hs.get_datastores().main.set_ratelimit_for_user(
+ self.user1_user_id, 0, 0
+ )
+ )
+
+ # Test that the request isn't ratelimited anymore.
+ channel = self.make_request(*args)
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
+
+ def test_update_delayed_event_without_id(self) -> None:
+ channel = self.make_request(
+ "POST",
+ f"{PATH_PREFIX}/",
+ access_token=self.user1_access_token,
+ )
+ self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, channel.result)
+
+ def test_update_delayed_event_without_body(self) -> None:
+ channel = self.make_request(
+ "POST",
+ f"{PATH_PREFIX}/abc",
+ access_token=self.user1_access_token,
+ )
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result)
+ self.assertEqual(
+ Codes.NOT_JSON,
+ channel.json_body["errcode"],
+ )
+
+ def test_update_delayed_event_without_action(self) -> None:
+ channel = self.make_request(
+ "POST",
+ f"{PATH_PREFIX}/abc",
+ {},
+ self.user1_access_token,
+ )
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result)
+ self.assertEqual(
+ Codes.MISSING_PARAM,
+ channel.json_body["errcode"],
+ )
+
+ def test_update_delayed_event_with_invalid_action(self) -> None:
+ channel = self.make_request(
+ "POST",
+ f"{PATH_PREFIX}/abc",
+ {"action": "oops"},
+ self.user1_access_token,
+ )
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result)
+ self.assertEqual(
+ Codes.INVALID_PARAM,
+ channel.json_body["errcode"],
+ )
+
+ @parameterized.expand(["cancel", "restart", "send"])
+ def test_update_delayed_event_without_match(self, action: str) -> None:
+ channel = self.make_request(
+ "POST",
+ f"{PATH_PREFIX}/abc",
+ {"action": action},
+ self.user1_access_token,
+ )
+ self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, channel.result)
+
+ def test_cancel_delayed_state_event(self) -> None:
+ state_key = "to_never_send"
+
+ setter_key = "setter"
+ setter_expected = "none"
+ channel = self.make_request(
+ "PUT",
+ _get_path_for_delayed_state(self.room_id, _EVENT_TYPE, state_key, 1500),
+ {
+ setter_key: setter_expected,
+ },
+ self.user1_access_token,
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
+ delay_id = channel.json_body.get("delay_id")
+ self.assertIsNotNone(delay_id)
+
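+        # Advance to 1s into the 1.5s delay: the event should still be pending.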
+ self.reactor.advance(1)
+ events = self._get_delayed_events()
+ self.assertEqual(1, len(events), events)
+ content = self._get_delayed_event_content(events[0])
+ self.assertEqual(setter_expected, content.get(setter_key), content)
+ self.helper.get_state(
+ self.room_id,
+ _EVENT_TYPE,
+ self.user1_access_token,
+ state_key=state_key,
+ expect_code=HTTPStatus.NOT_FOUND,
+ )
+
+ channel = self.make_request(
+ "POST",
+ f"{PATH_PREFIX}/{delay_id}",
+ {"action": "cancel"},
+ self.user1_access_token,
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
+ self.assertListEqual([], self._get_delayed_events())
+
+ self.reactor.advance(1)
+        self.helper.get_state(
+ self.room_id,
+ _EVENT_TYPE,
+ self.user1_access_token,
+ state_key=state_key,
+ expect_code=HTTPStatus.NOT_FOUND,
+ )
+
+ @unittest.override_config(
+ {"rc_delayed_event_mgmt": {"per_second": 0.5, "burst_count": 1}}
+ )
+ def test_cancel_delayed_event_ratelimit(self) -> None:
+ delay_ids = []
+ for _ in range(2):
+ channel = self.make_request(
+ "POST",
+ _get_path_for_delayed_send(self.room_id, _EVENT_TYPE, 100000),
+ {},
+ self.user1_access_token,
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
+ delay_id = channel.json_body.get("delay_id")
+ self.assertIsNotNone(delay_id)
+ delay_ids.append(delay_id)
+
+ channel = self.make_request(
+ "POST",
+ f"{PATH_PREFIX}/{delay_ids.pop(0)}",
+ {"action": "cancel"},
+ self.user1_access_token,
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
+
+ args = (
+ "POST",
+ f"{PATH_PREFIX}/{delay_ids.pop(0)}",
+ {"action": "cancel"},
+ self.user1_access_token,
+ )
+ channel = self.make_request(*args)
+ self.assertEqual(HTTPStatus.TOO_MANY_REQUESTS, channel.code, channel.result)
+
+        # Add the current user to the ratelimit overrides, exempting them from ratelimiting.
+ self.get_success(
+ self.hs.get_datastores().main.set_ratelimit_for_user(
+ self.user1_user_id, 0, 0
+ )
+ )
+
+ # Test that the request isn't ratelimited anymore.
+ channel = self.make_request(*args)
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
+
+ def test_send_delayed_state_event(self) -> None:
+ state_key = "to_send_on_request"
+
+ setter_key = "setter"
+ setter_expected = "on_send"
+ channel = self.make_request(
+ "PUT",
+ _get_path_for_delayed_state(self.room_id, _EVENT_TYPE, state_key, 100000),
+ {
+ setter_key: setter_expected,
+ },
+ self.user1_access_token,
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
+ delay_id = channel.json_body.get("delay_id")
+ self.assertIsNotNone(delay_id)
+
+ self.reactor.advance(1)
+ events = self._get_delayed_events()
+ self.assertEqual(1, len(events), events)
+ content = self._get_delayed_event_content(events[0])
+ self.assertEqual(setter_expected, content.get(setter_key), content)
+ self.helper.get_state(
+ self.room_id,
+ _EVENT_TYPE,
+ self.user1_access_token,
+ state_key=state_key,
+ expect_code=HTTPStatus.NOT_FOUND,
+ )
+
+ channel = self.make_request(
+ "POST",
+ f"{PATH_PREFIX}/{delay_id}",
+ {"action": "send"},
+ self.user1_access_token,
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
+ self.assertListEqual([], self._get_delayed_events())
+ content = self.helper.get_state(
+ self.room_id,
+ _EVENT_TYPE,
+ self.user1_access_token,
+ state_key=state_key,
+ )
+ self.assertEqual(setter_expected, content.get(setter_key), content)
+
+ @unittest.override_config({"rc_message": {"per_second": 3.5, "burst_count": 4}})
+ def test_send_delayed_event_ratelimit(self) -> None:
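+        # Unlike cancel/restart, sending a delayed event early counts against
+        # the ordinary message ratelimit (rc_message), as configured above.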
+ delay_ids = []
+ for _ in range(2):
+ channel = self.make_request(
+ "POST",
+ _get_path_for_delayed_send(self.room_id, _EVENT_TYPE, 100000),
+ {},
+ self.user1_access_token,
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
+ delay_id = channel.json_body.get("delay_id")
+ self.assertIsNotNone(delay_id)
+ delay_ids.append(delay_id)
+
+ channel = self.make_request(
+ "POST",
+ f"{PATH_PREFIX}/{delay_ids.pop(0)}",
+ {"action": "send"},
+ self.user1_access_token,
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
+
+ args = (
+ "POST",
+ f"{PATH_PREFIX}/{delay_ids.pop(0)}",
+ {"action": "send"},
+ self.user1_access_token,
+ )
+ channel = self.make_request(*args)
+ self.assertEqual(HTTPStatus.TOO_MANY_REQUESTS, channel.code, channel.result)
+
+        # Add the current user to the ratelimit overrides, exempting them from ratelimiting.
+ self.get_success(
+ self.hs.get_datastores().main.set_ratelimit_for_user(
+ self.user1_user_id, 0, 0
+ )
+ )
+
+ # Test that the request isn't ratelimited anymore.
+ channel = self.make_request(*args)
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
+
+ def test_restart_delayed_state_event(self) -> None:
+ state_key = "to_send_on_restarted_timeout"
+
+ setter_key = "setter"
+ setter_expected = "on_timeout"
+ channel = self.make_request(
+ "PUT",
+ _get_path_for_delayed_state(self.room_id, _EVENT_TYPE, state_key, 1500),
+ {
+ setter_key: setter_expected,
+ },
+ self.user1_access_token,
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
+ delay_id = channel.json_body.get("delay_id")
+ self.assertIsNotNone(delay_id)
+
+ self.reactor.advance(1)
+ events = self._get_delayed_events()
+ self.assertEqual(1, len(events), events)
+ content = self._get_delayed_event_content(events[0])
+ self.assertEqual(setter_expected, content.get(setter_key), content)
+ self.helper.get_state(
+ self.room_id,
+ _EVENT_TYPE,
+ self.user1_access_token,
+ state_key=state_key,
+ expect_code=HTTPStatus.NOT_FOUND,
+ )
+
+ channel = self.make_request(
+ "POST",
+ f"{PATH_PREFIX}/{delay_id}",
+ {"action": "restart"},
+ self.user1_access_token,
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
+
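+        # The restart resets the 1.5s timer, so after another second the event
+        # should still be pending.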
+ self.reactor.advance(1)
+ events = self._get_delayed_events()
+ self.assertEqual(1, len(events), events)
+ content = self._get_delayed_event_content(events[0])
+ self.assertEqual(setter_expected, content.get(setter_key), content)
+ self.helper.get_state(
+ self.room_id,
+ _EVENT_TYPE,
+ self.user1_access_token,
+ state_key=state_key,
+ expect_code=HTTPStatus.NOT_FOUND,
+ )
+
+ self.reactor.advance(1)
+ self.assertListEqual([], self._get_delayed_events())
+ content = self.helper.get_state(
+ self.room_id,
+ _EVENT_TYPE,
+ self.user1_access_token,
+ state_key=state_key,
+ )
+ self.assertEqual(setter_expected, content.get(setter_key), content)
+
+ @unittest.override_config(
+ {"rc_delayed_event_mgmt": {"per_second": 0.5, "burst_count": 1}}
+ )
+ def test_restart_delayed_event_ratelimit(self) -> None:
+ delay_ids = []
+ for _ in range(2):
+ channel = self.make_request(
+ "POST",
+ _get_path_for_delayed_send(self.room_id, _EVENT_TYPE, 100000),
+ {},
+ self.user1_access_token,
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
+ delay_id = channel.json_body.get("delay_id")
+ self.assertIsNotNone(delay_id)
+ delay_ids.append(delay_id)
+
+ channel = self.make_request(
+ "POST",
+ f"{PATH_PREFIX}/{delay_ids.pop(0)}",
+ {"action": "restart"},
+ self.user1_access_token,
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
+
+ args = (
+ "POST",
+ f"{PATH_PREFIX}/{delay_ids.pop(0)}",
+ {"action": "restart"},
+ self.user1_access_token,
+ )
+ channel = self.make_request(*args)
+ self.assertEqual(HTTPStatus.TOO_MANY_REQUESTS, channel.code, channel.result)
+
+        # Add the current user to the ratelimit overrides, exempting them from ratelimiting.
+ self.get_success(
+ self.hs.get_datastores().main.set_ratelimit_for_user(
+ self.user1_user_id, 0, 0
+ )
+ )
+
+ # Test that the request isn't ratelimited anymore.
+ channel = self.make_request(*args)
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
+
+ def test_delayed_state_is_not_cancelled_by_new_state_from_same_user(
+ self,
+ ) -> None:
+ state_key = "to_not_be_cancelled_by_same_user"
+
+ setter_key = "setter"
+ setter_expected = "on_timeout"
+ channel = self.make_request(
+ "PUT",
+ _get_path_for_delayed_state(self.room_id, _EVENT_TYPE, state_key, 900),
+ {
+ setter_key: setter_expected,
+ },
+ self.user1_access_token,
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
+ events = self._get_delayed_events()
+ self.assertEqual(1, len(events), events)
+
+ self.helper.send_state(
+ self.room_id,
+ _EVENT_TYPE,
+ {
+ setter_key: "manual",
+ },
+ self.user1_access_token,
+ state_key=state_key,
+ )
+ events = self._get_delayed_events()
+ self.assertEqual(1, len(events), events)
+
+ self.reactor.advance(1)
+ content = self.helper.get_state(
+ self.room_id,
+ _EVENT_TYPE,
+ self.user1_access_token,
+ state_key=state_key,
+ )
+ self.assertEqual(setter_expected, content.get(setter_key), content)
+
+ def test_delayed_state_is_cancelled_by_new_state_from_other_user(
+ self,
+ ) -> None:
+ state_key = "to_be_cancelled_by_other_user"
+
+ setter_key = "setter"
+ channel = self.make_request(
+ "PUT",
+ _get_path_for_delayed_state(self.room_id, _EVENT_TYPE, state_key, 900),
+ {
+ setter_key: "on_timeout",
+ },
+ self.user1_access_token,
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
+ events = self._get_delayed_events()
+ self.assertEqual(1, len(events), events)
+
+ setter_expected = "other_user"
+ self.helper.send_state(
+ self.room_id,
+ _EVENT_TYPE,
+ {
+ setter_key: setter_expected,
+ },
+ self.user2_access_token,
+ state_key=state_key,
+ )
+ self.assertListEqual([], self._get_delayed_events())
+
+ self.reactor.advance(1)
+ content = self.helper.get_state(
+ self.room_id,
+ _EVENT_TYPE,
+ self.user1_access_token,
+ state_key=state_key,
+ )
+ self.assertEqual(setter_expected, content.get(setter_key), content)
+
+ def _get_delayed_events(self) -> List[JsonDict]:
+ channel = self.make_request(
+ "GET",
+ PATH_PREFIX,
+ access_token=self.user1_access_token,
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
+
+ key = "delayed_events"
+ self.assertIn(key, channel.json_body)
+
+ events = channel.json_body[key]
+ self.assertIsInstance(events, list)
+
+ return events
+
+ def _get_delayed_event_content(self, event: JsonDict) -> JsonDict:
+ key = "content"
+ self.assertIn(key, event)
+
+ content = event[key]
+ self.assertIsInstance(content, dict)
+
+ return content
+
+
+def _get_path_for_delayed_state(
+ room_id: str, event_type: str, state_key: str, delay_ms: int
+) -> str:
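+    # MSC4140: the delay is requested via a query parameter, in milliseconds.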
+ return f"rooms/{room_id}/state/{event_type}/{state_key}?org.matrix.msc4140.delay={delay_ms}"
+
+
+def _get_path_for_delayed_send(room_id: str, event_type: str, delay_ms: int) -> str:
+ return f"rooms/{room_id}/send/{event_type}?org.matrix.msc4140.delay={delay_ms}"
diff --git a/tests/rest/client/test_devices.py b/tests/rest/client/test_devices.py
index a3ed12a38f..dd3abdebac 100644
--- a/tests/rest/client/test_devices.py
+++ b/tests/rest/client/test_devices.py
@@ -24,6 +24,7 @@ from twisted.internet.defer import ensureDeferred
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.errors import NotFoundError
+from synapse.appservice import ApplicationService
from synapse.rest import admin, devices, sync
from synapse.rest.client import keys, login, register
from synapse.server import HomeServer
@@ -455,3 +456,183 @@ class DehydratedDeviceTestCase(unittest.HomeserverTestCase):
token,
)
self.assertEqual(channel.json_body["device_keys"], {"@mikey:test": {}})
+
+
+class MSC4190AppserviceDevicesTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ register.register_servlets,
+ devices.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+ self.hs = self.setup_test_homeserver()
+
+ # This application service uses the new MSC4190 behaviours
+ self.msc4190_service = ApplicationService(
+ id="msc4190",
+ token="some_token",
+ hs_token="some_token",
+ sender="@as:example.com",
+ namespaces={
+ ApplicationService.NS_USERS: [{"regex": "@.*", "exclusive": False}]
+ },
+ msc4190_device_management=True,
+ )
+ # This application service doesn't use the new MSC4190 behaviours
+ self.pre_msc_service = ApplicationService(
+ id="regular",
+ token="other_token",
+ hs_token="other_token",
+ sender="@as2:example.com",
+ namespaces={
+ ApplicationService.NS_USERS: [{"regex": "@.*", "exclusive": False}]
+ },
+ msc4190_device_management=False,
+ )
+ self.hs.get_datastores().main.services_cache.append(self.msc4190_service)
+ self.hs.get_datastores().main.services_cache.append(self.pre_msc_service)
+ return self.hs
+
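+    # Under MSC4190, an appservice can create devices with PUT and delete them
+    # without user-interactive auth; a non-MSC4190 appservice cannot.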
+ def test_PUT_device(self) -> None:
+ self.register_appservice_user("alice", self.msc4190_service.token)
+ self.register_appservice_user("bob", self.pre_msc_service.token)
+
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/v3/devices?user_id=@alice:test",
+ access_token=self.msc4190_service.token,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+ self.assertEqual(channel.json_body, {"devices": []})
+
+ channel = self.make_request(
+ "PUT",
+ "/_matrix/client/v3/devices/AABBCCDD?user_id=@alice:test",
+ content={"display_name": "Alice's device"},
+ access_token=self.msc4190_service.token,
+ )
+ self.assertEqual(channel.code, 201, channel.json_body)
+
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/v3/devices?user_id=@alice:test",
+ access_token=self.msc4190_service.token,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+ self.assertEqual(len(channel.json_body["devices"]), 1)
+ self.assertEqual(channel.json_body["devices"][0]["device_id"], "AABBCCDD")
+
+        # Doing it a second time should return a 200 instead of a 201
+ channel = self.make_request(
+ "PUT",
+ "/_matrix/client/v3/devices/AABBCCDD?user_id=@alice:test",
+ content={"display_name": "Alice's device"},
+ access_token=self.msc4190_service.token,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # On the regular service, that API should not allow for the
+ # creation of new devices.
+ channel = self.make_request(
+ "PUT",
+ "/_matrix/client/v3/devices/AABBCCDD?user_id=@bob:test",
+ content={"display_name": "Bob's device"},
+ access_token=self.pre_msc_service.token,
+ )
+ self.assertEqual(channel.code, 404, channel.json_body)
+
+ def test_DELETE_device(self) -> None:
+ self.register_appservice_user("alice", self.msc4190_service.token)
+
+ # There should be no device
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/v3/devices?user_id=@alice:test",
+ access_token=self.msc4190_service.token,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+ self.assertEqual(channel.json_body, {"devices": []})
+
+ # Create a device
+ channel = self.make_request(
+ "PUT",
+ "/_matrix/client/v3/devices/AABBCCDD?user_id=@alice:test",
+ content={},
+ access_token=self.msc4190_service.token,
+ )
+ self.assertEqual(channel.code, 201, channel.json_body)
+
+ # There should be one device
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/v3/devices?user_id=@alice:test",
+ access_token=self.msc4190_service.token,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+ self.assertEqual(len(channel.json_body["devices"]), 1)
+
+ # Delete the device. UIA should not be required.
+ channel = self.make_request(
+ "DELETE",
+ "/_matrix/client/v3/devices/AABBCCDD?user_id=@alice:test",
+ access_token=self.msc4190_service.token,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # There should be no device again
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/v3/devices?user_id=@alice:test",
+ access_token=self.msc4190_service.token,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+ self.assertEqual(channel.json_body, {"devices": []})
+
+ def test_POST_delete_devices(self) -> None:
+ self.register_appservice_user("alice", self.msc4190_service.token)
+
+ # There should be no device
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/v3/devices?user_id=@alice:test",
+ access_token=self.msc4190_service.token,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+ self.assertEqual(channel.json_body, {"devices": []})
+
+ # Create a device
+ channel = self.make_request(
+ "PUT",
+ "/_matrix/client/v3/devices/AABBCCDD?user_id=@alice:test",
+ content={},
+ access_token=self.msc4190_service.token,
+ )
+ self.assertEqual(channel.code, 201, channel.json_body)
+
+ # There should be one device
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/v3/devices?user_id=@alice:test",
+ access_token=self.msc4190_service.token,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+ self.assertEqual(len(channel.json_body["devices"]), 1)
+
+ # Delete the device with delete_devices
+ # UIA should not be required.
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/v3/delete_devices?user_id=@alice:test",
+ content={"devices": ["AABBCCDD"]},
+ access_token=self.msc4190_service.token,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # There should be no device again
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/v3/devices?user_id=@alice:test",
+ access_token=self.msc4190_service.token,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+ self.assertEqual(channel.json_body, {"devices": []})
diff --git a/tests/rest/client/test_events.py b/tests/rest/client/test_events.py
index 06f1c1b234..039144fdbe 100644
--- a/tests/rest/client/test_events.py
+++ b/tests/rest/client/test_events.py
@@ -19,7 +19,7 @@
#
#
-""" Tests REST events for /events paths."""
+"""Tests REST events for /events paths."""
from unittest.mock import Mock
diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py
deleted file mode 100644
index 63c2c5923e..0000000000
--- a/tests/rest/client/test_identity.py
+++ /dev/null
@@ -1,67 +0,0 @@
-#
-# This file is licensed under the Affero General Public License (AGPL) version 3.
-#
-# Copyright (C) 2023 New Vector, Ltd
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Affero General Public License as
-# published by the Free Software Foundation, either version 3 of the
-# License, or (at your option) any later version.
-#
-# See the GNU Affero General Public License for more details:
-# <https://www.gnu.org/licenses/agpl-3.0.html>.
-#
-# Originally licensed under the Apache License, Version 2.0:
-# <http://www.apache.org/licenses/LICENSE-2.0>.
-#
-# [This file includes modifications made by New Vector Limited]
-#
-#
-
-from http import HTTPStatus
-
-from twisted.test.proto_helpers import MemoryReactor
-
-import synapse.rest.admin
-from synapse.rest.client import login, room
-from synapse.server import HomeServer
-from synapse.util import Clock
-
-from tests import unittest
-
-
-class IdentityTestCase(unittest.HomeserverTestCase):
- servlets = [
- synapse.rest.admin.register_servlets_for_client_rest_resource,
- room.register_servlets,
- login.register_servlets,
- ]
-
- def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
- config = self.default_config()
- config["enable_3pid_lookup"] = False
- self.hs = self.setup_test_homeserver(config=config)
-
- return self.hs
-
- def test_3pid_lookup_disabled(self) -> None:
- self.hs.config.registration.enable_3pid_lookup = False
-
- self.register_user("kermit", "monkey")
- tok = self.login("kermit", "monkey")
-
- channel = self.make_request(b"POST", "/createRoom", b"{}", access_token=tok)
- self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
- room_id = channel.json_body["room_id"]
-
- request_data = {
- "id_server": "testis",
- "medium": "email",
- "address": "test@example.com",
- "id_access_token": tok,
- }
- request_url = ("/rooms/%s/invite" % (room_id)).encode("ascii")
- channel = self.make_request(
- b"POST", request_url, request_data, access_token=tok
- )
- self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result)
diff --git a/tests/rest/client/test_keys.py b/tests/rest/client/test_keys.py
index 8bbd109092..d9a210b616 100644
--- a/tests/rest/client/test_keys.py
+++ b/tests/rest/client/test_keys.py
@@ -315,9 +315,7 @@ class SigningKeyUploadServletTestCase(unittest.HomeserverTestCase):
"master_key": master_key2,
},
)
- self.assertEqual(
- channel.code, HTTPStatus.NOT_IMPLEMENTED, channel.json_body
- )
+ self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.json_body)
# Pretend that MAS did UIA and allowed us to replace the master key.
channel = self.make_request(
@@ -349,9 +347,7 @@ class SigningKeyUploadServletTestCase(unittest.HomeserverTestCase):
"master_key": master_key3,
},
)
- self.assertEqual(
- channel.code, HTTPStatus.NOT_IMPLEMENTED, channel.json_body
- )
+ self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.json_body)
# Pretend that MAS did UIA and allowed us to replace the master key.
channel = self.make_request(
@@ -376,6 +372,4 @@ class SigningKeyUploadServletTestCase(unittest.HomeserverTestCase):
"master_key": master_key3,
},
)
- self.assertEqual(
- channel.code, HTTPStatus.NOT_IMPLEMENTED, channel.json_body
- )
+ self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.json_body)
diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py
index 2b1e44381b..24e2288ee3 100644
--- a/tests/rest/client/test_login.py
+++ b/tests/rest/client/test_login.py
@@ -27,6 +27,7 @@ from typing import (
Collection,
Dict,
List,
+ Literal,
Optional,
Tuple,
Union,
@@ -35,7 +36,6 @@ from unittest.mock import Mock
from urllib.parse import urlencode
import pymacaroons
-from typing_extensions import Literal
from twisted.test.proto_helpers import MemoryReactor
from twisted.web.resource import Resource
@@ -43,6 +43,7 @@ from twisted.web.resource import Resource
import synapse.rest.admin
from synapse.api.constants import ApprovalNoticeMedium, LoginType
from synapse.api.errors import Codes
+from synapse.api.urls import LoginSSORedirectURIBuilder
from synapse.appservice import ApplicationService
from synapse.http.client import RawHeaders
from synapse.module_api import ModuleApi
@@ -55,7 +56,6 @@ from synapse.util import Clock
from tests import unittest
from tests.handlers.test_oidc import HAS_OIDC
-from tests.handlers.test_saml import has_saml2
from tests.rest.client.utils import TEST_OIDC_CONFIG
from tests.server import FakeChannel
from tests.test_utils.html_parsers import TestHtmlParser
@@ -69,6 +69,10 @@ try:
except ImportError:
HAS_JWT = False
+import logging
+
+logger = logging.getLogger(__name__)
+
# synapse server name: used to populate public_baseurl in some tests
SYNAPSE_SERVER_PUBLIC_HOSTNAME = "synapse"
@@ -77,22 +81,7 @@ SYNAPSE_SERVER_PUBLIC_HOSTNAME = "synapse"
# FakeChannel.isSecure() returns False, so synapse will see the requested uri as
# http://..., so using http in the public_baseurl stops Synapse trying to redirect to
# https://....
-BASE_URL = "http://%s/" % (SYNAPSE_SERVER_PUBLIC_HOSTNAME,)
-
-# CAS server used in some tests
-CAS_SERVER = "https://fake.test"
-
-# just enough to tell pysaml2 where to redirect to
-SAML_SERVER = "https://test.saml.server/idp/sso"
-TEST_SAML_METADATA = """
-<md:EntityDescriptor xmlns:md="urn:oasis:names:tc:SAML:2.0:metadata">
- <md:IDPSSODescriptor protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol">
- <md:SingleSignOnService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" Location="%(SAML_SERVER)s"/>
- </md:IDPSSODescriptor>
-</md:EntityDescriptor>
-""" % {
- "SAML_SERVER": SAML_SERVER,
-}
+PUBLIC_BASEURL = "http://%s/" % (SYNAPSE_SERVER_PUBLIC_HOSTNAME,)
LOGIN_URL = b"/_matrix/client/r0/login"
TEST_URL = b"/_matrix/client/r0/account/whoami"
@@ -109,6 +98,23 @@ ADDITIONAL_LOGIN_FLOWS = [
]
+def get_relative_uri_from_absolute_uri(absolute_uri: str) -> str:
+ """
+ Peels off the path and query string from an absolute URI. Useful when interacting
+    with the `make_request(...)` util function, which expects a relative path
+    instead of a full URI.
+ """
+ parsed_uri = urllib.parse.urlparse(absolute_uri)
+ # Sanity check that we're working with an absolute URI
+    assert parsed_uri.scheme in ("http", "https")
+
+ relative_uri = parsed_uri.path
+ if parsed_uri.query:
+ relative_uri += "?" + parsed_uri.query
+
+ return relative_uri
+
+
class TestSpamChecker:
def __init__(self, config: None, api: ModuleApi):
api.register_spam_checker_callbacks(
@@ -172,7 +178,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.hs = self.setup_test_homeserver()
self.hs.config.registration.enable_registration = True
- self.hs.config.registration.registrations_require_3pid = []
self.hs.config.registration.auto_join_rooms = []
self.hs.config.captcha.enable_registration_captcha = False
@@ -603,7 +608,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
)
-@skip_unless(has_saml2 and HAS_OIDC, "Requires SAML2 and OIDC")
+@skip_unless(HAS_OIDC, "Requires OIDC")
class MultiSSOTestCase(unittest.HomeserverTestCase):
"""Tests for homeservers with multiple SSO providers enabled"""
@@ -614,21 +619,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
def default_config(self) -> Dict[str, Any]:
config = super().default_config()
- config["public_baseurl"] = BASE_URL
-
- config["cas_config"] = {
- "enabled": True,
- "server_url": CAS_SERVER,
- "service_url": "https://matrix.goodserver.com:8448",
- }
-
- config["saml2_config"] = {
- "sp_config": {
- "metadata": {"inline": [TEST_SAML_METADATA]},
- # use the XMLSecurity backend to avoid relying on xmlsec1
- "crypto_backend": "XMLSecurity",
- },
- }
+ config["public_baseurl"] = PUBLIC_BASEURL
# default OIDC provider
config["oidc_config"] = TEST_OIDC_CONFIG
@@ -653,6 +644,9 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
]
return config
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.login_sso_redirect_url_builder = LoginSSORedirectURIBuilder(hs.config)
+
def create_resource_dict(self) -> Dict[str, Resource]:
d = super().create_resource_dict()
d.update(build_synapse_client_resource_tree(self.hs))
@@ -664,7 +658,6 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200, channel.result)
expected_flow_types = [
- "m.login.cas",
"m.login.sso",
"m.login.token",
"m.login.password",
@@ -678,8 +671,6 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
self.assertCountEqual(
flows["m.login.sso"]["identity_providers"],
[
- {"id": "cas", "name": "CAS"},
- {"id": "saml", "name": "SAML"},
{"id": "oidc-idp1", "name": "IDP1"},
{"id": "oidc", "name": "OIDC"},
],
@@ -713,56 +704,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
self.assertEqual(params["redirectUrl"], [TEST_CLIENT_REDIRECT_URL])
returned_idps.append(params["idp"][0])
- self.assertCountEqual(returned_idps, ["cas", "oidc", "oidc-idp1", "saml"])
-
- def test_multi_sso_redirect_to_cas(self) -> None:
- """If CAS is chosen, should redirect to the CAS server"""
-
- channel = self.make_request(
- "GET",
- "/_synapse/client/pick_idp?redirectUrl="
- + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
- + "&idp=cas",
- shorthand=False,
- )
- self.assertEqual(channel.code, 302, channel.result)
- location_headers = channel.headers.getRawHeaders("Location")
- assert location_headers
- cas_uri = location_headers[0]
- cas_uri_path, cas_uri_query = cas_uri.split("?", 1)
-
- # it should redirect us to the login page of the cas server
- self.assertEqual(cas_uri_path, CAS_SERVER + "/login")
-
- # check that the redirectUrl is correctly encoded in the service param - ie, the
- # place that CAS will redirect to
- cas_uri_params = urllib.parse.parse_qs(cas_uri_query)
- service_uri = cas_uri_params["service"][0]
- _, service_uri_query = service_uri.split("?", 1)
- service_uri_params = urllib.parse.parse_qs(service_uri_query)
- self.assertEqual(service_uri_params["redirectUrl"][0], TEST_CLIENT_REDIRECT_URL)
-
- def test_multi_sso_redirect_to_saml(self) -> None:
- """If SAML is chosen, should redirect to the SAML server"""
- channel = self.make_request(
- "GET",
- "/_synapse/client/pick_idp?redirectUrl="
- + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
- + "&idp=saml",
- )
- self.assertEqual(channel.code, 302, channel.result)
- location_headers = channel.headers.getRawHeaders("Location")
- assert location_headers
- saml_uri = location_headers[0]
- saml_uri_path, saml_uri_query = saml_uri.split("?", 1)
-
- # it should redirect us to the login page of the SAML server
- self.assertEqual(saml_uri_path, SAML_SERVER)
-
- # the RelayState is used to carry the client redirect url
- saml_uri_params = urllib.parse.parse_qs(saml_uri_query)
- relay_state_param = saml_uri_params["RelayState"][0]
- self.assertEqual(relay_state_param, TEST_CLIENT_REDIRECT_URL)
+ self.assertCountEqual(returned_idps, ["oidc", "oidc-idp1"])
def test_login_via_oidc(self) -> None:
"""If OIDC is chosen, should redirect to the OIDC auth endpoint"""
@@ -773,13 +715,38 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
# pick the default OIDC provider
channel = self.make_request(
"GET",
- "/_synapse/client/pick_idp?redirectUrl="
- + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
- + "&idp=oidc",
+ f"/_synapse/client/pick_idp?redirectUrl={urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)}&idp=oidc",
)
self.assertEqual(channel.code, 302, channel.result)
location_headers = channel.headers.getRawHeaders("Location")
assert location_headers
+ sso_login_redirect_uri = location_headers[0]
+
+ # it should redirect us to the standard login SSO redirect flow
+ self.assertEqual(
+ sso_login_redirect_uri,
+ self.login_sso_redirect_url_builder.build_login_sso_redirect_uri(
+ idp_id="oidc", client_redirect_url=TEST_CLIENT_REDIRECT_URL
+ ),
+ )
+
+ with fake_oidc_server.patch_homeserver(hs=self.hs):
+ # follow the redirect
+ channel = self.make_request(
+ "GET",
+ # We have to make this relative to be compatible with `make_request(...)`
+ get_relative_uri_from_absolute_uri(sso_login_redirect_uri),
+ # We have to set the Host header to match the `public_baseurl` to avoid
+ # the extra redirect in the `SsoRedirectServlet` in order for the
+ # cookies to be visible.
+ custom_headers=[
+ ("Host", SYNAPSE_SERVER_PUBLIC_HOSTNAME),
+ ],
+ )
+
+ self.assertEqual(channel.code, 302, channel.result)
+ location_headers = channel.headers.getRawHeaders("Location")
+ assert location_headers
oidc_uri = location_headers[0]
oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)
@@ -838,12 +805,38 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
self.assertEqual(chan.json_body["user_id"], "@user1:test")
def test_multi_sso_redirect_to_unknown(self) -> None:
- """An unknown IdP should cause a 400"""
+ """An unknown IdP should cause a 404"""
channel = self.make_request(
"GET",
"/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz",
)
- self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(channel.code, 302, channel.result)
+ location_headers = channel.headers.getRawHeaders("Location")
+ assert location_headers
+ sso_login_redirect_uri = location_headers[0]
+
+ # it should redirect us to the standard login SSO redirect flow
+ self.assertEqual(
+ sso_login_redirect_uri,
+ self.login_sso_redirect_url_builder.build_login_sso_redirect_uri(
+ idp_id="xyz", client_redirect_url="http://x"
+ ),
+ )
+
+ # follow the redirect
+ channel = self.make_request(
+ "GET",
+ # We have to make this relative to be compatible with `make_request(...)`
+ get_relative_uri_from_absolute_uri(sso_login_redirect_uri),
+ # We have to set the Host header to match the `public_baseurl` to avoid
+ # the extra redirect in the `SsoRedirectServlet` in order for the
+ # cookies to be visible.
+ custom_headers=[
+ ("Host", SYNAPSE_SERVER_PUBLIC_HOSTNAME),
+ ],
+ )
+
+ self.assertEqual(channel.code, 404, channel.result)
def test_client_idp_redirect_to_unknown(self) -> None:
"""If the client tries to pick an unknown IdP, return a 404"""
@@ -891,162 +884,12 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
raise ValueError("No %s caveat in macaroon" % (key,))
-class CASTestCase(unittest.HomeserverTestCase):
- servlets = [
- login.register_servlets,
- ]
-
- def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
- self.base_url = "https://matrix.goodserver.com/"
- self.redirect_path = "_synapse/client/login/sso/redirect/confirm"
-
- config = self.default_config()
- config["public_baseurl"] = (
- config.get("public_baseurl") or "https://matrix.goodserver.com:8448"
- )
- config["cas_config"] = {
- "enabled": True,
- "server_url": CAS_SERVER,
- }
-
- cas_user_id = "username"
- self.user_id = "@%s:test" % cas_user_id
-
- async def get_raw(uri: str, args: Any) -> bytes:
- """Return an example response payload from a call to the `/proxyValidate`
- endpoint of a CAS server, copied from
- https://apereo.github.io/cas/5.0.x/protocol/CAS-Protocol-V2-Specification.html#26-proxyvalidate-cas-20
-
- This needs to be returned by an async function (as opposed to set as the
- mock's return value) because the corresponding Synapse code awaits on it.
- """
- return (
- """
- <cas:serviceResponse xmlns:cas='http://www.yale.edu/tp/cas'>
- <cas:authenticationSuccess>
- <cas:user>%s</cas:user>
- <cas:proxyGrantingTicket>PGTIOU-84678-8a9d...</cas:proxyGrantingTicket>
- <cas:proxies>
- <cas:proxy>https://proxy2/pgtUrl</cas:proxy>
- <cas:proxy>https://proxy1/pgtUrl</cas:proxy>
- </cas:proxies>
- </cas:authenticationSuccess>
- </cas:serviceResponse>
- """
- % cas_user_id
- ).encode("utf-8")
-
- mocked_http_client = Mock(spec=["get_raw"])
- mocked_http_client.get_raw.side_effect = get_raw
-
- self.hs = self.setup_test_homeserver(
- config=config,
- proxied_http_client=mocked_http_client,
- )
-
- return self.hs
-
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.deactivate_account_handler = hs.get_deactivate_account_handler()
-
- def test_cas_redirect_confirm(self) -> None:
- """Tests that the SSO login flow serves a confirmation page before redirecting a
- user to the redirect URL.
- """
- base_url = "/_matrix/client/r0/login/cas/ticket?redirectUrl"
- redirect_url = "https://dodgy-site.com/"
-
- url_parts = list(urllib.parse.urlparse(base_url))
- query = dict(urllib.parse.parse_qsl(url_parts[4]))
- query.update({"redirectUrl": redirect_url})
- query.update({"ticket": "ticket"})
- url_parts[4] = urllib.parse.urlencode(query)
- cas_ticket_url = urllib.parse.urlunparse(url_parts)
-
- # Get Synapse to call the fake CAS and serve the template.
- channel = self.make_request("GET", cas_ticket_url)
-
- # Test that the response is HTML.
- self.assertEqual(channel.code, 200, channel.result)
- content_type_header_value = ""
- for header in channel.headers.getRawHeaders("Content-Type", []):
- content_type_header_value = header
-
- self.assertTrue(content_type_header_value.startswith("text/html"))
-
- # Test that the body isn't empty.
- self.assertTrue(len(channel.result["body"]) > 0)
-
- # And that it contains our redirect link
- self.assertIn(redirect_url, channel.result["body"].decode("UTF-8"))
-
- @override_config(
- {
- "sso": {
- "client_whitelist": [
- "https://legit-site.com/",
- "https://other-site.com/",
- ]
- }
- }
- )
- def test_cas_redirect_whitelisted(self) -> None:
- """Tests that the SSO login flow serves a redirect to a whitelisted url"""
- self._test_redirect("https://legit-site.com/")
-
- @override_config({"public_baseurl": "https://example.com"})
- def test_cas_redirect_login_fallback(self) -> None:
- self._test_redirect("https://example.com/_matrix/static/client/login")
-
- def _test_redirect(self, redirect_url: str) -> None:
- """Tests that the SSO login flow serves a redirect for the given redirect URL."""
- cas_ticket_url = (
- "/_matrix/client/r0/login/cas/ticket?redirectUrl=%s&ticket=ticket"
- % (urllib.parse.quote(redirect_url))
- )
-
- # Get Synapse to call the fake CAS and serve the template.
- channel = self.make_request("GET", cas_ticket_url)
-
- self.assertEqual(channel.code, 302)
- location_headers = channel.headers.getRawHeaders("Location")
- assert location_headers
- self.assertEqual(location_headers[0][: len(redirect_url)], redirect_url)
-
- @override_config({"sso": {"client_whitelist": ["https://legit-site.com/"]}})
- def test_deactivated_user(self) -> None:
- """Logging in as a deactivated account should error."""
- redirect_url = "https://legit-site.com/"
-
- # First login (to create the user).
- self._test_redirect(redirect_url)
-
- # Deactivate the account.
- self.get_success(
- self.deactivate_account_handler.deactivate_account(
- self.user_id, False, create_requester(self.user_id)
- )
- )
-
- # Request the CAS ticket.
- cas_ticket_url = (
- "/_matrix/client/r0/login/cas/ticket?redirectUrl=%s&ticket=ticket"
- % (urllib.parse.quote(redirect_url))
- )
-
- # Get Synapse to call the fake CAS and serve the template.
- channel = self.make_request("GET", cas_ticket_url)
-
- # Because the user is deactivated they are served an error template.
- self.assertEqual(channel.code, 403)
- self.assertIn(b"SSO account deactivated", channel.result["body"])
-
-
@skip_unless(HAS_JWT, "requires authlib")
class JWTTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
login.register_servlets,
+ profile.register_servlets,
]
jwt_secret = "secret"
@@ -1133,18 +976,18 @@ class JWTTestCase(unittest.HomeserverTestCase):
channel = self.jwt_login({"sub": "kermit", "iss": "invalid"})
self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
- self.assertEqual(
+ self.assertRegex(
channel.json_body["error"],
- 'JWT validation failed: invalid_claim: Invalid claim "iss"',
+ r"^JWT validation failed: invalid_claim: Invalid claim [\"']iss[\"']$",
)
# Not providing an issuer.
channel = self.jwt_login({"sub": "kermit"})
self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
- self.assertEqual(
+ self.assertRegex(
channel.json_body["error"],
- 'JWT validation failed: missing_claim: Missing "iss" claim',
+ r"^JWT validation failed: missing_claim: Missing [\"']iss[\"'] claim$",
)
def test_login_iss_no_config(self) -> None:
@@ -1165,18 +1008,18 @@ class JWTTestCase(unittest.HomeserverTestCase):
channel = self.jwt_login({"sub": "kermit", "aud": "invalid"})
self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
- self.assertEqual(
+ self.assertRegex(
channel.json_body["error"],
- 'JWT validation failed: invalid_claim: Invalid claim "aud"',
+ r"^JWT validation failed: invalid_claim: Invalid claim [\"']aud[\"']$",
)
# Not providing an audience.
channel = self.jwt_login({"sub": "kermit"})
self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
- self.assertEqual(
+ self.assertRegex(
channel.json_body["error"],
- 'JWT validation failed: missing_claim: Missing "aud" claim',
+ r"^JWT validation failed: missing_claim: Missing [\"']aud[\"'] claim$",
)
def test_login_aud_no_config(self) -> None:
@@ -1184,9 +1027,9 @@ class JWTTestCase(unittest.HomeserverTestCase):
channel = self.jwt_login({"sub": "kermit", "aud": "invalid"})
self.assertEqual(channel.code, 403, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
- self.assertEqual(
+ self.assertRegex(
channel.json_body["error"],
- 'JWT validation failed: invalid_claim: Invalid claim "aud"',
+ r"^JWT validation failed: invalid_claim: Invalid claim [\"']aud[\"']$",
)
def test_login_default_sub(self) -> None:
@@ -1202,6 +1045,30 @@ class JWTTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@frog:test")
+ @override_config(
+ {"jwt_config": {**base_config, "display_name_claim": "display_name"}}
+ )
+ def test_login_custom_display_name(self) -> None:
+ """Test setting a custom display name."""
+ localpart = "pinkie"
+ user_id = f"@{localpart}:test"
+ display_name = "Pinkie Pie"
+
+ # Perform the login, specifying a custom display name.
+ channel = self.jwt_login({"sub": localpart, "display_name": display_name})
+ self.assertEqual(channel.code, 200, msg=channel.result)
+ self.assertEqual(channel.json_body["user_id"], user_id)
+
+ # Fetch the user's display name and check that it was set correctly.
+ access_token = channel.json_body["access_token"]
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/v3/profile/{user_id}/displayname",
+ access_token=access_token,
+ )
+ self.assertEqual(channel.code, 200, msg=channel.result)
+ self.assertEqual(channel.json_body["displayname"], display_name)
+
def test_login_no_token(self) -> None:
params = {"type": "org.matrix.login.jwt"}
channel = self.make_request(b"POST", LOGIN_URL, params)
@@ -1448,7 +1315,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
def default_config(self) -> Dict[str, Any]:
config = super().default_config()
- config["public_baseurl"] = BASE_URL
+ config["public_baseurl"] = PUBLIC_BASEURL
config["oidc_config"] = {}
config["oidc_config"].update(TEST_OIDC_CONFIG)
@@ -1474,7 +1341,6 @@ class UsernamePickerTestCase(HomeserverTestCase):
self,
fake_oidc_server: FakeOidcServer,
displayname: str,
- email: str,
picture: str,
) -> Tuple[str, str]:
# do the start of the login flow
@@ -1483,8 +1349,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
{
"sub": "tester",
"displayname": displayname,
- "picture": picture,
- "email": email,
+ "picture": picture
},
TEST_CLIENT_REDIRECT_URL,
)
@@ -1513,7 +1378,6 @@ class UsernamePickerTestCase(HomeserverTestCase):
session = username_mapping_sessions[session_id]
self.assertEqual(session.remote_user_id, "tester")
self.assertEqual(session.display_name, displayname)
- self.assertEqual(session.emails, [email])
self.assertEqual(session.avatar_url, picture)
self.assertEqual(session.client_redirect_url, TEST_CLIENT_REDIRECT_URL)
@@ -1530,11 +1394,10 @@ class UsernamePickerTestCase(HomeserverTestCase):
mxid = "@bobby:test"
displayname = "Jonny"
- email = "bobby@test.com"
picture = "mxc://test/avatar_url"
picker_url, session_id = self.proceed_to_username_picker_page(
- fake_oidc_server, displayname, email, picture
+ fake_oidc_server, displayname, picture
)
# Now, submit a username to the username picker, which should serve a redirect
@@ -1544,8 +1407,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
{
b"username": b"bobby",
b"use_display_name": b"true",
- b"use_avatar": b"true",
- b"use_email": email,
+ b"use_avatar": b"true"
}
).encode("utf8")
chan = self.make_request(
@@ -1606,12 +1468,6 @@ class UsernamePickerTestCase(HomeserverTestCase):
self.assertIn("mxc://test", channel.json_body["avatar_url"])
self.assertEqual(displayname, channel.json_body["displayname"])
- # ensure the email from the OIDC response has been configured for the user.
- channel = self.make_request(
- "GET", "/account/3pid", access_token=chan.json_body["access_token"]
- )
- self.assertEqual(channel.code, 200, channel.result)
- self.assertEqual(email, channel.json_body["threepids"][0]["address"])
def test_username_picker_dont_use_displayname_avatar_or_email(self) -> None:
"""Test the happy path of a username picker flow without using displayname, avatar or email."""
@@ -1620,12 +1476,11 @@ class UsernamePickerTestCase(HomeserverTestCase):
mxid = "@bobby:test"
displayname = "Jonny"
- email = "bobby@test.com"
picture = "mxc://test/avatar_url"
username = "bobby"
picker_url, session_id = self.proceed_to_username_picker_page(
- fake_oidc_server, displayname, email, picture
+ fake_oidc_server, displayname, picture
)
# Now, submit a username to the username picker, which should serve a redirect
@@ -1696,13 +1551,6 @@ class UsernamePickerTestCase(HomeserverTestCase):
self.assertNotIn("avatar_url", channel.json_body)
self.assertEqual(username, channel.json_body["displayname"])
- # ensure the email from the OIDC response has not been configured for the user.
- channel = self.make_request(
- "GET", "/account/3pid", access_token=chan.json_body["access_token"]
- )
- self.assertEqual(channel.code, 200, channel.result)
- self.assertListEqual([], channel.json_body["threepids"])
-
async def mock_get_file(
url: str,
diff --git a/tests/rest/client/test_login_token_request.py b/tests/rest/client/test_login_token_request.py
index fbacf9d869..99a0fd4fcd 100644
--- a/tests/rest/client/test_login_token_request.py
+++ b/tests/rest/client/test_login_token_request.py
@@ -43,7 +43,6 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.hs = self.setup_test_homeserver()
self.hs.config.registration.enable_registration = True
- self.hs.config.registration.registrations_require_3pid = []
self.hs.config.registration.auto_join_rooms = []
self.hs.config.captcha.enable_registration_captcha = False
diff --git a/tests/rest/client/test_media.py b/tests/rest/client/test_media.py
index 30b6d31d0a..6ee761e44b 100644
--- a/tests/rest/client/test_media.py
+++ b/tests/rest/client/test_media.py
@@ -24,14 +24,13 @@ import json
import os
import re
import shutil
-from typing import Any, BinaryIO, Dict, List, Optional, Sequence, Tuple, Type
+from typing import Any, BinaryIO, ClassVar, Dict, List, Optional, Sequence, Tuple, Type
from unittest.mock import MagicMock, Mock, patch
from urllib import parse
from urllib.parse import quote, urlencode
from parameterized import parameterized, parameterized_class
from PIL import Image as Image
-from typing_extensions import ClassVar
from twisted.internet import defer
from twisted.internet._resolver import HostResolution
@@ -66,6 +65,7 @@ from tests.media.test_media_storage import (
SVG,
TestImage,
empty_file,
+ small_cmyk_jpeg,
small_lossless_webp,
small_png,
small_png_with_transparency,
@@ -137,6 +137,7 @@ class MediaDomainBlockingTests(unittest.HomeserverTestCase):
time_now_ms=clock.time_msec(),
upload_name="test.png",
filesystem_id=file_id,
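+ # The tests pass the file_id as a stand-in for the real sha256 digest.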
+ sha256=file_id,
)
)
self.register_user("user", "password")
@@ -1005,7 +1006,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
data = base64.b64encode(SMALL_PNG)
end_content = (
- b"<html><head>" b'<img src="data:image/png;base64,%s" />' b"</head></html>"
+ b'<html><head><img src="data:image/png;base64,%s" /></head></html>'
) % (data,)
channel = self.make_request(
@@ -1617,6 +1618,63 @@ class MediaConfigTest(unittest.HomeserverTestCase):
)
+class MediaConfigModuleCallbackTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ media.register_servlets,
+ admin.register_servlets,
+ login.register_servlets,
+ ]
+
+ def make_homeserver(
+ self, reactor: ThreadedMemoryReactorClock, clock: Clock
+ ) -> HomeServer:
+ config = self.default_config()
+
+ self.storage_path = self.mktemp()
+ self.media_store_path = self.mktemp()
+ os.mkdir(self.storage_path)
+ os.mkdir(self.media_store_path)
+ config["media_store_path"] = self.media_store_path
+
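+ # Mirror media into an additional file-based storage provider.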
+ provider_config = {
+ "module": "synapse.media.storage_provider.FileStorageProviderBackend",
+ "store_local": True,
+ "store_synchronous": False,
+ "store_remote": True,
+ "config": {"directory": self.storage_path},
+ }
+
+ config["media_storage_providers"] = [provider_config]
+
+ return self.setup_test_homeserver(config=config)
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.user = self.register_user("user", "password")
+ self.tok = self.login("user", "password")
+
+ hs.get_module_api().register_media_repository_callbacks(
+ get_media_config_for_user=self.get_media_config_for_user,
+ )
+
+ async def get_media_config_for_user(
+ self,
+ user_id: str,
+ ) -> Optional[JsonDict]:
+ # We echo back the user_id and set a custom upload size.
+ return {"m.upload.size": 1024, "user_id": user_id}
+
+ def test_media_config(self) -> None:
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/v1/media/config",
+ shorthand=False,
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body["m.upload.size"], 1024)
+ self.assertEqual(channel.json_body["user_id"], self.user)
+
+
class RemoteDownloadLimiterTestCase(unittest.HomeserverTestCase):
servlets = [
media.register_servlets,
@@ -1916,6 +1974,7 @@ class RemoteDownloadLimiterTestCase(unittest.HomeserverTestCase):
test_images = [
small_png,
small_png_with_transparency,
+ small_cmyk_jpeg,
small_lossless_webp,
empty_file,
SVG,
@@ -1957,7 +2016,7 @@ class DownloadAndThumbnailTestCase(unittest.HomeserverTestCase):
"""A mock for MatrixFederationHttpClient.federation_get_file."""
def write_to(
- r: Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]], bytes]]
+ r: Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]], bytes]],
) -> Tuple[int, Dict[bytes, List[bytes]], bytes]:
data, response = r
output_stream.write(data)
@@ -1991,7 +2050,7 @@ class DownloadAndThumbnailTestCase(unittest.HomeserverTestCase):
"""A mock for MatrixFederationHttpClient.get_file."""
def write_to(
- r: Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]
+ r: Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]],
) -> Tuple[int, Dict[bytes, List[bytes]]]:
data, response = r
output_stream.write(data)
@@ -2400,7 +2459,7 @@ class DownloadAndThumbnailTestCase(unittest.HomeserverTestCase):
if expected_body is not None:
self.assertEqual(
- channel.result["body"], expected_body, channel.result["body"]
+ channel.result["body"], expected_body, channel.result["body"].hex()
)
else:
# ensure that the result is at least some valid image
@@ -2592,6 +2651,7 @@ class AuthenticatedMediaTestCase(unittest.HomeserverTestCase):
time_now_ms=self.clock.time_msec(),
upload_name="remote_test.png",
filesystem_id=file_id,
+ sha256=file_id,
)
)
@@ -2675,3 +2735,114 @@ class AuthenticatedMediaTestCase(unittest.HomeserverTestCase):
access_token=self.tok,
)
self.assertEqual(channel10.code, 200)
+
+ def test_authenticated_media_etag(self) -> None:
+ """Test that ETag works correctly with authenticated media over client
+ APIs"""
+
+ # upload some local media with authentication on
+ channel = self.make_request(
+ "POST",
+ "_matrix/media/v3/upload?filename=test_png_upload",
+ SMALL_PNG,
+ self.tok,
+ shorthand=False,
+ content_type=b"image/png",
+ custom_headers=[("Content-Length", str(67))],
+ )
+ self.assertEqual(channel.code, 200)
+ res = channel.json_body.get("content_uri")
+ assert res is not None
+ uri = res.split("mxc://")[1]
+
+ # Check standard media endpoint
+ self._check_caching(f"/download/{uri}")
+
+ # check thumbnails as well
+ params = "?width=32&height=32&method=crop"
+ self._check_caching(f"/thumbnail/{uri}{params}")
+
+ # Inject a piece of remote media.
+ file_id = "abcdefg12345"
+ file_info = FileInfo(server_name="lonelyIsland", file_id=file_id)
+
+ media_storage = self.hs.get_media_repository().media_storage
+
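+ # store_into_file returns an async context manager; since this test is
+ # synchronous, drive __aenter__/__aexit__ by hand.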
+ ctx = media_storage.store_into_file(file_info)
+ (f, fname) = self.get_success(ctx.__aenter__())
+ f.write(SMALL_PNG)
+ self.get_success(ctx.__aexit__(None, None, None))
+
+ # We write the authenticated status when storing media, so this should
+ # pick up the config and authenticate the media.
+ self.get_success(
+ self.store.store_cached_remote_media(
+ origin="lonelyIsland",
+ media_id="52",
+ media_type="image/png",
+ media_length=1,
+ time_now_ms=self.clock.time_msec(),
+ upload_name="remote_test.png",
+ filesystem_id=file_id,
+ sha256=file_id,
+ )
+ )
+
+ # ensure we have thumbnails for the non-dynamic code path
+ if self.extra_config == {"dynamic_thumbnails": False}:
+ self.get_success(
+ self.repo._generate_thumbnails(
+ "lonelyIsland", "52", file_id, "image/png"
+ )
+ )
+
+ self._check_caching("/download/lonelyIsland/52")
+
+ params = "?width=32&height=32&method=crop"
+ self._check_caching(f"/thumbnail/lonelyIsland/52{params}")
+
+ def _check_caching(self, path: str) -> None:
+ """
+ Checks that:
+ 1. fetching the path returns an ETag header
+ 2. refetching with the ETag returns a 304 without a body
+ 3. refetching with the ETag through the unauthenticated endpoint
+ returns a 404
+ """
+
+ # Request media over the authenticated endpoint; it should be found.
+ channel1 = self.make_request(
+ "GET",
+ f"/_matrix/client/v1/media{path}",
+ access_token=self.tok,
+ shorthand=False,
+ )
+ self.assertEqual(channel1.code, 200)
+
+ # Should have a single ETag field
+ etags = channel1.headers.getRawHeaders("ETag")
+ self.assertIsNotNone(etags)
+ assert etags is not None # For mypy
+ self.assertEqual(len(etags), 1)
+ etag = etags[0]
+
+ # Refetching with the etag should result in 304 and empty body.
+ channel2 = self.make_request(
+ "GET",
+ f"/_matrix/client/v1/media{path}",
+ access_token=self.tok,
+ shorthand=False,
+ custom_headers=[("If-None-Match", etag)],
+ )
+ self.assertEqual(channel2.code, 304)
+ self.assertEqual(channel2.is_finished(), True)
+ self.assertNotIn("body", channel2.result)
+
+ # Refetching with the etag but no access token should result in 404.
+ channel3 = self.make_request(
+ "GET",
+ f"/_matrix/media/r0{path}",
+ shorthand=False,
+ custom_headers=[("If-None-Match", etag)],
+ )
+ self.assertEqual(channel3.code, 404)
diff --git a/tests/rest/client/test_models.py b/tests/rest/client/test_models.py
deleted file mode 100644
index f8a56c80ca..0000000000
--- a/tests/rest/client/test_models.py
+++ /dev/null
@@ -1,89 +0,0 @@
-#
-# This file is licensed under the Affero General Public License (AGPL) version 3.
-#
-# Copyright 2022 The Matrix.org Foundation C.I.C.
-# Copyright (C) 2023 New Vector, Ltd
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Affero General Public License as
-# published by the Free Software Foundation, either version 3 of the
-# License, or (at your option) any later version.
-#
-# See the GNU Affero General Public License for more details:
-# <https://www.gnu.org/licenses/agpl-3.0.html>.
-#
-# Originally licensed under the Apache License, Version 2.0:
-# <http://www.apache.org/licenses/LICENSE-2.0>.
-#
-# [This file includes modifications made by New Vector Limited]
-#
-#
-import unittest as stdlib_unittest
-from typing import TYPE_CHECKING
-
-from typing_extensions import Literal
-
-from synapse._pydantic_compat import HAS_PYDANTIC_V2
-from synapse.types.rest.client import EmailRequestTokenBody
-
-if TYPE_CHECKING or HAS_PYDANTIC_V2:
- from pydantic.v1 import BaseModel, ValidationError
-else:
- from pydantic import BaseModel, ValidationError
-
-
-class ThreepidMediumEnumTestCase(stdlib_unittest.TestCase):
- class Model(BaseModel):
- medium: Literal["email", "msisdn"]
-
- def test_accepts_valid_medium_string(self) -> None:
- """Sanity check that Pydantic behaves sensibly with an enum-of-str
-
- This is arguably more of a test of a class that inherits from str and Enum
- simultaneously.
- """
- model = self.Model.parse_obj({"medium": "email"})
- self.assertEqual(model.medium, "email")
-
- def test_rejects_invalid_medium_value(self) -> None:
- with self.assertRaises(ValidationError):
- self.Model.parse_obj({"medium": "interpretive_dance"})
-
- def test_rejects_invalid_medium_type(self) -> None:
- with self.assertRaises(ValidationError):
- self.Model.parse_obj({"medium": 123})
-
-
-class EmailRequestTokenBodyTestCase(stdlib_unittest.TestCase):
- base_request = {
- "client_secret": "hunter2",
- "email": "alice@wonderland.com",
- "send_attempt": 1,
- }
-
- def test_token_required_if_id_server_provided(self) -> None:
- with self.assertRaises(ValidationError):
- EmailRequestTokenBody.parse_obj(
- {
- **self.base_request,
- "id_server": "identity.wonderland.com",
- }
- )
- with self.assertRaises(ValidationError):
- EmailRequestTokenBody.parse_obj(
- {
- **self.base_request,
- "id_server": "identity.wonderland.com",
- "id_access_token": None,
- }
- )
-
- def test_token_typechecked_when_id_server_provided(self) -> None:
- with self.assertRaises(ValidationError):
- EmailRequestTokenBody.parse_obj(
- {
- **self.base_request,
- "id_server": "identity.wonderland.com",
- "id_access_token": 1337,
- }
- )
diff --git a/tests/rest/client/test_owned_state.py b/tests/rest/client/test_owned_state.py
new file mode 100644
index 0000000000..5fb5767676
--- /dev/null
+++ b/tests/rest/client/test_owned_state.py
@@ -0,0 +1,308 @@
+from http import HTTPStatus
+
+from parameterized import parameterized_class
+
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.api.errors import Codes
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
+from synapse.rest import admin
+from synapse.rest.client import login, room
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util import Clock
+
+from tests.unittest import HomeserverTestCase
+
+_STATE_EVENT_TEST_TYPE = "com.example.test"
+
+# To stress-test parsing, include separator & sigil characters
+_STATE_KEY_SUFFIX = "_state_key_suffix:!@#$123"
+
+
+class OwnedStateBase(HomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.creator_user_id = self.register_user("creator", "pass")
+ self.creator_access_token = self.login("creator", "pass")
+ self.user1_user_id = self.register_user("user1", "pass")
+ self.user1_access_token = self.login("user1", "pass")
+
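+ # Create a public room in which any member may send the test state event type.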
+ self.room_id = self.helper.create_room_as(
+ self.creator_user_id,
+ tok=self.creator_access_token,
+ is_public=True,
+ extra_content={
+ "power_level_content_override": {
+ "events": {
+ _STATE_EVENT_TEST_TYPE: 0,
+ },
+ },
+ },
+ )
+
+ self.helper.join(
+ room=self.room_id, user=self.user1_user_id, tok=self.user1_access_token
+ )
+
+
+class WithoutOwnedStateTestCase(OwnedStateBase):
+ def default_config(self) -> JsonDict:
+ config = super().default_config()
+ config["default_room_version"] = RoomVersions.V10.identifier
+ return config
+
+ def test_user_can_set_state_with_own_userid_key(self) -> None:
+ self.helper.send_state(
+ self.room_id,
+ _STATE_EVENT_TEST_TYPE,
+ {},
+ state_key=f"{self.user1_user_id}",
+ tok=self.user1_access_token,
+ expect_code=HTTPStatus.OK,
+ )
+
+ def test_room_creator_cannot_set_state_with_own_suffixed_key(self) -> None:
+ body = self.helper.send_state(
+ self.room_id,
+ _STATE_EVENT_TEST_TYPE,
+ {},
+ state_key=f"{self.creator_user_id}{_STATE_KEY_SUFFIX}",
+ tok=self.creator_access_token,
+ expect_code=HTTPStatus.FORBIDDEN,
+ )
+
+ self.assertEqual(
+ body["errcode"],
+ Codes.FORBIDDEN,
+ body,
+ )
+
+ def test_room_creator_cannot_set_state_with_other_userid_key(self) -> None:
+ body = self.helper.send_state(
+ self.room_id,
+ _STATE_EVENT_TEST_TYPE,
+ {},
+ state_key=f"{self.user1_user_id}",
+ tok=self.creator_access_token,
+ expect_code=HTTPStatus.FORBIDDEN,
+ )
+
+ self.assertEqual(
+ body["errcode"],
+ Codes.FORBIDDEN,
+ body,
+ )
+
+ def test_room_creator_cannot_set_state_with_other_suffixed_key(self) -> None:
+ body = self.helper.send_state(
+ self.room_id,
+ _STATE_EVENT_TEST_TYPE,
+ {},
+ state_key=f"{self.user1_user_id}{_STATE_KEY_SUFFIX}",
+ tok=self.creator_access_token,
+ expect_code=HTTPStatus.FORBIDDEN,
+ )
+
+ self.assertEqual(
+ body["errcode"],
+ Codes.FORBIDDEN,
+ body,
+ )
+
+ def test_room_creator_cannot_set_state_with_nonmember_userid_key(self) -> None:
+ body = self.helper.send_state(
+ self.room_id,
+ _STATE_EVENT_TEST_TYPE,
+ {},
+ state_key="@notinroom:hs2",
+ tok=self.creator_access_token,
+ expect_code=HTTPStatus.FORBIDDEN,
+ )
+
+ self.assertEqual(
+ body["errcode"],
+ Codes.FORBIDDEN,
+ body,
+ )
+
+ def test_room_creator_cannot_set_state_with_malformed_userid_key(self) -> None:
+ body = self.helper.send_state(
+ self.room_id,
+ _STATE_EVENT_TEST_TYPE,
+ {},
+ state_key="@oops",
+ tok=self.creator_access_token,
+ expect_code=HTTPStatus.FORBIDDEN,
+ )
+
+ self.assertEqual(
+ body["errcode"],
+ Codes.FORBIDDEN,
+ body,
+ )
+
+
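+# Run this suite once per known room version that enables MSC3757 (owned state).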
+@parameterized_class(
+ ("room_version",),
+ [(i,) for i, v in KNOWN_ROOM_VERSIONS.items() if v.msc3757_enabled],
+)
+class MSC3757OwnedStateTestCase(OwnedStateBase):
+ room_version: str
+
+ def default_config(self) -> JsonDict:
+ config = super().default_config()
+ config["default_room_version"] = self.room_version
+ return config
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ super().prepare(reactor, clock, hs)
+
+ self.user2_user_id = self.register_user("user2", "pass")
+ self.user2_access_token = self.login("user2", "pass")
+
+ self.helper.join(
+ room=self.room_id, user=self.user2_user_id, tok=self.user2_access_token
+ )
+
+ def test_user_can_set_state_with_own_suffixed_key(self) -> None:
+ self.helper.send_state(
+ self.room_id,
+ _STATE_EVENT_TEST_TYPE,
+ {},
+ state_key=f"{self.user1_user_id}{_STATE_KEY_SUFFIX}",
+ tok=self.user1_access_token,
+ expect_code=HTTPStatus.OK,
+ )
+
+ def test_room_creator_can_set_state_with_other_userid_key(self) -> None:
+ self.helper.send_state(
+ self.room_id,
+ _STATE_EVENT_TEST_TYPE,
+ {},
+ state_key=f"{self.user1_user_id}",
+ tok=self.creator_access_token,
+ expect_code=HTTPStatus.OK,
+ )
+
+ def test_room_creator_can_set_state_with_other_suffixed_key(self) -> None:
+ self.helper.send_state(
+ self.room_id,
+ _STATE_EVENT_TEST_TYPE,
+ {},
+ state_key=f"{self.user1_user_id}{_STATE_KEY_SUFFIX}",
+ tok=self.creator_access_token,
+ expect_code=HTTPStatus.OK,
+ )
+
+ def test_user_cannot_set_state_with_other_userid_key(self) -> None:
+ body = self.helper.send_state(
+ self.room_id,
+ _STATE_EVENT_TEST_TYPE,
+ {},
+ state_key=f"{self.user2_user_id}",
+ tok=self.user1_access_token,
+ expect_code=HTTPStatus.FORBIDDEN,
+ )
+
+ self.assertEqual(
+ body["errcode"],
+ Codes.FORBIDDEN,
+ body,
+ )
+
+ def test_user_cannot_set_state_with_other_suffixed_key(self) -> None:
+ body = self.helper.send_state(
+ self.room_id,
+ _STATE_EVENT_TEST_TYPE,
+ {},
+ state_key=f"{self.user2_user_id}{_STATE_KEY_SUFFIX}",
+ tok=self.user1_access_token,
+ expect_code=HTTPStatus.FORBIDDEN,
+ )
+
+ self.assertEqual(
+ body["errcode"],
+ Codes.FORBIDDEN,
+ body,
+ )
+
+ def test_user_cannot_set_state_with_unseparated_suffixed_key(self) -> None:
+ body = self.helper.send_state(
+ self.room_id,
+ _STATE_EVENT_TEST_TYPE,
+ {},
+ state_key=f"{self.user1_user_id}{_STATE_KEY_SUFFIX[1:]}",
+ tok=self.user1_access_token,
+ expect_code=HTTPStatus.FORBIDDEN,
+ )
+
+ self.assertEqual(
+ body["errcode"],
+ Codes.FORBIDDEN,
+ body,
+ )
+
+ def test_user_cannot_set_state_with_misplaced_userid_in_key(self) -> None:
+ body = self.helper.send_state(
+ self.room_id,
+ _STATE_EVENT_TEST_TYPE,
+ {},
+ # Still put the @ at the start of the state key; without it there is no write protection at all.
+ state_key=f"@prefix_{self.user1_user_id}{_STATE_KEY_SUFFIX}",
+ tok=self.user1_access_token,
+ expect_code=HTTPStatus.FORBIDDEN,
+ )
+
+ self.assertEqual(
+ body["errcode"],
+ Codes.FORBIDDEN,
+ body,
+ )
+
+ def test_room_creator_can_set_state_with_nonmember_userid_key(self) -> None:
+ self.helper.send_state(
+ self.room_id,
+ _STATE_EVENT_TEST_TYPE,
+ {},
+ state_key="@notinroom:hs2",
+ tok=self.creator_access_token,
+ expect_code=HTTPStatus.OK,
+ )
+
+ def test_room_creator_cannot_set_state_with_malformed_userid_key(self) -> None:
+ body = self.helper.send_state(
+ self.room_id,
+ _STATE_EVENT_TEST_TYPE,
+ {},
+ state_key="@oops",
+ tok=self.creator_access_token,
+ expect_code=HTTPStatus.BAD_REQUEST,
+ )
+
+ self.assertEqual(
+ body["errcode"],
+ Codes.BAD_JSON,
+ body,
+ )
+
+ def test_room_creator_cannot_set_state_with_improperly_suffixed_key(self) -> None:
+ body = self.helper.send_state(
+ self.room_id,
+ _STATE_EVENT_TEST_TYPE,
+ {},
+ state_key=f"{self.creator_user_id}@{_STATE_KEY_SUFFIX[1:]}",
+ tok=self.creator_access_token,
+ expect_code=HTTPStatus.BAD_REQUEST,
+ )
+
+ self.assertEqual(
+ body["errcode"],
+ Codes.BAD_JSON,
+ body,
+ )
diff --git a/tests/rest/client/test_presence.py b/tests/rest/client/test_presence.py
index 5ced8319e1..6b9c70974a 100644
--- a/tests/rest/client/test_presence.py
+++ b/tests/rest/client/test_presence.py
@@ -29,6 +29,7 @@ from synapse.types import UserID
from synapse.util import Clock
from tests import unittest
+from tests.unittest import override_config
class PresenceTestCase(unittest.HomeserverTestCase):
@@ -95,3 +96,54 @@ class PresenceTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, HTTPStatus.OK)
self.assertEqual(self.presence_handler.set_state.call_count, 0)
+
+ @override_config(
+ {"rc_presence": {"per_user": {"per_second": 0.1, "burst_count": 1}}}
+ )
+ def test_put_presence_over_ratelimit(self) -> None:
+ """
+ Multiple PUTs to the status endpoint without sufficient delay will be rate limited.
+ """
+ self.hs.config.server.presence_enabled = True
+
+ body = {"presence": "here", "status_msg": "beep boop"}
+ channel = self.make_request(
+ "PUT", "/presence/%s/status" % (self.user_id,), body
+ )
+
+ self.assertEqual(channel.code, HTTPStatus.OK)
+
+ body = {"presence": "here", "status_msg": "beep boop"}
+ channel = self.make_request(
+ "PUT", "/presence/%s/status" % (self.user_id,), body
+ )
+
+ self.assertEqual(channel.code, HTTPStatus.TOO_MANY_REQUESTS)
+ self.assertEqual(self.presence_handler.set_state.call_count, 1)
+
+ @override_config(
+ {"rc_presence": {"per_user": {"per_second": 0.1, "burst_count": 1}}}
+ )
+ def test_put_presence_within_ratelimit(self) -> None:
+ """
+ Multiple PUTs to the status endpoint with sufficient delay should all call set_state.
+ """
+ self.hs.config.server.presence_enabled = True
+
+ body = {"presence": "here", "status_msg": "beep boop"}
+ channel = self.make_request(
+ "PUT", "/presence/%s/status" % (self.user_id,), body
+ )
+
+ self.assertEqual(channel.code, HTTPStatus.OK)
+
+ # Advance time a sufficient amount to avoid rate limiting.
+ self.reactor.advance(30)
+
+ body = {"presence": "here", "status_msg": "beep boop"}
+ channel = self.make_request(
+ "PUT", "/presence/%s/status" % (self.user_id,), body
+ )
+
+ self.assertEqual(channel.code, HTTPStatus.OK)
+ self.assertEqual(self.presence_handler.set_state.call_count, 2)
diff --git a/tests/rest/client/test_profile.py b/tests/rest/client/test_profile.py
index f98f3f77aa..708402b792 100644
--- a/tests/rest/client/test_profile.py
+++ b/tests/rest/client/test_profile.py
@@ -20,20 +20,25 @@
#
"""Tests REST events for /profile paths."""
+
import urllib.parse
from http import HTTPStatus
from typing import Any, Dict, Optional
+from canonicaljson import encode_canonical_json
+
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.errors import Codes
from synapse.rest import admin
from synapse.rest.client import login, profile, room
from synapse.server import HomeServer
+from synapse.storage.databases.main.profile import MAX_PROFILE_SIZE
from synapse.types import UserID
from synapse.util import Clock
from tests import unittest
+from tests.utils import USE_POSTGRES_FOR_TESTS
class ProfileTestCase(unittest.HomeserverTestCase):
@@ -479,6 +484,298 @@ class ProfileTestCase(unittest.HomeserverTestCase):
# The client requested ?propagate=true, so it should have happened.
self.assertEqual(channel.json_body.get(prop), "http://my.server/pic.gif")
+ @unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
+ def test_get_missing_custom_field(self) -> None:
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field",
+ )
+ self.assertEqual(channel.code, HTTPStatus.NOT_FOUND, channel.result)
+ self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+ @unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
+ def test_get_missing_custom_field_invalid_field_name(self) -> None:
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/[custom_field]",
+ )
+ self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ @unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
+ def test_get_custom_field_rejects_bad_username(self) -> None:
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{urllib.parse.quote('@alice:')}/custom_field",
+ )
+ self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ @unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
+ def test_set_custom_field(self) -> None:
+ channel = self.make_request(
+ "PUT",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field",
+ content={"custom_field": "test"},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field",
+ )
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+ self.assertEqual(channel.json_body, {"custom_field": "test"})
+
+ # Overwriting the field should work.
+ channel = self.make_request(
+ "PUT",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field",
+ content={"custom_field": "new_Value"},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field",
+ )
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+ self.assertEqual(channel.json_body, {"custom_field": "new_Value"})
+
+ # Deleting the field should work.
+ channel = self.make_request(
+ "DELETE",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field",
+ content={},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field",
+ )
+ self.assertEqual(channel.code, HTTPStatus.NOT_FOUND, channel.result)
+ self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+ @unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
+ def test_non_string(self) -> None:
+ """Non-string fields are supported for custom fields."""
+ fields = {
+ "bool_field": True,
+ "array_field": ["test"],
+ "object_field": {"test": "test"},
+ "numeric_field": 1,
+ "null_field": None,
+ }
+
+ for key, value in fields.items():
+ channel = self.make_request(
+ "PUT",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/{key}",
+ content={key: value},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/v3/profile/{self.owner}",
+ )
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+ self.assertEqual(channel.json_body, {"displayname": "owner", **fields})
+
+ # Check getting individual fields works.
+ for key, value in fields.items():
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/{key}",
+ )
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+ self.assertEqual(channel.json_body, {key: value})
+
+ @unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
+ def test_set_custom_field_noauth(self) -> None:
+ channel = self.make_request(
+ "PUT",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field",
+ content={"custom_field": "test"},
+ )
+ self.assertEqual(channel.code, 401, channel.result)
+ self.assertEqual(channel.json_body["errcode"], Codes.MISSING_TOKEN)
+
+ @unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
+ def test_set_custom_field_size(self) -> None:
+ """
+ Attempts to set a custom field with an invalid key (empty, too long, or not matching the body) should get a 400 error.
+ """
+ # Key is missing.
+ channel = self.make_request(
+ "PUT",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/",
+ content={"": "test"},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ # Single key is too large.
+ key = "c" * 500
+ channel = self.make_request(
+ "PUT",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/{key}",
+ content={key: "test"},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(channel.json_body["errcode"], Codes.KEY_TOO_LARGE)
+
+ channel = self.make_request(
+ "DELETE",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/{key}",
+ content={key: "test"},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(channel.json_body["errcode"], Codes.KEY_TOO_LARGE)
+
+ # Key doesn't match body.
+ channel = self.make_request(
+ "PUT",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field",
+ content={"diff_key": "test"},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(channel.json_body["errcode"], Codes.MISSING_PARAM)
+
+ @unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
+ def test_set_custom_field_profile_too_long(self) -> None:
+ """
+ Attempts to set a custom field that would push the overall profile too large.
+ """
+ # Get right to the boundary:
+ # len("displayname") + len("owner") + 5 = 21 for the displayname
+ # 1 + 65498 + 5 for key "a" = 65504
+ # 2 braces, 1 comma
+ # 3 + 21 + 65504 = 65528 < 65536.
+ key = "a"
+ channel = self.make_request(
+ "PUT",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/{key}",
+ content={key: "a" * 65498},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ # Get the entire profile.
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/v3/profile/{self.owner}",
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+ canonical_json = encode_canonical_json(channel.json_body)
+ # 6 is the minimum bytes to store a value: 4 quotes, 1 colon, 1 comma, an empty key.
+ # Be one below that so we can prove we're at the boundary.
+ self.assertEqual(len(canonical_json), MAX_PROFILE_SIZE - 8)
+
+ # Postgres stores JSONB with whitespace, while SQLite doesn't.
+ if USE_POSTGRES_FOR_TESTS:
+ ADDITIONAL_CHARS = 0
+ else:
+ ADDITIONAL_CHARS = 1
+
+ # The next one should fail; note the value has a (JSON) length of 2.
+ key = "b"
+ channel = self.make_request(
+ "PUT",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/{key}",
+ content={key: "1" + "a" * ADDITIONAL_CHARS},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(channel.json_body["errcode"], Codes.PROFILE_TOO_LARGE)
+
+ # Setting an avatar or (longer) display name should not work.
+ channel = self.make_request(
+ "PUT",
+ f"/profile/{self.owner}/displayname",
+ content={"displayname": "owner12345678" + "a" * ADDITIONAL_CHARS},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(channel.json_body["errcode"], Codes.PROFILE_TOO_LARGE)
+
+ channel = self.make_request(
+ "PUT",
+ f"/profile/{self.owner}/avatar_url",
+ content={"avatar_url": "mxc://foo/bar"},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(channel.json_body["errcode"], Codes.PROFILE_TOO_LARGE)
+
+ # Removing a single byte should work.
+ key = "b"
+ channel = self.make_request(
+ "PUT",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/{key}",
+ content={key: "" + "a" * ADDITIONAL_CHARS},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ # Finally, setting a field that already exists to a value that is <= in length should work.
+ key = "a"
+ channel = self.make_request(
+ "PUT",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/{key}",
+ content={key: ""},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ @unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
+ def test_set_custom_field_displayname(self) -> None:
+ channel = self.make_request(
+ "PUT",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/displayname",
+ content={"displayname": "test"},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ displayname = self._get_displayname()
+ self.assertEqual(displayname, "test")
+
+ @unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
+ def test_set_custom_field_avatar_url(self) -> None:
+ channel = self.make_request(
+ "PUT",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/avatar_url",
+ content={"avatar_url": "mxc://test/good"},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ avatar_url = self._get_avatar_url()
+ self.assertEqual(avatar_url, "mxc://test/good")
+
+ @unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
+ def test_set_custom_field_other(self) -> None:
+ """Setting someone else's profile field should fail"""
+ channel = self.make_request(
+ "PUT",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.other}/custom_field",
+ content={"custom_field": "test"},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 403, channel.result)
+ self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN)
+
def _setup_local_files(self, names_and_props: Dict[str, Dict[str, Any]]) -> None:
"""Stores metadata about files in the database.
diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py
index 694f143eff..d40efdfe1d 100644
--- a/tests/rest/client/test_register.py
+++ b/tests/rest/client/test_register.py
@@ -64,7 +64,6 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self, reactor: ThreadedMemoryReactorClock, clock: Clock
) -> HomeServer:
hs = super().make_homeserver(reactor, clock)
- hs.get_send_email_handler()._sendmail = AsyncMock()
return hs
def test_POST_appservice_registration_valid(self) -> None:
@@ -120,6 +119,34 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 401, msg=channel.result)
+ def test_POST_appservice_msc4190_enabled(self) -> None:
+ # With MSC4190 enabled, the registration should *not* return an access token
+ user_id = "@as_user_kermit:test"
+ as_token = "i_am_an_app_service"
+
+ appservice = ApplicationService(
+ as_token,
+ id="1234",
+ namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
+ sender="@as:test",
+ msc4190_device_management=True,
+ )
+
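+ # Register the appservice directly in the datastore's in-memory cache.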
+ self.hs.get_datastores().main.services_cache.append(appservice)
+ request_data = {
+ "username": "as_user_kermit",
+ "type": APP_SERVICE_REGISTRATION_TYPE,
+ }
+
+ channel = self.make_request(
+ b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
+ )
+
+ self.assertEqual(channel.code, 200, msg=channel.result)
+ det_data = {"user_id": user_id, "home_server": self.hs.hostname}
+ self.assertLessEqual(det_data.items(), channel.json_body.items())
+ self.assertNotIn("access_token", channel.json_body)
+
def test_POST_bad_password(self) -> None:
request_data = {"username": "kermit", "password": 666}
channel = self.make_request(b"POST", self.url, request_data)
@@ -593,155 +620,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
# with the stock config, we only expect the dummy flow
self.assertCountEqual([["m.login.dummy"]], (f["stages"] for f in flows))
-
- @unittest.override_config(
- {
- "public_baseurl": "https://test_server",
- "enable_registration_captcha": True,
- "user_consent": {
- "version": "1",
- "template_dir": "/",
- "require_at_registration": True,
- },
- "account_threepid_delegates": {
- "msisdn": "https://id_server",
- },
- "email": {"notif_from": "Synapse <synapse@example.com>"},
- }
- )
- def test_advertised_flows_captcha_and_terms_and_3pids(self) -> None:
- channel = self.make_request(b"POST", self.url, b"{}")
- self.assertEqual(channel.code, 401, msg=channel.result)
- flows = channel.json_body["flows"]
-
- self.assertCountEqual(
- [
- ["m.login.recaptcha", "m.login.terms", "m.login.dummy"],
- ["m.login.recaptcha", "m.login.terms", "m.login.email.identity"],
- ["m.login.recaptcha", "m.login.terms", "m.login.msisdn"],
- [
- "m.login.recaptcha",
- "m.login.terms",
- "m.login.msisdn",
- "m.login.email.identity",
- ],
- ],
- (f["stages"] for f in flows),
- )
-
- @unittest.override_config(
- {
- "public_baseurl": "https://test_server",
- "registrations_require_3pid": ["email"],
- "disable_msisdn_registration": True,
- "email": {
- "smtp_host": "mail_server",
- "smtp_port": 2525,
- "notif_from": "sender@host",
- },
- }
- )
- def test_advertised_flows_no_msisdn_email_required(self) -> None:
- channel = self.make_request(b"POST", self.url, b"{}")
- self.assertEqual(channel.code, 401, msg=channel.result)
- flows = channel.json_body["flows"]
-
- # with the stock config, we expect all four combinations of 3pid
- self.assertCountEqual(
- [["m.login.email.identity"]], (f["stages"] for f in flows)
- )
-
- @unittest.override_config(
- {
- "request_token_inhibit_3pid_errors": True,
- "public_baseurl": "https://test_server",
- "email": {
- "smtp_host": "mail_server",
- "smtp_port": 2525,
- "notif_from": "sender@host",
- },
- }
- )
- def test_request_token_existing_email_inhibit_error(self) -> None:
- """Test that requesting a token via this endpoint doesn't leak existing
- associations if configured that way.
- """
- user_id = self.register_user("kermit", "monkey")
- self.login("kermit", "monkey")
-
- email = "test@example.com"
-
- # Add a threepid
- self.get_success(
- self.hs.get_datastores().main.user_add_threepid(
- user_id=user_id,
- medium="email",
- address=email,
- validated_at=0,
- added_at=0,
- )
- )
-
- channel = self.make_request(
- "POST",
- b"register/email/requestToken",
- {"client_secret": "foobar", "email": email, "send_attempt": 1},
- )
- self.assertEqual(200, channel.code, channel.result)
-
- self.assertIsNotNone(channel.json_body.get("sid"))
-
- @unittest.override_config(
- {
- "public_baseurl": "https://test_server",
- "email": {
- "smtp_host": "mail_server",
- "smtp_port": 2525,
- "notif_from": "sender@host",
- },
- }
- )
- def test_reject_invalid_email(self) -> None:
- """Check that bad emails are rejected"""
-
- # Test for email with multiple @
- channel = self.make_request(
- "POST",
- b"register/email/requestToken",
- {"client_secret": "foobar", "email": "email@@email", "send_attempt": 1},
- )
- self.assertEqual(400, channel.code, channel.result)
- # Check error to ensure that we're not erroring due to a bug in the test.
- self.assertEqual(
- channel.json_body,
- {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"},
- )
-
- # Test for email with no @
- channel = self.make_request(
- "POST",
- b"register/email/requestToken",
- {"client_secret": "foobar", "email": "email", "send_attempt": 1},
- )
- self.assertEqual(400, channel.code, channel.result)
- self.assertEqual(
- channel.json_body,
- {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"},
- )
-
- # Test for super long email
- email = "a@" + "a" * 1000
- channel = self.make_request(
- "POST",
- b"register/email/requestToken",
- {"client_secret": "foobar", "email": email, "send_attempt": 1},
- )
- self.assertEqual(400, channel.code, channel.result)
- self.assertEqual(
- channel.json_body,
- {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"},
- )
-
+
@override_config(
{
"inhibit_user_in_use_error": True,
@@ -925,224 +804,6 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200, msg=channel.result)
-class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
- servlets = [
- register.register_servlets,
- synapse.rest.admin.register_servlets_for_client_rest_resource,
- login.register_servlets,
- sync.register_servlets,
- account_validity.register_servlets,
- account.register_servlets,
- ]
-
- def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
- config = self.default_config()
-
- # Test for account expiring after a week and renewal emails being sent 2
- # days before expiry.
- config["enable_registration"] = True
- config["account_validity"] = {
- "enabled": True,
- "period": 604800000, # Time in ms for 1 week
- "renew_at": 172800000, # Time in ms for 2 days
- "renew_by_email_enabled": True,
- "renew_email_subject": "Renew your account",
- "account_renewed_html_path": "account_renewed.html",
- "invalid_token_html_path": "invalid_token.html",
- }
-
- # Email config.
-
- config["email"] = {
- "enable_notifs": True,
- "template_dir": os.path.abspath(
- pkg_resources.resource_filename("synapse", "res/templates")
- ),
- "expiry_template_html": "notice_expiry.html",
- "expiry_template_text": "notice_expiry.txt",
- "notif_template_html": "notif_mail.html",
- "notif_template_text": "notif_mail.txt",
- "smtp_host": "127.0.0.1",
- "smtp_port": 20,
- "require_transport_security": False,
- "smtp_user": None,
- "smtp_pass": None,
- "notif_from": "test@example.com",
- }
-
- self.hs = self.setup_test_homeserver(config=config)
-
- async def sendmail(*args: Any, **kwargs: Any) -> None:
- self.email_attempts.append((args, kwargs))
-
- self.email_attempts: List[Tuple[Any, Any]] = []
- self.hs.get_send_email_handler()._sendmail = sendmail
-
- self.store = self.hs.get_datastores().main
-
- return self.hs
-
- def test_renewal_email(self) -> None:
- self.email_attempts = []
-
- (user_id, tok) = self.create_user()
-
- # Move 5 days forward. This should trigger a renewal email to be sent.
- self.reactor.advance(datetime.timedelta(days=5).total_seconds())
- self.assertEqual(len(self.email_attempts), 1)
-
- # Retrieving the URL from the email is too much pain for now, so we
- # retrieve the token from the DB.
- renewal_token = self.get_success(self.store.get_renewal_token_for_user(user_id))
- url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token
- channel = self.make_request(b"GET", url)
- self.assertEqual(channel.code, 200, msg=channel.result)
-
- # Check that we're getting HTML back.
- content_type = channel.headers.getRawHeaders(b"Content-Type")
- self.assertEqual(content_type, [b"text/html; charset=utf-8"], channel.result)
-
- # Check that the HTML we're getting is the one we expect on a successful renewal.
- expiration_ts = self.get_success(self.store.get_expiration_ts_for_user(user_id))
- expected_html = self.hs.config.account_validity.account_validity_account_renewed_template.render(
- expiration_ts=expiration_ts
- )
- self.assertEqual(
- channel.result["body"], expected_html.encode("utf8"), channel.result
- )
-
- # Move 1 day forward. Try to renew with the same token again.
- url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token
- channel = self.make_request(b"GET", url)
- self.assertEqual(channel.code, 200, msg=channel.result)
-
- # Check that we're getting HTML back.
- content_type = channel.headers.getRawHeaders(b"Content-Type")
- self.assertEqual(content_type, [b"text/html; charset=utf-8"], channel.result)
-
- # Check that the HTML we're getting is the one we expect when reusing a
- # token. The account expiration date should not have changed.
- expected_html = self.hs.config.account_validity.account_validity_account_previously_renewed_template.render(
- expiration_ts=expiration_ts
- )
- self.assertEqual(
- channel.result["body"], expected_html.encode("utf8"), channel.result
- )
-
- # Move 3 days forward. If the renewal failed, every authed request with
- # our access token should be denied from now, otherwise they should
- # succeed.
- self.reactor.advance(datetime.timedelta(days=3).total_seconds())
- channel = self.make_request(b"GET", "/sync", access_token=tok)
- self.assertEqual(channel.code, 200, msg=channel.result)
-
- def test_renewal_invalid_token(self) -> None:
- # Hit the renewal endpoint with an invalid token and check that it behaves as
- # expected, i.e. that it responds with 404 Not Found and the correct HTML.
- url = "/_matrix/client/unstable/account_validity/renew?token=123"
- channel = self.make_request(b"GET", url)
- self.assertEqual(channel.code, 404, msg=channel.result)
-
- # Check that we're getting HTML back.
- content_type = channel.headers.getRawHeaders(b"Content-Type")
- self.assertEqual(content_type, [b"text/html; charset=utf-8"], channel.result)
-
- # Check that the HTML we're getting is the one we expect when using an
- # invalid/unknown token.
- expected_html = (
- self.hs.config.account_validity.account_validity_invalid_token_template.render()
- )
- self.assertEqual(
- channel.result["body"], expected_html.encode("utf8"), channel.result
- )
-
- def test_manual_email_send(self) -> None:
- self.email_attempts = []
-
- (user_id, tok) = self.create_user()
- channel = self.make_request(
- b"POST",
- "/_matrix/client/unstable/account_validity/send_mail",
- access_token=tok,
- )
- self.assertEqual(channel.code, 200, msg=channel.result)
-
- self.assertEqual(len(self.email_attempts), 1)
-
- def test_deactivated_user(self) -> None:
- self.email_attempts = []
-
- (user_id, tok) = self.create_user()
-
- request_data = {
- "auth": {
- "type": "m.login.password",
- "user": user_id,
- "password": "monkey",
- },
- "erase": False,
- }
- channel = self.make_request(
- "POST", "account/deactivate", request_data, access_token=tok
- )
- self.assertEqual(channel.code, 200)
-
- self.reactor.advance(datetime.timedelta(days=8).total_seconds())
-
- self.assertEqual(len(self.email_attempts), 0)
-
- def create_user(self) -> Tuple[str, str]:
- user_id = self.register_user("kermit", "monkey")
- tok = self.login("kermit", "monkey")
- # We need to manually add an email address otherwise the handler will do
- # nothing.
- now = self.hs.get_clock().time_msec()
- self.get_success(
- self.store.user_add_threepid(
- user_id=user_id,
- medium="email",
- address="kermit@example.com",
- validated_at=now,
- added_at=now,
- )
- )
- return user_id, tok
-
- def test_manual_email_send_expired_account(self) -> None:
- user_id = self.register_user("kermit", "monkey")
- tok = self.login("kermit", "monkey")
-
- # We need to manually add an email address otherwise the handler will do
- # nothing.
- now = self.hs.get_clock().time_msec()
- self.get_success(
- self.store.user_add_threepid(
- user_id=user_id,
- medium="email",
- address="kermit@example.com",
- validated_at=now,
- added_at=now,
- )
- )
-
- # Make the account expire.
- self.reactor.advance(datetime.timedelta(days=8).total_seconds())
-
- # Ignore all emails sent by the automatic background task and only focus on the
- # ones sent manually.
- self.email_attempts = []
-
- # Test that we're still able to manually trigger a mail to be sent.
- channel = self.make_request(
- b"POST",
- "/_matrix/client/unstable/account_validity/send_mail",
- access_token=tok,
- )
- self.assertEqual(channel.code, 200, msg=channel.result)
-
- self.assertEqual(len(self.email_attempts), 1)
-
-
class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
servlets = [synapse.rest.admin.register_servlets_for_client_rest_resource]
diff --git a/tests/rest/client/test_rendezvous.py b/tests/rest/client/test_rendezvous.py
index 0ab754a11a..83a5cbdc15 100644
--- a/tests/rest/client/test_rendezvous.py
+++ b/tests/rest/client/test_rendezvous.py
@@ -34,7 +34,6 @@ from tests import unittest
from tests.unittest import override_config
from tests.utils import HAS_AUTHLIB
-msc3886_endpoint = "/_matrix/client/unstable/org.matrix.msc3886/rendezvous"
msc4108_endpoint = "/_matrix/client/unstable/org.matrix.msc4108/rendezvous"
@@ -54,17 +53,9 @@ class RendezvousServletTestCase(unittest.HomeserverTestCase):
}
def test_disabled(self) -> None:
- channel = self.make_request("POST", msc3886_endpoint, {}, access_token=None)
- self.assertEqual(channel.code, 404)
channel = self.make_request("POST", msc4108_endpoint, {}, access_token=None)
self.assertEqual(channel.code, 404)
- @override_config({"experimental_features": {"msc3886_endpoint": "/asd"}})
- def test_msc3886_redirect(self) -> None:
- channel = self.make_request("POST", msc3886_endpoint, {}, access_token=None)
- self.assertEqual(channel.code, 307)
- self.assertEqual(channel.headers.getRawHeaders("Location"), ["/asd"])
-
@unittest.skip_unless(HAS_AUTHLIB, "requires authlib")
@override_config(
{
@@ -126,10 +117,11 @@ class RendezvousServletTestCase(unittest.HomeserverTestCase):
headers = dict(channel.headers.getAllRawHeaders())
self.assertIn(b"ETag", headers)
self.assertIn(b"Expires", headers)
+ self.assertIn(b"Content-Length", headers)
self.assertEqual(headers[b"Content-Type"], [b"application/json"])
self.assertEqual(headers[b"Access-Control-Allow-Origin"], [b"*"])
self.assertEqual(headers[b"Access-Control-Expose-Headers"], [b"etag"])
- self.assertEqual(headers[b"Cache-Control"], [b"no-store"])
+ self.assertEqual(headers[b"Cache-Control"], [b"no-store, no-transform"])
self.assertEqual(headers[b"Pragma"], [b"no-cache"])
self.assertIn("url", channel.json_body)
self.assertTrue(channel.json_body["url"].startswith("https://"))
@@ -150,9 +142,10 @@ class RendezvousServletTestCase(unittest.HomeserverTestCase):
self.assertEqual(headers[b"ETag"], [etag])
self.assertIn(b"Expires", headers)
self.assertEqual(headers[b"Content-Type"], [b"text/plain"])
+ self.assertEqual(headers[b"Content-Length"], [b"7"])
self.assertEqual(headers[b"Access-Control-Allow-Origin"], [b"*"])
self.assertEqual(headers[b"Access-Control-Expose-Headers"], [b"etag"])
- self.assertEqual(headers[b"Cache-Control"], [b"no-store"])
+ self.assertEqual(headers[b"Cache-Control"], [b"no-store, no-transform"])
self.assertEqual(headers[b"Pragma"], [b"no-cache"])
self.assertEqual(channel.text_body, "foo=bar")
diff --git a/tests/rest/client/test_reporting.py b/tests/rest/client/test_reporting.py
index 009deb9cb0..723553979f 100644
--- a/tests/rest/client/test_reporting.py
+++ b/tests/rest/client/test_reporting.py
@@ -156,58 +156,31 @@ class ReportRoomTestCase(unittest.HomeserverTestCase):
self.room_id = self.helper.create_room_as(
self.other_user, tok=self.other_user_tok, is_public=True
)
- self.report_path = (
- f"/_matrix/client/unstable/org.matrix.msc4151/rooms/{self.room_id}/report"
- )
+ self.report_path = f"/_matrix/client/v3/rooms/{self.room_id}/report"
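+        # Room reporting (MSC4151) has been stabilised, so the endpoint now
+        # lives under the stable v3 prefix and no longer needs an
+        # experimental-features flag.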
- @unittest.override_config(
- {
- "experimental_features": {"msc4151_enabled": True},
- }
- )
def test_reason_str(self) -> None:
data = {"reason": "this makes me sad"}
self._assert_status(200, data)
- @unittest.override_config(
- {
- "experimental_features": {"msc4151_enabled": True},
- }
- )
def test_no_reason(self) -> None:
data = {"not_reason": "for typechecking"}
self._assert_status(400, data)
- @unittest.override_config(
- {
- "experimental_features": {"msc4151_enabled": True},
- }
- )
def test_reason_nonstring(self) -> None:
data = {"reason": 42}
self._assert_status(400, data)
- @unittest.override_config(
- {
- "experimental_features": {"msc4151_enabled": True},
- }
- )
def test_reason_null(self) -> None:
data = {"reason": None}
self._assert_status(400, data)
- @unittest.override_config(
- {
- "experimental_features": {"msc4151_enabled": True},
- }
- )
def test_cannot_report_nonexistent_room(self) -> None:
"""
Tests that we don't accept event reports for rooms which do not exist.
"""
channel = self.make_request(
"POST",
- "/_matrix/client/unstable/org.matrix.msc4151/rooms/!bloop:example.org/report",
+ "/_matrix/client/v3/rooms/!bloop:example.org/report",
{"reason": "i am very sad"},
access_token=self.other_user_tok,
shorthand=False,
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index c559dfda83..04442febb4 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -4,7 +4,7 @@
# Copyright 2019 The Matrix.org Foundation C.I.C.
# Copyright 2017 Vector Creations Ltd
# Copyright 2014-2016 OpenMarket Ltd
-# Copyright (C) 2023 New Vector, Ltd
+# Copyright (C) 2023-2024 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
@@ -25,12 +25,11 @@
import json
from http import HTTPStatus
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, Union
from unittest.mock import AsyncMock, Mock, call, patch
from urllib import parse as urlparse
from parameterized import param, parameterized
-from typing_extensions import Literal
from twisted.test.proto_helpers import MemoryReactor
@@ -68,6 +67,7 @@ from tests.http.server._base import make_request_with_cancellation_test
from tests.storage.test_stream import PaginationTestCase
from tests.test_utils.event_injection import create_event
from tests.unittest import override_config
+from tests.utils import default_config
PATH_PREFIX = b"/_matrix/client/api/v1"
@@ -742,7 +742,7 @@ class RoomsCreateTestCase(RoomBase):
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
self.assertTrue("room_id" in channel.json_body)
assert channel.resource_usage is not None
- self.assertEqual(33, channel.resource_usage.db_txn_count)
+ self.assertEqual(35, channel.resource_usage.db_txn_count)
def test_post_room_initial_state(self) -> None:
# POST with initial_state config key, expect new room id
@@ -755,7 +755,7 @@ class RoomsCreateTestCase(RoomBase):
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
self.assertTrue("room_id" in channel.json_body)
assert channel.resource_usage is not None
- self.assertEqual(35, channel.resource_usage.db_txn_count)
+ self.assertEqual(37, channel.resource_usage.db_txn_count)
def test_post_room_visibility_key(self) -> None:
# POST with visibility config key, expect new room id
@@ -1337,17 +1337,13 @@ class RoomJoinTestCase(RoomBase):
"POST", f"/join/{self.room1}", access_token=self.tok2
)
self.assertEqual(channel.code, 403)
- self.assertEqual(
- channel.json_body["errcode"], "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED"
- )
+ self.assertEqual(channel.json_body["errcode"], "M_USER_SUSPENDED")
channel = self.make_request(
"POST", f"/rooms/{self.room1}/join", access_token=self.tok2
)
self.assertEqual(channel.code, 403)
- self.assertEqual(
- channel.json_body["errcode"], "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED"
- )
+ self.assertEqual(channel.json_body["errcode"], "M_USER_SUSPENDED")
def test_suspended_user_cannot_knock_on_room(self) -> None:
# set the user as suspended
@@ -1361,9 +1357,7 @@ class RoomJoinTestCase(RoomBase):
shorthand=False,
)
self.assertEqual(channel.code, 403)
- self.assertEqual(
- channel.json_body["errcode"], "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED"
- )
+ self.assertEqual(channel.json_body["errcode"], "M_USER_SUSPENDED")
def test_suspended_user_cannot_invite_to_room(self) -> None:
# set the user as suspended
@@ -1376,9 +1370,24 @@ class RoomJoinTestCase(RoomBase):
access_token=self.tok1,
content={"user_id": self.user2},
)
- self.assertEqual(
- channel.json_body["errcode"], "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED"
+ self.assertEqual(channel.json_body["errcode"], "M_USER_SUSPENDED")
+
+ def test_suspended_user_can_leave_room(self) -> None:
+ channel = self.make_request(
+ "POST", f"/join/{self.room1}", access_token=self.tok1
)
+ self.assertEqual(channel.code, 200)
+
+ # set the user as suspended
+ self.get_success(self.store.set_user_suspended_status(self.user1, True))
+
+ # leave room
+ channel = self.make_request(
+ "POST",
+ f"/rooms/{self.room1}/leave",
+ access_token=self.tok1,
+ )
+ self.assertEqual(channel.code, 200)
class RoomAppserviceTsParamTestCase(unittest.HomeserverTestCase):
@@ -2291,6 +2300,141 @@ class RoomMessageFilterTestCase(RoomBase):
self.assertEqual(len(chunk), 2, [event["content"] for event in chunk])
+class RoomDelayedEventTestCase(RoomBase):
+ """Tests delayed events."""
+
+ user_id = "@sid1:red"
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.room_id = self.helper.create_room_as(self.user_id)
+
+ @unittest.override_config({"max_event_delay_duration": "24h"})
+ def test_send_delayed_invalid_event(self) -> None:
+ """Test sending a delayed event with invalid content."""
+ channel = self.make_request(
+ "PUT",
+ (
+ "rooms/%s/send/m.room.message/mid1?org.matrix.msc4140.delay=2000"
+ % self.room_id
+ ).encode("ascii"),
+ {},
+ )
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result)
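+        # The empty content fails ordinary event validation, so no
+        # MSC4140-specific errcode should be attached to the error.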
+ self.assertNotIn("org.matrix.msc4140.errcode", channel.json_body)
+
+ def test_delayed_event_unsupported_by_default(self) -> None:
+ """Test that sending a delayed event is unsupported with the default config."""
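+        # `max_event_delay_duration` is unset by default, so the server
+        # rejects any request that includes a delay parameter.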
+ channel = self.make_request(
+ "PUT",
+ (
+ "rooms/%s/send/m.room.message/mid1?org.matrix.msc4140.delay=2000"
+ % self.room_id
+ ).encode("ascii"),
+ {"body": "test", "msgtype": "m.text"},
+ )
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result)
+ self.assertEqual(
+ "M_MAX_DELAY_UNSUPPORTED",
+ channel.json_body.get("org.matrix.msc4140.errcode"),
+ channel.json_body,
+ )
+
+ @unittest.override_config({"max_event_delay_duration": "1000"})
+ def test_delayed_event_exceeds_max_delay(self) -> None:
+ """Test that sending a delayed event fails if its delay is longer than allowed."""
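+        # A bare number for `max_event_delay_duration` is interpreted as
+        # milliseconds, so the 2000 ms delay requested below exceeds the
+        # configured 1000 ms maximum.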
+ channel = self.make_request(
+ "PUT",
+ (
+ "rooms/%s/send/m.room.message/mid1?org.matrix.msc4140.delay=2000"
+ % self.room_id
+ ).encode("ascii"),
+ {"body": "test", "msgtype": "m.text"},
+ )
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result)
+ self.assertEqual(
+ "M_MAX_DELAY_EXCEEDED",
+ channel.json_body.get("org.matrix.msc4140.errcode"),
+ channel.json_body,
+ )
+
+ @unittest.override_config({"max_event_delay_duration": "24h"})
+ def test_delayed_event_with_negative_delay(self) -> None:
+ """Test that sending a delayed event fails if its delay is negative."""
+ channel = self.make_request(
+ "PUT",
+ (
+ "rooms/%s/send/m.room.message/mid1?org.matrix.msc4140.delay=-2000"
+ % self.room_id
+ ).encode("ascii"),
+ {"body": "test", "msgtype": "m.text"},
+ )
+ self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result)
+ self.assertEqual(
+ Codes.INVALID_PARAM, channel.json_body["errcode"], channel.json_body
+ )
+
+ @unittest.override_config({"max_event_delay_duration": "24h"})
+ def test_send_delayed_message_event(self) -> None:
+ """Test sending a valid delayed message event."""
+ channel = self.make_request(
+ "PUT",
+ (
+ "rooms/%s/send/m.room.message/mid1?org.matrix.msc4140.delay=2000"
+ % self.room_id
+ ).encode("ascii"),
+ {"body": "test", "msgtype": "m.text"},
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
+
+ @unittest.override_config({"max_event_delay_duration": "24h"})
+ def test_send_delayed_state_event(self) -> None:
+ """Test sending a valid delayed state event."""
+ channel = self.make_request(
+ "PUT",
+ (
+ "rooms/%s/state/m.room.topic/?org.matrix.msc4140.delay=2000"
+ % self.room_id
+ ).encode("ascii"),
+ {"topic": "This is a topic"},
+ )
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
+
+ @unittest.override_config(
+ {
+ "max_event_delay_duration": "24h",
+ "rc_message": {"per_second": 1, "burst_count": 2},
+ }
+ )
+ def test_add_delayed_event_ratelimit(self) -> None:
+        """Test that requests to schedule new delayed events are ratelimited,
+        and that no limit is applied when the requester is exempt from
+        ratelimiting.
+        """
+
+ # Test that new delayed events are correctly ratelimited.
+ args = (
+ "POST",
+ (
+ "rooms/%s/send/m.room.message?org.matrix.msc4140.delay=2000"
+ % self.room_id
+ ).encode("ascii"),
+ {"body": "test", "msgtype": "m.text"},
+ )
+ channel = self.make_request(*args)
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
+ channel = self.make_request(*args)
+ self.assertEqual(HTTPStatus.TOO_MANY_REQUESTS, channel.code, channel.result)
+
+        # Add the current user to the ratelimit overrides, exempting them from ratelimiting.
+ self.get_success(
+ self.hs.get_datastores().main.set_ratelimit_for_user(self.user_id, 0, 0)
+ )
+
+ # Test that the new delayed events aren't ratelimited anymore.
+ channel = self.make_request(*args)
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
+
+
class RoomSearchTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
@@ -2457,6 +2601,11 @@ class PublicRoomsRoomTypeFilterTestCase(unittest.HomeserverTestCase):
tok=self.token,
)
+ def default_config(self) -> JsonDict:
+ config = default_config("test")
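+        # Allow rooms to be published to the public room list, so that the
+        # tests below can populate and query it.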
+ config["room_list_publication_rules"] = [{"action": "allow"}]
+ return config
+
def make_public_rooms_request(
self,
room_types: Optional[List[Union[str, None]]],
@@ -2794,6 +2943,68 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
self.assertEqual(event_content.get("reason"), reason, channel.result)
+class RoomForgottenTestCase(unittest.HomeserverTestCase):
+ """
+ Test forget/forgotten rooms
+ """
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.store = hs.get_datastores().main
+
+ def test_room_not_forgotten_after_unban(self) -> None:
+ """
+        Test what happens when someone is banned from a room, forgets the room,
+        and is unbanned some time later.
+
+        Currently, the unban clears the forgotten status, which may or may not
+        be the expected behaviour.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ room_id = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True)
+ self.helper.join(room_id, user1_id, tok=user1_tok)
+
+        # User1 is banned
+ self.helper.ban(room_id, src=user2_id, targ=user1_id, tok=user2_tok)
+ # User1 forgets the room
+ self.get_success(self.store.forget(user1_id, room_id))
+
+ # The room should show up as forgotten
+ forgotten_room_ids = self.get_success(
+ self.store.get_forgotten_rooms_for_user(user1_id)
+ )
+ self.assertIncludes(forgotten_room_ids, {room_id}, exact=True)
+
+ # Unban user1
+ self.helper.change_membership(
+ room=room_id,
+ src=user2_id,
+ targ=user1_id,
+ membership=Membership.LEAVE,
+ tok=user2_tok,
+ )
+
+ # Room is no longer forgotten because it's a new membership
+ #
+ # XXX: Is this how we actually want it to behave? It seems like ideally, the
+ # room forgotten status should only be reset when the user decides to join again
+ # (or is invited/knocks). This way the room remains forgotten for any ban/leave
+ # transitions.
+ forgotten_room_ids = self.get_success(
+ self.store.get_forgotten_rooms_for_user(user1_id)
+ )
+ self.assertIncludes(forgotten_room_ids, set(), exact=True)
+
+
class LabelsTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
@@ -3577,191 +3788,6 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
self._set_canonical_alias({"alt_aliases": ["@unknown:test"]}, expected_code=400)
-class ThreepidInviteTestCase(unittest.HomeserverTestCase):
- servlets = [
- admin.register_servlets,
- login.register_servlets,
- room.register_servlets,
- ]
-
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.user_id = self.register_user("thomas", "hackme")
- self.tok = self.login("thomas", "hackme")
-
- self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
-
- def test_threepid_invite_spamcheck_deprecated(self) -> None:
- """
- Test allowing/blocking threepid invites with a spam-check module.
-
- In this test, we use the deprecated API in which callbacks return a bool.
- """
- # Mock a few functions to prevent the test from failing due to failing to talk to
- # a remote IS. We keep the mock for make_and_store_3pid_invite around so we
- # can check its call_count later on during the test.
- make_invite_mock = AsyncMock(return_value=(Mock(event_id="abc"), 0))
- self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock # type: ignore[method-assign]
- self.hs.get_identity_handler().lookup_3pid = AsyncMock( # type: ignore[method-assign]
- return_value=None,
- )
-
- # Add a mock to the spamchecker callbacks for user_may_send_3pid_invite. Make it
- # allow everything for now.
- # `spec` argument is needed for this function mock to have `__qualname__`, which
- # is needed for `Measure` metrics buried in SpamChecker.
- mock = AsyncMock(return_value=True, spec=lambda *x: None)
- self.hs.get_module_api_callbacks().spam_checker._user_may_send_3pid_invite_callbacks.append(
- mock
- )
-
- # Send a 3PID invite into the room and check that it succeeded.
- email_to_invite = "teresa@example.com"
- channel = self.make_request(
- method="POST",
- path="/rooms/" + self.room_id + "/invite",
- content={
- "id_server": "example.com",
- "id_access_token": "sometoken",
- "medium": "email",
- "address": email_to_invite,
- },
- access_token=self.tok,
- )
- self.assertEqual(channel.code, 200)
-
- # Check that the callback was called with the right params.
- mock.assert_called_with(self.user_id, "email", email_to_invite, self.room_id)
-
- # Check that the call to send the invite was made.
- make_invite_mock.assert_called_once()
-
- # Now change the return value of the callback to deny any invite and test that
- # we can't send the invite.
- mock.return_value = False
- channel = self.make_request(
- method="POST",
- path="/rooms/" + self.room_id + "/invite",
- content={
- "id_server": "example.com",
- "id_access_token": "sometoken",
- "medium": "email",
- "address": email_to_invite,
- },
- access_token=self.tok,
- )
- self.assertEqual(channel.code, 403)
-
- # Also check that it stopped before calling _make_and_store_3pid_invite.
- make_invite_mock.assert_called_once()
-
- def test_threepid_invite_spamcheck(self) -> None:
- """
- Test allowing/blocking threepid invites with a spam-check module.
-
- In this test, we use the more recent API in which callbacks return a `Union[Codes, Literal["NOT_SPAM"]]`.
- """
- # Mock a few functions to prevent the test from failing due to failing to talk to
- # a remote IS. We keep the mock for make_and_store_3pid_invite around so we
- # can check its call_count later on during the test.
- make_invite_mock = AsyncMock(return_value=(Mock(event_id="abc"), 0))
- self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock # type: ignore[method-assign]
- self.hs.get_identity_handler().lookup_3pid = AsyncMock( # type: ignore[method-assign]
- return_value=None,
- )
-
- # Add a mock to the spamchecker callbacks for user_may_send_3pid_invite. Make it
- # allow everything for now.
- # `spec` argument is needed for this function mock to have `__qualname__`, which
- # is needed for `Measure` metrics buried in SpamChecker.
- mock = AsyncMock(
- return_value=synapse.module_api.NOT_SPAM,
- spec=lambda *x: None,
- )
- self.hs.get_module_api_callbacks().spam_checker._user_may_send_3pid_invite_callbacks.append(
- mock
- )
-
- # Send a 3PID invite into the room and check that it succeeded.
- email_to_invite = "teresa@example.com"
- channel = self.make_request(
- method="POST",
- path="/rooms/" + self.room_id + "/invite",
- content={
- "id_server": "example.com",
- "id_access_token": "sometoken",
- "medium": "email",
- "address": email_to_invite,
- },
- access_token=self.tok,
- )
- self.assertEqual(channel.code, 200)
-
- # Check that the callback was called with the right params.
- mock.assert_called_with(self.user_id, "email", email_to_invite, self.room_id)
-
- # Check that the call to send the invite was made.
- make_invite_mock.assert_called_once()
-
- # Now change the return value of the callback to deny any invite and test that
- # we can't send the invite. We pick an arbitrary error code to be able to check
- # that the same code has been returned
- mock.return_value = Codes.CONSENT_NOT_GIVEN
- channel = self.make_request(
- method="POST",
- path="/rooms/" + self.room_id + "/invite",
- content={
- "id_server": "example.com",
- "id_access_token": "sometoken",
- "medium": "email",
- "address": email_to_invite,
- },
- access_token=self.tok,
- )
- self.assertEqual(channel.code, 403)
- self.assertEqual(channel.json_body["errcode"], Codes.CONSENT_NOT_GIVEN)
-
- # Also check that it stopped before calling _make_and_store_3pid_invite.
- make_invite_mock.assert_called_once()
-
- # Run variant with `Tuple[Codes, dict]`.
- mock.return_value = (Codes.EXPIRED_ACCOUNT, {"field": "value"})
- channel = self.make_request(
- method="POST",
- path="/rooms/" + self.room_id + "/invite",
- content={
- "id_server": "example.com",
- "id_access_token": "sometoken",
- "medium": "email",
- "address": email_to_invite,
- },
- access_token=self.tok,
- )
- self.assertEqual(channel.code, 403)
- self.assertEqual(channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT)
- self.assertEqual(channel.json_body["field"], "value")
-
- # Also check that it stopped before calling _make_and_store_3pid_invite.
- make_invite_mock.assert_called_once()
-
- def test_400_missing_param_without_id_access_token(self) -> None:
- """
- Test that a 3pid invite request returns 400 M_MISSING_PARAM
- if we do not include id_access_token.
- """
- channel = self.make_request(
- method="POST",
- path="/rooms/" + self.room_id + "/invite",
- content={
- "id_server": "example.com",
- "medium": "email",
- "address": "teresa@example.com",
- },
- access_token=self.tok,
- )
- self.assertEqual(channel.code, 400)
- self.assertEqual(channel.json_body["errcode"], "M_MISSING_PARAM")
-
-
class TimestampLookupTestCase(unittest.HomeserverTestCase):
servlets = [
admin.register_servlets,
@@ -3836,10 +3862,25 @@ class UserSuspensionTests(unittest.HomeserverTestCase):
self.user2 = self.register_user("teresa", "hackme")
self.tok2 = self.login("teresa", "hackme")
- self.room1 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1)
+ self.admin = self.register_user("admin", "pass", True)
+ self.admin_tok = self.login("admin", "pass")
+
+ self.room1 = self.helper.create_room_as(
+ room_creator=self.user1, tok=self.tok1, room_version="11"
+ )
self.store = hs.get_datastores().main
- def test_suspended_user_cannot_send_message_to_room(self) -> None:
+ self.room2 = self.helper.create_room_as(
+ room_creator=self.user1, is_public=False, tok=self.tok1
+ )
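+        # Mark room2 as end-to-end encrypted, so we can check that suspension
+        # also blocks sending `m.room.encrypted` events into it.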
+ self.helper.send_state(
+ self.room2,
+ EventTypes.RoomEncryption,
+ {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"},
+ tok=self.tok1,
+ )
+
+ def test_suspended_user_cannot_send_message_to_public_room(self) -> None:
# set the user as suspended
self.get_success(self.store.set_user_suspended_status(self.user1, True))
@@ -3849,9 +3890,25 @@ class UserSuspensionTests(unittest.HomeserverTestCase):
access_token=self.tok1,
content={"body": "hello", "msgtype": "m.text"},
)
- self.assertEqual(
- channel.json_body["errcode"], "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED"
+ self.assertEqual(channel.json_body["errcode"], "M_USER_SUSPENDED")
+
+ def test_suspended_user_cannot_send_message_to_encrypted_room(self) -> None:
+ channel = self.make_request(
+ "PUT",
+ f"/_synapse/admin/v1/suspend/{self.user1}",
+ {"suspend": True},
+ access_token=self.admin_tok,
+ )
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body, {f"user_{self.user1}_suspended": True})
+
+ channel = self.make_request(
+ "PUT",
+ f"/rooms/{self.room2}/send/m.room.encrypted/1",
+ access_token=self.tok1,
+ content={},
)
+ self.assertEqual(channel.json_body["errcode"], "M_USER_SUSPENDED")
def test_suspended_user_cannot_change_profile_data(self) -> None:
# set the user as suspended
@@ -3864,9 +3921,7 @@ class UserSuspensionTests(unittest.HomeserverTestCase):
content={"avatar_url": "mxc://matrix.org/wefh34uihSDRGhw34"},
shorthand=False,
)
- self.assertEqual(
- channel.json_body["errcode"], "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED"
- )
+ self.assertEqual(channel.json_body["errcode"], "M_USER_SUSPENDED")
channel2 = self.make_request(
"PUT",
@@ -3875,9 +3930,7 @@ class UserSuspensionTests(unittest.HomeserverTestCase):
content={"displayname": "something offensive"},
shorthand=False,
)
- self.assertEqual(
- channel2.json_body["errcode"], "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED"
- )
+ self.assertEqual(channel2.json_body["errcode"], "M_USER_SUSPENDED")
def test_suspended_user_cannot_redact_messages_other_than_their_own(self) -> None:
# first user sends message
@@ -3911,9 +3964,7 @@ class UserSuspensionTests(unittest.HomeserverTestCase):
content={"reason": "bogus"},
shorthand=False,
)
- self.assertEqual(
- channel.json_body["errcode"], "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED"
- )
+ self.assertEqual(channel.json_body["errcode"], "M_USER_SUSPENDED")
# but can redact their own
channel = self.make_request(
@@ -3924,3 +3975,244 @@ class UserSuspensionTests(unittest.HomeserverTestCase):
shorthand=False,
)
self.assertEqual(channel.code, 200)
+
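+        # The same restrictions apply when redacting via an `m.room.redaction`
+        # event with `redacts` in the content (the room v11 format): another
+        # user's event cannot be redacted, but the user's own can.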
+ channel = self.make_request(
+ "PUT",
+ f"/_matrix/client/v3/rooms/{self.room1}/send/m.room.redaction/3456346",
+ access_token=self.tok1,
+ content={"reason": "bogus", "redacts": event_id},
+ shorthand=False,
+ )
+ self.assertEqual(channel.json_body["errcode"], "M_USER_SUSPENDED")
+
+ channel = self.make_request(
+ "PUT",
+ f"/_matrix/client/v3/rooms/{self.room1}/send/m.room.redaction/3456346",
+ access_token=self.tok1,
+ content={"reason": "bogus", "redacts": event_id2},
+ shorthand=False,
+ )
+ self.assertEqual(channel.code, 200)
+
+ def test_suspended_user_cannot_ban_others(self) -> None:
+ # user to ban joins room user1 created
+ self.make_request("POST", f"/rooms/{self.room1}/join", access_token=self.tok2)
+
+ # suspend user1
+ self.get_success(self.store.set_user_suspended_status(self.user1, True))
+
+ # user1 tries to ban other user while suspended
+ channel = self.make_request(
+ "POST",
+ f"/_matrix/client/v3/rooms/{self.room1}/ban",
+ access_token=self.tok1,
+ content={"reason": "spite", "user_id": self.user2},
+ shorthand=False,
+ )
+ self.assertEqual(channel.json_body["errcode"], "M_USER_SUSPENDED")
+
+ # un-suspend user1
+ self.get_success(self.store.set_user_suspended_status(self.user1, False))
+
+ # ban now goes through
+ channel = self.make_request(
+ "POST",
+ f"/_matrix/client/v3/rooms/{self.room1}/ban",
+ access_token=self.tok1,
+ content={"reason": "spite", "user_id": self.user2},
+ shorthand=False,
+ )
+ self.assertEqual(channel.code, 200)
+
+
+class RoomParticipantTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ login.register_servlets,
+ room.register_servlets,
+ profile.register_servlets,
+ admin.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.user1 = self.register_user("thomas", "hackme")
+ self.tok1 = self.login("thomas", "hackme")
+
+ self.user2 = self.register_user("teresa", "hackme")
+ self.tok2 = self.login("teresa", "hackme")
+
+ self.room1 = self.helper.create_room_as(
+ room_creator=self.user1,
+ tok=self.tok1,
+ # Allow user2 to send state events into the room.
+ extra_content={
+ "power_level_content_override": {
+ "state_default": 0,
+ },
+ },
+ )
+ self.store = hs.get_datastores().main
+
+ @parameterized.expand(
+ [
+ # Should record participation.
+ param(
+ is_state=False,
+ event_type="m.room.message",
+ event_content={
+ "msgtype": "m.text",
+ "body": "I am engaging in this room",
+ },
+ record_participation=True,
+ ),
+ param(
+ is_state=False,
+ event_type="m.room.encrypted",
+ event_content={
+ "algorithm": "m.megolm.v1.aes-sha2",
+ "ciphertext": "AwgAEnACgAkLmt6qF84IK++J7UDH2Za1YVchHyprqTqsg...",
+ "device_id": "RJYKSTBOIE",
+ "sender_key": "IlRMeOPX2e0MurIyfWEucYBRVOEEUMrOHqn/8mLqMjA",
+ "session_id": "X3lUlvLELLYxeTx4yOVu6UDpasGEVO0Jbu+QFnm0cKQ",
+ },
+ record_participation=True,
+ ),
+ # Should not record participation.
+ param(
+ is_state=False,
+ event_type="m.sticker",
+ event_content={
+ "body": "My great sticker",
+ "info": {},
+ "url": "mxc://unused/mxcurl",
+ },
+ record_participation=False,
+ ),
+ # An invalid **state event** with type `m.room.message`
+ param(
+ is_state=True,
+ event_type="m.room.message",
+ event_content={
+ "msgtype": "m.text",
+ "body": "I am engaging in this room",
+ },
+ record_participation=False,
+ ),
+ # An invalid **state event** with type `m.room.encrypted`
+ # Note: this may become valid in the future with encrypted state, though we
+ # still may not want to consider it grounds for marking a user as participating.
+ param(
+ is_state=True,
+ event_type="m.room.encrypted",
+ event_content={
+ "algorithm": "m.megolm.v1.aes-sha2",
+ "ciphertext": "AwgAEnACgAkLmt6qF84IK++J7UDH2Za1YVchHyprqTqsg...",
+ "device_id": "RJYKSTBOIE",
+ "sender_key": "IlRMeOPX2e0MurIyfWEucYBRVOEEUMrOHqn/8mLqMjA",
+ "session_id": "X3lUlvLELLYxeTx4yOVu6UDpasGEVO0Jbu+QFnm0cKQ",
+ },
+ record_participation=False,
+ ),
+ ]
+ )
+ def test_sending_message_records_participation(
+ self,
+ is_state: bool,
+ event_type: str,
+ event_content: JsonDict,
+ record_participation: bool,
+ ) -> None:
+ """
+        Test that sending various events into a room causes the user to be
+        appropriately marked, or not marked, as a participant in that room.
+ """
+ self.helper.join(self.room1, self.user2, tok=self.tok2)
+
+ # user has not sent any messages, so should not be a participant
+ participant = self.get_success(
+ self.store.get_room_participation(self.user2, self.room1)
+ )
+ self.assertFalse(participant)
+
+ # send an event into the room
+ if is_state:
+ # send a state event
+ self.helper.send_state(
+ self.room1,
+ event_type,
+ body=event_content,
+ tok=self.tok2,
+ )
+ else:
+ # send a non-state event
+ self.helper.send_event(
+ self.room1,
+ event_type,
+ content=event_content,
+ tok=self.tok2,
+ )
+
+ # check whether the user has been marked as a participant
+ participant = self.get_success(
+ self.store.get_room_participation(self.user2, self.room1)
+ )
+ self.assertEqual(participant, record_participation)
+
+ @parameterized.expand(
+ [
+ param(
+ event_type="m.room.message",
+ event_content={
+ "msgtype": "m.text",
+ "body": "I am engaging in this room",
+ },
+ ),
+ param(
+ event_type="m.room.encrypted",
+ event_content={
+ "algorithm": "m.megolm.v1.aes-sha2",
+ "ciphertext": "AwgAEnACgAkLmt6qF84IK++J7UDH2Za1YVchHyprqTqsg...",
+ "device_id": "RJYKSTBOIE",
+ "sender_key": "IlRMeOPX2e0MurIyfWEucYBRVOEEUMrOHqn/8mLqMjA",
+ "session_id": "X3lUlvLELLYxeTx4yOVu6UDpasGEVO0Jbu+QFnm0cKQ",
+ },
+ ),
+ ]
+ )
+ def test_sending_event_and_leaving_does_not_record_participation(
+ self,
+ event_type: str,
+ event_content: JsonDict,
+ ) -> None:
+ """
+        Test that sending an event into a room that should mark a user as a
+        participant, and then leaving the room, results in the user no longer
+        being marked as a participant in that room.
+ """
+ self.helper.join(self.room1, self.user2, tok=self.tok2)
+
+ # user has not sent any messages, so should not be a participant
+ participant = self.get_success(
+ self.store.get_room_participation(self.user2, self.room1)
+ )
+ self.assertFalse(participant)
+
+ # sending a message should now mark user as participant
+ self.helper.send_event(
+ self.room1,
+ event_type,
+ content=event_content,
+ tok=self.tok2,
+ )
+ participant = self.get_success(
+ self.store.get_room_participation(self.user2, self.room1)
+ )
+ self.assertTrue(participant)
+
+ # leave the room
+ self.helper.leave(self.room1, self.user2, tok=self.tok2)
+
+ # user should no longer be considered a participant
+ participant = self.get_success(
+ self.store.get_room_participation(self.user2, self.room1)
+ )
+ self.assertFalse(participant)
diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py
index 2287f233b4..b406a578f0 100644
--- a/tests/rest/client/test_shadow_banned.py
+++ b/tests/rest/client/test_shadow_banned.py
@@ -88,35 +88,6 @@ class RoomTestCase(_ShadowBannedBase):
)
self.assertEqual(invited_rooms, [])
- def test_invite_3pid(self) -> None:
- """Ensure that a 3PID invite does not attempt to contact the identity server."""
- identity_handler = self.hs.get_identity_handler()
- identity_handler.lookup_3pid = Mock( # type: ignore[method-assign]
- side_effect=AssertionError("This should not get called")
- )
-
- # The create works fine.
- room_id = self.helper.create_room_as(
- self.banned_user_id, tok=self.banned_access_token
- )
-
- # Inviting the user completes successfully.
- channel = self.make_request(
- "POST",
- "/rooms/%s/invite" % (room_id,),
- {
- "id_server": "test",
- "medium": "email",
- "address": "test@test.test",
- "id_access_token": "anytoken",
- },
- access_token=self.banned_access_token,
- )
- self.assertEqual(200, channel.code, channel.result)
-
- # This should have raised an error earlier, but double check this wasn't called.
- identity_handler.lookup_3pid.assert_not_called()
-
def test_create_room(self) -> None:
"""Invitations during a room creation should be discarded, but the room still gets created."""
# The room creation is successful.
diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py
index 63df31ec75..c52a5b2e79 100644
--- a/tests/rest/client/test_sync.py
+++ b/tests/rest/client/test_sync.py
@@ -282,22 +282,33 @@ class SyncTypingTests(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code)
next_batch = channel.json_body["next_batch"]
- # This should time out! But it does not, because our stream token is
- # ahead, and therefore it's saying the typing (that we've actually
- # already seen) is new, since it's got a token above our new, now-reset
- # stream token.
- channel = self.make_request("GET", sync_url % (access_token, next_batch))
- self.assertEqual(200, channel.code)
- next_batch = channel.json_body["next_batch"]
-
# Clear the typing information, so that it doesn't think everything is
- # in the future.
+ # in the future. This happens automatically when the typing stream
+ # resets.
typing._reset()
- # Now it SHOULD fail as it never completes!
+ # Nothing new, so we time out.
with self.assertRaises(TimedOutException):
self.make_request("GET", sync_url % (access_token, next_batch))
+ # Sync and start typing again.
+ sync_channel = self.make_request(
+ "GET", sync_url % (access_token, next_batch), await_result=False
+ )
+ self.assertFalse(sync_channel.is_finished())
+
+ channel = self.make_request(
+ "PUT",
+ typing_url % (room, other_user_id, other_access_token),
+ b'{"typing": true, "timeout": 30000}',
+ )
+ self.assertEqual(200, channel.code)
+
+ # Sync should now return.
+ sync_channel.await_result()
+ self.assertEqual(200, sync_channel.code)
+ next_batch = sync_channel.json_body["next_batch"]
+
class SyncKnockTestCase(KnockingStrippedStateEventHelperMixin):
servlets = [
diff --git a/tests/rest/client/test_tags.py b/tests/rest/client/test_tags.py
new file mode 100644
index 0000000000..5d596409e1
--- /dev/null
+++ b/tests/rest/client/test_tags.py
@@ -0,0 +1,95 @@
+#
+# This file is licensed under the Affero General Public License (AGPL) version 3.
+#
+# Copyright (C) 2024 New Vector, Ltd
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# See the GNU Affero General Public License for more details:
+# <https://www.gnu.org/licenses/agpl-3.0.html>.
+#
+
+"""Tests REST events for /tags paths."""
+
+from http import HTTPStatus
+
+import synapse.rest.admin
+from synapse.rest.client import login, room, tags
+
+from tests import unittest
+
+
+class RoomTaggingTestCase(unittest.HomeserverTestCase):
+ """Tests /user/$user_id/rooms/$room_id/tags/$tag REST API."""
+
+ servlets = [
+ room.register_servlets,
+ tags.register_servlets,
+ login.register_servlets,
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ ]
+
+ def test_put_tag_checks_room_membership(self) -> None:
+ """
+        Test that a user can add a tag to a room if they have membership in the room.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+ tag = "test_tag"
+
+ # Make the request
+ channel = self.make_request(
+ "PUT",
+ f"/user/{user1_id}/rooms/{room_id}/tags/{tag}",
+ content={"order": 0.5},
+ access_token=user1_tok,
+ )
+ # Check that the request was successful
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+
+ def test_put_tag_fails_if_not_in_room(self) -> None:
+ """
+        Test that a user cannot add a tag to a room if they don't have membership in
+        the room.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+ # Create the room with user2 (user1 has no membership in the room)
+ room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
+ tag = "test_tag"
+
+ # Make the request
+ channel = self.make_request(
+ "PUT",
+ f"/user/{user1_id}/rooms/{room_id}/tags/{tag}",
+ content={"order": 0.5},
+ access_token=user1_tok,
+ )
+ # Check that the request failed with the correct error
+ self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result)
+
+ def test_put_tag_fails_if_room_does_not_exist(self) -> None:
+ """
+        Test that a user cannot add a tag to a room if the room doesn't exist
+        (and therefore they have no membership in it).
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ room_id = "!nonexistent:test"
+ tag = "test_tag"
+
+ # Make the request
+ channel = self.make_request(
+ "PUT",
+ f"/user/{user1_id}/rooms/{room_id}/tags/{tag}",
+ content={"order": 0.5},
+ access_token=user1_tok,
+ )
+ # Check that the request failed with the correct error
+ self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result)
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index d10df1a90f..f02317533e 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -915,162 +915,3 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
# Check that the mock was called with the right room ID
self.assertEqual(args[1], self.room_id)
-
- def test_on_threepid_bind(self) -> None:
- """Tests that the on_threepid_bind module callback is called correctly after
- associating a 3PID to an account.
- """
- # Register a mocked callback.
- threepid_bind_mock = AsyncMock(return_value=None)
- third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules
- third_party_rules._on_threepid_bind_callbacks.append(threepid_bind_mock)
-
- # Register an admin user.
- self.register_user("admin", "password", admin=True)
- admin_tok = self.login("admin", "password")
-
- # Also register a normal user we can modify.
- user_id = self.register_user("user", "password")
-
- # Add a 3PID to the user.
- channel = self.make_request(
- "PUT",
- "/_synapse/admin/v2/users/%s" % user_id,
- {
- "threepids": [
- {
- "medium": "email",
- "address": "foo@example.com",
- },
- ],
- },
- access_token=admin_tok,
- )
-
- # Check that the shutdown was blocked
- self.assertEqual(channel.code, 200, channel.json_body)
-
- # Check that the mock was called once.
- threepid_bind_mock.assert_called_once()
- args = threepid_bind_mock.call_args[0]
-
- # Check that the mock was called with the right parameters
- self.assertEqual(args, (user_id, "email", "foo@example.com"))
-
- def test_on_add_and_remove_user_third_party_identifier(self) -> None:
- """Tests that the on_add_user_third_party_identifier and
- on_remove_user_third_party_identifier module callbacks are called
- just before associating and removing a 3PID to/from an account.
- """
- # Pretend to be a Synapse module and register both callbacks as mocks.
- on_add_user_third_party_identifier_callback_mock = AsyncMock(return_value=None)
- on_remove_user_third_party_identifier_callback_mock = AsyncMock(
- return_value=None
- )
- self.hs.get_module_api().register_third_party_rules_callbacks(
- on_add_user_third_party_identifier=on_add_user_third_party_identifier_callback_mock,
- on_remove_user_third_party_identifier=on_remove_user_third_party_identifier_callback_mock,
- )
-
- # Register an admin user.
- self.register_user("admin", "password", admin=True)
- admin_tok = self.login("admin", "password")
-
- # Also register a normal user we can modify.
- user_id = self.register_user("user", "password")
-
- # Add a 3PID to the user.
- channel = self.make_request(
- "PUT",
- "/_synapse/admin/v2/users/%s" % user_id,
- {
- "threepids": [
- {
- "medium": "email",
- "address": "foo@example.com",
- },
- ],
- },
- access_token=admin_tok,
- )
-
- # Check that the mocked add callback was called with the appropriate
- # 3PID details.
- self.assertEqual(channel.code, 200, channel.json_body)
- on_add_user_third_party_identifier_callback_mock.assert_called_once()
- args = on_add_user_third_party_identifier_callback_mock.call_args[0]
- self.assertEqual(args, (user_id, "email", "foo@example.com"))
-
- # Now remove the 3PID from the user
- channel = self.make_request(
- "PUT",
- "/_synapse/admin/v2/users/%s" % user_id,
- {
- "threepids": [],
- },
- access_token=admin_tok,
- )
-
- # Check that the mocked remove callback was called with the appropriate
- # 3PID details.
- self.assertEqual(channel.code, 200, channel.json_body)
- on_remove_user_third_party_identifier_callback_mock.assert_called_once()
- args = on_remove_user_third_party_identifier_callback_mock.call_args[0]
- self.assertEqual(args, (user_id, "email", "foo@example.com"))
-
- def test_on_remove_user_third_party_identifier_is_called_on_deactivate(
- self,
- ) -> None:
- """Tests that the on_remove_user_third_party_identifier module callback is called
- when a user is deactivated and their third-party ID associations are deleted.
- """
- # Pretend to be a Synapse module and register both callbacks as mocks.
- on_remove_user_third_party_identifier_callback_mock = AsyncMock(
- return_value=None
- )
- self.hs.get_module_api().register_third_party_rules_callbacks(
- on_remove_user_third_party_identifier=on_remove_user_third_party_identifier_callback_mock,
- )
-
- # Register an admin user.
- self.register_user("admin", "password", admin=True)
- admin_tok = self.login("admin", "password")
-
- # Also register a normal user we can modify.
- user_id = self.register_user("user", "password")
-
- # Add a 3PID to the user.
- channel = self.make_request(
- "PUT",
- "/_synapse/admin/v2/users/%s" % user_id,
- {
- "threepids": [
- {
- "medium": "email",
- "address": "foo@example.com",
- },
- ],
- },
- access_token=admin_tok,
- )
- self.assertEqual(channel.code, 200, channel.json_body)
-
- # Check that the mock was not called on the act of adding a third-party ID.
- on_remove_user_third_party_identifier_callback_mock.assert_not_called()
-
- # Now deactivate the user.
- channel = self.make_request(
- "PUT",
- "/_synapse/admin/v2/users/%s" % user_id,
- {
- "deactivated": True,
- },
- access_token=admin_tok,
- )
-
- # Check that the mocked remove callback was called with the appropriate
- # 3PID details.
- self.assertEqual(channel.code, 200, channel.json_body)
- on_remove_user_third_party_identifier_callback_mock.assert_called_once()
- args = on_remove_user_third_party_identifier_callback_mock.call_args[0]
- self.assertEqual(args, (user_id, "email", "foo@example.com"))
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index e43140720d..280486da08 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -31,6 +31,7 @@ from typing import (
AnyStr,
Dict,
Iterable,
+ Literal,
Mapping,
MutableMapping,
Optional,
@@ -40,12 +41,11 @@ from typing import (
from urllib.parse import urlencode
import attr
-from typing_extensions import Literal
from twisted.test.proto_helpers import MemoryReactorClock
from twisted.web.server import Site
-from synapse.api.constants import Membership
+from synapse.api.constants import Membership, ReceiptTypes
from synapse.api.errors import Codes
from synapse.server import HomeServer
from synapse.types import JsonDict
@@ -330,22 +330,24 @@ class RestHelper:
data,
)
- assert (
- channel.code == expect_code
- ), "Expected: %d, got: %d, PUT %s -> resp: %r" % (
- expect_code,
- channel.code,
- path,
- channel.result["body"],
+ assert channel.code == expect_code, (
+ "Expected: %d, got: %d, PUT %s -> resp: %r"
+ % (
+ expect_code,
+ channel.code,
+ path,
+ channel.result["body"],
+ )
)
if expect_errcode:
- assert (
- str(channel.json_body["errcode"]) == expect_errcode
- ), "Expected: %r, got: %r, resp: %r" % (
- expect_errcode,
- channel.json_body["errcode"],
- channel.result["body"],
+ assert str(channel.json_body["errcode"]) == expect_errcode, (
+ "Expected: %r, got: %r, resp: %r"
+ % (
+ expect_errcode,
+ channel.json_body["errcode"],
+ channel.result["body"],
+ )
)
if expect_additional_fields is not None:
@@ -354,13 +356,14 @@ class RestHelper:
expect_key,
channel.json_body,
)
- assert (
- channel.json_body[expect_key] == expect_value
- ), "Expected: %s at %s, got: %s, resp: %s" % (
- expect_value,
- expect_key,
- channel.json_body[expect_key],
- channel.json_body,
+ assert channel.json_body[expect_key] == expect_value, (
+ "Expected: %s at %s, got: %s, resp: %s"
+ % (
+ expect_value,
+ expect_key,
+ channel.json_body[expect_key],
+ channel.json_body,
+ )
)
self.auth_user_id = temp_id
@@ -545,7 +548,7 @@ class RestHelper:
room_id: str,
event_type: str,
body: Dict[str, Any],
- tok: Optional[str],
+ tok: Optional[str] = None,
expect_code: int = HTTPStatus.OK,
state_key: str = "",
) -> JsonDict:
@@ -713,9 +716,9 @@ class RestHelper:
"/login",
content={"type": "m.login.token", "token": login_token},
)
- assert (
- channel.code == expected_status
- ), f"unexpected status in response: {channel.code}"
+ assert channel.code == expected_status, (
+ f"unexpected status in response: {channel.code}"
+ )
return channel.json_body
def auth_via_oidc(
@@ -886,7 +889,7 @@ class RestHelper:
"GET",
uri,
)
- assert channel.code == 302
+ assert channel.code == 302, f"Expected 302 for {uri}, got {channel.code}"
# hit the redirect url again with the right Host header, which should now issue
# a cookie and redirect to the SSO provider.
@@ -898,17 +901,18 @@ class RestHelper:
location = get_location(channel)
parts = urllib.parse.urlsplit(location)
+ next_uri = urllib.parse.urlunsplit(("", "") + parts[2:])
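+        # With an empty scheme and netloc, urlunsplit yields a relative URL
+        # (path, query and fragment only), so the request below is made
+        # against the test site.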
channel = make_request(
self.reactor,
self.site,
"GET",
- urllib.parse.urlunsplit(("", "") + parts[2:]),
+ next_uri,
custom_headers=[
("Host", parts[1]),
],
)
- assert channel.code == 302
+ assert channel.code == 302, f"Expected 302 for {next_uri}, got {channel.code}"
channel.extract_cookies(cookies)
return get_location(channel)
@@ -944,3 +948,15 @@ class RestHelper:
assert len(p.links) == 1, "not exactly one link in confirmation page"
oauth_uri = p.links[0]
return oauth_uri
+
+ def send_read_receipt(self, room_id: str, event_id: str, *, tok: str) -> None:
+ """Send a read receipt into the room at the given event"""
+ channel = make_request(
+ self.reactor,
+ self.site,
+ method="POST",
+ path=f"/rooms/{room_id}/receipt/{ReceiptTypes.READ}/{event_id}",
+ content={},
+ access_token=tok,
+ )
+ assert channel.code == HTTPStatus.OK, channel.text_body
diff --git a/tests/rest/media/test_domain_blocking.py b/tests/rest/media/test_domain_blocking.py
index 72205c6bb3..26453f70dd 100644
--- a/tests/rest/media/test_domain_blocking.py
+++ b/tests/rest/media/test_domain_blocking.py
@@ -61,6 +61,7 @@ class MediaDomainBlockingTests(unittest.HomeserverTestCase):
time_now_ms=clock.time_msec(),
upload_name="test.png",
filesystem_id=file_id,
+ sha256=file_id,
)
)
@@ -91,7 +92,8 @@ class MediaDomainBlockingTests(unittest.HomeserverTestCase):
{
# Disable downloads from a domain we won't be requesting downloads from.
# This proves we haven't broken anything.
- "prevent_media_downloads_from": ["not-listed.com"]
+ "prevent_media_downloads_from": ["not-listed.com"],
+ "enable_authenticated_media": False,
}
)
def test_remote_media_normally_unblocked(self) -> None:
@@ -132,6 +134,7 @@ class MediaDomainBlockingTests(unittest.HomeserverTestCase):
# This proves we haven't broken anything.
"prevent_media_downloads_from": ["not-listed.com"],
"dynamic_thumbnails": True,
+ "enable_authenticated_media": False,
}
)
def test_remote_media_thumbnail_normally_unblocked(self) -> None:
diff --git a/tests/rest/media/test_url_preview.py b/tests/rest/media/test_url_preview.py
index a96f0e7fca..2a7bee19f9 100644
--- a/tests/rest/media/test_url_preview.py
+++ b/tests/rest/media/test_url_preview.py
@@ -42,6 +42,7 @@ from synapse.util.stringutils import parse_and_validate_mxc_uri
from tests import unittest
from tests.server import FakeTransport
from tests.test_utils import SMALL_PNG
+from tests.unittest import override_config
try:
import lxml
@@ -877,7 +878,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
data = base64.b64encode(SMALL_PNG)
end_content = (
- b"<html><head>" b'<img src="data:image/png;base64,%s" />' b"</head></html>"
+ b'<html><head><img src="data:image/png;base64,%s" /></head></html>'
) % (data,)
channel = self.make_request(
@@ -1259,6 +1260,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertIsNone(_port)
return host, media_id
+ @override_config({"enable_authenticated_media": False})
def test_storage_providers_exclude_files(self) -> None:
"""Test that files are not stored in or fetched from storage providers."""
host, media_id = self._download_image()
@@ -1301,6 +1303,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
"URL cache file was unexpectedly retrieved from a storage provider",
)
+ @override_config({"enable_authenticated_media": False})
def test_storage_providers_exclude_thumbnails(self) -> None:
"""Test that thumbnails are not stored in or fetched from storage providers."""
host, media_id = self._download_image()
diff --git a/tests/rest/test_well_known.py b/tests/rest/test_well_known.py
index e166c13bc1..96a4f5598e 100644
--- a/tests/rest/test_well_known.py
+++ b/tests/rest/test_well_known.py
@@ -17,6 +17,8 @@
# [This file includes modifications made by New Vector Limited]
#
#
+from unittest.mock import AsyncMock
+
from twisted.web.resource import Resource
from synapse.rest.well_known import well_known_resource
@@ -35,7 +37,6 @@ class WellKnownTests(unittest.HomeserverTestCase):
@unittest.override_config(
{
"public_baseurl": "https://tesths",
- "default_identity_server": "https://testis",
}
)
def test_client_well_known(self) -> None:
@@ -48,7 +49,6 @@ class WellKnownTests(unittest.HomeserverTestCase):
channel.json_body,
{
"m.homeserver": {"base_url": "https://tesths/"},
- "m.identity_server": {"base_url": "https://testis"},
},
)
@@ -67,7 +67,6 @@ class WellKnownTests(unittest.HomeserverTestCase):
@unittest.override_config(
{
"public_baseurl": "https://tesths",
- "default_identity_server": "https://testis",
"extra_well_known_client_content": {"custom": False},
}
)
@@ -81,7 +80,6 @@ class WellKnownTests(unittest.HomeserverTestCase):
channel.json_body,
{
"m.homeserver": {"base_url": "https://tesths/"},
- "m.identity_server": {"base_url": "https://testis"},
"custom": False,
},
)
@@ -112,7 +110,6 @@ class WellKnownTests(unittest.HomeserverTestCase):
"msc3861": {
"enabled": True,
"issuer": "https://issuer",
- "account_management_url": "https://my-account.issuer",
"client_id": "id",
"client_auth_method": "client_secret_post",
"client_secret": "secret",
@@ -122,18 +119,33 @@ class WellKnownTests(unittest.HomeserverTestCase):
}
)
def test_client_well_known_msc3861_oauth_delegation(self) -> None:
- channel = self.make_request(
- "GET", "/.well-known/matrix/client", shorthand=False
+ # Patch the HTTP client to return the issuer metadata
+ req_mock = AsyncMock(
+ return_value={
+ "issuer": "https://issuer",
+ "account_management_uri": "https://my-account.issuer",
+ }
)
+ self.hs.get_proxied_http_client().get_json = req_mock # type: ignore[method-assign]
- self.assertEqual(channel.code, 200)
- self.assertEqual(
- channel.json_body,
- {
- "m.homeserver": {"base_url": "https://homeserver/"},
- "org.matrix.msc2965.authentication": {
- "issuer": "https://issuer",
- "account": "https://my-account.issuer",
+ for _ in range(2):
+ channel = self.make_request(
+ "GET", "/.well-known/matrix/client", shorthand=False
+ )
+
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(
+ channel.json_body,
+ {
+ "m.homeserver": {"base_url": "https://homeserver/"},
+ "org.matrix.msc2965.authentication": {
+ "issuer": "https://issuer",
+ "account": "https://my-account.issuer",
+ },
},
- },
+ )
+
+ # It should have been called exactly once, because it gets cached
+ req_mock.assert_called_once_with(
+ "https://issuer/.well-known/openid-configuration"
)
diff --git a/tests/server.py b/tests/server.py
index 3e377585ce..f01708b77f 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -58,6 +58,7 @@ import twisted
from twisted.enterprise import adbapi
from twisted.internet import address, tcp, threads, udp
from twisted.internet._resolver import SimpleResolverComplexifier
+from twisted.internet.address import IPv4Address, IPv6Address
from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed
from twisted.internet.error import DNSLookupError
from twisted.internet.interfaces import (
@@ -73,6 +74,7 @@ from twisted.internet.interfaces import (
IReactorPluggableNameResolver,
IReactorTime,
IResolverSimple,
+ ITCPTransport,
ITransport,
)
from twisted.internet.protocol import ClientFactory, DatagramProtocol, Factory
@@ -223,9 +225,9 @@ class FakeChannel:
new_headers.addRawHeader(k, v)
headers = new_headers
- assert isinstance(
- headers, Headers
- ), f"headers are of the wrong type: {headers!r}"
+ assert isinstance(headers, Headers), (
+ f"headers are of the wrong type: {headers!r}"
+ )
self.result["headers"] = headers
@@ -341,7 +343,6 @@ class FakeSite:
self,
resource: IResource,
reactor: IReactorTime,
- experimental_cors_msc3886: bool = False,
):
"""
@@ -350,7 +351,6 @@ class FakeSite:
"""
self._resource = resource
self.reactor = reactor
- self.experimental_cors_msc3886 = experimental_cors_msc3886
def getResourceFor(self, request: Request) -> IResource:
return self._resource
@@ -780,7 +780,7 @@ def get_clock() -> Tuple[ThreadedMemoryReactorClock, Clock]:
return clock, hs_clock
-@implementer(ITransport)
+@implementer(ITCPTransport)
@attr.s(cmp=False, auto_attribs=True)
class FakeTransport:
"""
@@ -809,12 +809,12 @@ class FakeTransport:
will get called back for connectionLost() notifications etc.
"""
- _peer_address: IAddress = attr.Factory(
+ _peer_address: Union[IPv4Address, IPv6Address] = attr.Factory(
lambda: address.IPv4Address("TCP", "127.0.0.1", 5678)
)
"""The value to be returned by getPeer"""
- _host_address: IAddress = attr.Factory(
+ _host_address: Union[IPv4Address, IPv6Address] = attr.Factory(
lambda: address.IPv4Address("TCP", "127.0.0.1", 1234)
)
"""The value to be returned by getHost"""
@@ -826,10 +826,10 @@ class FakeTransport:
producer: Optional[IPushProducer] = None
autoflush: bool = True
- def getPeer(self) -> IAddress:
+ def getPeer(self) -> Union[IPv4Address, IPv6Address]:
return self._peer_address
- def getHost(self) -> IAddress:
+ def getHost(self) -> Union[IPv4Address, IPv6Address]:
return self._host_address
def loseConnection(self) -> None:
@@ -939,6 +939,51 @@ class FakeTransport:
logger.info("FakeTransport: Buffer now empty, completing disconnect")
self.disconnected = True
+ ## ITCPTransport methods. ##
+
+ def loseWriteConnection(self) -> None:
+ """
+ Half-close the write side of a TCP connection.
+
+ If the protocol instance this is attached to provides
+ IHalfCloseableProtocol, it will get notified when the operation is
+        done. When closing the write connection, as with loseConnection this
+        will only happen once the buffer has emptied and there is no registered
+ producer.
+ """
+ raise NotImplementedError()
+
+ def getTcpNoDelay(self) -> bool:
+ """
+ Return if C{TCP_NODELAY} is enabled.
+ """
+ return False
+
+ def setTcpNoDelay(self, enabled: bool) -> None:
+ """
+ Enable/disable C{TCP_NODELAY}.
+
+ Enabling C{TCP_NODELAY} turns off Nagle's algorithm. Small packets are
+ sent sooner, possibly at the expense of overall throughput.
+ """
+ # Ignore setting this.
+
+ def getTcpKeepAlive(self) -> bool:
+ """
+ Return if C{SO_KEEPALIVE} is enabled.
+ """
+ return False
+
+ def setTcpKeepAlive(self, enabled: bool) -> None:
+ """
+ Enable/disable C{SO_KEEPALIVE}.
+
+ Enabling C{SO_KEEPALIVE} sends packets periodically when the connection
+ is otherwise idle, usually once every two hours. They are intended
+ to allow detection of lost peers in a non-infinite amount of time.
+ """
+ # Ignore setting this.
+
def connect_client(
reactor: ThreadedMemoryReactorClock, client_id: int
@@ -1166,6 +1211,12 @@ def setup_test_homeserver(
hs.get_auth_handler().validate_hash = validate_hash # type: ignore[assignment]
+ # We need to replace the media threadpool with the fake test threadpool.
+ def thread_pool() -> threadpool.ThreadPool:
+ return reactor.getThreadPool()
+
+ hs.get_media_sender_thread_pool = thread_pool # type: ignore[method-assign]
+
# Load any configured modules into the homeserver
module_api = hs.get_module_api()
for module, module_config in hs.config.modules.loaded_modules:
diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index 0e3e4f7293..997ee7b91b 100644
--- a/tests/server_notices/test_resource_limits_server_notices.py
+++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -89,7 +89,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
return_value="!something:localhost"
)
self._rlsn._store.add_tag_to_room = AsyncMock(return_value=None) # type: ignore[method-assign]
- self._rlsn._store.get_tags_for_room = AsyncMock(return_value={}) # type: ignore[method-assign]
+ self._rlsn._store.get_tags_for_room = AsyncMock(return_value={})
@override_config({"hs_disabled": True})
def test_maybe_send_server_notice_disabled_hs(self) -> None:
diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py
index fd1f5e7fd5..104d141a72 100644
--- a/tests/storage/databases/main/test_events_worker.py
+++ b/tests/storage/databases/main/test_events_worker.py
@@ -20,7 +20,7 @@
#
import json
from contextlib import contextmanager
-from typing import Generator, List, Tuple
+from typing import Generator, List, Set, Tuple
from unittest import mock
from twisted.enterprise.adbapi import ConnectionPool
@@ -295,6 +295,53 @@ class EventCacheTestCase(unittest.HomeserverTestCase):
self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1)
+class GetEventsTestCase(unittest.HomeserverTestCase):
+ """Test `get_events(...)`/`get_events_as_list(...)`"""
+
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.store: EventsWorkerStore = hs.get_datastores().main
+
+ def test_get_lots_of_messages(self) -> None:
+ """Sanity check that `get_events(...)`/`get_events_as_list(...)` works"""
+ num_events = 100
+
+ user_id = self.register_user("user", "pass")
+ user_tok = self.login(user_id, "pass")
+
+ room_id = self.helper.create_room_as(user_id, tok=user_tok)
+
+ event_ids: Set[str] = set()
+ for i in range(num_events):
+ event = self.get_success(
+ inject_event(
+ self.hs,
+ room_id=room_id,
+ type="m.room.message",
+ sender=user_id,
+ content={
+ "body": f"foo{i}",
+ "msgtype": "m.text",
+ },
+ )
+ )
+ event_ids.add(event.event_id)
+
+ # Sanity check that we actually created the events
+ self.assertEqual(len(event_ids), num_events)
+
+ # This is the function under test
+ fetched_event_map = self.get_success(self.store.get_events(event_ids))
+
+ # Sanity check that we got the events back
+ self.assertIncludes(fetched_event_map.keys(), event_ids, exact=True)
+
+
class DatabaseOutageTestCase(unittest.HomeserverTestCase):
"""Test event fetching during a database outage."""
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index 506d981ce6..49dc973a36 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -112,6 +112,24 @@ class UpdateUpsertManyTests(unittest.HomeserverTestCase):
{(1, "user1", "hello"), (2, "user2", "bleb")},
)
+ self.get_success(
+ self.storage.db_pool.runInteraction(
+ "test",
+ self.storage.db_pool.simple_upsert_many_txn,
+ self.table_name,
+ key_names=key_names,
+ key_values=[[2, "user2"]],
+ value_names=[],
+ value_values=[],
+ )
+ )
+
+ # Check results are what we expect
+ self.assertEqual(
+ set(self._dump_table_to_tuple()),
+ {(1, "user1", "hello"), (2, "user2", "bleb")},
+ )
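
The new assertion pins down the edge case of simple_upsert_many_txn with empty value_names/value_values. A standalone sqlite3 sketch (hypothetical table, not Synapse's DatabasePool) of what such an upsert reduces to, namely an insert that does nothing on conflict:

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE t (id INTEGER, username TEXT, value TEXT, UNIQUE (id, username))"
)
conn.execute("INSERT INTO t VALUES (2, 'user2', 'bleb')")

# Upsert with key columns only: there is nothing to update on conflict.
conn.execute(
    "INSERT INTO t (id, username) VALUES (2, 'user2') "
    "ON CONFLICT (id, username) DO NOTHING"
)

# The pre-existing row is untouched, mirroring the assertion above.
assert conn.execute("SELECT value FROM t WHERE id = 2").fetchone() == ("bleb",)
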
+
def test_simple_update_many(self) -> None:
"""
simple_update_many performs many updates at once.
diff --git a/tests/storage/test_account_data.py b/tests/storage/test_account_data.py
index 2859bcf4bd..0e52dd26ce 100644
--- a/tests/storage/test_account_data.py
+++ b/tests/storage/test_account_data.py
@@ -24,6 +24,7 @@ from typing import Iterable, Optional, Set
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import AccountDataTypes
+from synapse.api.errors import Codes, SynapseError
from synapse.server import HomeServer
from synapse.util import Clock
@@ -93,6 +94,20 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase):
# Check the removed user.
self.assert_ignorers("@another:remote", {self.user})
+ def test_ignoring_self_fails(self) -> None:
+ """Ensure users cannot add themselves to the ignored list."""
+
+ f = self.get_failure(
+ self.store.add_account_data_for_user(
+ self.user,
+ AccountDataTypes.IGNORED_USER_LIST,
+ {"ignored_users": {self.user: {}}},
+ ),
+ SynapseError,
+ ).value
+ self.assertEqual(f.code, 400)
+ self.assertEqual(f.errcode, Codes.INVALID_PARAM)
+
def test_caching(self) -> None:
"""Ensure that caching works properly between different users."""
# The first user ignores a user.
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index 9420d03841..11313fc933 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -349,7 +349,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
)
self.mock_txn.execute.assert_called_once_with(
- "UPDATE tablename SET colC = ?, colD = ? WHERE" " colA = ? AND colB = ?",
+ "UPDATE tablename SET colC = ?, colD = ? WHERE colA = ? AND colB = ?",
[3, 4, 1, 2],
)
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index ba01b038ab..74edca7523 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -211,9 +211,9 @@ class DeviceStoreTestCase(HomeserverTestCase):
even if that means leaving an earlier batch one EDU short of the limit.
"""
- assert self.hs.is_mine_id(
- "@user_id:test"
- ), "Test not valid: this MXID should be considered local"
+ assert self.hs.is_mine_id("@user_id:test"), (
+ "Test not valid: this MXID should be considered local"
+ )
self.get_success(
self.store.set_e2e_cross_signing_key(
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 088f0d24f9..0500c68e9d 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -114,7 +114,7 @@ def get_all_topologically_sorted_orders(
# This is implemented by Kahn's algorithm, and forking execution each time
# we have a choice over which node to consider next.
- degree_map = {node: 0 for node in nodes}
+ degree_map = dict.fromkeys(nodes, 0)
reverse_graph: Dict[T, Set[T]] = {}
for node, edges in graph.items():
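
The dict.fromkeys swap initialises the in-degree map that drives the forking Kahn's algorithm described in the comment. A self-contained sketch of the idea (independent of this helper's actual signature):

from typing import Dict, Iterator, List, Set

def all_topo_orders(graph: Dict[str, Set[str]]) -> Iterator[List[str]]:
    """Yield every ordering of graph's nodes that respects its edges.

    `graph` maps each node to the set of nodes it has an edge to; an edge
    a -> b means a must come before b.
    """
    indegree = dict.fromkeys(graph, 0)
    for edges in graph.values():
        for dest in edges:
            indegree[dest] += 1

    def fork(order: List[str]) -> Iterator[List[str]]:
        if len(order) == len(graph):
            yield order
            return
        # Fork on every node that is currently unblocked and not yet placed.
        for node in [n for n, d in indegree.items() if d == 0 and n not in order]:
            for dest in graph[node]:
                indegree[dest] -= 1
            yield from fork(order + [node])
            for dest in graph[node]:
                indegree[dest] += 1

    yield from fork([])

# {a -> b, a -> c} admits exactly two orders.
assert sorted(all_topo_orders({"a": {"b", "c"}, "b": set(), "c": set()})) == [
    ["a", "b", "c"],
    ["a", "c", "b"],
]
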
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index 233066bf82..b095090535 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -101,14 +101,6 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
)
self.assertEqual(2, len(http_actions))
- # Fetch unread actions for email pushers.
- email_actions = self.get_success(
- self.store.get_unread_push_actions_for_user_in_range_for_email(
- user_id, 0, 1000, 20
- )
- )
- self.assertEqual(2, len(email_actions))
-
# Send a receipt, which should clear the first action.
self.get_success(
self.store.insert_receipt(
@@ -126,12 +118,6 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
)
)
self.assertEqual(1, len(http_actions))
- email_actions = self.get_success(
- self.store.get_unread_push_actions_for_user_in_range_for_email(
- user_id, 0, 1000, 20
- )
- )
- self.assertEqual(1, len(email_actions))
# Send a thread receipt to clear the thread action.
self.get_success(
@@ -150,12 +136,6 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
)
)
self.assertEqual([], http_actions)
- email_actions = self.get_success(
- self.store.get_unread_push_actions_for_user_in_range_for_email(
- user_id, 0, 1000, 20
- )
- )
- self.assertEqual([], email_actions)
def test_count_aggregation(self) -> None:
# Create a user to receive notifications and send receipts.
diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py
index 0a7c4c9421..cb3d8e19bc 100644
--- a/tests/storage/test_events.py
+++ b/tests/storage/test_events.py
@@ -19,6 +19,7 @@
#
#
+import logging
from typing import List, Optional
from twisted.test.proto_helpers import MemoryReactor
@@ -35,6 +36,8 @@ from synapse.util import Clock
from tests.unittest import HomeserverTestCase
+logger = logging.getLogger(__name__)
+
class ExtremPruneTestCase(HomeserverTestCase):
servlets = [
diff --git a/tests/storage/test_events_bg_updates.py b/tests/storage/test_events_bg_updates.py
new file mode 100644
index 0000000000..ecdf413e3b
--- /dev/null
+++ b/tests/storage/test_events_bg_updates.py
@@ -0,0 +1,157 @@
+#
+# This file is licensed under the Affero General Public License (AGPL) version 3.
+#
+# Copyright (C) 2025 New Vector, Ltd
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# See the GNU Affero General Public License for more details:
+# <https://www.gnu.org/licenses/agpl-3.0.html>.
+#
+#
+
+from typing import Dict
+
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.api.constants import MAX_DEPTH
+from synapse.api.room_versions import RoomVersion, RoomVersions
+from synapse.server import HomeServer
+from synapse.util import Clock
+
+from tests.unittest import HomeserverTestCase
+
+
+class TestFixupMaxDepthCapBgUpdate(HomeserverTestCase):
+ """Test the background update that caps topological_ordering at MAX_DEPTH."""
+
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+ ) -> None:
+ self.store = self.hs.get_datastores().main
+ self.db_pool = self.store.db_pool
+
+ self.room_id = "!testroom:example.com"
+
+ # Reinsert the background update as it was already run at the start of
+ # the test.
+ self.get_success(
+ self.db_pool.simple_insert(
+ table="background_updates",
+ values={
+ "update_name": "fixup_max_depth_cap",
+ "progress_json": "{}",
+ },
+ )
+ )
+
+ def create_room(self, room_version: RoomVersion) -> Dict[str, int]:
+ """Create a room with a known room version and insert events.
+
+ Returns a map from each inserted event ID to its depth; some of the
+ depths exceed MAX_DEPTH.
+ """
+
+ # Create a room with a specific room version
+ self.get_success(
+ self.db_pool.simple_insert(
+ table="rooms",
+ values={
+ "room_id": self.room_id,
+ "room_version": room_version.identifier,
+ },
+ )
+ )
+
+ # Insert events with some depths exceeding MAX_DEPTH
+ event_id_to_depth: Dict[str, int] = {}
+ for depth in range(MAX_DEPTH - 5, MAX_DEPTH + 5):
+ event_id = f"$event{depth}:example.com"
+ event_id_to_depth[event_id] = depth
+
+ self.get_success(
+ self.db_pool.simple_insert(
+ table="events",
+ values={
+ "event_id": event_id,
+ "room_id": self.room_id,
+ "topological_ordering": depth,
+ "depth": depth,
+ "type": "m.test",
+ "sender": "@user:test",
+ "processed": True,
+ "outlier": False,
+ },
+ )
+ )
+
+ return event_id_to_depth
+
+ def test_fixup_max_depth_cap_bg_update(self) -> None:
+ """Test that the background update correctly caps topological_ordering
+ at MAX_DEPTH."""
+
+ event_id_to_depth = self.create_room(RoomVersions.V6)
+
+ # Run the background update
+ progress = {"room_id": ""}
+ batch_size = 10
+ num_rooms = self.get_success(
+ self.store.fixup_max_depth_cap_bg_update(progress, batch_size)
+ )
+
+ # Verify the number of rooms processed
+ self.assertEqual(num_rooms, 1)
+
+ # Verify that the topological_ordering of events has been capped at
+ # MAX_DEPTH
+ rows = self.get_success(
+ self.db_pool.simple_select_list(
+ table="events",
+ keyvalues={"room_id": self.room_id},
+ retcols=["event_id", "topological_ordering"],
+ )
+ )
+
+ for event_id, topological_ordering in rows:
+ if event_id_to_depth[event_id] >= MAX_DEPTH:
+ # Events with a depth greater than or equal to MAX_DEPTH should
+ # be capped at MAX_DEPTH.
+ self.assertEqual(topological_ordering, MAX_DEPTH)
+ else:
+ # Events with a depth less than MAX_DEPTH should remain
+ # unchanged.
+ self.assertEqual(topological_ordering, event_id_to_depth[event_id])
+
+ def test_fixup_max_depth_cap_bg_update_old_room_version(self) -> None:
+ """Test that the background update does not cap topological_ordering for
+ rooms with old room versions."""
+
+ event_id_to_depth = self.create_room(RoomVersions.V5)
+
+ # Run the background update
+ progress = {"room_id": ""}
+ batch_size = 10
+ num_rooms = self.get_success(
+ self.store.fixup_max_depth_cap_bg_update(progress, batch_size)
+ )
+
+ # Verify the number of rooms processed
+ self.assertEqual(num_rooms, 0)
+
+ # Fetch the topological_ordering of the events in the room
+ rows = self.get_success(
+ self.db_pool.simple_select_list(
+ table="events",
+ keyvalues={"room_id": self.room_id},
+ retcols=["event_id", "topological_ordering"],
+ )
+ )
+
+ # Assert that the topological_ordering of events has not been changed
+ # from their depth.
+ self.assertDictEqual(event_id_to_depth, dict(rows))
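
For orientation, the capping behaviour the first test verifies can be pictured as a single UPDATE; a toy sqlite3 sketch, with a tiny stand-in MAX_DEPTH and without the room-version gate the second test covers:

import sqlite3

MAX_DEPTH = 100  # tiny stand-in for the sketch; Synapse's real cap is far larger

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE events (event_id TEXT, topological_ordering INTEGER)")
conn.executemany(
    "INSERT INTO events VALUES (?, ?)",
    [("$shallow", 42), ("$at_cap", MAX_DEPTH), ("$too_deep", MAX_DEPTH + 7)],
)

# Cap anything above MAX_DEPTH; rows at or below it are left unchanged.
conn.execute(
    "UPDATE events SET topological_ordering = ? WHERE topological_ordering > ?",
    (MAX_DEPTH, MAX_DEPTH),
)

assert dict(conn.execute("SELECT event_id, topological_ordering FROM events")) == {
    "$shallow": 42,
    "$at_cap": MAX_DEPTH,
    "$too_deep": MAX_DEPTH,
}
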
diff --git a/tests/storage/test_invite_rule.py b/tests/storage/test_invite_rule.py
new file mode 100644
index 0000000000..38c97ecaa3
--- /dev/null
+++ b/tests/storage/test_invite_rule.py
@@ -0,0 +1,167 @@
+from synapse.storage.invite_rule import InviteRule, InviteRulesConfig
+from synapse.types import UserID
+
+from tests import unittest
+
+regular_user = UserID.from_string("@test:example.org")
+allowed_user = UserID.from_string("@allowed:allow.example.org")
+blocked_user = UserID.from_string("@blocked:block.example.org")
+ignored_user = UserID.from_string("@ignored:ignore.example.org")
+
+
+class InviteFilterTestCase(unittest.TestCase):
+ def test_empty(self) -> None:
+ """Permit by default"""
+ config = InviteRulesConfig(None)
+ self.assertEqual(
+ config.get_invite_rule(regular_user.to_string()), InviteRule.ALLOW
+ )
+
+ def test_ignore_invalid(self) -> None:
+ """Invalid strings are ignored"""
+ config = InviteRulesConfig({"blocked_users": ["not a user"]})
+ self.assertEqual(
+ config.get_invite_rule(blocked_user.to_string()), InviteRule.ALLOW
+ )
+
+ def test_user_blocked(self) -> None:
+ """Permit all, except explicitly blocked users"""
+ config = InviteRulesConfig({"blocked_users": [blocked_user.to_string()]})
+ self.assertEqual(
+ config.get_invite_rule(blocked_user.to_string()), InviteRule.BLOCK
+ )
+ self.assertEqual(
+ config.get_invite_rule(regular_user.to_string()), InviteRule.ALLOW
+ )
+
+ def test_user_ignored(self) -> None:
+ """Permit all, except explicitly ignored users"""
+ config = InviteRulesConfig({"ignored_users": [ignored_user.to_string()]})
+ self.assertEqual(
+ config.get_invite_rule(ignored_user.to_string()), InviteRule.IGNORE
+ )
+ self.assertEqual(
+ config.get_invite_rule(regular_user.to_string()), InviteRule.ALLOW
+ )
+
+ def test_user_precedence(self) -> None:
+ """Always take allowed over ignored, ignored over blocked, and then block."""
+ config = InviteRulesConfig(
+ {
+ "allowed_users": [allowed_user.to_string()],
+ "ignored_users": [allowed_user.to_string(), ignored_user.to_string()],
+ "blocked_users": [
+ allowed_user.to_string(),
+ ignored_user.to_string(),
+ blocked_user.to_string(),
+ ],
+ }
+ )
+ self.assertEqual(
+ config.get_invite_rule(allowed_user.to_string()), InviteRule.ALLOW
+ )
+ self.assertEqual(
+ config.get_invite_rule(ignored_user.to_string()), InviteRule.IGNORE
+ )
+ self.assertEqual(
+ config.get_invite_rule(blocked_user.to_string()), InviteRule.BLOCK
+ )
+
+ def test_server_blocked(self) -> None:
+ """Block all users on the server except those allowed."""
+ user_on_same_server = UserID("blocked", allowed_user.domain)
+ config = InviteRulesConfig(
+ {
+ "allowed_users": [allowed_user.to_string()],
+ "blocked_servers": [allowed_user.domain],
+ }
+ )
+ self.assertEqual(
+ config.get_invite_rule(allowed_user.to_string()), InviteRule.ALLOW
+ )
+ self.assertEqual(
+ config.get_invite_rule(user_on_same_server.to_string()), InviteRule.BLOCK
+ )
+
+ def test_server_ignored(self) -> None:
+ """Ignore all users on the server except those allowed."""
+ user_on_same_server = UserID("ignored", allowed_user.domain)
+ config = InviteRulesConfig(
+ {
+ "allowed_users": [allowed_user.to_string()],
+ "ignored_servers": [allowed_user.domain],
+ }
+ )
+ self.assertEqual(
+ config.get_invite_rule(allowed_user.to_string()), InviteRule.ALLOW
+ )
+ self.assertEqual(
+ config.get_invite_rule(user_on_same_server.to_string()), InviteRule.IGNORE
+ )
+
+ def test_server_allow(self) -> None:
+ """Allow all from a server except explictly blocked or ignored users."""
+ blocked_user_on_same_server = UserID("blocked", allowed_user.domain)
+ ignored_user_on_same_server = UserID("ignored", allowed_user.domain)
+ allowed_user_on_same_server = UserID("another", allowed_user.domain)
+ config = InviteRulesConfig(
+ {
+ "ignored_users": [ignored_user_on_same_server.to_string()],
+ "blocked_users": [blocked_user_on_same_server.to_string()],
+ "allowed_servers": [allowed_user.to_string()],
+ }
+ )
+ self.assertEqual(
+ config.get_invite_rule(allowed_user.to_string()), InviteRule.ALLOW
+ )
+ self.assertEqual(
+ config.get_invite_rule(allowed_user_on_same_server.to_string()),
+ InviteRule.ALLOW,
+ )
+ self.assertEqual(
+ config.get_invite_rule(blocked_user_on_same_server.to_string()),
+ InviteRule.BLOCK,
+ )
+ self.assertEqual(
+ config.get_invite_rule(ignored_user_on_same_server.to_string()),
+ InviteRule.IGNORE,
+ )
+
+ def test_server_precedence(self) -> None:
+ """Always take allowed over ignored, ignored over blocked, and then block."""
+ config = InviteRulesConfig(
+ {
+ "allowed_servers": [allowed_user.domain],
+ "ignored_servers": [allowed_user.domain, ignored_user.domain],
+ "blocked_servers": [
+ allowed_user.domain,
+ ignored_user.domain,
+ blocked_user.domain,
+ ],
+ }
+ )
+ self.assertEqual(
+ config.get_invite_rule(allowed_user.to_string()), InviteRule.ALLOW
+ )
+ self.assertEqual(
+ config.get_invite_rule(ignored_user.to_string()), InviteRule.IGNORE
+ )
+ self.assertEqual(
+ config.get_invite_rule(blocked_user.to_string()), InviteRule.BLOCK
+ )
+
+ def test_server_glob(self) -> None:
+ """Test that glob patterns match"""
+ config = InviteRulesConfig({"blocked_servers": ["*.example.org"]})
+ self.assertEqual(
+ config.get_invite_rule(allowed_user.to_string()), InviteRule.BLOCK
+ )
+ self.assertEqual(
+ config.get_invite_rule(ignored_user.to_string()), InviteRule.BLOCK
+ )
+ self.assertEqual(
+ config.get_invite_rule(blocked_user.to_string()), InviteRule.BLOCK
+ )
+ self.assertEqual(
+ config.get_invite_rule(regular_user.to_string()), InviteRule.ALLOW
+ )
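
Taken together, these tests pin down a precedence order: user rules are consulted before server rules, and within each scope allow beats ignore beats block, with a default of allow. A condensed sketch of that ordering (hypothetical names, not Synapse's InviteRulesConfig):

from enum import Enum
from fnmatch import fnmatch
from typing import Sequence, Tuple

class Rule(Enum):
    ALLOW = "allow"
    IGNORE = "ignore"
    BLOCK = "block"

def invite_rule_for(
    user_id: str,
    user_rules: Sequence[Tuple[Rule, str]],
    server_rules: Sequence[Tuple[Rule, str]],
) -> Rule:
    # The caller lists rules in ALLOW, IGNORE, BLOCK order within each scope;
    # the first match wins, so user rules beat server rules and ALLOW beats
    # IGNORE beats BLOCK.
    server = user_id.split(":", 1)[1]
    for rule, pattern in user_rules:
        if fnmatch(user_id, pattern):
            return rule
    for rule, pattern in server_rules:
        if fnmatch(server, pattern):
            return rule
    return Rule.ALLOW  # permit by default

# A user-level ALLOW wins over a server-level BLOCK, as in test_server_blocked.
assert invite_rule_for(
    "@allowed:allow.example.org",
    [(Rule.ALLOW, "@allowed:allow.example.org")],
    [(Rule.BLOCK, "*.example.org")],
) is Rule.ALLOW
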
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index 15ae582051..c453c8b642 100644
--- a/tests/storage/test_monthly_active_users.py
+++ b/tests/storage/test_monthly_active_users.py
@@ -32,13 +32,6 @@ from tests.unittest import default_config, override_config
FORTY_DAYS = 40 * 24 * 60 * 60
-def gen_3pids(count: int) -> List[Dict[str, Any]]:
- """Generate `count` threepids as a list."""
- return [
- {"medium": "email", "address": "user%i@matrix.org" % i} for i in range(count)
- ]
-
-
class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
def default_config(self) -> Dict[str, Any]:
config = default_config("test")
@@ -57,87 +50,6 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
# Advance the clock a bit
self.reactor.advance(FORTY_DAYS)
- @override_config({"max_mau_value": 3, "mau_limit_reserved_threepids": gen_3pids(3)})
- def test_initialise_reserved_users(self) -> None:
- threepids = self.hs.config.server.mau_limits_reserved_threepids
-
- # register three users, of which two have reserved 3pids, and a third
- # which is a support user.
- user1 = "@user1:server"
- user1_email = threepids[0]["address"]
- user2 = "@user2:server"
- user2_email = threepids[1]["address"]
- user3 = "@user3:server"
-
- self.get_success(self.store.register_user(user_id=user1))
- self.get_success(self.store.register_user(user_id=user2))
- self.get_success(
- self.store.register_user(user_id=user3, user_type=UserTypes.SUPPORT)
- )
-
- now = int(self.hs.get_clock().time_msec())
- self.get_success(
- self.store.user_add_threepid(user1, "email", user1_email, now, now)
- )
- self.get_success(
- self.store.user_add_threepid(user2, "email", user2_email, now, now)
- )
-
- # XXX why are we doing this here? this function is only run at startup
- # so it is odd to re-run it here.
- self.get_success(
- self.store.db_pool.runInteraction(
- "initialise", self.store._initialise_reserved_users, threepids
- )
- )
-
- # the number of users we expect will be counted against the mau limit
- # -1 because user3 is a support user and does not count
- user_num = len(threepids) - 1
-
- # Check the number of active users. Ensure user3 (support user) is not counted
- active_count = self.get_success(self.store.get_monthly_active_count())
- self.assertEqual(active_count, user_num)
-
- # Test each of the registered users is marked as active
- timestamp = self.get_success(self.store.user_last_seen_monthly_active(user1))
- # Mypy notes that one shouldn't compare Optional[int] to 0 with assertGreater.
- # Check that timestamp really is an int.
- assert timestamp is not None
- self.assertGreater(timestamp, 0)
- timestamp = self.get_success(self.store.user_last_seen_monthly_active(user2))
- assert timestamp is not None
- self.assertGreater(timestamp, 0)
-
- # Test that users with reserved 3pids are not removed from the MAU table
- # XXX some of this is redundant. poking things into the config shouldn't
- # work, and in any case it's not obvious what we expect to happen when
- # we advance the reactor.
- self.hs.config.server.max_mau_value = 0
- self.reactor.advance(FORTY_DAYS)
- self.hs.config.server.max_mau_value = 5
-
- self.get_success(self.store.reap_monthly_active_users())
-
- active_count = self.get_success(self.store.get_monthly_active_count())
- self.assertEqual(active_count, user_num)
-
- # Add some more users and check they are counted as active
- ru_count = 2
-
- self.get_success(self.store.upsert_monthly_active_user("@ru1:server"))
- self.get_success(self.store.upsert_monthly_active_user("@ru2:server"))
-
- active_count = self.get_success(self.store.get_monthly_active_count())
- self.assertEqual(active_count, user_num + ru_count)
-
- # now run the reaper and check that the number of active users is reduced
- # to max_mau_value
- self.get_success(self.store.reap_monthly_active_users())
-
- active_count = self.get_success(self.store.get_monthly_active_count())
- self.assertEqual(active_count, 3)
-
def test_can_insert_and_count_mau(self) -> None:
count = self.get_success(self.store.get_monthly_active_count())
self.assertEqual(count, 0)
@@ -206,49 +118,6 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
count = self.get_success(self.store.get_monthly_active_count())
self.assertEqual(count, 0)
- # Note that below says mau_limit (no s), this is the name of the config
- # value, although it gets stored on the config object as mau_limits.
- @override_config({"max_mau_value": 5, "mau_limit_reserved_threepids": gen_3pids(5)})
- def test_reap_monthly_active_users_reserved_users(self) -> None:
- """Tests that reaping correctly handles reaping where reserved users are
- present"""
- threepids = self.hs.config.server.mau_limits_reserved_threepids
- initial_users = len(threepids)
- reserved_user_number = initial_users - 1
- for i in range(initial_users):
- user = "@user%d:server" % i
- email = "user%d@matrix.org" % i
-
- self.get_success(self.store.upsert_monthly_active_user(user))
-
- # Need to ensure that the most recent entries in the
- # monthly_active_users table are reserved
- now = int(self.hs.get_clock().time_msec())
- if i != 0:
- self.get_success(
- self.store.register_user(user_id=user, password_hash=None)
- )
- self.get_success(
- self.store.user_add_threepid(user, "email", email, now, now)
- )
-
- self.get_success(
- self.store.db_pool.runInteraction(
- "initialise", self.store._initialise_reserved_users, threepids
- )
- )
-
- count = self.get_success(self.store.get_monthly_active_count())
- self.assertEqual(count, initial_users)
-
- users = self.get_success(self.store.get_registered_reserved_users())
- self.assertEqual(len(users), reserved_user_number)
-
- self.get_success(self.store.reap_monthly_active_users())
-
- count = self.get_success(self.store.get_monthly_active_count())
- self.assertEqual(count, self.hs.config.server.max_mau_value)
-
def test_populate_monthly_users_is_guest(self) -> None:
# Test that guest users are not added to mau list
user_id = "@user_id:host"
@@ -289,46 +158,6 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.store.upsert_monthly_active_user.assert_not_called()
- def test_get_reserved_real_user_account(self) -> None:
- # Test no reserved users, or reserved threepids
- users = self.get_success(self.store.get_registered_reserved_users())
- self.assertEqual(len(users), 0)
-
- # Test reserved users but no registered users
- user1 = "@user1:example.com"
- user2 = "@user2:example.com"
-
- user1_email = "user1@example.com"
- user2_email = "user2@example.com"
- threepids = [
- {"medium": "email", "address": user1_email},
- {"medium": "email", "address": user2_email},
- ]
-
- self.hs.config.server.mau_limits_reserved_threepids = threepids
- d = self.store.db_pool.runInteraction(
- "initialise", self.store._initialise_reserved_users, threepids
- )
- self.get_success(d)
-
- users = self.get_success(self.store.get_registered_reserved_users())
- self.assertEqual(len(users), 0)
-
- # Test reserved registered users
- self.get_success(self.store.register_user(user_id=user1, password_hash=None))
- self.get_success(self.store.register_user(user_id=user2, password_hash=None))
-
- now = int(self.hs.get_clock().time_msec())
- self.get_success(
- self.store.user_add_threepid(user1, "email", user1_email, now, now)
- )
- self.get_success(
- self.store.user_add_threepid(user2, "email", user2_email, now, now)
- )
-
- users = self.get_success(self.store.get_registered_reserved_users())
- self.assertEqual(len(users), len(threepids))
-
def test_support_user_not_add_to_mau_limits(self) -> None:
support_user_id = "@support:test"
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index 080d5640a5..0aa14fd1f4 100644
--- a/tests/storage/test_purge.py
+++ b/tests/storage/test_purge.py
@@ -23,6 +23,8 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.api.errors import NotFoundError, SynapseError
from synapse.rest.client import room
from synapse.server import HomeServer
+from synapse.types.state import StateFilter
+from synapse.types.storage import _BackgroundUpdates
from synapse.util import Clock
from tests.unittest import HomeserverTestCase
@@ -40,6 +42,8 @@ class PurgeTests(HomeserverTestCase):
self.room_id = self.helper.create_room_as(self.user_id)
self.store = hs.get_datastores().main
+ self.state_store = hs.get_datastores().state
+ self.state_deletion_store = hs.get_datastores().state_deletion
self._storage_controllers = self.hs.get_storage_controllers()
def test_purge_history(self) -> None:
@@ -128,3 +132,328 @@ class PurgeTests(HomeserverTestCase):
self.store._invalidate_local_get_event_cache(create_event.event_id)
self.get_failure(self.store.get_event(create_event.event_id), NotFoundError)
self.get_failure(self.store.get_event(first["event_id"]), NotFoundError)
+
+ def test_purge_history_deletes_state_groups(self) -> None:
+ """Test that unreferenced state groups get cleaned up after purge"""
+
+ # Send four state changes to the room.
+ first = self.helper.send_state(
+ self.room_id, event_type="m.foo", body={"test": 1}
+ )
+ second = self.helper.send_state(
+ self.room_id, event_type="m.foo", body={"test": 2}
+ )
+ third = self.helper.send_state(
+ self.room_id, event_type="m.foo", body={"test": 3}
+ )
+ last = self.helper.send_state(
+ self.room_id, event_type="m.foo", body={"test": 4}
+ )
+
+ # Get references to the state groups
+ event_to_groups = self.get_success(
+ self.store._get_state_group_for_events(
+ [
+ first["event_id"],
+ second["event_id"],
+ third["event_id"],
+ last["event_id"],
+ ]
+ )
+ )
+
+ # Get the topological token
+ token = self.get_success(
+ self.store.get_topological_token_for_event(last["event_id"])
+ )
+ token_str = self.get_success(token.to_string(self.hs.get_datastores().main))
+
+ # Purge everything before this topological token
+ self.get_success(
+ self._storage_controllers.purge_events.purge_history(
+ self.room_id, token_str, True
+ )
+ )
+
+ # Advance so that the background jobs to delete the state groups runs
+ self.reactor.advance(
+ 1 + self.state_deletion_store.DELAY_BEFORE_DELETION_MS / 1000
+ )
+
+ # We expect all the state groups associated with events above, except
+ # the last one, should return no state.
+ state_groups = self.get_success(
+ self.state_store._get_state_groups_from_groups(
+ list(event_to_groups.values()), StateFilter.all()
+ )
+ )
+ first_state = state_groups[event_to_groups[first["event_id"]]]
+ second_state = state_groups[event_to_groups[second["event_id"]]]
+ third_state = state_groups[event_to_groups[third["event_id"]]]
+ last_state = state_groups[event_to_groups[last["event_id"]]]
+
+ self.assertEqual(first_state, {})
+ self.assertEqual(second_state, {})
+ self.assertEqual(third_state, {})
+ self.assertNotEqual(last_state, {})
+
+ def test_purge_unreferenced_state_group(self) -> None:
+ """Test that purging a room also gets rid of unreferenced state groups
+ it encounters during the purge.
+
+ This is important, as otherwise these unreferenced state groups get
+ "de-deltaed" during the purge process, consuming lots of disk space.
+ """
+
+ self.helper.send(self.room_id, body="test1")
+ state1 = self.helper.send_state(
+ self.room_id, "org.matrix.test", body={"number": 2}
+ )
+ state2 = self.helper.send_state(
+ self.room_id, "org.matrix.test", body={"number": 3}
+ )
+ self.helper.send(self.room_id, body="test4")
+ last = self.helper.send(self.room_id, body="test5")
+
+ # Create an unreferenced state group that has a prev group of one of the
+ # to-be-purged events.
+ prev_group = self.get_success(
+ self.store._get_state_group_for_event(state1["event_id"])
+ )
+ unreferenced_state_group = self.get_success(
+ self.state_store.store_state_group(
+ event_id=last["event_id"],
+ room_id=self.room_id,
+ prev_group=prev_group,
+ delta_ids={("org.matrix.test", ""): state2["event_id"]},
+ current_state_ids=None,
+ )
+ )
+
+ # Get the topological token
+ token = self.get_success(
+ self.store.get_topological_token_for_event(last["event_id"])
+ )
+ token_str = self.get_success(token.to_string(self.hs.get_datastores().main))
+
+ # Purge everything before this topological token
+ self.get_success(
+ self._storage_controllers.purge_events.purge_history(
+ self.room_id, token_str, True
+ )
+ )
+
+ # Advance so that the background jobs to delete the state groups runs
+ self.reactor.advance(
+ 1 + self.state_deletion_store.DELAY_BEFORE_DELETION_MS / 1000
+ )
+
+ # We expect that the unreferenced state group has been deleted from all tables.
+ row = self.get_success(
+ self.state_store.db_pool.simple_select_one_onecol(
+ table="state_groups",
+ keyvalues={"id": unreferenced_state_group},
+ retcol="id",
+ allow_none=True,
+ desc="test_purge_unreferenced_state_group",
+ )
+ )
+ self.assertIsNone(row)
+
+ row = self.get_success(
+ self.state_store.db_pool.simple_select_one_onecol(
+ table="state_groups_state",
+ keyvalues={"state_group": unreferenced_state_group},
+ retcol="state_group",
+ allow_none=True,
+ desc="test_purge_unreferenced_state_group",
+ )
+ )
+ self.assertIsNone(row)
+
+ row = self.get_success(
+ self.state_store.db_pool.simple_select_one_onecol(
+ table="state_group_edges",
+ keyvalues={"state_group": unreferenced_state_group},
+ retcol="state_group",
+ allow_none=True,
+ desc="test_purge_unreferenced_state_group",
+ )
+ )
+ self.assertIsNone(row)
+
+ row = self.get_success(
+ self.state_store.db_pool.simple_select_one_onecol(
+ table="state_groups_pending_deletion",
+ keyvalues={"state_group": unreferenced_state_group},
+ retcol="state_group",
+ allow_none=True,
+ desc="test_purge_unreferenced_state_group",
+ )
+ )
+ self.assertIsNone(row)
+
+ # We expect there to now only be one state group for the room, which is
+ # the state group of the last event (as the only outlier).
+ state_groups = self.get_success(
+ self.state_store.db_pool.simple_select_onecol(
+ table="state_groups",
+ keyvalues={"room_id": self.room_id},
+ retcol="id",
+ desc="test_purge_unreferenced_state_group",
+ )
+ )
+ self.assertEqual(len(state_groups), 1)
+
+ def test_clear_unreferenced_state_groups(self) -> None:
+ """Test that any unreferenced state groups are automatically cleaned up."""
+
+ self.helper.send(self.room_id, body="test1")
+ state1 = self.helper.send_state(
+ self.room_id, "org.matrix.test", body={"number": 2}
+ )
+ # Create enough state events to require multiple batches of
+ # mark_unreferenced_state_groups_for_deletion_bg_update to be run.
+ for i in range(200):
+ self.helper.send_state(self.room_id, "org.matrix.test", body={"number": i})
+ self.helper.send(self.room_id, body="test4")
+ last = self.helper.send(self.room_id, body="test5")
+
+ # Create an unreferenced state group that has no prev group.
+ unreferenced_free_state_group = self.get_success(
+ self.state_store.store_state_group(
+ event_id=last["event_id"],
+ room_id=self.room_id,
+ prev_group=None,
+ delta_ids={("org.matrix.test", ""): state1["event_id"]},
+ current_state_ids={("org.matrix.test", ""): ""},
+ )
+ )
+
+ # Create some unreferenced state groups that have a prev group of one of the
+ # existing state groups.
+ prev_group = self.get_success(
+ self.store._get_state_group_for_event(state1["event_id"])
+ )
+ unreferenced_end_state_group = self.get_success(
+ self.state_store.store_state_group(
+ event_id=last["event_id"],
+ room_id=self.room_id,
+ prev_group=prev_group,
+ delta_ids={("org.matrix.test", ""): state1["event_id"]},
+ current_state_ids=None,
+ )
+ )
+ another_unreferenced_end_state_group = self.get_success(
+ self.state_store.store_state_group(
+ event_id=last["event_id"],
+ room_id=self.room_id,
+ prev_group=unreferenced_end_state_group,
+ delta_ids={("org.matrix.test", ""): state1["event_id"]},
+ current_state_ids=None,
+ )
+ )
+
+ # Add some other unreferenced state groups which lead to a referenced state
+ # group.
+ # These state groups should not get deleted.
+ chain_state_group = self.get_success(
+ self.state_store.store_state_group(
+ event_id=last["event_id"],
+ room_id=self.room_id,
+ prev_group=None,
+ delta_ids={("org.matrix.test", ""): ""},
+ current_state_ids={("org.matrix.test", ""): ""},
+ )
+ )
+ chain_state_group_2 = self.get_success(
+ self.state_store.store_state_group(
+ event_id=last["event_id"],
+ room_id=self.room_id,
+ prev_group=chain_state_group,
+ delta_ids={("org.matrix.test", ""): ""},
+ current_state_ids=None,
+ )
+ )
+ referenced_chain_state_group = self.get_success(
+ self.state_store.store_state_group(
+ event_id=last["event_id"],
+ room_id=self.room_id,
+ prev_group=chain_state_group_2,
+ delta_ids={("org.matrix.test", ""): ""},
+ current_state_ids=None,
+ )
+ )
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ "event_to_state_groups",
+ {
+ "event_id": "$new_event",
+ "state_group": referenced_chain_state_group,
+ },
+ )
+ )
+
+ # Insert and run the background update.
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ "background_updates",
+ {
+ "update_name": _BackgroundUpdates.MARK_UNREFERENCED_STATE_GROUPS_FOR_DELETION_BG_UPDATE,
+ "progress_json": "{}",
+ },
+ )
+ )
+ self.store.db_pool.updates._all_done = False
+ self.wait_for_background_updates()
+
+ # Advance so that the background job to delete the state groups runs
+ self.reactor.advance(
+ 1 + self.state_deletion_store.DELAY_BEFORE_DELETION_MS / 1000
+ )
+
+ # We expect that the unreferenced free state group has been deleted.
+ row = self.get_success(
+ self.state_store.db_pool.simple_select_one_onecol(
+ table="state_groups",
+ keyvalues={"id": unreferenced_free_state_group},
+ retcol="id",
+ allow_none=True,
+ desc="test_purge_unreferenced_state_group",
+ )
+ )
+ self.assertIsNone(row)
+
+ # We expect that both unreferenced end state groups have been deleted.
+ row = self.get_success(
+ self.state_store.db_pool.simple_select_one_onecol(
+ table="state_groups",
+ keyvalues={"id": unreferenced_end_state_group},
+ retcol="id",
+ allow_none=True,
+ desc="test_purge_unreferenced_state_group",
+ )
+ )
+ self.assertIsNone(row)
+ row = self.get_success(
+ self.state_store.db_pool.simple_select_one_onecol(
+ table="state_groups",
+ keyvalues={"id": another_unreferenced_end_state_group},
+ retcol="id",
+ allow_none=True,
+ desc="test_purge_unreferenced_state_group",
+ )
+ )
+ self.assertIsNone(row)
+
+ # We expect all of the referenced state groups for the room to remain,
+ # now that the unreferenced ones have been deleted.
+ state_groups = self.get_success(
+ self.state_store.db_pool.simple_select_onecol(
+ table="state_groups",
+ keyvalues={"room_id": self.room_id},
+ retcol="id",
+ desc="test_purge_unreferenced_state_group",
+ )
+ )
+ self.assertEqual(len(state_groups), 210)
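
All three new tests lean on the same pattern: deletion is deferred by DELAY_BEFORE_DELETION_MS, so each test advances the fake reactor just past that delay rather than sleeping. A minimal sketch of the pattern with a bare twisted Clock (the constant here is a stand-in for the store's attribute):

from typing import List

from twisted.internet.task import Clock

DELAY_BEFORE_DELETION_MS = 60_000  # stand-in for the store's real constant

deleted: List[str] = []
clock = Clock()
clock.callLater(DELAY_BEFORE_DELETION_MS / 1000, deleted.append, "unreferenced_group")

assert deleted == []  # nothing fires until the clock moves
clock.advance(1 + DELAY_BEFORE_DELETION_MS / 1000)  # step just past the delay
assert deleted == ["unreferenced_group"]
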
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 14e3871dc1..ebad759fd1 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -21,7 +21,6 @@
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import UserTypes
-from synapse.api.errors import ThreepidValidationError
from synapse.server import HomeServer
from synapse.types import JsonDict, UserID, UserInfo
from synapse.util import Clock
@@ -145,39 +144,6 @@ class RegistrationStoreTestCase(HomeserverTestCase):
res = self.get_success(self.store.is_support_user(SUPPORT_USER))
self.assertTrue(res)
- def test_3pid_inhibit_invalid_validation_session_error(self) -> None:
- """Tests that enabling the configuration option to inhibit 3PID errors on
- /requestToken also inhibits validation errors caused by an unknown session ID.
- """
-
- # Check that, with the config setting set to false (the default value), a
- # validation error is caused by the unknown session ID.
- e = self.get_failure(
- self.store.validate_threepid_session(
- "fake_sid",
- "fake_client_secret",
- "fake_token",
- 0,
- ),
- ThreepidValidationError,
- )
- self.assertEqual(e.value.msg, "Unknown session_id", e)
-
- # Set the config setting to true.
- self.store._ignore_unknown_session_error = True
-
- # Check that now the validation error is caused by the token not matching.
- e = self.get_failure(
- self.store.validate_threepid_session(
- "fake_sid",
- "fake_client_secret",
- "fake_token",
- 0,
- ),
- ThreepidValidationError,
- )
- self.assertEqual(e.value.msg, "Validation token not found or has expired", e)
-
class ApprovalRequiredRegistrationTestCase(HomeserverTestCase):
def default_config(self) -> JsonDict:
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 418b556108..330fea0e62 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -24,7 +24,7 @@ from typing import List, Optional, Tuple, cast
from twisted.test.proto_helpers import MemoryReactor
-from synapse.api.constants import EventTypes, JoinRules, Membership
+from synapse.api.constants import EventContentFields, EventTypes, JoinRules, Membership
from synapse.api.room_versions import RoomVersions
from synapse.rest import admin
from synapse.rest.admin import register_servlets_for_client_rest_resource
@@ -38,6 +38,7 @@ from synapse.util import Clock
from tests import unittest
from tests.server import TestHomeServer
from tests.test_utils import event_injection
+from tests.test_utils.event_injection import create_event
from tests.unittest import skip_unless
logger = logging.getLogger(__name__)
@@ -54,6 +55,10 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
# We can't test the RoomMemberStore on its own without the other event
# storage logic
self.store = hs.get_datastores().main
+ self.state_handler = self.hs.get_state_handler()
+ persistence = self.hs.get_storage_controllers().persistence
+ assert persistence is not None
+ self.persistence = persistence
self.u_alice = self.register_user("alice", "pass")
self.t_alice = self.login("alice", "pass")
@@ -220,31 +225,166 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
)
def test_join_locally_forgotten_room(self) -> None:
- """Tests if a user joins a forgotten room the room is not forgotten anymore."""
- self.room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
- self.assertFalse(
- self.get_success(self.store.is_locally_forgotten_room(self.room))
+ """
+ Tests that if a user joins a forgotten room, the room is no longer forgotten.
+
+ Since a room can't be re-joined once everyone has left, this can only happen
+ with a room that has remote users in it.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ # Create a remote room
+ creator = "@user:other"
+ room_id = "!foo:other"
+ room_version = RoomVersions.V10
+ shared_kwargs = {
+ "room_id": room_id,
+ "room_version": room_version.identifier,
+ }
+
+ create_tuple = self.get_success(
+ create_event(
+ self.hs,
+ prev_event_ids=[],
+ type=EventTypes.Create,
+ state_key="",
+ content={
+ # The `ROOM_CREATOR` field could be removed if we used a room
+ # version > 10 (in favor of relying on `sender`)
+ EventContentFields.ROOM_CREATOR: creator,
+ EventContentFields.ROOM_VERSION: room_version.identifier,
+ },
+ sender=creator,
+ **shared_kwargs,
+ )
+ )
+ creator_tuple = self.get_success(
+ create_event(
+ self.hs,
+ prev_event_ids=[create_tuple[0].event_id],
+ auth_event_ids=[create_tuple[0].event_id],
+ type=EventTypes.Member,
+ state_key=creator,
+ content={"membership": Membership.JOIN},
+ sender=creator,
+ **shared_kwargs,
+ )
)
- # after leaving and forget the room, it is forgotten
- self.get_success(
- event_injection.inject_member_event(
- self.hs, self.room, self.u_alice, "leave"
+ remote_events_and_contexts = [
+ create_tuple,
+ creator_tuple,
+ ]
+
+ # Ensure the local HS knows the room version
+ self.get_success(self.store.store_room(room_id, creator, False, room_version))
+
+ # Persist these events as backfilled events.
+ for event, context in remote_events_and_contexts:
+ self.get_success(
+ self.persistence.persist_event(event, context, backfilled=True)
+ )
+
+ # Now we join the local user to the room. We want to make this feel as close to
+ # the real `process_remote_join()` as possible but we'd like to avoid some of
+ # the auth checks that would be done in the real code.
+ #
+ # FIXME: The test was originally written using this less-real
+ # `persist_event(...)` shortcut but it would be nice to use the real remote join
+ # process in a `FederatingHomeserverTestCase`.
+ flawed_join_tuple = self.get_success(
+ create_event(
+ self.hs,
+ prev_event_ids=[creator_tuple[0].event_id],
+ # This doesn't work correctly to create an `EventContext` that includes
+ # both of these state events. I assume it's because we're working on our
+ # local homeserver which has the remote state set as `outlier`. We have
+ # to create our own EventContext below to get this right.
+ auth_event_ids=[create_tuple[0].event_id],
+ type=EventTypes.Member,
+ state_key=user1_id,
+ content={"membership": Membership.JOIN},
+ sender=user1_id,
+ **shared_kwargs,
)
)
- self.get_success(self.store.forget(self.u_alice, self.room))
- self.assertTrue(
- self.get_success(self.store.is_locally_forgotten_room(self.room))
+ # We have to create our own context to get the state set correctly. If we use
+ # the `EventContext` from the `flawed_join_tuple`, the `current_state_events`
+ # table will only have the join event in it which should never happen in our
+ # real server.
+ join_event = flawed_join_tuple[0]
+ join_context = self.get_success(
+ self.state_handler.compute_event_context(
+ join_event,
+ state_ids_before_event={
+ (e.type, e.state_key): e.event_id for e in [create_tuple[0]]
+ },
+ partial_state=False,
+ )
)
+ self.get_success(self.persistence.persist_event(join_event, join_context))
- # after rejoin the room is not forgotten anymore
- self.get_success(
- event_injection.inject_member_event(
- self.hs, self.room, self.u_alice, "join"
+ # The room shouldn't be forgotten because the local user just joined
+ self.assertFalse(
+ self.get_success(self.store.is_locally_forgotten_room(room_id))
+ )
+
+ # After all of the local users (there is only user1) have left and
+ # forgotten the room, it is forgotten
+ user1_leave_response = self.helper.leave(room_id, user1_id, tok=user1_tok)
+ user1_leave_event = self.get_success(
+ self.store.get_event(user1_leave_response["event_id"])
+ )
+ self.get_success(self.store.forget(user1_id, room_id))
+ self.assertTrue(self.get_success(self.store.is_locally_forgotten_room(room_id)))
+
+ # Join the local user to the room (again). We want to make this feel as close to
+ # the real `process_remote_join()` as possible but we'd like to avoid some of
+ # the auth checks that would be done in the real code.
+ #
+ # FIXME: The test was originally written using this less-real
+ # `event_injection.inject_member_event(...)` shortcut but it would be nice to
+ # use the real remote join process in a `FederatingHomeserverTestCase`.
+ flawed_join_tuple = self.get_success(
+ create_event(
+ self.hs,
+ prev_event_ids=[user1_leave_response["event_id"]],
+ # This doesn't work correctly to create an `EventContext` that includes
+ # both of these state events. I assume it's because we're working on our
+ # local homeserver which has the remote state set as `outlier`. We have
+ # to create our own EventContext below to get this right.
+ auth_event_ids=[
+ create_tuple[0].event_id,
+ user1_leave_response["event_id"],
+ ],
+ type=EventTypes.Member,
+ state_key=user1_id,
+ content={"membership": Membership.JOIN},
+ sender=user1_id,
+ **shared_kwargs,
+ )
+ )
+ # We have to create our own context to get the state set correctly. If we use
+ # the `EventContext` from the `flawed_join_tuple`, the `current_state_events`
+ # table will only have the join event in it which should never happen in our
+ # real server.
+ join_event = flawed_join_tuple[0]
+ join_context = self.get_success(
+ self.state_handler.compute_event_context(
+ join_event,
+ state_ids_before_event={
+ (e.type, e.state_key): e.event_id
+ for e in [create_tuple[0], user1_leave_event]
+ },
+ partial_state=False,
)
)
+ self.get_success(self.persistence.persist_event(join_event, join_context))
+
+ # After the local user rejoins the remote room, it isn't forgotten anymore
self.assertFalse(
- self.get_success(self.store.is_locally_forgotten_room(self.room))
+ self.get_success(self.store.is_locally_forgotten_room(room_id))
)
diff --git a/tests/storage/test_sliding_sync_tables.py b/tests/storage/test_sliding_sync_tables.py
new file mode 100644
index 0000000000..53212f7c45
--- /dev/null
+++ b/tests/storage/test_sliding_sync_tables.py
@@ -0,0 +1,5119 @@
+#
+# This file is licensed under the Affero General Public License (AGPL) version 3.
+#
+# Copyright (C) 2024 New Vector, Ltd
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# See the GNU Affero General Public License for more details:
+# <https://www.gnu.org/licenses/agpl-3.0.html>.
+#
+# Originally licensed under the Apache License, Version 2.0:
+# <http://www.apache.org/licenses/LICENSE-2.0>.
+#
+# [This file includes modifications made by New Vector Limited]
+#
+#
+import logging
+from typing import Dict, List, Optional, Tuple, cast
+
+import attr
+from parameterized import parameterized
+
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.api.constants import EventContentFields, EventTypes, Membership, RoomTypes
+from synapse.api.room_versions import RoomVersions
+from synapse.events import EventBase, StrippedStateEvent, make_event_from_dict
+from synapse.events.snapshot import EventContext
+from synapse.rest import admin
+from synapse.rest.client import login, room
+from synapse.server import HomeServer
+from synapse.storage.databases.main.events import DeltaState
+from synapse.storage.databases.main.events_bg_updates import (
+ _resolve_stale_data_in_sliding_sync_joined_rooms_table,
+ _resolve_stale_data_in_sliding_sync_membership_snapshots_table,
+)
+from synapse.types import create_requester
+from synapse.types.storage import _BackgroundUpdates
+from synapse.util import Clock
+
+from tests.test_utils.event_injection import create_event
+from tests.unittest import HomeserverTestCase
+
+logger = logging.getLogger(__name__)
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _SlidingSyncJoinedRoomResult:
+ room_id: str
+ # `event_stream_ordering` is only optional to allow easier semantics when we make
+ # expected objects from `event.internal_metadata.stream_ordering` in the tests.
+ # `event.internal_metadata.stream_ordering` is marked optional because it only
+ # exists for persisted events but in the context of these tests, we're only working
+ # with persisted events and we're making comparisons so we will find any mismatch.
+ event_stream_ordering: Optional[int]
+ bump_stamp: Optional[int]
+ room_type: Optional[str]
+ room_name: Optional[str]
+ is_encrypted: bool
+ tombstone_successor_room_id: Optional[str]
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _SlidingSyncMembershipSnapshotResult:
+ room_id: str
+ user_id: str
+ sender: str
+ membership_event_id: str
+ membership: str
+ # `event_stream_ordering` is only optional to allow easier semantics when we make
+ # expected objects from `event.internal_metadata.stream_ordering` in the tests.
+ # `event.internal_metadata.stream_ordering` is marked optional because it only
+ # exists for persisted events but in the context of these tests, we're only working
+ # with persisted events and we're making comparisons so we will find any mismatch.
+ event_stream_ordering: Optional[int]
+ has_known_state: bool
+ room_type: Optional[str]
+ room_name: Optional[str]
+ is_encrypted: bool
+ tombstone_successor_room_id: Optional[str]
+ # Make this default to "not forgotten" because it doesn't apply to many tests and we
+ # don't want to force all of the tests to deal with it.
+ forgotten: bool = False
+
+
+class SlidingSyncTablesTestCaseBase(HomeserverTestCase):
+ """
+ Helpers to deal with testing that the
+ `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` database tables are
+ populated correctly.
+ """
+
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.store = hs.get_datastores().main
+ self.storage_controllers = hs.get_storage_controllers()
+ persist_events_store = self.hs.get_datastores().persist_events
+ assert persist_events_store is not None
+ self.persist_events_store = persist_events_store
+
+ persist_controller = self.hs.get_storage_controllers().persistence
+ assert persist_controller is not None
+ self.persist_controller = persist_controller
+
+ self.state_handler = self.hs.get_state_handler()
+
+ def _get_sliding_sync_joined_rooms(self) -> Dict[str, _SlidingSyncJoinedRoomResult]:
+ """
+ Return the rows from the `sliding_sync_joined_rooms` table.
+
+ Returns:
+ Mapping from room_id to _SlidingSyncJoinedRoomResult.
+ """
+ rows = cast(
+ List[Tuple[str, int, int, str, str, bool, str]],
+ self.get_success(
+ self.store.db_pool.simple_select_list(
+ "sliding_sync_joined_rooms",
+ None,
+ retcols=(
+ "room_id",
+ "event_stream_ordering",
+ "bump_stamp",
+ "room_type",
+ "room_name",
+ "is_encrypted",
+ "tombstone_successor_room_id",
+ ),
+ ),
+ ),
+ )
+
+ return {
+ row[0]: _SlidingSyncJoinedRoomResult(
+ room_id=row[0],
+ event_stream_ordering=row[1],
+ bump_stamp=row[2],
+ room_type=row[3],
+ room_name=row[4],
+ is_encrypted=bool(row[5]),
+ tombstone_successor_room_id=row[6],
+ )
+ for row in rows
+ }
+
+ def _get_sliding_sync_membership_snapshots(
+ self,
+ ) -> Dict[Tuple[str, str], _SlidingSyncMembershipSnapshotResult]:
+ """
+ Return the rows from the `sliding_sync_membership_snapshots` table.
+
+ Returns:
+ Mapping from the (room_id, user_id) to _SlidingSyncMembershipSnapshotResult.
+ """
+ rows = cast(
+ List[Tuple[str, str, str, str, str, int, int, bool, str, str, bool, str]],
+ self.get_success(
+ self.store.db_pool.simple_select_list(
+ "sliding_sync_membership_snapshots",
+ None,
+ retcols=(
+ "room_id",
+ "user_id",
+ "sender",
+ "membership_event_id",
+ "membership",
+ "forgotten",
+ "event_stream_ordering",
+ "has_known_state",
+ "room_type",
+ "room_name",
+ "is_encrypted",
+ "tombstone_successor_room_id",
+ ),
+ ),
+ ),
+ )
+
+ return {
+ (row[0], row[1]): _SlidingSyncMembershipSnapshotResult(
+ room_id=row[0],
+ user_id=row[1],
+ sender=row[2],
+ membership_event_id=row[3],
+ membership=row[4],
+ forgotten=bool(row[5]),
+ event_stream_ordering=row[6],
+ has_known_state=bool(row[7]),
+ room_type=row[8],
+ room_name=row[9],
+ is_encrypted=bool(row[10]),
+ tombstone_successor_room_id=row[11],
+ )
+ for row in rows
+ }
+
+ _remote_invite_count: int = 0
+
+ def _create_remote_invite_room_for_user(
+ self,
+ invitee_user_id: str,
+ unsigned_invite_room_state: Optional[List[StrippedStateEvent]],
+ ) -> Tuple[str, EventBase]:
+ """
+ Create a fake invite for a remote room and persist it.
+
+ We don't have any state for these kinds of rooms and can only rely on the
+ stripped state included in the unsigned portion of the invite event to identify
+ the room.
+
+ Args:
+ invitee_user_id: The person being invited
+ unsigned_invite_room_state: List of stripped state events to assist the
+ receiver in identifying the room.
+
+ Returns:
+ The room ID of the remote invite room and the persisted remote invite event.
+ """
+ invite_room_id = f"!test_room{self._remote_invite_count}:remote_server"
+
+ invite_event_dict = {
+ "room_id": invite_room_id,
+ "sender": "@inviter:remote_server",
+ "state_key": invitee_user_id,
+ "depth": 1,
+ "origin_server_ts": 1,
+ "type": EventTypes.Member,
+ "content": {"membership": Membership.INVITE},
+ "auth_events": [],
+ "prev_events": [],
+ }
+ if unsigned_invite_room_state is not None:
+ serialized_stripped_state_events = []
+ for stripped_event in unsigned_invite_room_state:
+ serialized_stripped_state_events.append(
+ {
+ "type": stripped_event.type,
+ "state_key": stripped_event.state_key,
+ "sender": stripped_event.sender,
+ "content": stripped_event.content,
+ }
+ )
+
+ invite_event_dict["unsigned"] = {
+ "invite_room_state": serialized_stripped_state_events
+ }
+
+ invite_event = make_event_from_dict(
+ invite_event_dict,
+ room_version=RoomVersions.V10,
+ )
+ invite_event.internal_metadata.outlier = True
+ invite_event.internal_metadata.out_of_band_membership = True
+
+ self.get_success(
+ self.store.maybe_store_room_on_outlier_membership(
+ room_id=invite_room_id, room_version=invite_event.room_version
+ )
+ )
+ context = EventContext.for_outlier(self.hs.get_storage_controllers())
+ persisted_event, _, _ = self.get_success(
+ self.persist_controller.persist_event(invite_event, context)
+ )
+
+ self._remote_invite_count += 1
+
+ return invite_room_id, persisted_event
+
+ def _retract_remote_invite_for_user(
+ self,
+ user_id: str,
+ remote_room_id: str,
+ ) -> EventBase:
+ """
+ Create a fake invite retraction for a remote room and persist it.
+
+ Retracting an invite just means the person is no longer invited to the room.
+ This is done by someone with proper power levels kicking the user from the room.
+ A kick shows up as a leave event for a given person with a different `sender`.
+
+ Args:
+ user_id: The person who was invited and whose invite we're going to
+ retract.
+ remote_room_id: The room ID that the invite was for.
+
+ Returns:
+ The persisted leave (kick) event.
+ """
+
+ kick_event_dict = {
+ "room_id": remote_room_id,
+ "sender": "@inviter:remote_server",
+ "state_key": user_id,
+ "depth": 1,
+ "origin_server_ts": 1,
+ "type": EventTypes.Member,
+ "content": {"membership": Membership.LEAVE},
+ "auth_events": [],
+ "prev_events": [],
+ }
+
+ kick_event = make_event_from_dict(
+ kick_event_dict,
+ room_version=RoomVersions.V10,
+ )
+ kick_event.internal_metadata.outlier = True
+ kick_event.internal_metadata.out_of_band_membership = True
+
+ self.get_success(
+ self.store.maybe_store_room_on_outlier_membership(
+ room_id=remote_room_id, room_version=kick_event.room_version
+ )
+ )
+ context = EventContext.for_outlier(self.hs.get_storage_controllers())
+ persisted_event, _, _ = self.get_success(
+ self.persist_controller.persist_event(kick_event, context)
+ )
+
+ return persisted_event
+
+
+class SlidingSyncTablesTestCase(SlidingSyncTablesTestCaseBase):
+ """
+ Tests to make sure the
+ `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` database tables are
+ populated and updated correctly as new events are sent.
+ """
+
+ def test_joined_room_with_no_info(self) -> None:
+ """
+ Test that a joined room without a room type, encryption, or name shows up
+ in `sliding_sync_joined_rooms`.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ room_id1 = self.helper.create_room_as(user1_id, tok=user1_tok)
+
+ state_map = self.get_success(
+ self.storage_controllers.state.get_current_state(room_id1)
+ )
+
+ sliding_sync_joined_rooms_results = self._get_sliding_sync_joined_rooms()
+ self.assertIncludes(
+ set(sliding_sync_joined_rooms_results.keys()),
+ {room_id1},
+ exact=True,
+ )
+ self.assertEqual(
+ sliding_sync_joined_rooms_results[room_id1],
+ _SlidingSyncJoinedRoomResult(
+ room_id=room_id1,
+ # History visibility just happens to be the last event sent in the room
+ event_stream_ordering=state_map[
+ (EventTypes.RoomHistoryVisibility, "")
+ ].internal_metadata.stream_ordering,
+ bump_stamp=state_map[
+ (EventTypes.Create, "")
+ ].internal_metadata.stream_ordering,
+ room_type=None,
+ room_name=None,
+ is_encrypted=False,
+ tombstone_successor_room_id=None,
+ ),
+ )
+
+ sliding_sync_membership_snapshots_results = (
+ self._get_sliding_sync_membership_snapshots()
+ )
+ self.assertIncludes(
+ set(sliding_sync_membership_snapshots_results.keys()),
+ {
+ (room_id1, user1_id),
+ },
+ exact=True,
+ )
+ # Holds the info according to the current state when the user joined
+ self.assertEqual(
+ sliding_sync_membership_snapshots_results.get((room_id1, user1_id)),
+ _SlidingSyncMembershipSnapshotResult(
+ room_id=room_id1,
+ user_id=user1_id,
+ sender=user1_id,
+ membership_event_id=state_map[(EventTypes.Member, user1_id)].event_id,
+ membership=Membership.JOIN,
+ event_stream_ordering=state_map[
+ (EventTypes.Member, user1_id)
+ ].internal_metadata.stream_ordering,
+ has_known_state=True,
+ room_type=None,
+ room_name=None,
+ is_encrypted=False,
+ tombstone_successor_room_id=None,
+ ),
+ )
+
+ def test_joined_room_with_info(self) -> None:
+ """
+        Test that a joined, encrypted room with a name shows up in
+        `sliding_sync_joined_rooms`.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
+ # Add a room name
+ self.helper.send_state(
+ room_id1,
+ EventTypes.Name,
+ {"name": "my super duper room"},
+ tok=user2_tok,
+ )
+ # Encrypt the room
+ self.helper.send_state(
+ room_id1,
+ EventTypes.RoomEncryption,
+ {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"},
+ tok=user2_tok,
+ )
+ # Add a tombstone
+ self.helper.send_state(
+ room_id1,
+ EventTypes.Tombstone,
+ {EventContentFields.TOMBSTONE_SUCCESSOR_ROOM: "another_room"},
+ tok=user2_tok,
+ )
+
+ # User1 joins the room
+ self.helper.join(room_id1, user1_id, tok=user1_tok)
+
+ state_map = self.get_success(
+ self.storage_controllers.state.get_current_state(room_id1)
+ )
+
+ sliding_sync_joined_rooms_results = self._get_sliding_sync_joined_rooms()
+ self.assertIncludes(
+ set(sliding_sync_joined_rooms_results.keys()),
+ {room_id1},
+ exact=True,
+ )
+ self.assertEqual(
+ sliding_sync_joined_rooms_results[room_id1],
+ _SlidingSyncJoinedRoomResult(
+ room_id=room_id1,
+ # This should be whatever is the last event in the room
+ event_stream_ordering=state_map[
+ (EventTypes.Member, user1_id)
+ ].internal_metadata.stream_ordering,
+ bump_stamp=state_map[
+ (EventTypes.Create, "")
+ ].internal_metadata.stream_ordering,
+ room_type=None,
+ room_name="my super duper room",
+ is_encrypted=True,
+ tombstone_successor_room_id="another_room",
+ ),
+ )
+
+ sliding_sync_membership_snapshots_results = (
+ self._get_sliding_sync_membership_snapshots()
+ )
+ self.assertIncludes(
+ set(sliding_sync_membership_snapshots_results.keys()),
+ {
+ (room_id1, user1_id),
+ (room_id1, user2_id),
+ },
+ exact=True,
+ )
+ # Holds the info according to the current state when the user joined
+ self.assertEqual(
+ sliding_sync_membership_snapshots_results.get((room_id1, user1_id)),
+ _SlidingSyncMembershipSnapshotResult(
+ room_id=room_id1,
+ user_id=user1_id,
+ sender=user1_id,
+ membership_event_id=state_map[(EventTypes.Member, user1_id)].event_id,
+ membership=Membership.JOIN,
+ event_stream_ordering=state_map[
+ (EventTypes.Member, user1_id)
+ ].internal_metadata.stream_ordering,
+ has_known_state=True,
+ room_type=None,
+ room_name="my super duper room",
+ is_encrypted=True,
+ tombstone_successor_room_id="another_room",
+ ),
+ )
+ # Holds the info according to the current state when the user joined
+ self.assertEqual(
+ sliding_sync_membership_snapshots_results.get((room_id1, user2_id)),
+ _SlidingSyncMembershipSnapshotResult(
+ room_id=room_id1,
+ user_id=user2_id,
+ sender=user2_id,
+ membership_event_id=state_map[(EventTypes.Member, user2_id)].event_id,
+ membership=Membership.JOIN,
+ event_stream_ordering=state_map[
+ (EventTypes.Member, user2_id)
+ ].internal_metadata.stream_ordering,
+ has_known_state=True,
+ room_type=None,
+                # Even though this room does have a name, is encrypted, and has a
+                # tombstone, user2 is the room creator and joined at room creation
+                # time, before any of this state was set.
+ room_name=None,
+ is_encrypted=False,
+ tombstone_successor_room_id=None,
+ ),
+ )
+
+ def test_joined_space_room_with_info(self) -> None:
+ """
+        Test that a joined space room with a name shows up in
+        `sliding_sync_joined_rooms`.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ space_room_id = self.helper.create_room_as(
+ user2_id,
+ tok=user2_tok,
+ extra_content={
+ "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE}
+ },
+ )
+ # Add a room name
+ self.helper.send_state(
+ space_room_id,
+ EventTypes.Name,
+ {"name": "my super duper space"},
+ tok=user2_tok,
+ )
+
+ # User1 joins the room
+ user1_join_response = self.helper.join(space_room_id, user1_id, tok=user1_tok)
+ user1_join_event_pos = self.get_success(
+ self.store.get_position_for_event(user1_join_response["event_id"])
+ )
+
+ state_map = self.get_success(
+ self.storage_controllers.state.get_current_state(space_room_id)
+ )
+
+ sliding_sync_joined_rooms_results = self._get_sliding_sync_joined_rooms()
+ self.assertIncludes(
+ set(sliding_sync_joined_rooms_results.keys()),
+ {space_room_id},
+ exact=True,
+ )
+ self.assertEqual(
+ sliding_sync_joined_rooms_results[space_room_id],
+ _SlidingSyncJoinedRoomResult(
+ room_id=space_room_id,
+ event_stream_ordering=user1_join_event_pos.stream,
+ bump_stamp=state_map[
+ (EventTypes.Create, "")
+ ].internal_metadata.stream_ordering,
+ room_type=RoomTypes.SPACE,
+ room_name="my super duper space",
+ is_encrypted=False,
+ tombstone_successor_room_id=None,
+ ),
+ )
+
+ sliding_sync_membership_snapshots_results = (
+ self._get_sliding_sync_membership_snapshots()
+ )
+ self.assertIncludes(
+ set(sliding_sync_membership_snapshots_results.keys()),
+ {
+ (space_room_id, user1_id),
+ (space_room_id, user2_id),
+ },
+ exact=True,
+ )
+ # Holds the info according to the current state when the user joined
+ self.assertEqual(
+ sliding_sync_membership_snapshots_results.get((space_room_id, user1_id)),
+ _SlidingSyncMembershipSnapshotResult(
+ room_id=space_room_id,
+ user_id=user1_id,
+ sender=user1_id,
+ membership_event_id=state_map[(EventTypes.Member, user1_id)].event_id,
+ membership=Membership.JOIN,
+ event_stream_ordering=state_map[
+ (EventTypes.Member, user1_id)
+ ].internal_metadata.stream_ordering,
+ has_known_state=True,
+ room_type=RoomTypes.SPACE,
+ room_name="my super duper space",
+ is_encrypted=False,
+ tombstone_successor_room_id=None,
+ ),
+ )
+ # Holds the info according to the current state when the user joined
+ self.assertEqual(
+ sliding_sync_membership_snapshots_results.get((space_room_id, user2_id)),
+ _SlidingSyncMembershipSnapshotResult(
+ room_id=space_room_id,
+ user_id=user2_id,
+ sender=user2_id,
+ membership_event_id=state_map[(EventTypes.Member, user2_id)].event_id,
+ membership=Membership.JOIN,
+ event_stream_ordering=state_map[
+ (EventTypes.Member, user2_id)
+ ].internal_metadata.stream_ordering,
+ has_known_state=True,
+ room_type=RoomTypes.SPACE,
+                # Even though this room does have a name, user2 is the room creator
+                # and joined at room creation time, before this state was set.
+ room_name=None,
+ is_encrypted=False,
+ tombstone_successor_room_id=None,
+ ),
+ )
+
+ def test_joined_room_with_state_updated(self) -> None:
+ """
+        Test that state-derived info in `sliding_sync_joined_rooms` is updated when
+        the current state changes.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
+ # Add a room name
+ self.helper.send_state(
+ room_id1,
+ EventTypes.Name,
+ {"name": "my super duper room"},
+ tok=user2_tok,
+ )
+
+ # User1 joins the room
+ user1_join_response = self.helper.join(room_id1, user1_id, tok=user1_tok)
+ user1_join_event_pos = self.get_success(
+ self.store.get_position_for_event(user1_join_response["event_id"])
+ )
+
+ state_map = self.get_success(
+ self.storage_controllers.state.get_current_state(room_id1)
+ )
+
+ sliding_sync_joined_rooms_results = self._get_sliding_sync_joined_rooms()
+ self.assertIncludes(
+ set(sliding_sync_joined_rooms_results.keys()),
+ {room_id1},
+ exact=True,
+ )
+ self.assertEqual(
+ sliding_sync_joined_rooms_results[room_id1],
+ _SlidingSyncJoinedRoomResult(
+ room_id=room_id1,
+ event_stream_ordering=user1_join_event_pos.stream,
+ bump_stamp=state_map[
+ (EventTypes.Create, "")
+ ].internal_metadata.stream_ordering,
+ room_type=None,
+ room_name="my super duper room",
+ is_encrypted=False,
+ tombstone_successor_room_id=None,
+ ),
+ )
+
+ sliding_sync_membership_snapshots_results = (
+ self._get_sliding_sync_membership_snapshots()
+ )
+ self.assertIncludes(
+ set(sliding_sync_membership_snapshots_results.keys()),
+ {
+ (room_id1, user1_id),
+ (room_id1, user2_id),
+ },
+ exact=True,
+ )
+
+ # Update the room name
+ self.helper.send_state(
+ room_id1,
+ EventTypes.Name,
+ {"name": "my super duper room was renamed"},
+ tok=user2_tok,
+ )
+ # Encrypt the room
+ encrypt_room_response = self.helper.send_state(
+ room_id1,
+ EventTypes.RoomEncryption,
+ {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"},
+ tok=user2_tok,
+ )
+ encrypt_room_event_pos = self.get_success(
+ self.store.get_position_for_event(encrypt_room_response["event_id"])
+ )
+
+ sliding_sync_joined_rooms_results = self._get_sliding_sync_joined_rooms()
+ self.assertIncludes(
+ set(sliding_sync_joined_rooms_results.keys()),
+ {room_id1},
+ exact=True,
+ )
+        # Make sure we see the new room name and the updated encryption state
+ self.assertEqual(
+ sliding_sync_joined_rooms_results[room_id1],
+ _SlidingSyncJoinedRoomResult(
+ room_id=room_id1,
+ event_stream_ordering=encrypt_room_event_pos.stream,
+ bump_stamp=state_map[
+ (EventTypes.Create, "")
+ ].internal_metadata.stream_ordering,
+ room_type=None,
+ room_name="my super duper room was renamed",
+ is_encrypted=True,
+ tombstone_successor_room_id=None,
+ ),
+ )
+
+ sliding_sync_membership_snapshots_results = (
+ self._get_sliding_sync_membership_snapshots()
+ )
+ self.assertIncludes(
+ set(sliding_sync_membership_snapshots_results.keys()),
+ {
+ (room_id1, user1_id),
+ (room_id1, user2_id),
+ },
+ exact=True,
+ )
+ # Holds the info according to the current state when the user joined
+ self.assertEqual(
+ sliding_sync_membership_snapshots_results.get((room_id1, user1_id)),
+ _SlidingSyncMembershipSnapshotResult(
+ room_id=room_id1,
+ user_id=user1_id,
+ sender=user1_id,
+ membership_event_id=state_map[(EventTypes.Member, user1_id)].event_id,
+ membership=Membership.JOIN,
+ event_stream_ordering=state_map[
+ (EventTypes.Member, user1_id)
+ ].internal_metadata.stream_ordering,
+ has_known_state=True,
+ room_type=None,
+ room_name="my super duper room",
+ is_encrypted=False,
+ tombstone_successor_room_id=None,
+ ),
+ )
+ # Holds the info according to the current state when the user joined
+ self.assertEqual(
+ sliding_sync_membership_snapshots_results.get((room_id1, user2_id)),
+ _SlidingSyncMembershipSnapshotResult(
+ room_id=room_id1,
+ user_id=user2_id,
+ sender=user2_id,
+ membership_event_id=state_map[(EventTypes.Member, user2_id)].event_id,
+ membership=Membership.JOIN,
+ event_stream_ordering=state_map[
+ (EventTypes.Member, user2_id)
+ ].internal_metadata.stream_ordering,
+ has_known_state=True,
+ room_type=None,
+ room_name=None,
+ is_encrypted=False,
+ tombstone_successor_room_id=None,
+ ),
+ )
+
+ def test_joined_room_is_bumped(self) -> None:
+ """
+        Test that `event_stream_ordering` and `bump_stamp` are updated when a new
+        bump event is sent (`sliding_sync_joined_rooms`).
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
+ # Add a room name
+ self.helper.send_state(
+ room_id1,
+ EventTypes.Name,
+ {"name": "my super duper room"},
+ tok=user2_tok,
+ )
+
+ # User1 joins the room
+ user1_join_response = self.helper.join(room_id1, user1_id, tok=user1_tok)
+ user1_join_event_pos = self.get_success(
+ self.store.get_position_for_event(user1_join_response["event_id"])
+ )
+
+ state_map = self.get_success(
+ self.storage_controllers.state.get_current_state(room_id1)
+ )
+
+ sliding_sync_joined_rooms_results = self._get_sliding_sync_joined_rooms()
+ self.assertIncludes(
+ set(sliding_sync_joined_rooms_results.keys()),
+ {room_id1},
+ exact=True,
+ )
+ self.assertEqual(
+ sliding_sync_joined_rooms_results[room_id1],
+ _SlidingSyncJoinedRoomResult(
+ room_id=room_id1,
+ event_stream_ordering=user1_join_event_pos.stream,
+ bump_stamp=state_map[
+ (EventTypes.Create, "")
+ ].internal_metadata.stream_ordering,
+ room_type=None,
+ room_name="my super duper room",
+ is_encrypted=False,
+ tombstone_successor_room_id=None,
+ ),
+ )
+
+ sliding_sync_membership_snapshots_results = (
+ self._get_sliding_sync_membership_snapshots()
+ )
+ self.assertIncludes(
+ set(sliding_sync_membership_snapshots_results.keys()),
+ {
+ (room_id1, user1_id),
+ (room_id1, user2_id),
+ },
+ exact=True,
+ )
+ # Holds the info according to the current state when the user joined
+ user1_snapshot = _SlidingSyncMembershipSnapshotResult(
+ room_id=room_id1,
+ user_id=user1_id,
+ sender=user1_id,
+ membership_event_id=state_map[(EventTypes.Member, user1_id)].event_id,
+ membership=Membership.JOIN,
+ event_stream_ordering=state_map[
+ (EventTypes.Member, user1_id)
+ ].internal_metadata.stream_ordering,
+ has_known_state=True,
+ room_type=None,
+ room_name="my super duper room",
+ is_encrypted=False,
+ tombstone_successor_room_id=None,
+ )
+ self.assertEqual(
+ sliding_sync_membership_snapshots_results.get((room_id1, user1_id)),
+ user1_snapshot,
+ )
+ # Holds the info according to the current state when the user joined
+ user2_snapshot = _SlidingSyncMembershipSnapshotResult(
+ room_id=room_id1,
+ user_id=user2_id,
+ sender=user2_id,
+ membership_event_id=state_map[(EventTypes.Member, user2_id)].event_id,
+ membership=Membership.JOIN,
+ event_stream_ordering=state_map[
+ (EventTypes.Member, user2_id)
+ ].internal_metadata.stream_ordering,
+ has_known_state=True,
+ room_type=None,
+ room_name=None,
+ is_encrypted=False,
+ tombstone_successor_room_id=None,
+ )
+ self.assertEqual(
+ sliding_sync_membership_snapshots_results.get((room_id1, user2_id)),
+ user2_snapshot,
+ )
+
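+        # Only message-like event types (e.g. `m.room.message`) count as "bump"
+        # events that move `bump_stamp` forward.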
+ # Send a new message to bump the room
+ event_response = self.helper.send(room_id1, "some message", tok=user1_tok)
+ event_pos = self.get_success(
+ self.store.get_position_for_event(event_response["event_id"])
+ )
+
+ sliding_sync_joined_rooms_results = self._get_sliding_sync_joined_rooms()
+ self.assertIncludes(
+ set(sliding_sync_joined_rooms_results.keys()),
+ {room_id1},
+ exact=True,
+ )
+        # Make sure `event_stream_ordering` and `bump_stamp` were updated by the
+        # new message
+ self.assertEqual(
+ sliding_sync_joined_rooms_results[room_id1],
+ _SlidingSyncJoinedRoomResult(
+ room_id=room_id1,
+ # Updated `event_stream_ordering`
+ event_stream_ordering=event_pos.stream,
+ # And since the event was a bump event, the `bump_stamp` should be updated
+ bump_stamp=event_pos.stream,
+ # The state is still the same (it didn't change)
+ room_type=None,
+ room_name="my super duper room",
+ is_encrypted=False,
+ tombstone_successor_room_id=None,
+ ),
+ )
+
+ sliding_sync_membership_snapshots_results = (
+ self._get_sliding_sync_membership_snapshots()
+ )
+ self.assertIncludes(
+ set(sliding_sync_membership_snapshots_results.keys()),
+ {
+ (room_id1, user1_id),
+ (room_id1, user2_id),
+ },
+ exact=True,
+ )
+ self.assertEqual(
+ sliding_sync_membership_snapshots_results.get((room_id1, user1_id)),
+ user1_snapshot,
+ )
+ self.assertEqual(
+ sliding_sync_membership_snapshots_results.get((room_id1, user2_id)),
+ user2_snapshot,
+ )
+
+ def test_joined_room_bump_stamp_backfill(self) -> None:
+ """
+ Test that `bump_stamp` ignores backfilled events, i.e. events with a
+ negative stream ordering.
+ """
+ user1_id = self.register_user("user1", "pass")
+ _user1_tok = self.login(user1_id, "pass")
+
+ # Create a remote room
+ creator = "@user:other"
+ room_id = "!foo:other"
+ room_version = RoomVersions.V10
+ shared_kwargs = {
+ "room_id": room_id,
+ "room_version": room_version.identifier,
+ }
+
+ create_tuple = self.get_success(
+ create_event(
+ self.hs,
+ prev_event_ids=[],
+ type=EventTypes.Create,
+ state_key="",
+ content={
+ # The `ROOM_CREATOR` field could be removed if we used a room
+ # version > 10 (in favor of relying on `sender`)
+ EventContentFields.ROOM_CREATOR: creator,
+ EventContentFields.ROOM_VERSION: room_version.identifier,
+ },
+ sender=creator,
+ **shared_kwargs,
+ )
+ )
+ creator_tuple = self.get_success(
+ create_event(
+ self.hs,
+ prev_event_ids=[create_tuple[0].event_id],
+ auth_event_ids=[create_tuple[0].event_id],
+ type=EventTypes.Member,
+ state_key=creator,
+ content={"membership": Membership.JOIN},
+ sender=creator,
+ **shared_kwargs,
+ )
+ )
+ room_name_tuple = self.get_success(
+ create_event(
+ self.hs,
+ prev_event_ids=[creator_tuple[0].event_id],
+ auth_event_ids=[create_tuple[0].event_id, creator_tuple[0].event_id],
+ type=EventTypes.Name,
+ state_key="",
+ content={
+ EventContentFields.ROOM_NAME: "my super duper room",
+ },
+ sender=creator,
+ **shared_kwargs,
+ )
+ )
+ # We add a message event as a valid "bump type"
+ msg_tuple = self.get_success(
+ create_event(
+ self.hs,
+ prev_event_ids=[room_name_tuple[0].event_id],
+ auth_event_ids=[create_tuple[0].event_id, creator_tuple[0].event_id],
+ type=EventTypes.Message,
+ content={"body": "foo", "msgtype": "m.text"},
+ sender=creator,
+ **shared_kwargs,
+ )
+ )
+ invite_tuple = self.get_success(
+ create_event(
+ self.hs,
+ prev_event_ids=[msg_tuple[0].event_id],
+ auth_event_ids=[create_tuple[0].event_id, creator_tuple[0].event_id],
+ type=EventTypes.Member,
+ state_key=user1_id,
+ content={"membership": Membership.INVITE},
+ sender=creator,
+ **shared_kwargs,
+ )
+ )
+
+ remote_events_and_contexts = [
+ create_tuple,
+ creator_tuple,
+ room_name_tuple,
+ msg_tuple,
+ invite_tuple,
+ ]
+
+ # Ensure the local HS knows the room version
+ self.get_success(self.store.store_room(room_id, creator, False, room_version))
+
+ # Persist these events as backfilled events.
+ for event, context in remote_events_and_contexts:
+ self.get_success(
+ self.persist_controller.persist_event(event, context, backfilled=True)
+ )
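+        # Backfilled events are persisted with negative stream orderings, so none
+        # of them should count towards `bump_stamp`.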
+
+ # Now we join the local user to the room. We want to make this feel as close to
+ # the real `process_remote_join()` as possible but we'd like to avoid some of
+ # the auth checks that would be done in the real code.
+ #
+ # FIXME: The test was originally written using this less-real
+ # `persist_event(...)` shortcut but it would be nice to use the real remote join
+ # process in a `FederatingHomeserverTestCase`.
+ flawed_join_tuple = self.get_success(
+ create_event(
+ self.hs,
+ prev_event_ids=[invite_tuple[0].event_id],
+ # This doesn't work correctly to create an `EventContext` that includes
+ # both of these state events. I assume it's because we're working on our
+ # local homeserver which has the remote state set as `outlier`. We have
+ # to create our own EventContext below to get this right.
+ auth_event_ids=[create_tuple[0].event_id, invite_tuple[0].event_id],
+ type=EventTypes.Member,
+ state_key=user1_id,
+ content={"membership": Membership.JOIN},
+ sender=user1_id,
+ **shared_kwargs,
+ )
+ )
+ # We have to create our own context to get the state set correctly. If we use
+ # the `EventContext` from the `flawed_join_tuple`, the `current_state_events`
+ # table will only have the join event in it which should never happen in our
+ # real server.
+ join_event = flawed_join_tuple[0]
+ join_context = self.get_success(
+ self.state_handler.compute_event_context(
+ join_event,
+ state_ids_before_event={
+ (e.type, e.state_key): e.event_id
+ for e in [create_tuple[0], invite_tuple[0], room_name_tuple[0]]
+ },
+ partial_state=False,
+ )
+ )
+ join_event, _join_event_pos, _room_token = self.get_success(
+ self.persist_controller.persist_event(join_event, join_context)
+ )
+
+ # Make sure the tables are populated correctly
+ sliding_sync_joined_rooms_results = self._get_sliding_sync_joined_rooms()
+ self.assertIncludes(
+ set(sliding_sync_joined_rooms_results.keys()),
+ {room_id},
+ exact=True,
+ )
+ self.assertEqual(
+ sliding_sync_joined_rooms_results[room_id],
+ _SlidingSyncJoinedRoomResult(
+ room_id=room_id,
+ # This should be the last event in the room (the join membership)
+ event_stream_ordering=join_event.internal_metadata.stream_ordering,
+                # Since all of the bump events are backfilled, the `bump_stamp`
+                # should still be `None` (and we will fall back to the user's
+                # membership event position in the Sliding Sync API).
+ bump_stamp=None,
+ room_type=None,
+ # We still pick up state of the room even if it's backfilled
+ room_name="my super duper room",
+ is_encrypted=False,
+ tombstone_successor_room_id=None,
+ ),
+ )
+
+ sliding_sync_membership_snapshots_results = (
+ self._get_sliding_sync_membership_snapshots()
+ )
+ self.assertIncludes(
+ set(sliding_sync_membership_snapshots_results.keys()),
+ {
+ (room_id, user1_id),
+ },
+ exact=True,
+ )
+ self.assertEqual(
+ sliding_sync_membership_snapshots_results.get((room_id, user1_id)),
+ _SlidingSyncMembershipSnapshotResult(
+ room_id=room_id,
+ user_id=user1_id,
+ sender=user1_id,
+ membership_event_id=join_event.event_id,
+ membership=Membership.JOIN,
+ event_stream_ordering=join_event.internal_metadata.stream_ordering,
+ has_known_state=True,
+ room_type=None,
+ room_name="my super duper room",
+ is_encrypted=False,
+ tombstone_successor_room_id=None,
+ ),
+ )
+
+ @parameterized.expand(
+        # Test both an insert and an upsert into the
+ # `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` to exercise
+ # more possibilities of things going wrong.
+ [
+ ("insert", True),
+ ("upsert", False),
+ ]
+ )
+ def test_joined_room_outlier_and_deoutlier(
+ self, description: str, should_insert: bool
+ ) -> None:
+ """
+ This is a regression test.
+
+        This simulates the case where an event is first persisted as an outlier
+        (like a remote invite) and then later persisted again to de-outlier it. The
+        first time, the outlier is persisted with one `stream_ordering`, but when
+        persisted again and de-outliered, it is assigned a different `stream_ordering`
+        that won't end up being used. Since we call
+        `_calculate_sliding_sync_table_changes()` before `_update_outliers_txn()`, which
+        fixes this discrepancy (it always uses the `stream_ordering` from the first
+        time the event was persisted), make sure we're not using unreliable
+        `stream_ordering` values that would cause `FOREIGN KEY constraint failed`
+        errors in the
+        `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` tables.
+ """
+ user1_id = self.register_user("user1", "pass")
+ _user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ room_version = RoomVersions.V10
+ room_id = self.helper.create_room_as(
+ user2_id, tok=user2_tok, room_version=room_version.identifier
+ )
+
+ if should_insert:
+ # Clear these out so we always insert
+ self.get_success(
+ self.store.db_pool.simple_delete(
+ table="sliding_sync_joined_rooms",
+ keyvalues={"room_id": room_id},
+ desc="TODO",
+ )
+ )
+ self.get_success(
+ self.store.db_pool.simple_delete(
+ table="sliding_sync_membership_snapshots",
+ keyvalues={"room_id": room_id},
+ desc="TODO",
+ )
+ )
+
+ # Create a membership event (which triggers an insert into
+ # `sliding_sync_membership_snapshots`)
+ membership_event_dict = {
+ "type": EventTypes.Member,
+ "state_key": user1_id,
+ "sender": user1_id,
+ "room_id": room_id,
+ "content": {EventContentFields.MEMBERSHIP: Membership.JOIN},
+ }
+ # Create a relevant state event (which triggers an insert into
+ # `sliding_sync_joined_rooms`)
+ state_event_dict = {
+ "type": EventTypes.Name,
+ "state_key": "",
+ "sender": user2_id,
+ "room_id": room_id,
+ "content": {EventContentFields.ROOM_NAME: "my super room"},
+ }
+ event_dicts_to_persist = [
+ membership_event_dict,
+ state_event_dict,
+ ]
+
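+        # Persist each event twice: first as an outlier, then again as a
+        # non-outlier. The second persist de-outliers the event and assigns it a
+        # new `stream_ordering` that won't end up being used, which is exactly the
+        # discrepancy this regression test exercises.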
+ for event_dict in event_dicts_to_persist:
+ events_to_persist = []
+
+            # Create the event as an outlier
+ (
+ event,
+ unpersisted_context,
+ ) = self.get_success(
+ self.hs.get_event_creation_handler().create_event(
+ requester=create_requester(user1_id),
+ event_dict=event_dict,
+ outlier=True,
+ )
+ )
+ # FIXME: Should we use an `EventContext.for_outlier(...)` here?
+ # Doesn't seem to matter for this test.
+ context = self.get_success(unpersisted_context.persist(event))
+ events_to_persist.append((event, context))
+
+            # Create the event again but as a non-outlier. This will de-outlier the
+            # event when we persist it.
+ (
+ event,
+ unpersisted_context,
+ ) = self.get_success(
+ self.hs.get_event_creation_handler().create_event(
+ requester=create_requester(user1_id),
+ event_dict=event_dict,
+ outlier=False,
+ )
+ )
+ context = self.get_success(unpersisted_context.persist(event))
+ events_to_persist.append((event, context))
+
+ for event, context in events_to_persist:
+ self.get_success(
+ self.persist_controller.persist_event(
+ event,
+ context,
+ )
+ )
+
+ # We're just testing that it does not explode
+
+ def test_joined_room_meta_state_reset(self) -> None:
+ """
+ Test that a state reset on the room name is reflected in the
+ `sliding_sync_joined_rooms` table.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
+ # Add a room name
+ self.helper.send_state(
+ room_id,
+ EventTypes.Name,
+ {"name": "my super duper room"},
+ tok=user2_tok,
+ )
+
+ # User1 joins the room
+ self.helper.join(room_id, user1_id, tok=user1_tok)
+
+ # Make sure we see the new room name
+ sliding_sync_joined_rooms_results = self._get_sliding_sync_joined_rooms()
+ self.assertIncludes(
+ set(sliding_sync_joined_rooms_results.keys()),
+ {room_id},
+ exact=True,
+ )
+ state_map = self.get_success(
+ self.storage_controllers.state.get_current_state(room_id)
+ )
+ self.assertEqual(
+ sliding_sync_joined_rooms_results[room_id],
+ _SlidingSyncJoinedRoomResult(
+ room_id=room_id,
+ # This should be whatever is the last event in the room
+ event_stream_ordering=state_map[
+ (EventTypes.Member, user1_id)
+ ].internal_metadata.stream_ordering,
+ bump_stamp=state_map[
+ (EventTypes.Create, "")
+ ].internal_metadata.stream_ordering,
+ room_type=None,
+ room_name="my super duper room",
+ is_encrypted=False,
+ tombstone_successor_room_id=None,
+ ),
+ )
+
+ sliding_sync_membership_snapshots_results = (
+ self._get_sliding_sync_membership_snapshots()
+ )
+ self.assertIncludes(
+ set(sliding_sync_membership_snapshots_results.keys()),
+ {
+ (room_id, user1_id),
+ (room_id, user2_id),
+ },
+ exact=True,
+ )
+ user1_snapshot = _SlidingSyncMembershipSnapshotResult(
+ room_id=room_id,
+ user_id=user1_id,
+ sender=user1_id,
+ membership_event_id=state_map[(EventTypes.Member, user1_id)].event_id,
+ membership=Membership.JOIN,
+ event_stream_ordering=state_map[
+ (EventTypes.Member, user1_id)
+ ].internal_metadata.stream_ordering,
+ has_known_state=True,
+ room_type=None,
+ room_name="my super duper room",
+ is_encrypted=False,
+ tombstone_successor_room_id=None,
+ )
+ self.assertEqual(
+ sliding_sync_membership_snapshots_results.get((room_id, user1_id)),
+ user1_snapshot,
+ )
+ # Holds the info according to the current state when the user joined (no room
+ # name when the room creator joined)
+ user2_snapshot = _SlidingSyncMembershipSnapshotResult(
+ room_id=room_id,
+ user_id=user2_id,
+ sender=user2_id,
+ membership_event_id=state_map[(EventTypes.Member, user2_id)].event_id,
+ membership=Membership.JOIN,
+ event_stream_ordering=state_map[
+ (EventTypes.Member, user2_id)
+ ].internal_metadata.stream_ordering,
+ has_known_state=True,
+ room_type=None,
+ room_name=None,
+ is_encrypted=False,
+ tombstone_successor_room_id=None,
+ )
+ self.assertEqual(
+ sliding_sync_membership_snapshots_results.get((room_id, user2_id)),
+ user2_snapshot,
+ )
+
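+        # A "state reset" happens when the room's current state changes without a
+        # corresponding new state event, e.g. as a result of state resolution when
+        # the forward extremities change, so previously-set state can just vanish.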
+ # Mock a state reset removing the room name state from the current state
+ message_tuple = self.get_success(
+ create_event(
+ self.hs,
+ prev_event_ids=[state_map[(EventTypes.Name, "")].event_id],
+ auth_event_ids=[
+ state_map[(EventTypes.Create, "")].event_id,
+ state_map[(EventTypes.Member, user1_id)].event_id,
+ ],
+ type=EventTypes.Message,
+ content={"body": "foo", "msgtype": "m.text"},
+ sender=user1_id,
+ room_id=room_id,
+ room_version=RoomVersions.V10.identifier,
+ )
+ )
+ event_chunk = [message_tuple]
+ self.get_success(
+ self.persist_events_store._persist_events_and_state_updates(
+ room_id,
+ event_chunk,
+ state_delta_for_room=DeltaState(
+ # This is the state reset part. We're removing the room name state.
+ to_delete=[(EventTypes.Name, "")],
+ to_insert={},
+ ),
+ new_forward_extremities={message_tuple[0].event_id},
+ use_negative_stream_ordering=False,
+ inhibit_local_membership_updates=False,
+ new_event_links={},
+ )
+ )
+
+ # Make sure the state reset is reflected in the `sliding_sync_joined_rooms` table
+ sliding_sync_joined_rooms_results = self._get_sliding_sync_joined_rooms()
+ self.assertIncludes(
+ set(sliding_sync_joined_rooms_results.keys()),
+ {room_id},
+ exact=True,
+ )
+ state_map = self.get_success(
+ self.storage_controllers.state.get_current_state(room_id)
+ )
+ self.assertEqual(
+ sliding_sync_joined_rooms_results[room_id],
+ _SlidingSyncJoinedRoomResult(
+ room_id=room_id,
+ # This should be whatever is the last event in the room
+ event_stream_ordering=message_tuple[
+ 0
+ ].internal_metadata.stream_ordering,
+ bump_stamp=message_tuple[0].internal_metadata.stream_ordering,
+ room_type=None,
+ # This was state reset back to None
+ room_name=None,
+ is_encrypted=False,
+ tombstone_successor_room_id=None,
+ ),
+ )
+
+ # State reset shouldn't be reflected in the `sliding_sync_membership_snapshots`
+ sliding_sync_membership_snapshots_results = (
+ self._get_sliding_sync_membership_snapshots()
+ )
+ self.assertIncludes(
+ set(sliding_sync_membership_snapshots_results.keys()),
+ {
+ (room_id, user1_id),
+ (room_id, user2_id),
+ },
+ exact=True,
+ )
+ # Snapshots haven't changed
+ self.assertEqual(
+ sliding_sync_membership_snapshots_results.get((room_id, user1_id)),
+ user1_snapshot,
+ )
+ self.assertEqual(
+ sliding_sync_membership_snapshots_results.get((room_id, user2_id)),
+ user2_snapshot,
+ )
+
+ def test_joined_room_fully_insert_on_state_update(self) -> None:
+ """
+        Test that when an existing room updates its state and we don't have a
+        corresponding row in `sliding_sync_joined_rooms` yet, we fully-insert the
+        row even though only a tiny piece of state changed.
+
+ FIXME: This can be removed once we bump `SCHEMA_COMPAT_VERSION` and run the
+ foreground update for
+ `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` (tracked by
+ https://github.com/element-hq/synapse/issues/17623)
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+ # Add a room name
+ self.helper.send_state(
+ room_id,
+ EventTypes.Name,
+ {"name": "my super duper room"},
+ tok=user1_tok,
+ )
+
+        # Clean up the `sliding_sync_joined_rooms` table as if the room never made
+        # it into the table. This simulates an existing room (created before we even
+        # added the sliding sync tables) not being in the
+        # `sliding_sync_joined_rooms` table yet.
+ self.get_success(
+ self.store.db_pool.simple_delete(
+ table="sliding_sync_joined_rooms",
+ keyvalues={"room_id": room_id},
+ desc="simulate existing room not being in the sliding_sync_joined_rooms table yet",
+ )
+ )
+
+        # We shouldn't find anything in the table because we just deleted the rows
+        # in preparation for the test.
+ sliding_sync_joined_rooms_results = self._get_sliding_sync_joined_rooms()
+ self.assertIncludes(
+ set(sliding_sync_joined_rooms_results.keys()),
+ set(),
+ exact=True,
+ )
+
+ # Encrypt the room
+ self.helper.send_state(
+ room_id,
+ EventTypes.RoomEncryption,
+ {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"},
+ tok=user1_tok,
+ )
+
+ # The room should now be in the `sliding_sync_joined_rooms` table
+ # (fully-inserted with all of the state values).
+ sliding_sync_joined_rooms_results = self._get_sliding_sync_joined_rooms()
+ self.assertIncludes(
+ set(sliding_sync_joined_rooms_results.keys()),
+ {room_id},
+ exact=True,
+ )
+ state_map = self.get_success(
+ self.storage_controllers.state.get_current_state(room_id)
+ )
+ self.assertEqual(
+ sliding_sync_joined_rooms_results[room_id],
+ _SlidingSyncJoinedRoomResult(
+ room_id=room_id,
+ # This should be whatever is the last event in the room
+ event_stream_ordering=state_map[
+ (EventTypes.RoomEncryption, "")
+ ].internal_metadata.stream_ordering,
+ bump_stamp=state_map[
+ (EventTypes.Create, "")
+ ].internal_metadata.stream_ordering,
+ room_type=None,
+ room_name="my super duper room",
+ is_encrypted=True,
+ tombstone_successor_room_id=None,
+ ),
+ )
+
+ def test_joined_room_nothing_if_not_in_table_when_bumped(self) -> None:
+ """
+ Test a new message being sent in an existing room when we don't have a
+ corresponding row in `sliding_sync_joined_rooms` yet; either nothing should
+ happen or we should fully-insert the row. We currently do nothing.
+
+ FIXME: This can be removed once we bump `SCHEMA_COMPAT_VERSION` and run the
+ foreground update for
+ `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` (tracked by
+ https://github.com/element-hq/synapse/issues/17623)
+ """
+
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+ # Add a room name
+ self.helper.send_state(
+ room_id,
+ EventTypes.Name,
+ {"name": "my super duper room"},
+ tok=user1_tok,
+ )
+ # Encrypt the room
+ self.helper.send_state(
+ room_id,
+ EventTypes.RoomEncryption,
+ {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"},
+ tok=user1_tok,
+ )
+
+        # Clean up the `sliding_sync_joined_rooms` table as if the room never made
+        # it into the table. This simulates an existing room (created before we even
+        # added the sliding sync tables) not being in the
+        # `sliding_sync_joined_rooms` table yet.
+ self.get_success(
+ self.store.db_pool.simple_delete(
+ table="sliding_sync_joined_rooms",
+ keyvalues={"room_id": room_id},
+ desc="simulate existing room not being in the sliding_sync_joined_rooms table yet",
+ )
+ )
+
+        # We shouldn't find anything in the table because we just deleted the rows
+        # in preparation for the test.
+ sliding_sync_joined_rooms_results = self._get_sliding_sync_joined_rooms()
+ self.assertIncludes(
+ set(sliding_sync_joined_rooms_results.keys()),
+ set(),
+ exact=True,
+ )
+
+ # Send a new message to bump the room
+ self.helper.send(room_id, "some message", tok=user1_tok)
+
+ # Either nothing should happen or we should fully-insert the row. We currently
+ # do nothing for non-state events.
+ sliding_sync_joined_rooms_results = self._get_sliding_sync_joined_rooms()
+ self.assertIncludes(
+ set(sliding_sync_joined_rooms_results.keys()),
+ set(),
+ exact=True,
+ )
+
+ def test_non_join_space_room_with_info(self) -> None:
+ """
+        Test that a user who was invited shows up in
+        `sliding_sync_membership_snapshots`.
+ """
+ user1_id = self.register_user("user1", "pass")
+ _user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ space_room_id = self.helper.create_room_as(
+ user2_id,
+ tok=user2_tok,
+ extra_content={
+ "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE}
+ },
+ )
+ # Add a room name
+ self.helper.send_state(
+ space_room_id,
+ EventTypes.Name,
+ {"name": "my super duper space"},
+ tok=user2_tok,
+ )
+ # Encrypt the room
+ self.helper.send_state(
+ space_room_id,
+ EventTypes.RoomEncryption,
+ {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"},
+ tok=user2_tok,
+ )
+ # Add a tombstone
+ self.helper.send_state(
+ space_room_id,
+ EventTypes.Tombstone,
+ {EventContentFields.TOMBSTONE_SUCCESSOR_ROOM: "another_room"},
+ tok=user2_tok,
+ )
+
+ # User1 is invited to the room
+ user1_invited_response = self.helper.invite(
+ space_room_id, src=user2_id, targ=user1_id, tok=user2_tok
+ )
+ user1_invited_event_pos = self.get_success(
+ self.store.get_position_for_event(user1_invited_response["event_id"])
+ )
+
+ # Update the room name after we are invited just to make sure
+ # we don't update non-join memberships when the room name changes.
+ rename_response = self.helper.send_state(
+ space_room_id,
+ EventTypes.Name,
+ {"name": "my super duper space was renamed"},
+ tok=user2_tok,
+ )
+ rename_event_pos = self.get_success(
+ self.store.get_position_for_event(rename_response["event_id"])
+ )
+
+ state_map = self.get_success(
+ self.storage_controllers.state.get_current_state(space_room_id)
+ )
+
+ # User2 is still joined to the room so we should still have an entry in the
+ # `sliding_sync_joined_rooms` table.
+ sliding_sync_joined_rooms_results = self._get_sliding_sync_joined_rooms()
+ self.assertIncludes(
+ set(sliding_sync_joined_rooms_results.keys()),
+ {space_room_id},
+ exact=True,
+ )
+ self.assertEqual(
+ sliding_sync_joined_rooms_results[space_room_id],
+ _SlidingSyncJoinedRoomResult(
+ room_id=space_room_id,
+ event_stream_ordering=rename_event_pos.stream,
+ bump_stamp=state_map[
+ (EventTypes.Create, "")
+ ].internal_metadata.stream_ordering,
+ room_type=RoomTypes.SPACE,
+ room_name="my super duper space was renamed",
+ is_encrypted=True,
+ tombstone_successor_room_id="another_room",
+ ),
+ )
+
+ sliding_sync_membership_snapshots_results = (
+ self._get_sliding_sync_membership_snapshots()
+ )
+ self.assertIncludes(
+ set(sliding_sync_membership_snapshots_results.keys()),
+ {
+ (space_room_id, user1_id),
+ (space_room_id, user2_id),
+ },
+ exact=True,
+ )
+ # Holds the info according to the current state when the user was invited
+ self.assertEqual(
+ sliding_sync_membership_snapshots_results.get((space_room_id, user1_id)),
+ _SlidingSyncMembershipSnapshotResult(
+ room_id=space_room_id,
+ user_id=user1_id,
+ sender=user2_id,
+ membership_event_id=user1_invited_response["event_id"],
+ membership=Membership.INVITE,
+ event_stream_ordering=user1_invited_event_pos.stream,
+ has_known_state=True,
+ room_type=RoomTypes.SPACE,
+ room_name="my super duper space",
+ is_encrypted=True,
+ tombstone_successor_room_id="another_room",
+ ),
+ )
+ # Holds the info according to the current state when the user joined
+ self.assertEqual(
+ sliding_sync_membership_snapshots_results.get((space_room_id, user2_id)),
+ _SlidingSyncMembershipSnapshotResult(
+ room_id=space_room_id,
+ user_id=user2_id,
+ sender=user2_id,
+ membership_event_id=state_map[(EventTypes.Member, user2_id)].event_id,
+ membership=Membership.JOIN,
+ event_stream_ordering=state_map[
+ (EventTypes.Member, user2_id)
+ ].internal_metadata.stream_ordering,
+ has_known_state=True,
+ room_type=RoomTypes.SPACE,
+ room_name=None,
+ is_encrypted=False,
+ tombstone_successor_room_id=None,
+ ),
+ )
+
+ def test_non_join_invite_ban(self) -> None:
+ """
+        Test that users who have an invite/ban membership in a room show up in
+        `sliding_sync_membership_snapshots`.
+ """
+ user1_id = self.register_user("user1", "pass")
+ _user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+ user3_id = self.register_user("user3", "pass")
+ user3_tok = self.login(user3_id, "pass")
+
+ room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
+
+ # User1 is invited to the room
+ user1_invited_response = self.helper.invite(
+ room_id1, src=user2_id, targ=user1_id, tok=user2_tok
+ )
+ user1_invited_event_pos = self.get_success(
+ self.store.get_position_for_event(user1_invited_response["event_id"])
+ )
+
+ # User3 joins the room
+ self.helper.join(room_id1, user3_id, tok=user3_tok)
+ # User3 is banned from the room
+ user3_ban_response = self.helper.ban(
+ room_id1, src=user2_id, targ=user3_id, tok=user2_tok
+ )
+ user3_ban_event_pos = self.get_success(
+ self.store.get_position_for_event(user3_ban_response["event_id"])
+ )
+
+ state_map = self.get_success(
+ self.storage_controllers.state.get_current_state(room_id1)
+ )
+
+ # User2 is still joined to the room so we should still have an entry
+ # in the `sliding_sync_joined_rooms` table.
+ sliding_sync_joined_rooms_results = self._get_sliding_sync_joined_rooms()
+ self.assertIncludes(
+ set(sliding_sync_joined_rooms_results.keys()),
+ {room_id1},
+ exact=True,
+ )
+ self.assertEqual(
+ sliding_sync_joined_rooms_results[room_id1],
+ _SlidingSyncJoinedRoomResult(
+ room_id=room_id1,
+ event_stream_ordering=user3_ban_event_pos.stream,
+ bump_stamp=state_map[
+ (EventTypes.Create, "")
+ ].internal_metadata.stream_ordering,
+ room_type=None,
+ room_name=None,
+ is_encrypted=False,
+ tombstone_successor_room_id=None,
+ ),
+ )
+
+ sliding_sync_membership_snapshots_results = (
+ self._get_sliding_sync_membership_snapshots()
+ )
+ self.assertIncludes(
+ set(sliding_sync_membership_snapshots_results.keys()),
+ {
+ (room_id1, user1_id),
+ (room_id1, user2_id),
+ (room_id1, user3_id),
+ },
+ exact=True,
+ )
+ # Holds the info according to the current state when the user was invited
+ self.assertEqual(
+ sliding_sync_membership_snapshots_results.get((room_id1, user1_id)),
+ _SlidingSyncMembershipSnapshotResult(
+ room_id=room_id1,
+ user_id=user1_id,
+ sender=user2_id,
+ membership_event_id=user1_invited_response["event_id"],
+ membership=Membership.INVITE,
+ event_stream_ordering=user1_invited_event_pos.stream,
+ has_known_state=True,
+ room_type=None,
+ room_name=None,
+ is_encrypted=False,
+ tombstone_successor_room_id=None,
+ ),
+ )
+ # Holds the info according to the current state when the user joined
+ self.assertEqual(
+ sliding_sync_membership_snapshots_results.get((room_id1, user2_id)),
+ _SlidingSyncMembershipSnapshotResult(
+ room_id=room_id1,
+ user_id=user2_id,
+ sender=user2_id,
+ membership_event_id=state_map[(EventTypes.Member, user2_id)].event_id,
+ membership=Membership.JOIN,
+ event_stream_ordering=state_map[
+ (EventTypes.Member, user2_id)
+ ].internal_metadata.stream_ordering,
+ has_known_state=True,
+ room_type=None,
+ room_name=None,
+ is_encrypted=False,
+ tombstone_successor_room_id=None,
+ ),
+ )
+ # Holds the info according to the current state when the user was banned
+ self.assertEqual(
+ sliding_sync_membership_snapshots_results.get((room_id1, user3_id)),
+ _SlidingSyncMembershipSnapshotResult(
+ room_id=room_id1,
+ user_id=user3_id,
+ sender=user2_id,
+ membership_event_id=user3_ban_response["event_id"],
+ membership=Membership.BAN,
+ event_stream_ordering=user3_ban_event_pos.stream,
+ has_known_state=True,
+ room_type=None,
+ room_name=None,
+ is_encrypted=False,
+ tombstone_successor_room_id=None,
+ ),
+ )
+
+ def test_non_join_reject_invite_empty_room(self) -> None:
+ """
+ In a room where no one is joined (`no_longer_in_room`), test rejecting an invite.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
+
+ # User1 is invited to the room
+ self.helper.invite(room_id1, src=user2_id, targ=user1_id, tok=user2_tok)
+
+ # User2 leaves the room
+ user2_leave_response = self.helper.leave(room_id1, user2_id, tok=user2_tok)
+ user2_leave_event_pos = self.get_success(
+ self.store.get_position_for_event(user2_leave_response["event_id"])
+ )
+
+ # User1 rejects the invite
+ user1_leave_response = self.helper.leave(room_id1, user1_id, tok=user1_tok)
+ user1_leave_event_pos = self.get_success(
+ self.store.get_position_for_event(user1_leave_response["event_id"])
+ )
+
+ # No one is joined to the room
+ sliding_sync_joined_rooms_results = self._get_sliding_sync_joined_rooms()
+ self.assertIncludes(
+ set(sliding_sync_joined_rooms_results.keys()),
+ set(),
+ exact=True,
+ )
+
+ sliding_sync_membership_snapshots_results = (
+ self._get_sliding_sync_membership_snapshots()
+ )
+ self.assertIncludes(
+ set(sliding_sync_membership_snapshots_results.keys()),
+ {
+ (room_id1, user1_id),
+ (room_id1, user2_id),
+ },
+ exact=True,
+ )
+ # Holds the info according to the current state when the user left
+ self.assertEqual(
+ sliding_sync_membership_snapshots_results.get((room_id1, user1_id)),
+ _SlidingSyncMembershipSnapshotResult(
+ room_id=room_id1,
+ user_id=user1_id,
+ sender=user1_id,
+ membership_event_id=user1_leave_response["event_id"],
+ membership=Membership.LEAVE,
+ event_stream_ordering=user1_leave_event_pos.stream,
+ has_known_state=True,
+ room_type=None,
+ room_name=None,
+ is_encrypted=False,
+ tombstone_successor_room_id=None,
+ ),
+ )
+        # Holds the info according to the current state when the user left
+ self.assertEqual(
+ sliding_sync_membership_snapshots_results.get((room_id1, user2_id)),
+ _SlidingSyncMembershipSnapshotResult(
+ room_id=room_id1,
+ user_id=user2_id,
+ sender=user2_id,
+ membership_event_id=user2_leave_response["event_id"],
+ membership=Membership.LEAVE,
+ event_stream_ordering=user2_leave_event_pos.stream,
+ has_known_state=True,
+ room_type=None,
+ room_name=None,
+ is_encrypted=False,
+ tombstone_successor_room_id=None,
+ ),
+ )
+
+ def test_membership_changing(self) -> None:
+ """
+        Test that the latest snapshot evolves as membership changes
+        (`sliding_sync_membership_snapshots`).
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
+
+ # User1 is invited to the room
+ # ======================================================
+ user1_invited_response = self.helper.invite(
+ room_id1, src=user2_id, targ=user1_id, tok=user2_tok
+ )
+ user1_invited_event_pos = self.get_success(
+ self.store.get_position_for_event(user1_invited_response["event_id"])
+ )
+
+ # Update the room name after the user was invited
+ room_name_update_response = self.helper.send_state(
+ room_id1,
+ EventTypes.Name,
+ {"name": "my super duper room"},
+ tok=user2_tok,
+ )
+ room_name_update_event_pos = self.get_success(
+ self.store.get_position_for_event(room_name_update_response["event_id"])
+ )
+
+ state_map = self.get_success(
+ self.storage_controllers.state.get_current_state(room_id1)
+ )
+
+ # Assert joined room status
+ sliding_sync_joined_rooms_results = self._get_sliding_sync_joined_rooms()
+ self.assertIncludes(
+ set(sliding_sync_joined_rooms_results.keys()),
+ {room_id1},
+ exact=True,
+ )
+ self.assertEqual(
+ sliding_sync_joined_rooms_results[room_id1],
+ _SlidingSyncJoinedRoomResult(
+ room_id=room_id1,
+ # Latest event in the room
+ event_stream_ordering=room_name_update_event_pos.stream,
+ bump_stamp=state_map[
+ (EventTypes.Create, "")
+ ].internal_metadata.stream_ordering,
+ room_type=None,
+ room_name="my super duper room",
+ is_encrypted=False,
+ tombstone_successor_room_id=None,
+ ),
+ )
+
+ # Assert membership snapshots
+ sliding_sync_membership_snapshots_results = (
+ self._get_sliding_sync_membership_snapshots()
+ )
+ self.assertIncludes(
+ set(sliding_sync_membership_snapshots_results.keys()),
+ {
+ (room_id1, user1_id),
+ (room_id1, user2_id),
+ },
+ exact=True,
+ )
+ # Holds the info according to the current state when the user was invited
+ self.assertEqual(
+ sliding_sync_membership_snapshots_results.get((room_id1, user1_id)),
+ _SlidingSyncMembershipSnapshotResult(
+ room_id=room_id1,
+ user_id=user1_id,
+ sender=user2_id,
+ membership_event_id=user1_invited_response["event_id"],
+ membership=Membership.INVITE,
+ event_stream_ordering=user1_invited_event_pos.stream,
+ has_known_state=True,
+ room_type=None,
+ # Room name was updated after the user was invited so we should still
+ # see it unset here
+ room_name=None,
+ is_encrypted=False,
+ tombstone_successor_room_id=None,
+ ),
+ )
+ # Holds the info according to the current state when the user joined
+ user2_snapshot = _SlidingSyncMembershipSnapshotResult(
+ room_id=room_id1,
+ user_id=user2_id,
+ sender=user2_id,
+ membership_event_id=state_map[(EventTypes.Member, user2_id)].event_id,
+ membership=Membership.JOIN,
+ event_stream_ordering=state_map[
+ (EventTypes.Member, user2_id)
+ ].internal_metadata.stream_ordering,
+ has_known_state=True,
+ room_type=None,
+ room_name=None,
+ is_encrypted=False,
+ tombstone_successor_room_id=None,
+ )
+ self.assertEqual(
+ sliding_sync_membership_snapshots_results.get((room_id1, user2_id)),
+ user2_snapshot,
+ )
+
+ # User1 joins the room
+ # ======================================================
+ user1_joined_response = self.helper.join(room_id1, user1_id, tok=user1_tok)
+ user1_joined_event_pos = self.get_success(
+ self.store.get_position_for_event(user1_joined_response["event_id"])
+ )
+
+ # Assert joined room status
+ sliding_sync_joined_rooms_results = self._get_sliding_sync_joined_rooms()
+ self.assertIncludes(
+ set(sliding_sync_joined_rooms_results.keys()),
+ {room_id1},
+ exact=True,
+ )
+ self.assertEqual(
+ sliding_sync_joined_rooms_results[room_id1],
+ _SlidingSyncJoinedRoomResult(
+ room_id=room_id1,
+ # Latest event in the room
+ event_stream_ordering=user1_joined_event_pos.stream,
+ bump_stamp=state_map[
+ (EventTypes.Create, "")
+ ].internal_metadata.stream_ordering,
+ room_type=None,
+ room_name="my super duper room",
+ is_encrypted=False,
+ tombstone_successor_room_id=None,
+ ),
+ )
+
+ # Assert membership snapshots
+ sliding_sync_membership_snapshots_results = (
+ self._get_sliding_sync_membership_snapshots()
+ )
+ self.assertIncludes(
+ set(sliding_sync_membership_snapshots_results.keys()),
+ {
+ (room_id1, user1_id),
+ (room_id1, user2_id),
+ },
+ exact=True,
+ )
+ # Holds the info according to the current state when the user joined
+ self.assertEqual(
+ sliding_sync_membership_snapshots_results.get((room_id1, user1_id)),
+ _SlidingSyncMembershipSnapshotResult(
+ room_id=room_id1,
+ user_id=user1_id,
+ sender=user1_id,
+ membership_event_id=user1_joined_response["event_id"],
+ membership=Membership.JOIN,
+ event_stream_ordering=user1_joined_event_pos.stream,
+ has_known_state=True,
+ room_type=None,
+                # We see the updated state because the user joined after the room
+                # name change
+ room_name="my super duper room",
+ is_encrypted=False,
+ tombstone_successor_room_id=None,
+ ),
+ )
+ # Holds the info according to the current state when the user joined
+ self.assertEqual(
+ sliding_sync_membership_snapshots_results.get((room_id1, user2_id)),
+ user2_snapshot,
+ )
+
+ # User1 is banned from the room
+ # ======================================================
+ user1_ban_response = self.helper.ban(
+ room_id1, src=user2_id, targ=user1_id, tok=user2_tok
+ )
+ user1_ban_event_pos = self.get_success(
+ self.store.get_position_for_event(user1_ban_response["event_id"])
+ )
+
+ # Assert joined room status
+ sliding_sync_joined_rooms_results = self._get_sliding_sync_joined_rooms()
+ self.assertIncludes(
+ set(sliding_sync_joined_rooms_results.keys()),
+ {room_id1},
+ exact=True,
+ )
+ self.assertEqual(
+ sliding_sync_joined_rooms_results[room_id1],
+ _SlidingSyncJoinedRoomResult(
+ room_id=room_id1,
+ # Latest event in the room
+ event_stream_ordering=user1_ban_event_pos.stream,
+ bump_stamp=state_map[
+ (EventTypes.Create, "")
+ ].internal_metadata.stream_ordering,
+ room_type=None,
+ room_name="my super duper room",
+ is_encrypted=False,
+ tombstone_successor_room_id=None,
+ ),
+ )
+
+ # Assert membership snapshots
+ sliding_sync_membership_snapshots_results = (
+ self._get_sliding_sync_membership_snapshots()
+ )
+ self.assertIncludes(
+ set(sliding_sync_membership_snapshots_results.keys()),
+ {
+ (room_id1, user1_id),
+ (room_id1, user2_id),
+ },
+ exact=True,
+ )
+ # Holds the info according to the current state when the user was banned
+ self.assertEqual(
+ sliding_sync_membership_snapshots_results.get((room_id1, user1_id)),
+ _SlidingSyncMembershipSnapshotResult(
+ room_id=room_id1,
+ user_id=user1_id,
+ sender=user2_id,
+ membership_event_id=user1_ban_response["event_id"],
+ membership=Membership.BAN,
+ event_stream_ordering=user1_ban_event_pos.stream,
+ has_known_state=True,
+ room_type=None,
+                # We see the updated state because the user was banned after the
+                # room name change
+ room_name="my super duper room",
+ is_encrypted=False,
+ tombstone_successor_room_id=None,
+ ),
+ )
+ # Holds the info according to the current state when the user joined
+ self.assertEqual(
+ sliding_sync_membership_snapshots_results.get((room_id1, user2_id)),
+ user2_snapshot,
+ )
+
+ def test_non_join_server_left_room(self) -> None:
+ """
+        Test that when everyone local leaves the room, their leave memberships still
+        show up in `sliding_sync_membership_snapshots`.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
+
+ # User1 joins the room
+ self.helper.join(room_id1, user1_id, tok=user1_tok)
+
+ # User2 leaves the room
+ user2_leave_response = self.helper.leave(room_id1, user2_id, tok=user2_tok)
+ user2_leave_event_pos = self.get_success(
+ self.store.get_position_for_event(user2_leave_response["event_id"])
+ )
+
+ # User1 leaves the room
+ user1_leave_response = self.helper.leave(room_id1, user1_id, tok=user1_tok)
+ user1_leave_event_pos = self.get_success(
+ self.store.get_position_for_event(user1_leave_response["event_id"])
+ )
+
+ # No one is joined to the room anymore so we shouldn't have an entry in the
+ # `sliding_sync_joined_rooms` table.
+ sliding_sync_joined_rooms_results = self._get_sliding_sync_joined_rooms()
+ self.assertIncludes(
+ set(sliding_sync_joined_rooms_results.keys()),
+ set(),
+ exact=True,
+ )
+
+ # We should still see rows for the leave events (non-joins)
+ sliding_sync_membership_snapshots_results = (
+ self._get_sliding_sync_membership_snapshots()
+ )
+ self.assertIncludes(
+ set(sliding_sync_membership_snapshots_results.keys()),
+ {
+ (room_id1, user1_id),
+ (room_id1, user2_id),
+ },
+ exact=True,
+ )
+ self.assertEqual(
+ sliding_sync_membership_snapshots_results.get((room_id1, user1_id)),
+ _SlidingSyncMembershipSnapshotResult(
+ room_id=room_id1,
+ user_id=user1_id,
+ sender=user1_id,
+ membership_event_id=user1_leave_response["event_id"],
+ membership=Membership.LEAVE,
+ event_stream_ordering=user1_leave_event_pos.stream,
+ has_known_state=True,
+ room_type=None,
+ room_name=None,
+ is_encrypted=False,
+ tombstone_successor_room_id=None,
+ ),
+ )
+ self.assertEqual(
+ sliding_sync_membership_snapshots_results.get((room_id1, user2_id)),
+ _SlidingSyncMembershipSnapshotResult(
+ room_id=room_id1,
+ user_id=user2_id,
+ sender=user2_id,
+ membership_event_id=user2_leave_response["event_id"],
+ membership=Membership.LEAVE,
+ event_stream_ordering=user2_leave_event_pos.stream,
+ has_known_state=True,
+ room_type=None,
+ room_name=None,
+ is_encrypted=False,
+ tombstone_successor_room_id=None,
+ ),
+ )
+
+ @parameterized.expand(
+ [
+ # No stripped state provided
+ ("none", None),
+ # Empty stripped state provided
+ ("empty", []),
+ ]
+ )
+ def test_non_join_remote_invite_no_stripped_state(
+ self, _description: str, stripped_state: Optional[List[StrippedStateEvent]]
+ ) -> None:
+ """
+        Test that a remote invite with no stripped state provided shows up in
+        `sliding_sync_membership_snapshots` with `has_known_state=False`.
+ """
+ user1_id = self.register_user("user1", "pass")
+ _user1_tok = self.login(user1_id, "pass")
+
+ # Create a remote invite room without any `unsigned.invite_room_state`
+ remote_invite_room_id, remote_invite_event = (
+ self._create_remote_invite_room_for_user(user1_id, stripped_state)
+ )
+
+ # No one local is joined to the remote room
+ sliding_sync_joined_rooms_results = self._get_sliding_sync_joined_rooms()
+ self.assertIncludes(
+ set(sliding_sync_joined_rooms_results.keys()),
+ set(),
+ exact=True,
+ )
+
+ sliding_sync_membership_snapshots_results = (
+ self._get_sliding_sync_membership_snapshots()
+ )
+ self.assertIncludes(
+ set(sliding_sync_membership_snapshots_results.keys()),
+ {
+ (remote_invite_room_id, user1_id),
+ },
+ exact=True,
+ )
+ self.assertEqual(
+ sliding_sync_membership_snapshots_results.get(
+ (remote_invite_room_id, user1_id)
+ ),
+ _SlidingSyncMembershipSnapshotResult(
+ room_id=remote_invite_room_id,
+ user_id=user1_id,
+ sender="@inviter:remote_server",
+ membership_event_id=remote_invite_event.event_id,
+ membership=Membership.INVITE,
+ event_stream_ordering=remote_invite_event.internal_metadata.stream_ordering,
+ # No stripped state provided
+ has_known_state=False,
+ room_type=None,
+ room_name=None,
+ is_encrypted=False,
+ tombstone_successor_room_id=None,
+ ),
+ )
+
+ def test_non_join_remote_invite_unencrypted_room(self) -> None:
+ """
+ Test remote invite with stripped state (unencrypted room) shows up in
+ `sliding_sync_membership_snapshots`.
+ """
+ user1_id = self.register_user("user1", "pass")
+ _user1_tok = self.login(user1_id, "pass")
+
+ # Create a remote invite room with some `unsigned.invite_room_state`
+        # indicating the room name (the room is unencrypted).
+ remote_invite_room_id, remote_invite_event = (
+ self._create_remote_invite_room_for_user(
+ user1_id,
+ [
+ StrippedStateEvent(
+ type=EventTypes.Create,
+ state_key="",
+ sender="@inviter:remote_server",
+ content={
+ EventContentFields.ROOM_CREATOR: "@inviter:remote_server",
+ EventContentFields.ROOM_VERSION: RoomVersions.V10.identifier,
+ },
+ ),
+ StrippedStateEvent(
+ type=EventTypes.Name,
+ state_key="",
+ sender="@inviter:remote_server",
+ content={
+ EventContentFields.ROOM_NAME: "my super duper room",
+ },
+ ),
+ ],
+ )
+ )
+
+ # No one local is joined to the remote room
+ sliding_sync_joined_rooms_results = self._get_sliding_sync_joined_rooms()
+ self.assertIncludes(
+ set(sliding_sync_joined_rooms_results.keys()),
+ set(),
+ exact=True,
+ )
+
+ sliding_sync_membership_snapshots_results = (
+ self._get_sliding_sync_membership_snapshots()
+ )
+ self.assertIncludes(
+ set(sliding_sync_membership_snapshots_results.keys()),
+ {
+ (remote_invite_room_id, user1_id),
+ },
+ exact=True,
+ )
+ self.assertEqual(
+ sliding_sync_membership_snapshots_results.get(
+ (remote_invite_room_id, user1_id)
+ ),
+ _SlidingSyncMembershipSnapshotResult(
+ room_id=remote_invite_room_id,
+ user_id=user1_id,
+ sender="@inviter:remote_server",
+ membership_event_id=remote_invite_event.event_id,
+ membership=Membership.INVITE,
+ event_stream_ordering=remote_invite_event.internal_metadata.stream_ordering,
+ has_known_state=True,
+ room_type=None,
+ room_name="my super duper room",
+ is_encrypted=False,
+ tombstone_successor_room_id=None,
+ ),
+ )
+
+ def test_non_join_remote_invite_encrypted_room(self) -> None:
+ """
+ Test remote invite with stripped state (encrypted room) shows up in
+ `sliding_sync_membership_snapshots`.
+ """
+ user1_id = self.register_user("user1", "pass")
+ _user1_tok = self.login(user1_id, "pass")
+
+ # Create a remote invite room with some `unsigned.invite_room_state`
+ # indicating that the room is encrypted.
+ remote_invite_room_id, remote_invite_event = (
+ self._create_remote_invite_room_for_user(
+ user1_id,
+ [
+ StrippedStateEvent(
+ type=EventTypes.Create,
+ state_key="",
+ sender="@inviter:remote_server",
+ content={
+ EventContentFields.ROOM_CREATOR: "@inviter:remote_server",
+ EventContentFields.ROOM_VERSION: RoomVersions.V10.identifier,
+ },
+ ),
+ StrippedStateEvent(
+ type=EventTypes.RoomEncryption,
+ state_key="",
+ sender="@inviter:remote_server",
+ content={
+ EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2",
+ },
+ ),
+                    # This is not one of the stripped state events recommended by the
+                    # spec, but we still handle it.
+ StrippedStateEvent(
+ type=EventTypes.Tombstone,
+ state_key="",
+ sender="@inviter:remote_server",
+ content={
+ EventContentFields.TOMBSTONE_SUCCESSOR_ROOM: "another_room",
+ },
+ ),
+ # Also test a random event that we don't care about
+ StrippedStateEvent(
+ type="org.matrix.foo_state",
+ state_key="",
+ sender="@inviter:remote_server",
+ content={
+ "foo": "qux",
+ },
+ ),
+ ],
+ )
+ )
+
+ # No one local is joined to the remote room
+ sliding_sync_joined_rooms_results = self._get_sliding_sync_joined_rooms()
+ self.assertIncludes(
+ set(sliding_sync_joined_rooms_results.keys()),
+ set(),
+ exact=True,
+ )
+
+ sliding_sync_membership_snapshots_results = (
+ self._get_sliding_sync_membership_snapshots()
+ )
+ self.assertIncludes(
+ set(sliding_sync_membership_snapshots_results.keys()),
+ {
+ (remote_invite_room_id, user1_id),
+ },
+ exact=True,
+ )
+ self.assertEqual(
+ sliding_sync_membership_snapshots_results.get(
+ (remote_invite_room_id, user1_id)
+ ),
+ _SlidingSyncMembershipSnapshotResult(
+ room_id=remote_invite_room_id,
+ user_id=user1_id,
+ sender="@inviter:remote_server",
+ membership_event_id=remote_invite_event.event_id,
+ membership=Membership.INVITE,
+ event_stream_ordering=remote_invite_event.internal_metadata.stream_ordering,
+ has_known_state=True,
+ room_type=None,
+ room_name=None,
+ is_encrypted=True,
+ tombstone_successor_room_id="another_room",
+ ),
+ )
+
+ def test_non_join_remote_invite_space_room(self) -> None:
+ """
+ Test remote invite with stripped state (encrypted space room with name) shows up in
+ `sliding_sync_membership_snapshots`.
+ """
+ user1_id = self.register_user("user1", "pass")
+ _user1_tok = self.login(user1_id, "pass")
+
+ # Create a remote invite room with some `unsigned.invite_room_state`
+        # indicating that the room is an encrypted space with a name.
+ remote_invite_room_id, remote_invite_event = (
+ self._create_remote_invite_room_for_user(
+ user1_id,
+ [
+ StrippedStateEvent(
+ type=EventTypes.Create,
+ state_key="",
+ sender="@inviter:remote_server",
+ content={
+ EventContentFields.ROOM_CREATOR: "@inviter:remote_server",
+ EventContentFields.ROOM_VERSION: RoomVersions.V10.identifier,
+ # Specify that it is a space room
+ EventContentFields.ROOM_TYPE: RoomTypes.SPACE,
+ },
+ ),
+ StrippedStateEvent(
+ type=EventTypes.RoomEncryption,
+ state_key="",
+ sender="@inviter:remote_server",
+ content={
+ EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2",
+ },
+ ),
+ StrippedStateEvent(
+ type=EventTypes.Name,
+ state_key="",
+ sender="@inviter:remote_server",
+ content={
+ EventContentFields.ROOM_NAME: "my super duper space",
+ },
+ ),
+ ],
+ )
+ )
+
+ # No one local is joined to the remote room
+ sliding_sync_joined_rooms_results = self._get_sliding_sync_joined_rooms()
+ self.assertIncludes(
+ set(sliding_sync_joined_rooms_results.keys()),
+ set(),
+ exact=True,
+ )
+
+ sliding_sync_membership_snapshots_results = (
+ self._get_sliding_sync_membership_snapshots()
+ )
+ self.assertIncludes(
+ set(sliding_sync_membership_snapshots_results.keys()),
+ {
+ (remote_invite_room_id, user1_id),
+ },
+ exact=True,
+ )
+ self.assertEqual(
+ sliding_sync_membership_snapshots_results.get(
+ (remote_invite_room_id, user1_id)
+ ),
+ _SlidingSyncMembershipSnapshotResult(
+ room_id=remote_invite_room_id,
+ user_id=user1_id,
+ sender="@inviter:remote_server",
+ membership_event_id=remote_invite_event.event_id,
+ membership=Membership.INVITE,
+ event_stream_ordering=remote_invite_event.internal_metadata.stream_ordering,
+ has_known_state=True,
+ room_type=RoomTypes.SPACE,
+ room_name="my super duper space",
+ is_encrypted=True,
+ tombstone_successor_room_id=None,
+ ),
+ )
+
+ def test_non_join_reject_remote_invite(self) -> None:
+ """
+        Test that a rejected remote invite (the user decided to leave the room) inherits
+        metadata from the remote invite's stripped state and shows up in
+        `sliding_sync_membership_snapshots`.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ # Create a remote invite room with some `unsigned.invite_room_state`
+ # indicating that the room is encrypted.
+ remote_invite_room_id, remote_invite_event = (
+ self._create_remote_invite_room_for_user(
+ user1_id,
+ [
+ StrippedStateEvent(
+ type=EventTypes.Create,
+ state_key="",
+ sender="@inviter:remote_server",
+ content={
+ EventContentFields.ROOM_CREATOR: "@inviter:remote_server",
+ EventContentFields.ROOM_VERSION: RoomVersions.V10.identifier,
+ },
+ ),
+ StrippedStateEvent(
+ type=EventTypes.RoomEncryption,
+ state_key="",
+ sender="@inviter:remote_server",
+ content={
+ EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2",
+ },
+ ),
+ ],
+ )
+ )
+
+ # User1 decides to leave the room (reject the invite)
+ user1_leave_response = self.helper.leave(
+ remote_invite_room_id, user1_id, tok=user1_tok
+ )
+ user1_leave_pos = self.get_success(
+ self.store.get_position_for_event(user1_leave_response["event_id"])
+ )
+
+ # No one local is joined to the remote room
+ sliding_sync_joined_rooms_results = self._get_sliding_sync_joined_rooms()
+ self.assertIncludes(
+ set(sliding_sync_joined_rooms_results.keys()),
+ set(),
+ exact=True,
+ )
+
+ sliding_sync_membership_snapshots_results = (
+ self._get_sliding_sync_membership_snapshots()
+ )
+ self.assertIncludes(
+ set(sliding_sync_membership_snapshots_results.keys()),
+ {
+ (remote_invite_room_id, user1_id),
+ },
+ exact=True,
+ )
+ self.assertEqual(
+ sliding_sync_membership_snapshots_results.get(
+ (remote_invite_room_id, user1_id)
+ ),
+ _SlidingSyncMembershipSnapshotResult(
+ room_id=remote_invite_room_id,
+ user_id=user1_id,
+ sender=user1_id,
+ membership_event_id=user1_leave_response["event_id"],
+ membership=Membership.LEAVE,
+ event_stream_ordering=user1_leave_pos.stream,
+ has_known_state=True,
+ room_type=None,
+ room_name=None,
+ is_encrypted=True,
+ tombstone_successor_room_id=None,
+ ),
+ )
+
+ def test_non_join_retracted_remote_invite(self) -> None:
+ """
+        Test that a retracted remote invite (the remote inviter kicks the person who was
+        invited) inherits metadata from the remote invite's stripped state and shows up
+        in `sliding_sync_membership_snapshots`.
+ """
+ user1_id = self.register_user("user1", "pass")
+ _user1_tok = self.login(user1_id, "pass")
+
+ # Create a remote invite room with some `unsigned.invite_room_state`
+ # indicating that the room is encrypted.
+ remote_invite_room_id, remote_invite_event = (
+ self._create_remote_invite_room_for_user(
+ user1_id,
+ [
+ StrippedStateEvent(
+ type=EventTypes.Create,
+ state_key="",
+ sender="@inviter:remote_server",
+ content={
+ EventContentFields.ROOM_CREATOR: "@inviter:remote_server",
+ EventContentFields.ROOM_VERSION: RoomVersions.V10.identifier,
+ },
+ ),
+ StrippedStateEvent(
+ type=EventTypes.RoomEncryption,
+ state_key="",
+ sender="@inviter:remote_server",
+ content={
+ EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2",
+ },
+ ),
+ ],
+ )
+ )
+
+ # `@inviter:remote_server` decides to retract the invite (kicks the user).
+ # (Note: A kick is just a leave event with a different sender)
+ remote_invite_retraction_event = self._retract_remote_invite_for_user(
+ user_id=user1_id,
+ remote_room_id=remote_invite_room_id,
+ )
+
+ # No one local is joined to the remote room
+ sliding_sync_joined_rooms_results = self._get_sliding_sync_joined_rooms()
+ self.assertIncludes(
+ set(sliding_sync_joined_rooms_results.keys()),
+ set(),
+ exact=True,
+ )
+
+ sliding_sync_membership_snapshots_results = (
+ self._get_sliding_sync_membership_snapshots()
+ )
+ self.assertIncludes(
+ set(sliding_sync_membership_snapshots_results.keys()),
+ {
+ (remote_invite_room_id, user1_id),
+ },
+ exact=True,
+ )
+ self.assertEqual(
+ sliding_sync_membership_snapshots_results.get(
+ (remote_invite_room_id, user1_id)
+ ),
+ _SlidingSyncMembershipSnapshotResult(
+ room_id=remote_invite_room_id,
+ user_id=user1_id,
+ sender="@inviter:remote_server",
+ membership_event_id=remote_invite_retraction_event.event_id,
+ membership=Membership.LEAVE,
+ event_stream_ordering=remote_invite_retraction_event.internal_metadata.stream_ordering,
+ has_known_state=True,
+ room_type=None,
+ room_name=None,
+ is_encrypted=True,
+ tombstone_successor_room_id=None,
+ ),
+ )
+
+ def test_non_join_state_reset(self) -> None:
+ """
+ Test a state reset that removes someone from the room.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
+ # Add a room name
+ self.helper.send_state(
+ room_id,
+ EventTypes.Name,
+ {"name": "my super duper room"},
+ tok=user2_tok,
+ )
+
+ # User1 joins the room
+ self.helper.join(room_id, user1_id, tok=user1_tok)
+
+ # Make sure we see the new room name
+ sliding_sync_joined_rooms_results = self._get_sliding_sync_joined_rooms()
+ self.assertIncludes(
+ set(sliding_sync_joined_rooms_results.keys()),
+ {room_id},
+ exact=True,
+ )
+ state_map = self.get_success(
+ self.storage_controllers.state.get_current_state(room_id)
+ )
+ self.assertEqual(
+ sliding_sync_joined_rooms_results[room_id],
+ _SlidingSyncJoinedRoomResult(
+ room_id=room_id,
+                # This should be the stream ordering of the last event in the room
+ event_stream_ordering=state_map[
+ (EventTypes.Member, user1_id)
+ ].internal_metadata.stream_ordering,
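+                # Membership events aren't bump events, so `bump_stamp` still points
+                # at the `m.room.create` event.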
+ bump_stamp=state_map[
+ (EventTypes.Create, "")
+ ].internal_metadata.stream_ordering,
+ room_type=None,
+ room_name="my super duper room",
+ is_encrypted=False,
+ tombstone_successor_room_id=None,
+ ),
+ )
+
+ sliding_sync_membership_snapshots_results = (
+ self._get_sliding_sync_membership_snapshots()
+ )
+ self.assertIncludes(
+ set(sliding_sync_membership_snapshots_results.keys()),
+ {
+ (room_id, user1_id),
+ (room_id, user2_id),
+ },
+ exact=True,
+ )
+ user1_snapshot = _SlidingSyncMembershipSnapshotResult(
+ room_id=room_id,
+ user_id=user1_id,
+ sender=user1_id,
+ membership_event_id=state_map[(EventTypes.Member, user1_id)].event_id,
+ membership=Membership.JOIN,
+ event_stream_ordering=state_map[
+ (EventTypes.Member, user1_id)
+ ].internal_metadata.stream_ordering,
+ has_known_state=True,
+ room_type=None,
+ room_name="my super duper room",
+ is_encrypted=False,
+ tombstone_successor_room_id=None,
+ )
+ self.assertEqual(
+ sliding_sync_membership_snapshots_results.get((room_id, user1_id)),
+ user1_snapshot,
+ )
+ # Holds the info according to the current state when the user joined (no room
+ # name when the room creator joined)
+ user2_snapshot = _SlidingSyncMembershipSnapshotResult(
+ room_id=room_id,
+ user_id=user2_id,
+ sender=user2_id,
+ membership_event_id=state_map[(EventTypes.Member, user2_id)].event_id,
+ membership=Membership.JOIN,
+ event_stream_ordering=state_map[
+ (EventTypes.Member, user2_id)
+ ].internal_metadata.stream_ordering,
+ has_known_state=True,
+ room_type=None,
+ room_name=None,
+ is_encrypted=False,
+ tombstone_successor_room_id=None,
+ )
+ self.assertEqual(
+ sliding_sync_membership_snapshots_results.get((room_id, user2_id)),
+ user2_snapshot,
+ )
+
+ # Mock a state reset removing the membership for user1 in the current state
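+        # (A state reset can happen when state resolution across forks of the event
+        # DAG settles on a current state that no longer includes an event, here
+        # user1's membership, even though user1 never sent a leave event.)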
+ message_tuple = self.get_success(
+ create_event(
+ self.hs,
+ prev_event_ids=[state_map[(EventTypes.Name, "")].event_id],
+ auth_event_ids=[
+ state_map[(EventTypes.Create, "")].event_id,
+ state_map[(EventTypes.Member, user1_id)].event_id,
+ ],
+ type=EventTypes.Message,
+ content={"body": "foo", "msgtype": "m.text"},
+ sender=user1_id,
+ room_id=room_id,
+ room_version=RoomVersions.V10.identifier,
+ )
+ )
+ event_chunk = [message_tuple]
+ self.get_success(
+ self.persist_events_store._persist_events_and_state_updates(
+ room_id,
+ event_chunk,
+ state_delta_for_room=DeltaState(
+                    # This is the state reset part. We're removing user1's membership
+                    # from the current state.
+ to_delete=[(EventTypes.Member, user1_id)],
+ to_insert={},
+ ),
+ new_forward_extremities={message_tuple[0].event_id},
+ use_negative_stream_ordering=False,
+ inhibit_local_membership_updates=False,
+ new_event_links={},
+ )
+ )
+
+        # A state reset on membership doesn't affect the `sliding_sync_joined_rooms` table
+ sliding_sync_joined_rooms_results = self._get_sliding_sync_joined_rooms()
+ self.assertIncludes(
+ set(sliding_sync_joined_rooms_results.keys()),
+ {room_id},
+ exact=True,
+ )
+ state_map = self.get_success(
+ self.storage_controllers.state.get_current_state(room_id)
+ )
+ self.assertEqual(
+ sliding_sync_joined_rooms_results[room_id],
+ _SlidingSyncJoinedRoomResult(
+ room_id=room_id,
+                # This should be the stream ordering of the last event in the room
+ event_stream_ordering=message_tuple[
+ 0
+ ].internal_metadata.stream_ordering,
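+                # `m.room.message` is a bump event type, so the message sent as part
+                # of the state reset moves `bump_stamp` forward as well.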
+ bump_stamp=message_tuple[0].internal_metadata.stream_ordering,
+ room_type=None,
+ room_name="my super duper room",
+ is_encrypted=False,
+ tombstone_successor_room_id=None,
+ ),
+ )
+
+ # State reset on membership should remove the user's snapshot
+ sliding_sync_membership_snapshots_results = (
+ self._get_sliding_sync_membership_snapshots()
+ )
+ self.assertIncludes(
+ set(sliding_sync_membership_snapshots_results.keys()),
+ {
+ # We shouldn't see user1 in the snapshots table anymore
+ (room_id, user2_id),
+ },
+ exact=True,
+ )
+ # Snapshot for user2 hasn't changed
+ self.assertEqual(
+ sliding_sync_membership_snapshots_results.get((room_id, user2_id)),
+ user2_snapshot,
+ )
+
+ def test_membership_snapshot_forget(self) -> None:
+ """
+        Test that forgetting a room updates `sliding_sync_membership_snapshots`.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
+
+ # User1 joins the room
+ self.helper.join(room_id, user1_id, tok=user1_tok)
+ # User1 leaves the room (we have to leave in order to forget the room)
+ self.helper.leave(room_id, user1_id, tok=user1_tok)
+
+ state_map = self.get_success(
+ self.storage_controllers.state.get_current_state(room_id)
+ )
+
+ # Check on the `sliding_sync_membership_snapshots` table (nothing should be
+ # forgotten yet)
+ sliding_sync_membership_snapshots_results = (
+ self._get_sliding_sync_membership_snapshots()
+ )
+ self.assertIncludes(
+ set(sliding_sync_membership_snapshots_results.keys()),
+ {
+ (room_id, user1_id),
+ (room_id, user2_id),
+ },
+ exact=True,
+ )
+ # Holds the info according to the current state when the user joined
+ user1_snapshot = _SlidingSyncMembershipSnapshotResult(
+ room_id=room_id,
+ user_id=user1_id,
+ sender=user1_id,
+ membership_event_id=state_map[(EventTypes.Member, user1_id)].event_id,
+ membership=Membership.LEAVE,
+ event_stream_ordering=state_map[
+ (EventTypes.Member, user1_id)
+ ].internal_metadata.stream_ordering,
+ has_known_state=True,
+ room_type=None,
+ room_name=None,
+ is_encrypted=False,
+ tombstone_successor_room_id=None,
+ # Room is not forgotten
+ forgotten=False,
+ )
+ self.assertEqual(
+ sliding_sync_membership_snapshots_results.get((room_id, user1_id)),
+ user1_snapshot,
+ )
+ # Holds the info according to the current state when the user joined
+ user2_snapshot = _SlidingSyncMembershipSnapshotResult(
+ room_id=room_id,
+ user_id=user2_id,
+ sender=user2_id,
+ membership_event_id=state_map[(EventTypes.Member, user2_id)].event_id,
+ membership=Membership.JOIN,
+ event_stream_ordering=state_map[
+ (EventTypes.Member, user2_id)
+ ].internal_metadata.stream_ordering,
+ has_known_state=True,
+ room_type=None,
+ room_name=None,
+ is_encrypted=False,
+ tombstone_successor_room_id=None,
+ )
+ self.assertEqual(
+ sliding_sync_membership_snapshots_results.get((room_id, user2_id)),
+ user2_snapshot,
+ )
+
+ # Forget the room
+ channel = self.make_request(
+ "POST",
+ f"/_matrix/client/r0/rooms/{room_id}/forget",
+ content={},
+ access_token=user1_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ # Check on the `sliding_sync_membership_snapshots` table
+ sliding_sync_membership_snapshots_results = (
+ self._get_sliding_sync_membership_snapshots()
+ )
+ self.assertIncludes(
+ set(sliding_sync_membership_snapshots_results.keys()),
+ {
+ (room_id, user1_id),
+ (room_id, user2_id),
+ },
+ exact=True,
+ )
+ # Room is now forgotten for user1
+ self.assertEqual(
+ sliding_sync_membership_snapshots_results.get((room_id, user1_id)),
+ attr.evolve(user1_snapshot, forgotten=True),
+ )
+ # Nothing changed for user2
+ self.assertEqual(
+ sliding_sync_membership_snapshots_results.get((room_id, user2_id)),
+ user2_snapshot,
+ )
+
+ def test_membership_snapshot_missing_forget(
+ self,
+ ) -> None:
+ """
+ Test forgetting a room with no existing row in `sliding_sync_membership_snapshots`.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
+
+ # User1 joins the room
+ self.helper.join(room_id, user1_id, tok=user1_tok)
+ # User1 leaves the room (we have to leave in order to forget the room)
+ self.helper.leave(room_id, user1_id, tok=user1_tok)
+
+ # Clean-up the `sliding_sync_membership_snapshots` table as if the inserts did not
+ # happen during event creation.
+ self.get_success(
+ self.store.db_pool.simple_delete_many(
+ table="sliding_sync_membership_snapshots",
+ column="room_id",
+ iterable=(room_id,),
+ keyvalues={},
+ desc="sliding_sync_membership_snapshots.test_membership_snapshots_background_update_forgotten_missing",
+ )
+ )
+
+ # We shouldn't find anything in the table because we just deleted them in
+ # preparation for the test.
+ sliding_sync_membership_snapshots_results = (
+ self._get_sliding_sync_membership_snapshots()
+ )
+ self.assertIncludes(
+ set(sliding_sync_membership_snapshots_results.keys()),
+ set(),
+ exact=True,
+ )
+
+ # Forget the room
+ channel = self.make_request(
+ "POST",
+ f"/_matrix/client/r0/rooms/{room_id}/forget",
+ content={},
+ access_token=user1_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ # It doesn't explode
+
+ # We still shouldn't find anything in the table because nothing has re-created them
+ sliding_sync_membership_snapshots_results = (
+ self._get_sliding_sync_membership_snapshots()
+ )
+ self.assertIncludes(
+ set(sliding_sync_membership_snapshots_results.keys()),
+ set(),
+ exact=True,
+ )
+
+
+class SlidingSyncTablesBackgroundUpdatesTestCase(SlidingSyncTablesTestCaseBase):
+ """
+ Test the background updates that populate the `sliding_sync_joined_rooms` and
+ `sliding_sync_membership_snapshots` tables.
+ """
+
+ def test_joined_background_update_missing(self) -> None:
+ """
+ Test that the background update for `sliding_sync_joined_rooms` populates missing rows
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ # Create rooms with various levels of state that should appear in the table
+ #
+ room_id_no_info = self.helper.create_room_as(user1_id, tok=user1_tok)
+
+ room_id_with_info = self.helper.create_room_as(user1_id, tok=user1_tok)
+ # Add a room name
+ self.helper.send_state(
+ room_id_with_info,
+ EventTypes.Name,
+ {"name": "my super duper room"},
+ tok=user1_tok,
+ )
+ # Encrypt the room
+ self.helper.send_state(
+ room_id_with_info,
+ EventTypes.RoomEncryption,
+ {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"},
+ tok=user1_tok,
+ )
+
+ space_room_id = self.helper.create_room_as(
+ user1_id,
+ tok=user1_tok,
+ extra_content={
+ "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE}
+ },
+ )
+ # Add a room name
+ self.helper.send_state(
+ space_room_id,
+ EventTypes.Name,
+ {"name": "my super duper space"},
+ tok=user1_tok,
+ )
+
+ # Clean-up the `sliding_sync_joined_rooms` table as if the inserts did not
+ # happen during event creation.
+ self.get_success(
+ self.store.db_pool.simple_delete_many(
+ table="sliding_sync_joined_rooms",
+ column="room_id",
+ iterable=(room_id_no_info, room_id_with_info, space_room_id),
+ keyvalues={},
+ desc="sliding_sync_joined_rooms.test_joined_background_update_missing",
+ )
+ )
+
+ # We shouldn't find anything in the table because we just deleted them in
+ # preparation for the test.
+ sliding_sync_joined_rooms_results = self._get_sliding_sync_joined_rooms()
+ self.assertIncludes(
+ set(sliding_sync_joined_rooms_results.keys()),
+ set(),
+ exact=True,
+ )
+
+ # Insert and run the background updates.
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ "background_updates",
+ {
+ "update_name": _BackgroundUpdates.SLIDING_SYNC_PREFILL_JOINED_ROOMS_TO_RECALCULATE_TABLE_BG_UPDATE,
+ "progress_json": "{}",
+ },
+ )
+ )
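+        # The joined-rooms update below declares a `depends_on` on this prefill
+        # update, so the background update handler will run them in order.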
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ "background_updates",
+ {
+ "update_name": _BackgroundUpdates.SLIDING_SYNC_JOINED_ROOMS_BG_UPDATE,
+ "progress_json": "{}",
+ "depends_on": _BackgroundUpdates.SLIDING_SYNC_PREFILL_JOINED_ROOMS_TO_RECALCULATE_TABLE_BG_UPDATE,
+ },
+ )
+ )
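+        # Reset the flag so the background update handler notices the newly
+        # inserted updates.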
+ self.store.db_pool.updates._all_done = False
+ self.wait_for_background_updates()
+
+ # Make sure the table is populated
+ sliding_sync_joined_rooms_results = self._get_sliding_sync_joined_rooms()
+ self.assertIncludes(
+ set(sliding_sync_joined_rooms_results.keys()),
+ {room_id_no_info, room_id_with_info, space_room_id},
+ exact=True,
+ )
+ state_map = self.get_success(
+ self.storage_controllers.state.get_current_state(room_id_no_info)
+ )
+ self.assertEqual(
+ sliding_sync_joined_rooms_results[room_id_no_info],
+ _SlidingSyncJoinedRoomResult(
+ room_id=room_id_no_info,
+ # History visibility just happens to be the last event sent in the room
+ event_stream_ordering=state_map[
+ (EventTypes.RoomHistoryVisibility, "")
+ ].internal_metadata.stream_ordering,
+ bump_stamp=state_map[
+ (EventTypes.Create, "")
+ ].internal_metadata.stream_ordering,
+ room_type=None,
+ room_name=None,
+ is_encrypted=False,
+ tombstone_successor_room_id=None,
+ ),
+ )
+ state_map = self.get_success(
+ self.storage_controllers.state.get_current_state(room_id_with_info)
+ )
+ self.assertEqual(
+ sliding_sync_joined_rooms_results[room_id_with_info],
+ _SlidingSyncJoinedRoomResult(
+ room_id=room_id_with_info,
+                # Latest event sent in the room
+ event_stream_ordering=state_map[
+ (EventTypes.RoomEncryption, "")
+ ].internal_metadata.stream_ordering,
+ bump_stamp=state_map[
+ (EventTypes.Create, "")
+ ].internal_metadata.stream_ordering,
+ room_type=None,
+ room_name="my super duper room",
+ is_encrypted=True,
+ tombstone_successor_room_id=None,
+ ),
+ )
+ state_map = self.get_success(
+ self.storage_controllers.state.get_current_state(space_room_id)
+ )
+ self.assertEqual(
+ sliding_sync_joined_rooms_results[space_room_id],
+ _SlidingSyncJoinedRoomResult(
+ room_id=space_room_id,
+                # Latest event sent in the room
+ event_stream_ordering=state_map[
+ (EventTypes.Name, "")
+ ].internal_metadata.stream_ordering,
+ bump_stamp=state_map[
+ (EventTypes.Create, "")
+ ].internal_metadata.stream_ordering,
+ room_type=RoomTypes.SPACE,
+ room_name="my super duper space",
+ is_encrypted=False,
+ tombstone_successor_room_id=None,
+ ),
+ )
+
+ def test_membership_snapshots_background_update_joined(self) -> None:
+ """
+ Test that the background update for `sliding_sync_membership_snapshots`
+ populates missing rows for join memberships.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ # Create rooms with various levels of state that should appear in the table
+ #
+ room_id_no_info = self.helper.create_room_as(user1_id, tok=user1_tok)
+
+ room_id_with_info = self.helper.create_room_as(user1_id, tok=user1_tok)
+ # Add a room name
+ self.helper.send_state(
+ room_id_with_info,
+ EventTypes.Name,
+ {"name": "my super duper room"},
+ tok=user1_tok,
+ )
+ # Encrypt the room
+ self.helper.send_state(
+ room_id_with_info,
+ EventTypes.RoomEncryption,
+ {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"},
+ tok=user1_tok,
+ )
+ # Add a tombstone
+ self.helper.send_state(
+ room_id_with_info,
+ EventTypes.Tombstone,
+ {EventContentFields.TOMBSTONE_SUCCESSOR_ROOM: "another_room"},
+ tok=user1_tok,
+ )
+
+ space_room_id = self.helper.create_room_as(
+ user1_id,
+ tok=user1_tok,
+ extra_content={
+ "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE}
+ },
+ )
+ # Add a room name
+ self.helper.send_state(
+ space_room_id,
+ EventTypes.Name,
+ {"name": "my super duper space"},
+ tok=user1_tok,
+ )
+
+ # Clean-up the `sliding_sync_membership_snapshots` table as if the inserts did not
+ # happen during event creation.
+ self.get_success(
+ self.store.db_pool.simple_delete_many(
+ table="sliding_sync_membership_snapshots",
+ column="room_id",
+ iterable=(room_id_no_info, room_id_with_info, space_room_id),
+ keyvalues={},
+ desc="sliding_sync_membership_snapshots.test_membership_snapshots_background_update_joined",
+ )
+ )
+
+ # We shouldn't find anything in the table because we just deleted them in
+ # preparation for the test.
+ sliding_sync_membership_snapshots_results = (
+ self._get_sliding_sync_membership_snapshots()
+ )
+ self.assertIncludes(
+ set(sliding_sync_membership_snapshots_results.keys()),
+ set(),
+ exact=True,
+ )
+
+ # Insert and run the background update.
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ "background_updates",
+ {
+ "update_name": _BackgroundUpdates.SLIDING_SYNC_MEMBERSHIP_SNAPSHOTS_BG_UPDATE,
+ "progress_json": "{}",
+ },
+ )
+ )
+ self.store.db_pool.updates._all_done = False
+ self.wait_for_background_updates()
+
+ # Make sure the table is populated
+ sliding_sync_membership_snapshots_results = (
+ self._get_sliding_sync_membership_snapshots()
+ )
+ self.assertIncludes(
+ set(sliding_sync_membership_snapshots_results.keys()),
+ {
+ (room_id_no_info, user1_id),
+ (room_id_with_info, user1_id),
+ (space_room_id, user1_id),
+ },
+ exact=True,
+ )
+ state_map = self.get_success(
+ self.storage_controllers.state.get_current_state(room_id_no_info)
+ )
+ self.assertEqual(
+ sliding_sync_membership_snapshots_results.get((room_id_no_info, user1_id)),
+ _SlidingSyncMembershipSnapshotResult(
+ room_id=room_id_no_info,
+ user_id=user1_id,
+ sender=user1_id,
+ membership_event_id=state_map[(EventTypes.Member, user1_id)].event_id,
+ membership=Membership.JOIN,
+ event_stream_ordering=state_map[
+ (EventTypes.Member, user1_id)
+ ].internal_metadata.stream_ordering,
+ has_known_state=True,
+ room_type=None,
+ room_name=None,
+ is_encrypted=False,
+ tombstone_successor_room_id=None,
+ ),
+ )
+ state_map = self.get_success(
+ self.storage_controllers.state.get_current_state(room_id_with_info)
+ )
+ self.assertEqual(
+ sliding_sync_membership_snapshots_results.get(
+ (room_id_with_info, user1_id)
+ ),
+ _SlidingSyncMembershipSnapshotResult(
+ room_id=room_id_with_info,
+ user_id=user1_id,
+ sender=user1_id,
+ membership_event_id=state_map[(EventTypes.Member, user1_id)].event_id,
+ membership=Membership.JOIN,
+ event_stream_ordering=state_map[
+ (EventTypes.Member, user1_id)
+ ].internal_metadata.stream_ordering,
+ has_known_state=True,
+ room_type=None,
+ room_name="my super duper room",
+ is_encrypted=True,
+ tombstone_successor_room_id="another_room",
+ ),
+ )
+ state_map = self.get_success(
+ self.storage_controllers.state.get_current_state(space_room_id)
+ )
+ self.assertEqual(
+ sliding_sync_membership_snapshots_results.get((space_room_id, user1_id)),
+ _SlidingSyncMembershipSnapshotResult(
+ room_id=space_room_id,
+ user_id=user1_id,
+ sender=user1_id,
+ membership_event_id=state_map[(EventTypes.Member, user1_id)].event_id,
+ membership=Membership.JOIN,
+ event_stream_ordering=state_map[
+ (EventTypes.Member, user1_id)
+ ].internal_metadata.stream_ordering,
+ has_known_state=True,
+ room_type=RoomTypes.SPACE,
+ room_name="my super duper space",
+ is_encrypted=False,
+ tombstone_successor_room_id=None,
+ ),
+ )
+
+ def test_membership_snapshots_background_update_local_invite(self) -> None:
+ """
+ Test that the background update for `sliding_sync_membership_snapshots`
+ populates missing rows for invite memberships.
+ """
+ user1_id = self.register_user("user1", "pass")
+ _user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ # Create rooms with various levels of state that should appear in the table
+ #
+ room_id_no_info = self.helper.create_room_as(user2_id, tok=user2_tok)
+
+ room_id_with_info = self.helper.create_room_as(user2_id, tok=user2_tok)
+ # Add a room name
+ self.helper.send_state(
+ room_id_with_info,
+ EventTypes.Name,
+ {"name": "my super duper room"},
+ tok=user2_tok,
+ )
+ # Encrypt the room
+ self.helper.send_state(
+ room_id_with_info,
+ EventTypes.RoomEncryption,
+ {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"},
+ tok=user2_tok,
+ )
+ # Add a tombstone
+ self.helper.send_state(
+ room_id_with_info,
+ EventTypes.Tombstone,
+ {EventContentFields.TOMBSTONE_SUCCESSOR_ROOM: "another_room"},
+ tok=user2_tok,
+ )
+
+ space_room_id = self.helper.create_room_as(
+            user2_id,
+ tok=user2_tok,
+ extra_content={
+ "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE}
+ },
+ )
+ # Add a room name
+ self.helper.send_state(
+ space_room_id,
+ EventTypes.Name,
+ {"name": "my super duper space"},
+ tok=user2_tok,
+ )
+
+ # Invite user1 to the rooms
+ user1_invite_room_id_no_info_response = self.helper.invite(
+ room_id_no_info, src=user2_id, targ=user1_id, tok=user2_tok
+ )
+ user1_invite_room_id_with_info_response = self.helper.invite(
+ room_id_with_info, src=user2_id, targ=user1_id, tok=user2_tok
+ )
+ user1_invite_space_room_id_response = self.helper.invite(
+ space_room_id, src=user2_id, targ=user1_id, tok=user2_tok
+ )
+
+ # Have user2 leave the rooms to make sure that our background update is not just
+ # reading from `current_state_events`. For invite/knock memberships, we should
+ # be reading from the stripped state on the invite/knock event itself.
+ self.helper.leave(room_id_no_info, user2_id, tok=user2_tok)
+ self.helper.leave(room_id_with_info, user2_id, tok=user2_tok)
+ self.helper.leave(space_room_id, user2_id, tok=user2_tok)
+ # Check to make sure we actually don't have any `current_state_events` for the rooms
+ current_state_check_rows = self.get_success(
+ self.store.db_pool.simple_select_many_batch(
+ table="current_state_events",
+ column="room_id",
+ iterable=[room_id_no_info, room_id_with_info, space_room_id],
+ retcols=("event_id",),
+ keyvalues={},
+ desc="check current_state_events in test",
+ )
+ )
+ self.assertEqual(len(current_state_check_rows), 0)
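+        # (Synapse clears a room's `current_state_events` rows once the last local
+        # user has left, which is what lets this check pass.)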
+
+ # Clean-up the `sliding_sync_membership_snapshots` table as if the inserts did not
+ # happen during event creation.
+ self.get_success(
+ self.store.db_pool.simple_delete_many(
+ table="sliding_sync_membership_snapshots",
+ column="room_id",
+ iterable=(room_id_no_info, room_id_with_info, space_room_id),
+ keyvalues={},
+ desc="sliding_sync_membership_snapshots.test_membership_snapshots_background_update_local_invite",
+ )
+ )
+
+ # We shouldn't find anything in the table because we just deleted them in
+ # preparation for the test.
+ sliding_sync_membership_snapshots_results = (
+ self._get_sliding_sync_membership_snapshots()
+ )
+ self.assertIncludes(
+ set(sliding_sync_membership_snapshots_results.keys()),
+ set(),
+ exact=True,
+ )
+
+ # Insert and run the background update.
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ "background_updates",
+ {
+ "update_name": _BackgroundUpdates.SLIDING_SYNC_MEMBERSHIP_SNAPSHOTS_BG_UPDATE,
+ "progress_json": "{}",
+ },
+ )
+ )
+ self.store.db_pool.updates._all_done = False
+ self.wait_for_background_updates()
+
+ # Make sure the table is populated
+ sliding_sync_membership_snapshots_results = (
+ self._get_sliding_sync_membership_snapshots()
+ )
+ self.assertIncludes(
+ set(sliding_sync_membership_snapshots_results.keys()),
+ {
+ # The invite memberships for user1
+ (room_id_no_info, user1_id),
+ (room_id_with_info, user1_id),
+ (space_room_id, user1_id),
+ # The leave memberships for user2
+ (room_id_no_info, user2_id),
+ (room_id_with_info, user2_id),
+ (space_room_id, user2_id),
+ },
+ exact=True,
+ )
+ self.assertEqual(
+ sliding_sync_membership_snapshots_results.get((room_id_no_info, user1_id)),
+ _SlidingSyncMembershipSnapshotResult(
+ room_id=room_id_no_info,
+ user_id=user1_id,
+ sender=user2_id,
+ membership_event_id=user1_invite_room_id_no_info_response["event_id"],
+ membership=Membership.INVITE,
+ event_stream_ordering=self.get_success(
+ self.store.get_position_for_event(
+ user1_invite_room_id_no_info_response["event_id"]
+ )
+ ).stream,
+ has_known_state=True,
+ room_type=None,
+ room_name=None,
+ is_encrypted=False,
+ tombstone_successor_room_id=None,
+ ),
+ )
+ self.assertEqual(
+ sliding_sync_membership_snapshots_results.get(
+ (room_id_with_info, user1_id)
+ ),
+ _SlidingSyncMembershipSnapshotResult(
+ room_id=room_id_with_info,
+ user_id=user1_id,
+ sender=user2_id,
+ membership_event_id=user1_invite_room_id_with_info_response["event_id"],
+ membership=Membership.INVITE,
+ event_stream_ordering=self.get_success(
+ self.store.get_position_for_event(
+ user1_invite_room_id_with_info_response["event_id"]
+ )
+ ).stream,
+ has_known_state=True,
+ room_type=None,
+ room_name="my super duper room",
+ is_encrypted=True,
+                # The tombstone isn't showing here ("another_room") because it's not one
+                # of the stripped state events that we hand out as part of the invite
+                # event. Even though we handle this scenario for invites from other
+                # remote homeservers, Synapse does not include the tombstone in the
+                # invite events it sends out.
+ tombstone_successor_room_id=None,
+ ),
+ )
+ self.assertEqual(
+ sliding_sync_membership_snapshots_results.get((space_room_id, user1_id)),
+ _SlidingSyncMembershipSnapshotResult(
+ room_id=space_room_id,
+ user_id=user1_id,
+ sender=user2_id,
+ membership_event_id=user1_invite_space_room_id_response["event_id"],
+ membership=Membership.INVITE,
+ event_stream_ordering=self.get_success(
+ self.store.get_position_for_event(
+ user1_invite_space_room_id_response["event_id"]
+ )
+ ).stream,
+ has_known_state=True,
+ room_type=RoomTypes.SPACE,
+ room_name="my super duper space",
+ is_encrypted=False,
+ tombstone_successor_room_id=None,
+ ),
+ )
+
+ def test_membership_snapshots_background_update_remote_invite(
+ self,
+ ) -> None:
+ """
+ Test that the background update for `sliding_sync_membership_snapshots`
+ populates missing rows for remote invites (out-of-band memberships).
+ """
+ user1_id = self.register_user("user1", "pass")
+ _user1_tok = self.login(user1_id, "pass")
+
+ # Create rooms with various levels of state that should appear in the table
+ #
+ room_id_unknown_state, room_id_unknown_state_invite_event = (
+ self._create_remote_invite_room_for_user(user1_id, None)
+ )
+
+ room_id_no_info, room_id_no_info_invite_event = (
+ self._create_remote_invite_room_for_user(
+ user1_id,
+ [
+ StrippedStateEvent(
+ type=EventTypes.Create,
+ state_key="",
+ sender="@inviter:remote_server",
+ content={
+ EventContentFields.ROOM_CREATOR: "@inviter:remote_server",
+ EventContentFields.ROOM_VERSION: RoomVersions.V10.identifier,
+ },
+ ),
+ ],
+ )
+ )
+
+ room_id_with_info, room_id_with_info_invite_event = (
+ self._create_remote_invite_room_for_user(
+ user1_id,
+ [
+ StrippedStateEvent(
+ type=EventTypes.Create,
+ state_key="",
+ sender="@inviter:remote_server",
+ content={
+ EventContentFields.ROOM_CREATOR: "@inviter:remote_server",
+ EventContentFields.ROOM_VERSION: RoomVersions.V10.identifier,
+ },
+ ),
+ StrippedStateEvent(
+ type=EventTypes.Name,
+ state_key="",
+ sender="@inviter:remote_server",
+ content={
+ EventContentFields.ROOM_NAME: "my super duper room",
+ },
+ ),
+ StrippedStateEvent(
+ type=EventTypes.RoomEncryption,
+ state_key="",
+ sender="@inviter:remote_server",
+ content={
+ EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2",
+ },
+ ),
+ ],
+ )
+ )
+
+ space_room_id, space_room_id_invite_event = (
+ self._create_remote_invite_room_for_user(
+ user1_id,
+ [
+ StrippedStateEvent(
+ type=EventTypes.Create,
+ state_key="",
+ sender="@inviter:remote_server",
+ content={
+ EventContentFields.ROOM_CREATOR: "@inviter:remote_server",
+ EventContentFields.ROOM_VERSION: RoomVersions.V10.identifier,
+ EventContentFields.ROOM_TYPE: RoomTypes.SPACE,
+ },
+ ),
+ StrippedStateEvent(
+ type=EventTypes.Name,
+ state_key="",
+ sender="@inviter:remote_server",
+ content={
+ EventContentFields.ROOM_NAME: "my super duper space",
+ },
+ ),
+ ],
+ )
+ )
+
+ # Clean-up the `sliding_sync_membership_snapshots` table as if the inserts did not
+ # happen during event creation.
+ self.get_success(
+ self.store.db_pool.simple_delete_many(
+ table="sliding_sync_membership_snapshots",
+ column="room_id",
+ iterable=(
+ room_id_unknown_state,
+ room_id_no_info,
+ room_id_with_info,
+ space_room_id,
+ ),
+ keyvalues={},
+ desc="sliding_sync_membership_snapshots.test_membership_snapshots_background_update_remote_invite",
+ )
+ )
+
+ # We shouldn't find anything in the table because we just deleted them in
+ # preparation for the test.
+ sliding_sync_membership_snapshots_results = (
+ self._get_sliding_sync_membership_snapshots()
+ )
+ self.assertIncludes(
+ set(sliding_sync_membership_snapshots_results.keys()),
+ set(),
+ exact=True,
+ )
+
+ # Insert and run the background update.
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ "background_updates",
+ {
+ "update_name": _BackgroundUpdates.SLIDING_SYNC_MEMBERSHIP_SNAPSHOTS_BG_UPDATE,
+ "progress_json": "{}",
+ },
+ )
+ )
+ self.store.db_pool.updates._all_done = False
+ self.wait_for_background_updates()
+
+ # Make sure the table is populated
+ sliding_sync_membership_snapshots_results = (
+ self._get_sliding_sync_membership_snapshots()
+ )
+ self.assertIncludes(
+ set(sliding_sync_membership_snapshots_results.keys()),
+ {
+ # The invite memberships for user1
+ (room_id_unknown_state, user1_id),
+ (room_id_no_info, user1_id),
+ (room_id_with_info, user1_id),
+ (space_room_id, user1_id),
+ },
+ exact=True,
+ )
+ self.assertEqual(
+ sliding_sync_membership_snapshots_results.get(
+ (room_id_unknown_state, user1_id)
+ ),
+ _SlidingSyncMembershipSnapshotResult(
+ room_id=room_id_unknown_state,
+ user_id=user1_id,
+ sender="@inviter:remote_server",
+ membership_event_id=room_id_unknown_state_invite_event.event_id,
+ membership=Membership.INVITE,
+ event_stream_ordering=room_id_unknown_state_invite_event.internal_metadata.stream_ordering,
+ has_known_state=False,
+ room_type=None,
+ room_name=None,
+ is_encrypted=False,
+ tombstone_successor_room_id=None,
+ ),
+ )
+ self.assertEqual(
+ sliding_sync_membership_snapshots_results.get((room_id_no_info, user1_id)),
+ _SlidingSyncMembershipSnapshotResult(
+ room_id=room_id_no_info,
+ user_id=user1_id,
+ sender="@inviter:remote_server",
+ membership_event_id=room_id_no_info_invite_event.event_id,
+ membership=Membership.INVITE,
+ event_stream_ordering=room_id_no_info_invite_event.internal_metadata.stream_ordering,
+ has_known_state=True,
+ room_type=None,
+ room_name=None,
+ is_encrypted=False,
+ tombstone_successor_room_id=None,
+ ),
+ )
+ self.assertEqual(
+ sliding_sync_membership_snapshots_results.get(
+ (room_id_with_info, user1_id)
+ ),
+ _SlidingSyncMembershipSnapshotResult(
+ room_id=room_id_with_info,
+ user_id=user1_id,
+ sender="@inviter:remote_server",
+ membership_event_id=room_id_with_info_invite_event.event_id,
+ membership=Membership.INVITE,
+ event_stream_ordering=room_id_with_info_invite_event.internal_metadata.stream_ordering,
+ has_known_state=True,
+ room_type=None,
+ room_name="my super duper room",
+ is_encrypted=True,
+ tombstone_successor_room_id=None,
+ ),
+ )
+ self.assertEqual(
+ sliding_sync_membership_snapshots_results.get((space_room_id, user1_id)),
+ _SlidingSyncMembershipSnapshotResult(
+ room_id=space_room_id,
+ user_id=user1_id,
+ sender="@inviter:remote_server",
+ membership_event_id=space_room_id_invite_event.event_id,
+ membership=Membership.INVITE,
+ event_stream_ordering=space_room_id_invite_event.internal_metadata.stream_ordering,
+ has_known_state=True,
+ room_type=RoomTypes.SPACE,
+ room_name="my super duper space",
+ is_encrypted=False,
+ tombstone_successor_room_id=None,
+ ),
+ )
+
+ def test_membership_snapshots_background_update_remote_invite_rejections_and_retractions(
+ self,
+ ) -> None:
+ """
+ Test that the background update for `sliding_sync_membership_snapshots`
+ populates missing rows for remote invite rejections/retractions (out-of-band memberships).
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ # Create rooms with various levels of state that should appear in the table
+ #
+ room_id_unknown_state, room_id_unknown_state_invite_event = (
+ self._create_remote_invite_room_for_user(user1_id, None)
+ )
+
+ room_id_no_info, room_id_no_info_invite_event = (
+ self._create_remote_invite_room_for_user(
+ user1_id,
+ [
+ StrippedStateEvent(
+ type=EventTypes.Create,
+ state_key="",
+ sender="@inviter:remote_server",
+ content={
+ EventContentFields.ROOM_CREATOR: "@inviter:remote_server",
+ EventContentFields.ROOM_VERSION: RoomVersions.V10.identifier,
+ },
+ ),
+ ],
+ )
+ )
+
+ room_id_with_info, room_id_with_info_invite_event = (
+ self._create_remote_invite_room_for_user(
+ user1_id,
+ [
+ StrippedStateEvent(
+ type=EventTypes.Create,
+ state_key="",
+ sender="@inviter:remote_server",
+ content={
+ EventContentFields.ROOM_CREATOR: "@inviter:remote_server",
+ EventContentFields.ROOM_VERSION: RoomVersions.V10.identifier,
+ },
+ ),
+ StrippedStateEvent(
+ type=EventTypes.Name,
+ state_key="",
+ sender="@inviter:remote_server",
+ content={
+ EventContentFields.ROOM_NAME: "my super duper room",
+ },
+ ),
+ StrippedStateEvent(
+ type=EventTypes.RoomEncryption,
+ state_key="",
+ sender="@inviter:remote_server",
+ content={
+ EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2",
+ },
+ ),
+ ],
+ )
+ )
+
+ space_room_id, space_room_id_invite_event = (
+ self._create_remote_invite_room_for_user(
+ user1_id,
+ [
+ StrippedStateEvent(
+ type=EventTypes.Create,
+ state_key="",
+ sender="@inviter:remote_server",
+ content={
+ EventContentFields.ROOM_CREATOR: "@inviter:remote_server",
+ EventContentFields.ROOM_VERSION: RoomVersions.V10.identifier,
+ EventContentFields.ROOM_TYPE: RoomTypes.SPACE,
+ },
+ ),
+ StrippedStateEvent(
+ type=EventTypes.Name,
+ state_key="",
+ sender="@inviter:remote_server",
+ content={
+ EventContentFields.ROOM_NAME: "my super duper space",
+ },
+ ),
+ ],
+ )
+ )
+
+        # Reject two of the remote invites and retract the other two.
+ room_id_unknown_state_leave_event_response = self.helper.leave(
+ room_id_unknown_state, user1_id, tok=user1_tok
+ )
+ room_id_no_info_leave_event = self._retract_remote_invite_for_user(
+ user_id=user1_id,
+ remote_room_id=room_id_no_info,
+ )
+ room_id_with_info_leave_event_response = self.helper.leave(
+ room_id_with_info, user1_id, tok=user1_tok
+ )
+ space_room_id_leave_event = self._retract_remote_invite_for_user(
+ user_id=user1_id,
+ remote_room_id=space_room_id,
+ )
+
+ # Clean-up the `sliding_sync_membership_snapshots` table as if the inserts did not
+ # happen during event creation.
+ self.get_success(
+ self.store.db_pool.simple_delete_many(
+ table="sliding_sync_membership_snapshots",
+ column="room_id",
+ iterable=(
+ room_id_unknown_state,
+ room_id_no_info,
+ room_id_with_info,
+ space_room_id,
+ ),
+ keyvalues={},
+ desc="sliding_sync_membership_snapshots.test_membership_snapshots_background_update_remote_invite_rejections_and_retractions",
+ )
+ )
+
+ # We shouldn't find anything in the table because we just deleted them in
+ # preparation for the test.
+ sliding_sync_membership_snapshots_results = (
+ self._get_sliding_sync_membership_snapshots()
+ )
+ self.assertIncludes(
+ set(sliding_sync_membership_snapshots_results.keys()),
+ set(),
+ exact=True,
+ )
+
+ # Insert and run the background update.
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ "background_updates",
+ {
+ "update_name": _BackgroundUpdates.SLIDING_SYNC_MEMBERSHIP_SNAPSHOTS_BG_UPDATE,
+ "progress_json": "{}",
+ },
+ )
+ )
+ self.store.db_pool.updates._all_done = False
+ self.wait_for_background_updates()
+
+ # Make sure the table is populated
+ sliding_sync_membership_snapshots_results = (
+ self._get_sliding_sync_membership_snapshots()
+ )
+ self.assertIncludes(
+ set(sliding_sync_membership_snapshots_results.keys()),
+ {
+ # The invite memberships for user1
+ (room_id_unknown_state, user1_id),
+ (room_id_no_info, user1_id),
+ (room_id_with_info, user1_id),
+ (space_room_id, user1_id),
+ },
+ exact=True,
+ )
+ self.assertEqual(
+ sliding_sync_membership_snapshots_results.get(
+ (room_id_unknown_state, user1_id)
+ ),
+ _SlidingSyncMembershipSnapshotResult(
+ room_id=room_id_unknown_state,
+ user_id=user1_id,
+ sender=user1_id,
+ membership_event_id=room_id_unknown_state_leave_event_response[
+ "event_id"
+ ],
+ membership=Membership.LEAVE,
+ event_stream_ordering=self.get_success(
+ self.store.get_position_for_event(
+ room_id_unknown_state_leave_event_response["event_id"]
+ )
+ ).stream,
+ has_known_state=False,
+ room_type=None,
+ room_name=None,
+ is_encrypted=False,
+ tombstone_successor_room_id=None,
+ ),
+ )
+ self.assertEqual(
+ sliding_sync_membership_snapshots_results.get((room_id_no_info, user1_id)),
+ _SlidingSyncMembershipSnapshotResult(
+ room_id=room_id_no_info,
+ user_id=user1_id,
+ sender="@inviter:remote_server",
+ membership_event_id=room_id_no_info_leave_event.event_id,
+ membership=Membership.LEAVE,
+ event_stream_ordering=room_id_no_info_leave_event.internal_metadata.stream_ordering,
+ has_known_state=True,
+ room_type=None,
+ room_name=None,
+ is_encrypted=False,
+ tombstone_successor_room_id=None,
+ ),
+ )
+ self.assertEqual(
+ sliding_sync_membership_snapshots_results.get(
+ (room_id_with_info, user1_id)
+ ),
+ _SlidingSyncMembershipSnapshotResult(
+ room_id=room_id_with_info,
+ user_id=user1_id,
+ sender=user1_id,
+ membership_event_id=room_id_with_info_leave_event_response["event_id"],
+ membership=Membership.LEAVE,
+ event_stream_ordering=self.get_success(
+ self.store.get_position_for_event(
+ room_id_with_info_leave_event_response["event_id"]
+ )
+ ).stream,
+ has_known_state=True,
+ room_type=None,
+ room_name="my super duper room",
+ is_encrypted=True,
+ tombstone_successor_room_id=None,
+ ),
+ )
+ self.assertEqual(
+ sliding_sync_membership_snapshots_results.get((space_room_id, user1_id)),
+ _SlidingSyncMembershipSnapshotResult(
+ room_id=space_room_id,
+ user_id=user1_id,
+ sender="@inviter:remote_server",
+ membership_event_id=space_room_id_leave_event.event_id,
+ membership=Membership.LEAVE,
+ event_stream_ordering=space_room_id_leave_event.internal_metadata.stream_ordering,
+ has_known_state=True,
+ room_type=RoomTypes.SPACE,
+ room_name="my super duper space",
+ is_encrypted=False,
+ tombstone_successor_room_id=None,
+ ),
+ )
+
+ @parameterized.expand(
+ [
+ # We'll do a kick for this
+ (Membership.LEAVE,),
+ (Membership.BAN,),
+ ]
+ )
+ def test_membership_snapshots_background_update_historical_state(
+ self, test_membership: str
+ ) -> None:
+ """
+ Test that the background update for `sliding_sync_membership_snapshots`
+        populates missing rows for kick (leave) and ban memberships.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ # Create rooms with various levels of state that should appear in the table
+ #
+ room_id_no_info = self.helper.create_room_as(user2_id, tok=user2_tok)
+
+ room_id_with_info = self.helper.create_room_as(user2_id, tok=user2_tok)
+ # Add a room name
+ self.helper.send_state(
+ room_id_with_info,
+ EventTypes.Name,
+ {"name": "my super duper room"},
+ tok=user2_tok,
+ )
+ # Encrypt the room
+ self.helper.send_state(
+ room_id_with_info,
+ EventTypes.RoomEncryption,
+ {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"},
+ tok=user2_tok,
+ )
+ # Add a tombstone
+ self.helper.send_state(
+ room_id_with_info,
+ EventTypes.Tombstone,
+ {EventContentFields.TOMBSTONE_SUCCESSOR_ROOM: "another_room"},
+ tok=user2_tok,
+ )
+
+ space_room_id = self.helper.create_room_as(
+            user2_id,
+ tok=user2_tok,
+ extra_content={
+ "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE}
+ },
+ )
+ # Add a room name
+ self.helper.send_state(
+ space_room_id,
+ EventTypes.Name,
+ {"name": "my super duper space"},
+ tok=user2_tok,
+ )
+
+        # Join the rooms in preparation for our test_membership
+ self.helper.join(room_id_no_info, user1_id, tok=user1_tok)
+ self.helper.join(room_id_with_info, user1_id, tok=user1_tok)
+ self.helper.join(space_room_id, user1_id, tok=user1_tok)
+
+ if test_membership == Membership.LEAVE:
+ # Kick user1 from the rooms
+ user1_membership_room_id_no_info_response = self.helper.change_membership(
+ room=room_id_no_info,
+ src=user2_id,
+ targ=user1_id,
+ tok=user2_tok,
+ membership=Membership.LEAVE,
+ extra_data={
+ "reason": "Bad manners",
+ },
+ )
+ user1_membership_room_id_with_info_response = self.helper.change_membership(
+ room=room_id_with_info,
+ src=user2_id,
+ targ=user1_id,
+ tok=user2_tok,
+ membership=Membership.LEAVE,
+ extra_data={
+ "reason": "Bad manners",
+ },
+ )
+ user1_membership_space_room_id_response = self.helper.change_membership(
+ room=space_room_id,
+ src=user2_id,
+ targ=user1_id,
+ tok=user2_tok,
+ membership=Membership.LEAVE,
+ extra_data={
+ "reason": "Bad manners",
+ },
+ )
+ elif test_membership == Membership.BAN:
+ # Ban user1 from the rooms
+ user1_membership_room_id_no_info_response = self.helper.ban(
+ room_id_no_info, src=user2_id, targ=user1_id, tok=user2_tok
+ )
+ user1_membership_room_id_with_info_response = self.helper.ban(
+ room_id_with_info, src=user2_id, targ=user1_id, tok=user2_tok
+ )
+ user1_membership_space_room_id_response = self.helper.ban(
+ space_room_id, src=user2_id, targ=user1_id, tok=user2_tok
+ )
+ else:
+ raise AssertionError("Unknown test_membership")
+
+ # Have user2 leave the rooms to make sure that our background update is not just
+ # reading from `current_state_events`. For leave memberships, we should be
+ # reading from the historical state.
+ self.helper.leave(room_id_no_info, user2_id, tok=user2_tok)
+ self.helper.leave(room_id_with_info, user2_id, tok=user2_tok)
+ self.helper.leave(space_room_id, user2_id, tok=user2_tok)
+ # Check to make sure we actually don't have any `current_state_events` for the rooms
+ current_state_check_rows = self.get_success(
+ self.store.db_pool.simple_select_many_batch(
+ table="current_state_events",
+ column="room_id",
+ iterable=[room_id_no_info, room_id_with_info, space_room_id],
+ retcols=("event_id",),
+ keyvalues={},
+ desc="check current_state_events in test",
+ )
+ )
+ self.assertEqual(len(current_state_check_rows), 0)
+
+ # Clean-up the `sliding_sync_membership_snapshots` table as if the inserts did not
+ # happen during event creation.
+ self.get_success(
+ self.store.db_pool.simple_delete_many(
+ table="sliding_sync_membership_snapshots",
+ column="room_id",
+ iterable=(room_id_no_info, room_id_with_info, space_room_id),
+ keyvalues={},
+ desc="sliding_sync_membership_snapshots.test_membership_snapshots_background_update_historical_state",
+ )
+ )
+
+ # We shouldn't find anything in the table because we just deleted them in
+ # preparation for the test.
+ sliding_sync_membership_snapshots_results = (
+ self._get_sliding_sync_membership_snapshots()
+ )
+ self.assertIncludes(
+ set(sliding_sync_membership_snapshots_results.keys()),
+ set(),
+ exact=True,
+ )
+
+ # Insert and run the background update.
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ "background_updates",
+ {
+ "update_name": _BackgroundUpdates.SLIDING_SYNC_MEMBERSHIP_SNAPSHOTS_BG_UPDATE,
+ "progress_json": "{}",
+ },
+ )
+ )
+ self.store.db_pool.updates._all_done = False
+ self.wait_for_background_updates()
+
+ # Make sure the table is populated
+ sliding_sync_membership_snapshots_results = (
+ self._get_sliding_sync_membership_snapshots()
+ )
+ self.assertIncludes(
+ set(sliding_sync_membership_snapshots_results.keys()),
+ {
+ # The memberships for user1
+ (room_id_no_info, user1_id),
+ (room_id_with_info, user1_id),
+ (space_room_id, user1_id),
+ # The leave memberships for user2
+ (room_id_no_info, user2_id),
+ (room_id_with_info, user2_id),
+ (space_room_id, user2_id),
+ },
+ exact=True,
+ )
+ self.assertEqual(
+ sliding_sync_membership_snapshots_results.get((room_id_no_info, user1_id)),
+ _SlidingSyncMembershipSnapshotResult(
+ room_id=room_id_no_info,
+ user_id=user1_id,
+ # Because user2 kicked/banned user1 from the room
+ sender=user2_id,
+ membership_event_id=user1_membership_room_id_no_info_response[
+ "event_id"
+ ],
+ membership=test_membership,
+ event_stream_ordering=self.get_success(
+ self.store.get_position_for_event(
+ user1_membership_room_id_no_info_response["event_id"]
+ )
+ ).stream,
+ has_known_state=True,
+ room_type=None,
+ room_name=None,
+ is_encrypted=False,
+ tombstone_successor_room_id=None,
+ ),
+ )
+ self.assertEqual(
+ sliding_sync_membership_snapshots_results.get(
+ (room_id_with_info, user1_id)
+ ),
+ _SlidingSyncMembershipSnapshotResult(
+ room_id=room_id_with_info,
+ user_id=user1_id,
+ # Because user2 kicked/banned user1 from the room
+ sender=user2_id,
+ membership_event_id=user1_membership_room_id_with_info_response[
+ "event_id"
+ ],
+ membership=test_membership,
+ event_stream_ordering=self.get_success(
+ self.store.get_position_for_event(
+ user1_membership_room_id_with_info_response["event_id"]
+ )
+ ).stream,
+ has_known_state=True,
+ room_type=None,
+ room_name="my super duper room",
+ is_encrypted=True,
+ tombstone_successor_room_id="another_room",
+ ),
+ )
+ self.assertEqual(
+ sliding_sync_membership_snapshots_results.get((space_room_id, user1_id)),
+ _SlidingSyncMembershipSnapshotResult(
+ room_id=space_room_id,
+ user_id=user1_id,
+ # Because user2 kicked/banned user1 from the room
+ sender=user2_id,
+ membership_event_id=user1_membership_space_room_id_response["event_id"],
+ membership=test_membership,
+ event_stream_ordering=self.get_success(
+ self.store.get_position_for_event(
+ user1_membership_space_room_id_response["event_id"]
+ )
+ ).stream,
+ has_known_state=True,
+ room_type=RoomTypes.SPACE,
+ room_name="my super duper space",
+ is_encrypted=False,
+ tombstone_successor_room_id=None,
+ ),
+ )
+
+ def test_membership_snapshots_background_update_forgotten_missing(self) -> None:
+ """
+ Test that a new row is inserted into `sliding_sync_membership_snapshots` when it
+ doesn't exist in the table yet.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
+
+ # User1 joins the room
+ self.helper.join(room_id, user1_id, tok=user1_tok)
+ # User1 leaves the room (we have to leave in order to forget the room)
+ self.helper.leave(room_id, user1_id, tok=user1_tok)
+
+ state_map = self.get_success(
+ self.storage_controllers.state.get_current_state(room_id)
+ )
+
+ # Forget the room
+ channel = self.make_request(
+ "POST",
+ f"/_matrix/client/r0/rooms/{room_id}/forget",
+ content={},
+ access_token=user1_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ # Clean-up the `sliding_sync_membership_snapshots` table as if the inserts did not
+ # happen during event creation.
+ self.get_success(
+ self.store.db_pool.simple_delete_many(
+ table="sliding_sync_membership_snapshots",
+ column="room_id",
+ iterable=(room_id,),
+ keyvalues={},
+ desc="sliding_sync_membership_snapshots.test_membership_snapshots_background_update_forgotten_missing",
+ )
+ )
+
+ # We shouldn't find anything in the table because we just deleted them in
+ # preparation for the test.
+ sliding_sync_membership_snapshots_results = (
+ self._get_sliding_sync_membership_snapshots()
+ )
+ self.assertIncludes(
+ set(sliding_sync_membership_snapshots_results.keys()),
+ set(),
+ exact=True,
+ )
+
+ # Insert and run the background update.
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ "background_updates",
+ {
+ "update_name": _BackgroundUpdates.SLIDING_SYNC_MEMBERSHIP_SNAPSHOTS_BG_UPDATE,
+ "progress_json": "{}",
+ },
+ )
+ )
+ self.store.db_pool.updates._all_done = False
+ self.wait_for_background_updates()
+
+ # Make sure the table is populated
+ sliding_sync_membership_snapshots_results = (
+ self._get_sliding_sync_membership_snapshots()
+ )
+ self.assertIncludes(
+ set(sliding_sync_membership_snapshots_results.keys()),
+ {
+ (room_id, user1_id),
+ (room_id, user2_id),
+ },
+ exact=True,
+ )
+ # Holds the info according to the room state at the time user1 left
+ self.assertEqual(
+ sliding_sync_membership_snapshots_results.get((room_id, user1_id)),
+ _SlidingSyncMembershipSnapshotResult(
+ room_id=room_id,
+ user_id=user1_id,
+ sender=user1_id,
+ membership_event_id=state_map[(EventTypes.Member, user1_id)].event_id,
+ membership=Membership.LEAVE,
+ event_stream_ordering=state_map[
+ (EventTypes.Member, user1_id)
+ ].internal_metadata.stream_ordering,
+ has_known_state=True,
+ room_type=None,
+ room_name=None,
+ is_encrypted=False,
+ tombstone_successor_room_id=None,
+ # Room is forgotten
+ forgotten=True,
+ ),
+ )
+ # Holds the info according to the current state when the user joined
+ self.assertEqual(
+ sliding_sync_membership_snapshots_results.get((room_id, user2_id)),
+ _SlidingSyncMembershipSnapshotResult(
+ room_id=room_id,
+ user_id=user2_id,
+ sender=user2_id,
+ membership_event_id=state_map[(EventTypes.Member, user2_id)].event_id,
+ membership=Membership.JOIN,
+ event_stream_ordering=state_map[
+ (EventTypes.Member, user2_id)
+ ].internal_metadata.stream_ordering,
+ has_known_state=True,
+ room_type=None,
+ room_name=None,
+ is_encrypted=False,
+ tombstone_successor_room_id=None,
+ ),
+ )
+
+
+class SlidingSyncTablesCatchUpBackgroundUpdatesTestCase(SlidingSyncTablesTestCaseBase):
+ """
+ Test the background updates that populate the `sliding_sync_joined_rooms` and
+ `sliding_sync_membership_snapshots` tables for catch-up after a Synapse
+ downgrade.
+
+ This tests the "catch-up" version of the background update, as opposed to the
+ "normal" background update that populates the tables with all of the historical
+ data. Both versions share the same background update but serve different
+ purposes. We check whether the "catch-up" version needs to run on start-up based
+ on whether there have been any changes to rooms that aren't reflected in the
+ sliding sync tables.
+
+ FIXME: This can be removed once we bump `SCHEMA_COMPAT_VERSION` and run the
+ foreground update for
+ `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` (tracked by
+ https://github.com/element-hq/synapse/issues/17623)
+ """
+
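+ # Each test below follows the same rough recipe: create some data, let the
+ # normal background updates finish, manually mutate the sliding sync tables to
+ # look like they were written by an older (downgraded) Synapse, run the
+ # `_resolve_stale_data_in_*` function under test, and then let the scheduled
+ # catch-up background update repopulate the tables.
+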
+ def test_joined_background_update_catch_up_new_room(self) -> None:
+ """
+ Test that rooms created while Synapse is downgraded (making
+ `sliding_sync_joined_rooms` stale) will be caught up when Synapse is upgraded
+ and the catch-up routine is run.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ # Instead of testing with various levels of room state that should appear in the
+ # table, we're only using one room to keep this test simple. Because the
+ # underlying background update to populate these tables is the same as this
+ # catch-up routine, we are going to rely on
+ # `SlidingSyncTablesBackgroundUpdatesTestCase` to cover that logic.
+ room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+
+ # Make sure all of the background updates have finished before we start the
+ # catch-up. Even though it should work fine if the other background update is
+ # still running, we want to see the catch-up routine restore the progress
+ # correctly.
+ #
+ # We also don't want the normal background update messing with our results so we
+ # run this before we do our manual database clean-up to simulate new events
+ # being sent while Synapse was downgraded.
+ self.wait_for_background_updates()
+
+ # Clean-up the `sliding_sync_joined_rooms` table as if the room never made
+ # it into the table. This is to simulate a new room being created while
+ # Synapse was downgraded.
+ self.get_success(
+ self.store.db_pool.simple_delete(
+ table="sliding_sync_joined_rooms",
+ keyvalues={"room_id": room_id},
+ desc="simulate new room while Synapse was downgraded",
+ )
+ )
+
+ # The function under test. It should clear out stale data and start the
+ # background update to catch-up on the missing data.
+ self.get_success(
+ self.store.db_pool.runInteraction(
+ "_resolve_stale_data_in_sliding_sync_joined_rooms_table",
+ _resolve_stale_data_in_sliding_sync_joined_rooms_table,
+ )
+ )
+
+ # We shouldn't see any new data yet
+ sliding_sync_joined_rooms_results = self._get_sliding_sync_joined_rooms()
+ self.assertIncludes(
+ set(sliding_sync_joined_rooms_results.keys()),
+ set(),
+ exact=True,
+ )
+
+ # Wait for the catch-up background update to finish
+ self.store.db_pool.updates._all_done = False
+ self.wait_for_background_updates()
+
+ # Ensure that the table is populated correctly after the catch-up background
+ # update finishes
+ sliding_sync_joined_rooms_results = self._get_sliding_sync_joined_rooms()
+ self.assertIncludes(
+ set(sliding_sync_joined_rooms_results.keys()),
+ {room_id},
+ exact=True,
+ )
+
+ def test_joined_background_update_catch_up_room_state_change(self) -> None:
+ """
+ Test that events sent while Synapse is downgraded (making
+ `sliding_sync_joined_rooms` stale) will be caught up when Synapse is upgraded
+ and the catch-up routine is run.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ # Instead of testing with various levels of room state that should appear in the
+ # table, we're only using one room to keep this test simple. Because the
+ # underlying background update to populate these tables is the same as this
+ # catch-up routine, we are going to rely on
+ # `SlidingSyncTablesBackgroundUpdatesTestCase` to cover that logic.
+ room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+
+ # Get a snapshot of the `sliding_sync_joined_rooms` table before we add some state
+ sliding_sync_joined_rooms_results_before_state = (
+ self._get_sliding_sync_joined_rooms()
+ )
+ self.assertIncludes(
+ set(sliding_sync_joined_rooms_results_before_state.keys()),
+ {room_id},
+ exact=True,
+ )
+
+ # Add a room name
+ self.helper.send_state(
+ room_id,
+ EventTypes.Name,
+ {"name": "my super duper room"},
+ tok=user1_tok,
+ )
+
+ # Make sure all of the background updates have finished before we start the
+ # catch-up. Even though it should work fine if the other background update is
+ # still running, we want to see the catch-up routine restore the progress
+ # correctly.
+ #
+ # We also don't want the normal background update messing with our results so we
+ # run this before we do our manual database clean-up to simulate new events
+ # being sent while Synapse was downgraded.
+ self.wait_for_background_updates()
+
+ # Clean-up the `sliding_sync_joined_rooms` table as if the room name
+ # never made it into the table. This is to simulate the room name event
+ # being sent while Synapse was downgraded.
+ self.get_success(
+ self.store.db_pool.simple_update(
+ table="sliding_sync_joined_rooms",
+ keyvalues={"room_id": room_id},
+ updatevalues={
+ # Clear the room name
+ "room_name": None,
+ # Reset the `event_stream_ordering` back to the value before the room name
+ "event_stream_ordering": sliding_sync_joined_rooms_results_before_state[
+ room_id
+ ].event_stream_ordering,
+ },
+ desc="simulate new events while Synapse was downgraded",
+ )
+ )
+
+ # The function under test. It should clear out stale data and start the
+ # background update to catch-up on the missing data.
+ self.get_success(
+ self.store.db_pool.runInteraction(
+ "_resolve_stale_data_in_sliding_sync_joined_rooms_table",
+ _resolve_stale_data_in_sliding_sync_joined_rooms_table,
+ )
+ )
+
+ # Ensure that the stale data is deleted from the table
+ sliding_sync_joined_rooms_results = self._get_sliding_sync_joined_rooms()
+ self.assertIncludes(
+ set(sliding_sync_joined_rooms_results.keys()),
+ set(),
+ exact=True,
+ )
+
+ # Wait for the catch-up background update to finish
+ self.store.db_pool.updates._all_done = False
+ self.wait_for_background_updates()
+
+ # Ensure that the table is populated correctly after the catch-up background
+ # update finishes
+ sliding_sync_joined_rooms_results = self._get_sliding_sync_joined_rooms()
+ self.assertIncludes(
+ set(sliding_sync_joined_rooms_results.keys()),
+ {room_id},
+ exact=True,
+ )
+
+ def test_joined_background_update_catch_up_no_rooms(self) -> None:
+ """
+ Test that if you start your homeserver with no rooms on a Synapse version that
+ supports the sliding sync tables, the historical background update completes
+ (because there are no rooms to process), and Synapse is then downgraded and new
+ rooms are created/joined; when Synapse is upgraded again, the rooms will be
+ processed once the catch-up routine is run.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ # Instead of testing with various levels of room state that should appear in the
+ # table, we're only using one room to keep this test simple. Because the
+ # underlying background update to populate these tables is the same as this
+ # catch-up routine, we are going to rely on
+ # `SlidingSyncTablesBackgroundUpdatesTestCase` to cover that logic.
+ room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+
+ # Make sure all of the background updates have finished before we start the
+ # catch-up. Even though it should work fine if the other background update is
+ # still running, we want to see the catch-up routine restore the progress
+ # correctly.
+ #
+ # We also don't want the normal background update messing with our results so we
+ # run this before we do our manual database clean-up to simulate room being
+ # created while Synapse was downgraded.
+ self.wait_for_background_updates()
+
+ # Clean-up the `sliding_sync_joined_rooms` table as if the room never made
+ # it into the table. This is to simulate the room being created while Synapse
+ # was downgraded.
+ self.get_success(
+ self.store.db_pool.simple_delete_many(
+ table="sliding_sync_joined_rooms",
+ column="room_id",
+ iterable=(room_id,),
+ keyvalues={},
+ desc="simulate room being created while Synapse was downgraded",
+ )
+ )
+
+ # We shouldn't find anything in the table because we just deleted them in
+ # preparation for the test.
+ sliding_sync_joined_rooms_results = self._get_sliding_sync_joined_rooms()
+ self.assertIncludes(
+ set(sliding_sync_joined_rooms_results.keys()),
+ set(),
+ exact=True,
+ )
+
+ # The function under test. It should clear out stale data and start the
+ # background update to catch-up on the missing data.
+ self.get_success(
+ self.store.db_pool.runInteraction(
+ "_resolve_stale_data_in_sliding_sync_joined_rooms_table",
+ _resolve_stale_data_in_sliding_sync_joined_rooms_table,
+ )
+ )
+
+ # We still shouldn't find any data yet
+ sliding_sync_joined_rooms_results = self._get_sliding_sync_joined_rooms()
+ self.assertIncludes(
+ set(sliding_sync_joined_rooms_results.keys()),
+ set(),
+ exact=True,
+ )
+
+ # Wait for the catch-up background update to finish
+ self.store.db_pool.updates._all_done = False
+ self.wait_for_background_updates()
+
+ # Ensure that the table is populated correctly after the catch-up background
+ # update finishes
+ sliding_sync_joined_rooms_results = self._get_sliding_sync_joined_rooms()
+ self.assertIncludes(
+ set(sliding_sync_joined_rooms_results.keys()),
+ {room_id},
+ exact=True,
+ )
+
+ def test_membership_snapshots_background_update_catch_up_new_membership(
+ self,
+ ) -> None:
+ """
+ Test that a completely new membership created while Synapse is downgraded
+ (making `sliding_sync_membership_snapshots` stale) will be caught up when
+ Synapse is upgraded and the catch-up routine is run.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ # Instead of testing with various levels of room state that should appear in the
+ # table, we're only using one room to keep this test simple. Because the
+ # underlying background update to populate these tables is the same as this
+ # catch-up routine, we are going to rely on
+ # `SlidingSyncTablesBackgroundUpdatesTestCase` to cover that logic.
+ room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+ # User2 joins the room
+ self.helper.join(room_id, user2_id, tok=user2_tok)
+
+ # Both users are joined to the room
+ sliding_sync_membership_snapshots_results = (
+ self._get_sliding_sync_membership_snapshots()
+ )
+ self.assertIncludes(
+ set(sliding_sync_membership_snapshots_results.keys()),
+ {
+ (room_id, user1_id),
+ (room_id, user2_id),
+ },
+ exact=True,
+ )
+
+ # Make sure all of the background updates have finished before we start the
+ # catch-up. Even though it should work fine if the other background update is
+ # still running, we want to see the catch-up routine restore the progress
+ # correctly.
+ #
+ # We also don't want the normal background update messing with our results so we
+ # run this before we do our manual database clean-up to simulate new events
+ # being sent while Synapse was downgraded.
+ self.wait_for_background_updates()
+
+ # Clean-up the `sliding_sync_membership_snapshots` table as if the user2
+ # membership never made it into the table. This is to simulate a membership
+ # change while Synapse was downgraded.
+ self.get_success(
+ self.store.db_pool.simple_delete(
+ table="sliding_sync_membership_snapshots",
+ keyvalues={"room_id": room_id, "user_id": user2_id},
+ desc="simulate new membership while Synapse was downgraded",
+ )
+ )
+
+ # We shouldn't find the user2 membership in the table because we just deleted it
+ # in preparation for the test.
+ sliding_sync_membership_snapshots_results = (
+ self._get_sliding_sync_membership_snapshots()
+ )
+ self.assertIncludes(
+ set(sliding_sync_membership_snapshots_results.keys()),
+ {
+ (room_id, user1_id),
+ },
+ exact=True,
+ )
+
+ # The function under test. It should clear out stale data and start the
+ # background update to catch-up on the missing data.
+ self.get_success(
+ self.store.db_pool.runInteraction(
+ "_resolve_stale_data_in_sliding_sync_membership_snapshots_table",
+ _resolve_stale_data_in_sliding_sync_membership_snapshots_table,
+ )
+ )
+
+ # We still shouldn't find the user2 membership yet
+ sliding_sync_membership_snapshots_results = (
+ self._get_sliding_sync_membership_snapshots()
+ )
+ self.assertIncludes(
+ set(sliding_sync_membership_snapshots_results.keys()),
+ {
+ (room_id, user1_id),
+ },
+ exact=True,
+ )
+
+ # Wait for the catch-up background update to finish
+ self.store.db_pool.updates._all_done = False
+ self.wait_for_background_updates()
+
+ # Ensure that the table is populated correctly after the catch-up background
+ # update finishes
+ sliding_sync_membership_snapshots_results = (
+ self._get_sliding_sync_membership_snapshots()
+ )
+ self.assertIncludes(
+ set(sliding_sync_membership_snapshots_results.keys()),
+ {
+ (room_id, user1_id),
+ (room_id, user2_id),
+ },
+ exact=True,
+ )
+
+ def test_membership_snapshots_background_update_catch_up_membership_change(
+ self,
+ ) -> None:
+ """
+ Test that membership changes made while Synapse is downgraded (making
+ `sliding_sync_membership_snapshots` stale) will be caught up when Synapse is
+ upgraded and the catch-up routine is run.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ # Instead of testing with various levels of room state that should appear in the
+ # table, we're only using one room to keep this test simple. Because the
+ # underlying background update to populate these tables is the same as this
+ # catch-up routine, we are going to rely on
+ # `SlidingSyncTablesBackgroundUpdatesTestCase` to cover that logic.
+ room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+ # User2 joins the room
+ self.helper.join(room_id, user2_id, tok=user2_tok)
+
+ # Both users are joined to the room
+ sliding_sync_membership_snapshots_results_before_membership_changes = (
+ self._get_sliding_sync_membership_snapshots()
+ )
+ self.assertIncludes(
+ set(
+ sliding_sync_membership_snapshots_results_before_membership_changes.keys()
+ ),
+ {
+ (room_id, user1_id),
+ (room_id, user2_id),
+ },
+ exact=True,
+ )
+
+ # User2 leaves the room
+ self.helper.leave(room_id, user2_id, tok=user2_tok)
+
+ # Make sure all of the background updates have finished before we start the
+ # catch-up. Even though it should work fine if the other background update is
+ # still running, we want to see the catch-up routine restore the progress
+ # correctly.
+ #
+ # We also don't want the normal background update messing with our results so we
+ # run this before we do our manual database clean-up to simulate new events
+ # being sent while Synapse was downgraded.
+ self.wait_for_background_updates()
+
+ # Rollback the `sliding_sync_membership_snapshots` table as if user2's
+ # leave never made it into the table. This is to simulate a membership
+ # change while Synapse was downgraded.
+ self.get_success(
+ self.store.db_pool.simple_update(
+ table="sliding_sync_membership_snapshots",
+ keyvalues={"room_id": room_id, "user_id": user2_id},
+ updatevalues={
+ # Reset everything back to the value before user2 left the room
+ "membership": sliding_sync_membership_snapshots_results_before_membership_changes[
+ (room_id, user2_id)
+ ].membership,
+ "membership_event_id": sliding_sync_membership_snapshots_results_before_membership_changes[
+ (room_id, user2_id)
+ ].membership_event_id,
+ "event_stream_ordering": sliding_sync_membership_snapshots_results_before_membership_changes[
+ (room_id, user2_id)
+ ].event_stream_ordering,
+ },
+ desc="simulate membership change while Synapse was downgraded",
+ )
+ )
+
+ # We should see user2 still joined to the room because we made that change in
+ # preparation for the test.
+ sliding_sync_membership_snapshots_results = (
+ self._get_sliding_sync_membership_snapshots()
+ )
+ self.assertIncludes(
+ set(sliding_sync_membership_snapshots_results.keys()),
+ {
+ (room_id, user1_id),
+ (room_id, user2_id),
+ },
+ exact=True,
+ )
+ self.assertEqual(
+ sliding_sync_membership_snapshots_results.get((room_id, user1_id)),
+ sliding_sync_membership_snapshots_results_before_membership_changes[
+ (room_id, user1_id)
+ ],
+ )
+ self.assertEqual(
+ sliding_sync_membership_snapshots_results.get((room_id, user2_id)),
+ sliding_sync_membership_snapshots_results_before_membership_changes[
+ (room_id, user2_id)
+ ],
+ )
+
+ # The function under test. It should clear out stale data and start the
+ # background update to catch-up on the missing data.
+ self.get_success(
+ self.store.db_pool.runInteraction(
+ "_resolve_stale_data_in_sliding_sync_membership_snapshots_table",
+ _resolve_stale_data_in_sliding_sync_membership_snapshots_table,
+ )
+ )
+
+ # Ensure that the stale data is deleted from the table
+ sliding_sync_membership_snapshots_results = (
+ self._get_sliding_sync_membership_snapshots()
+ )
+ self.assertIncludes(
+ set(sliding_sync_membership_snapshots_results.keys()),
+ {
+ (room_id, user1_id),
+ },
+ exact=True,
+ )
+
+ # Wait for the catch-up background update to finish
+ self.store.db_pool.updates._all_done = False
+ self.wait_for_background_updates()
+
+ # Ensure that the table is populated correctly after the catch-up background
+ # update finishes
+ sliding_sync_membership_snapshots_results = (
+ self._get_sliding_sync_membership_snapshots()
+ )
+ self.assertIncludes(
+ set(sliding_sync_membership_snapshots_results.keys()),
+ {
+ (room_id, user1_id),
+ (room_id, user2_id),
+ },
+ exact=True,
+ )
+
+ def test_membership_snapshots_background_update_catch_up_no_membership(
+ self,
+ ) -> None:
+ """
+ Test that if you start your homeserver with no rooms on a Synapse version that
+ supports the sliding sync tables, the historical background update completes
+ (because there are no rooms/memberships to process), and Synapse is then
+ downgraded and new rooms are created/joined; when Synapse is upgraded again, the
+ rooms will be processed once the catch-up routine is run.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ # Instead of testing with various levels of room state that should appear in the
+ # table, we're only using one room to keep this test simple. Because the
+ # underlying background update to populate these tables is the same as this
+ # catch-up routine, we are going to rely on
+ # `SlidingSyncTablesBackgroundUpdatesTestCase` to cover that logic.
+ room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+ # User2 joins the room
+ self.helper.join(room_id, user2_id, tok=user2_tok)
+
+ # Make sure all of the background updates have finished before we start the
+ # catch-up. Even though it should work fine if the other background update is
+ # still running, we want to see the catch-up routine restore the progress
+ # correctly.
+ #
+ # We also don't want the normal background update messing with our results so we
+ # run this before we do our manual database clean-up to simulate new events
+ # being sent while Synapse was downgraded.
+ self.wait_for_background_updates()
+
+ # Clean-up the `sliding_sync_membership_snapshots` table as if the
+ # memberships never made it into the table. This is to simulate the room
+ # being created while Synapse was downgraded.
+ self.get_success(
+ self.store.db_pool.simple_delete_many(
+ table="sliding_sync_membership_snapshots",
+ column="room_id",
+ iterable=(room_id,),
+ keyvalues={},
+ desc="simulate room being created while Synapse was downgraded",
+ )
+ )
+
+ # We shouldn't find anything in the table because we just deleted them in
+ # preparation for the test.
+ sliding_sync_membership_snapshots_results = (
+ self._get_sliding_sync_membership_snapshots()
+ )
+ self.assertIncludes(
+ set(sliding_sync_membership_snapshots_results.keys()),
+ set(),
+ exact=True,
+ )
+
+ # The function under test. It should clear out stale data and start the
+ # background update to catch-up on the missing data.
+ self.get_success(
+ self.store.db_pool.runInteraction(
+ "_resolve_stale_data_in_sliding_sync_membership_snapshots_table",
+ _resolve_stale_data_in_sliding_sync_membership_snapshots_table,
+ )
+ )
+
+ # We still shouldn't find any data yet
+ sliding_sync_membership_snapshots_results = (
+ self._get_sliding_sync_membership_snapshots()
+ )
+ self.assertIncludes(
+ set(sliding_sync_membership_snapshots_results.keys()),
+ set(),
+ exact=True,
+ )
+
+ # Wait for the catch-up background update to finish
+ self.store.db_pool.updates._all_done = False
+ self.wait_for_background_updates()
+
+ # Ensure that the table is populated correctly after the catch-up background
+ # update finishes
+ sliding_sync_membership_snapshots_results = (
+ self._get_sliding_sync_membership_snapshots()
+ )
+ self.assertIncludes(
+ set(sliding_sync_membership_snapshots_results.keys()),
+ {
+ (room_id, user1_id),
+ (room_id, user2_id),
+ },
+ exact=True,
+ )
+
+
+class SlidingSyncMembershipSnapshotsTableFixForgottenColumnBackgroundUpdatesTestCase(
+ SlidingSyncTablesTestCaseBase
+):
+ """
+ Test the background update that fixes the `sliding_sync_membership_snapshots` ->
+ `forgotten` column.
+ """
+
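+ # Context: the `/forget` endpoint sets a per-user `forgotten` flag in the
+ # `room_memberships` table, and re-joining the room clears it again. The
+ # background update under test re-syncs that flag into
+ # `sliding_sync_membership_snapshots`.
+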
+ def test_membership_snapshots_fix_forgotten_column_background_update(self) -> None:
+ """
+ Test that the background update updates the `sliding_sync_membership_snapshots`
+ -> `forgotten` column to be in sync with the `room_memberships` table.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ room_id = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True)
+ # User1 joins the room
+ self.helper.join(room_id, user1_id, tok=user1_tok)
+
+ # Leave and forget the room
+ self.helper.leave(room_id, user1_id, tok=user1_tok)
+ # User1 forgets the room
+ channel = self.make_request(
+ "POST",
+ f"/_matrix/client/r0/rooms/{room_id}/forget",
+ content={},
+ access_token=user1_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ # Re-join the room
+ self.helper.join(room_id, user1_id, tok=user1_tok)
+
+ # Reset `sliding_sync_membership_snapshots` table as if the `forgotten` column
+ # got out of sync from the `room_memberships` table from the previous flawed
+ # code.
+ self.get_success(
+ self.store.db_pool.simple_update_one(
+ table="sliding_sync_membership_snapshots",
+ keyvalues={"room_id": room_id, "user_id": user1_id},
+ updatevalues={"forgotten": 1},
+ desc="sliding_sync_membership_snapshots.test_membership_snapshots_fix_forgotten_column_background_update",
+ )
+ )
+
+ # Insert and run the background update.
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ "background_updates",
+ {
+ "update_name": _BackgroundUpdates.SLIDING_SYNC_MEMBERSHIP_SNAPSHOTS_FIX_FORGOTTEN_COLUMN_BG_UPDATE,
+ "progress_json": "{}",
+ },
+ )
+ )
+ self.store.db_pool.updates._all_done = False
+ self.wait_for_background_updates()
+
+ # Make sure the table is populated
+ sliding_sync_membership_snapshots_results = (
+ self._get_sliding_sync_membership_snapshots()
+ )
+ self.assertIncludes(
+ set(sliding_sync_membership_snapshots_results.keys()),
+ {
+ (room_id, user1_id),
+ (room_id, user2_id),
+ },
+ exact=True,
+ )
+ state_map = self.get_success(
+ self.storage_controllers.state.get_current_state(room_id)
+ )
+ # Holds the info according to the current state when the user joined.
+ #
+ # We only care about checking on user1 as that's what we reset and expect to be
+ # correct now
+ self.assertEqual(
+ sliding_sync_membership_snapshots_results.get((room_id, user1_id)),
+ _SlidingSyncMembershipSnapshotResult(
+ room_id=room_id,
+ user_id=user1_id,
+ sender=user1_id,
+ membership_event_id=state_map[(EventTypes.Member, user1_id)].event_id,
+ membership=Membership.JOIN,
+ event_stream_ordering=state_map[
+ (EventTypes.Member, user1_id)
+ ].internal_metadata.stream_ordering,
+ has_known_state=True,
+ room_type=None,
+ room_name=None,
+ is_encrypted=False,
+ tombstone_successor_room_id=None,
+ # We should see the room as no longer forgotten
+ forgotten=False,
+ ),
+ )
diff --git a/tests/storage/test_state_deletion.py b/tests/storage/test_state_deletion.py
new file mode 100644
index 0000000000..a4d318ae20
--- /dev/null
+++ b/tests/storage/test_state_deletion.py
@@ -0,0 +1,475 @@
+#
+# This file is licensed under the Affero General Public License (AGPL) version 3.
+#
+# Copyright (C) 2025 New Vector, Ltd
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# See the GNU Affero General Public License for more details:
+# <https://www.gnu.org/licenses/agpl-3.0.html>.
+#
+
+
+import logging
+
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.rest import admin
+from synapse.rest.client import login, room
+from synapse.server import HomeServer
+from synapse.util import Clock
+
+from tests.test_utils.event_injection import create_event
+from tests.unittest import HomeserverTestCase
+
+logger = logging.getLogger(__name__)
+
+
+class StateDeletionStoreTestCase(HomeserverTestCase):
+ """Tests for the StateDeletionStore."""
+
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.store = hs.get_datastores().main
+ self.state_store = hs.get_datastores().state
+ self.state_deletion_store = hs.get_datastores().state_deletion
+ self.purge_events = hs.get_storage_controllers().purge_events
+
+ # We want to disable the automatic deletion of state groups in the
+ # background, so we can do controlled tests.
+ self.purge_events._delete_state_loop_call.stop()
+
+ self.user_id = self.register_user("test", "password")
+ tok = self.login("test", "password")
+ self.room_id = self.helper.create_room_as(self.user_id, tok=tok)
+
+ def check_if_can_be_deleted(self, state_group: int) -> bool:
+ """Check if the state group is pending deletion."""
+
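+ # Fetch the pending-deletion sequence number for the group (absent from
+ # the map if the group is not marked for deletion at all).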
+ state_group_to_sequence_number = self.get_success(
+ self.state_deletion_store.get_pending_deletions([state_group])
+ )
+
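+ # Ask the store which of those pending groups have waited long enough to
+ # be eligible for actual deletion.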
+ can_be_deleted = self.get_success(
+ self.state_deletion_store.db_pool.runInteraction(
+ "test_existing_pending_deletion_is_cleared",
+ self.state_deletion_store.get_state_groups_ready_for_potential_deletion_txn,
+ state_group_to_sequence_number,
+ )
+ )
+
+ return state_group in can_be_deleted
+
+ def test_no_deletion(self) -> None:
+ """Test that calling persisting_state_group_references is fine if
+ nothing is pending deletion"""
+ event, context = self.get_success(
+ create_event(
+ self.hs,
+ room_id=self.room_id,
+ type="m.test",
+ sender=self.user_id,
+ )
+ )
+
+ ctx_mgr = self.state_deletion_store.persisting_state_group_references(
+ [(event, context)]
+ )
+
+ self.get_success(ctx_mgr.__aenter__())
+ self.get_success(ctx_mgr.__aexit__(None, None, None))
+
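+ # A note on the pattern above: these tests drive the async context manager by
+ # hand via `__aenter__`/`__aexit__` so that each step can run under
+ # `get_success`. Outside of tests, the equivalent usage would simply be an
+ # `async with` block; a minimal sketch using the names from this file:
+ #
+ #     async with self.state_deletion_store.persisting_state_group_references(
+ #         [(event, context)]
+ #     ):
+ #         ...  # persist the event while its state groups are protected
+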
+ def test_no_deletion_error(self) -> None:
+ """Test that calling persisting_state_group_references is fine if
+ nothing is pending deletion, but an error occurs."""
+
+ event, context = self.get_success(
+ create_event(
+ self.hs,
+ room_id=self.room_id,
+ type="m.test",
+ sender=self.user_id,
+ )
+ )
+
+ ctx_mgr = self.state_deletion_store.persisting_state_group_references(
+ [(event, context)]
+ )
+
+ self.get_success(ctx_mgr.__aenter__())
+ self.get_success(ctx_mgr.__aexit__(Exception, Exception("test"), None))
+
+ def test_existing_pending_deletion_is_cleared(self) -> None:
+ """Test that the pending deletion flag gets cleared when the state group
+ gets persisted."""
+
+ event, context = self.get_success(
+ create_event(
+ self.hs,
+ room_id=self.room_id,
+ type="m.test",
+ state_key="",
+ sender=self.user_id,
+ )
+ )
+ assert context.state_group is not None
+
+ # Mark a state group that we're referencing as pending deletion.
+ self.get_success(
+ self.state_deletion_store.mark_state_groups_as_pending_deletion(
+ [context.state_group]
+ )
+ )
+
+ ctx_mgr = self.state_deletion_store.persisting_state_group_references(
+ [(event, context)]
+ )
+
+ self.get_success(ctx_mgr.__aenter__())
+ self.get_success(ctx_mgr.__aexit__(None, None, None))
+
+ # The pending deletion flag should be cleared
+ pending_deletion = self.get_success(
+ self.state_deletion_store.db_pool.simple_select_one_onecol(
+ table="state_groups_pending_deletion",
+ keyvalues={"state_group": context.state_group},
+ retcol="1",
+ allow_none=True,
+ desc="test_existing_pending_deletion_is_cleared",
+ )
+ )
+ self.assertIsNone(pending_deletion)
+
+ def test_pending_deletion_is_cleared_during_persist(self) -> None:
+ """Test that the pending deletion flag is cleared when a state group
+ gets marked for deletion during persistence"""
+
+ event, context = self.get_success(
+ create_event(
+ self.hs,
+ room_id=self.room_id,
+ type="m.test",
+ state_key="",
+ sender=self.user_id,
+ )
+ )
+ assert context.state_group is not None
+
+ ctx_mgr = self.state_deletion_store.persisting_state_group_references(
+ [(event, context)]
+ )
+ self.get_success(ctx_mgr.__aenter__())
+
+ # Mark the state group that we're referencing as pending deletion,
+ # *after* we have started persisting.
+ self.get_success(
+ self.state_deletion_store.mark_state_groups_as_pending_deletion(
+ [context.state_group]
+ )
+ )
+
+ self.get_success(ctx_mgr.__aexit__(None, None, None))
+
+ # The pending deletion flag should be cleared
+ pending_deletion = self.get_success(
+ self.state_deletion_store.db_pool.simple_select_one_onecol(
+ table="state_groups_pending_deletion",
+ keyvalues={"state_group": context.state_group},
+ retcol="1",
+ allow_none=True,
+ desc="test_existing_pending_deletion_is_cleared",
+ )
+ )
+ self.assertIsNone(pending_deletion)
+
+ def test_deletion_check(self) -> None:
+ """Test that the `get_state_groups_that_can_be_purged_txn` check is
+ correct during different points of the lifecycle of persisting an
+ event."""
+ event, context = self.get_success(
+ create_event(
+ self.hs,
+ room_id=self.room_id,
+ type="m.test",
+ state_key="",
+ sender=self.user_id,
+ )
+ )
+ assert context.state_group is not None
+
+ self.get_success(
+ self.state_deletion_store.mark_state_groups_as_pending_deletion(
+ [context.state_group]
+ )
+ )
+
+ # We shouldn't be able to delete the state group as not enough time has passed
+ can_be_deleted = self.check_if_can_be_deleted(context.state_group)
+ self.assertFalse(can_be_deleted)
+
+ # After enough time we can delete the state group
+ self.reactor.advance(
+ 1 + self.state_deletion_store.DELAY_BEFORE_DELETION_MS / 1000
+ )
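+ # (`reactor.advance` takes seconds while `DELAY_BEFORE_DELETION_MS` is in
+ # milliseconds, hence the division; the extra second is just headroom.)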
+ can_be_deleted = self.check_if_can_be_deleted(context.state_group)
+ self.assertTrue(can_be_deleted)
+
+ ctx_mgr = self.state_deletion_store.persisting_state_group_references(
+ [(event, context)]
+ )
+ self.get_success(ctx_mgr.__aenter__())
+
+ # But once we start persisting we can't delete the state group
+ can_be_deleted = self.check_if_can_be_deleted(context.state_group)
+ self.assertFalse(can_be_deleted)
+
+ self.get_success(ctx_mgr.__aexit__(None, None, None))
+
+ # The pending deletion flag should remain cleared after persistence has
+ # finished.
+ can_be_deleted = self.check_if_can_be_deleted(context.state_group)
+ self.assertFalse(can_be_deleted)
+
+ def test_deletion_error_during_persistence(self) -> None:
+ """Test that state groups remain marked as pending deletion if persisting
+ the event fails."""
+
+ event, context = self.get_success(
+ create_event(
+ self.hs,
+ room_id=self.room_id,
+ type="m.test",
+ state_key="",
+ sender=self.user_id,
+ )
+ )
+ assert context.state_group is not None
+
+ # Mark a state group that we're referencing as pending deletion.
+ self.get_success(
+ self.state_deletion_store.mark_state_groups_as_pending_deletion(
+ [context.state_group]
+ )
+ )
+
+ ctx_mgr = self.state_deletion_store.persisting_state_group_references(
+ [(event, context)]
+ )
+
+ self.get_success(ctx_mgr.__aenter__())
+ self.get_success(ctx_mgr.__aexit__(Exception, Exception("test"), None))
+
+ # We should be able to delete the state group after a certain amount of
+ # time
+ self.reactor.advance(
+ 1 + self.state_deletion_store.DELAY_BEFORE_DELETION_MS / 1000
+ )
+ can_be_deleted = self.check_if_can_be_deleted(context.state_group)
+ self.assertTrue(can_be_deleted)
+
+ def test_race_between_check_and_insert(self) -> None:
+ """Check that we correctly handle the race where we go to delete a
+ state group, check that it is unreferenced, and then it becomes
+ referenced just before we delete it."""
+
+ event, context = self.get_success(
+ create_event(
+ self.hs,
+ room_id=self.room_id,
+ type="m.test",
+ state_key="",
+ sender=self.user_id,
+ )
+ )
+ assert context.state_group is not None
+
+ # Mark a state group that we're referencing as pending deletion.
+ self.get_success(
+ self.state_deletion_store.mark_state_groups_as_pending_deletion(
+ [context.state_group]
+ )
+ )
+
+ # Advance time enough so we can delete the state group
+ self.reactor.advance(
+ 1 + self.state_deletion_store.DELAY_BEFORE_DELETION_MS / 1000
+ )
+
+ # Check that we'd be able to delete this state group.
+ state_group_to_sequence_number = self.get_success(
+ self.state_deletion_store.get_pending_deletions([context.state_group])
+ )
+
+ can_be_deleted = self.get_success(
+ self.state_deletion_store.db_pool.runInteraction(
+ "test_existing_pending_deletion_is_cleared",
+ self.state_deletion_store.get_state_groups_ready_for_potential_deletion_txn,
+ state_group_to_sequence_number,
+ )
+ )
+ self.assertCountEqual(can_be_deleted, [context.state_group])
+
+ # ... in the real world we'd check that the state group isn't referenced here ...
+
+ # Now we persist the event to reference the state group, *after* we
+ # check that the state group wasn't referenced
+ ctx_mgr = self.state_deletion_store.persisting_state_group_references(
+ [(event, context)]
+ )
+
+ self.get_success(ctx_mgr.__aenter__())
+ self.get_success(ctx_mgr.__aexit__(Exception, Exception("test"), None))
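+ # (Exiting with an exception keeps the state group in the pending-deletion
+ # table, unlike the success path which clears the flag; the deletion check
+ # below must still refuse using the stale sequence numbers captured above.)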
+
+ # We simulate a pause (required to hit the race)
+ self.reactor.advance(
+ 1 + self.state_deletion_store.DELAY_BEFORE_DELETION_MS / 1000
+ )
+
+ # We should no longer be able to delete the state group, without having
+ # to recheck whether it's referenced.
+ can_be_deleted = self.get_success(
+ self.state_deletion_store.db_pool.runInteraction(
+ "test_existing_pending_deletion_is_cleared",
+ self.state_deletion_store.get_state_groups_ready_for_potential_deletion_txn,
+ state_group_to_sequence_number,
+ )
+ )
+ self.assertCountEqual(can_be_deleted, [])
+
+ def test_remove_ancestors_from_can_delete(self) -> None:
+ """Test that if a state group is not ready to be deleted, we also don't
+ delete anything that is referenced by it"""
+
+ event, context = self.get_success(
+ create_event(
+ self.hs,
+ room_id=self.room_id,
+ type="m.test",
+ state_key="",
+ sender=self.user_id,
+ )
+ )
+ assert context.state_group is not None
+
+ # Create a new state group that references the one from the event
+ new_state_group = self.get_success(
+ self.state_store.store_state_group(
+ event.event_id,
+ event.room_id,
+ prev_group=context.state_group,
+ delta_ids={},
+ current_state_ids=None,
+ )
+ )
+
+ # Mark them both as pending deletion
+ self.get_success(
+ self.state_deletion_store.mark_state_groups_as_pending_deletion(
+ [context.state_group, new_state_group]
+ )
+ )
+
+ # Advance time enough that both state groups become ready for deletion.
+ self.reactor.advance(
+ 1 + self.state_deletion_store.DELAY_BEFORE_DELETION_MS / 1000
+ )
+
+ # We can now delete both state groups
+ self.assertTrue(self.check_if_can_be_deleted(context.state_group))
+ self.assertTrue(self.check_if_can_be_deleted(new_state_group))
+
+ # Use the new_state_group to bump its deletion time
+ self.get_success(
+ self.state_store.store_state_group(
+ event.event_id,
+ event.room_id,
+ prev_group=new_state_group,
+ delta_ids={},
+ current_state_ids=None,
+ )
+ )
+
+ # We should now not be able to delete either of the state groups.
+ state_group_to_sequence_number = self.get_success(
+ self.state_deletion_store.get_pending_deletions(
+ [context.state_group, new_state_group]
+ )
+ )
+
+ # Neither state group should be ready for deletion: the new state group was
+ # just used again, and the other is referenced by it.
+ can_be_deleted = self.get_success(
+ self.state_deletion_store.db_pool.runInteraction(
+ "test_existing_pending_deletion_is_cleared",
+ self.state_deletion_store.get_state_groups_ready_for_potential_deletion_txn,
+ state_group_to_sequence_number,
+ )
+ )
+ self.assertCountEqual(can_be_deleted, [])
+
+ def test_newly_referenced_state_group_gets_removed_from_pending(self) -> None:
+ """Check that if a state group marked for deletion becomes referenced
+ (without being removed from pending deletion table), it gets removed
+ from pending deletion table."""
+
+ event, context = self.get_success(
+ create_event(
+ self.hs,
+ room_id=self.room_id,
+ type="m.test",
+ state_key="",
+ sender=self.user_id,
+ )
+ )
+ assert context.state_group is not None
+
+ # Mark a state group that we're referencing as pending deletion.
+ self.get_success(
+ self.state_deletion_store.mark_state_groups_as_pending_deletion(
+ [context.state_group]
+ )
+ )
+
+ # Advance time enough that the state group becomes ready for deletion.
+ self.reactor.advance(
+ 1 + self.state_deletion_store.DELAY_BEFORE_DELETION_MS / 1000
+ )
+
+ # Manually insert into the table to mimic the state group getting used.
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ table="event_to_state_groups",
+ values={"state_group": context.state_group, "event_id": event.event_id},
+ desc="test_newly_referenced_state_group_gets_removed_from_pending",
+ )
+ )
+
+ # Manually run the background task to delete pending state groups.
+ self.get_success(self.purge_events._delete_state_groups_loop())
+
+ # The pending deletion flag should be cleared...
+ pending_deletion = self.get_success(
+ self.state_deletion_store.db_pool.simple_select_one_onecol(
+ table="state_groups_pending_deletion",
+ keyvalues={"state_group": context.state_group},
+ retcol="1",
+ allow_none=True,
+ desc="test_newly_referenced_state_group_gets_removed_from_pending",
+ )
+ )
+ self.assertIsNone(pending_deletion)
+
+ # .. but the state should not have been deleted.
+ state = self.get_success(
+ self.state_store._get_state_for_groups([context.state_group])
+ )
+ self.assertGreater(len(state[context.state_group]), 0)
diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py
index 7b7590da76..0f58dc8a0a 100644
--- a/tests/storage/test_stream.py
+++ b/tests/storage/test_stream.py
@@ -27,7 +27,13 @@ from immutabledict import immutabledict
from twisted.test.proto_helpers import MemoryReactor
-from synapse.api.constants import Direction, EventTypes, Membership, RelationTypes
+from synapse.api.constants import (
+ Direction,
+ EventTypes,
+ JoinRules,
+ Membership,
+ RelationTypes,
+)
from synapse.api.filtering import Filter
from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.events import FrozenEventV3
@@ -147,7 +153,7 @@ class PaginationTestCase(HomeserverTestCase):
def _filter_messages(self, filter: JsonDict) -> List[str]:
"""Make a request to /messages with a filter, returns the chunk of events."""
- events, next_key = self.get_success(
+ events, next_key, _ = self.get_success(
self.hs.get_datastores().main.paginate_room_events_by_topological_ordering(
room_id=self.room_id,
from_key=self.from_token.room_key,
@@ -1154,7 +1160,7 @@ class GetCurrentStateDeltaMembershipChangesForUserTestCase(HomeserverTestCase):
room_id=room_id1,
event_id=None,
event_pos=dummy_state_pos,
- membership="leave",
+ membership=Membership.LEAVE,
sender=None, # user1_id,
prev_event_id=join_response1["event_id"],
prev_event_pos=join_pos1,
@@ -1164,6 +1170,75 @@ class GetCurrentStateDeltaMembershipChangesForUserTestCase(HomeserverTestCase):
],
)
+ def test_state_reset2(self) -> None:
+ """
+ Test a state reset scenario where the user gets removed from the room (when
+ there is no corresponding leave event)
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ room_id1 = self.helper.create_room_as(user2_id, is_public=True, tok=user2_tok)
+
+ event_response = self.helper.send(room_id1, "test", tok=user2_tok)
+ event_id = event_response["event_id"]
+
+ user1_join_response = self.helper.join(room_id1, user1_id, tok=user1_tok)
+ user1_join_pos = self.get_success(
+ self.store.get_position_for_event(user1_join_response["event_id"])
+ )
+
+ before_reset_token = self.event_sources.get_current_token()
+
+ # Trigger a state reset
+ join_rule_event, join_rule_context = self.get_success(
+ create_event(
+ self.hs,
+ prev_event_ids=[event_id],
+ type=EventTypes.JoinRules,
+ state_key="",
+ content={"join_rule": JoinRules.INVITE},
+ sender=user2_id,
+ room_id=room_id1,
+ room_version=self.get_success(self.store.get_room_version_id(room_id1)),
+ )
+ )
+ _, join_rule_event_pos, _ = self.get_success(
+ self.persistence.persist_event(join_rule_event, join_rule_context)
+ )
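+ # Because the join-rules event's `prev_event_ids` point at an event from
+ # *before* user1 joined, the resulting forward extremity's state does not
+ # include user1's membership, so the resolved current state drops user1 from
+ # the room without any leave event, i.e. a state reset.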
+
+ after_reset_token = self.event_sources.get_current_token()
+
+ membership_changes = self.get_success(
+ self.store.get_current_state_delta_membership_changes_for_user(
+ user1_id,
+ from_key=before_reset_token.room_key,
+ to_key=after_reset_token.room_key,
+ )
+ )
+
+ # Let the whole diff show on failure
+ self.maxDiff = None
+ self.assertEqual(
+ membership_changes,
+ [
+ CurrentStateDeltaMembership(
+ room_id=room_id1,
+ event_id=None,
+ # The position where the state reset happened
+ event_pos=join_rule_event_pos,
+ membership=Membership.LEAVE,
+ sender=None,
+ prev_event_id=user1_join_response["event_id"],
+ prev_event_pos=user1_join_pos,
+ prev_membership="join",
+ prev_sender=user1_id,
+ ),
+ ],
+ )
+
def test_excluded_room_ids(self) -> None:
"""
Test that the `excluded_room_ids` option excludes changes from the specified rooms.
@@ -1384,20 +1459,25 @@ class GetCurrentStateDeltaMembershipChangesForUserFederationTestCase(
)
)
- with patch.object(
- self.room_member_handler.federation_handler.federation_client,
- "make_membership_event",
- mock_make_membership_event,
- ), patch.object(
- self.room_member_handler.federation_handler.federation_client,
- "send_join",
- mock_send_join,
- ), patch(
- "synapse.event_auth._is_membership_change_allowed",
- return_value=None,
- ), patch(
- "synapse.handlers.federation_event.check_state_dependent_auth_rules",
- return_value=None,
+ with (
+ patch.object(
+ self.room_member_handler.federation_handler.federation_client,
+ "make_membership_event",
+ mock_make_membership_event,
+ ),
+ patch.object(
+ self.room_member_handler.federation_handler.federation_client,
+ "send_join",
+ mock_send_join,
+ ),
+ patch(
+ "synapse.event_auth._is_membership_change_allowed",
+ return_value=None,
+ ),
+ patch(
+ "synapse.handlers.federation_event.check_state_dependent_auth_rules",
+ return_value=None,
+ ),
):
self.get_success(
self.room_member_handler.update_membership(
diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py
index 6d1ae4c8d7..f12402f5f2 100644
--- a/tests/test_event_auth.py
+++ b/tests/test_event_auth.py
@@ -292,12 +292,14 @@ class EventAuthTestCase(unittest.TestCase):
]
# pleb should not be able to send state
- self.assertRaises(
- AuthError,
- event_auth.check_state_dependent_auth_rules,
- _random_state_event(RoomVersions.V1, pleb),
- auth_events,
- ),
+ (
+ self.assertRaises(
+ AuthError,
+ event_auth.check_state_dependent_auth_rules,
+ _random_state_event(RoomVersions.V1, pleb),
+ auth_events,
+ ),
+ )
# king should be able to send state
event_auth.check_state_dependent_auth_rules(
diff --git a/tests/test_federation.py b/tests/test_federation.py
deleted file mode 100644
index 4e9adc0625..0000000000
--- a/tests/test_federation.py
+++ /dev/null
@@ -1,376 +0,0 @@
-#
-# This file is licensed under the Affero General Public License (AGPL) version 3.
-#
-# Copyright 2020 The Matrix.org Foundation C.I.C.
-# Copyright (C) 2023 New Vector, Ltd
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Affero General Public License as
-# published by the Free Software Foundation, either version 3 of the
-# License, or (at your option) any later version.
-#
-# See the GNU Affero General Public License for more details:
-# <https://www.gnu.org/licenses/agpl-3.0.html>.
-#
-# Originally licensed under the Apache License, Version 2.0:
-# <http://www.apache.org/licenses/LICENSE-2.0>.
-#
-# [This file includes modifications made by New Vector Limited]
-#
-#
-
-from typing import Collection, List, Optional, Union
-from unittest.mock import AsyncMock, Mock
-
-from twisted.test.proto_helpers import MemoryReactor
-
-from synapse.api.errors import FederationError
-from synapse.api.room_versions import RoomVersion, RoomVersions
-from synapse.events import EventBase, make_event_from_dict
-from synapse.events.snapshot import EventContext
-from synapse.federation.federation_base import event_from_pdu_json
-from synapse.handlers.device import DeviceListUpdater
-from synapse.http.types import QueryParams
-from synapse.logging.context import LoggingContext
-from synapse.server import HomeServer
-from synapse.types import JsonDict, UserID, create_requester
-from synapse.util import Clock
-from synapse.util.retryutils import NotRetryingDestination
-
-from tests import unittest
-
-
-class MessageAcceptTests(unittest.HomeserverTestCase):
- def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
- self.http_client = Mock()
- return self.setup_test_homeserver(federation_http_client=self.http_client)
-
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- user_id = UserID("us", "test")
- our_user = create_requester(user_id)
- room_creator = self.hs.get_room_creation_handler()
- self.room_id = self.get_success(
- room_creator.create_room(
- our_user, room_creator._presets_dict["public_chat"], ratelimit=False
- )
- )[0]
-
- self.store = self.hs.get_datastores().main
-
- # Figure out what the most recent event is
- most_recent = next(
- iter(
- self.get_success(
- self.hs.get_datastores().main.get_latest_event_ids_in_room(
- self.room_id
- )
- )
- )
- )
-
- join_event = make_event_from_dict(
- {
- "room_id": self.room_id,
- "sender": "@baduser:test.serv",
- "state_key": "@baduser:test.serv",
- "event_id": "$join:test.serv",
- "depth": 1000,
- "origin_server_ts": 1,
- "type": "m.room.member",
- "origin": "test.servx",
- "content": {"membership": "join"},
- "auth_events": [],
- "prev_state": [(most_recent, {})],
- "prev_events": [(most_recent, {})],
- }
- )
-
- self.handler = self.hs.get_federation_handler()
- federation_event_handler = self.hs.get_federation_event_handler()
-
- async def _check_event_auth(
- origin: Optional[str], event: EventBase, context: EventContext
- ) -> None:
- pass
-
- federation_event_handler._check_event_auth = _check_event_auth # type: ignore[method-assign]
- self.client = self.hs.get_federation_client()
-
- async def _check_sigs_and_hash_for_pulled_events_and_fetch(
- dest: str, pdus: Collection[EventBase], room_version: RoomVersion
- ) -> List[EventBase]:
- return list(pdus)
-
- self.client._check_sigs_and_hash_for_pulled_events_and_fetch = _check_sigs_and_hash_for_pulled_events_and_fetch # type: ignore[assignment]
-
- # Send the join, it should return None (which is not an error)
- self.assertEqual(
- self.get_success(
- federation_event_handler.on_receive_pdu("test.serv", join_event)
- ),
- None,
- )
-
- # Make sure we actually joined the room
- self.assertEqual(
- self.get_success(self.store.get_latest_event_ids_in_room(self.room_id)),
- {"$join:test.serv"},
- )
-
- def test_cant_hide_direct_ancestors(self) -> None:
- """
- If you send a message, you must be able to provide the direct
- prev_events that said event references.
- """
-
- async def post_json(
- destination: str,
- path: str,
- data: Optional[JsonDict] = None,
- long_retries: bool = False,
- timeout: Optional[int] = None,
- ignore_backoff: bool = False,
- args: Optional[QueryParams] = None,
- ) -> Union[JsonDict, list]:
- # If it asks us for new missing events, give them NOTHING
- if path.startswith("/_matrix/federation/v1/get_missing_events/"):
- return {"events": []}
- return {}
-
- self.http_client.post_json = post_json
-
- # Figure out what the most recent event is
- most_recent = next(
- iter(
- self.get_success(self.store.get_latest_event_ids_in_room(self.room_id))
- )
- )
-
- # Now lie about an event
- lying_event = make_event_from_dict(
- {
- "room_id": self.room_id,
- "sender": "@baduser:test.serv",
- "event_id": "one:test.serv",
- "depth": 1000,
- "origin_server_ts": 1,
- "type": "m.room.message",
- "origin": "test.serv",
- "content": {"body": "hewwo?"},
- "auth_events": [],
- "prev_events": [("two:test.serv", {}), (most_recent, {})],
- }
- )
-
- federation_event_handler = self.hs.get_federation_event_handler()
- with LoggingContext("test-context"):
- failure = self.get_failure(
- federation_event_handler.on_receive_pdu("test.serv", lying_event),
- FederationError,
- )
-
- # on_receive_pdu should throw an error
- self.assertEqual(
- failure.value.args[0],
- (
- "ERROR 403: Your server isn't divulging details about prev_events "
- "referenced in this event."
- ),
- )
-
- # Make sure the invalid event isn't there
- extrem = self.get_success(self.store.get_latest_event_ids_in_room(self.room_id))
- self.assertEqual(extrem, {"$join:test.serv"})
-
- def test_retry_device_list_resync(self) -> None:
- """Tests that device lists are marked as stale if they couldn't be synced, and
- that stale device lists are retried periodically.
- """
- remote_user_id = "@john:test_remote"
- remote_origin = "test_remote"
-
- # Track the number of attempts to resync the user's device list.
- self.resync_attempts = 0
-
- # When this function is called, increment the number of resync attempts (only if
- # we're querying devices for the right user ID), then raise a
- # NotRetryingDestination error to fail the resync gracefully.
- def query_user_devices(
- destination: str, user_id: str, timeout: int = 30000
- ) -> JsonDict:
- if user_id == remote_user_id:
- self.resync_attempts += 1
-
- raise NotRetryingDestination(0, 0, destination)
-
- # Register the mock on the federation client.
- federation_client = self.hs.get_federation_client()
- federation_client.query_user_devices = Mock(side_effect=query_user_devices) # type: ignore[method-assign]
-
- # Register a mock on the store so that the incoming update doesn't fail because
- # we don't share a room with the user.
- store = self.hs.get_datastores().main
- store.get_rooms_for_user = AsyncMock(return_value=["!someroom:test"])
-
- # Manually inject a fake device list update. We need this update to include at
- # least one prev_id so that the user's device list will need to be retried.
- device_list_updater = self.hs.get_device_handler().device_list_updater
- assert isinstance(device_list_updater, DeviceListUpdater)
- self.get_success(
- device_list_updater.incoming_device_list_update(
- origin=remote_origin,
- edu_content={
- "deleted": False,
- "device_display_name": "Mobile",
- "device_id": "QBUAZIFURK",
- "prev_id": [5],
- "stream_id": 6,
- "user_id": remote_user_id,
- },
- )
- )
-
- # Check that there was one resync attempt.
- self.assertEqual(self.resync_attempts, 1)
-
- # Check that the resync attempt failed and caused the user's device list to be
- # marked as stale.
- need_resync = self.get_success(
- store.get_user_ids_requiring_device_list_resync()
- )
- self.assertIn(remote_user_id, need_resync)
-
- # Check that waiting for 30 seconds caused Synapse to retry resyncing the device
- # list.
- self.reactor.advance(30)
- self.assertEqual(self.resync_attempts, 2)
-
- def test_cross_signing_keys_retry(self) -> None:
- """Tests that resyncing a device list correctly processes cross-signing keys from
- the remote server.
- """
- remote_user_id = "@john:test_remote"
- remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY"
- remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
-
- # Register mock device list retrieval on the federation client.
- federation_client = self.hs.get_federation_client()
- federation_client.query_user_devices = AsyncMock( # type: ignore[method-assign]
- return_value={
- "user_id": remote_user_id,
- "stream_id": 1,
- "devices": [],
- "master_key": {
- "user_id": remote_user_id,
- "usage": ["master"],
- "keys": {"ed25519:" + remote_master_key: remote_master_key},
- },
- "self_signing_key": {
- "user_id": remote_user_id,
- "usage": ["self_signing"],
- "keys": {
- "ed25519:" + remote_self_signing_key: remote_self_signing_key
- },
- },
- }
- )
-
- # Resync the device list.
- device_handler = self.hs.get_device_handler()
- self.get_success(
- device_handler.device_list_updater.multi_user_device_resync(
- [remote_user_id]
- ),
- )
-
- # Retrieve the cross-signing keys for this user.
- keys = self.get_success(
- self.store.get_e2e_cross_signing_keys_bulk(user_ids=[remote_user_id]),
- )
- self.assertIn(remote_user_id, keys)
- key = keys[remote_user_id]
- assert key is not None
-
- # Check that the master key is the one returned by the mock.
- master_key = key["master"]
- self.assertEqual(len(master_key["keys"]), 1)
- self.assertTrue("ed25519:" + remote_master_key in master_key["keys"].keys())
- self.assertTrue(remote_master_key in master_key["keys"].values())
-
- # Check that the self-signing key is the one returned by the mock.
- self_signing_key = key["self_signing"]
- self.assertEqual(len(self_signing_key["keys"]), 1)
- self.assertTrue(
- "ed25519:" + remote_self_signing_key in self_signing_key["keys"].keys(),
- )
- self.assertTrue(remote_self_signing_key in self_signing_key["keys"].values())
-
-
-class StripUnsignedFromEventsTestCase(unittest.TestCase):
- def test_strip_unauthorized_unsigned_values(self) -> None:
- event1 = {
- "sender": "@baduser:test.serv",
- "state_key": "@baduser:test.serv",
- "event_id": "$event1:test.serv",
- "depth": 1000,
- "origin_server_ts": 1,
- "type": "m.room.member",
- "origin": "test.servx",
- "content": {"membership": "join"},
- "auth_events": [],
- "unsigned": {"malicious garbage": "hackz", "more warez": "more hackz"},
- }
- filtered_event = event_from_pdu_json(event1, RoomVersions.V1)
- # Make sure unauthorized fields are stripped from unsigned
- self.assertNotIn("more warez", filtered_event.unsigned)
-
- def test_strip_event_maintains_allowed_fields(self) -> None:
- event2 = {
- "sender": "@baduser:test.serv",
- "state_key": "@baduser:test.serv",
- "event_id": "$event2:test.serv",
- "depth": 1000,
- "origin_server_ts": 1,
- "type": "m.room.member",
- "origin": "test.servx",
- "auth_events": [],
- "content": {"membership": "join"},
- "unsigned": {
- "malicious garbage": "hackz",
- "more warez": "more hackz",
- "age": 14,
- "invite_room_state": [],
- },
- }
-
- filtered_event2 = event_from_pdu_json(event2, RoomVersions.V1)
- self.assertIn("age", filtered_event2.unsigned)
- self.assertEqual(14, filtered_event2.unsigned["age"])
- self.assertNotIn("more warez", filtered_event2.unsigned)
- # Invite_room_state is allowed in events of type m.room.member
- self.assertIn("invite_room_state", filtered_event2.unsigned)
- self.assertEqual([], filtered_event2.unsigned["invite_room_state"])
-
- def test_strip_event_removes_fields_based_on_event_type(self) -> None:
- event3 = {
- "sender": "@baduser:test.serv",
- "state_key": "@baduser:test.serv",
- "event_id": "$event3:test.serv",
- "depth": 1000,
- "origin_server_ts": 1,
- "type": "m.room.power_levels",
- "origin": "test.servx",
- "content": {},
- "auth_events": [],
- "unsigned": {
- "malicious garbage": "hackz",
- "more warez": "more hackz",
- "age": 14,
- "invite_room_state": [],
- },
- }
- filtered_event3 = event_from_pdu_json(event3, RoomVersions.V1)
- self.assertIn("age", filtered_event3.unsigned)
- # Invite_room_state field is only permitted in event type m.room.member
- self.assertNotIn("invite_room_state", filtered_event3.unsigned)
- self.assertNotIn("more warez", filtered_event3.unsigned)
diff --git a/tests/test_server.py b/tests/test_server.py
index 9ff2589497..9cb6766b5f 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -233,9 +233,7 @@ class OptionsResourceTests(unittest.TestCase):
self.resource = OptionsResource()
self.resource.putChild(b"res", DummyResource())
- def _make_request(
- self, method: bytes, path: bytes, experimental_cors_msc3886: bool = False
- ) -> FakeChannel:
+ def _make_request(self, method: bytes, path: bytes) -> FakeChannel:
"""Create a request from the method/path and return a channel with the response."""
# Create a site and query for the resource.
site = SynapseSite(
@@ -246,7 +244,6 @@ class OptionsResourceTests(unittest.TestCase):
{
"type": "http",
"port": 0,
- "experimental_cors_msc3886": experimental_cors_msc3886,
},
),
self.resource,
@@ -283,32 +280,6 @@ class OptionsResourceTests(unittest.TestCase):
[b"Synapse-Trace-Id, Server"],
)
- def _check_cors_msc3886_headers(self, channel: FakeChannel) -> None:
- # Ensure the correct CORS headers have been added
- # as per https://github.com/matrix-org/matrix-spec-proposals/blob/hughns/simple-rendezvous-capability/proposals/3886-simple-rendezvous-capability.md#cors
- self.assertEqual(
- channel.headers.getRawHeaders(b"Access-Control-Allow-Origin"),
- [b"*"],
- "has correct CORS Origin header",
- )
- self.assertEqual(
- channel.headers.getRawHeaders(b"Access-Control-Allow-Methods"),
- [b"GET, HEAD, POST, PUT, DELETE, OPTIONS"], # HEAD isn't in the spec
- "has correct CORS Methods header",
- )
- self.assertEqual(
- channel.headers.getRawHeaders(b"Access-Control-Allow-Headers"),
- [
- b"X-Requested-With, Content-Type, Authorization, Date, If-Match, If-None-Match"
- ],
- "has correct CORS Headers header",
- )
- self.assertEqual(
- channel.headers.getRawHeaders(b"Access-Control-Expose-Headers"),
- [b"ETag, Location, X-Max-Bytes"],
- "has correct CORS Expose Headers header",
- )
-
def test_unknown_options_request(self) -> None:
"""An OPTIONS requests to an unknown URL still returns 204 No Content."""
channel = self._make_request(b"OPTIONS", b"/foo/")
@@ -325,16 +296,6 @@ class OptionsResourceTests(unittest.TestCase):
self._check_cors_standard_headers(channel)
- def test_known_options_request_msc3886(self) -> None:
- """An OPTIONS requests to an known URL still returns 204 No Content."""
- channel = self._make_request(
- b"OPTIONS", b"/res/", experimental_cors_msc3886=True
- )
- self.assertEqual(channel.code, 204)
- self.assertNotIn("body", channel.result)
-
- self._check_cors_msc3886_headers(channel)
-
def test_unknown_request(self) -> None:
"""A non-OPTIONS request to an unknown URL should 404."""
channel = self._make_request(b"GET", b"/foo/")
diff --git a/tests/test_state.py b/tests/test_state.py
index 311a590693..adb72b0730 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -31,7 +31,7 @@ from typing import (
Tuple,
cast,
)
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
from twisted.internet import defer
@@ -149,7 +149,7 @@ class _DummyStore:
async def get_partial_state_events(
self, event_ids: Collection[str]
) -> Dict[str, bool]:
- return {e: False for e in event_ids}
+ return dict.fromkeys(event_ids, False)
async def get_state_group_delta(
self, name: str
@@ -221,7 +221,16 @@ class Graph:
class StateTestCase(unittest.TestCase):
def setUp(self) -> None:
self.dummy_store = _DummyStore()
- storage_controllers = Mock(main=self.dummy_store, state=self.dummy_store)
+
+        # Add a dummy state deletion store that always returns that we have
+        # all the necessary state groups.
+ dummy_deletion_store = AsyncMock()
+ dummy_deletion_store.check_state_groups_and_bump_deletion.return_value = []
+
+ storage_controllers = Mock(
+ main=self.dummy_store,
+ state=self.dummy_store,
+ )
hs = Mock(
spec_set=[
"config",
@@ -241,7 +250,10 @@ class StateTestCase(unittest.TestCase):
)
clock = cast(Clock, MockClock())
hs.config = default_config("tesths", True)
- hs.get_datastores.return_value = Mock(main=self.dummy_store)
+ hs.get_datastores.return_value = Mock(
+ main=self.dummy_store,
+ state_deletion=dummy_deletion_store,
+ )
hs.get_state_handler.return_value = None
hs.get_clock.return_value = clock
hs.get_macaroon_generator.return_value = MacaroonGenerator(
diff --git a/tests/test_types.py b/tests/test_types.py
index 00adc65a5a..0c08bc8ecc 100644
--- a/tests/test_types.py
+++ b/tests/test_types.py
@@ -145,7 +145,9 @@ class MapUsernameTestCase(unittest.TestCase):
(MultiWriterStreamToken,),
(RoomStreamToken,),
],
- class_name_func=lambda cls, num, params_dict: f"{cls.__name__}_{params_dict['token_type'].__name__}",
+ class_name_func=lambda cls,
+ num,
+ params_dict: f"{cls.__name__}_{params_dict['token_type'].__name__}",
)
class MultiWriterTokenTestCase(unittest.HomeserverTestCase):
"""Tests for the different types of multi writer tokens."""
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index 4ab42a02b9..3e6fd03600 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -22,6 +22,8 @@
"""
Utilities for running the unit tests
"""
+
+import base64
import json
import sys
import warnings
@@ -137,3 +139,23 @@ SMALL_PNG = unhexlify(
b"0000001f15c4890000000a49444154789c63000100000500010d"
b"0a2db40000000049454e44ae426082"
)
+# The SHA256 hexdigest for the above bytes.
+SMALL_PNG_SHA256 = "ebf4f635a17d10d6eb46ba680b70142419aa3220f228001a036d311a22ee9d2a"
+
+# A small CMYK-encoded JPEG image used in some tests.
+#
+# Generated with:
+# img = PIL.Image.new('CMYK', (1, 1), (0, 0, 0, 0))
+# img.save('minimal_cmyk.jpg', 'JPEG')
+#
+# Resolution: 1x1, MIME type: image/jpeg, Extension: jpeg, Size: 4 KiB
+SMALL_CMYK_JPEG = base64.b64decode("""
+/9j/7gAOQWRvYmUAZAAAAAAA/9sAQwAIBgYHBgUIBwcHCQkICgwUDQwLCww
+ZEhMPFB0aHx4dGhwcICQuJyAiLCMcHCg3KSwwMTQ0NB8nOT04MjwuMzQy/8
+AAFAgAAQABBEMRAE0RAFkRAEsRAP/EAB8AAAEFAQEBAQEBAAAAAAAAAAABA
+gMEBQYHCAkKC//EALUQAAIBAwMCBAMFBQQEAAABfQECAwAEEQUSITFBBhNR
+YQcicRQygZGhCCNCscEVUtHwJDNicoIJChYXGBkaJSYnKCkqNDU2Nzg5OkN
+ERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6g4SFhoeIiYqSk5SVlp
+eYmZqio6Slpqeoqaqys7S1tre4ubrCw8TFxsfIycrS09TV1tfY2drh4uPk5
+ebn6Onq8fLz9PX29/j5+v/aAA4EQwBNAFkASwAAPwD3+vf69/r3+v/Z
+""")
diff --git a/tests/test_utils/logging_setup.py b/tests/test_utils/logging_setup.py
index dd40c338d6..d58222a9f6 100644
--- a/tests/test_utils/logging_setup.py
+++ b/tests/test_utils/logging_setup.py
@@ -48,7 +48,7 @@ def setup_logging() -> None:
# We exclude `%(asctime)s` from this format because the Twisted logger adds its own
# timestamp
- log_format = "%(name)s - %(lineno)d - " "%(levelname)s - %(request)s - %(message)s"
+ log_format = "%(name)s - %(lineno)d - %(levelname)s - %(request)s - %(message)s"
handler = ToTwistedHandler()
formatter = logging.Formatter(log_format)
diff --git a/tests/test_utils/oidc.py b/tests/test_utils/oidc.py
index 6c4be1c1f8..5bf5e5cb0c 100644
--- a/tests/test_utils/oidc.py
+++ b/tests/test_utils/oidc.py
@@ -20,7 +20,9 @@
#
+import base64
import json
+from hashlib import sha256
from typing import Any, ContextManager, Dict, List, Optional, Tuple
from unittest.mock import Mock, patch
from urllib.parse import parse_qs
@@ -154,10 +156,23 @@ class FakeOidcServer:
json_payload = json.dumps(payload)
return jws.serialize_compact(protected, json_payload, self._key).decode("utf-8")
- def generate_id_token(self, grant: FakeAuthorizationGrant) -> str:
+ def generate_id_token(
+ self, grant: FakeAuthorizationGrant, access_token: str
+ ) -> str:
+ # Generate a hash of the access token for the optional
+ # `at_hash` field in an ID Token.
+ #
+ # 3.1.3.6. ID Token, https://openid.net/specs/openid-connect-core-1_0.html#CodeIDToken
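+        # Per that spec, `at_hash` is the unpadded base64url encoding of the
+        # left-most half of the access token's hash.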
+ at_hash = (
+ base64.urlsafe_b64encode(sha256(access_token.encode("ascii")).digest()[:16])
+ .rstrip(b"=")
+ .decode("ascii")
+ )
+
now = int(self._clock.time())
id_token = {
**grant.userinfo,
+ "at_hash": at_hash,
"iss": self.issuer,
"aud": grant.client_id,
"iat": now,
@@ -243,7 +258,7 @@ class FakeOidcServer:
}
if "openid" in grant.scope:
- token["id_token"] = self.generate_id_token(grant)
+ token["id_token"] = self.generate_id_token(grant, access_token)
return dict(token)
diff --git a/tests/unittest.py b/tests/unittest.py
index 4aa7f56106..24077d79d6 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -40,6 +40,7 @@ from typing import (
Mapping,
NoReturn,
Optional,
+ Protocol,
Tuple,
Type,
TypeVar,
@@ -50,7 +51,7 @@ from unittest.mock import Mock, patch
import canonicaljson
import signedjson.key
import unpaddedbase64
-from typing_extensions import Concatenate, ParamSpec, Protocol
+from typing_extensions import Concatenate, ParamSpec
from twisted.internet.defer import Deferred, ensureDeferred
from twisted.python.failure import Failure
@@ -272,8 +273,8 @@ class TestCase(unittest.TestCase):
def assertIncludes(
self,
- actual_items: AbstractSet[str],
- expected_items: AbstractSet[str],
+ actual_items: AbstractSet[TV],
+ expected_items: AbstractSet[TV],
exact: bool = False,
message: Optional[str] = None,
) -> None:
@@ -457,7 +458,9 @@ class HomeserverTestCase(TestCase):
# Type ignore: mypy doesn't like us assigning to methods.
self.hs.get_auth().get_user_by_req = get_requester # type: ignore[method-assign]
self.hs.get_auth().get_user_by_access_token = get_requester # type: ignore[method-assign]
- self.hs.get_auth().get_access_token_from_request = Mock(return_value=token) # type: ignore[method-assign]
+ self.hs.get_auth().get_access_token_from_request = Mock( # type: ignore[method-assign]
+ return_value=token
+ )
if self.needs_threadpool:
self.reactor.threadpool = ThreadPool() # type: ignore[assignment]
@@ -779,7 +782,7 @@ class HomeserverTestCase(TestCase):
self,
username: str,
appservice_token: str,
- ) -> Tuple[str, str]:
+ ) -> Tuple[str, Optional[str]]:
"""Register an appservice user as an application service.
Requires the client-facing registration API be registered.
@@ -803,7 +806,7 @@ class HomeserverTestCase(TestCase):
access_token=appservice_token,
)
self.assertEqual(channel.code, 200, channel.json_body)
- return channel.json_body["user_id"], channel.json_body["device_id"]
+ return channel.json_body["user_id"], channel.json_body.get("device_id")
def login(
self,
diff --git a/tests/util/test_async_helpers.py b/tests/util/test_async_helpers.py
index d82822d00d..cfd2882410 100644
--- a/tests/util/test_async_helpers.py
+++ b/tests/util/test_async_helpers.py
@@ -18,7 +18,7 @@
#
#
import traceback
-from typing import Generator, List, NoReturn, Optional
+from typing import Any, Coroutine, Generator, List, NoReturn, Optional, Tuple, TypeVar
from parameterized import parameterized_class
@@ -39,6 +39,7 @@ from synapse.util.async_helpers import (
ObservableDeferred,
concurrently_execute,
delay_cancellation,
+ gather_optional_coroutines,
stop_cancellation,
timeout_deferred,
)
@@ -46,6 +47,8 @@ from synapse.util.async_helpers import (
from tests.server import get_clock
from tests.unittest import TestCase
+T = TypeVar("T")
+
class ObservableDeferredTest(TestCase):
def test_succeed(self) -> None:
@@ -317,12 +320,19 @@ class ConcurrentlyExecuteTest(TestCase):
await concurrently_execute(callback, [1], 2)
except _TestException as e:
tb = traceback.extract_tb(e.__traceback__)
- # we expect to see "caller", "concurrently_execute", "callback",
- # and some magic from inside ensureDeferred that happens when .fail
- # is called.
+
+ # Remove twisted internals from the stack, as we don't care
+ # about the precise details.
+ tb = traceback.StackSummary(
+ t for t in tb if "/twisted/" not in t.filename
+ )
+
+ # we expect to see "caller", "concurrently_execute" at the top of the stack
self.assertEqual(tb[0].name, "caller")
self.assertEqual(tb[1].name, "concurrently_execute")
- self.assertEqual(tb[-2].name, "callback")
+ # ... some stack frames from the implementation of `concurrently_execute` ...
+ # and at the bottom of the stack we expect to see "callback"
+ self.assertEqual(tb[-1].name, "callback")
else:
self.fail("No exception thrown")
@@ -588,3 +598,106 @@ class AwakenableSleeperTests(TestCase):
sleeper.wake("name")
self.assertTrue(d1.called)
self.assertTrue(d2.called)
+
+
+class GatherCoroutineTests(TestCase):
+ """Tests for `gather_optional_coroutines`"""
+
+ def make_coroutine(self) -> Tuple[Coroutine[Any, Any, T], "defer.Deferred[T]"]:
+ """Returns a coroutine and a deferred that it is waiting on to resolve"""
+
+ d: "defer.Deferred[T]" = defer.Deferred()
+
+ async def inner() -> T:
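+            # Follow the Synapse logcontext rules: drop to the sentinel
+            # context while we yield control to the reactor.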
+ with PreserveLoggingContext():
+ return await d
+
+ return inner(), d
+
+ def test_single(self) -> None:
+ "Test passing in a single coroutine works"
+
+ with LoggingContext("test_ctx") as text_ctx:
+ deferred: "defer.Deferred[None]"
+ coroutine, deferred = self.make_coroutine()
+
+ gather_deferred = defer.ensureDeferred(
+ gather_optional_coroutines(coroutine)
+ )
+
+ # We shouldn't have a result yet, and should be in the sentinel
+ # context.
+ self.assertNoResult(gather_deferred)
+ self.assertEqual(current_context(), SENTINEL_CONTEXT)
+
+ # Resolving the deferred will resolve the coroutine
+ deferred.callback(None)
+
+ # All coroutines have resolved, and so we should have the results
+ result = self.successResultOf(gather_deferred)
+ self.assertEqual(result, (None,))
+
+ # We should be back in the normal context.
+            self.assertEqual(current_context(), test_ctx)
+
+ def test_multiple_resolve(self) -> None:
+ "Test passing in multiple coroutine that all resolve works"
+
+ with LoggingContext("test_ctx") as test_ctx:
+ deferred1: "defer.Deferred[int]"
+ coroutine1, deferred1 = self.make_coroutine()
+ deferred2: "defer.Deferred[str]"
+ coroutine2, deferred2 = self.make_coroutine()
+
+ gather_deferred = defer.ensureDeferred(
+ gather_optional_coroutines(coroutine1, coroutine2)
+ )
+
+ # We shouldn't have a result yet, and should be in the sentinel
+ # context.
+ self.assertNoResult(gather_deferred)
+ self.assertEqual(current_context(), SENTINEL_CONTEXT)
+
+ # Even if we resolve one of the coroutines, we shouldn't have a result
+ # yet
+ deferred2.callback("test")
+ self.assertNoResult(gather_deferred)
+ self.assertEqual(current_context(), SENTINEL_CONTEXT)
+
+ deferred1.callback(1)
+
+ # All coroutines have resolved, and so we should have the results
+ result = self.successResultOf(gather_deferred)
+ self.assertEqual(result, (1, "test"))
+
+ # We should be back in the normal context.
+ self.assertEqual(current_context(), test_ctx)
+
+ def test_multiple_fail(self) -> None:
+ "Test passing in multiple coroutine where one fails does the right thing"
+
+ with LoggingContext("test_ctx") as test_ctx:
+ deferred1: "defer.Deferred[int]"
+ coroutine1, deferred1 = self.make_coroutine()
+ deferred2: "defer.Deferred[str]"
+ coroutine2, deferred2 = self.make_coroutine()
+
+ gather_deferred = defer.ensureDeferred(
+ gather_optional_coroutines(coroutine1, coroutine2)
+ )
+
+ # We shouldn't have a result yet, and should be in the sentinel
+ # context.
+ self.assertNoResult(gather_deferred)
+ self.assertEqual(current_context(), SENTINEL_CONTEXT)
+
+ # Throw an exception in one of the coroutines
+ exc = Exception("test")
+ deferred2.errback(exc)
+
+ # Expect the gather deferred to immediately fail
+ result_exc = self.failureResultOf(gather_deferred)
+ self.assertEqual(result_exc.value, exc)
+
+ # We should be back in the normal context.
+ self.assertEqual(current_context(), test_ctx)
diff --git a/tests/util/test_check_dependencies.py b/tests/util/test_check_dependencies.py
index 13a4e6ddaa..c052ba2b75 100644
--- a/tests/util/test_check_dependencies.py
+++ b/tests/util/test_check_dependencies.py
@@ -109,10 +109,13 @@ class TestDependencyChecker(TestCase):
def test_checks_ignore_dev_dependencies(self) -> None:
"""Both generic and per-extra checks should ignore dev dependencies."""
- with patch(
- "synapse.util.check_dependencies.metadata.requires",
- return_value=["dummypkg >= 1; extra == 'mypy'"],
- ), patch("synapse.util.check_dependencies.RUNTIME_EXTRAS", {"cool-extra"}):
+ with (
+ patch(
+ "synapse.util.check_dependencies.metadata.requires",
+ return_value=["dummypkg >= 1; extra == 'mypy'"],
+ ),
+ patch("synapse.util.check_dependencies.RUNTIME_EXTRAS", {"cool-extra"}),
+ ):
# We're testing that none of these calls raise.
with self.mock_installed_package(None):
check_requirements()
@@ -141,10 +144,13 @@ class TestDependencyChecker(TestCase):
def test_check_for_extra_dependencies(self) -> None:
"""Complain if a package required for an extra is missing or old."""
- with patch(
- "synapse.util.check_dependencies.metadata.requires",
- return_value=["dummypkg >= 1; extra == 'cool-extra'"],
- ), patch("synapse.util.check_dependencies.RUNTIME_EXTRAS", {"cool-extra"}):
+ with (
+ patch(
+ "synapse.util.check_dependencies.metadata.requires",
+ return_value=["dummypkg >= 1; extra == 'cool-extra'"],
+ ),
+ patch("synapse.util.check_dependencies.RUNTIME_EXTRAS", {"cool-extra"}),
+ ):
with self.mock_installed_package(None):
self.assertRaises(DependencyException, check_requirements, "cool-extra")
with self.mock_installed_package(old):
diff --git a/tests/util/test_linearizer.py b/tests/util/test_linearizer.py
index 7cbb1007da..7510657b85 100644
--- a/tests/util/test_linearizer.py
+++ b/tests/util/test_linearizer.py
@@ -19,9 +19,7 @@
#
#
-from typing import Hashable, Tuple
-
-from typing_extensions import Protocol
+from typing import Hashable, Protocol, Tuple
from twisted.internet import defer, reactor
from twisted.internet.base import ReactorBase
diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py
index af1199ef8a..9254bff79b 100644
--- a/tests/util/test_stream_change_cache.py
+++ b/tests/util/test_stream_change_cache.py
@@ -53,8 +53,8 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
# return True, whether it's a known entity or not.
self.assertTrue(cache.has_entity_changed("user@foo.com", 0))
self.assertTrue(cache.has_entity_changed("not@here.website", 0))
- self.assertTrue(cache.has_entity_changed("user@foo.com", 3))
- self.assertTrue(cache.has_entity_changed("not@here.website", 3))
+ self.assertTrue(cache.has_entity_changed("user@foo.com", 2))
+ self.assertTrue(cache.has_entity_changed("not@here.website", 2))
def test_entity_has_changed_pops_off_start(self) -> None:
"""
@@ -76,9 +76,11 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
self.assertTrue("user@foo.com" not in cache._entity_to_key)
self.assertEqual(
- cache.get_all_entities_changed(3).entities, ["user@elsewhere.org"]
+ cache.get_all_entities_changed(2).entities,
+ ["bar@baz.net", "user@elsewhere.org"],
)
- self.assertFalse(cache.get_all_entities_changed(2).hit)
+ self.assertFalse(cache.get_all_entities_changed(1).hit)
+ self.assertTrue(cache.get_all_entities_changed(2).hit)
# If we update an existing entity, it keeps the two existing entities
cache.entity_has_changed("bar@baz.net", 5)
@@ -89,7 +91,8 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
cache.get_all_entities_changed(3).entities,
["user@elsewhere.org", "bar@baz.net"],
)
- self.assertFalse(cache.get_all_entities_changed(2).hit)
+ self.assertFalse(cache.get_all_entities_changed(1).hit)
+ self.assertTrue(cache.get_all_entities_changed(2).hit)
def test_get_all_entities_changed(self) -> None:
"""
@@ -114,7 +117,8 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
self.assertEqual(
cache.get_all_entities_changed(3).entities, ["user@elsewhere.org"]
)
- self.assertFalse(cache.get_all_entities_changed(1).hit)
+ self.assertFalse(cache.get_all_entities_changed(0).hit)
+ self.assertTrue(cache.get_all_entities_changed(1).hit)
# ... later, things gest more updates
cache.entity_has_changed("user@foo.com", 5)
@@ -149,7 +153,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
# With no entities, it returns True for the past, present, and False for
# the future.
self.assertTrue(cache.has_any_entity_changed(0))
- self.assertTrue(cache.has_any_entity_changed(1))
+ self.assertFalse(cache.has_any_entity_changed(1))
self.assertFalse(cache.has_any_entity_changed(2))
# We add an entity
@@ -251,3 +255,28 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
# Unknown entities will return None
self.assertEqual(cache.get_max_pos_of_last_change("not@here.website"), None)
+
+ def test_all_entities_changed(self) -> None:
+ """
+        `StreamChangeCache.all_entities_changed(...)` will mark all entities as changed.
+ """
+ cache = StreamChangeCache("#test", 1, max_size=10)
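+        # The cache starts at stream position 1.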
+
+ cache.entity_has_changed("user@foo.com", 2)
+ cache.entity_has_changed("bar@baz.net", 3)
+ cache.entity_has_changed("user@elsewhere.org", 4)
+
+ cache.all_entities_changed(5)
+
+ # Everything should be marked as changed before the stream position where the
+ # change occurred.
+ self.assertTrue(cache.has_entity_changed("user@foo.com", 4))
+ self.assertTrue(cache.has_entity_changed("bar@baz.net", 4))
+ self.assertTrue(cache.has_entity_changed("user@elsewhere.org", 4))
+
+ # Nothing should be marked as changed at/after the stream position where the
+ # change occurred. In other words, nothing has changed since the stream position
+ # 5.
+ self.assertFalse(cache.has_entity_changed("user@foo.com", 5))
+ self.assertFalse(cache.has_entity_changed("bar@baz.net", 5))
+ self.assertFalse(cache.has_entity_changed("user@elsewhere.org", 5))
diff --git a/tests/util/test_stringutils.py b/tests/util/test_stringutils.py
index 646fd2163e..34c2395ecf 100644
--- a/tests/util/test_stringutils.py
+++ b/tests/util/test_stringutils.py
@@ -20,7 +20,11 @@
#
from synapse.api.errors import SynapseError
-from synapse.util.stringutils import assert_valid_client_secret, base62_encode
+from synapse.util.stringutils import (
+ assert_valid_client_secret,
+ base62_encode,
+ is_namedspaced_grammar,
+)
from .. import unittest
@@ -58,3 +62,25 @@ class StringUtilsTestCase(unittest.TestCase):
self.assertEqual("10", base62_encode(62))
self.assertEqual("1c", base62_encode(100))
self.assertEqual("001c", base62_encode(100, minwidth=4))
+
+ def test_namespaced_identifier(self) -> None:
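+        # Valid identifiers, per the "Common namespaced identifier grammar"
+        # in the Matrix spec appendices.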
+ self.assertTrue(is_namedspaced_grammar("test"))
+ self.assertTrue(is_namedspaced_grammar("m.test"))
+ self.assertTrue(is_namedspaced_grammar("org.matrix.test"))
+ self.assertTrue(is_namedspaced_grammar("org.matrix.msc1234"))
+ self.assertTrue(is_namedspaced_grammar("test"))
+ self.assertTrue(is_namedspaced_grammar("t-e_s.t"))
+
+        # Must start with a letter.
+ self.assertFalse(is_namedspaced_grammar("1test"))
+ self.assertFalse(is_namedspaced_grammar("-test"))
+ self.assertFalse(is_namedspaced_grammar("_test"))
+ self.assertFalse(is_namedspaced_grammar(".test"))
+
+        # Must contain only a-z, 0-9, hyphen, underscore, and dot.
+ self.assertFalse(is_namedspaced_grammar("test/"))
+ self.assertFalse(is_namedspaced_grammar('test"'))
+ self.assertFalse(is_namedspaced_grammar("testö"))
+
+        # Must not exceed 255 characters.
+ self.assertFalse(is_namedspaced_grammar("t" * 256))
diff --git a/tests/util/test_task_scheduler.py b/tests/util/test_task_scheduler.py
index 30f0510c9f..7f6e63bd49 100644
--- a/tests/util/test_task_scheduler.py
+++ b/tests/util/test_task_scheduler.py
@@ -18,8 +18,7 @@
# [This file includes modifications made by New Vector Limited]
#
#
-
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple
from twisted.internet.task import deferLater
from twisted.test.proto_helpers import MemoryReactor
@@ -104,38 +103,48 @@ class TestTaskScheduler(HomeserverTestCase):
)
)
- # This is to give the time to the active tasks to finish
+ def get_tasks_of_status(status: TaskStatus) -> List[ScheduledTask]:
+ tasks = (
+ self.get_success(self.task_scheduler.get_task(task_id))
+ for task_id in task_ids
+ )
+ return [t for t in tasks if t is not None and t.status == status]
+
+ # At this point, there should be MAX_CONCURRENT_RUNNING_TASKS active tasks and
+ # one scheduled task.
+ self.assertEqual(
+ len(get_tasks_of_status(TaskStatus.ACTIVE)),
+ TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS,
+ )
+ self.assertEqual(
+ len(get_tasks_of_status(TaskStatus.SCHEDULED)),
+ 1,
+ )
+
+        # Give the active tasks time to finish.
self.reactor.advance(1)
- # Check that only MAX_CONCURRENT_RUNNING_TASKS tasks has run and that one
+ # Check that MAX_CONCURRENT_RUNNING_TASKS tasks have run and that one
# is still scheduled.
- tasks = [
- self.get_success(self.task_scheduler.get_task(task_id))
- for task_id in task_ids
- ]
-
- self.assertEquals(
- len(
- [t for t in tasks if t is not None and t.status == TaskStatus.COMPLETE]
- ),
+ self.assertEqual(
+ len(get_tasks_of_status(TaskStatus.COMPLETE)),
TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS,
)
+ scheduled_tasks = get_tasks_of_status(TaskStatus.SCHEDULED)
+ self.assertEqual(len(scheduled_tasks), 1)
- scheduled_tasks = [
- t for t in tasks if t is not None and t.status == TaskStatus.ACTIVE
- ]
- self.assertEquals(len(scheduled_tasks), 1)
+ # The scheduled task should start 0.1s after the first of the active tasks
+ # finishes
+ self.reactor.advance(0.1)
+ self.assertEqual(len(get_tasks_of_status(TaskStatus.ACTIVE)), 1)
- # We need to wait for the next run of the scheduler loop
- self.reactor.advance((TaskScheduler.SCHEDULE_INTERVAL_MS / 1000))
+ # ... and should finally complete after another second
self.reactor.advance(1)
-
- # Check that the last task has been properly executed after the next scheduler loop run
prev_scheduled_task = self.get_success(
self.task_scheduler.get_task(scheduled_tasks[0].id)
)
assert prev_scheduled_task is not None
- self.assertEquals(
+ self.assertEqual(
prev_scheduled_task.status,
TaskStatus.COMPLETE,
)
diff --git a/tests/util/test_threepids.py b/tests/util/test_threepids.py
deleted file mode 100644
index 15575cc572..0000000000
--- a/tests/util/test_threepids.py
+++ /dev/null
@@ -1,55 +0,0 @@
-#
-# This file is licensed under the Affero General Public License (AGPL) version 3.
-#
-# Copyright 2020 Dirk Klimpel
-# Copyright (C) 2023 New Vector, Ltd
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Affero General Public License as
-# published by the Free Software Foundation, either version 3 of the
-# License, or (at your option) any later version.
-#
-# See the GNU Affero General Public License for more details:
-# <https://www.gnu.org/licenses/agpl-3.0.html>.
-#
-# Originally licensed under the Apache License, Version 2.0:
-# <http://www.apache.org/licenses/LICENSE-2.0>.
-#
-# [This file includes modifications made by New Vector Limited]
-#
-#
-
-from synapse.util.threepids import canonicalise_email
-
-from tests.unittest import HomeserverTestCase
-
-
-class CanonicaliseEmailTests(HomeserverTestCase):
- def test_no_at(self) -> None:
- with self.assertRaises(ValueError):
- canonicalise_email("address-without-at.bar")
-
- def test_two_at(self) -> None:
- with self.assertRaises(ValueError):
- canonicalise_email("foo@foo@test.bar")
-
- def test_bad_format(self) -> None:
- with self.assertRaises(ValueError):
- canonicalise_email("user@bad.example.net@good.example.com")
-
- def test_valid_format(self) -> None:
- self.assertEqual(canonicalise_email("foo@test.bar"), "foo@test.bar")
-
- def test_domain_to_lower(self) -> None:
- self.assertEqual(canonicalise_email("foo@TEST.BAR"), "foo@test.bar")
-
- def test_domain_with_umlaut(self) -> None:
- self.assertEqual(canonicalise_email("foo@Öumlaut.com"), "foo@öumlaut.com")
-
- def test_address_casefold(self) -> None:
- self.assertEqual(
- canonicalise_email("Strauß@Example.com"), "strauss@example.com"
- )
-
- def test_address_trim(self) -> None:
- self.assertEqual(canonicalise_email(" foo@test.bar "), "foo@test.bar")
diff --git a/tests/util/test_wheel_timer.py b/tests/util/test_wheel_timer.py
index 173a7cfaec..6fa575a18e 100644
--- a/tests/util/test_wheel_timer.py
+++ b/tests/util/test_wheel_timer.py
@@ -28,53 +28,55 @@ class WheelTimerTestCase(unittest.TestCase):
def test_single_insert_fetch(self) -> None:
wheel: WheelTimer[object] = WheelTimer(bucket_size=5)
- obj = object()
- wheel.insert(100, obj, 150)
+ wheel.insert(100, "1", 150)
self.assertListEqual(wheel.fetch(101), [])
self.assertListEqual(wheel.fetch(110), [])
self.assertListEqual(wheel.fetch(120), [])
self.assertListEqual(wheel.fetch(130), [])
self.assertListEqual(wheel.fetch(149), [])
- self.assertListEqual(wheel.fetch(156), [obj])
+ self.assertListEqual(wheel.fetch(156), ["1"])
self.assertListEqual(wheel.fetch(170), [])
def test_multi_insert(self) -> None:
wheel: WheelTimer[object] = WheelTimer(bucket_size=5)
- obj1 = object()
- obj2 = object()
- obj3 = object()
- wheel.insert(100, obj1, 150)
- wheel.insert(105, obj2, 130)
- wheel.insert(106, obj3, 160)
+ wheel.insert(100, "1", 150)
+ wheel.insert(105, "2", 130)
+ wheel.insert(106, "3", 160)
self.assertListEqual(wheel.fetch(110), [])
- self.assertListEqual(wheel.fetch(135), [obj2])
+ self.assertListEqual(wheel.fetch(135), ["2"])
self.assertListEqual(wheel.fetch(149), [])
- self.assertListEqual(wheel.fetch(158), [obj1])
+ self.assertListEqual(wheel.fetch(158), ["1"])
self.assertListEqual(wheel.fetch(160), [])
- self.assertListEqual(wheel.fetch(200), [obj3])
+ self.assertListEqual(wheel.fetch(200), ["3"])
self.assertListEqual(wheel.fetch(210), [])
def test_insert_past(self) -> None:
wheel: WheelTimer[object] = WheelTimer(bucket_size=5)
- obj = object()
- wheel.insert(100, obj, 50)
- self.assertListEqual(wheel.fetch(120), [obj])
+ wheel.insert(100, "1", 50)
+ self.assertListEqual(wheel.fetch(120), ["1"])
def test_insert_past_multi(self) -> None:
wheel: WheelTimer[object] = WheelTimer(bucket_size=5)
- obj1 = object()
- obj2 = object()
- obj3 = object()
- wheel.insert(100, obj1, 150)
- wheel.insert(100, obj2, 140)
- wheel.insert(100, obj3, 50)
- self.assertListEqual(wheel.fetch(110), [obj3])
+ wheel.insert(100, "1", 150)
+ wheel.insert(100, "2", 140)
+ wheel.insert(100, "3", 50)
+ self.assertListEqual(wheel.fetch(110), ["3"])
self.assertListEqual(wheel.fetch(120), [])
- self.assertListEqual(wheel.fetch(147), [obj2])
- self.assertListEqual(wheel.fetch(200), [obj1])
+ self.assertListEqual(wheel.fetch(147), ["2"])
+ self.assertListEqual(wheel.fetch(200), ["1"])
self.assertListEqual(wheel.fetch(240), [])
+
+ def test_multi_insert_then_past(self) -> None:
+ wheel: WheelTimer[object] = WheelTimer(bucket_size=5)
+
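+        # Insert entries whose expiry times are out of insertion order, then
+        # check that fetches return entries as their buckets expire.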
+ wheel.insert(100, "1", 150)
+ wheel.insert(100, "2", 160)
+ wheel.insert(100, "3", 155)
+
+ self.assertListEqual(wheel.fetch(110), [])
+ self.assertListEqual(wheel.fetch(158), ["1"])
diff --git a/tests/utils.py b/tests/utils.py
index 9fd26ef348..57986c18bc 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -28,6 +28,7 @@ from typing import (
Callable,
Dict,
List,
+ Literal,
Optional,
Tuple,
Type,
@@ -37,7 +38,7 @@ from typing import (
)
import attr
-from typing_extensions import Literal, ParamSpec
+from typing_extensions import ParamSpec
from synapse.api.constants import EventTypes
from synapse.api.room_versions import RoomVersions
@@ -181,7 +182,6 @@ def default_config(
"max_mau_value": 50,
"mau_trial_days": 0,
"mau_stats_only": False,
- "mau_limits_reserved_threepids": [],
"admin_contact": None,
"rc_message": {"per_second": 10000, "burst_count": 10000},
"rc_registration": {"per_second": 10000, "burst_count": 10000},
@@ -200,9 +200,8 @@ def default_config(
"per_user": {"per_second": 10000, "burst_count": 10000},
},
"rc_3pid_validation": {"per_second": 10000, "burst_count": 10000},
- "saml2_enabled": False,
+ "rc_presence": {"per_user": {"per_second": 10000, "burst_count": 10000}},
"public_baseurl": None,
- "default_identity_server": None,
"key_refresh_interval": 24 * 60 * 60 * 1000,
"old_signing_keys": {},
"tls_fingerprints": [],
@@ -399,11 +398,24 @@ class TestTimeout(Exception):
class test_timeout:
+ """
+    FIXME: This implementation is not robust against other code tight-looping and
+    preventing the signal from propagating and timing out the test. You may need to add
+ `time.sleep(0.1)` to your code in order to allow this timeout to work correctly.
+
+ ```py
+ with test_timeout(3):
+ while True:
+ my_checking_func()
+ time.sleep(0.1)
+ ```
+ """
+
def __init__(self, seconds: int, error_message: Optional[str] = None) -> None:
- if error_message is None:
- error_message = "test timed out after {}s.".format(seconds)
+ self.error_message = f"Test timed out after {seconds}s"
+ if error_message is not None:
+ self.error_message += f": {error_message}"
self.seconds = seconds
- self.error_message = error_message
def handle_timeout(self, signum: int, frame: Optional[FrameType]) -> None:
raise TestTimeout(self.error_message)