summary refs log tree commit diff
path: root/tests/handlers
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--tests/handlers/test_admin.py1
-rw-r--r--tests/handlers/test_appservice.py15
-rw-r--r--tests/handlers/test_cas.py239
-rw-r--r--tests/handlers/test_directory.py1
-rw-r--r--tests/handlers/test_e2e_keys.py298
-rw-r--r--tests/handlers/test_federation.py33
-rw-r--r--tests/handlers/test_federation_event.py27
-rw-r--r--tests/handlers/test_oauth_delegation.py188
-rw-r--r--tests/handlers/test_oidc.py291
-rw-r--r--tests/handlers/test_password_providers.py71
-rw-r--r--tests/handlers/test_presence.py217
-rw-r--r--tests/handlers/test_profile.py1
-rw-r--r--tests/handlers/test_register.py58
-rw-r--r--tests/handlers/test_room_list.py6
-rw-r--r--tests/handlers/test_room_member.py224
-rw-r--r--tests/handlers/test_room_policy.py226
-rw-r--r--tests/handlers/test_room_summary.py48
-rw-r--r--tests/handlers/test_saml.py381
-rw-r--r--tests/handlers/test_send_email.py230
-rw-r--r--tests/handlers/test_sliding_sync.py4027
-rw-r--r--tests/handlers/test_sync.py317
-rw-r--r--tests/handlers/test_user_directory.py101
-rw-r--r--tests/handlers/test_worker_lock.py30
23 files changed, 4089 insertions, 2941 deletions
diff --git a/tests/handlers/test_admin.py b/tests/handlers/test_admin.py

