summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/databases/main/events.py4
-rw-r--r--synapse/storage/databases/main/relations.py59
2 files changed, 62 insertions, 1 deletions
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 37439f8562..8d9086ecf0 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1710,6 +1710,7 @@ class PersistEventsStore:
             RelationTypes.ANNOTATION,
             RelationTypes.REFERENCE,
             RelationTypes.REPLACE,
+            RelationTypes.THREAD,
         ):
             # Unknown relation type
             return
@@ -1740,6 +1741,9 @@ class PersistEventsStore:
         if rel_type == RelationTypes.REPLACE:
             txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,))
 
+        if rel_type == RelationTypes.THREAD:
+            txn.call_after(self.store.get_thread_summary.invalidate, (parent_id,))
+
     def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase):
         """Handles keeping track of insertion events and edges/connections.
         Part of MSC2716.
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 2bbf6d6a95..40760fbd1b 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import logging
-from typing import Optional
+from typing import Optional, Tuple
 
 import attr
 
@@ -269,6 +269,63 @@ class RelationsWorkerStore(SQLBaseStore):
 
         return await self.get_event(edit_id, allow_none=True)
 
+    @cached()
+    async def get_thread_summary(
+        self, event_id: str
+    ) -> Tuple[int, Optional[EventBase]]:
+        """Get the number of threaded replies, the senders of those replies, and
+        the latest reply (if any) for the given event.
+
+        Args:
+            event_id: The original event ID
+
+        Returns:
+            The number of items in the thread and the most recent response, if any.
+        """
+
+        def _get_thread_summary_txn(txn) -> Tuple[int, Optional[str]]:
+            # Fetch the count of threaded events and the latest event ID.
+            # TODO Should this only allow m.room.message events.
+            sql = """
+                SELECT event_id
+                FROM event_relations
+                INNER JOIN events USING (event_id)
+                WHERE
+                    relates_to_id = ?
+                    AND relation_type = ?
+                ORDER BY topological_ordering DESC, stream_ordering DESC
+                LIMIT 1
+            """
+
+            txn.execute(sql, (event_id, RelationTypes.THREAD))
+            row = txn.fetchone()
+            if row is None:
+                return 0, None
+
+            latest_event_id = row[0]
+
+            sql = """
+                SELECT COALESCE(COUNT(event_id), 0)
+                FROM event_relations
+                WHERE
+                    relates_to_id = ?
+                    AND relation_type = ?
+            """
+            txn.execute(sql, (event_id, RelationTypes.THREAD))
+            count = txn.fetchone()[0]
+
+            return count, latest_event_id
+
+        count, latest_event_id = await self.db_pool.runInteraction(
+            "get_thread_summary", _get_thread_summary_txn
+        )
+
+        latest_event = None
+        if latest_event_id:
+            latest_event = await self.get_event(latest_event_id, allow_none=True)
+
+        return count, latest_event
+
     async def has_user_annotated_event(
         self, parent_id: str, event_type: str, aggregation_key: str, sender: str
     ) -> bool: