diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 4ff6aed253..c6c4bd18da 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -13,14 +13,30 @@
# limitations under the License.
import logging
-from typing import List, Optional, Tuple, Union, cast
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Tuple,
+ Union,
+ cast,
+)
import attr
+from frozendict import frozendict
-from synapse.api.constants import RelationTypes
+from synapse.api.constants import EventTypes, RelationTypes
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+ make_in_list_sql_clause,
+)
from synapse.storage.databases.main.stream import generate_pagination_where_clause
from synapse.storage.relations import (
AggregationPaginationToken,
@@ -29,10 +45,24 @@ from synapse.storage.relations import (
)
from synapse.util.caches.descriptors import cached
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
class RelationsWorkerStore(SQLBaseStore):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
+ super().__init__(database, db_conn, hs)
+
+ self._msc1849_enabled = hs.config.experimental.msc1849_enabled
+ self._msc3440_enabled = hs.config.experimental.msc3440_enabled
+
@cached(tree=True)
async def get_relations_for_event(
self,
@@ -515,6 +545,98 @@ class RelationsWorkerStore(SQLBaseStore):
"get_if_user_has_annotated_event", _get_if_user_has_annotated_event
)
+ async def _get_bundled_aggregation_for_event(
+ self, event: EventBase
+ ) -> Optional[Dict[str, Any]]:
+ """Generate bundled aggregations for an event.
+
+ Note that this does not use a cache, but depends on cached methods.
+
+ Args:
+ event: The event to calculate bundled aggregations for.
+
+ Returns:
+ The bundled aggregations for an event, if bundled aggregations are
+ enabled and the event can have bundled aggregations.
+ """
+ # State events and redacted events do not get bundled aggregations.
+ if event.is_state() or event.internal_metadata.is_redacted():
+ return None
+
+ # Do not bundle aggregations for an event which represents an edit or an
+ # annotation. It does not make sense for them to have related events.
+ relates_to = event.content.get("m.relates_to")
+ if isinstance(relates_to, (dict, frozendict)):
+ relation_type = relates_to.get("rel_type")
+ if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE):
+ return None
+
+ event_id = event.event_id
+ room_id = event.room_id
+
+ # The bundled aggregations to include, a mapping of relation type to a
+ # type-specific value. Some types include the direct return type here
+ # while others need more processing during serialization.
+ aggregations: Dict[str, Any] = {}
+
+ annotations = await self.get_aggregation_groups_for_event(event_id, room_id)
+ if annotations.chunk:
+ aggregations[RelationTypes.ANNOTATION] = annotations.to_dict()
+
+ references = await self.get_relations_for_event(
+ event_id, room_id, RelationTypes.REFERENCE, direction="f"
+ )
+ if references.chunk:
+ aggregations[RelationTypes.REFERENCE] = references.to_dict()
+
+ edit = None
+ if event.type == EventTypes.Message:
+ edit = await self.get_applicable_edit(event_id, room_id)
+
+ if edit:
+ aggregations[RelationTypes.REPLACE] = edit
+
+ # If this event is the start of a thread, include a summary of the replies.
+ if self._msc3440_enabled:
+ (
+ thread_count,
+ latest_thread_event,
+ ) = await self.get_thread_summary(event_id, room_id)
+ if latest_thread_event:
+ aggregations[RelationTypes.THREAD] = {
+ # Don't bundle aggregations as this could recurse forever.
+ "latest_event": latest_thread_event,
+ "count": thread_count,
+ }
+
+ # Store the bundled aggregations in the event metadata for later use.
+ return aggregations
+
+ async def get_bundled_aggregations(
+ self, events: Iterable[EventBase]
+ ) -> Dict[str, Dict[str, Any]]:
+ """Generate bundled aggregations for events.
+
+ Args:
+ events: The iterable of events to calculate bundled aggregations for.
+
+ Returns:
+ A map of event ID to the bundled aggregation for the event. Not all
+ events may have bundled aggregations in the results.
+ """
+ # If bundled aggregations are disabled, nothing to do.
+ if not self._msc1849_enabled:
+ return {}
+
+ # TODO Parallelize.
+ results = {}
+ for event in events:
+ event_result = await self._get_bundled_aggregation_for_event(event)
+ if event_result is not None:
+ results[event.event_id] = event_result
+
+ return results
+
class RelationsStore(RelationsWorkerStore):
pass
|