summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/events/utils.py11
-rw-r--r--synapse/rest/client/relations.py7
-rw-r--r--synapse/storage/databases/main/events.py8
-rw-r--r--synapse/storage/databases/main/relations.py36
4 files changed, 45 insertions, 17 deletions
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 84ef69df67..3da432c1df 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -454,23 +454,26 @@ class EventClientSerializer:
                 return
 
         event_id = event.event_id
+        room_id = event.room_id
 
         # The bundled aggregations to include.
         aggregations = {}
 
-        annotations = await self.store.get_aggregation_groups_for_event(event_id)
+        annotations = await self.store.get_aggregation_groups_for_event(
+            event_id, room_id
+        )
         if annotations.chunk:
             aggregations[RelationTypes.ANNOTATION] = annotations.to_dict()
 
         references = await self.store.get_relations_for_event(
-            event_id, RelationTypes.REFERENCE, direction="f"
+            event_id, room_id, RelationTypes.REFERENCE, direction="f"
         )
         if references.chunk:
             aggregations[RelationTypes.REFERENCE] = references.to_dict()
 
         edit = None
         if event.type == EventTypes.Message:
-            edit = await self.store.get_applicable_edit(event_id)
+            edit = await self.store.get_applicable_edit(event_id, room_id)
 
         if edit:
             # If there is an edit replace the content, preserving existing
@@ -503,7 +506,7 @@ class EventClientSerializer:
             (
                 thread_count,
                 latest_thread_event,
-            ) = await self.store.get_thread_summary(event_id)
+            ) = await self.store.get_thread_summary(event_id, room_id)
             if latest_thread_event:
                 aggregations[RelationTypes.THREAD] = {
                     # Don't bundle aggregations as this could recurse forever.
diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py
index fc4e6921c5..ffa37ef06c 100644
--- a/synapse/rest/client/relations.py
+++ b/synapse/rest/client/relations.py
@@ -212,6 +212,7 @@ class RelationPaginationServlet(RestServlet):
 
             pagination_chunk = await self.store.get_relations_for_event(
                 event_id=parent_id,
+                room_id=room_id,
                 relation_type=relation_type,
                 event_type=event_type,
                 limit=limit,
@@ -317,6 +318,7 @@ class RelationAggregationPaginationServlet(RestServlet):
 
             pagination_chunk = await self.store.get_aggregation_groups_for_event(
                 event_id=parent_id,
+                room_id=room_id,
                 event_type=event_type,
                 limit=limit,
                 from_token=from_token,
@@ -383,7 +385,9 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
 
         # This checks that a) the event exists and b) the user is allowed to
         # view it.
-        await self.event_handler.get_event(requester.user, room_id, parent_id)
+        event = await self.event_handler.get_event(requester.user, room_id, parent_id)
+        if event is None:
+            raise SynapseError(404, "Unknown parent event.")
 
         if relation_type != RelationTypes.ANNOTATION:
             raise SynapseError(400, "Relation type must be 'annotation'")
@@ -402,6 +406,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
 
         result = await self.store.get_relations_for_event(
             event_id=parent_id,
+            room_id=room_id,
             relation_type=relation_type,
             event_type=event_type,
             aggregation_key=key,
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 4e528612ea..f1f4ce5e07 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1780,10 +1780,14 @@ class PersistEventsStore:
         )
 
         if rel_type == RelationTypes.REPLACE:
-            txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,))
+            txn.call_after(
+                self.store.get_applicable_edit.invalidate, (parent_id, event.room_id)
+            )
 
         if rel_type == RelationTypes.THREAD:
-            txn.call_after(self.store.get_thread_summary.invalidate, (parent_id,))
+            txn.call_after(
+                self.store.get_thread_summary.invalidate, (parent_id, event.room_id)
+            )
 
     def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase):
         """Handles keeping track of insertion events and edges/connections.
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 0a43acda07..3368a8b084 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -37,6 +37,7 @@ class RelationsWorkerStore(SQLBaseStore):
     async def get_relations_for_event(
         self,
         event_id: str,
+        room_id: str,
         relation_type: Optional[str] = None,
         event_type: Optional[str] = None,
         aggregation_key: Optional[str] = None,
@@ -49,6 +50,7 @@ class RelationsWorkerStore(SQLBaseStore):
 
         Args:
             event_id: Fetch events that relate to this event ID.
+            room_id: The room the event belongs to.
             relation_type: Only fetch events with this relation type, if given.
             event_type: Only fetch events with this event type, if given.
             aggregation_key: Only fetch events with this aggregation key, if given.
@@ -63,8 +65,8 @@ class RelationsWorkerStore(SQLBaseStore):
             the form `{"event_id": "..."}`.
         """
 
