summary refs log tree commit diff
path: root/synapse/storage/databases
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r--synapse/storage/databases/main/events.py7
-rw-r--r--synapse/storage/databases/main/relations.py66
2 files changed, 62 insertions, 11 deletions
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 2be36a741a..7278002322 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1793,6 +1793,13 @@ class PersistEventsStore:
             txn.call_after(
                 self.store.get_thread_summary.invalidate, (parent_id, event.room_id)
             )
+            # It should be safe to only invalidate the cache if the user has not
+            # previously participated in the thread, but that's difficult (and
+            # potentially error-prone) so it is always invalidated.
+            txn.call_after(
+                self.store.get_thread_participated.invalidate,
+                (parent_id, event.room_id, event.sender),
+            )
 
     def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase):
         """Handles keeping track of insertion events and edges/connections.
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index c6c4bd18da..2cb5d06c13 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -384,8 +384,7 @@ class RelationsWorkerStore(SQLBaseStore):
     async def get_thread_summary(
         self, event_id: str, room_id: str
     ) -> Tuple[int, Optional[EventBase]]:
-        """Get the number of threaded replies, the senders of those replies, and
-        the latest reply (if any) for the given event.
+        """Get the number of threaded replies and the latest reply (if any) for the given event.
 
         Args:
             event_id: Summarize the thread related to this event ID.
@@ -398,7 +397,7 @@ class RelationsWorkerStore(SQLBaseStore):
         def _get_thread_summary_txn(
             txn: LoggingTransaction,
         ) -> Tuple[int, Optional[str]]:
-            # Fetch the count of threaded events and the latest event ID.
+            # Fetch the latest event ID in the thread.
             # TODO Should this only allow m.room.message events.
             sql = """
                 SELECT event_id
@@ -419,6 +418,7 @@ class RelationsWorkerStore(SQLBaseStore):
 
             latest_event_id = row[0]
 
+            # Fetch the number of threaded replies.
             sql = """
                 SELECT COUNT(event_id)
                 FROM event_relations
@@ -443,6 +443,44 @@ class RelationsWorkerStore(SQLBaseStore):
 
         return count, latest_event
 
+    @cached()
+    async def get_thread_participated(
+        self, event_id: str, room_id: str, user_id: str
+    ) -> bool:
+        """Get whether the requesting user participated in a thread.
+
+        This is separate from get_thread_summary since that can be cached across
+        all users while this value is specific to the requeser.
+
+        Args:
+            event_id: The thread related to this event ID.
+            room_id: The room the event belongs to.
+            user_id: The user requesting the summary.
+
+        Returns:
+            True if the requesting user participated in the thread, otherwise false.
+        """
+
+        def _get_thread_summary_txn(txn: LoggingTransaction) -> bool:
+            # Fetch whether the requester has participated or not.
+            sql = """
+                SELECT 1
+                FROM event_relations
+                INNER JOIN events USING (event_id)
+                WHERE
+                    relates_to_id = ?
+                    AND room_id = ?
+                    AND relation_type = ?
+                    AND sender = ?
+            """
+
+            txn.execute(sql, (event_id, room_id, RelationTypes.THREAD, user_id))
+            return bool(txn.fetchone())
+
+        return await self.db_pool.runInteraction(
+            "get_thread_summary", _get_thread_summary_txn
+        )
+
     async def events_have_relations(
         self,
         parent_ids: List[str],
@@ -546,7 +584,7 @@ class RelationsWorkerStore(SQLBaseStore):
         )
 
     async def _get_bundled_aggregation_for_event(
-        self, event: EventBase
+        self, event: EventBase, user_id: str
     ) -> Optional[Dict[str, Any]]:
         """Generate bundled aggregations for an event.
 
@@ -554,6 +592,7 @@ class RelationsWorkerStore(SQLBaseStore):
 
         Args:
             event: The event to calculate bundled aggregations for.
+            user_id: The user requesting the bundled aggregations.
 
         Returns:
             The bundled aggregations for an event, if bundled aggregations are
@@ -598,27 +637,32 @@ class RelationsWorkerStore(SQLBaseStore):
 
         # If this event is the start of a thread, include a summary of the replies.
         if self._msc3440_enabled:
-            (
-                thread_count,
-                latest_thread_event,
-            ) = await self.get_thread_summary(event_id, room_id)
+            thread_count, latest_thread_event = await self.get_thread_summary(
+                event_id, room_id
+            )
+            participated = await self.get_thread_participated(
+                event_id, room_id, user_id
+            )
             if latest_thread_event:
                 aggregations[RelationTypes.THREAD] = {
-                    # Don't bundle aggregations as this could recurse forever.
                     "latest_event": latest_thread_event,
                     "count": thread_count,
+                    "current_user_participated": participated,
                 }
 
         # Store the bundled aggregations in the event metadata for later use.
         return aggregations
 
     async def get_bundled_aggregations(
-        self, events: Iterable[EventBase]
+        self,
+        events: Iterable[EventBase],
+        user_id: str,
     ) -> Dict[str, Dict[str, Any]]:
         """Generate bundled aggregations for events.
 
         Args:
             events: The iterable of events to calculate bundled aggregations for.
+            user_id: The user requesting the bundled aggregations.
 
         Returns:
             A map of event ID to the bundled aggregation for the event. Not all
@@ -631,7 +675,7 @@ class RelationsWorkerStore(SQLBaseStore):
         # TODO Parallelize.
         results = {}
         for event in events:
-            event_result = await self._get_bundled_aggregation_for_event(event)
+            event_result = await self._get_bundled_aggregation_for_event(event, user_id)
             if event_result is not None:
                 results[event.event_id] = event_result