summary refs log tree commit diff
diff options
context:
space:
mode:
author Erik Johnston <erik@matrix.org> 2022-05-21 13:58:52 +0100
committer Erik Johnston <erik@matrix.org> 2022-05-21 13:58:52 +0100
commit 2ebb0c6f994732928bc6040976880d74477e5e79 (patch)
tree a0886e3332577ce1e8f4f7c7f19295110fa0d911
parent Send `USER_IP` commands on a different Redis channel, in order to reduce traf... (diff)
download synapse-2ebb0c6f994732928bc6040976880d74477e5e79.tar.xz
Pull out less state when handling gaps
-rw-r--r-- synapse/handlers/federation_event.py 140
-rw-r--r-- synapse/handlers/message.py 17
-rw-r--r-- synapse/state/__init__.py 14
-rw-r--r-- synapse/storage/databases/main/state.py 42
4 files changed, 127 insertions, 86 deletions
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py

index 05c122f224..8521a230cb 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py
@@ -463,7 +463,9 @@ class FederationEventHandler: with nested_logging_context(suffix=event.event_id): context = await self._state_handler.compute_event_context( event, - old_state=state, + state_ids_before_event={ + (e.type, e.state_key): e.event_id for e in state + }, partial_state=partial_state, ) @@ -501,7 +503,7 @@ class FederationEventHandler: # build a new state group for it if need be context = await self._state_handler.compute_event_context( event, - old_state=state, + state_ids_before_event=state, ) if context.partial_state: # this can happen if some or all of the event's prev_events still have @@ -765,7 +767,7 @@ class FederationEventHandler: async def _resolve_state_at_missing_prevs( self, dest: str, event: EventBase - ) -> Optional[Iterable[EventBase]]: + ) -> Optional[StateMap[str]]: """Calculate the state at an event with missing prev_events. This is used when we have pulled a batch of events from a remote server, and @@ -792,8 +794,8 @@ class FederationEventHandler: event: an event to check for missing prevs. Returns: - if we already had all the prev events, `None`. Otherwise, returns a list of - the events in the state at `event`. + if we already had all the prev events, `None`. Otherwise, returns + the state at `event`. """ room_id = event.room_id event_id = event.event_id @@ -837,13 +839,7 @@ class FederationEventHandler: dest, room_id, p ) - remote_state_map = { - (x.type, x.state_key): x.event_id for x in remote_state - } - state_maps.append(remote_state_map) - - for x in remote_state: - event_map[x.event_id] = x + state_maps.append(remote_state) room_version = await self._store.get_room_version_id(room_id) state_map = await self._state_resolution_handler.resolve_events_with_store( @@ -854,19 +850,6 @@ class FederationEventHandler: state_res_store=StateResolutionStore(self._store), ) - # We need to give _process_received_pdu the actual state events - # rather than event ids, so generate that now. 
- - # First though we need to fetch all the events that are in - # state_map, so we can build up the state below. - evs = await self._store.get_events( - list(state_map.values()), - get_prev_content=False, - redact_behaviour=EventRedactBehaviour.as_is, - ) - event_map.update(evs) - - state = [event_map[e] for e in state_map.values()] except Exception: logger.warning( "Error attempting to resolve state at missing prev_events", @@ -878,14 +861,14 @@ class FederationEventHandler: "We can't get valid state history.", affected=event_id, ) - return state + return state_map async def _get_state_after_missing_prev_event( self, destination: str, room_id: str, event_id: str, - ) -> List[EventBase]: + ) -> StateMap[str]: """Requests all of the room state at a given event from a remote homeserver. Args: @@ -894,7 +877,7 @@ class FederationEventHandler: event_id: The id of the event we want the state at. Returns: - A list of events in the state, including the event itself + The state *after* the given event. """ ( state_event_ids, @@ -913,15 +896,13 @@ class FederationEventHandler: desired_events = set(state_event_ids) desired_events.add(event_id) logger.debug("Fetching %i events from cache/store", len(desired_events)) - fetched_events = await self._store.get_events( - desired_events, allow_rejected=True - ) + have_events = await self._store.have_seen_events(room_id, desired_events) - missing_desired_events = desired_events - fetched_events.keys() + missing_desired_events = desired_events - have_events logger.debug( "We are missing %i events (got %i)", len(missing_desired_events), - len(fetched_events), + len(have_events), ) # We probably won't need most of the auth events, so let's just check which @@ -932,7 +913,7 @@ class FederationEventHandler: # already have a bunch of the state events. It would be nice if the # federation api gave us a way of finding out which we actually need. 
- missing_auth_events = set(auth_event_ids) - fetched_events.keys() + missing_auth_events = set(auth_event_ids) - have_events missing_auth_events.difference_update( await self._store.have_seen_events(room_id, missing_auth_events) ) @@ -958,47 +939,54 @@ class FederationEventHandler: destination=destination, room_id=room_id, event_ids=missing_events ) - # we need to make sure we re-load from the database to get the rejected - # state correct. - fetched_events.update( - await self._store.get_events(missing_desired_events, allow_rejected=True) - ) + event_metadata = await self._store.get_metadata_for_events(state_event_ids) # check for events which were in the wrong room. # # this can happen if a remote server claims that the state or # auth_events at an event in room A are actually events in room B - bad_events = [ - (event_id, event.room_id) - for event_id, event in fetched_events.items() - if event.room_id != room_id - ] + event_metadata = await self._store.get_metadata_for_events(state_event_ids) - for bad_event_id, bad_room_id in bad_events: - # This is a bogus situation, but since we may only discover it a long time - # after it happened, we try our best to carry on, by just omitting the - # bad events from the returned state set. - logger.warning( - "Remote server %s claims event %s in room %s is an auth/state " - "event in room %s", - destination, - bad_event_id, - bad_room_id, - room_id, - ) + state_map = {} + + for state_event_id, metadata in event_metadata.items(): + if metadata.room_id != room_id: + # This is a bogus situation, but since we may only discover it a long time + # after it happened, we try our best to carry on, by just omitting the + # bad events from the returned state set. 
+ logger.warning( + "Remote server %s claims event %s in room %s is an auth/state " + "event in room %s", + destination, + state_event_id, + metadata.room_id, + room_id, + ) + continue + + if metadata.state_key is None: + logger.warning( + "Remote server gave us non-state event in state: %s", state_event_id + ) + continue - del fetched_events[bad_event_id] + state_map[(metadata.event_type, metadata.state_key)] = state_event_id # if we couldn't get the prev event in question, that's a problem. - remote_event = fetched_events.get(event_id) + remote_event = await self._store.get_event( + event_id, + allow_none=True, + allow_rejected=True, + redact_behaviour=EventRedactBehaviour.as_is, + ) if not remote_event: raise Exception("Unable to get missing prev_event %s" % (event_id,)) # missing state at that event is a warning, not a blocker # XXX: this doesn't sound right? it means that we'll end up with incomplete # state. - failed_to_fetch = desired_events - fetched_events.keys() + failed_to_fetch = desired_events - event_metadata.keys() if failed_to_fetch: logger.warning( "Failed to fetch missing state events for %s %s", @@ -1006,14 +994,12 @@ class FederationEventHandler: failed_to_fetch, ) - remote_state = [ - fetched_events[e_id] for e_id in state_event_ids if e_id in fetched_events - ] - if remote_event.is_state() and remote_event.rejected_reason is None: - remote_state.append(remote_event) + state_map[ + (remote_event.type, remote_event.state_key) + ] = remote_event.event_id - return remote_state + return state_map async def _get_state_and_persist( self, destination: str, room_id: str, event_id: str @@ -1040,7 +1026,7 @@ class FederationEventHandler: self, origin: str, event: EventBase, - state: Optional[Iterable[EventBase]], + state: Optional[StateMap[str]], backfilled: bool = False, ) -> None: """Called when we have a new non-outlier event. 
@@ -1074,7 +1060,7 @@ class FederationEventHandler: try: context = await self._state_handler.compute_event_context( - event, old_state=state + event, state_ids_before_event=state ) context = await self._check_event_auth( origin, @@ -1565,7 +1551,7 @@ class FederationEventHandler: async def _check_for_soft_fail( self, event: EventBase, - state: Optional[Iterable[EventBase]], + state: Optional[StateMap[str]], origin: str, ) -> None: """Checks if we should soft fail the event; if so, marks the event as @@ -1602,17 +1588,21 @@ class FederationEventHandler: # given state at the event. This should correctly handle cases # like bans, especially with state res v2. - state_sets_d = await self._state_store.get_state_groups( + state_sets_d = await self._state_store.get_state_groups_ids( event.room_id, extrem_ids ) - state_sets: List[Iterable[EventBase]] = list(state_sets_d.values()) + state_sets: List[StateMap[str]] = list(state_sets_d.values()) state_sets.append(state) - current_states = await self._state_handler.resolve_events( - room_version, state_sets, event + + current_state_ids = ( + await self._state_resolution_handler.resolve_events_with_store( + event.room_id, + room_version, + state_sets, + event_map={}, + state_res_store=StateResolutionStore(self._store), + ) ) - current_state_ids: StateMap[str] = { - k: e.event_id for k, e in current_states.items() - } else: current_state_ids = await self._state_handler.get_current_state_ids( event.room_id, latest_event_ids=extrem_ids diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index e566ff1f8e..79e4b1cce3 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py
@@ -1021,8 +1021,21 @@ class EventCreationHandler: # # TODO(faster_joins): figure out how this works, and make sure that the # old state is complete. - old_state = await self.store.get_events_as_list(state_event_ids) - context = await self.state.compute_event_context(event, old_state=old_state) + metadata = await self.store.get_metadata_for_events(state_event_ids) + + state_map = {} + for event_id, data in metadata.items(): + if data.state_key is None: + raise Exception( + "Trying to set non-state event as state: %s", event_id + ) + + state_map[(data.event_type, data.state_key)] = event_id + + context = await self.state.compute_event_context( + event, + state_ids_before_event=state_map, + ) else: context = await self.state.compute_event_context(event) diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 4b4ed42cff..712facd3a8 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py
@@ -261,7 +261,7 @@ class StateHandler: async def compute_event_context( self, event: EventBase, - old_state: Optional[Iterable[EventBase]] = None, + state_ids_before_event: Optional[StateMap[str]] = None, partial_state: bool = False, ) -> EventContext: """Build an EventContext structure for a non-outlier event. @@ -273,12 +273,12 @@ class StateHandler: Args: event: - old_state: The state at the event if it can't be + state_ids_before_event: The state at the event if it can't be calculated from existing events. This is normally only specified when receiving an event from federation where we don't have the prev events for, e.g. when backfilling. - partial_state: True if `old_state` is partial and omits non-critical - membership events + partial_state: True if `state_ids_before_event` is partial and omits + non-critical membership events Returns: The event context. """ @@ -288,11 +288,7 @@ class StateHandler: # # first of all, figure out the state before the event # - if old_state: - # if we're given the state before the event, then we use that - state_ids_before_event: StateMap[str] = { - (s.type, s.state_key): s.event_id for s in old_state - } + if state_ids_before_event: state_group_before_event = None state_group_before_event_prev_group = None deltas_to_state_group_before_event = None diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 18ae8aee29..d939a0e0b6 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py
@@ -16,6 +16,8 @@ import collections.abc import logging from typing import TYPE_CHECKING, Collection, Dict, Iterable, Optional, Set, Tuple +import attr + from synapse.api.constants import EventTypes, Membership from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion @@ -26,6 +28,7 @@ from synapse.storage.database import ( DatabasePool, LoggingDatabaseConnection, LoggingTransaction, + make_in_list_sql_clause, ) from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.roommember import RoomMemberWorkerStore @@ -43,6 +46,15 @@ logger = logging.getLogger(__name__) MAX_STATE_DELTA_HOPS = 100 +@attr.s(slots=True, frozen=True, auto_attribs=True) +class EventMetadata: + """Returned by `get_metadata_for_events`""" + + room_id: str + event_type: str + state_key: Optional[str] + + def _retrieve_and_check_room_version(room_id: str, room_version_id: str) -> RoomVersion: v = KNOWN_ROOM_VERSIONS.get(room_version_id) if not v: @@ -133,6 +145,36 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): return room_version + async def get_metadata_for_events( + self, event_ids: Collection[str] + ) -> Dict[str, EventMetadata]: + """Get some metadata (room_id, type, state_key) for the given events.""" + + clause, args = make_in_list_sql_clause( + self.database_engine, "e.event_id", event_ids + ) + + sql = f""" + SELECT e.event_id, e.room_id, e.type, e.state_key FROM events AS e + LEFT JOIN state_events USING (event_id) + WHERE {clause} + """ + + def get_metadata_for_events_txn( + txn: LoggingTransaction, + ) -> Dict[str, EventMetadata]: + txn.execute(sql, args) + return { + event_id: EventMetadata( + room_id=room_id, event_type=event_type, state_key=state_key + ) + for event_id, room_id, event_type, state_key in txn + } + + return await self.db_pool.runInteraction( + "get_metadata_for_events", get_metadata_for_events_txn + ) + async def 
get_room_predecessor(self, room_id: str) -> Optional[JsonMapping]: """Get the predecessor of an upgraded room if it exists. Otherwise return None.