From 59710437e4a885252de5e5555fbcf42d223b092c Mon Sep 17 00:00:00 2001 From: Melvyn Laïly Date: Fri, 26 Apr 2024 10:43:52 +0200 Subject: Return the search terms as search highlights for SQLite instead of nothing (#17000) Fixes https://github.com/element-hq/synapse/issues/16999 and https://github.com/element-hq/element-android/pull/8729 by returning the search terms as search highlights. --- tests/storage/test_room_search.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) (limited to 'tests') diff --git a/tests/storage/test_room_search.py b/tests/storage/test_room_search.py index 1eab89f140..340642b7e7 100644 --- a/tests/storage/test_room_search.py +++ b/tests/storage/test_room_search.py @@ -71,17 +71,16 @@ class EventSearchInsertionTest(HomeserverTestCase): store.search_msgs([room_id], "hi bob", ["content.body"]) ) self.assertEqual(result.get("count"), 1) - if isinstance(store.database_engine, PostgresEngine): - self.assertIn("hi", result.get("highlights")) - self.assertIn("bob", result.get("highlights")) + self.assertIn("hi", result.get("highlights")) + self.assertIn("bob", result.get("highlights")) # Check that search works for an unrelated message result = self.get_success( store.search_msgs([room_id], "another", ["content.body"]) ) self.assertEqual(result.get("count"), 1) - if isinstance(store.database_engine, PostgresEngine): - self.assertIn("another", result.get("highlights")) + + self.assertIn("another", result.get("highlights")) # Check that search works for a search term that overlaps with the message # containing a null byte and an unrelated message. 
@@ -90,8 +89,8 @@ class EventSearchInsertionTest(HomeserverTestCase): result = self.get_success( store.search_msgs([room_id], "hi alice", ["content.body"]) ) - if isinstance(store.database_engine, PostgresEngine): - self.assertIn("alice", result.get("highlights")) + + self.assertIn("alice", result.get("highlights")) def test_non_string(self) -> None: """Test that non-string `value`s are not inserted into `event_search`. -- cgit 1.5.1 From 89fc579329d7c81c040b1c178099860e7de37bed Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Fri, 26 Apr 2024 10:52:24 +0100 Subject: Fix filtering of rooms when supplying the `destination` query parameter to `/_synapse/admin/v1/federation/destinations/<destination>/rooms` (#17077) --- changelog.d/17077.bugfix | 1 + synapse/storage/databases/main/transactions.py | 1 + tests/rest/admin/test_federation.py | 67 ++++++++++++++++++++++++-- 3 files changed, 66 insertions(+), 3 deletions(-) create mode 100644 changelog.d/17077.bugfix (limited to 'tests') diff --git a/changelog.d/17077.bugfix b/changelog.d/17077.bugfix new file mode 100644 index 0000000000..7d8ea37406 --- /dev/null +++ b/changelog.d/17077.bugfix @@ -0,0 +1 @@ +Fixes a bug introduced in v1.52.0 where the `destination` query parameter for the [Destination Rooms Admin API](https://element-hq.github.io/synapse/v1.105/usage/administration/admin_api/federation.html#destination-rooms) failed to actually filter returned rooms. 
\ No newline at end of file diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index 08e0241f68..770802483c 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -660,6 +660,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): limit=limit, retcols=("room_id", "stream_ordering"), order_direction=order, + keyvalues={"destination": destination}, ), ) return rooms, count diff --git a/tests/rest/admin/test_federation.py b/tests/rest/admin/test_federation.py index c1d88f0176..c2015774a1 100644 --- a/tests/rest/admin/test_federation.py +++ b/tests/rest/admin/test_federation.py @@ -778,20 +778,81 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase): self.assertEqual(number_rooms, len(channel.json_body["rooms"])) self._check_fields(channel.json_body["rooms"]) - def _create_destination_rooms(self, number_rooms: int) -> None: - """Create a number rooms for destination + def test_room_filtering(self) -> None: + """Tests that rooms are correctly filtered""" + + # Create two rooms on the homeserver. Each has a different remote homeserver + # participating in it. + other_destination = "other.destination.org" + room_ids_self_dest = self._create_destination_rooms(2, destination=self.dest) + room_ids_other_dest = self._create_destination_rooms( + 1, destination=other_destination + ) + + # Ask for the rooms that `self.dest` is participating in. + channel = self.make_request("GET", self.url, access_token=self.admin_user_tok) + self.assertEqual(200, channel.code, msg=channel.json_body) + + # Verify that we received only the rooms that `self.dest` is participating in. + # This assertion method name is a bit misleading. It does check that both lists + # contain the same items, and the same counts. 
+ self.assertCountEqual( + [r["room_id"] for r in channel.json_body["rooms"]], room_ids_self_dest + ) + self.assertEqual(channel.json_body["total"], len(room_ids_self_dest)) + + # Ask for the rooms that `other_destination` is participating in. + channel = self.make_request( + "GET", + self.url.replace(self.dest, other_destination), + access_token=self.admin_user_tok, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + + # Verify that we received only the rooms that `other_destination` is + # participating in. + self.assertCountEqual( + [r["room_id"] for r in channel.json_body["rooms"]], room_ids_other_dest + ) + self.assertEqual(channel.json_body["total"], len(room_ids_other_dest)) + + def _create_destination_rooms( + self, + number_rooms: int, + destination: Optional[str] = None, + ) -> List[str]: + """ + Create the given number of rooms. The given `destination` homeserver will + be recorded as a participant. Args: number_rooms: Number of rooms to be created + destination: The domain of the homeserver that will be considered + as a participant in the rooms. + + Returns: + The IDs of the rooms that have been created. """ + room_ids = [] + + # If no destination was provided, default to `self.dest`. 
+ if destination is None: + destination = self.dest + for _ in range(number_rooms): room_id = self.helper.create_room_as( self.admin_user, tok=self.admin_user_tok ) + room_ids.append(room_id) + self.get_success( - self.store.store_destination_rooms_entries((self.dest,), room_id, 1234) + self.store.store_destination_rooms_entries( + (destination,), room_id, 1234 + ) ) + return room_ids + def _check_fields(self, content: List[JsonDict]) -> None: """Checks that the expected room attributes are present in content -- cgit 1.5.1 From c897ac63e90e198723baa4bc73574a30fb02176b Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 29 Apr 2024 14:11:00 +0100 Subject: Ensure that incoming to-device messages are not dropped (#17127) ... when workers are unreachable, etc. Fixes https://github.com/element-hq/synapse/issues/17117. The general principle is just to make sure that we propagate any exceptions to the JsonResource, so that we return an error code to the sending server. That means that the sending server no longer considers the message safely sent, so it will retry later. In the issue, Erik mentions that an alternative solution would be to persist the to-device messages into a table so that they can be retried. This might be an improvement for performance, but even if we did that, we still need this mechanism, since we might be unable to reach the database. So, if we want to do that, it can be a later follow-up. 
--------- Co-authored-by: Erik Johnston --- changelog.d/17127.bugfix | 1 + synapse/federation/federation_server.py | 44 ++++++++++++++++++------------ synapse/handlers/devicemessage.py | 3 ++ tests/federation/test_federation_server.py | 17 ++++++++++++ tests/federation/transport/test_server.py | 9 +++++- 5 files changed, 55 insertions(+), 19 deletions(-) create mode 100644 changelog.d/17127.bugfix (limited to 'tests') diff --git a/changelog.d/17127.bugfix b/changelog.d/17127.bugfix new file mode 100644 index 0000000000..93c7314098 --- /dev/null +++ b/changelog.d/17127.bugfix @@ -0,0 +1 @@ +Fix a bug which meant that to-device messages received over federation could be dropped when the server was under load or networking problems caused problems between Synapse processes or the database. diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 65d3a661fe..7ffc650aa1 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -546,7 +546,25 @@ class FederationServer(FederationBase): edu_type=edu_dict["edu_type"], content=edu_dict["content"], ) - await self.registry.on_edu(edu.edu_type, origin, edu.content) + try: + await self.registry.on_edu(edu.edu_type, origin, edu.content) + except Exception: + # If there was an error handling the EDU, we must reject the + # transaction. + # + # Some EDU types (notably, to-device messages) are, despite their name, + # expected to be reliable; if we weren't able to do something with it, + # we have to tell the sender that, and the only way the protocol gives + # us to do so is by sending an HTTP error back on the transaction. + # + # We log the exception now, and then raise a new SynapseError to cause + # the transaction to be failed. 
+ logger.exception("Error handling EDU of type %s", edu.edu_type) + raise SynapseError(500, f"Error handing EDU of type {edu.edu_type}") + + # TODO: if the first EDU fails, we should probably abort the whole + # thing rather than carrying on with the rest of them. That would + # probably be best done inside `concurrently_execute`. await concurrently_execute( _process_edu, @@ -1414,12 +1432,7 @@ class FederationHandlerRegistry: handler = self.edu_handlers.get(edu_type) if handler: with start_active_span_from_edu(content, "handle_edu"): - try: - await handler(origin, content) - except SynapseError as e: - logger.info("Failed to handle edu %r: %r", edu_type, e) - except Exception: - logger.exception("Failed to handle edu %r", edu_type) + await handler(origin, content) return # Check if we can route it somewhere else that isn't us @@ -1428,17 +1441,12 @@ class FederationHandlerRegistry: # Pick an instance randomly so that we don't overload one. route_to = random.choice(instances) - try: - await self._send_edu( - instance_name=route_to, - edu_type=edu_type, - origin=origin, - content=content, - ) - except SynapseError as e: - logger.info("Failed to handle edu %r: %r", edu_type, e) - except Exception: - logger.exception("Failed to handle edu %r", edu_type) + await self._send_edu( + instance_name=route_to, + edu_type=edu_type, + origin=origin, + content=content, + ) return # Oh well, let's just log and move on. diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index 2b034dcbb7..79be7c97c8 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -104,6 +104,9 @@ class DeviceMessageHandler: """ Handle receiving to-device messages from remote homeservers. + Note that any errors thrown from this method will cause the federation /send + request to receive an error response. + Args: origin: The remote homeserver. content: The JSON dictionary containing the to-device messages. 
diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py index 36684c2c91..88261450b1 100644 --- a/tests/federation/test_federation_server.py +++ b/tests/federation/test_federation_server.py @@ -67,6 +67,23 @@ class FederationServerTests(unittest.FederatingHomeserverTestCase): self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result) self.assertEqual(channel.json_body["errcode"], "M_NOT_JSON") + def test_failed_edu_causes_500(self) -> None: + """If the EDU handler fails, /send should return a 500.""" + + async def failing_handler(_origin: str, _content: JsonDict) -> None: + raise Exception("bleh") + + self.hs.get_federation_registry().register_edu_handler( + "FAIL_EDU_TYPE", failing_handler + ) + + channel = self.make_signed_federation_request( + "PUT", + "/_matrix/federation/v1/send/txn", + {"edus": [{"edu_type": "FAIL_EDU_TYPE", "content": {}}]}, + ) + self.assertEqual(500, channel.code, channel.result) + class ServerACLsTestCase(unittest.TestCase): def test_blocked_server(self) -> None: diff --git a/tests/federation/transport/test_server.py b/tests/federation/transport/test_server.py index 190b79bf26..0237369998 100644 --- a/tests/federation/transport/test_server.py +++ b/tests/federation/transport/test_server.py @@ -59,7 +59,14 @@ class RoomDirectoryFederationTests(unittest.FederatingHomeserverTestCase): "/_matrix/federation/v1/send/txn_id_1234/", content={ "edus": [ - {"edu_type": EduTypes.DEVICE_LIST_UPDATE, "content": {"foo": "bar"}} + { + "edu_type": EduTypes.DEVICE_LIST_UPDATE, + "content": { + "device_id": "QBUAZIFURK", + "stream_id": 0, + "user_id": "@user:id", + }, + }, ], "pdus": [], }, -- cgit 1.5.1 From b548f7803a9b7ba51a66d47ddb9bb69dce541a48 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 29 Apr 2024 15:22:13 +0100 Subject: Add support for MSC4115 (#17104) Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> 
--- changelog.d/17104.feature | 1 + .../complement/conf/workers-shared-extra.yaml.j2 | 4 +- rust/src/events/internal_metadata.rs | 9 +- synapse/api/constants.py | 7 + synapse/config/experimental.py | 4 + synapse/events/utils.py | 30 +- synapse/handlers/admin.py | 6 +- synapse/handlers/events.py | 7 +- synapse/handlers/initial_sync.py | 7 +- synapse/handlers/pagination.py | 1 + synapse/handlers/relations.py | 3 + synapse/handlers/room.py | 1 + synapse/handlers/search.py | 20 +- synapse/handlers/sync.py | 2 + synapse/notifier.py | 1 + synapse/push/mailer.py | 5 +- synapse/visibility.py | 73 ++++- tests/events/test_utils.py | 24 ++ tests/rest/client/test_retention.py | 7 +- tests/test_visibility.py | 320 +++++++++++++++------ 20 files changed, 407 insertions(+), 125 deletions(-) create mode 100644 changelog.d/17104.feature (limited to 'tests') diff --git a/changelog.d/17104.feature b/changelog.d/17104.feature new file mode 100644 index 0000000000..1c2355e155 --- /dev/null +++ b/changelog.d/17104.feature @@ -0,0 +1 @@ +Add support for MSC4115 (membership metadata on events). 
diff --git a/docker/complement/conf/workers-shared-extra.yaml.j2 b/docker/complement/conf/workers-shared-extra.yaml.j2 index 32eada4419..a2c378f547 100644 --- a/docker/complement/conf/workers-shared-extra.yaml.j2 +++ b/docker/complement/conf/workers-shared-extra.yaml.j2 @@ -92,8 +92,6 @@ allow_device_name_lookup_over_federation: true ## Experimental Features ## experimental_features: - # client-side support for partial state in /send_join responses - faster_joins: true # Enable support for polls msc3381_polls_enabled: true # Enable deleting device-specific notification settings stored in account data @@ -105,6 +103,8 @@ experimental_features: # no UIA for x-signing upload for the first time msc3967_enabled: true + msc4115_membership_on_events: true + server_notices: system_mxid_localpart: _server system_mxid_display_name: "Server Alert" diff --git a/rust/src/events/internal_metadata.rs b/rust/src/events/internal_metadata.rs index a53601862d..53c7b1ba61 100644 --- a/rust/src/events/internal_metadata.rs +++ b/rust/src/events/internal_metadata.rs @@ -20,8 +20,10 @@ //! Implements the internal metadata class attached to events. //! -//! The internal metadata is a bit like a `TypedDict`, in that it is stored as a -//! JSON dict in the DB. Most events have zero, or only a few, of these keys +//! The internal metadata is a bit like a `TypedDict`, in that most of +//! it is stored as a JSON dict in the DB (the exceptions being `outlier` +//! and `stream_ordering` which have their own columns in the database). +//! Most events have zero, or only a few, of these keys //! set. Therefore, since we care more about memory size than performance here, //! we store these fields in a mapping. //! @@ -234,6 +236,9 @@ impl EventInternalMetadata { self.clone() } + /// Get a dict holding the data stored in the `internal_metadata` column in the database. + /// + /// Note that `outlier` and `stream_ordering` are stored in separate columns so are not returned here. 
fn get_dict(&self, py: Python<'_>) -> PyResult { let dict = PyDict::new(py); diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 98884b4967..0a9123c56b 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -234,6 +234,13 @@ class EventContentFields: TO_DEVICE_MSGID: Final = "org.matrix.msgid" +class EventUnsignedContentFields: + """Fields found inside the 'unsigned' data on events""" + + # Requesting user's membership, per MSC4115 + MSC4115_MEMBERSHIP: Final = "io.element.msc4115.membership" + + class RoomTypes: """Understood values of the room_type field of m.room.create events.""" diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index baa3580f29..749452ce93 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -432,3 +432,7 @@ class ExperimentalConfig(Config): "You cannot have MSC4108 both enabled and delegated at the same time", ("experimental", "msc4108_delegation_endpoint"), ) + + self.msc4115_membership_on_events = experimental.get( + "msc4115_membership_on_events", False + ) diff --git a/synapse/events/utils.py b/synapse/events/utils.py index e0613d0dbc..0772472312 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -49,7 +49,7 @@ from synapse.api.errors import Codes, SynapseError from synapse.api.room_versions import RoomVersion from synapse.types import JsonDict, Requester -from . import EventBase +from . import EventBase, make_event_from_dict if TYPE_CHECKING: from synapse.handlers.relations import BundledAggregations @@ -82,17 +82,14 @@ def prune_event(event: EventBase) -> EventBase: """ pruned_event_dict = prune_event_dict(event.room_version, event.get_dict()) - from . 
import make_event_from_dict - pruned_event = make_event_from_dict( pruned_event_dict, event.room_version, event.internal_metadata.get_dict() ) - # copy the internal fields + # Copy the bits of `internal_metadata` that aren't returned by `get_dict` pruned_event.internal_metadata.stream_ordering = ( event.internal_metadata.stream_ordering ) - pruned_event.internal_metadata.outlier = event.internal_metadata.outlier # Mark the event as redacted @@ -101,6 +98,29 @@ def prune_event(event: EventBase) -> EventBase: return pruned_event +def clone_event(event: EventBase) -> EventBase: + """Take a copy of the event. + + This is mostly useful because it does a *shallow* copy of the `unsigned` data, + which means it can then be updated without corrupting the in-memory cache. Note that + other properties of the event, such as `content`, are *not* (currently) copied here. + """ + # XXX: We rely on at least one of `event.get_dict()` and `make_event_from_dict()` + # making a copy of `unsigned`. Currently, both do, though I don't really know why. + # Still, as long as they do, there's not much point doing yet another copy here. + new_event = make_event_from_dict( + event.get_dict(), event.room_version, event.internal_metadata.get_dict() + ) + + # Copy the bits of `internal_metadata` that aren't returned by `get_dict`. 
+ new_event.internal_metadata.stream_ordering = ( + event.internal_metadata.stream_ordering + ) + new_event.internal_metadata.outlier = event.internal_metadata.outlier + + return new_event + + def prune_event_dict(room_version: RoomVersion, event_dict: JsonDict) -> JsonDict: """Redacts the event_dict in the same way as `prune_event`, except it operates on dicts rather than event objects diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index 360614e25b..702d40332c 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -42,6 +42,7 @@ class AdminHandler: self._device_handler = hs.get_device_handler() self._storage_controllers = hs.get_storage_controllers() self._state_storage_controller = self._storage_controllers.state + self._hs_config = hs.config self._msc3866_enabled = hs.config.experimental.msc3866.enabled async def get_whois(self, user: UserID) -> JsonMapping: @@ -217,7 +218,10 @@ class AdminHandler: ) events = await filter_events_for_client( - self._storage_controllers, user_id, events + self._storage_controllers, + user_id, + events, + msc4115_membership_on_events=self._hs_config.experimental.msc4115_membership_on_events, ) writer.write_events(room_id, events) diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index c3fee74a98..09d553cff1 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -148,6 +148,7 @@ class EventHandler: def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main self._storage_controllers = hs.get_storage_controllers() + self._config = hs.config async def get_event( self, @@ -189,7 +190,11 @@ class EventHandler: is_peeking = not is_user_in_room filtered = await filter_events_for_client( - self._storage_controllers, user.to_string(), [event], is_peeking=is_peeking + self._storage_controllers, + user.to_string(), + [event], + is_peeking=is_peeking, + msc4115_membership_on_events=self._config.experimental.msc4115_membership_on_events, ) if not 
filtered: diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index bcc5b285ac..d99fc4bec0 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -221,7 +221,10 @@ class InitialSyncHandler: ).addErrback(unwrapFirstError) messages = await filter_events_for_client( - self._storage_controllers, user_id, messages + self._storage_controllers, + user_id, + messages, + msc4115_membership_on_events=self.hs.config.experimental.msc4115_membership_on_events, ) start_token = now_token.copy_and_replace(StreamKeyType.ROOM, token) @@ -380,6 +383,7 @@ class InitialSyncHandler: requester.user.to_string(), messages, is_peeking=is_peeking, + msc4115_membership_on_events=self.hs.config.experimental.msc4115_membership_on_events, ) start_token = StreamToken.START.copy_and_replace(StreamKeyType.ROOM, token) @@ -494,6 +498,7 @@ class InitialSyncHandler: requester.user.to_string(), messages, is_peeking=is_peeking, + msc4115_membership_on_events=self.hs.config.experimental.msc4115_membership_on_events, ) start_token = now_token.copy_and_replace(StreamKeyType.ROOM, token) diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index cd3a9088cd..6617105cdb 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -623,6 +623,7 @@ class PaginationHandler: user_id, events, is_peeking=(member_event_id is None), + msc4115_membership_on_events=self.hs.config.experimental.msc4115_membership_on_events, ) # if after the filter applied there are no more events diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index 931ac0c813..c5cee8860b 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -95,6 +95,7 @@ class RelationsHandler: self._event_handler = hs.get_event_handler() self._event_serializer = hs.get_event_client_serializer() self._event_creation_handler = hs.get_event_creation_handler() + self._config = hs.config async def get_relations( 
self, @@ -163,6 +164,7 @@ class RelationsHandler: user_id, events, is_peeking=(member_event_id is None), + msc4115_membership_on_events=self._config.experimental.msc4115_membership_on_events, ) # The relations returned for the requested event do include their @@ -608,6 +610,7 @@ class RelationsHandler: user_id, events, is_peeking=(member_event_id is None), + msc4115_membership_on_events=self._config.experimental.msc4115_membership_on_events, ) aggregations = await self.get_bundled_aggregations( diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 5e81a51638..51739a2653 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1476,6 +1476,7 @@ class RoomContextHandler: user.to_string(), events, is_peeking=is_peeking, + msc4115_membership_on_events=self.hs.config.experimental.msc4115_membership_on_events, ) event = await self.store.get_event( diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 19c5a2f257..fdbe98de3b 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -480,7 +480,10 @@ class SearchHandler: filtered_events = await search_filter.filter([r["event"] for r in results]) events = await filter_events_for_client( - self._storage_controllers, user.to_string(), filtered_events + self._storage_controllers, + user.to_string(), + filtered_events, + msc4115_membership_on_events=self.hs.config.experimental.msc4115_membership_on_events, ) events.sort(key=lambda e: -rank_map[e.event_id]) @@ -579,7 +582,10 @@ class SearchHandler: filtered_events = await search_filter.filter([r["event"] for r in results]) events = await filter_events_for_client( - self._storage_controllers, user.to_string(), filtered_events + self._storage_controllers, + user.to_string(), + filtered_events, + msc4115_membership_on_events=self.hs.config.experimental.msc4115_membership_on_events, ) room_events.extend(events) @@ -664,11 +670,17 @@ class SearchHandler: ) events_before = await filter_events_for_client( - 
self._storage_controllers, user.to_string(), res.events_before + self._storage_controllers, + user.to_string(), + res.events_before, + msc4115_membership_on_events=self.hs.config.experimental.msc4115_membership_on_events, ) events_after = await filter_events_for_client( - self._storage_controllers, user.to_string(), res.events_after + self._storage_controllers, + user.to_string(), + res.events_after, + msc4115_membership_on_events=self.hs.config.experimental.msc4115_membership_on_events, ) context: JsonDict = { diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index a6d54ee4b8..8ff45a3353 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -596,6 +596,7 @@ class SyncHandler: sync_config.user.to_string(), recents, always_include_ids=current_state_ids, + msc4115_membership_on_events=self.hs_config.experimental.msc4115_membership_on_events, ) log_kv({"recents_after_visibility_filtering": len(recents)}) else: @@ -681,6 +682,7 @@ class SyncHandler: sync_config.user.to_string(), loaded_recents, always_include_ids=current_state_ids, + msc4115_membership_on_events=self.hs_config.experimental.msc4115_membership_on_events, ) loaded_recents = [] diff --git a/synapse/notifier.py b/synapse/notifier.py index e87333a80a..7c1cd3b5f2 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -721,6 +721,7 @@ class Notifier: user.to_string(), new_events, is_peeking=is_peeking, + msc4115_membership_on_events=self.hs.config.experimental.msc4115_membership_on_events, ) elif keyname == StreamKeyType.PRESENCE: now = self.clock.time_msec() diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index 7c15eb7440..49ce9d6dda 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -529,7 +529,10 @@ class Mailer: } the_events = await filter_events_for_client( - self._storage_controllers, user_id, results.events_before + self._storage_controllers, + user_id, + results.events_before, + 
msc4115_membership_on_events=self.hs.config.experimental.msc4115_membership_on_events, ) the_events.append(notif_event) diff --git a/synapse/visibility.py b/synapse/visibility.py index d1d478129f..09a947ef15 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -36,10 +36,15 @@ from typing import ( import attr -from synapse.api.constants import EventTypes, HistoryVisibility, Membership +from synapse.api.constants import ( + EventTypes, + EventUnsignedContentFields, + HistoryVisibility, + Membership, +) from synapse.events import EventBase from synapse.events.snapshot import EventContext -from synapse.events.utils import prune_event +from synapse.events.utils import clone_event, prune_event from synapse.logging.opentracing import trace from synapse.storage.controllers import StorageControllers from synapse.storage.databases.main import DataStore @@ -77,6 +82,7 @@ async def filter_events_for_client( is_peeking: bool = False, always_include_ids: FrozenSet[str] = frozenset(), filter_send_to_client: bool = True, + msc4115_membership_on_events: bool = False, ) -> List[EventBase]: """ Check which events a user is allowed to see. If the user can see the event but its @@ -95,9 +101,12 @@ async def filter_events_for_client( filter_send_to_client: Whether we're checking an event that's going to be sent to a client. This might not always be the case since this function can also be called to check whether a user can see the state at a given point. + msc4115_membership_on_events: Whether to include the requesting user's + membership in the "unsigned" data, per MSC4115. Returns: - The filtered events. + The filtered events. If `msc4115_membership_on_events` is true, the `unsigned` + data is annotated with the membership state of `user_id` at each event. """ # Filter out events that have been soft failed so that we don't relay them # to clients. 
@@ -134,7 +143,8 @@ async def filter_events_for_client( ) def allowed(event: EventBase) -> Optional[EventBase]: - return _check_client_allowed_to_see_event( + state_after_event = event_id_to_state.get(event.event_id) + filtered = _check_client_allowed_to_see_event( user_id=user_id, event=event, clock=storage.main.clock, @@ -142,13 +152,45 @@ async def filter_events_for_client( sender_ignored=event.sender in ignore_list, always_include_ids=always_include_ids, retention_policy=retention_policies[room_id], - state=event_id_to_state.get(event.event_id), + state=state_after_event, is_peeking=is_peeking, sender_erased=erased_senders.get(event.sender, False), ) + if filtered is None: + return None + + if not msc4115_membership_on_events: + return filtered + + # Annotate the event with the user's membership after the event. + # + # Normally we just look in `state_after_event`, but if the event is an outlier + # we won't have such a state. The only outliers that are returned here are the + # user's own membership event, so we can just inspect that. + + user_membership_event: Optional[EventBase] + if event.type == EventTypes.Member and event.state_key == user_id: + user_membership_event = event + elif state_after_event is not None: + user_membership_event = state_after_event.get((EventTypes.Member, user_id)) + else: + # unreachable! + raise Exception("Missing state for event that is not user's own membership") + + user_membership = ( + user_membership_event.membership + if user_membership_event + else Membership.LEAVE + ) - # Check each event: gives an iterable of None or (a potentially modified) - # EventBase. + # Copy the event before updating the unsigned data: this shouldn't be persisted + # to the cache! + cloned = clone_event(filtered) + cloned.unsigned[EventUnsignedContentFields.MSC4115_MEMBERSHIP] = user_membership + + return cloned + + # Check each event: gives an iterable of None or (a modified) EventBase. 
filtered_events = map(allowed, events) # Turn it into a list and remove None entries before returning. @@ -396,7 +438,13 @@ def _check_client_allowed_to_see_event( @attr.s(frozen=True, slots=True, auto_attribs=True) class _CheckMembershipReturn: - "Return value of _check_membership" + """Return value of `_check_membership`. + + Attributes: + allowed: Whether the user should be allowed to see the event. + joined: Whether the user was joined to the room at the event. + """ + allowed: bool joined: bool @@ -408,12 +456,7 @@ def _check_membership( state: StateMap[EventBase], is_peeking: bool, ) -> _CheckMembershipReturn: - """Check whether the user can see the event due to their membership - - Returns: - True if they can, False if they can't, plus the membership of the user - at the event. - """ + """Check whether the user can see the event due to their membership""" # If the event is the user's own membership event, use the 'most joined' # membership membership = None @@ -435,7 +478,7 @@ def _check_membership( if membership == "leave" and ( prev_membership == "join" or prev_membership == "invite" ): - return _CheckMembershipReturn(True, membership == Membership.JOIN) + return _CheckMembershipReturn(True, False) new_priority = MEMBERSHIP_PRIORITY.index(membership) old_priority = MEMBERSHIP_PRIORITY.index(prev_membership) diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py index cf81bcf52c..d5ac66a6ed 100644 --- a/tests/events/test_utils.py +++ b/tests/events/test_utils.py @@ -32,6 +32,7 @@ from synapse.events.utils import ( PowerLevelsContent, SerializeEventConfig, _split_field, + clone_event, copy_and_fixup_power_levels_contents, maybe_upsert_event_field, prune_event, @@ -611,6 +612,29 @@ class PruneEventTestCase(stdlib_unittest.TestCase): ) +class CloneEventTestCase(stdlib_unittest.TestCase): + def test_unsigned_is_copied(self) -> None: + original = make_event_from_dict( + { + "type": "A", + "event_id": "$test:domain", + "unsigned": {"a": 1, "b": 2}, 
+ }, + RoomVersions.V1, + {"txn_id": "txn"}, + ) + original.internal_metadata.stream_ordering = 1234 + self.assertEqual(original.internal_metadata.stream_ordering, 1234) + + cloned = clone_event(original) + cloned.unsigned["b"] = 3 + + self.assertEqual(original.unsigned, {"a": 1, "b": 2}) + self.assertEqual(cloned.unsigned, {"a": 1, "b": 3}) + self.assertEqual(cloned.internal_metadata.stream_ordering, 1234) + self.assertEqual(cloned.internal_metadata.txn_id, "txn") + + class SerializeEventTestCase(stdlib_unittest.TestCase): def serialize(self, ev: EventBase, fields: Optional[List[str]]) -> JsonDict: return serialize_event( diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py index 09a5d64349..ceae40498e 100644 --- a/tests/rest/client/test_retention.py +++ b/tests/rest/client/test_retention.py @@ -163,7 +163,12 @@ class RetentionTestCase(unittest.HomeserverTestCase): ) self.assertEqual(2, len(events), "events retrieved from database") filtered_events = self.get_success( - filter_events_for_client(storage_controllers, self.user_id, events) + filter_events_for_client( + storage_controllers, + self.user_id, + events, + msc4115_membership_on_events=True, + ) ) # We should only get one event back. 
diff --git a/tests/test_visibility.py b/tests/test_visibility.py index e51f72d65f..3e2100eab4 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -21,13 +21,19 @@ import logging from typing import Optional from unittest.mock import patch +from synapse.api.constants import EventUnsignedContentFields from synapse.api.room_versions import RoomVersions from synapse.events import EventBase, make_event_from_dict from synapse.events.snapshot import EventContext -from synapse.types import JsonDict, create_requester +from synapse.rest import admin +from synapse.rest.client import login, room +from synapse.server import HomeServer +from synapse.types import create_requester from synapse.visibility import filter_events_for_client, filter_events_for_server from tests import unittest +from tests.test_utils.event_injection import inject_event, inject_member_event +from tests.unittest import HomeserverTestCase from tests.utils import create_room logger = logging.getLogger(__name__) @@ -56,15 +62,31 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): # # before we do that, we persist some other events to act as state. 
- self._inject_visibility("@admin:hs", "joined") + self.get_success( + inject_visibility_event(self.hs, TEST_ROOM_ID, "@admin:hs", "joined") + ) for i in range(10): - self._inject_room_member("@resident%i:hs" % i) + self.get_success( + inject_member_event( + self.hs, + TEST_ROOM_ID, + "@resident%i:hs" % i, + "join", + ) + ) events_to_filter = [] for i in range(10): - user = "@user%i:%s" % (i, "test_server" if i == 5 else "other_server") - evt = self._inject_room_member(user, extra_content={"a": "b"}) + evt = self.get_success( + inject_member_event( + self.hs, + TEST_ROOM_ID, + "@user%i:%s" % (i, "test_server" if i == 5 else "other_server"), + "join", + extra_content={"a": "b"}, + ) + ) events_to_filter.append(evt) filtered = self.get_success( @@ -90,8 +112,19 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): def test_filter_outlier(self) -> None: # outlier events must be returned, for the good of the collective federation - self._inject_room_member("@resident:remote_hs") - self._inject_visibility("@resident:remote_hs", "joined") + self.get_success( + inject_member_event( + self.hs, + TEST_ROOM_ID, + "@resident:remote_hs", + "join", + ) + ) + self.get_success( + inject_visibility_event( + self.hs, TEST_ROOM_ID, "@resident:remote_hs", "joined" + ) + ) outlier = self._inject_outlier() self.assertEqual( @@ -110,7 +143,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): ) # it should also work when there are other events in the list - evt = self._inject_message("@unerased:local_hs") + evt = self.get_success( + inject_message_event(self.hs, TEST_ROOM_ID, "@unerased:local_hs") + ) filtered = self.get_success( filter_events_for_server( @@ -150,19 +185,34 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): # change in the middle of them. 
events_to_filter = [] - evt = self._inject_message("@unerased:local_hs") + evt = self.get_success( + inject_message_event(self.hs, TEST_ROOM_ID, "@unerased:local_hs") + ) events_to_filter.append(evt) - evt = self._inject_message("@erased:local_hs") + evt = self.get_success( + inject_message_event(self.hs, TEST_ROOM_ID, "@erased:local_hs") + ) events_to_filter.append(evt) - evt = self._inject_room_member("@joiner:remote_hs") + evt = self.get_success( + inject_member_event( + self.hs, + TEST_ROOM_ID, + "@joiner:remote_hs", + "join", + ) + ) events_to_filter.append(evt) - evt = self._inject_message("@unerased:local_hs") + evt = self.get_success( + inject_message_event(self.hs, TEST_ROOM_ID, "@unerased:local_hs") + ) events_to_filter.append(evt) - evt = self._inject_message("@erased:local_hs") + evt = self.get_success( + inject_message_event(self.hs, TEST_ROOM_ID, "@erased:local_hs") + ) events_to_filter.append(evt) # the erasey user gets erased @@ -200,76 +250,6 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): for i in (1, 4): self.assertNotIn("body", filtered[i].content) - def _inject_visibility(self, user_id: str, visibility: str) -> EventBase: - content = {"history_visibility": visibility} - builder = self.event_builder_factory.for_room_version( - RoomVersions.V1, - { - "type": "m.room.history_visibility", - "sender": user_id, - "state_key": "", - "room_id": TEST_ROOM_ID, - "content": content, - }, - ) - - event, unpersisted_context = self.get_success( - self.event_creation_handler.create_new_client_event(builder) - ) - context = self.get_success(unpersisted_context.persist(event)) - self.get_success(self._persistence.persist_event(event, context)) - return event - - def _inject_room_member( - self, - user_id: str, - membership: str = "join", - extra_content: Optional[JsonDict] = None, - ) -> EventBase: - content = {"membership": membership} - content.update(extra_content or {}) - builder = self.event_builder_factory.for_room_version( - 
RoomVersions.V1, - { - "type": "m.room.member", - "sender": user_id, - "state_key": user_id, - "room_id": TEST_ROOM_ID, - "content": content, - }, - ) - - event, unpersisted_context = self.get_success( - self.event_creation_handler.create_new_client_event(builder) - ) - context = self.get_success(unpersisted_context.persist(event)) - - self.get_success(self._persistence.persist_event(event, context)) - return event - - def _inject_message( - self, user_id: str, content: Optional[JsonDict] = None - ) -> EventBase: - if content is None: - content = {"body": "testytest", "msgtype": "m.text"} - builder = self.event_builder_factory.for_room_version( - RoomVersions.V1, - { - "type": "m.room.message", - "sender": user_id, - "room_id": TEST_ROOM_ID, - "content": content, - }, - ) - - event, unpersisted_context = self.get_success( - self.event_creation_handler.create_new_client_event(builder) - ) - context = self.get_success(unpersisted_context.persist(event)) - - self.get_success(self._persistence.persist_event(event, context)) - return event - def _inject_outlier(self) -> EventBase: builder = self.event_builder_factory.for_room_version( RoomVersions.V1, @@ -292,7 +272,122 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): return event -class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase): +class FilterEventsForClientTestCase(HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def test_joined_history_visibility(self) -> None: + # User joins and leaves room. Should be able to see the join and leave, + # and messages sent between the two, but not before or after. 
+ + self.register_user("resident", "p1") + resident_token = self.login("resident", "p1") + room_id = self.helper.create_room_as("resident", tok=resident_token) + + self.get_success( + inject_visibility_event(self.hs, room_id, "@resident:test", "joined") + ) + before_event = self.get_success( + inject_message_event(self.hs, room_id, "@resident:test", body="before") + ) + join_event = self.get_success( + inject_member_event(self.hs, room_id, "@joiner:test", "join") + ) + during_event = self.get_success( + inject_message_event(self.hs, room_id, "@resident:test", body="during") + ) + leave_event = self.get_success( + inject_member_event(self.hs, room_id, "@joiner:test", "leave") + ) + after_event = self.get_success( + inject_message_event(self.hs, room_id, "@resident:test", body="after") + ) + + # We have to reload the events from the db, to ensure that prev_content is + # populated. + events_to_filter = [ + self.get_success( + self.hs.get_storage_controllers().main.get_event( + e.event_id, + get_prev_content=True, + ) + ) + for e in [ + before_event, + join_event, + during_event, + leave_event, + after_event, + ] + ] + + # Now run the events through the filter, and check that we can see the events + # we expect, and that the membership prop is as expected. + # + # We deliberately do the queries for both users upfront; this simulates + # concurrent queries on the server, and helps ensure that we aren't + # accidentally serving the same event object (with the same unsigned.membership + # property) to both users. 
+ joiner_filtered_events = self.get_success( + filter_events_for_client( + self.hs.get_storage_controllers(), + "@joiner:test", + events_to_filter, + msc4115_membership_on_events=True, + ) + ) + resident_filtered_events = self.get_success( + filter_events_for_client( + self.hs.get_storage_controllers(), + "@resident:test", + events_to_filter, + msc4115_membership_on_events=True, + ) + ) + + # The joiner should be able to see the join and leave, + # and messages sent between the two, but not before or after. + self.assertEqual( + [e.event_id for e in [join_event, during_event, leave_event]], + [e.event_id for e in joiner_filtered_events], + ) + self.assertEqual( + ["join", "join", "leave"], + [ + e.unsigned[EventUnsignedContentFields.MSC4115_MEMBERSHIP] + for e in joiner_filtered_events + ], + ) + + # The resident user should see all the events. + self.assertEqual( + [ + e.event_id + for e in [ + before_event, + join_event, + during_event, + leave_event, + after_event, + ] + ], + [e.event_id for e in resident_filtered_events], + ) + self.assertEqual( + ["join", "join", "join", "join", "join"], + [ + e.unsigned[EventUnsignedContentFields.MSC4115_MEMBERSHIP] + for e in resident_filtered_events + ], + ) + + +class FilterEventsOutOfBandEventsForClientTestCase( + unittest.FederatingHomeserverTestCase +): def test_out_of_band_invite_rejection(self) -> None: # this is where we have received an invite event over federation, and then # rejected it. 
@@ -341,15 +436,24 @@ class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase): ) # the invited user should be able to see both the invite and the rejection + filtered_events = self.get_success( + filter_events_for_client( + self.hs.get_storage_controllers(), + "@user:test", + [invite_event, reject_event], + msc4115_membership_on_events=True, + ) + ) self.assertEqual( - self.get_success( - filter_events_for_client( - self.hs.get_storage_controllers(), - "@user:test", - [invite_event, reject_event], - ) - ), - [invite_event, reject_event], + [e.event_id for e in filtered_events], + [e.event_id for e in [invite_event, reject_event]], + ) + self.assertEqual( + ["invite", "leave"], + [ + e.unsigned[EventUnsignedContentFields.MSC4115_MEMBERSHIP] + for e in filtered_events + ], ) # other users should see neither @@ -359,7 +463,39 @@ class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase): self.hs.get_storage_controllers(), "@other:test", [invite_event, reject_event], + msc4115_membership_on_events=True, ) ), [], ) + + +async def inject_visibility_event( + hs: HomeServer, + room_id: str, + sender: str, + visibility: str, +) -> EventBase: + return await inject_event( + hs, + type="m.room.history_visibility", + sender=sender, + state_key="", + room_id=room_id, + content={"history_visibility": visibility}, + ) + + +async def inject_message_event( + hs: HomeServer, + room_id: str, + sender: str, + body: Optional[str] = "testytest", +) -> EventBase: + return await inject_event( + hs, + type="m.room.message", + sender=sender, + room_id=room_id, + content={"body": body, "msgtype": "m.text"}, + ) -- cgit 1.5.1 From 7ab0f630da0ab16c4d5dc0603695df888e2a7ab0 Mon Sep 17 00:00:00 2001 From: devonh Date: Mon, 29 Apr 2024 15:23:05 +0000 Subject: Apply user `email` & `picture` during OIDC registration if present & selected (#17120) This change will apply the `email` & `picture` provided by OIDC to the new user account when registering a new user via 
OIDC. If the user is directed to the account details form, this change makes sure they have been selected before applying them, otherwise they are omitted. In particular, this change ensures the values are carried through when Synapse has consent configured, and the redirect to the consent form/s are followed. I have tested everything manually. Including: - with/without consent configured - allowing/not allowing the use of email/avatar (via `sso_auth_account_details.html`) - with/without automatic account detail population (by un/commenting the `localpart_template` option in synapse config). ### Pull Request Checklist * [X] Pull request is based on the develop branch * [X] Pull request includes a [changelog file](https://element-hq.github.io/synapse/latest/development/contributing_guide.html#changelog). The entry should: - Be a short description of your change which makes sense to users. "Fixed a bug that prevented receiving messages from other servers." instead of "Moved X method from `EventStore` to `EventWorkerStore`.". - Use markdown where necessary, mostly for `code blocks`. - End with either a period (.) or an exclamation mark (!). - Start with a capital letter. - Feel free to credit yourself, by adding a sentence "Contributed by @github_username." or "Contributed by [Your Name]." to the end of the entry. 
* [X] [Code style](https://element-hq.github.io/synapse/latest/code_style.html) is correct (run the [linters](https://element-hq.github.io/synapse/latest/development/contributing_guide.html#run-the-linters)) --- changelog.d/17120.bugfix | 1 + docs/sso_mapping_providers.md | 1 + synapse/handlers/sso.py | 10 ++ synapse/rest/synapse/client/pick_username.py | 4 +- tests/rest/client/test_login.py | 204 +++++++++++++++++++++++++-- 5 files changed, 205 insertions(+), 15 deletions(-) create mode 100644 changelog.d/17120.bugfix (limited to 'tests') diff --git a/changelog.d/17120.bugfix b/changelog.d/17120.bugfix new file mode 100644 index 0000000000..85b34c2e98 --- /dev/null +++ b/changelog.d/17120.bugfix @@ -0,0 +1 @@ +Apply user email & picture during OIDC registration if present & selected. diff --git a/docs/sso_mapping_providers.md b/docs/sso_mapping_providers.md index 10c695029f..d6c4e860ae 100644 --- a/docs/sso_mapping_providers.md +++ b/docs/sso_mapping_providers.md @@ -98,6 +98,7 @@ A custom mapping provider must specify the following methods: either accept this localpart or pick their own username. Otherwise this option has no effect. If omitted, defaults to `False`. - `display_name`: An optional string, the display name for the user. + - `picture`: An optional string, the avatar url for the user. - `emails`: A list of strings, the email address(es) to associate with this user. If omitted, defaults to an empty list. * `async def get_extra_attributes(self, userinfo, token)` diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 8e39e76c97..f275d4f35a 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -169,6 +169,7 @@ class UsernameMappingSession: # attributes returned by the ID mapper display_name: Optional[str] emails: StrCollection + avatar_url: Optional[str] # An optional dictionary of extra attributes to be provided to the client in the # login response. 
@@ -183,6 +184,7 @@ class UsernameMappingSession: # choices made by the user chosen_localpart: Optional[str] = None use_display_name: bool = True + use_avatar: bool = True emails_to_use: StrCollection = () terms_accepted_version: Optional[str] = None @@ -660,6 +662,9 @@ class SsoHandler: remote_user_id=remote_user_id, display_name=attributes.display_name, emails=attributes.emails, + avatar_url=attributes.picture, + # Default to using all mapped emails. Will be overwritten in handle_submit_username_request. + emails_to_use=attributes.emails, client_redirect_url=client_redirect_url, expiry_time_ms=now + self._MAPPING_SESSION_VALIDITY_PERIOD_MS, extra_login_attributes=extra_login_attributes, @@ -966,6 +971,7 @@ class SsoHandler: session_id: str, localpart: str, use_display_name: bool, + use_avatar: bool, emails_to_use: Iterable[str], ) -> None: """Handle a request to the username-picker 'submit' endpoint @@ -988,6 +994,7 @@ class SsoHandler: # update the session with the user's choices session.chosen_localpart = localpart session.use_display_name = use_display_name + session.use_avatar = use_avatar emails_from_idp = set(session.emails) filtered_emails: Set[str] = set() @@ -1068,6 +1075,9 @@ class SsoHandler: if session.use_display_name: attributes.display_name = session.display_name + if session.use_avatar: + attributes.picture = session.avatar_url + # the following will raise a 400 error if the username has been taken in the # meantime. 
user_id = await self._register_mapped_user( diff --git a/synapse/rest/synapse/client/pick_username.py b/synapse/rest/synapse/client/pick_username.py index e671774aeb..7d16b796d4 100644 --- a/synapse/rest/synapse/client/pick_username.py +++ b/synapse/rest/synapse/client/pick_username.py @@ -113,6 +113,7 @@ class AccountDetailsResource(DirectServeHtmlResource): "display_name": session.display_name, "emails": session.emails, "localpart": localpart, + "avatar_url": session.avatar_url, }, } @@ -134,6 +135,7 @@ class AccountDetailsResource(DirectServeHtmlResource): try: localpart = parse_string(request, "username", required=True) use_display_name = parse_boolean(request, "use_display_name", default=False) + use_avatar = parse_boolean(request, "use_avatar", default=False) try: emails_to_use: List[str] = [ @@ -147,5 +149,5 @@ class AccountDetailsResource(DirectServeHtmlResource): return await self._sso_handler.handle_submit_username_request( - request, session_id, localpart, use_display_name, emails_to_use + request, session_id, localpart, use_display_name, use_avatar, emails_to_use ) diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index 3a1f150082..3fb77fd9dd 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -20,7 +20,17 @@ # import time import urllib.parse -from typing import Any, Collection, Dict, List, Optional, Tuple, Union +from typing import ( + Any, + BinaryIO, + Callable, + Collection, + Dict, + List, + Optional, + Tuple, + Union, +) from unittest.mock import Mock from urllib.parse import urlencode @@ -34,8 +44,9 @@ import synapse.rest.admin from synapse.api.constants import ApprovalNoticeMedium, LoginType from synapse.api.errors import Codes from synapse.appservice import ApplicationService +from synapse.http.client import RawHeaders from synapse.module_api import ModuleApi -from synapse.rest.client import devices, login, logout, register +from synapse.rest.client import account, devices, login, 
logout, profile, register from synapse.rest.client.account import WhoamiRestServlet from synapse.rest.synapse.client import build_synapse_client_resource_tree from synapse.server import HomeServer @@ -48,6 +59,7 @@ from tests.handlers.test_saml import has_saml2 from tests.rest.client.utils import TEST_OIDC_CONFIG from tests.server import FakeChannel from tests.test_utils.html_parsers import TestHtmlParser +from tests.test_utils.oidc import FakeOidcServer from tests.unittest import HomeserverTestCase, override_config, skip_unless try: @@ -1421,7 +1433,19 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): class UsernamePickerTestCase(HomeserverTestCase): """Tests for the username picker flow of SSO login""" - servlets = [login.register_servlets] + servlets = [ + login.register_servlets, + profile.register_servlets, + account.register_servlets, + ] + + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + self.http_client = Mock(spec=["get_file"]) + self.http_client.get_file.side_effect = mock_get_file + hs = self.setup_test_homeserver( + proxied_blocklisted_http_client=self.http_client + ) + return hs def default_config(self) -> Dict[str, Any]: config = super().default_config() @@ -1430,7 +1454,11 @@ class UsernamePickerTestCase(HomeserverTestCase): config["oidc_config"] = {} config["oidc_config"].update(TEST_OIDC_CONFIG) config["oidc_config"]["user_mapping_provider"] = { - "config": {"display_name_template": "{{ user.displayname }}"} + "config": { + "display_name_template": "{{ user.displayname }}", + "email_template": "{{ user.email }}", + "picture_template": "{{ user.picture }}", + } } # whitelist this client URI so we redirect straight to it rather than @@ -1443,15 +1471,22 @@ class UsernamePickerTestCase(HomeserverTestCase): d.update(build_synapse_client_resource_tree(self.hs)) return d - def test_username_picker(self) -> None: - """Test the happy path of a username picker flow.""" - - fake_oidc_server = 
self.helper.fake_oidc_server() - + def proceed_to_username_picker_page( + self, + fake_oidc_server: FakeOidcServer, + displayname: str, + email: str, + picture: str, + ) -> Tuple[str, str]: # do the start of the login flow channel, _ = self.helper.auth_via_oidc( fake_oidc_server, - {"sub": "tester", "displayname": "Jonny"}, + { + "sub": "tester", + "displayname": displayname, + "picture": picture, + "email": email, + }, TEST_CLIENT_REDIRECT_URL, ) @@ -1478,16 +1513,132 @@ class UsernamePickerTestCase(HomeserverTestCase): ) session = username_mapping_sessions[session_id] self.assertEqual(session.remote_user_id, "tester") - self.assertEqual(session.display_name, "Jonny") + self.assertEqual(session.display_name, displayname) + self.assertEqual(session.emails, [email]) + self.assertEqual(session.avatar_url, picture) self.assertEqual(session.client_redirect_url, TEST_CLIENT_REDIRECT_URL) # the expiry time should be about 15 minutes away expected_expiry = self.clock.time_msec() + (15 * 60 * 1000) self.assertApproximates(session.expiry_time_ms, expected_expiry, tolerance=1000) + return picker_url, session_id + + def test_username_picker_use_displayname_avatar_and_email(self) -> None: + """Test the happy path of a username picker flow with using displayname, avatar and email.""" + + fake_oidc_server = self.helper.fake_oidc_server() + + mxid = "@bobby:test" + displayname = "Jonny" + email = "bobby@test.com" + picture = "mxc://test/avatar_url" + + picker_url, session_id = self.proceed_to_username_picker_page( + fake_oidc_server, displayname, email, picture + ) + + # Now, submit a username to the username picker, which should serve a redirect + # to the completion page. + # Also specify that we should use the provided displayname, avatar and email. 
+ content = urlencode( + { + b"username": b"bobby", + b"use_display_name": b"true", + b"use_avatar": b"true", + b"use_email": email, + } + ).encode("utf8") + chan = self.make_request( + "POST", + path=picker_url, + content=content, + content_is_form=True, + custom_headers=[ + ("Cookie", "username_mapping_session=" + session_id), + # old versions of twisted don't do form-parsing without a valid + # content-length header. + ("Content-Length", str(len(content))), + ], + ) + self.assertEqual(chan.code, 302, chan.result) + location_headers = chan.headers.getRawHeaders("Location") + assert location_headers + + # send a request to the completion page, which should 302 to the client redirectUrl + chan = self.make_request( + "GET", + path=location_headers[0], + custom_headers=[("Cookie", "username_mapping_session=" + session_id)], + ) + self.assertEqual(chan.code, 302, chan.result) + location_headers = chan.headers.getRawHeaders("Location") + assert location_headers + + # ensure that the returned location matches the requested redirect URL + path, query = location_headers[0].split("?", 1) + self.assertEqual(path, "https://x") + + # it will have url-encoded the params properly, so we'll have to parse them + params = urllib.parse.parse_qsl( + query, keep_blank_values=True, strict_parsing=True, errors="strict" + ) + self.assertEqual(params[0:2], EXPECTED_CLIENT_REDIRECT_URL_PARAMS) + self.assertEqual(params[2][0], "loginToken") + + # fish the login token out of the returned redirect uri + login_token = params[2][1] + + # finally, submit the matrix login token to the login API, which gives us our + # matrix access token, mxid, and device id. + chan = self.make_request( + "POST", + "/login", + content={"type": "m.login.token", "token": login_token}, + ) + self.assertEqual(chan.code, 200, chan.result) + self.assertEqual(chan.json_body["user_id"], mxid) + + # ensure the displayname and avatar from the OIDC response have been configured for the user. 
+ channel = self.make_request( + "GET", "/profile/" + mxid, access_token=chan.json_body["access_token"] + ) + self.assertEqual(channel.code, 200, channel.result) + self.assertIn("mxc://test", channel.json_body["avatar_url"]) + self.assertEqual(displayname, channel.json_body["displayname"]) + + # ensure the email from the OIDC response has been configured for the user. + channel = self.make_request( + "GET", "/account/3pid", access_token=chan.json_body["access_token"] + ) + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(email, channel.json_body["threepids"][0]["address"]) + + def test_username_picker_dont_use_displayname_avatar_or_email(self) -> None: + """Test the happy path of a username picker flow without using displayname, avatar or email.""" + + fake_oidc_server = self.helper.fake_oidc_server() + + mxid = "@bobby:test" + displayname = "Jonny" + email = "bobby@test.com" + picture = "mxc://test/avatar_url" + username = "bobby" + + picker_url, session_id = self.proceed_to_username_picker_page( + fake_oidc_server, displayname, email, picture + ) + # Now, submit a username to the username picker, which should serve a redirect - # to the completion page - content = urlencode({b"username": b"bobby"}).encode("utf8") + # to the completion page. + # Also specify that we should not use the provided displayname, avatar or email. + content = urlencode( + { + b"username": username, + b"use_display_name": b"false", + b"use_avatar": b"false", + } + ).encode("utf8") chan = self.make_request( "POST", path=picker_url, @@ -1536,4 +1687,29 @@ class UsernamePickerTestCase(HomeserverTestCase): content={"type": "m.login.token", "token": login_token}, ) self.assertEqual(chan.code, 200, chan.result) - self.assertEqual(chan.json_body["user_id"], "@bobby:test") + self.assertEqual(chan.json_body["user_id"], mxid) + + # ensure the displayname and avatar from the OIDC response have not been configured for the user. 
+ channel = self.make_request( + "GET", "/profile/" + mxid, access_token=chan.json_body["access_token"] + ) + self.assertEqual(channel.code, 200, channel.result) + self.assertNotIn("avatar_url", channel.json_body) + self.assertEqual(username, channel.json_body["displayname"]) + + # ensure the email from the OIDC response has not been configured for the user. + channel = self.make_request( + "GET", "/account/3pid", access_token=chan.json_body["access_token"] + ) + self.assertEqual(channel.code, 200, channel.result) + self.assertListEqual([], channel.json_body["threepids"]) + + +async def mock_get_file( + url: str, + output_stream: BinaryIO, + max_size: Optional[int] = None, + headers: Optional[RawHeaders] = None, + is_allowed_content_type: Optional[Callable[[str], bool]] = None, +) -> Tuple[int, Dict[bytes, List[bytes]], str, int]: + return 0, {b"Content-Type": [b"image/png"]}, "", 200 -- cgit 1.5.1 From 37558d5e4cd22ec8f120d2c0fbb8c9842d6dd131 Mon Sep 17 00:00:00 2001 From: Shay Date: Wed, 1 May 2024 09:45:17 -0700 Subject: Add support for MSC3823 - Account Suspension (#17051) --- changelog.d/17051.feature | 1 + synapse/_scripts/synapse_port_db.py | 2 +- synapse/handlers/room_member.py | 30 ++++++++++ synapse/storage/databases/main/registration.py | 55 ++++++++++++++++- synapse/storage/schema/__init__.py | 5 +- .../schema/main/delta/85/01_add_suspended.sql | 14 +++++ synapse/types/__init__.py | 2 + tests/rest/client/test_rooms.py | 69 +++++++++++++++++++++- tests/storage/test_registration.py | 2 +- 9 files changed, 173 insertions(+), 7 deletions(-) create mode 100644 changelog.d/17051.feature create mode 100644 synapse/storage/schema/main/delta/85/01_add_suspended.sql (limited to 'tests') diff --git a/changelog.d/17051.feature b/changelog.d/17051.feature new file mode 100644 index 0000000000..1c41f49f7d --- /dev/null +++ b/changelog.d/17051.feature @@ -0,0 +1 @@ +Add preliminary support for [MSC3823](https://github.com/matrix-org/matrix-spec-proposals/pull/3823) - 
Account Suspension. \ No newline at end of file diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py index 15507372a4..1e56f46911 100755 --- a/synapse/_scripts/synapse_port_db.py +++ b/synapse/_scripts/synapse_port_db.py @@ -127,7 +127,7 @@ BOOLEAN_COLUMNS = { "redactions": ["have_censored"], "room_stats_state": ["is_federatable"], "rooms": ["is_public", "has_auth_chain_index"], - "users": ["shadow_banned", "approved", "locked"], + "users": ["shadow_banned", "approved", "locked", "suspended"], "un_partial_stated_event_stream": ["rejection_status_changed"], "users_who_share_rooms": ["share_private"], "per_user_experimental_features": ["enabled"], diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 601d37341b..655c78e150 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -752,6 +752,36 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): and requester.user.to_string() == self._server_notices_mxid ) + requester_suspended = await self.store.get_user_suspended_status( + requester.user.to_string() + ) + if action == Membership.INVITE and requester_suspended: + raise SynapseError( + 403, + "Sending invites while account is suspended is not allowed.", + Codes.USER_ACCOUNT_SUSPENDED, + ) + + if target.to_string() != requester.user.to_string(): + target_suspended = await self.store.get_user_suspended_status( + target.to_string() + ) + else: + target_suspended = requester_suspended + + if action == Membership.JOIN and target_suspended: + raise SynapseError( + 403, + "Joining rooms while account is suspended is not allowed.", + Codes.USER_ACCOUNT_SUSPENDED, + ) + if action == Membership.KNOCK and target_suspended: + raise SynapseError( + 403, + "Knocking on rooms while account is suspended is not allowed.", + Codes.USER_ACCOUNT_SUSPENDED, + ) + if ( not self.allow_per_room_profiles and not is_requester_server_notices_user ) or requester.shadow_banned: diff --git 
a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 29bf47befc..df7f8a43b7 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -236,7 +236,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): consent_server_notice_sent, appservice_id, creation_ts, user_type, deactivated, COALESCE(shadow_banned, FALSE) AS shadow_banned, COALESCE(approved, TRUE) AS approved, - COALESCE(locked, FALSE) AS locked + COALESCE(locked, FALSE) AS locked, + suspended FROM users WHERE name = ? """, @@ -261,6 +262,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): shadow_banned, approved, locked, + suspended, ) = row return UserInfo( @@ -277,6 +279,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): user_type=user_type, approved=bool(approved), locked=bool(locked), + suspended=bool(suspended), ) return await self.db_pool.runInteraction( @@ -1180,6 +1183,27 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): # Convert the potential integer into a boolean. return bool(res) + @cached() + async def get_user_suspended_status(self, user_id: str) -> bool: + """ + Determine whether the user's account is suspended. + Args: + user_id: The user ID of the user in question + Returns: + True if the user's account is suspended, false if it is not suspended or + if the user ID cannot be found. 
+ """ + + res = await self.db_pool.simple_select_one_onecol( + table="users", + keyvalues={"name": user_id}, + retcol="suspended", + allow_none=True, + desc="get_user_suspended", + ) + + return bool(res) + async def get_threepid_validation_session( self, medium: Optional[str], @@ -2213,6 +2237,35 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) txn.call_after(self.is_guest.invalidate, (user_id,)) + async def set_user_suspended_status(self, user_id: str, suspended: bool) -> None: + """ + Set whether the user's account is suspended in the `users` table. + + Args: + user_id: The user ID of the user in question + suspended: True if the user is suspended, false if not + """ + await self.db_pool.runInteraction( + "set_user_suspended_status", + self.set_user_suspended_status_txn, + user_id, + suspended, + ) + + def set_user_suspended_status_txn( + self, txn: LoggingTransaction, user_id: str, suspended: bool + ) -> None: + self.db_pool.simple_update_one_txn( + txn=txn, + table="users", + keyvalues={"name": user_id}, + updatevalues={"suspended": suspended}, + ) + self._invalidate_cache_and_stream( + txn, self.get_user_suspended_status, (user_id,) + ) + self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) + async def set_user_locked_status(self, user_id: str, locked: bool) -> None: """Set the `locked` property for the provided user to the provided value. 
diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py index 039aa91b92..0dc5d24249 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py @@ -19,7 +19,7 @@ # # -SCHEMA_VERSION = 84 # remember to update the list below when updating +SCHEMA_VERSION = 85 # remember to update the list below when updating """Represents the expectations made by the codebase about the database schema This should be incremented whenever the codebase changes its requirements on the @@ -136,6 +136,9 @@ Changes in SCHEMA_VERSION = 83 Changes in SCHEMA_VERSION = 84 - No longer assumes that `event_auth_chain_links` holds transitive links, and so read operations must do graph traversal. + +Changes in SCHEMA_VERSION = 85 + - Add a column `suspended` to the `users` table """ diff --git a/synapse/storage/schema/main/delta/85/01_add_suspended.sql b/synapse/storage/schema/main/delta/85/01_add_suspended.sql new file mode 100644 index 0000000000..807aad374f --- /dev/null +++ b/synapse/storage/schema/main/delta/85/01_add_suspended.sql @@ -0,0 +1,14 @@ +-- +-- This file is licensed under the Affero General Public License (AGPL) version 3. +-- +-- Copyright (C) 2024 New Vector, Ltd +-- +-- This program is free software: you can redistribute it and/or modify +-- it under the terms of the GNU Affero General Public License as +-- published by the Free Software Foundation, either version 3 of the +-- License, or (at your option) any later version. +-- +-- See the GNU Affero General Public License for more details: +-- . + +ALTER TABLE users ADD COLUMN suspended BOOLEAN DEFAULT FALSE NOT NULL; \ No newline at end of file diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py index a88982a04c..509a2d3a0f 100644 --- a/synapse/types/__init__.py +++ b/synapse/types/__init__.py @@ -1156,6 +1156,7 @@ class UserInfo: user_type: User type (None for normal user, 'support' and 'bot' other options). 
approved: If the user has been "approved" to register on the server. locked: Whether the user's account has been locked + suspended: Whether the user's account is currently suspended """ user_id: UserID @@ -1171,6 +1172,7 @@ class UserInfo: is_shadow_banned: bool approved: bool locked: bool + suspended: bool class UserProfile(TypedDict): diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index b796163dcb..d398cead1c 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -48,7 +48,16 @@ from synapse.appservice import ApplicationService from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.rest import admin -from synapse.rest.client import account, directory, login, profile, register, room, sync +from synapse.rest.client import ( + account, + directory, + knock, + login, + profile, + register, + room, + sync, +) from synapse.server import HomeServer from synapse.types import JsonDict, RoomAlias, UserID, create_requester from synapse.util import Clock @@ -733,7 +742,7 @@ class RoomsCreateTestCase(RoomBase): self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) assert channel.resource_usage is not None - self.assertEqual(32, channel.resource_usage.db_txn_count) + self.assertEqual(33, channel.resource_usage.db_txn_count) def test_post_room_initial_state(self) -> None: # POST with initial_state config key, expect new room id @@ -746,7 +755,7 @@ class RoomsCreateTestCase(RoomBase): self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) assert channel.resource_usage is not None - self.assertEqual(34, channel.resource_usage.db_txn_count) + self.assertEqual(35, channel.resource_usage.db_txn_count) def test_post_room_visibility_key(self) -> None: # POST with visibility config key, expect new room id @@ -1154,6 +1163,7 @@ class RoomJoinTestCase(RoomBase): 
admin.register_servlets, login.register_servlets, room.register_servlets, + knock.register_servlets, ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: @@ -1167,6 +1177,8 @@ class RoomJoinTestCase(RoomBase): self.room2 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1) self.room3 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1) + self.store = hs.get_datastores().main + def test_spam_checker_may_join_room_deprecated(self) -> None: """Tests that the user_may_join_room spam checker callback is correctly called and blocks room joins when needed. @@ -1317,6 +1329,57 @@ class RoomJoinTestCase(RoomBase): expect_additional_fields=return_value[1], ) + def test_suspended_user_cannot_join_room(self) -> None: + # set the user as suspended + self.get_success(self.store.set_user_suspended_status(self.user2, True)) + + channel = self.make_request( + "POST", f"/join/{self.room1}", access_token=self.tok2 + ) + self.assertEqual(channel.code, 403) + self.assertEqual( + channel.json_body["errcode"], "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED" + ) + + channel = self.make_request( + "POST", f"/rooms/{self.room1}/join", access_token=self.tok2 + ) + self.assertEqual(channel.code, 403) + self.assertEqual( + channel.json_body["errcode"], "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED" + ) + + def test_suspended_user_cannot_knock_on_room(self) -> None: + # set the user as suspended + self.get_success(self.store.set_user_suspended_status(self.user2, True)) + + channel = self.make_request( + "POST", + f"/_matrix/client/v3/knock/{self.room1}", + access_token=self.tok2, + content={}, + shorthand=False, + ) + self.assertEqual(channel.code, 403) + self.assertEqual( + channel.json_body["errcode"], "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED" + ) + + def test_suspended_user_cannot_invite_to_room(self) -> None: + # set the user as suspended + self.get_success(self.store.set_user_suspended_status(self.user1, True)) + + # first user 
invites second user + channel = self.make_request( + "POST", + f"/rooms/{self.room1}/invite", + access_token=self.tok1, + content={"user_id": self.user2}, + ) + self.assertEqual( + channel.json_body["errcode"], "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED" + ) + class RoomAppserviceTsParamTestCase(unittest.HomeserverTestCase): servlets = [ diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index 505465d529..14e3871dc1 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -43,7 +43,6 @@ class RegistrationStoreTestCase(HomeserverTestCase): self.assertEqual( UserInfo( - # TODO(paul): Surely this field should be 'user_id', not 'name' user_id=UserID.from_string(self.user_id), is_admin=False, is_guest=False, @@ -57,6 +56,7 @@ class RegistrationStoreTestCase(HomeserverTestCase): locked=False, is_shadow_banned=False, approved=True, + suspended=False, ), (self.get_success(self.store.get_user_by_id(self.user_id))), ) -- cgit 1.5.1 From 3e6ee8ff88c41ad1fca8c055520be952ab21b705 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 6 May 2024 12:56:52 +0100 Subject: Add optimisation to `StreamChangeCache` (#17130) When there have been lots of changes compared with the number of entities, we can do a fast(er) path. Locally I ran some benchmarking, and the comparison seems to give the best determination of which method we use. --- changelog.d/17130.misc | 1 + synapse/util/caches/stream_change_cache.py | 20 +++++++++++++++++++- tests/util/test_stream_change_cache.py | 17 ++++++++++++++--- 3 files changed, 34 insertions(+), 4 deletions(-) create mode 100644 changelog.d/17130.misc (limited to 'tests') diff --git a/changelog.d/17130.misc b/changelog.d/17130.misc new file mode 100644 index 0000000000..ac20c90bde --- /dev/null +++ b/changelog.d/17130.misc @@ -0,0 +1 @@ +Add optimisation to `StreamChangeCache.get_entities_changed(..)`. 
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py index 2079ca789c..91c335f85b 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py @@ -165,7 +165,7 @@ class StreamChangeCache: return False def get_entities_changed( - self, entities: Collection[EntityType], stream_pos: int + self, entities: Collection[EntityType], stream_pos: int, _perf_factor: int = 1 ) -> Union[Set[EntityType], FrozenSet[EntityType]]: """ Returns the subset of the given entities that have had changes after the given position. @@ -177,6 +177,8 @@ class StreamChangeCache: Args: entities: Entities to check for changes. stream_pos: The stream position to check for changes after. + _perf_factor: Used by unit tests to choose when to use each + optimisation. Return: A subset of entities which have changed after the given stream position. @@ -184,6 +186,22 @@ class StreamChangeCache: This will be all entities if the given stream position is at or earlier than the earliest known stream position. """ + if not self._cache or stream_pos <= self._earliest_known_stream_pos: + self.metrics.inc_misses() + return set(entities) + + # If there have been tonnes of changes compared with the number of + # entities, it is faster to check each entities stream ordering + # one-by-one. 
+ max_stream_pos, _ = self._cache.peekitem() + if max_stream_pos - stream_pos > _perf_factor * len(entities): + self.metrics.inc_hits() + return { + entity + for entity in entities + if self._entity_to_key.get(entity, -1) > stream_pos + } + cache_result = self.get_all_entities_changed(stream_pos) if cache_result.hit: # We now do an intersection, trying to do so in the most efficient diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py index 3df053493b..5d38718a50 100644 --- a/tests/util/test_stream_change_cache.py +++ b/tests/util/test_stream_change_cache.py @@ -1,3 +1,5 @@ +from parameterized import parameterized + from synapse.util.caches.stream_change_cache import StreamChangeCache from tests import unittest @@ -161,7 +163,8 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase): self.assertFalse(cache.has_any_entity_changed(2)) self.assertFalse(cache.has_any_entity_changed(3)) - def test_get_entities_changed(self) -> None: + @parameterized.expand([(0,), (1000000000,)]) + def test_get_entities_changed(self, perf_factor: int) -> None: """ StreamChangeCache.get_entities_changed will return the entities in the given list that have changed since the provided stream ID. If the @@ -178,7 +181,9 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase): # get the ones after that point. 
self.assertEqual( cache.get_entities_changed( - ["user@foo.com", "bar@baz.net", "user@elsewhere.org"], stream_pos=2 + ["user@foo.com", "bar@baz.net", "user@elsewhere.org"], + stream_pos=2, + _perf_factor=perf_factor, ), {"bar@baz.net", "user@elsewhere.org"}, ) @@ -195,6 +200,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase): "not@here.website", ], stream_pos=2, + _perf_factor=perf_factor, ), {"bar@baz.net", "user@elsewhere.org"}, ) @@ -210,6 +216,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase): "not@here.website", ], stream_pos=0, + _perf_factor=perf_factor, ), {"user@foo.com", "bar@baz.net", "user@elsewhere.org", "not@here.website"}, ) @@ -217,7 +224,11 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase): # Query a subset of the entries mid-way through the stream. We should # only get back the subset. self.assertEqual( - cache.get_entities_changed(["bar@baz.net"], stream_pos=2), + cache.get_entities_changed( + ["bar@baz.net"], + stream_pos=2, + _perf_factor=perf_factor, + ), {"bar@baz.net"}, ) -- cgit 1.5.1