summary refs log tree commit diff
path: root/synapse/storage/databases/main/relations.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/relations.py')
-rw-r--r--synapse/storage/databases/main/relations.py58
1 files changed, 57 insertions, 1 deletions
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 53576ad52f..907af10995 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -20,7 +20,7 @@ import attr
 from synapse.api.constants import RelationTypes
 from synapse.events import EventBase
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import LoggingTransaction
+from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause
 from synapse.storage.databases.main.stream import generate_pagination_where_clause
 from synapse.storage.relations import (
     AggregationPaginationToken,
@@ -334,6 +334,62 @@ class RelationsWorkerStore(SQLBaseStore):
 
         return count, latest_event
 
+    async def events_have_relations(
+        self,
+        parent_ids: List[str],
+        relation_senders: Optional[List[str]],
+        relation_types: Optional[List[str]],
+    ) -> List[str]:
+        """Check which events have a relationship from the given senders of the
+        given types.
+
+        Args:
+            parent_ids: The events being annotated
+            relation_senders: The relation senders to check.
+            relation_types: The relation types to check.
+
+        Returns:
+            True if the event has at least one relationship from one of the given senders of the given type.
+        """
+        # If no restrictions are given then the event has the required relations.
+        if not relation_senders and not relation_types:
+            return parent_ids
+
+        sql = """
+            SELECT relates_to_id FROM event_relations
+            INNER JOIN events USING (event_id)
+            WHERE
+                %s;
+        """
+
+        def _get_if_event_has_relations(txn) -> List[str]:
+            clauses: List[str] = []
+            clause, args = make_in_list_sql_clause(
+                txn.database_engine, "relates_to_id", parent_ids
+            )
+            clauses.append(clause)
+
+            if relation_senders:
+                clause, temp_args = make_in_list_sql_clause(
+                    txn.database_engine, "sender", relation_senders
+                )
+                clauses.append(clause)
+                args.extend(temp_args)
+            if relation_types:
+                clause, temp_args = make_in_list_sql_clause(
+                    txn.database_engine, "relation_type", relation_types
+                )
+                clauses.append(clause)
+                args.extend(temp_args)
+
+            txn.execute(sql % " AND ".join(clauses), args)
+
+            return [row[0] for row in txn]
+
+        return await self.db_pool.runInteraction(
+            "get_if_event_has_relations", _get_if_event_has_relations
+        )
+
     async def has_user_annotated_event(
         self, parent_id: str, event_type: str, aggregation_key: str, sender: str
     ) -> bool: