summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2022-01-18 11:38:57 -0500
committerGitHub <noreply@github.com>2022-01-18 11:38:57 -0500
commit68acb0a29dcb03a0ecbcebdb95e09c5999598f42 (patch)
treea14144b467bf09b1a60a608a511f2065ac1d16a0 /synapse
parentRemove `log_function` and its uses (#11761) (diff)
downloadsynapse-68acb0a29dcb03a0ecbcebdb95e09c5999598f42.tar.xz
Include whether the requesting user has participated in a thread. (#11577)
Per updates to MSC3440.

This is implement as a separate method since it needs to be cached
on a per-user basis, instead of a per-thread basis.
Diffstat (limited to 'synapse')
-rw-r--r--synapse/handlers/pagination.py2
-rw-r--r--synapse/handlers/room.py12
-rw-r--r--synapse/handlers/sync.py4
-rw-r--r--synapse/rest/client/relations.py4
-rw-r--r--synapse/rest/client/room.py4
-rw-r--r--synapse/storage/databases/main/events.py7
-rw-r--r--synapse/storage/databases/main/relations.py66
7 files changed, 81 insertions, 18 deletions
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 472688f045..973f262964 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -537,7 +537,7 @@ class PaginationHandler:
                 state_dict = await self.store.get_events(list(state_ids.values()))
                 state = state_dict.values()
 
-        aggregations = await self.store.get_bundled_aggregations(events)
+        aggregations = await self.store.get_bundled_aggregations(events, user_id)
 
         time_now = self.clock.time_msec()
 
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 3d47163f25..f963078e59 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -1182,12 +1182,18 @@ class RoomContextHandler:
         results["event"] = filtered[0]
 
         # Fetch the aggregations.
-        aggregations = await self.store.get_bundled_aggregations([results["event"]])
+        aggregations = await self.store.get_bundled_aggregations(
+            [results["event"]], user.to_string()
+        )
         aggregations.update(
-            await self.store.get_bundled_aggregations(results["events_before"])
+            await self.store.get_bundled_aggregations(
+                results["events_before"], user.to_string()
+            )
         )
         aggregations.update(
-            await self.store.get_bundled_aggregations(results["events_after"])
+            await self.store.get_bundled_aggregations(
+                results["events_after"], user.to_string()
+            )
         )
         results["aggregations"] = aggregations
 
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index e1df9b3106..ffc6b748e8 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -637,7 +637,9 @@ class SyncHandler:
         # as clients will have all the necessary information.
         bundled_aggregations = None
         if limited or newly_joined_room:
-            bundled_aggregations = await self.store.get_bundled_aggregations(recents)
+            bundled_aggregations = await self.store.get_bundled_aggregations(
+                recents, sync_config.user.to_string()
+            )
 
         return TimelineBatch(
             events=recents,
diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py
index 37d949a71e..8cf5ebaa07 100644
--- a/synapse/rest/client/relations.py
+++ b/synapse/rest/client/relations.py
@@ -118,7 +118,9 @@ class RelationPaginationServlet(RestServlet):
         )
         # The relations returned for the requested event do include their
         # bundled aggregations.
-        aggregations = await self.store.get_bundled_aggregations(events)
+        aggregations = await self.store.get_bundled_aggregations(
+            events, requester.user.to_string()
+        )
         serialized_events = self._event_serializer.serialize_events(
             events, now, bundle_aggregations=aggregations
         )
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index da6014900a..31fd329a38 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -663,7 +663,9 @@ class RoomEventServlet(RestServlet):
 
         if event:
             # Ensure there are bundled aggregations available.
-            aggregations = await self._store.get_bundled_aggregations([event])
+            aggregations = await self._store.get_bundled_aggregations(
+                [event], requester.user.to_string()
+            )
 
             time_now = self.clock.time_msec()
             event_dict = self._event_serializer.serialize_event(
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 2be36a741a..7278002322 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1793,6 +1793,13 @@ class PersistEventsStore:
             txn.call_after(
                 self.store.get_thread_summary.invalidate, (parent_id, event.room_id)
             )
+            # It should be safe to only invalidate the cache if the user has not
+            # previously participated in the thread, but that's difficult (and
+            # potentially error-prone) so it is always invalidated.
+            txn.call_after(
+                self.store.get_thread_participated.invalidate,
+                (parent_id, event.room_id, event.sender),
+            )
 
     def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase):
         """Handles keeping track of insertion events and edges/connections.
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index c6c4bd18da..2cb5d06c13 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -384,8 +384,7 @@ class RelationsWorkerStore(SQLBaseStore):
     async def get_thread_summary(
         self, event_id: str, room_id: str
     ) -> Tuple[int, Optional[EventBase]]:
-        """Get the number of threaded replies, the senders of those replies, and
-        the latest reply (if any) for the given event.
+        """Get the number of threaded replies and the latest reply (if any) for the given event.
 
         Args:
             event_id: Summarize the thread related to this event ID.
@@ -398,7 +397,7 @@ class RelationsWorkerStore(SQLBaseStore):
         def _get_thread_summary_txn(
             txn: LoggingTransaction,
         ) -> Tuple[int, Optional[str]]:
-            # Fetch the count of threaded events and the latest event ID.
+            # Fetch the latest event ID in the thread.
             # TODO Should this only allow m.room.message events.
             sql = """
                 SELECT event_id
@@ -419,6 +418,7 @@ class RelationsWorkerStore(SQLBaseStore):
 
             latest_event_id = row[0]
 
+            # Fetch the number of threaded replies.
             sql = """
                 SELECT COUNT(event_id)
                 FROM event_relations
@@ -443,6 +443,44 @@ class RelationsWorkerStore(SQLBaseStore):
 
         return count, latest_event
 
+    @cached()
+    async def get_thread_participated(
+        self, event_id: str, room_id: str, user_id: str
+    ) -> bool:
+        """Get whether the requesting user participated in a thread.
+
+        This is separate from get_thread_summary since that can be cached across
+        all users while this value is specific to the requeser.
+
+        Args:
+            event_id: The thread related to this event ID.
+            room_id: The room the event belongs to.
+            user_id: The user requesting the summary.
+
+        Returns:
+            True if the requesting user participated in the thread, otherwise false.
+        """
+
+        def _get_thread_summary_txn(txn: LoggingTransaction) -> bool:
+            # Fetch whether the requester has participated or not.
+            sql = """
+                SELECT 1
+                FROM event_relations
+                INNER JOIN events USING (event_id)
+                WHERE
+                    relates_to_id = ?
+                    AND room_id = ?
+                    AND relation_type = ?
+                    AND sender = ?
+            """
+
+            txn.execute(sql, (event_id, room_id, RelationTypes.THREAD, user_id))
+            return bool(txn.fetchone())
+
+        return await self.db_pool.runInteraction(
+            "get_thread_summary", _get_thread_summary_txn
+        )
+
     async def events_have_relations(
         self,
         parent_ids: List[str],
@@ -546,7 +584,7 @@ class RelationsWorkerStore(SQLBaseStore):
         )
 
     async def _get_bundled_aggregation_for_event(
-        self, event: EventBase
+        self, event: EventBase, user_id: str
     ) -> Optional[Dict[str, Any]]:
         """Generate bundled aggregations for an event.
 
@@ -554,6 +592,7 @@ class RelationsWorkerStore(SQLBaseStore):
 
         Args:
             event: The event to calculate bundled aggregations for.
+            user_id: The user requesting the bundled aggregations.
 
         Returns:
             The bundled aggregations for an event, if bundled aggregations are
@@ -598,27 +637,32 @@ class RelationsWorkerStore(SQLBaseStore):
 
         # If this event is the start of a thread, include a summary of the replies.
         if self._msc3440_enabled:
-            (
-                thread_count,
-                latest_thread_event,
-            ) = await self.get_thread_summary(event_id, room_id)
+            thread_count, latest_thread_event = await self.get_thread_summary(
+                event_id, room_id
+            )
+            participated = await self.get_thread_participated(
+                event_id, room_id, user_id
+            )
             if latest_thread_event:
                 aggregations[RelationTypes.THREAD] = {
-                    # Don't bundle aggregations as this could recurse forever.
                     "latest_event": latest_thread_event,
                     "count": thread_count,
+                    "current_user_participated": participated,
                 }
 
         # Store the bundled aggregations in the event metadata for later use.
         return aggregations
 
     async def get_bundled_aggregations(
-        self, events: Iterable[EventBase]
+        self,
+        events: Iterable[EventBase],
+        user_id: str,
     ) -> Dict[str, Dict[str, Any]]:
         """Generate bundled aggregations for events.
 
         Args:
             events: The iterable of events to calculate bundled aggregations for.
+            user_id: The user requesting the bundled aggregations.
 
         Returns:
             A map of event ID to the bundled aggregation for the event. Not all
@@ -631,7 +675,7 @@ class RelationsWorkerStore(SQLBaseStore):
         # TODO Parallelize.
         results = {}
         for event in events:
-            event_result = await self._get_bundled_aggregation_for_event(event)
+            event_result = await self._get_bundled_aggregation_for_event(event, user_id)
             if event_result is not None:
                 results[event.event_id] = event_result