summary refs log tree commit diff
path: root/synapse/storage/databases/main/relations.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/relations.py')
-rw-r--r--synapse/storage/databases/main/relations.py159
1 files changed, 147 insertions, 12 deletions
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 53576ad52f..729ff17e2e 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -20,7 +20,7 @@ import attr
 from synapse.api.constants import RelationTypes
 from synapse.events import EventBase
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import LoggingTransaction
+from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause
 from synapse.storage.databases.main.stream import generate_pagination_where_clause
 from synapse.storage.relations import (
     AggregationPaginationToken,
@@ -37,6 +37,7 @@ class RelationsWorkerStore(SQLBaseStore):
     async def get_relations_for_event(
         self,
         event_id: str,
+        room_id: str,
         relation_type: Optional[str] = None,
         event_type: Optional[str] = None,
         aggregation_key: Optional[str] = None,
@@ -49,6 +50,7 @@ class RelationsWorkerStore(SQLBaseStore):
 
         Args:
             event_id: Fetch events that relate to this event ID.
+            room_id: The room the event belongs to.
             relation_type: Only fetch events with this relation type, if given.
             event_type: Only fetch events with this event type, if given.
             aggregation_key: Only fetch events with this aggregation key, if given.
@@ -63,8 +65,8 @@ class RelationsWorkerStore(SQLBaseStore):
             the form `{"event_id": "..."}`.
         """
 
-        where_clause = ["relates_to_id = ?"]
-        where_args: List[Union[str, int]] = [event_id]
+        where_clause = ["relates_to_id = ?", "room_id = ?"]
+        where_args: List[Union[str, int]] = [event_id, room_id]
 
         if relation_type is not None:
             where_clause.append("relation_type = ?")
@@ -132,10 +134,74 @@ class RelationsWorkerStore(SQLBaseStore):
             "get_recent_references_for_event", _get_recent_references_for_event_txn
         )
 
+    async def event_includes_relation(self, event_id: str) -> bool:
+        """Check if the given event relates to another event.
+
+        An event has a relation if it has a valid m.relates_to with a rel_type
+        and event_id in the content:
+
+        {
+            "content": {
+                "m.relates_to": {
+                    "rel_type": "m.replace",
+                    "event_id": "$other_event_id"
+                }
+            }
+        }
+
+        Args:
+            event_id: The event to check.
+
+        Returns:
+            True if the event includes a valid relation.
+        """
+
+        result = await self.db_pool.simple_select_one_onecol(
+            table="event_relations",
+            keyvalues={"event_id": event_id},
+            retcol="event_id",
+            allow_none=True,
+            desc="event_includes_relation",
+        )
+        return result is not None
+
+    async def event_is_target_of_relation(self, parent_id: str) -> bool:
+        """Check if the given event is the target of another event's relation.
+
+        An event is the target of an event relation if it has a valid
+        m.relates_to with a rel_type and event_id pointing to parent_id in the
+        content:
+
+        {
+            "content": {
+                "m.relates_to": {
+                    "rel_type": "m.replace",
+                    "event_id": "$parent_id"
+                }
+            }
+        }
+
+        Args:
+            parent_id: The event to check.
+
+        Returns:
+            True if the event is the target of another event's relation.
+        """
+
+        result = await self.db_pool.simple_select_one_onecol(
+            table="event_relations",
+            keyvalues={"relates_to_id": parent_id},
+            retcol="event_id",
+            allow_none=True,
+            desc="event_is_target_of_relation",
+        )
+        return result is not None
+
     @cached(tree=True)
     async def get_aggregation_groups_for_event(
         self,
         event_id: str,
+        room_id: str,
         event_type: Optional[str] = None,
         limit: int = 5,
         direction: str = "b",
@@ -150,6 +216,7 @@ class RelationsWorkerStore(SQLBaseStore):
 
         Args:
             event_id: Fetch events that relate to this event ID.
+            room_id: The room the event belongs to.
             event_type: Only fetch events with this event type, if given.
             limit: Only fetch the `limit` groups.
             direction: Whether to fetch the highest count first (`"b"`) or
@@ -162,8 +229,12 @@ class RelationsWorkerStore(SQLBaseStore):
             `type`, `key` and `count` fields.
         """
 
-        where_clause = ["relates_to_id = ?", "relation_type = ?"]
-        where_args: List[Union[str, int]] = [event_id, RelationTypes.ANNOTATION]
+        where_clause = ["relates_to_id = ?", "room_id = ?", "relation_type = ?"]
+        where_args: List[Union[str, int]] = [
+            event_id,
+            room_id,
+            RelationTypes.ANNOTATION,
+        ]
 
         if event_type:
             where_clause.append("type = ?")
@@ -225,7 +296,9 @@ class RelationsWorkerStore(SQLBaseStore):
         )
 
     @cached()
-    async def get_applicable_edit(self, event_id: str) -> Optional[EventBase]:
+    async def get_applicable_edit(
+        self, event_id: str, room_id: str
+    ) -> Optional[EventBase]:
         """Get the most recent edit (if any) that has happened for the given
         event.
 
@@ -233,6 +306,7 @@ class RelationsWorkerStore(SQLBaseStore):
 
         Args:
             event_id: The original event ID
+            room_id: The original event's room ID
 
         Returns:
             The most recent edit, if any.
@@ -254,13 +328,14 @@ class RelationsWorkerStore(SQLBaseStore):
             WHERE
                 relates_to_id = ?
                 AND relation_type = ?
+                AND edit.room_id = ?
                 AND edit.type = 'm.room.message'
             ORDER by edit.origin_server_ts DESC, edit.event_id DESC
             LIMIT 1
         """
 
         def _get_applicable_edit_txn(txn: LoggingTransaction) -> Optional[str]:
-            txn.execute(sql, (event_id, RelationTypes.REPLACE))
+            txn.execute(sql, (event_id, RelationTypes.REPLACE, room_id))
             row = txn.fetchone()
             if row:
                 return row[0]
@@ -277,13 +352,14 @@ class RelationsWorkerStore(SQLBaseStore):
 
     @cached()
     async def get_thread_summary(
-        self, event_id: str
+        self, event_id: str, room_id: str
     ) -> Tuple[int, Optional[EventBase]]:
         """Get the number of threaded replies, the senders of those replies, and
         the latest reply (if any) for the given event.
 
         Args:
-            event_id: The original event ID
+            event_id: Summarize the thread related to this event ID.
+            room_id: The room the event belongs to.
 
         Returns:
             The number of items in the thread and the most recent response, if any.
@@ -300,12 +376,13 @@ class RelationsWorkerStore(SQLBaseStore):
                 INNER JOIN events USING (event_id)
                 WHERE
                     relates_to_id = ?
+                    AND room_id = ?
                     AND relation_type = ?
                 ORDER BY topological_ordering DESC, stream_ordering DESC
                 LIMIT 1
             """
 
-            txn.execute(sql, (event_id, RelationTypes.THREAD))
+            txn.execute(sql, (event_id, room_id, RelationTypes.THREAD))
             row = txn.fetchone()
             if row is None:
                 return 0, None
@@ -313,13 +390,15 @@ class RelationsWorkerStore(SQLBaseStore):
             latest_event_id = row[0]
 
             sql = """
-                SELECT COALESCE(COUNT(event_id), 0)
+                SELECT COUNT(event_id)
                 FROM event_relations
+                INNER JOIN events USING (event_id)
                 WHERE
                     relates_to_id = ?
+                    AND room_id = ?
                     AND relation_type = ?
             """
-            txn.execute(sql, (event_id, RelationTypes.THREAD))
+            txn.execute(sql, (event_id, room_id, RelationTypes.THREAD))
             count = txn.fetchone()[0]  # type: ignore[index]
 
             return count, latest_event_id
@@ -334,6 +413,62 @@ class RelationsWorkerStore(SQLBaseStore):
 
         return count, latest_event
 
+    async def events_have_relations(
+        self,
+        parent_ids: List[str],
+        relation_senders: Optional[List[str]],
+        relation_types: Optional[List[str]],
+    ) -> List[str]:
+        """Check which events have a relationship from the given senders of the
+        given types.
+
+        Args:
+            parent_ids: The events being annotated
+            relation_senders: The relation senders to check.
+            relation_types: The relation types to check.
+
+        Returns:
+            True if the event has at least one relationship from one of the given senders of the given type.
+        """
+        # If no restrictions are given then the event has the required relations.
+        if not relation_senders and not relation_types:
+            return parent_ids
+
+        sql = """
+            SELECT relates_to_id FROM event_relations
+            INNER JOIN events USING (event_id)
+            WHERE
+                %s;
+        """
+
+        def _get_if_events_have_relations(txn) -> List[str]:
+            clauses: List[str] = []
+            clause, args = make_in_list_sql_clause(
+                txn.database_engine, "relates_to_id", parent_ids
+            )
+            clauses.append(clause)
+
+            if relation_senders:
+                clause, temp_args = make_in_list_sql_clause(
+                    txn.database_engine, "sender", relation_senders
+                )
+                clauses.append(clause)
+                args.extend(temp_args)
+            if relation_types:
+                clause, temp_args = make_in_list_sql_clause(
+                    txn.database_engine, "relation_type", relation_types
+                )
+                clauses.append(clause)
+                args.extend(temp_args)
+
+            txn.execute(sql % " AND ".join(clauses), args)
+
+            return [row[0] for row in txn]
+
+        return await self.db_pool.runInteraction(
+            "get_if_events_have_relations", _get_if_events_have_relations
+        )
+
     async def has_user_annotated_event(
         self, parent_id: str, event_type: str, aggregation_key: str, sender: str
     ) -> bool: