summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2024-05-07 11:48:08 +0100
committerErik Johnston <erik@matrix.org>2024-05-07 11:48:08 +0100
commitfa68816fb858081b880557af8baee03e9325e494 (patch)
tree7a24c8bc68ca45ec4ca84d51dbf6e1076476599d /tests
parentMerge remote-tracking branch 'origin/release-v1.106' into matrix-org-hotfixes (diff)
parentBump serde from 1.0.199 to 1.0.200 (#17161) (diff)
downloadsynapse-fa68816fb858081b880557af8baee03e9325e494.tar.xz
Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes
Diffstat (limited to 'tests')
-rw-r--r--tests/events/test_utils.py24
-rw-r--r--tests/federation/test_federation_server.py17
-rw-r--r--tests/federation/transport/test_server.py9
-rw-r--r--tests/rest/admin/test_federation.py67
-rw-r--r--tests/rest/client/test_login.py204
-rw-r--r--tests/rest/client/test_retention.py7
-rw-r--r--tests/rest/client/test_rooms.py69
-rw-r--r--tests/storage/test_registration.py2
-rw-r--r--tests/storage/test_room_search.py13
-rw-r--r--tests/test_visibility.py320
-rw-r--r--tests/util/test_stream_change_cache.py17
11 files changed, 624 insertions, 125 deletions
diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py

index cf81bcf52c..d5ac66a6ed 100644 --- a/tests/events/test_utils.py +++ b/tests/events/test_utils.py
@@ -32,6 +32,7 @@ from synapse.events.utils import ( PowerLevelsContent, SerializeEventConfig, _split_field, + clone_event, copy_and_fixup_power_levels_contents, maybe_upsert_event_field, prune_event, @@ -611,6 +612,29 @@ class PruneEventTestCase(stdlib_unittest.TestCase): ) +class CloneEventTestCase(stdlib_unittest.TestCase): + def test_unsigned_is_copied(self) -> None: + original = make_event_from_dict( + { + "type": "A", + "event_id": "$test:domain", + "unsigned": {"a": 1, "b": 2}, + }, + RoomVersions.V1, + {"txn_id": "txn"}, + ) + original.internal_metadata.stream_ordering = 1234 + self.assertEqual(original.internal_metadata.stream_ordering, 1234) + + cloned = clone_event(original) + cloned.unsigned["b"] = 3 + + self.assertEqual(original.unsigned, {"a": 1, "b": 2}) + self.assertEqual(cloned.unsigned, {"a": 1, "b": 3}) + self.assertEqual(cloned.internal_metadata.stream_ordering, 1234) + self.assertEqual(cloned.internal_metadata.txn_id, "txn") + + class SerializeEventTestCase(stdlib_unittest.TestCase): def serialize(self, ev: EventBase, fields: Optional[List[str]]) -> JsonDict: return serialize_event( diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py
index 36684c2c91..88261450b1 100644 --- a/tests/federation/test_federation_server.py +++ b/tests/federation/test_federation_server.py
@@ -67,6 +67,23 @@ class FederationServerTests(unittest.FederatingHomeserverTestCase): self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result) self.assertEqual(channel.json_body["errcode"], "M_NOT_JSON") + def test_failed_edu_causes_500(self) -> None: + """If the EDU handler fails, /send should return a 500.""" + + async def failing_handler(_origin: str, _content: JsonDict) -> None: + raise Exception("bleh") + + self.hs.get_federation_registry().register_edu_handler( + "FAIL_EDU_TYPE", failing_handler + ) + + channel = self.make_signed_federation_request( + "PUT", + "/_matrix/federation/v1/send/txn", + {"edus": [{"edu_type": "FAIL_EDU_TYPE", "content": {}}]}, + ) + self.assertEqual(500, channel.code, channel.result) + class ServerACLsTestCase(unittest.TestCase): def test_blocked_server(self) -> None: diff --git a/tests/federation/transport/test_server.py b/tests/federation/transport/test_server.py
index 190b79bf26..0237369998 100644 --- a/tests/federation/transport/test_server.py +++ b/tests/federation/transport/test_server.py
@@ -59,7 +59,14 @@ class RoomDirectoryFederationTests(unittest.FederatingHomeserverTestCase): "/_matrix/federation/v1/send/txn_id_1234/", content={ "edus": [ - {"edu_type": EduTypes.DEVICE_LIST_UPDATE, "content": {"foo": "bar"}} + { + "edu_type": EduTypes.DEVICE_LIST_UPDATE, + "content": { + "device_id": "QBUAZIFURK", + "stream_id": 0, + "user_id": "@user:id", + }, + }, ], "pdus": [], }, diff --git a/tests/rest/admin/test_federation.py b/tests/rest/admin/test_federation.py
index c1d88f0176..c2015774a1 100644 --- a/tests/rest/admin/test_federation.py +++ b/tests/rest/admin/test_federation.py
@@ -778,20 +778,81 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase): self.assertEqual(number_rooms, len(channel.json_body["rooms"])) self._check_fields(channel.json_body["rooms"]) - def _create_destination_rooms(self, number_rooms: int) -> None: - """Create a number rooms for destination + def test_room_filtering(self) -> None: + """Tests that rooms are correctly filtered""" + + # Create two rooms on the homeserver. Each has a different remote homeserver + # participating in it. + other_destination = "other.destination.org" + room_ids_self_dest = self._create_destination_rooms(2, destination=self.dest) + room_ids_other_dest = self._create_destination_rooms( + 1, destination=other_destination + ) + + # Ask for the rooms that `self.dest` is participating in. + channel = self.make_request("GET", self.url, access_token=self.admin_user_tok) + self.assertEqual(200, channel.code, msg=channel.json_body) + + # Verify that we received only the rooms that `self.dest` is participating in. + # This assertion method name is a bit misleading. It does check that both lists + # contain the same items, and the same counts. + self.assertCountEqual( + [r["room_id"] for r in channel.json_body["rooms"]], room_ids_self_dest + ) + self.assertEqual(channel.json_body["total"], len(room_ids_self_dest)) + + # Ask for the rooms that `other_destination` is participating in. + channel = self.make_request( + "GET", + self.url.replace(self.dest, other_destination), + access_token=self.admin_user_tok, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + + # Verify that we received only the rooms that `other_destination` is + # participating in. + self.assertCountEqual( + [r["room_id"] for r in channel.json_body["rooms"]], room_ids_other_dest + ) + self.assertEqual(channel.json_body["total"], len(room_ids_other_dest)) + + def _create_destination_rooms( + self, + number_rooms: int, + destination: Optional[str] = None, + ) -> List[str]: + """ + Create the given number of rooms. The given `destination` homeserver will + be recorded as a participant. Args: number_rooms: Number of rooms to be created + destination: The domain of the homeserver that will be considered + as a participant in the rooms. + + Returns: + The IDs of the rooms that have been created. """ + room_ids = [] + + # If no destination was provided, default to `self.dest`. + if destination is None: + destination = self.dest + for _ in range(number_rooms): room_id = self.helper.create_room_as( self.admin_user, tok=self.admin_user_tok ) + room_ids.append(room_id) + self.get_success( - self.store.store_destination_rooms_entries((self.dest,), room_id, 1234) + self.store.store_destination_rooms_entries( + (destination,), room_id, 1234 + ) ) + return room_ids + def _check_fields(self, content: List[JsonDict]) -> None: """Checks that the expected room attributes are present in content diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py
index 3a1f150082..3fb77fd9dd 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py
@@ -20,7 +20,17 @@ # import time import urllib.parse -from typing import Any, Collection, Dict, List, Optional, Tuple, Union +from typing import ( + Any, + BinaryIO, + Callable, + Collection, + Dict, + List, + Optional, + Tuple, + Union, +) from unittest.mock import Mock from urllib.parse import urlencode @@ -34,8 +44,9 @@ import synapse.rest.admin from synapse.api.constants import ApprovalNoticeMedium, LoginType from synapse.api.errors import Codes from synapse.appservice import ApplicationService +from synapse.http.client import RawHeaders from synapse.module_api import ModuleApi -from synapse.rest.client import devices, login, logout, register +from synapse.rest.client import account, devices, login, logout, profile, register from synapse.rest.client.account import WhoamiRestServlet from synapse.rest.synapse.client import build_synapse_client_resource_tree from synapse.server import HomeServer @@ -48,6 +59,7 @@ from tests.handlers.test_saml import has_saml2 from tests.rest.client.utils import TEST_OIDC_CONFIG from tests.server import FakeChannel from tests.test_utils.html_parsers import TestHtmlParser +from tests.test_utils.oidc import FakeOidcServer from tests.unittest import HomeserverTestCase, override_config, skip_unless try: @@ -1421,7 +1433,19 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): class UsernamePickerTestCase(HomeserverTestCase): """Tests for the username picker flow of SSO login""" - servlets = [login.register_servlets] + servlets = [ + login.register_servlets, + profile.register_servlets, + account.register_servlets, + ] + + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + self.http_client = Mock(spec=["get_file"]) + self.http_client.get_file.side_effect = mock_get_file + hs = self.setup_test_homeserver( + proxied_blocklisted_http_client=self.http_client + ) + return hs def default_config(self) -> Dict[str, Any]: config = super().default_config() @@ -1430,7 +1454,11 @@ class UsernamePickerTestCase(HomeserverTestCase): config["oidc_config"] = {} config["oidc_config"].update(TEST_OIDC_CONFIG) config["oidc_config"]["user_mapping_provider"] = { - "config": {"display_name_template": "{{ user.displayname }}"} + "config": { + "display_name_template": "{{ user.displayname }}", + "email_template": "{{ user.email }}", + "picture_template": "{{ user.picture }}", + } } # whitelist this client URI so we redirect straight to it rather than @@ -1443,15 +1471,22 @@ class UsernamePickerTestCase(HomeserverTestCase): d.update(build_synapse_client_resource_tree(self.hs)) return d - def test_username_picker(self) -> None: - """Test the happy path of a username picker flow.""" - - fake_oidc_server = self.helper.fake_oidc_server() - + def proceed_to_username_picker_page( + self, + fake_oidc_server: FakeOidcServer, + displayname: str, + email: str, + picture: str, + ) -> Tuple[str, str]: # do the start of the login flow channel, _ = self.helper.auth_via_oidc( fake_oidc_server, - {"sub": "tester", "displayname": "Jonny"}, + { + "sub": "tester", + "displayname": displayname, + "picture": picture, + "email": email, + }, TEST_CLIENT_REDIRECT_URL, ) @@ -1478,16 +1513,132 @@ class UsernamePickerTestCase(HomeserverTestCase): ) session = username_mapping_sessions[session_id] self.assertEqual(session.remote_user_id, "tester") - self.assertEqual(session.display_name, "Jonny") + self.assertEqual(session.display_name, displayname) + self.assertEqual(session.emails, [email]) + self.assertEqual(session.avatar_url, picture) self.assertEqual(session.client_redirect_url, TEST_CLIENT_REDIRECT_URL) # the expiry time should be about 15 minutes away expected_expiry = self.clock.time_msec() + (15 * 60 * 1000) self.assertApproximates(session.expiry_time_ms, expected_expiry, tolerance=1000) + return picker_url, session_id + + def test_username_picker_use_displayname_avatar_and_email(self) -> None: + """Test the happy path of a username picker flow with using displayname, avatar and email.""" + + fake_oidc_server = self.helper.fake_oidc_server() + + mxid = "@bobby:test" + displayname = "Jonny" + email = "bobby@test.com" + picture = "mxc://test/avatar_url" + + picker_url, session_id = self.proceed_to_username_picker_page( + fake_oidc_server, displayname, email, picture + ) + + # Now, submit a username to the username picker, which should serve a redirect + # to the completion page. + # Also specify that we should use the provided displayname, avatar and email. + content = urlencode( + { + b"username": b"bobby", + b"use_display_name": b"true", + b"use_avatar": b"true", + b"use_email": email, + } + ).encode("utf8") + chan = self.make_request( + "POST", + path=picker_url, + content=content, + content_is_form=True, + custom_headers=[ + ("Cookie", "username_mapping_session=" + session_id), + # old versions of twisted don't do form-parsing without a valid + # content-length header. + ("Content-Length", str(len(content))), + ], + ) + self.assertEqual(chan.code, 302, chan.result) + location_headers = chan.headers.getRawHeaders("Location") + assert location_headers + + # send a request to the completion page, which should 302 to the client redirectUrl + chan = self.make_request( + "GET", + path=location_headers[0], + custom_headers=[("Cookie", "username_mapping_session=" + session_id)], + ) + self.assertEqual(chan.code, 302, chan.result) + location_headers = chan.headers.getRawHeaders("Location") + assert location_headers + + # ensure that the returned location matches the requested redirect URL + path, query = location_headers[0].split("?", 1) + self.assertEqual(path, "https://x") + + # it will have url-encoded the params properly, so we'll have to parse them + params = urllib.parse.parse_qsl( + query, keep_blank_values=True, strict_parsing=True, errors="strict" + ) + self.assertEqual(params[0:2], EXPECTED_CLIENT_REDIRECT_URL_PARAMS) + self.assertEqual(params[2][0], "loginToken") + + # fish the login token out of the returned redirect uri + login_token = params[2][1] + + # finally, submit the matrix login token to the login API, which gives us our + # matrix access token, mxid, and device id. + chan = self.make_request( + "POST", + "/login", + content={"type": "m.login.token", "token": login_token}, + ) + self.assertEqual(chan.code, 200, chan.result) + self.assertEqual(chan.json_body["user_id"], mxid) + + # ensure the displayname and avatar from the OIDC response have been configured for the user. + channel = self.make_request( + "GET", "/profile/" + mxid, access_token=chan.json_body["access_token"] + ) + self.assertEqual(channel.code, 200, channel.result) + self.assertIn("mxc://test", channel.json_body["avatar_url"]) + self.assertEqual(displayname, channel.json_body["displayname"]) + + # ensure the email from the OIDC response has been configured for the user. + channel = self.make_request( + "GET", "/account/3pid", access_token=chan.json_body["access_token"] + ) + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(email, channel.json_body["threepids"][0]["address"]) + + def test_username_picker_dont_use_displayname_avatar_or_email(self) -> None: + """Test the happy path of a username picker flow without using displayname, avatar or email.""" + + fake_oidc_server = self.helper.fake_oidc_server() + + mxid = "@bobby:test" + displayname = "Jonny" + email = "bobby@test.com" + picture = "mxc://test/avatar_url" + username = "bobby" + + picker_url, session_id = self.proceed_to_username_picker_page( + fake_oidc_server, displayname, email, picture + ) + # Now, submit a username to the username picker, which should serve a redirect - # to the completion page - content = urlencode({b"username": b"bobby"}).encode("utf8") + # to the completion page. + # Also specify that we should not use the provided displayname, avatar or email. + content = urlencode( + { + b"username": username, + b"use_display_name": b"false", + b"use_avatar": b"false", + } + ).encode("utf8") chan = self.make_request( "POST", path=picker_url, @@ -1536,4 +1687,29 @@ class UsernamePickerTestCase(HomeserverTestCase): content={"type": "m.login.token", "token": login_token}, ) self.assertEqual(chan.code, 200, chan.result) - self.assertEqual(chan.json_body["user_id"], "@bobby:test") + self.assertEqual(chan.json_body["user_id"], mxid) + + # ensure the displayname and avatar from the OIDC response have not been configured for the user. + channel = self.make_request( + "GET", "/profile/" + mxid, access_token=chan.json_body["access_token"] + ) + self.assertEqual(channel.code, 200, channel.result) + self.assertNotIn("avatar_url", channel.json_body) + self.assertEqual(username, channel.json_body["displayname"]) + + # ensure the email from the OIDC response has not been configured for the user. + channel = self.make_request( + "GET", "/account/3pid", access_token=chan.json_body["access_token"] + ) + self.assertEqual(channel.code, 200, channel.result) + self.assertListEqual([], channel.json_body["threepids"]) + + +async def mock_get_file( + url: str, + output_stream: BinaryIO, + max_size: Optional[int] = None, + headers: Optional[RawHeaders] = None, + is_allowed_content_type: Optional[Callable[[str], bool]] = None, +) -> Tuple[int, Dict[bytes, List[bytes]], str, int]: + return 0, {b"Content-Type": [b"image/png"]}, "", 200 diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index 09a5d64349..ceae40498e 100644 --- a/tests/rest/client/test_retention.py +++ b/tests/rest/client/test_retention.py
@@ -163,7 +163,12 @@ class RetentionTestCase(unittest.HomeserverTestCase): ) self.assertEqual(2, len(events), "events retrieved from database") filtered_events = self.get_success( - filter_events_for_client(storage_controllers, self.user_id, events) + filter_events_for_client( + storage_controllers, + self.user_id, + events, + msc4115_membership_on_events=True, + ) ) # We should only get one event back. diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index b796163dcb..d398cead1c 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py
@@ -48,7 +48,16 @@ from synapse.appservice import ApplicationService from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.rest import admin -from synapse.rest.client import account, directory, login, profile, register, room, sync +from synapse.rest.client import ( + account, + directory, + knock, + login, + profile, + register, + room, + sync, +) from synapse.server import HomeServer from synapse.types import JsonDict, RoomAlias, UserID, create_requester from synapse.util import Clock @@ -733,7 +742,7 @@ class RoomsCreateTestCase(RoomBase): self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) assert channel.resource_usage is not None - self.assertEqual(32, channel.resource_usage.db_txn_count) + self.assertEqual(33, channel.resource_usage.db_txn_count) def test_post_room_initial_state(self) -> None: # POST with initial_state config key, expect new room id @@ -746,7 +755,7 @@ class RoomsCreateTestCase(RoomBase): self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) assert channel.resource_usage is not None - self.assertEqual(34, channel.resource_usage.db_txn_count) + self.assertEqual(35, channel.resource_usage.db_txn_count) def test_post_room_visibility_key(self) -> None: # POST with visibility config key, expect new room id @@ -1154,6 +1163,7 @@ class RoomJoinTestCase(RoomBase): admin.register_servlets, login.register_servlets, room.register_servlets, + knock.register_servlets, ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: @@ -1167,6 +1177,8 @@ class RoomJoinTestCase(RoomBase): self.room2 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1) self.room3 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1) + self.store = hs.get_datastores().main + def test_spam_checker_may_join_room_deprecated(self) -> None: """Tests that the user_may_join_room spam checker callback is correctly called and blocks room joins when needed. @@ -1317,6 +1329,57 @@ class RoomJoinTestCase(RoomBase): expect_additional_fields=return_value[1], ) + def test_suspended_user_cannot_join_room(self) -> None: + # set the user as suspended + self.get_success(self.store.set_user_suspended_status(self.user2, True)) + + channel = self.make_request( + "POST", f"/join/{self.room1}", access_token=self.tok2 + ) + self.assertEqual(channel.code, 403) + self.assertEqual( + channel.json_body["errcode"], "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED" + ) + + channel = self.make_request( + "POST", f"/rooms/{self.room1}/join", access_token=self.tok2 + ) + self.assertEqual(channel.code, 403) + self.assertEqual( + channel.json_body["errcode"], "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED" + ) + + def test_suspended_user_cannot_knock_on_room(self) -> None: + # set the user as suspended + self.get_success(self.store.set_user_suspended_status(self.user2, True)) + + channel = self.make_request( + "POST", + f"/_matrix/client/v3/knock/{self.room1}", + access_token=self.tok2, + content={}, + shorthand=False, + ) + self.assertEqual(channel.code, 403) + self.assertEqual( + channel.json_body["errcode"], "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED" + ) + + def test_suspended_user_cannot_invite_to_room(self) -> None: + # set the user as suspended + self.get_success(self.store.set_user_suspended_status(self.user1, True)) + + # first user invites second user + channel = self.make_request( + "POST", + f"/rooms/{self.room1}/invite", + access_token=self.tok1, + content={"user_id": self.user2}, + ) + self.assertEqual( + channel.json_body["errcode"], "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED" + ) + class RoomAppserviceTsParamTestCase(unittest.HomeserverTestCase): servlets = [ diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 505465d529..14e3871dc1 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py
@@ -43,7 +43,6 @@ class RegistrationStoreTestCase(HomeserverTestCase): self.assertEqual( UserInfo( - # TODO(paul): Surely this field should be 'user_id', not 'name' user_id=UserID.from_string(self.user_id), is_admin=False, is_guest=False, @@ -57,6 +56,7 @@ class RegistrationStoreTestCase(HomeserverTestCase): locked=False, is_shadow_banned=False, approved=True, + suspended=False, ), (self.get_success(self.store.get_user_by_id(self.user_id))), ) diff --git a/tests/storage/test_room_search.py b/tests/storage/test_room_search.py
index 1eab89f140..340642b7e7 100644 --- a/tests/storage/test_room_search.py +++ b/tests/storage/test_room_search.py
@@ -71,17 +71,16 @@ class EventSearchInsertionTest(HomeserverTestCase): store.search_msgs([room_id], "hi bob", ["content.body"]) ) self.assertEqual(result.get("count"), 1) - if isinstance(store.database_engine, PostgresEngine): - self.assertIn("hi", result.get("highlights")) - self.assertIn("bob", result.get("highlights")) + self.assertIn("hi", result.get("highlights")) + self.assertIn("bob", result.get("highlights")) # Check that search works for an unrelated message result = self.get_success( store.search_msgs([room_id], "another", ["content.body"]) ) self.assertEqual(result.get("count"), 1) - if isinstance(store.database_engine, PostgresEngine): - self.assertIn("another", result.get("highlights")) + + self.assertIn("another", result.get("highlights")) # Check that search works for a search term that overlaps with the message # containing a null byte and an unrelated message. @@ -90,8 +89,8 @@ class EventSearchInsertionTest(HomeserverTestCase): result = self.get_success( store.search_msgs([room_id], "hi alice", ["content.body"]) ) - if isinstance(store.database_engine, PostgresEngine): - self.assertIn("alice", result.get("highlights")) + + self.assertIn("alice", result.get("highlights")) def test_non_string(self) -> None: """Test that non-string `value`s are not inserted into `event_search`. diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index e51f72d65f..3e2100eab4 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py
@@ -21,13 +21,19 @@ import logging from typing import Optional from unittest.mock import patch +from synapse.api.constants import EventUnsignedContentFields from synapse.api.room_versions import RoomVersions from synapse.events import EventBase, make_event_from_dict from synapse.events.snapshot import EventContext -from synapse.types import JsonDict, create_requester +from synapse.rest import admin +from synapse.rest.client import login, room +from synapse.server import HomeServer +from synapse.types import create_requester from synapse.visibility import filter_events_for_client, filter_events_for_server from tests import unittest +from tests.test_utils.event_injection import inject_event, inject_member_event +from tests.unittest import HomeserverTestCase from tests.utils import create_room logger = logging.getLogger(__name__) @@ -56,15 +62,31 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): # # before we do that, we persist some other events to act as state. - self._inject_visibility("@admin:hs", "joined") + self.get_success( + inject_visibility_event(self.hs, TEST_ROOM_ID, "@admin:hs", "joined") + ) for i in range(10): - self._inject_room_member("@resident%i:hs" % i) + self.get_success( + inject_member_event( + self.hs, + TEST_ROOM_ID, + "@resident%i:hs" % i, + "join", + ) + ) events_to_filter = [] for i in range(10): - user = "@user%i:%s" % (i, "test_server" if i == 5 else "other_server") - evt = self._inject_room_member(user, extra_content={"a": "b"}) + evt = self.get_success( + inject_member_event( + self.hs, + TEST_ROOM_ID, + "@user%i:%s" % (i, "test_server" if i == 5 else "other_server"), + "join", + extra_content={"a": "b"}, + ) + ) events_to_filter.append(evt) filtered = self.get_success( @@ -90,8 +112,19 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): def test_filter_outlier(self) -> None: # outlier events must be returned, for the good of the collective federation - self._inject_room_member("@resident:remote_hs") - self._inject_visibility("@resident:remote_hs", "joined") + self.get_success( + inject_member_event( + self.hs, + TEST_ROOM_ID, + "@resident:remote_hs", + "join", + ) + ) + self.get_success( + inject_visibility_event( + self.hs, TEST_ROOM_ID, "@resident:remote_hs", "joined" + ) + ) outlier = self._inject_outlier() self.assertEqual( @@ -110,7 +143,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): ) # it should also work when there are other events in the list - evt = self._inject_message("@unerased:local_hs") + evt = self.get_success( + inject_message_event(self.hs, TEST_ROOM_ID, "@unerased:local_hs") + ) filtered = self.get_success( filter_events_for_server( @@ -150,19 +185,34 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): # change in the middle of them. events_to_filter = [] - evt = self._inject_message("@unerased:local_hs") + evt = self.get_success( + inject_message_event(self.hs, TEST_ROOM_ID, "@unerased:local_hs") + ) events_to_filter.append(evt) - evt = self._inject_message("@erased:local_hs") + evt = self.get_success( + inject_message_event(self.hs, TEST_ROOM_ID, "@erased:local_hs") + ) events_to_filter.append(evt) - evt = self._inject_room_member("@joiner:remote_hs") + evt = self.get_success( + inject_member_event( + self.hs, + TEST_ROOM_ID, + "@joiner:remote_hs", + "join", + ) + ) events_to_filter.append(evt) - evt = self._inject_message("@unerased:local_hs") + evt = self.get_success( + inject_message_event(self.hs, TEST_ROOM_ID, "@unerased:local_hs") + ) events_to_filter.append(evt) - evt = self._inject_message("@erased:local_hs") + evt = self.get_success( + inject_message_event(self.hs, TEST_ROOM_ID, "@erased:local_hs") + ) events_to_filter.append(evt) # the erasey user gets erased @@ -200,76 +250,6 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): for i in (1, 4): self.assertNotIn("body", filtered[i].content) - def _inject_visibility(self, user_id: str, visibility: str) -> EventBase: - content = {"history_visibility": visibility} - builder = self.event_builder_factory.for_room_version( - RoomVersions.V1, - { - "type": "m.room.history_visibility", - "sender": user_id, - "state_key": "", - "room_id": TEST_ROOM_ID, - "content": content, - }, - ) - - event, unpersisted_context = self.get_success( - self.event_creation_handler.create_new_client_event(builder) - ) - context = self.get_success(unpersisted_context.persist(event)) - self.get_success(self._persistence.persist_event(event, context)) - return event - - def _inject_room_member( - self, - user_id: str, - membership: str = "join", - extra_content: Optional[JsonDict] = None, - ) -> EventBase: - content = {"membership": membership} - content.update(extra_content or {}) - builder = self.event_builder_factory.for_room_version( - RoomVersions.V1, - { - "type": "m.room.member", - "sender": user_id, - "state_key": user_id, - "room_id": TEST_ROOM_ID, - "content": content, - }, - ) - - event, unpersisted_context = self.get_success( - self.event_creation_handler.create_new_client_event(builder) - ) - context = self.get_success(unpersisted_context.persist(event)) - - self.get_success(self._persistence.persist_event(event, context)) - return event - - def _inject_message( - self, user_id: str, content: Optional[JsonDict] = None - ) -> EventBase: - if content is None: - content = {"body": "testytest", "msgtype": "m.text"} - builder = self.event_builder_factory.for_room_version( - RoomVersions.V1, - { - "type": "m.room.message", - "sender": user_id, - "room_id": TEST_ROOM_ID, - "content": content, - }, - ) - - event, unpersisted_context = self.get_success( - self.event_creation_handler.create_new_client_event(builder) - ) - context = self.get_success(unpersisted_context.persist(event)) - - self.get_success(self._persistence.persist_event(event, context)) - return event - def _inject_outlier(self) -> EventBase: builder = self.event_builder_factory.for_room_version( RoomVersions.V1, @@ -292,7 +272,122 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): return event -class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase): +class FilterEventsForClientTestCase(HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def test_joined_history_visibility(self) -> None: + # User joins and leaves room. Should be able to see the join and leave, + # and messages sent between the two, but not before or after. + + self.register_user("resident", "p1") + resident_token = self.login("resident", "p1") + room_id = self.helper.create_room_as("resident", tok=resident_token) + + self.get_success( + inject_visibility_event(self.hs, room_id, "@resident:test", "joined") + ) + before_event = self.get_success( + inject_message_event(self.hs, room_id, "@resident:test", body="before") + ) + join_event = self.get_success( + inject_member_event(self.hs, room_id, "@joiner:test", "join") + ) + during_event = self.get_success( + inject_message_event(self.hs, room_id, "@resident:test", body="during") + ) + leave_event = self.get_success( + inject_member_event(self.hs, room_id, "@joiner:test", "leave") + ) + after_event = self.get_success( + inject_message_event(self.hs, room_id, "@resident:test", body="after") + ) + + # We have to reload the events from the db, to ensure that prev_content is + # populated. + events_to_filter = [ + self.get_success( + self.hs.get_storage_controllers().main.get_event( + e.event_id, + get_prev_content=True, + ) + ) + for e in [ + before_event, + join_event, + during_event, + leave_event, + after_event, + ] + ] + + # Now run the events through the filter, and check that we can see the events + # we expect, and that the membership prop is as expected. + # + # We deliberately do the queries for both users upfront; this simulates + # concurrent queries on the server, and helps ensure that we aren't + # accidentally serving the same event object (with the same unsigned.membership + # property) to both users. + joiner_filtered_events = self.get_success( + filter_events_for_client( + self.hs.get_storage_controllers(), + "@joiner:test", + events_to_filter, + msc4115_membership_on_events=True, + ) + ) + resident_filtered_events = self.get_success( + filter_events_for_client( + self.hs.get_storage_controllers(), + "@resident:test", + events_to_filter, + msc4115_membership_on_events=True, + ) + ) + + # The joiner should be able to seem the join and leave, + # and messages sent between the two, but not before or after. + self.assertEqual( + [e.event_id for e in [join_event, during_event, leave_event]], + [e.event_id for e in joiner_filtered_events], + ) + self.assertEqual( + ["join", "join", "leave"], + [ + e.unsigned[EventUnsignedContentFields.MSC4115_MEMBERSHIP] + for e in joiner_filtered_events + ], + ) + + # The resident user should see all the events. + self.assertEqual( + [ + e.event_id + for e in [ + before_event, + join_event, + during_event, + leave_event, + after_event, + ] + ], + [e.event_id for e in resident_filtered_events], + ) + self.assertEqual( + ["join", "join", "join", "join", "join"], + [ + e.unsigned[EventUnsignedContentFields.MSC4115_MEMBERSHIP] + for e in resident_filtered_events + ], + ) + + +class FilterEventsOutOfBandEventsForClientTestCase( + unittest.FederatingHomeserverTestCase +): def test_out_of_band_invite_rejection(self) -> None: # this is where we have received an invite event over federation, and then # rejected it. @@ -341,15 +436,24 @@ class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase): ) # the invited user should be able to see both the invite and the rejection + filtered_events = self.get_success( + filter_events_for_client( + self.hs.get_storage_controllers(), + "@user:test", + [invite_event, reject_event], + msc4115_membership_on_events=True, + ) + ) self.assertEqual( - self.get_success( - filter_events_for_client( - self.hs.get_storage_controllers(), - "@user:test", - [invite_event, reject_event], - ) - ), - [invite_event, reject_event], + [e.event_id for e in filtered_events], + [e.event_id for e in [invite_event, reject_event]], + ) + self.assertEqual( + ["invite", "leave"], + [ + e.unsigned[EventUnsignedContentFields.MSC4115_MEMBERSHIP] + for e in filtered_events + ], ) # other users should see neither @@ -359,7 +463,39 @@ class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase): self.hs.get_storage_controllers(), "@other:test", [invite_event, reject_event], + msc4115_membership_on_events=True, ) ), [], ) + + +async def inject_visibility_event( + hs: HomeServer, + room_id: str, + sender: str, + visibility: str, +) -> EventBase: + return await inject_event( + hs, + type="m.room.history_visibility", + sender=sender, + state_key="", + room_id=room_id, + content={"history_visibility": visibility}, + ) + + +async def inject_message_event( + hs: HomeServer, + room_id: str, + sender: str, + body: Optional[str] = "testytest", +) -> EventBase: + return await inject_event( + hs, + type="m.room.message", + sender=sender, + room_id=room_id, + content={"body": body, "msgtype": "m.text"}, + ) diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py
index 3df053493b..5d38718a50 100644 --- a/tests/util/test_stream_change_cache.py +++ b/tests/util/test_stream_change_cache.py
@@ -1,3 +1,5 @@ +from parameterized import parameterized + from synapse.util.caches.stream_change_cache import StreamChangeCache from tests import unittest @@ -161,7 +163,8 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase): self.assertFalse(cache.has_any_entity_changed(2)) self.assertFalse(cache.has_any_entity_changed(3)) - def test_get_entities_changed(self) -> None: + @parameterized.expand([(0,), (1000000000,)]) + def test_get_entities_changed(self, perf_factor: int) -> None: """ StreamChangeCache.get_entities_changed will return the entities in the given list that have changed since the provided stream ID. If the @@ -178,7 +181,9 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase): # get the ones after that point. self.assertEqual( cache.get_entities_changed( - ["user@foo.com", "bar@baz.net", "user@elsewhere.org"], stream_pos=2 + ["user@foo.com", "bar@baz.net", "user@elsewhere.org"], + stream_pos=2, + _perf_factor=perf_factor, ), {"bar@baz.net", "user@elsewhere.org"}, ) @@ -195,6 +200,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase): "not@here.website", ], stream_pos=2, + _perf_factor=perf_factor, ), {"bar@baz.net", "user@elsewhere.org"}, ) @@ -210,6 +216,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase): "not@here.website", ], stream_pos=0, + _perf_factor=perf_factor, ), {"user@foo.com", "bar@baz.net", "user@elsewhere.org", "not@here.website"}, ) @@ -217,7 +224,11 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase): # Query a subset of the entries mid-way through the stream. We should # only get back the subset. self.assertEqual( - cache.get_entities_changed(["bar@baz.net"], stream_pos=2), + cache.get_entities_changed( + ["bar@baz.net"], + stream_pos=2, + _perf_factor=perf_factor, + ), {"bar@baz.net"}, )