summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/databases/main/relations.py58
-rw-r--r--synapse/storage/databases/main/stream.py86
2 files changed, 121 insertions, 23 deletions
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 53576ad52f..907af10995 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -20,7 +20,7 @@ import attr
 from synapse.api.constants import RelationTypes
 from synapse.events import EventBase
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import LoggingTransaction
+from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause
 from synapse.storage.databases.main.stream import generate_pagination_where_clause
 from synapse.storage.relations import (
     AggregationPaginationToken,
@@ -334,6 +334,62 @@ class RelationsWorkerStore(SQLBaseStore):
 
         return count, latest_event
 
+    async def events_have_relations(
+        self,
+        parent_ids: List[str],
+        relation_senders: Optional[List[str]],
+        relation_types: Optional[List[str]],
+    ) -> List[str]:
+        """Check which events have a relationship from the given senders of the
+        given types.
+
+        Args:
+            parent_ids: The events being annotated
+            relation_senders: The relation senders to check.
+            relation_types: The relation types to check.
+
+        Returns:
+            True if the event has at least one relationship from one of the given senders of the given type.
+        """
+        # If no restrictions are given then the event has the required relations.
+        if not relation_senders and not relation_types:
+            return parent_ids
+
+        sql = """
+            SELECT relates_to_id FROM event_relations
+            INNER JOIN events USING (event_id)
+            WHERE
+                %s;
+        """
+
+        def _get_if_event_has_relations(txn) -> List[str]:
+            clauses: List[str] = []
+            clause, args = make_in_list_sql_clause(
+                txn.database_engine, "relates_to_id", parent_ids
+            )
+            clauses.append(clause)
+
+            if relation_senders:
+                clause, temp_args = make_in_list_sql_clause(
+                    txn.database_engine, "sender", relation_senders
+                )
+                clauses.append(clause)
+                args.extend(temp_args)
+            if relation_types:
+                clause, temp_args = make_in_list_sql_clause(
+                    txn.database_engine, "relation_type", relation_types
+                )
+                clauses.append(clause)
+                args.extend(temp_args)
+
+            txn.execute(sql % " AND ".join(clauses), args)
+
+            return [row[0] for row in txn]
+
+        return await self.db_pool.runInteraction(
+            "get_if_event_has_relations", _get_if_event_has_relations
+        )
+
     async def has_user_annotated_event(
         self, parent_id: str, event_type: str, aggregation_key: str, sender: str
     ) -> bool:
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index dc7884b1c0..42dc807d17 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -272,31 +272,37 @@ def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
     args = []
 
     if event_filter.types:
-        clauses.append("(%s)" % " OR ".join("type = ?" for _ in event_filter.types))
+        clauses.append(
+            "(%s)" % " OR ".join("event.type = ?" for _ in event_filter.types)
+        )
         args.extend(event_filter.types)
 
     for typ in event_filter.not_types:
-        clauses.append("type != ?")
+        clauses.append("event.type != ?")
         args.append(typ)
 
     if event_filter.senders:
-        clauses.append("(%s)" % " OR ".join("sender = ?" for _ in event_filter.senders))
+        clauses.append(
+            "(%s)" % " OR ".join("event.sender = ?" for _ in event_filter.senders)
+        )
         args.extend(event_filter.senders)
 
     for sender in event_filter.not_senders:
-        clauses.append("sender != ?")
+        clauses.append("event.sender != ?")
         args.append(sender)
 
     if event_filter.rooms:
-        clauses.append("(%s)" % " OR ".join("room_id = ?" for _ in event_filter.rooms))
+        clauses.append(
+            "(%s)" % " OR ".join("event.room_id = ?" for _ in event_filter.rooms)
+        )
         args.extend(event_filter.rooms)
 
     for room_id in event_filter.not_rooms:
-        clauses.append("room_id != ?")
+        clauses.append("event.room_id != ?")
         args.append(room_id)
 
     if event_filter.contains_url:
-        clauses.append("contains_url = ?")
+        clauses.append("event.contains_url = ?")
         args.append(event_filter.contains_url)
 
     # We're only applying the "labels" filter on the database query, because applying the
@@ -307,6 +313,23 @@ def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
         clauses.append("(%s)" % " OR ".join("label = ?" for _ in event_filter.labels))
         args.extend(event_filter.labels)
 
+    # Filter on relation_senders / relation types from the joined tables.
+    if event_filter.relation_senders:
+        clauses.append(
+            "(%s)"
+            % " OR ".join(
+                "related_event.sender = ?" for _ in event_filter.relation_senders
+            )
+        )
+        args.extend(event_filter.relation_senders)
+
+    if event_filter.relation_types:
+        clauses.append(
+            "(%s)"
+            % " OR ".join("relation_type = ?" for _ in event_filter.relation_types)
+        )
+        args.extend(event_filter.relation_types)
+
     return " AND ".join(clauses), args
 
 
@@ -1116,7 +1139,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
 
         bounds = generate_pagination_where_clause(
             direction=direction,
-            column_names=("topological_ordering", "stream_ordering"),
+            column_names=("event.topological_ordering", "event.stream_ordering"),
             from_token=from_bound,
             to_token=to_bound,
             engine=self.database_engine,
@@ -1133,32 +1156,51 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
 
         select_keywords = "SELECT"
         join_clause = ""
+        # Using DISTINCT in this SELECT query is quite expensive, because it
+        # requires the engine to sort on the entire (not limited) result set,
+        # i.e. the entire events table. Only use it in scenarios that could result
+        # in the same event ID occurring multiple times in the results.
+        needs_distinct = False
         if event_filter and event_filter.labels:
             # If we're not filtering on a label, then joining on event_labels will
             # return as many row for a single event as the number of labels it has. To
             # avoid this, only join if we're filtering on at least one label.
-            join_clause = """
+            join_clause += """
                 LEFT JOIN event_labels
                 USING (event_id, room_id, topological_ordering)
             """
             if len(event_filter.labels) > 1:
-                # Using DISTINCT in this SELECT query is quite expensive, because it
-                # requires the engine to sort on the entire (not limited) result set,
-                # i.e. the entire events table. We only need to use it when we're
-                # filtering on more than two labels, because that's the only scenario
-                # in which we can possibly to get multiple times the same event ID in
-                # the results.
-                select_keywords += "DISTINCT"
+                # Multiple labels could cause the same event to appear multiple times.
+                needs_distinct = True
+
+        # If there is a filter on relation_senders and relation_types join to the
+        # relations table.
+        if event_filter and (
+            event_filter.relation_senders or event_filter.relation_types
+        ):
+            # Filtering by relations could cause the same event to appear multiple
+            # times (since there's no limit on the number of relations to an event).
+            needs_distinct = True
+            join_clause += """
+                LEFT JOIN event_relations AS relation ON (event.event_id = relation.relates_to_id)
+            """
+            if event_filter.relation_senders:
+                join_clause += """
+                    LEFT JOIN events AS related_event ON (relation.event_id = related_event.event_id)
+                """
+
+        if needs_distinct:
+            select_keywords += " DISTINCT"
 
         sql = """
             %(select_keywords)s
-                event_id, instance_name,
-                topological_ordering, stream_ordering
-            FROM events
+                event.event_id, event.instance_name,
+                event.topological_ordering, event.stream_ordering
+            FROM events AS event
             %(join_clause)s
-            WHERE outlier = ? AND room_id = ? AND %(bounds)s
-            ORDER BY topological_ordering %(order)s,
-            stream_ordering %(order)s LIMIT ?
+            WHERE event.outlier = ? AND event.room_id = ? AND %(bounds)s
+            ORDER BY event.topological_ordering %(order)s,
+            event.stream_ordering %(order)s LIMIT ?
         """ % {
             "select_keywords": select_keywords,
             "join_clause": join_clause,