Diffstat (limited to 'synapse/storage/databases/main/relations.py')
-rw-r--r--  synapse/storage/databases/main/relations.py | 218
1 file changed, 50 insertions(+), 168 deletions(-)
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 36aa1092f6..b2295fd51f 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -27,7 +27,6 @@ from typing import (
 )
 
 import attr
-from frozendict import frozendict
 
 from synapse.api.constants import RelationTypes
 from synapse.events import EventBase
@@ -41,45 +40,15 @@ from synapse.storage.database import (
 from synapse.storage.databases.main.stream import generate_pagination_where_clause
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.relations import AggregationPaginationToken, PaginationChunk
-from synapse.types import JsonDict, RoomStreamToken, StreamToken
+from synapse.types import RoomStreamToken, StreamToken
 from synapse.util.caches.descriptors import cached, cachedList
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
-    from synapse.storage.databases.main import DataStore
 
 logger = logging.getLogger(__name__)
 
 
-@attr.s(slots=True, frozen=True, auto_attribs=True)
-class _ThreadAggregation:
-    # The latest event in the thread.
-    latest_event: EventBase
-    # The latest edit to the latest event in the thread.
-    latest_edit: Optional[EventBase]
-    # The total number of events in the thread.
-    count: int
-    # True if the current user has sent an event to the thread.
-    current_user_participated: bool
-
-
-@attr.s(slots=True, auto_attribs=True)
-class BundledAggregations:
-    """
-    The bundled aggregations for an event.
-
-    Some values require additional processing during serialization.
-    """
-
-    annotations: Optional[JsonDict] = None
-    references: Optional[JsonDict] = None
-    replace: Optional[EventBase] = None
-    thread: Optional[_ThreadAggregation] = None
-
-    def __bool__(self) -> bool:
-        return bool(self.annotations or self.references or self.replace or self.thread)
-
-
 class RelationsWorkerStore(SQLBaseStore):
     def __init__(
         self,
@@ -91,10 +60,11 @@ class RelationsWorkerStore(SQLBaseStore):
 
         self._msc3440_enabled = hs.config.experimental.msc3440_enabled
 
-    @cached(tree=True)
+    @cached(uncached_args=("event",), tree=True)
     async def get_relations_for_event(
         self,
         event_id: str,
+        event: EventBase,
         room_id: str,
         relation_type: Optional[str] = None,
         event_type: Optional[str] = None,
@@ -108,6 +78,7 @@ class RelationsWorkerStore(SQLBaseStore):
 
         Args:
             event_id: Fetch events that relate to this event ID.
+            event: The EventBase matching event_id.
             room_id: The room the event belongs to.
             relation_type: Only fetch events with this relation type, if given.
             event_type: Only fetch events with this event type, if given.
@@ -122,9 +93,13 @@ class RelationsWorkerStore(SQLBaseStore):
             List of event IDs that match relations requested. The rows are of
             the form `{"event_id": "..."}`.
         """
+        # We don't use `event_id` directly; it's only there so that we can
+        # cache based on it, and it must always match `event.event_id`.
+        assert event.event_id == event_id
 
         where_clause = ["relates_to_id = ?", "room_id = ?"]
-        where_args: List[Union[str, int]] = [event_id, room_id]
+        where_args: List[Union[str, int]] = [event.event_id, room_id]
+        is_redacted = event.internal_metadata.is_redacted()
 
         if relation_type is not None:
             where_clause.append("relation_type = ?")
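The change above keys the @cached lookup on event_id (and the remaining positional arguments) while excluding the full EventBase from the cache key via uncached_args; the assert enforces that the two always agree. A minimal standalone sketch of the same idea follows, using functools.lru_cache rather than Synapse's cache descriptor; FakeEvent and _relations_cached are made-up names for illustration only.

# Illustrative sketch: cache on the cheap, hashable key (event_id) while still
# accepting the full event object for convenience. Not Synapse code.
from functools import lru_cache
from typing import List, NamedTuple


class FakeEvent(NamedTuple):
    event_id: str
    room_id: str


@lru_cache(maxsize=1024)
def _relations_cached(event_id: str, room_id: str) -> List[str]:
    # Expensive lookup, keyed only on the IDs.
    return [f"relation-of-{event_id}-in-{room_id}"]


def get_relations_for_event(event_id: str, event: FakeEvent, room_id: str) -> List[str]:
    # The full event is not part of the cache key, so it must always match the
    # event_id it is passed alongside (mirroring the assert above).
    assert event.event_id == event_id
    return _relations_cached(event_id, room_id)


print(get_relations_for_event("$abc", FakeEvent("$abc", "!room:example.org"), "!room:example.org"))
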
@@ -157,7 +132,7 @@ class RelationsWorkerStore(SQLBaseStore):
             order = "ASC"
 
         sql = """
-            SELECT event_id, topological_ordering, stream_ordering
+            SELECT event_id, relation_type, topological_ordering, stream_ordering
             FROM event_relations
             INNER JOIN events USING (event_id)
             WHERE %s
@@ -178,9 +153,12 @@ class RelationsWorkerStore(SQLBaseStore):
             last_stream_id = None
             events = []
             for row in txn:
-                events.append({"event_id": row[0]})
-                last_topo_id = row[1]
-                last_stream_id = row[2]
+                # Do not include edits for redacted events as they leak event
+                # content.
+                if not is_redacted or row[1] != RelationTypes.REPLACE:
+                    events.append({"event_id": row[0]})
+                last_topo_id = row[2]
+                last_stream_id = row[3]
 
             # If there are more events, generate the next pagination key.
             next_token = None
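The SELECT above now also returns relation_type so the loop can drop m.replace relations when the parent event is redacted, since serving edits of a redacted event would leak its original content; note that the pagination cursors still advance over skipped rows. Below is a small sketch of that filter in isolation, assuming rows shaped like the amended query's result; filter_relation_rows is a hypothetical helper, not part of Synapse.

# Rows are assumed to be (event_id, relation_type, topological_ordering,
# stream_ordering) tuples; "m.replace" corresponds to RelationTypes.REPLACE.
from typing import Dict, List, Optional, Tuple

Row = Tuple[str, str, int, int]


def filter_relation_rows(rows: List[Row], is_redacted: bool):
    events: List[Dict[str, str]] = []
    last_topo_id: Optional[int] = None
    last_stream_id: Optional[int] = None
    for event_id, relation_type, topo, stream in rows:
        # Do not include edits for redacted events, as they leak event content.
        if not is_redacted or relation_type != "m.replace":
            events.append({"event_id": event_id})
        # Pagination tokens still advance over rows that were skipped.
        last_topo_id = topo
        last_stream_id = stream
    return events, last_topo_id, last_stream_id


rows = [("$edit", "m.replace", 3, 7), ("$reaction", "m.annotation", 4, 8)]
print(filter_relation_rows(rows, is_redacted=True))
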
@@ -375,7 +353,7 @@ class RelationsWorkerStore(SQLBaseStore):
         raise NotImplementedError()
 
     @cachedList(cached_method_name="get_applicable_edit", list_name="event_ids")
-    async def _get_applicable_edits(
+    async def get_applicable_edits(
         self, event_ids: Collection[str]
     ) -> Dict[str, Optional[EventBase]]:
         """Get the most recent edit (if any) that has happened for the given
@@ -464,7 +442,7 @@ class RelationsWorkerStore(SQLBaseStore):
         raise NotImplementedError()
 
     @cachedList(cached_method_name="get_thread_summary", list_name="event_ids")
-    async def _get_thread_summaries(
+    async def get_thread_summaries(
         self, event_ids: Collection[str]
     ) -> Dict[str, Optional[Tuple[int, EventBase, Optional[EventBase]]]]:
         """Get the number of threaded replies, the latest reply (if any), and the latest edit for that reply for the given event.
@@ -499,7 +477,7 @@ class RelationsWorkerStore(SQLBaseStore):
                         AND parent.room_id = child.room_id
                     WHERE
                         %s
-                        AND relation_type = ?
+                        AND %s
                     ORDER BY parent.event_id, child.topological_ordering DESC, child.stream_ordering DESC
                 """
             else:
@@ -514,16 +492,22 @@ class RelationsWorkerStore(SQLBaseStore):
                         AND parent.room_id = child.room_id
                     WHERE
                         %s
-                        AND relation_type = ?
+                        AND %s
                     ORDER BY child.topological_ordering DESC, child.stream_ordering DESC
                 """
 
             clause, args = make_in_list_sql_clause(
                 txn.database_engine, "relates_to_id", event_ids
             )
-            args.append(RelationTypes.THREAD)
 
-            txn.execute(sql % (clause,), args)
+            if self._msc3440_enabled:
+                relations_clause = "(relation_type = ? OR relation_type = ?)"
+                args.extend((RelationTypes.THREAD, RelationTypes.UNSTABLE_THREAD))
+            else:
+                relations_clause = "relation_type = ?"
+                args.append(RelationTypes.THREAD)
+
+            txn.execute(sql % (clause, relations_clause), args)
             latest_event_ids = {}
             for parent_event_id, child_event_id in txn:
                 # Only consider the latest threaded reply (by topological ordering).
@@ -543,7 +527,7 @@ class RelationsWorkerStore(SQLBaseStore):
                     AND parent.room_id = child.room_id
                 WHERE
                     %s
-                    AND relation_type = ?
+                    AND %s
                 GROUP BY parent.event_id
             """
 
@@ -552,9 +536,15 @@ class RelationsWorkerStore(SQLBaseStore):
             clause, args = make_in_list_sql_clause(
                 txn.database_engine, "relates_to_id", latest_event_ids.keys()
             )
-            args.append(RelationTypes.THREAD)
 
-            txn.execute(sql % (clause,), args)
+            if self._msc3440_enabled:
+                relations_clause = "(relation_type = ? OR relation_type = ?)"
+                args.extend((RelationTypes.THREAD, RelationTypes.UNSTABLE_THREAD))
+            else:
+                relations_clause = "relation_type = ?"
+                args.append(RelationTypes.THREAD)
+
+            txn.execute(sql % (clause, relations_clause), args)
             counts = dict(cast(List[Tuple[str, int]], txn.fetchall()))
 
             return counts, latest_event_ids
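The thread queries above, and the participation query further down, now match either the stable thread relation or the unstable MSC3440 one when experimental support is enabled. The repeated clause-building could be expressed as a small helper along these lines; build_thread_relation_clause is a hypothetical name, and the relation identifiers are assumed to be m.thread and io.element.thread.

# Sketch of the repeated clause-building pattern, factored into a helper.
from typing import List, Tuple

THREAD = "m.thread"
UNSTABLE_THREAD = "io.element.thread"


def build_thread_relation_clause(msc3440_enabled: bool) -> Tuple[str, List[str]]:
    if msc3440_enabled:
        # Match either the stable or the unstable thread relation type.
        return "(relation_type = ? OR relation_type = ?)", [THREAD, UNSTABLE_THREAD]
    return "relation_type = ?", [THREAD]


clause, args = build_thread_relation_clause(msc3440_enabled=True)
print(clause, args)
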
@@ -566,7 +556,7 @@ class RelationsWorkerStore(SQLBaseStore):
         latest_events = await self.get_events(latest_event_ids.values())  # type: ignore[attr-defined]
 
         # Check to see if any of those events are edited.
-        latest_edits = await self._get_applicable_edits(latest_event_ids.values())
+        latest_edits = await self.get_applicable_edits(latest_event_ids.values())
 
         # Map to the event IDs to the thread summary.
         #
@@ -589,7 +579,7 @@ class RelationsWorkerStore(SQLBaseStore):
         raise NotImplementedError()
 
     @cachedList(cached_method_name="get_thread_participated", list_name="event_ids")
-    async def _get_threads_participated(
+    async def get_threads_participated(
         self, event_ids: Collection[str], user_id: str
     ) -> Dict[str, bool]:
         """Get whether the requesting user participated in the given threads.
@@ -617,16 +607,24 @@ class RelationsWorkerStore(SQLBaseStore):
                     AND parent.room_id = child.room_id
                 WHERE
                     %s
-                    AND relation_type = ?
+                    AND %s
                     AND child.sender = ?
             """
 
             clause, args = make_in_list_sql_clause(
                 txn.database_engine, "relates_to_id", event_ids
             )
-            args.extend((RelationTypes.THREAD, user_id))
 
-            txn.execute(sql % (clause,), args)
+            if self._msc3440_enabled:
+                relations_clause = "(relation_type = ? OR relation_type = ?)"
+                args.extend((RelationTypes.THREAD, RelationTypes.UNSTABLE_THREAD))
+            else:
+                relations_clause = "relation_type = ?"
+                args.append(RelationTypes.THREAD)
+
+            args.append(user_id)
+
+            txn.execute(sql % (clause, relations_clause), args)
             return {row[0] for row in txn.fetchall()}
 
         participated_threads = await self.db_pool.runInteraction(
@@ -737,122 +735,6 @@ class RelationsWorkerStore(SQLBaseStore):
             "get_if_user_has_annotated_event", _get_if_user_has_annotated_event
         )
 
-    async def _get_bundled_aggregation_for_event(
-        self, event: EventBase, user_id: str
-    ) -> Optional[BundledAggregations]:
-        """Generate bundled aggregations for an event.
-
-        Note that this does not use a cache, but depends on cached methods.
-
-        Args:
-            event: The event to calculate bundled aggregations for.
-            user_id: The user requesting the bundled aggregations.
-
-        Returns:
-            The bundled aggregations for an event, if bundled aggregations are
-            enabled and the event can have bundled aggregations.
-        """
-
-        # Do not bundle aggregations for an event which represents an edit or an
-        # annotation. It does not make sense for them to have related events.
-        relates_to = event.content.get("m.relates_to")
-        if isinstance(relates_to, (dict, frozendict)):
-            relation_type = relates_to.get("rel_type")
-            if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE):
-                return None
-
-        event_id = event.event_id
-        room_id = event.room_id
-
-        # The bundled aggregations to include, a mapping of relation type to a
-        # type-specific value. Some types include the direct return type here
-        # while others need more processing during serialization.
-        aggregations = BundledAggregations()
-
-        annotations = await self.get_aggregation_groups_for_event(event_id, room_id)
-        if annotations.chunk:
-            aggregations.annotations = await annotations.to_dict(
-                cast("DataStore", self)
-            )
-
-        references = await self.get_relations_for_event(
-            event_id, room_id, RelationTypes.REFERENCE, direction="f"
-        )
-        if references.chunk:
-            aggregations.references = await references.to_dict(cast("DataStore", self))
-
-        # Store the bundled aggregations in the event metadata for later use.
-        return aggregations
-
-    async def get_bundled_aggregations(
-        self, events: Iterable[EventBase], user_id: str
-    ) -> Dict[str, BundledAggregations]:
-        """Generate bundled aggregations for events.
-
-        Args:
-            events: The iterable of events to calculate bundled aggregations for.
-            user_id: The user requesting the bundled aggregations.
-
-        Returns:
-            A map of event ID to the bundled aggregation for the event. Not all
-            events may have bundled aggregations in the results.
-        """
-        # The already processed event IDs. Tracked separately from the result
-        # since the result omits events which do not have bundled aggregations.
-        seen_event_ids = set()
-
-        # State events and redacted events do not get bundled aggregations.
-        events = [
-            event
-            for event in events
-            if not event.is_state() and not event.internal_metadata.is_redacted()
-        ]
-
-        # event ID -> bundled aggregation in non-serialized form.
-        results: Dict[str, BundledAggregations] = {}
-
-        # Fetch other relations per event.
-        for event in events:
-            # De-duplicate events by ID to handle the same event requested multiple
-            # times. The caches that _get_bundled_aggregation_for_event use should
-            # capture this, but best to reduce work.
-            if event.event_id in seen_event_ids:
-                continue
-            seen_event_ids.add(event.event_id)
-
-            event_result = await self._get_bundled_aggregation_for_event(event, user_id)
-            if event_result:
-                results[event.event_id] = event_result
-
-        # Fetch any edits.
-        edits = await self._get_applicable_edits(seen_event_ids)
-        for event_id, edit in edits.items():
-            results.setdefault(event_id, BundledAggregations()).replace = edit
-
-        # Fetch thread summaries.
-        if self._msc3440_enabled:
-            summaries = await self._get_thread_summaries(seen_event_ids)
-            # Only fetch participated for a limited selection based on what had
-            # summaries.
-            participated = await self._get_threads_participated(
-                summaries.keys(), user_id
-            )
-            for event_id, summary in summaries.items():
-                if summary:
-                    thread_count, latest_thread_event, edit = summary
-                    results.setdefault(
-                        event_id, BundledAggregations()
-                    ).thread = _ThreadAggregation(
-                        latest_event=latest_thread_event,
-                        latest_edit=edit,
-                        count=thread_count,
-                        # If there's a thread summary it must also exist in the
-                        # participated dictionary.
-                        current_user_participated=participated[event_id],
-                    )
-
-        return results
-
 
 class RelationsStore(RelationsWorkerStore):
     pass
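
The bundled-aggregation assembly (get_bundled_aggregations and its per-event helper) is removed from the store in this diff, and the helpers it relied on are renamed to public methods, which suggests the composition now happens at a higher layer. A rough sketch of what such a caller might look like, reconstructed from the removed code, is shown below; the actual relocated implementation may differ.

# Rough, illustrative reconstruction of the removed composition logic as a
# store-external caller. Names and result shape are not Synapse's API.
async def get_bundled_aggregations(store, events, user_id):
    # State events and redacted events do not get bundled aggregations.
    event_ids = {
        e.event_id
        for e in events
        if not e.is_state() and not e.internal_metadata.is_redacted()
    }

    results = {}

    # Most recent edit for each event.
    edits = await store.get_applicable_edits(event_ids)
    for event_id, edit in edits.items():
        results.setdefault(event_id, {})["m.replace"] = edit

    # Thread summaries, plus whether the requester participated in each thread.
    summaries = await store.get_thread_summaries(event_ids)
    participated = await store.get_threads_participated(summaries.keys(), user_id)
    for event_id, summary in summaries.items():
        if summary:
            count, latest_event, latest_edit = summary
            results.setdefault(event_id, {})["m.thread"] = {
                "count": count,
                "latest_event": latest_event,
                "latest_edit": latest_edit,
                "current_user_participated": participated[event_id],
            }

    return results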