-        where_clause = ["relates_to_id = ?"]
-        where_args: List[Union[str, int]] = [event_id]
+        where_clause = ["relates_to_id = ?", "room_id = ?"]
+        where_args: List[Union[str, int]] = [event_id, room_id]
 
         if relation_type is not None:
             where_clause.append("relation_type = ?")
@@ -199,6 +201,7 @@ class RelationsWorkerStore(SQLBaseStore):
     async def get_aggregation_groups_for_event(
         self,
         event_id: str,
+        room_id: str,
         event_type: Optional[str] = None,
         limit: int = 5,
         direction: str = "b",
@@ -213,6 +216,7 @@ class RelationsWorkerStore(SQLBaseStore):
 
         Args:
             event_id: Fetch events that relate to this event ID.
+            room_id: The room the event belongs to.
             event_type: Only fetch events with this event type, if given.
             limit: Only fetch the `limit` groups.
             direction: Whether to fetch the highest count first (`"b"`) or
@@ -225,8 +229,12 @@ class RelationsWorkerStore(SQLBaseStore):
             `type`, `key` and `count` fields.
         """
 
-        where_clause = ["relates_to_id = ?", "relation_type = ?"]
-        where_args: List[Union[str, int]] = [event_id, RelationTypes.ANNOTATION]
+        where_clause = ["relates_to_id = ?", "room_id = ?", "relation_type = ?"]
+        where_args: List[Union[str, int]] = [
+            event_id,
+            room_id,
+            RelationTypes.ANNOTATION,
+        ]
 
         if event_type:
             where_clause.append("type = ?")
@@ -288,7 +296,9 @@ class RelationsWorkerStore(SQLBaseStore):
         )
 
     @cached()
-    async def get_applicable_edit(self, event_id: str) -> Optional[EventBase]:
+    async def get_applicable_edit(
+        self, event_id: str, room_id: str
+    ) -> Optional[EventBase]:
         """Get the most recent edit (if any) that has happened for the given
         event.
 
@@ -296,6 +306,7 @@ class RelationsWorkerStore(SQLBaseStore):
 
         Args:
             event_id: The original event ID
+            room_id: The original event's room ID
 
         Returns:
             The most recent edit, if any.
@@ -317,13 +328,14 @@ class RelationsWorkerStore(SQLBaseStore):
             WHERE
                 relates_to_id = ?
                 AND relation_type = ?
+                AND edit.room_id = ?
                 AND edit.type = 'm.room.message'
             ORDER by edit.origin_server_ts DESC, edit.event_id DESC
             LIMIT 1
         """
 
         def _get_applicable_edit_txn(txn: LoggingTransaction) -> Optional[str]:
-            txn.execute(sql, (event_id, RelationTypes.REPLACE))
+            txn.execute(sql, (event_id, RelationTypes.REPLACE, room_id))
             row = txn.fetchone()
             if row:
                 return row[0]
@@ -340,13 +352,14 @@ class RelationsWorkerStore(SQLBaseStore):
 
     @cached()
     async def get_thread_summary(
-        self, event_id: str
+        self, event_id: str, room_id: str
     ) -> Tuple[int, Optional[EventBase]]:
         """Get the number of threaded replies, the senders of those replies, and
         the latest reply (if any) for the given event.
 
         Args:
-            event_id: The original event ID
+            event_id: Summarize the thread related to this event ID.
+            room_id: The room the event belongs to.
 
         Returns:
             The number of items in the thread and the most recent response, if any.
@@ -363,12 +376,13 @@ class RelationsWorkerStore(SQLBaseStore):
                 INNER JOIN events USING (event_id)
                 WHERE
                     relates_to_id = ?
+                    AND room_id = ?
                     AND relation_type = ?
                 ORDER BY topological_ordering DESC, stream_ordering DESC
                 LIMIT 1
             """
 
-            txn.execute(sql, (event_id, RelationTypes.THREAD))
+            txn.execute(sql, (event_id, room_id, RelationTypes.THREAD))
             row = txn.fetchone()
             if row is None:
                 return 0, None
@@ -378,11 +392,13 @@ class RelationsWorkerStore(SQLBaseStore):
             sql = """
                 SELECT COALESCE(COUNT(event_id), 0)
                 FROM event_relations
+                INNER JOIN events USING (event_id)
                 WHERE
                     relates_to_id = ?
+                    AND room_id = ?
                     AND relation_type = ?
             """
-            txn.execute(sql, (event_id, RelationTypes.THREAD))
+            txn.execute(sql, (event_id, room_id, RelationTypes.THREAD))
             count = txn.fetchone()[0]  # type: ignore[index]
 
             return count, latest_event_id