From a648a06d52715d0d4ad1ec72d042df1b3fd1be71 Mon Sep 17 00:00:00 2001 From: Shay Date: Wed, 3 Aug 2022 10:19:34 -0700 Subject: Add some tracing spans to give insight into local joins (#13439) --- synapse/handlers/message.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) (limited to 'synapse/handlers/message.py') diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index ee0773988e..6b03603598 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -52,6 +52,7 @@ from synapse.events.builder import EventBuilder from synapse.events.snapshot import EventContext from synapse.events.validator import EventValidator from synapse.handlers.directory import DirectoryHandler +from synapse.logging import opentracing from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.http.send_event import ReplicationSendEventRestServlet @@ -1374,9 +1375,10 @@ class EventCreationHandler: # and `state_groups` because they have `prev_events` that aren't persisted yet # (historical messages persisted in reverse-chronological order). if not event.internal_metadata.is_historical(): - await self._bulk_push_rule_evaluator.action_for_event_by_user( - event, context - ) + with opentracing.start_active_span("calculate_push_actions"): + await self._bulk_push_rule_evaluator.action_for_event_by_user( + event, context + ) try: # If we're a worker we need to hit out to the master. @@ -1463,9 +1465,10 @@ class EventCreationHandler: state = await state_entry.get_state( self._storage_controllers.state, StateFilter.all() ) - joined_hosts = await self.store.get_joined_hosts( - event.room_id, state, state_entry - ) + with opentracing.start_active_span("get_joined_hosts"): + joined_hosts = await self.store.get_joined_hosts( + event.room_id, state, state_entry + ) # Note that the expiry times must be larger than the expiry time in # _external_cache_joined_hosts_updates. -- cgit 1.5.1 From c3516e9decc355b75a297d72a13b98a43d312e66 Mon Sep 17 00:00:00 2001 From: reivilibre Date: Tue, 16 Aug 2022 12:16:56 +0000 Subject: Faster room joins: make `/joined_members` block whilst the room is partial stated. (#13514) --- changelog.d/13514.bugfix | 1 + synapse/handlers/message.py | 6 +++++- synapse/storage/controllers/state.py | 13 +++++++++++++ synapse/storage/databases/main/roommember.py | 3 +++ 4 files changed, 22 insertions(+), 1 deletion(-) create mode 100644 changelog.d/13514.bugfix (limited to 'synapse/handlers/message.py') diff --git a/changelog.d/13514.bugfix b/changelog.d/13514.bugfix new file mode 100644 index 0000000000..7498af0e47 --- /dev/null +++ b/changelog.d/13514.bugfix @@ -0,0 +1 @@ +Faster room joins: make `/joined_members` block whilst the room is partial stated. \ No newline at end of file diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 6b03603598..8f29ee9a87 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -331,7 +331,11 @@ class MessageHandler: msg="Getting joined members while not being a current member of the room is forbidden.", ) - users_with_profile = await self.store.get_users_in_room_with_profiles(room_id) + users_with_profile = ( + await self._state_storage_controller.get_users_in_room_with_profiles( + room_id + ) + ) # If this is an AS, double check that they are allowed to see the members. # This can either be because the AS user is in the room or because there diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py index 0d480f1014..0c78eb735e 100644 --- a/synapse/storage/controllers/state.py +++ b/synapse/storage/controllers/state.py @@ -30,6 +30,7 @@ from typing import ( from synapse.api.constants import EventTypes from synapse.events import EventBase from synapse.logging.opentracing import trace +from synapse.storage.roommember import ProfileInfo from synapse.storage.state import StateFilter from synapse.storage.util.partial_state_events_tracker import ( PartialCurrentStateTracker, @@ -506,3 +507,15 @@ class StateStorageController: await self._partial_state_room_tracker.await_full_state(room_id) return await self.stores.main.get_current_hosts_in_room(room_id) + + async def get_users_in_room_with_profiles( + self, room_id: str + ) -> Dict[str, ProfileInfo]: + """ + Get the current users in the room with their profiles. + If the room is currently partial-stated, this will block until the room has + full state. + """ + await self._partial_state_room_tracker.await_full_state(room_id) + + return await self.stores.main.get_users_in_room_with_profiles(room_id) diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 93ff4816c8..5e5f607a14 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -283,6 +283,9 @@ class RoomMemberWorkerStore(EventsWorkerStore): Returns: A mapping from user ID to ProfileInfo. + + Preconditions: + - There is full state available for the room (it is not partial-stated). """ def _get_users_in_room_with_profiles( -- cgit 1.5.1 From 3dd175b628bab5638165f20de9eade36a4e88147 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Mon, 22 Aug 2022 15:17:59 +0200 Subject: `synapse.api.auth.Auth` cleanup: make permission-related methods use `Requester` instead of the `UserID` (#13024) Part of #13019 This changes all the permission-related methods to rely on the Requester instead of the UserID. This is a first step towards enabling scoped access tokens at some point, since I expect the Requester to have scope-related informations in it. It also changes methods which figure out the user/device/appservice out of the access token to return a Requester instead of something else. This avoids having store-related objects in the methods signatures. --- changelog.d/13024.misc | 1 + synapse/api/auth.py | 202 +++++++++++------------ synapse/handlers/auth.py | 17 +- synapse/handlers/directory.py | 24 ++- synapse/handlers/initial_sync.py | 6 +- synapse/handlers/message.py | 23 +-- synapse/handlers/pagination.py | 2 +- synapse/handlers/register.py | 15 +- synapse/handlers/relations.py | 2 +- synapse/handlers/room.py | 4 +- synapse/handlers/room_member.py | 10 +- synapse/handlers/typing.py | 10 +- synapse/http/site.py | 2 +- synapse/rest/admin/_base.py | 10 +- synapse/rest/admin/media.py | 6 +- synapse/rest/admin/rooms.py | 12 +- synapse/rest/admin/users.py | 15 +- synapse/rest/client/profile.py | 4 +- synapse/rest/client/register.py | 3 - synapse/rest/client/room.py | 13 +- synapse/server_notices/server_notices_manager.py | 2 +- synapse/storage/databases/main/registration.py | 2 +- tests/api/test_auth.py | 8 +- tests/handlers/test_typing.py | 8 +- tests/rest/client/test_retention.py | 4 +- tests/rest/client/test_shadow_banned.py | 6 +- 26 files changed, 203 insertions(+), 208 deletions(-) create mode 100644 changelog.d/13024.misc (limited to 'synapse/handlers/message.py') diff --git a/changelog.d/13024.misc b/changelog.d/13024.misc new file mode 100644 index 0000000000..aa43c82429 --- /dev/null +++ b/changelog.d/13024.misc @@ -0,0 +1 @@ +Refactor methods in `synapse.api.auth.Auth` to use `Requester` objects everywhere instead of user IDs. diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 523bad0c55..9a1aea083f 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -37,8 +37,7 @@ from synapse.logging.opentracing import ( start_active_span, trace, ) -from synapse.storage.databases.main.registration import TokenLookupResult -from synapse.types import Requester, UserID, create_requester +from synapse.types import Requester, create_requester if TYPE_CHECKING: from synapse.server import HomeServer @@ -70,14 +69,14 @@ class Auth: async def check_user_in_room( self, room_id: str, - user_id: str, + requester: Requester, allow_departed_users: bool = False, ) -> Tuple[str, Optional[str]]: """Check if the user is in the room, or was at some point. Args: room_id: The room to check. - user_id: The user to check. + requester: The user making the request, according to the access token. current_state: Optional map of the current state of the room. If provided then that map is used to check whether they are a @@ -94,6 +93,7 @@ class Auth: membership event ID of the user. """ + user_id = requester.user.to_string() ( membership, member_event_id, @@ -182,96 +182,69 @@ class Auth: access_token = self.get_access_token_from_request(request) - ( - user_id, - device_id, - app_service, - ) = await self._get_appservice_user_id_and_device_id(request) - if user_id and app_service: - if ip_addr and self._track_appservice_user_ips: - await self.store.insert_client_ip( - user_id=user_id, - access_token=access_token, - ip=ip_addr, - user_agent=user_agent, - device_id="dummy-device" - if device_id is None - else device_id, # stubbed - ) - - requester = create_requester( - user_id, app_service=app_service, device_id=device_id + # First check if it could be a request from an appservice + requester = await self._get_appservice_user(request) + if not requester: + # If not, it should be from a regular user + requester = await self.get_user_by_access_token( + access_token, allow_expired=allow_expired ) - request.requester = user_id - return requester - - user_info = await self.get_user_by_access_token( - access_token, allow_expired=allow_expired - ) - token_id = user_info.token_id - is_guest = user_info.is_guest - shadow_banned = user_info.shadow_banned - - # Deny the request if the user account has expired. - if not allow_expired: - if await self._account_validity_handler.is_user_expired( - user_info.user_id - ): - # Raise the error if either an account validity module has determined - # the account has expired, or the legacy account validity - # implementation is enabled and determined the account has expired - raise AuthError( - 403, - "User account has expired", - errcode=Codes.EXPIRED_ACCOUNT, - ) - - device_id = user_info.device_id - - if access_token and ip_addr: + # Deny the request if the user account has expired. + # This check is only done for regular users, not appservice ones. + if not allow_expired: + if await self._account_validity_handler.is_user_expired( + requester.user.to_string() + ): + # Raise the error if either an account validity module has determined + # the account has expired, or the legacy account validity + # implementation is enabled and determined the account has expired + raise AuthError( + 403, + "User account has expired", + errcode=Codes.EXPIRED_ACCOUNT, + ) + + if ip_addr and ( + not requester.app_service or self._track_appservice_user_ips + ): + # XXX(quenting): I'm 95% confident that we could skip setting the + # device_id to "dummy-device" for appservices, and that the only impact + # would be some rows which whould not deduplicate in the 'user_ips' + # table during the transition + recorded_device_id = ( + "dummy-device" + if requester.device_id is None and requester.app_service is not None + else requester.device_id + ) await self.store.insert_client_ip( - user_id=user_info.token_owner, + user_id=requester.authenticated_entity, access_token=access_token, ip=ip_addr, user_agent=user_agent, - device_id=device_id, + device_id=recorded_device_id, ) + # Track also the puppeted user client IP if enabled and the user is puppeting if ( - user_info.user_id != user_info.token_owner + requester.user.to_string() != requester.authenticated_entity and self._track_puppeted_user_ips ): await self.store.insert_client_ip( - user_id=user_info.user_id, + user_id=requester.user.to_string(), access_token=access_token, ip=ip_addr, user_agent=user_agent, - device_id=device_id, + device_id=requester.device_id, ) - if is_guest and not allow_guest: + if requester.is_guest and not allow_guest: raise AuthError( 403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN, ) - # Mark the token as used. This is used to invalidate old refresh - # tokens after some time. - if not user_info.token_used and token_id is not None: - await self.store.mark_access_token_as_used(token_id) - - requester = create_requester( - user_info.user_id, - token_id, - is_guest, - shadow_banned, - device_id, - app_service=app_service, - authenticated_entity=user_info.token_owner, - ) - request.requester = requester return requester except KeyError: @@ -308,9 +281,7 @@ class Auth: 403, "Application service has not registered this user (%s)" % user_id ) - async def _get_appservice_user_id_and_device_id( - self, request: Request - ) -> Tuple[Optional[str], Optional[str], Optional[ApplicationService]]: + async def _get_appservice_user(self, request: Request) -> Optional[Requester]: """ Given a request, reads the request parameters to determine: - whether it's an application service that's making this request @@ -325,15 +296,13 @@ class Auth: Must use `org.matrix.msc3202.device_id` in place of `device_id` for now. Returns: - 3-tuple of - (user ID?, device ID?, application service?) + the application service `Requester` of that request Postconditions: - - If an application service is returned, so is a user ID - - A user ID is never returned without an application service - - A device ID is never returned without a user ID or an application service - - The returned application service, if present, is permitted to control the - returned user ID. + - The `app_service` field in the returned `Requester` is set + - The `user_id` field in the returned `Requester` is either the application + service sender or the controlled user set by the `user_id` URI parameter + - The returned application service is permitted to control the returned user ID. - The returned device ID, if present, has been checked to be a valid device ID for the returned user ID. """ @@ -343,12 +312,12 @@ class Auth: self.get_access_token_from_request(request) ) if app_service is None: - return None, None, None + return None if app_service.ip_range_whitelist: ip_address = IPAddress(request.getClientAddress().host) if ip_address not in app_service.ip_range_whitelist: - return None, None, None + return None # This will always be set by the time Twisted calls us. assert request.args is not None @@ -382,13 +351,15 @@ class Auth: Codes.EXCLUSIVE, ) - return effective_user_id, effective_device_id, app_service + return create_requester( + effective_user_id, app_service=app_service, device_id=effective_device_id + ) async def get_user_by_access_token( self, token: str, allow_expired: bool = False, - ) -> TokenLookupResult: + ) -> Requester: """Validate access token and get user_id from it Args: @@ -405,9 +376,9 @@ class Auth: # First look in the database to see if the access token is present # as an opaque token. - r = await self.store.get_user_by_access_token(token) - if r: - valid_until_ms = r.valid_until_ms + user_info = await self.store.get_user_by_access_token(token) + if user_info: + valid_until_ms = user_info.valid_until_ms if ( not allow_expired and valid_until_ms is not None @@ -419,7 +390,20 @@ class Auth: msg="Access token has expired", soft_logout=True ) - return r + # Mark the token as used. This is used to invalidate old refresh + # tokens after some time. + await self.store.mark_access_token_as_used(user_info.token_id) + + requester = create_requester( + user_id=user_info.user_id, + access_token_id=user_info.token_id, + is_guest=user_info.is_guest, + shadow_banned=user_info.shadow_banned, + device_id=user_info.device_id, + authenticated_entity=user_info.token_owner, + ) + + return requester # If the token isn't found in the database, then it could still be a # macaroon for a guest, so we check that here. @@ -445,11 +429,12 @@ class Auth: "Guest access token used for regular user" ) - return TokenLookupResult( + return create_requester( user_id=user_id, is_guest=True, # all guests get the same device id device_id=GUEST_DEVICE_ID, + authenticated_entity=user_id, ) except ( pymacaroons.exceptions.MacaroonException, @@ -472,32 +457,33 @@ class Auth: request.requester = create_requester(service.sender, app_service=service) return service - async def is_server_admin(self, user: UserID) -> bool: + async def is_server_admin(self, requester: Requester) -> bool: """Check if the given user is a local server admin. Args: - user: user to check + requester: The user making the request, according to the access token. Returns: True if the user is an admin """ - return await self.store.is_server_admin(user) + return await self.store.is_server_admin(requester.user) - async def check_can_change_room_list(self, room_id: str, user: UserID) -> bool: + async def check_can_change_room_list( + self, room_id: str, requester: Requester + ) -> bool: """Determine whether the user is allowed to edit the room's entry in the published room list. Args: - room_id - user + room_id: The room to check. + requester: The user making the request, according to the access token. """ - is_admin = await self.is_server_admin(user) + is_admin = await self.is_server_admin(requester) if is_admin: return True - user_id = user.to_string() - await self.check_user_in_room(room_id, user_id) + await self.check_user_in_room(room_id, requester) # We currently require the user is a "moderator" in the room. We do this # by checking if they would (theoretically) be able to change the @@ -516,7 +502,9 @@ class Auth: send_level = event_auth.get_send_level( EventTypes.CanonicalAlias, "", power_level_event ) - user_level = event_auth.get_user_power_level(user_id, auth_events) + user_level = event_auth.get_user_power_level( + requester.user.to_string(), auth_events + ) return user_level >= send_level @@ -574,16 +562,16 @@ class Auth: @trace async def check_user_in_room_or_world_readable( - self, room_id: str, user_id: str, allow_departed_users: bool = False + self, room_id: str, requester: Requester, allow_departed_users: bool = False ) -> Tuple[str, Optional[str]]: """Checks that the user is or was in the room or the room is world readable. If it isn't then an exception is raised. Args: - room_id: room to check - user_id: user to check - allow_departed_users: if True, accept users that were previously - members but have now departed + room_id: The room to check. + requester: The user making the request, according to the access token. + allow_departed_users: If True, accept users that were previously + members but have now departed. Returns: Resolves to the current membership of the user in the room and the @@ -598,7 +586,7 @@ class Auth: # * The user is a guest user, and has joined the room # else it will throw. return await self.check_user_in_room( - room_id, user_id, allow_departed_users=allow_departed_users + room_id, requester, allow_departed_users=allow_departed_users ) except AuthError: visibility = await self._storage_controllers.state.get_current_state_event( @@ -613,6 +601,6 @@ class Auth: raise UnstableSpecAuthError( 403, "User %s not in room %s, and room previews are disabled" - % (user_id, room_id), + % (requester.user, room_id), errcode=Codes.NOT_JOINED, ) diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index bfa5535044..0327fc57a4 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -280,7 +280,7 @@ class AuthHandler: that it isn't stolen by re-authenticating them. Args: - requester: The user, as given by the access token + requester: The user making the request, according to the access token. request: The request sent by the client. @@ -1435,20 +1435,25 @@ class AuthHandler: access_token: access token to be deleted """ - user_info = await self.auth.get_user_by_access_token(access_token) + token = await self.store.get_user_by_access_token(access_token) + if not token: + # At this point, the token should already have been fetched once by + # the caller, so this should not happen, unless of a race condition + # between two delete requests + raise SynapseError(HTTPStatus.UNAUTHORIZED, "Unrecognised access token") await self.store.delete_access_token(access_token) # see if any modules want to know about this await self.password_auth_provider.on_logged_out( - user_id=user_info.user_id, - device_id=user_info.device_id, + user_id=token.user_id, + device_id=token.device_id, access_token=access_token, ) # delete pushers associated with this access token - if user_info.token_id is not None: + if token.token_id is not None: await self.hs.get_pusherpool().remove_pushers_by_access_token( - user_info.user_id, (user_info.token_id,) + token.user_id, (token.token_id,) ) async def delete_access_tokens_for_user( diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 09a7a4b238..948f66a94d 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -30,7 +30,7 @@ from synapse.api.errors import ( from synapse.appservice import ApplicationService from synapse.module_api import NOT_SPAM from synapse.storage.databases.main.directory import RoomAliasMapping -from synapse.types import JsonDict, Requester, RoomAlias, UserID, get_domain_from_id +from synapse.types import JsonDict, Requester, RoomAlias, get_domain_from_id if TYPE_CHECKING: from synapse.server import HomeServer @@ -133,7 +133,7 @@ class DirectoryHandler: else: # Server admins are not subject to the same constraints as normal # users when creating an alias (e.g. being in the room). - is_admin = await self.auth.is_server_admin(requester.user) + is_admin = await self.auth.is_server_admin(requester) if (self.require_membership and check_membership) and not is_admin: rooms_for_user = await self.store.get_rooms_for_user(user_id) @@ -197,7 +197,7 @@ class DirectoryHandler: user_id = requester.user.to_string() try: - can_delete = await self._user_can_delete_alias(room_alias, user_id) + can_delete = await self._user_can_delete_alias(room_alias, requester) except StoreError as e: if e.code == 404: raise NotFoundError("Unknown room alias") @@ -400,7 +400,9 @@ class DirectoryHandler: # either no interested services, or no service with an exclusive lock return True - async def _user_can_delete_alias(self, alias: RoomAlias, user_id: str) -> bool: + async def _user_can_delete_alias( + self, alias: RoomAlias, requester: Requester + ) -> bool: """Determine whether a user can delete an alias. One of the following must be true: @@ -413,7 +415,7 @@ class DirectoryHandler: """ creator = await self.store.get_room_alias_creator(alias.to_string()) - if creator == user_id: + if creator == requester.user.to_string(): return True # Resolve the alias to the corresponding room. @@ -422,9 +424,7 @@ class DirectoryHandler: if not room_id: return False - return await self.auth.check_can_change_room_list( - room_id, UserID.from_string(user_id) - ) + return await self.auth.check_can_change_room_list(room_id, requester) async def edit_published_room_list( self, requester: Requester, room_id: str, visibility: str @@ -463,7 +463,7 @@ class DirectoryHandler: raise SynapseError(400, "Unknown room") can_change_room_list = await self.auth.check_can_change_room_list( - room_id, requester.user + room_id, requester ) if not can_change_room_list: raise AuthError( @@ -528,10 +528,8 @@ class DirectoryHandler: Get a list of the aliases that currently point to this room on this server """ # allow access to server admins and current members of the room - is_admin = await self.auth.is_server_admin(requester.user) + is_admin = await self.auth.is_server_admin(requester) if not is_admin: - await self.auth.check_user_in_room_or_world_readable( - room_id, requester.user.to_string() - ) + await self.auth.check_user_in_room_or_world_readable(room_id, requester) return await self.store.get_aliases_for_room(room_id) diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 6484e47e5f..860c82c110 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -309,18 +309,18 @@ class InitialSyncHandler: if blocked: raise SynapseError(403, "This room has been blocked on this server") - user_id = requester.user.to_string() - ( membership, member_event_id, ) = await self.auth.check_user_in_room_or_world_readable( room_id, - user_id, + requester, allow_departed_users=True, ) is_peeking = member_event_id is None + user_id = requester.user.to_string() + if membership == Membership.JOIN: result = await self._room_initial_sync_joined( user_id, room_id, pagin_config, membership, is_peeking diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 8f29ee9a87..acd3de06f6 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -104,7 +104,7 @@ class MessageHandler: async def get_room_data( self, - user_id: str, + requester: Requester, room_id: str, event_type: str, state_key: str, @@ -112,7 +112,7 @@ class MessageHandler: """Get data from a room. Args: - user_id + requester: The user who did the request. room_id event_type state_key @@ -125,7 +125,7 @@ class MessageHandler: membership, membership_event_id, ) = await self.auth.check_user_in_room_or_world_readable( - room_id, user_id, allow_departed_users=True + room_id, requester, allow_departed_users=True ) if membership == Membership.JOIN: @@ -161,11 +161,10 @@ class MessageHandler: async def get_state_events( self, - user_id: str, + requester: Requester, room_id: str, state_filter: Optional[StateFilter] = None, at_token: Optional[StreamToken] = None, - is_guest: bool = False, ) -> List[dict]: """Retrieve all state events for a given room. If the user is joined to the room then return the current state. If the user has @@ -174,14 +173,13 @@ class MessageHandler: visible. Args: - user_id: The user requesting state events. + requester: The user requesting state events. room_id: The room ID to get all state events from. state_filter: The state filter used to fetch state from the database. at_token: the stream token of the at which we are requesting the stats. If the user is not allowed to view the state as of that stream token, we raise a 403 SynapseError. If None, returns the current state based on the current_state_events table. - is_guest: whether this user is a guest Returns: A list of dicts representing state events. [{}, {}, {}] Raises: @@ -191,6 +189,7 @@ class MessageHandler: members of this room. """ state_filter = state_filter or StateFilter.all() + user_id = requester.user.to_string() if at_token: last_event_id = ( @@ -223,7 +222,7 @@ class MessageHandler: membership, membership_event_id, ) = await self.auth.check_user_in_room_or_world_readable( - room_id, user_id, allow_departed_users=True + room_id, requester, allow_departed_users=True ) if membership == Membership.JOIN: @@ -317,12 +316,11 @@ class MessageHandler: Returns: A dict of user_id to profile info """ - user_id = requester.user.to_string() if not requester.app_service: # We check AS auth after fetching the room membership, as it # requires us to pull out all joined members anyway. membership, _ = await self.auth.check_user_in_room_or_world_readable( - room_id, user_id, allow_departed_users=True + room_id, requester, allow_departed_users=True ) if membership != Membership.JOIN: raise SynapseError( @@ -340,7 +338,10 @@ class MessageHandler: # If this is an AS, double check that they are allowed to see the members. # This can either be because the AS user is in the room or because there # is a user in the room that the AS is "interested in" - if requester.app_service and user_id not in users_with_profile: + if ( + requester.app_service + and requester.user.to_string() not in users_with_profile + ): for uid in users_with_profile: if requester.app_service.is_interested_in_user(uid): break diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index e1e34e3b16..74e944bce7 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -464,7 +464,7 @@ class PaginationHandler: membership, member_event_id, ) = await self.auth.check_user_in_room_or_world_readable( - room_id, user_id, allow_departed_users=True + room_id, requester, allow_departed_users=True ) if pagin_config.direction == "b": diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index c77d181722..20ec22105a 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -29,7 +29,13 @@ from synapse.api.constants import ( JoinRules, LoginType, ) -from synapse.api.errors import AuthError, Codes, ConsentNotGivenError, SynapseError +from synapse.api.errors import ( + AuthError, + Codes, + ConsentNotGivenError, + InvalidClientTokenError, + SynapseError, +) from synapse.appservice import ApplicationService from synapse.config.server import is_threepid_reserved from synapse.http.servlet import assert_params_in_dict @@ -180,10 +186,7 @@ class RegistrationHandler: ) if guest_access_token: user_data = await self.auth.get_user_by_access_token(guest_access_token) - if ( - not user_data.is_guest - or UserID.from_string(user_data.user_id).localpart != localpart - ): + if not user_data.is_guest or user_data.user.localpart != localpart: raise AuthError( 403, "Cannot register taken user ID without valid guest " @@ -618,7 +621,7 @@ class RegistrationHandler: user_id = user.to_string() service = self.store.get_app_service_by_token(as_token) if not service: - raise AuthError(403, "Invalid application service token.") + raise InvalidClientTokenError() if not service.is_interested_in_user(user_id): raise SynapseError( 400, diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index 72d25df8c8..28d7093f08 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -103,7 +103,7 @@ class RelationsHandler: # TODO Properly handle a user leaving a room. (_, member_event_id) = await self._auth.check_user_in_room_or_world_readable( - room_id, user_id, allow_departed_users=True + room_id, requester, allow_departed_users=True ) # This gets the original event and checks that a) the event exists and diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 55395457c3..2bf0ebd025 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -721,7 +721,7 @@ class RoomCreationHandler: # allow the server notices mxid to create rooms is_requester_admin = True else: - is_requester_admin = await self.auth.is_server_admin(requester.user) + is_requester_admin = await self.auth.is_server_admin(requester) # Let the third party rules modify the room creation config if needed, or abort # the room creation entirely with an exception. @@ -1279,7 +1279,7 @@ class RoomContextHandler: """ user = requester.user if use_admin_priviledge: - await assert_user_is_admin(self.auth, requester.user) + await assert_user_is_admin(self.auth, requester) before_limit = math.floor(limit / 2.0) after_limit = limit - before_limit diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 70dc69c809..d1909665d6 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -179,7 +179,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): """Try and join a room that this server is not in Args: - requester + requester: The user making the request, according to the access token. remote_room_hosts: List of servers that can be used to join via. room_id: Room that we are trying to join user: User who is trying to join @@ -744,7 +744,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): is_requester_admin = True else: - is_requester_admin = await self.auth.is_server_admin(requester.user) + is_requester_admin = await self.auth.is_server_admin(requester) if not is_requester_admin: if self.config.server.block_non_admin_invites: @@ -868,7 +868,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): bypass_spam_checker = True else: - bypass_spam_checker = await self.auth.is_server_admin(requester.user) + bypass_spam_checker = await self.auth.is_server_admin(requester) inviter = await self._get_inviter(target.to_string(), room_id) if ( @@ -1410,7 +1410,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): ShadowBanError if the requester has been shadow-banned. """ if self.config.server.block_non_admin_invites: - is_requester_admin = await self.auth.is_server_admin(requester.user) + is_requester_admin = await self.auth.is_server_admin(requester) if not is_requester_admin: raise SynapseError( 403, "Invites have been disabled on this server", Codes.FORBIDDEN @@ -1693,7 +1693,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): check_complexity and self.hs.config.server.limit_remote_rooms.admins_can_join ): - check_complexity = not await self.auth.is_server_admin(user) + check_complexity = not await self.store.is_server_admin(user) if check_complexity: # Fetch the room complexity diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 27aa0d3126..bcac3372a2 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -253,12 +253,11 @@ class TypingWriterHandler(FollowerTypingHandler): self, target_user: UserID, requester: Requester, room_id: str, timeout: int ) -> None: target_user_id = target_user.to_string() - auth_user_id = requester.user.to_string() if not self.is_mine_id(target_user_id): raise SynapseError(400, "User is not hosted on this homeserver") - if target_user_id != auth_user_id: + if target_user != requester.user: raise AuthError(400, "Cannot set another user's typing state") if requester.shadow_banned: @@ -266,7 +265,7 @@ class TypingWriterHandler(FollowerTypingHandler): await self.clock.sleep(random.randint(1, 10)) raise ShadowBanError() - await self.auth.check_user_in_room(room_id, target_user_id) + await self.auth.check_user_in_room(room_id, requester) logger.debug("%s has started typing in %s", target_user_id, room_id) @@ -289,12 +288,11 @@ class TypingWriterHandler(FollowerTypingHandler): self, target_user: UserID, requester: Requester, room_id: str ) -> None: target_user_id = target_user.to_string() - auth_user_id = requester.user.to_string() if not self.is_mine_id(target_user_id): raise SynapseError(400, "User is not hosted on this homeserver") - if target_user_id != auth_user_id: + if target_user != requester.user: raise AuthError(400, "Cannot set another user's typing state") if requester.shadow_banned: @@ -302,7 +300,7 @@ class TypingWriterHandler(FollowerTypingHandler): await self.clock.sleep(random.randint(1, 10)) raise ShadowBanError() - await self.auth.check_user_in_room(room_id, target_user_id) + await self.auth.check_user_in_room(room_id, requester) logger.debug("%s has stopped typing in %s", target_user_id, room_id) diff --git a/synapse/http/site.py b/synapse/http/site.py index eeec74b78a..1155f3f610 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -226,7 +226,7 @@ class SynapseRequest(Request): # If this is a request where the target user doesn't match the user who # authenticated (e.g. and admin is puppetting a user) then we return both. - if self._requester.user.to_string() != authenticated_entity: + if requester != authenticated_entity: return requester, authenticated_entity return requester, None diff --git a/synapse/rest/admin/_base.py b/synapse/rest/admin/_base.py index 399b205aaf..b467a61dfb 100644 --- a/synapse/rest/admin/_base.py +++ b/synapse/rest/admin/_base.py @@ -19,7 +19,7 @@ from typing import Iterable, Pattern from synapse.api.auth import Auth from synapse.api.errors import AuthError from synapse.http.site import SynapseRequest -from synapse.types import UserID +from synapse.types import Requester def admin_patterns(path_regex: str, version: str = "v1") -> Iterable[Pattern]: @@ -48,19 +48,19 @@ async def assert_requester_is_admin(auth: Auth, request: SynapseRequest) -> None AuthError if the requester is not a server admin """ requester = await auth.get_user_by_req(request) - await assert_user_is_admin(auth, requester.user) + await assert_user_is_admin(auth, requester) -async def assert_user_is_admin(auth: Auth, user_id: UserID) -> None: +async def assert_user_is_admin(auth: Auth, requester: Requester) -> None: """Verify that the given user is an admin user Args: auth: Auth singleton - user_id: user to check + requester: The user making the request, according to the access token. Raises: AuthError if the user is not a server admin """ - is_admin = await auth.is_server_admin(user_id) + is_admin = await auth.is_server_admin(requester) if not is_admin: raise AuthError(HTTPStatus.FORBIDDEN, "You are not a server admin") diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py index 19d4a008e8..73470f09ae 100644 --- a/synapse/rest/admin/media.py +++ b/synapse/rest/admin/media.py @@ -54,7 +54,7 @@ class QuarantineMediaInRoom(RestServlet): self, request: SynapseRequest, room_id: str ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_user_is_admin(self.auth, requester) logging.info("Quarantining room: %s", room_id) @@ -81,7 +81,7 @@ class QuarantineMediaByUser(RestServlet): self, request: SynapseRequest, user_id: str ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_user_is_admin(self.auth, requester) logging.info("Quarantining media by user: %s", user_id) @@ -110,7 +110,7 @@ class QuarantineMediaByID(RestServlet): self, request: SynapseRequest, server_name: str, media_id: str ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_user_is_admin(self.auth, requester) logging.info("Quarantining media by ID: %s/%s", server_name, media_id) diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 68054ffc28..3d870629c4 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -75,7 +75,7 @@ class RoomRestV2Servlet(RestServlet): ) -> Tuple[int, JsonDict]: requester = await self._auth.get_user_by_req(request) - await assert_user_is_admin(self._auth, requester.user) + await assert_user_is_admin(self._auth, requester) content = parse_json_object_from_request(request) @@ -327,7 +327,7 @@ class RoomRestServlet(RestServlet): pagination_handler: "PaginationHandler", ) -> Tuple[int, JsonDict]: requester = await auth.get_user_by_req(request) - await assert_user_is_admin(auth, requester.user) + await assert_user_is_admin(auth, requester) content = parse_json_object_from_request(request) @@ -461,7 +461,7 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet): assert request.args is not None requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_user_is_admin(self.auth, requester) content = parse_json_object_from_request(request) @@ -551,7 +551,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet): self, request: SynapseRequest, room_identifier: str ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_user_is_admin(self.auth, requester) content = parse_json_object_from_request(request, allow_empty_body=True) room_id, _ = await self.resolve_room_id(room_identifier) @@ -742,7 +742,7 @@ class RoomEventContextServlet(RestServlet): self, request: SynapseRequest, room_id: str, event_id: str ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=False) - await assert_user_is_admin(self.auth, requester.user) + await assert_user_is_admin(self.auth, requester) limit = parse_integer(request, "limit", default=10) @@ -834,7 +834,7 @@ class BlockRoomRestServlet(RestServlet): self, request: SynapseRequest, room_id: str ) -> Tuple[int, JsonDict]: requester = await self._auth.get_user_by_req(request) - await assert_user_is_admin(self._auth, requester.user) + await assert_user_is_admin(self._auth, requester) content = parse_json_object_from_request(request) diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index ba2f7fa6d8..78ee9b6532 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -183,7 +183,7 @@ class UserRestServletV2(RestServlet): self, request: SynapseRequest, user_id: str ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_user_is_admin(self.auth, requester) target_user = UserID.from_string(user_id) body = parse_json_object_from_request(request) @@ -575,10 +575,9 @@ class WhoisRestServlet(RestServlet): ) -> Tuple[int, JsonDict]: target_user = UserID.from_string(user_id) requester = await self.auth.get_user_by_req(request) - auth_user = requester.user - if target_user != auth_user: - await assert_user_is_admin(self.auth, auth_user) + if target_user != requester.user: + await assert_user_is_admin(self.auth, requester) if not self.is_mine(target_user): raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only whois a local user") @@ -601,7 +600,7 @@ class DeactivateAccountRestServlet(RestServlet): self, request: SynapseRequest, target_user_id: str ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_user_is_admin(self.auth, requester) if not self.is_mine(UserID.from_string(target_user_id)): raise SynapseError( @@ -693,7 +692,7 @@ class ResetPasswordRestServlet(RestServlet): This needs user to have administrator access in Synapse. """ requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_user_is_admin(self.auth, requester) UserID.from_string(target_user_id) @@ -807,7 +806,7 @@ class UserAdminServlet(RestServlet): self, request: SynapseRequest, user_id: str ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_user_is_admin(self.auth, requester) auth_user = requester.user target_user = UserID.from_string(user_id) @@ -921,7 +920,7 @@ class UserTokenRestServlet(RestServlet): self, request: SynapseRequest, user_id: str ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) + await assert_user_is_admin(self.auth, requester) auth_user = requester.user if not self.is_mine_id(user_id): diff --git a/synapse/rest/client/profile.py b/synapse/rest/client/profile.py index c16d707909..e69fa0829d 100644 --- a/synapse/rest/client/profile.py +++ b/synapse/rest/client/profile.py @@ -66,7 +66,7 @@ class ProfileDisplaynameRestServlet(RestServlet): ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) user = UserID.from_string(user_id) - is_admin = await self.auth.is_server_admin(requester.user) + is_admin = await self.auth.is_server_admin(requester) content = parse_json_object_from_request(request) @@ -123,7 +123,7 @@ class ProfileAvatarURLRestServlet(RestServlet): ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) user = UserID.from_string(user_id) - is_admin = await self.auth.is_server_admin(requester.user) + is_admin = await self.auth.is_server_admin(requester) content = parse_json_object_from_request(request) try: diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index 956c45e60a..1b953d3fa0 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -484,9 +484,6 @@ class RegisterRestServlet(RestServlet): "Appservice token must be provided when using a type of m.login.application_service", ) - # Verify the AS - self.auth.get_appservice_by_req(request) - # Set the desired user according to the AS API (which uses the # 'user' key not 'username'). Since this is a new addition, we'll # fallback to 'username' if they gave one. diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 13bc9482c5..0eafbae457 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -229,7 +229,7 @@ class RoomStateEventRestServlet(TransactionRestServlet): msg_handler = self.message_handler data = await msg_handler.get_room_data( - user_id=requester.user.to_string(), + requester=requester, room_id=room_id, event_type=event_type, state_key=state_key, @@ -574,7 +574,7 @@ class RoomMemberListRestServlet(RestServlet): events = await handler.get_state_events( room_id=room_id, - user_id=requester.user.to_string(), + requester=requester, at_token=at_token, state_filter=StateFilter.from_types([(EventTypes.Member, None)]), ) @@ -696,8 +696,7 @@ class RoomStateRestServlet(RestServlet): # Get all the current state for this room events = await self.message_handler.get_state_events( room_id=room_id, - user_id=requester.user.to_string(), - is_guest=requester.is_guest, + requester=requester, ) return 200, events @@ -755,7 +754,7 @@ class RoomEventServlet(RestServlet): == "true" ) if include_unredacted_content and not await self.auth.is_server_admin( - requester.user + requester ): power_level_event = ( await self._storage_controllers.state.get_current_state_event( @@ -1260,9 +1259,7 @@ class TimestampLookupRestServlet(RestServlet): self, request: SynapseRequest, room_id: str ) -> Tuple[int, JsonDict]: requester = await self._auth.get_user_by_req(request) - await self._auth.check_user_in_room_or_world_readable( - room_id, requester.user.to_string() - ) + await self._auth.check_user_in_room_or_world_readable(room_id, requester) timestamp = parse_integer(request, "ts", required=True) direction = parse_string(request, "dir", default="f", allowed_values=["f", "b"]) diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py index 8ecab86ec7..70d054a8f4 100644 --- a/synapse/server_notices/server_notices_manager.py +++ b/synapse/server_notices/server_notices_manager.py @@ -244,7 +244,7 @@ class ServerNoticesManager: assert self.server_notices_mxid is not None notice_user_data_in_room = await self._message_handler.get_room_data( - self.server_notices_mxid, + create_requester(self.server_notices_mxid), room_id, EventTypes.Member, self.server_notices_mxid, diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index cb63cd9b7d..7fb9c801da 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -69,9 +69,9 @@ class TokenLookupResult: """ user_id: str + token_id: int is_guest: bool = False shadow_banned: bool = False - token_id: Optional[int] = None device_id: Optional[str] = None valid_until_ms: Optional[int] = None token_owner: str = attr.ib() diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index dfcfaf79b6..e0f363555b 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -284,10 +284,13 @@ class AuthTestCase(unittest.HomeserverTestCase): TokenLookupResult( user_id="@baldrick:matrix.org", device_id="device", + token_id=5, token_owner="@admin:matrix.org", + token_used=True, ) ) self.store.insert_client_ip = simple_async_mock(None) + self.store.mark_access_token_as_used = simple_async_mock(None) request = Mock(args={}) request.getClientAddress.return_value.host = "127.0.0.1" request.args[b"access_token"] = [self.test_token] @@ -301,10 +304,13 @@ class AuthTestCase(unittest.HomeserverTestCase): TokenLookupResult( user_id="@baldrick:matrix.org", device_id="device", + token_id=5, token_owner="@admin:matrix.org", + token_used=True, ) ) self.store.insert_client_ip = simple_async_mock(None) + self.store.mark_access_token_as_used = simple_async_mock(None) request = Mock(args={}) request.getClientAddress.return_value.host = "127.0.0.1" request.args[b"access_token"] = [self.test_token] @@ -347,7 +353,7 @@ class AuthTestCase(unittest.HomeserverTestCase): serialized = macaroon.serialize() user_info = self.get_success(self.auth.get_user_by_access_token(serialized)) - self.assertEqual(user_id, user_info.user_id) + self.assertEqual(user_id, user_info.user.to_string()) self.assertTrue(user_info.is_guest) self.store.get_user_by_id.assert_called_with(user_id) diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 7af1333126..8adba29d7f 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -25,7 +25,7 @@ from synapse.api.constants import EduTypes from synapse.api.errors import AuthError from synapse.federation.transport.server import TransportLayerServer from synapse.server import HomeServer -from synapse.types import JsonDict, UserID, create_requester +from synapse.types import JsonDict, Requester, UserID, create_requester from synapse.util import Clock from tests import unittest @@ -117,8 +117,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.room_members = [] - async def check_user_in_room(room_id: str, user_id: str) -> None: - if user_id not in [u.to_string() for u in self.room_members]: + async def check_user_in_room(room_id: str, requester: Requester) -> None: + if requester.user.to_string() not in [ + u.to_string() for u in self.room_members + ]: raise AuthError(401, "User is not in the room") return None diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py index ac9c113354..9c8c1889d3 100644 --- a/tests/rest/client/test_retention.py +++ b/tests/rest/client/test_retention.py @@ -20,7 +20,7 @@ from synapse.api.constants import EventTypes from synapse.rest import admin from synapse.rest.client import login, room from synapse.server import HomeServer -from synapse.types import JsonDict +from synapse.types import JsonDict, create_requester from synapse.util import Clock from synapse.visibility import filter_events_for_client @@ -188,7 +188,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): message_handler = self.hs.get_message_handler() create_event = self.get_success( message_handler.get_room_data( - self.user_id, room_id, EventTypes.Create, state_key="" + create_requester(self.user_id), room_id, EventTypes.Create, state_key="" ) ) diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py index d9bd8c4a28..c50f034b34 100644 --- a/tests/rest/client/test_shadow_banned.py +++ b/tests/rest/client/test_shadow_banned.py @@ -26,7 +26,7 @@ from synapse.rest.client import ( room_upgrade_rest_servlet, ) from synapse.server import HomeServer -from synapse.types import UserID +from synapse.types import UserID, create_requester from synapse.util import Clock from tests import unittest @@ -275,7 +275,7 @@ class ProfileTestCase(_ShadowBannedBase): message_handler = self.hs.get_message_handler() event = self.get_success( message_handler.get_room_data( - self.banned_user_id, + create_requester(self.banned_user_id), room_id, "m.room.member", self.banned_user_id, @@ -310,7 +310,7 @@ class ProfileTestCase(_ShadowBannedBase): message_handler = self.hs.get_message_handler() event = self.get_success( message_handler.get_room_data( - self.banned_user_id, + create_requester(self.banned_user_id), room_id, "m.room.member", self.banned_user_id, -- cgit 1.5.1 From d58615c82cec5bd866bedcb33e3e2a5d2a961c44 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Wed, 24 Aug 2022 14:13:12 -0500 Subject: Directly lookup local membership instead of getting all members in a room first (`get_users_in_room` mis-use) (#13608) See https://github.com/matrix-org/synapse/pull/13575#discussion_r953023755 --- changelog.d/13608.misc | 1 + synapse/handlers/events.py | 9 +++++--- synapse/handlers/message.py | 6 ++++-- synapse/handlers/room.py | 7 +++++-- synapse/handlers/room_member.py | 6 ++++-- synapse/server_notices/server_notices_manager.py | 10 +++++++-- synapse/storage/databases/main/roommember.py | 26 ++++++++++++++++++++++++ tests/rest/client/test_relations.py | 12 +++++------ 8 files changed, 60 insertions(+), 17 deletions(-) create mode 100644 changelog.d/13608.misc (limited to 'synapse/handlers/message.py') diff --git a/changelog.d/13608.misc b/changelog.d/13608.misc new file mode 100644 index 0000000000..19bcc45e33 --- /dev/null +++ b/changelog.d/13608.misc @@ -0,0 +1 @@ +Refactor `get_users_in_room(room_id)` mis-use to lookup single local user with dedicated `check_local_user_in_room(...)` function. diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index ac13340d3a..949b69cb41 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -151,7 +151,7 @@ class EventHandler: """Retrieve a single specified event. Args: - user: The user requesting the event + user: The local user requesting the event room_id: The expected room id. We'll return None if the event's room does not match. event_id: The event ID to obtain. @@ -173,8 +173,11 @@ class EventHandler: if not event: return None - users = await self.store.get_users_in_room(event.room_id) - is_peeking = user.to_string() not in users + is_user_in_room = await self.store.check_local_user_in_room( + user_id=user.to_string(), room_id=event.room_id + ) + # The user is peeking if they aren't in the room already + is_peeking = not is_user_in_room filtered = await filter_events_for_client( self._storage_controllers, user.to_string(), [event], is_peeking=is_peeking diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index acd3de06f6..72157d5a36 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -761,8 +761,10 @@ class EventCreationHandler: async def _is_server_notices_room(self, room_id: str) -> bool: if self.config.servernotices.server_notices_mxid is None: return False - user_ids = await self.store.get_users_in_room(room_id) - return self.config.servernotices.server_notices_mxid in user_ids + is_server_notices_room = await self.store.check_local_user_in_room( + user_id=self.config.servernotices.server_notices_mxid, room_id=room_id + ) + return is_server_notices_room async def assert_accepted_privacy_policy(self, requester: Requester) -> None: """Check if a user has accepted the privacy policy diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 2bf0ebd025..2fc8264858 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1284,8 +1284,11 @@ class RoomContextHandler: before_limit = math.floor(limit / 2.0) after_limit = limit - before_limit - users = await self.store.get_users_in_room(room_id) - is_peeking = user.to_string() not in users + is_user_in_room = await self.store.check_local_user_in_room( + user_id=user.to_string(), room_id=room_id + ) + # The user is peeking if they aren't in the room already + is_peeking = not is_user_in_room async def filter_evts(events: List[EventBase]) -> List[EventBase]: if use_admin_priviledge: diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 65b9a655d4..709682622f 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -1620,8 +1620,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): async def _is_server_notice_room(self, room_id: str) -> bool: if self._server_notices_mxid is None: return False - user_ids = await self.store.get_users_in_room(room_id) - return self._server_notices_mxid in user_ids + is_server_notices_room = await self.store.check_local_user_in_room( + user_id=self._server_notices_mxid, room_id=room_id + ) + return is_server_notices_room class RoomMemberMasterHandler(RoomMemberHandler): diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py index 70d054a8f4..564e3705c2 100644 --- a/synapse/server_notices/server_notices_manager.py +++ b/synapse/server_notices/server_notices_manager.py @@ -102,6 +102,10 @@ class ServerNoticesManager: Returns: The room's ID, or None if no room could be found. """ + # If there is no server notices MXID, then there is no server notices room + if self.server_notices_mxid is None: + return None + rooms = await self._store.get_rooms_for_local_user_where_membership_is( user_id, [Membership.INVITE, Membership.JOIN] ) @@ -111,8 +115,10 @@ class ServerNoticesManager: # be joined. This is kinda deliberate, in that if somebody somehow # manages to invite the system user to a room, that doesn't make it # the server notices room. - user_ids = await self._store.get_users_in_room(room.room_id) - if len(user_ids) <= 2 and self.server_notices_mxid in user_ids: + is_server_notices_room = await self._store.check_local_user_in_room( + user_id=self.server_notices_mxid, room_id=room.room_id + ) + if is_server_notices_room: # we found a room which our user shares with the system notice # user return room.room_id diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 046ad3a11c..9e5034b401 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -534,6 +534,32 @@ class RoomMemberWorkerStore(EventsWorkerStore): desc="get_local_users_in_room", ) + async def check_local_user_in_room(self, user_id: str, room_id: str) -> bool: + """ + Check whether a given local user is currently joined to the given room. + + Returns: + A boolean indicating whether the user is currently joined to the room + + Raises: + Exeption when called with a non-local user to this homeserver + """ + if not self.hs.is_mine_id(user_id): + raise Exception( + "Cannot call 'check_local_user_in_room' on " + "non-local user %s" % (user_id,), + ) + + ( + membership, + member_event_id, + ) = await self.get_local_current_membership_for_user_in_room( + user_id=user_id, + room_id=room_id, + ) + + return membership == Membership.JOIN + async def get_local_current_membership_for_user_in_room( self, user_id: str, room_id: str ) -> Tuple[Optional[str], Optional[str]]: diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index d589f07314..651f4f415d 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -999,7 +999,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): bundled_aggregations, ) - self._test_bundled_aggregations(RelationTypes.ANNOTATION, assert_annotations, 6) + self._test_bundled_aggregations(RelationTypes.ANNOTATION, assert_annotations, 7) def test_annotation_to_annotation(self) -> None: """Any relation to an annotation should be ignored.""" @@ -1035,7 +1035,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): bundled_aggregations, ) - self._test_bundled_aggregations(RelationTypes.REFERENCE, assert_annotations, 6) + self._test_bundled_aggregations(RelationTypes.REFERENCE, assert_annotations, 7) def test_thread(self) -> None: """ @@ -1080,21 +1080,21 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): # The "user" sent the root event and is making queries for the bundled # aggregations: they have participated. - self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 8) + self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 9) # The "user2" sent replies in the thread and is making queries for the # bundled aggregations: they have participated. # # Note that this re-uses some cached values, so the total number of # queries is much smaller. self._test_bundled_aggregations( - RelationTypes.THREAD, _gen_assert(True), 2, access_token=self.user2_token + RelationTypes.THREAD, _gen_assert(True), 3, access_token=self.user2_token ) # A user with no interactions with the thread: they have not participated. user3_id, user3_token = self._create_user("charlie") self.helper.join(self.room, user=user3_id, tok=user3_token) self._test_bundled_aggregations( - RelationTypes.THREAD, _gen_assert(False), 2, access_token=user3_token + RelationTypes.THREAD, _gen_assert(False), 3, access_token=user3_token ) def test_thread_with_bundled_aggregations_for_latest(self) -> None: @@ -1142,7 +1142,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): bundled_aggregations["latest_event"].get("unsigned"), ) - self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 8) + self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 9) def test_nested_thread(self) -> None: """ -- cgit 1.5.1 From 6302753012927b63feddc71dd287e2d3554707d4 Mon Sep 17 00:00:00 2001 From: reivilibre Date: Wed, 14 Sep 2022 15:53:18 +0000 Subject: Deduplicate `is_server_notices_room`. (#13780) --- changelog.d/13780.misc | 1 + synapse/handlers/message.py | 10 +--------- synapse/handlers/room_member.py | 10 +--------- synapse/storage/databases/main/roommember.py | 17 +++++++++++++++++ 4 files changed, 20 insertions(+), 18 deletions(-) create mode 100644 changelog.d/13780.misc (limited to 'synapse/handlers/message.py') diff --git a/changelog.d/13780.misc b/changelog.d/13780.misc new file mode 100644 index 0000000000..1bcac51cad --- /dev/null +++ b/changelog.d/13780.misc @@ -0,0 +1 @@ +Deduplicate `is_server_notices_room`. \ No newline at end of file diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 72157d5a36..e07cda133a 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -752,20 +752,12 @@ class EventCreationHandler: if builder.type == EventTypes.Member: membership = builder.content.get("membership", None) if membership == Membership.JOIN: - return await self._is_server_notices_room(builder.room_id) + return await self.store.is_server_notice_room(builder.room_id) elif membership == Membership.LEAVE: # the user is always allowed to leave (but not kick people) return builder.state_key == requester.user.to_string() return False - async def _is_server_notices_room(self, room_id: str) -> bool: - if self.config.servernotices.server_notices_mxid is None: - return False - is_server_notices_room = await self.store.check_local_user_in_room( - user_id=self.config.servernotices.server_notices_mxid, room_id=room_id - ) - return is_server_notices_room - async def assert_accepted_privacy_policy(self, requester: Requester) -> None: """Check if a user has accepted the privacy policy diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 5d4adf5bfd..8d01f4bf2b 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -837,7 +837,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): old_membership == Membership.INVITE and effective_membership_state == Membership.LEAVE ): - is_blocked = await self._is_server_notice_room(room_id) + is_blocked = await self.store.is_server_notice_room(room_id) if is_blocked: raise SynapseError( HTTPStatus.FORBIDDEN, @@ -1617,14 +1617,6 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): return False - async def _is_server_notice_room(self, room_id: str) -> bool: - if self._server_notices_mxid is None: - return False - is_server_notices_room = await self.store.check_local_user_in_room( - user_id=self._server_notices_mxid, room_id=room_id - ) - return is_server_notices_room - class RoomMemberMasterHandler(RoomMemberHandler): def __init__(self, hs: "HomeServer"): diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index fdb4684e12..a8d224602a 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -88,6 +88,8 @@ class RoomMemberWorkerStore(EventsWorkerStore): # at a time. Keyed by room_id. self._joined_host_linearizer = Linearizer("_JoinedHostsCache") + self._server_notices_mxid = hs.config.servernotices.server_notices_mxid + if ( self.hs.config.worker.run_background_tasks and self.hs.config.metrics.metrics_flags.known_servers @@ -504,6 +506,21 @@ class RoomMemberWorkerStore(EventsWorkerStore): return membership == Membership.JOIN + async def is_server_notice_room(self, room_id: str) -> bool: + """ + Determines whether the given room is a 'Server Notices' room, used for + sending server notices to a user. + + This is determined by seeing whether the server notices user is present + in the room. + """ + if self._server_notices_mxid is None: + return False + is_server_notices_room = await self.check_local_user_in_room( + user_id=self._server_notices_mxid, room_id=room_id + ) + return is_server_notices_room + async def get_local_current_membership_for_user_in_room( self, user_id: str, room_id: str ) -> Tuple[Optional[str], Optional[str]]: -- cgit 1.5.1 From a2cf66a94d5dfd9d6496ac3e48ec9a22f17be69a Mon Sep 17 00:00:00 2001 From: Shay Date: Wed, 28 Sep 2022 02:39:03 -0700 Subject: Prepatory work for batching events to send (#13487) This PR begins work on batching up events during the creation of a room. The PR splits out the creation and sending/persisting of the events. The first three events in the creation of the room-creating the room, joining the creator to the room, and the power levels event are sent sequentially, while the subsequent events are created and collected to be sent at the end of the function. This is currently done by appending them to a list and then iterating over the list to send, the next step (after this PR) would be to send and persist the collected events as a batch. --- changelog.d/13487.misc | 1 + synapse/handlers/message.py | 175 ++++++++++++++++++++++++++-------------- synapse/handlers/room.py | 155 ++++++++++++++++++++++++----------- synapse/state/__init__.py | 63 +++++++++++++++ tests/rest/client/test_rooms.py | 4 +- 5 files changed, 290 insertions(+), 108 deletions(-) create mode 100644 changelog.d/13487.misc (limited to 'synapse/handlers/message.py') diff --git a/changelog.d/13487.misc b/changelog.d/13487.misc new file mode 100644 index 0000000000..761adc8b05 --- /dev/null +++ b/changelog.d/13487.misc @@ -0,0 +1 @@ +Speed up creation of DM rooms. diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index e07cda133a..062f93bc67 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -63,6 +63,7 @@ from synapse.types import ( MutableStateMap, Requester, RoomAlias, + StateMap, StreamToken, UserID, create_requester, @@ -567,9 +568,17 @@ class EventCreationHandler: outlier: bool = False, historical: bool = False, depth: Optional[int] = None, + state_map: Optional[StateMap[str]] = None, + for_batch: bool = False, + current_state_group: Optional[int] = None, ) -> Tuple[EventBase, EventContext]: """ - Given a dict from a client, create a new event. + Given a dict from a client, create a new event. If bool for_batch is true, will + create an event using the prev_event_ids, and will create an event context for + the event using the parameters state_map and current_state_group, thus these parameters + must be provided in this case if for_batch is True. The subsequently created event + and context are suitable for being batched up and bulk persisted to the database + with other similarly created events. Creates an FrozenEvent object, filling out auth_events, prev_events, etc. @@ -612,16 +621,27 @@ class EventCreationHandler: outlier: Indicates whether the event is an `outlier`, i.e. if it's from an arbitrary point and floating in the DAG as opposed to being inline with the current DAG. + historical: Indicates whether the message is being inserted back in time around some existing events. This is used to skip a few checks and mark the event as backfilled. + depth: Override the depth used to order the event in the DAG. Should normally be set to None, which will cause the depth to be calculated based on the prev_events. + state_map: A state map of previously created events, used only when creating events + for batch persisting + + for_batch: whether the event is being created for batch persisting to the db + + current_state_group: the current state group, used only for creating events for + batch persisting + Raises: ResourceLimitError if server is blocked to some resource being exceeded + Returns: Tuple of created event, Context """ @@ -693,6 +713,9 @@ class EventCreationHandler: auth_event_ids=auth_event_ids, state_event_ids=state_event_ids, depth=depth, + state_map=state_map, + for_batch=for_batch, + current_state_group=current_state_group, ) # In an ideal world we wouldn't need the second part of this condition. However, @@ -707,10 +730,14 @@ class EventCreationHandler: # federation as well as those created locally. As of room v3, aliases events # can be created by users that are not in the room, therefore we have to # tolerate them in event_auth.check(). - prev_state_ids = await context.get_prev_state_ids( - StateFilter.from_types([(EventTypes.Member, None)]) - ) - prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender)) + if for_batch: + assert state_map is not None + prev_event_id = state_map.get((EventTypes.Member, event.sender)) + else: + prev_state_ids = await context.get_prev_state_ids( + StateFilter.from_types([(EventTypes.Member, None)]) + ) + prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender)) prev_event = ( await self.store.get_event(prev_event_id, allow_none=True) if prev_event_id @@ -1009,8 +1036,16 @@ class EventCreationHandler: auth_event_ids: Optional[List[str]] = None, state_event_ids: Optional[List[str]] = None, depth: Optional[int] = None, + state_map: Optional[StateMap[str]] = None, + for_batch: bool = False, + current_state_group: Optional[int] = None, ) -> Tuple[EventBase, EventContext]: - """Create a new event for a local client + """Create a new event for a local client. If bool for_batch is true, will + create an event using the prev_event_ids, and will create an event context for + the event using the parameters state_map and current_state_group, thus these parameters + must be provided in this case if for_batch is True. The subsequently created event + and context are suitable for being batched up and bulk persisted to the database + with other similarly created events. Args: builder: @@ -1043,6 +1078,14 @@ class EventCreationHandler: Should normally be set to None, which will cause the depth to be calculated based on the prev_events. + state_map: A state map of previously created events, used only when creating events + for batch persisting + + for_batch: whether the event is being created for batch persisting to the db + + current_state_group: the current state group, used only for creating events for + batch persisting + Returns: Tuple of created event, context """ @@ -1095,64 +1138,76 @@ class EventCreationHandler: builder.type == EventTypes.Create or prev_event_ids ), "Attempting to create a non-m.room.create event with no prev_events" - event = await builder.build( - prev_event_ids=prev_event_ids, - auth_event_ids=auth_event_ids, - depth=depth, - ) + if for_batch: + assert prev_event_ids is not None + assert state_map is not None + assert current_state_group is not None + auth_ids = self._event_auth_handler.compute_auth_events(builder, state_map) + event = await builder.build( + prev_event_ids=prev_event_ids, auth_event_ids=auth_ids, depth=depth + ) + context = await self.state.compute_event_context_for_batched( + event, state_map, current_state_group + ) + else: + event = await builder.build( + prev_event_ids=prev_event_ids, + auth_event_ids=auth_event_ids, + depth=depth, + ) - # Pass on the outlier property from the builder to the event - # after it is created - if builder.internal_metadata.outlier: - event.internal_metadata.outlier = True - context = EventContext.for_outlier(self._storage_controllers) - elif ( - event.type == EventTypes.MSC2716_INSERTION - and state_event_ids - and builder.internal_metadata.is_historical() - ): - # Add explicit state to the insertion event so it has state to derive - # from even though it's floating with no `prev_events`. The rest of - # the batch can derive from this state and state_group. - # - # TODO(faster_joins): figure out how this works, and make sure that the - # old state is complete. - # https://github.com/matrix-org/synapse/issues/13003 - metadata = await self.store.get_metadata_for_events(state_event_ids) - - state_map_for_event: MutableStateMap[str] = {} - for state_id in state_event_ids: - data = metadata.get(state_id) - if data is None: - # We're trying to persist a new historical batch of events - # with the given state, e.g. via - # `RoomBatchSendEventRestServlet`. The state can be inferred - # by Synapse or set directly by the client. - # - # Either way, we should have persisted all the state before - # getting here. - raise Exception( - f"State event {state_id} not found in DB," - " Synapse should have persisted it before using it." - ) + # Pass on the outlier property from the builder to the event + # after it is created + if builder.internal_metadata.outlier: + event.internal_metadata.outlier = True + context = EventContext.for_outlier(self._storage_controllers) + elif ( + event.type == EventTypes.MSC2716_INSERTION + and state_event_ids + and builder.internal_metadata.is_historical() + ): + # Add explicit state to the insertion event so it has state to derive + # from even though it's floating with no `prev_events`. The rest of + # the batch can derive from this state and state_group. + # + # TODO(faster_joins): figure out how this works, and make sure that the + # old state is complete. + # https://github.com/matrix-org/synapse/issues/13003 + metadata = await self.store.get_metadata_for_events(state_event_ids) + + state_map_for_event: MutableStateMap[str] = {} + for state_id in state_event_ids: + data = metadata.get(state_id) + if data is None: + # We're trying to persist a new historical batch of events + # with the given state, e.g. via + # `RoomBatchSendEventRestServlet`. The state can be inferred + # by Synapse or set directly by the client. + # + # Either way, we should have persisted all the state before + # getting here. + raise Exception( + f"State event {state_id} not found in DB," + " Synapse should have persisted it before using it." + ) - if data.state_key is None: - raise Exception( - f"Trying to set non-state event {state_id} as state" - ) + if data.state_key is None: + raise Exception( + f"Trying to set non-state event {state_id} as state" + ) - state_map_for_event[(data.event_type, data.state_key)] = state_id + state_map_for_event[(data.event_type, data.state_key)] = state_id - context = await self.state.compute_event_context( - event, - state_ids_before_event=state_map_for_event, - # TODO(faster_joins): check how MSC2716 works and whether we can have - # partial state here - # https://github.com/matrix-org/synapse/issues/13003 - partial_state=False, - ) - else: - context = await self.state.compute_event_context(event) + context = await self.state.compute_event_context( + event, + state_ids_before_event=state_map_for_event, + # TODO(faster_joins): check how MSC2716 works and whether we can have + # partial state here + # https://github.com/matrix-org/synapse/issues/13003 + partial_state=False, + ) + else: + context = await self.state.compute_event_context(event) if requester: context.app_service = requester.app_service diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 33e9a87002..09a1a82e6c 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -716,7 +716,7 @@ class RoomCreationHandler: if ( self._server_notices_mxid is not None - and requester.user.to_string() == self._server_notices_mxid + and user_id == self._server_notices_mxid ): # allow the server notices mxid to create rooms is_requester_admin = True @@ -1042,7 +1042,9 @@ class RoomCreationHandler: creator_join_profile: Optional[JsonDict] = None, ratelimit: bool = True, ) -> Tuple[int, str, int]: - """Sends the initial events into a new room. + """Sends the initial events into a new room. Sends the room creation, membership, + and power level events into the room sequentially, then creates and batches up the + rest of the events to persist as a batch to the DB. `power_level_content_override` doesn't apply when initial state has power level state event content. @@ -1053,13 +1055,21 @@ class RoomCreationHandler: """ creator_id = creator.user.to_string() - event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""} - depth = 1 + # the last event sent/persisted to the db last_sent_event_id: Optional[str] = None - - def create(etype: str, content: JsonDict, **kwargs: Any) -> JsonDict: + # the most recently created event + prev_event: List[str] = [] + # a map of event types, state keys -> event_ids. We collect these mappings this as events are + # created (but not persisted to the db) to determine state for future created events + # (as this info can't be pulled from the db) + state_map: MutableStateMap[str] = {} + # current_state_group of last event created. Used for computing event context of + # events to be batched + current_state_group = None + + def create_event_dict(etype: str, content: JsonDict, **kwargs: Any) -> JsonDict: e = {"type": etype, "content": content} e.update(event_keys) @@ -1067,32 +1077,52 @@ class RoomCreationHandler: return e - async def send(etype: str, content: JsonDict, **kwargs: Any) -> int: - nonlocal last_sent_event_id + async def create_event( + etype: str, + content: JsonDict, + for_batch: bool, + **kwargs: Any, + ) -> Tuple[EventBase, synapse.events.snapshot.EventContext]: nonlocal depth + nonlocal prev_event - event = create(etype, content, **kwargs) - logger.debug("Sending %s in new room", etype) - # Allow these events to be sent even if the user is shadow-banned to - # allow the room creation to complete. - ( - sent_event, - last_stream_id, - ) = await self.event_creation_handler.create_and_send_nonmember_event( + event_dict = create_event_dict(etype, content, **kwargs) + + new_event, new_context = await self.event_creation_handler.create_event( creator, - event, + event_dict, + prev_event_ids=prev_event, + depth=depth, + state_map=state_map, + for_batch=for_batch, + current_state_group=current_state_group, + ) + depth += 1 + prev_event = [new_event.event_id] + state_map[(new_event.type, new_event.state_key)] = new_event.event_id + + return new_event, new_context + + async def send( + event: EventBase, + context: synapse.events.snapshot.EventContext, + creator: Requester, + ) -> int: + nonlocal last_sent_event_id + + ev = await self.event_creation_handler.handle_new_client_event( + requester=creator, + event=event, + context=context, ratelimit=False, ignore_shadow_ban=True, - # Note: we don't pass state_event_ids here because this triggers - # an additional query per event to look them up from the events table. - prev_event_ids=[last_sent_event_id] if last_sent_event_id else [], - depth=depth, ) - last_sent_event_id = sent_event.event_id - depth += 1 + last_sent_event_id = ev.event_id - return last_stream_id + # we know it was persisted, so must have a stream ordering + assert ev.internal_metadata.stream_ordering + return ev.internal_metadata.stream_ordering try: config = self._presets_dict[preset_config] @@ -1102,9 +1132,13 @@ class RoomCreationHandler: ) creation_content.update({"creator": creator_id}) - await send(etype=EventTypes.Create, content=creation_content) + creation_event, creation_context = await create_event( + EventTypes.Create, creation_content, False + ) logger.debug("Sending %s in new room", EventTypes.Member) + await send(creation_event, creation_context, creator) + # Room create event must exist at this point assert last_sent_event_id is not None member_event_id, _ = await self.room_member_handler.update_membership( @@ -1119,14 +1153,22 @@ class RoomCreationHandler: depth=depth, ) last_sent_event_id = member_event_id + prev_event = [member_event_id] + + # update the depth and state map here as the membership event has been created + # through a different code path + depth += 1 + state_map[(EventTypes.Member, creator.user.to_string())] = member_event_id # We treat the power levels override specially as this needs to be one # of the first events that get sent into a room. pl_content = initial_state.pop((EventTypes.PowerLevels, ""), None) if pl_content is not None: - last_sent_stream_id = await send( - etype=EventTypes.PowerLevels, content=pl_content + power_event, power_context = await create_event( + EventTypes.PowerLevels, pl_content, False ) + current_state_group = power_context._state_group + last_sent_stream_id = await send(power_event, power_context, creator) else: power_level_content: JsonDict = { "users": {creator_id: 100}, @@ -1169,47 +1211,68 @@ class RoomCreationHandler: # apply those. if power_level_content_override: power_level_content.update(power_level_content_override) - - last_sent_stream_id = await send( - etype=EventTypes.PowerLevels, content=power_level_content + pl_event, pl_context = await create_event( + EventTypes.PowerLevels, + power_level_content, + False, ) + current_state_group = pl_context._state_group + last_sent_stream_id = await send(pl_event, pl_context, creator) + events_to_send = [] if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state: - last_sent_stream_id = await send( - etype=EventTypes.CanonicalAlias, - content={"alias": room_alias.to_string()}, + room_alias_event, room_alias_context = await create_event( + EventTypes.CanonicalAlias, {"alias": room_alias.to_string()}, True ) + current_state_group = room_alias_context._state_group + events_to_send.append((room_alias_event, room_alias_context)) if (EventTypes.JoinRules, "") not in initial_state: - last_sent_stream_id = await send( - etype=EventTypes.JoinRules, content={"join_rule": config["join_rules"]} + join_rules_event, join_rules_context = await create_event( + EventTypes.JoinRules, + {"join_rule": config["join_rules"]}, + True, ) + current_state_group = join_rules_context._state_group + events_to_send.append((join_rules_event, join_rules_context)) if (EventTypes.RoomHistoryVisibility, "") not in initial_state: - last_sent_stream_id = await send( - etype=EventTypes.RoomHistoryVisibility, - content={"history_visibility": config["history_visibility"]}, + visibility_event, visibility_context = await create_event( + EventTypes.RoomHistoryVisibility, + {"history_visibility": config["history_visibility"]}, + True, ) + current_state_group = visibility_context._state_group + events_to_send.append((visibility_event, visibility_context)) if config["guest_can_join"]: if (EventTypes.GuestAccess, "") not in initial_state: - last_sent_stream_id = await send( - etype=EventTypes.GuestAccess, - content={EventContentFields.GUEST_ACCESS: GuestAccess.CAN_JOIN}, + guest_access_event, guest_access_context = await create_event( + EventTypes.GuestAccess, + {EventContentFields.GUEST_ACCESS: GuestAccess.CAN_JOIN}, + True, ) + current_state_group = guest_access_context._state_group + events_to_send.append((guest_access_event, guest_access_context)) for (etype, state_key), content in initial_state.items(): - last_sent_stream_id = await send( - etype=etype, state_key=state_key, content=content + event, context = await create_event( + etype, content, True, state_key=state_key ) + current_state_group = context._state_group + events_to_send.append((event, context)) if config["encrypted"]: - last_sent_stream_id = await send( - etype=EventTypes.RoomEncryption, + encryption_event, encryption_context = await create_event( + EventTypes.RoomEncryption, + {"algorithm": RoomEncryptionAlgorithms.DEFAULT}, + True, state_key="", - content={"algorithm": RoomEncryptionAlgorithms.DEFAULT}, ) + events_to_send.append((encryption_event, encryption_context)) + for event, context in events_to_send: + last_sent_stream_id = await send(event, context, creator) return last_sent_stream_id, last_sent_event_id, depth def _generate_room_id(self) -> str: diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 3787d35b24..6f3dd0463e 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -420,6 +420,69 @@ class StateHandler: partial_state=partial_state, ) + async def compute_event_context_for_batched( + self, + event: EventBase, + state_ids_before_event: StateMap[str], + current_state_group: int, + ) -> EventContext: + """ + Generate an event context for an event that has not yet been persisted to the + database. Intended for use with events that are created to be persisted in a batch. + Args: + event: the event the context is being computed for + state_ids_before_event: a state map consisting of the state ids of the events + created prior to this event. + current_state_group: the current state group before the event. + """ + state_group_before_event_prev_group = None + deltas_to_state_group_before_event = None + + state_group_before_event = current_state_group + + # if the event is not state, we are set + if not event.is_state(): + return EventContext.with_state( + storage=self._storage_controllers, + state_group_before_event=state_group_before_event, + state_group=state_group_before_event, + state_delta_due_to_event={}, + prev_group=state_group_before_event_prev_group, + delta_ids=deltas_to_state_group_before_event, + partial_state=False, + ) + + # otherwise, we'll need to create a new state group for after the event + key = (event.type, event.state_key) + + if state_ids_before_event is not None: + replaces = state_ids_before_event.get(key) + + if replaces and replaces != event.event_id: + event.unsigned["replaces_state"] = replaces + + delta_ids = {key: event.event_id} + + state_group_after_event = ( + await self._state_storage_controller.store_state_group( + event.event_id, + event.room_id, + prev_group=state_group_before_event, + delta_ids=delta_ids, + current_state_ids=None, + ) + ) + + return EventContext.with_state( + storage=self._storage_controllers, + state_group=state_group_after_event, + state_group_before_event=state_group_before_event, + state_delta_due_to_event=delta_ids, + prev_group=state_group_before_event, + delta_ids=delta_ids, + partial_state=False, + ) + @measure_func() async def resolve_state_groups_for_events( self, room_id: str, event_ids: Collection[str], await_full_state: bool = True diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index c7eb88d33f..e281aef779 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -710,7 +710,7 @@ class RoomsCreateTestCase(RoomBase): self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) assert channel.resource_usage is not None - self.assertEqual(44, channel.resource_usage.db_txn_count) + self.assertEqual(35, channel.resource_usage.db_txn_count) def test_post_room_initial_state(self) -> None: # POST with initial_state config key, expect new room id @@ -723,7 +723,7 @@ class RoomsCreateTestCase(RoomBase): self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) assert channel.resource_usage is not None - self.assertEqual(50, channel.resource_usage.db_txn_count) + self.assertEqual(38, channel.resource_usage.db_txn_count) def test_post_room_visibility_key(self) -> None: # POST with visibility config key, expect new room id -- cgit 1.5.1 From 8ab16a92edd675453c78cfd9974081e374b0f998 Mon Sep 17 00:00:00 2001 From: Shay Date: Wed, 28 Sep 2022 03:11:48 -0700 Subject: Persist CreateRoom events to DB in a batch (#13800) --- changelog.d/13800.misc | 1 + synapse/handlers/message.py | 663 +++++++++++++++++--------------- synapse/handlers/room.py | 21 +- synapse/handlers/room_batch.py | 3 +- synapse/handlers/room_member.py | 11 +- synapse/replication/http/__init__.py | 2 + synapse/replication/http/send_event.py | 4 +- synapse/replication/http/send_events.py | 171 ++++++++ tests/handlers/test_message.py | 10 +- tests/handlers/test_register.py | 4 +- tests/storage/test_event_chain.py | 8 +- tests/unittest.py | 4 +- 12 files changed, 563 insertions(+), 339 deletions(-) create mode 100644 changelog.d/13800.misc create mode 100644 synapse/replication/http/send_events.py (limited to 'synapse/handlers/message.py') diff --git a/changelog.d/13800.misc b/changelog.d/13800.misc new file mode 100644 index 0000000000..761adc8b05 --- /dev/null +++ b/changelog.d/13800.misc @@ -0,0 +1 @@ +Speed up creation of DM rooms. diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 062f93bc67..00e7645ba5 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -56,11 +56,13 @@ from synapse.logging import opentracing from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.http.send_event import ReplicationSendEventRestServlet +from synapse.replication.http.send_events import ReplicationSendEventsRestServlet from synapse.storage.databases.main.events import PartialStateConflictError from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.state import StateFilter from synapse.types import ( MutableStateMap, + PersistedEventPosition, Requester, RoomAlias, StateMap, @@ -493,6 +495,7 @@ class EventCreationHandler: self.membership_types_to_include_profile_data_in.add(Membership.INVITE) self.send_event = ReplicationSendEventRestServlet.make_client(hs) + self.send_events = ReplicationSendEventsRestServlet.make_client(hs) self.request_ratelimiter = hs.get_request_ratelimiter() @@ -1016,8 +1019,7 @@ class EventCreationHandler: ev = await self.handle_new_client_event( requester=requester, - event=event, - context=context, + events_and_context=[(event, context)], ratelimit=ratelimit, ignore_shadow_ban=ignore_shadow_ban, ) @@ -1293,13 +1295,13 @@ class EventCreationHandler: async def handle_new_client_event( self, requester: Requester, - event: EventBase, - context: EventContext, + events_and_context: List[Tuple[EventBase, EventContext]], ratelimit: bool = True, extra_users: Optional[List[UserID]] = None, ignore_shadow_ban: bool = False, ) -> EventBase: - """Processes a new event. + """Processes new events. Please note that if batch persisting events, an error in + handling any one of these events will result in all of the events being dropped. This includes deduplicating, checking auth, persisting, notifying users, sending to remote servers, etc. @@ -1309,8 +1311,7 @@ class EventCreationHandler: Args: requester - event - context + events_and_context: A list of one or more tuples of event, context to be persisted ratelimit extra_users: Any extra users to notify about event @@ -1328,62 +1329,63 @@ class EventCreationHandler: """ extra_users = extra_users or [] - # we don't apply shadow-banning to membership events here. Invites are blocked - # higher up the stack, and we allow shadow-banned users to send join and leave - # events as normal. - if ( - event.type != EventTypes.Member - and not ignore_shadow_ban - and requester.shadow_banned - ): - # We randomly sleep a bit just to annoy the requester. - await self.clock.sleep(random.randint(1, 10)) - raise ShadowBanError() + for event, context in events_and_context: + # we don't apply shadow-banning to membership events here. Invites are blocked + # higher up the stack, and we allow shadow-banned users to send join and leave + # events as normal. + if ( + event.type != EventTypes.Member + and not ignore_shadow_ban + and requester.shadow_banned + ): + # We randomly sleep a bit just to annoy the requester. + await self.clock.sleep(random.randint(1, 10)) + raise ShadowBanError() - if event.is_state(): - prev_event = await self.deduplicate_state_event(event, context) - if prev_event is not None: - logger.info( - "Not bothering to persist state event %s duplicated by %s", - event.event_id, - prev_event.event_id, - ) - return prev_event + if event.is_state(): + prev_event = await self.deduplicate_state_event(event, context) + if prev_event is not None: + logger.info( + "Not bothering to persist state event %s duplicated by %s", + event.event_id, + prev_event.event_id, + ) + return prev_event - if event.internal_metadata.is_out_of_band_membership(): - # the only sort of out-of-band-membership events we expect to see here are - # invite rejections and rescinded knocks that we have generated ourselves. - assert event.type == EventTypes.Member - assert event.content["membership"] == Membership.LEAVE - else: - try: - validate_event_for_room_version(event) - await self._event_auth_handler.check_auth_rules_from_context( - event, context - ) - except AuthError as err: - logger.warning("Denying new event %r because %s", event, err) - raise err + if event.internal_metadata.is_out_of_band_membership(): + # the only sort of out-of-band-membership events we expect to see here are + # invite rejections and rescinded knocks that we have generated ourselves. + assert event.type == EventTypes.Member + assert event.content["membership"] == Membership.LEAVE + else: + try: + validate_event_for_room_version(event) + await self._event_auth_handler.check_auth_rules_from_context( + event, context + ) + except AuthError as err: + logger.warning("Denying new event %r because %s", event, err) + raise err - # Ensure that we can round trip before trying to persist in db - try: - dump = json_encoder.encode(event.content) - json_decoder.decode(dump) - except Exception: - logger.exception("Failed to encode content: %r", event.content) - raise + # Ensure that we can round trip before trying to persist in db + try: + dump = json_encoder.encode(event.content) + json_decoder.decode(dump) + except Exception: + logger.exception("Failed to encode content: %r", event.content) + raise # We now persist the event (and update the cache in parallel, since we # don't want to block on it). + event, context = events_and_context[0] try: result, _ = await make_deferred_yieldable( gather_results( ( run_in_background( - self._persist_event, + self._persist_events, requester=requester, - event=event, - context=context, + events_and_context=events_and_context, ratelimit=ratelimit, extra_users=extra_users, ), @@ -1407,45 +1409,47 @@ class EventCreationHandler: return result - async def _persist_event( + async def _persist_events( self, requester: Requester, - event: EventBase, - context: EventContext, + events_and_context: List[Tuple[EventBase, EventContext]], ratelimit: bool = True, extra_users: Optional[List[UserID]] = None, ) -> EventBase: - """Actually persists the event. Should only be called by + """Actually persists new events. Should only be called by `handle_new_client_event`, and see its docstring for documentation of - the arguments. + the arguments. Please note that if batch persisting events, an error in + handling any one of these events will result in all of the events being dropped. PartialStateConflictError: if attempting to persist a partial state event in a room that has been un-partial stated. """ - # Skip push notification actions for historical messages - # because we don't want to notify people about old history back in time. - # The historical messages also do not have the proper `context.current_state_ids` - # and `state_groups` because they have `prev_events` that aren't persisted yet - # (historical messages persisted in reverse-chronological order). - if not event.internal_metadata.is_historical(): - with opentracing.start_active_span("calculate_push_actions"): - await self._bulk_push_rule_evaluator.action_for_event_by_user( - event, context - ) + for event, context in events_and_context: + # Skip push notification actions for historical messages + # because we don't want to notify people about old history back in time. + # The historical messages also do not have the proper `context.current_state_ids` + # and `state_groups` because they have `prev_events` that aren't persisted yet + # (historical messages persisted in reverse-chronological order). + if not event.internal_metadata.is_historical(): + with opentracing.start_active_span("calculate_push_actions"): + await self._bulk_push_rule_evaluator.action_for_event_by_user( + event, context + ) try: # If we're a worker we need to hit out to the master. - writer_instance = self._events_shard_config.get_instance(event.room_id) + first_event, _ = events_and_context[0] + writer_instance = self._events_shard_config.get_instance( + first_event.room_id + ) if writer_instance != self._instance_name: try: - result = await self.send_event( + result = await self.send_events( instance_name=writer_instance, - event_id=event.event_id, + events_and_context=events_and_context, store=self.store, requester=requester, - event=event, - context=context, ratelimit=ratelimit, extra_users=extra_users, ) @@ -1455,6 +1459,11 @@ class EventCreationHandler: raise stream_id = result["stream_id"] event_id = result["event_id"] + + # If we batch persisted events we return the last persisted event, otherwise + # we return the one event that was persisted + event, _ = events_and_context[-1] + if event_id != event.event_id: # If we get a different event back then it means that its # been de-duplicated, so we replace the given event with the @@ -1467,15 +1476,19 @@ class EventCreationHandler: event.internal_metadata.stream_ordering = stream_id return event - event = await self.persist_and_notify_client_event( - requester, event, context, ratelimit=ratelimit, extra_users=extra_users + event = await self.persist_and_notify_client_events( + requester, + events_and_context, + ratelimit=ratelimit, + extra_users=extra_users, ) return event except Exception: - # Ensure that we actually remove the entries in the push actions - # staging area, if we calculated them. - await self.store.remove_push_actions_from_staging(event.event_id) + for event, _ in events_and_context: + # Ensure that we actually remove the entries in the push actions + # staging area, if we calculated them. + await self.store.remove_push_actions_from_staging(event.event_id) raise async def cache_joined_hosts_for_event( @@ -1569,23 +1582,26 @@ class EventCreationHandler: Codes.BAD_ALIAS, ) - async def persist_and_notify_client_event( + async def persist_and_notify_client_events( self, requester: Requester, - event: EventBase, - context: EventContext, + events_and_context: List[Tuple[EventBase, EventContext]], ratelimit: bool = True, extra_users: Optional[List[UserID]] = None, ) -> EventBase: - """Called when we have fully built the event, have already - calculated the push actions for the event, and checked auth. + """Called when we have fully built the events, have already + calculated the push actions for the events, and checked auth. This should only be run on the instance in charge of persisting events. + Please note that if batch persisting events, an error in + handling any one of these events will result in all of the events being dropped. + Returns: - The persisted event. This may be different than the given event if - it was de-duplicated (e.g. because we had already persisted an - event with the same transaction ID.) + The persisted event, if one event is passed in, or the last event in the + list in the case of batch persisting. If only one event was persisted, the + returned event may be different than the given event if it was de-duplicated + (e.g. because we had already persisted an event with the same transaction ID.) Raises: PartialStateConflictError: if attempting to persist a partial state event in @@ -1593,277 +1609,297 @@ class EventCreationHandler: """ extra_users = extra_users or [] - assert self._storage_controllers.persistence is not None - assert self._events_shard_config.should_handle( - self._instance_name, event.room_id - ) + for event, context in events_and_context: + assert self._events_shard_config.should_handle( + self._instance_name, event.room_id + ) - if ratelimit: - # We check if this is a room admin redacting an event so that we - # can apply different ratelimiting. We do this by simply checking - # it's not a self-redaction (to avoid having to look up whether the - # user is actually admin or not). - is_admin_redaction = False - if event.type == EventTypes.Redaction: - assert event.redacts is not None + if ratelimit: + # We check if this is a room admin redacting an event so that we + # can apply different ratelimiting. We do this by simply checking + # it's not a self-redaction (to avoid having to look up whether the + # user is actually admin or not). + is_admin_redaction = False + if event.type == EventTypes.Redaction: + assert event.redacts is not None + + original_event = await self.store.get_event( + event.redacts, + redact_behaviour=EventRedactBehaviour.as_is, + get_prev_content=False, + allow_rejected=False, + allow_none=True, + ) - original_event = await self.store.get_event( - event.redacts, - redact_behaviour=EventRedactBehaviour.as_is, - get_prev_content=False, - allow_rejected=False, - allow_none=True, + is_admin_redaction = bool( + original_event and event.sender != original_event.sender + ) + + await self.request_ratelimiter.ratelimit( + requester, is_admin_redaction=is_admin_redaction ) - is_admin_redaction = bool( - original_event and event.sender != original_event.sender + # run checks/actions on event based on type + if event.type == EventTypes.Member and event.membership == Membership.JOIN: + ( + current_membership, + _, + ) = await self.store.get_local_current_membership_for_user_in_room( + event.state_key, event.room_id ) + if current_membership != Membership.JOIN: + self._notifier.notify_user_joined_room( + event.event_id, event.room_id + ) - await self.request_ratelimiter.ratelimit( - requester, is_admin_redaction=is_admin_redaction - ) + await self._maybe_kick_guest_users(event, context) - if event.type == EventTypes.Member and event.membership == Membership.JOIN: - ( - current_membership, - _, - ) = await self.store.get_local_current_membership_for_user_in_room( - event.state_key, event.room_id - ) - if current_membership != Membership.JOIN: - self._notifier.notify_user_joined_room(event.event_id, event.room_id) + if event.type == EventTypes.CanonicalAlias: + # Validate a newly added alias or newly added alt_aliases. - await self._maybe_kick_guest_users(event, context) + original_alias = None + original_alt_aliases: object = [] - if event.type == EventTypes.CanonicalAlias: - # Validate a newly added alias or newly added alt_aliases. + original_event_id = event.unsigned.get("replaces_state") + if original_event_id: + original_alias_event = await self.store.get_event(original_event_id) - original_alias = None - original_alt_aliases: object = [] + if original_alias_event: + original_alias = original_alias_event.content.get("alias", None) + original_alt_aliases = original_alias_event.content.get( + "alt_aliases", [] + ) - original_event_id = event.unsigned.get("replaces_state") - if original_event_id: - original_event = await self.store.get_event(original_event_id) + # Check the alias is currently valid (if it has changed). + room_alias_str = event.content.get("alias", None) + directory_handler = self.hs.get_directory_handler() + if room_alias_str and room_alias_str != original_alias: + await self._validate_canonical_alias( + directory_handler, room_alias_str, event.room_id + ) - if original_event: - original_alias = original_event.content.get("alias", None) - original_alt_aliases = original_event.content.get("alt_aliases", []) - - # Check the alias is currently valid (if it has changed). - room_alias_str = event.content.get("alias", None) - directory_handler = self.hs.get_directory_handler() - if room_alias_str and room_alias_str != original_alias: - await self._validate_canonical_alias( - directory_handler, room_alias_str, event.room_id - ) + # Check that alt_aliases is the proper form. + alt_aliases = event.content.get("alt_aliases", []) + if not isinstance(alt_aliases, (list, tuple)): + raise SynapseError( + 400, + "The alt_aliases property must be a list.", + Codes.INVALID_PARAM, + ) - # Check that alt_aliases is the proper form. - alt_aliases = event.content.get("alt_aliases", []) - if not isinstance(alt_aliases, (list, tuple)): - raise SynapseError( - 400, "The alt_aliases property must be a list.", Codes.INVALID_PARAM - ) + # If the old version of alt_aliases is of an unknown form, + # completely replace it. + if not isinstance(original_alt_aliases, (list, tuple)): + # TODO: check that the original_alt_aliases' entries are all strings + original_alt_aliases = [] + + # Check that each alias is currently valid. + new_alt_aliases = set(alt_aliases) - set(original_alt_aliases) + if new_alt_aliases: + for alias_str in new_alt_aliases: + await self._validate_canonical_alias( + directory_handler, alias_str, event.room_id + ) - # If the old version of alt_aliases is of an unknown form, - # completely replace it. - if not isinstance(original_alt_aliases, (list, tuple)): - # TODO: check that the original_alt_aliases' entries are all strings - original_alt_aliases = [] + federation_handler = self.hs.get_federation_handler() - # Check that each alias is currently valid. - new_alt_aliases = set(alt_aliases) - set(original_alt_aliases) - if new_alt_aliases: - for alias_str in new_alt_aliases: - await self._validate_canonical_alias( - directory_handler, alias_str, event.room_id + if event.type == EventTypes.Member: + if event.content["membership"] == Membership.INVITE: + event.unsigned[ + "invite_room_state" + ] = await self.store.get_stripped_room_state_from_event_context( + context, + self.room_prejoin_state_types, + membership_user_id=event.sender, ) - federation_handler = self.hs.get_federation_handler() + invitee = UserID.from_string(event.state_key) + if not self.hs.is_mine(invitee): + # TODO: Can we add signature from remote server in a nicer + # way? If we have been invited by a remote server, we need + # to get them to sign the event. - if event.type == EventTypes.Member: - if event.content["membership"] == Membership.INVITE: - event.unsigned[ - "invite_room_state" - ] = await self.store.get_stripped_room_state_from_event_context( - context, - self.room_prejoin_state_types, - membership_user_id=event.sender, - ) + returned_invite = await federation_handler.send_invite( + invitee.domain, event + ) + event.unsigned.pop("room_state", None) - invitee = UserID.from_string(event.state_key) - if not self.hs.is_mine(invitee): - # TODO: Can we add signature from remote server in a nicer - # way? If we have been invited by a remote server, we need - # to get them to sign the event. + # TODO: Make sure the signatures actually are correct. + event.signatures.update(returned_invite.signatures) - returned_invite = await federation_handler.send_invite( - invitee.domain, event + if event.content["membership"] == Membership.KNOCK: + event.unsigned[ + "knock_room_state" + ] = await self.store.get_stripped_room_state_from_event_context( + context, + self.room_prejoin_state_types, ) - event.unsigned.pop("room_state", None) - # TODO: Make sure the signatures actually are correct. - event.signatures.update(returned_invite.signatures) + if event.type == EventTypes.Redaction: + assert event.redacts is not None - if event.content["membership"] == Membership.KNOCK: - event.unsigned[ - "knock_room_state" - ] = await self.store.get_stripped_room_state_from_event_context( - context, - self.room_prejoin_state_types, + original_event = await self.store.get_event( + event.redacts, + redact_behaviour=EventRedactBehaviour.as_is, + get_prev_content=False, + allow_rejected=False, + allow_none=True, ) - if event.type == EventTypes.Redaction: - assert event.redacts is not None + room_version = await self.store.get_room_version_id(event.room_id) + room_version_obj = KNOWN_ROOM_VERSIONS[room_version] - original_event = await self.store.get_event( - event.redacts, - redact_behaviour=EventRedactBehaviour.as_is, - get_prev_content=False, - allow_rejected=False, - allow_none=True, - ) + # we can make some additional checks now if we have the original event. + if original_event: + if original_event.type == EventTypes.Create: + raise AuthError(403, "Redacting create events is not permitted") - room_version = await self.store.get_room_version_id(event.room_id) - room_version_obj = KNOWN_ROOM_VERSIONS[room_version] - - # we can make some additional checks now if we have the original event. - if original_event: - if original_event.type == EventTypes.Create: - raise AuthError(403, "Redacting create events is not permitted") - - if original_event.room_id != event.room_id: - raise SynapseError(400, "Cannot redact event from a different room") - - if original_event.type == EventTypes.ServerACL: - raise AuthError(403, "Redacting server ACL events is not permitted") - - # Add a little safety stop-gap to prevent people from trying to - # redact MSC2716 related events when they're in a room version - # which does not support it yet. We allow people to use MSC2716 - # events in existing room versions but only from the room - # creator since it does not require any changes to the auth - # rules and in effect, the redaction algorithm . In the - # supported room version, we add the `historical` power level to - # auth the MSC2716 related events and adjust the redaction - # algorthim to keep the `historical` field around (redacting an - # event should only strip fields which don't affect the - # structural protocol level). - is_msc2716_event = ( - original_event.type == EventTypes.MSC2716_INSERTION - or original_event.type == EventTypes.MSC2716_BATCH - or original_event.type == EventTypes.MSC2716_MARKER - ) - if not room_version_obj.msc2716_historical and is_msc2716_event: - raise AuthError( - 403, - "Redacting MSC2716 events is not supported in this room version", - ) + if original_event.room_id != event.room_id: + raise SynapseError( + 400, "Cannot redact event from a different room" + ) - event_types = event_auth.auth_types_for_event(event.room_version, event) - prev_state_ids = await context.get_prev_state_ids( - StateFilter.from_types(event_types) - ) + if original_event.type == EventTypes.ServerACL: + raise AuthError( + 403, "Redacting server ACL events is not permitted" + ) - auth_events_ids = self._event_auth_handler.compute_auth_events( - event, prev_state_ids, for_verification=True - ) - auth_events_map = await self.store.get_events(auth_events_ids) - auth_events = {(e.type, e.state_key): e for e in auth_events_map.values()} + # Add a little safety stop-gap to prevent people from trying to + # redact MSC2716 related events when they're in a room version + # which does not support it yet. We allow people to use MSC2716 + # events in existing room versions but only from the room + # creator since it does not require any changes to the auth + # rules and in effect, the redaction algorithm . In the + # supported room version, we add the `historical` power level to + # auth the MSC2716 related events and adjust the redaction + # algorthim to keep the `historical` field around (redacting an + # event should only strip fields which don't affect the + # structural protocol level). + is_msc2716_event = ( + original_event.type == EventTypes.MSC2716_INSERTION + or original_event.type == EventTypes.MSC2716_BATCH + or original_event.type == EventTypes.MSC2716_MARKER + ) + if not room_version_obj.msc2716_historical and is_msc2716_event: + raise AuthError( + 403, + "Redacting MSC2716 events is not supported in this room version", + ) - if event_auth.check_redaction( - room_version_obj, event, auth_events=auth_events - ): - # this user doesn't have 'redact' rights, so we need to do some more - # checks on the original event. Let's start by checking the original - # event exists. - if not original_event: - raise NotFoundError("Could not find event %s" % (event.redacts,)) - - if event.user_id != original_event.user_id: - raise AuthError(403, "You don't have permission to redact events") - - # all the checks are done. - event.internal_metadata.recheck_redaction = False - - if event.type == EventTypes.Create: - prev_state_ids = await context.get_prev_state_ids() - if prev_state_ids: - raise AuthError(403, "Changing the room create event is forbidden") - - if event.type == EventTypes.MSC2716_INSERTION: - room_version = await self.store.get_room_version_id(event.room_id) - room_version_obj = KNOWN_ROOM_VERSIONS[room_version] - - create_event = await self.store.get_create_event_for_room(event.room_id) - room_creator = create_event.content.get(EventContentFields.ROOM_CREATOR) - - # Only check an insertion event if the room version - # supports it or the event is from the room creator. - if room_version_obj.msc2716_historical or ( - self.config.experimental.msc2716_enabled - and event.sender == room_creator - ): - next_batch_id = event.content.get( - EventContentFields.MSC2716_NEXT_BATCH_ID + event_types = event_auth.auth_types_for_event(event.room_version, event) + prev_state_ids = await context.get_prev_state_ids( + StateFilter.from_types(event_types) ) - conflicting_insertion_event_id = None - if next_batch_id: - conflicting_insertion_event_id = ( - await self.store.get_insertion_event_id_by_batch_id( - event.room_id, next_batch_id + + auth_events_ids = self._event_auth_handler.compute_auth_events( + event, prev_state_ids, for_verification=True + ) + auth_events_map = await self.store.get_events(auth_events_ids) + auth_events = { + (e.type, e.state_key): e for e in auth_events_map.values() + } + + if event_auth.check_redaction( + room_version_obj, event, auth_events=auth_events + ): + # this user doesn't have 'redact' rights, so we need to do some more + # checks on the original event. Let's start by checking the original + # event exists. + if not original_event: + raise NotFoundError( + "Could not find event %s" % (event.redacts,) ) + + if event.user_id != original_event.user_id: + raise AuthError( + 403, "You don't have permission to redact events" + ) + + # all the checks are done. + event.internal_metadata.recheck_redaction = False + + if event.type == EventTypes.Create: + prev_state_ids = await context.get_prev_state_ids() + if prev_state_ids: + raise AuthError(403, "Changing the room create event is forbidden") + + if event.type == EventTypes.MSC2716_INSERTION: + room_version = await self.store.get_room_version_id(event.room_id) + room_version_obj = KNOWN_ROOM_VERSIONS[room_version] + + create_event = await self.store.get_create_event_for_room(event.room_id) + room_creator = create_event.content.get(EventContentFields.ROOM_CREATOR) + + # Only check an insertion event if the room version + # supports it or the event is from the room creator. + if room_version_obj.msc2716_historical or ( + self.config.experimental.msc2716_enabled + and event.sender == room_creator + ): + next_batch_id = event.content.get( + EventContentFields.MSC2716_NEXT_BATCH_ID ) - if conflicting_insertion_event_id is not None: - # The current insertion event that we're processing is invalid - # because an insertion event already exists in the room with the - # same next_batch_id. We can't allow multiple because the batch - # pointing will get weird, e.g. we can't determine which insertion - # event the batch event is pointing to. - raise SynapseError( - HTTPStatus.BAD_REQUEST, - "Another insertion event already exists with the same next_batch_id", - errcode=Codes.INVALID_PARAM, - ) + conflicting_insertion_event_id = None + if next_batch_id: + conflicting_insertion_event_id = ( + await self.store.get_insertion_event_id_by_batch_id( + event.room_id, next_batch_id + ) + ) + if conflicting_insertion_event_id is not None: + # The current insertion event that we're processing is invalid + # because an insertion event already exists in the room with the + # same next_batch_id. We can't allow multiple because the batch + # pointing will get weird, e.g. we can't determine which insertion + # event the batch event is pointing to. + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Another insertion event already exists with the same next_batch_id", + errcode=Codes.INVALID_PARAM, + ) - # Mark any `m.historical` messages as backfilled so they don't appear - # in `/sync` and have the proper decrementing `stream_ordering` as we import - backfilled = False - if event.internal_metadata.is_historical(): - backfilled = True + # Mark any `m.historical` messages as backfilled so they don't appear + # in `/sync` and have the proper decrementing `stream_ordering` as we import + backfilled = False + if event.internal_metadata.is_historical(): + backfilled = True - # Note that this returns the event that was persisted, which may not be - # the same as we passed in if it was deduplicated due transaction IDs. + assert self._storage_controllers.persistence is not None ( - event, - event_pos, + persisted_events, max_stream_token, - ) = await self._storage_controllers.persistence.persist_event( - event, context=context, backfilled=backfilled + ) = await self._storage_controllers.persistence.persist_events( + events_and_context, backfilled=backfilled ) - if self._ephemeral_events_enabled: - # If there's an expiry timestamp on the event, schedule its expiry. - self._message_handler.maybe_schedule_expiry(event) + for event in persisted_events: + if self._ephemeral_events_enabled: + # If there's an expiry timestamp on the event, schedule its expiry. + self._message_handler.maybe_schedule_expiry(event) - async def _notify() -> None: - try: - await self.notifier.on_new_room_event( - event, event_pos, max_stream_token, extra_users=extra_users - ) - except Exception: - logger.exception( - "Error notifying about new room event %s", - event.event_id, - ) + stream_ordering = event.internal_metadata.stream_ordering + assert stream_ordering is not None + pos = PersistedEventPosition(self._instance_name, stream_ordering) + + async def _notify() -> None: + try: + await self.notifier.on_new_room_event( + event, pos, max_stream_token, extra_users=extra_users + ) + except Exception: + logger.exception( + "Error notifying about new room event %s", + event.event_id, + ) - run_in_background(_notify) + run_in_background(_notify) - if event.type == EventTypes.Message: - # We don't want to block sending messages on any presence code. This - # matters as sometimes presence code can take a while. - run_in_background(self._bump_active_time, requester.user) + if event.type == EventTypes.Message: + # We don't want to block sending messages on any presence code. This + # matters as sometimes presence code can take a while. + run_in_background(self._bump_active_time, requester.user) - return event + return persisted_events[-1] async def _maybe_kick_guest_users( self, event: EventBase, context: EventContext @@ -1952,8 +1988,7 @@ class EventCreationHandler: # shadow-banned user. await self.handle_new_client_event( requester, - event, - context, + events_and_context=[(event, context)], ratelimit=False, ignore_shadow_ban=True, ) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 09a1a82e6c..b220238e55 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -301,8 +301,7 @@ class RoomCreationHandler: # now send the tombstone await self.event_creation_handler.handle_new_client_event( requester=requester, - event=tombstone_event, - context=tombstone_context, + events_and_context=[(tombstone_event, tombstone_context)], ) state_filter = StateFilter.from_types( @@ -1057,8 +1056,10 @@ class RoomCreationHandler: creator_id = creator.user.to_string() event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""} depth = 1 + # the last event sent/persisted to the db last_sent_event_id: Optional[str] = None + # the most recently created event prev_event: List[str] = [] # a map of event types, state keys -> event_ids. We collect these mappings this as events are @@ -1112,8 +1113,7 @@ class RoomCreationHandler: ev = await self.event_creation_handler.handle_new_client_event( requester=creator, - event=event, - context=context, + events_and_context=[(event, context)], ratelimit=False, ignore_shadow_ban=True, ) @@ -1152,7 +1152,6 @@ class RoomCreationHandler: prev_event_ids=[last_sent_event_id], depth=depth, ) - last_sent_event_id = member_event_id prev_event = [member_event_id] # update the depth and state map here as the membership event has been created @@ -1168,7 +1167,7 @@ class RoomCreationHandler: EventTypes.PowerLevels, pl_content, False ) current_state_group = power_context._state_group - last_sent_stream_id = await send(power_event, power_context, creator) + await send(power_event, power_context, creator) else: power_level_content: JsonDict = { "users": {creator_id: 100}, @@ -1217,7 +1216,7 @@ class RoomCreationHandler: False, ) current_state_group = pl_context._state_group - last_sent_stream_id = await send(pl_event, pl_context, creator) + await send(pl_event, pl_context, creator) events_to_send = [] if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state: @@ -1271,9 +1270,11 @@ class RoomCreationHandler: ) events_to_send.append((encryption_event, encryption_context)) - for event, context in events_to_send: - last_sent_stream_id = await send(event, context, creator) - return last_sent_stream_id, last_sent_event_id, depth + last_event = await self.event_creation_handler.handle_new_client_event( + creator, events_to_send, ignore_shadow_ban=True + ) + assert last_event.internal_metadata.stream_ordering is not None + return last_event.internal_metadata.stream_ordering, last_event.event_id, depth def _generate_room_id(self) -> str: """Generates a random room ID. diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py index 1414e575d6..411a6fb22f 100644 --- a/synapse/handlers/room_batch.py +++ b/synapse/handlers/room_batch.py @@ -379,8 +379,7 @@ class RoomBatchHandler: await self.create_requester_for_user_id_from_app_service( event.sender, app_service_requester.app_service ), - event=event, - context=context, + events_and_context=[(event, context)], ) return event_ids diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 8d01f4bf2b..88158822e0 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -432,8 +432,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): with opentracing.start_active_span("handle_new_client_event"): result_event = await self.event_creation_handler.handle_new_client_event( requester, - event, - context, + events_and_context=[(event, context)], extra_users=[target], ratelimit=ratelimit, ) @@ -1252,7 +1251,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): raise SynapseError(403, "This room has been blocked on this server") event = await self.event_creation_handler.handle_new_client_event( - requester, event, context, extra_users=[target_user], ratelimit=ratelimit + requester, + events_and_context=[(event, context)], + extra_users=[target_user], + ratelimit=ratelimit, ) prev_member_event_id = prev_state_ids.get( @@ -1860,8 +1862,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): result_event = await self.event_creation_handler.handle_new_client_event( requester, - event, - context, + events_and_context=[(event, context)], extra_users=[UserID.from_string(target_user)], ) # we know it was persisted, so must have a stream ordering diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py index 53aa7fa4c6..ac9a92240a 100644 --- a/synapse/replication/http/__init__.py +++ b/synapse/replication/http/__init__.py @@ -25,6 +25,7 @@ from synapse.replication.http import ( push, register, send_event, + send_events, state, streams, ) @@ -43,6 +44,7 @@ class ReplicationRestResource(JsonResource): def register_servlets(self, hs: "HomeServer") -> None: send_event.register_servlets(hs, self) + send_events.register_servlets(hs, self) federation.register_servlets(hs, self) presence.register_servlets(hs, self) membership.register_servlets(hs, self) diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py index 486f04723c..4215a1c1bc 100644 --- a/synapse/replication/http/send_event.py +++ b/synapse/replication/http/send_event.py @@ -141,8 +141,8 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): "Got event to send with ID: %s into room: %s", event.event_id, event.room_id ) - event = await self.event_creation_handler.persist_and_notify_client_event( - requester, event, context, ratelimit=ratelimit, extra_users=extra_users + event = await self.event_creation_handler.persist_and_notify_client_events( + requester, [(event, context)], ratelimit=ratelimit, extra_users=extra_users ) return ( diff --git a/synapse/replication/http/send_events.py b/synapse/replication/http/send_events.py new file mode 100644 index 0000000000..8889bbb644 --- /dev/null +++ b/synapse/replication/http/send_events.py @@ -0,0 +1,171 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import TYPE_CHECKING, List, Tuple + +from twisted.web.server import Request + +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS +from synapse.events import EventBase, make_event_from_dict +from synapse.events.snapshot import EventContext +from synapse.http.server import HttpServer +from synapse.http.servlet import parse_json_object_from_request +from synapse.replication.http._base import ReplicationEndpoint +from synapse.types import JsonDict, Requester, UserID +from synapse.util.metrics import Measure + +if TYPE_CHECKING: + from synapse.server import HomeServer + from synapse.storage.databases.main import DataStore + +logger = logging.getLogger(__name__) + + +class ReplicationSendEventsRestServlet(ReplicationEndpoint): + """Handles batches of newly created events on workers, including persisting and + notifying. + + The API looks like: + + POST /_synapse/replication/send_events/:txn_id + + { + "events": [{ + "event": { .. serialized event .. }, + "room_version": .., // "1", "2", "3", etc: the version of the room + // containing the event + "event_format_version": .., // 1,2,3 etc: the event format version + "internal_metadata": { .. serialized internal_metadata .. }, + "outlier": true|false, + "rejected_reason": .., // The event.rejected_reason field + "context": { .. serialized event context .. }, + "requester": { .. serialized requester .. }, + "ratelimit": true, + }] + } + + 200 OK + + { "stream_id": 12345, "event_id": "$abcdef..." } + + Responds with a 409 when a `PartialStateConflictError` is raised due to an event + context that needs to be recomputed due to the un-partial stating of a room. + + """ + + NAME = "send_events" + PATH_ARGS = () + + def __init__(self, hs: "HomeServer"): + super().__init__(hs) + + self.event_creation_handler = hs.get_event_creation_handler() + self.store = hs.get_datastores().main + self._storage_controllers = hs.get_storage_controllers() + self.clock = hs.get_clock() + + @staticmethod + async def _serialize_payload( # type: ignore[override] + events_and_context: List[Tuple[EventBase, EventContext]], + store: "DataStore", + requester: Requester, + ratelimit: bool, + extra_users: List[UserID], + ) -> JsonDict: + """ + Args: + store + requester + events_and_ctx + ratelimit + """ + serialized_events = [] + + for event, context in events_and_context: + serialized_context = await context.serialize(event, store) + serialized_event = { + "event": event.get_pdu_json(), + "room_version": event.room_version.identifier, + "event_format_version": event.format_version, + "internal_metadata": event.internal_metadata.get_dict(), + "outlier": event.internal_metadata.is_outlier(), + "rejected_reason": event.rejected_reason, + "context": serialized_context, + "requester": requester.serialize(), + "ratelimit": ratelimit, + "extra_users": [u.to_string() for u in extra_users], + } + serialized_events.append(serialized_event) + + payload = {"events": serialized_events} + + return payload + + async def _handle_request( # type: ignore[override] + self, request: Request + ) -> Tuple[int, JsonDict]: + with Measure(self.clock, "repl_send_events_parse"): + payload = parse_json_object_from_request(request) + events_and_context = [] + events = payload["events"] + + for event_payload in events: + event_dict = event_payload["event"] + room_ver = KNOWN_ROOM_VERSIONS[event_payload["room_version"]] + internal_metadata = event_payload["internal_metadata"] + rejected_reason = event_payload["rejected_reason"] + + event = make_event_from_dict( + event_dict, room_ver, internal_metadata, rejected_reason + ) + event.internal_metadata.outlier = event_payload["outlier"] + + requester = Requester.deserialize( + self.store, event_payload["requester"] + ) + context = EventContext.deserialize( + self._storage_controllers, event_payload["context"] + ) + + ratelimit = event_payload["ratelimit"] + events_and_context.append((event, context)) + + extra_users = [ + UserID.from_string(u) for u in event_payload["extra_users"] + ] + + logger.info( + "Got batch of events to send, last ID of batch is: %s, sending into room: %s", + event.event_id, + event.room_id, + ) + + last_event = ( + await self.event_creation_handler.persist_and_notify_client_events( + requester, events_and_context, ratelimit, extra_users + ) + ) + + return ( + 200, + { + "stream_id": last_event.internal_metadata.stream_ordering, + "event_id": last_event.event_id, + }, + ) + + +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: + ReplicationSendEventsRestServlet(hs).register(http_server) diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py index 986b50ce0c..99384837d0 100644 --- a/tests/handlers/test_message.py +++ b/tests/handlers/test_message.py @@ -105,7 +105,10 @@ class EventCreationTestCase(unittest.HomeserverTestCase): event1, context = self._create_duplicate_event(txn_id) ret_event1 = self.get_success( - self.handler.handle_new_client_event(self.requester, event1, context) + self.handler.handle_new_client_event( + self.requester, + events_and_context=[(event1, context)], + ) ) stream_id1 = ret_event1.internal_metadata.stream_ordering @@ -118,7 +121,10 @@ class EventCreationTestCase(unittest.HomeserverTestCase): self.assertNotEqual(event1.event_id, event2.event_id) ret_event2 = self.get_success( - self.handler.handle_new_client_event(self.requester, event2, context) + self.handler.handle_new_client_event( + self.requester, + events_and_context=[(event2, context)], + ) ) stream_id2 = ret_event2.internal_metadata.stream_ordering diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 86b3d51975..765df75d91 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -497,7 +497,9 @@ class RegistrationTestCase(unittest.HomeserverTestCase): ) ) self.get_success( - event_creation_handler.handle_new_client_event(requester, event, context) + event_creation_handler.handle_new_client_event( + requester, events_and_context=[(event, context)] + ) ) # Register a second user, which won't be be in the room (or even have an invite) diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py index a0ce077a99..de9f4af2de 100644 --- a/tests/storage/test_event_chain.py +++ b/tests/storage/test_event_chain.py @@ -531,7 +531,9 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): ) ) self.get_success( - event_handler.handle_new_client_event(self.requester, event, context) + event_handler.handle_new_client_event( + self.requester, events_and_context=[(event, context)] + ) ) state1 = set(self.get_success(context.get_current_state_ids()).values()) @@ -549,7 +551,9 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): ) ) self.get_success( - event_handler.handle_new_client_event(self.requester, event, context) + event_handler.handle_new_client_event( + self.requester, events_and_context=[(event, context)] + ) ) state2 = set(self.get_success(context.get_current_state_ids()).values()) diff --git a/tests/unittest.py b/tests/unittest.py index 00cb023198..5116be338e 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -734,7 +734,9 @@ class HomeserverTestCase(TestCase): event.internal_metadata.soft_failed = True self.get_success( - event_creator.handle_new_client_event(requester, event, context) + event_creator.handle_new_client_event( + requester, events_and_context=[(event, context)] + ) ) return event.event_id -- cgit 1.5.1 From 7b7478e8b65cceb9e7362c6c1cb932b569a6f383 Mon Sep 17 00:00:00 2001 From: Shay Date: Wed, 5 Oct 2022 10:12:48 -0700 Subject: Batch up notifications after event persistence (#14033) --- changelog.d/14033.misc | 1 + synapse/handlers/federation_event.py | 4 +- synapse/handlers/message.py | 25 ++++++------ synapse/notifier.py | 75 ++++++++++++++++++++---------------- synapse/replication/tcp/client.py | 19 ++++----- 5 files changed, 66 insertions(+), 58 deletions(-) create mode 100644 changelog.d/14033.misc (limited to 'synapse/handlers/message.py') diff --git a/changelog.d/14033.misc b/changelog.d/14033.misc new file mode 100644 index 0000000000..fe42852aa5 --- /dev/null +++ b/changelog.d/14033.misc @@ -0,0 +1 @@ +Don't repeatedly wake up the same users for batched events. \ No newline at end of file diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 778d8869b3..da319943cc 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -2240,8 +2240,8 @@ class FederationEventHandler: event_pos = PersistedEventPosition( self._instance_name, event.internal_metadata.stream_ordering ) - await self._notifier.on_new_room_event( - event, event_pos, max_stream_token, extra_users=extra_users + await self._notifier.on_new_room_events( + [(event, event_pos)], max_stream_token, extra_users=extra_users ) if event.type == EventTypes.Member and event.membership == Membership.JOIN: diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 00e7645ba5..da1acea275 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1872,6 +1872,7 @@ class EventCreationHandler: events_and_context, backfilled=backfilled ) + events_and_pos = [] for event in persisted_events: if self._ephemeral_events_enabled: # If there's an expiry timestamp on the event, schedule its expiry. @@ -1880,25 +1881,23 @@ class EventCreationHandler: stream_ordering = event.internal_metadata.stream_ordering assert stream_ordering is not None pos = PersistedEventPosition(self._instance_name, stream_ordering) - - async def _notify() -> None: - try: - await self.notifier.on_new_room_event( - event, pos, max_stream_token, extra_users=extra_users - ) - except Exception: - logger.exception( - "Error notifying about new room event %s", - event.event_id, - ) - - run_in_background(_notify) + events_and_pos.append((event, pos)) if event.type == EventTypes.Message: # We don't want to block sending messages on any presence code. This # matters as sometimes presence code can take a while. run_in_background(self._bump_active_time, requester.user) + async def _notify() -> None: + try: + await self.notifier.on_new_room_events( + events_and_pos, max_stream_token, extra_users=extra_users + ) + except Exception: + logger.exception("Error notifying about new room events") + + run_in_background(_notify) + return persisted_events[-1] async def _maybe_kick_guest_users( diff --git a/synapse/notifier.py b/synapse/notifier.py index c42bb8266a..26b97cf766 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -294,35 +294,31 @@ class Notifier: """ self._new_join_in_room_callbacks.append(cb) - async def on_new_room_event( + async def on_new_room_events( self, - event: EventBase, - event_pos: PersistedEventPosition, + events_and_pos: List[Tuple[EventBase, PersistedEventPosition]], max_room_stream_token: RoomStreamToken, extra_users: Optional[Collection[UserID]] = None, ) -> None: - """Unwraps event and calls `on_new_room_event_args`.""" - await self.on_new_room_event_args( - event_pos=event_pos, - room_id=event.room_id, - event_id=event.event_id, - event_type=event.type, - state_key=event.get("state_key"), - membership=event.content.get("membership"), - max_room_stream_token=max_room_stream_token, - extra_users=extra_users or [], - ) + """Creates a _PendingRoomEventEntry for each of the listed events and calls + notify_new_room_events with the results.""" + event_entries = [] + for event, pos in events_and_pos: + entry = self.create_pending_room_event_entry( + pos, + extra_users, + event.room_id, + event.type, + event.get("state_key"), + event.content.get("membership"), + ) + event_entries.append((entry, event.event_id)) + await self.notify_new_room_events(event_entries, max_room_stream_token) - async def on_new_room_event_args( + async def notify_new_room_events( self, - room_id: str, - event_id: str, - event_type: str, - state_key: Optional[str], - membership: Optional[str], - event_pos: PersistedEventPosition, + event_entries: List[Tuple[_PendingRoomEventEntry, str]], max_room_stream_token: RoomStreamToken, - extra_users: Optional[Collection[UserID]] = None, ) -> None: """Used by handlers to inform the notifier something has happened in the room, room event wise. @@ -338,22 +334,33 @@ class Notifier: until all previous events have been persisted before notifying the client streams. """ - self.pending_new_room_events.append( - _PendingRoomEventEntry( - event_pos=event_pos, - extra_users=extra_users or [], - room_id=room_id, - type=event_type, - state_key=state_key, - membership=membership, - ) - ) - self._notify_pending_new_room_events(max_room_stream_token) + for event_entry, event_id in event_entries: + self.pending_new_room_events.append(event_entry) + await self._third_party_rules.on_new_event(event_id) - await self._third_party_rules.on_new_event(event_id) + self._notify_pending_new_room_events(max_room_stream_token) self.notify_replication() + def create_pending_room_event_entry( + self, + event_pos: PersistedEventPosition, + extra_users: Optional[Collection[UserID]], + room_id: str, + event_type: str, + state_key: Optional[str], + membership: Optional[str], + ) -> _PendingRoomEventEntry: + """Creates and returns a _PendingRoomEventEntry""" + return _PendingRoomEventEntry( + event_pos=event_pos, + extra_users=extra_users or [], + room_id=room_id, + type=event_type, + state_key=state_key, + membership=membership, + ) + def _notify_pending_new_room_events( self, max_room_stream_token: RoomStreamToken ) -> None: diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index b2522f98ca..18252a2958 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -210,15 +210,16 @@ class ReplicationDataHandler: max_token = self.store.get_room_max_token() event_pos = PersistedEventPosition(instance_name, token) - await self.notifier.on_new_room_event_args( - event_pos=event_pos, - max_room_stream_token=max_token, - extra_users=extra_users, - room_id=row.data.room_id, - event_id=row.data.event_id, - event_type=row.data.type, - state_key=row.data.state_key, - membership=row.data.membership, + event_entry = self.notifier.create_pending_room_event_entry( + event_pos, + extra_users, + row.data.room_id, + row.data.type, + row.data.state_key, + row.data.membership, + ) + await self.notifier.notify_new_room_events( + [(event_entry, row.data.event_id)], max_token ) # If this event is a join, make a note of it so we have an accurate -- cgit 1.5.1 From b6baa46db078c3ef9e6c5751bccb8d2e1c5c5402 Mon Sep 17 00:00:00 2001 From: Shay Date: Wed, 12 Oct 2022 11:01:00 -0700 Subject: Fix a bug where the joined hosts for a given event were not being properly cached (#14125) --- changelog.d/14125.bugfix | 1 + synapse/handlers/federation_event.py | 4 +- synapse/handlers/message.py | 91 +++++++++++++++++++----------------- 3 files changed, 51 insertions(+), 45 deletions(-) create mode 100644 changelog.d/14125.bugfix (limited to 'synapse/handlers/message.py') diff --git a/changelog.d/14125.bugfix b/changelog.d/14125.bugfix new file mode 100644 index 0000000000..852f00ebb9 --- /dev/null +++ b/changelog.d/14125.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in v1.69.0rc1 where the joined hosts for a given event were not being properly cached. diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index da319943cc..f382961099 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -414,7 +414,9 @@ class FederationEventHandler: # First, precalculate the joined hosts so that the federation sender doesn't # need to. - await self._event_creation_handler.cache_joined_hosts_for_event(event, context) + await self._event_creation_handler.cache_joined_hosts_for_events( + [(event, context)] + ) await self._check_for_soft_fail(event, context=context, origin=origin) await self._run_push_actions_and_persist_event(event, context) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index da1acea275..4e55ebba0b 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1390,7 +1390,7 @@ class EventCreationHandler: extra_users=extra_users, ), run_in_background( - self.cache_joined_hosts_for_event, event, context + self.cache_joined_hosts_for_events, events_and_context ).addErrback( log_failure, "cache_joined_hosts_for_event failed" ), @@ -1491,62 +1491,65 @@ class EventCreationHandler: await self.store.remove_push_actions_from_staging(event.event_id) raise - async def cache_joined_hosts_for_event( - self, event: EventBase, context: EventContext + async def cache_joined_hosts_for_events( + self, events_and_context: List[Tuple[EventBase, EventContext]] ) -> None: - """Precalculate the joined hosts at the event, when using Redis, so that + """Precalculate the joined hosts at each of the given events, when using Redis, so that external federation senders don't have to recalculate it themselves. """ - if not self._external_cache.is_enabled(): - return - - # If external cache is enabled we should always have this. - assert self._external_cache_joined_hosts_updates is not None + for event, _ in events_and_context: + if not self._external_cache.is_enabled(): + return - # We actually store two mappings, event ID -> prev state group, - # state group -> joined hosts, which is much more space efficient - # than event ID -> joined hosts. - # - # Note: We have to cache event ID -> prev state group, as we don't - # store that in the DB. - # - # Note: We set the state group -> joined hosts cache if it hasn't been - # set for a while, so that the expiry time is reset. + # If external cache is enabled we should always have this. + assert self._external_cache_joined_hosts_updates is not None - state_entry = await self.state.resolve_state_groups_for_events( - event.room_id, event_ids=event.prev_event_ids() - ) + # We actually store two mappings, event ID -> prev state group, + # state group -> joined hosts, which is much more space efficient + # than event ID -> joined hosts. + # + # Note: We have to cache event ID -> prev state group, as we don't + # store that in the DB. + # + # Note: We set the state group -> joined hosts cache if it hasn't been + # set for a while, so that the expiry time is reset. - if state_entry.state_group: - await self._external_cache.set( - "event_to_prev_state_group", - event.event_id, - state_entry.state_group, - expiry_ms=60 * 60 * 1000, + state_entry = await self.state.resolve_state_groups_for_events( + event.room_id, event_ids=event.prev_event_ids() ) - if state_entry.state_group in self._external_cache_joined_hosts_updates: - return + if state_entry.state_group: + await self._external_cache.set( + "event_to_prev_state_group", + event.event_id, + state_entry.state_group, + expiry_ms=60 * 60 * 1000, + ) - state = await state_entry.get_state( - self._storage_controllers.state, StateFilter.all() - ) - with opentracing.start_active_span("get_joined_hosts"): - joined_hosts = await self.store.get_joined_hosts( - event.room_id, state, state_entry + if state_entry.state_group in self._external_cache_joined_hosts_updates: + return + + state = await state_entry.get_state( + self._storage_controllers.state, StateFilter.all() ) + with opentracing.start_active_span("get_joined_hosts"): + joined_hosts = await self.store.get_joined_hosts( + event.room_id, state, state_entry + ) - # Note that the expiry times must be larger than the expiry time in - # _external_cache_joined_hosts_updates. - await self._external_cache.set( - "get_joined_hosts", - str(state_entry.state_group), - list(joined_hosts), - expiry_ms=60 * 60 * 1000, - ) + # Note that the expiry times must be larger than the expiry time in + # _external_cache_joined_hosts_updates. + await self._external_cache.set( + "get_joined_hosts", + str(state_entry.state_group), + list(joined_hosts), + expiry_ms=60 * 60 * 1000, + ) - self._external_cache_joined_hosts_updates[state_entry.state_group] = None + self._external_cache_joined_hosts_updates[ + state_entry.state_group + ] = None async def _validate_canonical_alias( self, -- cgit 1.5.1 From 847e2393f3198b88809c9b99de5c681efbf1c92e Mon Sep 17 00:00:00 2001 From: Shay Date: Tue, 18 Oct 2022 09:58:47 -0700 Subject: Prepatory work for adding power level event to batched events (#14214) --- changelog.d/14214.misc | 1 + synapse/event_auth.py | 19 ++++++++++++++++++- synapse/handlers/event_auth.py | 18 +++++++++++++----- synapse/handlers/federation.py | 12 +++++------- synapse/handlers/message.py | 10 +++++++++- synapse/handlers/room.py | 4 +--- 6 files changed, 47 insertions(+), 17 deletions(-) create mode 100644 changelog.d/14214.misc (limited to 'synapse/handlers/message.py') diff --git a/changelog.d/14214.misc b/changelog.d/14214.misc new file mode 100644 index 0000000000..102928b575 --- /dev/null +++ b/changelog.d/14214.misc @@ -0,0 +1 @@ +When authenticating batched events, check for auth events in batch as well as DB. diff --git a/synapse/event_auth.py b/synapse/event_auth.py index c7d5ef92fc..bab31e33c5 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -15,7 +15,18 @@ import logging import typing -from typing import Any, Collection, Dict, Iterable, List, Optional, Set, Tuple, Union +from typing import ( + Any, + Collection, + Dict, + Iterable, + List, + Mapping, + Optional, + Set, + Tuple, + Union, +) from canonicaljson import encode_canonical_json from signedjson.key import decode_verify_key_bytes @@ -134,6 +145,7 @@ def validate_event_for_room_version(event: "EventBase") -> None: async def check_state_independent_auth_rules( store: _EventSourceStore, event: "EventBase", + batched_auth_events: Optional[Mapping[str, "EventBase"]] = None, ) -> None: """Check that an event complies with auth rules that are independent of room state @@ -143,6 +155,8 @@ async def check_state_independent_auth_rules( Args: store: the datastore; used to fetch the auth events for validation event: the event being checked. + batched_auth_events: if the event being authed is part of a batch, any events + from the same batch that may be necessary to auth the current event Raises: AuthError if the checks fail @@ -162,6 +176,9 @@ async def check_state_independent_auth_rules( redact_behaviour=EventRedactBehaviour.as_is, allow_rejected=True, ) + if batched_auth_events: + auth_events.update(batched_auth_events) + room_id = event.room_id auth_dict: MutableStateMap[str] = {} expected_auth_types = auth_types_for_event(event.room_version, event) diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py index 8249ca1ed2..3bbad0271b 100644 --- a/synapse/handlers/event_auth.py +++ b/synapse/handlers/event_auth.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Collection, List, Optional, Union +from typing import TYPE_CHECKING, Collection, List, Mapping, Optional, Union from synapse import event_auth from synapse.api.constants import ( @@ -29,7 +29,6 @@ from synapse.event_auth import ( ) from synapse.events import EventBase from synapse.events.builder import EventBuilder -from synapse.events.snapshot import EventContext from synapse.types import StateMap, get_domain_from_id if TYPE_CHECKING: @@ -51,12 +50,21 @@ class EventAuthHandler: async def check_auth_rules_from_context( self, event: EventBase, - context: EventContext, + batched_auth_events: Optional[Mapping[str, EventBase]] = None, ) -> None: - """Check an event passes the auth rules at its own auth events""" - await check_state_independent_auth_rules(self._store, event) + """Check an event passes the auth rules at its own auth events + Args: + event: event to be authed + batched_auth_events: if the event being authed is part of a batch, any events + from the same batch that may be necessary to auth the current event + """ + await check_state_independent_auth_rules( + self._store, event, batched_auth_events + ) auth_event_ids = event.auth_event_ids() auth_events_by_id = await self._store.get_events(auth_event_ids) + if batched_auth_events: + auth_events_by_id.update(batched_auth_events) check_state_dependent_auth_rules(event, auth_events_by_id.values()) def compute_auth_events( diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index ccc045d36f..275a37a575 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -942,7 +942,7 @@ class FederationHandler: # The remote hasn't signed it yet, obviously. We'll do the full checks # when we get the event back in `on_send_join_request` - await self._event_auth_handler.check_auth_rules_from_context(event, context) + await self._event_auth_handler.check_auth_rules_from_context(event) return event async def on_invite_request( @@ -1123,7 +1123,7 @@ class FederationHandler: try: # The remote hasn't signed it yet, obviously. We'll do the full checks # when we get the event back in `on_send_leave_request` - await self._event_auth_handler.check_auth_rules_from_context(event, context) + await self._event_auth_handler.check_auth_rules_from_context(event) except AuthError as e: logger.warning("Failed to create new leave %r because %s", event, e) raise e @@ -1182,7 +1182,7 @@ class FederationHandler: try: # The remote hasn't signed it yet, obviously. We'll do the full checks # when we get the event back in `on_send_knock_request` - await self._event_auth_handler.check_auth_rules_from_context(event, context) + await self._event_auth_handler.check_auth_rules_from_context(event) except AuthError as e: logger.warning("Failed to create new knock %r because %s", event, e) raise e @@ -1348,9 +1348,7 @@ class FederationHandler: try: validate_event_for_room_version(event) - await self._event_auth_handler.check_auth_rules_from_context( - event, context - ) + await self._event_auth_handler.check_auth_rules_from_context(event) except AuthError as e: logger.warning("Denying new third party invite %r because %s", event, e) raise e @@ -1400,7 +1398,7 @@ class FederationHandler: try: validate_event_for_room_version(event) - await self._event_auth_handler.check_auth_rules_from_context(event, context) + await self._event_auth_handler.check_auth_rules_from_context(event) except AuthError as e: logger.warning("Denying third party invite %r because %s", event, e) raise e diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 4e55ebba0b..15b828dd74 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1360,8 +1360,16 @@ class EventCreationHandler: else: try: validate_event_for_room_version(event) + # If we are persisting a batch of events the event(s) needed to auth the + # current event may be part of the batch and will not be in the DB yet + event_id_to_event = {e.event_id: e for e, _ in events_and_context} + batched_auth_events = {} + for event_id in event.auth_event_ids(): + auth_event = event_id_to_event.get(event_id) + if auth_event: + batched_auth_events[event_id] = auth_event await self._event_auth_handler.check_auth_rules_from_context( - event, context + event, batched_auth_events ) except AuthError as err: logger.warning("Denying new event %r because %s", event, err) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 4e1aacb408..638f54051a 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -229,9 +229,7 @@ class RoomCreationHandler: }, ) validate_event_for_room_version(tombstone_event) - await self._event_auth_handler.check_auth_rules_from_context( - tombstone_event, tombstone_context - ) + await self._event_auth_handler.check_auth_rules_from_context(tombstone_event) # Upgrade the room # -- cgit 1.5.1 From b7a7ff6ee39da4981dcfdce61bf8ac4735e3d047 Mon Sep 17 00:00:00 2001 From: Shay Date: Fri, 21 Oct 2022 10:46:22 -0700 Subject: Add initial power level event to batch of bulk persisted events when creating a new room. (#14228) --- changelog.d/14228.misc | 1 + synapse/handlers/federation.py | 4 +- synapse/handlers/federation_event.py | 4 +- synapse/handlers/message.py | 14 ++---- synapse/handlers/room.py | 39 ++++----------- synapse/push/bulk_push_rule_evaluator.py | 74 ++++++++++++++++++++++++----- tests/push/test_bulk_push_rule_evaluator.py | 2 +- tests/replication/_base.py | 2 +- 8 files changed, 82 insertions(+), 58 deletions(-) create mode 100644 changelog.d/14228.misc (limited to 'synapse/handlers/message.py') diff --git a/changelog.d/14228.misc b/changelog.d/14228.misc new file mode 100644 index 0000000000..14fe31a8bc --- /dev/null +++ b/changelog.d/14228.misc @@ -0,0 +1 @@ +Add initial power level event to batch of bulk persisted events when creating a new room. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 275a37a575..4fbc79a6cb 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1017,7 +1017,9 @@ class FederationHandler: context = EventContext.for_outlier(self._storage_controllers) - await self._bulk_push_rule_evaluator.action_for_event_by_user(event, context) + await self._bulk_push_rule_evaluator.action_for_events_by_user( + [(event, context)] + ) try: await self._federation_event_handler.persist_events_and_notify( event.room_id, [(event, context)] diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 06e41b5cc0..7da6316a82 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -2171,8 +2171,8 @@ class FederationEventHandler: min_depth, ) else: - await self._bulk_push_rule_evaluator.action_for_event_by_user( - event, context + await self._bulk_push_rule_evaluator.action_for_events_by_user( + [(event, context)] ) try: diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 15b828dd74..468900a07f 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1433,17 +1433,9 @@ class EventCreationHandler: a room that has been un-partial stated. """ - for event, context in events_and_context: - # Skip push notification actions for historical messages - # because we don't want to notify people about old history back in time. - # The historical messages also do not have the proper `context.current_state_ids` - # and `state_groups` because they have `prev_events` that aren't persisted yet - # (historical messages persisted in reverse-chronological order). - if not event.internal_metadata.is_historical(): - with opentracing.start_active_span("calculate_push_actions"): - await self._bulk_push_rule_evaluator.action_for_event_by_user( - event, context - ) + await self._bulk_push_rule_evaluator.action_for_events_by_user( + events_and_context + ) try: # If we're a worker we need to hit out to the master. diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 638f54051a..cc1e5c8f97 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1055,9 +1055,6 @@ class RoomCreationHandler: event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""} depth = 1 - # the last event sent/persisted to the db - last_sent_event_id: Optional[str] = None - # the most recently created event prev_event: List[str] = [] # a map of event types, state keys -> event_ids. We collect these mappings this as events are @@ -1102,26 +1099,6 @@ class RoomCreationHandler: return new_event, new_context - async def send( - event: EventBase, - context: synapse.events.snapshot.EventContext, - creator: Requester, - ) -> int: - nonlocal last_sent_event_id - - ev = await self.event_creation_handler.handle_new_client_event( - requester=creator, - events_and_context=[(event, context)], - ratelimit=False, - ignore_shadow_ban=True, - ) - - last_sent_event_id = ev.event_id - - # we know it was persisted, so must have a stream ordering - assert ev.internal_metadata.stream_ordering - return ev.internal_metadata.stream_ordering - try: config = self._presets_dict[preset_config] except KeyError: @@ -1135,10 +1112,14 @@ class RoomCreationHandler: ) logger.debug("Sending %s in new room", EventTypes.Member) - await send(creation_event, creation_context, creator) + ev = await self.event_creation_handler.handle_new_client_event( + requester=creator, + events_and_context=[(creation_event, creation_context)], + ratelimit=False, + ignore_shadow_ban=True, + ) + last_sent_event_id = ev.event_id - # Room create event must exist at this point - assert last_sent_event_id is not None member_event_id, _ = await self.room_member_handler.update_membership( creator, creator.user, @@ -1157,6 +1138,7 @@ class RoomCreationHandler: depth += 1 state_map[(EventTypes.Member, creator.user.to_string())] = member_event_id + events_to_send = [] # We treat the power levels override specially as this needs to be one # of the first events that get sent into a room. pl_content = initial_state.pop((EventTypes.PowerLevels, ""), None) @@ -1165,7 +1147,7 @@ class RoomCreationHandler: EventTypes.PowerLevels, pl_content, False ) current_state_group = power_context._state_group - await send(power_event, power_context, creator) + events_to_send.append((power_event, power_context)) else: power_level_content: JsonDict = { "users": {creator_id: 100}, @@ -1214,9 +1196,8 @@ class RoomCreationHandler: False, ) current_state_group = pl_context._state_group - await send(pl_event, pl_context, creator) + events_to_send.append((pl_event, pl_context)) - events_to_send = [] if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state: room_alias_event, room_alias_context = await create_event( EventTypes.CanonicalAlias, {"alias": room_alias.to_string()}, True diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index a75386f6a0..d7795a9080 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -165,8 +165,21 @@ class BulkPushRuleEvaluator: return rules_by_user async def _get_power_levels_and_sender_level( - self, event: EventBase, context: EventContext + self, + event: EventBase, + context: EventContext, + event_id_to_event: Mapping[str, EventBase], ) -> Tuple[dict, Optional[int]]: + """ + Given an event and an event context, get the power level event relevant to the event + and the power level of the sender of the event. + Args: + event: event to check + context: context of event to check + event_id_to_event: a mapping of event_id to event for a set of events being + batch persisted. This is needed as the sought-after power level event may + be in this batch rather than the DB + """ # There are no power levels and sender levels possible to get from outlier if event.internal_metadata.is_outlier(): return {}, None @@ -177,15 +190,26 @@ class BulkPushRuleEvaluator: ) pl_event_id = prev_state_ids.get(POWER_KEY) + # fastpath: if there's a power level event, that's all we need, and + # not having a power level event is an extreme edge case if pl_event_id: - # fastpath: if there's a power level event, that's all we need, and - # not having a power level event is an extreme edge case - auth_events = {POWER_KEY: await self.store.get_event(pl_event_id)} + # Get the power level event from the batch, or fall back to the database. + pl_event = event_id_to_event.get(pl_event_id) + if pl_event: + auth_events = {POWER_KEY: pl_event} + else: + auth_events = {POWER_KEY: await self.store.get_event(pl_event_id)} else: auth_events_ids = self._event_auth_handler.compute_auth_events( event, prev_state_ids, for_verification=False ) auth_events_dict = await self.store.get_events(auth_events_ids) + # Some needed auth events might be in the batch, combine them with those + # fetched from the database. + for auth_event_id in auth_events_ids: + auth_event = event_id_to_event.get(auth_event_id) + if auth_event: + auth_events_dict[auth_event_id] = auth_event auth_events = {(e.type, e.state_key): e for e in auth_events_dict.values()} sender_level = get_user_power_level(event.sender, auth_events) @@ -194,16 +218,38 @@ class BulkPushRuleEvaluator: return pl_event.content if pl_event else {}, sender_level - @measure_func("action_for_event_by_user") - async def action_for_event_by_user( - self, event: EventBase, context: EventContext + async def action_for_events_by_user( + self, events_and_context: List[Tuple[EventBase, EventContext]] ) -> None: - """Given an event and context, evaluate the push rules, check if the message - should increment the unread count, and insert the results into the - event_push_actions_staging table. + """Given a list of events and their associated contexts, evaluate the push rules + for each event, check if the message should increment the unread count, and + insert the results into the event_push_actions_staging table. """ - if not event.internal_metadata.is_notifiable(): - # Push rules for events that aren't notifiable can't be processed by this + # For batched events the power level events may not have been persisted yet, + # so we pass in the batched events. Thus if the event cannot be found in the + # database we can check in the batch. + event_id_to_event = {e.event_id: e for e, _ in events_and_context} + for event, context in events_and_context: + await self._action_for_event_by_user(event, context, event_id_to_event) + + @measure_func("action_for_event_by_user") + async def _action_for_event_by_user( + self, + event: EventBase, + context: EventContext, + event_id_to_event: Mapping[str, EventBase], + ) -> None: + + if ( + not event.internal_metadata.is_notifiable() + or event.internal_metadata.is_historical() + ): + # Push rules for events that aren't notifiable can't be processed by this and + # we want to skip push notification actions for historical messages + # because we don't want to notify people about old history back in time. + # The historical messages also do not have the proper `context.current_state_ids` + # and `state_groups` because they have `prev_events` that aren't persisted yet + # (historical messages persisted in reverse-chronological order). return # Disable counting as unread unless the experimental configuration is @@ -223,7 +269,9 @@ class BulkPushRuleEvaluator: ( power_levels, sender_power_level, - ) = await self._get_power_levels_and_sender_level(event, context) + ) = await self._get_power_levels_and_sender_level( + event, context, event_id_to_event + ) # Find the event's thread ID. relation = relation_from_event(event) diff --git a/tests/push/test_bulk_push_rule_evaluator.py b/tests/push/test_bulk_push_rule_evaluator.py index 675d7df2ac..594e7937a8 100644 --- a/tests/push/test_bulk_push_rule_evaluator.py +++ b/tests/push/test_bulk_push_rule_evaluator.py @@ -71,4 +71,4 @@ class TestBulkPushRuleEvaluator(unittest.HomeserverTestCase): bulk_evaluator = BulkPushRuleEvaluator(self.hs) # should not raise - self.get_success(bulk_evaluator.action_for_event_by_user(event, context)) + self.get_success(bulk_evaluator.action_for_events_by_user([(event, context)])) diff --git a/tests/replication/_base.py b/tests/replication/_base.py index ce53f808db..121f3d8d65 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -371,7 +371,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): config=worker_hs.config.server.listeners[0], resource=resource, server_version_string="1", - max_request_body_size=4096, + max_request_body_size=8192, reactor=self.reactor, ) -- cgit 1.5.1 From 86c5a710d8b4212f8a8a668d7d4a79c0bb371508 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Thu, 3 Nov 2022 16:21:31 +0000 Subject: Implement MSC3912: Relation-based redactions (#14260) Co-authored-by: Sean Quah <8349537+squahtx@users.noreply.github.com> --- changelog.d/14260.feature | 1 + synapse/api/constants.py | 2 + synapse/config/experimental.py | 3 + synapse/handlers/message.py | 47 ++++- synapse/handlers/relations.py | 56 +++++- synapse/rest/client/room.py | 57 ++++-- synapse/rest/client/versions.py | 2 + synapse/storage/databases/main/relations.py | 36 ++++ tests/rest/client/test_redactions.py | 273 +++++++++++++++++++++++++++- tests/rest/client/utils.py | 37 ++++ 10 files changed, 486 insertions(+), 28 deletions(-) create mode 100644 changelog.d/14260.feature (limited to 'synapse/handlers/message.py') diff --git a/changelog.d/14260.feature b/changelog.d/14260.feature new file mode 100644 index 0000000000..102dc7b3e0 --- /dev/null +++ b/changelog.d/14260.feature @@ -0,0 +1 @@ +Add experimental support for [MSC3912](https://github.com/matrix-org/matrix-spec-proposals/pull/3912): Relation-based redactions. diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 44c5ffc6a5..bc04a0755b 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -125,6 +125,8 @@ class EventTypes: MSC2716_BATCH: Final = "org.matrix.msc2716.batch" MSC2716_MARKER: Final = "org.matrix.msc2716.marker" + Reaction: Final = "m.reaction" + class ToDeviceEventTypes: RoomKeyRequest: Final = "m.room_key_request" diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index d9bdd66d55..d4b71d1673 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -128,3 +128,6 @@ class ExperimentalConfig(Config): self.msc3886_endpoint: Optional[str] = experimental.get( "msc3886_endpoint", None ) + + # MSC3912: Relation-based redactions. + self.msc3912_enabled: bool = experimental.get("msc3912_enabled", False) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 468900a07f..4cf593cfdc 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -877,6 +877,36 @@ class EventCreationHandler: return prev_event return None + async def get_event_from_transaction( + self, + requester: Requester, + txn_id: str, + room_id: str, + ) -> Optional[EventBase]: + """For the given transaction ID and room ID, check if there is a matching event. + If so, fetch it and return it. + + Args: + requester: The requester making the request in the context of which we want + to fetch the event. + txn_id: The transaction ID. + room_id: The room ID. + + Returns: + An event if one could be found, None otherwise. + """ + if requester.access_token_id: + existing_event_id = await self.store.get_event_id_from_transaction_id( + room_id, + requester.user.to_string(), + requester.access_token_id, + txn_id, + ) + if existing_event_id: + return await self.store.get_event(existing_event_id) + + return None + async def create_and_send_nonmember_event( self, requester: Requester, @@ -956,18 +986,17 @@ class EventCreationHandler: # extremities to pile up, which in turn leads to state resolution # taking longer. async with self.limiter.queue(event_dict["room_id"]): - if txn_id and requester.access_token_id: - existing_event_id = await self.store.get_event_id_from_transaction_id( - event_dict["room_id"], - requester.user.to_string(), - requester.access_token_id, - txn_id, + if txn_id: + event = await self.get_event_from_transaction( + requester, txn_id, event_dict["room_id"] ) - if existing_event_id: - event = await self.store.get_event(existing_event_id) + if event: # we know it was persisted, so must have a stream ordering assert event.internal_metadata.stream_ordering - return event, event.internal_metadata.stream_ordering + return ( + event, + event.internal_metadata.stream_ordering, + ) event, context = await self.create_event( requester, diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index 0a0c6d938e..8e71dda970 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Tup import attr -from synapse.api.constants import RelationTypes +from synapse.api.constants import EventTypes, RelationTypes from synapse.api.errors import SynapseError from synapse.events import EventBase, relation_from_event from synapse.logging.opentracing import trace @@ -75,6 +75,7 @@ class RelationsHandler: self._clock = hs.get_clock() self._event_handler = hs.get_event_handler() self._event_serializer = hs.get_event_client_serializer() + self._event_creation_handler = hs.get_event_creation_handler() async def get_relations( self, @@ -205,6 +206,59 @@ class RelationsHandler: return related_events, next_token + async def redact_events_related_to( + self, + requester: Requester, + event_id: str, + initial_redaction_event: EventBase, + relation_types: List[str], + ) -> None: + """Redacts all events related to the given event ID with one of the given + relation types. + + This method is expected to be called when redacting the event referred to by + the given event ID. + + If an event cannot be redacted (e.g. because of insufficient permissions), log + the error and try to redact the next one. + + Args: + requester: The requester to redact events on behalf of. + event_id: The event IDs to look and redact relations of. + initial_redaction_event: The redaction for the event referred to by + event_id. + relation_types: The types of relations to look for. + + Raises: + ShadowBanError if the requester is shadow-banned + """ + related_event_ids = ( + await self._main_store.get_all_relations_for_event_with_types( + event_id, relation_types + ) + ) + + for related_event_id in related_event_ids: + try: + await self._event_creation_handler.create_and_send_nonmember_event( + requester, + { + "type": EventTypes.Redaction, + "content": initial_redaction_event.content, + "room_id": initial_redaction_event.room_id, + "sender": requester.user.to_string(), + "redacts": related_event_id, + }, + ratelimit=False, + ) + except SynapseError as e: + logger.warning( + "Failed to redact event %s (related to event %s): %s", + related_event_id, + event_id, + e.msg, + ) + async def get_annotations_for_event( self, event_id: str, diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 01e5079963..91cb791139 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -52,6 +52,7 @@ from synapse.http.servlet import ( from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.opentracing import set_tag +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.rest.client._base import client_patterns from synapse.rest.client.transactions import HttpTransactionCache from synapse.storage.state import StateFilter @@ -1029,6 +1030,8 @@ class RoomRedactEventRestServlet(TransactionRestServlet): super().__init__(hs) self.event_creation_handler = hs.get_event_creation_handler() self.auth = hs.get_auth() + self._relation_handler = hs.get_relations_handler() + self._msc3912_enabled = hs.config.experimental.msc3912_enabled def register(self, http_server: HttpServer) -> None: PATTERNS = "/rooms/(?P[^/]*)/redact/(?P[^/]*)" @@ -1045,20 +1048,46 @@ class RoomRedactEventRestServlet(TransactionRestServlet): content = parse_json_object_from_request(request) try: - ( - event, - _, - ) = await self.event_creation_handler.create_and_send_nonmember_event( - requester, - { - "type": EventTypes.Redaction, - "content": content, - "room_id": room_id, - "sender": requester.user.to_string(), - "redacts": event_id, - }, - txn_id=txn_id, - ) + with_relations = None + if self._msc3912_enabled and "org.matrix.msc3912.with_relations" in content: + with_relations = content["org.matrix.msc3912.with_relations"] + del content["org.matrix.msc3912.with_relations"] + + # Check if there's an existing event for this transaction now (even though + # create_and_send_nonmember_event also does it) because, if there's one, + # then we want to skip the call to redact_events_related_to. + event = None + if txn_id: + event = await self.event_creation_handler.get_event_from_transaction( + requester, txn_id, room_id + ) + + if event is None: + ( + event, + _, + ) = await self.event_creation_handler.create_and_send_nonmember_event( + requester, + { + "type": EventTypes.Redaction, + "content": content, + "room_id": room_id, + "sender": requester.user.to_string(), + "redacts": event_id, + }, + txn_id=txn_id, + ) + + if with_relations: + run_as_background_process( + "redact_related_events", + self._relation_handler.redact_events_related_to, + requester=requester, + event_id=event_id, + initial_redaction_event=event, + relation_types=with_relations, + ) + event_id = event.event_id except ShadowBanError: event_id = "$" + random_string(43) diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index 9b1b72c68a..180a11ef88 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -119,6 +119,8 @@ class VersionsRestServlet(RestServlet): # Adds support for simple HTTP rendezvous as per MSC3886 "org.matrix.msc3886": self.config.experimental.msc3886_endpoint is not None, + # Adds support for relation-based redactions as per MSC3912. + "org.matrix.msc3912": self.config.experimental.msc3912_enabled, }, }, ) diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index c022510e76..ca431002c8 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -295,6 +295,42 @@ class RelationsWorkerStore(SQLBaseStore): "get_recent_references_for_event", _get_recent_references_for_event_txn ) + async def get_all_relations_for_event_with_types( + self, + event_id: str, + relation_types: List[str], + ) -> List[str]: + """Get the event IDs of all events that have a relation to the given event with + one of the given relation types. + + Args: + event_id: The event for which to look for related events. + relation_types: The types of relations to look for. + + Returns: + A list of the IDs of the events that relate to the given event with one of + the given relation types. + """ + + def get_all_relation_ids_for_event_with_types_txn( + txn: LoggingTransaction, + ) -> List[str]: + rows = self.db_pool.simple_select_many_txn( + txn=txn, + table="event_relations", + column="relation_type", + iterable=relation_types, + keyvalues={"relates_to_id": event_id}, + retcols=["event_id"], + ) + + return [row["event_id"] for row in rows] + + return await self.db_pool.runInteraction( + desc="get_all_relation_ids_for_event_with_types", + func=get_all_relation_ids_for_event_with_types_txn, + ) + async def event_includes_relation(self, event_id: str) -> bool: """Check if the given event relates to another event. diff --git a/tests/rest/client/test_redactions.py b/tests/rest/client/test_redactions.py index be4c67d68e..5dfe44defb 100644 --- a/tests/rest/client/test_redactions.py +++ b/tests/rest/client/test_redactions.py @@ -11,17 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List +from typing import List, Optional from twisted.test.proto_helpers import MemoryReactor +from synapse.api.constants import EventTypes, RelationTypes from synapse.rest import admin from synapse.rest.client import login, room, sync from synapse.server import HomeServer from synapse.types import JsonDict from synapse.util import Clock -from tests.unittest import HomeserverTestCase +from tests.unittest import HomeserverTestCase, override_config class RedactionsTestCase(HomeserverTestCase): @@ -67,7 +68,12 @@ class RedactionsTestCase(HomeserverTestCase): ) def _redact_event( - self, access_token: str, room_id: str, event_id: str, expect_code: int = 200 + self, + access_token: str, + room_id: str, + event_id: str, + expect_code: int = 200, + with_relations: Optional[List[str]] = None, ) -> JsonDict: """Helper function to send a redaction event. @@ -75,7 +81,13 @@ class RedactionsTestCase(HomeserverTestCase): """ path = "/_matrix/client/r0/rooms/%s/redact/%s" % (room_id, event_id) - channel = self.make_request("POST", path, content={}, access_token=access_token) + request_content = {} + if with_relations: + request_content["org.matrix.msc3912.with_relations"] = with_relations + + channel = self.make_request( + "POST", path, request_content, access_token=access_token + ) self.assertEqual(channel.code, expect_code) return channel.json_body @@ -201,3 +213,256 @@ class RedactionsTestCase(HomeserverTestCase): # These should all succeed, even though this would be denied by # the standard message ratelimiter self._redact_event(self.mod_access_token, self.room_id, msg_id) + + @override_config({"experimental_features": {"msc3912_enabled": True}}) + def test_redact_relations(self) -> None: + """Tests that we can redact the relations of an event at the same time as the + event itself. + """ + # Send a root event. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={"msgtype": "m.text", "body": "hello"}, + tok=self.mod_access_token, + ) + root_event_id = res["event_id"] + + # Send an edit to this root event. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "body": " * hello world", + "m.new_content": { + "body": "hello world", + "msgtype": "m.text", + }, + "m.relates_to": { + "event_id": root_event_id, + "rel_type": RelationTypes.REPLACE, + }, + "msgtype": "m.text", + }, + tok=self.mod_access_token, + ) + edit_event_id = res["event_id"] + + # Also send a threaded message whose root is the same as the edit's. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "message 1", + "m.relates_to": { + "event_id": root_event_id, + "rel_type": RelationTypes.THREAD, + }, + }, + tok=self.mod_access_token, + ) + threaded_event_id = res["event_id"] + + # Also send a reaction, again with the same root. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Reaction, + content={ + "m.relates_to": { + "rel_type": RelationTypes.ANNOTATION, + "event_id": root_event_id, + "key": "👍", + } + }, + tok=self.mod_access_token, + ) + reaction_event_id = res["event_id"] + + # Redact the root event, specifying that we also want to delete events that + # relate to it with m.replace. + self._redact_event( + self.mod_access_token, + self.room_id, + root_event_id, + with_relations=[ + RelationTypes.REPLACE, + RelationTypes.THREAD, + ], + ) + + # Check that the root event got redacted. + event_dict = self.helper.get_event( + self.room_id, root_event_id, self.mod_access_token + ) + self.assertIn("redacted_because", event_dict, event_dict) + + # Check that the edit got redacted. + event_dict = self.helper.get_event( + self.room_id, edit_event_id, self.mod_access_token + ) + self.assertIn("redacted_because", event_dict, event_dict) + + # Check that the threaded message got redacted. + event_dict = self.helper.get_event( + self.room_id, threaded_event_id, self.mod_access_token + ) + self.assertIn("redacted_because", event_dict, event_dict) + + # Check that the reaction did not get redacted. + event_dict = self.helper.get_event( + self.room_id, reaction_event_id, self.mod_access_token + ) + self.assertNotIn("redacted_because", event_dict, event_dict) + + @override_config({"experimental_features": {"msc3912_enabled": True}}) + def test_redact_relations_no_perms(self) -> None: + """Tests that, when redacting a message along with its relations, if not all + the related messages can be redacted because of insufficient permissions, the + server still redacts all the ones that can be. + """ + # Send a root event. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "root", + }, + tok=self.other_access_token, + ) + root_event_id = res["event_id"] + + # Send a first threaded message, this one from the moderator. We do this for the + # first message with the m.thread relation (and not the last one) to ensure + # that, when the server fails to redact it, it doesn't stop there, and it + # instead goes on to redact the other one. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "message 1", + "m.relates_to": { + "event_id": root_event_id, + "rel_type": RelationTypes.THREAD, + }, + }, + tok=self.mod_access_token, + ) + first_threaded_event_id = res["event_id"] + + # Send a second threaded message, this time from the user who'll perform the + # redaction. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "message 2", + "m.relates_to": { + "event_id": root_event_id, + "rel_type": RelationTypes.THREAD, + }, + }, + tok=self.other_access_token, + ) + second_threaded_event_id = res["event_id"] + + # Redact the thread's root, and request that all threaded messages are also + # redacted. Send that request from the non-mod user, so that the first threaded + # event cannot be redacted. + self._redact_event( + self.other_access_token, + self.room_id, + root_event_id, + with_relations=[RelationTypes.THREAD], + ) + + # Check that the thread root got redacted. + event_dict = self.helper.get_event( + self.room_id, root_event_id, self.other_access_token + ) + self.assertIn("redacted_because", event_dict, event_dict) + + # Check that the last message in the thread got redacted, despite failing to + # redact the one before it. + event_dict = self.helper.get_event( + self.room_id, second_threaded_event_id, self.other_access_token + ) + self.assertIn("redacted_because", event_dict, event_dict) + + # Check that the message that was sent into the tread by the mod user is not + # redacted. + event_dict = self.helper.get_event( + self.room_id, first_threaded_event_id, self.other_access_token + ) + self.assertIn("body", event_dict["content"], event_dict) + self.assertEqual("message 1", event_dict["content"]["body"]) + + @override_config({"experimental_features": {"msc3912_enabled": True}}) + def test_redact_relations_txn_id_reuse(self) -> None: + """Tests that redacting a message using a transaction ID, then reusing the same + transaction ID but providing an additional list of relations to redact, is + effectively a no-op. + """ + # Send a root event. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "root", + }, + tok=self.mod_access_token, + ) + root_event_id = res["event_id"] + + # Send a first threaded message. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "I'm in a thread!", + "m.relates_to": { + "event_id": root_event_id, + "rel_type": RelationTypes.THREAD, + }, + }, + tok=self.mod_access_token, + ) + threaded_event_id = res["event_id"] + + # Send a first redaction request which redacts only the root event. + channel = self.make_request( + method="PUT", + path=f"/rooms/{self.room_id}/redact/{root_event_id}/foo", + content={}, + access_token=self.mod_access_token, + ) + self.assertEqual(channel.code, 200) + + # Send a second redaction request which redacts the root event as well as + # threaded messages. + channel = self.make_request( + method="PUT", + path=f"/rooms/{self.room_id}/redact/{root_event_id}/foo", + content={"org.matrix.msc3912.with_relations": [RelationTypes.THREAD]}, + access_token=self.mod_access_token, + ) + self.assertEqual(channel.code, 200) + + # Check that the root event got redacted. + event_dict = self.helper.get_event( + self.room_id, root_event_id, self.mod_access_token + ) + self.assertIn("redacted_because", event_dict) + + # Check that the threaded message didn't get redacted (since that wasn't part of + # the original redaction). + event_dict = self.helper.get_event( + self.room_id, threaded_event_id, self.mod_access_token + ) + self.assertIn("body", event_dict["content"], event_dict) + self.assertEqual("I'm in a thread!", event_dict["content"]["body"]) diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index 706399fae5..8d6f2b6ff9 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -410,6 +410,43 @@ class RestHelper: return channel.json_body + def get_event( + self, + room_id: str, + event_id: str, + tok: Optional[str] = None, + expect_code: int = HTTPStatus.OK, + ) -> JsonDict: + """Request a specific event from the server. + + Args: + room_id: the room in which the event was sent. + event_id: the event's ID. + tok: the token to request the event with. + expect_code: the expected HTTP status for the response. + + Returns: + The event as a dict. + """ + path = f"/_matrix/client/v3/rooms/{room_id}/event/{event_id}" + if tok: + path = path + f"?access_token={tok}" + + channel = make_request( + self.hs.get_reactor(), + self.site, + "GET", + path, + ) + + assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % ( + expect_code, + channel.code, + channel.result["body"], + ) + + return channel.json_body + def _read_write_state( self, room_id: str, -- cgit 1.5.1