index 9ff853a83d..c5bff468e2 100644 --- a/tests/handlers/test_admin.py +++ b/tests/handlers/test_admin.py
@@ -260,7 +260,6 @@ class ExfiltrateData(unittest.HomeserverTestCase): self.assertEqual(args[0]["name"], self.user2) self.assertIn("displayname", args[0]) self.assertIn("avatar_url", args[0]) - self.assertIn("threepids", args[0]) self.assertIn("external_ids", args[0]) self.assertIn("creation_ts", args[0]) diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index 1eec0d43b7..1db630e9e4 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py
@@ -1165,12 +1165,23 @@ class ApplicationServicesHandlerOtkCountsTestCase(unittest.HomeserverTestCase): self.hs.get_datastores().main.services_cache = [self._service] # Register some appservice users - self._sender_user, self._sender_device = self.register_appservice_user( + user_id, device_id = self.register_appservice_user( "as.sender", self._service_token ) - self._namespaced_user, self._namespaced_device = self.register_appservice_user( + # With MSC4190 enabled, there will not be a device created + # during AS registration. However MSC4190 is not enabled + # in this test. It may become the default behaviour in the + # future, in which case this test will need to be updated. + assert device_id is not None + self._sender_user = user_id + self._sender_device = device_id + + user_id, device_id = self.register_appservice_user( "_as_user1", self._service_token ) + assert device_id is not None + self._namespaced_user = user_id + self._namespaced_device = device_id # Register a real user as well. self._real_user = self.register_user("real.user", "meow") diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py deleted file mode 100644
index f41f7d36ad..0000000000 --- a/tests/handlers/test_cas.py +++ /dev/null
@@ -1,239 +0,0 @@ -# -# This file is licensed under the Affero General Public License (AGPL) version 3. -# -# Copyright 2020 The Matrix.org Foundation C.I.C. -# Copyright (C) 2023 New Vector, Ltd -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# See the GNU Affero General Public License for more details: -# <https://www.gnu.org/licenses/agpl-3.0.html>. -# -# Originally licensed under the Apache License, Version 2.0: -# <http://www.apache.org/licenses/LICENSE-2.0>. -# -# [This file includes modifications made by New Vector Limited] -# -# -from typing import Any, Dict -from unittest.mock import AsyncMock, Mock - -from twisted.test.proto_helpers import MemoryReactor - -from synapse.handlers.cas import CasResponse -from synapse.server import HomeServer -from synapse.util import Clock - -from tests.unittest import HomeserverTestCase, override_config - -# These are a few constants that are used as config parameters in the tests. -BASE_URL = "https://synapse/" -SERVER_URL = "https://issuer/" - - -class CasHandlerTestCase(HomeserverTestCase): - def default_config(self) -> Dict[str, Any]: - config = super().default_config() - config["public_baseurl"] = BASE_URL - cas_config = { - "enabled": True, - "server_url": SERVER_URL, - "service_url": BASE_URL, - } - - # Update this config with what's in the default config so that - # override_config works as expected. - cas_config.update(config.get("cas_config", {})) - config["cas_config"] = cas_config - - return config - - def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - hs = self.setup_test_homeserver() - - self.handler = hs.get_cas_handler() - - # Reduce the number of attempts when generating MXIDs. - sso_handler = hs.get_sso_handler() - sso_handler._MAP_USERNAME_RETRIES = 3 - - return hs - - def test_map_cas_user_to_user(self) -> None: - """Ensure that mapping the CAS user returned from a provider to an MXID works properly.""" - - # stub out the auth handler - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign] - - cas_response = CasResponse("test_user", {}) - request = _mock_request() - self.get_success( - self.handler._handle_cas_response(request, cas_response, "redirect_uri", "") - ) - - # check that the auth handler got called as expected - auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", - "cas", - request, - "redirect_uri", - None, - new_user=True, - auth_provider_session_id=None, - ) - - def test_map_cas_user_to_existing_user(self) -> None: - """Existing users can log in with CAS account.""" - store = self.hs.get_datastores().main - self.get_success( - store.register_user(user_id="@test_user:test", password_hash=None) - ) - - # stub out the auth handler - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign] - - # Map a user via SSO. - cas_response = CasResponse("test_user", {}) - request = _mock_request() - self.get_success( - self.handler._handle_cas_response(request, cas_response, "redirect_uri", "") - ) - - # check that the auth handler got called as expected - auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", - "cas", - request, - "redirect_uri", - None, - new_user=False, - auth_provider_session_id=None, - ) - - # Subsequent calls should map to the same mxid. - auth_handler.complete_sso_login.reset_mock() - self.get_success( - self.handler._handle_cas_response(request, cas_response, "redirect_uri", "") - ) - auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", - "cas", - request, - "redirect_uri", - None, - new_user=False, - auth_provider_session_id=None, - ) - - def test_map_cas_user_to_invalid_localpart(self) -> None: - """CAS automaps invalid characters to base-64 encoding.""" - - # stub out the auth handler - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign] - - cas_response = CasResponse("föö", {}) - request = _mock_request() - self.get_success( - self.handler._handle_cas_response(request, cas_response, "redirect_uri", "") - ) - - # check that the auth handler got called as expected - auth_handler.complete_sso_login.assert_called_once_with( - "@f=c3=b6=c3=b6:test", - "cas", - request, - "redirect_uri", - None, - new_user=True, - auth_provider_session_id=None, - ) - - @override_config( - { - "cas_config": { - "required_attributes": {"userGroup": "staff", "department": None} - } - } - ) - def test_required_attributes(self) -> None: - """The required attributes must be met from the CAS response.""" - - # stub out the auth handler - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign] - - # The response doesn't have the proper userGroup or department. - cas_response = CasResponse("test_user", {}) - request = _mock_request() - self.get_success( - self.handler._handle_cas_response(request, cas_response, "redirect_uri", "") - ) - auth_handler.complete_sso_login.assert_not_called() - - # The response doesn't have any department. - cas_response = CasResponse("test_user", {"userGroup": ["staff"]}) - request.reset_mock() - self.get_success( - self.handler._handle_cas_response(request, cas_response, "redirect_uri", "") - ) - auth_handler.complete_sso_login.assert_not_called() - - # Add the proper attributes and it should succeed. - cas_response = CasResponse( - "test_user", {"userGroup": ["staff", "admin"], "department": ["sales"]} - ) - request.reset_mock() - self.get_success( - self.handler._handle_cas_response(request, cas_response, "redirect_uri", "") - ) - - # check that the auth handler got called as expected - auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", - "cas", - request, - "redirect_uri", - None, - new_user=True, - auth_provider_session_id=None, - ) - - @override_config({"cas_config": {"enable_registration": False}}) - def test_map_cas_user_does_not_register_new_user(self) -> None: - """Ensures new users are not registered if the enabled registration flag is disabled.""" - - # stub out the auth handler - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign] - - cas_response = CasResponse("test_user", {}) - request = _mock_request() - self.get_success( - self.handler._handle_cas_response(request, cas_response, "redirect_uri", "") - ) - - # check that the auth handler was not called as expected - auth_handler.complete_sso_login.assert_not_called() - - -def _mock_request() -> Mock: - """Returns a mock which will stand in as a SynapseRequest""" - mock = Mock( - spec=[ - "finish", - "getClientAddress", - "getHeader", - "setHeader", - "setResponseCode", - "write", - ] - ) - # `_disconnected` musn't be another `Mock`, otherwise it will be truthy. - mock._disconnected = False - return mock diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index 4a3e36ffde..b7058d8002 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py
@@ -587,6 +587,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase): self.room_list_handler = hs.get_room_list_handler() self.directory_handler = hs.get_directory_handler() + @unittest.override_config({"room_list_publication_rules": [{"action": "allow"}]}) def test_disabling_room_list(self) -> None: self.room_list_handler.enable_room_list_search = True self.directory_handler.enable_room_list_search = True diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 8a3dfdcf75..70fc4263e7 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py
@@ -19,6 +19,7 @@ # [This file includes modifications made by New Vector Limited] # # +import time from typing import Dict, Iterable from unittest import mock @@ -151,18 +152,30 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): def test_claim_one_time_key(self) -> None: local_user = "@boris:" + self.hs.hostname device_id = "xyz" - keys = {"alg1:k1": "key1"} - res = self.get_success( self.handler.upload_keys_for_user( - local_user, device_id, {"one_time_keys": keys} + local_user, device_id, {"one_time_keys": {"alg1:k1": "key1"}} ) ) self.assertDictEqual( res, {"one_time_key_counts": {"alg1": 1, "signed_curve25519": 0}} ) - res2 = self.get_success( + # Keys should be returned in the order they were uploaded. To test, advance time + # a little, then upload a second key with an earlier key ID; it should get + # returned second. + self.reactor.advance(1) + res = self.get_success( + self.handler.upload_keys_for_user( + local_user, device_id, {"one_time_keys": {"alg1:k0": "key0"}} + ) + ) + self.assertDictEqual( + res, {"one_time_key_counts": {"alg1": 2, "signed_curve25519": 0}} + ) + + # now claim both keys back. They should be in the same order + res = self.get_success( self.handler.claim_one_time_keys( {local_user: {device_id: {"alg1": 1}}}, self.requester, @@ -171,12 +184,27 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): ) ) self.assertEqual( - res2, + res, { "failures": {}, "one_time_keys": {local_user: {device_id: {"alg1:k1": "key1"}}}, }, ) + res = self.get_success( + self.handler.claim_one_time_keys( + {local_user: {device_id: {"alg1": 1}}}, + self.requester, + timeout=None, + always_include_fallback_keys=False, + ) + ) + self.assertEqual( + res, + { + "failures": {}, + "one_time_keys": {local_user: {device_id: {"alg1:k0": "key0"}}}, + }, + ) def test_claim_one_time_key_bulk(self) -> None: """Like test_claim_one_time_key but claims multiple keys in one handler call.""" @@ -336,6 +364,47 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): counts_by_alg, expected_counts_by_alg, f"{user_id}:{device_id}" ) + def test_claim_one_time_key_bulk_ordering(self) -> None: + """Keys returned by the bulk claim call should be returned in the correct order""" + + # Alice has lots of keys, uploaded in a specific order + alice = f"@alice:{self.hs.hostname}" + alice_dev = "alice_dev_1" + + self.get_success( + self.handler.upload_keys_for_user( + alice, + alice_dev, + {"one_time_keys": {"alg1:k20": 20, "alg1:k21": 21, "alg1:k22": 22}}, + ) + ) + # Advance time by 1s, to ensure that there is a difference in upload time. + self.reactor.advance(1) + self.get_success( + self.handler.upload_keys_for_user( + alice, + alice_dev, + {"one_time_keys": {"alg1:k10": 10, "alg1:k11": 11, "alg1:k12": 12}}, + ) + ) + + # Now claim some, and check we get the right ones. + claim_res = self.get_success( + self.handler.claim_one_time_keys( + {alice: {alice_dev: {"alg1": 2}}}, + self.requester, + timeout=None, + always_include_fallback_keys=False, + ) + ) + # We should get the first-uploaded keys, even though they have later key ids. + # We should get a random set of two of k20, k21, k22. + self.assertEqual(claim_res["failures"], {}) + claimed_keys = claim_res["one_time_keys"]["@alice:test"]["alice_dev_1"] + self.assertEqual(len(claimed_keys), 2) + for key_id in claimed_keys.keys(): + self.assertIn(key_id, ["alg1:k20", "alg1:k21", "alg1:k22"]) + def test_fallback_key(self) -> None: local_user = "@boris:" + self.hs.hostname device_id = "xyz" @@ -1758,3 +1827,222 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): ) self.assertIs(exists, True) self.assertIs(replaceable_without_uia, False) + + def test_delete_old_one_time_keys(self) -> None: + """Test the db migration that clears out old OTKs""" + + # We upload two sets of keys, one just over a week ago, and one just less than + # a week ago. Each batch contains some keys that match the deletion pattern + # (key IDs of 6 chars), and some that do not. + # + # Finally, set the scheduled task going, and check what gets deleted. + + user_id = "@user000:" + self.hs.hostname + device_id = "xyz" + + # The scheduled task should be for "now" in real, wallclock time, so + # set the test reactor to just over a week ago. + self.reactor.advance(time.time() - 7.5 * 24 * 3600) + + # Upload some keys + self.get_success( + self.handler.upload_keys_for_user( + user_id, + device_id, + { + "one_time_keys": { + # some keys to delete + "alg1:AAAAAA": "key1", + "alg2:AAAAAB": {"key": "key2", "signatures": {"k1": "sig1"}}, + # A key to *not* delete + "alg2:AAAAAAAAAA": {"key": "key3"}, + } + }, + ) + ) + + # A day passes + self.reactor.advance(24 * 3600) + + # Upload some more keys + self.get_success( + self.handler.upload_keys_for_user( + user_id, + device_id, + { + "one_time_keys": { + # some keys which match the pattern + "alg1:BAAAAA": "key1", + "alg2:BAAAAB": {"key": "key2", "signatures": {"k1": "sig1"}}, + # A key to *not* delete + "alg2:BAAAAAAAAA": {"key": "key3"}, + } + }, + ) + ) + + # The rest of the week passes, which should set the scheduled task going. + self.reactor.advance(6.5 * 24 * 3600) + + # Check what we're left with in the database + remaining_key_ids = { + row[0] + for row in self.get_success( + self.handler.store.db_pool.simple_select_list( + "e2e_one_time_keys_json", None, ["key_id"] + ) + ) + } + self.assertEqual( + remaining_key_ids, {"AAAAAAAAAA", "BAAAAA", "BAAAAB", "BAAAAAAAAA"} + ) + + @override_config( + { + "experimental_features": { + "msc4263_limit_key_queries_to_users_who_share_rooms": True + } + } + ) + def test_query_devices_remote_restricted_not_in_shared_room(self) -> None: + """Tests that querying keys for a remote user that we don't share a room + with returns nothing. + """ + + remote_user_id = "@test:other" + local_user_id = "@test:test" + + # Do *not* pretend we're sharing a room with the user we're querying. + + remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY" + remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ" + + self.hs.get_federation_client().query_client_keys = mock.AsyncMock( # type: ignore[method-assign] + return_value={ + "device_keys": {remote_user_id: {}}, + "master_keys": { + remote_user_id: { + "user_id": remote_user_id, + "usage": ["master"], + "keys": {"ed25519:" + remote_master_key: remote_master_key}, + }, + }, + "self_signing_keys": { + remote_user_id: { + "user_id": remote_user_id, + "usage": ["self_signing"], + "keys": { + "ed25519:" + + remote_self_signing_key: remote_self_signing_key + }, + } + }, + } + ) + + e2e_handler = self.hs.get_e2e_keys_handler() + + query_result = self.get_success( + e2e_handler.query_devices( + { + "device_keys": {remote_user_id: []}, + }, + timeout=10, + from_user_id=local_user_id, + from_device_id="some_device_id", + ) + ) + + self.assertEqual( + query_result, + { + "device_keys": {}, + "failures": {}, + "master_keys": {}, + "self_signing_keys": {}, + "user_signing_keys": {}, + }, + ) + + @override_config( + { + "experimental_features": { + "msc4263_limit_key_queries_to_users_who_share_rooms": True + } + } + ) + def test_query_devices_remote_restricted_in_shared_room(self) -> None: + """Tests that querying keys for a remote user that we share a room + with returns the cross signing keys correctly. + """ + + remote_user_id = "@test:other" + local_user_id = "@test:test" + + # Pretend we're sharing a room with the user we're querying. If not, + # `query_devices` will filter out the user ID and `_query_devices_for_destination` + # will return early. + self.store.do_users_share_a_room_joined_or_invited = mock.AsyncMock( # type: ignore[method-assign] + return_value=[remote_user_id] + ) + self.store.get_rooms_for_user = mock.AsyncMock(return_value={"some_room_id"}) + + remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY" + remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ" + + self.hs.get_federation_client().query_user_devices = mock.AsyncMock( # type: ignore[method-assign] + return_value={ + "user_id": remote_user_id, + "stream_id": 1, + "devices": [], + "master_key": { + "user_id": remote_user_id, + "usage": ["master"], + "keys": {"ed25519:" + remote_master_key: remote_master_key}, + }, + "self_signing_key": { + "user_id": remote_user_id, + "usage": ["self_signing"], + "keys": { + "ed25519:" + remote_self_signing_key: remote_self_signing_key + }, + }, + } + ) + + e2e_handler = self.hs.get_e2e_keys_handler() + + query_result = self.get_success( + e2e_handler.query_devices( + { + "device_keys": {remote_user_id: []}, + }, + timeout=10, + from_user_id=local_user_id, + from_device_id="some_device_id", + ) + ) + + self.assertEqual(query_result["failures"], {}) + self.assertEqual( + query_result["master_keys"], + { + remote_user_id: { + "user_id": remote_user_id, + "usage": ["master"], + "keys": {"ed25519:" + remote_master_key: remote_master_key}, + } + }, + ) + self.assertEqual( + query_result["self_signing_keys"], + { + remote_user_id: { + "user_id": remote_user_id, + "usage": ["self_signing"], + "keys": { + "ed25519:" + remote_self_signing_key: remote_self_signing_key + }, + } + }, + ) diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index 3fe5b0a1b4..b64a8a86a2 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py
@@ -44,7 +44,7 @@ from synapse.rest.client import login, room from synapse.server import HomeServer from synapse.storage.databases.main.events_worker import EventCacheEntry from synapse.util import Clock -from synapse.util.stringutils import random_string +from synapse.util.events import generate_fake_event_id from tests import unittest from tests.test_utils import event_injection @@ -52,10 +52,6 @@ from tests.test_utils import event_injection logger = logging.getLogger(__name__) -def generate_fake_event_id() -> str: - return "$fake_" + random_string(43) - - class FederationTestCase(unittest.FederatingHomeserverTestCase): servlets = [ admin.register_servlets, @@ -665,9 +661,12 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase): ) ) - with patch.object( - fed_client, "make_membership_event", mock_make_membership_event - ), patch.object(fed_client, "send_join", mock_send_join): + with ( + patch.object( + fed_client, "make_membership_event", mock_make_membership_event + ), + patch.object(fed_client, "send_join", mock_send_join), + ): # Join and check that our join event is rejected # (The join event is rejected because it doesn't have any signatures) join_exc = self.get_failure( @@ -712,9 +711,12 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase): fed_handler = self.hs.get_federation_handler() store = self.hs.get_datastores().main - with patch.object( - fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room - ), patch.object(store, "is_partial_state_room", mock_is_partial_state_room): + with ( + patch.object( + fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room + ), + patch.object(store, "is_partial_state_room", mock_is_partial_state_room), + ): # Start the partial state sync. fed_handler._start_partial_state_room_sync("hs1", {"hs2"}, "room_id") self.assertEqual(mock_sync_partial_state_room.call_count, 1) @@ -764,9 +766,12 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase): fed_handler = self.hs.get_federation_handler() store = self.hs.get_datastores().main - with patch.object( - fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room - ), patch.object(store, "is_partial_state_room", mock_is_partial_state_room): + with ( + patch.object( + fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room + ), + patch.object(store, "is_partial_state_room", mock_is_partial_state_room), + ): # Start the partial state sync. fed_handler._start_partial_state_room_sync("hs1", {"hs2"}, "room_id") self.assertEqual(mock_sync_partial_state_room.call_count, 1) diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py
index 1b83aea579..51eca56c3b 100644 --- a/tests/handlers/test_federation_event.py +++ b/tests/handlers/test_federation_event.py
@@ -288,13 +288,15 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): } # We also expect an outbound request to /state - self.mock_federation_transport_client.get_room_state.return_value = StateRequestResponse( - # Mimic the other server not knowing about the state at all. - # We want to cause Synapse to throw an error (`Unable to get - # missing prev_event $fake_prev_event`) and fail to backfill - # the pulled event. - auth_events=[], - state=[], + self.mock_federation_transport_client.get_room_state.return_value = ( + StateRequestResponse( + # Mimic the other server not knowing about the state at all. + # We want to cause Synapse to throw an error (`Unable to get + # missing prev_event $fake_prev_event`) and fail to backfill + # the pulled event. + auth_events=[], + state=[], + ) ) pulled_event = make_event_from_dict( @@ -373,7 +375,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): In this test, we pretend we are processing a "pulled" event via backfill. The pulled event succesfully processes and the backward - extremeties are updated along with clearing out any failed pull attempts + extremities are updated along with clearing out any failed pull attempts for those old extremities. We check that we correctly cleared failed pull attempts of the @@ -805,6 +807,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): OTHER_USER = f"@user:{self.OTHER_SERVER_NAME}" main_store = self.hs.get_datastores().main + state_deletion_store = self.hs.get_datastores().state_deletion # Create the room. kermit_user_id = self.register_user("kermit", "test") @@ -956,7 +959,9 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): bert_member_event.event_id: bert_member_event, rejected_kick_event.event_id: rejected_kick_event, }, - state_res_store=StateResolutionStore(main_store), + state_res_store=StateResolutionStore( + main_store, state_deletion_store + ), ) ), [bert_member_event.event_id, rejected_kick_event.event_id], @@ -1001,7 +1006,9 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): rejected_power_levels_event.event_id, ], event_map={}, - state_res_store=StateResolutionStore(main_store), + state_res_store=StateResolutionStore( + main_store, state_deletion_store + ), full_conflicted_set=set(), ) ), diff --git a/tests/handlers/test_oauth_delegation.py b/tests/handlers/test_oauth_delegation.py
index 036c539db2..37acb660e7 100644 --- a/tests/handlers/test_oauth_delegation.py +++ b/tests/handlers/test_oauth_delegation.py
@@ -43,6 +43,7 @@ from synapse.api.errors import ( OAuthInsufficientScopeError, SynapseError, ) +from synapse.appservice import ApplicationService from synapse.http.site import SynapseRequest from synapse.rest import admin from synapse.rest.client import account, devices, keys, login, logout, register @@ -146,6 +147,16 @@ class MSC3861OAuthDelegation(HomeserverTestCase): return hs + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: + # Provision the user and the device we use in the tests. + store = homeserver.get_datastores().main + self.get_success(store.register_user(USER_ID)) + self.get_success( + store.store_device(USER_ID, DEVICE, initial_device_display_name=None) + ) + def _assertParams(self) -> None: """Assert that the request parameters are correct.""" params = parse_qs(self.http_client.request.call_args[1]["data"].decode("utf-8")) @@ -379,6 +390,44 @@ class MSC3861OAuthDelegation(HomeserverTestCase): ) self.assertEqual(requester.device_id, DEVICE) + def test_active_user_with_device_explicit_device_id(self) -> None: + """The handler should return a requester with normal user rights and a device ID, given explicitly, as supported by MAS 0.15+""" + + self.http_client.request = AsyncMock( + return_value=FakeResponse.json( + code=200, + payload={ + "active": True, + "sub": SUBJECT, + "scope": " ".join([MATRIX_USER_SCOPE]), + "device_id": DEVICE, + "username": USERNAME, + }, + ) + ) + request = Mock(args={}) + request.args[b"access_token"] = [b"mockAccessToken"] + request.requestHeaders.getRawHeaders = mock_getRawHeaders() + requester = self.get_success(self.auth.get_user_by_req(request)) + self.http_client.get_json.assert_called_once_with(WELL_KNOWN) + self.http_client.request.assert_called_once_with( + method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY + ) + # It should have called with the 'X-MAS-Supports-Device-Id: 1' header + self.assertEqual( + self.http_client.request.call_args[1]["headers"].getRawHeaders( + b"X-MAS-Supports-Device-Id", + ), + [b"1"], + ) + self._assertParams() + self.assertEqual(requester.user.to_string(), "@%s:%s" % (USERNAME, SERVER_NAME)) + self.assertEqual(requester.is_guest, False) + self.assertEqual( + get_awaitable_result(self.auth.is_server_admin(requester)), False + ) + self.assertEqual(requester.device_id, DEVICE) + def test_multiple_devices(self) -> None: """The handler should raise an error if multiple devices are found in the scope.""" @@ -500,6 +549,44 @@ class MSC3861OAuthDelegation(HomeserverTestCase): error = self.get_failure(self.auth.get_user_by_req(request), SynapseError) self.assertEqual(error.value.code, 503) + def test_cached_expired_introspection(self) -> None: + """The handler should raise an error if the introspection response gives + an expiry time, the introspection response is cached and then the entry is + re-requested after it has expired.""" + + self.http_client.request = introspection_mock = AsyncMock( + return_value=FakeResponse.json( + code=200, + payload={ + "active": True, + "sub": SUBJECT, + "scope": " ".join( + [ + MATRIX_USER_SCOPE, + f"{MATRIX_DEVICE_SCOPE_PREFIX}AABBCC", + ] + ), + "username": USERNAME, + "expires_in": 60, + }, + ) + ) + request = Mock(args={}) + request.args[b"access_token"] = [b"mockAccessToken"] + request.requestHeaders.getRawHeaders = mock_getRawHeaders() + + # The first CS-API request causes a successful introspection + self.get_success(self.auth.get_user_by_req(request)) + self.assertEqual(introspection_mock.call_count, 1) + + # Sleep for 60 seconds so the token expires. + self.reactor.advance(60.0) + + # Now the CS-API request fails because the token expired + self.get_failure(self.auth.get_user_by_req(request), InvalidClientTokenError) + # Ensure another introspection request was not sent + self.assertEqual(introspection_mock.call_count, 1) + def make_device_keys(self, user_id: str, device_id: str) -> JsonDict: # We only generate a master key to simplify the test. master_signing_key = generate_signing_key(device_id) @@ -550,7 +637,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase): access_token="mockAccessToken", ) - self.assertEqual(channel.code, HTTPStatus.NOT_IMPLEMENTED, channel.json_body) + self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.json_body) def expect_unauthorized( self, method: str, path: str, content: Union[bytes, str, JsonDict] = "" @@ -560,15 +647,31 @@ class MSC3861OAuthDelegation(HomeserverTestCase): self.assertEqual(channel.code, 401, channel.json_body) def expect_unrecognized( - self, method: str, path: str, content: Union[bytes, str, JsonDict] = "" + self, + method: str, + path: str, + content: Union[bytes, str, JsonDict] = "", + auth: bool = False, ) -> None: - channel = self.make_request(method, path, content) + channel = self.make_request( + method, path, content, access_token="token" if auth else None + ) self.assertEqual(channel.code, 404, channel.json_body) self.assertEqual( channel.json_body["errcode"], Codes.UNRECOGNIZED, channel.json_body ) + def expect_forbidden( + self, method: str, path: str, content: Union[bytes, str, JsonDict] = "" + ) -> None: + channel = self.make_request(method, path, content) + + self.assertEqual(channel.code, 403, channel.json_body) + self.assertEqual( + channel.json_body["errcode"], Codes.FORBIDDEN, channel.json_body + ) + def test_uia_endpoints(self) -> None: """Test that endpoints that were removed in MSC2964 are no longer available.""" @@ -580,36 +683,6 @@ class MSC3861OAuthDelegation(HomeserverTestCase): "POST", "/_matrix/client/v3/keys/device_signing/upload" ) - def test_3pid_endpoints(self) -> None: - """Test that 3pid account management endpoints that were removed in MSC2964 are no longer available.""" - - # Remains and requires auth: - self.expect_unauthorized("GET", "/_matrix/client/v3/account/3pid") - self.expect_unauthorized( - "POST", - "/_matrix/client/v3/account/3pid/bind", - { - "client_secret": "foo", - "id_access_token": "bar", - "id_server": "foo", - "sid": "bar", - }, - ) - self.expect_unauthorized("POST", "/_matrix/client/v3/account/3pid/unbind", {}) - - # These are gone: - self.expect_unrecognized( - "POST", "/_matrix/client/v3/account/3pid" - ) # deprecated - self.expect_unrecognized("POST", "/_matrix/client/v3/account/3pid/add") - self.expect_unrecognized("POST", "/_matrix/client/v3/account/3pid/delete") - self.expect_unrecognized( - "POST", "/_matrix/client/v3/account/3pid/email/requestToken" - ) - self.expect_unrecognized( - "POST", "/_matrix/client/v3/account/3pid/msisdn/requestToken" - ) - def test_account_management_endpoints_removed(self) -> None: """Test that account management endpoints that were removed in MSC2964 are no longer available.""" self.expect_unrecognized("POST", "/_matrix/client/v3/account/deactivate") @@ -623,11 +696,35 @@ class MSC3861OAuthDelegation(HomeserverTestCase): def test_registration_endpoints_removed(self) -> None: """Test that registration endpoints that were removed in MSC2964 are no longer available.""" + appservice = ApplicationService( + token="i_am_an_app_service", + id="1234", + namespaces={"users": [{"regex": r"@alice:.+", "exclusive": True}]}, + sender="@as_main:test", + ) + + self.hs.get_datastores().main.services_cache = [appservice] self.expect_unrecognized( "GET", "/_matrix/client/v1/register/m.login.registration_token/validity" ) + + # Registration is disabled + self.expect_forbidden( + "POST", + "/_matrix/client/v3/register", + {"username": "alice", "password": "hunter2"}, + ) + # This is still available for AS registrations - # self.expect_unrecognized("POST", "/_matrix/client/v3/register") + channel = self.make_request( + "POST", + "/_matrix/client/v3/register", + {"username": "alice", "type": "m.login.application_service"}, + shorthand=False, + access_token="i_am_an_app_service", + ) + self.assertEqual(channel.code, 200, channel.json_body) + self.expect_unrecognized("GET", "/_matrix/client/v3/register/available") self.expect_unrecognized( "POST", "/_matrix/client/v3/register/email/requestToken" @@ -648,8 +745,25 @@ class MSC3861OAuthDelegation(HomeserverTestCase): def test_device_management_endpoints_removed(self) -> None: """Test that device management endpoints that were removed in MSC2964 are no longer available.""" - self.expect_unrecognized("POST", "/_matrix/client/v3/delete_devices") - self.expect_unrecognized("DELETE", "/_matrix/client/v3/devices/{DEVICE}") + + # Because we still support those endpoints with ASes, it checks the + # access token before returning 404 + self.http_client.request = AsyncMock( + return_value=FakeResponse.json( + code=200, + payload={ + "active": True, + "sub": SUBJECT, + "scope": " ".join([MATRIX_USER_SCOPE, MATRIX_DEVICE_SCOPE]), + "username": USERNAME, + }, + ) + ) + + self.expect_unrecognized("POST", "/_matrix/client/v3/delete_devices", auth=True) + self.expect_unrecognized( + "DELETE", "/_matrix/client/v3/devices/{DEVICE}", auth=True + ) def test_openid_endpoints_removed(self) -> None: """Test that OpenID id_token endpoints that were removed in MSC2964 are no longer available.""" @@ -772,7 +886,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase): req = SynapseRequest(channel, self.site) # type: ignore[arg-type] req.client.host = MAS_IPV4_ADDR req.requestHeaders.addRawHeader( - "Authorization", f"Bearer {self.auth._admin_token}" + "Authorization", f"Bearer {self.auth._admin_token()}" ) req.requestHeaders.addRawHeader("User-Agent", MAS_USER_AGENT) req.content = BytesIO(b"") diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index a81501979d..ff8e3c5cb6 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py
@@ -57,6 +57,7 @@ CLIENT_ID = "test-client-id" CLIENT_SECRET = "test-client-secret" BASE_URL = "https://synapse/" CALLBACK_URL = BASE_URL + "_synapse/client/oidc/callback" +TEST_REDIRECT_URI = "https://test/oidc/callback" SCOPES = ["openid"] # config for common cases @@ -70,12 +71,16 @@ DEFAULT_CONFIG = { } # extends the default config with explicit OAuth2 endpoints instead of using discovery +# +# We add "explicit" to things to make them different from the discovered values to make +# sure that the explicit values override the discovered ones. EXPLICIT_ENDPOINT_CONFIG = { **DEFAULT_CONFIG, "discover": False, - "authorization_endpoint": ISSUER + "authorize", - "token_endpoint": ISSUER + "token", - "jwks_uri": ISSUER + "jwks", + "authorization_endpoint": ISSUER + "authorize-explicit", + "token_endpoint": ISSUER + "token-explicit", + "jwks_uri": ISSUER + "jwks-explicit", + "id_token_signing_alg_values_supported": ["RS256", "<explicit>"], } @@ -259,12 +264,64 @@ class OidcHandlerTestCase(HomeserverTestCase): self.get_success(self.provider.load_metadata()) self.fake_server.get_metadata_handler.assert_not_called() + @override_config({"oidc_config": {**EXPLICIT_ENDPOINT_CONFIG, "discover": True}}) + def test_discovery_with_explicit_config(self) -> None: + """ + The handler should discover the endpoints from OIDC discovery document but + values are overriden by the explicit config. + """ + # This would throw if some metadata were invalid + metadata = self.get_success(self.provider.load_metadata()) + self.fake_server.get_metadata_handler.assert_called_once() + + self.assertEqual(metadata.issuer, self.fake_server.issuer) + # It seems like authlib does not have that defined in its metadata models + self.assertEqual( + metadata.get("userinfo_endpoint"), + self.fake_server.userinfo_endpoint, + ) + + # Ensure the values are overridden correctly since these were configured + # explicitly + self.assertEqual( + metadata.authorization_endpoint, + EXPLICIT_ENDPOINT_CONFIG["authorization_endpoint"], + ) + self.assertEqual( + metadata.token_endpoint, EXPLICIT_ENDPOINT_CONFIG["token_endpoint"] + ) + self.assertEqual(metadata.jwks_uri, EXPLICIT_ENDPOINT_CONFIG["jwks_uri"]) + self.assertEqual( + metadata.id_token_signing_alg_values_supported, + EXPLICIT_ENDPOINT_CONFIG["id_token_signing_alg_values_supported"], + ) + + # subsequent calls should be cached + self.reset_mocks() + self.get_success(self.provider.load_metadata()) + self.fake_server.get_metadata_handler.assert_not_called() + @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG}) def test_no_discovery(self) -> None: """When discovery is disabled, it should not try to load from discovery document.""" - self.get_success(self.provider.load_metadata()) + metadata = self.get_success(self.provider.load_metadata()) self.fake_server.get_metadata_handler.assert_not_called() + # Ensure the values are overridden correctly since these were configured + # explicitly + self.assertEqual( + metadata.authorization_endpoint, + EXPLICIT_ENDPOINT_CONFIG["authorization_endpoint"], + ) + self.assertEqual( + metadata.token_endpoint, EXPLICIT_ENDPOINT_CONFIG["token_endpoint"] + ) + self.assertEqual(metadata.jwks_uri, EXPLICIT_ENDPOINT_CONFIG["jwks_uri"]) + self.assertEqual( + metadata.id_token_signing_alg_values_supported, + EXPLICIT_ENDPOINT_CONFIG["id_token_signing_alg_values_supported"], + ) + @override_config({"oidc_config": DEFAULT_CONFIG}) def test_load_jwks(self) -> None: """JWKS loading is done once (then cached) if used.""" @@ -427,6 +484,32 @@ class OidcHandlerTestCase(HomeserverTestCase): self.assertEqual(code_verifier, "") self.assertEqual(redirect, "http://client/redirect") + @override_config( + { + "oidc_config": { + **DEFAULT_CONFIG, + "passthrough_authorization_parameters": ["additional_parameter"], + } + } + ) + def test_passthrough_parameters(self) -> None: + """The redirect request has additional parameters, one is authorized, one is not""" + req = Mock(spec=["cookies", "args"]) + req.cookies = [] + req.args = {} + req.args[b"additional_parameter"] = ["a_value".encode("utf-8")] + req.args[b"not_authorized_parameter"] = ["any".encode("utf-8")] + + url = urlparse( + self.get_success( + self.provider.handle_redirect_request(req, b"http://client/redirect") + ) + ) + + params = parse_qs(url.query) + self.assertEqual(params["additional_parameter"], ["a_value"]) + self.assertNotIn("not_authorized_parameters", params) + @override_config({"oidc_config": DEFAULT_CONFIG}) def test_redirect_request_with_code_challenge(self) -> None: """The redirect request has the right arguments & generates a valid session cookie.""" @@ -530,6 +613,24 @@ class OidcHandlerTestCase(HomeserverTestCase): code_verifier = get_value_from_macaroon(macaroon, "code_verifier") self.assertEqual(code_verifier, "") + @override_config( + {"oidc_config": {**DEFAULT_CONFIG, "redirect_uri": TEST_REDIRECT_URI}} + ) + def test_redirect_request_with_overridden_redirect_uri(self) -> None: + """The authorization endpoint redirect has the overridden `redirect_uri` value.""" + req = Mock(spec=["cookies"]) + req.cookies = [] + + url = urlparse( + self.get_success( + self.provider.handle_redirect_request(req, b"http://client/redirect") + ) + ) + + # Ensure that the redirect_uri in the returned url has been overridden. + params = parse_qs(url.query) + self.assertEqual(params["redirect_uri"], [TEST_REDIRECT_URI]) + @override_config({"oidc_config": DEFAULT_CONFIG}) def test_callback_error(self) -> None: """Errors from the provider returned in the callback are displayed.""" @@ -901,6 +1002,81 @@ class OidcHandlerTestCase(HomeserverTestCase): { "oidc_config": { **DEFAULT_CONFIG, + "redirect_uri": TEST_REDIRECT_URI, + } + } + ) + def test_code_exchange_with_overridden_redirect_uri(self) -> None: + """Code exchange behaves correctly and handles various error scenarios.""" + # Set up a fake IdP with a token endpoint handler. + token = { + "type": "Bearer", + "access_token": "aabbcc", + } + + self.fake_server.post_token_handler.side_effect = None + self.fake_server.post_token_handler.return_value = FakeResponse.json( + payload=token + ) + code = "code" + + # Exchange the code against the fake IdP. + self.get_success(self.provider._exchange_code(code, code_verifier="")) + + # Check that the `redirect_uri` parameter provided matches our + # overridden config value. + kwargs = self.fake_server.request.call_args[1] + args = parse_qs(kwargs["data"].decode("utf-8")) + self.assertEqual(args["redirect_uri"], [TEST_REDIRECT_URI]) + + @override_config( + { + "oidc_config": { + **DEFAULT_CONFIG, + "redirect_uri": TEST_REDIRECT_URI, + } + } + ) + def test_code_exchange_ignores_access_token(self) -> None: + """ + Code exchange completes successfully and doesn't validate the `at_hash` + (access token hash) field of an ID token when the access token isn't + going to be used. + + The access token won't be used in this test because Synapse (currently) + only needs it to fetch a user's metadata if it isn't included in the ID + token itself. + + Because we have included "openid" in the requested scopes for this IdP + (see `SCOPES`), user metadata is be included in the ID token. Thus the + access token isn't needed, and it's unnecessary for Synapse to validate + the access token. + + This is a regression test for a situation where an upstream identity + provider was providing an invalid `at_hash` value, which Synapse errored + on, yet Synapse wasn't using the access token for anything. + """ + # Exchange the code against the fake IdP. + userinfo = { + "sub": "foo", + "username": "foo", + "phone": "1234567", + } + with self.fake_server.id_token_override( + { + "at_hash": "invalid-hash", + } + ): + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + + # If no error was rendered, then we have success. + self.render_error.assert_not_called() + + @override_config( + { + "oidc_config": { + **DEFAULT_CONFIG, "user_mapping_provider": { "module": __name__ + ".TestMappingProviderExtra" }, @@ -1271,6 +1447,113 @@ class OidcHandlerTestCase(HomeserverTestCase): { "oidc_config": { **DEFAULT_CONFIG, + "attribute_requirements": [ + {"attribute": "test", "one_of": ["foo", "bar"]} + ], + } + } + ) + def test_attribute_requirements_one_of_succeeds(self) -> None: + """Test that auth succeeds if userinfo attribute has multiple values and CONTAINS required value""" + # userinfo with "test": ["bar"] attribute should succeed. + userinfo = { + "sub": "tester", + "username": "tester", + "test": ["bar"], + } + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + + # check that the auth handler got called as expected + self.complete_sso_login.assert_called_once_with( + "@tester:test", + self.provider.idp_id, + request, + ANY, + None, + new_user=True, + auth_provider_session_id=None, + ) + + @override_config( + { + "oidc_config": { + **DEFAULT_CONFIG, + "attribute_requirements": [ + {"attribute": "test", "one_of": ["foo", "bar"]} + ], + } + } + ) + def test_attribute_requirements_one_of_fails(self) -> None: + """Test that auth fails if userinfo attribute has multiple values yet + DOES NOT CONTAIN a required value + """ + # userinfo with "test": ["something else"] attribute should fail. + userinfo = { + "sub": "tester", + "username": "tester", + "test": ["something else"], + } + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() + + @override_config( + { + "oidc_config": { + **DEFAULT_CONFIG, + "attribute_requirements": [{"attribute": "test"}], + } + } + ) + def test_attribute_requirements_does_not_exist(self) -> None: + """OIDC login fails if the required attribute does not exist in the OIDC userinfo response.""" + # userinfo lacking "test" attribute should fail. + userinfo = { + "sub": "tester", + "username": "tester", + } + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + self.complete_sso_login.assert_not_called() + + @override_config( + { + "oidc_config": { + **DEFAULT_CONFIG, + "attribute_requirements": [{"attribute": "test"}], + } + } + ) + def test_attribute_requirements_exist(self) -> None: + """OIDC login succeeds if the required attribute exist (regardless of value) + in the OIDC userinfo response. + """ + # userinfo with "test" attribute and random value should succeed. + userinfo = { + "sub": "tester", + "username": "tester", + "test": random_string(5), # value does not matter + } + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + + # check that the auth handler got called as expected + self.complete_sso_login.assert_called_once_with( + "@tester:test", + self.provider.idp_id, + request, + ANY, + None, + new_user=True, + auth_provider_session_id=None, + ) + + @override_config( + { + "oidc_config": { + **DEFAULT_CONFIG, "attribute_requirements": [{"attribute": "test", "value": "foobar"}], } } diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index ed203eb299..d0351a8509 100644 --- a/tests/handlers/test_password_providers.py +++ b/tests/handlers/test_password_providers.py
@@ -768,17 +768,6 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): # Check that the callback has been called. m.assert_called_once() - # Set some email configuration so the test doesn't fail because of its absence. - @override_config({"email": {"notif_from": "noreply@test"}}) - def test_3pid_allowed(self) -> None: - """Tests that an is_3pid_allowed_callbacks forbidding a 3PID makes Synapse refuse - to bind the new 3PID, and that one allowing a 3PID makes Synapse accept to bind - the 3PID. Also checks that the module is passed a boolean indicating whether the - user to bind this 3PID to is currently registering. - """ - self._test_3pid_allowed("rin", False) - self._test_3pid_allowed("kitay", True) - def test_displayname(self) -> None: """Tests that the get_displayname_for_registration callback can define the display name of a user when registering. @@ -829,66 +818,6 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): # Check that the callback has been called. m.assert_called_once() - def _test_3pid_allowed(self, username: str, registration: bool) -> None: - """Tests that the "is_3pid_allowed" module callback is called correctly, using - either /register or /account URLs depending on the arguments. - - Args: - username: The username to use for the test. - registration: Whether to test with registration URLs. - """ - self.hs.get_identity_handler().send_threepid_validation = AsyncMock( # type: ignore[method-assign] - return_value=0 - ) - - m = AsyncMock(return_value=False) - self.hs.get_password_auth_provider().is_3pid_allowed_callbacks = [m] - - self.register_user(username, "password") - tok = self.login(username, "password") - - if registration: - url = "/register/email/requestToken" - else: - url = "/account/3pid/email/requestToken" - - channel = self.make_request( - "POST", - url, - { - "client_secret": "foo", - "email": "foo@test.com", - "send_attempt": 0, - }, - access_token=tok, - ) - self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result) - self.assertEqual( - channel.json_body["errcode"], - Codes.THREEPID_DENIED, - channel.json_body, - ) - - m.assert_called_once_with("email", "foo@test.com", registration) - - m = AsyncMock(return_value=True) - self.hs.get_password_auth_provider().is_3pid_allowed_callbacks = [m] - - channel = self.make_request( - "POST", - url, - { - "client_secret": "foo", - "email": "bar@test.com", - "send_attempt": 0, - }, - access_token=tok, - ) - self.assertEqual(channel.code, HTTPStatus.OK, channel.result) - self.assertIn("sid", channel.json_body) - - m.assert_called_once_with("email", "bar@test.com", registration) - def _setup_get_name_for_registration(self, callback_name: str) -> Mock: """Registers either a get_username_for_registration callback or a get_displayname_for_registration callback that appends "-foo" to the username the diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index cc630d606c..6b7bf112c2 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py
@@ -23,14 +23,21 @@ from typing import Optional, cast from unittest.mock import Mock, call from parameterized import parameterized -from signedjson.key import generate_signing_key +from signedjson.key import ( + encode_verify_key_base64, + generate_signing_key, + get_verify_key, +) from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import EventTypes, Membership, PresenceState from synapse.api.presence import UserDevicePresenceState, UserPresenceState -from synapse.api.room_versions import KNOWN_ROOM_VERSIONS -from synapse.events.builder import EventBuilder +from synapse.api.room_versions import ( + RoomVersion, +) +from synapse.crypto.event_signing import add_hashes_and_signatures +from synapse.events import EventBase, make_event_from_dict from synapse.federation.sender import FederationSender from synapse.handlers.presence import ( BUSY_ONLINE_TIMEOUT, @@ -45,18 +52,24 @@ from synapse.handlers.presence import ( handle_update, ) from synapse.rest import admin -from synapse.rest.client import room +from synapse.rest.client import login, room, sync from synapse.server import HomeServer from synapse.storage.database import LoggingDatabaseConnection +from synapse.storage.keys import FetchKeyResult from synapse.types import JsonDict, UserID, get_domain_from_id from synapse.util import Clock from tests import unittest from tests.replication._base import BaseMultiWorkerStreamTestCase +from tests.unittest import override_config class PresenceUpdateTestCase(unittest.HomeserverTestCase): - servlets = [admin.register_servlets] + servlets = [ + admin.register_servlets, + login.register_servlets, + sync.register_servlets, + ] def prepare( self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer @@ -425,6 +438,102 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): wheel_timer.insert.assert_not_called() + # `rc_presence` is set very high during unit tests to avoid ratelimiting + # subtly impacting unrelated tests. We set the ratelimiting back to a + # reasonable value for the tests specific to presence ratelimiting. + @override_config( + {"rc_presence": {"per_user": {"per_second": 0.1, "burst_count": 1}}} + ) + def test_over_ratelimit_offline_to_online_to_unavailable(self) -> None: + """ + Send a presence update, check that it went through, immediately send another one and + check that it was ignored. + """ + self._test_ratelimit_offline_to_online_to_unavailable(ratelimited=True) + + @override_config( + {"rc_presence": {"per_user": {"per_second": 0.1, "burst_count": 1}}} + ) + def test_within_ratelimit_offline_to_online_to_unavailable(self) -> None: + """ + Send a presence update, check that it went through, advancing time a sufficient amount, + send another presence update and check that it also worked. + """ + self._test_ratelimit_offline_to_online_to_unavailable(ratelimited=False) + + @override_config( + {"rc_presence": {"per_user": {"per_second": 0.1, "burst_count": 1}}} + ) + def _test_ratelimit_offline_to_online_to_unavailable( + self, ratelimited: bool + ) -> None: + """Test rate limit for presence updates sent with sync requests. + + Args: + ratelimited: Test rate limited case. + """ + wheel_timer = Mock() + user_id = "@user:pass" + now = 5000000 + sync_url = "/sync?access_token=%s&set_presence=%s" + + # Register the user who syncs presence + user_id = self.register_user("user", "pass") + access_token = self.login("user", "pass") + + # Get the handler (which kicks off a bunch of timers). + presence_handler = self.hs.get_presence_handler() + + # Ensure the user is initially offline. + prev_state = UserPresenceState.default(user_id) + new_state = prev_state.copy_and_replace( + state=PresenceState.OFFLINE, last_active_ts=now + ) + + state, persist_and_notify, federation_ping = handle_update( + prev_state, + new_state, + is_mine=True, + wheel_timer=wheel_timer, + now=now, + persist=False, + ) + + # Check that the user is offline. + state = self.get_success( + presence_handler.get_state(UserID.from_string(user_id)) + ) + self.assertEqual(state.state, PresenceState.OFFLINE) + + # Send sync request with set_presence=online. + channel = self.make_request("GET", sync_url % (access_token, "online")) + self.assertEqual(200, channel.code) + + # Assert the user is now online. + state = self.get_success( + presence_handler.get_state(UserID.from_string(user_id)) + ) + self.assertEqual(state.state, PresenceState.ONLINE) + + if not ratelimited: + # Advance time a sufficient amount to avoid rate limiting. + self.reactor.advance(30) + + # Send another sync request with set_presence=unavailable. + channel = self.make_request("GET", sync_url % (access_token, "unavailable")) + self.assertEqual(200, channel.code) + + state = self.get_success( + presence_handler.get_state(UserID.from_string(user_id)) + ) + + if ratelimited: + # Assert the user is still online and presence update was ignored. + self.assertEqual(state.state, PresenceState.ONLINE) + else: + # Assert the user is now unavailable. + self.assertEqual(state.state, PresenceState.UNAVAILABLE) + class PresenceTimeoutTestCase(unittest.TestCase): """Tests different timers and that the timer does not change `status_msg` of user.""" @@ -1107,7 +1216,9 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase): ), ] ], - name_func=lambda testcase_func, param_num, params: f"{testcase_func.__name__}_{param_num}_{'workers' if params.args[5] else 'monolith'}", + name_func=lambda testcase_func, + param_num, + params: f"{testcase_func.__name__}_{param_num}_{'workers' if params.args[5] else 'monolith'}", ) @unittest.override_config({"experimental_features": {"msc3026_enabled": True}}) def test_set_presence_from_syncing_multi_device( @@ -1343,7 +1454,9 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase): ), ] ], - name_func=lambda testcase_func, param_num, params: f"{testcase_func.__name__}_{param_num}_{'workers' if params.args[4] else 'monolith'}", + name_func=lambda testcase_func, + param_num, + params: f"{testcase_func.__name__}_{param_num}_{'workers' if params.args[4] else 'monolith'}", ) @unittest.override_config({"experimental_features": {"msc3026_enabled": True}}) def test_set_presence_from_non_syncing_multi_device( @@ -1821,6 +1934,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): # self.event_builder_for_2.hostname = "test2" self.store = hs.get_datastores().main + self.storage_controllers = hs.get_storage_controllers() self.state = hs.get_state_handler() self._event_auth_handler = hs.get_event_auth_handler() @@ -1936,29 +2050,35 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): hostname = get_domain_from_id(user_id) - room_version = self.get_success(self.store.get_room_version_id(room_id)) + room_version = self.get_success(self.store.get_room_version(room_id)) - builder = EventBuilder( - state=self.state, - event_auth_handler=self._event_auth_handler, - store=self.store, - clock=self.clock, - hostname=hostname, - signing_key=self.random_signing_key, - room_version=KNOWN_ROOM_VERSIONS[room_version], - room_id=room_id, - type=EventTypes.Member, - sender=user_id, - state_key=user_id, - content={"membership": Membership.JOIN}, + state_map = self.get_success( + self.storage_controllers.state.get_current_state(room_id) ) - prev_event_ids = self.get_success( - self.store.get_latest_event_ids_in_room(room_id) + # Figure out what the forward extremities in the room are (the most recent + # events that aren't tied into the DAG) + forward_extremity_event_ids = self.get_success( + self.hs.get_datastores().main.get_latest_event_ids_in_room(room_id) ) - event = self.get_success( - builder.build(prev_event_ids=list(prev_event_ids), auth_event_ids=None) + event = self.create_fake_event_from_remote_server( + remote_server_name=hostname, + event_dict={ + "room_id": room_id, + "sender": user_id, + "type": EventTypes.Member, + "state_key": user_id, + "depth": 1000, + "origin_server_ts": 1, + "content": {"membership": Membership.JOIN}, + "auth_events": [ + state_map[(EventTypes.Create, "")].event_id, + state_map[(EventTypes.JoinRules, "")].event_id, + ], + "prev_events": list(forward_extremity_event_ids), + }, + room_version=room_version, ) self.get_success(self.federation_event_handler.on_receive_pdu(hostname, event)) @@ -1966,3 +2086,50 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): # Check that it was successfully persisted. self.get_success(self.store.get_event(event.event_id)) self.get_success(self.store.get_event(event.event_id)) + + def create_fake_event_from_remote_server( + self, remote_server_name: str, event_dict: JsonDict, room_version: RoomVersion + ) -> EventBase: + """ + This is similar to what `FederatingHomeserverTestCase` is doing but we don't + need all of the extra baggage and we want to be able to create an event from + many remote servers. + """ + + # poke the other server's signing key into the key store, so that we don't + # make requests for it + other_server_signature_key = generate_signing_key("test") + verify_key = get_verify_key(other_server_signature_key) + verify_key_id = "%s:%s" % (verify_key.alg, verify_key.version) + + self.get_success( + self.hs.get_datastores().main.store_server_keys_response( + remote_server_name, + from_server=remote_server_name, + ts_added_ms=self.clock.time_msec(), + verify_keys={ + verify_key_id: FetchKeyResult( + verify_key=verify_key, + valid_until_ts=self.clock.time_msec() + 10000, + ), + }, + response_json={ + "verify_keys": { + verify_key_id: {"key": encode_verify_key_base64(verify_key)} + } + }, + ) + ) + + add_hashes_and_signatures( + room_version=room_version, + event_dict=event_dict, + signature_name=remote_server_name, + signing_key=other_server_signature_key, + ) + event = make_event_from_dict( + event_dict, + room_version=room_version, + ) + + return event diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index cb1c6fbb80..2b9b56da95 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py
@@ -369,6 +369,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): time_now_ms=self.clock.time_msec(), upload_name=None, filesystem_id="xyz", + sha256="abcdefg12345", ) ) diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index 92487692db..99bd0de834 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py
@@ -588,6 +588,29 @@ class RegistrationTestCase(unittest.HomeserverTestCase): d = self.store.is_support_user(user_id) self.assertFalse(self.get_success(d)) + def test_underscore_localpart_rejected_by_default(self) -> None: + for invalid_user_id in ("_", "_prefixed"): + with self.subTest(invalid_user_id=invalid_user_id): + self.get_failure( + self.handler.register_user(localpart=invalid_user_id), + SynapseError, + ) + + @override_config( + { + "allow_underscore_prefixed_localpart": True, + } + ) + def test_underscore_localpart_allowed_if_configured(self) -> None: + for valid_user_id in ("_", "_prefixed"): + with self.subTest(valid_user_id=valid_user_id): + user_id = self.get_success( + self.handler.register_user( + localpart=valid_user_id, + ), + ) + self.assertEqual(user_id, f"@{valid_user_id}:test") + def test_invalid_user_id(self) -> None: invalid_user_id = "^abcd" self.get_failure( @@ -715,6 +738,41 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.handler.register_user(localpart="bobflimflob", auth_provider_id="saml") ) + def test_register_default_user_type(self) -> None: + """Test that the default user type is none when registering a user.""" + user_id = self.get_success(self.handler.register_user(localpart="user")) + user_info = self.get_success(self.store.get_user_by_id(user_id)) + assert user_info is not None + self.assertEqual(user_info.user_type, None) + + def test_register_extra_user_types_valid(self) -> None: + """ + Test that the specified user type is set correctly when registering a user. + n.b. No validation is done on the user type, so this test + is only to ensure that the user type can be set to any value. + """ + user_id = self.get_success( + self.handler.register_user(localpart="user", user_type="anyvalue") + ) + user_info = self.get_success(self.store.get_user_by_id(user_id)) + assert user_info is not None + self.assertEqual(user_info.user_type, "anyvalue") + + @override_config( + { + "user_types": { + "extra_user_types": ["extra1", "extra2"], + "default_user_type": "extra1", + } + } + ) + def test_register_extra_user_types_with_default(self) -> None: + """Test that the default_user_type in config is set correctly when registering a user.""" + user_id = self.get_success(self.handler.register_user(localpart="user")) + user_info = self.get_success(self.store.get_user_by_id(user_id)) + assert user_info is not None + self.assertEqual(user_info.user_type, "extra1") + async def get_or_create_user( self, requester: Requester, diff --git a/tests/handlers/test_room_list.py b/tests/handlers/test_room_list.py
index 4d22ef98c2..45cef09b22 100644 --- a/tests/handlers/test_room_list.py +++ b/tests/handlers/test_room_list.py
@@ -6,6 +6,7 @@ from synapse.rest.client import directory, login, room from synapse.types import JsonDict from tests import unittest +from tests.utils import default_config class RoomListHandlerTestCase(unittest.HomeserverTestCase): @@ -30,6 +31,11 @@ class RoomListHandlerTestCase(unittest.HomeserverTestCase): assert channel.code == HTTPStatus.OK, f"couldn't publish room: {channel.result}" return room_id + def default_config(self) -> JsonDict: + config = default_config("test") + config["room_list_publication_rules"] = [{"action": "allow"}] + return config + def test_acls_applied_to_room_directory_results(self) -> None: """ Creates 3 rooms. Room 2 has an ACL that only permits the homeservers diff --git a/tests/handlers/test_room_member.py b/tests/handlers/test_room_member.py
index 213a66ed1a..d87fe9d62c 100644 --- a/tests/handlers/test_room_member.py +++ b/tests/handlers/test_room_member.py
@@ -5,10 +5,13 @@ from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin import synapse.rest.client.login import synapse.rest.client.room -from synapse.api.constants import EventTypes, Membership -from synapse.api.errors import LimitExceededError, SynapseError +from synapse.api.constants import AccountDataTypes, EventTypes, Membership +from synapse.api.errors import Codes, LimitExceededError, SynapseError from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.events import FrozenEventV3 +from synapse.federation.federation_base import ( + event_from_pdu_json, +) from synapse.federation.federation_client import SendJoinResult from synapse.server import HomeServer from synapse.types import UserID, create_requester @@ -172,20 +175,25 @@ class TestJoinsLimitedByPerRoomRateLimiter(FederatingHomeserverTestCase): ) ) - with patch.object( - self.handler.federation_handler.federation_client, - "make_membership_event", - mock_make_membership_event, - ), patch.object( - self.handler.federation_handler.federation_client, - "send_join", - mock_send_join, - ), patch( - "synapse.event_auth._is_membership_change_allowed", - return_value=None, - ), patch( - "synapse.handlers.federation_event.check_state_dependent_auth_rules", - return_value=None, + with ( + patch.object( + self.handler.federation_handler.federation_client, + "make_membership_event", + mock_make_membership_event, + ), + patch.object( + self.handler.federation_handler.federation_client, + "send_join", + mock_send_join, + ), + patch( + "synapse.event_auth._is_membership_change_allowed", + return_value=None, + ), + patch( + "synapse.handlers.federation_event.check_state_dependent_auth_rules", + return_value=None, + ), ): self.get_success( self.handler.update_membership( @@ -380,9 +388,29 @@ class RoomMemberMasterHandlerTestCase(HomeserverTestCase): ) def test_forget_when_not_left(self) -> None: - """Tests that a user cannot not forgets a room that has not left.""" + """Tests that a user cannot forget a room that they are still in.""" self.get_failure(self.handler.forget(self.alice_ID, self.room_id), SynapseError) + def test_nonlocal_room_user_action(self) -> None: + """ + Test that non-local user ids cannot perform room actions through + this homeserver. + """ + alien_user_id = UserID.from_string("@cheeky_monkey:matrix.org") + bad_room_id = f"{self.room_id}+BAD_ID" + + exc = self.get_failure( + self.handler.update_membership( + create_requester(self.alice), + alien_user_id, + bad_room_id, + "unban", + ), + SynapseError, + ).value + + self.assertEqual(exc.errcode, Codes.BAD_JSON) + def test_rejoin_forgotten_by_user(self) -> None: """Test that a user that has forgotten a room can do a re-join. The room was not forgotten from the local server. @@ -428,3 +456,165 @@ class RoomMemberMasterHandlerTestCase(HomeserverTestCase): new_count = rows[0][0] self.assertEqual(initial_count, new_count) + + +class TestInviteFiltering(FederatingHomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + synapse.rest.client.login.register_servlets, + synapse.rest.client.room.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.handler = hs.get_room_member_handler() + self.fed_handler = hs.get_federation_handler() + self.store = hs.get_datastores().main + + # Create three users. + self.alice = self.register_user("alice", "pass") + self.alice_token = self.login("alice", "pass") + self.bob = self.register_user("bob", "pass") + self.bob_token = self.login("bob", "pass") + + @override_config({"experimental_features": {"msc4155_enabled": True}}) + def test_misc4155_block_invite_local(self) -> None: + """Test that MSC4155 will block a user from being invited to a room""" + room_id = self.helper.create_room_as(self.alice, tok=self.alice_token) + + self.get_success( + self.store.add_account_data_for_user( + self.bob, + AccountDataTypes.MSC4155_INVITE_PERMISSION_CONFIG, + { + "blocked_users": [self.alice], + }, + ) + ) + + f = self.get_failure( + self.handler.update_membership( + requester=create_requester(self.alice), + target=UserID.from_string(self.bob), + room_id=room_id, + action=Membership.INVITE, + ), + SynapseError, + ).value + self.assertEqual(f.code, 403) + self.assertEqual(f.errcode, "ORG.MATRIX.MSC4155.M_INVITE_BLOCKED") + + @override_config({"experimental_features": {"msc4155_enabled": False}}) + def test_msc4155_disabled_allow_invite_local(self) -> None: + """Test that MSC4155 will block a user from being invited to a room""" + room_id = self.helper.create_room_as(self.alice, tok=self.alice_token) + + self.get_success( + self.store.add_account_data_for_user( + self.bob, + AccountDataTypes.MSC4155_INVITE_PERMISSION_CONFIG, + { + "blocked_users": [self.alice], + }, + ) + ) + + self.get_success( + self.handler.update_membership( + requester=create_requester(self.alice), + target=UserID.from_string(self.bob), + room_id=room_id, + action=Membership.INVITE, + ), + ) + + @override_config({"experimental_features": {"msc4155_enabled": True}}) + def test_msc4155_block_invite_remote(self) -> None: + """Test that MSC4155 will block a remote user from being invited to a room""" + # A remote user who sends the invite + remote_server = "otherserver" + remote_user = "@otheruser:" + remote_server + + self.get_success( + self.store.add_account_data_for_user( + self.bob, + AccountDataTypes.MSC4155_INVITE_PERMISSION_CONFIG, + {"blocked_users": [remote_user]}, + ) + ) + + room_id = self.helper.create_room_as( + room_creator=self.alice, tok=self.alice_token + ) + room_version = self.get_success(self.store.get_room_version(room_id)) + + invite_event = event_from_pdu_json( + { + "type": EventTypes.Member, + "content": {"membership": "invite"}, + "room_id": room_id, + "sender": remote_user, + "state_key": self.bob, + "depth": 32, + "prev_events": [], + "auth_events": [], + "origin_server_ts": self.clock.time_msec(), + }, + room_version, + ) + + f = self.get_failure( + self.fed_handler.on_invite_request( + remote_server, + invite_event, + invite_event.room_version, + ), + SynapseError, + ).value + self.assertEqual(f.code, 403) + self.assertEqual(f.errcode, "ORG.MATRIX.MSC4155.M_INVITE_BLOCKED") + + @override_config({"experimental_features": {"msc4155_enabled": True}}) + def test_msc4155_block_invite_remote_server(self) -> None: + """Test that MSC4155 will block a remote server's user from being invited to a room""" + # A remote user who sends the invite + remote_server = "otherserver" + remote_user = "@otheruser:" + remote_server + + self.get_success( + self.store.add_account_data_for_user( + self.bob, + AccountDataTypes.MSC4155_INVITE_PERMISSION_CONFIG, + {"blocked_servers": [remote_server]}, + ) + ) + + room_id = self.helper.create_room_as( + room_creator=self.alice, tok=self.alice_token + ) + room_version = self.get_success(self.store.get_room_version(room_id)) + + invite_event = event_from_pdu_json( + { + "type": EventTypes.Member, + "content": {"membership": "invite"}, + "room_id": room_id, + "sender": remote_user, + "state_key": self.bob, + "depth": 32, + "prev_events": [], + "auth_events": [], + "origin_server_ts": self.clock.time_msec(), + }, + room_version, + ) + + f = self.get_failure( + self.fed_handler.on_invite_request( + remote_server, + invite_event, + invite_event.room_version, + ), + SynapseError, + ).value + self.assertEqual(f.code, 403) + self.assertEqual(f.errcode, "ORG.MATRIX.MSC4155.M_INVITE_BLOCKED") diff --git a/tests/handlers/test_room_policy.py b/tests/handlers/test_room_policy.py new file mode 100644
index 0000000000..26642c18ea --- /dev/null +++ b/tests/handlers/test_room_policy.py
@@ -0,0 +1,226 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2025 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# <https://www.gnu.org/licenses/agpl-3.0.html>. +# +# +from typing import Optional +from unittest import mock + +from twisted.test.proto_helpers import MemoryReactor + +from synapse.events import EventBase, make_event_from_dict +from synapse.rest import admin +from synapse.rest.client import login, room +from synapse.server import HomeServer +from synapse.types import JsonDict, UserID +from synapse.types.handlers.policy_server import RECOMMENDATION_OK, RECOMMENDATION_SPAM +from synapse.util import Clock + +from tests import unittest +from tests.test_utils import event_injection + + +class RoomPolicyTestCase(unittest.FederatingHomeserverTestCase): + """Tests room policy handler.""" + + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + # mock out the federation transport client + self.mock_federation_transport_client = mock.Mock( + spec=["get_policy_recommendation_for_pdu"] + ) + self.mock_federation_transport_client.get_policy_recommendation_for_pdu = ( + mock.AsyncMock() + ) + return super().setup_test_homeserver( + federation_transport_client=self.mock_federation_transport_client + ) + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.hs = hs + self.handler = hs.get_room_policy_handler() + main_store = self.hs.get_datastores().main + + # Create a room + self.creator = self.register_user("creator", "test1234") + self.creator_token = self.login("creator", "test1234") + self.room_id = self.helper.create_room_as( + room_creator=self.creator, tok=self.creator_token + ) + room_version = self.get_success(main_store.get_room_version(self.room_id)) + + # Create some sample events + self.spammy_event = make_event_from_dict( + room_version=room_version, + internal_metadata_dict={}, + event_dict={ + "room_id": self.room_id, + "type": "m.room.message", + "sender": "@spammy:example.org", + "content": { + "msgtype": "m.text", + "body": "This is a spammy event.", + }, + }, + ) + self.not_spammy_event = make_event_from_dict( + room_version=room_version, + internal_metadata_dict={}, + event_dict={ + "room_id": self.room_id, + "type": "m.room.message", + "sender": "@not_spammy:example.org", + "content": { + "msgtype": "m.text", + "body": "This is a NOT spammy event.", + }, + }, + ) + + # Prepare the policy server mock to decide spam vs not spam on those events + self.call_count = 0 + + async def get_policy_recommendation_for_pdu( + destination: str, + pdu: EventBase, + timeout: Optional[int] = None, + ) -> JsonDict: + self.call_count += 1 + self.assertEqual(destination, self.OTHER_SERVER_NAME) + if pdu.event_id == self.spammy_event.event_id: + return {"recommendation": RECOMMENDATION_SPAM} + elif pdu.event_id == self.not_spammy_event.event_id: + return {"recommendation": RECOMMENDATION_OK} + else: + self.fail("Unexpected event ID") + + self.mock_federation_transport_client.get_policy_recommendation_for_pdu.side_effect = get_policy_recommendation_for_pdu + + def _add_policy_server_to_room(self) -> None: + # Inject a member event into the room + policy_user_id = f"@policy:{self.OTHER_SERVER_NAME}" + self.get_success( + event_injection.inject_member_event( + self.hs, self.room_id, policy_user_id, "join" + ) + ) + self.helper.send_state( + self.room_id, + "org.matrix.msc4284.policy", + { + "via": self.OTHER_SERVER_NAME, + }, + tok=self.creator_token, + state_key="", + ) + + def test_no_policy_event_set(self) -> None: + # We don't need to modify the room state at all - we're testing the default + # case where a room doesn't use a policy server. + ok = self.get_success(self.handler.is_event_allowed(self.spammy_event)) + self.assertEqual(ok, True) + self.assertEqual(self.call_count, 0) + + def test_empty_policy_event_set(self) -> None: + self.helper.send_state( + self.room_id, + "org.matrix.msc4284.policy", + { + # empty content (no `via`) + }, + tok=self.creator_token, + state_key="", + ) + + ok = self.get_success(self.handler.is_event_allowed(self.spammy_event)) + self.assertEqual(ok, True) + self.assertEqual(self.call_count, 0) + + def test_nonstring_policy_event_set(self) -> None: + self.helper.send_state( + self.room_id, + "org.matrix.msc4284.policy", + { + "via": 42, # should be a server name + }, + tok=self.creator_token, + state_key="", + ) + + ok = self.get_success(self.handler.is_event_allowed(self.spammy_event)) + self.assertEqual(ok, True) + self.assertEqual(self.call_count, 0) + + def test_self_policy_event_set(self) -> None: + self.helper.send_state( + self.room_id, + "org.matrix.msc4284.policy", + { + # We ignore events when the policy server is ourselves (for now?) + "via": (UserID.from_string(self.creator)).domain, + }, + tok=self.creator_token, + state_key="", + ) + + ok = self.get_success(self.handler.is_event_allowed(self.spammy_event)) + self.assertEqual(ok, True) + self.assertEqual(self.call_count, 0) + + def test_invalid_server_policy_event_set(self) -> None: + self.helper.send_state( + self.room_id, + "org.matrix.msc4284.policy", + { + "via": "|this| is *not* a (valid) server name.com", + }, + tok=self.creator_token, + state_key="", + ) + + ok = self.get_success(self.handler.is_event_allowed(self.spammy_event)) + self.assertEqual(ok, True) + self.assertEqual(self.call_count, 0) + + def test_not_in_room_policy_event_set(self) -> None: + self.helper.send_state( + self.room_id, + "org.matrix.msc4284.policy", + { + "via": f"x.{self.OTHER_SERVER_NAME}", + }, + tok=self.creator_token, + state_key="", + ) + + ok = self.get_success(self.handler.is_event_allowed(self.spammy_event)) + self.assertEqual(ok, True) + self.assertEqual(self.call_count, 0) + + def test_spammy_event_is_spam(self) -> None: + self._add_policy_server_to_room() + + ok = self.get_success(self.handler.is_event_allowed(self.spammy_event)) + self.assertEqual(ok, False) + self.assertEqual(self.call_count, 1) + + def test_not_spammy_event_is_not_spam(self) -> None: + self._add_policy_server_to_room() + + ok = self.get_success(self.handler.is_event_allowed(self.not_spammy_event)) + self.assertEqual(ok, True) + self.assertEqual(self.call_count, 1) diff --git a/tests/handlers/test_room_summary.py b/tests/handlers/test_room_summary.py
index 244a4e7689..b55fa1a8fd 100644 --- a/tests/handlers/test_room_summary.py +++ b/tests/handlers/test_room_summary.py
@@ -757,6 +757,54 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): ) self._assert_hierarchy(result, expected) + def test_fed_root(self) -> None: + """ + Test if requested room is available over federation. + """ + fed_hostname = self.hs.hostname + "2" + fed_space = "#fed_space:" + fed_hostname + fed_subroom = "#fed_sub_room:" + fed_hostname + + requested_room_entry = _RoomEntry( + fed_space, + { + "room_id": fed_space, + "world_readable": True, + "room_type": RoomTypes.SPACE, + }, + [ + { + "type": EventTypes.SpaceChild, + "room_id": fed_space, + "state_key": fed_subroom, + "content": {"via": [fed_hostname]}, + } + ], + ) + child_room = { + "room_id": fed_subroom, + "world_readable": True, + } + + async def summarize_remote_room_hierarchy( + _self: Any, room: Any, suggested_only: bool + ) -> Tuple[Optional[_RoomEntry], Dict[str, JsonDict], Set[str]]: + return requested_room_entry, {fed_subroom: child_room}, set() + + expected = [ + (fed_space, [fed_subroom]), + (fed_subroom, ()), + ] + + with mock.patch( + "synapse.handlers.room_summary.RoomSummaryHandler._summarize_remote_room_hierarchy", + new=summarize_remote_room_hierarchy, + ): + result = self.get_success( + self.handler.get_room_hierarchy(create_requester(self.user), fed_space) + ) + self._assert_hierarchy(result, expected) + def test_fed_filtering(self) -> None: """ Rooms returned over federation should be properly filtered to only include diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py deleted file mode 100644
index 6ab8fda6e7..0000000000 --- a/tests/handlers/test_saml.py +++ /dev/null
@@ -1,381 +0,0 @@ -# -# This file is licensed under the Affero General Public License (AGPL) version 3. -# -# Copyright 2020 The Matrix.org Foundation C.I.C. -# Copyright (C) 2023 New Vector, Ltd -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# See the GNU Affero General Public License for more details: -# <https://www.gnu.org/licenses/agpl-3.0.html>. -# -# Originally licensed under the Apache License, Version 2.0: -# <http://www.apache.org/licenses/LICENSE-2.0>. -# -# [This file includes modifications made by New Vector Limited] -# -# - -from typing import Any, Dict, Optional, Set, Tuple -from unittest.mock import AsyncMock, Mock - -import attr - -from twisted.test.proto_helpers import MemoryReactor - -from synapse.api.errors import RedirectException -from synapse.module_api import ModuleApi -from synapse.server import HomeServer -from synapse.types import JsonDict -from synapse.util import Clock - -from tests.unittest import HomeserverTestCase, override_config - -# Check if we have the dependencies to run the tests. -try: - import saml2.config - import saml2.response - from saml2.sigver import SigverError - - has_saml2 = True - - # pysaml2 can be installed and imported, but might not be able to find xmlsec1. - config = saml2.config.SPConfig() - try: - config.load({"metadata": {}}) - has_xmlsec1 = True - except SigverError: - has_xmlsec1 = False -except ImportError: - has_saml2 = False - has_xmlsec1 = False - -# These are a few constants that are used as config parameters in the tests. -BASE_URL = "https://synapse/" - - -@attr.s -class FakeAuthnResponse: - ava = attr.ib(type=dict) - assertions = attr.ib(type=list, factory=list) - in_response_to = attr.ib(type=Optional[str], default=None) - - -class TestMappingProvider: - def __init__(self, config: None, module: ModuleApi): - pass - - @staticmethod - def parse_config(config: JsonDict) -> None: - return None - - @staticmethod - def get_saml_attributes(config: None) -> Tuple[Set[str], Set[str]]: - return {"uid"}, {"displayName"} - - def get_remote_user_id( - self, saml_response: "saml2.response.AuthnResponse", client_redirect_url: str - ) -> str: - return saml_response.ava["uid"] - - def saml_response_to_user_attributes( - self, - saml_response: "saml2.response.AuthnResponse", - failures: int, - client_redirect_url: str, - ) -> dict: - localpart = saml_response.ava["username"] + (str(failures) if failures else "") - return {"mxid_localpart": localpart, "displayname": None} - - -class TestRedirectMappingProvider(TestMappingProvider): - def saml_response_to_user_attributes( - self, - saml_response: "saml2.response.AuthnResponse", - failures: int, - client_redirect_url: str, - ) -> dict: - raise RedirectException(b"https://custom-saml-redirect/") - - -class SamlHandlerTestCase(HomeserverTestCase): - def default_config(self) -> Dict[str, Any]: - config = super().default_config() - config["public_baseurl"] = BASE_URL - saml_config: Dict[str, Any] = { - "sp_config": {"metadata": {}}, - # Disable grandfathering. - "grandfathered_mxid_source_attribute": None, - "user_mapping_provider": {"module": __name__ + ".TestMappingProvider"}, - } - - # Update this config with what's in the default config so that - # override_config works as expected. - saml_config.update(config.get("saml2_config", {})) - config["saml2_config"] = saml_config - - return config - - def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - hs = self.setup_test_homeserver() - - self.handler = hs.get_saml_handler() - - # Reduce the number of attempts when generating MXIDs. - sso_handler = hs.get_sso_handler() - sso_handler._MAP_USERNAME_RETRIES = 3 - - return hs - - if not has_saml2: - skip = "Requires pysaml2" - elif not has_xmlsec1: - skip = "Requires xmlsec1" - - def test_map_saml_response_to_user(self) -> None: - """Ensure that mapping the SAML response returned from a provider to an MXID works properly.""" - - # stub out the auth handler - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign] - - # send a mocked-up SAML response to the callback - saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"}) - request = _mock_request() - self.get_success( - self.handler._handle_authn_response(request, saml_response, "redirect_uri") - ) - - # check that the auth handler got called as expected - auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", - "saml", - request, - "redirect_uri", - None, - new_user=True, - auth_provider_session_id=None, - ) - - @override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}}) - def test_map_saml_response_to_existing_user(self) -> None: - """Existing users can log in with SAML account.""" - store = self.hs.get_datastores().main - self.get_success( - store.register_user(user_id="@test_user:test", password_hash=None) - ) - - # stub out the auth handler - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign] - - # Map a user via SSO. - saml_response = FakeAuthnResponse( - {"uid": "tester", "mxid": ["test_user"], "username": "test_user"} - ) - request = _mock_request() - self.get_success( - self.handler._handle_authn_response(request, saml_response, "") - ) - - # check that the auth handler got called as expected - auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", - "saml", - request, - "", - None, - new_user=False, - auth_provider_session_id=None, - ) - - # Subsequent calls should map to the same mxid. - auth_handler.complete_sso_login.reset_mock() - self.get_success( - self.handler._handle_authn_response(request, saml_response, "") - ) - auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", - "saml", - request, - "", - None, - new_user=False, - auth_provider_session_id=None, - ) - - def test_map_saml_response_to_invalid_localpart(self) -> None: - """If the mapping provider generates an invalid localpart it should be rejected.""" - - # stub out the auth handler - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign] - - # mock out the error renderer too - sso_handler = self.hs.get_sso_handler() - sso_handler.render_error = Mock(return_value=None) # type: ignore[method-assign] - - saml_response = FakeAuthnResponse({"uid": "test", "username": "föö"}) - request = _mock_request() - self.get_success( - self.handler._handle_authn_response(request, saml_response, ""), - ) - sso_handler.render_error.assert_called_once_with( - request, "mapping_error", "localpart is invalid: föö" - ) - auth_handler.complete_sso_login.assert_not_called() - - def test_map_saml_response_to_user_retries(self) -> None: - """The mapping provider can retry generating an MXID if the MXID is already in use.""" - - # stub out the auth handler and error renderer - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign] - sso_handler = self.hs.get_sso_handler() - sso_handler.render_error = Mock(return_value=None) # type: ignore[method-assign] - - # register a user to occupy the first-choice MXID - store = self.hs.get_datastores().main - self.get_success( - store.register_user(user_id="@test_user:test", password_hash=None) - ) - - # send the fake SAML response - saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"}) - request = _mock_request() - self.get_success( - self.handler._handle_authn_response(request, saml_response, ""), - ) - - # test_user is already taken, so test_user1 gets registered instead. - auth_handler.complete_sso_login.assert_called_once_with( - "@test_user1:test", - "saml", - request, - "", - None, - new_user=True, - auth_provider_session_id=None, - ) - auth_handler.complete_sso_login.reset_mock() - - # Register all of the potential mxids for a particular SAML username. - self.get_success( - store.register_user(user_id="@tester:test", password_hash=None) - ) - for i in range(1, 3): - self.get_success( - store.register_user(user_id="@tester%d:test" % i, password_hash=None) - ) - - # Now attempt to map to a username, this will fail since all potential usernames are taken. - saml_response = FakeAuthnResponse({"uid": "tester", "username": "tester"}) - self.get_success( - self.handler._handle_authn_response(request, saml_response, ""), - ) - sso_handler.render_error.assert_called_once_with( - request, - "mapping_error", - "Unable to generate a Matrix ID from the SSO response", - ) - auth_handler.complete_sso_login.assert_not_called() - - @override_config( - { - "saml2_config": { - "user_mapping_provider": { - "module": __name__ + ".TestRedirectMappingProvider" - }, - } - } - ) - def test_map_saml_response_redirect(self) -> None: - """Test a mapping provider that raises a RedirectException""" - - saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"}) - request = _mock_request() - e = self.get_failure( - self.handler._handle_authn_response(request, saml_response, ""), - RedirectException, - ) - self.assertEqual(e.value.location, b"https://custom-saml-redirect/") - - @override_config( - { - "saml2_config": { - "attribute_requirements": [ - {"attribute": "userGroup", "value": "staff"}, - {"attribute": "department", "value": "sales"}, - ], - }, - } - ) - def test_attribute_requirements(self) -> None: - """The required attributes must be met from the SAML response.""" - - # stub out the auth handler - auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign] - - # The response doesn't have the proper userGroup or department. - saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"}) - request = _mock_request() - self.get_success( - self.handler._handle_authn_response(request, saml_response, "redirect_uri") - ) - auth_handler.complete_sso_login.assert_not_called() - - # The response doesn't have the proper department. - saml_response = FakeAuthnResponse( - {"uid": "test_user", "username": "test_user", "userGroup": ["staff"]} - ) - request = _mock_request() - self.get_success( - self.handler._handle_authn_response(request, saml_response, "redirect_uri") - ) - auth_handler.complete_sso_login.assert_not_called() - - # Add the proper attributes and it should succeed. - saml_response = FakeAuthnResponse( - { - "uid": "test_user", - "username": "test_user", - "userGroup": ["staff", "admin"], - "department": ["sales"], - } - ) - request.reset_mock() - self.get_success( - self.handler._handle_authn_response(request, saml_response, "redirect_uri") - ) - - # check that the auth handler got called as expected - auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", - "saml", - request, - "redirect_uri", - None, - new_user=True, - auth_provider_session_id=None, - ) - - -def _mock_request() -> Mock: - """Returns a mock which will stand in as a SynapseRequest""" - mock = Mock( - spec=[ - "finish", - "getClientAddress", - "getHeader", - "setHeader", - "setResponseCode", - "write", - ] - ) - # `_disconnected` musn't be another `Mock`, otherwise it will be truthy. - mock._disconnected = False - return mock diff --git a/tests/handlers/test_send_email.py b/tests/handlers/test_send_email.py deleted file mode 100644
index cedcea27d9..0000000000 --- a/tests/handlers/test_send_email.py +++ /dev/null
@@ -1,230 +0,0 @@ -# -# This file is licensed under the Affero General Public License (AGPL) version 3. -# -# Copyright 2021 The Matrix.org Foundation C.I.C. -# Copyright (C) 2023 New Vector, Ltd -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# See the GNU Affero General Public License for more details: -# <https://www.gnu.org/licenses/agpl-3.0.html>. -# -# Originally licensed under the Apache License, Version 2.0: -# <http://www.apache.org/licenses/LICENSE-2.0>. -# -# [This file includes modifications made by New Vector Limited] -# -# - - -from typing import Callable, List, Tuple, Type, Union -from unittest.mock import patch - -from zope.interface import implementer - -from twisted.internet import defer -from twisted.internet._sslverify import ClientTLSOptions -from twisted.internet.address import IPv4Address, IPv6Address -from twisted.internet.defer import ensureDeferred -from twisted.internet.interfaces import IProtocolFactory -from twisted.internet.ssl import ContextFactory -from twisted.mail import interfaces, smtp - -from tests.server import FakeTransport -from tests.unittest import HomeserverTestCase, override_config - - -def TestingESMTPTLSClientFactory( - contextFactory: ContextFactory, - _connectWrapped: bool, - wrappedProtocol: IProtocolFactory, -) -> IProtocolFactory: - """We use this to pass through in testing without using TLS, but - saving the context information to check that it would have happened. - - Note that this is what the MemoryReactor does on connectSSL. - It only saves the contextFactory, but starts the connection with the - underlying Factory. - See: L{twisted.internet.testing.MemoryReactor.connectSSL}""" - - wrappedProtocol._testingContextFactory = contextFactory # type: ignore[attr-defined] - return wrappedProtocol - - -@implementer(interfaces.IMessageDelivery) -class _DummyMessageDelivery: - def __init__(self) -> None: - # (recipient, message) tuples - self.messages: List[Tuple[smtp.Address, bytes]] = [] - - def receivedHeader( - self, - helo: Tuple[bytes, bytes], - origin: smtp.Address, - recipients: List[smtp.User], - ) -> None: - return None - - def validateFrom( - self, helo: Tuple[bytes, bytes], origin: smtp.Address - ) -> smtp.Address: - return origin - - def record_message(self, recipient: smtp.Address, message: bytes) -> None: - self.messages.append((recipient, message)) - - def validateTo(self, user: smtp.User) -> Callable[[], interfaces.IMessageSMTP]: - return lambda: _DummyMessage(self, user) - - -@implementer(interfaces.IMessageSMTP) -class _DummyMessage: - """IMessageSMTP implementation which saves the message delivered to it - to the _DummyMessageDelivery object. - """ - - def __init__(self, delivery: _DummyMessageDelivery, user: smtp.User): - self._delivery = delivery - self._user = user - self._buffer: List[bytes] = [] - - def lineReceived(self, line: bytes) -> None: - self._buffer.append(line) - - def eomReceived(self) -> "defer.Deferred[bytes]": - message = b"\n".join(self._buffer) + b"\n" - self._delivery.record_message(self._user.dest, message) - return defer.succeed(b"saved") - - def connectionLost(self) -> None: - pass - - -class SendEmailHandlerTestCaseIPv4(HomeserverTestCase): - ip_class: Union[Type[IPv4Address], Type[IPv6Address]] = IPv4Address - - def setUp(self) -> None: - super().setUp() - self.reactor.lookups["localhost"] = "127.0.0.1" - - def test_send_email(self) -> None: - """Happy-path test that we can send email to a non-TLS server.""" - h = self.hs.get_send_email_handler() - d = ensureDeferred( - h.send_email( - "foo@bar.com", "test subject", "Tests", "HTML content", "Text content" - ) - ) - # there should be an attempt to connect to localhost:25 - self.assertEqual(len(self.reactor.tcpClients), 1) - (host, port, client_factory, _timeout, _bindAddress) = self.reactor.tcpClients[ - 0 - ] - self.assertEqual(host, self.reactor.lookups["localhost"]) - self.assertEqual(port, 25) - - # wire it up to an SMTP server - message_delivery = _DummyMessageDelivery() - server_protocol = smtp.ESMTP() - server_protocol.delivery = message_delivery - # make sure that the server uses the test reactor to set timeouts - server_protocol.callLater = self.reactor.callLater # type: ignore[assignment] - - client_protocol = client_factory.buildProtocol(None) - client_protocol.makeConnection(FakeTransport(server_protocol, self.reactor)) - server_protocol.makeConnection( - FakeTransport( - client_protocol, - self.reactor, - peer_address=self.ip_class( - "TCP", self.reactor.lookups["localhost"], 1234 - ), - ) - ) - - # the message should now get delivered - self.get_success(d, by=0.1) - - # check it arrived - self.assertEqual(len(message_delivery.messages), 1) - user, msg = message_delivery.messages.pop() - self.assertEqual(str(user), "foo@bar.com") - self.assertIn(b"Subject: test subject", msg) - - @patch( - "synapse.handlers.send_email.TLSMemoryBIOFactory", - TestingESMTPTLSClientFactory, - ) - @override_config( - { - "email": { - "notif_from": "noreply@test", - "force_tls": True, - }, - } - ) - def test_send_email_force_tls(self) -> None: - """Happy-path test that we can send email to an Implicit TLS server.""" - h = self.hs.get_send_email_handler() - d = ensureDeferred( - h.send_email( - "foo@bar.com", "test subject", "Tests", "HTML content", "Text content" - ) - ) - # there should be an attempt to connect to localhost:465 - self.assertEqual(len(self.reactor.tcpClients), 1) - ( - host, - port, - client_factory, - _timeout, - _bindAddress, - ) = self.reactor.tcpClients[0] - self.assertEqual(host, self.reactor.lookups["localhost"]) - self.assertEqual(port, 465) - # We need to make sure that TLS is happenning - self.assertIsInstance( - client_factory._wrappedFactory._testingContextFactory, - ClientTLSOptions, - ) - # And since we use endpoints, they go through reactor.connectTCP - # which works differently to connectSSL on the testing reactor - - # wire it up to an SMTP server - message_delivery = _DummyMessageDelivery() - server_protocol = smtp.ESMTP() - server_protocol.delivery = message_delivery - # make sure that the server uses the test reactor to set timeouts - server_protocol.callLater = self.reactor.callLater # type: ignore[assignment] - - client_protocol = client_factory.buildProtocol(None) - client_protocol.makeConnection(FakeTransport(server_protocol, self.reactor)) - server_protocol.makeConnection( - FakeTransport( - client_protocol, - self.reactor, - peer_address=self.ip_class( - "TCP", self.reactor.lookups["localhost"], 1234 - ), - ) - ) - - # the message should now get delivered - self.get_success(d, by=0.1) - - # check it arrived - self.assertEqual(len(message_delivery.messages), 1) - user, msg = message_delivery.messages.pop() - self.assertEqual(str(user), "foo@bar.com") - self.assertIn(b"Subject: test subject", msg) - - -class SendEmailHandlerTestCaseIPv6(SendEmailHandlerTestCaseIPv4): - ip_class = IPv6Address - - def setUp(self) -> None: - super().setUp() - self.reactor.lookups["localhost"] = "::1" diff --git a/tests/handlers/test_sliding_sync.py b/tests/handlers/test_sliding_sync.py
index 96da47f3b9..7144c58217 100644 --- a/tests/handlers/test_sliding_sync.py +++ b/tests/handlers/test_sliding_sync.py
@@ -18,39 +18,40 @@ # # import logging -from copy import deepcopy -from typing import Dict, List, Optional +from typing import AbstractSet, Dict, Mapping, Optional, Set, Tuple from unittest.mock import patch -from parameterized import parameterized +import attr +from parameterized import parameterized, parameterized_class from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import ( - AccountDataTypes, - EventContentFields, EventTypes, JoinRules, Membership, - RoomTypes, ) from synapse.api.room_versions import RoomVersions -from synapse.events import StrippedStateEvent, make_event_from_dict -from synapse.events.snapshot import EventContext from synapse.handlers.sliding_sync import ( + MAX_NUMBER_PREVIOUS_STATE_KEYS_TO_REMEMBER, + RoomsForUserType, RoomSyncConfig, StateValues, - _RoomMembershipForUser, + _required_state_changes, ) from synapse.rest import admin from synapse.rest.client import knock, login, room from synapse.server import HomeServer from synapse.storage.util.id_generators import MultiWriterIdGenerator -from synapse.types import JsonDict, StreamToken, UserID -from synapse.types.handlers import SlidingSyncConfig +from synapse.types import JsonDict, StateMap, StreamToken, UserID, create_requester +from synapse.types.handlers.sliding_sync import PerConnectionState, SlidingSyncConfig +from synapse.types.state import StateFilter from synapse.util import Clock +from tests import unittest from tests.replication._base import BaseMultiWorkerStreamTestCase +from tests.rest.client.sliding_sync.test_sliding_sync import SlidingSyncBase +from tests.test_utils.event_injection import create_event from tests.unittest import HomeserverTestCase, TestCase logger = logging.getLogger(__name__) @@ -566,31 +567,39 @@ class RoomSyncConfigTestCase(TestCase): """ Combine A into B and B into A to make sure we get the same result. """ - # Since we're mutating these in place, make a copy for each of our trials - room_sync_config_a = deepcopy(a) - room_sync_config_b = deepcopy(b) - - # Combine B into A - room_sync_config_a.combine_room_sync_config(room_sync_config_b) - - self._assert_room_config_equal(room_sync_config_a, expected, "B into A") - - # Since we're mutating these in place, make a copy for each of our trials - room_sync_config_a = deepcopy(a) - room_sync_config_b = deepcopy(b) - - # Combine A into B - room_sync_config_b.combine_room_sync_config(room_sync_config_a) - - self._assert_room_config_equal(room_sync_config_b, expected, "A into B") - - -class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase): + combined_config = a.combine_room_sync_config(b) + self._assert_room_config_equal(combined_config, expected, "B into A") + + combined_config = a.combine_room_sync_config(b) + self._assert_room_config_equal(combined_config, expected, "A into B") + + +# FIXME: This can be removed once we bump `SCHEMA_COMPAT_VERSION` and run the +# foreground update for +# `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` (tracked by +# https://github.com/element-hq/synapse/issues/17623) +@parameterized_class( + ("use_new_tables",), + [ + (True,), + (False,), + ], + class_name_func=lambda cls, + num, + params_dict: f"{cls.__name__}_{'new' if params_dict['use_new_tables'] else 'fallback'}", +) +class ComputeInterestedRoomsTestCase(SlidingSyncBase): """ - Tests Sliding Sync handler `get_room_membership_for_user_at_to_token()` to make sure it returns + Tests Sliding Sync handler `compute_interested_rooms()` to make sure it returns the correct list of rooms IDs. """ + # FIXME: We should refactor these tests to run against `compute_interested_rooms(...)` + # instead of just `get_room_membership_for_user_at_to_token(...)` which is only used + # in the fallback path (`_compute_interested_rooms_fallback(...)`). These scenarios do + # well to stress that logic and we shouldn't remove them just because we're removing + # the fallback path (tracked by https://github.com/element-hq/synapse/issues/17623). + servlets = [ admin.register_servlets, knock.register_servlets, @@ -609,6 +618,11 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase): self.store = self.hs.get_datastores().main self.event_sources = hs.get_event_sources() self.storage_controllers = hs.get_storage_controllers() + persistence = self.hs.get_storage_controllers().persistence + assert persistence is not None + self.persistence = persistence + + super().prepare(reactor, clock, hs) def test_no_rooms(self) -> None: """ @@ -619,15 +633,28 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase): now_token = self.event_sources.get_current_token() - room_id_results = self.get_success( - self.sliding_sync_handler.get_room_membership_for_user_at_to_token( - UserID.from_string(user1_id), + interested_rooms = self.get_success( + self.sliding_sync_handler.room_lists.compute_interested_rooms( + SlidingSyncConfig( + user=UserID.from_string(user1_id), + requester=create_requester(user_id=user1_id), + lists={ + "foo-list": SlidingSyncConfig.SlidingSyncList( + ranges=[(0, 99)], + required_state=[], + timeline_limit=1, + ) + }, + conn_id=None, + ), + PerConnectionState(), from_token=now_token, to_token=now_token, ) ) + room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids) - self.assertEqual(room_id_results.keys(), set()) + self.assertIncludes(room_id_results, set(), exact=True) def test_get_newly_joined_room(self) -> None: """ @@ -646,26 +673,48 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase): after_room_token = self.event_sources.get_current_token() - room_id_results = self.get_success( - self.sliding_sync_handler.get_room_membership_for_user_at_to_token( - UserID.from_string(user1_id), + interested_rooms = self.get_success( + self.sliding_sync_handler.room_lists.compute_interested_rooms( + SlidingSyncConfig( + user=UserID.from_string(user1_id), + requester=create_requester(user_id=user1_id), + lists={ + "foo-list": SlidingSyncConfig.SlidingSyncList( + ranges=[(0, 99)], + required_state=[], + timeline_limit=1, + ) + }, + conn_id=None, + ), + PerConnectionState(), from_token=before_room_token, to_token=after_room_token, ) ) + room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids) + newly_joined = interested_rooms.newly_joined_rooms + newly_left = interested_rooms.newly_left_rooms - self.assertEqual(room_id_results.keys(), {room_id}) + self.assertIncludes( + room_id_results, + {room_id}, + exact=True, + ) # It should be pointing to the join event (latest membership event in the # from/to range) self.assertEqual( - room_id_results[room_id].event_id, + interested_rooms.room_membership_for_user_map[room_id].event_id, join_response["event_id"], ) - self.assertEqual(room_id_results[room_id].membership, Membership.JOIN) + self.assertEqual( + interested_rooms.room_membership_for_user_map[room_id].membership, + Membership.JOIN, + ) # We should be considered `newly_joined` because we joined during the token # range - self.assertEqual(room_id_results[room_id].newly_joined, True) - self.assertEqual(room_id_results[room_id].newly_left, False) + self.assertTrue(room_id in newly_joined) + self.assertTrue(room_id not in newly_left) def test_get_already_joined_room(self) -> None: """ @@ -681,25 +730,43 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase): after_room_token = self.event_sources.get_current_token() - room_id_results = self.get_success( - self.sliding_sync_handler.get_room_membership_for_user_at_to_token( - UserID.from_string(user1_id), + interested_rooms = self.get_success( + self.sliding_sync_handler.room_lists.compute_interested_rooms( + SlidingSyncConfig( + user=UserID.from_string(user1_id), + requester=create_requester(user_id=user1_id), + lists={ + "foo-list": SlidingSyncConfig.SlidingSyncList( + ranges=[(0, 99)], + required_state=[], + timeline_limit=1, + ) + }, + conn_id=None, + ), + PerConnectionState(), from_token=after_room_token, to_token=after_room_token, ) ) + room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids) + newly_joined = interested_rooms.newly_joined_rooms + newly_left = interested_rooms.newly_left_rooms - self.assertEqual(room_id_results.keys(), {room_id}) + self.assertIncludes(room_id_results, {room_id}, exact=True) # It should be pointing to the join event (latest membership event in the # from/to range) self.assertEqual( - room_id_results[room_id].event_id, + interested_rooms.room_membership_for_user_map[room_id].event_id, join_response["event_id"], ) - self.assertEqual(room_id_results[room_id].membership, Membership.JOIN) + self.assertEqual( + interested_rooms.room_membership_for_user_map[room_id].membership, + Membership.JOIN, + ) # We should *NOT* be `newly_joined` because we joined before the token range - self.assertEqual(room_id_results[room_id].newly_joined, False) - self.assertEqual(room_id_results[room_id].newly_left, False) + self.assertTrue(room_id not in newly_joined) + self.assertTrue(room_id not in newly_left) def test_get_invited_banned_knocked_room(self) -> None: """ @@ -755,48 +822,73 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase): after_room_token = self.event_sources.get_current_token() - room_id_results = self.get_success( - self.sliding_sync_handler.get_room_membership_for_user_at_to_token( - UserID.from_string(user1_id), + interested_rooms = self.get_success( + self.sliding_sync_handler.room_lists.compute_interested_rooms( + SlidingSyncConfig( + user=UserID.from_string(user1_id), + requester=create_requester(user_id=user1_id), + lists={ + "foo-list": SlidingSyncConfig.SlidingSyncList( + ranges=[(0, 99)], + required_state=[], + timeline_limit=1, + ) + }, + conn_id=None, + ), + PerConnectionState(), from_token=before_room_token, to_token=after_room_token, ) ) + room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids) + newly_joined = interested_rooms.newly_joined_rooms + newly_left = interested_rooms.newly_left_rooms # Ensure that the invited, ban, and knock rooms show up - self.assertEqual( - room_id_results.keys(), + self.assertIncludes( + room_id_results, { invited_room_id, ban_room_id, knock_room_id, }, + exact=True, ) # It should be pointing to the the respective membership event (latest # membership event in the from/to range) self.assertEqual( - room_id_results[invited_room_id].event_id, + interested_rooms.room_membership_for_user_map[invited_room_id].event_id, invite_response["event_id"], ) - self.assertEqual(room_id_results[invited_room_id].membership, Membership.INVITE) - self.assertEqual(room_id_results[invited_room_id].newly_joined, False) - self.assertEqual(room_id_results[invited_room_id].newly_left, False) + self.assertEqual( + interested_rooms.room_membership_for_user_map[invited_room_id].membership, + Membership.INVITE, + ) + self.assertTrue(invited_room_id not in newly_joined) + self.assertTrue(invited_room_id not in newly_left) self.assertEqual( - room_id_results[ban_room_id].event_id, + interested_rooms.room_membership_for_user_map[ban_room_id].event_id, ban_response["event_id"], ) - self.assertEqual(room_id_results[ban_room_id].membership, Membership.BAN) - self.assertEqual(room_id_results[ban_room_id].newly_joined, False) - self.assertEqual(room_id_results[ban_room_id].newly_left, False) + self.assertEqual( + interested_rooms.room_membership_for_user_map[ban_room_id].membership, + Membership.BAN, + ) + self.assertTrue(ban_room_id not in newly_joined) + self.assertTrue(ban_room_id not in newly_left) self.assertEqual( - room_id_results[knock_room_id].event_id, + interested_rooms.room_membership_for_user_map[knock_room_id].event_id, knock_room_membership_state_event.event_id, ) - self.assertEqual(room_id_results[knock_room_id].membership, Membership.KNOCK) - self.assertEqual(room_id_results[knock_room_id].newly_joined, False) - self.assertEqual(room_id_results[knock_room_id].newly_left, False) + self.assertEqual( + interested_rooms.room_membership_for_user_map[knock_room_id].membership, + Membership.KNOCK, + ) + self.assertTrue(knock_room_id not in newly_joined) + self.assertTrue(knock_room_id not in newly_left) def test_get_kicked_room(self) -> None: """ @@ -827,27 +919,47 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase): after_kick_token = self.event_sources.get_current_token() - room_id_results = self.get_success( - self.sliding_sync_handler.get_room_membership_for_user_at_to_token( - UserID.from_string(user1_id), + interested_rooms = self.get_success( + self.sliding_sync_handler.room_lists.compute_interested_rooms( + SlidingSyncConfig( + user=UserID.from_string(user1_id), + requester=create_requester(user_id=user1_id), + lists={ + "foo-list": SlidingSyncConfig.SlidingSyncList( + ranges=[(0, 99)], + required_state=[], + timeline_limit=1, + ) + }, + conn_id=None, + ), + PerConnectionState(), from_token=after_kick_token, to_token=after_kick_token, ) ) + room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids) + newly_joined = interested_rooms.newly_joined_rooms + newly_left = interested_rooms.newly_left_rooms # The kicked room should show up - self.assertEqual(room_id_results.keys(), {kick_room_id}) + self.assertIncludes(room_id_results, {kick_room_id}, exact=True) # It should be pointing to the latest membership event in the from/to range self.assertEqual( - room_id_results[kick_room_id].event_id, + interested_rooms.room_membership_for_user_map[kick_room_id].event_id, kick_response["event_id"], ) - self.assertEqual(room_id_results[kick_room_id].membership, Membership.LEAVE) - self.assertNotEqual(room_id_results[kick_room_id].sender, user1_id) + self.assertEqual( + interested_rooms.room_membership_for_user_map[kick_room_id].membership, + Membership.LEAVE, + ) + self.assertNotEqual( + interested_rooms.room_membership_for_user_map[kick_room_id].sender, user1_id + ) # We should *NOT* be `newly_joined` because we were not joined at the the time # of the `to_token`. - self.assertEqual(room_id_results[kick_room_id].newly_joined, False) - self.assertEqual(room_id_results[kick_room_id].newly_left, False) + self.assertTrue(kick_room_id not in newly_joined) + self.assertTrue(kick_room_id not in newly_left) def test_forgotten_rooms(self) -> None: """ @@ -920,16 +1032,29 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase): ) self.assertEqual(channel.code, 200, channel.result) - room_id_results = self.get_success( - self.sliding_sync_handler.get_room_membership_for_user_at_to_token( - UserID.from_string(user1_id), + interested_rooms = self.get_success( + self.sliding_sync_handler.room_lists.compute_interested_rooms( + SlidingSyncConfig( + user=UserID.from_string(user1_id), + requester=create_requester(user_id=user1_id), + lists={ + "foo-list": SlidingSyncConfig.SlidingSyncList( + ranges=[(0, 99)], + required_state=[], + timeline_limit=1, + ) + }, + conn_id=None, + ), + PerConnectionState(), from_token=before_room_forgets, to_token=before_room_forgets, ) ) + room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids) # We shouldn't see the room because it was forgotten - self.assertEqual(room_id_results.keys(), set()) + self.assertIncludes(room_id_results, set(), exact=True) def test_newly_left_rooms(self) -> None: """ @@ -940,7 +1065,7 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase): # Leave before we calculate the `from_token` room_id1 = self.helper.create_room_as(user1_id, tok=user1_tok) - leave_response1 = self.helper.leave(room_id1, user1_id, tok=user1_tok) + _leave_response1 = self.helper.leave(room_id1, user1_id, tok=user1_tok) after_room1_token = self.event_sources.get_current_token() @@ -950,34 +1075,55 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase): after_room2_token = self.event_sources.get_current_token() - room_id_results = self.get_success( - self.sliding_sync_handler.get_room_membership_for_user_at_to_token( - UserID.from_string(user1_id), + interested_rooms = self.get_success( + self.sliding_sync_handler.room_lists.compute_interested_rooms( + SlidingSyncConfig( + user=UserID.from_string(user1_id), + requester=create_requester(user_id=user1_id), + lists={ + "foo-list": SlidingSyncConfig.SlidingSyncList( + ranges=[(0, 99)], + required_state=[], + timeline_limit=1, + ) + }, + conn_id=None, + ), + PerConnectionState(), from_token=after_room1_token, to_token=after_room2_token, ) ) + room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids) + newly_joined = interested_rooms.newly_joined_rooms + newly_left = interested_rooms.newly_left_rooms - self.assertEqual(room_id_results.keys(), {room_id1, room_id2}) - - self.assertEqual( - room_id_results[room_id1].event_id, - leave_response1["event_id"], + # `room_id1` should not show up because it was left before the token range. + # `room_id2` should show up because it is `newly_left` within the token range. + self.assertIncludes( + room_id_results, + {room_id2}, + exact=True, + message="Corresponding map to disambiguate the opaque room IDs: " + + str( + { + "room_id1": room_id1, + "room_id2": room_id2, + } + ), ) - self.assertEqual(room_id_results[room_id1].membership, Membership.LEAVE) - # We should *NOT* be `newly_joined` or `newly_left` because that happened before - # the from/to range - self.assertEqual(room_id_results[room_id1].newly_joined, False) - self.assertEqual(room_id_results[room_id1].newly_left, False) self.assertEqual( - room_id_results[room_id2].event_id, + interested_rooms.room_membership_for_user_map[room_id2].event_id, leave_response2["event_id"], ) - self.assertEqual(room_id_results[room_id2].membership, Membership.LEAVE) + self.assertEqual( + interested_rooms.room_membership_for_user_map[room_id2].membership, + Membership.LEAVE, + ) # We should *NOT* be `newly_joined` because we are instead `newly_left` - self.assertEqual(room_id_results[room_id2].newly_joined, False) - self.assertEqual(room_id_results[room_id2].newly_left, True) + self.assertTrue(room_id2 not in newly_joined) + self.assertTrue(room_id2 in newly_left) def test_no_joins_after_to_token(self) -> None: """ @@ -1000,24 +1146,42 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase): room_id2 = self.helper.create_room_as(user2_id, tok=user2_tok) self.helper.join(room_id2, user1_id, tok=user1_tok) - room_id_results = self.get_success( - self.sliding_sync_handler.get_room_membership_for_user_at_to_token( - UserID.from_string(user1_id), + interested_rooms = self.get_success( + self.sliding_sync_handler.room_lists.compute_interested_rooms( + SlidingSyncConfig( + user=UserID.from_string(user1_id), + requester=create_requester(user_id=user1_id), + lists={ + "foo-list": SlidingSyncConfig.SlidingSyncList( + ranges=[(0, 99)], + required_state=[], + timeline_limit=1, + ) + }, + conn_id=None, + ), + PerConnectionState(), from_token=before_room1_token, to_token=after_room1_token, ) ) + room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids) + newly_joined = interested_rooms.newly_joined_rooms + newly_left = interested_rooms.newly_left_rooms - self.assertEqual(room_id_results.keys(), {room_id1}) + self.assertIncludes(room_id_results, {room_id1}, exact=True) # It should be pointing to the latest membership event in the from/to range self.assertEqual( - room_id_results[room_id1].event_id, + interested_rooms.room_membership_for_user_map[room_id1].event_id, join_response1["event_id"], ) - self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN) + self.assertEqual( + interested_rooms.room_membership_for_user_map[room_id1].membership, + Membership.JOIN, + ) # We should be `newly_joined` because we joined during the token range - self.assertEqual(room_id_results[room_id1].newly_joined, True) - self.assertEqual(room_id_results[room_id1].newly_left, False) + self.assertTrue(room_id1 in newly_joined) + self.assertTrue(room_id1 not in newly_left) def test_join_during_range_and_left_room_after_to_token(self) -> None: """ @@ -1040,20 +1204,35 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase): # Leave the room after we already have our tokens leave_response = self.helper.leave(room_id1, user1_id, tok=user1_tok) - room_id_results = self.get_success( - self.sliding_sync_handler.get_room_membership_for_user_at_to_token( - UserID.from_string(user1_id), + interested_rooms = self.get_success( + self.sliding_sync_handler.room_lists.compute_interested_rooms( + SlidingSyncConfig( + user=UserID.from_string(user1_id), + requester=create_requester(user_id=user1_id), + lists={ + "foo-list": SlidingSyncConfig.SlidingSyncList( + ranges=[(0, 99)], + required_state=[], + timeline_limit=1, + ) + }, + conn_id=None, + ), + PerConnectionState(), from_token=before_room1_token, to_token=after_room1_token, ) ) + room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids) + newly_joined = interested_rooms.newly_joined_rooms + newly_left = interested_rooms.newly_left_rooms # We should still see the room because we were joined during the # from_token/to_token time period. - self.assertEqual(room_id_results.keys(), {room_id1}) + self.assertIncludes(room_id_results, {room_id1}, exact=True) # It should be pointing to the latest membership event in the from/to range self.assertEqual( - room_id_results[room_id1].event_id, + interested_rooms.room_membership_for_user_map[room_id1].event_id, join_response["event_id"], "Corresponding map to disambiguate the opaque event IDs: " + str( @@ -1063,10 +1242,13 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase): } ), ) - self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN) + self.assertEqual( + interested_rooms.room_membership_for_user_map[room_id1].membership, + Membership.JOIN, + ) # We should be `newly_joined` because we joined during the token range - self.assertEqual(room_id_results[room_id1].newly_joined, True) - self.assertEqual(room_id_results[room_id1].newly_left, False) + self.assertTrue(room_id1 in newly_joined) + self.assertTrue(room_id1 not in newly_left) def test_join_before_range_and_left_room_after_to_token(self) -> None: """ @@ -1087,19 +1269,34 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase): # Leave the room after we already have our tokens leave_response = self.helper.leave(room_id1, user1_id, tok=user1_tok) - room_id_results = self.get_success( - self.sliding_sync_handler.get_room_membership_for_user_at_to_token( - UserID.from_string(user1_id), + interested_rooms = self.get_success( + self.sliding_sync_handler.room_lists.compute_interested_rooms( + SlidingSyncConfig( + user=UserID.from_string(user1_id), + requester=create_requester(user_id=user1_id), + lists={ + "foo-list": SlidingSyncConfig.SlidingSyncList( + ranges=[(0, 99)], + required_state=[], + timeline_limit=1, + ) + }, + conn_id=None, + ), + PerConnectionState(), from_token=after_room1_token, to_token=after_room1_token, ) ) + room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids) + newly_joined = interested_rooms.newly_joined_rooms + newly_left = interested_rooms.newly_left_rooms # We should still see the room because we were joined before the `from_token` - self.assertEqual(room_id_results.keys(), {room_id1}) + self.assertIncludes(room_id_results, {room_id1}, exact=True) # It should be pointing to the latest membership event in the from/to range self.assertEqual( - room_id_results[room_id1].event_id, + interested_rooms.room_membership_for_user_map[room_id1].event_id, join_response["event_id"], "Corresponding map to disambiguate the opaque event IDs: " + str( @@ -1109,10 +1306,13 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase): } ), ) - self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN) + self.assertEqual( + interested_rooms.room_membership_for_user_map[room_id1].membership, + Membership.JOIN, + ) # We should *NOT* be `newly_joined` because we joined before the token range - self.assertEqual(room_id_results[room_id1].newly_joined, False) - self.assertEqual(room_id_results[room_id1].newly_left, False) + self.assertTrue(room_id1 not in newly_joined) + self.assertTrue(room_id1 not in newly_left) def test_kicked_before_range_and_left_after_to_token(self) -> None: """ @@ -1151,19 +1351,34 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase): join_response2 = self.helper.join(kick_room_id, user1_id, tok=user1_tok) leave_response = self.helper.leave(kick_room_id, user1_id, tok=user1_tok) - room_id_results = self.get_success( - self.sliding_sync_handler.get_room_membership_for_user_at_to_token( - UserID.from_string(user1_id), + interested_rooms = self.get_success( + self.sliding_sync_handler.room_lists.compute_interested_rooms( + SlidingSyncConfig( + user=UserID.from_string(user1_id), + requester=create_requester(user_id=user1_id), + lists={ + "foo-list": SlidingSyncConfig.SlidingSyncList( + ranges=[(0, 99)], + required_state=[], + timeline_limit=1, + ) + }, + conn_id=None, + ), + PerConnectionState(), from_token=after_kick_token, to_token=after_kick_token, ) ) + room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids) + newly_joined = interested_rooms.newly_joined_rooms + newly_left = interested_rooms.newly_left_rooms # We shouldn't see the room because it was forgotten - self.assertEqual(room_id_results.keys(), {kick_room_id}) + self.assertIncludes(room_id_results, {kick_room_id}, exact=True) # It should be pointing to the latest membership event in the from/to range self.assertEqual( - room_id_results[kick_room_id].event_id, + interested_rooms.room_membership_for_user_map[kick_room_id].event_id, kick_response["event_id"], "Corresponding map to disambiguate the opaque event IDs: " + str( @@ -1175,11 +1390,16 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase): } ), ) - self.assertEqual(room_id_results[kick_room_id].membership, Membership.LEAVE) - self.assertNotEqual(room_id_results[kick_room_id].sender, user1_id) + self.assertEqual( + interested_rooms.room_membership_for_user_map[kick_room_id].membership, + Membership.LEAVE, + ) + self.assertNotEqual( + interested_rooms.room_membership_for_user_map[kick_room_id].sender, user1_id + ) # We should *NOT* be `newly_joined` because we were kicked - self.assertEqual(room_id_results[kick_room_id].newly_joined, False) - self.assertEqual(room_id_results[kick_room_id].newly_left, False) + self.assertTrue(kick_room_id not in newly_joined) + self.assertTrue(kick_room_id not in newly_left) def test_newly_left_during_range_and_join_leave_after_to_token(self) -> None: """ @@ -1207,19 +1427,34 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase): join_response2 = self.helper.join(room_id1, user1_id, tok=user1_tok) leave_response2 = self.helper.leave(room_id1, user1_id, tok=user1_tok) - room_id_results = self.get_success( - self.sliding_sync_handler.get_room_membership_for_user_at_to_token( - UserID.from_string(user1_id), + interested_rooms = self.get_success( + self.sliding_sync_handler.room_lists.compute_interested_rooms( + SlidingSyncConfig( + user=UserID.from_string(user1_id), + requester=create_requester(user_id=user1_id), + lists={ + "foo-list": SlidingSyncConfig.SlidingSyncList( + ranges=[(0, 99)], + required_state=[], + timeline_limit=1, + ) + }, + conn_id=None, + ), + PerConnectionState(), from_token=before_room1_token, to_token=after_room1_token, ) ) + room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids) + newly_joined = interested_rooms.newly_joined_rooms + newly_left = interested_rooms.newly_left_rooms # Room should still show up because it's newly_left during the from/to range - self.assertEqual(room_id_results.keys(), {room_id1}) + self.assertIncludes(room_id_results, {room_id1}, exact=True) # It should be pointing to the latest membership event in the from/to range self.assertEqual( - room_id_results[room_id1].event_id, + interested_rooms.room_membership_for_user_map[room_id1].event_id, leave_response1["event_id"], "Corresponding map to disambiguate the opaque event IDs: " + str( @@ -1231,11 +1466,14 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase): } ), ) - self.assertEqual(room_id_results[room_id1].membership, Membership.LEAVE) + self.assertEqual( + interested_rooms.room_membership_for_user_map[room_id1].membership, + Membership.LEAVE, + ) # We should *NOT* be `newly_joined` because we are actually `newly_left` during # the token range - self.assertEqual(room_id_results[room_id1].newly_joined, False) - self.assertEqual(room_id_results[room_id1].newly_left, True) + self.assertTrue(room_id1 not in newly_joined) + self.assertTrue(room_id1 in newly_left) def test_newly_left_during_range_and_join_after_to_token(self) -> None: """ @@ -1262,19 +1500,34 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase): # Join the room after we already have our tokens join_response2 = self.helper.join(room_id1, user1_id, tok=user1_tok) - room_id_results = self.get_success( - self.sliding_sync_handler.get_room_membership_for_user_at_to_token( - UserID.from_string(user1_id), + interested_rooms = self.get_success( + self.sliding_sync_handler.room_lists.compute_interested_rooms( + SlidingSyncConfig( + user=UserID.from_string(user1_id), + requester=create_requester(user_id=user1_id), + lists={ + "foo-list": SlidingSyncConfig.SlidingSyncList( + ranges=[(0, 99)], + required_state=[], + timeline_limit=1, + ) + }, + conn_id=None, + ), + PerConnectionState(), from_token=before_room1_token, to_token=after_room1_token, ) ) + room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids) + newly_joined = interested_rooms.newly_joined_rooms + newly_left = interested_rooms.newly_left_rooms # Room should still show up because it's newly_left during the from/to range - self.assertEqual(room_id_results.keys(), {room_id1}) + self.assertIncludes(room_id_results, {room_id1}, exact=True) # It should be pointing to the latest membership event in the from/to range self.assertEqual( - room_id_results[room_id1].event_id, + interested_rooms.room_membership_for_user_map[room_id1].event_id, leave_response1["event_id"], "Corresponding map to disambiguate the opaque event IDs: " + str( @@ -1285,11 +1538,14 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase): } ), ) - self.assertEqual(room_id_results[room_id1].membership, Membership.LEAVE) + self.assertEqual( + interested_rooms.room_membership_for_user_map[room_id1].membership, + Membership.LEAVE, + ) # We should *NOT* be `newly_joined` because we are actually `newly_left` during # the token range - self.assertEqual(room_id_results[room_id1].newly_joined, False) - self.assertEqual(room_id_results[room_id1].newly_left, True) + self.assertTrue(room_id1 not in newly_joined) + self.assertTrue(room_id1 in newly_left) def test_no_from_token(self) -> None: """ @@ -1314,47 +1570,53 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase): # Join and leave the room2 before the `to_token` self.helper.join(room_id2, user1_id, tok=user1_tok) - leave_response2 = self.helper.leave(room_id2, user1_id, tok=user1_tok) + _leave_response2 = self.helper.leave(room_id2, user1_id, tok=user1_tok) after_room1_token = self.event_sources.get_current_token() # Join the room2 after we already have our tokens self.helper.join(room_id2, user1_id, tok=user1_tok) - room_id_results = self.get_success( - self.sliding_sync_handler.get_room_membership_for_user_at_to_token( - UserID.from_string(user1_id), + interested_rooms = self.get_success( + self.sliding_sync_handler.room_lists.compute_interested_rooms( + SlidingSyncConfig( + user=UserID.from_string(user1_id), + requester=create_requester(user_id=user1_id), + lists={ + "foo-list": SlidingSyncConfig.SlidingSyncList( + ranges=[(0, 99)], + required_state=[], + timeline_limit=1, + ) + }, + conn_id=None, + ), + PerConnectionState(), from_token=None, to_token=after_room1_token, ) ) + room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids) + newly_joined = interested_rooms.newly_joined_rooms + newly_left = interested_rooms.newly_left_rooms # Only rooms we were joined to before the `to_token` should show up - self.assertEqual(room_id_results.keys(), {room_id1, room_id2}) + self.assertIncludes(room_id_results, {room_id1}, exact=True) # Room1 # It should be pointing to the latest membership event in the from/to range self.assertEqual( - room_id_results[room_id1].event_id, + interested_rooms.room_membership_for_user_map[room_id1].event_id, join_response1["event_id"], ) - self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN) - # We should *NOT* be `newly_joined`/`newly_left` because there is no - # `from_token` to define a "live" range to compare against - self.assertEqual(room_id_results[room_id1].newly_joined, False) - self.assertEqual(room_id_results[room_id1].newly_left, False) - - # Room2 - # It should be pointing to the latest membership event in the from/to range self.assertEqual( - room_id_results[room_id2].event_id, - leave_response2["event_id"], + interested_rooms.room_membership_for_user_map[room_id1].membership, + Membership.JOIN, ) - self.assertEqual(room_id_results[room_id2].membership, Membership.LEAVE) # We should *NOT* be `newly_joined`/`newly_left` because there is no # `from_token` to define a "live" range to compare against - self.assertEqual(room_id_results[room_id2].newly_joined, False) - self.assertEqual(room_id_results[room_id2].newly_left, False) + self.assertTrue(room_id1 not in newly_joined) + self.assertTrue(room_id1 not in newly_left) def test_from_token_ahead_of_to_token(self) -> None: """ @@ -1378,7 +1640,7 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase): # Join and leave the room2 before `to_token` _join_room2_response1 = self.helper.join(room_id2, user1_id, tok=user1_tok) - leave_room2_response1 = self.helper.leave(room_id2, user1_id, tok=user1_tok) + _leave_room2_response1 = self.helper.leave(room_id2, user1_id, tok=user1_tok) # Note: These are purposely swapped. The `from_token` should come after # the `to_token` in this test @@ -1403,54 +1665,69 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase): # Join the room4 after we already have our tokens self.helper.join(room_id4, user1_id, tok=user1_tok) - room_id_results = self.get_success( - self.sliding_sync_handler.get_room_membership_for_user_at_to_token( - UserID.from_string(user1_id), + interested_rooms = self.get_success( + self.sliding_sync_handler.room_lists.compute_interested_rooms( + SlidingSyncConfig( + user=UserID.from_string(user1_id), + requester=create_requester(user_id=user1_id), + lists={ + "foo-list": SlidingSyncConfig.SlidingSyncList( + ranges=[(0, 99)], + required_state=[], + timeline_limit=1, + ) + }, + conn_id=None, + ), + PerConnectionState(), from_token=from_token, to_token=to_token, ) ) + room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids) + newly_joined = interested_rooms.newly_joined_rooms + newly_left = interested_rooms.newly_left_rooms # In the "current" state snapshot, we're joined to all of the rooms but in the # from/to token range... self.assertIncludes( - room_id_results.keys(), + room_id_results, { # Included because we were joined before both tokens room_id1, - # Included because we had membership before the to_token - room_id2, + # Excluded because we left before the `from_token` and `to_token` + # room_id2, # Excluded because we joined after the `to_token` # room_id3, # Excluded because we joined after the `to_token` # room_id4, }, exact=True, + message="Corresponding map to disambiguate the opaque room IDs: " + + str( + { + "room_id1": room_id1, + "room_id2": room_id2, + "room_id3": room_id3, + "room_id4": room_id4, + } + ), ) # Room1 # It should be pointing to the latest membership event in the from/to range self.assertEqual( - room_id_results[room_id1].event_id, + interested_rooms.room_membership_for_user_map[room_id1].event_id, join_room1_response1["event_id"], ) - self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN) - # We should *NOT* be `newly_joined`/`newly_left` because we joined `room1` - # before either of the tokens - self.assertEqual(room_id_results[room_id1].newly_joined, False) - self.assertEqual(room_id_results[room_id1].newly_left, False) - - # Room2 - # It should be pointing to the latest membership event in the from/to range self.assertEqual( - room_id_results[room_id2].event_id, - leave_room2_response1["event_id"], + interested_rooms.room_membership_for_user_map[room_id1].membership, + Membership.JOIN, ) - self.assertEqual(room_id_results[room_id2].membership, Membership.LEAVE) - # We should *NOT* be `newly_joined`/`newly_left` because we joined and left - # `room1` before either of the tokens - self.assertEqual(room_id_results[room_id2].newly_joined, False) - self.assertEqual(room_id_results[room_id2].newly_left, False) + # We should *NOT* be `newly_joined`/`newly_left` because we joined `room1` + # before either of the tokens + self.assertTrue(room_id1 not in newly_joined) + self.assertTrue(room_id1 not in newly_left) def test_leave_before_range_and_join_leave_after_to_token(self) -> None: """ @@ -1468,7 +1745,7 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase): room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True) # Join and leave the room before the from/to range self.helper.join(room_id1, user1_id, tok=user1_tok) - leave_response = self.helper.leave(room_id1, user1_id, tok=user1_tok) + self.helper.leave(room_id1, user1_id, tok=user1_tok) after_room1_token = self.event_sources.get_current_token() @@ -1476,25 +1753,28 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase): self.helper.join(room_id1, user1_id, tok=user1_tok) self.helper.leave(room_id1, user1_id, tok=user1_tok) - room_id_results = self.get_success( - self.sliding_sync_handler.get_room_membership_for_user_at_to_token( - UserID.from_string(user1_id), + interested_rooms = self.get_success( + self.sliding_sync_handler.room_lists.compute_interested_rooms( + SlidingSyncConfig( + user=UserID.from_string(user1_id), + requester=create_requester(user_id=user1_id), + lists={ + "foo-list": SlidingSyncConfig.SlidingSyncList( + ranges=[(0, 99)], + required_state=[], + timeline_limit=1, + ) + }, + conn_id=None, + ), + PerConnectionState(), from_token=after_room1_token, to_token=after_room1_token, ) ) + room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids) - self.assertEqual(room_id_results.keys(), {room_id1}) - # It should be pointing to the latest membership event in the from/to range - self.assertEqual( - room_id_results[room_id1].event_id, - leave_response["event_id"], - ) - self.assertEqual(room_id_results[room_id1].membership, Membership.LEAVE) - # We should *NOT* be `newly_joined`/`newly_left` because we joined and left - # `room1` before either of the tokens - self.assertEqual(room_id_results[room_id1].newly_joined, False) - self.assertEqual(room_id_results[room_id1].newly_left, False) + self.assertIncludes(room_id_results, set(), exact=True) def test_leave_before_range_and_join_after_to_token(self) -> None: """ @@ -1512,32 +1792,35 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase): room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True) # Join and leave the room before the from/to range self.helper.join(room_id1, user1_id, tok=user1_tok) - leave_response = self.helper.leave(room_id1, user1_id, tok=user1_tok) + self.helper.leave(room_id1, user1_id, tok=user1_tok) after_room1_token = self.event_sources.get_current_token() # Join the room after we already have our tokens self.helper.join(room_id1, user1_id, tok=user1_tok) - room_id_results = self.get_success( - self.sliding_sync_handler.get_room_membership_for_user_at_to_token( - UserID.from_string(user1_id), + interested_rooms = self.get_success( + self.sliding_sync_handler.room_lists.compute_interested_rooms( + SlidingSyncConfig( + user=UserID.from_string(user1_id), + requester=create_requester(user_id=user1_id), + lists={ + "foo-list": SlidingSyncConfig.SlidingSyncList( + ranges=[(0, 99)], + required_state=[], + timeline_limit=1, + ) + }, + conn_id=None, + ), + PerConnectionState(), from_token=after_room1_token, to_token=after_room1_token, ) ) + room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids) - self.assertEqual(room_id_results.keys(), {room_id1}) - # It should be pointing to the latest membership event in the from/to range - self.assertEqual( - room_id_results[room_id1].event_id, - leave_response["event_id"], - ) - self.assertEqual(room_id_results[room_id1].membership, Membership.LEAVE) - # We should *NOT* be `newly_joined`/`newly_left` because we joined and left - # `room1` before either of the tokens - self.assertEqual(room_id_results[room_id1].newly_joined, False) - self.assertEqual(room_id_results[room_id1].newly_left, False) + self.assertIncludes(room_id_results, set(), exact=True) def test_join_leave_multiple_times_during_range_and_after_to_token( self, @@ -1569,19 +1852,34 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase): join_response3 = self.helper.join(room_id1, user1_id, tok=user1_tok) leave_response3 = self.helper.leave(room_id1, user1_id, tok=user1_tok) - room_id_results = self.get_success( - self.sliding_sync_handler.get_room_membership_for_user_at_to_token( - UserID.from_string(user1_id), + interested_rooms = self.get_success( + self.sliding_sync_handler.room_lists.compute_interested_rooms( + SlidingSyncConfig( + user=UserID.from_string(user1_id), + requester=create_requester(user_id=user1_id), + lists={ + "foo-list": SlidingSyncConfig.SlidingSyncList( + ranges=[(0, 99)], + required_state=[], + timeline_limit=1, + ) + }, + conn_id=None, + ), + PerConnectionState(), from_token=before_room1_token, to_token=after_room1_token, ) ) + room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids) + newly_joined = interested_rooms.newly_joined_rooms + newly_left = interested_rooms.newly_left_rooms # Room should show up because it was newly_left and joined during the from/to range - self.assertEqual(room_id_results.keys(), {room_id1}) + self.assertIncludes(room_id_results, {room_id1}, exact=True) # It should be pointing to the latest membership event in the from/to range self.assertEqual( - room_id_results[room_id1].event_id, + interested_rooms.room_membership_for_user_map[room_id1].event_id, join_response2["event_id"], "Corresponding map to disambiguate the opaque event IDs: " + str( @@ -1595,12 +1893,15 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase): } ), ) - self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN) + self.assertEqual( + interested_rooms.room_membership_for_user_map[room_id1].membership, + Membership.JOIN, + ) # We should be `newly_joined` because we joined during the token range - self.assertEqual(room_id_results[room_id1].newly_joined, True) + self.assertTrue(room_id1 in newly_joined) # We should *NOT* be `newly_left` because we joined during the token range and # was still joined at the end of the range - self.assertEqual(room_id_results[room_id1].newly_left, False) + self.assertTrue(room_id1 not in newly_left) def test_join_leave_multiple_times_before_range_and_after_to_token( self, @@ -1631,19 +1932,34 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase): join_response3 = self.helper.join(room_id1, user1_id, tok=user1_tok) leave_response3 = self.helper.leave(room_id1, user1_id, tok=user1_tok) - room_id_results = self.get_success( - self.sliding_sync_handler.get_room_membership_for_user_at_to_token( - UserID.from_string(user1_id), + interested_rooms = self.get_success( + self.sliding_sync_handler.room_lists.compute_interested_rooms( + SlidingSyncConfig( + user=UserID.from_string(user1_id), + requester=create_requester(user_id=user1_id), + lists={ + "foo-list": SlidingSyncConfig.SlidingSyncList( + ranges=[(0, 99)], + required_state=[], + timeline_limit=1, + ) + }, + conn_id=None, + ), + PerConnectionState(), from_token=after_room1_token, to_token=after_room1_token, ) ) + room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids) + newly_joined = interested_rooms.newly_joined_rooms + newly_left = interested_rooms.newly_left_rooms # Room should show up because we were joined before the from/to range - self.assertEqual(room_id_results.keys(), {room_id1}) + self.assertIncludes(room_id_results, {room_id1}, exact=True) # It should be pointing to the latest membership event in the from/to range self.assertEqual( - room_id_results[room_id1].event_id, + interested_rooms.room_membership_for_user_map[room_id1].event_id, join_response2["event_id"], "Corresponding map to disambiguate the opaque event IDs: " + str( @@ -1657,10 +1973,13 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase): } ), ) - self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN) + self.assertEqual( + interested_rooms.room_membership_for_user_map[room_id1].membership, + Membership.JOIN, + ) # We should *NOT* be `newly_joined` because we joined before the token range - self.assertEqual(room_id_results[room_id1].newly_joined, False) - self.assertEqual(room_id_results[room_id1].newly_left, False) + self.assertTrue(room_id1 not in newly_joined) + self.assertTrue(room_id1 not in newly_left) def test_invite_before_range_and_join_leave_after_to_token( self, @@ -1690,19 +2009,34 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase): join_respsonse = self.helper.join(room_id1, user1_id, tok=user1_tok) leave_response = self.helper.leave(room_id1, user1_id, tok=user1_tok) - room_id_results = self.get_success( - self.sliding_sync_handler.get_room_membership_for_user_at_to_token( - UserID.from_string(user1_id), + interested_rooms = self.get_success( + self.sliding_sync_handler.room_lists.compute_interested_rooms( + SlidingSyncConfig( + user=UserID.from_string(user1_id), + requester=create_requester(user_id=user1_id), + lists={ + "foo-list": SlidingSyncConfig.SlidingSyncList( + ranges=[(0, 99)], + required_state=[], + timeline_limit=1, + ) + }, + conn_id=None, + ), + PerConnectionState(), from_token=after_room1_token, to_token=after_room1_token, ) ) + room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids) + newly_joined = interested_rooms.newly_joined_rooms + newly_left = interested_rooms.newly_left_rooms # Room should show up because we were invited before the from/to range - self.assertEqual(room_id_results.keys(), {room_id1}) + self.assertIncludes(room_id_results, {room_id1}, exact=True) # It should be pointing to the latest membership event in the from/to range self.assertEqual( - room_id_results[room_id1].event_id, + interested_rooms.room_membership_for_user_map[room_id1].event_id, invite_response["event_id"], "Corresponding map to disambiguate the opaque event IDs: " + str( @@ -1713,11 +2047,14 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase): } ), ) - self.assertEqual(room_id_results[room_id1].membership, Membership.INVITE) + self.assertEqual( + interested_rooms.room_membership_for_user_map[room_id1].membership, + Membership.INVITE, + ) # We should *NOT* be `newly_joined` because we were only invited before the # token range - self.assertEqual(room_id_results[room_id1].newly_joined, False) - self.assertEqual(room_id_results[room_id1].newly_left, False) + self.assertTrue(room_id1 not in newly_joined) + self.assertTrue(room_id1 not in newly_left) def test_join_and_display_name_changes_in_token_range( self, @@ -1764,19 +2101,34 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase): tok=user1_tok, ) - room_id_results = self.get_success( - self.sliding_sync_handler.get_room_membership_for_user_at_to_token( - UserID.from_string(user1_id), + interested_rooms = self.get_success( + self.sliding_sync_handler.room_lists.compute_interested_rooms( + SlidingSyncConfig( + user=UserID.from_string(user1_id), + requester=create_requester(user_id=user1_id), + lists={ + "foo-list": SlidingSyncConfig.SlidingSyncList( + ranges=[(0, 99)], + required_state=[], + timeline_limit=1, + ) + }, + conn_id=None, + ), + PerConnectionState(), from_token=before_room1_token, to_token=after_room1_token, ) ) + room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids) + newly_joined = interested_rooms.newly_joined_rooms + newly_left = interested_rooms.newly_left_rooms # Room should show up because we were joined during the from/to range - self.assertEqual(room_id_results.keys(), {room_id1}) + self.assertIncludes(room_id_results, {room_id1}, exact=True) # It should be pointing to the latest membership event in the from/to range self.assertEqual( - room_id_results[room_id1].event_id, + interested_rooms.room_membership_for_user_map[room_id1].event_id, displayname_change_during_token_range_response["event_id"], "Corresponding map to disambiguate the opaque event IDs: " + str( @@ -1791,10 +2143,13 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase): } ), ) - self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN) + self.assertEqual( + interested_rooms.room_membership_for_user_map[room_id1].membership, + Membership.JOIN, + ) # We should be `newly_joined` because we joined during the token range - self.assertEqual(room_id_results[room_id1].newly_joined, True) - self.assertEqual(room_id_results[room_id1].newly_left, False) + self.assertTrue(room_id1 in newly_joined) + self.assertTrue(room_id1 not in newly_left) def test_display_name_changes_in_token_range( self, @@ -1829,19 +2184,34 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase): after_change1_token = self.event_sources.get_current_token() - room_id_results = self.get_success( - self.sliding_sync_handler.get_room_membership_for_user_at_to_token( - UserID.from_string(user1_id), + interested_rooms = self.get_success( + self.sliding_sync_handler.room_lists.compute_interested_rooms( + SlidingSyncConfig( + user=UserID.from_string(user1_id), + requester=create_requester(user_id=user1_id), + lists={ + "foo-list": SlidingSyncConfig.SlidingSyncList( + ranges=[(0, 99)], + required_state=[], + timeline_limit=1, + ) + }, + conn_id=None, + ), + PerConnectionState(), from_token=after_room1_token, to_token=after_change1_token, ) ) + room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids) + newly_joined = interested_rooms.newly_joined_rooms + newly_left = interested_rooms.newly_left_rooms # Room should show up because we were joined during the from/to range - self.assertEqual(room_id_results.keys(), {room_id1}) + self.assertIncludes(room_id_results, {room_id1}, exact=True) # It should be pointing to the latest membership event in the from/to range self.assertEqual( - room_id_results[room_id1].event_id, + interested_rooms.room_membership_for_user_map[room_id1].event_id, displayname_change_during_token_range_response["event_id"], "Corresponding map to disambiguate the opaque event IDs: " + str( @@ -1853,10 +2223,13 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase): } ), ) - self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN) + self.assertEqual( + interested_rooms.room_membership_for_user_map[room_id1].membership, + Membership.JOIN, + ) # We should *NOT* be `newly_joined` because we joined before the token range - self.assertEqual(room_id_results[room_id1].newly_joined, False) - self.assertEqual(room_id_results[room_id1].newly_left, False) + self.assertTrue(room_id1 not in newly_joined) + self.assertTrue(room_id1 not in newly_left) def test_display_name_changes_before_and_after_token_range( self, @@ -1901,19 +2274,34 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase): tok=user1_tok, ) - room_id_results = self.get_success( - self.sliding_sync_handler.get_room_membership_for_user_at_to_token( - UserID.from_string(user1_id), + interested_rooms = self.get_success( + self.sliding_sync_handler.room_lists.compute_interested_rooms( + SlidingSyncConfig( + user=UserID.from_string(user1_id), + requester=create_requester(user_id=user1_id), + lists={ + "foo-list": SlidingSyncConfig.SlidingSyncList( + ranges=[(0, 99)], + required_state=[], + timeline_limit=1, + ) + }, + conn_id=None, + ), + PerConnectionState(), from_token=after_room1_token, to_token=after_room1_token, ) ) + room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids) + newly_joined = interested_rooms.newly_joined_rooms + newly_left = interested_rooms.newly_left_rooms # Room should show up because we were joined before the from/to range - self.assertEqual(room_id_results.keys(), {room_id1}) + self.assertIncludes(room_id_results, {room_id1}, exact=True) # It should be pointing to the latest membership event in the from/to range self.assertEqual( - room_id_results[room_id1].event_id, + interested_rooms.room_membership_for_user_map[room_id1].event_id, displayname_change_before_token_range_response["event_id"], "Corresponding map to disambiguate the opaque event IDs: " + str( @@ -1928,18 +2316,22 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase): } ), ) - self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN) + self.assertEqual( + interested_rooms.room_membership_for_user_map[room_id1].membership, + Membership.JOIN, + ) # We should *NOT* be `newly_joined` because we joined before the token range - self.assertEqual(room_id_results[room_id1].newly_joined, False) - self.assertEqual(room_id_results[room_id1].newly_left, False) + self.assertTrue(room_id1 not in newly_joined) + self.assertTrue(room_id1 not in newly_left) - def test_display_name_changes_leave_after_token_range( + def test_newly_joined_display_name_changes_leave_after_token_range( self, ) -> None: """ Test that we point to the correct membership event within the from/to range even - if there are multiple `join` membership events in a row indicating - `displayname`/`avatar_url` updates and we leave after the `to_token`. + if we are `newly_joined` and there are multiple `join` membership events in a + row indicating `displayname`/`avatar_url` updates and we leave after the + `to_token`. See condition "1a)" comments in the `get_room_membership_for_user_at_to_token()` method. """ @@ -1954,6 +2346,7 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase): # leave and can still re-join. room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True) join_response = self.helper.join(room_id1, user1_id, tok=user1_tok) + # Update the displayname during the token range displayname_change_during_token_range_response = self.helper.send_state( room_id1, @@ -1983,19 +2376,34 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase): # Leave after the token self.helper.leave(room_id1, user1_id, tok=user1_tok) - room_id_results = self.get_success( - self.sliding_sync_handler.get_room_membership_for_user_at_to_token( - UserID.from_string(user1_id), + interested_rooms = self.get_success( + self.sliding_sync_handler.room_lists.compute_interested_rooms( + SlidingSyncConfig( + user=UserID.from_string(user1_id), + requester=create_requester(user_id=user1_id), + lists={ + "foo-list": SlidingSyncConfig.SlidingSyncList( + ranges=[(0, 99)], + required_state=[], + timeline_limit=1, + ) + }, + conn_id=None, + ), + PerConnectionState(), from_token=before_room1_token, to_token=after_room1_token, ) ) + room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids) + newly_joined = interested_rooms.newly_joined_rooms + newly_left = interested_rooms.newly_left_rooms # Room should show up because we were joined during the from/to range - self.assertEqual(room_id_results.keys(), {room_id1}) + self.assertIncludes(room_id_results, {room_id1}, exact=True) # It should be pointing to the latest membership event in the from/to range self.assertEqual( - room_id_results[room_id1].event_id, + interested_rooms.room_membership_for_user_map[room_id1].event_id, displayname_change_during_token_range_response["event_id"], "Corresponding map to disambiguate the opaque event IDs: " + str( @@ -2010,10 +2418,117 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase): } ), ) - self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN) + self.assertEqual( + interested_rooms.room_membership_for_user_map[room_id1].membership, + Membership.JOIN, + ) # We should be `newly_joined` because we joined during the token range - self.assertEqual(room_id_results[room_id1].newly_joined, True) - self.assertEqual(room_id_results[room_id1].newly_left, False) + self.assertTrue(room_id1 in newly_joined) + self.assertTrue(room_id1 not in newly_left) + + def test_display_name_changes_leave_after_token_range( + self, + ) -> None: + """ + Test that we point to the correct membership event within the from/to range even + if there are multiple `join` membership events in a row indicating + `displayname`/`avatar_url` updates and we leave after the `to_token`. + + See condition "1a)" comments in the `get_room_membership_for_user_at_to_token()` method. + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + user2_id = self.register_user("user2", "pass") + user2_tok = self.login(user2_id, "pass") + + _before_room1_token = self.event_sources.get_current_token() + + # We create the room with user2 so the room isn't left with no members when we + # leave and can still re-join. + room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True) + join_response = self.helper.join(room_id1, user1_id, tok=user1_tok) + + after_join_token = self.event_sources.get_current_token() + + # Update the displayname during the token range + displayname_change_during_token_range_response = self.helper.send_state( + room_id1, + event_type=EventTypes.Member, + state_key=user1_id, + body={ + "membership": Membership.JOIN, + "displayname": "displayname during token range", + }, + tok=user1_tok, + ) + + after_display_name_change_token = self.event_sources.get_current_token() + + # Update the displayname after the token range + displayname_change_after_token_range_response = self.helper.send_state( + room_id1, + event_type=EventTypes.Member, + state_key=user1_id, + body={ + "membership": Membership.JOIN, + "displayname": "displayname after token range", + }, + tok=user1_tok, + ) + + # Leave after the token + self.helper.leave(room_id1, user1_id, tok=user1_tok) + + interested_rooms = self.get_success( + self.sliding_sync_handler.room_lists.compute_interested_rooms( + SlidingSyncConfig( + user=UserID.from_string(user1_id), + requester=create_requester(user_id=user1_id), + lists={ + "foo-list": SlidingSyncConfig.SlidingSyncList( + ranges=[(0, 99)], + required_state=[], + timeline_limit=1, + ) + }, + conn_id=None, + ), + PerConnectionState(), + from_token=after_join_token, + to_token=after_display_name_change_token, + ) + ) + room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids) + newly_joined = interested_rooms.newly_joined_rooms + newly_left = interested_rooms.newly_left_rooms + + # Room should show up because we were joined during the from/to range + self.assertIncludes(room_id_results, {room_id1}, exact=True) + # It should be pointing to the latest membership event in the from/to range + self.assertEqual( + interested_rooms.room_membership_for_user_map[room_id1].event_id, + displayname_change_during_token_range_response["event_id"], + "Corresponding map to disambiguate the opaque event IDs: " + + str( + { + "join_response": join_response["event_id"], + "displayname_change_during_token_range_response": displayname_change_during_token_range_response[ + "event_id" + ], + "displayname_change_after_token_range_response": displayname_change_after_token_range_response[ + "event_id" + ], + } + ), + ) + self.assertEqual( + interested_rooms.room_membership_for_user_map[room_id1].membership, + Membership.JOIN, + ) + # We only changed our display name during the token range so we shouldn't be + # considered `newly_joined` or `newly_left` + self.assertTrue(room_id1 not in newly_joined) + self.assertTrue(room_id1 not in newly_left) def test_display_name_changes_join_after_token_range( self, @@ -2051,16 +2566,29 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase): tok=user1_tok, ) - room_id_results = self.get_success( - self.sliding_sync_handler.get_room_membership_for_user_at_to_token( - UserID.from_string(user1_id), + interested_rooms = self.get_success( + self.sliding_sync_handler.room_lists.compute_interested_rooms( + SlidingSyncConfig( + user=UserID.from_string(user1_id), + requester=create_requester(user_id=user1_id), + lists={ + "foo-list": SlidingSyncConfig.SlidingSyncList( + ranges=[(0, 99)], + required_state=[], + timeline_limit=1, + ) + }, + conn_id=None, + ), + PerConnectionState(), from_token=before_room1_token, to_token=after_room1_token, ) ) + room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids) # Room shouldn't show up because we joined after the from/to range - self.assertEqual(room_id_results.keys(), set()) + self.assertIncludes(room_id_results, set(), exact=True) def test_newly_joined_with_leave_join_in_token_range( self, @@ -2087,26 +2615,44 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase): after_more_changes_token = self.event_sources.get_current_token() - room_id_results = self.get_success( - self.sliding_sync_handler.get_room_membership_for_user_at_to_token( - UserID.from_string(user1_id), + interested_rooms = self.get_success( + self.sliding_sync_handler.room_lists.compute_interested_rooms( + SlidingSyncConfig( + user=UserID.from_string(user1_id), + requester=create_requester(user_id=user1_id), + lists={ + "foo-list": SlidingSyncConfig.SlidingSyncList( + ranges=[(0, 99)], + required_state=[], + timeline_limit=1, + ) + }, + conn_id=None, + ), + PerConnectionState(), from_token=after_room1_token, to_token=after_more_changes_token, ) ) + room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids) + newly_joined = interested_rooms.newly_joined_rooms + newly_left = interested_rooms.newly_left_rooms # Room should show up because we were joined during the from/to range - self.assertEqual(room_id_results.keys(), {room_id1}) + self.assertIncludes(room_id_results, {room_id1}, exact=True) # It should be pointing to the latest membership event in the from/to range self.assertEqual( - room_id_results[room_id1].event_id, + interested_rooms.room_membership_for_user_map[room_id1].event_id, join_response2["event_id"], ) - self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN) + self.assertEqual( + interested_rooms.room_membership_for_user_map[room_id1].membership, + Membership.JOIN, + ) # We should be considered `newly_joined` because there is some non-join event in # between our latest join event. - self.assertEqual(room_id_results[room_id1].newly_joined, True) - self.assertEqual(room_id_results[room_id1].newly_left, False) + self.assertTrue(room_id1 in newly_joined) + self.assertTrue(room_id1 not in newly_left) def test_newly_joined_only_joins_during_token_range( self, @@ -2152,19 +2698,34 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase): after_room1_token = self.event_sources.get_current_token() - room_id_results = self.get_success( - self.sliding_sync_handler.get_room_membership_for_user_at_to_token( - UserID.from_string(user1_id), + interested_rooms = self.get_success( + self.sliding_sync_handler.room_lists.compute_interested_rooms( + SlidingSyncConfig( + user=UserID.from_string(user1_id), + requester=create_requester(user_id=user1_id), + lists={ + "foo-list": SlidingSyncConfig.SlidingSyncList( + ranges=[(0, 99)], + required_state=[], + timeline_limit=1, + ) + }, + conn_id=None, + ), + PerConnectionState(), from_token=before_room1_token, to_token=after_room1_token, ) ) + room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids) + newly_joined = interested_rooms.newly_joined_rooms + newly_left = interested_rooms.newly_left_rooms # Room should show up because it was newly_left and joined during the from/to range - self.assertEqual(room_id_results.keys(), {room_id1}) + self.assertIncludes(room_id_results, {room_id1}, exact=True) # It should be pointing to the latest membership event in the from/to range self.assertEqual( - room_id_results[room_id1].event_id, + interested_rooms.room_membership_for_user_map[room_id1].event_id, displayname_change_during_token_range_response2["event_id"], "Corresponding map to disambiguate the opaque event IDs: " + str( @@ -2179,10 +2740,13 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase): } ), ) - self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN) + self.assertEqual( + interested_rooms.room_membership_for_user_map[room_id1].membership, + Membership.JOIN, + ) # We should be `newly_joined` because we first joined during the token range - self.assertEqual(room_id_results[room_id1].newly_joined, True) - self.assertEqual(room_id_results[room_id1].newly_left, False) + self.assertTrue(room_id1 in newly_joined) + self.assertTrue(room_id1 not in newly_left) def test_multiple_rooms_are_not_confused( self, @@ -2205,7 +2769,7 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase): # Invited and left the room before the token self.helper.invite(room_id1, src=user2_id, targ=user1_id, tok=user2_tok) - leave_room1_response = self.helper.leave(room_id1, user1_id, tok=user1_tok) + _leave_room1_response = self.helper.leave(room_id1, user1_id, tok=user1_tok) # Invited to room2 invite_room2_response = self.helper.invite( room_id2, src=user2_id, targ=user1_id, tok=user2_tok @@ -2228,61 +2792,71 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase): # Leave room3 self.helper.leave(room_id3, user1_id, tok=user1_tok) - room_id_results = self.get_success( - self.sliding_sync_handler.get_room_membership_for_user_at_to_token( - UserID.from_string(user1_id), + interested_rooms = self.get_success( + self.sliding_sync_handler.room_lists.compute_interested_rooms( + SlidingSyncConfig( + user=UserID.from_string(user1_id), + requester=create_requester(user_id=user1_id), + lists={ + "foo-list": SlidingSyncConfig.SlidingSyncList( + ranges=[(0, 99)], + required_state=[], + timeline_limit=1, + ) + }, + conn_id=None, + ), + PerConnectionState(), from_token=before_room3_token, to_token=after_room3_token, ) ) + room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids) + newly_joined = interested_rooms.newly_joined_rooms + newly_left = interested_rooms.newly_left_rooms - self.assertEqual( - room_id_results.keys(), + self.assertIncludes( + room_id_results, { - # Left before the from/to range - room_id1, + # Excluded because we left before the from/to range + # room_id1, # Invited before the from/to range room_id2, # `newly_left` during the from/to range room_id3, }, + exact=True, ) - # Room1 - # It should be pointing to the latest membership event in the from/to range - self.assertEqual( - room_id_results[room_id1].event_id, - leave_room1_response["event_id"], - ) - self.assertEqual(room_id_results[room_id1].membership, Membership.LEAVE) - # We should *NOT* be `newly_joined`/`newly_left` because we were invited and left - # before the token range - self.assertEqual(room_id_results[room_id1].newly_joined, False) - self.assertEqual(room_id_results[room_id1].newly_left, False) - # Room2 # It should be pointing to the latest membership event in the from/to range self.assertEqual( - room_id_results[room_id2].event_id, + interested_rooms.room_membership_for_user_map[room_id2].event_id, invite_room2_response["event_id"], ) - self.assertEqual(room_id_results[room_id2].membership, Membership.INVITE) + self.assertEqual( + interested_rooms.room_membership_for_user_map[room_id2].membership, + Membership.INVITE, + ) # We should *NOT* be `newly_joined`/`newly_left` because we were invited before # the token range - self.assertEqual(room_id_results[room_id2].newly_joined, False) - self.assertEqual(room_id_results[room_id2].newly_left, False) + self.assertTrue(room_id2 not in newly_joined) + self.assertTrue(room_id2 not in newly_left) # Room3 # It should be pointing to the latest membership event in the from/to range self.assertEqual( - room_id_results[room_id3].event_id, + interested_rooms.room_membership_for_user_map[room_id3].event_id, leave_room3_response["event_id"], ) - self.assertEqual(room_id_results[room_id3].membership, Membership.LEAVE) + self.assertEqual( + interested_rooms.room_membership_for_user_map[room_id3].membership, + Membership.LEAVE, + ) # We should be `newly_left` because we were invited and left during # the token range - self.assertEqual(room_id_results[room_id3].newly_joined, False) - self.assertEqual(room_id_results[room_id3].newly_left, True) + self.assertTrue(room_id3 not in newly_joined) + self.assertTrue(room_id3 in newly_left) def test_state_reset(self) -> None: """ @@ -2295,7 +2869,16 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase): user2_tok = self.login(user2_id, "pass") # The room where the state reset will happen - room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok) + room_id1 = self.helper.create_room_as( + user2_id, + is_public=True, + tok=user2_tok, + ) + # Create a dummy event for us to point back to for the state reset + dummy_event_response = self.helper.send(room_id1, "test", tok=user2_tok) + dummy_event_id = dummy_event_response["event_id"] + + # Join after the dummy event join_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok) # Join another room so we don't hit the short-circuit and return early if they @@ -2305,95 +2888,106 @@ class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase): before_reset_token = self.event_sources.get_current_token() - # Send another state event to make a position for the state reset to happen at - dummy_state_response = self.helper.send_state( - room_id1, - event_type="foobarbaz", - state_key="", - body={"foo": "bar"}, - tok=user2_tok, - ) - dummy_state_pos = self.get_success( - self.store.get_position_for_event(dummy_state_response["event_id"]) - ) - - # Mock a state reset removing the membership for user1 in the current state - self.get_success( - self.store.db_pool.simple_delete( - table="current_state_events", - keyvalues={ - "room_id": room_id1, - "type": EventTypes.Member, - "state_key": user1_id, - }, - desc="state reset user in current_state_events", + # Trigger a state reset + join_rule_event, join_rule_context = self.get_success( + create_event( + self.hs, + prev_event_ids=[dummy_event_id], + type=EventTypes.JoinRules, + state_key="", + content={"join_rule": JoinRules.INVITE}, + sender=user2_id, + room_id=room_id1, + room_version=self.get_success(self.store.get_room_version_id(room_id1)), ) ) - self.get_success( - self.store.db_pool.simple_delete( - table="local_current_membership", - keyvalues={ - "room_id": room_id1, - "user_id": user1_id, - }, - desc="state reset user in local_current_membership", - ) - ) - self.get_success( - self.store.db_pool.simple_insert( - table="current_state_delta_stream", - values={ - "stream_id": dummy_state_pos.stream, - "room_id": room_id1, - "type": EventTypes.Member, - "state_key": user1_id, - "event_id": None, - "prev_event_id": join_response1["event_id"], - "instance_name": dummy_state_pos.instance_name, - }, - desc="state reset user in current_state_delta_stream", - ) + _, join_rule_event_pos, _ = self.get_success( + self.persistence.persist_event(join_rule_event, join_rule_context) ) - # Manually bust the cache since we we're just manually messing with the database - # and not causing an actual state reset. - self.store._membership_stream_cache.entity_has_changed( - user1_id, dummy_state_pos.stream - ) + # Ensure that the state reset worked and only user2 is in the room now + users_in_room = self.get_success(self.store.get_users_in_room(room_id1)) + self.assertIncludes(set(users_in_room), {user2_id}, exact=True) after_reset_token = self.event_sources.get_current_token() # The function under test - room_id_results = self.get_success( - self.sliding_sync_handler.get_room_membership_for_user_at_to_token( - UserID.from_string(user1_id), + interested_rooms = self.get_success( + self.sliding_sync_handler.room_lists.compute_interested_rooms( + SlidingSyncConfig( + user=UserID.from_string(user1_id), + requester=create_requester(user_id=user1_id), + lists={ + "foo-list": SlidingSyncConfig.SlidingSyncList( + ranges=[(0, 99)], + required_state=[], + timeline_limit=1, + ) + }, + conn_id=None, + ), + PerConnectionState(), from_token=before_reset_token, to_token=after_reset_token, ) ) + room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids) + newly_joined = interested_rooms.newly_joined_rooms + newly_left = interested_rooms.newly_left_rooms # Room1 should show up because it was `newly_left` via state reset during the from/to range - self.assertEqual(room_id_results.keys(), {room_id1, room_id2}) + self.assertIncludes(room_id_results, {room_id1, room_id2}, exact=True) # It should be pointing to no event because we were removed from the room # without a corresponding leave event self.assertEqual( - room_id_results[room_id1].event_id, + interested_rooms.room_membership_for_user_map[room_id1].event_id, None, + "Corresponding map to disambiguate the opaque event IDs: " + + str( + { + "join_response1": join_response1["event_id"], + } + ), ) # State reset caused us to leave the room and there is no corresponding leave event - self.assertEqual(room_id_results[room_id1].membership, Membership.LEAVE) + self.assertEqual( + interested_rooms.room_membership_for_user_map[room_id1].membership, + Membership.LEAVE, + ) # We should *NOT* be `newly_joined` because we joined before the token range - self.assertEqual(room_id_results[room_id1].newly_joined, False) + self.assertTrue(room_id1 not in newly_joined) # We should be `newly_left` because we were removed via state reset during the from/to range - self.assertEqual(room_id_results[room_id1].newly_left, True) - - -class GetRoomMembershipForUserAtToTokenShardTestCase(BaseMultiWorkerStreamTestCase): + self.assertTrue(room_id1 in newly_left) + + +# FIXME: This can be removed once we bump `SCHEMA_COMPAT_VERSION` and run the +# foreground update for +# `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` (tracked by +# https://github.com/element-hq/synapse/issues/17623) +@parameterized_class( + ("use_new_tables",), + [ + (True,), + (False,), + ], + class_name_func=lambda cls, + num, + params_dict: f"{cls.__name__}_{'new' if params_dict['use_new_tables'] else 'fallback'}", +) +class ComputeInterestedRoomsShardTestCase( + BaseMultiWorkerStreamTestCase, SlidingSyncBase +): """ - Tests Sliding Sync handler `get_room_membership_for_user_at_to_token()` to make sure it works with + Tests Sliding Sync handler `compute_interested_rooms()` to make sure it works with sharded event stream_writers enabled """ + # FIXME: We should refactor these tests to run against `compute_interested_rooms(...)` + # instead of just `get_room_membership_for_user_at_to_token(...)` which is only used + # in the fallback path (`_compute_interested_rooms_fallback(...)`). These scenarios do + # well to stress that logic and we shouldn't remove them just because we're removing + # the fallback path (tracked by https://github.com/element-hq/synapse/issues/17623). + servlets = [ admin.register_servlets_for_client_rest_resource, room.register_servlets, @@ -2488,7 +3082,7 @@ class GetRoomMembershipForUserAtToTokenShardTestCase(BaseMultiWorkerStreamTestCa join_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok) join_response2 = self.helper.join(room_id2, user1_id, tok=user1_tok) # Leave room2 - leave_room2_response = self.helper.leave(room_id2, user1_id, tok=user1_tok) + _leave_room2_response = self.helper.leave(room_id2, user1_id, tok=user1_tok) join_response3 = self.helper.join(room_id3, user1_id, tok=user1_tok) # Leave room3 self.helper.leave(room_id3, user1_id, tok=user1_tok) @@ -2578,60 +3172,77 @@ class GetRoomMembershipForUserAtToTokenShardTestCase(BaseMultiWorkerStreamTestCa self.get_success(actx.__aexit__(None, None, None)) # The function under test - room_id_results = self.get_success( - self.sliding_sync_handler.get_room_membership_for_user_at_to_token( - UserID.from_string(user1_id), + interested_rooms = self.get_success( + self.sliding_sync_handler.room_lists.compute_interested_rooms( + SlidingSyncConfig( + user=UserID.from_string(user1_id), + requester=create_requester(user_id=user1_id), + lists={ + "foo-list": SlidingSyncConfig.SlidingSyncList( + ranges=[(0, 99)], + required_state=[], + timeline_limit=1, + ) + }, + conn_id=None, + ), + PerConnectionState(), from_token=before_stuck_activity_token, to_token=stuck_activity_token, ) ) + room_id_results = set(interested_rooms.lists["foo-list"].ops[0].room_ids) + newly_joined = interested_rooms.newly_joined_rooms + newly_left = interested_rooms.newly_left_rooms - self.assertEqual( - room_id_results.keys(), + self.assertIncludes( + room_id_results, { room_id1, - room_id2, + # Excluded because we left before the from/to range and the second join + # event happened while worker2 was stuck and technically occurs after + # the `stuck_activity_token`. + # room_id2, room_id3, }, + exact=True, + message="Corresponding map to disambiguate the opaque room IDs: " + + str( + { + "room_id1": room_id1, + "room_id2": room_id2, + "room_id3": room_id3, + } + ), ) # Room1 # It should be pointing to the latest membership event in the from/to range self.assertEqual( - room_id_results[room_id1].event_id, + interested_rooms.room_membership_for_user_map[room_id1].event_id, join_room1_response["event_id"], ) - self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN) - # We should be `newly_joined` because we joined during the token range - self.assertEqual(room_id_results[room_id1].newly_joined, True) - self.assertEqual(room_id_results[room_id1].newly_left, False) - - # Room2 - # It should be pointing to the latest membership event in the from/to range self.assertEqual( - room_id_results[room_id2].event_id, - leave_room2_response["event_id"], - ) - self.assertEqual(room_id_results[room_id2].membership, Membership.LEAVE) - # room_id2 should *NOT* be considered `newly_left` because we left before the - # from/to range and the join event during the range happened while worker2 was - # stuck. This means that from the perspective of the master, where the - # `stuck_activity_token` is generated, the stream position for worker2 wasn't - # advanced to the join yet. Looking at the `instance_map`, the join technically - # comes after `stuck_activity_token`. - self.assertEqual(room_id_results[room_id2].newly_joined, False) - self.assertEqual(room_id_results[room_id2].newly_left, False) + interested_rooms.room_membership_for_user_map[room_id1].membership, + Membership.JOIN, + ) + # We should be `newly_joined` because we joined during the token range + self.assertTrue(room_id1 in newly_joined) + self.assertTrue(room_id1 not in newly_left) # Room3 # It should be pointing to the latest membership event in the from/to range self.assertEqual( - room_id_results[room_id3].event_id, + interested_rooms.room_membership_for_user_map[room_id3].event_id, join_on_worker3_response["event_id"], ) - self.assertEqual(room_id_results[room_id3].membership, Membership.JOIN) + self.assertEqual( + interested_rooms.room_membership_for_user_map[room_id3].membership, + Membership.JOIN, + ) # We should be `newly_joined` because we joined during the token range - self.assertEqual(room_id_results[room_id3].newly_joined, True) - self.assertEqual(room_id_results[room_id3].newly_left, False) + self.assertTrue(room_id3 in newly_joined) + self.assertTrue(room_id3 not in newly_left) class FilterRoomsRelevantForSyncTestCase(HomeserverTestCase): @@ -2658,31 +3269,35 @@ class FilterRoomsRelevantForSyncTestCase(HomeserverTestCase): self.store = self.hs.get_datastores().main self.event_sources = hs.get_event_sources() self.storage_controllers = hs.get_storage_controllers() + persistence = self.hs.get_storage_controllers().persistence + assert persistence is not None + self.persistence = persistence def _get_sync_room_ids_for_user( self, user: UserID, to_token: StreamToken, from_token: Optional[StreamToken], - ) -> Dict[str, _RoomMembershipForUser]: + ) -> Tuple[Dict[str, RoomsForUserType], AbstractSet[str], AbstractSet[str]]: """ Get the rooms the user should be syncing with """ - room_membership_for_user_map = self.get_success( - self.sliding_sync_handler.get_room_membership_for_user_at_to_token( + room_membership_for_user_map, newly_joined, newly_left = self.get_success( + self.sliding_sync_handler.room_lists.get_room_membership_for_user_at_to_token( user=user, from_token=from_token, to_token=to_token, ) ) filtered_sync_room_map = self.get_success( - self.sliding_sync_handler.filter_rooms_relevant_for_sync( + self.sliding_sync_handler.room_lists.filter_rooms_relevant_for_sync( user=user, room_membership_for_user_map=room_membership_for_user_map, + newly_left_room_ids=newly_left, ) ) - return filtered_sync_room_map + return filtered_sync_room_map, newly_joined, newly_left def test_no_rooms(self) -> None: """ @@ -2693,13 +3308,13 @@ class FilterRoomsRelevantForSyncTestCase(HomeserverTestCase): now_token = self.event_sources.get_current_token() - room_id_results = self._get_sync_room_ids_for_user( + room_id_results, newly_joined, newly_left = self._get_sync_room_ids_for_user( UserID.from_string(user1_id), from_token=now_token, to_token=now_token, ) - self.assertEqual(room_id_results.keys(), set()) + self.assertIncludes(room_id_results.keys(), set(), exact=True) def test_basic_rooms(self) -> None: """ @@ -2758,14 +3373,14 @@ class FilterRoomsRelevantForSyncTestCase(HomeserverTestCase): after_room_token = self.event_sources.get_current_token() - room_id_results = self._get_sync_room_ids_for_user( + room_id_results, newly_joined, newly_left = self._get_sync_room_ids_for_user( UserID.from_string(user1_id), from_token=before_room_token, to_token=after_room_token, ) # Ensure that the invited, ban, and knock rooms show up - self.assertEqual( + self.assertIncludes( room_id_results.keys(), { join_room_id, @@ -2773,6 +3388,7 @@ class FilterRoomsRelevantForSyncTestCase(HomeserverTestCase): ban_room_id, knock_room_id, }, + exact=True, ) # It should be pointing to the the respective membership event (latest # membership event in the from/to range) @@ -2781,32 +3397,32 @@ class FilterRoomsRelevantForSyncTestCase(HomeserverTestCase): join_response["event_id"], ) self.assertEqual(room_id_results[join_room_id].membership, Membership.JOIN) - self.assertEqual(room_id_results[join_room_id].newly_joined, True) - self.assertEqual(room_id_results[join_room_id].newly_left, False) + self.assertTrue(join_room_id in newly_joined) + self.assertTrue(join_room_id not in newly_left) self.assertEqual( room_id_results[invited_room_id].event_id, invite_response["event_id"], ) self.assertEqual(room_id_results[invited_room_id].membership, Membership.INVITE) - self.assertEqual(room_id_results[invited_room_id].newly_joined, False) - self.assertEqual(room_id_results[invited_room_id].newly_left, False) + self.assertTrue(invited_room_id not in newly_joined) + self.assertTrue(invited_room_id not in newly_left) self.assertEqual( room_id_results[ban_room_id].event_id, ban_response["event_id"], ) self.assertEqual(room_id_results[ban_room_id].membership, Membership.BAN) - self.assertEqual(room_id_results[ban_room_id].newly_joined, False) - self.assertEqual(room_id_results[ban_room_id].newly_left, False) + self.assertTrue(ban_room_id not in newly_joined) + self.assertTrue(ban_room_id not in newly_left) self.assertEqual( room_id_results[knock_room_id].event_id, knock_room_membership_state_event.event_id, ) self.assertEqual(room_id_results[knock_room_id].membership, Membership.KNOCK) - self.assertEqual(room_id_results[knock_room_id].newly_joined, False) - self.assertEqual(room_id_results[knock_room_id].newly_left, False) + self.assertTrue(knock_room_id not in newly_joined) + self.assertTrue(knock_room_id not in newly_left) def test_only_newly_left_rooms_show_up(self) -> None: """ @@ -2829,21 +3445,21 @@ class FilterRoomsRelevantForSyncTestCase(HomeserverTestCase): after_room2_token = self.event_sources.get_current_token() - room_id_results = self._get_sync_room_ids_for_user( + room_id_results, newly_joined, newly_left = self._get_sync_room_ids_for_user( UserID.from_string(user1_id), from_token=after_room1_token, to_token=after_room2_token, ) # Only the `newly_left` room should show up - self.assertEqual(room_id_results.keys(), {room_id2}) + self.assertIncludes(room_id_results.keys(), {room_id2}, exact=True) self.assertEqual( room_id_results[room_id2].event_id, _leave_response2["event_id"], ) # We should *NOT* be `newly_joined` because we are instead `newly_left` - self.assertEqual(room_id_results[room_id2].newly_joined, False) - self.assertEqual(room_id_results[room_id2].newly_left, True) + self.assertTrue(room_id2 not in newly_joined) + self.assertTrue(room_id2 in newly_left) def test_get_kicked_room(self) -> None: """ @@ -2874,14 +3490,14 @@ class FilterRoomsRelevantForSyncTestCase(HomeserverTestCase): after_kick_token = self.event_sources.get_current_token() - room_id_results = self._get_sync_room_ids_for_user( + room_id_results, newly_joined, newly_left = self._get_sync_room_ids_for_user( UserID.from_string(user1_id), from_token=after_kick_token, to_token=after_kick_token, ) # The kicked room should show up - self.assertEqual(room_id_results.keys(), {kick_room_id}) + self.assertIncludes(room_id_results.keys(), {kick_room_id}, exact=True) # It should be pointing to the latest membership event in the from/to range self.assertEqual( room_id_results[kick_room_id].event_id, @@ -2891,8 +3507,8 @@ class FilterRoomsRelevantForSyncTestCase(HomeserverTestCase): self.assertNotEqual(room_id_results[kick_room_id].sender, user1_id) # We should *NOT* be `newly_joined` because we were not joined at the the time # of the `to_token`. - self.assertEqual(room_id_results[kick_room_id].newly_joined, False) - self.assertEqual(room_id_results[kick_room_id].newly_left, False) + self.assertTrue(kick_room_id not in newly_joined) + self.assertTrue(kick_room_id not in newly_left) def test_state_reset(self) -> None: """ @@ -2905,8 +3521,17 @@ class FilterRoomsRelevantForSyncTestCase(HomeserverTestCase): user2_tok = self.login(user2_id, "pass") # The room where the state reset will happen - room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok) - join_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok) + room_id1 = self.helper.create_room_as( + user2_id, + is_public=True, + tok=user2_tok, + ) + # Create a dummy event for us to point back to for the state reset + dummy_event_response = self.helper.send(room_id1, "test", tok=user2_tok) + dummy_event_id = dummy_event_response["event_id"] + + # Join after the dummy event + self.helper.join(room_id1, user1_id, tok=user1_tok) # Join another room so we don't hit the short-circuit and return early if they # have no room membership @@ -2915,73 +3540,38 @@ class FilterRoomsRelevantForSyncTestCase(HomeserverTestCase): before_reset_token = self.event_sources.get_current_token() - # Send another state event to make a position for the state reset to happen at - dummy_state_response = self.helper.send_state( - room_id1, - event_type="foobarbaz", - state_key="", - body={"foo": "bar"}, - tok=user2_tok, - ) - dummy_state_pos = self.get_success( - self.store.get_position_for_event(dummy_state_response["event_id"]) - ) - - # Mock a state reset removing the membership for user1 in the current state - self.get_success( - self.store.db_pool.simple_delete( - table="current_state_events", - keyvalues={ - "room_id": room_id1, - "type": EventTypes.Member, - "state_key": user1_id, - }, - desc="state reset user in current_state_events", + # Trigger a state reset + join_rule_event, join_rule_context = self.get_success( + create_event( + self.hs, + prev_event_ids=[dummy_event_id], + type=EventTypes.JoinRules, + state_key="", + content={"join_rule": JoinRules.INVITE}, + sender=user2_id, + room_id=room_id1, + room_version=self.get_success(self.store.get_room_version_id(room_id1)), ) ) - self.get_success( - self.store.db_pool.simple_delete( - table="local_current_membership", - keyvalues={ - "room_id": room_id1, - "user_id": user1_id, - }, - desc="state reset user in local_current_membership", - ) - ) - self.get_success( - self.store.db_pool.simple_insert( - table="current_state_delta_stream", - values={ - "stream_id": dummy_state_pos.stream, - "room_id": room_id1, - "type": EventTypes.Member, - "state_key": user1_id, - "event_id": None, - "prev_event_id": join_response1["event_id"], - "instance_name": dummy_state_pos.instance_name, - }, - desc="state reset user in current_state_delta_stream", - ) + _, join_rule_event_pos, _ = self.get_success( + self.persistence.persist_event(join_rule_event, join_rule_context) ) - # Manually bust the cache since we we're just manually messing with the database - # and not causing an actual state reset. - self.store._membership_stream_cache.entity_has_changed( - user1_id, dummy_state_pos.stream - ) + # Ensure that the state reset worked and only user2 is in the room now + users_in_room = self.get_success(self.store.get_users_in_room(room_id1)) + self.assertIncludes(set(users_in_room), {user2_id}, exact=True) after_reset_token = self.event_sources.get_current_token() # The function under test - room_id_results = self._get_sync_room_ids_for_user( + room_id_results, newly_joined, newly_left = self._get_sync_room_ids_for_user( UserID.from_string(user1_id), from_token=before_reset_token, to_token=after_reset_token, ) # Room1 should show up because it was `newly_left` via state reset during the from/to range - self.assertEqual(room_id_results.keys(), {room_id1, room_id2}) + self.assertIncludes(room_id_results.keys(), {room_id1, room_id2}, exact=True) # It should be pointing to no event because we were removed from the room # without a corresponding leave event self.assertEqual( @@ -2991,1345 +3581,9 @@ class FilterRoomsRelevantForSyncTestCase(HomeserverTestCase): # State reset caused us to leave the room and there is no corresponding leave event self.assertEqual(room_id_results[room_id1].membership, Membership.LEAVE) # We should *NOT* be `newly_joined` because we joined before the token range - self.assertEqual(room_id_results[room_id1].newly_joined, False) + self.assertTrue(room_id1 not in newly_joined) # We should be `newly_left` because we were removed via state reset during the from/to range - self.assertEqual(room_id_results[room_id1].newly_left, True) - - -class FilterRoomsTestCase(HomeserverTestCase): - """ - Tests Sliding Sync handler `filter_rooms()` to make sure it includes/excludes rooms - correctly. - """ - - servlets = [ - admin.register_servlets, - knock.register_servlets, - login.register_servlets, - room.register_servlets, - ] - - def default_config(self) -> JsonDict: - config = super().default_config() - # Enable sliding sync - config["experimental_features"] = {"msc3575_enabled": True} - return config - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.sliding_sync_handler = self.hs.get_sliding_sync_handler() - self.store = self.hs.get_datastores().main - self.event_sources = hs.get_event_sources() - - def _get_sync_room_ids_for_user( - self, - user: UserID, - to_token: StreamToken, - from_token: Optional[StreamToken], - ) -> Dict[str, _RoomMembershipForUser]: - """ - Get the rooms the user should be syncing with - """ - room_membership_for_user_map = self.get_success( - self.sliding_sync_handler.get_room_membership_for_user_at_to_token( - user=user, - from_token=from_token, - to_token=to_token, - ) - ) - filtered_sync_room_map = self.get_success( - self.sliding_sync_handler.filter_rooms_relevant_for_sync( - user=user, - room_membership_for_user_map=room_membership_for_user_map, - ) - ) - - return filtered_sync_room_map - - def _create_dm_room( - self, - inviter_user_id: str, - inviter_tok: str, - invitee_user_id: str, - invitee_tok: str, - ) -> str: - """ - Helper to create a DM room as the "inviter" and invite the "invitee" user to the room. The - "invitee" user also will join the room. The `m.direct` account data will be set - for both users. - """ - - # Create a room and send an invite the other user - room_id = self.helper.create_room_as( - inviter_user_id, - is_public=False, - tok=inviter_tok, - ) - self.helper.invite( - room_id, - src=inviter_user_id, - targ=invitee_user_id, - tok=inviter_tok, - extra_data={"is_direct": True}, - ) - # Person that was invited joins the room - self.helper.join(room_id, invitee_user_id, tok=invitee_tok) - - # Mimic the client setting the room as a direct message in the global account - # data - self.get_success( - self.store.add_account_data_for_user( - invitee_user_id, - AccountDataTypes.DIRECT, - {inviter_user_id: [room_id]}, - ) - ) - self.get_success( - self.store.add_account_data_for_user( - inviter_user_id, - AccountDataTypes.DIRECT, - {invitee_user_id: [room_id]}, - ) - ) - - return room_id - - _remote_invite_count: int = 0 - - def _create_remote_invite_room_for_user( - self, - invitee_user_id: str, - unsigned_invite_room_state: Optional[List[StrippedStateEvent]], - ) -> str: - """ - Create a fake invite for a remote room and persist it. - - We don't have any state for these kind of rooms and can only rely on the - stripped state included in the unsigned portion of the invite event to identify - the room. - - Args: - invitee_user_id: The person being invited - unsigned_invite_room_state: List of stripped state events to assist the - receiver in identifying the room. - - Returns: - The room ID of the remote invite room - """ - invite_room_id = f"!test_room{self._remote_invite_count}:remote_server" - - invite_event_dict = { - "room_id": invite_room_id, - "sender": "@inviter:remote_server", - "state_key": invitee_user_id, - "depth": 1, - "origin_server_ts": 1, - "type": EventTypes.Member, - "content": {"membership": Membership.INVITE}, - "auth_events": [], - "prev_events": [], - } - if unsigned_invite_room_state is not None: - serialized_stripped_state_events = [] - for stripped_event in unsigned_invite_room_state: - serialized_stripped_state_events.append( - { - "type": stripped_event.type, - "state_key": stripped_event.state_key, - "sender": stripped_event.sender, - "content": stripped_event.content, - } - ) - - invite_event_dict["unsigned"] = { - "invite_room_state": serialized_stripped_state_events - } - - invite_event = make_event_from_dict( - invite_event_dict, - room_version=RoomVersions.V10, - ) - invite_event.internal_metadata.outlier = True - invite_event.internal_metadata.out_of_band_membership = True - - self.get_success( - self.store.maybe_store_room_on_outlier_membership( - room_id=invite_room_id, room_version=invite_event.room_version - ) - ) - context = EventContext.for_outlier(self.hs.get_storage_controllers()) - persist_controller = self.hs.get_storage_controllers().persistence - assert persist_controller is not None - self.get_success(persist_controller.persist_event(invite_event, context)) - - self._remote_invite_count += 1 - - return invite_room_id - - def test_filter_dm_rooms(self) -> None: - """ - Test `filter.is_dm` for DM rooms - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - # Create a normal room - room_id = self.helper.create_room_as(user1_id, tok=user1_tok) - - # Create a DM room - dm_room_id = self._create_dm_room( - inviter_user_id=user1_id, - inviter_tok=user1_tok, - invitee_user_id=user2_id, - invitee_tok=user2_tok, - ) - - after_rooms_token = self.event_sources.get_current_token() - - # Get the rooms the user should be syncing with - sync_room_map = self._get_sync_room_ids_for_user( - UserID.from_string(user1_id), - from_token=None, - to_token=after_rooms_token, - ) - - # Try with `is_dm=True` - truthy_filtered_room_map = self.get_success( - self.sliding_sync_handler.filter_rooms( - UserID.from_string(user1_id), - sync_room_map, - SlidingSyncConfig.SlidingSyncList.Filters( - is_dm=True, - ), - after_rooms_token, - ) - ) - - self.assertEqual(truthy_filtered_room_map.keys(), {dm_room_id}) - - # Try with `is_dm=False` - falsy_filtered_room_map = self.get_success( - self.sliding_sync_handler.filter_rooms( - UserID.from_string(user1_id), - sync_room_map, - SlidingSyncConfig.SlidingSyncList.Filters( - is_dm=False, - ), - after_rooms_token, - ) - ) - - self.assertEqual(falsy_filtered_room_map.keys(), {room_id}) - - def test_filter_encrypted_rooms(self) -> None: - """ - Test `filter.is_encrypted` for encrypted rooms - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - - # Create an unencrypted room - room_id = self.helper.create_room_as(user1_id, tok=user1_tok) - - # Create an encrypted room - encrypted_room_id = self.helper.create_room_as(user1_id, tok=user1_tok) - self.helper.send_state( - encrypted_room_id, - EventTypes.RoomEncryption, - {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"}, - tok=user1_tok, - ) - - after_rooms_token = self.event_sources.get_current_token() - - # Get the rooms the user should be syncing with - sync_room_map = self._get_sync_room_ids_for_user( - UserID.from_string(user1_id), - from_token=None, - to_token=after_rooms_token, - ) - - # Try with `is_encrypted=True` - truthy_filtered_room_map = self.get_success( - self.sliding_sync_handler.filter_rooms( - UserID.from_string(user1_id), - sync_room_map, - SlidingSyncConfig.SlidingSyncList.Filters( - is_encrypted=True, - ), - after_rooms_token, - ) - ) - - self.assertEqual(truthy_filtered_room_map.keys(), {encrypted_room_id}) - - # Try with `is_encrypted=False` - falsy_filtered_room_map = self.get_success( - self.sliding_sync_handler.filter_rooms( - UserID.from_string(user1_id), - sync_room_map, - SlidingSyncConfig.SlidingSyncList.Filters( - is_encrypted=False, - ), - after_rooms_token, - ) - ) - - self.assertEqual(falsy_filtered_room_map.keys(), {room_id}) - - def test_filter_encrypted_server_left_room(self) -> None: - """ - Test that we can apply a `filter.is_encrypted` against a room that everyone has left. - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - - before_rooms_token = self.event_sources.get_current_token() - - # Create an unencrypted room - room_id = self.helper.create_room_as(user1_id, tok=user1_tok) - # Leave the room - self.helper.leave(room_id, user1_id, tok=user1_tok) - - # Create an encrypted room - encrypted_room_id = self.helper.create_room_as(user1_id, tok=user1_tok) - self.helper.send_state( - encrypted_room_id, - EventTypes.RoomEncryption, - {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"}, - tok=user1_tok, - ) - # Leave the room - self.helper.leave(encrypted_room_id, user1_id, tok=user1_tok) - - after_rooms_token = self.event_sources.get_current_token() - - # Get the rooms the user should be syncing with - sync_room_map = self._get_sync_room_ids_for_user( - UserID.from_string(user1_id), - # We're using a `from_token` so that the room is considered `newly_left` and - # appears in our list of relevant sync rooms - from_token=before_rooms_token, - to_token=after_rooms_token, - ) - - # Try with `is_encrypted=True` - truthy_filtered_room_map = self.get_success( - self.sliding_sync_handler.filter_rooms( - UserID.from_string(user1_id), - sync_room_map, - SlidingSyncConfig.SlidingSyncList.Filters( - is_encrypted=True, - ), - after_rooms_token, - ) - ) - - self.assertEqual(truthy_filtered_room_map.keys(), {encrypted_room_id}) - - # Try with `is_encrypted=False` - falsy_filtered_room_map = self.get_success( - self.sliding_sync_handler.filter_rooms( - UserID.from_string(user1_id), - sync_room_map, - SlidingSyncConfig.SlidingSyncList.Filters( - is_encrypted=False, - ), - after_rooms_token, - ) - ) - - self.assertEqual(falsy_filtered_room_map.keys(), {room_id}) - - def test_filter_encrypted_server_left_room2(self) -> None: - """ - Test that we can apply a `filter.is_encrypted` against a room that everyone has - left. - - There is still someone local who is invited to the rooms but that doesn't affect - whether the server is participating in the room (users need to be joined). - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - _user2_tok = self.login(user2_id, "pass") - - before_rooms_token = self.event_sources.get_current_token() - - # Create an unencrypted room - room_id = self.helper.create_room_as(user1_id, tok=user1_tok) - # Invite user2 - self.helper.invite(room_id, targ=user2_id, tok=user1_tok) - # User1 leaves the room - self.helper.leave(room_id, user1_id, tok=user1_tok) - - # Create an encrypted room - encrypted_room_id = self.helper.create_room_as(user1_id, tok=user1_tok) - self.helper.send_state( - encrypted_room_id, - EventTypes.RoomEncryption, - {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"}, - tok=user1_tok, - ) - # Invite user2 - self.helper.invite(encrypted_room_id, targ=user2_id, tok=user1_tok) - # User1 leaves the room - self.helper.leave(encrypted_room_id, user1_id, tok=user1_tok) - - after_rooms_token = self.event_sources.get_current_token() - - # Get the rooms the user should be syncing with - sync_room_map = self._get_sync_room_ids_for_user( - UserID.from_string(user1_id), - # We're using a `from_token` so that the room is considered `newly_left` and - # appears in our list of relevant sync rooms - from_token=before_rooms_token, - to_token=after_rooms_token, - ) - - # Try with `is_encrypted=True` - truthy_filtered_room_map = self.get_success( - self.sliding_sync_handler.filter_rooms( - UserID.from_string(user1_id), - sync_room_map, - SlidingSyncConfig.SlidingSyncList.Filters( - is_encrypted=True, - ), - after_rooms_token, - ) - ) - - self.assertEqual(truthy_filtered_room_map.keys(), {encrypted_room_id}) - - # Try with `is_encrypted=False` - falsy_filtered_room_map = self.get_success( - self.sliding_sync_handler.filter_rooms( - UserID.from_string(user1_id), - sync_room_map, - SlidingSyncConfig.SlidingSyncList.Filters( - is_encrypted=False, - ), - after_rooms_token, - ) - ) - - self.assertEqual(falsy_filtered_room_map.keys(), {room_id}) - - def test_filter_encrypted_after_we_left(self) -> None: - """ - Test that we can apply a `filter.is_encrypted` against a room that was encrypted - after we left the room (make sure we don't just use the current state) - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - before_rooms_token = self.event_sources.get_current_token() - - # Create an unencrypted room - room_id = self.helper.create_room_as(user2_id, tok=user2_tok) - # Leave the room - self.helper.join(room_id, user1_id, tok=user1_tok) - self.helper.leave(room_id, user1_id, tok=user1_tok) - - # Create a room that will be encrypted - encrypted_after_we_left_room_id = self.helper.create_room_as( - user2_id, tok=user2_tok - ) - # Leave the room - self.helper.join(encrypted_after_we_left_room_id, user1_id, tok=user1_tok) - self.helper.leave(encrypted_after_we_left_room_id, user1_id, tok=user1_tok) - - # Encrypt the room after we've left - self.helper.send_state( - encrypted_after_we_left_room_id, - EventTypes.RoomEncryption, - {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"}, - tok=user2_tok, - ) - - after_rooms_token = self.event_sources.get_current_token() - - # Get the rooms the user should be syncing with - sync_room_map = self._get_sync_room_ids_for_user( - UserID.from_string(user1_id), - # We're using a `from_token` so that the room is considered `newly_left` and - # appears in our list of relevant sync rooms - from_token=before_rooms_token, - to_token=after_rooms_token, - ) - - # Try with `is_encrypted=True` - truthy_filtered_room_map = self.get_success( - self.sliding_sync_handler.filter_rooms( - UserID.from_string(user1_id), - sync_room_map, - SlidingSyncConfig.SlidingSyncList.Filters( - is_encrypted=True, - ), - after_rooms_token, - ) - ) - - # Even though we left the room before it was encrypted, we still see it because - # someone else on our server is still participating in the room and we "leak" - # the current state to the left user. But we consider the room encryption status - # to not be a secret given it's often set at the start of the room and it's one - # of the stripped state events that is normally handed out. - self.assertEqual( - truthy_filtered_room_map.keys(), {encrypted_after_we_left_room_id} - ) - - # Try with `is_encrypted=False` - falsy_filtered_room_map = self.get_success( - self.sliding_sync_handler.filter_rooms( - UserID.from_string(user1_id), - sync_room_map, - SlidingSyncConfig.SlidingSyncList.Filters( - is_encrypted=False, - ), - after_rooms_token, - ) - ) - - # Even though we left the room before it was encrypted... (see comment above) - self.assertEqual(falsy_filtered_room_map.keys(), {room_id}) - - def test_filter_encrypted_with_remote_invite_room_no_stripped_state(self) -> None: - """ - Test that we can apply a `filter.is_encrypted` filter against a remote invite - room without any `unsigned.invite_room_state` (stripped state). - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - - # Create a remote invite room without any `unsigned.invite_room_state` - _remote_invite_room_id = self._create_remote_invite_room_for_user( - user1_id, None - ) - - # Create an unencrypted room - room_id = self.helper.create_room_as(user1_id, tok=user1_tok) - - # Create an encrypted room - encrypted_room_id = self.helper.create_room_as(user1_id, tok=user1_tok) - self.helper.send_state( - encrypted_room_id, - EventTypes.RoomEncryption, - {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"}, - tok=user1_tok, - ) - - after_rooms_token = self.event_sources.get_current_token() - - # Get the rooms the user should be syncing with - sync_room_map = self._get_sync_room_ids_for_user( - UserID.from_string(user1_id), - from_token=None, - to_token=after_rooms_token, - ) - - # Try with `is_encrypted=True` - truthy_filtered_room_map = self.get_success( - self.sliding_sync_handler.filter_rooms( - UserID.from_string(user1_id), - sync_room_map, - SlidingSyncConfig.SlidingSyncList.Filters( - is_encrypted=True, - ), - after_rooms_token, - ) - ) - - # `remote_invite_room_id` should not appear because we can't figure out whether - # it is encrypted or not (no stripped state, `unsigned.invite_room_state`). - self.assertEqual(truthy_filtered_room_map.keys(), {encrypted_room_id}) - - # Try with `is_encrypted=False` - falsy_filtered_room_map = self.get_success( - self.sliding_sync_handler.filter_rooms( - UserID.from_string(user1_id), - sync_room_map, - SlidingSyncConfig.SlidingSyncList.Filters( - is_encrypted=False, - ), - after_rooms_token, - ) - ) - - # `remote_invite_room_id` should not appear because we can't figure out whether - # it is encrypted or not (no stripped state, `unsigned.invite_room_state`). - self.assertEqual(falsy_filtered_room_map.keys(), {room_id}) - - def test_filter_encrypted_with_remote_invite_encrypted_room(self) -> None: - """ - Test that we can apply a `filter.is_encrypted` filter against a remote invite - encrypted room with some `unsigned.invite_room_state` (stripped state). - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - - # Create a remote invite room with some `unsigned.invite_room_state` - # indicating that the room is encrypted. - remote_invite_room_id = self._create_remote_invite_room_for_user( - user1_id, - [ - StrippedStateEvent( - type=EventTypes.Create, - state_key="", - sender="@inviter:remote_server", - content={ - EventContentFields.ROOM_CREATOR: "@inviter:remote_server", - EventContentFields.ROOM_VERSION: RoomVersions.V10.identifier, - }, - ), - StrippedStateEvent( - type=EventTypes.RoomEncryption, - state_key="", - sender="@inviter:remote_server", - content={ - EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2", - }, - ), - ], - ) - - # Create an unencrypted room - room_id = self.helper.create_room_as(user1_id, tok=user1_tok) - - # Create an encrypted room - encrypted_room_id = self.helper.create_room_as(user1_id, tok=user1_tok) - self.helper.send_state( - encrypted_room_id, - EventTypes.RoomEncryption, - {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"}, - tok=user1_tok, - ) - - after_rooms_token = self.event_sources.get_current_token() - - # Get the rooms the user should be syncing with - sync_room_map = self._get_sync_room_ids_for_user( - UserID.from_string(user1_id), - from_token=None, - to_token=after_rooms_token, - ) - - # Try with `is_encrypted=True` - truthy_filtered_room_map = self.get_success( - self.sliding_sync_handler.filter_rooms( - UserID.from_string(user1_id), - sync_room_map, - SlidingSyncConfig.SlidingSyncList.Filters( - is_encrypted=True, - ), - after_rooms_token, - ) - ) - - # `remote_invite_room_id` should appear here because it is encrypted - # according to the stripped state - self.assertEqual( - truthy_filtered_room_map.keys(), {encrypted_room_id, remote_invite_room_id} - ) - - # Try with `is_encrypted=False` - falsy_filtered_room_map = self.get_success( - self.sliding_sync_handler.filter_rooms( - UserID.from_string(user1_id), - sync_room_map, - SlidingSyncConfig.SlidingSyncList.Filters( - is_encrypted=False, - ), - after_rooms_token, - ) - ) - - # `remote_invite_room_id` should not appear here because it is encrypted - # according to the stripped state - self.assertEqual(falsy_filtered_room_map.keys(), {room_id}) - - def test_filter_encrypted_with_remote_invite_unencrypted_room(self) -> None: - """ - Test that we can apply a `filter.is_encrypted` filter against a remote invite - unencrypted room with some `unsigned.invite_room_state` (stripped state). - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - - # Create a remote invite room with some `unsigned.invite_room_state` - # but don't set any room encryption event. - remote_invite_room_id = self._create_remote_invite_room_for_user( - user1_id, - [ - StrippedStateEvent( - type=EventTypes.Create, - state_key="", - sender="@inviter:remote_server", - content={ - EventContentFields.ROOM_CREATOR: "@inviter:remote_server", - EventContentFields.ROOM_VERSION: RoomVersions.V10.identifier, - }, - ), - # No room encryption event - ], - ) - - # Create an unencrypted room - room_id = self.helper.create_room_as(user1_id, tok=user1_tok) - - # Create an encrypted room - encrypted_room_id = self.helper.create_room_as(user1_id, tok=user1_tok) - self.helper.send_state( - encrypted_room_id, - EventTypes.RoomEncryption, - {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"}, - tok=user1_tok, - ) - - after_rooms_token = self.event_sources.get_current_token() - - # Get the rooms the user should be syncing with - sync_room_map = self._get_sync_room_ids_for_user( - UserID.from_string(user1_id), - from_token=None, - to_token=after_rooms_token, - ) - - # Try with `is_encrypted=True` - truthy_filtered_room_map = self.get_success( - self.sliding_sync_handler.filter_rooms( - UserID.from_string(user1_id), - sync_room_map, - SlidingSyncConfig.SlidingSyncList.Filters( - is_encrypted=True, - ), - after_rooms_token, - ) - ) - - # `remote_invite_room_id` should not appear here because it is unencrypted - # according to the stripped state - self.assertEqual(truthy_filtered_room_map.keys(), {encrypted_room_id}) - - # Try with `is_encrypted=False` - falsy_filtered_room_map = self.get_success( - self.sliding_sync_handler.filter_rooms( - UserID.from_string(user1_id), - sync_room_map, - SlidingSyncConfig.SlidingSyncList.Filters( - is_encrypted=False, - ), - after_rooms_token, - ) - ) - - # `remote_invite_room_id` should appear because it is unencrypted according to - # the stripped state - self.assertEqual( - falsy_filtered_room_map.keys(), {room_id, remote_invite_room_id} - ) - - def test_filter_invite_rooms(self) -> None: - """ - Test `filter.is_invite` for rooms that the user has been invited to - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - user2_tok = self.login(user2_id, "pass") - - # Create a normal room - room_id = self.helper.create_room_as(user2_id, tok=user2_tok) - self.helper.join(room_id, user1_id, tok=user1_tok) - - # Create a room that user1 is invited to - invite_room_id = self.helper.create_room_as(user2_id, tok=user2_tok) - self.helper.invite(invite_room_id, src=user2_id, targ=user1_id, tok=user2_tok) - - after_rooms_token = self.event_sources.get_current_token() - - # Get the rooms the user should be syncing with - sync_room_map = self._get_sync_room_ids_for_user( - UserID.from_string(user1_id), - from_token=None, - to_token=after_rooms_token, - ) - - # Try with `is_invite=True` - truthy_filtered_room_map = self.get_success( - self.sliding_sync_handler.filter_rooms( - UserID.from_string(user1_id), - sync_room_map, - SlidingSyncConfig.SlidingSyncList.Filters( - is_invite=True, - ), - after_rooms_token, - ) - ) - - self.assertEqual(truthy_filtered_room_map.keys(), {invite_room_id}) - - # Try with `is_invite=False` - falsy_filtered_room_map = self.get_success( - self.sliding_sync_handler.filter_rooms( - UserID.from_string(user1_id), - sync_room_map, - SlidingSyncConfig.SlidingSyncList.Filters( - is_invite=False, - ), - after_rooms_token, - ) - ) - - self.assertEqual(falsy_filtered_room_map.keys(), {room_id}) - - def test_filter_room_types(self) -> None: - """ - Test `filter.room_types` for different room types - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - - # Create a normal room (no room type) - room_id = self.helper.create_room_as(user1_id, tok=user1_tok) - - # Create a space room - space_room_id = self.helper.create_room_as( - user1_id, - tok=user1_tok, - extra_content={ - "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE} - }, - ) - - # Create an arbitrarily typed room - foo_room_id = self.helper.create_room_as( - user1_id, - tok=user1_tok, - extra_content={ - "creation_content": { - EventContentFields.ROOM_TYPE: "org.matrix.foobarbaz" - } - }, - ) - - after_rooms_token = self.event_sources.get_current_token() - - # Get the rooms the user should be syncing with - sync_room_map = self._get_sync_room_ids_for_user( - UserID.from_string(user1_id), - from_token=None, - to_token=after_rooms_token, - ) - - # Try finding only normal rooms - filtered_room_map = self.get_success( - self.sliding_sync_handler.filter_rooms( - UserID.from_string(user1_id), - sync_room_map, - SlidingSyncConfig.SlidingSyncList.Filters(room_types=[None]), - after_rooms_token, - ) - ) - - self.assertEqual(filtered_room_map.keys(), {room_id}) - - # Try finding only spaces - filtered_room_map = self.get_success( - self.sliding_sync_handler.filter_rooms( - UserID.from_string(user1_id), - sync_room_map, - SlidingSyncConfig.SlidingSyncList.Filters(room_types=[RoomTypes.SPACE]), - after_rooms_token, - ) - ) - - self.assertEqual(filtered_room_map.keys(), {space_room_id}) - - # Try finding normal rooms and spaces - filtered_room_map = self.get_success( - self.sliding_sync_handler.filter_rooms( - UserID.from_string(user1_id), - sync_room_map, - SlidingSyncConfig.SlidingSyncList.Filters( - room_types=[None, RoomTypes.SPACE] - ), - after_rooms_token, - ) - ) - - self.assertEqual(filtered_room_map.keys(), {room_id, space_room_id}) - - # Try finding an arbitrary room type - filtered_room_map = self.get_success( - self.sliding_sync_handler.filter_rooms( - UserID.from_string(user1_id), - sync_room_map, - SlidingSyncConfig.SlidingSyncList.Filters( - room_types=["org.matrix.foobarbaz"] - ), - after_rooms_token, - ) - ) - - self.assertEqual(filtered_room_map.keys(), {foo_room_id}) - - def test_filter_not_room_types(self) -> None: - """ - Test `filter.not_room_types` for different room types - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - - # Create a normal room (no room type) - room_id = self.helper.create_room_as(user1_id, tok=user1_tok) - - # Create a space room - space_room_id = self.helper.create_room_as( - user1_id, - tok=user1_tok, - extra_content={ - "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE} - }, - ) - - # Create an arbitrarily typed room - foo_room_id = self.helper.create_room_as( - user1_id, - tok=user1_tok, - extra_content={ - "creation_content": { - EventContentFields.ROOM_TYPE: "org.matrix.foobarbaz" - } - }, - ) - - after_rooms_token = self.event_sources.get_current_token() - - # Get the rooms the user should be syncing with - sync_room_map = self._get_sync_room_ids_for_user( - UserID.from_string(user1_id), - from_token=None, - to_token=after_rooms_token, - ) - - # Try finding *NOT* normal rooms - filtered_room_map = self.get_success( - self.sliding_sync_handler.filter_rooms( - UserID.from_string(user1_id), - sync_room_map, - SlidingSyncConfig.SlidingSyncList.Filters(not_room_types=[None]), - after_rooms_token, - ) - ) - - self.assertEqual(filtered_room_map.keys(), {space_room_id, foo_room_id}) - - # Try finding *NOT* spaces - filtered_room_map = self.get_success( - self.sliding_sync_handler.filter_rooms( - UserID.from_string(user1_id), - sync_room_map, - SlidingSyncConfig.SlidingSyncList.Filters( - not_room_types=[RoomTypes.SPACE] - ), - after_rooms_token, - ) - ) - - self.assertEqual(filtered_room_map.keys(), {room_id, foo_room_id}) - - # Try finding *NOT* normal rooms or spaces - filtered_room_map = self.get_success( - self.sliding_sync_handler.filter_rooms( - UserID.from_string(user1_id), - sync_room_map, - SlidingSyncConfig.SlidingSyncList.Filters( - not_room_types=[None, RoomTypes.SPACE] - ), - after_rooms_token, - ) - ) - - self.assertEqual(filtered_room_map.keys(), {foo_room_id}) - - # Test how it behaves when we have both `room_types` and `not_room_types`. - # `not_room_types` should win. - filtered_room_map = self.get_success( - self.sliding_sync_handler.filter_rooms( - UserID.from_string(user1_id), - sync_room_map, - SlidingSyncConfig.SlidingSyncList.Filters( - room_types=[None], not_room_types=[None] - ), - after_rooms_token, - ) - ) - - # Nothing matches because nothing is both a normal room and not a normal room - self.assertEqual(filtered_room_map.keys(), set()) - - # Test how it behaves when we have both `room_types` and `not_room_types`. - # `not_room_types` should win. - filtered_room_map = self.get_success( - self.sliding_sync_handler.filter_rooms( - UserID.from_string(user1_id), - sync_room_map, - SlidingSyncConfig.SlidingSyncList.Filters( - room_types=[None, RoomTypes.SPACE], not_room_types=[None] - ), - after_rooms_token, - ) - ) - - self.assertEqual(filtered_room_map.keys(), {space_room_id}) - - def test_filter_room_types_server_left_room(self) -> None: - """ - Test that we can apply a `filter.room_types` against a room that everyone has left. - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - - before_rooms_token = self.event_sources.get_current_token() - - # Create a normal room (no room type) - room_id = self.helper.create_room_as(user1_id, tok=user1_tok) - # Leave the room - self.helper.leave(room_id, user1_id, tok=user1_tok) - - # Create a space room - space_room_id = self.helper.create_room_as( - user1_id, - tok=user1_tok, - extra_content={ - "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE} - }, - ) - # Leave the room - self.helper.leave(space_room_id, user1_id, tok=user1_tok) - - after_rooms_token = self.event_sources.get_current_token() - - # Get the rooms the user should be syncing with - sync_room_map = self._get_sync_room_ids_for_user( - UserID.from_string(user1_id), - # We're using a `from_token` so that the room is considered `newly_left` and - # appears in our list of relevant sync rooms - from_token=before_rooms_token, - to_token=after_rooms_token, - ) - - # Try finding only normal rooms - filtered_room_map = self.get_success( - self.sliding_sync_handler.filter_rooms( - UserID.from_string(user1_id), - sync_room_map, - SlidingSyncConfig.SlidingSyncList.Filters(room_types=[None]), - after_rooms_token, - ) - ) - - self.assertEqual(filtered_room_map.keys(), {room_id}) - - # Try finding only spaces - filtered_room_map = self.get_success( - self.sliding_sync_handler.filter_rooms( - UserID.from_string(user1_id), - sync_room_map, - SlidingSyncConfig.SlidingSyncList.Filters(room_types=[RoomTypes.SPACE]), - after_rooms_token, - ) - ) - - self.assertEqual(filtered_room_map.keys(), {space_room_id}) - - def test_filter_room_types_server_left_room2(self) -> None: - """ - Test that we can apply a `filter.room_types` against a room that everyone has left. - - There is still someone local who is invited to the rooms but that doesn't affect - whether the server is participating in the room (users need to be joined). - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - user2_id = self.register_user("user2", "pass") - _user2_tok = self.login(user2_id, "pass") - - before_rooms_token = self.event_sources.get_current_token() - - # Create a normal room (no room type) - room_id = self.helper.create_room_as(user1_id, tok=user1_tok) - # Invite user2 - self.helper.invite(room_id, targ=user2_id, tok=user1_tok) - # User1 leaves the room - self.helper.leave(room_id, user1_id, tok=user1_tok) - - # Create a space room - space_room_id = self.helper.create_room_as( - user1_id, - tok=user1_tok, - extra_content={ - "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE} - }, - ) - # Invite user2 - self.helper.invite(space_room_id, targ=user2_id, tok=user1_tok) - # User1 leaves the room - self.helper.leave(space_room_id, user1_id, tok=user1_tok) - - after_rooms_token = self.event_sources.get_current_token() - - # Get the rooms the user should be syncing with - sync_room_map = self._get_sync_room_ids_for_user( - UserID.from_string(user1_id), - # We're using a `from_token` so that the room is considered `newly_left` and - # appears in our list of relevant sync rooms - from_token=before_rooms_token, - to_token=after_rooms_token, - ) - - # Try finding only normal rooms - filtered_room_map = self.get_success( - self.sliding_sync_handler.filter_rooms( - UserID.from_string(user1_id), - sync_room_map, - SlidingSyncConfig.SlidingSyncList.Filters(room_types=[None]), - after_rooms_token, - ) - ) - - self.assertEqual(filtered_room_map.keys(), {room_id}) - - # Try finding only spaces - filtered_room_map = self.get_success( - self.sliding_sync_handler.filter_rooms( - UserID.from_string(user1_id), - sync_room_map, - SlidingSyncConfig.SlidingSyncList.Filters(room_types=[RoomTypes.SPACE]), - after_rooms_token, - ) - ) - - self.assertEqual(filtered_room_map.keys(), {space_room_id}) - - def test_filter_room_types_with_remote_invite_room_no_stripped_state(self) -> None: - """ - Test that we can apply a `filter.room_types` filter against a remote invite - room without any `unsigned.invite_room_state` (stripped state). - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - - # Create a remote invite room without any `unsigned.invite_room_state` - _remote_invite_room_id = self._create_remote_invite_room_for_user( - user1_id, None - ) - - # Create a normal room (no room type) - room_id = self.helper.create_room_as(user1_id, tok=user1_tok) - - # Create a space room - space_room_id = self.helper.create_room_as( - user1_id, - tok=user1_tok, - extra_content={ - "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE} - }, - ) - - after_rooms_token = self.event_sources.get_current_token() - - # Get the rooms the user should be syncing with - sync_room_map = self._get_sync_room_ids_for_user( - UserID.from_string(user1_id), - from_token=None, - to_token=after_rooms_token, - ) - - # Try finding only normal rooms - filtered_room_map = self.get_success( - self.sliding_sync_handler.filter_rooms( - UserID.from_string(user1_id), - sync_room_map, - SlidingSyncConfig.SlidingSyncList.Filters(room_types=[None]), - after_rooms_token, - ) - ) - - # `remote_invite_room_id` should not appear because we can't figure out what - # room type it is (no stripped state, `unsigned.invite_room_state`) - self.assertEqual(filtered_room_map.keys(), {room_id}) - - # Try finding only spaces - filtered_room_map = self.get_success( - self.sliding_sync_handler.filter_rooms( - UserID.from_string(user1_id), - sync_room_map, - SlidingSyncConfig.SlidingSyncList.Filters(room_types=[RoomTypes.SPACE]), - after_rooms_token, - ) - ) - - # `remote_invite_room_id` should not appear because we can't figure out what - # room type it is (no stripped state, `unsigned.invite_room_state`) - self.assertEqual(filtered_room_map.keys(), {space_room_id}) - - def test_filter_room_types_with_remote_invite_space(self) -> None: - """ - Test that we can apply a `filter.room_types` filter against a remote invite - to a space room with some `unsigned.invite_room_state` (stripped state). - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - - # Create a remote invite room with some `unsigned.invite_room_state` indicating - # that it is a space room - remote_invite_room_id = self._create_remote_invite_room_for_user( - user1_id, - [ - StrippedStateEvent( - type=EventTypes.Create, - state_key="", - sender="@inviter:remote_server", - content={ - EventContentFields.ROOM_CREATOR: "@inviter:remote_server", - EventContentFields.ROOM_VERSION: RoomVersions.V10.identifier, - # Specify that it is a space room - EventContentFields.ROOM_TYPE: RoomTypes.SPACE, - }, - ), - ], - ) - - # Create a normal room (no room type) - room_id = self.helper.create_room_as(user1_id, tok=user1_tok) - - # Create a space room - space_room_id = self.helper.create_room_as( - user1_id, - tok=user1_tok, - extra_content={ - "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE} - }, - ) - - after_rooms_token = self.event_sources.get_current_token() - - # Get the rooms the user should be syncing with - sync_room_map = self._get_sync_room_ids_for_user( - UserID.from_string(user1_id), - from_token=None, - to_token=after_rooms_token, - ) - - # Try finding only normal rooms - filtered_room_map = self.get_success( - self.sliding_sync_handler.filter_rooms( - UserID.from_string(user1_id), - sync_room_map, - SlidingSyncConfig.SlidingSyncList.Filters(room_types=[None]), - after_rooms_token, - ) - ) - - # `remote_invite_room_id` should not appear here because it is a space room - # according to the stripped state - self.assertEqual(filtered_room_map.keys(), {room_id}) - - # Try finding only spaces - filtered_room_map = self.get_success( - self.sliding_sync_handler.filter_rooms( - UserID.from_string(user1_id), - sync_room_map, - SlidingSyncConfig.SlidingSyncList.Filters(room_types=[RoomTypes.SPACE]), - after_rooms_token, - ) - ) - - # `remote_invite_room_id` should appear here because it is a space room - # according to the stripped state - self.assertEqual( - filtered_room_map.keys(), {space_room_id, remote_invite_room_id} - ) - - def test_filter_room_types_with_remote_invite_normal_room(self) -> None: - """ - Test that we can apply a `filter.room_types` filter against a remote invite - to a normal room with some `unsigned.invite_room_state` (stripped state). - """ - user1_id = self.register_user("user1", "pass") - user1_tok = self.login(user1_id, "pass") - - # Create a remote invite room with some `unsigned.invite_room_state` - # but the create event does not specify a room type (normal room) - remote_invite_room_id = self._create_remote_invite_room_for_user( - user1_id, - [ - StrippedStateEvent( - type=EventTypes.Create, - state_key="", - sender="@inviter:remote_server", - content={ - EventContentFields.ROOM_CREATOR: "@inviter:remote_server", - EventContentFields.ROOM_VERSION: RoomVersions.V10.identifier, - # No room type means this is a normal room - }, - ), - ], - ) - - # Create a normal room (no room type) - room_id = self.helper.create_room_as(user1_id, tok=user1_tok) - - # Create a space room - space_room_id = self.helper.create_room_as( - user1_id, - tok=user1_tok, - extra_content={ - "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE} - }, - ) - - after_rooms_token = self.event_sources.get_current_token() - - # Get the rooms the user should be syncing with - sync_room_map = self._get_sync_room_ids_for_user( - UserID.from_string(user1_id), - from_token=None, - to_token=after_rooms_token, - ) - - # Try finding only normal rooms - filtered_room_map = self.get_success( - self.sliding_sync_handler.filter_rooms( - UserID.from_string(user1_id), - sync_room_map, - SlidingSyncConfig.SlidingSyncList.Filters(room_types=[None]), - after_rooms_token, - ) - ) - - # `remote_invite_room_id` should appear here because it is a normal room - # according to the stripped state (no room type) - self.assertEqual(filtered_room_map.keys(), {room_id, remote_invite_room_id}) - - # Try finding only spaces - filtered_room_map = self.get_success( - self.sliding_sync_handler.filter_rooms( - UserID.from_string(user1_id), - sync_room_map, - SlidingSyncConfig.SlidingSyncList.Filters(room_types=[RoomTypes.SPACE]), - after_rooms_token, - ) - ) - - # `remote_invite_room_id` should not appear here because it is a normal room - # according to the stripped state (no room type) - self.assertEqual(filtered_room_map.keys(), {space_room_id}) + self.assertTrue(room_id1 in newly_left) class SortRoomsTestCase(HomeserverTestCase): @@ -4361,25 +3615,26 @@ class SortRoomsTestCase(HomeserverTestCase): user: UserID, to_token: StreamToken, from_token: Optional[StreamToken], - ) -> Dict[str, _RoomMembershipForUser]: + ) -> Tuple[Dict[str, RoomsForUserType], AbstractSet[str], AbstractSet[str]]: """ Get the rooms the user should be syncing with """ - room_membership_for_user_map = self.get_success( - self.sliding_sync_handler.get_room_membership_for_user_at_to_token( + room_membership_for_user_map, newly_joined, newly_left = self.get_success( + self.sliding_sync_handler.room_lists.get_room_membership_for_user_at_to_token( user=user, from_token=from_token, to_token=to_token, ) ) filtered_sync_room_map = self.get_success( - self.sliding_sync_handler.filter_rooms_relevant_for_sync( + self.sliding_sync_handler.room_lists.filter_rooms_relevant_for_sync( user=user, room_membership_for_user_map=room_membership_for_user_map, + newly_left_room_ids=newly_left, ) ) - return filtered_sync_room_map + return filtered_sync_room_map, newly_joined, newly_left def test_sort_activity_basic(self) -> None: """ @@ -4400,7 +3655,7 @@ class SortRoomsTestCase(HomeserverTestCase): after_rooms_token = self.event_sources.get_current_token() # Get the rooms the user should be syncing with - sync_room_map = self._get_sync_room_ids_for_user( + sync_room_map, newly_joined, newly_left = self._get_sync_room_ids_for_user( UserID.from_string(user1_id), from_token=None, to_token=after_rooms_token, @@ -4408,7 +3663,7 @@ class SortRoomsTestCase(HomeserverTestCase): # Sort the rooms (what we're testing) sorted_sync_rooms = self.get_success( - self.sliding_sync_handler.sort_rooms( + self.sliding_sync_handler.room_lists.sort_rooms( sync_room_map=sync_room_map, to_token=after_rooms_token, ) @@ -4481,7 +3736,7 @@ class SortRoomsTestCase(HomeserverTestCase): self.helper.send(room_id3, "activity in room3", tok=user2_tok) # Get the rooms the user should be syncing with - sync_room_map = self._get_sync_room_ids_for_user( + sync_room_map, newly_joined, newly_left = self._get_sync_room_ids_for_user( UserID.from_string(user1_id), from_token=before_rooms_token, to_token=after_rooms_token, @@ -4489,7 +3744,7 @@ class SortRoomsTestCase(HomeserverTestCase): # Sort the rooms (what we're testing) sorted_sync_rooms = self.get_success( - self.sliding_sync_handler.sort_rooms( + self.sliding_sync_handler.room_lists.sort_rooms( sync_room_map=sync_room_map, to_token=after_rooms_token, ) @@ -4545,7 +3800,7 @@ class SortRoomsTestCase(HomeserverTestCase): after_rooms_token = self.event_sources.get_current_token() # Get the rooms the user should be syncing with - sync_room_map = self._get_sync_room_ids_for_user( + sync_room_map, newly_joined, newly_left = self._get_sync_room_ids_for_user( UserID.from_string(user1_id), from_token=None, to_token=after_rooms_token, @@ -4553,7 +3808,7 @@ class SortRoomsTestCase(HomeserverTestCase): # Sort the rooms (what we're testing) sorted_sync_rooms = self.get_success( - self.sliding_sync_handler.sort_rooms( + self.sliding_sync_handler.room_lists.sort_rooms( sync_room_map=sync_room_map, to_token=after_rooms_token, ) @@ -4565,3 +3820,1071 @@ class SortRoomsTestCase(HomeserverTestCase): # We only care about the *latest* event in the room. [room_id1, room_id2], ) + + +@attr.s(slots=True, auto_attribs=True, frozen=True) +class RequiredStateChangesTestParameters: + previous_required_state_map: Dict[str, Set[str]] + request_required_state_map: Dict[str, Set[str]] + state_deltas: StateMap[str] + expected_with_state_deltas: Tuple[ + Optional[Mapping[str, AbstractSet[str]]], StateFilter + ] + expected_without_state_deltas: Tuple[ + Optional[Mapping[str, AbstractSet[str]]], StateFilter + ] + + +class RequiredStateChangesTestCase(unittest.TestCase): + """Test cases for `_required_state_changes`""" + + @parameterized.expand( + [ + ( + "simple_no_change", + """Test no change to required state""", + RequiredStateChangesTestParameters( + previous_required_state_map={"type1": {"state_key"}}, + request_required_state_map={"type1": {"state_key"}}, + state_deltas={("type1", "state_key"): "$event_id"}, + # No changes + expected_with_state_deltas=(None, StateFilter.none()), + expected_without_state_deltas=(None, StateFilter.none()), + ), + ), + ( + "simple_add_type", + """Test adding a type to the config""", + RequiredStateChangesTestParameters( + previous_required_state_map={"type1": {"state_key"}}, + request_required_state_map={ + "type1": {"state_key"}, + "type2": {"state_key"}, + }, + state_deltas={("type2", "state_key"): "$event_id"}, + expected_with_state_deltas=( + # We've added a type so we should persist the changed required state + # config. + {"type1": {"state_key"}, "type2": {"state_key"}}, + # We should see the new type added + StateFilter.from_types([("type2", "state_key")]), + ), + expected_without_state_deltas=( + {"type1": {"state_key"}, "type2": {"state_key"}}, + StateFilter.from_types([("type2", "state_key")]), + ), + ), + ), + ( + "simple_add_type_from_nothing", + """Test adding a type to the config when previously requesting nothing""", + RequiredStateChangesTestParameters( + previous_required_state_map={}, + request_required_state_map={ + "type1": {"state_key"}, + "type2": {"state_key"}, + }, + state_deltas={("type2", "state_key"): "$event_id"}, + expected_with_state_deltas=( + # We've added a type so we should persist the changed required state + # config. + {"type1": {"state_key"}, "type2": {"state_key"}}, + # We should see the new types added + StateFilter.from_types( + [("type1", "state_key"), ("type2", "state_key")] + ), + ), + expected_without_state_deltas=( + {"type1": {"state_key"}, "type2": {"state_key"}}, + StateFilter.from_types( + [("type1", "state_key"), ("type2", "state_key")] + ), + ), + ), + ), + ( + "simple_add_state_key", + """Test adding a state key to the config""", + RequiredStateChangesTestParameters( + previous_required_state_map={"type": {"state_key1"}}, + request_required_state_map={"type": {"state_key1", "state_key2"}}, + state_deltas={("type", "state_key2"): "$event_id"}, + expected_with_state_deltas=( + # We've added a key so we should persist the changed required state + # config. + {"type": {"state_key1", "state_key2"}}, + # We should see the new state_keys added + StateFilter.from_types([("type", "state_key2")]), + ), + expected_without_state_deltas=( + {"type": {"state_key1", "state_key2"}}, + StateFilter.from_types([("type", "state_key2")]), + ), + ), + ), + ( + "simple_retain_previous_state_keys", + """Test adding a state key to the config and retaining a previously sent state_key""", + RequiredStateChangesTestParameters( + previous_required_state_map={"type": {"state_key1"}}, + request_required_state_map={"type": {"state_key2", "state_key3"}}, + state_deltas={("type", "state_key2"): "$event_id"}, + expected_with_state_deltas=( + # We've added a key so we should persist the changed required state + # config. + # + # Retain `state_key1` from the `previous_required_state_map` + {"type": {"state_key1", "state_key2", "state_key3"}}, + # We should see the new state_keys added + StateFilter.from_types( + [("type", "state_key2"), ("type", "state_key3")] + ), + ), + expected_without_state_deltas=( + {"type": {"state_key1", "state_key2", "state_key3"}}, + StateFilter.from_types( + [("type", "state_key2"), ("type", "state_key3")] + ), + ), + ), + ), + ( + "simple_remove_type", + """ + Test removing a type from the config when there are a matching state + delta does cause the persisted required state config to change + + Test removing a type from the config when there are no matching state + deltas does *not* cause the persisted required state config to change + """, + RequiredStateChangesTestParameters( + previous_required_state_map={ + "type1": {"state_key"}, + "type2": {"state_key"}, + }, + request_required_state_map={"type1": {"state_key"}}, + state_deltas={("type2", "state_key"): "$event_id"}, + expected_with_state_deltas=( + # Remove `type2` since there's been a change to that state, + # (persist the change to required state). That way next time, + # they request `type2`, we see that we haven't sent it before + # and send the new state. (we should still keep track that we've + # sent `type1` before). + {"type1": {"state_key"}}, + # We don't need to request anything more if they are requesting + # less state now + StateFilter.none(), + ), + expected_without_state_deltas=( + # `type2` is no longer requested but since that state hasn't + # changed, nothing should change (we should still keep track + # that we've sent `type2` before). + None, + # We don't need to request anything more if they are requesting + # less state now + StateFilter.none(), + ), + ), + ), + ( + "simple_remove_type_to_nothing", + """ + Test removing a type from the config and no longer requesting any state + """, + RequiredStateChangesTestParameters( + previous_required_state_map={ + "type1": {"state_key"}, + "type2": {"state_key"}, + }, + request_required_state_map={}, + state_deltas={("type2", "state_key"): "$event_id"}, + expected_with_state_deltas=( + # Remove `type2` since there's been a change to that state, + # (persist the change to required state). That way next time, + # they request `type2`, we see that we haven't sent it before + # and send the new state. (we should still keep track that we've + # sent `type1` before). + {"type1": {"state_key"}}, + # We don't need to request anything more if they are requesting + # less state now + StateFilter.none(), + ), + expected_without_state_deltas=( + # `type2` is no longer requested but since that state hasn't + # changed, nothing should change (we should still keep track + # that we've sent `type2` before). + None, + # We don't need to request anything more if they are requesting + # less state now + StateFilter.none(), + ), + ), + ), + ( + "simple_remove_state_key", + """ + Test removing a state_key from the config + """, + RequiredStateChangesTestParameters( + previous_required_state_map={"type": {"state_key1", "state_key2"}}, + request_required_state_map={"type": {"state_key1"}}, + state_deltas={("type", "state_key2"): "$event_id"}, + expected_with_state_deltas=( + # Remove `(type, state_key2)` since there's been a change + # to that state (persist the change to required state). + # That way next time, they request `(type, state_key2)`, we see + # that we haven't sent it before and send the new state. (we + # should still keep track that we've sent `(type, state_key1)` + # before). + {"type": {"state_key1"}}, + # We don't need to request anything more if they are requesting + # less state now + StateFilter.none(), + ), + expected_without_state_deltas=( + # `(type, state_key2)` is no longer requested but since that + # state hasn't changed, nothing should change (we should still + # keep track that we've sent `(type, state_key1)` and `(type, + # state_key2)` before). + None, + # We don't need to request anything more if they are requesting + # less state now + StateFilter.none(), + ), + ), + ), + ( + "type_wildcards_add", + """ + Test adding a wildcard type causes the persisted required state config + to change and we request everything. + + If a event type wildcard has been added or removed we don't try and do + anything fancy, and instead always update the effective room required + state config to match the request. + """, + RequiredStateChangesTestParameters( + previous_required_state_map={"type1": {"state_key2"}}, + request_required_state_map={ + "type1": {"state_key2"}, + StateValues.WILDCARD: {"state_key"}, + }, + state_deltas={ + ("other_type", "state_key"): "$event_id", + }, + # We've added a wildcard, so we persist the change and request everything + expected_with_state_deltas=( + {"type1": {"state_key2"}, StateValues.WILDCARD: {"state_key"}}, + StateFilter.all(), + ), + expected_without_state_deltas=( + {"type1": {"state_key2"}, StateValues.WILDCARD: {"state_key"}}, + StateFilter.all(), + ), + ), + ), + ( + "type_wildcards_remove", + """ + Test removing a wildcard type causes the persisted required state config + to change and request nothing. + + If a event type wildcard has been added or removed we don't try and do + anything fancy, and instead always update the effective room required + state config to match the request. + """, + RequiredStateChangesTestParameters( + previous_required_state_map={ + "type1": {"state_key2"}, + StateValues.WILDCARD: {"state_key"}, + }, + request_required_state_map={"type1": {"state_key2"}}, + state_deltas={ + ("other_type", "state_key"): "$event_id", + }, + # We've removed a type wildcard, so we persist the change but don't request anything + expected_with_state_deltas=( + {"type1": {"state_key2"}}, + # We don't need to request anything more if they are requesting + # less state now + StateFilter.none(), + ), + expected_without_state_deltas=( + {"type1": {"state_key2"}}, + # We don't need to request anything more if they are requesting + # less state now + StateFilter.none(), + ), + ), + ), + ( + "state_key_wildcards_add", + """Test adding a wildcard state_key""", + RequiredStateChangesTestParameters( + previous_required_state_map={"type1": {"state_key"}}, + request_required_state_map={ + "type1": {"state_key"}, + "type2": {StateValues.WILDCARD}, + }, + state_deltas={("type2", "state_key"): "$event_id"}, + # We've added a wildcard state_key, so we persist the change and + # request all of the state for that type + expected_with_state_deltas=( + {"type1": {"state_key"}, "type2": {StateValues.WILDCARD}}, + StateFilter.from_types([("type2", None)]), + ), + expected_without_state_deltas=( + {"type1": {"state_key"}, "type2": {StateValues.WILDCARD}}, + StateFilter.from_types([("type2", None)]), + ), + ), + ), + ( + "state_key_wildcards_remove", + """Test removing a wildcard state_key""", + RequiredStateChangesTestParameters( + previous_required_state_map={ + "type1": {"state_key"}, + "type2": {StateValues.WILDCARD}, + }, + request_required_state_map={"type1": {"state_key"}}, + state_deltas={("type2", "state_key"): "$event_id"}, + # We've removed a state_key wildcard, so we persist the change and + # request nothing + expected_with_state_deltas=( + {"type1": {"state_key"}}, + # We don't need to request anything more if they are requesting + # less state now + StateFilter.none(), + ), + # We've removed a state_key wildcard but there have been no matching + # state changes, so no changes needed, just persist the + # `request_required_state_map` as-is. + expected_without_state_deltas=( + None, + # We don't need to request anything more if they are requesting + # less state now + StateFilter.none(), + ), + ), + ), + ( + "state_key_remove_some", + """ + Test that removing state keys work when only some of the state keys have + changed + """, + RequiredStateChangesTestParameters( + previous_required_state_map={ + "type1": {"state_key1", "state_key2", "state_key3"} + }, + request_required_state_map={"type1": {"state_key1"}}, + state_deltas={("type1", "state_key3"): "$event_id"}, + expected_with_state_deltas=( + # We've removed some state keys from the type, but only state_key3 was + # changed so only that one should be removed. + {"type1": {"state_key1", "state_key2"}}, + # We don't need to request anything more if they are requesting + # less state now + StateFilter.none(), + ), + expected_without_state_deltas=( + # No changes needed, just persist the + # `request_required_state_map` as-is + None, + # We don't need to request anything more if they are requesting + # less state now + StateFilter.none(), + ), + ), + ), + ( + "state_key_me_add", + """ + Test adding state keys work when using "$ME" + """, + RequiredStateChangesTestParameters( + previous_required_state_map={}, + request_required_state_map={"type1": {StateValues.ME}}, + state_deltas={("type1", "@user:test"): "$event_id"}, + expected_with_state_deltas=( + # We've added a type so we should persist the changed required state + # config. + {"type1": {StateValues.ME}}, + # We should see the new state_keys added + StateFilter.from_types([("type1", "@user:test")]), + ), + expected_without_state_deltas=( + {"type1": {StateValues.ME}}, + StateFilter.from_types([("type1", "@user:test")]), + ), + ), + ), + ( + "state_key_me_remove", + """ + Test removing state keys work when using "$ME" + """, + RequiredStateChangesTestParameters( + previous_required_state_map={"type1": {StateValues.ME}}, + request_required_state_map={}, + state_deltas={("type1", "@user:test"): "$event_id"}, + expected_with_state_deltas=( + # Remove `type1` since there's been a change to that state, + # (persist the change to required state). That way next time, + # they request `type1`, we see that we haven't sent it before + # and send the new state. (if we were tracking that we sent any + # other state, we should still keep track that). + {}, + # We don't need to request anything more if they are requesting + # less state now + StateFilter.none(), + ), + expected_without_state_deltas=( + # `type1` is no longer requested but since that state hasn't + # changed, nothing should change (we should still keep track + # that we've sent `type1` before). + None, + # We don't need to request anything more if they are requesting + # less state now + StateFilter.none(), + ), + ), + ), + ( + "state_key_user_id_add", + """ + Test adding state keys work when using your own user ID + """, + RequiredStateChangesTestParameters( + previous_required_state_map={}, + request_required_state_map={"type1": {"@user:test"}}, + state_deltas={("type1", "@user:test"): "$event_id"}, + expected_with_state_deltas=( + # We've added a type so we should persist the changed required state + # config. + {"type1": {"@user:test"}}, + # We should see the new state_keys added + StateFilter.from_types([("type1", "@user:test")]), + ), + expected_without_state_deltas=( + {"type1": {"@user:test"}}, + StateFilter.from_types([("type1", "@user:test")]), + ), + ), + ), + ( + "state_key_me_remove", + """ + Test removing state keys work when using your own user ID + """, + RequiredStateChangesTestParameters( + previous_required_state_map={"type1": {"@user:test"}}, + request_required_state_map={}, + state_deltas={("type1", "@user:test"): "$event_id"}, + expected_with_state_deltas=( + # Remove `type1` since there's been a change to that state, + # (persist the change to required state). That way next time, + # they request `type1`, we see that we haven't sent it before + # and send the new state. (if we were tracking that we sent any + # other state, we should still keep track that). + {}, + # We don't need to request anything more if they are requesting + # less state now + StateFilter.none(), + ), + expected_without_state_deltas=( + # `type1` is no longer requested but since that state hasn't + # changed, nothing should change (we should still keep track + # that we've sent `type1` before). + None, + # We don't need to request anything more if they are requesting + # less state now + StateFilter.none(), + ), + ), + ), + ( + "state_key_lazy_add", + """ + Test adding state keys work when using "$LAZY" + """, + RequiredStateChangesTestParameters( + previous_required_state_map={}, + request_required_state_map={EventTypes.Member: {StateValues.LAZY}}, + state_deltas={(EventTypes.Member, "@user:test"): "$event_id"}, + expected_with_state_deltas=( + # If a "$LAZY" has been added or removed we always update the + # required state to what was requested for simplicity. + {EventTypes.Member: {StateValues.LAZY}}, + StateFilter.none(), + ), + expected_without_state_deltas=( + {EventTypes.Member: {StateValues.LAZY}}, + StateFilter.none(), + ), + ), + ), + ( + "state_key_lazy_remove", + """ + Test removing state keys work when using "$LAZY" + """, + RequiredStateChangesTestParameters( + previous_required_state_map={EventTypes.Member: {StateValues.LAZY}}, + request_required_state_map={}, + state_deltas={(EventTypes.Member, "@user:test"): "$event_id"}, + expected_with_state_deltas=( + # If a "$LAZY" has been added or removed we always update the + # required state to what was requested for simplicity. + {}, + # We don't need to request anything more if they are requesting + # less state now + StateFilter.none(), + ), + expected_without_state_deltas=( + # `EventTypes.Member` is no longer requested but since that + # state hasn't changed, nothing should change (we should still + # keep track that we've sent `EventTypes.Member` before). + None, + # We don't need to request anything more if they are requesting + # less state now + StateFilter.none(), + ), + ), + ), + ( + "state_key_lazy_keep_previous_memberships_and_no_new_memberships", + """ + This test mimics a request with lazy-loading room members enabled where + we have previously sent down user2 and user3's membership events and now + we're sending down another response without any timeline events. + """, + RequiredStateChangesTestParameters( + previous_required_state_map={ + EventTypes.Member: { + StateValues.LAZY, + "@user2:test", + "@user3:test", + } + }, + request_required_state_map={EventTypes.Member: {StateValues.LAZY}}, + state_deltas={(EventTypes.Member, "@user2:test"): "$event_id"}, + expected_with_state_deltas=( + # Remove "@user2:test" since that state has changed and is no + # longer being requested anymore. Since something was removed, + # we should persist the changed to required state. That way next + # time, they request "@user2:test", we see that we haven't sent + # it before and send the new state. (we should still keep track + # that we've sent specific `EventTypes.Member` before) + { + EventTypes.Member: { + StateValues.LAZY, + "@user3:test", + } + }, + # We don't need to request anything more if they are requesting + # less state now + StateFilter.none(), + ), + expected_without_state_deltas=( + # We're not requesting any specific `EventTypes.Member` now but + # since that state hasn't changed, nothing should change (we + # should still keep track that we've sent specific + # `EventTypes.Member` before). + None, + # We don't need to request anything more if they are requesting + # less state now + StateFilter.none(), + ), + ), + ), + ( + "state_key_lazy_keep_previous_memberships_with_new_memberships", + """ + This test mimics a request with lazy-loading room members enabled where + we have previously sent down user2 and user3's membership events and now + we're sending down another response with a new event from user4. + """, + RequiredStateChangesTestParameters( + previous_required_state_map={ + EventTypes.Member: { + StateValues.LAZY, + "@user2:test", + "@user3:test", + } + }, + request_required_state_map={ + EventTypes.Member: {StateValues.LAZY, "@user4:test"} + }, + state_deltas={(EventTypes.Member, "@user2:test"): "$event_id"}, + expected_with_state_deltas=( + # Since "@user4:test" was added, we should persist the changed + # required state config. + # + # Also remove "@user2:test" since that state has changed and is no + # longer being requested anymore. Since something was removed, + # we also should persist the changed to required state. That way next + # time, they request "@user2:test", we see that we haven't sent + # it before and send the new state. (we should still keep track + # that we've sent specific `EventTypes.Member` before) + { + EventTypes.Member: { + StateValues.LAZY, + "@user3:test", + "@user4:test", + } + }, + # We should see the new state_keys added + StateFilter.from_types([(EventTypes.Member, "@user4:test")]), + ), + expected_without_state_deltas=( + # Since "@user4:test" was added, we should persist the changed + # required state config. + { + EventTypes.Member: { + StateValues.LAZY, + "@user2:test", + "@user3:test", + "@user4:test", + } + }, + # We should see the new state_keys added + StateFilter.from_types([(EventTypes.Member, "@user4:test")]), + ), + ), + ), + ( + "state_key_expand_lazy_keep_previous_memberships", + """ + Test expanding the `required_state` to lazy-loading room members. + """, + RequiredStateChangesTestParameters( + previous_required_state_map={ + EventTypes.Member: {"@user2:test", "@user3:test"} + }, + request_required_state_map={EventTypes.Member: {StateValues.LAZY}}, + state_deltas={(EventTypes.Member, "@user2:test"): "$event_id"}, + expected_with_state_deltas=( + # Since `StateValues.LAZY` was added, we should persist the + # changed required state config. + # + # Also remove "@user2:test" since that state has changed and is no + # longer being requested anymore. Since something was removed, + # we also should persist the changed to required state. That way next + # time, they request "@user2:test", we see that we haven't sent + # it before and send the new state. (we should still keep track + # that we've sent specific `EventTypes.Member` before) + { + EventTypes.Member: { + StateValues.LAZY, + "@user3:test", + } + }, + # We don't need to request anything more if they are requesting + # less state now + StateFilter.none(), + ), + expected_without_state_deltas=( + # Since `StateValues.LAZY` was added, we should persist the + # changed required state config. + { + EventTypes.Member: { + StateValues.LAZY, + "@user2:test", + "@user3:test", + } + }, + # We don't need to request anything more if they are requesting + # less state now + StateFilter.none(), + ), + ), + ), + ( + "state_key_retract_lazy_keep_previous_memberships_no_new_memberships", + """ + Test retracting the `required_state` to no longer lazy-loading room members. + """, + RequiredStateChangesTestParameters( + previous_required_state_map={ + EventTypes.Member: { + StateValues.LAZY, + "@user2:test", + "@user3:test", + } + }, + request_required_state_map={}, + state_deltas={(EventTypes.Member, "@user2:test"): "$event_id"}, + expected_with_state_deltas=( + # Remove `EventTypes.Member` since there's been a change to that + # state, (persist the change to required state). That way next + # time, they request `EventTypes.Member`, we see that we haven't + # sent it before and send the new state. (if we were tracking + # that we sent any other state, we should still keep track + # that). + # + # This acts the same as the `simple_remove_type` test. It's + # possible that we could remember the specific `state_keys` that + # we have sent down before but this currently just acts the same + # as if a whole `type` was removed. Perhaps it's good that we + # "garbage collect" and forget what we've sent before for a + # given `type` when the client stops caring about a certain + # `type`. + {}, + # We don't need to request anything more if they are requesting + # less state now + StateFilter.none(), + ), + expected_without_state_deltas=( + # `EventTypes.Member` is no longer requested but since that + # state hasn't changed, nothing should change (we should still + # keep track that we've sent `EventTypes.Member` before). + None, + # We don't need to request anything more if they are requesting + # less state now + StateFilter.none(), + ), + ), + ), + ( + "state_key_retract_lazy_keep_previous_memberships_with_new_memberships", + """ + Test retracting the `required_state` to no longer lazy-loading room members. + """, + RequiredStateChangesTestParameters( + previous_required_state_map={ + EventTypes.Member: { + StateValues.LAZY, + "@user2:test", + "@user3:test", + } + }, + request_required_state_map={EventTypes.Member: {"@user4:test"}}, + state_deltas={(EventTypes.Member, "@user2:test"): "$event_id"}, + expected_with_state_deltas=( + # Since "@user4:test" was added, we should persist the changed + # required state config. + # + # Also remove "@user2:test" since that state has changed and is no + # longer being requested anymore. Since something was removed, + # we also should persist the changed to required state. That way next + # time, they request "@user2:test", we see that we haven't sent + # it before and send the new state. (we should still keep track + # that we've sent specific `EventTypes.Member` before) + { + EventTypes.Member: { + "@user3:test", + "@user4:test", + } + }, + # We should see the new state_keys added + StateFilter.from_types([(EventTypes.Member, "@user4:test")]), + ), + expected_without_state_deltas=( + # Since "@user4:test" was added, we should persist the changed + # required state config. + { + EventTypes.Member: { + "@user2:test", + "@user3:test", + "@user4:test", + } + }, + # We should see the new state_keys added + StateFilter.from_types([(EventTypes.Member, "@user4:test")]), + ), + ), + ), + ( + "type_wildcard_with_state_key_wildcard_to_explicit_state_keys", + """ + Test switching from a wildcard ("*", "*") to explicit state keys + """, + RequiredStateChangesTestParameters( + previous_required_state_map={ + StateValues.WILDCARD: {StateValues.WILDCARD} + }, + request_required_state_map={ + StateValues.WILDCARD: {"state_key1", "state_key2", "state_key3"} + }, + state_deltas={("type1", "state_key1"): "$event_id"}, + # If we were previously fetching everything ("*", "*"), always update the effective + # room required state config to match the request. And since we we're previously + # already fetching everything, we don't have to fetch anything now that they've + # narrowed. + expected_with_state_deltas=( + { + StateValues.WILDCARD: { + "state_key1", + "state_key2", + "state_key3", + } + }, + StateFilter.none(), + ), + expected_without_state_deltas=( + { + StateValues.WILDCARD: { + "state_key1", + "state_key2", + "state_key3", + } + }, + StateFilter.none(), + ), + ), + ), + ( + "type_wildcard_with_explicit_state_keys_to_wildcard_state_key", + """ + Test switching from explicit to wildcard state keys ("*", "*") + """, + RequiredStateChangesTestParameters( + previous_required_state_map={ + StateValues.WILDCARD: {"state_key1", "state_key2", "state_key3"} + }, + request_required_state_map={ + StateValues.WILDCARD: {StateValues.WILDCARD} + }, + state_deltas={("type1", "state_key1"): "$event_id"}, + # We've added a wildcard, so we persist the change and request everything + expected_with_state_deltas=( + {StateValues.WILDCARD: {StateValues.WILDCARD}}, + StateFilter.all(), + ), + expected_without_state_deltas=( + {StateValues.WILDCARD: {StateValues.WILDCARD}}, + StateFilter.all(), + ), + ), + ), + ( + "state_key_wildcard_to_explicit_state_keys", + """Test switching from a wildcard to explicit state keys with a concrete type""", + RequiredStateChangesTestParameters( + previous_required_state_map={"type1": {StateValues.WILDCARD}}, + request_required_state_map={ + "type1": {"state_key1", "state_key2", "state_key3"} + }, + state_deltas={("type1", "state_key1"): "$event_id"}, + # If a state_key wildcard has been added or removed, we always + # update the effective room required state config to match the + # request. And since we we're previously already fetching + # everything, we don't have to fetch anything now that they've + # narrowed. + expected_with_state_deltas=( + { + "type1": { + "state_key1", + "state_key2", + "state_key3", + } + }, + StateFilter.none(), + ), + expected_without_state_deltas=( + { + "type1": { + "state_key1", + "state_key2", + "state_key3", + } + }, + StateFilter.none(), + ), + ), + ), + ( + "explicit_state_keys_to_wildcard_state_key", + """Test switching from a wildcard to explicit state keys with a concrete type""", + RequiredStateChangesTestParameters( + previous_required_state_map={ + "type1": {"state_key1", "state_key2", "state_key3"} + }, + request_required_state_map={"type1": {StateValues.WILDCARD}}, + state_deltas={("type1", "state_key1"): "$event_id"}, + # If a state_key wildcard has been added or removed, we always + # update the effective room required state config to match the + # request. And we need to request all of the state for that type + # because we previously, only sent down a few keys. + expected_with_state_deltas=( + {"type1": {StateValues.WILDCARD, "state_key2", "state_key3"}}, + StateFilter.from_types([("type1", None)]), + ), + expected_without_state_deltas=( + { + "type1": { + StateValues.WILDCARD, + "state_key1", + "state_key2", + "state_key3", + } + }, + StateFilter.from_types([("type1", None)]), + ), + ), + ), + ] + ) + def test_xxx( + self, + _test_label: str, + _test_description: str, + test_parameters: RequiredStateChangesTestParameters, + ) -> None: + # Without `state_deltas` + changed_required_state_map, added_state_filter = _required_state_changes( + user_id="@user:test", + prev_required_state_map=test_parameters.previous_required_state_map, + request_required_state_map=test_parameters.request_required_state_map, + state_deltas={}, + ) + + self.assertEqual( + changed_required_state_map, + test_parameters.expected_without_state_deltas[0], + "changed_required_state_map does not match (without state_deltas)", + ) + self.assertEqual( + added_state_filter, + test_parameters.expected_without_state_deltas[1], + "added_state_filter does not match (without state_deltas)", + ) + + # With `state_deltas` + changed_required_state_map, added_state_filter = _required_state_changes( + user_id="@user:test", + prev_required_state_map=test_parameters.previous_required_state_map, + request_required_state_map=test_parameters.request_required_state_map, + state_deltas=test_parameters.state_deltas, + ) + + self.assertEqual( + changed_required_state_map, + test_parameters.expected_with_state_deltas[0], + "changed_required_state_map does not match (with state_deltas)", + ) + self.assertEqual( + added_state_filter, + test_parameters.expected_with_state_deltas[1], + "added_state_filter does not match (with state_deltas)", + ) + + @parameterized.expand( + [ + # Test with a normal arbitrary type (no special meaning) + ("arbitrary_type", "type", set()), + # Test with membership + ("membership", EventTypes.Member, set()), + # Test with lazy-loading room members + ("lazy_loading_membership", EventTypes.Member, {StateValues.LAZY}), + ] + ) + def test_limit_retained_previous_state_keys( + self, + _test_label: str, + event_type: str, + extra_state_keys: Set[str], + ) -> None: + """ + Test that we limit the number of state_keys that we remember but always include + the state_keys that we've just requested. + """ + previous_required_state_map = { + event_type: { + # Prefix the state_keys we've "prev_"iously sent so they are easier to + # identify in our assertions. + f"prev_state_key{i}" + for i in range(MAX_NUMBER_PREVIOUS_STATE_KEYS_TO_REMEMBER - 30) + } + | extra_state_keys + } + request_required_state_map = { + event_type: {f"state_key{i}" for i in range(50)} | extra_state_keys + } + + # (function under test) + changed_required_state_map, added_state_filter = _required_state_changes( + user_id="@user:test", + prev_required_state_map=previous_required_state_map, + request_required_state_map=request_required_state_map, + state_deltas={}, + ) + assert changed_required_state_map is not None + + # We should only remember up to the maximum number of state keys + self.assertGreaterEqual( + len(changed_required_state_map[event_type]), + # Most of the time this will be `MAX_NUMBER_PREVIOUS_STATE_KEYS_TO_REMEMBER` but + # because we are just naively selecting enough previous state_keys to fill + # the limit, there might be some overlap in what's added back which means we + # might have slightly less than the limit. + # + # `extra_state_keys` overlaps in the previous and requested + # `required_state_map` so we might see this this scenario. + MAX_NUMBER_PREVIOUS_STATE_KEYS_TO_REMEMBER - len(extra_state_keys), + ) + + # Should include all of the requested state + self.assertIncludes( + changed_required_state_map[event_type], + request_required_state_map[event_type], + ) + # And the rest is filled with the previous state keys + # + # We can't assert the exact state_keys since we don't know the order so we just + # check that they all start with "prev_" and that we have the correct amount. + remaining_state_keys = ( + changed_required_state_map[event_type] + - request_required_state_map[event_type] + ) + self.assertGreater( + len(remaining_state_keys), + 0, + ) + assert all( + state_key.startswith("prev_") for state_key in remaining_state_keys + ), "Remaining state_keys should be the previous state_keys" + + def test_request_more_state_keys_than_remember_limit(self) -> None: + """ + Test requesting more state_keys than fit in our limit to remember from previous + requests. + """ + previous_required_state_map = { + "type": { + # Prefix the state_keys we've "prev_"iously sent so they are easier to + # identify in our assertions. + f"prev_state_key{i}" + for i in range(MAX_NUMBER_PREVIOUS_STATE_KEYS_TO_REMEMBER - 30) + } + } + request_required_state_map = { + "type": { + f"state_key{i}" + # Requesting more than the MAX_NUMBER_PREVIOUS_STATE_KEYS_TO_REMEMBER + for i in range(MAX_NUMBER_PREVIOUS_STATE_KEYS_TO_REMEMBER + 20) + } + } + # Ensure that we are requesting more than the limit + self.assertGreater( + len(request_required_state_map["type"]), + MAX_NUMBER_PREVIOUS_STATE_KEYS_TO_REMEMBER, + ) + + # (function under test) + changed_required_state_map, added_state_filter = _required_state_changes( + user_id="@user:test", + prev_required_state_map=previous_required_state_map, + request_required_state_map=request_required_state_map, + state_deltas={}, + ) + assert changed_required_state_map is not None + + # Should include all of the requested state + self.assertIncludes( + changed_required_state_map["type"], + request_required_state_map["type"], + exact=True, + ) diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index fa55f76916..6b202dfbd5 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py
@@ -17,10 +17,11 @@ # [This file includes modifications made by New Vector Limited] # # +from http import HTTPStatus from typing import Collection, ContextManager, List, Optional from unittest.mock import AsyncMock, Mock, patch -from parameterized import parameterized +from parameterized import parameterized, parameterized_class from twisted.internet import defer from twisted.test.proto_helpers import MemoryReactor @@ -32,7 +33,13 @@ from synapse.api.room_versions import RoomVersion, RoomVersions from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.federation.federation_base import event_from_pdu_json -from synapse.handlers.sync import SyncConfig, SyncRequestKey, SyncResult, SyncVersion +from synapse.handlers.sync import ( + SyncConfig, + SyncRequestKey, + SyncResult, + SyncVersion, + TimelineBatch, +) from synapse.rest import admin from synapse.rest.client import knock, login, room from synapse.server import HomeServer @@ -58,9 +65,21 @@ def generate_request_key() -> SyncRequestKey: return ("request_key", _request_key) +@parameterized_class( + ("use_state_after",), + [ + (True,), + (False,), + ], + class_name_func=lambda cls, + num, + params_dict: f"{cls.__name__}_{'state_after' if params_dict['use_state_after'] else 'state'}", +) class SyncTestCase(tests.unittest.HomeserverTestCase): """Tests Sync Handler.""" + use_state_after: bool + servlets = [ admin.register_servlets, knock.register_servlets, @@ -79,7 +98,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): def test_wait_for_sync_for_user_auth_blocking(self) -> None: user_id1 = "@user1:test" user_id2 = "@user2:test" - sync_config = generate_sync_config(user_id1) + sync_config = generate_sync_config( + user_id1, use_state_after=self.use_state_after + ) requester = create_requester(user_id1) self.reactor.advance(100) # So we get not 0 time @@ -112,7 +133,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): self.auth_blocking._hs_disabled = False - sync_config = generate_sync_config(user_id2) + sync_config = generate_sync_config( + user_id2, use_state_after=self.use_state_after + ) requester = create_requester(user_id2) e = self.get_failure( @@ -141,7 +164,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): initial_result = self.get_success( self.sync_handler.wait_for_sync_for_user( requester, - sync_config=generate_sync_config(user, device_id="dev"), + sync_config=generate_sync_config( + user, device_id="dev", use_state_after=self.use_state_after + ), sync_version=SyncVersion.SYNC_V2, request_key=generate_request_key(), ) @@ -175,7 +200,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): result = self.get_success( self.sync_handler.wait_for_sync_for_user( requester, - sync_config=generate_sync_config(user), + sync_config=generate_sync_config( + user, use_state_after=self.use_state_after + ), sync_version=SyncVersion.SYNC_V2, request_key=generate_request_key(), ) @@ -188,7 +215,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): result = self.get_success( self.sync_handler.wait_for_sync_for_user( requester, - sync_config=generate_sync_config(user, device_id="dev"), + sync_config=generate_sync_config( + user, device_id="dev", use_state_after=self.use_state_after + ), sync_version=SyncVersion.SYNC_V2, request_key=generate_request_key(), since_token=initial_result.next_batch, @@ -220,7 +249,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): result = self.get_success( self.sync_handler.wait_for_sync_for_user( requester, - sync_config=generate_sync_config(user), + sync_config=generate_sync_config( + user, use_state_after=self.use_state_after + ), sync_version=SyncVersion.SYNC_V2, request_key=generate_request_key(), ) @@ -233,7 +264,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): result = self.get_success( self.sync_handler.wait_for_sync_for_user( requester, - sync_config=generate_sync_config(user, device_id="dev"), + sync_config=generate_sync_config( + user, device_id="dev", use_state_after=self.use_state_after + ), sync_version=SyncVersion.SYNC_V2, request_key=generate_request_key(), since_token=initial_result.next_batch, @@ -276,7 +309,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): alice_sync_result: SyncResult = self.get_success( self.sync_handler.wait_for_sync_for_user( create_requester(owner), - generate_sync_config(owner), + generate_sync_config(owner, use_state_after=self.use_state_after), sync_version=SyncVersion.SYNC_V2, request_key=generate_request_key(), ) @@ -296,7 +329,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): # Eve syncs. eve_requester = create_requester(eve) - eve_sync_config = generate_sync_config(eve) + eve_sync_config = generate_sync_config( + eve, use_state_after=self.use_state_after + ) eve_sync_after_ban: SyncResult = self.get_success( self.sync_handler.wait_for_sync_for_user( eve_requester, @@ -313,7 +348,15 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): # the prev_events used when creating the join event, such that the ban does not # precede the join. with self._patch_get_latest_events([last_room_creation_event_id]): - self.helper.join(room_id, eve, tok=eve_token) + self.helper.join( + room_id, + eve, + tok=eve_token, + # Previously, this join would succeed but now we expect it to fail at + # this point. The rest of the test is for the case when this used to + # succeed. + expect_code=HTTPStatus.FORBIDDEN, + ) # Eve makes a second, incremental sync. eve_incremental_sync_after_join: SyncResult = self.get_success( @@ -367,7 +410,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): initial_sync_result = self.get_success( self.sync_handler.wait_for_sync_for_user( alice_requester, - generate_sync_config(alice), + generate_sync_config(alice, use_state_after=self.use_state_after), sync_version=SyncVersion.SYNC_V2, request_key=generate_request_key(), ) @@ -396,6 +439,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): filter_collection=FilterCollection( self.hs, {"room": {"timeline": {"limit": 2}}} ), + use_state_after=self.use_state_after, ), sync_version=SyncVersion.SYNC_V2, request_key=generate_request_key(), @@ -442,7 +486,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): initial_sync_result = self.get_success( self.sync_handler.wait_for_sync_for_user( alice_requester, - generate_sync_config(alice), + generate_sync_config(alice, use_state_after=self.use_state_after), sync_version=SyncVersion.SYNC_V2, request_key=generate_request_key(), ) @@ -481,6 +525,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): } }, ), + use_state_after=self.use_state_after, ), sync_version=SyncVersion.SYNC_V2, request_key=generate_request_key(), @@ -518,6 +563,8 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): ... and a filter that means we only return 1 event, represented by the dashed horizontal lines: `S2` must be included in the `state` section on the second sync. + + When `use_state_after` is enabled, then we expect to see `s2` in the first sync. """ alice = self.register_user("alice", "password") alice_tok = self.login(alice, "password") @@ -528,7 +575,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): initial_sync_result = self.get_success( self.sync_handler.wait_for_sync_for_user( alice_requester, - generate_sync_config(alice), + generate_sync_config(alice, use_state_after=self.use_state_after), sync_version=SyncVersion.SYNC_V2, request_key=generate_request_key(), ) @@ -554,6 +601,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): filter_collection=FilterCollection( self.hs, {"room": {"timeline": {"limit": 1}}} ), + use_state_after=self.use_state_after, ), sync_version=SyncVersion.SYNC_V2, request_key=generate_request_key(), @@ -567,10 +615,18 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): [e.event_id for e in room_sync.timeline.events], [e3_event], ) - self.assertEqual( - [e.event_id for e in room_sync.state.values()], - [], - ) + + if self.use_state_after: + # When using `state_after` we get told about s2 immediately + self.assertEqual( + [e.event_id for e in room_sync.state.values()], + [s2_event], + ) + else: + self.assertEqual( + [e.event_id for e in room_sync.state.values()], + [], + ) # Now send another event that points to S2, but not E3. with self._patch_get_latest_events([s2_event]): @@ -585,6 +641,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): filter_collection=FilterCollection( self.hs, {"room": {"timeline": {"limit": 1}}} ), + use_state_after=self.use_state_after, ), sync_version=SyncVersion.SYNC_V2, request_key=generate_request_key(), @@ -598,10 +655,19 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): [e.event_id for e in room_sync.timeline.events], [e4_event], ) - self.assertEqual( - [e.event_id for e in room_sync.state.values()], - [s2_event], - ) + + if self.use_state_after: + # When using `state_after` we got told about s2 previously, so we + # don't again. + self.assertEqual( + [e.event_id for e in room_sync.state.values()], + [], + ) + else: + self.assertEqual( + [e.event_id for e in room_sync.state.values()], + [s2_event], + ) def test_state_includes_changes_on_ungappy_syncs(self) -> None: """Test `state` where the sync is not gappy. @@ -638,6 +704,8 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): This is the last chance for us to tell the client about S2, so it *must* be included in the response. + + When `use_state_after` is enabled, then we expect to see `s2` in the first sync. """ alice = self.register_user("alice", "password") alice_tok = self.login(alice, "password") @@ -648,7 +716,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): initial_sync_result = self.get_success( self.sync_handler.wait_for_sync_for_user( alice_requester, - generate_sync_config(alice), + generate_sync_config(alice, use_state_after=self.use_state_after), sync_version=SyncVersion.SYNC_V2, request_key=generate_request_key(), ) @@ -673,6 +741,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): filter_collection=FilterCollection( self.hs, {"room": {"timeline": {"limit": 1}}} ), + use_state_after=self.use_state_after, ), sync_version=SyncVersion.SYNC_V2, request_key=generate_request_key(), @@ -684,7 +753,11 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): [e.event_id for e in room_sync.timeline.events], [e3_event], ) - self.assertNotIn(s2_event, [e.event_id for e in room_sync.state.values()]) + if self.use_state_after: + # When using `state_after` we get told about s2 immediately + self.assertIn(s2_event, [e.event_id for e in room_sync.state.values()]) + else: + self.assertNotIn(s2_event, [e.event_id for e in room_sync.state.values()]) # More events, E4 and E5 with self._patch_get_latest_events([e3_event]): @@ -695,7 +768,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): incremental_sync = self.get_success( self.sync_handler.wait_for_sync_for_user( alice_requester, - generate_sync_config(alice), + generate_sync_config(alice, use_state_after=self.use_state_after), sync_version=SyncVersion.SYNC_V2, request_key=generate_request_key(), since_token=initial_sync_result.next_batch, @@ -710,10 +783,19 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): [e.event_id for e in room_sync.timeline.events], [e4_event, e5_event], ) - self.assertEqual( - [e.event_id for e in room_sync.state.values()], - [s2_event], - ) + + if self.use_state_after: + # When using `state_after` we got told about s2 previously, so we + # don't again. + self.assertEqual( + [e.event_id for e in room_sync.state.values()], + [], + ) + else: + self.assertEqual( + [e.event_id for e in room_sync.state.values()], + [s2_event], + ) @parameterized.expand( [ @@ -721,7 +803,8 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): (True, False), (False, True), (True, True), - ] + ], + name_func=lambda func, num, p: f"{func.__name__}_{p.args[0]}_{p.args[1]}", ) def test_archived_rooms_do_not_include_state_after_leave( self, initial_sync: bool, empty_timeline: bool @@ -749,7 +832,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): initial_sync_result = self.get_success( self.sync_handler.wait_for_sync_for_user( bob_requester, - generate_sync_config(bob), + generate_sync_config(bob, use_state_after=self.use_state_after), sync_version=SyncVersion.SYNC_V2, request_key=generate_request_key(), ) @@ -780,7 +863,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): self.sync_handler.wait_for_sync_for_user( bob_requester, generate_sync_config( - bob, filter_collection=FilterCollection(self.hs, filter_dict) + bob, + filter_collection=FilterCollection(self.hs, filter_dict), + use_state_after=self.use_state_after, ), sync_version=SyncVersion.SYNC_V2, request_key=generate_request_key(), @@ -791,7 +876,15 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): if empty_timeline: # The timeline should be empty self.assertEqual(sync_room_result.timeline.events, []) + else: + # The last three events in the timeline should be those leading up to the + # leave + self.assertEqual( + [e.event_id for e in sync_room_result.timeline.events[-3:]], + [before_message_event, before_state_event, leave_event], + ) + if empty_timeline or self.use_state_after: # And the state should include the leave event... self.assertEqual( sync_room_result.state[("m.room.member", bob)].event_id, leave_event @@ -801,12 +894,6 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): sync_room_result.state[("test_state", "")].event_id, before_state_event ) else: - # The last three events in the timeline should be those leading up to the - # leave - self.assertEqual( - [e.event_id for e in sync_room_result.timeline.events[-3:]], - [before_message_event, before_state_event, leave_event], - ) # ... And the state should be empty self.assertEqual(sync_room_result.state, {}) @@ -843,7 +930,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): ) -> List[EventBase]: return list(pdus) - self.client._check_sigs_and_hash_for_pulled_events_and_fetch = _check_sigs_and_hash_for_pulled_events_and_fetch # type: ignore[assignment] + self.client._check_sigs_and_hash_for_pulled_events_and_fetch = ( # type: ignore[method-assign] + _check_sigs_and_hash_for_pulled_events_and_fetch # type: ignore[assignment] + ) prev_events = self.get_success(self.store.get_prev_events_for_room(room_id)) @@ -877,7 +966,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): sync_result: SyncResult = self.get_success( self.sync_handler.wait_for_sync_for_user( create_requester(user), - generate_sync_config(user), + generate_sync_config(user, use_state_after=self.use_state_after), sync_version=SyncVersion.SYNC_V2, request_key=generate_request_key(), ) @@ -926,7 +1015,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): private_sync_result: SyncResult = self.get_success( self.sync_handler.wait_for_sync_for_user( create_requester(user2), - generate_sync_config(user2), + generate_sync_config(user2, use_state_after=self.use_state_after), sync_version=SyncVersion.SYNC_V2, request_key=generate_request_key(), ) @@ -952,7 +1041,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): sync_result: SyncResult = self.get_success( self.sync_handler.wait_for_sync_for_user( create_requester(user), - generate_sync_config(user), + generate_sync_config(user, use_state_after=self.use_state_after), sync_version=SyncVersion.SYNC_V2, request_key=generate_request_key(), ) @@ -989,7 +1078,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): sync_d = defer.ensureDeferred( self.sync_handler.wait_for_sync_for_user( create_requester(user), - generate_sync_config(user), + generate_sync_config(user, use_state_after=self.use_state_after), sync_version=SyncVersion.SYNC_V2, request_key=generate_request_key(), since_token=since_token, @@ -1044,7 +1133,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): sync_d = defer.ensureDeferred( self.sync_handler.wait_for_sync_for_user( create_requester(user), - generate_sync_config(user), + generate_sync_config(user, use_state_after=self.use_state_after), sync_version=SyncVersion.SYNC_V2, request_key=generate_request_key(), since_token=since_token, @@ -1060,6 +1149,7 @@ def generate_sync_config( user_id: str, device_id: Optional[str] = "device_id", filter_collection: Optional[FilterCollection] = None, + use_state_after: bool = False, ) -> SyncConfig: """Generate a sync config (with a unique request key). @@ -1067,7 +1157,8 @@ def generate_sync_config( user_id: user who is syncing. device_id: device that is syncing. Defaults to "device_id". filter_collection: filter to apply. Defaults to the default filter (ie, - return everything, with a default limit) + return everything, with a default limit) + use_state_after: whether the `use_state_after` flag was set. """ if filter_collection is None: filter_collection = Filtering(Mock()).DEFAULT_FILTER_COLLECTION @@ -1077,4 +1168,138 @@ def generate_sync_config( filter_collection=filter_collection, is_guest=False, device_id=device_id, + use_state_after=use_state_after, ) + + +class SyncStateAfterTestCase(tests.unittest.HomeserverTestCase): + """Tests Sync Handler state behavior when using `use_state_after.""" + + servlets = [ + admin.register_servlets, + knock.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.sync_handler = self.hs.get_sync_handler() + self.store = self.hs.get_datastores().main + + # AuthBlocking reads from the hs' config on initialization. We need to + # modify its config instead of the hs' + self.auth_blocking = self.hs.get_auth_blocking() + + def test_initial_sync_multiple_deltas(self) -> None: + """Test that if multiple state deltas have happened during processing of + a full state sync we return the correct state""" + + user = self.register_user("user", "password") + tok = self.login("user", "password") + + # Create a room as the user and set some custom state. + joined_room = self.helper.create_room_as(user, tok=tok) + + first_state = self.helper.send_state( + joined_room, event_type="m.test_event", body={"num": 1}, tok=tok + ) + + # Take a snapshot of the stream token, to simulate doing an initial sync + # at this point. + end_stream_token = self.hs.get_event_sources().get_current_token() + + # Send some state *after* the stream token + self.helper.send_state( + joined_room, event_type="m.test_event", body={"num": 2}, tok=tok + ) + + # Calculating the full state will return the first state, and not the + # second. + state = self.get_success( + self.sync_handler._compute_state_delta_for_full_sync( + room_id=joined_room, + sync_config=generate_sync_config(user, use_state_after=True), + batch=TimelineBatch( + prev_batch=end_stream_token, events=[], limited=True + ), + end_token=end_stream_token, + members_to_fetch=None, + timeline_state={}, + joined=True, + ) + ) + self.assertEqual(state[("m.test_event", "")], first_state["event_id"]) + + def test_incremental_sync_multiple_deltas(self) -> None: + """Test that if multiple state deltas have happened since an incremental + state sync we return the correct state""" + + user = self.register_user("user", "password") + tok = self.login("user", "password") + + # Create a room as the user and set some custom state. + joined_room = self.helper.create_room_as(user, tok=tok) + + # Take a snapshot of the stream token, to simulate doing an incremental sync + # from this point. + since_token = self.hs.get_event_sources().get_current_token() + + self.helper.send_state( + joined_room, event_type="m.test_event", body={"num": 1}, tok=tok + ) + + # Send some state *after* the stream token + second_state = self.helper.send_state( + joined_room, event_type="m.test_event", body={"num": 2}, tok=tok + ) + + end_stream_token = self.hs.get_event_sources().get_current_token() + + # Calculating the incrementals state will return the second state, and not the + # first. + state = self.get_success( + self.sync_handler._compute_state_delta_for_incremental_sync( + room_id=joined_room, + sync_config=generate_sync_config(user, use_state_after=True), + batch=TimelineBatch( + prev_batch=end_stream_token, events=[], limited=True + ), + since_token=since_token, + end_token=end_stream_token, + members_to_fetch=None, + timeline_state={}, + ) + ) + self.assertEqual(state[("m.test_event", "")], second_state["event_id"]) + + def test_incremental_sync_lazy_loaded_no_timeline(self) -> None: + """Test that lazy-loading with an empty timeline doesn't return the full + state. + + There was a bug where an empty state filter would cause the DB to return + the full state, rather than an empty set. + """ + user = self.register_user("user", "password") + tok = self.login("user", "password") + + # Create a room as the user and set some custom state. + joined_room = self.helper.create_room_as(user, tok=tok) + + since_token = self.hs.get_event_sources().get_current_token() + end_stream_token = self.hs.get_event_sources().get_current_token() + + state = self.get_success( + self.sync_handler._compute_state_delta_for_incremental_sync( + room_id=joined_room, + sync_config=generate_sync_config(user, use_state_after=True), + batch=TimelineBatch( + prev_batch=end_stream_token, events=[], limited=True + ), + since_token=since_token, + end_token=end_stream_token, + members_to_fetch=set(), + timeline_state={}, + ) + ) + + self.assertEqual(state, {}) diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 878d9683b6..b12ffc3665 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py
@@ -796,6 +796,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): s = self.get_success(self.handler.search_users(u1, "user2", 10)) self.assertEqual(len(s["results"]), 1) + # Kept old spam checker without `requester_id` tests for backwards compatibility. async def allow_all(user_profile: UserProfile) -> bool: # Allow all users. return False @@ -809,6 +810,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): s = self.get_success(self.handler.search_users(u1, "user2", 10)) self.assertEqual(len(s["results"]), 1) + # Kept old spam checker without `requester_id` tests for backwards compatibility. # Configure a spam checker that filters all users. async def block_all(user_profile: UserProfile) -> bool: # All users are spammy. @@ -820,6 +822,40 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): s = self.get_success(self.handler.search_users(u1, "user2", 10)) self.assertEqual(len(s["results"]), 0) + async def allow_all_expects_requester_id( + user_profile: UserProfile, requester_id: str + ) -> bool: + self.assertEqual(requester_id, u1) + # Allow all users. + return False + + # Configure a spam checker that does not filter any users. + spam_checker = self.hs.get_module_api_callbacks().spam_checker + spam_checker._check_username_for_spam_callbacks = [ + allow_all_expects_requester_id + ] + + # The results do not change: + # We get one search result when searching for user2 by user1. + s = self.get_success(self.handler.search_users(u1, "user2", 10)) + self.assertEqual(len(s["results"]), 1) + + # Configure a spam checker that filters all users. + async def block_all_expects_requester_id( + user_profile: UserProfile, requester_id: str + ) -> bool: + self.assertEqual(requester_id, u1) + # All users are spammy. + return True + + spam_checker._check_username_for_spam_callbacks = [ + block_all_expects_requester_id + ] + + # User1 now gets no search results for any of the other users. + s = self.get_success(self.handler.search_users(u1, "user2", 10)) + self.assertEqual(len(s["results"]), 0) + @override_config( { "spam_checker": { @@ -956,6 +992,67 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): [self.assertIn(user, local_users) for user in received_user_id_ordering[:3]] [self.assertIn(user, remote_users) for user in received_user_id_ordering[3:]] + @override_config( + { + "user_directory": { + "enabled": True, + "search_all_users": True, + "exclude_remote_users": True, + } + } + ) + def test_exclude_remote_users(self) -> None: + """Tests that only local users are returned when + user_directory.exclude_remote_users is True. + """ + + # Create a room and few users to test the directory with + searching_user = self.register_user("searcher", "password") + searching_user_tok = self.login("searcher", "password") + + room_id = self.helper.create_room_as( + searching_user, + room_version=RoomVersions.V1.identifier, + tok=searching_user_tok, + ) + + # Create a few local users and join them to the room + local_user_1 = self.register_user("user_xxxxx", "password") + local_user_2 = self.register_user("user_bbbbb", "password") + local_user_3 = self.register_user("user_zzzzz", "password") + + self._add_user_to_room(room_id, RoomVersions.V1, local_user_1) + self._add_user_to_room(room_id, RoomVersions.V1, local_user_2) + self._add_user_to_room(room_id, RoomVersions.V1, local_user_3) + + # Create a few "remote" users and join them to the room + remote_user_1 = "@user_aaaaa:remote_server" + remote_user_2 = "@user_yyyyy:remote_server" + remote_user_3 = "@user_ccccc:remote_server" + self._add_user_to_room(room_id, RoomVersions.V1, remote_user_1) + self._add_user_to_room(room_id, RoomVersions.V1, remote_user_2) + self._add_user_to_room(room_id, RoomVersions.V1, remote_user_3) + + local_users = [local_user_1, local_user_2, local_user_3] + remote_users = [remote_user_1, remote_user_2, remote_user_3] + + # The local searching user searches for the term "user", which other users have + # in their user id + results = self.get_success( + self.handler.search_users(searching_user, "user", 20) + )["results"] + received_user_ids = [result["user_id"] for result in results] + + for user in local_users: + self.assertIn( + user, received_user_ids, f"Local user {user} not found in results" + ) + + for user in remote_users: + self.assertNotIn( + user, received_user_ids, f"Remote user {user} should not be in results" + ) + def _add_user_to_room( self, room_id: str, @@ -1081,10 +1178,10 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): for use_numeric in [False, True]: if use_numeric: prefix1 = f"{i}" - prefix2 = f"{i+1}" + prefix2 = f"{i + 1}" else: prefix1 = f"a{i}" - prefix2 = f"a{i+1}" + prefix2 = f"a{i + 1}" local_user_1 = self.register_user(f"user{char}{prefix1}", "password") local_user_2 = self.register_user(f"user{char}{prefix2}", "password") diff --git a/tests/handlers/test_worker_lock.py b/tests/handlers/test_worker_lock.py
index 6e9a15c8ee..0691d3f99c 100644 --- a/tests/handlers/test_worker_lock.py +++ b/tests/handlers/test_worker_lock.py
@@ -19,6 +19,9 @@ # # +import logging +import platform + from twisted.internet import defer from twisted.test.proto_helpers import MemoryReactor @@ -29,6 +32,8 @@ from tests import unittest from tests.replication._base import BaseMultiWorkerStreamTestCase from tests.utils import test_timeout +logger = logging.getLogger(__name__) + class WorkerLockTestCase(unittest.HomeserverTestCase): def prepare( @@ -53,12 +58,27 @@ class WorkerLockTestCase(unittest.HomeserverTestCase): def test_lock_contention(self) -> None: """Test lock contention when a lot of locks wait on a single worker""" - + nb_locks_to_test = 500 + current_machine = platform.machine().lower() + if current_machine.startswith("riscv"): + # RISC-V specific settings + timeout_seconds = 15 # Increased timeout for RISC-V + # add a print or log statement here for visibility in CI logs + logger.info( # use logger.info + f"Detected RISC-V architecture ({current_machine}). " + f"Adjusting test_lock_contention: timeout={timeout_seconds}s" + ) + else: + # Settings for other architectures + timeout_seconds = 5 # It takes around 0.5s on a 5+ years old laptop - with test_timeout(5): - nb_locks = 500 - d = self._take_locks(nb_locks) - self.assertEqual(self.get_success(d), nb_locks) + with test_timeout(timeout_seconds): # Use the dynamically set timeout + d = self._take_locks( + nb_locks_to_test + ) # Use the (potentially adjusted) number of locks + self.assertEqual( + self.get_success(d), nb_locks_to_test + ) # Assert against the used number of locks async def _take_locks(self, nb_locks: int) -> int: locks = [