From 55e4d27b36fd69a3cf3eceecbd42706579ef2dc7 Mon Sep 17 00:00:00 2001 From: Shay Date: Wed, 8 Feb 2023 11:25:11 -0800 Subject: Limit concurrent event creation for a room to avoid state resolution when sending bursts of events to a local room (#14977) --- synapse/handlers/message.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'synapse/handlers/message.py') diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index e688e00575..5f6da2943f 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -499,9 +499,9 @@ class EventCreationHandler: self.request_ratelimiter = hs.get_request_ratelimiter() - # We arbitrarily limit concurrent event creation for a room to 5. - # This is to stop us from diverging history *too* much. - self.limiter = Linearizer(max_count=5, name="room_event_creation_limit") + # We limit concurrent event creation for a room to 1. This prevents state resolution + # from occurring when sending bursts of events to a local room + self.limiter = Linearizer(max_count=1, name="room_event_creation_limit") self._bulk_push_rule_evaluator = hs.get_bulk_push_rule_evaluator() -- cgit 1.5.1 From 03bccd542bcffe3ea12cd35108740a7d62dd38ab Mon Sep 17 00:00:00 2001 From: Shay Date: Thu, 9 Feb 2023 13:05:02 -0800 Subject: Add a class UnpersistedEventContext to allow for the batching up of storing state groups (#14675) * add class UnpersistedEventContext * modify create new client event to create unpersistedeventcontexts * persist event contexts after creation * fix tests to persist unpersisted event contexts * cleanup * misc lints + cleanup * changelog + fix comments * lints * fix batch insertion? * reduce redundant calculation * add unpersisted event classes * rework compute_event_context, split into function that returns unpersisted event context and then persists it * use calculate_context_info to create unpersisted event contexts * update typing * $%#^&* * black * fix comments and consolidate classes, use attr.s for class * requested changes * lint * requested changes * requested changes * refactor to be stupidly explicit * clearer renaming and flow * make partial state non-optional * update docstrings --------- Co-authored-by: Erik Johnston --- changelog.d/14675.misc | 1 + synapse/events/snapshot.py | 174 ++++++++++++++++++++++++++++++++- synapse/events/third_party_rules.py | 6 +- synapse/handlers/federation.py | 59 ++++++++---- synapse/handlers/federation_event.py | 6 +- synapse/handlers/message.py | 42 +++++--- synapse/state/__init__.py | 176 ++++++++++++++-------------------- tests/handlers/test_user_directory.py | 4 +- tests/rest/admin/test_user.py | 4 +- tests/storage/test_redaction.py | 24 +++-- tests/storage/test_state.py | 4 +- tests/test_utils/event_injection.py | 7 +- tests/test_visibility.py | 9 +- tests/utils.py | 5 +- 14 files changed, 359 insertions(+), 162 deletions(-) create mode 100644 changelog.d/14675.misc (limited to 'synapse/handlers/message.py') diff --git a/changelog.d/14675.misc b/changelog.d/14675.misc new file mode 100644 index 0000000000..bc1ac1c82a --- /dev/null +++ b/changelog.d/14675.misc @@ -0,0 +1 @@ +Add a class UnpersistedEventContext to allow for the batching up of storing state groups. diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index 6eaef8b57a..e0d82ad81c 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from abc import ABC, abstractmethod from typing import TYPE_CHECKING, List, Optional, Tuple import attr @@ -26,8 +27,51 @@ if TYPE_CHECKING: from synapse.types.state import StateFilter +class UnpersistedEventContextBase(ABC): + """ + This is a base class for EventContext and UnpersistedEventContext, objects which + hold information relevant to storing an associated event. Note that an + UnpersistedEventContexts must be converted into an EventContext before it is + suitable to send to the db with its associated event. + + Attributes: + _storage: storage controllers for interfacing with the database + app_service: If the associated event is being sent by a (local) application service, that + app service. + """ + + def __init__(self, storage_controller: "StorageControllers"): + self._storage: "StorageControllers" = storage_controller + self.app_service: Optional[ApplicationService] = None + + @abstractmethod + async def persist( + self, + event: EventBase, + ) -> "EventContext": + """ + A method to convert an UnpersistedEventContext to an EventContext, suitable for + sending to the database with the associated event. + """ + pass + + @abstractmethod + async def get_prev_state_ids( + self, state_filter: Optional["StateFilter"] = None + ) -> StateMap[str]: + """ + Gets the room state at the event (ie not including the event if the event is a + state event). + + Args: + state_filter: specifies the type of state event to fetch from DB, example: + EventTypes.JoinRules + """ + pass + + @attr.s(slots=True, auto_attribs=True) -class EventContext: +class EventContext(UnpersistedEventContextBase): """ Holds information relevant to persisting an event @@ -77,9 +121,6 @@ class EventContext: delta_ids: If ``prev_group`` is not None, the state delta between ``prev_group`` and ``state_group``. - app_service: If this event is being sent by a (local) application service, that - app service. - partial_state: if True, we may be storing this event with a temporary, incomplete state. """ @@ -122,6 +163,9 @@ class EventContext: """Return an EventContext instance suitable for persisting an outlier event""" return EventContext(storage=storage) + async def persist(self, event: EventBase) -> "EventContext": + return self + async def serialize(self, event: EventBase, store: "DataStore") -> JsonDict: """Converts self to a type that can be serialized as JSON, and then deserialized by `deserialize` @@ -254,6 +298,128 @@ class EventContext: ) +@attr.s(slots=True, auto_attribs=True) +class UnpersistedEventContext(UnpersistedEventContextBase): + """ + The event context holds information about the state groups for an event. It is important + to remember that an event technically has two state groups: the state group before the + event, and the state group after the event. If the event is not a state event, the state + group will not change (ie the state group before the event will be the same as the state + group after the event), but if it is a state event the state group before the event + will differ from the state group after the event. + This is a version of an EventContext before the new state group (if any) has been + computed and stored. It contains information about the state before the event (which + also may be the information after the event, if the event is not a state event). The + UnpersistedEventContext must be converted into an EventContext by calling the method + 'persist' on it before it is suitable to be sent to the DB for processing. + + state_group_after_event: + The state group after the event. This will always be None until it is persisted. + If the event is not a state event, this will be the same as + state_group_before_event. + + state_group_before_event: + The ID of the state group representing the state of the room before this event. + + state_delta_due_to_event: + If the event is a state event, then this is the delta of the state between + `state_group` and `state_group_before_event` + + prev_group_for_state_group_before_event: + If it is known, ``state_group_before_event``'s previous state group. + + delta_ids_to_state_group_before_event: + If ``prev_group_for_state_group_before_event`` is not None, the state delta + between ``prev_group_for_state_group_before_event`` and ``state_group_before_event``. + + partial_state: + Whether the event has partial state. + + state_map_before_event: + A map of the state before the event, i.e. the state at `state_group_before_event` + """ + + _storage: "StorageControllers" + state_group_before_event: Optional[int] + state_group_after_event: Optional[int] + state_delta_due_to_event: Optional[dict] + prev_group_for_state_group_before_event: Optional[int] + delta_ids_to_state_group_before_event: Optional[StateMap[str]] + partial_state: bool + state_map_before_event: Optional[StateMap[str]] = None + + async def get_prev_state_ids( + self, state_filter: Optional["StateFilter"] = None + ) -> StateMap[str]: + """ + Gets the room state map, excluding this event. + + Args: + state_filter: specifies the type of state event to fetch from DB + + Returns: + Maps a (type, state_key) to the event ID of the state event matching + this tuple. + """ + if self.state_map_before_event: + return self.state_map_before_event + + assert self.state_group_before_event is not None + return await self._storage.state.get_state_ids_for_group( + self.state_group_before_event, state_filter + ) + + async def persist(self, event: EventBase) -> EventContext: + """ + Creates a full `EventContext` for the event, persisting any referenced state that + has not yet been persisted. + + Args: + event: event that the EventContext is associated with. + + Returns: An EventContext suitable for sending to the database with the event + for persisting + """ + assert self.partial_state is not None + + # If we have a full set of state for before the event but don't have a state + # group for that state, we need to get one + if self.state_group_before_event is None: + assert self.state_map_before_event + state_group_before_event = await self._storage.state.store_state_group( + event.event_id, + event.room_id, + prev_group=self.prev_group_for_state_group_before_event, + delta_ids=self.delta_ids_to_state_group_before_event, + current_state_ids=self.state_map_before_event, + ) + self.state_group_before_event = state_group_before_event + + # if the event isn't a state event the state group doesn't change + if not self.state_delta_due_to_event: + state_group_after_event = self.state_group_before_event + + # otherwise if it is a state event we need to get a state group for it + else: + state_group_after_event = await self._storage.state.store_state_group( + event.event_id, + event.room_id, + prev_group=self.state_group_before_event, + delta_ids=self.state_delta_due_to_event, + current_state_ids=None, + ) + + return EventContext.with_state( + storage=self._storage, + state_group=state_group_after_event, + state_group_before_event=self.state_group_before_event, + state_delta_due_to_event=self.state_delta_due_to_event, + partial_state=self.partial_state, + prev_group=self.state_group_before_event, + delta_ids=self.state_delta_due_to_event, + ) + + def _encode_state_dict( state_dict: Optional[StateMap[str]], ) -> Optional[List[Tuple[str, str, str]]]: diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py index 72ab696898..97c61cc258 100644 --- a/synapse/events/third_party_rules.py +++ b/synapse/events/third_party_rules.py @@ -18,7 +18,7 @@ from twisted.internet.defer import CancelledError from synapse.api.errors import ModuleFailedException, SynapseError from synapse.events import EventBase -from synapse.events.snapshot import EventContext +from synapse.events.snapshot import UnpersistedEventContextBase from synapse.storage.roommember import ProfileInfo from synapse.types import Requester, StateMap from synapse.util.async_helpers import delay_cancellation, maybe_awaitable @@ -231,7 +231,9 @@ class ThirdPartyEventRules: self._on_threepid_bind_callbacks.append(on_threepid_bind) async def check_event_allowed( - self, event: EventBase, context: EventContext + self, + event: EventBase, + context: UnpersistedEventContextBase, ) -> Tuple[bool, Optional[dict]]: """Check if a provided event should be allowed in the given context. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 7f64130e0a..43ed4a3dd1 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -56,7 +56,7 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.crypto.event_signing import compute_event_signature from synapse.event_auth import validate_event_for_room_version from synapse.events import EventBase -from synapse.events.snapshot import EventContext +from synapse.events.snapshot import EventContext, UnpersistedEventContextBase from synapse.events.validator import EventValidator from synapse.federation.federation_client import InvalidResponseError from synapse.http.servlet import assert_params_in_dict @@ -990,7 +990,10 @@ class FederationHandler: ) try: - event, context = await self.event_creation_handler.create_new_client_event( + ( + event, + unpersisted_context, + ) = await self.event_creation_handler.create_new_client_event( builder=builder ) except SynapseError as e: @@ -998,7 +1001,9 @@ class FederationHandler: raise # Ensure the user can even join the room. - await self._federation_event_handler.check_join_restrictions(context, event) + await self._federation_event_handler.check_join_restrictions( + unpersisted_context, event + ) # The remote hasn't signed it yet, obviously. We'll do the full checks # when we get the event back in `on_send_join_request` @@ -1178,7 +1183,7 @@ class FederationHandler: }, ) - event, context = await self.event_creation_handler.create_new_client_event( + event, _ = await self.event_creation_handler.create_new_client_event( builder=builder ) @@ -1228,12 +1233,13 @@ class FederationHandler: }, ) - event, context = await self.event_creation_handler.create_new_client_event( - builder=builder - ) + ( + event, + unpersisted_context, + ) = await self.event_creation_handler.create_new_client_event(builder=builder) event_allowed, _ = await self.third_party_event_rules.check_event_allowed( - event, context + event, unpersisted_context ) if not event_allowed: logger.warning("Creation of knock %s forbidden by third-party rules", event) @@ -1406,15 +1412,20 @@ class FederationHandler: try: ( event, - context, + unpersisted_context, ) = await self.event_creation_handler.create_new_client_event( builder=builder ) - event, context = await self.add_display_name_to_third_party_invite( - room_version_obj, event_dict, event, context + ( + event, + unpersisted_context, + ) = await self.add_display_name_to_third_party_invite( + room_version_obj, event_dict, event, unpersisted_context ) + context = await unpersisted_context.persist(event) + EventValidator().validate_new(event, self.config) # We need to tell the transaction queue to send this out, even @@ -1483,14 +1494,19 @@ class FederationHandler: try: ( event, - context, + unpersisted_context, ) = await self.event_creation_handler.create_new_client_event( builder=builder ) - event, context = await self.add_display_name_to_third_party_invite( - room_version_obj, event_dict, event, context + ( + event, + unpersisted_context, + ) = await self.add_display_name_to_third_party_invite( + room_version_obj, event_dict, event, unpersisted_context ) + context = await unpersisted_context.persist(event) + try: validate_event_for_room_version(event) await self._event_auth_handler.check_auth_rules_from_context(event) @@ -1522,8 +1538,8 @@ class FederationHandler: room_version_obj: RoomVersion, event_dict: JsonDict, event: EventBase, - context: EventContext, - ) -> Tuple[EventBase, EventContext]: + context: UnpersistedEventContextBase, + ) -> Tuple[EventBase, UnpersistedEventContextBase]: key = ( EventTypes.ThirdPartyInvite, event.content["third_party_invite"]["signed"]["token"], @@ -1557,11 +1573,14 @@ class FederationHandler: room_version_obj, event_dict ) EventValidator().validate_builder(builder) - event, context = await self.event_creation_handler.create_new_client_event( - builder=builder - ) + + ( + event, + unpersisted_context, + ) = await self.event_creation_handler.create_new_client_event(builder=builder) + EventValidator().validate_new(event, self.config) - return event, context + return event, unpersisted_context async def _check_signature(self, event: EventBase, context: EventContext) -> None: """ diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index e037acbca2..3561f2f1de 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -58,7 +58,7 @@ from synapse.event_auth import ( validate_event_for_room_version, ) from synapse.events import EventBase -from synapse.events.snapshot import EventContext +from synapse.events.snapshot import EventContext, UnpersistedEventContextBase from synapse.federation.federation_client import InvalidResponseError, PulledPduInfo from synapse.logging.context import nested_logging_context from synapse.logging.opentracing import ( @@ -426,7 +426,9 @@ class FederationEventHandler: return event, context async def check_join_restrictions( - self, context: EventContext, event: EventBase + self, + context: UnpersistedEventContextBase, + event: EventBase, ) -> None: """Check that restrictions in restricted join rules are matched diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 5f6da2943f..3e30f52e4d 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -48,7 +48,7 @@ from synapse.api.urls import ConsentURIBuilder from synapse.event_auth import validate_event_for_room_version from synapse.events import EventBase, relation_from_event from synapse.events.builder import EventBuilder -from synapse.events.snapshot import EventContext +from synapse.events.snapshot import EventContext, UnpersistedEventContextBase from synapse.events.utils import maybe_upsert_event_field from synapse.events.validator import EventValidator from synapse.handlers.directory import DirectoryHandler @@ -708,7 +708,7 @@ class EventCreationHandler: builder.internal_metadata.historical = historical - event, context = await self.create_new_client_event( + event, unpersisted_context = await self.create_new_client_event( builder=builder, requester=requester, allow_no_prev_events=allow_no_prev_events, @@ -721,6 +721,8 @@ class EventCreationHandler: current_state_group=current_state_group, ) + context = await unpersisted_context.persist(event) + # In an ideal world we wouldn't need the second part of this condition. However, # this behaviour isn't spec'd yet, meaning we should be able to deactivate this # behaviour. Another reason is that this code is also evaluated each time a new @@ -1083,13 +1085,14 @@ class EventCreationHandler: state_map: Optional[StateMap[str]] = None, for_batch: bool = False, current_state_group: Optional[int] = None, - ) -> Tuple[EventBase, EventContext]: + ) -> Tuple[EventBase, UnpersistedEventContextBase]: """Create a new event for a local client. If bool for_batch is true, will create an event using the prev_event_ids, and will create an event context for the event using the parameters state_map and current_state_group, thus these parameters must be provided in this case if for_batch is True. The subsequently created event and context are suitable for being batched up and bulk persisted to the database - with other similarly created events. + with other similarly created events. Note that this returns an UnpersistedEventContext, + which must be converted to an EventContext before it can be sent to the DB. Args: builder: @@ -1131,7 +1134,7 @@ class EventCreationHandler: batch persisting Returns: - Tuple of created event, context + Tuple of created event, UnpersistedEventContext """ # Strip down the state_event_ids to only what we need to auth the event. # For example, we don't need extra m.room.member that don't match event.sender @@ -1192,9 +1195,16 @@ class EventCreationHandler: event = await builder.build( prev_event_ids=prev_event_ids, auth_event_ids=auth_ids, depth=depth ) - context = await self.state.compute_event_context_for_batched( - event, state_map, current_state_group + + context: UnpersistedEventContextBase = ( + await self.state.calculate_context_info( + event, + state_ids_before_event=state_map, + partial_state=False, + state_group_before_event=current_state_group, + ) ) + else: event = await builder.build( prev_event_ids=prev_event_ids, @@ -1244,16 +1254,17 @@ class EventCreationHandler: state_map_for_event[(data.event_type, data.state_key)] = state_id - context = await self.state.compute_event_context( + # TODO(faster_joins): check how MSC2716 works and whether we can have + # partial state here + # https://github.com/matrix-org/synapse/issues/13003 + context = await self.state.calculate_context_info( event, state_ids_before_event=state_map_for_event, - # TODO(faster_joins): check how MSC2716 works and whether we can have - # partial state here - # https://github.com/matrix-org/synapse/issues/13003 partial_state=False, ) + else: - context = await self.state.compute_event_context(event) + context = await self.state.calculate_context_info(event) if requester: context.app_service = requester.app_service @@ -2082,9 +2093,9 @@ class EventCreationHandler: async def _rebuild_event_after_third_party_rules( self, third_party_result: dict, original_event: EventBase - ) -> Tuple[EventBase, EventContext]: + ) -> Tuple[EventBase, UnpersistedEventContextBase]: # the third_party_event_rules want to replace the event. - # we do some basic checks, and then return the replacement event and context. + # we do some basic checks, and then return the replacement event. # Construct a new EventBuilder and validate it, which helps with the # rest of these checks. @@ -2138,5 +2149,6 @@ class EventCreationHandler: # we rebuild the event context, to be on the safe side. If nothing else, # delta_ids might need an update. - context = await self.state.compute_event_context(event) + context = await self.state.calculate_context_info(event) + return event, context diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index fdfb46ab82..e877e6f1a1 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -39,7 +39,11 @@ from prometheus_client import Counter, Histogram from synapse.api.constants import EventTypes from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions from synapse.events import EventBase -from synapse.events.snapshot import EventContext +from synapse.events.snapshot import ( + EventContext, + UnpersistedEventContext, + UnpersistedEventContextBase, +) from synapse.logging.context import ContextResourceUsage from synapse.replication.http.state import ReplicationUpdateCurrentStateRestServlet from synapse.state import v1, v2 @@ -262,31 +266,31 @@ class StateHandler: state = await entry.get_state(self._state_storage_controller, StateFilter.all()) return await self.store.get_joined_hosts(room_id, state, entry) - async def compute_event_context( + async def calculate_context_info( self, event: EventBase, state_ids_before_event: Optional[StateMap[str]] = None, partial_state: Optional[bool] = None, - ) -> EventContext: - """Build an EventContext structure for a non-outlier event. - - (for an outlier, call EventContext.for_outlier directly) - - This works out what the current state should be for the event, and - generates a new state group if necessary. - - Args: - event: - state_ids_before_event: The event ids of the state before the event if - it can't be calculated from existing events. This is normally - only specified when receiving an event from federation where we - don't have the prev events, e.g. when backfilling. - partial_state: - `True` if `state_ids_before_event` is partial and omits non-critical - membership events. - `False` if `state_ids_before_event` is the full state. - `None` when `state_ids_before_event` is not provided. In this case, the - flag will be calculated based on `event`'s prev events. + state_group_before_event: Optional[int] = None, + ) -> UnpersistedEventContextBase: + """ + Calulates the contents of an unpersisted event context, other than the current + state group (which is either provided or calculated when the event context is persisted) + + state_ids_before_event: + The event ids of the full state before the event if + it can't be calculated from existing events. This is normally + only specified when receiving an event from federation where we + don't have the prev events, e.g. when backfilling or when the event + is being created for batch persisting. + partial_state: + `True` if `state_ids_before_event` is partial and omits non-critical + membership events. + `False` if `state_ids_before_event` is the full state. + `None` when `state_ids_before_event` is not provided. In this case, the + flag will be calculated based on `event`'s prev events. + state_group_before_event: + the current state group at the time of event, if known Returns: The event context. @@ -294,7 +298,6 @@ class StateHandler: RuntimeError if `state_ids_before_event` is not provided and one or more prev events are missing or outliers. """ - assert not event.internal_metadata.is_outlier() # @@ -306,17 +309,6 @@ class StateHandler: state_group_before_event_prev_group = None deltas_to_state_group_before_event = None - # .. though we need to get a state group for it. - state_group_before_event = ( - await self._state_storage_controller.store_state_group( - event.event_id, - event.room_id, - prev_group=None, - delta_ids=None, - current_state_ids=state_ids_before_event, - ) - ) - # the partial_state flag must be provided assert partial_state is not None else: @@ -345,6 +337,7 @@ class StateHandler: logger.debug("calling resolve_state_groups from compute_event_context") # we've already taken into account partial state, so no need to wait for # complete state here. + entry = await self.resolve_state_groups_for_events( event.room_id, event.prev_event_ids(), @@ -383,18 +376,19 @@ class StateHandler: # if not event.is_state(): - return EventContext.with_state( + return UnpersistedEventContext( storage=self._storage_controllers, state_group_before_event=state_group_before_event, - state_group=state_group_before_event, + state_group_after_event=state_group_before_event, state_delta_due_to_event={}, - prev_group=state_group_before_event_prev_group, - delta_ids=deltas_to_state_group_before_event, + prev_group_for_state_group_before_event=state_group_before_event_prev_group, + delta_ids_to_state_group_before_event=deltas_to_state_group_before_event, partial_state=partial_state, + state_map_before_event=state_ids_before_event, ) # - # otherwise, we'll need to create a new state group for after the event + # otherwise, we'll need to set up creating a new state group for after the event # key = (event.type, event.state_key) @@ -412,88 +406,60 @@ class StateHandler: delta_ids = {key: event.event_id} - state_group_after_event = ( - await self._state_storage_controller.store_state_group( - event.event_id, - event.room_id, - prev_group=state_group_before_event, - delta_ids=delta_ids, - current_state_ids=None, - ) - ) - - return EventContext.with_state( + return UnpersistedEventContext( storage=self._storage_controllers, - state_group=state_group_after_event, state_group_before_event=state_group_before_event, + state_group_after_event=None, state_delta_due_to_event=delta_ids, - prev_group=state_group_before_event, - delta_ids=delta_ids, + prev_group_for_state_group_before_event=state_group_before_event_prev_group, + delta_ids_to_state_group_before_event=deltas_to_state_group_before_event, partial_state=partial_state, + state_map_before_event=state_ids_before_event, ) - async def compute_event_context_for_batched( + async def compute_event_context( self, event: EventBase, - state_ids_before_event: StateMap[str], - current_state_group: int, + state_ids_before_event: Optional[StateMap[str]] = None, + partial_state: Optional[bool] = None, ) -> EventContext: - """ - Generate an event context for an event that has not yet been persisted to the - database. Intended for use with events that are created to be persisted in a batch. - Args: - event: the event the context is being computed for - state_ids_before_event: a state map consisting of the state ids of the events - created prior to this event. - current_state_group: the current state group before the event. - """ - state_group_before_event_prev_group = None - deltas_to_state_group_before_event = None - - state_group_before_event = current_state_group - - # if the event is not state, we are set - if not event.is_state(): - return EventContext.with_state( - storage=self._storage_controllers, - state_group_before_event=state_group_before_event, - state_group=state_group_before_event, - state_delta_due_to_event={}, - prev_group=state_group_before_event_prev_group, - delta_ids=deltas_to_state_group_before_event, - partial_state=False, - ) + """Build an EventContext structure for a non-outlier event. - # otherwise, we'll need to create a new state group for after the event - key = (event.type, event.state_key) + (for an outlier, call EventContext.for_outlier directly) - if state_ids_before_event is not None: - replaces = state_ids_before_event.get(key) + This works out what the current state should be for the event, and + generates a new state group if necessary. - if replaces and replaces != event.event_id: - event.unsigned["replaces_state"] = replaces + Args: + event: + state_ids_before_event: The event ids of the state before the event if + it can't be calculated from existing events. This is normally + only specified when receiving an event from federation where we + don't have the prev events, e.g. when backfilling. + partial_state: + `True` if `state_ids_before_event` is partial and omits non-critical + membership events. + `False` if `state_ids_before_event` is the full state. + `None` when `state_ids_before_event` is not provided. In this case, the + flag will be calculated based on `event`'s prev events. + entry: + A state cache entry for the resolved state across the prev events. We may + have already calculated this, so if it's available pass it in + Returns: + The event context. - delta_ids = {key: event.event_id} + Raises: + RuntimeError if `state_ids_before_event` is not provided and one or more + prev events are missing or outliers. + """ - state_group_after_event = ( - await self._state_storage_controller.store_state_group( - event.event_id, - event.room_id, - prev_group=state_group_before_event, - delta_ids=delta_ids, - current_state_ids=None, - ) + unpersisted_context = await self.calculate_context_info( + event=event, + state_ids_before_event=state_ids_before_event, + partial_state=partial_state, ) - return EventContext.with_state( - storage=self._storage_controllers, - state_group=state_group_after_event, - state_group_before_event=state_group_before_event, - state_delta_due_to_event=delta_ids, - prev_group=state_group_before_event, - delta_ids=delta_ids, - partial_state=False, - ) + return await unpersisted_context.persist(event) @measure_func() async def resolve_state_groups_for_events( diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 75fc5a17a4..e9be5fb504 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -949,10 +949,12 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): }, ) - event, context = self.get_success( + event, unpersisted_context = self.get_success( self.event_creation_handler.create_new_client_event(builder) ) + context = self.get_success(unpersisted_context.persist(event)) + self.get_success( self.hs.get_storage_controllers().persistence.persist_event(event, context) ) diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 5c1ced355f..b50406e129 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -2934,10 +2934,12 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): }, ) - event, context = self.get_success( + event, unpersisted_context = self.get_success( event_creation_handler.create_new_client_event(builder) ) + context = self.get_success(unpersisted_context.persist(event)) + self.get_success(storage_controllers.persistence.persist_event(event, context)) # Now get rooms diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index df4740f9d9..0100f7da14 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -74,10 +74,12 @@ class RedactionTestCase(unittest.HomeserverTestCase): }, ) - event, context = self.get_success( + event, unpersisted_context = self.get_success( self.event_creation_handler.create_new_client_event(builder) ) + context = self.get_success(unpersisted_context.persist(event)) + self.get_success(self._persistence.persist_event(event, context)) return event @@ -96,10 +98,12 @@ class RedactionTestCase(unittest.HomeserverTestCase): }, ) - event, context = self.get_success( + event, unpersisted_context = self.get_success( self.event_creation_handler.create_new_client_event(builder) ) + context = self.get_success(unpersisted_context.persist(event)) + self.get_success(self._persistence.persist_event(event, context)) return event @@ -119,10 +123,12 @@ class RedactionTestCase(unittest.HomeserverTestCase): }, ) - event, context = self.get_success( + event, unpersisted_context = self.get_success( self.event_creation_handler.create_new_client_event(builder) ) + context = self.get_success(unpersisted_context.persist(event)) + self.get_success(self._persistence.persist_event(event, context)) return event @@ -259,7 +265,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): def internal_metadata(self) -> _EventInternalMetadata: return self._base_builder.internal_metadata - event_1, context_1 = self.get_success( + event_1, unpersisted_context_1 = self.get_success( self.event_creation_handler.create_new_client_event( cast( EventBuilder, @@ -280,9 +286,11 @@ class RedactionTestCase(unittest.HomeserverTestCase): ) ) + context_1 = self.get_success(unpersisted_context_1.persist(event_1)) + self.get_success(self._persistence.persist_event(event_1, context_1)) - event_2, context_2 = self.get_success( + event_2, unpersisted_context_2 = self.get_success( self.event_creation_handler.create_new_client_event( cast( EventBuilder, @@ -302,6 +310,8 @@ class RedactionTestCase(unittest.HomeserverTestCase): ) ) ) + + context_2 = self.get_success(unpersisted_context_2.persist(event_2)) self.get_success(self._persistence.persist_event(event_2, context_2)) # fetch one of the redactions @@ -421,10 +431,12 @@ class RedactionTestCase(unittest.HomeserverTestCase): }, ) - redaction_event, context = self.get_success( + redaction_event, unpersisted_context = self.get_success( self.event_creation_handler.create_new_client_event(builder) ) + context = self.get_success(unpersisted_context.persist(redaction_event)) + self.get_success(self._persistence.persist_event(redaction_event, context)) # Now lets jump to the future where we have censored the redaction event diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index bad7f0bc60..f730b888f7 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -67,10 +67,12 @@ class StateStoreTestCase(HomeserverTestCase): }, ) - event, context = self.get_success( + event, unpersisted_context = self.get_success( self.event_creation_handler.create_new_client_event(builder) ) + context = self.get_success(unpersisted_context.persist(event)) + assert self.storage.persistence is not None self.get_success(self.storage.persistence.persist_event(event, context)) diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py index 1a50c2acf1..a6330ed840 100644 --- a/tests/test_utils/event_injection.py +++ b/tests/test_utils/event_injection.py @@ -92,8 +92,13 @@ async def create_event( builder = hs.get_event_builder_factory().for_room_version( KNOWN_ROOM_VERSIONS[room_version], kwargs ) - event, context = await hs.get_event_creation_handler().create_new_client_event( + ( + event, + unpersisted_context, + ) = await hs.get_event_creation_handler().create_new_client_event( builder, prev_event_ids=prev_event_ids ) + context = await unpersisted_context.persist(event) + return event, context diff --git a/tests/test_visibility.py b/tests/test_visibility.py index 875e37988f..36d6b37aa4 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -175,9 +175,10 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): }, ) - event, context = self.get_success( + event, unpersisted_context = self.get_success( self.event_creation_handler.create_new_client_event(builder) ) + context = self.get_success(unpersisted_context.persist(event)) self.get_success( self._storage_controllers.persistence.persist_event(event, context) ) @@ -202,9 +203,10 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): }, ) - event, context = self.get_success( + event, unpersisted_context = self.get_success( self.event_creation_handler.create_new_client_event(builder) ) + context = self.get_success(unpersisted_context.persist(event)) self.get_success( self._storage_controllers.persistence.persist_event(event, context) @@ -226,9 +228,10 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): }, ) - event, context = self.get_success( + event, unpersisted_context = self.get_success( self.event_creation_handler.create_new_client_event(builder) ) + context = self.get_success(unpersisted_context.persist(event)) self.get_success( self._storage_controllers.persistence.persist_event(event, context) diff --git a/tests/utils.py b/tests/utils.py index d76bf9716a..15fabbc2d0 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -335,6 +335,9 @@ async def create_room(hs: HomeServer, room_id: str, creator_id: str) -> None: }, ) - event, context = await event_creation_handler.create_new_client_event(builder) + event, unpersisted_context = await event_creation_handler.create_new_client_event( + builder + ) + context = await unpersisted_context.persist(event) await persistence_store.persist_event(event, context) -- cgit 1.5.1 From 6cddf24e361fe43f086307c833cd814dc03363b6 Mon Sep 17 00:00:00 2001 From: Mathieu Velten Date: Sat, 11 Feb 2023 00:31:05 +0100 Subject: Faster joins: don't stall when a user joins during a fast join (#14606) Fixes #12801. Complement tests are at https://github.com/matrix-org/complement/pull/567. Avoid blocking on full state when handling a subsequent join into a partial state room. Also always perform a remote join into partial state rooms, since we do not know whether the joining user has been banned and want to avoid leaking history to banned users. Signed-off-by: Mathieu Velten Co-authored-by: Sean Quah Co-authored-by: David Robertson --- changelog.d/14606.misc | 1 + synapse/api/errors.py | 22 ++++++ synapse/federation/federation_server.py | 2 +- synapse/handlers/event_auth.py | 16 ++--- synapse/handlers/federation.py | 2 +- synapse/handlers/federation_event.py | 59 ++++++++++++++-- synapse/handlers/message.py | 2 +- synapse/handlers/room.py | 2 +- synapse/handlers/room_member.py | 118 ++++++++++++++++++++++--------- synapse/handlers/room_member_worker.py | 5 +- synapse/storage/databases/main/events.py | 21 +----- tests/handlers/test_federation.py | 40 +++++------ 12 files changed, 196 insertions(+), 94 deletions(-) create mode 100644 changelog.d/14606.misc (limited to 'synapse/handlers/message.py') diff --git a/changelog.d/14606.misc b/changelog.d/14606.misc new file mode 100644 index 0000000000..e2debc96d8 --- /dev/null +++ b/changelog.d/14606.misc @@ -0,0 +1 @@ +Faster joins: don't stall when another user joins during a fast join resync. diff --git a/synapse/api/errors.py b/synapse/api/errors.py index c2c177fd71..9235ce6536 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -751,3 +751,25 @@ class ModuleFailedException(Exception): Raised when a module API callback fails, for example because it raised an exception. """ + + +class PartialStateConflictError(SynapseError): + """An internal error raised when attempting to persist an event with partial state + after the room containing the event has been un-partial stated. + + This error should be handled by recomputing the event context and trying again. + + This error has an HTTP status code so that it can be transported over replication. + It should not be exposed to clients. + """ + + @staticmethod + def message() -> str: + return "Cannot persist partial state event in un-partial stated room" + + def __init__(self) -> None: + super().__init__( + HTTPStatus.CONFLICT, + msg=PartialStateConflictError.message(), + errcode=Codes.UNKNOWN, + ) diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 6addc0bb65..6d99845de5 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -48,6 +48,7 @@ from synapse.api.errors import ( FederationError, IncompatibleRoomVersionError, NotFoundError, + PartialStateConflictError, SynapseError, UnsupportedRoomVersionError, ) @@ -81,7 +82,6 @@ from synapse.replication.http.federation import ( ReplicationFederationSendEduRestServlet, ReplicationGetQueryRestServlet, ) -from synapse.storage.databases.main.events import PartialStateConflictError from synapse.storage.databases.main.lock import Lock from synapse.storage.databases.main.roommember import extract_heroes_from_room_summary from synapse.storage.roommember import MemberSummary diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py index a23a8ce2a1..46dd63c3f0 100644 --- a/synapse/handlers/event_auth.py +++ b/synapse/handlers/event_auth.py @@ -202,7 +202,7 @@ class EventAuthHandler: state_ids: StateMap[str], room_version: RoomVersion, user_id: str, - prev_member_event: Optional[EventBase], + prev_membership: Optional[str], ) -> None: """ Check whether a user can join a room without an invite due to restricted join rules. @@ -214,15 +214,14 @@ class EventAuthHandler: state_ids: The state of the room as it currently is. room_version: The room version of the room being joined. user_id: The user joining the room. - prev_member_event: The current membership event for this user. + prev_membership: The current membership state for this user. `None` if the + user has never joined the room (equivalent to "leave"). Raises: AuthError if the user cannot join the room. """ # If the member is invited or currently joined, then nothing to do. - if prev_member_event and ( - prev_member_event.membership in (Membership.JOIN, Membership.INVITE) - ): + if prev_membership in (Membership.JOIN, Membership.INVITE): return # This is not a room with a restricted join rule, so we don't need to do the @@ -255,13 +254,14 @@ class EventAuthHandler: ) async def has_restricted_join_rules( - self, state_ids: StateMap[str], room_version: RoomVersion + self, partial_state_ids: StateMap[str], room_version: RoomVersion ) -> bool: """ Return if the room has the proper join rules set for access via rooms. Args: - state_ids: The state of the room as it currently is. + state_ids: The state of the room as it currently is. May be full or partial + state. room_version: The room version of the room to query. Returns: @@ -272,7 +272,7 @@ class EventAuthHandler: return False # If there's no join rule, then it defaults to invite (so this doesn't apply). - join_rules_event_id = state_ids.get((EventTypes.JoinRules, ""), None) + join_rules_event_id = partial_state_ids.get((EventTypes.JoinRules, ""), None) if not join_rules_event_id: return False diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 43ed4a3dd1..08727e4857 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -49,6 +49,7 @@ from synapse.api.errors import ( FederationPullAttemptBackoffError, HttpResponseException, NotFoundError, + PartialStateConflictError, RequestSendFailed, SynapseError, ) @@ -68,7 +69,6 @@ from synapse.replication.http.federation import ( ReplicationCleanRoomRestServlet, ReplicationStoreRoomOnOutlierMembershipRestServlet, ) -from synapse.storage.databases.main.events import PartialStateConflictError from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.types import JsonDict, StrCollection, get_domain_from_id from synapse.types.state import StateFilter diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 3561f2f1de..b7136f8d1c 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -47,6 +47,7 @@ from synapse.api.errors import ( FederationError, FederationPullAttemptBackoffError, HttpResponseException, + PartialStateConflictError, RequestSendFailed, SynapseError, ) @@ -74,7 +75,6 @@ from synapse.replication.http.federation import ( ReplicationFederationSendEventsRestServlet, ) from synapse.state import StateResolutionStore -from synapse.storage.databases.main.events import PartialStateConflictError from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.types import ( PersistedEventPosition, @@ -441,16 +441,17 @@ class FederationEventHandler: # Check if the user is already in the room or invited to the room. user_id = event.state_key prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None) - prev_member_event = None + prev_membership = None if prev_member_event_id: prev_member_event = await self._store.get_event(prev_member_event_id) + prev_membership = prev_member_event.membership # Check if the member should be allowed access via membership in a space. await self._event_auth_handler.check_restricted_join_rules( prev_state_ids, event.room_version, user_id, - prev_member_event, + prev_membership, ) @trace @@ -526,11 +527,57 @@ class FederationEventHandler: "Peristing join-via-remote %s (partial_state: %s)", event, partial_state ) with nested_logging_context(suffix=event.event_id): + if partial_state: + # When handling a second partial state join into a partial state room, + # the returned state will exclude the membership from the first join. To + # preserve prior memberships, we try to compute the partial state before + # the event ourselves if we know about any of the prev events. + # + # When we don't know about any of the prev events, it's fine to just use + # the returned state, since the new join will create a new forward + # extremity, and leave the forward extremity containing our prior + # memberships alone. + prev_event_ids = set(event.prev_event_ids()) + seen_event_ids = await self._store.have_events_in_timeline( + prev_event_ids + ) + missing_event_ids = prev_event_ids - seen_event_ids + + state_maps_to_resolve: List[StateMap[str]] = [] + + # Fetch the state after the prev events that we know about. + state_maps_to_resolve.extend( + ( + await self._state_storage_controller.get_state_groups_ids( + room_id, seen_event_ids, await_full_state=False + ) + ).values() + ) + + # When there are prev events we do not have the state for, we state + # resolve with the state returned by the remote homeserver. + if missing_event_ids or len(state_maps_to_resolve) == 0: + state_maps_to_resolve.append( + {(e.type, e.state_key): e.event_id for e in state} + ) + + state_ids_before_event = ( + await self._state_resolution_handler.resolve_events_with_store( + event.room_id, + room_version.identifier, + state_maps_to_resolve, + event_map=None, + state_res_store=StateResolutionStore(self._store), + ) + ) + else: + state_ids_before_event = { + (e.type, e.state_key): e.event_id for e in state + } + context = await self._state_handler.compute_event_context( event, - state_ids_before_event={ - (e.type, e.state_key): e.event_id for e in state - }, + state_ids_before_event=state_ids_before_event, partial_state=partial_state, ) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 3e30f52e4d..8f5b658d9d 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -38,6 +38,7 @@ from synapse.api.errors import ( Codes, ConsentNotGivenError, NotFoundError, + PartialStateConflictError, ShadowBanError, SynapseError, UnstableSpecAuthError, @@ -57,7 +58,6 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.http.send_event import ReplicationSendEventRestServlet from synapse.replication.http.send_events import ReplicationSendEventsRestServlet -from synapse.storage.databases.main.events import PartialStateConflictError from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.types import ( MutableStateMap, diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 060bbcb181..837dabb3b7 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -43,6 +43,7 @@ from synapse.api.errors import ( Codes, LimitExceededError, NotFoundError, + PartialStateConflictError, StoreError, SynapseError, ) @@ -54,7 +55,6 @@ from synapse.events.utils import copy_and_fixup_power_levels_contents from synapse.handlers.relations import BundledAggregations from synapse.module_api import NOT_SPAM from synapse.rest.admin._base import assert_user_is_admin -from synapse.storage.databases.main.events import PartialStateConflictError from synapse.streams import EventSource from synapse.types import ( JsonDict, diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 6e7141d2ef..a965c7ec76 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -26,7 +26,13 @@ from synapse.api.constants import ( GuestAccess, Membership, ) -from synapse.api.errors import AuthError, Codes, ShadowBanError, SynapseError +from synapse.api.errors import ( + AuthError, + Codes, + PartialStateConflictError, + ShadowBanError, + SynapseError, +) from synapse.api.ratelimiting import Ratelimiter from synapse.event_auth import get_named_level, get_power_level_event from synapse.events import EventBase @@ -34,7 +40,6 @@ from synapse.events.snapshot import EventContext from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN from synapse.logging import opentracing from synapse.module_api import NOT_SPAM -from synapse.storage.databases.main.events import PartialStateConflictError from synapse.types import ( JsonDict, Requester, @@ -56,6 +61,13 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +class NoKnownServersError(SynapseError): + """No server already resident to the room was provided to the join/knock operation.""" + + def __init__(self, msg: str = "No known servers"): + super().__init__(404, msg) + + class RoomMemberHandler(metaclass=abc.ABCMeta): # TODO(paul): This handler currently contains a messy conflation of # low-level API that works on UserID objects and so on, and REST-level @@ -185,6 +197,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): room_id: Room that we are trying to join user: User who is trying to join content: A dict that should be used as the content of the join event. + + Raises: + NoKnownServersError: if remote_room_hosts does not contain a server joined to + the room. """ raise NotImplementedError() @@ -823,14 +839,19 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): latest_event_ids = await self.store.get_prev_events_for_room(room_id) - state_before_join = await self.state_handler.compute_state_after_events( - room_id, latest_event_ids + is_partial_state_room = await self.store.is_partial_state_room(room_id) + partial_state_before_join = await self.state_handler.compute_state_after_events( + room_id, latest_event_ids, await_full_state=False ) + # `is_partial_state_room` also indicates whether `partial_state_before_join` is + # partial. # TODO: Refactor into dictionary of explicitly allowed transitions # between old and new state, with specific error messages for some # transitions and generic otherwise - old_state_id = state_before_join.get((EventTypes.Member, target.to_string())) + old_state_id = partial_state_before_join.get( + (EventTypes.Member, target.to_string()) + ) if old_state_id: old_state = await self.store.get_event(old_state_id, allow_none=True) old_membership = old_state.content.get("membership") if old_state else None @@ -881,11 +902,11 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): if action == "kick": raise AuthError(403, "The target user is not in the room") - is_host_in_room = await self._is_host_in_room(state_before_join) + is_host_in_room = await self._is_host_in_room(partial_state_before_join) if effective_membership_state == Membership.JOIN: if requester.is_guest: - guest_can_join = await self._can_guest_join(state_before_join) + guest_can_join = await self._can_guest_join(partial_state_before_join) if not guest_can_join: # This should be an auth check, but guests are a local concept, # so don't really fit into the general auth process. @@ -927,8 +948,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): room_id, remote_room_hosts, content, + is_partial_state_room, is_host_in_room, - state_before_join, + partial_state_before_join, ) if remote_join: if ratelimit: @@ -1073,8 +1095,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): room_id: str, remote_room_hosts: List[str], content: JsonDict, + is_partial_state_room: bool, is_host_in_room: bool, - state_before_join: StateMap[str], + partial_state_before_join: StateMap[str], ) -> Tuple[bool, List[str]]: """ Check whether the server should do a remote join (as opposed to a local @@ -1093,9 +1116,12 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): remote_room_hosts: A list of remote room hosts. content: The content to use as the event body of the join. This may be modified. - is_host_in_room: True if the host is in the room. - state_before_join: The state before the join event (i.e. the resolution of - the states after its parent events). + is_partial_state_room: `True` if the server currently doesn't hold the full + state of the room. + is_host_in_room: `True` if the host is in the room. + partial_state_before_join: The state before the join event (i.e. the + resolution of the states after its parent events). May be full or + partial state, depending on `is_partial_state_room`. Returns: A tuple of: @@ -1109,6 +1135,23 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): if not is_host_in_room: return True, remote_room_hosts + prev_member_event_id = partial_state_before_join.get( + (EventTypes.Member, user_id), None + ) + previous_membership = None + if prev_member_event_id: + prev_member_event = await self.store.get_event(prev_member_event_id) + previous_membership = prev_member_event.membership + + # If we are not fully joined yet, and the target is not already in the room, + # let's do a remote join so another server with the full state can validate + # that the user has not been banned for example. + # We could just accept the join and wait for state res to resolve that later on + # but we would then leak room history to this person until then, which is pretty + # bad. + if is_partial_state_room and previous_membership != Membership.JOIN: + return True, remote_room_hosts + # If the host is in the room, but not one of the authorised hosts # for restricted join rules, a remote join must be used. room_version = await self.store.get_room_version(room_id) @@ -1116,21 +1159,19 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # If restricted join rules are not being used, a local join can always # be used. if not await self.event_auth_handler.has_restricted_join_rules( - state_before_join, room_version + partial_state_before_join, room_version ): return False, [] # If the user is invited to the room or already joined, the join # event can always be issued locally. - prev_member_event_id = state_before_join.get((EventTypes.Member, user_id), None) - prev_member_event = None - if prev_member_event_id: - prev_member_event = await self.store.get_event(prev_member_event_id) - if prev_member_event.membership in ( - Membership.JOIN, - Membership.INVITE, - ): - return False, [] + if previous_membership in (Membership.JOIN, Membership.INVITE): + return False, [] + + # All the partial state cases are covered above. We have been given the full + # state of the room. + assert not is_partial_state_room + state_before_join = partial_state_before_join # If the local host has a user who can issue invites, then a local # join can be done. @@ -1154,7 +1195,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # Ensure the member should be allowed access via membership in a room. await self.event_auth_handler.check_restricted_join_rules( - state_before_join, room_version, user_id, prev_member_event + state_before_join, room_version, user_id, previous_membership ) # If this is going to be a local join, additional information must @@ -1304,11 +1345,17 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): if prev_member_event.membership == Membership.JOIN: await self._user_left_room(target_user, room_id) - async def _can_guest_join(self, current_state_ids: StateMap[str]) -> bool: + async def _can_guest_join(self, partial_current_state_ids: StateMap[str]) -> bool: """ Returns whether a guest can join a room based on its current state. + + Args: + partial_current_state_ids: The current state of the room. May be full or + partial state. """ - guest_access_id = current_state_ids.get((EventTypes.GuestAccess, ""), None) + guest_access_id = partial_current_state_ids.get( + (EventTypes.GuestAccess, ""), None + ) if not guest_access_id: return False @@ -1634,19 +1681,25 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): ) return event, stream_id - async def _is_host_in_room(self, current_state_ids: StateMap[str]) -> bool: + async def _is_host_in_room(self, partial_current_state_ids: StateMap[str]) -> bool: + """Returns whether the homeserver is in the room based on its current state. + + Args: + partial_current_state_ids: The current state of the room. May be full or + partial state. + """ # Have we just created the room, and is this about to be the very # first member event? - create_event_id = current_state_ids.get(("m.room.create", "")) - if len(current_state_ids) == 1 and create_event_id: + create_event_id = partial_current_state_ids.get(("m.room.create", "")) + if len(partial_current_state_ids) == 1 and create_event_id: # We can only get here if we're in the process of creating the room return True - for etype, state_key in current_state_ids: + for etype, state_key in partial_current_state_ids: if etype != EventTypes.Member or not self.hs.is_mine_id(state_key): continue - event_id = current_state_ids[(etype, state_key)] + event_id = partial_current_state_ids[(etype, state_key)] event = await self.store.get_event(event_id, allow_none=True) if not event: continue @@ -1715,8 +1768,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): ] if len(remote_room_hosts) == 0: - raise SynapseError( - 404, + raise NoKnownServersError( "Can't join remote room because no servers " "that are in the room have been provided.", ) @@ -1947,7 +1999,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): ] if len(remote_room_hosts) == 0: - raise SynapseError(404, "No known servers") + raise NoKnownServersError() return await self.federation_handler.do_knock( remote_room_hosts, room_id, user.to_string(), content=content diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py index 221552a2a6..ba261702d4 100644 --- a/synapse/handlers/room_member_worker.py +++ b/synapse/handlers/room_member_worker.py @@ -15,8 +15,7 @@ import logging from typing import TYPE_CHECKING, List, Optional, Tuple -from synapse.api.errors import SynapseError -from synapse.handlers.room_member import RoomMemberHandler +from synapse.handlers.room_member import NoKnownServersError, RoomMemberHandler from synapse.replication.http.membership import ( ReplicationRemoteJoinRestServlet as ReplRemoteJoin, ReplicationRemoteKnockRestServlet as ReplRemoteKnock, @@ -52,7 +51,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler): ) -> Tuple[str, int]: """Implements RoomMemberHandler._remote_join""" if len(remote_room_hosts) == 0: - raise SynapseError(404, "No known servers") + raise NoKnownServersError() ret = await self._remote_join_client( requester=requester, diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index cb66376fb4..ffe766fd56 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -16,7 +16,6 @@ import itertools import logging from collections import OrderedDict -from http import HTTPStatus from typing import ( TYPE_CHECKING, Any, @@ -36,7 +35,7 @@ from prometheus_client import Counter import synapse.metrics from synapse.api.constants import EventContentFields, EventTypes, RelationTypes -from synapse.api.errors import Codes, SynapseError +from synapse.api.errors import PartialStateConflictError from synapse.api.room_versions import RoomVersions from synapse.events import EventBase, relation_from_event from synapse.events.snapshot import EventContext @@ -72,24 +71,6 @@ event_counter = Counter( ) -class PartialStateConflictError(SynapseError): - """An internal error raised when attempting to persist an event with partial state - after the room containing the event has been un-partial stated. - - This error should be handled by recomputing the event context and trying again. - - This error has an HTTP status code so that it can be transported over replication. - It should not be exposed to clients. - """ - - def __init__(self) -> None: - super().__init__( - HTTPStatus.CONFLICT, - msg="Cannot persist partial state event in un-partial stated room", - errcode=Codes.UNKNOWN, - ) - - @attr.s(slots=True, auto_attribs=True) class DeltaState: """Deltas to use to update the `current_state_events` table. diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index 57675fa407..5868eb2da7 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -575,26 +575,6 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase): fed_client = fed_handler.federation_client room_id = "!room:example.com" - membership_event = make_event_from_dict( - { - "room_id": room_id, - "type": "m.room.member", - "sender": "@alice:test", - "state_key": "@alice:test", - "content": {"membership": "join"}, - }, - RoomVersions.V10, - ) - - mock_make_membership_event = Mock( - return_value=make_awaitable( - ( - "example.com", - membership_event, - RoomVersions.V10, - ) - ) - ) EVENT_CREATE = make_event_from_dict( { @@ -640,6 +620,26 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase): }, room_version=RoomVersions.V10, ) + membership_event = make_event_from_dict( + { + "room_id": room_id, + "type": "m.room.member", + "sender": "@alice:test", + "state_key": "@alice:test", + "content": {"membership": "join"}, + "prev_events": [EVENT_INVITATION_MEMBERSHIP.event_id], + }, + RoomVersions.V10, + ) + mock_make_membership_event = Mock( + return_value=make_awaitable( + ( + "example.com", + membership_event, + RoomVersions.V10, + ) + ) + ) mock_send_join = Mock( return_value=make_awaitable( SendJoinResult( -- cgit 1.5.1 From 5febf88b6c5194582f427142dc0850625547c0d9 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 15 Feb 2023 11:47:57 +0000 Subject: Update the error code for duplicate annotation (#15075) --- changelog.d/15075.feature | 2 ++ synapse/api/errors.py | 4 ++++ synapse/handlers/message.py | 6 +++++- 3 files changed, 11 insertions(+), 1 deletion(-) create mode 100644 changelog.d/15075.feature (limited to 'synapse/handlers/message.py') diff --git a/changelog.d/15075.feature b/changelog.d/15075.feature new file mode 100644 index 0000000000..d25a7567a4 --- /dev/null +++ b/changelog.d/15075.feature @@ -0,0 +1,2 @@ +Update the error code returned when user sends a duplicate annotation. + diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 9235ce6536..e1737de59b 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -108,6 +108,10 @@ class Codes(str, Enum): USER_AWAITING_APPROVAL = "ORG.MATRIX.MSC3866_USER_AWAITING_APPROVAL" + # Attempt to send a second annotation with the same event type & annotation key + # MSC2677 + DUPLICATE_ANNOTATION = "M_DUPLICATE_ANNOTATION" + class CodeMessageException(RuntimeError): """An exception with integer code and message string attributes. diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 8f5b658d9d..aa90d0000d 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1337,7 +1337,11 @@ class EventCreationHandler: relation.parent_id, event.type, aggregation_key, event.sender ) if already_exists: - raise SynapseError(400, "Can't send same reaction twice") + raise SynapseError( + 400, + "Can't send same reaction twice", + errcode=Codes.DUPLICATE_ANNOTATION, + ) # Don't attempt to start a thread if the parent event is a relation. elif relation.rel_type == RelationTypes.THREAD: -- cgit 1.5.1 From 1c95ddd09bbc46046a3412e7bb03a87aa3b6f65a Mon Sep 17 00:00:00 2001 From: Shay Date: Fri, 24 Feb 2023 13:15:29 -0800 Subject: Batch up storing state groups when creating new room (#14918) --- changelog.d/14918.misc | 1 + synapse/events/snapshot.py | 49 +++++++++++ synapse/handlers/message.py | 16 ++-- synapse/handlers/room.py | 37 ++++---- synapse/handlers/room_batch.py | 4 +- synapse/handlers/room_member.py | 13 ++- synapse/storage/databases/state/store.py | 119 ++++++++++++++++++++++++++ tests/handlers/test_message.py | 25 ++++-- tests/handlers/test_register.py | 3 +- tests/push/test_bulk_push_rule_evaluator.py | 13 +-- tests/rest/client/test_rooms.py | 4 +- tests/storage/test_event_chain.py | 6 +- tests/storage/test_state.py | 126 ++++++++++++++++++++++++++++ tests/unittest.py | 4 +- 14 files changed, 371 insertions(+), 49 deletions(-) create mode 100644 changelog.d/14918.misc (limited to 'synapse/handlers/message.py') diff --git a/changelog.d/14918.misc b/changelog.d/14918.misc new file mode 100644 index 0000000000..828794354a --- /dev/null +++ b/changelog.d/14918.misc @@ -0,0 +1 @@ +Batch up storing state groups when creating a new room. \ No newline at end of file diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index e0d82ad81c..a91a5d1e3c 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -23,6 +23,7 @@ from synapse.types import JsonDict, StateMap if TYPE_CHECKING: from synapse.storage.controllers import StorageControllers + from synapse.storage.databases import StateGroupDataStore from synapse.storage.databases.main import DataStore from synapse.types.state import StateFilter @@ -348,6 +349,54 @@ class UnpersistedEventContext(UnpersistedEventContextBase): partial_state: bool state_map_before_event: Optional[StateMap[str]] = None + @classmethod + async def batch_persist_unpersisted_contexts( + cls, + events_and_context: List[Tuple[EventBase, "UnpersistedEventContextBase"]], + room_id: str, + last_known_state_group: int, + datastore: "StateGroupDataStore", + ) -> List[Tuple[EventBase, EventContext]]: + """ + Takes a list of events and their associated unpersisted contexts and persists + the unpersisted contexts, returning a list of events and persisted contexts. + Note that all the events must be in a linear chain (ie a <- b <- c). + + Args: + events_and_context: A list of events and their unpersisted contexts + room_id: the room_id for the events + last_known_state_group: the last persisted state group + datastore: a state datastore + """ + amended_events_and_context = await datastore.store_state_deltas_for_batched( + events_and_context, room_id, last_known_state_group + ) + + events_and_persisted_context = [] + for event, unpersisted_context in amended_events_and_context: + if event.is_state(): + context = EventContext( + storage=unpersisted_context._storage, + state_group=unpersisted_context.state_group_after_event, + state_group_before_event=unpersisted_context.state_group_before_event, + state_delta_due_to_event=unpersisted_context.state_delta_due_to_event, + partial_state=unpersisted_context.partial_state, + prev_group=unpersisted_context.state_group_before_event, + delta_ids=unpersisted_context.state_delta_due_to_event, + ) + else: + context = EventContext( + storage=unpersisted_context._storage, + state_group=unpersisted_context.state_group_after_event, + state_group_before_event=unpersisted_context.state_group_before_event, + state_delta_due_to_event=unpersisted_context.state_delta_due_to_event, + partial_state=unpersisted_context.partial_state, + prev_group=unpersisted_context.prev_group_for_state_group_before_event, + delta_ids=unpersisted_context.delta_ids_to_state_group_before_event, + ) + events_and_persisted_context.append((event, context)) + return events_and_persisted_context + async def get_prev_state_ids( self, state_filter: Optional["StateFilter"] = None ) -> StateMap[str]: diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index aa90d0000d..e433d6b01f 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -574,7 +574,7 @@ class EventCreationHandler: state_map: Optional[StateMap[str]] = None, for_batch: bool = False, current_state_group: Optional[int] = None, - ) -> Tuple[EventBase, EventContext]: + ) -> Tuple[EventBase, UnpersistedEventContextBase]: """ Given a dict from a client, create a new event. If bool for_batch is true, will create an event using the prev_event_ids, and will create an event context for @@ -721,8 +721,6 @@ class EventCreationHandler: current_state_group=current_state_group, ) - context = await unpersisted_context.persist(event) - # In an ideal world we wouldn't need the second part of this condition. However, # this behaviour isn't spec'd yet, meaning we should be able to deactivate this # behaviour. Another reason is that this code is also evaluated each time a new @@ -739,7 +737,7 @@ class EventCreationHandler: assert state_map is not None prev_event_id = state_map.get((EventTypes.Member, event.sender)) else: - prev_state_ids = await context.get_prev_state_ids( + prev_state_ids = await unpersisted_context.get_prev_state_ids( StateFilter.from_types([(EventTypes.Member, None)]) ) prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender)) @@ -764,8 +762,7 @@ class EventCreationHandler: ) self.validator.validate_new(event, self.config) - - return event, context + return event, unpersisted_context async def _is_exempt_from_privacy_policy( self, builder: EventBuilder, requester: Requester @@ -1005,7 +1002,7 @@ class EventCreationHandler: max_retries = 5 for i in range(max_retries): try: - event, context = await self.create_event( + event, unpersisted_context = await self.create_event( requester, event_dict, txn_id=txn_id, @@ -1016,6 +1013,7 @@ class EventCreationHandler: historical=historical, depth=depth, ) + context = await unpersisted_context.persist(event) assert self.hs.is_mine_id(event.sender), "User must be our own: %s" % ( event.sender, @@ -1190,7 +1188,6 @@ class EventCreationHandler: if for_batch: assert prev_event_ids is not None assert state_map is not None - assert current_state_group is not None auth_ids = self._event_auth_handler.compute_auth_events(builder, state_map) event = await builder.build( prev_event_ids=prev_event_ids, auth_event_ids=auth_ids, depth=depth @@ -2046,7 +2043,7 @@ class EventCreationHandler: max_retries = 5 for i in range(max_retries): try: - event, context = await self.create_event( + event, unpersisted_context = await self.create_event( requester, { "type": EventTypes.Dummy, @@ -2055,6 +2052,7 @@ class EventCreationHandler: "sender": user_id, }, ) + context = await unpersisted_context.persist(event) event.internal_metadata.proactively_send = False diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index a26ec02284..b1784638f4 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -51,6 +51,7 @@ from synapse.api.filtering import Filter from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.event_auth import validate_event_for_room_version from synapse.events import EventBase +from synapse.events.snapshot import UnpersistedEventContext from synapse.events.utils import copy_and_fixup_power_levels_contents from synapse.handlers.relations import BundledAggregations from synapse.module_api import NOT_SPAM @@ -211,7 +212,7 @@ class RoomCreationHandler: # the required power level to send the tombstone event. ( tombstone_event, - tombstone_context, + tombstone_unpersisted_context, ) = await self.event_creation_handler.create_event( requester, { @@ -225,6 +226,9 @@ class RoomCreationHandler: }, }, ) + tombstone_context = await tombstone_unpersisted_context.persist( + tombstone_event + ) validate_event_for_room_version(tombstone_event) await self._event_auth_handler.check_auth_rules_from_context( tombstone_event @@ -1092,7 +1096,7 @@ class RoomCreationHandler: content: JsonDict, for_batch: bool, **kwargs: Any, - ) -> Tuple[EventBase, synapse.events.snapshot.EventContext]: + ) -> Tuple[EventBase, synapse.events.snapshot.UnpersistedEventContextBase]: """ Creates an event and associated event context. Args: @@ -1111,20 +1115,23 @@ class RoomCreationHandler: event_dict = create_event_dict(etype, content, **kwargs) - new_event, new_context = await self.event_creation_handler.create_event( + ( + new_event, + new_unpersisted_context, + ) = await self.event_creation_handler.create_event( creator, event_dict, prev_event_ids=prev_event, depth=depth, state_map=state_map, for_batch=for_batch, - current_state_group=current_state_group, ) + depth += 1 prev_event = [new_event.event_id] state_map[(new_event.type, new_event.state_key)] = new_event.event_id - return new_event, new_context + return new_event, new_unpersisted_context try: config = self._presets_dict[preset_config] @@ -1134,10 +1141,10 @@ class RoomCreationHandler: ) creation_content.update({"creator": creator_id}) - creation_event, creation_context = await create_event( + creation_event, unpersisted_creation_context = await create_event( EventTypes.Create, creation_content, False ) - + creation_context = await unpersisted_creation_context.persist(creation_event) logger.debug("Sending %s in new room", EventTypes.Member) ev = await self.event_creation_handler.handle_new_client_event( requester=creator, @@ -1181,7 +1188,6 @@ class RoomCreationHandler: power_event, power_context = await create_event( EventTypes.PowerLevels, pl_content, True ) - current_state_group = power_context._state_group events_to_send.append((power_event, power_context)) else: power_level_content: JsonDict = { @@ -1230,14 +1236,12 @@ class RoomCreationHandler: power_level_content, True, ) - current_state_group = pl_context._state_group events_to_send.append((pl_event, pl_context)) if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state: room_alias_event, room_alias_context = await create_event( EventTypes.CanonicalAlias, {"alias": room_alias.to_string()}, True ) - current_state_group = room_alias_context._state_group events_to_send.append((room_alias_event, room_alias_context)) if (EventTypes.JoinRules, "") not in initial_state: @@ -1246,7 +1250,6 @@ class RoomCreationHandler: {"join_rule": config["join_rules"]}, True, ) - current_state_group = join_rules_context._state_group events_to_send.append((join_rules_event, join_rules_context)) if (EventTypes.RoomHistoryVisibility, "") not in initial_state: @@ -1255,7 +1258,6 @@ class RoomCreationHandler: {"history_visibility": config["history_visibility"]}, True, ) - current_state_group = visibility_context._state_group events_to_send.append((visibility_event, visibility_context)) if config["guest_can_join"]: @@ -1265,14 +1267,12 @@ class RoomCreationHandler: {EventContentFields.GUEST_ACCESS: GuestAccess.CAN_JOIN}, True, ) - current_state_group = guest_access_context._state_group events_to_send.append((guest_access_event, guest_access_context)) for (etype, state_key), content in initial_state.items(): event, context = await create_event( etype, content, True, state_key=state_key ) - current_state_group = context._state_group events_to_send.append((event, context)) if config["encrypted"]: @@ -1284,9 +1284,16 @@ class RoomCreationHandler: ) events_to_send.append((encryption_event, encryption_context)) + datastore = self.hs.get_datastores().state + events_and_context = ( + await UnpersistedEventContext.batch_persist_unpersisted_contexts( + events_to_send, room_id, current_state_group, datastore + ) + ) + last_event = await self.event_creation_handler.handle_new_client_event( creator, - events_to_send, + events_and_context, ignore_shadow_ban=True, ratelimit=False, ) diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py index 5d4ca0e2d2..bf9df60218 100644 --- a/synapse/handlers/room_batch.py +++ b/synapse/handlers/room_batch.py @@ -327,7 +327,7 @@ class RoomBatchHandler: # Mark all events as historical event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True - event, context = await self.event_creation_handler.create_event( + event, unpersisted_context = await self.event_creation_handler.create_event( await self.create_requester_for_user_id_from_app_service( ev["sender"], app_service_requester.app_service ), @@ -345,7 +345,7 @@ class RoomBatchHandler: historical=True, depth=inherited_depth, ) - + context = await unpersisted_context.persist(event) assert context._state_group # Normally this is done when persisting the event but we have to diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index a965c7ec76..de7476f300 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -414,7 +414,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): max_retries = 5 for i in range(max_retries): try: - event, context = await self.event_creation_handler.create_event( + ( + event, + unpersisted_context, + ) = await self.event_creation_handler.create_event( requester, { "type": EventTypes.Member, @@ -435,7 +438,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): outlier=outlier, historical=historical, ) - + context = await unpersisted_context.persist(event) prev_state_ids = await context.get_prev_state_ids( StateFilter.from_types([(EventTypes.Member, None)]) ) @@ -1944,7 +1947,10 @@ class RoomMemberMasterHandler(RoomMemberHandler): max_retries = 5 for i in range(max_retries): try: - event, context = await self.event_creation_handler.create_event( + ( + event, + unpersisted_context, + ) = await self.event_creation_handler.create_event( requester, event_dict, txn_id=txn_id, @@ -1952,6 +1958,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): auth_event_ids=auth_event_ids, outlier=True, ) + context = await unpersisted_context.persist(event) event.internal_metadata.out_of_band_membership = True result_event = ( diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index 89b1faa6c8..bf4cdfdf29 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -18,6 +18,8 @@ from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Se import attr from synapse.api.constants import EventTypes +from synapse.events import EventBase +from synapse.events.snapshot import UnpersistedEventContext, UnpersistedEventContextBase from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( DatabasePool, @@ -401,6 +403,123 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): fetched_keys=non_member_types, ) + async def store_state_deltas_for_batched( + self, + events_and_context: List[Tuple[EventBase, UnpersistedEventContextBase]], + room_id: str, + prev_group: int, + ) -> List[Tuple[EventBase, UnpersistedEventContext]]: + """Generate and store state deltas for a group of events and contexts created to be + batch persisted. Note that all the events must be in a linear chain (ie a <- b <- c). + + Args: + events_and_context: the events to generate and store a state groups for + and their associated contexts + room_id: the id of the room the events were created for + prev_group: the state group of the last event persisted before the batched events + were created + """ + + def insert_deltas_group_txn( + txn: LoggingTransaction, + events_and_context: List[Tuple[EventBase, UnpersistedEventContext]], + prev_group: int, + ) -> List[Tuple[EventBase, UnpersistedEventContext]]: + """Generate and store state groups for the provided events and contexts. + + Requires that we have the state as a delta from the last persisted state group. + + Returns: + A list of state groups + """ + is_in_db = self.db_pool.simple_select_one_onecol_txn( + txn, + table="state_groups", + keyvalues={"id": prev_group}, + retcol="id", + allow_none=True, + ) + if not is_in_db: + raise Exception( + "Trying to persist state with unpersisted prev_group: %r" + % (prev_group,) + ) + + num_state_groups = sum( + 1 for event, _ in events_and_context if event.is_state() + ) + + state_groups = self._state_group_seq_gen.get_next_mult_txn( + txn, num_state_groups + ) + + sg_before = prev_group + state_group_iter = iter(state_groups) + for event, context in events_and_context: + if not event.is_state(): + context.state_group_after_event = sg_before + context.state_group_before_event = sg_before + continue + + sg_after = next(state_group_iter) + context.state_group_after_event = sg_after + context.state_group_before_event = sg_before + context.state_delta_due_to_event = { + (event.type, event.state_key): event.event_id + } + sg_before = sg_after + + self.db_pool.simple_insert_many_txn( + txn, + table="state_groups", + keys=("id", "room_id", "event_id"), + values=[ + (context.state_group_after_event, room_id, event.event_id) + for event, context in events_and_context + if event.is_state() + ], + ) + + self.db_pool.simple_insert_many_txn( + txn, + table="state_group_edges", + keys=("state_group", "prev_state_group"), + values=[ + ( + context.state_group_after_event, + context.state_group_before_event, + ) + for event, context in events_and_context + if event.is_state() + ], + ) + + self.db_pool.simple_insert_many_txn( + txn, + table="state_groups_state", + keys=("state_group", "room_id", "type", "state_key", "event_id"), + values=[ + ( + context.state_group_after_event, + room_id, + key[0], + key[1], + state_id, + ) + for event, context in events_and_context + if context.state_delta_due_to_event is not None + for key, state_id in context.state_delta_due_to_event.items() + ], + ) + return events_and_context + + return await self.db_pool.runInteraction( + "store_state_deltas_for_batched.insert_deltas_group", + insert_deltas_group_txn, + events_and_context, + prev_group, + ) + async def store_state_group( self, event_id: str, diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py index 69d384442f..9691d66b48 100644 --- a/tests/handlers/test_message.py +++ b/tests/handlers/test_message.py @@ -18,7 +18,7 @@ from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import EventTypes from synapse.events import EventBase -from synapse.events.snapshot import EventContext +from synapse.events.snapshot import EventContext, UnpersistedEventContextBase from synapse.rest import admin from synapse.rest.client import login, room from synapse.server import HomeServer @@ -79,7 +79,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase): return memberEvent, memberEventContext - def _create_duplicate_event(self, txn_id: str) -> Tuple[EventBase, EventContext]: + def _create_duplicate_event( + self, txn_id: str + ) -> Tuple[EventBase, UnpersistedEventContextBase]: """Create a new event with the given transaction ID. All events produced by this method will be considered duplicates. """ @@ -107,7 +109,8 @@ class EventCreationTestCase(unittest.HomeserverTestCase): txn_id = "something_suitably_random" - event1, context = self._create_duplicate_event(txn_id) + event1, unpersisted_context = self._create_duplicate_event(txn_id) + context = self.get_success(unpersisted_context.persist(event1)) ret_event1 = self.get_success( self.handler.handle_new_client_event( @@ -119,7 +122,8 @@ class EventCreationTestCase(unittest.HomeserverTestCase): self.assertEqual(event1.event_id, ret_event1.event_id) - event2, context = self._create_duplicate_event(txn_id) + event2, unpersisted_context = self._create_duplicate_event(txn_id) + context = self.get_success(unpersisted_context.persist(event2)) # We want to test that the deduplication at the persit event end works, # so we want to make sure we test with different events. @@ -140,7 +144,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase): # Let's test that calling `persist_event` directly also does the right # thing. - event3, context = self._create_duplicate_event(txn_id) + event3, unpersisted_context = self._create_duplicate_event(txn_id) + context = self.get_success(unpersisted_context.persist(event3)) + self.assertNotEqual(event1.event_id, event3.event_id) ret_event3, event_pos3, _ = self.get_success( @@ -154,7 +160,8 @@ class EventCreationTestCase(unittest.HomeserverTestCase): # Let's test that calling `persist_events` directly also does the right # thing. - event4, context = self._create_duplicate_event(txn_id) + event4, unpersisted_context = self._create_duplicate_event(txn_id) + context = self.get_success(unpersisted_context.persist(event4)) self.assertNotEqual(event1.event_id, event3.event_id) events, _ = self.get_success( @@ -174,8 +181,10 @@ class EventCreationTestCase(unittest.HomeserverTestCase): txn_id = "something_else_suitably_random" # Create two duplicate events to persist at the same time - event1, context1 = self._create_duplicate_event(txn_id) - event2, context2 = self._create_duplicate_event(txn_id) + event1, unpersisted_context1 = self._create_duplicate_event(txn_id) + context1 = self.get_success(unpersisted_context1.persist(event1)) + event2, unpersisted_context2 = self._create_duplicate_event(txn_id) + context2 = self.get_success(unpersisted_context2.persist(event2)) # Ensure their event IDs are different to start with self.assertNotEqual(event1.event_id, event2.event_id) diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 1db99b3c00..aff1ec4758 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -507,7 +507,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): # Lower the permissions of the inviter. event_creation_handler = self.hs.get_event_creation_handler() requester = create_requester(inviter) - event, context = self.get_success( + event, unpersisted_context = self.get_success( event_creation_handler.create_event( requester, { @@ -519,6 +519,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): }, ) ) + context = self.get_success(unpersisted_context.persist(event)) self.get_success( event_creation_handler.handle_new_client_event( requester, events_and_context=[(event, context)] diff --git a/tests/push/test_bulk_push_rule_evaluator.py b/tests/push/test_bulk_push_rule_evaluator.py index dce6899e78..1458076a90 100644 --- a/tests/push/test_bulk_push_rule_evaluator.py +++ b/tests/push/test_bulk_push_rule_evaluator.py @@ -130,7 +130,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): # Create a new message event, and try to evaluate it under the dodgy # power level event. - event, context = self.get_success( + event, unpersisted_context = self.get_success( self.event_creation_handler.create_event( self.requester, { @@ -145,6 +145,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): prev_event_ids=[pl_event_id], ) ) + context = self.get_success(unpersisted_context.persist(event)) bulk_evaluator = BulkPushRuleEvaluator(self.hs) # should not raise @@ -170,7 +171,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): """Ensure that push rules are not calculated when disabled in the config""" # Create a new message event which should cause a notification. - event, context = self.get_success( + event, unpersisted_context = self.get_success( self.event_creation_handler.create_event( self.requester, { @@ -184,6 +185,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): }, ) ) + context = self.get_success(unpersisted_context.persist(event)) bulk_evaluator = BulkPushRuleEvaluator(self.hs) # Mock the method which calculates push rules -- we do this instead of @@ -200,7 +202,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): ) -> bool: """Returns true iff the `mentions` trigger an event push action.""" # Create a new message event which should cause a notification. - event, context = self.get_success( + event, unpersisted_context = self.get_success( self.event_creation_handler.create_event( self.requester, { @@ -211,7 +213,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): }, ) ) - + context = self.get_success(unpersisted_context.persist(event)) # Execute the push rule machinery. self.get_success(bulk_evaluator.action_for_events_by_user([(event, context)])) @@ -390,7 +392,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): bulk_evaluator = BulkPushRuleEvaluator(self.hs) # Create & persist an event to use as the parent of the relation. - event, context = self.get_success( + event, unpersisted_context = self.get_success( self.event_creation_handler.create_event( self.requester, { @@ -404,6 +406,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): }, ) ) + context = self.get_success(unpersisted_context.persist(event)) self.get_success( self.event_creation_handler.handle_new_client_event( self.requester, events_and_context=[(event, context)] diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 4dd763096d..a4900703c4 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -713,7 +713,7 @@ class RoomsCreateTestCase(RoomBase): self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) assert channel.resource_usage is not None - self.assertEqual(33, channel.resource_usage.db_txn_count) + self.assertEqual(30, channel.resource_usage.db_txn_count) def test_post_room_initial_state(self) -> None: # POST with initial_state config key, expect new room id @@ -726,7 +726,7 @@ class RoomsCreateTestCase(RoomBase): self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) assert channel.resource_usage is not None - self.assertEqual(36, channel.resource_usage.db_txn_count) + self.assertEqual(32, channel.resource_usage.db_txn_count) def test_post_room_visibility_key(self) -> None: # POST with visibility config key, expect new room id diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py index 73d11e7786..e39b63edac 100644 --- a/tests/storage/test_event_chain.py +++ b/tests/storage/test_event_chain.py @@ -522,7 +522,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): latest_event_ids = self.get_success( self.store.get_prev_events_for_room(room_id) ) - event, context = self.get_success( + event, unpersisted_context = self.get_success( event_handler.create_event( self.requester, { @@ -535,6 +535,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): prev_event_ids=latest_event_ids, ) ) + context = self.get_success(unpersisted_context.persist(event)) self.get_success( event_handler.handle_new_client_event( self.requester, events_and_context=[(event, context)] @@ -544,7 +545,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): assert state_ids1 is not None state1 = set(state_ids1.values()) - event, context = self.get_success( + event, unpersisted_context = self.get_success( event_handler.create_event( self.requester, { @@ -557,6 +558,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): prev_event_ids=latest_event_ids, ) ) + context = self.get_success(unpersisted_context.persist(event)) self.get_success( event_handler.handle_new_client_event( self.requester, events_and_context=[(event, context)] diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index e82c03f597..62aed6af0a 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -496,3 +496,129 @@ class StateStoreTestCase(HomeserverTestCase): self.assertEqual(is_all, True) self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict) + + def test_batched_state_group_storing(self) -> None: + creation_event = self.inject_state_event( + self.room, self.u_alice, EventTypes.Create, "", {} + ) + state_to_event = self.get_success( + self.storage.state.get_state_groups( + self.room.to_string(), [creation_event.event_id] + ) + ) + current_state_group = list(state_to_event.keys())[0] + + # create some unpersisted events and event contexts to store against room + events_and_context = [] + builder = self.event_builder_factory.for_room_version( + RoomVersions.V1, + { + "type": EventTypes.Name, + "sender": self.u_alice.to_string(), + "state_key": "", + "room_id": self.room.to_string(), + "content": {"name": "first rename of room"}, + }, + ) + + event1, unpersisted_context1 = self.get_success( + self.event_creation_handler.create_new_client_event(builder) + ) + events_and_context.append((event1, unpersisted_context1)) + + builder2 = self.event_builder_factory.for_room_version( + RoomVersions.V1, + { + "type": EventTypes.JoinRules, + "sender": self.u_alice.to_string(), + "state_key": "", + "room_id": self.room.to_string(), + "content": {"join_rule": "private"}, + }, + ) + + event2, unpersisted_context2 = self.get_success( + self.event_creation_handler.create_new_client_event(builder2) + ) + events_and_context.append((event2, unpersisted_context2)) + + builder3 = self.event_builder_factory.for_room_version( + RoomVersions.V1, + { + "type": EventTypes.Message, + "sender": self.u_alice.to_string(), + "room_id": self.room.to_string(), + "content": {"body": "hello from event 3", "msgtype": "m.text"}, + }, + ) + + event3, unpersisted_context3 = self.get_success( + self.event_creation_handler.create_new_client_event(builder3) + ) + events_and_context.append((event3, unpersisted_context3)) + + builder4 = self.event_builder_factory.for_room_version( + RoomVersions.V1, + { + "type": EventTypes.JoinRules, + "sender": self.u_alice.to_string(), + "state_key": "", + "room_id": self.room.to_string(), + "content": {"join_rule": "public"}, + }, + ) + + event4, unpersisted_context4 = self.get_success( + self.event_creation_handler.create_new_client_event(builder4) + ) + events_and_context.append((event4, unpersisted_context4)) + + processed_events_and_context = self.get_success( + self.hs.get_datastores().state.store_state_deltas_for_batched( + events_and_context, self.room.to_string(), current_state_group + ) + ) + + # check that only state events are in state_groups, and all state events are in state_groups + res = self.get_success( + self.store.db_pool.simple_select_list( + table="state_groups", + keyvalues=None, + retcols=("event_id",), + ) + ) + + events = [] + for result in res: + self.assertNotIn(event3.event_id, result) + events.append(result.get("event_id")) + + for event, _ in processed_events_and_context: + if event.is_state(): + self.assertIn(event.event_id, events) + + # check that each unique state has state group in state_groups_state and that the + # type/state key is correct, and check that each state event's state group + # has an entry and prev event in state_group_edges + for event, context in processed_events_and_context: + if event.is_state(): + state = self.get_success( + self.store.db_pool.simple_select_list( + table="state_groups_state", + keyvalues={"state_group": context.state_group_after_event}, + retcols=("type", "state_key"), + ) + ) + self.assertEqual(event.type, state[0].get("type")) + self.assertEqual(event.state_key, state[0].get("state_key")) + + groups = self.get_success( + self.store.db_pool.simple_select_list( + table="state_group_edges", + keyvalues={"state_group": str(context.state_group_after_event)}, + retcols=("*",), + ) + ) + self.assertEqual( + context.state_group_before_event, groups[0].get("prev_state_group") + ) diff --git a/tests/unittest.py b/tests/unittest.py index b21e7f1221..f9160faa1d 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -723,7 +723,7 @@ class HomeserverTestCase(TestCase): event_creator = self.hs.get_event_creation_handler() requester = create_requester(user) - event, context = self.get_success( + event, unpersisted_context = self.get_success( event_creator.create_event( requester, { @@ -735,7 +735,7 @@ class HomeserverTestCase(TestCase): prev_event_ids=prev_event_ids, ) ) - + context = self.get_success(unpersisted_context.persist(event)) if soft_failed: event.internal_metadata.soft_failed = True -- cgit 1.5.1 From 41f127e06861230024f43aa4ce272116dc886700 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Mon, 6 Mar 2023 17:08:39 +0100 Subject: Pass the requester during event serialization. (#15174) This allows Synapse to properly include the transaction ID in the unsigned data of events. --- changelog.d/15174.bugfix | 1 + synapse/events/utils.py | 30 +++++++++++++++------ synapse/handlers/events.py | 20 +++++++------- synapse/handlers/initial_sync.py | 51 +++++++++++++++++++++++++----------- synapse/handlers/message.py | 9 ++++--- synapse/handlers/pagination.py | 4 ++- synapse/handlers/relations.py | 12 +++++++-- synapse/handlers/search.py | 43 +++++++++++++++++++----------- synapse/rest/client/events.py | 16 ++++++----- synapse/rest/client/notifications.py | 12 ++++++--- synapse/rest/client/room.py | 18 +++++++++---- synapse/rest/client/sync.py | 10 +++---- 12 files changed, 151 insertions(+), 75 deletions(-) create mode 100644 changelog.d/15174.bugfix (limited to 'synapse/handlers/message.py') diff --git a/changelog.d/15174.bugfix b/changelog.d/15174.bugfix new file mode 100644 index 0000000000..a0c70cbe22 --- /dev/null +++ b/changelog.d/15174.bugfix @@ -0,0 +1 @@ +Add the `transaction_id` in the events included in many endpoints responses. diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 45f46949a1..b9c15ffcdb 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -38,7 +38,7 @@ from synapse.api.constants import ( ) from synapse.api.errors import Codes, SynapseError from synapse.api.room_versions import RoomVersion -from synapse.types import JsonDict +from synapse.types import JsonDict, Requester from . import EventBase @@ -316,8 +316,9 @@ class SerializeEventConfig: as_client_event: bool = True # Function to convert from federation format to client format event_format: Callable[[JsonDict], JsonDict] = format_event_for_client_v1 - # ID of the user's auth token - used for namespacing of transaction IDs - token_id: Optional[int] = None + # The entity that requested the event. This is used to determine whether to include + # the transaction_id in the unsigned section of the event. + requester: Optional[Requester] = None # List of event fields to include. If empty, all fields will be returned. only_event_fields: Optional[List[str]] = None # Some events can have stripped room state stored in the `unsigned` field. @@ -367,11 +368,24 @@ def serialize_event( e.unsigned["redacted_because"], time_now_ms, config=config ) - if config.token_id is not None: - if config.token_id == getattr(e.internal_metadata, "token_id", None): - txn_id = getattr(e.internal_metadata, "txn_id", None) - if txn_id is not None: - d["unsigned"]["transaction_id"] = txn_id + # If we have a txn_id saved in the internal_metadata, we should include it in the + # unsigned section of the event if it was sent by the same session as the one + # requesting the event. + # There is a special case for guests, because they only have one access token + # without associated access_token_id, so we always include the txn_id for events + # they sent. + txn_id = getattr(e.internal_metadata, "txn_id", None) + if txn_id is not None and config.requester is not None: + event_token_id = getattr(e.internal_metadata, "token_id", None) + if config.requester.user.to_string() == e.sender and ( + ( + event_token_id is not None + and config.requester.access_token_id is not None + and event_token_id == config.requester.access_token_id + ) + or config.requester.is_guest + ): + d["unsigned"]["transaction_id"] = txn_id # invite_room_state and knock_room_state are a list of stripped room state events # that are meant to provide metadata about a room to an invitee/knocker. They are diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 949b69cb41..68c07f0265 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -23,7 +23,7 @@ from synapse.events.utils import SerializeEventConfig from synapse.handlers.presence import format_user_presence_state from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.streams.config import PaginationConfig -from synapse.types import JsonDict, UserID +from synapse.types import JsonDict, Requester, UserID from synapse.visibility import filter_events_for_client if TYPE_CHECKING: @@ -46,13 +46,12 @@ class EventStreamHandler: async def get_stream( self, - auth_user_id: str, + requester: Requester, pagin_config: PaginationConfig, timeout: int = 0, as_client_event: bool = True, affect_presence: bool = True, room_id: Optional[str] = None, - is_guest: bool = False, ) -> JsonDict: """Fetches the events stream for a given user.""" @@ -62,13 +61,12 @@ class EventStreamHandler: raise SynapseError(403, "This room has been blocked on this server") # send any outstanding server notices to the user. - await self._server_notices_sender.on_user_syncing(auth_user_id) + await self._server_notices_sender.on_user_syncing(requester.user.to_string()) - auth_user = UserID.from_string(auth_user_id) presence_handler = self.hs.get_presence_handler() context = await presence_handler.user_syncing( - auth_user_id, + requester.user.to_string(), affect_presence=affect_presence, presence_state=PresenceState.ONLINE, ) @@ -82,10 +80,10 @@ class EventStreamHandler: timeout = random.randint(int(timeout * 0.9), int(timeout * 1.1)) stream_result = await self.notifier.get_events_for( - auth_user, + requester.user, pagin_config, timeout, - is_guest=is_guest, + is_guest=requester.is_guest, explicit_room_id=room_id, ) events = stream_result.events @@ -102,7 +100,7 @@ class EventStreamHandler: if event.membership != Membership.JOIN: continue # Send down presence. - if event.state_key == auth_user_id: + if event.state_key == requester.user.to_string(): # Send down presence for everyone in the room. users: Iterable[str] = await self.store.get_users_in_room( event.room_id @@ -124,7 +122,9 @@ class EventStreamHandler: chunks = self._event_serializer.serialize_events( events, time_now, - config=SerializeEventConfig(as_client_event=as_client_event), + config=SerializeEventConfig( + as_client_event=as_client_event, requester=requester + ), ) chunk = { diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index aead0b44b9..b3be7a86f0 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -318,11 +318,9 @@ class InitialSyncHandler: ) is_peeking = member_event_id is None - user_id = requester.user.to_string() - if membership == Membership.JOIN: result = await self._room_initial_sync_joined( - user_id, room_id, pagin_config, membership, is_peeking + requester, room_id, pagin_config, membership, is_peeking ) elif membership == Membership.LEAVE: # The member_event_id will always be available if membership is set @@ -330,10 +328,16 @@ class InitialSyncHandler: assert member_event_id result = await self._room_initial_sync_parted( - user_id, room_id, pagin_config, membership, member_event_id, is_peeking + requester, + room_id, + pagin_config, + membership, + member_event_id, + is_peeking, ) account_data_events = [] + user_id = requester.user.to_string() tags = await self.store.get_tags_for_room(user_id, room_id) if tags: account_data_events.append( @@ -350,7 +354,7 @@ class InitialSyncHandler: async def _room_initial_sync_parted( self, - user_id: str, + requester: Requester, room_id: str, pagin_config: PaginationConfig, membership: str, @@ -369,13 +373,17 @@ class InitialSyncHandler: ) messages = await filter_events_for_client( - self._storage_controllers, user_id, messages, is_peeking=is_peeking + self._storage_controllers, + requester.user.to_string(), + messages, + is_peeking=is_peeking, ) start_token = StreamToken.START.copy_and_replace(StreamKeyType.ROOM, token) end_token = StreamToken.START.copy_and_replace(StreamKeyType.ROOM, stream_token) time_now = self.clock.time_msec() + serialize_options = SerializeEventConfig(requester=requester) return { "membership": membership, @@ -383,14 +391,18 @@ class InitialSyncHandler: "messages": { "chunk": ( # Don't bundle aggregations as this is a deprecated API. - self._event_serializer.serialize_events(messages, time_now) + self._event_serializer.serialize_events( + messages, time_now, config=serialize_options + ) ), "start": await start_token.to_string(self.store), "end": await end_token.to_string(self.store), }, "state": ( # Don't bundle aggregations as this is a deprecated API. - self._event_serializer.serialize_events(room_state.values(), time_now) + self._event_serializer.serialize_events( + room_state.values(), time_now, config=serialize_options + ) ), "presence": [], "receipts": [], @@ -398,7 +410,7 @@ class InitialSyncHandler: async def _room_initial_sync_joined( self, - user_id: str, + requester: Requester, room_id: str, pagin_config: PaginationConfig, membership: str, @@ -410,9 +422,12 @@ class InitialSyncHandler: # TODO: These concurrently time_now = self.clock.time_msec() + serialize_options = SerializeEventConfig(requester=requester) # Don't bundle aggregations as this is a deprecated API. state = self._event_serializer.serialize_events( - current_state.values(), time_now + current_state.values(), + time_now, + config=serialize_options, ) now_token = self.hs.get_event_sources().get_current_token() @@ -450,7 +465,10 @@ class InitialSyncHandler: if not receipts: return [] - return ReceiptEventSource.filter_out_private_receipts(receipts, user_id) + return ReceiptEventSource.filter_out_private_receipts( + receipts, + requester.user.to_string(), + ) presence, receipts, (messages, token) = await make_deferred_yieldable( gather_results( @@ -469,20 +487,23 @@ class InitialSyncHandler: ) messages = await filter_events_for_client( - self._storage_controllers, user_id, messages, is_peeking=is_peeking + self._storage_controllers, + requester.user.to_string(), + messages, + is_peeking=is_peeking, ) start_token = now_token.copy_and_replace(StreamKeyType.ROOM, token) end_token = now_token - time_now = self.clock.time_msec() - ret = { "room_id": room_id, "messages": { "chunk": ( # Don't bundle aggregations as this is a deprecated API. - self._event_serializer.serialize_events(messages, time_now) + self._event_serializer.serialize_events( + messages, time_now, config=serialize_options + ) ), "start": await start_token.to_string(self.store), "end": await end_token.to_string(self.store), diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index e433d6b01f..da129ec16a 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -50,7 +50,7 @@ from synapse.event_auth import validate_event_for_room_version from synapse.events import EventBase, relation_from_event from synapse.events.builder import EventBuilder from synapse.events.snapshot import EventContext, UnpersistedEventContextBase -from synapse.events.utils import maybe_upsert_event_field +from synapse.events.utils import SerializeEventConfig, maybe_upsert_event_field from synapse.events.validator import EventValidator from synapse.handlers.directory import DirectoryHandler from synapse.logging import opentracing @@ -245,8 +245,11 @@ class MessageHandler: ) room_state = room_state_events[membership_event_id] - now = self.clock.time_msec() - events = self._event_serializer.serialize_events(room_state.values(), now) + events = self._event_serializer.serialize_events( + room_state.values(), + self.clock.time_msec(), + config=SerializeEventConfig(requester=requester), + ) return events async def _user_can_see_state_at_event( diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index ceefa16b49..8c79c055ba 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -579,7 +579,9 @@ class PaginationHandler: time_now = self.clock.time_msec() - serialize_options = SerializeEventConfig(as_client_event=as_client_event) + serialize_options = SerializeEventConfig( + as_client_event=as_client_event, requester=requester + ) chunk = { "chunk": ( diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index 553053b694..1d09fdf135 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -20,6 +20,7 @@ import attr from synapse.api.constants import Direction, EventTypes, RelationTypes from synapse.api.errors import SynapseError from synapse.events import EventBase, relation_from_event +from synapse.events.utils import SerializeEventConfig from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.opentracing import trace from synapse.storage.databases.main.relations import ThreadsNextBatch, _RelatedEvent @@ -151,16 +152,23 @@ class RelationsHandler: ) now = self._clock.time_msec() + serialize_options = SerializeEventConfig(requester=requester) return_value: JsonDict = { "chunk": self._event_serializer.serialize_events( - events, now, bundle_aggregations=aggregations + events, + now, + bundle_aggregations=aggregations, + config=serialize_options, ), } if include_original_event: # Do not bundle aggregations when retrieving the original event because # we want the content before relations are applied to it. return_value["original_event"] = self._event_serializer.serialize_event( - event, now, bundle_aggregations=None + event, + now, + bundle_aggregations=None, + config=serialize_options, ) if next_token: diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 9bbf83047d..aad4706f14 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -23,7 +23,8 @@ from synapse.api.constants import EventTypes, Membership from synapse.api.errors import NotFoundError, SynapseError from synapse.api.filtering import Filter from synapse.events import EventBase -from synapse.types import JsonDict, StrCollection, StreamKeyType, UserID +from synapse.events.utils import SerializeEventConfig +from synapse.types import JsonDict, Requester, StrCollection, StreamKeyType, UserID from synapse.types.state import StateFilter from synapse.visibility import filter_events_for_client @@ -109,12 +110,12 @@ class SearchHandler: return historical_room_ids async def search( - self, user: UserID, content: JsonDict, batch: Optional[str] = None + self, requester: Requester, content: JsonDict, batch: Optional[str] = None ) -> JsonDict: """Performs a full text search for a user. Args: - user: The user performing the search. + requester: The user performing the search. content: Search parameters batch: The next_batch parameter. Used for pagination. @@ -199,7 +200,7 @@ class SearchHandler: ) return await self._search( - user, + requester, batch_group, batch_group_key, batch_token, @@ -217,7 +218,7 @@ class SearchHandler: async def _search( self, - user: UserID, + requester: Requester, batch_group: Optional[str], batch_group_key: Optional[str], batch_token: Optional[str], @@ -235,7 +236,7 @@ class SearchHandler: """Performs a full text search for a user. Args: - user: The user performing the search. + requester: The user performing the search. batch_group: Pagination information. batch_group_key: Pagination information. batch_token: Pagination information. @@ -269,7 +270,7 @@ class SearchHandler: # TODO: Search through left rooms too rooms = await self.store.get_rooms_for_local_user_where_membership_is( - user.to_string(), + requester.user.to_string(), membership_list=[Membership.JOIN], # membership_list=[Membership.JOIN, Membership.LEAVE, Membership.Ban], ) @@ -303,13 +304,13 @@ class SearchHandler: if order_by == "rank": search_result, sender_group = await self._search_by_rank( - user, room_ids, search_term, keys, search_filter + requester.user, room_ids, search_term, keys, search_filter ) # Unused return values for rank search. global_next_batch = None elif order_by == "recent": search_result, global_next_batch = await self._search_by_recent( - user, + requester.user, room_ids, search_term, keys, @@ -334,7 +335,7 @@ class SearchHandler: assert after_limit is not None contexts = await self._calculate_event_contexts( - user, + requester.user, search_result.allowed_events, before_limit, after_limit, @@ -363,27 +364,37 @@ class SearchHandler: # The returned events. search_result.allowed_events, ), - user.to_string(), + requester.user.to_string(), ) # We're now about to serialize the events. We should not make any # blocking calls after this. Otherwise, the 'age' will be wrong. time_now = self.clock.time_msec() + serialize_options = SerializeEventConfig(requester=requester) for context in contexts.values(): context["events_before"] = self._event_serializer.serialize_events( - context["events_before"], time_now, bundle_aggregations=aggregations + context["events_before"], + time_now, + bundle_aggregations=aggregations, + config=serialize_options, ) context["events_after"] = self._event_serializer.serialize_events( - context["events_after"], time_now, bundle_aggregations=aggregations + context["events_after"], + time_now, + bundle_aggregations=aggregations, + config=serialize_options, ) results = [ { "rank": search_result.rank_map[e.event_id], "result": self._event_serializer.serialize_event( - e, time_now, bundle_aggregations=aggregations + e, + time_now, + bundle_aggregations=aggregations, + config=serialize_options, ), "context": contexts.get(e.event_id, {}), } @@ -398,7 +409,9 @@ class SearchHandler: if state_results: rooms_cat_res["state"] = { - room_id: self._event_serializer.serialize_events(state_events, time_now) + room_id: self._event_serializer.serialize_events( + state_events, time_now, config=serialize_options + ) for room_id, state_events in state_results.items() } diff --git a/synapse/rest/client/events.py b/synapse/rest/client/events.py index 782e7d14e8..694d77d287 100644 --- a/synapse/rest/client/events.py +++ b/synapse/rest/client/events.py @@ -17,6 +17,7 @@ import logging from typing import TYPE_CHECKING, Dict, List, Tuple, Union from synapse.api.errors import SynapseError +from synapse.events.utils import SerializeEventConfig from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_string from synapse.http.site import SynapseRequest @@ -43,9 +44,8 @@ class EventStreamRestServlet(RestServlet): async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) - is_guest = requester.is_guest args: Dict[bytes, List[bytes]] = request.args # type: ignore - if is_guest: + if requester.is_guest: if b"room_id" not in args: raise SynapseError(400, "Guest users must specify room_id param") room_id = parse_string(request, "room_id") @@ -63,13 +63,12 @@ class EventStreamRestServlet(RestServlet): as_client_event = b"raw" not in args chunk = await self.event_stream_handler.get_stream( - requester.user.to_string(), + requester, pagin_config, timeout=timeout, as_client_event=as_client_event, - affect_presence=(not is_guest), + affect_presence=(not requester.is_guest), room_id=room_id, - is_guest=is_guest, ) return 200, chunk @@ -91,9 +90,12 @@ class EventRestServlet(RestServlet): requester = await self.auth.get_user_by_req(request) event = await self.event_handler.get_event(requester.user, None, event_id) - time_now = self.clock.time_msec() if event: - result = self._event_serializer.serialize_event(event, time_now) + result = self._event_serializer.serialize_event( + event, + self.clock.time_msec(), + config=SerializeEventConfig(requester=requester), + ) return 200, result else: return 404, "Event not found." diff --git a/synapse/rest/client/notifications.py b/synapse/rest/client/notifications.py index 61268e3af1..ea10042569 100644 --- a/synapse/rest/client/notifications.py +++ b/synapse/rest/client/notifications.py @@ -72,6 +72,12 @@ class NotificationsServlet(RestServlet): next_token = None + serialize_options = SerializeEventConfig( + event_format=format_event_for_client_v2_without_room_id, + requester=requester, + ) + now = self.clock.time_msec() + for pa in push_actions: returned_pa = { "room_id": pa.room_id, @@ -81,10 +87,8 @@ class NotificationsServlet(RestServlet): "event": ( self._event_serializer.serialize_event( notif_events[pa.event_id], - self.clock.time_msec(), - config=SerializeEventConfig( - event_format=format_event_for_client_v2_without_room_id - ), + now, + config=serialize_options, ) ), } diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index c5af07816a..61e4cf0213 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -37,7 +37,7 @@ from synapse.api.errors import ( UnredactedContentDeletedError, ) from synapse.api.filtering import Filter -from synapse.events.utils import format_event_for_client_v2 +from synapse.events.utils import SerializeEventConfig, format_event_for_client_v2 from synapse.http.server import HttpServer from synapse.http.servlet import ( ResolveRoomIdMixin, @@ -814,11 +814,13 @@ class RoomEventServlet(RestServlet): [event], requester.user.to_string() ) - time_now = self.clock.time_msec() # per MSC2676, /rooms/{roomId}/event/{eventId}, should return the # *original* event, rather than the edited version event_dict = self._event_serializer.serialize_event( - event, time_now, bundle_aggregations=aggregations + event, + self.clock.time_msec(), + bundle_aggregations=aggregations, + config=SerializeEventConfig(requester=requester), ) return 200, event_dict @@ -863,24 +865,30 @@ class RoomEventContextServlet(RestServlet): raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) time_now = self.clock.time_msec() + serializer_options = SerializeEventConfig(requester=requester) results = { "events_before": self._event_serializer.serialize_events( event_context.events_before, time_now, bundle_aggregations=event_context.aggregations, + config=serializer_options, ), "event": self._event_serializer.serialize_event( event_context.event, time_now, bundle_aggregations=event_context.aggregations, + config=serializer_options, ), "events_after": self._event_serializer.serialize_events( event_context.events_after, time_now, bundle_aggregations=event_context.aggregations, + config=serializer_options, ), "state": self._event_serializer.serialize_events( - event_context.state, time_now + event_context.state, + time_now, + config=serializer_options, ), "start": event_context.start, "end": event_context.end, @@ -1192,7 +1200,7 @@ class SearchRestServlet(RestServlet): content = parse_json_object_from_request(request) batch = parse_string(request, "next_batch") - results = await self.search_handler.search(requester.user, content, batch) + results = await self.search_handler.search(requester, content, batch) return 200, results diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index 8fcb8ac3d9..e578b26fa3 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -38,7 +38,7 @@ from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string from synapse.http.site import SynapseRequest from synapse.logging.opentracing import trace_with_opname -from synapse.types import JsonDict, StreamToken +from synapse.types import JsonDict, Requester, StreamToken from synapse.util import json_decoder from ._base import client_patterns, set_timeline_upper_limit @@ -226,7 +226,7 @@ class SyncRestServlet(RestServlet): # We know that the the requester has an access token since appservices # cannot use sync. response_content = await self.encode_response( - time_now, sync_result, requester.access_token_id, filter_collection + time_now, sync_result, requester, filter_collection ) logger.debug("Event formatting complete") @@ -237,7 +237,7 @@ class SyncRestServlet(RestServlet): self, time_now: int, sync_result: SyncResult, - access_token_id: Optional[int], + requester: Requester, filter: FilterCollection, ) -> JsonDict: logger.debug("Formatting events in sync response") @@ -250,12 +250,12 @@ class SyncRestServlet(RestServlet): serialize_options = SerializeEventConfig( event_format=event_formatter, - token_id=access_token_id, + requester=requester, only_event_fields=filter.event_fields, ) stripped_serialize_options = SerializeEventConfig( event_format=event_formatter, - token_id=access_token_id, + requester=requester, include_stripped_room_state=True, ) -- cgit 1.5.1