diff --git a/tests/handlers/test_admin.py b/tests/handlers/test_admin.py
index 9ff853a83d..c5bff468e2 100644
--- a/tests/handlers/test_admin.py
+++ b/tests/handlers/test_admin.py
@@ -260,7 +260,6 @@ class ExfiltrateData(unittest.HomeserverTestCase):
self.assertEqual(args[0]["name"], self.user2)
self.assertIn("displayname", args[0])
self.assertIn("avatar_url", args[0])
- self.assertIn("threepids", args[0])
self.assertIn("external_ids", args[0])
self.assertIn("creation_ts", args[0])
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index 1eec0d43b7..1db630e9e4 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -1165,12 +1165,23 @@ class ApplicationServicesHandlerOtkCountsTestCase(unittest.HomeserverTestCase):
self.hs.get_datastores().main.services_cache = [self._service]
# Register some appservice users
- self._sender_user, self._sender_device = self.register_appservice_user(
+ user_id, device_id = self.register_appservice_user(
"as.sender", self._service_token
)
- self._namespaced_user, self._namespaced_device = self.register_appservice_user(
+ # With MSC4190 enabled, there will not be a device created
+ # during AS registration. However MSC4190 is not enabled
+ # in this test. It may become the default behaviour in the
+ # future, in which case this test will need to be updated.
+ assert device_id is not None
+ self._sender_user = user_id
+ self._sender_device = device_id
+
+ user_id, device_id = self.register_appservice_user(
"_as_user1", self._service_token
)
+ assert device_id is not None
+ self._namespaced_user = user_id
+ self._namespaced_device = device_id
# Register a real user as well.
self._real_user = self.register_user("real.user", "meow")
diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py
deleted file mode 100644
index f41f7d36ad..0000000000
--- a/tests/handlers/test_cas.py
+++ /dev/null
@@ -1,239 +0,0 @@
-#
-# This file is licensed under the Affero General Public License (AGPL) version 3.
-#
-# Copyright 2020 The Matrix.org Foundation C.I.C.
-# Copyright (C) 2023 New Vector, Ltd
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Affero General Public License as
-# published by the Free Software Foundation, either version 3 of the
-# License, or (at your option) any later version.
-#
-# See the GNU Affero General Public License for more details:
-# <https://www.gnu.org/licenses/agpl-3.0.html>.
-#
-# Originally licensed under the Apache License, Version 2.0:
-# <http://www.apache.org/licenses/LICENSE-2.0>.
-#
-# [This file includes modifications made by New Vector Limited]
-#
-#
-from typing import Any, Dict
-from unittest.mock import AsyncMock, Mock
-
-from twisted.test.proto_helpers import MemoryReactor
-
-from synapse.handlers.cas import CasResponse
-from synapse.server import HomeServer
-from synapse.util import Clock
-
-from tests.unittest import HomeserverTestCase, override_config
-
-# These are a few constants that are used as config parameters in the tests.
-BASE_URL = "https://synapse/"
-SERVER_URL = "https://issuer/"
-
-
-class CasHandlerTestCase(HomeserverTestCase):
- def default_config(self) -> Dict[str, Any]:
- config = super().default_config()
- config["public_baseurl"] = BASE_URL
- cas_config = {
- "enabled": True,
- "server_url": SERVER_URL,
- "service_url": BASE_URL,
- }
-
- # Update this config with what's in the default config so that
- # override_config works as expected.
- cas_config.update(config.get("cas_config", {}))
- config["cas_config"] = cas_config
-
- return config
-
- def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
- hs = self.setup_test_homeserver()
-
- self.handler = hs.get_cas_handler()
-
- # Reduce the number of attempts when generating MXIDs.
- sso_handler = hs.get_sso_handler()
- sso_handler._MAP_USERNAME_RETRIES = 3
-
- return hs
-
- def test_map_cas_user_to_user(self) -> None:
- """Ensure that mapping the CAS user returned from a provider to an MXID works properly."""
-
- # stub out the auth handler
- auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign]
-
- cas_response = CasResponse("test_user", {})
- request = _mock_request()
- self.get_success(
- self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
- )
-
- # check that the auth handler got called as expected
- auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user:test",
- "cas",
- request,
- "redirect_uri",
- None,
- new_user=True,
- auth_provider_session_id=None,
- )
-
- def test_map_cas_user_to_existing_user(self) -> None:
- """Existing users can log in with CAS account."""
- store = self.hs.get_datastores().main
- self.get_success(
- store.register_user(user_id="@test_user:test", password_hash=None)
- )
-
- # stub out the auth handler
- auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign]
-
- # Map a user via SSO.
- cas_response = CasResponse("test_user", {})
- request = _mock_request()
- self.get_success(
- self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
- )
-
- # check that the auth handler got called as expected
- auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user:test",
- "cas",
- request,
- "redirect_uri",
- None,
- new_user=False,
- auth_provider_session_id=None,
- )
-
- # Subsequent calls should map to the same mxid.
- auth_handler.complete_sso_login.reset_mock()
- self.get_success(
- self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
- )
- auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user:test",
- "cas",
- request,
- "redirect_uri",
- None,
- new_user=False,
- auth_provider_session_id=None,
- )
-
- def test_map_cas_user_to_invalid_localpart(self) -> None:
- """CAS automaps invalid characters to base-64 encoding."""
-
- # stub out the auth handler
- auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign]
-
- cas_response = CasResponse("föö", {})
- request = _mock_request()
- self.get_success(
- self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
- )
-
- # check that the auth handler got called as expected
- auth_handler.complete_sso_login.assert_called_once_with(
- "@f=c3=b6=c3=b6:test",
- "cas",
- request,
- "redirect_uri",
- None,
- new_user=True,
- auth_provider_session_id=None,
- )
-
- @override_config(
- {
- "cas_config": {
- "required_attributes": {"userGroup": "staff", "department": None}
- }
- }
- )
- def test_required_attributes(self) -> None:
- """The required attributes must be met from the CAS response."""
-
- # stub out the auth handler
- auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign]
-
- # The response doesn't have the proper userGroup or department.
- cas_response = CasResponse("test_user", {})
- request = _mock_request()
- self.get_success(
- self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
- )
- auth_handler.complete_sso_login.assert_not_called()
-
- # The response doesn't have any department.
- cas_response = CasResponse("test_user", {"userGroup": ["staff"]})
- request.reset_mock()
- self.get_success(
- self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
- )
- auth_handler.complete_sso_login.assert_not_called()
-
- # Add the proper attributes and it should succeed.
- cas_response = CasResponse(
- "test_user", {"userGroup": ["staff", "admin"], "department": ["sales"]}
- )
- request.reset_mock()
- self.get_success(
- self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
- )
-
- # check that the auth handler got called as expected
- auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user:test",
- "cas",
- request,
- "redirect_uri",
- None,
- new_user=True,
- auth_provider_session_id=None,
- )
-
- @override_config({"cas_config": {"enable_registration": False}})
- def test_map_cas_user_does_not_register_new_user(self) -> None:
- """Ensures new users are not registered if the enabled registration flag is disabled."""
-
- # stub out the auth handler
- auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign]
-
- cas_response = CasResponse("test_user", {})
- request = _mock_request()
- self.get_success(
- self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
- )
-
- # check that the auth handler was not called as expected
- auth_handler.complete_sso_login.assert_not_called()
-
-
-def _mock_request() -> Mock:
- """Returns a mock which will stand in as a SynapseRequest"""
- mock = Mock(
- spec=[
- "finish",
- "getClientAddress",
- "getHeader",
- "setHeader",
- "setResponseCode",
- "write",
- ]
- )
- # `_disconnected` musn't be another `Mock`, otherwise it will be truthy.
- mock._disconnected = False
- return mock
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index 4a3e36ffde..b7058d8002 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -587,6 +587,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
self.room_list_handler = hs.get_room_list_handler()
self.directory_handler = hs.get_directory_handler()
+ @unittest.override_config({"room_list_publication_rules": [{"action": "allow"}]})
def test_disabling_room_list(self) -> None:
self.room_list_handler.enable_room_list_search = True
self.directory_handler.enable_room_list_search = True
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 8a3dfdcf75..70fc4263e7 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -19,6 +19,7 @@
# [This file includes modifications made by New Vector Limited]
#
#
+import time
from typing import Dict, Iterable
from unittest import mock
@@ -151,18 +152,30 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
def test_claim_one_time_key(self) -> None:
local_user = "@boris:" + self.hs.hostname
device_id = "xyz"
- keys = {"alg1:k1": "key1"}
-
res = self.get_success(
self.handler.upload_keys_for_user(
- local_user, device_id, {"one_time_keys": keys}
+ local_user, device_id, {"one_time_keys": {"alg1:k1": "key1"}}
)
)
self.assertDictEqual(
res, {"one_time_key_counts": {"alg1": 1, "signed_curve25519": 0}}
)
- res2 = self.get_success(
+ # Keys should be returned in the order they were uploaded. To test, advance time
+ # a little, then upload a second key with an earlier key ID; it should get
+ # returned second.
+ self.reactor.advance(1)
+ res = self.get_success(
+ self.handler.upload_keys_for_user(
+ local_user, device_id, {"one_time_keys": {"alg1:k0": "key0"}}
+ )
+ )
+ self.assertDictEqual(
+ res, {"one_time_key_counts": {"alg1": 2, "signed_curve25519": 0}}
+ )
+
+    # Now claim both keys back. They should be in the same order.
+ res = self.get_success(
self.handler.claim_one_time_keys(
{local_user: {device_id: {"alg1": 1}}},
self.requester,
@@ -171,12 +184,27 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
)
)
self.assertEqual(
- res2,
+ res,
{
"failures": {},
"one_time_keys": {local_user: {device_id: {"alg1:k1": "key1"}}},
},
)
+ res = self.get_success(
+ self.handler.claim_one_time_keys(
+ {local_user: {device_id: {"alg1": 1}}},
+ self.requester,
+ timeout=None,
+ always_include_fallback_keys=False,
+ )
+ )
+ self.assertEqual(
+ res,
+ {
+ "failures": {},
+ "one_time_keys": {local_user: {device_id: {"alg1:k0": "key0"}}},
+ },
+ )
def test_claim_one_time_key_bulk(self) -> None:
"""Like test_claim_one_time_key but claims multiple keys in one handler call."""
@@ -336,6 +364,47 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
counts_by_alg, expected_counts_by_alg, f"{user_id}:{device_id}"
)
+ def test_claim_one_time_key_bulk_ordering(self) -> None:
+ """Keys returned by the bulk claim call should be returned in the correct order"""
+
+ # Alice has lots of keys, uploaded in a specific order
+ alice = f"@alice:{self.hs.hostname}"
+ alice_dev = "alice_dev_1"
+
+ self.get_success(
+ self.handler.upload_keys_for_user(
+ alice,
+ alice_dev,
+ {"one_time_keys": {"alg1:k20": 20, "alg1:k21": 21, "alg1:k22": 22}},
+ )
+ )
+ # Advance time by 1s, to ensure that there is a difference in upload time.
+ self.reactor.advance(1)
+ self.get_success(
+ self.handler.upload_keys_for_user(
+ alice,
+ alice_dev,
+ {"one_time_keys": {"alg1:k10": 10, "alg1:k11": 11, "alg1:k12": 12}},
+ )
+ )
+
+ # Now claim some, and check we get the right ones.
+ claim_res = self.get_success(
+ self.handler.claim_one_time_keys(
+ {alice: {alice_dev: {"alg1": 2}}},
+ self.requester,
+ timeout=None,
+ always_include_fallback_keys=False,
+ )
+ )
+ # We should get the first-uploaded keys, even though they have later key ids.
+ # We should get a random set of two of k20, k21, k22.
+ self.assertEqual(claim_res["failures"], {})
+ claimed_keys = claim_res["one_time_keys"]["@alice:test"]["alice_dev_1"]
+ self.assertEqual(len(claimed_keys), 2)
+ for key_id in claimed_keys.keys():
+ self.assertIn(key_id, ["alg1:k20", "alg1:k21", "alg1:k22"])
+
def test_fallback_key(self) -> None:
local_user = "@boris:" + self.hs.hostname
device_id = "xyz"
@@ -1758,3 +1827,222 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
)
self.assertIs(exists, True)
self.assertIs(replaceable_without_uia, False)
+
+ def test_delete_old_one_time_keys(self) -> None:
+ """Test the db migration that clears out old OTKs"""
+
+ # We upload two sets of keys, one just over a week ago, and one just less than
+ # a week ago. Each batch contains some keys that match the deletion pattern
+ # (key IDs of 6 chars), and some that do not.
+ #
+ # Finally, set the scheduled task going, and check what gets deleted.
+
+ user_id = "@user000:" + self.hs.hostname
+ device_id = "xyz"
+
+ # The scheduled task should be for "now" in real, wallclock time, so
+ # set the test reactor to just over a week ago.
+ self.reactor.advance(time.time() - 7.5 * 24 * 3600)
+
+ # Upload some keys
+ self.get_success(
+ self.handler.upload_keys_for_user(
+ user_id,
+ device_id,
+ {
+ "one_time_keys": {
+ # some keys to delete
+ "alg1:AAAAAA": "key1",
+ "alg2:AAAAAB": {"key": "key2", "signatures": {"k1": "sig1"}},
+ # A key to *not* delete
+ "alg2:AAAAAAAAAA": {"key": "key3"},
+ }
+ },
+ )
+ )
+
+ # A day passes
+ self.reactor.advance(24 * 3600)
+
+ # Upload some more keys
+ self.get_success(
+ self.handler.upload_keys_for_user(
+ user_id,
+ device_id,
+ {
+ "one_time_keys": {
+ # some keys which match the pattern
+ "alg1:BAAAAA": "key1",
+ "alg2:BAAAAB": {"key": "key2", "signatures": {"k1": "sig1"}},
+ # A key to *not* delete
+ "alg2:BAAAAAAAAA": {"key": "key3"},
+ }
+ },
+ )
+ )
+
+ # The rest of the week passes, which should set the scheduled task going.
+ self.reactor.advance(6.5 * 24 * 3600)
+
+ # Check what we're left with in the database
+ remaining_key_ids = {
+ row[0]
+ for row in self.get_success(
+ self.handler.store.db_pool.simple_select_list(
+ "e2e_one_time_keys_json", None, ["key_id"]
+ )
+ )
+ }
+ self.assertEqual(
+ remaining_key_ids, {"AAAAAAAAAA", "BAAAAA", "BAAAAB", "BAAAAAAAAA"}
+ )
+
+ @override_config(
+ {
+ "experimental_features": {
+ "msc4263_limit_key_queries_to_users_who_share_rooms": True
+ }
+ }
+ )
+ def test_query_devices_remote_restricted_not_in_shared_room(self) -> None:
+ """Tests that querying keys for a remote user that we don't share a room
+ with returns nothing.
+ """
+
+ remote_user_id = "@test:other"
+ local_user_id = "@test:test"
+
+ # Do *not* pretend we're sharing a room with the user we're querying.
+
+ remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY"
+ remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
+
+ self.hs.get_federation_client().query_client_keys = mock.AsyncMock( # type: ignore[method-assign]
+ return_value={
+ "device_keys": {remote_user_id: {}},
+ "master_keys": {
+ remote_user_id: {
+ "user_id": remote_user_id,
+ "usage": ["master"],
+ "keys": {"ed25519:" + remote_master_key: remote_master_key},
+ },
+ },
+ "self_signing_keys": {
+ remote_user_id: {
+ "user_id": remote_user_id,
+ "usage": ["self_signing"],
+ "keys": {
+ "ed25519:"
+ + remote_self_signing_key: remote_self_signing_key
+ },
+ }
+ },
+ }
+ )
+
+ e2e_handler = self.hs.get_e2e_keys_handler()
+
+ query_result = self.get_success(
+ e2e_handler.query_devices(
+ {
+ "device_keys": {remote_user_id: []},
+ },
+ timeout=10,
+ from_user_id=local_user_id,
+ from_device_id="some_device_id",
+ )
+ )
+
+ self.assertEqual(
+ query_result,
+ {
+ "device_keys": {},
+ "failures": {},
+ "master_keys": {},
+ "self_signing_keys": {},
+ "user_signing_keys": {},
+ },
+ )
+
+ @override_config(
+ {
+ "experimental_features": {
+ "msc4263_limit_key_queries_to_users_who_share_rooms": True
+ }
+ }
+ )
+ def test_query_devices_remote_restricted_in_shared_room(self) -> None:
+ """Tests that querying keys for a remote user that we share a room
+ with returns the cross signing keys correctly.
+ """
+
+ remote_user_id = "@test:other"
+ local_user_id = "@test:test"
+
+ # Pretend we're sharing a room with the user we're querying. If not,
+ # `query_devices` will filter out the user ID and `_query_devices_for_destination`
+ # will return early.
+ self.store.do_users_share_a_room_joined_or_invited = mock.AsyncMock( # type: ignore[method-assign]
+ return_value=[remote_user_id]
+ )
+ self.store.get_rooms_for_user = mock.AsyncMock(return_value={"some_room_id"})
+
+ remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY"
+ remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
+
+ self.hs.get_federation_client().query_user_devices = mock.AsyncMock( # type: ignore[method-assign]
+ return_value={
+ "user_id": remote_user_id,
+ "stream_id": 1,
+ "devices": [],
+ "master_key": {
+ "user_id": remote_user_id,
+ "usage": ["master"],
+ "keys": {"ed25519:" + remote_master_key: remote_master_key},
+ },
+ "self_signing_key": {
+ "user_id": remote_user_id,
+ "usage": ["self_signing"],
+ "keys": {
+ "ed25519:" + remote_self_signing_key: remote_self_signing_key
+ },
+ },
+ }
+ )
+
+ e2e_handler = self.hs.get_e2e_keys_handler()
+
+ query_result = self.get_success(
+ e2e_handler.query_devices(
+ {
+ "device_keys": {remote_user_id: []},
+ },
+ timeout=10,
+ from_user_id=local_user_id,
+ from_device_id="some_device_id",
+ )
+ )
+
+ self.assertEqual(query_result["failures"], {})
+ self.assertEqual(
+ query_result["master_keys"],
+ {
+ remote_user_id: {
+ "user_id": remote_user_id,
+ "usage": ["master"],
+ "keys": {"ed25519:" + remote_master_key: remote_master_key},
+ }
+ },
+ )
+ self.assertEqual(
+ query_result["self_signing_keys"],
+ {
+ remote_user_id: {
+ "user_id": remote_user_id,
+ "usage": ["self_signing"],
+ "keys": {
+ "ed25519:" + remote_self_signing_key: remote_self_signing_key
+ },
+ }
+ },
+ )
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index 3fe5b0a1b4..b64a8a86a2 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -44,7 +44,7 @@ from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.storage.databases.main.events_worker import EventCacheEntry
from synapse.util import Clock
-from synapse.util.stringutils import random_string
+from synapse.util.events import generate_fake_event_id
from tests import unittest
from tests.test_utils import event_injection
@@ -52,10 +52,6 @@ from tests.test_utils import event_injection
logger = logging.getLogger(__name__)
-def generate_fake_event_id() -> str:
- return "$fake_" + random_string(43)
-
-
class FederationTestCase(unittest.FederatingHomeserverTestCase):
servlets = [
admin.register_servlets,
@@ -665,9 +661,12 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
)
)
- with patch.object(
- fed_client, "make_membership_event", mock_make_membership_event
- ), patch.object(fed_client, "send_join", mock_send_join):
+ with (
+ patch.object(
+ fed_client, "make_membership_event", mock_make_membership_event
+ ),
+ patch.object(fed_client, "send_join", mock_send_join),
+ ):
# Join and check that our join event is rejected
# (The join event is rejected because it doesn't have any signatures)
join_exc = self.get_failure(
@@ -712,9 +711,12 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
fed_handler = self.hs.get_federation_handler()
store = self.hs.get_datastores().main
- with patch.object(
- fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room
- ), patch.object(store, "is_partial_state_room", mock_is_partial_state_room):
+ with (
+ patch.object(
+ fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room
+ ),
+ patch.object(store, "is_partial_state_room", mock_is_partial_state_room),
+ ):
# Start the partial state sync.
fed_handler._start_partial_state_room_sync("hs1", {"hs2"}, "room_id")
self.assertEqual(mock_sync_partial_state_room.call_count, 1)
@@ -764,9 +766,12 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
fed_handler = self.hs.get_federation_handler()
store = self.hs.get_datastores().main
- with patch.object(
- fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room
- ), patch.object(store, "is_partial_state_room", mock_is_partial_state_room):
+ with (
+ patch.object(
+ fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room
+ ),
+ patch.object(store, "is_partial_state_room", mock_is_partial_state_room),
+ ):
# Start the partial state sync.
fed_handler._start_partial_state_room_sync("hs1", {"hs2"}, "room_id")
self.assertEqual(mock_sync_partial_state_room.call_count, 1)
diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py
index 1b83aea579..51eca56c3b 100644
--- a/tests/handlers/test_federation_event.py
+++ b/tests/handlers/test_federation_event.py
@@ -288,13 +288,15 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
}
# We also expect an outbound request to /state
- self.mock_federation_transport_client.get_room_state.return_value = StateRequestResponse(
- # Mimic the other server not knowing about the state at all.
- # We want to cause Synapse to throw an error (`Unable to get
- # missing prev_event $fake_prev_event`) and fail to backfill
- # the pulled event.
- auth_events=[],
- state=[],
+ self.mock_federation_transport_client.get_room_state.return_value = (
+ StateRequestResponse(
+ # Mimic the other server not knowing about the state at all.
+ # We want to cause Synapse to throw an error (`Unable to get
+ # missing prev_event $fake_prev_event`) and fail to backfill
+ # the pulled event.
+ auth_events=[],
+ state=[],
+ )
)
pulled_event = make_event_from_dict(
@@ -373,7 +375,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
In this test, we pretend we are processing a "pulled" event via
backfill. The pulled event succesfully processes and the backward
- extremeties are updated along with clearing out any failed pull attempts
+ extremities are updated along with clearing out any failed pull attempts
for those old extremities.
We check that we correctly cleared failed pull attempts of the
@@ -805,6 +807,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
OTHER_USER = f"@user:{self.OTHER_SERVER_NAME}"
main_store = self.hs.get_datastores().main
+ state_deletion_store = self.hs.get_datastores().state_deletion
# Create the room.
kermit_user_id = self.register_user("kermit", "test")
@@ -956,7 +959,9 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
bert_member_event.event_id: bert_member_event,
rejected_kick_event.event_id: rejected_kick_event,
},
- state_res_store=StateResolutionStore(main_store),
+ state_res_store=StateResolutionStore(
+ main_store, state_deletion_store
+ ),
)
),
[bert_member_event.event_id, rejected_kick_event.event_id],
@@ -1001,7 +1006,9 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
rejected_power_levels_event.event_id,
],
event_map={},
- state_res_store=StateResolutionStore(main_store),
+ state_res_store=StateResolutionStore(
+ main_store, state_deletion_store
+ ),
full_conflicted_set=set(),
)
),
diff --git a/tests/handlers/test_oauth_delegation.py b/tests/handlers/test_oauth_delegation.py
index 036c539db2..37acb660e7 100644
--- a/tests/handlers/test_oauth_delegation.py
+++ b/tests/handlers/test_oauth_delegation.py
@@ -43,6 +43,7 @@ from synapse.api.errors import (
OAuthInsufficientScopeError,
SynapseError,
)
+from synapse.appservice import ApplicationService
from synapse.http.site import SynapseRequest
from synapse.rest import admin
from synapse.rest.client import account, devices, keys, login, logout, register
@@ -146,6 +147,16 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
return hs
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+ ) -> None:
+ # Provision the user and the device we use in the tests.
+ store = homeserver.get_datastores().main
+ self.get_success(store.register_user(USER_ID))
+ self.get_success(
+ store.store_device(USER_ID, DEVICE, initial_device_display_name=None)
+ )
+
def _assertParams(self) -> None:
"""Assert that the request parameters are correct."""
params = parse_qs(self.http_client.request.call_args[1]["data"].decode("utf-8"))
@@ -379,6 +390,44 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
)
self.assertEqual(requester.device_id, DEVICE)
+ def test_active_user_with_device_explicit_device_id(self) -> None:
+ """The handler should return a requester with normal user rights and a device ID, given explicitly, as supported by MAS 0.15+"""
+
+ self.http_client.request = AsyncMock(
+ return_value=FakeResponse.json(
+ code=200,
+ payload={
+ "active": True,
+ "sub": SUBJECT,
+ "scope": " ".join([MATRIX_USER_SCOPE]),
+ "device_id": DEVICE,
+ "username": USERNAME,
+ },
+ )
+ )
+ request = Mock(args={})
+ request.args[b"access_token"] = [b"mockAccessToken"]
+ request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+ requester = self.get_success(self.auth.get_user_by_req(request))
+ self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
+ self.http_client.request.assert_called_once_with(
+ method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
+ )
+ # It should have called with the 'X-MAS-Supports-Device-Id: 1' header
+ self.assertEqual(
+ self.http_client.request.call_args[1]["headers"].getRawHeaders(
+ b"X-MAS-Supports-Device-Id",
+ ),
+ [b"1"],
+ )
+ self._assertParams()
+ self.assertEqual(requester.user.to_string(), "@%s:%s" % (USERNAME, SERVER_NAME))
+ self.assertEqual(requester.is_guest, False)
+ self.assertEqual(
+ get_awaitable_result(self.auth.is_server_admin(requester)), False
+ )
+ self.assertEqual(requester.device_id, DEVICE)
+
def test_multiple_devices(self) -> None:
"""The handler should raise an error if multiple devices are found in the scope."""
@@ -500,6 +549,44 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
self.assertEqual(error.value.code, 503)
+ def test_cached_expired_introspection(self) -> None:
+ """The handler should raise an error if the introspection response gives
+ an expiry time, the introspection response is cached and then the entry is
+ re-requested after it has expired."""
+
+ self.http_client.request = introspection_mock = AsyncMock(
+ return_value=FakeResponse.json(
+ code=200,
+ payload={
+ "active": True,
+ "sub": SUBJECT,
+ "scope": " ".join(
+ [
+ MATRIX_USER_SCOPE,
+ f"{MATRIX_DEVICE_SCOPE_PREFIX}AABBCC",
+ ]
+ ),
+ "username": USERNAME,
+ "expires_in": 60,
+ },
+ )
+ )
+ request = Mock(args={})
+ request.args[b"access_token"] = [b"mockAccessToken"]
+ request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+
+ # The first CS-API request causes a successful introspection
+ self.get_success(self.auth.get_user_by_req(request))
+ self.assertEqual(introspection_mock.call_count, 1)
+
+ # Sleep for 60 seconds so the token expires.
+ self.reactor.advance(60.0)
+
+ # Now the CS-API request fails because the token expired
+ self.get_failure(self.auth.get_user_by_req(request), InvalidClientTokenError)
+ # Ensure another introspection request was not sent
+ self.assertEqual(introspection_mock.call_count, 1)
+
def make_device_keys(self, user_id: str, device_id: str) -> JsonDict:
# We only generate a master key to simplify the test.
master_signing_key = generate_signing_key(device_id)
@@ -550,7 +637,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
access_token="mockAccessToken",
)
- self.assertEqual(channel.code, HTTPStatus.NOT_IMPLEMENTED, channel.json_body)
+ self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.json_body)
def expect_unauthorized(
self, method: str, path: str, content: Union[bytes, str, JsonDict] = ""
@@ -560,15 +647,31 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
self.assertEqual(channel.code, 401, channel.json_body)
def expect_unrecognized(
- self, method: str, path: str, content: Union[bytes, str, JsonDict] = ""
+ self,
+ method: str,
+ path: str,
+ content: Union[bytes, str, JsonDict] = "",
+ auth: bool = False,
) -> None:
- channel = self.make_request(method, path, content)
+ channel = self.make_request(
+ method, path, content, access_token="token" if auth else None
+ )
self.assertEqual(channel.code, 404, channel.json_body)
self.assertEqual(
channel.json_body["errcode"], Codes.UNRECOGNIZED, channel.json_body
)
+ def expect_forbidden(
+ self, method: str, path: str, content: Union[bytes, str, JsonDict] = ""
+ ) -> None:
+ channel = self.make_request(method, path, content)
+
+ self.assertEqual(channel.code, 403, channel.json_body)
+ self.assertEqual(
+ channel.json_body["errcode"], Codes.FORBIDDEN, channel.json_body
+ )
+
def test_uia_endpoints(self) -> None:
"""Test that endpoints that were removed in MSC2964 are no longer available."""
@@ -580,36 +683,6 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
"POST", "/_matrix/client/v3/keys/device_signing/upload"
)
- def test_3pid_endpoints(self) -> None:
- """Test that 3pid account management endpoints that were removed in MSC2964 are no longer available."""
-
- # Remains and requires auth:
- self.expect_unauthorized("GET", "/_matrix/client/v3/account/3pid")
- self.expect_unauthorized(
- "POST",
- "/_matrix/client/v3/account/3pid/bind",
- {
- "client_secret": "foo",
- "id_access_token": "bar",
- "id_server": "foo",
- "sid": "bar",
- },
- )
- self.expect_unauthorized("POST", "/_matrix/client/v3/account/3pid/unbind", {})
-
- # These are gone:
- self.expect_unrecognized(
- "POST", "/_matrix/client/v3/account/3pid"
- ) # deprecated
- self.expect_unrecognized("POST", "/_matrix/client/v3/account/3pid/add")
- self.expect_unrecognized("POST", "/_matrix/client/v3/account/3pid/delete")
- self.expect_unrecognized(
- "POST", "/_matrix/client/v3/account/3pid/email/requestToken"
- )
- self.expect_unrecognized(
- "POST", "/_matrix/client/v3/account/3pid/msisdn/requestToken"
- )
-
def test_account_management_endpoints_removed(self) -> None:
"""Test that account management endpoints that were removed in MSC2964 are no longer available."""
self.expect_unrecognized("POST", "/_matrix/client/v3/account/deactivate")
@@ -623,11 +696,35 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
def test_registration_endpoints_removed(self) -> None:
"""Test that registration endpoints that were removed in MSC2964 are no longer available."""
+ appservice = ApplicationService(
+ token="i_am_an_app_service",
+ id="1234",
+ namespaces={"users": [{"regex": r"@alice:.+", "exclusive": True}]},
+ sender="@as_main:test",
+ )
+
+ self.hs.get_datastores().main.services_cache = [appservice]
self.expect_unrecognized(
"GET", "/_matrix/client/v1/register/m.login.registration_token/validity"
)
+
+ # Registration is disabled
+ self.expect_forbidden(
+ "POST",
+ "/_matrix/client/v3/register",
+ {"username": "alice", "password": "hunter2"},
+ )
+
# This is still available for AS registrations
- # self.expect_unrecognized("POST", "/_matrix/client/v3/register")
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/v3/register",
+ {"username": "alice", "type": "m.login.application_service"},
+ shorthand=False,
+ access_token="i_am_an_app_service",
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
self.expect_unrecognized("GET", "/_matrix/client/v3/register/available")
self.expect_unrecognized(
"POST", "/_matrix/client/v3/register/email/requestToken"
@@ -648,8 +745,25 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
def test_device_management_endpoints_removed(self) -> None:
"""Test that device management endpoints that were removed in MSC2964 are no longer available."""
- self.expect_unrecognized("POST", "/_matrix/client/v3/delete_devices")
- self.expect_unrecognized("DELETE", "/_matrix/client/v3/devices/{DEVICE}")
+
+ # Because we still support those endpoints with ASes, it checks the
+ # access token before returning 404
+ self.http_client.request = AsyncMock(
+ return_value=FakeResponse.json(
+ code=200,
+ payload={
+ "active": True,
+ "sub": SUBJECT,
+ "scope": " ".join([MATRIX_USER_SCOPE, MATRIX_DEVICE_SCOPE]),
+ "username": USERNAME,
+ },
+ )
+ )
+
+ self.expect_unrecognized("POST", "/_matrix/client/v3/delete_devices", auth=True)
+ self.expect_unrecognized(
+ "DELETE", "/_matrix/client/v3/devices/{DEVICE}", auth=True
+ )
def test_openid_endpoints_removed(self) -> None:
"""Test that OpenID id_token endpoints that were removed in MSC2964 are no longer available."""
@@ -772,7 +886,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
req = SynapseRequest(channel, self.site) # type: ignore[arg-type]
req.client.host = MAS_IPV4_ADDR
req.requestHeaders.addRawHeader(
- "Authorization", f"Bearer {self.auth._admin_token}"
+ "Authorization", f"Bearer {self.auth._admin_token()}"
)
req.requestHeaders.addRawHeader("User-Agent", MAS_USER_AGENT)
req.content = BytesIO(b"")
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index a81501979d..ff8e3c5cb6 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -57,6 +57,7 @@ CLIENT_ID = "test-client-id"
CLIENT_SECRET = "test-client-secret"
BASE_URL = "https://synapse/"
CALLBACK_URL = BASE_URL + "_synapse/client/oidc/callback"
+TEST_REDIRECT_URI = "https://test/oidc/callback"
SCOPES = ["openid"]
# config for common cases
@@ -70,12 +71,16 @@ DEFAULT_CONFIG = {
}
# extends the default config with explicit OAuth2 endpoints instead of using discovery
+#
+# We add "explicit" to things to make them different from the discovered values to make
+# sure that the explicit values override the discovered ones.
EXPLICIT_ENDPOINT_CONFIG = {
**DEFAULT_CONFIG,
"discover": False,
- "authorization_endpoint": ISSUER + "authorize",
- "token_endpoint": ISSUER + "token",
- "jwks_uri": ISSUER + "jwks",
+ "authorization_endpoint": ISSUER + "authorize-explicit",
+ "token_endpoint": ISSUER + "token-explicit",
+ "jwks_uri": ISSUER + "jwks-explicit",
+ "id_token_signing_alg_values_supported": ["RS256", "<explicit>"],
}
@@ -259,12 +264,64 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.get_success(self.provider.load_metadata())
self.fake_server.get_metadata_handler.assert_not_called()
+ @override_config({"oidc_config": {**EXPLICIT_ENDPOINT_CONFIG, "discover": True}})
+ def test_discovery_with_explicit_config(self) -> None:
+ """
+ The handler should discover the endpoints from OIDC discovery document but
+ values are overridden by the explicit config.
+ """
+ # This would throw if some metadata were invalid
+ metadata = self.get_success(self.provider.load_metadata())
+ self.fake_server.get_metadata_handler.assert_called_once()
+
+ self.assertEqual(metadata.issuer, self.fake_server.issuer)
+ # It seems like authlib does not have that defined in its metadata models
+ self.assertEqual(
+ metadata.get("userinfo_endpoint"),
+ self.fake_server.userinfo_endpoint,
+ )
+
+ # Ensure the values are overridden correctly since these were configured
+ # explicitly
+ self.assertEqual(
+ metadata.authorization_endpoint,
+ EXPLICIT_ENDPOINT_CONFIG["authorization_endpoint"],
+ )
+ self.assertEqual(
+ metadata.token_endpoint, EXPLICIT_ENDPOINT_CONFIG["token_endpoint"]
+ )
+ self.assertEqual(metadata.jwks_uri, EXPLICIT_ENDPOINT_CONFIG["jwks_uri"])
+ self.assertEqual(
+ metadata.id_token_signing_alg_values_supported,
+ EXPLICIT_ENDPOINT_CONFIG["id_token_signing_alg_values_supported"],
+ )
+
+ # subsequent calls should be cached
+ self.reset_mocks()
+ self.get_success(self.provider.load_metadata())
+ self.fake_server.get_metadata_handler.assert_not_called()
+
@override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
def test_no_discovery(self) -> None:
"""When discovery is disabled, it should not try to load from discovery document."""
- self.get_success(self.provider.load_metadata())
+ metadata = self.get_success(self.provider.load_metadata())
self.fake_server.get_metadata_handler.assert_not_called()
+ # Ensure the values are overridden correctly since these were configured
+ # explicitly
+ self.assertEqual(
+ metadata.authorization_endpoint,
+ EXPLICIT_ENDPOINT_CONFIG["authorization_endpoint"],
+ )
+ self.assertEqual(
+ metadata.token_endpoint, EXPLICIT_ENDPOINT_CONFIG["token_endpoint"]
+ )
+ self.assertEqual(metadata.jwks_uri, EXPLICIT_ENDPOINT_CONFIG["jwks_uri"])
+ self.assertEqual(
+ metadata.id_token_signing_alg_values_supported,
+ EXPLICIT_ENDPOINT_CONFIG["id_token_signing_alg_values_supported"],
+ )
+
@override_config({"oidc_config": DEFAULT_CONFIG})
def test_load_jwks(self) -> None:
"""JWKS loading is done once (then cached) if used."""
@@ -427,6 +484,32 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertEqual(code_verifier, "")
self.assertEqual(redirect, "http://client/redirect")
+ @override_config(
+ {
+ "oidc_config": {
+ **DEFAULT_CONFIG,
+ "passthrough_authorization_parameters": ["additional_parameter"],
+ }
+ }
+ )
+ def test_passthrough_parameters(self) -> None:
+ """The redirect request has additional parameters, one is authorized, one is not"""
+ req = Mock(spec=["cookies", "args"])
+ req.cookies = []
+ req.args = {}
+ req.args[b"additional_parameter"] = ["a_value".encode("utf-8")]
+ req.args[b"not_authorized_parameter"] = ["any".encode("utf-8")]
+
+ url = urlparse(
+ self.get_success(
+ self.provider.handle_redirect_request(req, b"http://client/redirect")
+ )
+ )
+
+ params = parse_qs(url.query)
+ self.assertEqual(params["additional_parameter"], ["a_value"])
+ self.assertNotIn("not_authorized_parameter", params)
+
@override_config({"oidc_config": DEFAULT_CONFIG})
def test_redirect_request_with_code_challenge(self) -> None:
"""The redirect request has the right arguments & generates a valid session cookie."""
@@ -530,6 +613,24 @@ class OidcHandlerTestCase(HomeserverTestCase):
code_verifier = get_value_from_macaroon(macaroon, "code_verifier")
self.assertEqual(code_verifier, "")
+ @override_config(
+ {"oidc_config": {**DEFAULT_CONFIG, "redirect_uri": TEST_REDIRECT_URI}}
+ )
+ def test_redirect_request_with_overridden_redirect_uri(self) -> None:
+ """The authorization endpoint redirect has the overridden `redirect_uri` value."""
+ req = Mock(spec=["cookies"])
+ req.cookies = []
+
+ url = urlparse(
+ self.get_success(
+ self.provider.handle_redirect_request(req, b"http://client/redirect")
+ )
+ )
+
+ # Ensure that the redirect_uri in the returned url has been overridden.
+ params = parse_qs(url.query)
+ self.assertEqual(params["redirect_uri"], [TEST_REDIRECT_URI])
+
@override_config({"oidc_config": DEFAULT_CONFIG})
def test_callback_error(self) -> None:
"""Errors from the provider returned in the callback are displayed."""
@@ -901,6 +1002,81 @@ class OidcHandlerTestCase(HomeserverTestCase):
{
"oidc_config": {
**DEFAULT_CONFIG,
+ "redirect_uri": TEST_REDIRECT_URI,
+ }
+ }
+ )
+ def test_code_exchange_with_overridden_redirect_uri(self) -> None:
+ """Code exchange sends the overridden `redirect_uri` value to the token endpoint."""
+ # Set up a fake IdP with a token endpoint handler.
+ token = {
+ "type": "Bearer",
+ "access_token": "aabbcc",
+ }
+
+ self.fake_server.post_token_handler.side_effect = None
+ self.fake_server.post_token_handler.return_value = FakeResponse.json(
+ payload=token
+ )
+ code = "code"
+
+ # Exchange the code against the fake IdP.
+ self.get_success(self.provider._exchange_code(code, code_verifier=""))
+
+ # Check that the `redirect_uri` parameter provided matches our
+ # overridden config value.
+ kwargs = self.fake_server.request.call_args[1]
+ args = parse_qs(kwargs["data"].decode("utf-8"))
+ self.assertEqual(args["redirect_uri"], [TEST_REDIRECT_URI])
+
+ @override_config(
+ {
+ "oidc_config": {
+ **DEFAULT_CONFIG,
+ "redirect_uri": TEST_REDIRECT_URI,
+ }
+ }
+ )
+ def test_code_exchange_ignores_access_token(self) -> None:
+ """
+ Code exchange completes successfully and doesn't validate the `at_hash`
+ (access token hash) field of an ID token when the access token isn't
+ going to be used.
+
+ The access token won't be used in this test because Synapse (currently)
+ only needs it to fetch a user's metadata if it isn't included in the ID
+ token itself.
+
+ Because we have included "openid" in the requested scopes for this IdP
+ (see `SCOPES`), user metadata is included in the ID token. Thus the
+ access token isn't needed, and it's unnecessary for Synapse to validate
+ the access token.
+
+ This is a regression test for a situation where an upstream identity
+ provider was providing an invalid `at_hash` value, which Synapse errored
+ on, yet Synapse wasn't using the access token for anything.
+ """
+ # Exchange the code against the fake IdP.
+ userinfo = {
+ "sub": "foo",
+ "username": "foo",
+ "phone": "1234567",
+ }
+ with self.fake_server.id_token_override(
+ {
+ "at_hash": "invalid-hash",
+ }
+ ):
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+
+ # If no error was rendered, then we have success.
+ self.render_error.assert_not_called()
+
+ @override_config(
+ {
+ "oidc_config": {
+ **DEFAULT_CONFIG,
"user_mapping_provider": {
"module": __name__ + ".TestMappingProviderExtra"
},
@@ -1271,6 +1447,113 @@ class OidcHandlerTestCase(HomeserverTestCase):
{
"oidc_config": {
**DEFAULT_CONFIG,
+ "attribute_requirements": [
+ {"attribute": "test", "one_of": ["foo", "bar"]}
+ ],
+ }
+ }
+ )
+ def test_attribute_requirements_one_of_succeeds(self) -> None:
+ """Test that auth succeeds if userinfo attribute has multiple values and CONTAINS required value"""
+ # userinfo with "test": ["bar"] attribute should succeed.
+ userinfo = {
+ "sub": "tester",
+ "username": "tester",
+ "test": ["bar"],
+ }
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+
+ # check that the auth handler got called as expected
+ self.complete_sso_login.assert_called_once_with(
+ "@tester:test",
+ self.provider.idp_id,
+ request,
+ ANY,
+ None,
+ new_user=True,
+ auth_provider_session_id=None,
+ )
+
+ @override_config(
+ {
+ "oidc_config": {
+ **DEFAULT_CONFIG,
+ "attribute_requirements": [
+ {"attribute": "test", "one_of": ["foo", "bar"]}
+ ],
+ }
+ }
+ )
+ def test_attribute_requirements_one_of_fails(self) -> None:
+ """Test that auth fails if userinfo attribute has multiple values yet
+ DOES NOT CONTAIN a required value
+ """
+ # userinfo with "test": ["something else"] attribute should fail.
+ userinfo = {
+ "sub": "tester",
+ "username": "tester",
+ "test": ["something else"],
+ }
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_not_called()
+
+ @override_config(
+ {
+ "oidc_config": {
+ **DEFAULT_CONFIG,
+ "attribute_requirements": [{"attribute": "test"}],
+ }
+ }
+ )
+ def test_attribute_requirements_does_not_exist(self) -> None:
+ """OIDC login fails if the required attribute does not exist in the OIDC userinfo response."""
+ # userinfo lacking "test" attribute should fail.
+ userinfo = {
+ "sub": "tester",
+ "username": "tester",
+ }
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_not_called()
+
+ @override_config(
+ {
+ "oidc_config": {
+ **DEFAULT_CONFIG,
+ "attribute_requirements": [{"attribute": "test"}],
+ }
+ }
+ )
+ def test_attribute_requirements_exist(self) -> None:
+ """OIDC login succeeds if the required attribute exists (regardless of value)
+ in the OIDC userinfo response.
+ """
+ # userinfo with "test" attribute and random value should succeed.
+ userinfo = {
+ "sub": "tester",
+ "username": "tester",
+ "test": random_string(5), # value does not matter
+ }
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+
+ # check that the auth handler got called as expected
+ self.complete_sso_login.assert_called_once_with(
+ "@tester:test",
+ self.provider.idp_id,
+ request,
+ ANY,
+ None,
+ new_user=True,
+ auth_provider_session_id=None,
+ )
+
+ @override_config(
+ {
+ "oidc_config": {
+ **DEFAULT_CONFIG,
"attribute_requirements": [{"attribute": "test", "value": "foobar"}],
}
}
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index ed203eb299..d0351a8509 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -768,17 +768,6 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# Check that the callback has been called.
m.assert_called_once()
- # Set some email configuration so the test doesn't fail because of its absence.
- @override_config({"email": {"notif_from": "noreply@test"}})
- def test_3pid_allowed(self) -> None:
- """Tests that an is_3pid_allowed_callbacks forbidding a 3PID makes Synapse refuse
- to bind the new 3PID, and that one allowing a 3PID makes Synapse accept to bind
- the 3PID. Also checks that the module is passed a boolean indicating whether the
- user to bind this 3PID to is currently registering.
- """
- self._test_3pid_allowed("rin", False)
- self._test_3pid_allowed("kitay", True)
-
def test_displayname(self) -> None:
"""Tests that the get_displayname_for_registration callback can define the
display name of a user when registering.
@@ -829,66 +818,6 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# Check that the callback has been called.
m.assert_called_once()
- def _test_3pid_allowed(self, username: str, registration: bool) -> None:
- """Tests that the "is_3pid_allowed" module callback is called correctly, using
- either /register or /account URLs depending on the arguments.
-
- Args:
- username: The username to use for the test.
- registration: Whether to test with registration URLs.
- """
- self.hs.get_identity_handler().send_threepid_validation = AsyncMock( # type: ignore[method-assign]
- return_value=0
- )
-
- m = AsyncMock(return_value=False)
- self.hs.get_password_auth_provider().is_3pid_allowed_callbacks = [m]
-
- self.register_user(username, "password")
- tok = self.login(username, "password")
-
- if registration:
- url = "/register/email/requestToken"
- else:
- url = "/account/3pid/email/requestToken"
-
- channel = self.make_request(
- "POST",
- url,
- {
- "client_secret": "foo",
- "email": "foo@test.com",
- "send_attempt": 0,
- },
- access_token=tok,
- )
- self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result)
- self.assertEqual(
- channel.json_body["errcode"],
- Codes.THREEPID_DENIED,
- channel.json_body,
- )
-
- m.assert_called_once_with("email", "foo@test.com", registration)
-
- m = AsyncMock(return_value=True)
- self.hs.get_password_auth_provider().is_3pid_allowed_callbacks = [m]
-
- channel = self.make_request(
- "POST",
- url,
- {
- "client_secret": "foo",
- "email": "bar@test.com",
- "send_attempt": 0,
- },
- access_token=tok,
- )
- self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
- self.assertIn("sid", channel.json_body)
-
- m.assert_called_once_with("email", "bar@test.com", registration)
-
def _setup_get_name_for_registration(self, callback_name: str) -> Mock:
"""Registers either a get_username_for_registration callback or a
get_displayname_for_registration callback that appends "-foo" to the username the
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index cc630d606c..6b7bf112c2 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -23,14 +23,21 @@ from typing import Optional, cast
from unittest.mock import Mock, call
from parameterized import parameterized
-from signedjson.key import generate_signing_key
+from signedjson.key import (
+ encode_verify_key_base64,
+ generate_signing_key,
+ get_verify_key,
+)
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes, Membership, PresenceState
from synapse.api.presence import UserDevicePresenceState, UserPresenceState
-from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
-from synapse.events.builder import EventBuilder
+from synapse.api.room_versions import (
+ RoomVersion,
+)
+from synapse.crypto.event_signing import add_hashes_and_signatures
+from synapse.events import EventBase, make_event_from_dict
from synapse.federation.sender import FederationSender
from synapse.handlers.presence import (
BUSY_ONLINE_TIMEOUT,
@@ -45,18 +52,24 @@ from synapse.handlers.presence import (
handle_update,
)
from synapse.rest import admin
-from synapse.rest.client import room
+from synapse.rest.client import login, room, sync
from synapse.server import HomeServer
from synapse.storage.database import LoggingDatabaseConnection
+from synapse.storage.keys import FetchKeyResult
from synapse.types import JsonDict, UserID, get_domain_from_id
from synapse.util import Clock
from tests import unittest
from tests.replication._base import BaseMultiWorkerStreamTestCase
+from tests.unittest import override_config
class PresenceUpdateTestCase(unittest.HomeserverTestCase):
- servlets = [admin.register_servlets]
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ sync.register_servlets,
+ ]
def prepare(
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
@@ -425,6 +438,102 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
wheel_timer.insert.assert_not_called()
+ # `rc_presence` is set very high during unit tests to avoid ratelimiting
+ # subtly impacting unrelated tests. We set the ratelimiting back to a
+ # reasonable value for the tests specific to presence ratelimiting.
+ @override_config(
+ {"rc_presence": {"per_user": {"per_second": 0.1, "burst_count": 1}}}
+ )
+ def test_over_ratelimit_offline_to_online_to_unavailable(self) -> None:
+ """
+ Send a presence update, check that it went through, immediately send another one and
+ check that it was ignored.
+ """
+ self._test_ratelimit_offline_to_online_to_unavailable(ratelimited=True)
+
+ @override_config(
+ {"rc_presence": {"per_user": {"per_second": 0.1, "burst_count": 1}}}
+ )
+ def test_within_ratelimit_offline_to_online_to_unavailable(self) -> None:
+ """
+ Send a presence update, check that it went through, advancing time a sufficient amount,
+ send another presence update and check that it also worked.
+ """
+ self._test_ratelimit_offline_to_online_to_unavailable(ratelimited=False)
+
+ @override_config(
+ {"rc_presence": {"per_user": {"per_second": 0.1, "burst_count": 1}}}
+ )
+ def _test_ratelimit_offline_to_online_to_unavailable(
+ self, ratelimited: bool
+ ) -> None:
+ """Test rate limit for presence updates sent with sync requests.
+
+ Args:
+ ratelimited: Test rate limited case.
+ """
+ wheel_timer = Mock()
+ user_id = "@user:pass"
+ now = 5000000
+ sync_url = "/sync?access_token=%s&set_presence=%s"
+
+ # Register the user who syncs presence
+ user_id = self.register_user("user", "pass")
+ access_token = self.login("user", "pass")
+
+ # Get the handler (which kicks off a bunch of timers).
+ presence_handler = self.hs.get_presence_handler()
+
+ # Ensure the user is initially offline.
+ prev_state = UserPresenceState.default(user_id)
+ new_state = prev_state.copy_and_replace(
+ state=PresenceState.OFFLINE, last_active_ts=now
+ )
+
+ state, persist_and_notify, federation_ping = handle_update(
+ prev_state,
+ new_state,
+ is_mine=True,
+ wheel_timer=wheel_timer,
+ now=now,
+ persist=False,
+ )
+
+ # Check that the user is offline.
+ state = self.get_success(
+ presence_handler.get_state(UserID.from_string(user_id))
+ )
+ self.assertEqual(state.state, PresenceState.OFFLINE)
+
+ # Send sync request with set_presence=online.
+ channel = self.make_request("GET", sync_url % (access_token, "online"))
+ self.assertEqual(200, channel.code)
+
+ # Assert the user is now online.
+ state = self.get_success(
+ presence_handler.get_state(UserID.from_string(user_id))
+ )
+ self.assertEqual(state.state, PresenceState.ONLINE)
+
+ if not ratelimited:
+ # Advance time a sufficient amount to avoid rate limiting.
+ self.reactor.advance(30)
+
+ # Send another sync request with set_presence=unavailable.
+ channel = self.make_request("GET", sync_url % (access_token, "unavailable"))
+ self.assertEqual(200, channel.code)
+
+ state = self.get_success(
+ presence_handler.get_state(UserID.from_string(user_id))
+ )
+
+ if ratelimited:
+ # Assert the user is still online and presence update was ignored.
+ self.assertEqual(state.state, PresenceState.ONLINE)
+ else:
+ # Assert the user is now unavailable.
+ self.assertEqual(state.state, PresenceState.UNAVAILABLE)
+
class PresenceTimeoutTestCase(unittest.TestCase):
"""Tests different timers and that the timer does not change `status_msg` of user."""
@@ -1107,7 +1216,9 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
),
]
],
- name_func=lambda testcase_func, param_num, params: f"{testcase_func.__name__}_{param_num}_{'workers' if params.args[5] else 'monolith'}",
+ name_func=lambda testcase_func,
+ param_num,
+ params: f"{testcase_func.__name__}_{param_num}_{'workers' if params.args[5] else 'monolith'}",
)
@unittest.override_config({"experimental_features": {"msc3026_enabled": True}})
def test_set_presence_from_syncing_multi_device(
@@ -1343,7 +1454,9 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
),
]
],
- name_func=lambda testcase_func, param_num, params: f"{testcase_func.__name__}_{param_num}_{'workers' if params.args[4] else 'monolith'}",
+ name_func=lambda testcase_func,
+ param_num,
+ params: f"{testcase_func.__name__}_{param_num}_{'workers' if params.args[4] else 'monolith'}",
)
@unittest.override_config({"experimental_features": {"msc3026_enabled": True}})
def test_set_presence_from_non_syncing_multi_device(
@@ -1821,6 +1934,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
# self.event_builder_for_2.hostname = "test2"
self.store = hs.get_datastores().main
+ self.storage_controllers = hs.get_storage_controllers()
self.state = hs.get_state_handler()
self._event_auth_handler = hs.get_event_auth_handler()
@@ -1936,29 +2050,35 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
hostname = get_domain_from_id(user_id)
- room_version = self.get_success(self.store.get_room_version_id(room_id))
+ room_version = self.get_success(self.store.get_room_version(room_id))
- builder = EventBuilder(
- state=self.state,
- event_auth_handler=self._event_auth_handler,
- store=self.store,
- clock=self.clock,
- hostname=hostname,
- signing_key=self.random_signing_key,
- room_version=KNOWN_ROOM_VERSIONS[room_version],
- room_id=room_id,
- type=EventTypes.Member,
- sender=user_id,
- state_key=user_id,
- content={"membership": Membership.JOIN},
+ state_map = self.get_success(
+ self.storage_controllers.state.get_current_state(room_id)
)
- prev_event_ids = self.get_success(
- self.store.get_latest_event_ids_in_room(room_id)
+ # Figure out what the forward extremities in the room are (the most recent
+ # events that aren't tied into the DAG)
+ forward_extremity_event_ids = self.get_success(
+ self.hs.get_datastores().main.get_latest_event_ids_in_room(room_id)
)
- event = self.get_success(
- builder.build(prev_event_ids=list(prev_event_ids), auth_event_ids=None)
+ event = self.create_fake_event_from_remote_server(
+ remote_server_name=hostname,
+ event_dict={
+ "room_id": room_id,
+ "sender": user_id,
+ "type": EventTypes.Member,
+ "state_key": user_id,
+ "depth": 1000,
+ "origin_server_ts": 1,
+ "content": {"membership": Membership.JOIN},
+ "auth_events": [
+ state_map[(EventTypes.Create, "")].event_id,
+ state_map[(EventTypes.JoinRules, "")].event_id,
+ ],
+ "prev_events": list(forward_extremity_event_ids),
+ },
+ room_version=room_version,
)
self.get_success(self.federation_event_handler.on_receive_pdu(hostname, event))
@@ -1966,3 +2086,50 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
# Check that it was successfully persisted.
self.get_success(self.store.get_event(event.event_id))
self.get_success(self.store.get_event(event.event_id))
+
+ def create_fake_event_from_remote_server(
+ self, remote_server_name: str, event_dict: JsonDict, room_version: RoomVersion
+ ) -> EventBase:
+ """
+ This is similar to what `FederatingHomeserverTestCase` is doing but we don't
+ need all of the extra baggage and we want to be able to create an event from
+ many remote servers.
+ """
+
+ # poke the other server's signing key into the key store, so that we don't
+ # make requests for it
+ other_server_signature_key = generate_signing_key("test")
+ verify_key = get_verify_key(other_server_signature_key)
+ verify_key_id = "%s:%s" % (verify_key.alg, verify_key.version)
+
+ self.get_success(
+ self.hs.get_datastores().main.store_server_keys_response(
+ remote_server_name,
+ from_server=remote_server_name,
+ ts_added_ms=self.clock.time_msec(),
+ verify_keys={
+ verify_key_id: FetchKeyResult(
+ verify_key=verify_key,
+ valid_until_ts=self.clock.time_msec() + 10000,
+ ),
+ },
+ response_json={
+ "verify_keys": {
+ verify_key_id: {"key": encode_verify_key_base64(verify_key)}
+ }
+ },
+ )
+ )
+
+ add_hashes_and_signatures(
+ room_version=room_version,
+ event_dict=event_dict,
+ signature_name=remote_server_name,
+ signing_key=other_server_signature_key,
+ )
+ event = make_event_from_dict(
+ event_dict,
+ room_version=room_version,
+ )
+
+ return event
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index cb1c6fbb80..2b9b56da95 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -369,6 +369,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
time_now_ms=self.clock.time_msec(),
upload_name=None,
filesystem_id="xyz",
+ sha256="abcdefg12345",
)
)
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index 92487692db..99bd0de834 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -588,6 +588,29 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
d = self.store.is_support_user(user_id)
self.assertFalse(self.get_success(d))
+ def test_underscore_localpart_rejected_by_default(self) -> None:
+ for invalid_user_id in ("_", "_prefixed"):
+ with self.subTest(invalid_user_id=invalid_user_id):
+ self.get_failure(
+ self.handler.register_user(localpart=invalid_user_id),
+ SynapseError,
+ )
+
+ @override_config(
+ {
+ "allow_underscore_prefixed_localpart": True,
+ }
+ )
+ def test_underscore_localpart_allowed_if_configured(self) -> None:
+ for valid_user_id in ("_", "_prefixed"):
+ with self.subTest(valid_user_id=valid_user_id):
+ user_id = self.get_success(
+ self.handler.register_user(
+ localpart=valid_user_id,
+ ),
+ )
+ self.assertEqual(user_id, f"@{valid_user_id}:test")
+
def test_invalid_user_id(self) -> None:
invalid_user_id = "^abcd"
self.get_failure(
@@ -715,6 +738,41 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.handler.register_user(localpart="bobflimflob", auth_provider_id="saml")
)
+ def test_register_default_user_type(self) -> None:
+ """Test that the default user type is none when registering a user."""
+ user_id = self.get_success(self.handler.register_user(localpart="user"))
+ user_info = self.get_success(self.store.get_user_by_id(user_id))
+ assert user_info is not None
+ self.assertEqual(user_info.user_type, None)
+
+ def test_register_extra_user_types_valid(self) -> None:
+ """
+ Test that the specified user type is set correctly when registering a user.
+ n.b. No validation is done on the user type, so this test
+ is only to ensure that the user type can be set to any value.
+ """
+ user_id = self.get_success(
+ self.handler.register_user(localpart="user", user_type="anyvalue")
+ )
+ user_info = self.get_success(self.store.get_user_by_id(user_id))
+ assert user_info is not None
+ self.assertEqual(user_info.user_type, "anyvalue")
+
+ @override_config(
+ {
+ "user_types": {
+ "extra_user_types": ["extra1", "extra2"],
+ "default_user_type": "extra1",
+ }
+ }
+ )
+ def test_register_extra_user_types_with_default(self) -> None:
+ """Test that the default_user_type in config is set correctly when registering a user."""
+ user_id = self.get_success(self.handler.register_user(localpart="user"))
+ user_info = self.get_success(self.store.get_user_by_id(user_id))
+ assert user_info is not None
+ self.assertEqual(user_info.user_type, "extra1")
+
async def get_or_create_user(
self,
requester: Requester,
diff --git a/tests/handlers/test_room_list.py b/tests/handlers/test_room_list.py
index 4d22ef98c2..45cef09b22 100644
--- a/tests/handlers/test_room_list.py
+++ b/tests/handlers/test_room_list.py
@@ -6,6 +6,7 @@ from synapse.rest.client import directory, login, room
from synapse.types import JsonDict
from tests import unittest
+from tests.utils import default_config
class RoomListHandlerTestCase(unittest.HomeserverTestCase):
@@ -30,6 +31,11 @@ class RoomListHandlerTestCase(unittest.HomeserverTestCase):
assert channel.code == HTTPStatus.OK, f"couldn't publish room: {channel.result}"
return room_id
+ def default_config(self) -> JsonDict:
+ config = default_config("test")
+ config["room_list_publication_rules"] = [{"action": "allow"}]
+ return config
+
def test_acls_applied_to_room_directory_results(self) -> None:
"""
Creates 3 rooms. Room 2 has an ACL that only permits the homeservers
diff --git a/tests/handlers/test_room_member.py b/tests/handlers/test_room_member.py
index 213a66ed1a..d87fe9d62c 100644
--- a/tests/handlers/test_room_member.py
+++ b/tests/handlers/test_room_member.py
@@ -5,10 +5,13 @@ from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
import synapse.rest.client.login
import synapse.rest.client.room
-from synapse.api.constants import EventTypes, Membership
-from synapse.api.errors import LimitExceededError, SynapseError
+from synapse.api.constants import AccountDataTypes, EventTypes, Membership
+from synapse.api.errors import Codes, LimitExceededError, SynapseError
from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.events import FrozenEventV3
+from synapse.federation.federation_base import (
+ event_from_pdu_json,
+)
from synapse.federation.federation_client import SendJoinResult
from synapse.server import HomeServer
from synapse.types import UserID, create_requester
@@ -172,20 +175,25 @@ class TestJoinsLimitedByPerRoomRateLimiter(FederatingHomeserverTestCase):
)
)
- with patch.object(
- self.handler.federation_handler.federation_client,
- "make_membership_event",
- mock_make_membership_event,
- ), patch.object(
- self.handler.federation_handler.federation_client,
- "send_join",
- mock_send_join,
- ), patch(
- "synapse.event_auth._is_membership_change_allowed",
- return_value=None,
- ), patch(
- "synapse.handlers.federation_event.check_state_dependent_auth_rules",
- return_value=None,
+ with (
+ patch.object(
+ self.handler.federation_handler.federation_client,
+ "make_membership_event",
+ mock_make_membership_event,
+ ),
+ patch.object(
+ self.handler.federation_handler.federation_client,
+ "send_join",
+ mock_send_join,
+ ),
+ patch(
+ "synapse.event_auth._is_membership_change_allowed",
+ return_value=None,
+ ),
+ patch(
+ "synapse.handlers.federation_event.check_state_dependent_auth_rules",
+ return_value=None,
+ ),
):
self.get_success(
self.handler.update_membership(
@@ -380,9 +388,29 @@ class RoomMemberMasterHandlerTestCase(HomeserverTestCase):
)
def test_forget_when_not_left(self) -> None:
- """Tests that a user cannot not forgets a room that has not left."""
+ """Tests that a user cannot forget a room that they are still in."""
self.get_failure(self.handler.forget(self.alice_ID, self.room_id), SynapseError)
+ def test_nonlocal_room_user_action(self) -> None:
+ """
+ Test that non-local user ids cannot perform room actions through
+ this homeserver.
+ """
+ alien_user_id = UserID.from_string("@cheeky_monkey:matrix.org")
+ bad_room_id = f"{self.room_id}+BAD_ID"
+
+ exc = self.get_failure(
+ self.handler.update_membership(
+ create_requester(self.alice),
+ alien_user_id,
+ bad_room_id,
+ "unban",
+ ),
+ SynapseError,
+ ).value
+
+ self.assertEqual(exc.errcode, Codes.BAD_JSON)
+
def test_rejoin_forgotten_by_user(self) -> None:
"""Test that a user that has forgotten a room can do a re-join.
The room was not forgotten from the local server.
@@ -428,3 +456,165 @@ class RoomMemberMasterHandlerTestCase(HomeserverTestCase):
new_count = rows[0][0]
self.assertEqual(initial_count, new_count)
+
+
+class TestInviteFiltering(FederatingHomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ synapse.rest.client.login.register_servlets,
+ synapse.rest.client.room.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.handler = hs.get_room_member_handler()
+ self.fed_handler = hs.get_federation_handler()
+ self.store = hs.get_datastores().main
+
+ # Create three users.
+ self.alice = self.register_user("alice", "pass")
+ self.alice_token = self.login("alice", "pass")
+ self.bob = self.register_user("bob", "pass")
+ self.bob_token = self.login("bob", "pass")
+
+ @override_config({"experimental_features": {"msc4155_enabled": True}})
+ def test_msc4155_block_invite_local(self) -> None:
+ """Test that MSC4155 will block a user from being invited to a room"""
+ room_id = self.helper.create_room_as(self.alice, tok=self.alice_token)
+
+ self.get_success(
+ self.store.add_account_data_for_user(
+ self.bob,
+ AccountDataTypes.MSC4155_INVITE_PERMISSION_CONFIG,
+ {
+ "blocked_users": [self.alice],
+ },
+ )
+ )
+
+ f = self.get_failure(
+ self.handler.update_membership(
+ requester=create_requester(self.alice),
+ target=UserID.from_string(self.bob),
+ room_id=room_id,
+ action=Membership.INVITE,
+ ),
+ SynapseError,
+ ).value
+ self.assertEqual(f.code, 403)
+ self.assertEqual(f.errcode, "ORG.MATRIX.MSC4155.M_INVITE_BLOCKED")
+
+ @override_config({"experimental_features": {"msc4155_enabled": False}})
+ def test_msc4155_disabled_allow_invite_local(self) -> None:
+ """Test that MSC4155 will block a user from being invited to a room"""
+ room_id = self.helper.create_room_as(self.alice, tok=self.alice_token)
+
+ self.get_success(
+ self.store.add_account_data_for_user(
+ self.bob,
+ AccountDataTypes.MSC4155_INVITE_PERMISSION_CONFIG,
+ {
+ "blocked_users": [self.alice],
+ },
+ )
+ )
+
+ self.get_success(
+ self.handler.update_membership(
+ requester=create_requester(self.alice),
+ target=UserID.from_string(self.bob),
+ room_id=room_id,
+ action=Membership.INVITE,
+ ),
+ )
+
+ @override_config({"experimental_features": {"msc4155_enabled": True}})
+ def test_msc4155_block_invite_remote(self) -> None:
+ """Test that MSC4155 will block a remote user from being invited to a room"""
+ # A remote user who sends the invite
+ remote_server = "otherserver"
+ remote_user = "@otheruser:" + remote_server
+
+ self.get_success(
+ self.store.add_account_data_for_user(
+ self.bob,
+ AccountDataTypes.MSC4155_INVITE_PERMISSION_CONFIG,
+ {"blocked_users": [remote_user]},
+ )
+ )
+
+ room_id = self.helper.create_room_as(
+ room_creator=self.alice, tok=self.alice_token
+ )
+ room_version = self.get_success(self.store.get_room_version(room_id))
+
+ invite_event = event_from_pdu_json(
+ {
+ "type": EventTypes.Member,
+ "content": {"membership": "invite"},
+ "room_id": room_id,
+ "sender": remote_user,
+ "state_key": self.bob,
+ "depth": 32,
+ "prev_events": [],
+ "auth_events": [],
+ "origin_server_ts": self.clock.time_msec(),
+ },
+ room_version,
+ )
+
+ f = self.get_failure(
+ self.fed_handler.on_invite_request(
+ remote_server,
+ invite_event,
+ invite_event.room_version,
+ ),
+ SynapseError,
+ ).value
+ self.assertEqual(f.code, 403)
+ self.assertEqual(f.errcode, "ORG.MATRIX.MSC4155.M_INVITE_BLOCKED")
+
+ @override_config({"experimental_features": {"msc4155_enabled": True}})
+ def test_msc4155_block_invite_remote_server(self) -> None:
+ """Test that MSC4155 will block a remote server's user from being invited to a room"""
+ # A remote user who sends the invite
+ remote_server = "otherserver"
+ remote_user = "@otheruser:" + remote_server
+
+ self.get_success(
+ self.store.add_account_data_for_user(
+ self.bob,
+ AccountDataTypes.MSC4155_INVITE_PERMISSION_CONFIG,
+ {"blocked_servers": [remote_server]},
+ )
+ )
+
+ room_id = self.helper.create_room_as(
+ room_creator=self.alice, tok=self.alice_token
+ )
+ room_version = self.get_success(self.store.get_room_version(room_id))
+
+ invite_event = event_from_pdu_json(
+ {
+ "type": EventTypes.Member,
+ "content": {"membership": "invite"},
+ "room_id": room_id,
+ "sender": remote_user,
+ "state_key": self.bob,
+ "depth": 32,
+ "prev_events": [],
+ "auth_events": [],
+ "origin_server_ts": self.clock.time_msec(),
+ },
+ room_version,
+ )
+
+ f = self.get_failure(
+ self.fed_handler.on_invite_request(
+ remote_server,
+ invite_event,
+ invite_event.room_version,
+ ),
+ SynapseError,
+ ).value
+ self.assertEqual(f.code, 403)
+ self.assertEqual(f.errcode, "ORG.MATRIX.MSC4155.M_INVITE_BLOCKED")
diff --git a/tests/handlers/test_room_policy.py b/tests/handlers/test_room_policy.py
new file mode 100644
index 0000000000..26642c18ea
--- /dev/null
+++ b/tests/handlers/test_room_policy.py
@@ -0,0 +1,226 @@
+#
+# This file is licensed under the Affero General Public License (AGPL) version 3.
+#
+# Copyright (C) 2025 New Vector, Ltd
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# See the GNU Affero General Public License for more details:
+# <https://www.gnu.org/licenses/agpl-3.0.html>.
+#
+#
+from typing import Optional
+from unittest import mock
+
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.events import EventBase, make_event_from_dict
+from synapse.rest import admin
+from synapse.rest.client import login, room
+from synapse.server import HomeServer
+from synapse.types import JsonDict, UserID
+from synapse.types.handlers.policy_server import RECOMMENDATION_OK, RECOMMENDATION_SPAM
+from synapse.util import Clock
+
+from tests import unittest
+from tests.test_utils import event_injection
+
+
+class RoomPolicyTestCase(unittest.FederatingHomeserverTestCase):
+ """Tests room policy handler."""
+
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+ # mock out the federation transport client
+ self.mock_federation_transport_client = mock.Mock(
+ spec=["get_policy_recommendation_for_pdu"]
+ )
+ self.mock_federation_transport_client.get_policy_recommendation_for_pdu = (
+ mock.AsyncMock()
+ )
+ return super().setup_test_homeserver(
+ federation_transport_client=self.mock_federation_transport_client
+ )
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.hs = hs
+ self.handler = hs.get_room_policy_handler()
+ main_store = self.hs.get_datastores().main
+
+ # Create a room
+ self.creator = self.register_user("creator", "test1234")
+ self.creator_token = self.login("creator", "test1234")
+ self.room_id = self.helper.create_room_as(
+ room_creator=self.creator, tok=self.creator_token
+ )
+ room_version = self.get_success(main_store.get_room_version(self.room_id))
+
+ # Create some sample events
+ self.spammy_event = make_event_from_dict(
+ room_version=room_version,
+ internal_metadata_dict={},
+ event_dict={
+ "room_id": self.room_id,
+ "type": "m.room.message",
+ "sender": "@spammy:example.org",
+ "content": {
+ "msgtype": "m.text",
+ "body": "This is a spammy event.",
+ },
+ },
+ )
+ self.not_spammy_event = make_event_from_dict(
+ room_version=room_version,
+ internal_metadata_dict={},
+ event_dict={
+ "room_id": self.room_id,
+ "type": "m.room.message",
+ "sender": "@not_spammy:example.org",
+ "content": {
+ "msgtype": "m.text",
+ "body": "This is a NOT spammy event.",
+ },
+ },
+ )
+
+ # Prepare the policy server mock to decide spam vs not spam on those events
+ self.call_count = 0
+
+ async def get_policy_recommendation_for_pdu(
+ destination: str,
+ pdu: EventBase,
+ timeout: Optional[int] = None,
+ ) -> JsonDict:
+ self.call_count += 1
+ self.assertEqual(destination, self.OTHER_SERVER_NAME)
+ if pdu.event_id == self.spammy_event.event_id:
+ return {"recommendation": RECOMMENDATION_SPAM}
+ elif pdu.event_id == self.not_spammy_event.event_id:
+ return {"recommendation": RECOMMENDATION_OK}
+ else:
+ self.fail("Unexpected event ID")
+
+ self.mock_federation_transport_client.get_policy_recommendation_for_pdu.side_effect = get_policy_recommendation_for_pdu
+
+ def _add_policy_server_to_room(self) -> None:
+ # Inject a member event into the room
+ policy_user_id = f"@policy:{self.OTHER_SERVER_NAME}"
+ self.get_success(
+ event_injection.inject_member_event(
+ self.hs, self.room_id, policy_user_id, "join"
+ )
+ )
+ self.helper.send_state(
+ self.room_id,
+ "org.matrix.msc4284.policy",
+ {
+ "via": self.OTHER_SERVER_NAME,
+ },
+ tok=self.creator_token,
+ state_key="",
+ )
+
+ def test_no_policy_event_set(self) -> None:
+ # We don't need to modify the room state at all - we're testing the default
+ # case where a room doesn't use a policy server.
+ ok = self.get_success(self.handler.is_event_allowed(self.spammy_event))
+ self.assertEqual(ok, True)
+ self.assertEqual(self.call_count, 0)
+
+ def test_empty_policy_event_set(self) -> None:
+ self.helper.send_state(
+ self.room_id,
+ "org.matrix.msc4284.policy",
+ {
+ # empty content (no `via`)
+ },
+ tok=self.creator_token,
+ state_key="",
+ )
+
+ ok = self.get_success(self.handler.is_event_allowed(self.spammy_event))
+ self.assertEqual(ok, True)
+ self.assertEqual(self.call_count, 0)
+
+ def test_nonstring_policy_event_set(self) -> None:
+ self.helper.send_state(
+ self.room_id,
+ "org.matrix.msc4284.policy",
+ {
+ "via": 42, # should be a server name
+ },
+ tok=self.creator_token,
+ state_key="",
+ )
+
+ ok = self.get_success(self.handler.is_event_allowed(self.spammy_event))
+ self.assertEqual(ok, True)
+ self.assertEqual(self.call_count, 0)
+
+ def test_self_policy_event_set(self) -> None:
+ self.helper.send_state(
+ self.room_id,
+ "org.matrix.msc4284.policy",
+ {
+ # We ignore events when the policy server is ourselves (for now?)
+ "via": (UserID.from_string(self.creator)).domain,
+ },
+ tok=self.creator_token,
+ state_key="",
+ )
+
+ ok = self.get_success(self.handler.is_event_allowed(self.spammy_event))
+ self.assertEqual(ok, True)
+ self.assertEqual(self.call_count, 0)
+
+ def test_invalid_server_policy_event_set(self) -> None:
+ self.helper.send_state(
+ self.room_id,
+ "org.matrix.msc4284.policy",
+ {
+ "via": "|this| is *not* a (valid) server name.com",
+ },
+ tok=self.creator_token,
+ state_key="",
+ )
+
+ ok = self.get_success(self.handler.is_event_allowed(self.spammy_event))
+ self.assertEqual(ok, True)
+ self.assertEqual(self.call_count, 0)
+
+ def test_not_in_room_policy_event_set(self) -> None:
+ self.helper.send_state(
+ self.room_id,
+ "org.matrix.msc4284.policy",
+ {
+ "via": f"x.{self.OTHER_SERVER_NAME}",
+ },
+ tok=self.creator_token,
+ state_key="",
+ )
+
+ ok = self.get_success(self.handler.is_event_allowed(self.spammy_event))
+ self.assertEqual(ok, True)
+ self.assertEqual(self.call_count, 0)
+
+ def test_spammy_event_is_spam(self) -> None:
+ self._add_policy_server_to_room()
+
+ ok = self.get_success(self.handler.is_event_allowed(self.spammy_event))
+ self.assertEqual(ok, False)
+ self.assertEqual(self.call_count, 1)
+
+ def test_not_spammy_event_is_not_spam(self) -> None:
+ self._add_policy_server_to_room()
+
+ ok = self.get_success(self.handler.is_event_allowed(self.not_spammy_event))
+ self.assertEqual(ok, True)
+ self.assertEqual(self.call_count, 1)
diff --git a/tests/handlers/test_room_summary.py b/tests/handlers/test_room_summary.py
index 244a4e7689..b55fa1a8fd 100644
--- a/tests/handlers/test_room_summary.py
+++ b/tests/handlers/test_room_summary.py
@@ -757,6 +757,54 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
)
self._assert_hierarchy(result, expected)
+ def test_fed_root(self) -> None:
+ """
+ Test if requested room is available over federation.
+ """
+ fed_hostname = self.hs.hostname + "2"
+ fed_space = "#fed_space:" + fed_hostname
+ fed_subroom = "#fed_sub_room:" + fed_hostname
+
+ requested_room_entry = _RoomEntry(
+ fed_space,
+ {
+ "room_id": fed_space,
+ "world_readable": True,
+ "room_type": RoomTypes.SPACE,
+ },
+ [
+ {
+ "type": EventTypes.SpaceChild,
+ "room_id": fed_space,
+ "state_key": fed_subroom,
+ "content": {"via": [fed_hostname]},
+ }
+ ],
+ )
+ child_room = {
+ "room_id": fed_subroom,
+ "world_readable": True,
+ }
+
+ async def summarize_remote_room_hierarchy(
+ _self: Any, room: Any, suggested_only: bool
+ ) -> Tuple[Optional[_RoomEntry], Dict[str, JsonDict], Set[str]]:
+ return requested_room_entry, {fed_subroom: child_room}, set()
+
+ expected = [
+ (fed_space, [fed_subroom]),
+ (fed_subroom, ()),
+ ]
+
+ with mock.patch(
+ "synapse.handlers.room_summary.RoomSummaryHandler._summarize_remote_room_hierarchy",
+ new=summarize_remote_room_hierarchy,
+ ):
+ result = self.get_success(
+ self.handler.get_room_hierarchy(create_requester(self.user), fed_space)
+ )
+ self._assert_hierarchy(result, expected)
+
def test_fed_filtering(self) -> None:
"""
Rooms returned over federation should be properly filtered to only include
diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py
deleted file mode 100644
index 6ab8fda6e7..0000000000
--- a/tests/handlers/test_saml.py
+++ /dev/null
@@ -1,381 +0,0 @@
-#
-# This file is licensed under the Affero General Public License (AGPL) version 3.
-#
-# Copyright 2020 The Matrix.org Foundation C.I.C.
-# Copyright (C) 2023 New Vector, Ltd
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Affero General Public License as
-# published by the Free Software Foundation, either version 3 of the
-# License, or (at your option) any later version.
-#
-# See the GNU Affero General Public License for more details:
-# <https://www.gnu.org/licenses/agpl-3.0.html>.
-#
-# Originally licensed under the Apache License, Version 2.0:
-# <http://www.apache.org/licenses/LICENSE-2.0>.
-#
-# [This file includes modifications made by New Vector Limited]
-#
-#
-
-from typing import Any, Dict, Optional, Set, Tuple
-from unittest.mock import AsyncMock, Mock
-
-import attr
-
-from twisted.test.proto_helpers import MemoryReactor
-
-from synapse.api.errors import RedirectException
-from synapse.module_api import ModuleApi
-from synapse.server import HomeServer
-from synapse.types import JsonDict
-from synapse.util import Clock
-
-from tests.unittest import HomeserverTestCase, override_config
-
-# Check if we have the dependencies to run the tests.
-try:
- import saml2.config
- import saml2.response
- from saml2.sigver import SigverError
-
- has_saml2 = True
-
- # pysaml2 can be installed and imported, but might not be able to find xmlsec1.
- config = saml2.config.SPConfig()
- try:
- config.load({"metadata": {}})
- has_xmlsec1 = True
- except SigverError:
- has_xmlsec1 = False
-except ImportError:
- has_saml2 = False
- has_xmlsec1 = False
-
-# These are a few constants that are used as config parameters in the tests.
-BASE_URL = "https://synapse/"
-
-
-@attr.s
-class FakeAuthnResponse:
- ava = attr.ib(type=dict)
- assertions = attr.ib(type=list, factory=list)
- in_response_to = attr.ib(type=Optional[str], default=None)
-
-
-class TestMappingProvider:
- def __init__(self, config: None, module: ModuleApi):
- pass
-
- @staticmethod
- def parse_config(config: JsonDict) -> None:
- return None
-
- @staticmethod
- def get_saml_attributes(config: None) -> Tuple[Set[str], Set[str]]:
- return {"uid"}, {"displayName"}
-
- def get_remote_user_id(
- self, saml_response: "saml2.response.AuthnResponse", client_redirect_url: str
- ) -> str:
- return saml_response.ava["uid"]
-
- def saml_response_to_user_attributes(
- self,
- saml_response: "saml2.response.AuthnResponse",
- failures: int,
- client_redirect_url: str,
- ) -> dict:
- localpart = saml_response.ava["username"] + (str(failures) if failures else "")
- return {"mxid_localpart": localpart, "displayname": None}
-
-
-class TestRedirectMappingProvider(TestMappingProvider):
- def saml_response_to_user_attributes(
- self,
- saml_response: "saml2.response.AuthnResponse",
- failures: int,
- client_redirect_url: str,
- ) -> dict:
- raise RedirectException(b"https://custom-saml-redirect/")
-
-
-class SamlHandlerTestCase(HomeserverTestCase):
- def default_config(self) -> Dict[str, Any]:
- config = super().default_config()
- config["public_baseurl"] = BASE_URL
- saml_config: Dict[str, Any] = {
- "sp_config": {"metadata": {}},
- # Disable grandfathering.
- "grandfathered_mxid_source_attribute": None,
- "user_mapping_provider": {"module": __name__ + ".TestMappingProvider"},
- }
-
- # Update this config with what's in the default config so that
- # override_config works as expected.
- saml_config.update(config.get("saml2_config", {}))
- config["saml2_config"] = saml_config
-
- return config
-
- def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
- hs = self.setup_test_homeserver()
-
- self.handler = hs.get_saml_handler()
-
- # Reduce the number of attempts when generating MXIDs.
- sso_handler = hs.get_sso_handler()
- sso_handler._MAP_USERNAME_RETRIES = 3
-
- return hs
-
- if not has_saml2:
- skip = "Requires pysaml2"
- elif not has_xmlsec1:
- skip = "Requires xmlsec1"
-
- def test_map_saml_response_to_user(self) -> None:
- """Ensure that mapping the SAML response returned from a provider to an MXID works properly."""
-
- # stub out the auth handler
- auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign]
-
- # send a mocked-up SAML response to the callback
- saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})
- request = _mock_request()
- self.get_success(
- self.handler._handle_authn_response(request, saml_response, "redirect_uri")
- )
-
- # check that the auth handler got called as expected
- auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user:test",
- "saml",
- request,
- "redirect_uri",
- None,
- new_user=True,
- auth_provider_session_id=None,
- )
-
- @override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
- def test_map_saml_response_to_existing_user(self) -> None:
- """Existing users can log in with SAML account."""
- store = self.hs.get_datastores().main
- self.get_success(
- store.register_user(user_id="@test_user:test", password_hash=None)
- )
-
- # stub out the auth handler
- auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign]
-
- # Map a user via SSO.
- saml_response = FakeAuthnResponse(
- {"uid": "tester", "mxid": ["test_user"], "username": "test_user"}
- )
- request = _mock_request()
- self.get_success(
- self.handler._handle_authn_response(request, saml_response, "")
- )
-
- # check that the auth handler got called as expected
- auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user:test",
- "saml",
- request,
- "",
- None,
- new_user=False,
- auth_provider_session_id=None,
- )
-
- # Subsequent calls should map to the same mxid.
- auth_handler.complete_sso_login.reset_mock()
- self.get_success(
- self.handler._handle_authn_response(request, saml_response, "")
- )
- auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user:test",
- "saml",
- request,
- "",
- None,
- new_user=False,
- auth_provider_session_id=None,
- )
-
- def test_map_saml_response_to_invalid_localpart(self) -> None:
- """If the mapping provider generates an invalid localpart it should be rejected."""
-
- # stub out the auth handler
- auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign]
-
- # mock out the error renderer too
- sso_handler = self.hs.get_sso_handler()
- sso_handler.render_error = Mock(return_value=None) # type: ignore[method-assign]
-
- saml_response = FakeAuthnResponse({"uid": "test", "username": "föö"})
- request = _mock_request()
- self.get_success(
- self.handler._handle_authn_response(request, saml_response, ""),
- )
- sso_handler.render_error.assert_called_once_with(
- request, "mapping_error", "localpart is invalid: föö"
- )
- auth_handler.complete_sso_login.assert_not_called()
-
- def test_map_saml_response_to_user_retries(self) -> None:
- """The mapping provider can retry generating an MXID if the MXID is already in use."""
-
- # stub out the auth handler and error renderer
- auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign]
- sso_handler = self.hs.get_sso_handler()
- sso_handler.render_error = Mock(return_value=None) # type: ignore[method-assign]
-
- # register a user to occupy the first-choice MXID
- store = self.hs.get_datastores().main
- self.get_success(
- store.register_user(user_id="@test_user:test", password_hash=None)
- )
-
- # send the fake SAML response
- saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"})
- request = _mock_request()
- self.get_success(
- self.handler._handle_authn_response(request, saml_response, ""),
- )
-
- # test_user is already taken, so test_user1 gets registered instead.
- auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user1:test",
- "saml",
- request,
- "",
- None,
- new_user=True,
- auth_provider_session_id=None,
- )
- auth_handler.complete_sso_login.reset_mock()
-
- # Register all of the potential mxids for a particular SAML username.
- self.get_success(
- store.register_user(user_id="@tester:test", password_hash=None)
- )
- for i in range(1, 3):
- self.get_success(
- store.register_user(user_id="@tester%d:test" % i, password_hash=None)
- )
-
- # Now attempt to map to a username, this will fail since all potential usernames are taken.
- saml_response = FakeAuthnResponse({"uid": "tester", "username": "tester"})
- self.get_success(
- self.handler._handle_authn_response(request, saml_response, ""),
- )
- sso_handler.render_error.assert_called_once_with(
- request,
- "mapping_error",
- "Unable to generate a Matrix ID from the SSO response",
- )
- auth_handler.complete_sso_login.assert_not_called()
-
- @override_config(
- {
- "saml2_config": {
- "user_mapping_provider": {
- "module": __name__ + ".TestRedirectMappingProvider"
- },
- }
- }
- )
- def test_map_saml_response_redirect(self) -> None:
- """Test a mapping provider that raises a RedirectException"""
-
- saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"})
- request = _mock_request()
- e = self.get_failure(
- self.handler._handle_authn_response(request, saml_response, ""),
- RedirectException,
- )
- self.assertEqual(e.value.location, b"https://custom-saml-redirect/")
-
- @override_config(
- {
- "saml2_config": {
- "attribute_requirements": [
- {"attribute": "userGroup", "value": "staff"},
- {"attribute": "department", "value": "sales"},
- ],
- },
- }
- )
- def test_attribute_requirements(self) -> None:
- """The required attributes must be met from the SAML response."""
-
- # stub out the auth handler
- auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign]
-
- # The response doesn't have the proper userGroup or department.
- saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})
- request = _mock_request()
- self.get_success(
- self.handler._handle_authn_response(request, saml_response, "redirect_uri")
- )
- auth_handler.complete_sso_login.assert_not_called()
-
- # The response doesn't have the proper department.
- saml_response = FakeAuthnResponse(
- {"uid": "test_user", "username": "test_user", "userGroup": ["staff"]}
- )
- request = _mock_request()
- self.get_success(
- self.handler._handle_authn_response(request, saml_response, "redirect_uri")
- )
- auth_handler.complete_sso_login.assert_not_called()
-
- # Add the proper attributes and it should succeed.
- saml_response = FakeAuthnResponse(
- {
- "uid": "test_user",
- "username": "test_user",
- "userGroup": ["staff", "admin"],
- "department": ["sales"],
- }
- )
- request.reset_mock()
- self.get_success(
- self.handler._handle_authn_response(request, saml_response, "redirect_uri")
- )
-
- # check that the auth handler got called as expected
- auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user:test",
- "saml",
- request,
- "redirect_uri",
- None,
- new_user=True,
- auth_provider_session_id=None,
- )
-
-
-def _mock_request() -> Mock:
- """Returns a mock which will stand in as a SynapseRequest"""
- mock = Mock(
- spec=[
- "finish",
- "getClientAddress",
- "getHeader",
- "setHeader",
- "setResponseCode",
- "write",
- ]
- )
- # `_disconnected` musn't be another `Mock`, otherwise it will be truthy.
- mock._disconnected = False
- return mock
diff --git a/tests/handlers/test_send_email.py b/tests/handlers/test_send_email.py
deleted file mode 100644
index cedcea27d9..0000000000
--- a/tests/handlers/test_send_email.py
+++ /dev/null
@@ -1,230 +0,0 @@
-#
-# This file is licensed under the Affero General Public License (AGPL) version 3.
-#
-# Copyright 2021 The Matrix.org Foundation C.I.C.
-# Copyright (C) 2023 New Vector, Ltd
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Affero General Public License as
-# published by the Free Software Foundation, either version 3 of the
-# License, or (at your option) any later version.
-#
-# See the GNU Affero General Public License for more details:
-# <https://www.gnu.org/licenses/agpl-3.0.html>.
-#
-# Originally licensed under the Apache License, Version 2.0:
-# <http://www.apache.org/licenses/LICENSE-2.0>.
-#
-# [This file includes modifications made by New Vector Limited]
-#
-#
-
-
-from typing import Callable, List, Tuple, Type, Union
-from unittest.mock import patch
-
-from zope.interface import implementer
-
-from twisted.internet import defer
-from twisted.internet._sslverify import ClientTLSOptions
-from twisted.internet.address import IPv4Address, IPv6Address
-from twisted.internet.defer import ensureDeferred
-from twisted.internet.interfaces import IProtocolFactory
-from twisted.internet.ssl import ContextFactory
-from twisted.mail import interfaces, smtp
-
-from tests.server import FakeTransport
-from tests.unittest import HomeserverTestCase, override_config
-
-
-def TestingESMTPTLSClientFactory(
- contextFactory: ContextFactory,
- _connectWrapped: bool,
- wrappedProtocol: IProtocolFactory,
-) -> IProtocolFactory:
- """We use this to pass through in testing without using TLS, but
- saving the context information to check that it would have happened.
-
- Note that this is what the MemoryReactor does on connectSSL.
- It only saves the contextFactory, but starts the connection with the
- underlying Factory.
- See: L{twisted.internet.testing.MemoryReactor.connectSSL}"""
-
- wrappedProtocol._testingContextFactory = contextFactory # type: ignore[attr-defined]
- return wrappedProtocol
-
-
-@implementer(interfaces.IMessageDelivery)
-class _DummyMessageDelivery:
- def __init__(self) -> None:
- # (recipient, message) tuples
- self.messages: List[Tuple[smtp.Address, bytes]] = []
-
- def receivedHeader(
- self,
- helo: Tuple[bytes, bytes],
- origin: smtp.Address,
- recipients: List[smtp.User],
- ) -> None:
- return None
-
- def validateFrom(
- self, helo: Tuple[bytes, bytes], origin: smtp.Address
- ) -> smtp.Address:
- return origin
-
- def record_message(self, recipient: smtp.Address, message: bytes) -> None:
- self.messages.append((recipient, message))
-
- def validateTo(self, user: smtp.User) -> Callable[[], interfaces.IMessageSMTP]:
- return lambda: _DummyMessage(self, user)
-
-
-@implementer(interfaces.IMessageSMTP)
-class _DummyMessage:
- """IMessageSMTP implementation which saves the message delivered to it
- to the _DummyMessageDelivery object.
- """
-
- def __init__(self, delivery: _DummyMessageDelivery, user: smtp.User):
- self._delivery = delivery
- self._user = user
- self._buffer: List[bytes] = []
-
- def lineReceived(self, line: bytes) -> None:
- self._buffer.append(line)
-
- def eomReceived(self) -> "defer.Deferred[bytes]":
- message = b"\n".join(self._buffer) + b"\n"
- self._delivery.record_message(self._user.dest, message)
- return defer.succeed(b"saved")
-
- def connectionLost(self) -> None:
- pass
-
-
-class SendEmailHandlerTestCaseIPv4(HomeserverTestCase):
- ip_class: Union[Type[IPv4Address], Type[IPv6Address]] = IPv4Address
-
- def setUp(self) -> None:
- super().setUp()
- self.reactor.lookups["localhost"] = "127.0.0.1"
-
- def test_send_email(self) -> None:
- """Happy-path test that we can send email to a non-TLS server."""
- h = self.hs.get_send_email_handler()
- d = ensureDeferred(
- h.send_email(
- "foo@bar.com", "test subject", "Tests", "HTML content", "Text content"
- )
- )
- # there should be an attempt to connect to localhost:25
- self.assertEqual(len(self.reactor.tcpClients), 1)
- (host, port, client_factory, _timeout, _bindAddress) = self.reactor.tcpClients[
- 0
- ]
- self.assertEqual(host, self.reactor.lookups["localhost"])
- self.assertEqual(port, 25)
-
- # wire it up to an SMTP server
- message_delivery = _DummyMessageDelivery()
- server_protocol = smtp.ESMTP()
- server_protocol.delivery = message_delivery
- # make sure that the server uses the test reactor to set timeouts
- server_protocol.callLater = self.reactor.callLater # type: ignore[assignment]
-
- client_protocol = client_factory.buildProtocol(None)
- client_protocol.makeConnection(FakeTransport(server_protocol, self.reactor))
- server_protocol.makeConnection(
- FakeTransport(
- client_protocol,
- self.reactor,
- peer_address=self.ip_class(
- "TCP", self.reactor.lookups["localhost"], 1234
- ),
- )
- )
-
- # the message should now get delivered
- self.get_success(d, by=0.1)
-
- # check it arrived
- self.assertEqual(len(message_delivery.messages), 1)
- user, msg = message_delivery.messages.pop()
- self.assertEqual(str(user), "foo@bar.com")
- self.assertIn(b"Subject: test subject", msg)
-
- @patch(
- "synapse.handlers.send_email.TLSMemoryBIOFactory",
- TestingESMTPTLSClientFactory,
- )
- @override_config(
- {
- "email": {
- "notif_from": "noreply@test",
- "force_tls": True,
- },
- }
- )
- def test_send_email_force_tls(self) -> None:
- """Happy-path test that we can send email to an Implicit TLS server."""
- h = self.hs.get_send_email_handler()
- d = ensureDeferred(
- h.send_email(
- "foo@bar.com", "test subject", "Tests", "HTML content", "Text content"
- )
- )
- # there should be an attempt to connect to localhost:465
- self.assertEqual(len(self.reactor.tcpClients), 1)
- (
- host,
- port,
- client_factory,
- _timeout,
- _bindAddress,
- ) = self.reactor.tcpClients[0]
- self.assertEqual(host, self.reactor.lookups["localhost"])
- self.assertEqual(port, 465)
- # We need to make sure that TLS is happenning
- self.assertIsInstance(
- client_factory._wrappedFactory._testingContextFactory,
- ClientTLSOptions,
- )
- # And since we use endpoints, they go through reactor.connectTCP
- # which works differently to connectSSL on the testing reactor
-
- # wire it up to an SMTP server
- message_delivery = _DummyMessageDelivery()
- server_protocol = smtp.ESMTP()
- server_protocol.delivery = message_delivery
- # make sure that the server uses the test reactor to set timeouts
- server_protocol.callLater = self.reactor.callLater # type: ignore[assignment]
-
- client_protocol = client_factory.buildProtocol(None)
- client_protocol.makeConnection(FakeTransport(server_protocol, self.reactor))
- server_protocol.makeConnection(
- FakeTransport(
- client_protocol,
- self.reactor,
- peer_address=self.ip_class(
- "TCP", self.reactor.lookups["localhost"], 1234
- ),
- )
- )
-
- # the message should now get delivered
- self.get_success(d, by=0.1)
-
- # check it arrived
- self.assertEqual(len(message_delivery.messages), 1)
- user, msg = message_delivery.messages.pop()
- self.assertEqual(str(user), "foo@bar.com")
- self.assertIn(b"Subject: test subject", msg)
-
-
-class SendEmailHandlerTestCaseIPv6(SendEmailHandlerTestCaseIPv4):
- ip_class = IPv6Address
-
- def setUp(self) -> None:
- super().setUp()
- self.reactor.lookups["localhost"] = "::1"
diff --git a/tests/handlers/test_sliding_sync.py b/tests/handlers/test_sliding_sync.py
index 96da47f3b9..7144c58217 100644
--- a/tests/handlers/test_sliding_sync.py
+++ b/tests/handlers/test_sliding_sync.py
@@ -18,39 +18,40 @@
#
#
import logging
-from copy import deepcopy
-from typing import Dict, List, Optional
+from typing import AbstractSet, Dict, Mapping, Optional, Set, Tuple
from unittest.mock import patch
-from parameterized import parameterized
+import attr
+from parameterized import parameterized, parameterized_class
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import (
- AccountDataTypes,
- EventContentFields,
EventTypes,
JoinRules,
Membership,
- RoomTypes,
)
from synapse.api.room_versions import RoomVersions
-from synapse.events import StrippedStateEvent, make_event_from_dict
-from synapse.events.snapshot import EventContext
from synapse.handlers.sliding_sync import (
+ MAX_NUMBER_PREVIOUS_STATE_KEYS_TO_REMEMBER,
+ RoomsForUserType,
RoomSyncConfig,
StateValues,
- _RoomMembershipForUser,
+ _required_state_changes,
)
from synapse.rest import admin
from synapse.rest.client import knock, login, room
from synapse.server import HomeServer
from synapse.storage.util.id_generators import MultiWriterIdGenerator
-from synapse.types import JsonDict, StreamToken, UserID
-from synapse.types.handlers import SlidingSyncConfig
+from synapse.types import JsonDict, StateMap, StreamToken, UserID, create_requester
+from synapse.types.handlers.sliding_sync import PerConnectionState, SlidingSyncConfig
+from synapse.types.state import StateFilter
from synapse.util import Clock
+from tests import unittest
from tests.replication._base import BaseMultiWorkerStreamTestCase
+from tests.rest.client.sliding_sync.test_sliding_sync import SlidingSyncBase
+from tests.test_utils.event_injection import create_event
from tests.unittest import HomeserverTestCase, TestCase
logger = logging.getLogger(__name__)
@@ -566,31 +567,39 @@ class RoomSyncConfigTestCase(TestCase):
"""
Combine A into B and B into A to make sure we get the same result.
"""
- # Since we're mutating these in place, make a copy for each of our trials
- room_sync_config_a = deepcopy(a)
- room_sync_config_b = deepcopy(b)
-
- # Combine B into A
- room_sync_config_a.combine_room_sync_config(room_sync_config_b)
-
- self._assert_room_config_equal(room_sync_config_a, expected, "B into A")
-
- # Since we're mutating these in place, make a copy for each of our trials
- room_sync_config_a = deepcopy(a)
- room_sync_config_b = deepcopy(b)
-
- # Combine A into B
- room_sync_config_b.combine_room_sync_config(room_sync_config_a)
-
- self._assert_room_config_equal(room_sync_config_b, expected, "A into B")
-
-
-class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
+ combined_config = a.combine_room_sync_config(b)
+ self._assert_room_config_equal(combined_config, expected, "B into A")
+
+ combined_config = a.combine_room_sync_config(b)
+ self._assert_room_config_equal(combined_config, expected, "A into B")
+
+
+# FIXME: This can be removed once we bump `SCHEMA_COMPAT_VERSION` and run the
+# background update for
+# `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` (tracked by
+# https://github.com/element-hq/synapse/issues/17623)
+@parameterized_class(
+ ("use_new_tables",),
+ [
+ (True,),
+ (False,),
+ ],
+ class_name_func=lambda cls,
+ num,
+ params_dict: f"{cls.__name__}_{'new' if params_dict['use_new_tables'] else 'fallback'}",
+)
+class ComputeInterestedRoomsTestCase(SlidingSyncBase):
"""
- Tests Sliding Sync handler `get_room_membership_for_user_at_to_token()` to make sure it returns
+ Tests Sliding Sync handler `compute_interested_rooms()` to make sure it returns
the correct list of rooms IDs.
"""
+ # FIXME: We should refactor these tests to run against `compute_interested_rooms(...)`
+ # instead of just `get_room_membership_for_user_at_to_token(...)` which is only used
+ # in the fallback path (`_compute_interested_rooms_fallback(...)`). These scenarios do
+ # well to stress that logic and we shouldn't remove them just because we're removing
+ # the fallback path (tracked by https://github.com/element-hq/synapse/issues/17623).
+
servlets = [
admin.register_servlets,
knock.register_servlets,
@@ -609,6 +618,11 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
self.store = self.hs.get_datastores().main
self.event_sources = hs.get_event_sources()
self.storage_controllers = hs.get_storage_controllers()
+ persistence = self.hs.get_storage_controllers().persistence
+ assert persistence is not None
+ self.persistence = persistence
+
+ super().prepare(reactor, clock, hs)
def test_no_rooms(self) -> None:
"""
@@ -619,15 +633,28 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
now_token = self.event_sources.get_current_token()
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
+ interested_rooms = self.get_success(
+ self.sliding_sync_handler.room_lists.compute_interested_rooms(
+ SlidingSyncConfig(
+ user=UserID.from_string(user1_id),
+ requester=create_requester(user_id=user1_id),
+ lists={
+ "foo-list": SlidingSyncConfig.SlidingSyncList(
+ ranges=[(0, 99)],
+ required_state=[],
+ timeline_limit=1,
+ )
+ },
+ conn_id=None,
+ ),
+ PerConnectionState(),
from_token=now_token,
to_token=now_token,
)
)
+ room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids)
- self.assertEqual(room_id_results.keys(), set())
+ self.assertIncludes(room_id_results, set(), exact=True)
def test_get_newly_joined_room(self) -> None:
"""
@@ -646,26 +673,48 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
after_room_token = self.event_sources.get_current_token()
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
+ interested_rooms = self.get_success(
+ self.sliding_sync_handler.room_lists.compute_interested_rooms(
+ SlidingSyncConfig(
+ user=UserID.from_string(user1_id),
+ requester=create_requester(user_id=user1_id),
+ lists={
+ "foo-list": SlidingSyncConfig.SlidingSyncList(
+ ranges=[(0, 99)],
+ required_state=[],
+ timeline_limit=1,
+ )
+ },
+ conn_id=None,
+ ),
+ PerConnectionState(),
from_token=before_room_token,
to_token=after_room_token,
)
)
+ room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids)
+ newly_joined = interested_rooms.newly_joined_rooms
+ newly_left = interested_rooms.newly_left_rooms
- self.assertEqual(room_id_results.keys(), {room_id})
+ self.assertIncludes(
+ room_id_results,
+ {room_id},
+ exact=True,
+ )
# It should be pointing to the join event (latest membership event in the
# from/to range)
self.assertEqual(
- room_id_results[room_id].event_id,
+ interested_rooms.room_membership_for_user_map[room_id].event_id,
join_response["event_id"],
)
- self.assertEqual(room_id_results[room_id].membership, Membership.JOIN)
+ self.assertEqual(
+ interested_rooms.room_membership_for_user_map[room_id].membership,
+ Membership.JOIN,
+ )
# We should be considered `newly_joined` because we joined during the token
# range
- self.assertEqual(room_id_results[room_id].newly_joined, True)
- self.assertEqual(room_id_results[room_id].newly_left, False)
+ self.assertTrue(room_id in newly_joined)
+ self.assertTrue(room_id not in newly_left)
def test_get_already_joined_room(self) -> None:
"""
@@ -681,25 +730,43 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
after_room_token = self.event_sources.get_current_token()
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
+ interested_rooms = self.get_success(
+ self.sliding_sync_handler.room_lists.compute_interested_rooms(
+ SlidingSyncConfig(
+ user=UserID.from_string(user1_id),
+ requester=create_requester(user_id=user1_id),
+ lists={
+ "foo-list": SlidingSyncConfig.SlidingSyncList(
+ ranges=[(0, 99)],
+ required_state=[],
+ timeline_limit=1,
+ )
+ },
+ conn_id=None,
+ ),
+ PerConnectionState(),
from_token=after_room_token,
to_token=after_room_token,
)
)
+ room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids)
+ newly_joined = interested_rooms.newly_joined_rooms
+ newly_left = interested_rooms.newly_left_rooms
- self.assertEqual(room_id_results.keys(), {room_id})
+ self.assertIncludes(room_id_results, {room_id}, exact=True)
# It should be pointing to the join event (latest membership event in the
# from/to range)
self.assertEqual(
- room_id_results[room_id].event_id,
+ interested_rooms.room_membership_for_user_map[room_id].event_id,
join_response["event_id"],
)
- self.assertEqual(room_id_results[room_id].membership, Membership.JOIN)
+ self.assertEqual(
+ interested_rooms.room_membership_for_user_map[room_id].membership,
+ Membership.JOIN,
+ )
# We should *NOT* be `newly_joined` because we joined before the token range
- self.assertEqual(room_id_results[room_id].newly_joined, False)
- self.assertEqual(room_id_results[room_id].newly_left, False)
+ self.assertTrue(room_id not in newly_joined)
+ self.assertTrue(room_id not in newly_left)
def test_get_invited_banned_knocked_room(self) -> None:
"""
@@ -755,48 +822,73 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
after_room_token = self.event_sources.get_current_token()
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
+ interested_rooms = self.get_success(
+ self.sliding_sync_handler.room_lists.compute_interested_rooms(
+ SlidingSyncConfig(
+ user=UserID.from_string(user1_id),
+ requester=create_requester(user_id=user1_id),
+ lists={
+ "foo-list": SlidingSyncConfig.SlidingSyncList(
+ ranges=[(0, 99)],
+ required_state=[],
+ timeline_limit=1,
+ )
+ },
+ conn_id=None,
+ ),
+ PerConnectionState(),
from_token=before_room_token,
to_token=after_room_token,
)
)
+ room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids)
+ newly_joined = interested_rooms.newly_joined_rooms
+ newly_left = interested_rooms.newly_left_rooms
# Ensure that the invited, ban, and knock rooms show up
- self.assertEqual(
- room_id_results.keys(),
+ self.assertIncludes(
+ room_id_results,
{
invited_room_id,
ban_room_id,
knock_room_id,
},
+ exact=True,
)
# It should be pointing to the the respective membership event (latest
# membership event in the from/to range)
self.assertEqual(
- room_id_results[invited_room_id].event_id,
+ interested_rooms.room_membership_for_user_map[invited_room_id].event_id,
invite_response["event_id"],
)
- self.assertEqual(room_id_results[invited_room_id].membership, Membership.INVITE)
- self.assertEqual(room_id_results[invited_room_id].newly_joined, False)
- self.assertEqual(room_id_results[invited_room_id].newly_left, False)
+ self.assertEqual(
+ interested_rooms.room_membership_for_user_map[invited_room_id].membership,
+ Membership.INVITE,
+ )
+ self.assertTrue(invited_room_id not in newly_joined)
+ self.assertTrue(invited_room_id not in newly_left)
self.assertEqual(
- room_id_results[ban_room_id].event_id,
+ interested_rooms.room_membership_for_user_map[ban_room_id].event_id,
ban_response["event_id"],
)
- self.assertEqual(room_id_results[ban_room_id].membership, Membership.BAN)
- self.assertEqual(room_id_results[ban_room_id].newly_joined, False)
- self.assertEqual(room_id_results[ban_room_id].newly_left, False)
+ self.assertEqual(
+ interested_rooms.room_membership_for_user_map[ban_room_id].membership,
+ Membership.BAN,
+ )
+ self.assertTrue(ban_room_id not in newly_joined)
+ self.assertTrue(ban_room_id not in newly_left)
self.assertEqual(
- room_id_results[knock_room_id].event_id,
+ interested_rooms.room_membership_for_user_map[knock_room_id].event_id,
knock_room_membership_state_event.event_id,
)
- self.assertEqual(room_id_results[knock_room_id].membership, Membership.KNOCK)
- self.assertEqual(room_id_results[knock_room_id].newly_joined, False)
- self.assertEqual(room_id_results[knock_room_id].newly_left, False)
+ self.assertEqual(
+ interested_rooms.room_membership_for_user_map[knock_room_id].membership,
+ Membership.KNOCK,
+ )
+ self.assertTrue(knock_room_id not in newly_joined)
+ self.assertTrue(knock_room_id not in newly_left)
def test_get_kicked_room(self) -> None:
"""
@@ -827,27 +919,47 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
after_kick_token = self.event_sources.get_current_token()
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
+ interested_rooms = self.get_success(
+ self.sliding_sync_handler.room_lists.compute_interested_rooms(
+ SlidingSyncConfig(
+ user=UserID.from_string(user1_id),
+ requester=create_requester(user_id=user1_id),
+ lists={
+ "foo-list": SlidingSyncConfig.SlidingSyncList(
+ ranges=[(0, 99)],
+ required_state=[],
+ timeline_limit=1,
+ )
+ },
+ conn_id=None,
+ ),
+ PerConnectionState(),
from_token=after_kick_token,
to_token=after_kick_token,
)
)
+ room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids)
+ newly_joined = interested_rooms.newly_joined_rooms
+ newly_left = interested_rooms.newly_left_rooms
# The kicked room should show up
- self.assertEqual(room_id_results.keys(), {kick_room_id})
+ self.assertIncludes(room_id_results, {kick_room_id}, exact=True)
# It should be pointing to the latest membership event in the from/to range
self.assertEqual(
- room_id_results[kick_room_id].event_id,
+ interested_rooms.room_membership_for_user_map[kick_room_id].event_id,
kick_response["event_id"],
)
- self.assertEqual(room_id_results[kick_room_id].membership, Membership.LEAVE)
- self.assertNotEqual(room_id_results[kick_room_id].sender, user1_id)
+ self.assertEqual(
+ interested_rooms.room_membership_for_user_map[kick_room_id].membership,
+ Membership.LEAVE,
+ )
+ self.assertNotEqual(
+ interested_rooms.room_membership_for_user_map[kick_room_id].sender, user1_id
+ )
# We should *NOT* be `newly_joined` because we were not joined at the the time
# of the `to_token`.
- self.assertEqual(room_id_results[kick_room_id].newly_joined, False)
- self.assertEqual(room_id_results[kick_room_id].newly_left, False)
+ self.assertTrue(kick_room_id not in newly_joined)
+ self.assertTrue(kick_room_id not in newly_left)
def test_forgotten_rooms(self) -> None:
"""
@@ -920,16 +1032,29 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
)
self.assertEqual(channel.code, 200, channel.result)
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
+ interested_rooms = self.get_success(
+ self.sliding_sync_handler.room_lists.compute_interested_rooms(
+ SlidingSyncConfig(
+ user=UserID.from_string(user1_id),
+ requester=create_requester(user_id=user1_id),
+ lists={
+ "foo-list": SlidingSyncConfig.SlidingSyncList(
+ ranges=[(0, 99)],
+ required_state=[],
+ timeline_limit=1,
+ )
+ },
+ conn_id=None,
+ ),
+ PerConnectionState(),
from_token=before_room_forgets,
to_token=before_room_forgets,
)
)
+ room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids)
# We shouldn't see the room because it was forgotten
- self.assertEqual(room_id_results.keys(), set())
+ self.assertIncludes(room_id_results, set(), exact=True)
def test_newly_left_rooms(self) -> None:
"""
@@ -940,7 +1065,7 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
# Leave before we calculate the `from_token`
room_id1 = self.helper.create_room_as(user1_id, tok=user1_tok)
- leave_response1 = self.helper.leave(room_id1, user1_id, tok=user1_tok)
+ _leave_response1 = self.helper.leave(room_id1, user1_id, tok=user1_tok)
after_room1_token = self.event_sources.get_current_token()
@@ -950,34 +1075,55 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
after_room2_token = self.event_sources.get_current_token()
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
+ interested_rooms = self.get_success(
+ self.sliding_sync_handler.room_lists.compute_interested_rooms(
+ SlidingSyncConfig(
+ user=UserID.from_string(user1_id),
+ requester=create_requester(user_id=user1_id),
+ lists={
+ "foo-list": SlidingSyncConfig.SlidingSyncList(
+ ranges=[(0, 99)],
+ required_state=[],
+ timeline_limit=1,
+ )
+ },
+ conn_id=None,
+ ),
+ PerConnectionState(),
from_token=after_room1_token,
to_token=after_room2_token,
)
)
+ room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids)
+ newly_joined = interested_rooms.newly_joined_rooms
+ newly_left = interested_rooms.newly_left_rooms
- self.assertEqual(room_id_results.keys(), {room_id1, room_id2})
-
- self.assertEqual(
- room_id_results[room_id1].event_id,
- leave_response1["event_id"],
+ # `room_id1` should not show up because it was left before the token range.
+ # `room_id2` should show up because it is `newly_left` within the token range.
+ self.assertIncludes(
+ room_id_results,
+ {room_id2},
+ exact=True,
+ message="Corresponding map to disambiguate the opaque room IDs: "
+ + str(
+ {
+ "room_id1": room_id1,
+ "room_id2": room_id2,
+ }
+ ),
)
- self.assertEqual(room_id_results[room_id1].membership, Membership.LEAVE)
- # We should *NOT* be `newly_joined` or `newly_left` because that happened before
- # the from/to range
- self.assertEqual(room_id_results[room_id1].newly_joined, False)
- self.assertEqual(room_id_results[room_id1].newly_left, False)
self.assertEqual(
- room_id_results[room_id2].event_id,
+ interested_rooms.room_membership_for_user_map[room_id2].event_id,
leave_response2["event_id"],
)
- self.assertEqual(room_id_results[room_id2].membership, Membership.LEAVE)
+ self.assertEqual(
+ interested_rooms.room_membership_for_user_map[room_id2].membership,
+ Membership.LEAVE,
+ )
# We should *NOT* be `newly_joined` because we are instead `newly_left`
- self.assertEqual(room_id_results[room_id2].newly_joined, False)
- self.assertEqual(room_id_results[room_id2].newly_left, True)
+ self.assertTrue(room_id2 not in newly_joined)
+ self.assertTrue(room_id2 in newly_left)
def test_no_joins_after_to_token(self) -> None:
"""
@@ -1000,24 +1146,42 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
room_id2 = self.helper.create_room_as(user2_id, tok=user2_tok)
self.helper.join(room_id2, user1_id, tok=user1_tok)
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
+ interested_rooms = self.get_success(
+ self.sliding_sync_handler.room_lists.compute_interested_rooms(
+ SlidingSyncConfig(
+ user=UserID.from_string(user1_id),
+ requester=create_requester(user_id=user1_id),
+ lists={
+ "foo-list": SlidingSyncConfig.SlidingSyncList(
+ ranges=[(0, 99)],
+ required_state=[],
+ timeline_limit=1,
+ )
+ },
+ conn_id=None,
+ ),
+ PerConnectionState(),
from_token=before_room1_token,
to_token=after_room1_token,
)
)
+ room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids)
+ newly_joined = interested_rooms.newly_joined_rooms
+ newly_left = interested_rooms.newly_left_rooms
- self.assertEqual(room_id_results.keys(), {room_id1})
+ self.assertIncludes(room_id_results, {room_id1}, exact=True)
# It should be pointing to the latest membership event in the from/to range
self.assertEqual(
- room_id_results[room_id1].event_id,
+ interested_rooms.room_membership_for_user_map[room_id1].event_id,
join_response1["event_id"],
)
- self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN)
+ self.assertEqual(
+ interested_rooms.room_membership_for_user_map[room_id1].membership,
+ Membership.JOIN,
+ )
# We should be `newly_joined` because we joined during the token range
- self.assertEqual(room_id_results[room_id1].newly_joined, True)
- self.assertEqual(room_id_results[room_id1].newly_left, False)
+ self.assertTrue(room_id1 in newly_joined)
+ self.assertTrue(room_id1 not in newly_left)
def test_join_during_range_and_left_room_after_to_token(self) -> None:
"""
@@ -1040,20 +1204,35 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
# Leave the room after we already have our tokens
leave_response = self.helper.leave(room_id1, user1_id, tok=user1_tok)
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
+ interested_rooms = self.get_success(
+ self.sliding_sync_handler.room_lists.compute_interested_rooms(
+ SlidingSyncConfig(
+ user=UserID.from_string(user1_id),
+ requester=create_requester(user_id=user1_id),
+ lists={
+ "foo-list": SlidingSyncConfig.SlidingSyncList(
+ ranges=[(0, 99)],
+ required_state=[],
+ timeline_limit=1,
+ )
+ },
+ conn_id=None,
+ ),
+ PerConnectionState(),
from_token=before_room1_token,
to_token=after_room1_token,
)
)
+ room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids)
+ newly_joined = interested_rooms.newly_joined_rooms
+ newly_left = interested_rooms.newly_left_rooms
# We should still see the room because we were joined during the
# from_token/to_token time period.
- self.assertEqual(room_id_results.keys(), {room_id1})
+ self.assertIncludes(room_id_results, {room_id1}, exact=True)
# It should be pointing to the latest membership event in the from/to range
self.assertEqual(
- room_id_results[room_id1].event_id,
+ interested_rooms.room_membership_for_user_map[room_id1].event_id,
join_response["event_id"],
"Corresponding map to disambiguate the opaque event IDs: "
+ str(
@@ -1063,10 +1242,13 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
}
),
)
- self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN)
+ self.assertEqual(
+ interested_rooms.room_membership_for_user_map[room_id1].membership,
+ Membership.JOIN,
+ )
# We should be `newly_joined` because we joined during the token range
- self.assertEqual(room_id_results[room_id1].newly_joined, True)
- self.assertEqual(room_id_results[room_id1].newly_left, False)
+ self.assertTrue(room_id1 in newly_joined)
+ self.assertTrue(room_id1 not in newly_left)
def test_join_before_range_and_left_room_after_to_token(self) -> None:
"""
@@ -1087,19 +1269,34 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
# Leave the room after we already have our tokens
leave_response = self.helper.leave(room_id1, user1_id, tok=user1_tok)
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
+ interested_rooms = self.get_success(
+ self.sliding_sync_handler.room_lists.compute_interested_rooms(
+ SlidingSyncConfig(
+ user=UserID.from_string(user1_id),
+ requester=create_requester(user_id=user1_id),
+ lists={
+ "foo-list": SlidingSyncConfig.SlidingSyncList(
+ ranges=[(0, 99)],
+ required_state=[],
+ timeline_limit=1,
+ )
+ },
+ conn_id=None,
+ ),
+ PerConnectionState(),
from_token=after_room1_token,
to_token=after_room1_token,
)
)
+ room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids)
+ newly_joined = interested_rooms.newly_joined_rooms
+ newly_left = interested_rooms.newly_left_rooms
# We should still see the room because we were joined before the `from_token`
- self.assertEqual(room_id_results.keys(), {room_id1})
+ self.assertIncludes(room_id_results, {room_id1}, exact=True)
# It should be pointing to the latest membership event in the from/to range
self.assertEqual(
- room_id_results[room_id1].event_id,
+ interested_rooms.room_membership_for_user_map[room_id1].event_id,
join_response["event_id"],
"Corresponding map to disambiguate the opaque event IDs: "
+ str(
@@ -1109,10 +1306,13 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
}
),
)
- self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN)
+ self.assertEqual(
+ interested_rooms.room_membership_for_user_map[room_id1].membership,
+ Membership.JOIN,
+ )
# We should *NOT* be `newly_joined` because we joined before the token range
- self.assertEqual(room_id_results[room_id1].newly_joined, False)
- self.assertEqual(room_id_results[room_id1].newly_left, False)
+ self.assertTrue(room_id1 not in newly_joined)
+ self.assertTrue(room_id1 not in newly_left)
def test_kicked_before_range_and_left_after_to_token(self) -> None:
"""
@@ -1151,19 +1351,34 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
join_response2 = self.helper.join(kick_room_id, user1_id, tok=user1_tok)
leave_response = self.helper.leave(kick_room_id, user1_id, tok=user1_tok)
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
+ interested_rooms = self.get_success(
+ self.sliding_sync_handler.room_lists.compute_interested_rooms(
+ SlidingSyncConfig(
+ user=UserID.from_string(user1_id),
+ requester=create_requester(user_id=user1_id),
+ lists={
+ "foo-list": SlidingSyncConfig.SlidingSyncList(
+ ranges=[(0, 99)],
+ required_state=[],
+ timeline_limit=1,
+ )
+ },
+ conn_id=None,
+ ),
+ PerConnectionState(),
from_token=after_kick_token,
to_token=after_kick_token,
)
)
+ room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids)
+ newly_joined = interested_rooms.newly_joined_rooms
+ newly_left = interested_rooms.newly_left_rooms
# We shouldn't see the room because it was forgotten
- self.assertEqual(room_id_results.keys(), {kick_room_id})
+ self.assertIncludes(room_id_results, {kick_room_id}, exact=True)
# It should be pointing to the latest membership event in the from/to range
self.assertEqual(
- room_id_results[kick_room_id].event_id,
+ interested_rooms.room_membership_for_user_map[kick_room_id].event_id,
kick_response["event_id"],
"Corresponding map to disambiguate the opaque event IDs: "
+ str(
@@ -1175,11 +1390,16 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
}
),
)
- self.assertEqual(room_id_results[kick_room_id].membership, Membership.LEAVE)
- self.assertNotEqual(room_id_results[kick_room_id].sender, user1_id)
+ self.assertEqual(
+ interested_rooms.room_membership_for_user_map[kick_room_id].membership,
+ Membership.LEAVE,
+ )
+ self.assertNotEqual(
+ interested_rooms.room_membership_for_user_map[kick_room_id].sender, user1_id
+ )
# We should *NOT* be `newly_joined` because we were kicked
- self.assertEqual(room_id_results[kick_room_id].newly_joined, False)
- self.assertEqual(room_id_results[kick_room_id].newly_left, False)
+ self.assertTrue(kick_room_id not in newly_joined)
+ self.assertTrue(kick_room_id not in newly_left)
def test_newly_left_during_range_and_join_leave_after_to_token(self) -> None:
"""
@@ -1207,19 +1427,34 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
join_response2 = self.helper.join(room_id1, user1_id, tok=user1_tok)
leave_response2 = self.helper.leave(room_id1, user1_id, tok=user1_tok)
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
+ interested_rooms = self.get_success(
+ self.sliding_sync_handler.room_lists.compute_interested_rooms(
+ SlidingSyncConfig(
+ user=UserID.from_string(user1_id),
+ requester=create_requester(user_id=user1_id),
+ lists={
+ "foo-list": SlidingSyncConfig.SlidingSyncList(
+ ranges=[(0, 99)],
+ required_state=[],
+ timeline_limit=1,
+ )
+ },
+ conn_id=None,
+ ),
+ PerConnectionState(),
from_token=before_room1_token,
to_token=after_room1_token,
)
)
+ room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids)
+ newly_joined = interested_rooms.newly_joined_rooms
+ newly_left = interested_rooms.newly_left_rooms
# Room should still show up because it's newly_left during the from/to range
- self.assertEqual(room_id_results.keys(), {room_id1})
+ self.assertIncludes(room_id_results, {room_id1}, exact=True)
# It should be pointing to the latest membership event in the from/to range
self.assertEqual(
- room_id_results[room_id1].event_id,
+ interested_rooms.room_membership_for_user_map[room_id1].event_id,
leave_response1["event_id"],
"Corresponding map to disambiguate the opaque event IDs: "
+ str(
@@ -1231,11 +1466,14 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
}
),
)
- self.assertEqual(room_id_results[room_id1].membership, Membership.LEAVE)
+ self.assertEqual(
+ interested_rooms.room_membership_for_user_map[room_id1].membership,
+ Membership.LEAVE,
+ )
# We should *NOT* be `newly_joined` because we are actually `newly_left` during
# the token range
- self.assertEqual(room_id_results[room_id1].newly_joined, False)
- self.assertEqual(room_id_results[room_id1].newly_left, True)
+ self.assertTrue(room_id1 not in newly_joined)
+ self.assertTrue(room_id1 in newly_left)
def test_newly_left_during_range_and_join_after_to_token(self) -> None:
"""
@@ -1262,19 +1500,34 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
# Join the room after we already have our tokens
join_response2 = self.helper.join(room_id1, user1_id, tok=user1_tok)
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
+ interested_rooms = self.get_success(
+ self.sliding_sync_handler.room_lists.compute_interested_rooms(
+ SlidingSyncConfig(
+ user=UserID.from_string(user1_id),
+ requester=create_requester(user_id=user1_id),
+ lists={
+ "foo-list": SlidingSyncConfig.SlidingSyncList(
+ ranges=[(0, 99)],
+ required_state=[],
+ timeline_limit=1,
+ )
+ },
+ conn_id=None,
+ ),
+ PerConnectionState(),
from_token=before_room1_token,
to_token=after_room1_token,
)
)
+ room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids)
+ newly_joined = interested_rooms.newly_joined_rooms
+ newly_left = interested_rooms.newly_left_rooms
# Room should still show up because it's newly_left during the from/to range
- self.assertEqual(room_id_results.keys(), {room_id1})
+ self.assertIncludes(room_id_results, {room_id1}, exact=True)
# It should be pointing to the latest membership event in the from/to range
self.assertEqual(
- room_id_results[room_id1].event_id,
+ interested_rooms.room_membership_for_user_map[room_id1].event_id,
leave_response1["event_id"],
"Corresponding map to disambiguate the opaque event IDs: "
+ str(
@@ -1285,11 +1538,14 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
}
),
)
- self.assertEqual(room_id_results[room_id1].membership, Membership.LEAVE)
+ self.assertEqual(
+ interested_rooms.room_membership_for_user_map[room_id1].membership,
+ Membership.LEAVE,
+ )
# We should *NOT* be `newly_joined` because we are actually `newly_left` during
# the token range
- self.assertEqual(room_id_results[room_id1].newly_joined, False)
- self.assertEqual(room_id_results[room_id1].newly_left, True)
+ self.assertTrue(room_id1 not in newly_joined)
+ self.assertTrue(room_id1 in newly_left)
def test_no_from_token(self) -> None:
"""
@@ -1314,47 +1570,53 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
# Join and leave the room2 before the `to_token`
self.helper.join(room_id2, user1_id, tok=user1_tok)
- leave_response2 = self.helper.leave(room_id2, user1_id, tok=user1_tok)
+ _leave_response2 = self.helper.leave(room_id2, user1_id, tok=user1_tok)
after_room1_token = self.event_sources.get_current_token()
# Join the room2 after we already have our tokens
self.helper.join(room_id2, user1_id, tok=user1_tok)
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
+ interested_rooms = self.get_success(
+ self.sliding_sync_handler.room_lists.compute_interested_rooms(
+ SlidingSyncConfig(
+ user=UserID.from_string(user1_id),
+ requester=create_requester(user_id=user1_id),
+ lists={
+ "foo-list": SlidingSyncConfig.SlidingSyncList(
+ ranges=[(0, 99)],
+ required_state=[],
+ timeline_limit=1,
+ )
+ },
+ conn_id=None,
+ ),
+ PerConnectionState(),
from_token=None,
to_token=after_room1_token,
)
)
+ room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids)
+ newly_joined = interested_rooms.newly_joined_rooms
+ newly_left = interested_rooms.newly_left_rooms
# Only rooms we were joined to before the `to_token` should show up
- self.assertEqual(room_id_results.keys(), {room_id1, room_id2})
+ self.assertIncludes(room_id_results, {room_id1}, exact=True)
# Room1
# It should be pointing to the latest membership event in the from/to range
self.assertEqual(
- room_id_results[room_id1].event_id,
+ interested_rooms.room_membership_for_user_map[room_id1].event_id,
join_response1["event_id"],
)
- self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN)
- # We should *NOT* be `newly_joined`/`newly_left` because there is no
- # `from_token` to define a "live" range to compare against
- self.assertEqual(room_id_results[room_id1].newly_joined, False)
- self.assertEqual(room_id_results[room_id1].newly_left, False)
-
- # Room2
- # It should be pointing to the latest membership event in the from/to range
self.assertEqual(
- room_id_results[room_id2].event_id,
- leave_response2["event_id"],
+ interested_rooms.room_membership_for_user_map[room_id1].membership,
+ Membership.JOIN,
)
- self.assertEqual(room_id_results[room_id2].membership, Membership.LEAVE)
# We should *NOT* be `newly_joined`/`newly_left` because there is no
# `from_token` to define a "live" range to compare against
- self.assertEqual(room_id_results[room_id2].newly_joined, False)
- self.assertEqual(room_id_results[room_id2].newly_left, False)
+ self.assertTrue(room_id1 not in newly_joined)
+ self.assertTrue(room_id1 not in newly_left)
def test_from_token_ahead_of_to_token(self) -> None:
"""
@@ -1378,7 +1640,7 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
# Join and leave the room2 before `to_token`
_join_room2_response1 = self.helper.join(room_id2, user1_id, tok=user1_tok)
- leave_room2_response1 = self.helper.leave(room_id2, user1_id, tok=user1_tok)
+ _leave_room2_response1 = self.helper.leave(room_id2, user1_id, tok=user1_tok)
# Note: These are purposely swapped. The `from_token` should come after
# the `to_token` in this test
@@ -1403,54 +1665,69 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
# Join the room4 after we already have our tokens
self.helper.join(room_id4, user1_id, tok=user1_tok)
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
+ interested_rooms = self.get_success(
+ self.sliding_sync_handler.room_lists.compute_interested_rooms(
+ SlidingSyncConfig(
+ user=UserID.from_string(user1_id),
+ requester=create_requester(user_id=user1_id),
+ lists={
+ "foo-list": SlidingSyncConfig.SlidingSyncList(
+ ranges=[(0, 99)],
+ required_state=[],
+ timeline_limit=1,
+ )
+ },
+ conn_id=None,
+ ),
+ PerConnectionState(),
from_token=from_token,
to_token=to_token,
)
)
+ room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids)
+ newly_joined = interested_rooms.newly_joined_rooms
+ newly_left = interested_rooms.newly_left_rooms
# In the "current" state snapshot, we're joined to all of the rooms but in the
# from/to token range...
self.assertIncludes(
- room_id_results.keys(),
+ room_id_results,
{
# Included because we were joined before both tokens
room_id1,
- # Included because we had membership before the to_token
- room_id2,
+ # Excluded because we left before the `from_token` and `to_token`
+ # room_id2,
# Excluded because we joined after the `to_token`
# room_id3,
# Excluded because we joined after the `to_token`
# room_id4,
},
exact=True,
+ message="Corresponding map to disambiguate the opaque room IDs: "
+ + str(
+ {
+ "room_id1": room_id1,
+ "room_id2": room_id2,
+ "room_id3": room_id3,
+ "room_id4": room_id4,
+ }
+ ),
)
# Room1
# It should be pointing to the latest membership event in the from/to range
self.assertEqual(
- room_id_results[room_id1].event_id,
+ interested_rooms.room_membership_for_user_map[room_id1].event_id,
join_room1_response1["event_id"],
)
- self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN)
- # We should *NOT* be `newly_joined`/`newly_left` because we joined `room1`
- # before either of the tokens
- self.assertEqual(room_id_results[room_id1].newly_joined, False)
- self.assertEqual(room_id_results[room_id1].newly_left, False)
-
- # Room2
- # It should be pointing to the latest membership event in the from/to range
self.assertEqual(
- room_id_results[room_id2].event_id,
- leave_room2_response1["event_id"],
+ interested_rooms.room_membership_for_user_map[room_id1].membership,
+ Membership.JOIN,
)
- self.assertEqual(room_id_results[room_id2].membership, Membership.LEAVE)
- # We should *NOT* be `newly_joined`/`newly_left` because we joined and left
- # `room1` before either of the tokens
- self.assertEqual(room_id_results[room_id2].newly_joined, False)
- self.assertEqual(room_id_results[room_id2].newly_left, False)
+ # We should *NOT* be `newly_joined`/`newly_left` because we joined `room1`
+ # before either of the tokens
+ self.assertTrue(room_id1 not in newly_joined)
+ self.assertTrue(room_id1 not in newly_left)
def test_leave_before_range_and_join_leave_after_to_token(self) -> None:
"""
@@ -1468,7 +1745,7 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True)
# Join and leave the room before the from/to range
self.helper.join(room_id1, user1_id, tok=user1_tok)
- leave_response = self.helper.leave(room_id1, user1_id, tok=user1_tok)
+ self.helper.leave(room_id1, user1_id, tok=user1_tok)
after_room1_token = self.event_sources.get_current_token()
@@ -1476,25 +1753,28 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
self.helper.join(room_id1, user1_id, tok=user1_tok)
self.helper.leave(room_id1, user1_id, tok=user1_tok)
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
+ interested_rooms = self.get_success(
+ self.sliding_sync_handler.room_lists.compute_interested_rooms(
+ SlidingSyncConfig(
+ user=UserID.from_string(user1_id),
+ requester=create_requester(user_id=user1_id),
+ lists={
+ "foo-list": SlidingSyncConfig.SlidingSyncList(
+ ranges=[(0, 99)],
+ required_state=[],
+ timeline_limit=1,
+ )
+ },
+ conn_id=None,
+ ),
+ PerConnectionState(),
from_token=after_room1_token,
to_token=after_room1_token,
)
)
+ room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids)
- self.assertEqual(room_id_results.keys(), {room_id1})
- # It should be pointing to the latest membership event in the from/to range
- self.assertEqual(
- room_id_results[room_id1].event_id,
- leave_response["event_id"],
- )
- self.assertEqual(room_id_results[room_id1].membership, Membership.LEAVE)
- # We should *NOT* be `newly_joined`/`newly_left` because we joined and left
- # `room1` before either of the tokens
- self.assertEqual(room_id_results[room_id1].newly_joined, False)
- self.assertEqual(room_id_results[room_id1].newly_left, False)
+ self.assertIncludes(room_id_results, set(), exact=True)
def test_leave_before_range_and_join_after_to_token(self) -> None:
"""
@@ -1512,32 +1792,35 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True)
# Join and leave the room before the from/to range
self.helper.join(room_id1, user1_id, tok=user1_tok)
- leave_response = self.helper.leave(room_id1, user1_id, tok=user1_tok)
+ self.helper.leave(room_id1, user1_id, tok=user1_tok)
after_room1_token = self.event_sources.get_current_token()
# Join the room after we already have our tokens
self.helper.join(room_id1, user1_id, tok=user1_tok)
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
+ interested_rooms = self.get_success(
+ self.sliding_sync_handler.room_lists.compute_interested_rooms(
+ SlidingSyncConfig(
+ user=UserID.from_string(user1_id),
+ requester=create_requester(user_id=user1_id),
+ lists={
+ "foo-list": SlidingSyncConfig.SlidingSyncList(
+ ranges=[(0, 99)],
+ required_state=[],
+ timeline_limit=1,
+ )
+ },
+ conn_id=None,
+ ),
+ PerConnectionState(),
from_token=after_room1_token,
to_token=after_room1_token,
)
)
+ room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids)
- self.assertEqual(room_id_results.keys(), {room_id1})
- # It should be pointing to the latest membership event in the from/to range
- self.assertEqual(
- room_id_results[room_id1].event_id,
- leave_response["event_id"],
- )
- self.assertEqual(room_id_results[room_id1].membership, Membership.LEAVE)
- # We should *NOT* be `newly_joined`/`newly_left` because we joined and left
- # `room1` before either of the tokens
- self.assertEqual(room_id_results[room_id1].newly_joined, False)
- self.assertEqual(room_id_results[room_id1].newly_left, False)
+ self.assertIncludes(room_id_results, set(), exact=True)
def test_join_leave_multiple_times_during_range_and_after_to_token(
self,
@@ -1569,19 +1852,34 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
join_response3 = self.helper.join(room_id1, user1_id, tok=user1_tok)
leave_response3 = self.helper.leave(room_id1, user1_id, tok=user1_tok)
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
+ interested_rooms = self.get_success(
+ self.sliding_sync_handler.room_lists.compute_interested_rooms(
+ SlidingSyncConfig(
+ user=UserID.from_string(user1_id),
+ requester=create_requester(user_id=user1_id),
+ lists={
+ "foo-list": SlidingSyncConfig.SlidingSyncList(
+ ranges=[(0, 99)],
+ required_state=[],
+ timeline_limit=1,
+ )
+ },
+ conn_id=None,
+ ),
+ PerConnectionState(),
from_token=before_room1_token,
to_token=after_room1_token,
)
)
+ room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids)
+ newly_joined = interested_rooms.newly_joined_rooms
+ newly_left = interested_rooms.newly_left_rooms
# Room should show up because it was newly_left and joined during the from/to range
- self.assertEqual(room_id_results.keys(), {room_id1})
+ self.assertIncludes(room_id_results, {room_id1}, exact=True)
# It should be pointing to the latest membership event in the from/to range
self.assertEqual(
- room_id_results[room_id1].event_id,
+ interested_rooms.room_membership_for_user_map[room_id1].event_id,
join_response2["event_id"],
"Corresponding map to disambiguate the opaque event IDs: "
+ str(
@@ -1595,12 +1893,15 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
}
),
)
- self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN)
+ self.assertEqual(
+ interested_rooms.room_membership_for_user_map[room_id1].membership,
+ Membership.JOIN,
+ )
# We should be `newly_joined` because we joined during the token range
- self.assertEqual(room_id_results[room_id1].newly_joined, True)
+ self.assertTrue(room_id1 in newly_joined)
# We should *NOT* be `newly_left` because we joined during the token range and
# was still joined at the end of the range
- self.assertEqual(room_id_results[room_id1].newly_left, False)
+ self.assertTrue(room_id1 not in newly_left)
def test_join_leave_multiple_times_before_range_and_after_to_token(
self,
@@ -1631,19 +1932,34 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
join_response3 = self.helper.join(room_id1, user1_id, tok=user1_tok)
leave_response3 = self.helper.leave(room_id1, user1_id, tok=user1_tok)
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
+ interested_rooms = self.get_success(
+ self.sliding_sync_handler.room_lists.compute_interested_rooms(
+ SlidingSyncConfig(
+ user=UserID.from_string(user1_id),
+ requester=create_requester(user_id=user1_id),
+ lists={
+ "foo-list": SlidingSyncConfig.SlidingSyncList(
+ ranges=[(0, 99)],
+ required_state=[],
+ timeline_limit=1,
+ )
+ },
+ conn_id=None,
+ ),
+ PerConnectionState(),
from_token=after_room1_token,
to_token=after_room1_token,
)
)
+ room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids)
+ newly_joined = interested_rooms.newly_joined_rooms
+ newly_left = interested_rooms.newly_left_rooms
# Room should show up because we were joined before the from/to range
- self.assertEqual(room_id_results.keys(), {room_id1})
+ self.assertIncludes(room_id_results, {room_id1}, exact=True)
# It should be pointing to the latest membership event in the from/to range
self.assertEqual(
- room_id_results[room_id1].event_id,
+ interested_rooms.room_membership_for_user_map[room_id1].event_id,
join_response2["event_id"],
"Corresponding map to disambiguate the opaque event IDs: "
+ str(
@@ -1657,10 +1973,13 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
}
),
)
- self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN)
+ self.assertEqual(
+ interested_rooms.room_membership_for_user_map[room_id1].membership,
+ Membership.JOIN,
+ )
# We should *NOT* be `newly_joined` because we joined before the token range
- self.assertEqual(room_id_results[room_id1].newly_joined, False)
- self.assertEqual(room_id_results[room_id1].newly_left, False)
+ self.assertTrue(room_id1 not in newly_joined)
+ self.assertTrue(room_id1 not in newly_left)
def test_invite_before_range_and_join_leave_after_to_token(
self,
@@ -1690,19 +2009,34 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
join_respsonse = self.helper.join(room_id1, user1_id, tok=user1_tok)
leave_response = self.helper.leave(room_id1, user1_id, tok=user1_tok)
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
+ interested_rooms = self.get_success(
+ self.sliding_sync_handler.room_lists.compute_interested_rooms(
+ SlidingSyncConfig(
+ user=UserID.from_string(user1_id),
+ requester=create_requester(user_id=user1_id),
+ lists={
+ "foo-list": SlidingSyncConfig.SlidingSyncList(
+ ranges=[(0, 99)],
+ required_state=[],
+ timeline_limit=1,
+ )
+ },
+ conn_id=None,
+ ),
+ PerConnectionState(),
from_token=after_room1_token,
to_token=after_room1_token,
)
)
+ room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids)
+ newly_joined = interested_rooms.newly_joined_rooms
+ newly_left = interested_rooms.newly_left_rooms
# Room should show up because we were invited before the from/to range
- self.assertEqual(room_id_results.keys(), {room_id1})
+ self.assertIncludes(room_id_results, {room_id1}, exact=True)
# It should be pointing to the latest membership event in the from/to range
self.assertEqual(
- room_id_results[room_id1].event_id,
+ interested_rooms.room_membership_for_user_map[room_id1].event_id,
invite_response["event_id"],
"Corresponding map to disambiguate the opaque event IDs: "
+ str(
@@ -1713,11 +2047,14 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
}
),
)
- self.assertEqual(room_id_results[room_id1].membership, Membership.INVITE)
+ self.assertEqual(
+ interested_rooms.room_membership_for_user_map[room_id1].membership,
+ Membership.INVITE,
+ )
# We should *NOT* be `newly_joined` because we were only invited before the
# token range
- self.assertEqual(room_id_results[room_id1].newly_joined, False)
- self.assertEqual(room_id_results[room_id1].newly_left, False)
+ self.assertTrue(room_id1 not in newly_joined)
+ self.assertTrue(room_id1 not in newly_left)
def test_join_and_display_name_changes_in_token_range(
self,
@@ -1764,19 +2101,34 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
tok=user1_tok,
)
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
+ interested_rooms = self.get_success(
+ self.sliding_sync_handler.room_lists.compute_interested_rooms(
+ SlidingSyncConfig(
+ user=UserID.from_string(user1_id),
+ requester=create_requester(user_id=user1_id),
+ lists={
+ "foo-list": SlidingSyncConfig.SlidingSyncList(
+ ranges=[(0, 99)],
+ required_state=[],
+ timeline_limit=1,
+ )
+ },
+ conn_id=None,
+ ),
+ PerConnectionState(),
from_token=before_room1_token,
to_token=after_room1_token,
)
)
+ room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids)
+ newly_joined = interested_rooms.newly_joined_rooms
+ newly_left = interested_rooms.newly_left_rooms
# Room should show up because we were joined during the from/to range
- self.assertEqual(room_id_results.keys(), {room_id1})
+ self.assertIncludes(room_id_results, {room_id1}, exact=True)
# It should be pointing to the latest membership event in the from/to range
self.assertEqual(
- room_id_results[room_id1].event_id,
+ interested_rooms.room_membership_for_user_map[room_id1].event_id,
displayname_change_during_token_range_response["event_id"],
"Corresponding map to disambiguate the opaque event IDs: "
+ str(
@@ -1791,10 +2143,13 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
}
),
)
- self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN)
+ self.assertEqual(
+ interested_rooms.room_membership_for_user_map[room_id1].membership,
+ Membership.JOIN,
+ )
# We should be `newly_joined` because we joined during the token range
- self.assertEqual(room_id_results[room_id1].newly_joined, True)
- self.assertEqual(room_id_results[room_id1].newly_left, False)
+ self.assertTrue(room_id1 in newly_joined)
+ self.assertTrue(room_id1 not in newly_left)
def test_display_name_changes_in_token_range(
self,
@@ -1829,19 +2184,34 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
after_change1_token = self.event_sources.get_current_token()
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
+ interested_rooms = self.get_success(
+ self.sliding_sync_handler.room_lists.compute_interested_rooms(
+ SlidingSyncConfig(
+ user=UserID.from_string(user1_id),
+ requester=create_requester(user_id=user1_id),
+ lists={
+ "foo-list": SlidingSyncConfig.SlidingSyncList(
+ ranges=[(0, 99)],
+ required_state=[],
+ timeline_limit=1,
+ )
+ },
+ conn_id=None,
+ ),
+ PerConnectionState(),
from_token=after_room1_token,
to_token=after_change1_token,
)
)
+ room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids)
+ newly_joined = interested_rooms.newly_joined_rooms
+ newly_left = interested_rooms.newly_left_rooms
# Room should show up because we were joined during the from/to range
- self.assertEqual(room_id_results.keys(), {room_id1})
+ self.assertIncludes(room_id_results, {room_id1}, exact=True)
# It should be pointing to the latest membership event in the from/to range
self.assertEqual(
- room_id_results[room_id1].event_id,
+ interested_rooms.room_membership_for_user_map[room_id1].event_id,
displayname_change_during_token_range_response["event_id"],
"Corresponding map to disambiguate the opaque event IDs: "
+ str(
@@ -1853,10 +2223,13 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
}
),
)
- self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN)
+ self.assertEqual(
+ interested_rooms.room_membership_for_user_map[room_id1].membership,
+ Membership.JOIN,
+ )
# We should *NOT* be `newly_joined` because we joined before the token range
- self.assertEqual(room_id_results[room_id1].newly_joined, False)
- self.assertEqual(room_id_results[room_id1].newly_left, False)
+ self.assertTrue(room_id1 not in newly_joined)
+ self.assertTrue(room_id1 not in newly_left)
def test_display_name_changes_before_and_after_token_range(
self,
@@ -1901,19 +2274,34 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
tok=user1_tok,
)
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
+ interested_rooms = self.get_success(
+ self.sliding_sync_handler.room_lists.compute_interested_rooms(
+ SlidingSyncConfig(
+ user=UserID.from_string(user1_id),
+ requester=create_requester(user_id=user1_id),
+ lists={
+ "foo-list": SlidingSyncConfig.SlidingSyncList(
+ ranges=[(0, 99)],
+ required_state=[],
+ timeline_limit=1,
+ )
+ },
+ conn_id=None,
+ ),
+ PerConnectionState(),
from_token=after_room1_token,
to_token=after_room1_token,
)
)
+ room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids)
+ newly_joined = interested_rooms.newly_joined_rooms
+ newly_left = interested_rooms.newly_left_rooms
# Room should show up because we were joined before the from/to range
- self.assertEqual(room_id_results.keys(), {room_id1})
+ self.assertIncludes(room_id_results, {room_id1}, exact=True)
# It should be pointing to the latest membership event in the from/to range
self.assertEqual(
- room_id_results[room_id1].event_id,
+ interested_rooms.room_membership_for_user_map[room_id1].event_id,
displayname_change_before_token_range_response["event_id"],
"Corresponding map to disambiguate the opaque event IDs: "
+ str(
@@ -1928,18 +2316,22 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
}
),
)
- self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN)
+ self.assertEqual(
+ interested_rooms.room_membership_for_user_map[room_id1].membership,
+ Membership.JOIN,
+ )
# We should *NOT* be `newly_joined` because we joined before the token range
- self.assertEqual(room_id_results[room_id1].newly_joined, False)
- self.assertEqual(room_id_results[room_id1].newly_left, False)
+ self.assertTrue(room_id1 not in newly_joined)
+ self.assertTrue(room_id1 not in newly_left)
- def test_display_name_changes_leave_after_token_range(
+ def test_newly_joined_display_name_changes_leave_after_token_range(
self,
) -> None:
"""
Test that we point to the correct membership event within the from/to range even
- if there are multiple `join` membership events in a row indicating
- `displayname`/`avatar_url` updates and we leave after the `to_token`.
+ if we are `newly_joined` and there are multiple `join` membership events in a
+ row indicating `displayname`/`avatar_url` updates and we leave after the
+ `to_token`.
See condition "1a)" comments in the `get_room_membership_for_user_at_to_token()` method.
"""
@@ -1954,6 +2346,7 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
# leave and can still re-join.
room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True)
join_response = self.helper.join(room_id1, user1_id, tok=user1_tok)
+
# Update the displayname during the token range
displayname_change_during_token_range_response = self.helper.send_state(
room_id1,
@@ -1983,19 +2376,34 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
# Leave after the token
self.helper.leave(room_id1, user1_id, tok=user1_tok)
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
+ interested_rooms = self.get_success(
+ self.sliding_sync_handler.room_lists.compute_interested_rooms(
+ SlidingSyncConfig(
+ user=UserID.from_string(user1_id),
+ requester=create_requester(user_id=user1_id),
+ lists={
+ "foo-list": SlidingSyncConfig.SlidingSyncList(
+ ranges=[(0, 99)],
+ required_state=[],
+ timeline_limit=1,
+ )
+ },
+ conn_id=None,
+ ),
+ PerConnectionState(),
from_token=before_room1_token,
to_token=after_room1_token,
)
)
+ room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids)
+ newly_joined = interested_rooms.newly_joined_rooms
+ newly_left = interested_rooms.newly_left_rooms
# Room should show up because we were joined during the from/to range
- self.assertEqual(room_id_results.keys(), {room_id1})
+ self.assertIncludes(room_id_results, {room_id1}, exact=True)
# It should be pointing to the latest membership event in the from/to range
self.assertEqual(
- room_id_results[room_id1].event_id,
+ interested_rooms.room_membership_for_user_map[room_id1].event_id,
displayname_change_during_token_range_response["event_id"],
"Corresponding map to disambiguate the opaque event IDs: "
+ str(
@@ -2010,10 +2418,117 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
}
),
)
- self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN)
+ self.assertEqual(
+ interested_rooms.room_membership_for_user_map[room_id1].membership,
+ Membership.JOIN,
+ )
# We should be `newly_joined` because we joined during the token range
- self.assertEqual(room_id_results[room_id1].newly_joined, True)
- self.assertEqual(room_id_results[room_id1].newly_left, False)
+ self.assertTrue(room_id1 in newly_joined)
+ self.assertTrue(room_id1 not in newly_left)
+
+ def test_display_name_changes_leave_after_token_range(
+ self,
+ ) -> None:
+ """
+ Test that we point to the correct membership event within the from/to range even
+ if there are multiple `join` membership events in a row indicating
+ `displayname`/`avatar_url` updates and we leave after the `to_token`.
+
+ See condition "1a)" comments in the `get_room_membership_for_user_at_to_token()` method.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ _before_room1_token = self.event_sources.get_current_token()
+
+ # We create the room with user2 so the room isn't left with no members when we
+ # leave and can still re-join.
+ room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True)
+ join_response = self.helper.join(room_id1, user1_id, tok=user1_tok)
+
+ after_join_token = self.event_sources.get_current_token()
+
+ # Update the displayname during the token range
+ displayname_change_during_token_range_response = self.helper.send_state(
+ room_id1,
+ event_type=EventTypes.Member,
+ state_key=user1_id,
+ body={
+ "membership": Membership.JOIN,
+ "displayname": "displayname during token range",
+ },
+ tok=user1_tok,
+ )
+
+ after_display_name_change_token = self.event_sources.get_current_token()
+
+ # Update the displayname after the token range
+ displayname_change_after_token_range_response = self.helper.send_state(
+ room_id1,
+ event_type=EventTypes.Member,
+ state_key=user1_id,
+ body={
+ "membership": Membership.JOIN,
+ "displayname": "displayname after token range",
+ },
+ tok=user1_tok,
+ )
+
+ # Leave after the token
+ self.helper.leave(room_id1, user1_id, tok=user1_tok)
+
+ interested_rooms = self.get_success(
+ self.sliding_sync_handler.room_lists.compute_interested_rooms(
+ SlidingSyncConfig(
+ user=UserID.from_string(user1_id),
+ requester=create_requester(user_id=user1_id),
+ lists={
+ "foo-list": SlidingSyncConfig.SlidingSyncList(
+ ranges=[(0, 99)],
+ required_state=[],
+ timeline_limit=1,
+ )
+ },
+ conn_id=None,
+ ),
+ PerConnectionState(),
+ from_token=after_join_token,
+ to_token=after_display_name_change_token,
+ )
+ )
+ room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids)
+ newly_joined = interested_rooms.newly_joined_rooms
+ newly_left = interested_rooms.newly_left_rooms
+
+ # Room should show up because we were joined during the from/to range
+ self.assertIncludes(room_id_results, {room_id1}, exact=True)
+ # It should be pointing to the latest membership event in the from/to range
+ self.assertEqual(
+ interested_rooms.room_membership_for_user_map[room_id1].event_id,
+ displayname_change_during_token_range_response["event_id"],
+ "Corresponding map to disambiguate the opaque event IDs: "
+ + str(
+ {
+ "join_response": join_response["event_id"],
+ "displayname_change_during_token_range_response": displayname_change_during_token_range_response[
+ "event_id"
+ ],
+ "displayname_change_after_token_range_response": displayname_change_after_token_range_response[
+ "event_id"
+ ],
+ }
+ ),
+ )
+ self.assertEqual(
+ interested_rooms.room_membership_for_user_map[room_id1].membership,
+ Membership.JOIN,
+ )
+ # We only changed our display name during the token range so we shouldn't be
+ # considered `newly_joined` or `newly_left`
+ self.assertTrue(room_id1 not in newly_joined)
+ self.assertTrue(room_id1 not in newly_left)
def test_display_name_changes_join_after_token_range(
self,
@@ -2051,16 +2566,29 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
tok=user1_tok,
)
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
+ interested_rooms = self.get_success(
+ self.sliding_sync_handler.room_lists.compute_interested_rooms(
+ SlidingSyncConfig(
+ user=UserID.from_string(user1_id),
+ requester=create_requester(user_id=user1_id),
+ lists={
+ "foo-list": SlidingSyncConfig.SlidingSyncList(
+ ranges=[(0, 99)],
+ required_state=[],
+ timeline_limit=1,
+ )
+ },
+ conn_id=None,
+ ),
+ PerConnectionState(),
from_token=before_room1_token,
to_token=after_room1_token,
)
)
+ room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids)
# Room shouldn't show up because we joined after the from/to range
- self.assertEqual(room_id_results.keys(), set())
+ self.assertIncludes(room_id_results, set(), exact=True)
def test_newly_joined_with_leave_join_in_token_range(
self,
@@ -2087,26 +2615,44 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
after_more_changes_token = self.event_sources.get_current_token()
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
+ interested_rooms = self.get_success(
+ self.sliding_sync_handler.room_lists.compute_interested_rooms(
+ SlidingSyncConfig(
+ user=UserID.from_string(user1_id),
+ requester=create_requester(user_id=user1_id),
+ lists={
+ "foo-list": SlidingSyncConfig.SlidingSyncList(
+ ranges=[(0, 99)],
+ required_state=[],
+ timeline_limit=1,
+ )
+ },
+ conn_id=None,
+ ),
+ PerConnectionState(),
from_token=after_room1_token,
to_token=after_more_changes_token,
)
)
+ room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids)
+ newly_joined = interested_rooms.newly_joined_rooms
+ newly_left = interested_rooms.newly_left_rooms
# Room should show up because we were joined during the from/to range
- self.assertEqual(room_id_results.keys(), {room_id1})
+ self.assertIncludes(room_id_results, {room_id1}, exact=True)
# It should be pointing to the latest membership event in the from/to range
self.assertEqual(
- room_id_results[room_id1].event_id,
+ interested_rooms.room_membership_for_user_map[room_id1].event_id,
join_response2["event_id"],
)
- self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN)
+ self.assertEqual(
+ interested_rooms.room_membership_for_user_map[room_id1].membership,
+ Membership.JOIN,
+ )
# We should be considered `newly_joined` because there is some non-join event in
# between our latest join event.
- self.assertEqual(room_id_results[room_id1].newly_joined, True)
- self.assertEqual(room_id_results[room_id1].newly_left, False)
+ self.assertTrue(room_id1 in newly_joined)
+ self.assertTrue(room_id1 not in newly_left)
def test_newly_joined_only_joins_during_token_range(
self,
@@ -2152,19 +2698,34 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
after_room1_token = self.event_sources.get_current_token()
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
+ interested_rooms = self.get_success(
+ self.sliding_sync_handler.room_lists.compute_interested_rooms(
+ SlidingSyncConfig(
+ user=UserID.from_string(user1_id),
+ requester=create_requester(user_id=user1_id),
+ lists={
+ "foo-list": SlidingSyncConfig.SlidingSyncList(
+ ranges=[(0, 99)],
+ required_state=[],
+ timeline_limit=1,
+ )
+ },
+ conn_id=None,
+ ),
+ PerConnectionState(),
from_token=before_room1_token,
to_token=after_room1_token,
)
)
+ room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids)
+ newly_joined = interested_rooms.newly_joined_rooms
+ newly_left = interested_rooms.newly_left_rooms
# Room should show up because it was newly_left and joined during the from/to range
- self.assertEqual(room_id_results.keys(), {room_id1})
+ self.assertIncludes(room_id_results, {room_id1}, exact=True)
# It should be pointing to the latest membership event in the from/to range
self.assertEqual(
- room_id_results[room_id1].event_id,
+ interested_rooms.room_membership_for_user_map[room_id1].event_id,
displayname_change_during_token_range_response2["event_id"],
"Corresponding map to disambiguate the opaque event IDs: "
+ str(
@@ -2179,10 +2740,13 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
}
),
)
- self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN)
+ self.assertEqual(
+ interested_rooms.room_membership_for_user_map[room_id1].membership,
+ Membership.JOIN,
+ )
# We should be `newly_joined` because we first joined during the token range
- self.assertEqual(room_id_results[room_id1].newly_joined, True)
- self.assertEqual(room_id_results[room_id1].newly_left, False)
+ self.assertTrue(room_id1 in newly_joined)
+ self.assertTrue(room_id1 not in newly_left)
def test_multiple_rooms_are_not_confused(
self,
@@ -2205,7 +2769,7 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
# Invited and left the room before the token
self.helper.invite(room_id1, src=user2_id, targ=user1_id, tok=user2_tok)
- leave_room1_response = self.helper.leave(room_id1, user1_id, tok=user1_tok)
+ _leave_room1_response = self.helper.leave(room_id1, user1_id, tok=user1_tok)
# Invited to room2
invite_room2_response = self.helper.invite(
room_id2, src=user2_id, targ=user1_id, tok=user2_tok
@@ -2228,61 +2792,71 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
# Leave room3
self.helper.leave(room_id3, user1_id, tok=user1_tok)
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
+ interested_rooms = self.get_success(
+ self.sliding_sync_handler.room_lists.compute_interested_rooms(
+ SlidingSyncConfig(
+ user=UserID.from_string(user1_id),
+ requester=create_requester(user_id=user1_id),
+ lists={
+ "foo-list": SlidingSyncConfig.SlidingSyncList(
+ ranges=[(0, 99)],
+ required_state=[],
+ timeline_limit=1,
+ )
+ },
+ conn_id=None,
+ ),
+ PerConnectionState(),
from_token=before_room3_token,
to_token=after_room3_token,
)
)
+ room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids)
+ newly_joined = interested_rooms.newly_joined_rooms
+ newly_left = interested_rooms.newly_left_rooms
- self.assertEqual(
- room_id_results.keys(),
+ self.assertIncludes(
+ room_id_results,
{
- # Left before the from/to range
- room_id1,
+ # Excluded because we left before the from/to range
+ # room_id1,
# Invited before the from/to range
room_id2,
# `newly_left` during the from/to range
room_id3,
},
+ exact=True,
)
- # Room1
- # It should be pointing to the latest membership event in the from/to range
- self.assertEqual(
- room_id_results[room_id1].event_id,
- leave_room1_response["event_id"],
- )
- self.assertEqual(room_id_results[room_id1].membership, Membership.LEAVE)
- # We should *NOT* be `newly_joined`/`newly_left` because we were invited and left
- # before the token range
- self.assertEqual(room_id_results[room_id1].newly_joined, False)
- self.assertEqual(room_id_results[room_id1].newly_left, False)
-
# Room2
# It should be pointing to the latest membership event in the from/to range
self.assertEqual(
- room_id_results[room_id2].event_id,
+ interested_rooms.room_membership_for_user_map[room_id2].event_id,
invite_room2_response["event_id"],
)
- self.assertEqual(room_id_results[room_id2].membership, Membership.INVITE)
+ self.assertEqual(
+ interested_rooms.room_membership_for_user_map[room_id2].membership,
+ Membership.INVITE,
+ )
# We should *NOT* be `newly_joined`/`newly_left` because we were invited before
# the token range
- self.assertEqual(room_id_results[room_id2].newly_joined, False)
- self.assertEqual(room_id_results[room_id2].newly_left, False)
+ self.assertTrue(room_id2 not in newly_joined)
+ self.assertTrue(room_id2 not in newly_left)
# Room3
# It should be pointing to the latest membership event in the from/to range
self.assertEqual(
- room_id_results[room_id3].event_id,
+ interested_rooms.room_membership_for_user_map[room_id3].event_id,
leave_room3_response["event_id"],
)
- self.assertEqual(room_id_results[room_id3].membership, Membership.LEAVE)
+ self.assertEqual(
+ interested_rooms.room_membership_for_user_map[room_id3].membership,
+ Membership.LEAVE,
+ )
# We should be `newly_left` because we were invited and left during
# the token range
- self.assertEqual(room_id_results[room_id3].newly_joined, False)
- self.assertEqual(room_id_results[room_id3].newly_left, True)
+ self.assertTrue(room_id3 not in newly_joined)
+ self.assertTrue(room_id3 in newly_left)
def test_state_reset(self) -> None:
"""
@@ -2295,7 +2869,16 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
user2_tok = self.login(user2_id, "pass")
# The room where the state reset will happen
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
+ room_id1 = self.helper.create_room_as(
+ user2_id,
+ is_public=True,
+ tok=user2_tok,
+ )
+ # Create a dummy event for us to point back to for the state reset
+ dummy_event_response = self.helper.send(room_id1, "test", tok=user2_tok)
+ dummy_event_id = dummy_event_response["event_id"]
+
+ # Join after the dummy event
join_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok)
# Join another room so we don't hit the short-circuit and return early if they
@@ -2305,95 +2888,106 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase):
before_reset_token = self.event_sources.get_current_token()
- # Send another state event to make a position for the state reset to happen at
- dummy_state_response = self.helper.send_state(
- room_id1,
- event_type="foobarbaz",
- state_key="",
- body={"foo": "bar"},
- tok=user2_tok,
- )
- dummy_state_pos = self.get_success(
- self.store.get_position_for_event(dummy_state_response["event_id"])
- )
-
- # Mock a state reset removing the membership for user1 in the current state
- self.get_success(
- self.store.db_pool.simple_delete(
- table="current_state_events",
- keyvalues={
- "room_id": room_id1,
- "type": EventTypes.Member,
- "state_key": user1_id,
- },
- desc="state reset user in current_state_events",
+ # Trigger a state reset
+ join_rule_event, join_rule_context = self.get_success(
+ create_event(
+ self.hs,
+ prev_event_ids=[dummy_event_id],
+ type=EventTypes.JoinRules,
+ state_key="",
+ content={"join_rule": JoinRules.INVITE},
+ sender=user2_id,
+ room_id=room_id1,
+ room_version=self.get_success(self.store.get_room_version_id(room_id1)),
)
)
- self.get_success(
- self.store.db_pool.simple_delete(
- table="local_current_membership",
- keyvalues={
- "room_id": room_id1,
- "user_id": user1_id,
- },
- desc="state reset user in local_current_membership",
- )
- )
- self.get_success(
- self.store.db_pool.simple_insert(
- table="current_state_delta_stream",
- values={
- "stream_id": dummy_state_pos.stream,
- "room_id": room_id1,
- "type": EventTypes.Member,
- "state_key": user1_id,
- "event_id": None,
- "prev_event_id": join_response1["event_id"],
- "instance_name": dummy_state_pos.instance_name,
- },
- desc="state reset user in current_state_delta_stream",
- )
+ _, join_rule_event_pos, _ = self.get_success(
+ self.persistence.persist_event(join_rule_event, join_rule_context)
)
- # Manually bust the cache since we we're just manually messing with the database
- # and not causing an actual state reset.
- self.store._membership_stream_cache.entity_has_changed(
- user1_id, dummy_state_pos.stream
- )
+ # Ensure that the state reset worked and only user2 is in the room now
+ users_in_room = self.get_success(self.store.get_users_in_room(room_id1))
+ self.assertIncludes(set(users_in_room), {user2_id}, exact=True)
after_reset_token = self.event_sources.get_current_token()
# The function under test
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
+ interested_rooms = self.get_success(
+ self.sliding_sync_handler.room_lists.compute_interested_rooms(
+ SlidingSyncConfig(
+ user=UserID.from_string(user1_id),
+ requester=create_requester(user_id=user1_id),
+ lists={
+ "foo-list": SlidingSyncConfig.SlidingSyncList(
+ ranges=[(0, 99)],
+ required_state=[],
+ timeline_limit=1,
+ )
+ },
+ conn_id=None,
+ ),
+ PerConnectionState(),
from_token=before_reset_token,
to_token=after_reset_token,
)
)
+ room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids)
+ newly_joined = interested_rooms.newly_joined_rooms
+ newly_left = interested_rooms.newly_left_rooms
# Room1 should show up because it was `newly_left` via state reset during the from/to range
- self.assertEqual(room_id_results.keys(), {room_id1, room_id2})
+ self.assertIncludes(room_id_results, {room_id1, room_id2}, exact=True)
# It should be pointing to no event because we were removed from the room
# without a corresponding leave event
self.assertEqual(
- room_id_results[room_id1].event_id,
+ interested_rooms.room_membership_for_user_map[room_id1].event_id,
None,
+ "Corresponding map to disambiguate the opaque event IDs: "
+ + str(
+ {
+ "join_response1": join_response1["event_id"],
+ }
+ ),
)
# State reset caused us to leave the room and there is no corresponding leave event
- self.assertEqual(room_id_results[room_id1].membership, Membership.LEAVE)
+ self.assertEqual(
+ interested_rooms.room_membership_for_user_map[room_id1].membership,
+ Membership.LEAVE,
+ )
# We should *NOT* be `newly_joined` because we joined before the token range
- self.assertEqual(room_id_results[room_id1].newly_joined, False)
+ self.assertTrue(room_id1 not in newly_joined)
# We should be `newly_left` because we were removed via state reset during the from/to range
- self.assertEqual(room_id_results[room_id1].newly_left, True)
-
-
-class GetRoomMembershipForUserAtToTokenShardTestCase(BaseMultiWorkerStreamTestCase):
+ self.assertTrue(room_id1 in newly_left)
+
+
+# FIXME: This can be removed once we bump `SCHEMA_COMPAT_VERSION` and run the
+# foreground update for
+# `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` (tracked by
+# https://github.com/element-hq/synapse/issues/17623)
+@parameterized_class(
+ ("use_new_tables",),
+ [
+ (True,),
+ (False,),
+ ],
+ class_name_func=lambda cls,
+ num,
+ params_dict: f"{cls.__name__}_{'new' if params_dict['use_new_tables'] else 'fallback'}",
+)
+class ComputeInterestedRoomsShardTestCase(
+ BaseMultiWorkerStreamTestCase, SlidingSyncBase
+):
"""
- Tests Sliding Sync handler `get_room_membership_for_user_at_to_token()` to make sure it works with
+ Tests Sliding Sync handler `compute_interested_rooms()` to make sure it works with
sharded event stream_writers enabled
"""
+ # FIXME: We should refactor these tests to run against `compute_interested_rooms(...)`
+ # instead of just `get_room_membership_for_user_at_to_token(...)` which is only used
+ # in the fallback path (`_compute_interested_rooms_fallback(...)`). These scenarios do
+ # well to stress that logic and we shouldn't remove them just because we're removing
+ # the fallback path (tracked by https://github.com/element-hq/synapse/issues/17623).
+
servlets = [
admin.register_servlets_for_client_rest_resource,
room.register_servlets,
@@ -2488,7 +3082,7 @@ class GetRoomMembershipForUserAtToTokenShardTestCase(BaseMultiWorkerStreamTestCa
join_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok)
join_response2 = self.helper.join(room_id2, user1_id, tok=user1_tok)
# Leave room2
- leave_room2_response = self.helper.leave(room_id2, user1_id, tok=user1_tok)
+ _leave_room2_response = self.helper.leave(room_id2, user1_id, tok=user1_tok)
join_response3 = self.helper.join(room_id3, user1_id, tok=user1_tok)
# Leave room3
self.helper.leave(room_id3, user1_id, tok=user1_tok)
@@ -2578,60 +3172,77 @@ class GetRoomMembershipForUserAtToTokenShardTestCase(BaseMultiWorkerStreamTestCa
self.get_success(actx.__aexit__(None, None, None))
# The function under test
- room_id_results = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- UserID.from_string(user1_id),
+ interested_rooms = self.get_success(
+ self.sliding_sync_handler.room_lists.compute_interested_rooms(
+ SlidingSyncConfig(
+ user=UserID.from_string(user1_id),
+ requester=create_requester(user_id=user1_id),
+ lists={
+ "foo-list": SlidingSyncConfig.SlidingSyncList(
+ ranges=[(0, 99)],
+ required_state=[],
+ timeline_limit=1,
+ )
+ },
+ conn_id=None,
+ ),
+ PerConnectionState(),
from_token=before_stuck_activity_token,
to_token=stuck_activity_token,
)
)
+ room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids)
+ newly_joined = interested_rooms.newly_joined_rooms
+ newly_left = interested_rooms.newly_left_rooms
- self.assertEqual(
- room_id_results.keys(),
+ self.assertIncludes(
+ room_id_results,
{
room_id1,
- room_id2,
+ # Excluded because we left before the from/to range and the second join
+ # event happened while worker2 was stuck and technically occurs after
+ # the `stuck_activity_token`.
+ # room_id2,
room_id3,
},
+ exact=True,
+ message="Corresponding map to disambiguate the opaque room IDs: "
+ + str(
+ {
+ "room_id1": room_id1,
+ "room_id2": room_id2,
+ "room_id3": room_id3,
+ }
+ ),
)
# Room1
# It should be pointing to the latest membership event in the from/to range
self.assertEqual(
- room_id_results[room_id1].event_id,
+ interested_rooms.room_membership_for_user_map[room_id1].event_id,
join_room1_response["event_id"],
)
- self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN)
- # We should be `newly_joined` because we joined during the token range
- self.assertEqual(room_id_results[room_id1].newly_joined, True)
- self.assertEqual(room_id_results[room_id1].newly_left, False)
-
- # Room2
- # It should be pointing to the latest membership event in the from/to range
self.assertEqual(
- room_id_results[room_id2].event_id,
- leave_room2_response["event_id"],
- )
- self.assertEqual(room_id_results[room_id2].membership, Membership.LEAVE)
- # room_id2 should *NOT* be considered `newly_left` because we left before the
- # from/to range and the join event during the range happened while worker2 was
- # stuck. This means that from the perspective of the master, where the
- # `stuck_activity_token` is generated, the stream position for worker2 wasn't
- # advanced to the join yet. Looking at the `instance_map`, the join technically
- # comes after `stuck_activity_token`.
- self.assertEqual(room_id_results[room_id2].newly_joined, False)
- self.assertEqual(room_id_results[room_id2].newly_left, False)
+ interested_rooms.room_membership_for_user_map[room_id1].membership,
+ Membership.JOIN,
+ )
+ # We should be `newly_joined` because we joined during the token range
+ self.assertTrue(room_id1 in newly_joined)
+ self.assertTrue(room_id1 not in newly_left)
# Room3
# It should be pointing to the latest membership event in the from/to range
self.assertEqual(
- room_id_results[room_id3].event_id,
+ interested_rooms.room_membership_for_user_map[room_id3].event_id,
join_on_worker3_response["event_id"],
)
- self.assertEqual(room_id_results[room_id3].membership, Membership.JOIN)
+ self.assertEqual(
+ interested_rooms.room_membership_for_user_map[room_id3].membership,
+ Membership.JOIN,
+ )
# We should be `newly_joined` because we joined during the token range
- self.assertEqual(room_id_results[room_id3].newly_joined, True)
- self.assertEqual(room_id_results[room_id3].newly_left, False)
+ self.assertTrue(room_id3 in newly_joined)
+ self.assertTrue(room_id3 not in newly_left)
class FilterRoomsRelevantForSyncTestCase(HomeserverTestCase):
@@ -2658,31 +3269,35 @@ class FilterRoomsRelevantForSyncTestCase(HomeserverTestCase):
self.store = self.hs.get_datastores().main
self.event_sources = hs.get_event_sources()
self.storage_controllers = hs.get_storage_controllers()
+ persistence = self.hs.get_storage_controllers().persistence
+ assert persistence is not None
+ self.persistence = persistence
def _get_sync_room_ids_for_user(
self,
user: UserID,
to_token: StreamToken,
from_token: Optional[StreamToken],
- ) -> Dict[str, _RoomMembershipForUser]:
+ ) -> Tuple[Dict[str, RoomsForUserType], AbstractSet[str], AbstractSet[str]]:
"""
Get the rooms the user should be syncing with
"""
- room_membership_for_user_map = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
+ room_membership_for_user_map, newly_joined, newly_left = self.get_success(
+ self.sliding_sync_handler.room_lists.get_room_membership_for_user_at_to_token(
user=user,
from_token=from_token,
to_token=to_token,
)
)
filtered_sync_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms_relevant_for_sync(
+ self.sliding_sync_handler.room_lists.filter_rooms_relevant_for_sync(
user=user,
room_membership_for_user_map=room_membership_for_user_map,
+ newly_left_room_ids=newly_left,
)
)
- return filtered_sync_room_map
+ return filtered_sync_room_map, newly_joined, newly_left
def test_no_rooms(self) -> None:
"""
@@ -2693,13 +3308,13 @@ class FilterRoomsRelevantForSyncTestCase(HomeserverTestCase):
now_token = self.event_sources.get_current_token()
- room_id_results = self._get_sync_room_ids_for_user(
+ room_id_results, newly_joined, newly_left = self._get_sync_room_ids_for_user(
UserID.from_string(user1_id),
from_token=now_token,
to_token=now_token,
)
- self.assertEqual(room_id_results.keys(), set())
+ self.assertIncludes(room_id_results.keys(), set(), exact=True)
def test_basic_rooms(self) -> None:
"""
@@ -2758,14 +3373,14 @@ class FilterRoomsRelevantForSyncTestCase(HomeserverTestCase):
after_room_token = self.event_sources.get_current_token()
- room_id_results = self._get_sync_room_ids_for_user(
+ room_id_results, newly_joined, newly_left = self._get_sync_room_ids_for_user(
UserID.from_string(user1_id),
from_token=before_room_token,
to_token=after_room_token,
)
# Ensure that the invited, ban, and knock rooms show up
- self.assertEqual(
+ self.assertIncludes(
room_id_results.keys(),
{
join_room_id,
@@ -2773,6 +3388,7 @@ class FilterRoomsRelevantForSyncTestCase(HomeserverTestCase):
ban_room_id,
knock_room_id,
},
+ exact=True,
)
# It should be pointing to the the respective membership event (latest
# membership event in the from/to range)
@@ -2781,32 +3397,32 @@ class FilterRoomsRelevantForSyncTestCase(HomeserverTestCase):
join_response["event_id"],
)
self.assertEqual(room_id_results[join_room_id].membership, Membership.JOIN)
- self.assertEqual(room_id_results[join_room_id].newly_joined, True)
- self.assertEqual(room_id_results[join_room_id].newly_left, False)
+ self.assertTrue(join_room_id in newly_joined)
+ self.assertTrue(join_room_id not in newly_left)
self.assertEqual(
room_id_results[invited_room_id].event_id,
invite_response["event_id"],
)
self.assertEqual(room_id_results[invited_room_id].membership, Membership.INVITE)
- self.assertEqual(room_id_results[invited_room_id].newly_joined, False)
- self.assertEqual(room_id_results[invited_room_id].newly_left, False)
+ self.assertTrue(invited_room_id not in newly_joined)
+ self.assertTrue(invited_room_id not in newly_left)
self.assertEqual(
room_id_results[ban_room_id].event_id,
ban_response["event_id"],
)
self.assertEqual(room_id_results[ban_room_id].membership, Membership.BAN)
- self.assertEqual(room_id_results[ban_room_id].newly_joined, False)
- self.assertEqual(room_id_results[ban_room_id].newly_left, False)
+ self.assertTrue(ban_room_id not in newly_joined)
+ self.assertTrue(ban_room_id not in newly_left)
self.assertEqual(
room_id_results[knock_room_id].event_id,
knock_room_membership_state_event.event_id,
)
self.assertEqual(room_id_results[knock_room_id].membership, Membership.KNOCK)
- self.assertEqual(room_id_results[knock_room_id].newly_joined, False)
- self.assertEqual(room_id_results[knock_room_id].newly_left, False)
+ self.assertTrue(knock_room_id not in newly_joined)
+ self.assertTrue(knock_room_id not in newly_left)
def test_only_newly_left_rooms_show_up(self) -> None:
"""
@@ -2829,21 +3445,21 @@ class FilterRoomsRelevantForSyncTestCase(HomeserverTestCase):
after_room2_token = self.event_sources.get_current_token()
- room_id_results = self._get_sync_room_ids_for_user(
+ room_id_results, newly_joined, newly_left = self._get_sync_room_ids_for_user(
UserID.from_string(user1_id),
from_token=after_room1_token,
to_token=after_room2_token,
)
# Only the `newly_left` room should show up
- self.assertEqual(room_id_results.keys(), {room_id2})
+ self.assertIncludes(room_id_results.keys(), {room_id2}, exact=True)
self.assertEqual(
room_id_results[room_id2].event_id,
_leave_response2["event_id"],
)
# We should *NOT* be `newly_joined` because we are instead `newly_left`
- self.assertEqual(room_id_results[room_id2].newly_joined, False)
- self.assertEqual(room_id_results[room_id2].newly_left, True)
+ self.assertTrue(room_id2 not in newly_joined)
+ self.assertTrue(room_id2 in newly_left)
def test_get_kicked_room(self) -> None:
"""
@@ -2874,14 +3490,14 @@ class FilterRoomsRelevantForSyncTestCase(HomeserverTestCase):
after_kick_token = self.event_sources.get_current_token()
- room_id_results = self._get_sync_room_ids_for_user(
+ room_id_results, newly_joined, newly_left = self._get_sync_room_ids_for_user(
UserID.from_string(user1_id),
from_token=after_kick_token,
to_token=after_kick_token,
)
# The kicked room should show up
- self.assertEqual(room_id_results.keys(), {kick_room_id})
+ self.assertIncludes(room_id_results.keys(), {kick_room_id}, exact=True)
# It should be pointing to the latest membership event in the from/to range
self.assertEqual(
room_id_results[kick_room_id].event_id,
@@ -2891,8 +3507,8 @@ class FilterRoomsRelevantForSyncTestCase(HomeserverTestCase):
self.assertNotEqual(room_id_results[kick_room_id].sender, user1_id)
# We should *NOT* be `newly_joined` because we were not joined at the the time
# of the `to_token`.
- self.assertEqual(room_id_results[kick_room_id].newly_joined, False)
- self.assertEqual(room_id_results[kick_room_id].newly_left, False)
+ self.assertTrue(kick_room_id not in newly_joined)
+ self.assertTrue(kick_room_id not in newly_left)
def test_state_reset(self) -> None:
"""
@@ -2905,8 +3521,17 @@ class FilterRoomsRelevantForSyncTestCase(HomeserverTestCase):
user2_tok = self.login(user2_id, "pass")
# The room where the state reset will happen
- room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
- join_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok)
+ room_id1 = self.helper.create_room_as(
+ user2_id,
+ is_public=True,
+ tok=user2_tok,
+ )
+ # Create a dummy event for us to point back to for the state reset
+ dummy_event_response = self.helper.send(room_id1, "test", tok=user2_tok)
+ dummy_event_id = dummy_event_response["event_id"]
+
+ # Join after the dummy event
+ self.helper.join(room_id1, user1_id, tok=user1_tok)
# Join another room so we don't hit the short-circuit and return early if they
# have no room membership
@@ -2915,73 +3540,38 @@ class FilterRoomsRelevantForSyncTestCase(HomeserverTestCase):
before_reset_token = self.event_sources.get_current_token()
- # Send another state event to make a position for the state reset to happen at
- dummy_state_response = self.helper.send_state(
- room_id1,
- event_type="foobarbaz",
- state_key="",
- body={"foo": "bar"},
- tok=user2_tok,
- )
- dummy_state_pos = self.get_success(
- self.store.get_position_for_event(dummy_state_response["event_id"])
- )
-
- # Mock a state reset removing the membership for user1 in the current state
- self.get_success(
- self.store.db_pool.simple_delete(
- table="current_state_events",
- keyvalues={
- "room_id": room_id1,
- "type": EventTypes.Member,
- "state_key": user1_id,
- },
- desc="state reset user in current_state_events",
+ # Trigger a state reset
+ join_rule_event, join_rule_context = self.get_success(
+ create_event(
+ self.hs,
+ prev_event_ids=[dummy_event_id],
+ type=EventTypes.JoinRules,
+ state_key="",
+ content={"join_rule": JoinRules.INVITE},
+ sender=user2_id,
+ room_id=room_id1,
+ room_version=self.get_success(self.store.get_room_version_id(room_id1)),
)
)
- self.get_success(
- self.store.db_pool.simple_delete(
- table="local_current_membership",
- keyvalues={
- "room_id": room_id1,
- "user_id": user1_id,
- },
- desc="state reset user in local_current_membership",
- )
- )
- self.get_success(
- self.store.db_pool.simple_insert(
- table="current_state_delta_stream",
- values={
- "stream_id": dummy_state_pos.stream,
- "room_id": room_id1,
- "type": EventTypes.Member,
- "state_key": user1_id,
- "event_id": None,
- "prev_event_id": join_response1["event_id"],
- "instance_name": dummy_state_pos.instance_name,
- },
- desc="state reset user in current_state_delta_stream",
- )
+ _, join_rule_event_pos, _ = self.get_success(
+ self.persistence.persist_event(join_rule_event, join_rule_context)
)
- # Manually bust the cache since we we're just manually messing with the database
- # and not causing an actual state reset.
- self.store._membership_stream_cache.entity_has_changed(
- user1_id, dummy_state_pos.stream
- )
+ # Ensure that the state reset worked and only user2 is in the room now
+ users_in_room = self.get_success(self.store.get_users_in_room(room_id1))
+ self.assertIncludes(set(users_in_room), {user2_id}, exact=True)
after_reset_token = self.event_sources.get_current_token()
# The function under test
- room_id_results = self._get_sync_room_ids_for_user(
+ room_id_results, newly_joined, newly_left = self._get_sync_room_ids_for_user(
UserID.from_string(user1_id),
from_token=before_reset_token,
to_token=after_reset_token,
)
# Room1 should show up because it was `newly_left` via state reset during the from/to range
- self.assertEqual(room_id_results.keys(), {room_id1, room_id2})
+ self.assertIncludes(room_id_results.keys(), {room_id1, room_id2}, exact=True)
# It should be pointing to no event because we were removed from the room
# without a corresponding leave event
self.assertEqual(
@@ -2991,1345 +3581,9 @@ class FilterRoomsRelevantForSyncTestCase(HomeserverTestCase):
# State reset caused us to leave the room and there is no corresponding leave event
self.assertEqual(room_id_results[room_id1].membership, Membership.LEAVE)
# We should *NOT* be `newly_joined` because we joined before the token range
- self.assertEqual(room_id_results[room_id1].newly_joined, False)
+ self.assertTrue(room_id1 not in newly_joined)
# We should be `newly_left` because we were removed via state reset during the from/to range
- self.assertEqual(room_id_results[room_id1].newly_left, True)
-
-
-class FilterRoomsTestCase(HomeserverTestCase):
- """
- Tests Sliding Sync handler `filter_rooms()` to make sure it includes/excludes rooms
- correctly.
- """
-
- servlets = [
- admin.register_servlets,
- knock.register_servlets,
- login.register_servlets,
- room.register_servlets,
- ]
-
- def default_config(self) -> JsonDict:
- config = super().default_config()
- # Enable sliding sync
- config["experimental_features"] = {"msc3575_enabled": True}
- return config
-
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.sliding_sync_handler = self.hs.get_sliding_sync_handler()
- self.store = self.hs.get_datastores().main
- self.event_sources = hs.get_event_sources()
-
- def _get_sync_room_ids_for_user(
- self,
- user: UserID,
- to_token: StreamToken,
- from_token: Optional[StreamToken],
- ) -> Dict[str, _RoomMembershipForUser]:
- """
- Get the rooms the user should be syncing with
- """
- room_membership_for_user_map = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
- user=user,
- from_token=from_token,
- to_token=to_token,
- )
- )
- filtered_sync_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms_relevant_for_sync(
- user=user,
- room_membership_for_user_map=room_membership_for_user_map,
- )
- )
-
- return filtered_sync_room_map
-
- def _create_dm_room(
- self,
- inviter_user_id: str,
- inviter_tok: str,
- invitee_user_id: str,
- invitee_tok: str,
- ) -> str:
- """
- Helper to create a DM room as the "inviter" and invite the "invitee" user to the room. The
- "invitee" user also will join the room. The `m.direct` account data will be set
- for both users.
- """
-
- # Create a room and send an invite the other user
- room_id = self.helper.create_room_as(
- inviter_user_id,
- is_public=False,
- tok=inviter_tok,
- )
- self.helper.invite(
- room_id,
- src=inviter_user_id,
- targ=invitee_user_id,
- tok=inviter_tok,
- extra_data={"is_direct": True},
- )
- # Person that was invited joins the room
- self.helper.join(room_id, invitee_user_id, tok=invitee_tok)
-
- # Mimic the client setting the room as a direct message in the global account
- # data
- self.get_success(
- self.store.add_account_data_for_user(
- invitee_user_id,
- AccountDataTypes.DIRECT,
- {inviter_user_id: [room_id]},
- )
- )
- self.get_success(
- self.store.add_account_data_for_user(
- inviter_user_id,
- AccountDataTypes.DIRECT,
- {invitee_user_id: [room_id]},
- )
- )
-
- return room_id
-
- _remote_invite_count: int = 0
-
- def _create_remote_invite_room_for_user(
- self,
- invitee_user_id: str,
- unsigned_invite_room_state: Optional[List[StrippedStateEvent]],
- ) -> str:
- """
- Create a fake invite for a remote room and persist it.
-
- We don't have any state for these kind of rooms and can only rely on the
- stripped state included in the unsigned portion of the invite event to identify
- the room.
-
- Args:
- invitee_user_id: The person being invited
- unsigned_invite_room_state: List of stripped state events to assist the
- receiver in identifying the room.
-
- Returns:
- The room ID of the remote invite room
- """
- invite_room_id = f"!test_room{self._remote_invite_count}:remote_server"
-
- invite_event_dict = {
- "room_id": invite_room_id,
- "sender": "@inviter:remote_server",
- "state_key": invitee_user_id,
- "depth": 1,
- "origin_server_ts": 1,
- "type": EventTypes.Member,
- "content": {"membership": Membership.INVITE},
- "auth_events": [],
- "prev_events": [],
- }
- if unsigned_invite_room_state is not None:
- serialized_stripped_state_events = []
- for stripped_event in unsigned_invite_room_state:
- serialized_stripped_state_events.append(
- {
- "type": stripped_event.type,
- "state_key": stripped_event.state_key,
- "sender": stripped_event.sender,
- "content": stripped_event.content,
- }
- )
-
- invite_event_dict["unsigned"] = {
- "invite_room_state": serialized_stripped_state_events
- }
-
- invite_event = make_event_from_dict(
- invite_event_dict,
- room_version=RoomVersions.V10,
- )
- invite_event.internal_metadata.outlier = True
- invite_event.internal_metadata.out_of_band_membership = True
-
- self.get_success(
- self.store.maybe_store_room_on_outlier_membership(
- room_id=invite_room_id, room_version=invite_event.room_version
- )
- )
- context = EventContext.for_outlier(self.hs.get_storage_controllers())
- persist_controller = self.hs.get_storage_controllers().persistence
- assert persist_controller is not None
- self.get_success(persist_controller.persist_event(invite_event, context))
-
- self._remote_invite_count += 1
-
- return invite_room_id
-
- def test_filter_dm_rooms(self) -> None:
- """
- Test `filter.is_dm` for DM rooms
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- # Create a normal room
- room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
-
- # Create a DM room
- dm_room_id = self._create_dm_room(
- inviter_user_id=user1_id,
- inviter_tok=user1_tok,
- invitee_user_id=user2_id,
- invitee_tok=user2_tok,
- )
-
- after_rooms_token = self.event_sources.get_current_token()
-
- # Get the rooms the user should be syncing with
- sync_room_map = self._get_sync_room_ids_for_user(
- UserID.from_string(user1_id),
- from_token=None,
- to_token=after_rooms_token,
- )
-
- # Try with `is_dm=True`
- truthy_filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- is_dm=True,
- ),
- after_rooms_token,
- )
- )
-
- self.assertEqual(truthy_filtered_room_map.keys(), {dm_room_id})
-
- # Try with `is_dm=False`
- falsy_filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- is_dm=False,
- ),
- after_rooms_token,
- )
- )
-
- self.assertEqual(falsy_filtered_room_map.keys(), {room_id})
-
- def test_filter_encrypted_rooms(self) -> None:
- """
- Test `filter.is_encrypted` for encrypted rooms
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- # Create an unencrypted room
- room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
-
- # Create an encrypted room
- encrypted_room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
- self.helper.send_state(
- encrypted_room_id,
- EventTypes.RoomEncryption,
- {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"},
- tok=user1_tok,
- )
-
- after_rooms_token = self.event_sources.get_current_token()
-
- # Get the rooms the user should be syncing with
- sync_room_map = self._get_sync_room_ids_for_user(
- UserID.from_string(user1_id),
- from_token=None,
- to_token=after_rooms_token,
- )
-
- # Try with `is_encrypted=True`
- truthy_filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- is_encrypted=True,
- ),
- after_rooms_token,
- )
- )
-
- self.assertEqual(truthy_filtered_room_map.keys(), {encrypted_room_id})
-
- # Try with `is_encrypted=False`
- falsy_filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- is_encrypted=False,
- ),
- after_rooms_token,
- )
- )
-
- self.assertEqual(falsy_filtered_room_map.keys(), {room_id})
-
- def test_filter_encrypted_server_left_room(self) -> None:
- """
- Test that we can apply a `filter.is_encrypted` against a room that everyone has left.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- before_rooms_token = self.event_sources.get_current_token()
-
- # Create an unencrypted room
- room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
- # Leave the room
- self.helper.leave(room_id, user1_id, tok=user1_tok)
-
- # Create an encrypted room
- encrypted_room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
- self.helper.send_state(
- encrypted_room_id,
- EventTypes.RoomEncryption,
- {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"},
- tok=user1_tok,
- )
- # Leave the room
- self.helper.leave(encrypted_room_id, user1_id, tok=user1_tok)
-
- after_rooms_token = self.event_sources.get_current_token()
-
- # Get the rooms the user should be syncing with
- sync_room_map = self._get_sync_room_ids_for_user(
- UserID.from_string(user1_id),
- # We're using a `from_token` so that the room is considered `newly_left` and
- # appears in our list of relevant sync rooms
- from_token=before_rooms_token,
- to_token=after_rooms_token,
- )
-
- # Try with `is_encrypted=True`
- truthy_filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- is_encrypted=True,
- ),
- after_rooms_token,
- )
- )
-
- self.assertEqual(truthy_filtered_room_map.keys(), {encrypted_room_id})
-
- # Try with `is_encrypted=False`
- falsy_filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- is_encrypted=False,
- ),
- after_rooms_token,
- )
- )
-
- self.assertEqual(falsy_filtered_room_map.keys(), {room_id})
-
- def test_filter_encrypted_server_left_room2(self) -> None:
- """
- Test that we can apply a `filter.is_encrypted` against a room that everyone has
- left.
-
- There is still someone local who is invited to the rooms but that doesn't affect
- whether the server is participating in the room (users need to be joined).
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- _user2_tok = self.login(user2_id, "pass")
-
- before_rooms_token = self.event_sources.get_current_token()
-
- # Create an unencrypted room
- room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
- # Invite user2
- self.helper.invite(room_id, targ=user2_id, tok=user1_tok)
- # User1 leaves the room
- self.helper.leave(room_id, user1_id, tok=user1_tok)
-
- # Create an encrypted room
- encrypted_room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
- self.helper.send_state(
- encrypted_room_id,
- EventTypes.RoomEncryption,
- {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"},
- tok=user1_tok,
- )
- # Invite user2
- self.helper.invite(encrypted_room_id, targ=user2_id, tok=user1_tok)
- # User1 leaves the room
- self.helper.leave(encrypted_room_id, user1_id, tok=user1_tok)
-
- after_rooms_token = self.event_sources.get_current_token()
-
- # Get the rooms the user should be syncing with
- sync_room_map = self._get_sync_room_ids_for_user(
- UserID.from_string(user1_id),
- # We're using a `from_token` so that the room is considered `newly_left` and
- # appears in our list of relevant sync rooms
- from_token=before_rooms_token,
- to_token=after_rooms_token,
- )
-
- # Try with `is_encrypted=True`
- truthy_filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- is_encrypted=True,
- ),
- after_rooms_token,
- )
- )
-
- self.assertEqual(truthy_filtered_room_map.keys(), {encrypted_room_id})
-
- # Try with `is_encrypted=False`
- falsy_filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- is_encrypted=False,
- ),
- after_rooms_token,
- )
- )
-
- self.assertEqual(falsy_filtered_room_map.keys(), {room_id})
-
- def test_filter_encrypted_after_we_left(self) -> None:
- """
- Test that we can apply a `filter.is_encrypted` against a room that was encrypted
- after we left the room (make sure we don't just use the current state)
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- before_rooms_token = self.event_sources.get_current_token()
-
- # Create an unencrypted room
- room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
- # Leave the room
- self.helper.join(room_id, user1_id, tok=user1_tok)
- self.helper.leave(room_id, user1_id, tok=user1_tok)
-
- # Create a room that will be encrypted
- encrypted_after_we_left_room_id = self.helper.create_room_as(
- user2_id, tok=user2_tok
- )
- # Leave the room
- self.helper.join(encrypted_after_we_left_room_id, user1_id, tok=user1_tok)
- self.helper.leave(encrypted_after_we_left_room_id, user1_id, tok=user1_tok)
-
- # Encrypt the room after we've left
- self.helper.send_state(
- encrypted_after_we_left_room_id,
- EventTypes.RoomEncryption,
- {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"},
- tok=user2_tok,
- )
-
- after_rooms_token = self.event_sources.get_current_token()
-
- # Get the rooms the user should be syncing with
- sync_room_map = self._get_sync_room_ids_for_user(
- UserID.from_string(user1_id),
- # We're using a `from_token` so that the room is considered `newly_left` and
- # appears in our list of relevant sync rooms
- from_token=before_rooms_token,
- to_token=after_rooms_token,
- )
-
- # Try with `is_encrypted=True`
- truthy_filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- is_encrypted=True,
- ),
- after_rooms_token,
- )
- )
-
- # Even though we left the room before it was encrypted, we still see it because
- # someone else on our server is still participating in the room and we "leak"
- # the current state to the left user. But we consider the room encryption status
- # to not be a secret given it's often set at the start of the room and it's one
- # of the stripped state events that is normally handed out.
- self.assertEqual(
- truthy_filtered_room_map.keys(), {encrypted_after_we_left_room_id}
- )
-
- # Try with `is_encrypted=False`
- falsy_filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- is_encrypted=False,
- ),
- after_rooms_token,
- )
- )
-
- # Even though we left the room before it was encrypted... (see comment above)
- self.assertEqual(falsy_filtered_room_map.keys(), {room_id})
-
- def test_filter_encrypted_with_remote_invite_room_no_stripped_state(self) -> None:
- """
- Test that we can apply a `filter.is_encrypted` filter against a remote invite
- room without any `unsigned.invite_room_state` (stripped state).
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- # Create a remote invite room without any `unsigned.invite_room_state`
- _remote_invite_room_id = self._create_remote_invite_room_for_user(
- user1_id, None
- )
-
- # Create an unencrypted room
- room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
-
- # Create an encrypted room
- encrypted_room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
- self.helper.send_state(
- encrypted_room_id,
- EventTypes.RoomEncryption,
- {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"},
- tok=user1_tok,
- )
-
- after_rooms_token = self.event_sources.get_current_token()
-
- # Get the rooms the user should be syncing with
- sync_room_map = self._get_sync_room_ids_for_user(
- UserID.from_string(user1_id),
- from_token=None,
- to_token=after_rooms_token,
- )
-
- # Try with `is_encrypted=True`
- truthy_filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- is_encrypted=True,
- ),
- after_rooms_token,
- )
- )
-
- # `remote_invite_room_id` should not appear because we can't figure out whether
- # it is encrypted or not (no stripped state, `unsigned.invite_room_state`).
- self.assertEqual(truthy_filtered_room_map.keys(), {encrypted_room_id})
-
- # Try with `is_encrypted=False`
- falsy_filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- is_encrypted=False,
- ),
- after_rooms_token,
- )
- )
-
- # `remote_invite_room_id` should not appear because we can't figure out whether
- # it is encrypted or not (no stripped state, `unsigned.invite_room_state`).
- self.assertEqual(falsy_filtered_room_map.keys(), {room_id})
-
- def test_filter_encrypted_with_remote_invite_encrypted_room(self) -> None:
- """
- Test that we can apply a `filter.is_encrypted` filter against a remote invite
- encrypted room with some `unsigned.invite_room_state` (stripped state).
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- # Create a remote invite room with some `unsigned.invite_room_state`
- # indicating that the room is encrypted.
- remote_invite_room_id = self._create_remote_invite_room_for_user(
- user1_id,
- [
- StrippedStateEvent(
- type=EventTypes.Create,
- state_key="",
- sender="@inviter:remote_server",
- content={
- EventContentFields.ROOM_CREATOR: "@inviter:remote_server",
- EventContentFields.ROOM_VERSION: RoomVersions.V10.identifier,
- },
- ),
- StrippedStateEvent(
- type=EventTypes.RoomEncryption,
- state_key="",
- sender="@inviter:remote_server",
- content={
- EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2",
- },
- ),
- ],
- )
-
- # Create an unencrypted room
- room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
-
- # Create an encrypted room
- encrypted_room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
- self.helper.send_state(
- encrypted_room_id,
- EventTypes.RoomEncryption,
- {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"},
- tok=user1_tok,
- )
-
- after_rooms_token = self.event_sources.get_current_token()
-
- # Get the rooms the user should be syncing with
- sync_room_map = self._get_sync_room_ids_for_user(
- UserID.from_string(user1_id),
- from_token=None,
- to_token=after_rooms_token,
- )
-
- # Try with `is_encrypted=True`
- truthy_filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- is_encrypted=True,
- ),
- after_rooms_token,
- )
- )
-
- # `remote_invite_room_id` should appear here because it is encrypted
- # according to the stripped state
- self.assertEqual(
- truthy_filtered_room_map.keys(), {encrypted_room_id, remote_invite_room_id}
- )
-
- # Try with `is_encrypted=False`
- falsy_filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- is_encrypted=False,
- ),
- after_rooms_token,
- )
- )
-
- # `remote_invite_room_id` should not appear here because it is encrypted
- # according to the stripped state
- self.assertEqual(falsy_filtered_room_map.keys(), {room_id})
-
- def test_filter_encrypted_with_remote_invite_unencrypted_room(self) -> None:
- """
- Test that we can apply a `filter.is_encrypted` filter against a remote invite
- unencrypted room with some `unsigned.invite_room_state` (stripped state).
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- # Create a remote invite room with some `unsigned.invite_room_state`
- # but don't set any room encryption event.
- remote_invite_room_id = self._create_remote_invite_room_for_user(
- user1_id,
- [
- StrippedStateEvent(
- type=EventTypes.Create,
- state_key="",
- sender="@inviter:remote_server",
- content={
- EventContentFields.ROOM_CREATOR: "@inviter:remote_server",
- EventContentFields.ROOM_VERSION: RoomVersions.V10.identifier,
- },
- ),
- # No room encryption event
- ],
- )
-
- # Create an unencrypted room
- room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
-
- # Create an encrypted room
- encrypted_room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
- self.helper.send_state(
- encrypted_room_id,
- EventTypes.RoomEncryption,
- {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"},
- tok=user1_tok,
- )
-
- after_rooms_token = self.event_sources.get_current_token()
-
- # Get the rooms the user should be syncing with
- sync_room_map = self._get_sync_room_ids_for_user(
- UserID.from_string(user1_id),
- from_token=None,
- to_token=after_rooms_token,
- )
-
- # Try with `is_encrypted=True`
- truthy_filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- is_encrypted=True,
- ),
- after_rooms_token,
- )
- )
-
- # `remote_invite_room_id` should not appear here because it is unencrypted
- # according to the stripped state
- self.assertEqual(truthy_filtered_room_map.keys(), {encrypted_room_id})
-
- # Try with `is_encrypted=False`
- falsy_filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- is_encrypted=False,
- ),
- after_rooms_token,
- )
- )
-
- # `remote_invite_room_id` should appear because it is unencrypted according to
- # the stripped state
- self.assertEqual(
- falsy_filtered_room_map.keys(), {room_id, remote_invite_room_id}
- )
-
- def test_filter_invite_rooms(self) -> None:
- """
- Test `filter.is_invite` for rooms that the user has been invited to
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- user2_tok = self.login(user2_id, "pass")
-
- # Create a normal room
- room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
- self.helper.join(room_id, user1_id, tok=user1_tok)
-
- # Create a room that user1 is invited to
- invite_room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
- self.helper.invite(invite_room_id, src=user2_id, targ=user1_id, tok=user2_tok)
-
- after_rooms_token = self.event_sources.get_current_token()
-
- # Get the rooms the user should be syncing with
- sync_room_map = self._get_sync_room_ids_for_user(
- UserID.from_string(user1_id),
- from_token=None,
- to_token=after_rooms_token,
- )
-
- # Try with `is_invite=True`
- truthy_filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- is_invite=True,
- ),
- after_rooms_token,
- )
- )
-
- self.assertEqual(truthy_filtered_room_map.keys(), {invite_room_id})
-
- # Try with `is_invite=False`
- falsy_filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- is_invite=False,
- ),
- after_rooms_token,
- )
- )
-
- self.assertEqual(falsy_filtered_room_map.keys(), {room_id})
-
- def test_filter_room_types(self) -> None:
- """
- Test `filter.room_types` for different room types
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- # Create a normal room (no room type)
- room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
-
- # Create a space room
- space_room_id = self.helper.create_room_as(
- user1_id,
- tok=user1_tok,
- extra_content={
- "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE}
- },
- )
-
- # Create an arbitrarily typed room
- foo_room_id = self.helper.create_room_as(
- user1_id,
- tok=user1_tok,
- extra_content={
- "creation_content": {
- EventContentFields.ROOM_TYPE: "org.matrix.foobarbaz"
- }
- },
- )
-
- after_rooms_token = self.event_sources.get_current_token()
-
- # Get the rooms the user should be syncing with
- sync_room_map = self._get_sync_room_ids_for_user(
- UserID.from_string(user1_id),
- from_token=None,
- to_token=after_rooms_token,
- )
-
- # Try finding only normal rooms
- filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(room_types=[None]),
- after_rooms_token,
- )
- )
-
- self.assertEqual(filtered_room_map.keys(), {room_id})
-
- # Try finding only spaces
- filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(room_types=[RoomTypes.SPACE]),
- after_rooms_token,
- )
- )
-
- self.assertEqual(filtered_room_map.keys(), {space_room_id})
-
- # Try finding normal rooms and spaces
- filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- room_types=[None, RoomTypes.SPACE]
- ),
- after_rooms_token,
- )
- )
-
- self.assertEqual(filtered_room_map.keys(), {room_id, space_room_id})
-
- # Try finding an arbitrary room type
- filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- room_types=["org.matrix.foobarbaz"]
- ),
- after_rooms_token,
- )
- )
-
- self.assertEqual(filtered_room_map.keys(), {foo_room_id})
-
- def test_filter_not_room_types(self) -> None:
- """
- Test `filter.not_room_types` for different room types
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- # Create a normal room (no room type)
- room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
-
- # Create a space room
- space_room_id = self.helper.create_room_as(
- user1_id,
- tok=user1_tok,
- extra_content={
- "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE}
- },
- )
-
- # Create an arbitrarily typed room
- foo_room_id = self.helper.create_room_as(
- user1_id,
- tok=user1_tok,
- extra_content={
- "creation_content": {
- EventContentFields.ROOM_TYPE: "org.matrix.foobarbaz"
- }
- },
- )
-
- after_rooms_token = self.event_sources.get_current_token()
-
- # Get the rooms the user should be syncing with
- sync_room_map = self._get_sync_room_ids_for_user(
- UserID.from_string(user1_id),
- from_token=None,
- to_token=after_rooms_token,
- )
-
- # Try finding *NOT* normal rooms
- filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(not_room_types=[None]),
- after_rooms_token,
- )
- )
-
- self.assertEqual(filtered_room_map.keys(), {space_room_id, foo_room_id})
-
- # Try finding *NOT* spaces
- filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- not_room_types=[RoomTypes.SPACE]
- ),
- after_rooms_token,
- )
- )
-
- self.assertEqual(filtered_room_map.keys(), {room_id, foo_room_id})
-
- # Try finding *NOT* normal rooms or spaces
- filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- not_room_types=[None, RoomTypes.SPACE]
- ),
- after_rooms_token,
- )
- )
-
- self.assertEqual(filtered_room_map.keys(), {foo_room_id})
-
- # Test how it behaves when we have both `room_types` and `not_room_types`.
- # `not_room_types` should win.
- filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- room_types=[None], not_room_types=[None]
- ),
- after_rooms_token,
- )
- )
-
- # Nothing matches because nothing is both a normal room and not a normal room
- self.assertEqual(filtered_room_map.keys(), set())
-
- # Test how it behaves when we have both `room_types` and `not_room_types`.
- # `not_room_types` should win.
- filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(
- room_types=[None, RoomTypes.SPACE], not_room_types=[None]
- ),
- after_rooms_token,
- )
- )
-
- self.assertEqual(filtered_room_map.keys(), {space_room_id})
-
- def test_filter_room_types_server_left_room(self) -> None:
- """
- Test that we can apply a `filter.room_types` against a room that everyone has left.
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- before_rooms_token = self.event_sources.get_current_token()
-
- # Create a normal room (no room type)
- room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
- # Leave the room
- self.helper.leave(room_id, user1_id, tok=user1_tok)
-
- # Create a space room
- space_room_id = self.helper.create_room_as(
- user1_id,
- tok=user1_tok,
- extra_content={
- "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE}
- },
- )
- # Leave the room
- self.helper.leave(space_room_id, user1_id, tok=user1_tok)
-
- after_rooms_token = self.event_sources.get_current_token()
-
- # Get the rooms the user should be syncing with
- sync_room_map = self._get_sync_room_ids_for_user(
- UserID.from_string(user1_id),
- # We're using a `from_token` so that the room is considered `newly_left` and
- # appears in our list of relevant sync rooms
- from_token=before_rooms_token,
- to_token=after_rooms_token,
- )
-
- # Try finding only normal rooms
- filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(room_types=[None]),
- after_rooms_token,
- )
- )
-
- self.assertEqual(filtered_room_map.keys(), {room_id})
-
- # Try finding only spaces
- filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(room_types=[RoomTypes.SPACE]),
- after_rooms_token,
- )
- )
-
- self.assertEqual(filtered_room_map.keys(), {space_room_id})
-
- def test_filter_room_types_server_left_room2(self) -> None:
- """
- Test that we can apply a `filter.room_types` against a room that everyone has left.
-
- There is still someone local who is invited to the rooms but that doesn't affect
- whether the server is participating in the room (users need to be joined).
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
- user2_id = self.register_user("user2", "pass")
- _user2_tok = self.login(user2_id, "pass")
-
- before_rooms_token = self.event_sources.get_current_token()
-
- # Create a normal room (no room type)
- room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
- # Invite user2
- self.helper.invite(room_id, targ=user2_id, tok=user1_tok)
- # User1 leaves the room
- self.helper.leave(room_id, user1_id, tok=user1_tok)
-
- # Create a space room
- space_room_id = self.helper.create_room_as(
- user1_id,
- tok=user1_tok,
- extra_content={
- "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE}
- },
- )
- # Invite user2
- self.helper.invite(space_room_id, targ=user2_id, tok=user1_tok)
- # User1 leaves the room
- self.helper.leave(space_room_id, user1_id, tok=user1_tok)
-
- after_rooms_token = self.event_sources.get_current_token()
-
- # Get the rooms the user should be syncing with
- sync_room_map = self._get_sync_room_ids_for_user(
- UserID.from_string(user1_id),
- # We're using a `from_token` so that the room is considered `newly_left` and
- # appears in our list of relevant sync rooms
- from_token=before_rooms_token,
- to_token=after_rooms_token,
- )
-
- # Try finding only normal rooms
- filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(room_types=[None]),
- after_rooms_token,
- )
- )
-
- self.assertEqual(filtered_room_map.keys(), {room_id})
-
- # Try finding only spaces
- filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(room_types=[RoomTypes.SPACE]),
- after_rooms_token,
- )
- )
-
- self.assertEqual(filtered_room_map.keys(), {space_room_id})
-
- def test_filter_room_types_with_remote_invite_room_no_stripped_state(self) -> None:
- """
- Test that we can apply a `filter.room_types` filter against a remote invite
- room without any `unsigned.invite_room_state` (stripped state).
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- # Create a remote invite room without any `unsigned.invite_room_state`
- _remote_invite_room_id = self._create_remote_invite_room_for_user(
- user1_id, None
- )
-
- # Create a normal room (no room type)
- room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
-
- # Create a space room
- space_room_id = self.helper.create_room_as(
- user1_id,
- tok=user1_tok,
- extra_content={
- "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE}
- },
- )
-
- after_rooms_token = self.event_sources.get_current_token()
-
- # Get the rooms the user should be syncing with
- sync_room_map = self._get_sync_room_ids_for_user(
- UserID.from_string(user1_id),
- from_token=None,
- to_token=after_rooms_token,
- )
-
- # Try finding only normal rooms
- filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(room_types=[None]),
- after_rooms_token,
- )
- )
-
- # `remote_invite_room_id` should not appear because we can't figure out what
- # room type it is (no stripped state, `unsigned.invite_room_state`)
- self.assertEqual(filtered_room_map.keys(), {room_id})
-
- # Try finding only spaces
- filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(room_types=[RoomTypes.SPACE]),
- after_rooms_token,
- )
- )
-
- # `remote_invite_room_id` should not appear because we can't figure out what
- # room type it is (no stripped state, `unsigned.invite_room_state`)
- self.assertEqual(filtered_room_map.keys(), {space_room_id})
-
- def test_filter_room_types_with_remote_invite_space(self) -> None:
- """
- Test that we can apply a `filter.room_types` filter against a remote invite
- to a space room with some `unsigned.invite_room_state` (stripped state).
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- # Create a remote invite room with some `unsigned.invite_room_state` indicating
- # that it is a space room
- remote_invite_room_id = self._create_remote_invite_room_for_user(
- user1_id,
- [
- StrippedStateEvent(
- type=EventTypes.Create,
- state_key="",
- sender="@inviter:remote_server",
- content={
- EventContentFields.ROOM_CREATOR: "@inviter:remote_server",
- EventContentFields.ROOM_VERSION: RoomVersions.V10.identifier,
- # Specify that it is a space room
- EventContentFields.ROOM_TYPE: RoomTypes.SPACE,
- },
- ),
- ],
- )
-
- # Create a normal room (no room type)
- room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
-
- # Create a space room
- space_room_id = self.helper.create_room_as(
- user1_id,
- tok=user1_tok,
- extra_content={
- "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE}
- },
- )
-
- after_rooms_token = self.event_sources.get_current_token()
-
- # Get the rooms the user should be syncing with
- sync_room_map = self._get_sync_room_ids_for_user(
- UserID.from_string(user1_id),
- from_token=None,
- to_token=after_rooms_token,
- )
-
- # Try finding only normal rooms
- filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(room_types=[None]),
- after_rooms_token,
- )
- )
-
- # `remote_invite_room_id` should not appear here because it is a space room
- # according to the stripped state
- self.assertEqual(filtered_room_map.keys(), {room_id})
-
- # Try finding only spaces
- filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(room_types=[RoomTypes.SPACE]),
- after_rooms_token,
- )
- )
-
- # `remote_invite_room_id` should appear here because it is a space room
- # according to the stripped state
- self.assertEqual(
- filtered_room_map.keys(), {space_room_id, remote_invite_room_id}
- )
-
- def test_filter_room_types_with_remote_invite_normal_room(self) -> None:
- """
- Test that we can apply a `filter.room_types` filter against a remote invite
- to a normal room with some `unsigned.invite_room_state` (stripped state).
- """
- user1_id = self.register_user("user1", "pass")
- user1_tok = self.login(user1_id, "pass")
-
- # Create a remote invite room with some `unsigned.invite_room_state`
- # but the create event does not specify a room type (normal room)
- remote_invite_room_id = self._create_remote_invite_room_for_user(
- user1_id,
- [
- StrippedStateEvent(
- type=EventTypes.Create,
- state_key="",
- sender="@inviter:remote_server",
- content={
- EventContentFields.ROOM_CREATOR: "@inviter:remote_server",
- EventContentFields.ROOM_VERSION: RoomVersions.V10.identifier,
- # No room type means this is a normal room
- },
- ),
- ],
- )
-
- # Create a normal room (no room type)
- room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
-
- # Create a space room
- space_room_id = self.helper.create_room_as(
- user1_id,
- tok=user1_tok,
- extra_content={
- "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE}
- },
- )
-
- after_rooms_token = self.event_sources.get_current_token()
-
- # Get the rooms the user should be syncing with
- sync_room_map = self._get_sync_room_ids_for_user(
- UserID.from_string(user1_id),
- from_token=None,
- to_token=after_rooms_token,
- )
-
- # Try finding only normal rooms
- filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(room_types=[None]),
- after_rooms_token,
- )
- )
-
- # `remote_invite_room_id` should appear here because it is a normal room
- # according to the stripped state (no room type)
- self.assertEqual(filtered_room_map.keys(), {room_id, remote_invite_room_id})
-
- # Try finding only spaces
- filtered_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms(
- UserID.from_string(user1_id),
- sync_room_map,
- SlidingSyncConfig.SlidingSyncList.Filters(room_types=[RoomTypes.SPACE]),
- after_rooms_token,
- )
- )
-
- # `remote_invite_room_id` should not appear here because it is a normal room
- # according to the stripped state (no room type)
- self.assertEqual(filtered_room_map.keys(), {space_room_id})
+ self.assertTrue(room_id1 in newly_left)
class SortRoomsTestCase(HomeserverTestCase):
@@ -4361,25 +3615,26 @@ class SortRoomsTestCase(HomeserverTestCase):
user: UserID,
to_token: StreamToken,
from_token: Optional[StreamToken],
- ) -> Dict[str, _RoomMembershipForUser]:
+ ) -> Tuple[Dict[str, RoomsForUserType], AbstractSet[str], AbstractSet[str]]:
"""
Get the rooms the user should be syncing with
"""
- room_membership_for_user_map = self.get_success(
- self.sliding_sync_handler.get_room_membership_for_user_at_to_token(
+ room_membership_for_user_map, newly_joined, newly_left = self.get_success(
+ self.sliding_sync_handler.room_lists.get_room_membership_for_user_at_to_token(
user=user,
from_token=from_token,
to_token=to_token,
)
)
filtered_sync_room_map = self.get_success(
- self.sliding_sync_handler.filter_rooms_relevant_for_sync(
+ self.sliding_sync_handler.room_lists.filter_rooms_relevant_for_sync(
user=user,
room_membership_for_user_map=room_membership_for_user_map,
+ newly_left_room_ids=newly_left,
)
)
- return filtered_sync_room_map
+ return filtered_sync_room_map, newly_joined, newly_left
def test_sort_activity_basic(self) -> None:
"""
@@ -4400,7 +3655,7 @@ class SortRoomsTestCase(HomeserverTestCase):
after_rooms_token = self.event_sources.get_current_token()
# Get the rooms the user should be syncing with
- sync_room_map = self._get_sync_room_ids_for_user(
+ sync_room_map, newly_joined, newly_left = self._get_sync_room_ids_for_user(
UserID.from_string(user1_id),
from_token=None,
to_token=after_rooms_token,
@@ -4408,7 +3663,7 @@ class SortRoomsTestCase(HomeserverTestCase):
# Sort the rooms (what we're testing)
sorted_sync_rooms = self.get_success(
- self.sliding_sync_handler.sort_rooms(
+ self.sliding_sync_handler.room_lists.sort_rooms(
sync_room_map=sync_room_map,
to_token=after_rooms_token,
)
@@ -4481,7 +3736,7 @@ class SortRoomsTestCase(HomeserverTestCase):
self.helper.send(room_id3, "activity in room3", tok=user2_tok)
# Get the rooms the user should be syncing with
- sync_room_map = self._get_sync_room_ids_for_user(
+ sync_room_map, newly_joined, newly_left = self._get_sync_room_ids_for_user(
UserID.from_string(user1_id),
from_token=before_rooms_token,
to_token=after_rooms_token,
@@ -4489,7 +3744,7 @@ class SortRoomsTestCase(HomeserverTestCase):
# Sort the rooms (what we're testing)
sorted_sync_rooms = self.get_success(
- self.sliding_sync_handler.sort_rooms(
+ self.sliding_sync_handler.room_lists.sort_rooms(
sync_room_map=sync_room_map,
to_token=after_rooms_token,
)
@@ -4545,7 +3800,7 @@ class SortRoomsTestCase(HomeserverTestCase):
after_rooms_token = self.event_sources.get_current_token()
# Get the rooms the user should be syncing with
- sync_room_map = self._get_sync_room_ids_for_user(
+ sync_room_map, newly_joined, newly_left = self._get_sync_room_ids_for_user(
UserID.from_string(user1_id),
from_token=None,
to_token=after_rooms_token,
@@ -4553,7 +3808,7 @@ class SortRoomsTestCase(HomeserverTestCase):
# Sort the rooms (what we're testing)
sorted_sync_rooms = self.get_success(
- self.sliding_sync_handler.sort_rooms(
+ self.sliding_sync_handler.room_lists.sort_rooms(
sync_room_map=sync_room_map,
to_token=after_rooms_token,
)
@@ -4565,3 +3820,1071 @@ class SortRoomsTestCase(HomeserverTestCase):
# We only care about the *latest* event in the room.
[room_id1, room_id2],
)
+
+
+@attr.s(slots=True, auto_attribs=True, frozen=True)
+class RequiredStateChangesTestParameters:
+ previous_required_state_map: Dict[str, Set[str]]
+ request_required_state_map: Dict[str, Set[str]]
+ state_deltas: StateMap[str]
+ expected_with_state_deltas: Tuple[
+ Optional[Mapping[str, AbstractSet[str]]], StateFilter
+ ]
+ expected_without_state_deltas: Tuple[
+ Optional[Mapping[str, AbstractSet[str]]], StateFilter
+ ]
+
+
+class RequiredStateChangesTestCase(unittest.TestCase):
+ """Test cases for `_required_state_changes`"""
+
+ @parameterized.expand(
+ [
+ (
+ "simple_no_change",
+ """Test no change to required state""",
+ RequiredStateChangesTestParameters(
+ previous_required_state_map={"type1": {"state_key"}},
+ request_required_state_map={"type1": {"state_key"}},
+ state_deltas={("type1", "state_key"): "$event_id"},
+ # No changes
+ expected_with_state_deltas=(None, StateFilter.none()),
+ expected_without_state_deltas=(None, StateFilter.none()),
+ ),
+ ),
+ (
+ "simple_add_type",
+ """Test adding a type to the config""",
+ RequiredStateChangesTestParameters(
+ previous_required_state_map={"type1": {"state_key"}},
+ request_required_state_map={
+ "type1": {"state_key"},
+ "type2": {"state_key"},
+ },
+ state_deltas={("type2", "state_key"): "$event_id"},
+ expected_with_state_deltas=(
+ # We've added a type so we should persist the changed required state
+ # config.
+ {"type1": {"state_key"}, "type2": {"state_key"}},
+ # We should see the new type added
+ StateFilter.from_types([("type2", "state_key")]),
+ ),
+ expected_without_state_deltas=(
+ {"type1": {"state_key"}, "type2": {"state_key"}},
+ StateFilter.from_types([("type2", "state_key")]),
+ ),
+ ),
+ ),
+ (
+ "simple_add_type_from_nothing",
+ """Test adding a type to the config when previously requesting nothing""",
+ RequiredStateChangesTestParameters(
+ previous_required_state_map={},
+ request_required_state_map={
+ "type1": {"state_key"},
+ "type2": {"state_key"},
+ },
+ state_deltas={("type2", "state_key"): "$event_id"},
+ expected_with_state_deltas=(
+ # We've added a type so we should persist the changed required state
+ # config.
+ {"type1": {"state_key"}, "type2": {"state_key"}},
+ # We should see the new types added
+ StateFilter.from_types(
+ [("type1", "state_key"), ("type2", "state_key")]
+ ),
+ ),
+ expected_without_state_deltas=(
+ {"type1": {"state_key"}, "type2": {"state_key"}},
+ StateFilter.from_types(
+ [("type1", "state_key"), ("type2", "state_key")]
+ ),
+ ),
+ ),
+ ),
+ (
+ "simple_add_state_key",
+ """Test adding a state key to the config""",
+ RequiredStateChangesTestParameters(
+ previous_required_state_map={"type": {"state_key1"}},
+ request_required_state_map={"type": {"state_key1", "state_key2"}},
+ state_deltas={("type", "state_key2"): "$event_id"},
+ expected_with_state_deltas=(
+ # We've added a key so we should persist the changed required state
+ # config.
+ {"type": {"state_key1", "state_key2"}},
+ # We should see the new state_keys added
+ StateFilter.from_types([("type", "state_key2")]),
+ ),
+ expected_without_state_deltas=(
+ {"type": {"state_key1", "state_key2"}},
+ StateFilter.from_types([("type", "state_key2")]),
+ ),
+ ),
+ ),
+ (
+ "simple_retain_previous_state_keys",
+ """Test adding a state key to the config and retaining a previously sent state_key""",
+ RequiredStateChangesTestParameters(
+ previous_required_state_map={"type": {"state_key1"}},
+ request_required_state_map={"type": {"state_key2", "state_key3"}},
+ state_deltas={("type", "state_key2"): "$event_id"},
+ expected_with_state_deltas=(
+ # We've added a key so we should persist the changed required state
+ # config.
+ #
+ # Retain `state_key1` from the `previous_required_state_map`
+ {"type": {"state_key1", "state_key2", "state_key3"}},
+ # We should see the new state_keys added
+ StateFilter.from_types(
+ [("type", "state_key2"), ("type", "state_key3")]
+ ),
+ ),
+ expected_without_state_deltas=(
+ {"type": {"state_key1", "state_key2", "state_key3"}},
+ StateFilter.from_types(
+ [("type", "state_key2"), ("type", "state_key3")]
+ ),
+ ),
+ ),
+ ),
+ (
+ "simple_remove_type",
+ """
+ Test removing a type from the config when there is a matching state
+ delta does cause the persisted required state config to change
+
+ Test removing a type from the config when there are no matching state
+ deltas does *not* cause the persisted required state config to change
+ """,
+ RequiredStateChangesTestParameters(
+ previous_required_state_map={
+ "type1": {"state_key"},
+ "type2": {"state_key"},
+ },
+ request_required_state_map={"type1": {"state_key"}},
+ state_deltas={("type2", "state_key"): "$event_id"},
+ expected_with_state_deltas=(
+ # Remove `type2` since there's been a change to that state,
+ # (persist the change to required state). That way next time,
+ # they request `type2`, we see that we haven't sent it before
+ # and send the new state. (we should still keep track that we've
+ # sent `type1` before).
+ {"type1": {"state_key"}},
+ # We don't need to request anything more if they are requesting
+ # less state now
+ StateFilter.none(),
+ ),
+ expected_without_state_deltas=(
+ # `type2` is no longer requested but since that state hasn't
+ # changed, nothing should change (we should still keep track
+ # that we've sent `type2` before).
+ None,
+ # We don't need to request anything more if they are requesting
+ # less state now
+ StateFilter.none(),
+ ),
+ ),
+ ),
+ (
+ "simple_remove_type_to_nothing",
+ """
+ Test removing a type from the config and no longer requesting any state
+ """,
+ RequiredStateChangesTestParameters(
+ previous_required_state_map={
+ "type1": {"state_key"},
+ "type2": {"state_key"},
+ },
+ request_required_state_map={},
+ state_deltas={("type2", "state_key"): "$event_id"},
+ expected_with_state_deltas=(
+ # Remove `type2` since there's been a change to that state,
+ # (persist the change to required state). That way next time,
+ # they request `type2`, we see that we haven't sent it before
+ # and send the new state. (we should still keep track that we've
+ # sent `type1` before).
+ {"type1": {"state_key"}},
+ # We don't need to request anything more if they are requesting
+ # less state now
+ StateFilter.none(),
+ ),
+ expected_without_state_deltas=(
+ # `type2` is no longer requested but since that state hasn't
+ # changed, nothing should change (we should still keep track
+ # that we've sent `type2` before).
+ None,
+ # We don't need to request anything more if they are requesting
+ # less state now
+ StateFilter.none(),
+ ),
+ ),
+ ),
+ (
+ "simple_remove_state_key",
+ """
+ Test removing a state_key from the config
+ """,
+ RequiredStateChangesTestParameters(
+ previous_required_state_map={"type": {"state_key1", "state_key2"}},
+ request_required_state_map={"type": {"state_key1"}},
+ state_deltas={("type", "state_key2"): "$event_id"},
+ expected_with_state_deltas=(
+ # Remove `(type, state_key2)` since there's been a change
+ # to that state (persist the change to required state).
+ # That way next time, they request `(type, state_key2)`, we see
+ # that we haven't sent it before and send the new state. (we
+ # should still keep track that we've sent `(type, state_key1)`
+ # before).
+ {"type": {"state_key1"}},
+ # We don't need to request anything more if they are requesting
+ # less state now
+ StateFilter.none(),
+ ),
+ expected_without_state_deltas=(
+ # `(type, state_key2)` is no longer requested but since that
+ # state hasn't changed, nothing should change (we should still
+ # keep track that we've sent `(type, state_key1)` and `(type,
+ # state_key2)` before).
+ None,
+ # We don't need to request anything more if they are requesting
+ # less state now
+ StateFilter.none(),
+ ),
+ ),
+ ),
+ (
+ "type_wildcards_add",
+ """
+ Test adding a wildcard type causes the persisted required state config
+ to change and we request everything.
+
+ If an event type wildcard has been added or removed we don't try and do
+ anything fancy, and instead always update the effective room required
+ state config to match the request.
+ """,
+ RequiredStateChangesTestParameters(
+ previous_required_state_map={"type1": {"state_key2"}},
+ request_required_state_map={
+ "type1": {"state_key2"},
+ StateValues.WILDCARD: {"state_key"},
+ },
+ state_deltas={
+ ("other_type", "state_key"): "$event_id",
+ },
+ # We've added a wildcard, so we persist the change and request everything
+ expected_with_state_deltas=(
+ {"type1": {"state_key2"}, StateValues.WILDCARD: {"state_key"}},
+ StateFilter.all(),
+ ),
+ expected_without_state_deltas=(
+ {"type1": {"state_key2"}, StateValues.WILDCARD: {"state_key"}},
+ StateFilter.all(),
+ ),
+ ),
+ ),
+ (
+ "type_wildcards_remove",
+ """
+ Test removing a wildcard type causes the persisted required state config
+ to change and request nothing.
+
+ If an event type wildcard has been added or removed we don't try and do
+ anything fancy, and instead always update the effective room required
+ state config to match the request.
+ """,
+ RequiredStateChangesTestParameters(
+ previous_required_state_map={
+ "type1": {"state_key2"},
+ StateValues.WILDCARD: {"state_key"},
+ },
+ request_required_state_map={"type1": {"state_key2"}},
+ state_deltas={
+ ("other_type", "state_key"): "$event_id",
+ },
+ # We've removed a type wildcard, so we persist the change but don't request anything
+ expected_with_state_deltas=(
+ {"type1": {"state_key2"}},
+ # We don't need to request anything more if they are requesting
+ # less state now
+ StateFilter.none(),
+ ),
+ expected_without_state_deltas=(
+ {"type1": {"state_key2"}},
+ # We don't need to request anything more if they are requesting
+ # less state now
+ StateFilter.none(),
+ ),
+ ),
+ ),
+ (
+ "state_key_wildcards_add",
+ """Test adding a wildcard state_key""",
+ RequiredStateChangesTestParameters(
+ previous_required_state_map={"type1": {"state_key"}},
+ request_required_state_map={
+ "type1": {"state_key"},
+ "type2": {StateValues.WILDCARD},
+ },
+ state_deltas={("type2", "state_key"): "$event_id"},
+ # We've added a wildcard state_key, so we persist the change and
+ # request all of the state for that type
+ expected_with_state_deltas=(
+ {"type1": {"state_key"}, "type2": {StateValues.WILDCARD}},
+ StateFilter.from_types([("type2", None)]),
+ ),
+ expected_without_state_deltas=(
+ {"type1": {"state_key"}, "type2": {StateValues.WILDCARD}},
+ StateFilter.from_types([("type2", None)]),
+ ),
+ ),
+ ),
+ (
+ "state_key_wildcards_remove",
+ """Test removing a wildcard state_key""",
+ RequiredStateChangesTestParameters(
+ previous_required_state_map={
+ "type1": {"state_key"},
+ "type2": {StateValues.WILDCARD},
+ },
+ request_required_state_map={"type1": {"state_key"}},
+ state_deltas={("type2", "state_key"): "$event_id"},
+ # We've removed a state_key wildcard, so we persist the change and
+ # request nothing
+ expected_with_state_deltas=(
+ {"type1": {"state_key"}},
+ # We don't need to request anything more if they are requesting
+ # less state now
+ StateFilter.none(),
+ ),
+ # We've removed a state_key wildcard but there have been no matching
+ # state changes, so no changes needed, just persist the
+ # `request_required_state_map` as-is.
+ expected_without_state_deltas=(
+ None,
+ # We don't need to request anything more if they are requesting
+ # less state now
+ StateFilter.none(),
+ ),
+ ),
+ ),
+ (
+ "state_key_remove_some",
+ """
+ Test that removing state keys work when only some of the state keys have
+ changed
+ """,
+ RequiredStateChangesTestParameters(
+ previous_required_state_map={
+ "type1": {"state_key1", "state_key2", "state_key3"}
+ },
+ request_required_state_map={"type1": {"state_key1"}},
+ state_deltas={("type1", "state_key3"): "$event_id"},
+ expected_with_state_deltas=(
+ # We've removed some state keys from the type, but only state_key3 was
+ # changed so only that one should be removed.
+ {"type1": {"state_key1", "state_key2"}},
+ # We don't need to request anything more if they are requesting
+ # less state now
+ StateFilter.none(),
+ ),
+ expected_without_state_deltas=(
+ # No changes needed, just persist the
+ # `request_required_state_map` as-is
+ None,
+ # We don't need to request anything more if they are requesting
+ # less state now
+ StateFilter.none(),
+ ),
+ ),
+ ),
+ (
+ "state_key_me_add",
+ """
+ Test adding state keys work when using "$ME"
+ """,
+ RequiredStateChangesTestParameters(
+ previous_required_state_map={},
+ request_required_state_map={"type1": {StateValues.ME}},
+ state_deltas={("type1", "@user:test"): "$event_id"},
+ expected_with_state_deltas=(
+ # We've added a type so we should persist the changed required state
+ # config.
+ {"type1": {StateValues.ME}},
+ # We should see the new state_keys added
+ StateFilter.from_types([("type1", "@user:test")]),
+ ),
+ expected_without_state_deltas=(
+ {"type1": {StateValues.ME}},
+ StateFilter.from_types([("type1", "@user:test")]),
+ ),
+ ),
+ ),
+ (
+ "state_key_me_remove",
+ """
+ Test removing state keys work when using "$ME"
+ """,
+ RequiredStateChangesTestParameters(
+ previous_required_state_map={"type1": {StateValues.ME}},
+ request_required_state_map={},
+ state_deltas={("type1", "@user:test"): "$event_id"},
+ expected_with_state_deltas=(
+ # Remove `type1` since there's been a change to that state,
+ # (persist the change to required state). That way next time,
+ # they request `type1`, we see that we haven't sent it before
+ # and send the new state. (if we were tracking that we sent any
+ # other state, we should still keep track that).
+ {},
+ # We don't need to request anything more if they are requesting
+ # less state now
+ StateFilter.none(),
+ ),
+ expected_without_state_deltas=(
+ # `type1` is no longer requested but since that state hasn't
+ # changed, nothing should change (we should still keep track
+ # that we've sent `type1` before).
+ None,
+ # We don't need to request anything more if they are requesting
+ # less state now
+ StateFilter.none(),
+ ),
+ ),
+ ),
+ (
+ "state_key_user_id_add",
+ """
+ Test adding state keys work when using your own user ID
+ """,
+ RequiredStateChangesTestParameters(
+ previous_required_state_map={},
+ request_required_state_map={"type1": {"@user:test"}},
+ state_deltas={("type1", "@user:test"): "$event_id"},
+ expected_with_state_deltas=(
+ # We've added a type so we should persist the changed required state
+ # config.
+ {"type1": {"@user:test"}},
+ # We should see the new state_keys added
+ StateFilter.from_types([("type1", "@user:test")]),
+ ),
+ expected_without_state_deltas=(
+ {"type1": {"@user:test"}},
+ StateFilter.from_types([("type1", "@user:test")]),
+ ),
+ ),
+ ),
+ (
+ "state_key_me_remove",
+ """
+ Test removing state keys work when using your own user ID
+ """,
+ RequiredStateChangesTestParameters(
+ previous_required_state_map={"type1": {"@user:test"}},
+ request_required_state_map={},
+ state_deltas={("type1", "@user:test"): "$event_id"},
+ expected_with_state_deltas=(
+ # Remove `type1` since there's been a change to that state,
+ # (persist the change to required state). That way next time,
+ # they request `type1`, we see that we haven't sent it before
+ # and send the new state. (if we were tracking that we sent any
+ # other state, we should still keep track that).
+ {},
+ # We don't need to request anything more if they are requesting
+ # less state now
+ StateFilter.none(),
+ ),
+ expected_without_state_deltas=(
+ # `type1` is no longer requested but since that state hasn't
+ # changed, nothing should change (we should still keep track
+ # that we've sent `type1` before).
+ None,
+ # We don't need to request anything more if they are requesting
+ # less state now
+ StateFilter.none(),
+ ),
+ ),
+ ),
+ (
+ "state_key_lazy_add",
+ """
+ Test adding state keys work when using "$LAZY"
+ """,
+ RequiredStateChangesTestParameters(
+ previous_required_state_map={},
+ request_required_state_map={EventTypes.Member: {StateValues.LAZY}},
+ state_deltas={(EventTypes.Member, "@user:test"): "$event_id"},
+ expected_with_state_deltas=(
+ # If a "$LAZY" has been added or removed we always update the
+ # required state to what was requested for simplicity.
+ {EventTypes.Member: {StateValues.LAZY}},
+ StateFilter.none(),
+ ),
+ expected_without_state_deltas=(
+ {EventTypes.Member: {StateValues.LAZY}},
+ StateFilter.none(),
+ ),
+ ),
+ ),
+ (
+ "state_key_lazy_remove",
+ """
+ Test removing state keys work when using "$LAZY"
+ """,
+ RequiredStateChangesTestParameters(
+ previous_required_state_map={EventTypes.Member: {StateValues.LAZY}},
+ request_required_state_map={},
+ state_deltas={(EventTypes.Member, "@user:test"): "$event_id"},
+ expected_with_state_deltas=(
+ # If a "$LAZY" has been added or removed we always update the
+ # required state to what was requested for simplicity.
+ {},
+ # We don't need to request anything more if they are requesting
+ # less state now
+ StateFilter.none(),
+ ),
+ expected_without_state_deltas=(
+ # `EventTypes.Member` is no longer requested but since that
+ # state hasn't changed, nothing should change (we should still
+ # keep track that we've sent `EventTypes.Member` before).
+ None,
+ # We don't need to request anything more if they are requesting
+ # less state now
+ StateFilter.none(),
+ ),
+ ),
+ ),
+ (
+ "state_key_lazy_keep_previous_memberships_and_no_new_memberships",
+ """
+ This test mimics a request with lazy-loading room members enabled where
+ we have previously sent down user2 and user3's membership events and now
+ we're sending down another response without any timeline events.
+ """,
+ RequiredStateChangesTestParameters(
+ previous_required_state_map={
+ EventTypes.Member: {
+ StateValues.LAZY,
+ "@user2:test",
+ "@user3:test",
+ }
+ },
+ request_required_state_map={EventTypes.Member: {StateValues.LAZY}},
+ state_deltas={(EventTypes.Member, "@user2:test"): "$event_id"},
+ expected_with_state_deltas=(
+ # Remove "@user2:test" since that state has changed and is no
+ # longer being requested anymore. Since something was removed,
+ # we should persist the changed to required state. That way next
+ # time, they request "@user2:test", we see that we haven't sent
+ # it before and send the new state. (we should still keep track
+ # that we've sent specific `EventTypes.Member` before)
+ {
+ EventTypes.Member: {
+ StateValues.LAZY,
+ "@user3:test",
+ }
+ },
+ # We don't need to request anything more if they are requesting
+ # less state now
+ StateFilter.none(),
+ ),
+ expected_without_state_deltas=(
+ # We're not requesting any specific `EventTypes.Member` now but
+ # since that state hasn't changed, nothing should change (we
+ # should still keep track that we've sent specific
+ # `EventTypes.Member` before).
+ None,
+ # We don't need to request anything more if they are requesting
+ # less state now
+ StateFilter.none(),
+ ),
+ ),
+ ),
+ (
+ "state_key_lazy_keep_previous_memberships_with_new_memberships",
+ """
+ This test mimics a request with lazy-loading room members enabled where
+ we have previously sent down user2 and user3's membership events and now
+ we're sending down another response with a new event from user4.
+ """,
+ RequiredStateChangesTestParameters(
+ previous_required_state_map={
+ EventTypes.Member: {
+ StateValues.LAZY,
+ "@user2:test",
+ "@user3:test",
+ }
+ },
+ request_required_state_map={
+ EventTypes.Member: {StateValues.LAZY, "@user4:test"}
+ },
+ state_deltas={(EventTypes.Member, "@user2:test"): "$event_id"},
+ expected_with_state_deltas=(
+ # Since "@user4:test" was added, we should persist the changed
+ # required state config.
+ #
+ # Also remove "@user2:test" since that state has changed and is no
+ # longer being requested anymore. Since something was removed,
+ we also should persist the change to required state. That way next
+ # time, they request "@user2:test", we see that we haven't sent
+ # it before and send the new state. (we should still keep track
+ # that we've sent specific `EventTypes.Member` before)
+ {
+ EventTypes.Member: {
+ StateValues.LAZY,
+ "@user3:test",
+ "@user4:test",
+ }
+ },
+ # We should see the new state_keys added
+ StateFilter.from_types([(EventTypes.Member, "@user4:test")]),
+ ),
+ expected_without_state_deltas=(
+ # Since "@user4:test" was added, we should persist the changed
+ # required state config.
+ {
+ EventTypes.Member: {
+ StateValues.LAZY,
+ "@user2:test",
+ "@user3:test",
+ "@user4:test",
+ }
+ },
+ # We should see the new state_keys added
+ StateFilter.from_types([(EventTypes.Member, "@user4:test")]),
+ ),
+ ),
+ ),
+ (
+ "state_key_expand_lazy_keep_previous_memberships",
+ """
+ Test expanding the `required_state` to lazy-loading room members.
+ """,
+ RequiredStateChangesTestParameters(
+ previous_required_state_map={
+ EventTypes.Member: {"@user2:test", "@user3:test"}
+ },
+ request_required_state_map={EventTypes.Member: {StateValues.LAZY}},
+ state_deltas={(EventTypes.Member, "@user2:test"): "$event_id"},
+ expected_with_state_deltas=(
+ # Since `StateValues.LAZY` was added, we should persist the
+ # changed required state config.
+ #
+ # Also remove "@user2:test" since that state has changed and is no
+ # longer being requested anymore. Since something was removed,
+                    # we also should persist the change to required state. That way next
+ # time, they request "@user2:test", we see that we haven't sent
+ # it before and send the new state. (we should still keep track
+ # that we've sent specific `EventTypes.Member` before)
+ {
+ EventTypes.Member: {
+ StateValues.LAZY,
+ "@user3:test",
+ }
+ },
+ # We don't need to request anything more if they are requesting
+ # less state now
+ StateFilter.none(),
+ ),
+ expected_without_state_deltas=(
+ # Since `StateValues.LAZY` was added, we should persist the
+ # changed required state config.
+ {
+ EventTypes.Member: {
+ StateValues.LAZY,
+ "@user2:test",
+ "@user3:test",
+ }
+ },
+ # We don't need to request anything more if they are requesting
+ # less state now
+ StateFilter.none(),
+ ),
+ ),
+ ),
+ (
+ "state_key_retract_lazy_keep_previous_memberships_no_new_memberships",
+ """
+ Test retracting the `required_state` to no longer lazy-loading room members.
+ """,
+ RequiredStateChangesTestParameters(
+ previous_required_state_map={
+ EventTypes.Member: {
+ StateValues.LAZY,
+ "@user2:test",
+ "@user3:test",
+ }
+ },
+ request_required_state_map={},
+ state_deltas={(EventTypes.Member, "@user2:test"): "$event_id"},
+ expected_with_state_deltas=(
+ # Remove `EventTypes.Member` since there's been a change to that
+ # state, (persist the change to required state). That way next
+ # time, they request `EventTypes.Member`, we see that we haven't
+ # sent it before and send the new state. (if we were tracking
+ # that we sent any other state, we should still keep track
+ # that).
+ #
+ # This acts the same as the `simple_remove_type` test. It's
+ # possible that we could remember the specific `state_keys` that
+ # we have sent down before but this currently just acts the same
+ # as if a whole `type` was removed. Perhaps it's good that we
+ # "garbage collect" and forget what we've sent before for a
+ # given `type` when the client stops caring about a certain
+ # `type`.
+ {},
+ # We don't need to request anything more if they are requesting
+ # less state now
+ StateFilter.none(),
+ ),
+ expected_without_state_deltas=(
+ # `EventTypes.Member` is no longer requested but since that
+ # state hasn't changed, nothing should change (we should still
+ # keep track that we've sent `EventTypes.Member` before).
+ None,
+ # We don't need to request anything more if they are requesting
+ # less state now
+ StateFilter.none(),
+ ),
+ ),
+ ),
+ (
+ "state_key_retract_lazy_keep_previous_memberships_with_new_memberships",
+ """
+ Test retracting the `required_state` to no longer lazy-loading room members.
+ """,
+ RequiredStateChangesTestParameters(
+ previous_required_state_map={
+ EventTypes.Member: {
+ StateValues.LAZY,
+ "@user2:test",
+ "@user3:test",
+ }
+ },
+ request_required_state_map={EventTypes.Member: {"@user4:test"}},
+ state_deltas={(EventTypes.Member, "@user2:test"): "$event_id"},
+ expected_with_state_deltas=(
+ # Since "@user4:test" was added, we should persist the changed
+ # required state config.
+ #
+ # Also remove "@user2:test" since that state has changed and is no
+ # longer being requested anymore. Since something was removed,
+                    # we also should persist the change to required state. That way next
+ # time, they request "@user2:test", we see that we haven't sent
+ # it before and send the new state. (we should still keep track
+ # that we've sent specific `EventTypes.Member` before)
+ {
+ EventTypes.Member: {
+ "@user3:test",
+ "@user4:test",
+ }
+ },
+ # We should see the new state_keys added
+ StateFilter.from_types([(EventTypes.Member, "@user4:test")]),
+ ),
+ expected_without_state_deltas=(
+ # Since "@user4:test" was added, we should persist the changed
+ # required state config.
+ {
+ EventTypes.Member: {
+ "@user2:test",
+ "@user3:test",
+ "@user4:test",
+ }
+ },
+ # We should see the new state_keys added
+ StateFilter.from_types([(EventTypes.Member, "@user4:test")]),
+ ),
+ ),
+ ),
+ (
+ "type_wildcard_with_state_key_wildcard_to_explicit_state_keys",
+ """
+ Test switching from a wildcard ("*", "*") to explicit state keys
+ """,
+ RequiredStateChangesTestParameters(
+ previous_required_state_map={
+ StateValues.WILDCARD: {StateValues.WILDCARD}
+ },
+ request_required_state_map={
+ StateValues.WILDCARD: {"state_key1", "state_key2", "state_key3"}
+ },
+ state_deltas={("type1", "state_key1"): "$event_id"},
+ # If we were previously fetching everything ("*", "*"), always update the effective
+            # room required state config to match the request. And since we were previously
+ # already fetching everything, we don't have to fetch anything now that they've
+ # narrowed.
+ expected_with_state_deltas=(
+ {
+ StateValues.WILDCARD: {
+ "state_key1",
+ "state_key2",
+ "state_key3",
+ }
+ },
+ StateFilter.none(),
+ ),
+ expected_without_state_deltas=(
+ {
+ StateValues.WILDCARD: {
+ "state_key1",
+ "state_key2",
+ "state_key3",
+ }
+ },
+ StateFilter.none(),
+ ),
+ ),
+ ),
+ (
+ "type_wildcard_with_explicit_state_keys_to_wildcard_state_key",
+ """
+ Test switching from explicit to wildcard state keys ("*", "*")
+ """,
+ RequiredStateChangesTestParameters(
+ previous_required_state_map={
+ StateValues.WILDCARD: {"state_key1", "state_key2", "state_key3"}
+ },
+ request_required_state_map={
+ StateValues.WILDCARD: {StateValues.WILDCARD}
+ },
+ state_deltas={("type1", "state_key1"): "$event_id"},
+ # We've added a wildcard, so we persist the change and request everything
+ expected_with_state_deltas=(
+ {StateValues.WILDCARD: {StateValues.WILDCARD}},
+ StateFilter.all(),
+ ),
+ expected_without_state_deltas=(
+ {StateValues.WILDCARD: {StateValues.WILDCARD}},
+ StateFilter.all(),
+ ),
+ ),
+ ),
+ (
+ "state_key_wildcard_to_explicit_state_keys",
+ """Test switching from a wildcard to explicit state keys with a concrete type""",
+ RequiredStateChangesTestParameters(
+ previous_required_state_map={"type1": {StateValues.WILDCARD}},
+ request_required_state_map={
+ "type1": {"state_key1", "state_key2", "state_key3"}
+ },
+ state_deltas={("type1", "state_key1"): "$event_id"},
+ # If a state_key wildcard has been added or removed, we always
+ # update the effective room required state config to match the
+                # request. And since we were previously already fetching
+ # everything, we don't have to fetch anything now that they've
+ # narrowed.
+ expected_with_state_deltas=(
+ {
+ "type1": {
+ "state_key1",
+ "state_key2",
+ "state_key3",
+ }
+ },
+ StateFilter.none(),
+ ),
+ expected_without_state_deltas=(
+ {
+ "type1": {
+ "state_key1",
+ "state_key2",
+ "state_key3",
+ }
+ },
+ StateFilter.none(),
+ ),
+ ),
+ ),
+ (
+ "explicit_state_keys_to_wildcard_state_key",
+ """Test switching from a wildcard to explicit state keys with a concrete type""",
+ RequiredStateChangesTestParameters(
+ previous_required_state_map={
+ "type1": {"state_key1", "state_key2", "state_key3"}
+ },
+ request_required_state_map={"type1": {StateValues.WILDCARD}},
+ state_deltas={("type1", "state_key1"): "$event_id"},
+ # If a state_key wildcard has been added or removed, we always
+ # update the effective room required state config to match the
+ # request. And we need to request all of the state for that type
+ # because we previously, only sent down a few keys.
+ expected_with_state_deltas=(
+ {"type1": {StateValues.WILDCARD, "state_key2", "state_key3"}},
+ StateFilter.from_types([("type1", None)]),
+ ),
+ expected_without_state_deltas=(
+ {
+ "type1": {
+ StateValues.WILDCARD,
+ "state_key1",
+ "state_key2",
+ "state_key3",
+ }
+ },
+ StateFilter.from_types([("type1", None)]),
+ ),
+ ),
+ ),
+ ]
+ )
+ def test_xxx(
+ self,
+ _test_label: str,
+ _test_description: str,
+ test_parameters: RequiredStateChangesTestParameters,
+ ) -> None:
+ # Without `state_deltas`
+ changed_required_state_map, added_state_filter = _required_state_changes(
+ user_id="@user:test",
+ prev_required_state_map=test_parameters.previous_required_state_map,
+ request_required_state_map=test_parameters.request_required_state_map,
+ state_deltas={},
+ )
+
+ self.assertEqual(
+ changed_required_state_map,
+ test_parameters.expected_without_state_deltas[0],
+ "changed_required_state_map does not match (without state_deltas)",
+ )
+ self.assertEqual(
+ added_state_filter,
+ test_parameters.expected_without_state_deltas[1],
+ "added_state_filter does not match (without state_deltas)",
+ )
+
+ # With `state_deltas`
+ changed_required_state_map, added_state_filter = _required_state_changes(
+ user_id="@user:test",
+ prev_required_state_map=test_parameters.previous_required_state_map,
+ request_required_state_map=test_parameters.request_required_state_map,
+ state_deltas=test_parameters.state_deltas,
+ )
+
+ self.assertEqual(
+ changed_required_state_map,
+ test_parameters.expected_with_state_deltas[0],
+ "changed_required_state_map does not match (with state_deltas)",
+ )
+ self.assertEqual(
+ added_state_filter,
+ test_parameters.expected_with_state_deltas[1],
+ "added_state_filter does not match (with state_deltas)",
+ )
+
+ @parameterized.expand(
+ [
+ # Test with a normal arbitrary type (no special meaning)
+ ("arbitrary_type", "type", set()),
+ # Test with membership
+ ("membership", EventTypes.Member, set()),
+ # Test with lazy-loading room members
+ ("lazy_loading_membership", EventTypes.Member, {StateValues.LAZY}),
+ ]
+ )
+ def test_limit_retained_previous_state_keys(
+ self,
+ _test_label: str,
+ event_type: str,
+ extra_state_keys: Set[str],
+ ) -> None:
+ """
+ Test that we limit the number of state_keys that we remember but always include
+ the state_keys that we've just requested.
+ """
+ previous_required_state_map = {
+ event_type: {
+ # Prefix the state_keys we've "prev_"iously sent so they are easier to
+ # identify in our assertions.
+ f"prev_state_key{i}"
+ for i in range(MAX_NUMBER_PREVIOUS_STATE_KEYS_TO_REMEMBER - 30)
+ }
+ | extra_state_keys
+ }
+ request_required_state_map = {
+ event_type: {f"state_key{i}" for i in range(50)} | extra_state_keys
+ }
+
+ # (function under test)
+ changed_required_state_map, added_state_filter = _required_state_changes(
+ user_id="@user:test",
+ prev_required_state_map=previous_required_state_map,
+ request_required_state_map=request_required_state_map,
+ state_deltas={},
+ )
+ assert changed_required_state_map is not None
+
+ # We should only remember up to the maximum number of state keys
+ self.assertGreaterEqual(
+ len(changed_required_state_map[event_type]),
+ # Most of the time this will be `MAX_NUMBER_PREVIOUS_STATE_KEYS_TO_REMEMBER` but
+ # because we are just naively selecting enough previous state_keys to fill
+ # the limit, there might be some overlap in what's added back which means we
+ # might have slightly less than the limit.
+ #
+ # `extra_state_keys` overlaps in the previous and requested
+            # `required_state_map` so we might see this scenario.
+ MAX_NUMBER_PREVIOUS_STATE_KEYS_TO_REMEMBER - len(extra_state_keys),
+ )
+
+ # Should include all of the requested state
+ self.assertIncludes(
+ changed_required_state_map[event_type],
+ request_required_state_map[event_type],
+ )
+ # And the rest is filled with the previous state keys
+ #
+ # We can't assert the exact state_keys since we don't know the order so we just
+ # check that they all start with "prev_" and that we have the correct amount.
+ remaining_state_keys = (
+ changed_required_state_map[event_type]
+ - request_required_state_map[event_type]
+ )
+ self.assertGreater(
+ len(remaining_state_keys),
+ 0,
+ )
+ assert all(
+ state_key.startswith("prev_") for state_key in remaining_state_keys
+ ), "Remaining state_keys should be the previous state_keys"
+
+ def test_request_more_state_keys_than_remember_limit(self) -> None:
+ """
+ Test requesting more state_keys than fit in our limit to remember from previous
+ requests.
+ """
+ previous_required_state_map = {
+ "type": {
+ # Prefix the state_keys we've "prev_"iously sent so they are easier to
+ # identify in our assertions.
+ f"prev_state_key{i}"
+ for i in range(MAX_NUMBER_PREVIOUS_STATE_KEYS_TO_REMEMBER - 30)
+ }
+ }
+ request_required_state_map = {
+ "type": {
+ f"state_key{i}"
+ # Requesting more than the MAX_NUMBER_PREVIOUS_STATE_KEYS_TO_REMEMBER
+ for i in range(MAX_NUMBER_PREVIOUS_STATE_KEYS_TO_REMEMBER + 20)
+ }
+ }
+ # Ensure that we are requesting more than the limit
+ self.assertGreater(
+ len(request_required_state_map["type"]),
+ MAX_NUMBER_PREVIOUS_STATE_KEYS_TO_REMEMBER,
+ )
+
+ # (function under test)
+ changed_required_state_map, added_state_filter = _required_state_changes(
+ user_id="@user:test",
+ prev_required_state_map=previous_required_state_map,
+ request_required_state_map=request_required_state_map,
+ state_deltas={},
+ )
+ assert changed_required_state_map is not None
+
+ # Should include all of the requested state
+ self.assertIncludes(
+ changed_required_state_map["type"],
+ request_required_state_map["type"],
+ exact=True,
+ )
diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index fa55f76916..6b202dfbd5 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -17,10 +17,11 @@
# [This file includes modifications made by New Vector Limited]
#
#
+from http import HTTPStatus
from typing import Collection, ContextManager, List, Optional
from unittest.mock import AsyncMock, Mock, patch
-from parameterized import parameterized
+from parameterized import parameterized, parameterized_class
from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor
@@ -32,7 +33,13 @@ from synapse.api.room_versions import RoomVersion, RoomVersions
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.federation.federation_base import event_from_pdu_json
-from synapse.handlers.sync import SyncConfig, SyncRequestKey, SyncResult, SyncVersion
+from synapse.handlers.sync import (
+ SyncConfig,
+ SyncRequestKey,
+ SyncResult,
+ SyncVersion,
+ TimelineBatch,
+)
from synapse.rest import admin
from synapse.rest.client import knock, login, room
from synapse.server import HomeServer
@@ -58,9 +65,21 @@ def generate_request_key() -> SyncRequestKey:
return ("request_key", _request_key)
+@parameterized_class(
+ ("use_state_after",),
+ [
+ (True,),
+ (False,),
+ ],
+ class_name_func=lambda cls,
+ num,
+ params_dict: f"{cls.__name__}_{'state_after' if params_dict['use_state_after'] else 'state'}",
+)
class SyncTestCase(tests.unittest.HomeserverTestCase):
"""Tests Sync Handler."""
+ use_state_after: bool
+
servlets = [
admin.register_servlets,
knock.register_servlets,
@@ -79,7 +98,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
def test_wait_for_sync_for_user_auth_blocking(self) -> None:
user_id1 = "@user1:test"
user_id2 = "@user2:test"
- sync_config = generate_sync_config(user_id1)
+ sync_config = generate_sync_config(
+ user_id1, use_state_after=self.use_state_after
+ )
requester = create_requester(user_id1)
self.reactor.advance(100) # So we get not 0 time
@@ -112,7 +133,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
self.auth_blocking._hs_disabled = False
- sync_config = generate_sync_config(user_id2)
+ sync_config = generate_sync_config(
+ user_id2, use_state_after=self.use_state_after
+ )
requester = create_requester(user_id2)
e = self.get_failure(
@@ -141,7 +164,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
initial_result = self.get_success(
self.sync_handler.wait_for_sync_for_user(
requester,
- sync_config=generate_sync_config(user, device_id="dev"),
+ sync_config=generate_sync_config(
+ user, device_id="dev", use_state_after=self.use_state_after
+ ),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
)
@@ -175,7 +200,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
result = self.get_success(
self.sync_handler.wait_for_sync_for_user(
requester,
- sync_config=generate_sync_config(user),
+ sync_config=generate_sync_config(
+ user, use_state_after=self.use_state_after
+ ),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
)
@@ -188,7 +215,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
result = self.get_success(
self.sync_handler.wait_for_sync_for_user(
requester,
- sync_config=generate_sync_config(user, device_id="dev"),
+ sync_config=generate_sync_config(
+ user, device_id="dev", use_state_after=self.use_state_after
+ ),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
since_token=initial_result.next_batch,
@@ -220,7 +249,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
result = self.get_success(
self.sync_handler.wait_for_sync_for_user(
requester,
- sync_config=generate_sync_config(user),
+ sync_config=generate_sync_config(
+ user, use_state_after=self.use_state_after
+ ),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
)
@@ -233,7 +264,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
result = self.get_success(
self.sync_handler.wait_for_sync_for_user(
requester,
- sync_config=generate_sync_config(user, device_id="dev"),
+ sync_config=generate_sync_config(
+ user, device_id="dev", use_state_after=self.use_state_after
+ ),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
since_token=initial_result.next_batch,
@@ -276,7 +309,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
alice_sync_result: SyncResult = self.get_success(
self.sync_handler.wait_for_sync_for_user(
create_requester(owner),
- generate_sync_config(owner),
+ generate_sync_config(owner, use_state_after=self.use_state_after),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
)
@@ -296,7 +329,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
# Eve syncs.
eve_requester = create_requester(eve)
- eve_sync_config = generate_sync_config(eve)
+ eve_sync_config = generate_sync_config(
+ eve, use_state_after=self.use_state_after
+ )
eve_sync_after_ban: SyncResult = self.get_success(
self.sync_handler.wait_for_sync_for_user(
eve_requester,
@@ -313,7 +348,15 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
# the prev_events used when creating the join event, such that the ban does not
# precede the join.
with self._patch_get_latest_events([last_room_creation_event_id]):
- self.helper.join(room_id, eve, tok=eve_token)
+ self.helper.join(
+ room_id,
+ eve,
+ tok=eve_token,
+ # Previously, this join would succeed but now we expect it to fail at
+ # this point. The rest of the test is for the case when this used to
+ # succeed.
+ expect_code=HTTPStatus.FORBIDDEN,
+ )
# Eve makes a second, incremental sync.
eve_incremental_sync_after_join: SyncResult = self.get_success(
@@ -367,7 +410,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
initial_sync_result = self.get_success(
self.sync_handler.wait_for_sync_for_user(
alice_requester,
- generate_sync_config(alice),
+ generate_sync_config(alice, use_state_after=self.use_state_after),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
)
@@ -396,6 +439,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
filter_collection=FilterCollection(
self.hs, {"room": {"timeline": {"limit": 2}}}
),
+ use_state_after=self.use_state_after,
),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
@@ -442,7 +486,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
initial_sync_result = self.get_success(
self.sync_handler.wait_for_sync_for_user(
alice_requester,
- generate_sync_config(alice),
+ generate_sync_config(alice, use_state_after=self.use_state_after),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
)
@@ -481,6 +525,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
}
},
),
+ use_state_after=self.use_state_after,
),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
@@ -518,6 +563,8 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
... and a filter that means we only return 1 event, represented by the dashed
horizontal lines: `S2` must be included in the `state` section on the second sync.
+
+ When `use_state_after` is enabled, then we expect to see `s2` in the first sync.
"""
alice = self.register_user("alice", "password")
alice_tok = self.login(alice, "password")
@@ -528,7 +575,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
initial_sync_result = self.get_success(
self.sync_handler.wait_for_sync_for_user(
alice_requester,
- generate_sync_config(alice),
+ generate_sync_config(alice, use_state_after=self.use_state_after),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
)
@@ -554,6 +601,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
filter_collection=FilterCollection(
self.hs, {"room": {"timeline": {"limit": 1}}}
),
+ use_state_after=self.use_state_after,
),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
@@ -567,10 +615,18 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
[e.event_id for e in room_sync.timeline.events],
[e3_event],
)
- self.assertEqual(
- [e.event_id for e in room_sync.state.values()],
- [],
- )
+
+ if self.use_state_after:
+ # When using `state_after` we get told about s2 immediately
+ self.assertEqual(
+ [e.event_id for e in room_sync.state.values()],
+ [s2_event],
+ )
+ else:
+ self.assertEqual(
+ [e.event_id for e in room_sync.state.values()],
+ [],
+ )
# Now send another event that points to S2, but not E3.
with self._patch_get_latest_events([s2_event]):
@@ -585,6 +641,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
filter_collection=FilterCollection(
self.hs, {"room": {"timeline": {"limit": 1}}}
),
+ use_state_after=self.use_state_after,
),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
@@ -598,10 +655,19 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
[e.event_id for e in room_sync.timeline.events],
[e4_event],
)
- self.assertEqual(
- [e.event_id for e in room_sync.state.values()],
- [s2_event],
- )
+
+ if self.use_state_after:
+ # When using `state_after` we got told about s2 previously, so we
+ # don't again.
+ self.assertEqual(
+ [e.event_id for e in room_sync.state.values()],
+ [],
+ )
+ else:
+ self.assertEqual(
+ [e.event_id for e in room_sync.state.values()],
+ [s2_event],
+ )
def test_state_includes_changes_on_ungappy_syncs(self) -> None:
"""Test `state` where the sync is not gappy.
@@ -638,6 +704,8 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
This is the last chance for us to tell the client about S2, so it *must* be
included in the response.
+
+ When `use_state_after` is enabled, then we expect to see `s2` in the first sync.
"""
alice = self.register_user("alice", "password")
alice_tok = self.login(alice, "password")
@@ -648,7 +716,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
initial_sync_result = self.get_success(
self.sync_handler.wait_for_sync_for_user(
alice_requester,
- generate_sync_config(alice),
+ generate_sync_config(alice, use_state_after=self.use_state_after),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
)
@@ -673,6 +741,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
filter_collection=FilterCollection(
self.hs, {"room": {"timeline": {"limit": 1}}}
),
+ use_state_after=self.use_state_after,
),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
@@ -684,7 +753,11 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
[e.event_id for e in room_sync.timeline.events],
[e3_event],
)
- self.assertNotIn(s2_event, [e.event_id for e in room_sync.state.values()])
+ if self.use_state_after:
+ # When using `state_after` we get told about s2 immediately
+ self.assertIn(s2_event, [e.event_id for e in room_sync.state.values()])
+ else:
+ self.assertNotIn(s2_event, [e.event_id for e in room_sync.state.values()])
# More events, E4 and E5
with self._patch_get_latest_events([e3_event]):
@@ -695,7 +768,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
incremental_sync = self.get_success(
self.sync_handler.wait_for_sync_for_user(
alice_requester,
- generate_sync_config(alice),
+ generate_sync_config(alice, use_state_after=self.use_state_after),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
since_token=initial_sync_result.next_batch,
@@ -710,10 +783,19 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
[e.event_id for e in room_sync.timeline.events],
[e4_event, e5_event],
)
- self.assertEqual(
- [e.event_id for e in room_sync.state.values()],
- [s2_event],
- )
+
+ if self.use_state_after:
+ # When using `state_after` we got told about s2 previously, so we
+ # don't again.
+ self.assertEqual(
+ [e.event_id for e in room_sync.state.values()],
+ [],
+ )
+ else:
+ self.assertEqual(
+ [e.event_id for e in room_sync.state.values()],
+ [s2_event],
+ )
@parameterized.expand(
[
@@ -721,7 +803,8 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
(True, False),
(False, True),
(True, True),
- ]
+ ],
+ name_func=lambda func, num, p: f"{func.__name__}_{p.args[0]}_{p.args[1]}",
)
def test_archived_rooms_do_not_include_state_after_leave(
self, initial_sync: bool, empty_timeline: bool
@@ -749,7 +832,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
initial_sync_result = self.get_success(
self.sync_handler.wait_for_sync_for_user(
bob_requester,
- generate_sync_config(bob),
+ generate_sync_config(bob, use_state_after=self.use_state_after),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
)
@@ -780,7 +863,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
self.sync_handler.wait_for_sync_for_user(
bob_requester,
generate_sync_config(
- bob, filter_collection=FilterCollection(self.hs, filter_dict)
+ bob,
+ filter_collection=FilterCollection(self.hs, filter_dict),
+ use_state_after=self.use_state_after,
),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
@@ -791,7 +876,15 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
if empty_timeline:
# The timeline should be empty
self.assertEqual(sync_room_result.timeline.events, [])
+ else:
+ # The last three events in the timeline should be those leading up to the
+ # leave
+ self.assertEqual(
+ [e.event_id for e in sync_room_result.timeline.events[-3:]],
+ [before_message_event, before_state_event, leave_event],
+ )
+ if empty_timeline or self.use_state_after:
# And the state should include the leave event...
self.assertEqual(
sync_room_result.state[("m.room.member", bob)].event_id, leave_event
@@ -801,12 +894,6 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
sync_room_result.state[("test_state", "")].event_id, before_state_event
)
else:
- # The last three events in the timeline should be those leading up to the
- # leave
- self.assertEqual(
- [e.event_id for e in sync_room_result.timeline.events[-3:]],
- [before_message_event, before_state_event, leave_event],
- )
# ... And the state should be empty
self.assertEqual(sync_room_result.state, {})
@@ -843,7 +930,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
) -> List[EventBase]:
return list(pdus)
- self.client._check_sigs_and_hash_for_pulled_events_and_fetch = _check_sigs_and_hash_for_pulled_events_and_fetch # type: ignore[assignment]
+ self.client._check_sigs_and_hash_for_pulled_events_and_fetch = ( # type: ignore[method-assign]
+ _check_sigs_and_hash_for_pulled_events_and_fetch # type: ignore[assignment]
+ )
prev_events = self.get_success(self.store.get_prev_events_for_room(room_id))
@@ -877,7 +966,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
sync_result: SyncResult = self.get_success(
self.sync_handler.wait_for_sync_for_user(
create_requester(user),
- generate_sync_config(user),
+ generate_sync_config(user, use_state_after=self.use_state_after),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
)
@@ -926,7 +1015,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
private_sync_result: SyncResult = self.get_success(
self.sync_handler.wait_for_sync_for_user(
create_requester(user2),
- generate_sync_config(user2),
+ generate_sync_config(user2, use_state_after=self.use_state_after),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
)
@@ -952,7 +1041,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
sync_result: SyncResult = self.get_success(
self.sync_handler.wait_for_sync_for_user(
create_requester(user),
- generate_sync_config(user),
+ generate_sync_config(user, use_state_after=self.use_state_after),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
)
@@ -989,7 +1078,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
sync_d = defer.ensureDeferred(
self.sync_handler.wait_for_sync_for_user(
create_requester(user),
- generate_sync_config(user),
+ generate_sync_config(user, use_state_after=self.use_state_after),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
since_token=since_token,
@@ -1044,7 +1133,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
sync_d = defer.ensureDeferred(
self.sync_handler.wait_for_sync_for_user(
create_requester(user),
- generate_sync_config(user),
+ generate_sync_config(user, use_state_after=self.use_state_after),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
since_token=since_token,
@@ -1060,6 +1149,7 @@ def generate_sync_config(
user_id: str,
device_id: Optional[str] = "device_id",
filter_collection: Optional[FilterCollection] = None,
+ use_state_after: bool = False,
) -> SyncConfig:
"""Generate a sync config (with a unique request key).
@@ -1067,7 +1157,8 @@ def generate_sync_config(
user_id: user who is syncing.
device_id: device that is syncing. Defaults to "device_id".
filter_collection: filter to apply. Defaults to the default filter (ie,
- return everything, with a default limit)
+ return everything, with a default limit)
+ use_state_after: whether the `use_state_after` flag was set.
"""
if filter_collection is None:
filter_collection = Filtering(Mock()).DEFAULT_FILTER_COLLECTION
@@ -1077,4 +1168,138 @@ def generate_sync_config(
filter_collection=filter_collection,
is_guest=False,
device_id=device_id,
+ use_state_after=use_state_after,
)
+
+
+class SyncStateAfterTestCase(tests.unittest.HomeserverTestCase):
+    """Tests Sync Handler state behavior when using `use_state_after`."""
+
+ servlets = [
+ admin.register_servlets,
+ knock.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.sync_handler = self.hs.get_sync_handler()
+ self.store = self.hs.get_datastores().main
+
+ # AuthBlocking reads from the hs' config on initialization. We need to
+ # modify its config instead of the hs'
+ self.auth_blocking = self.hs.get_auth_blocking()
+
+ def test_initial_sync_multiple_deltas(self) -> None:
+ """Test that if multiple state deltas have happened during processing of
+ a full state sync we return the correct state"""
+
+ user = self.register_user("user", "password")
+ tok = self.login("user", "password")
+
+ # Create a room as the user and set some custom state.
+ joined_room = self.helper.create_room_as(user, tok=tok)
+
+ first_state = self.helper.send_state(
+ joined_room, event_type="m.test_event", body={"num": 1}, tok=tok
+ )
+
+ # Take a snapshot of the stream token, to simulate doing an initial sync
+ # at this point.
+ end_stream_token = self.hs.get_event_sources().get_current_token()
+
+ # Send some state *after* the stream token
+ self.helper.send_state(
+ joined_room, event_type="m.test_event", body={"num": 2}, tok=tok
+ )
+
+ # Calculating the full state will return the first state, and not the
+ # second.
+ state = self.get_success(
+ self.sync_handler._compute_state_delta_for_full_sync(
+ room_id=joined_room,
+ sync_config=generate_sync_config(user, use_state_after=True),
+ batch=TimelineBatch(
+ prev_batch=end_stream_token, events=[], limited=True
+ ),
+ end_token=end_stream_token,
+ members_to_fetch=None,
+ timeline_state={},
+ joined=True,
+ )
+ )
+ self.assertEqual(state[("m.test_event", "")], first_state["event_id"])
+
+ def test_incremental_sync_multiple_deltas(self) -> None:
+ """Test that if multiple state deltas have happened since an incremental
+ state sync we return the correct state"""
+
+ user = self.register_user("user", "password")
+ tok = self.login("user", "password")
+
+ # Create a room as the user and set some custom state.
+ joined_room = self.helper.create_room_as(user, tok=tok)
+
+ # Take a snapshot of the stream token, to simulate doing an incremental sync
+ # from this point.
+ since_token = self.hs.get_event_sources().get_current_token()
+
+ self.helper.send_state(
+ joined_room, event_type="m.test_event", body={"num": 1}, tok=tok
+ )
+
+ # Send some state *after* the stream token
+ second_state = self.helper.send_state(
+ joined_room, event_type="m.test_event", body={"num": 2}, tok=tok
+ )
+
+ end_stream_token = self.hs.get_event_sources().get_current_token()
+
+        # Calculating the incremental state will return the second state, and not the
+ # first.
+ state = self.get_success(
+ self.sync_handler._compute_state_delta_for_incremental_sync(
+ room_id=joined_room,
+ sync_config=generate_sync_config(user, use_state_after=True),
+ batch=TimelineBatch(
+ prev_batch=end_stream_token, events=[], limited=True
+ ),
+ since_token=since_token,
+ end_token=end_stream_token,
+ members_to_fetch=None,
+ timeline_state={},
+ )
+ )
+ self.assertEqual(state[("m.test_event", "")], second_state["event_id"])
+
+ def test_incremental_sync_lazy_loaded_no_timeline(self) -> None:
+ """Test that lazy-loading with an empty timeline doesn't return the full
+ state.
+
+ There was a bug where an empty state filter would cause the DB to return
+ the full state, rather than an empty set.
+ """
+ user = self.register_user("user", "password")
+ tok = self.login("user", "password")
+
+ # Create a room as the user and set some custom state.
+ joined_room = self.helper.create_room_as(user, tok=tok)
+
+ since_token = self.hs.get_event_sources().get_current_token()
+ end_stream_token = self.hs.get_event_sources().get_current_token()
+
+ state = self.get_success(
+ self.sync_handler._compute_state_delta_for_incremental_sync(
+ room_id=joined_room,
+ sync_config=generate_sync_config(user, use_state_after=True),
+ batch=TimelineBatch(
+ prev_batch=end_stream_token, events=[], limited=True
+ ),
+ since_token=since_token,
+ end_token=end_stream_token,
+ members_to_fetch=set(),
+ timeline_state={},
+ )
+ )
+
+ self.assertEqual(state, {})
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 878d9683b6..b12ffc3665 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -796,6 +796,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
s = self.get_success(self.handler.search_users(u1, "user2", 10))
self.assertEqual(len(s["results"]), 1)
+    # Keep the old spam-checker tests (without `requester_id`) for backwards compatibility.
async def allow_all(user_profile: UserProfile) -> bool:
# Allow all users.
return False
@@ -809,6 +810,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
s = self.get_success(self.handler.search_users(u1, "user2", 10))
self.assertEqual(len(s["results"]), 1)
+    # Keep the old spam-checker tests (without `requester_id`) for backwards compatibility.
# Configure a spam checker that filters all users.
async def block_all(user_profile: UserProfile) -> bool:
# All users are spammy.
@@ -820,6 +822,40 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
s = self.get_success(self.handler.search_users(u1, "user2", 10))
self.assertEqual(len(s["results"]), 0)
+ async def allow_all_expects_requester_id(
+ user_profile: UserProfile, requester_id: str
+ ) -> bool:
+ self.assertEqual(requester_id, u1)
+ # Allow all users.
+ return False
+
+ # Configure a spam checker that does not filter any users.
+ spam_checker = self.hs.get_module_api_callbacks().spam_checker
+ spam_checker._check_username_for_spam_callbacks = [
+ allow_all_expects_requester_id
+ ]
+
+ # The results do not change:
+ # We get one search result when searching for user2 by user1.
+ s = self.get_success(self.handler.search_users(u1, "user2", 10))
+ self.assertEqual(len(s["results"]), 1)
+
+ # Configure a spam checker that filters all users.
+ async def block_all_expects_requester_id(
+ user_profile: UserProfile, requester_id: str
+ ) -> bool:
+ self.assertEqual(requester_id, u1)
+ # All users are spammy.
+ return True
+
+ spam_checker._check_username_for_spam_callbacks = [
+ block_all_expects_requester_id
+ ]
+
+ # User1 now gets no search results for any of the other users.
+ s = self.get_success(self.handler.search_users(u1, "user2", 10))
+ self.assertEqual(len(s["results"]), 0)
+
@override_config(
{
"spam_checker": {
@@ -956,6 +992,67 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
[self.assertIn(user, local_users) for user in received_user_id_ordering[:3]]
[self.assertIn(user, remote_users) for user in received_user_id_ordering[3:]]
+ @override_config(
+ {
+ "user_directory": {
+ "enabled": True,
+ "search_all_users": True,
+ "exclude_remote_users": True,
+ }
+ }
+ )
+ def test_exclude_remote_users(self) -> None:
+ """Tests that only local users are returned when
+ user_directory.exclude_remote_users is True.
+ """
+
+ # Create a room and few users to test the directory with
+ searching_user = self.register_user("searcher", "password")
+ searching_user_tok = self.login("searcher", "password")
+
+ room_id = self.helper.create_room_as(
+ searching_user,
+ room_version=RoomVersions.V1.identifier,
+ tok=searching_user_tok,
+ )
+
+ # Create a few local users and join them to the room
+ local_user_1 = self.register_user("user_xxxxx", "password")
+ local_user_2 = self.register_user("user_bbbbb", "password")
+ local_user_3 = self.register_user("user_zzzzz", "password")
+
+ self._add_user_to_room(room_id, RoomVersions.V1, local_user_1)
+ self._add_user_to_room(room_id, RoomVersions.V1, local_user_2)
+ self._add_user_to_room(room_id, RoomVersions.V1, local_user_3)
+
+ # Create a few "remote" users and join them to the room
+ remote_user_1 = "@user_aaaaa:remote_server"
+ remote_user_2 = "@user_yyyyy:remote_server"
+ remote_user_3 = "@user_ccccc:remote_server"
+ self._add_user_to_room(room_id, RoomVersions.V1, remote_user_1)
+ self._add_user_to_room(room_id, RoomVersions.V1, remote_user_2)
+ self._add_user_to_room(room_id, RoomVersions.V1, remote_user_3)
+
+ local_users = [local_user_1, local_user_2, local_user_3]
+ remote_users = [remote_user_1, remote_user_2, remote_user_3]
+
+ # The local searching user searches for the term "user", which other users have
+ # in their user id
+ results = self.get_success(
+ self.handler.search_users(searching_user, "user", 20)
+ )["results"]
+ received_user_ids = [result["user_id"] for result in results]
+
+ for user in local_users:
+ self.assertIn(
+ user, received_user_ids, f"Local user {user} not found in results"
+ )
+
+ for user in remote_users:
+ self.assertNotIn(
+ user, received_user_ids, f"Remote user {user} should not be in results"
+ )
+
def _add_user_to_room(
self,
room_id: str,
@@ -1081,10 +1178,10 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
for use_numeric in [False, True]:
if use_numeric:
prefix1 = f"{i}"
- prefix2 = f"{i+1}"
+ prefix2 = f"{i + 1}"
else:
prefix1 = f"a{i}"
- prefix2 = f"a{i+1}"
+ prefix2 = f"a{i + 1}"
local_user_1 = self.register_user(f"user{char}{prefix1}", "password")
local_user_2 = self.register_user(f"user{char}{prefix2}", "password")
diff --git a/tests/handlers/test_worker_lock.py b/tests/handlers/test_worker_lock.py
index 6e9a15c8ee..0691d3f99c 100644
--- a/tests/handlers/test_worker_lock.py
+++ b/tests/handlers/test_worker_lock.py
@@ -19,6 +19,9 @@
#
#
+import logging
+import platform
+
from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor
@@ -29,6 +32,8 @@ from tests import unittest
from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.utils import test_timeout
+logger = logging.getLogger(__name__)
+
class WorkerLockTestCase(unittest.HomeserverTestCase):
def prepare(
@@ -53,12 +58,27 @@ class WorkerLockTestCase(unittest.HomeserverTestCase):
def test_lock_contention(self) -> None:
"""Test lock contention when a lot of locks wait on a single worker"""
-
+ nb_locks_to_test = 500
+ current_machine = platform.machine().lower()
+ if current_machine.startswith("riscv"):
+ # RISC-V specific settings
+ timeout_seconds = 15 # Increased timeout for RISC-V
+            # Log the adjusted settings so they are visible in CI logs.
+            logger.info(
+ f"Detected RISC-V architecture ({current_machine}). "
+ f"Adjusting test_lock_contention: timeout={timeout_seconds}s"
+ )
+ else:
+ # Settings for other architectures
+ timeout_seconds = 5
# It takes around 0.5s on a 5+ years old laptop
- with test_timeout(5):
- nb_locks = 500
- d = self._take_locks(nb_locks)
- self.assertEqual(self.get_success(d), nb_locks)
+ with test_timeout(timeout_seconds): # Use the dynamically set timeout
+ d = self._take_locks(
+ nb_locks_to_test
+ ) # Use the (potentially adjusted) number of locks
+ self.assertEqual(
+ self.get_success(d), nb_locks_to_test
+ ) # Assert against the used number of locks
async def _take_locks(self, nb_locks: int) -> int:
locks = [
|