From 6bf81a7a61d8d5248be5def955104c44fcb78dae Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 7 Jan 2022 09:10:46 -0500 Subject: Bundle aggregations outside of the serialization method. (#11612) This makes the serialization of events synchronous (and it no longer access the database), but we must manually calculate and provide the bundled aggregations. Overall this should cause no change in behavior, but is prep work for other improvements. --- synapse/storage/databases/main/relations.py | 128 +++++++++++++++++++++++++++- 1 file changed, 125 insertions(+), 3 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 4ff6aed253..c6c4bd18da 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -13,14 +13,30 @@ # limitations under the License. import logging -from typing import List, Optional, Tuple, Union, cast +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Iterable, + List, + Optional, + Tuple, + Union, + cast, +) import attr +from frozendict import frozendict -from synapse.api.constants import RelationTypes +from synapse.api.constants import EventTypes, RelationTypes from synapse.events import EventBase from synapse.storage._base import SQLBaseStore -from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, + make_in_list_sql_clause, +) from synapse.storage.databases.main.stream import generate_pagination_where_clause from synapse.storage.relations import ( AggregationPaginationToken, @@ -29,10 +45,24 @@ from synapse.storage.relations import ( ) from synapse.util.caches.descriptors import cached +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) class RelationsWorkerStore(SQLBaseStore): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): + super().__init__(database, db_conn, hs) + + self._msc1849_enabled = hs.config.experimental.msc1849_enabled + self._msc3440_enabled = hs.config.experimental.msc3440_enabled + @cached(tree=True) async def get_relations_for_event( self, @@ -515,6 +545,98 @@ class RelationsWorkerStore(SQLBaseStore): "get_if_user_has_annotated_event", _get_if_user_has_annotated_event ) + async def _get_bundled_aggregation_for_event( + self, event: EventBase + ) -> Optional[Dict[str, Any]]: + """Generate bundled aggregations for an event. + + Note that this does not use a cache, but depends on cached methods. + + Args: + event: The event to calculate bundled aggregations for. + + Returns: + The bundled aggregations for an event, if bundled aggregations are + enabled and the event can have bundled aggregations. + """ + # State events and redacted events do not get bundled aggregations. + if event.is_state() or event.internal_metadata.is_redacted(): + return None + + # Do not bundle aggregations for an event which represents an edit or an + # annotation. It does not make sense for them to have related events. + relates_to = event.content.get("m.relates_to") + if isinstance(relates_to, (dict, frozendict)): + relation_type = relates_to.get("rel_type") + if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE): + return None + + event_id = event.event_id + room_id = event.room_id + + # The bundled aggregations to include, a mapping of relation type to a + # type-specific value. Some types include the direct return type here + # while others need more processing during serialization. + aggregations: Dict[str, Any] = {} + + annotations = await self.get_aggregation_groups_for_event(event_id, room_id) + if annotations.chunk: + aggregations[RelationTypes.ANNOTATION] = annotations.to_dict() + + references = await self.get_relations_for_event( + event_id, room_id, RelationTypes.REFERENCE, direction="f" + ) + if references.chunk: + aggregations[RelationTypes.REFERENCE] = references.to_dict() + + edit = None + if event.type == EventTypes.Message: + edit = await self.get_applicable_edit(event_id, room_id) + + if edit: + aggregations[RelationTypes.REPLACE] = edit + + # If this event is the start of a thread, include a summary of the replies. + if self._msc3440_enabled: + ( + thread_count, + latest_thread_event, + ) = await self.get_thread_summary(event_id, room_id) + if latest_thread_event: + aggregations[RelationTypes.THREAD] = { + # Don't bundle aggregations as this could recurse forever. + "latest_event": latest_thread_event, + "count": thread_count, + } + + # Store the bundled aggregations in the event metadata for later use. + return aggregations + + async def get_bundled_aggregations( + self, events: Iterable[EventBase] + ) -> Dict[str, Dict[str, Any]]: + """Generate bundled aggregations for events. + + Args: + events: The iterable of events to calculate bundled aggregations for. + + Returns: + A map of event ID to the bundled aggregation for the event. Not all + events may have bundled aggregations in the results. + """ + # If bundled aggregations are disabled, nothing to do. + if not self._msc1849_enabled: + return {} + + # TODO Parallelize. + results = {} + for event in events: + event_result = await self._get_bundled_aggregation_for_event(event) + if event_result is not None: + results[event.event_id] = event_result + + return results + class RelationsStore(RelationsWorkerStore): pass -- cgit 1.4.1