summary refs log tree commit diff
path: root/synapse/storage/databases
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r--synapse/storage/databases/main/events.py8
-rw-r--r--synapse/storage/databases/main/relations.py36
2 files changed, 32 insertions, 12 deletions
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 4e528612ea..f1f4ce5e07 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1780,10 +1780,14 @@ class PersistEventsStore:
         )
 
         if rel_type == RelationTypes.REPLACE:
-            txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,))
+            txn.call_after(
+                self.store.get_applicable_edit.invalidate, (parent_id, event.room_id)
+            )
 
         if rel_type == RelationTypes.THREAD:
-            txn.call_after(self.store.get_thread_summary.invalidate, (parent_id,))
+            txn.call_after(
+                self.store.get_thread_summary.invalidate, (parent_id, event.room_id)
+            )
 
     def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase):
         """Handles keeping track of insertion events and edges/connections.
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 0a43acda07..3368a8b084 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -37,6 +37,7 @@ class RelationsWorkerStore(SQLBaseStore):
     async def get_relations_for_event(
         self,
         event_id: str,
+        room_id: str,
         relation_type: Optional[str] = None,
         event_type: Optional[str] = None,
         aggregation_key: Optional[str] = None,
@@ -49,6 +50,7 @@ class RelationsWorkerStore(SQLBaseStore):
 
         Args:
             event_id: Fetch events that relate to this event ID.
+            room_id: The room the event belongs to.
             relation_type: Only fetch events with this relation type, if given.
             event_type: Only fetch events with this event type, if given.
             aggregation_key: Only fetch events with this aggregation key, if given.
@@ -63,8 +65,8 @@ class RelationsWorkerStore(SQLBaseStore):
             the form `{"event_id": "..."}`.
         """
 
-        where_clause = ["relates_to_id = ?"]
-        where_args: List[Union[str, int]] = [event_id]
+        where_clause = ["relates_to_id = ?", "room_id = ?"]
+        where_args: List[Union[str, int]] = [event_id, room_id]
 
         if relation_type is not None:
             where_clause.append("relation_type = ?")
@@ -199,6 +201,7 @@ class RelationsWorkerStore(SQLBaseStore):
     async def get_aggregation_groups_for_event(
         self,
         event_id: str,
+        room_id: str,
         event_type: Optional[str] = None,
         limit: int = 5,
         direction: str = "b",
@@ -213,6 +216,7 @@ class RelationsWorkerStore(SQLBaseStore):
 
         Args:
             event_id: Fetch events that relate to this event ID.
+            room_id: The room the event belongs to.
             event_type: Only fetch events with this event type, if given.
             limit: Only fetch the `limit` groups.
             direction: Whether to fetch the highest count first (`"b"`) or
@@ -225,8 +229,12 @@ class RelationsWorkerStore(SQLBaseStore):
             `type`, `key` and `count` fields.
         """
 
-        where_clause = ["relates_to_id = ?", "relation_type = ?"]
-        where_args: List[Union[str, int]] = [event_id, RelationTypes.ANNOTATION]
+        where_clause = ["relates_to_id = ?", "room_id = ?", "relation_type = ?"]
+        where_args: List[Union[str, int]] = [
+            event_id,
+            room_id,
+            RelationTypes.ANNOTATION,
+        ]
 
         if event_type:
             where_clause.append("type = ?")
@@ -288,7 +296,9 @@ class RelationsWorkerStore(SQLBaseStore):
         )
 
     @cached()
-    async def get_applicable_edit(self, event_id: str) -> Optional[EventBase]:
+    async def get_applicable_edit(
+        self, event_id: str, room_id: str
+    ) -> Optional[EventBase]:
         """Get the most recent edit (if any) that has happened for the given
         event.
 
@@ -296,6 +306,7 @@ class RelationsWorkerStore(SQLBaseStore):
 
         Args:
             event_id: The original event ID
+            room_id: The original event's room ID
 
         Returns:
             The most recent edit, if any.
@@ -317,13 +328,14 @@ class RelationsWorkerStore(SQLBaseStore):
             WHERE
                 relates_to_id = ?
                 AND relation_type = ?
+                AND edit.room_id = ?
                 AND edit.type = 'm.room.message'
             ORDER by edit.origin_server_ts DESC, edit.event_id DESC
             LIMIT 1
         """
 
         def _get_applicable_edit_txn(txn: LoggingTransaction) -> Optional[str]:
-            txn.execute(sql, (event_id, RelationTypes.REPLACE))
+            txn.execute(sql, (event_id, RelationTypes.REPLACE, room_id))
             row = txn.fetchone()
             if row:
                 return row[0]
@@ -340,13 +352,14 @@ class RelationsWorkerStore(SQLBaseStore):
 
     @cached()
     async def get_thread_summary(
-        self, event_id: str
+        self, event_id: str, room_id: str
     ) -> Tuple[int, Optional[EventBase]]:
         """Get the number of threaded replies, the senders of those replies, and
         the latest reply (if any) for the given event.
 
         Args:
-            event_id: The original event ID
+            event_id: Summarize the thread related to this event ID.
+            room_id: The room the event belongs to.
 
         Returns:
             The number of items in the thread and the most recent response, if any.
@@ -363,12 +376,13 @@ class RelationsWorkerStore(SQLBaseStore):
                 INNER JOIN events USING (event_id)
                 WHERE
                     relates_to_id = ?
+                    AND room_id = ?
                     AND relation_type = ?
                 ORDER BY topological_ordering DESC, stream_ordering DESC
                 LIMIT 1
             """
 
-            txn.execute(sql, (event_id, RelationTypes.THREAD))
+            txn.execute(sql, (event_id, room_id, RelationTypes.THREAD))
             row = txn.fetchone()
             if row is None:
                 return 0, None
@@ -378,11 +392,13 @@ class RelationsWorkerStore(SQLBaseStore):
             sql = """
                 SELECT COALESCE(COUNT(event_id), 0)
                 FROM event_relations
+                INNER JOIN events USING (event_id)
                 WHERE
                     relates_to_id = ?
+                    AND room_id = ?
                     AND relation_type = ?
             """
-            txn.execute(sql, (event_id, RelationTypes.THREAD))
+            txn.execute(sql, (event_id, room_id, RelationTypes.THREAD))
             count = txn.fetchone()[0]  # type: ignore[index]
 
             return count, latest_event_id