diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 53576ad52f..907af10995 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -20,7 +20,7 @@ import attr
from synapse.api.constants import RelationTypes
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import LoggingTransaction
+from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause
from synapse.storage.databases.main.stream import generate_pagination_where_clause
from synapse.storage.relations import (
AggregationPaginationToken,
@@ -334,6 +334,62 @@ class RelationsWorkerStore(SQLBaseStore):
return count, latest_event
+ async def events_have_relations(
+ self,
+ parent_ids: List[str],
+ relation_senders: Optional[List[str]],
+ relation_types: Optional[List[str]],
+ ) -> List[str]:
+ """Check which events have a relationship from the given senders of the
+ given types.
+
+ Args:
+ parent_ids: The events being annotated
+ relation_senders: The relation senders to check.
+ relation_types: The relation types to check.
+
+ Returns:
+ True if the event has at least one relationship from one of the given senders of the given type.
+ """
+ # If no restrictions are given then the event has the required relations.
+ if not relation_senders and not relation_types:
+ return parent_ids
+
+ sql = """
+ SELECT relates_to_id FROM event_relations
+ INNER JOIN events USING (event_id)
+ WHERE
+ %s;
+ """
+
+ def _get_if_event_has_relations(txn) -> List[str]:
+ clauses: List[str] = []
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "relates_to_id", parent_ids
+ )
+ clauses.append(clause)
+
+ if relation_senders:
+ clause, temp_args = make_in_list_sql_clause(
+ txn.database_engine, "sender", relation_senders
+ )
+ clauses.append(clause)
+ args.extend(temp_args)
+ if relation_types:
+ clause, temp_args = make_in_list_sql_clause(
+ txn.database_engine, "relation_type", relation_types
+ )
+ clauses.append(clause)
+ args.extend(temp_args)
+
+ txn.execute(sql % " AND ".join(clauses), args)
+
+ return [row[0] for row in txn]
+
+ return await self.db_pool.runInteraction(
+ "get_if_event_has_relations", _get_if_event_has_relations
+ )
+
async def has_user_annotated_event(
self, parent_id: str, event_type: str, aggregation_key: str, sender: str
) -> bool:
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index dc7884b1c0..42dc807d17 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -272,31 +272,37 @@ def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
args = []
if event_filter.types:
- clauses.append("(%s)" % " OR ".join("type = ?" for _ in event_filter.types))
+ clauses.append(
+ "(%s)" % " OR ".join("event.type = ?" for _ in event_filter.types)
+ )
args.extend(event_filter.types)
for typ in event_filter.not_types:
- clauses.append("type != ?")
+ clauses.append("event.type != ?")
args.append(typ)
if event_filter.senders:
- clauses.append("(%s)" % " OR ".join("sender = ?" for _ in event_filter.senders))
+ clauses.append(
+ "(%s)" % " OR ".join("event.sender = ?" for _ in event_filter.senders)
+ )
args.extend(event_filter.senders)
for sender in event_filter.not_senders:
- clauses.append("sender != ?")
+ clauses.append("event.sender != ?")
args.append(sender)
if event_filter.rooms:
- clauses.append("(%s)" % " OR ".join("room_id = ?" for _ in event_filter.rooms))
+ clauses.append(
+ "(%s)" % " OR ".join("event.room_id = ?" for _ in event_filter.rooms)
+ )
args.extend(event_filter.rooms)
for room_id in event_filter.not_rooms:
- clauses.append("room_id != ?")
+ clauses.append("event.room_id != ?")
args.append(room_id)
if event_filter.contains_url:
- clauses.append("contains_url = ?")
+ clauses.append("event.contains_url = ?")
args.append(event_filter.contains_url)
# We're only applying the "labels" filter on the database query, because applying the
@@ -307,6 +313,23 @@ def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
clauses.append("(%s)" % " OR ".join("label = ?" for _ in event_filter.labels))
args.extend(event_filter.labels)
+ # Filter on relation_senders / relation types from the joined tables.
+ if event_filter.relation_senders:
+ clauses.append(
+ "(%s)"
+ % " OR ".join(
+ "related_event.sender = ?" for _ in event_filter.relation_senders
+ )
+ )
+ args.extend(event_filter.relation_senders)
+
+ if event_filter.relation_types:
+ clauses.append(
+ "(%s)"
+ % " OR ".join("relation_type = ?" for _ in event_filter.relation_types)
+ )
+ args.extend(event_filter.relation_types)
+
return " AND ".join(clauses), args
@@ -1116,7 +1139,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
bounds = generate_pagination_where_clause(
direction=direction,
- column_names=("topological_ordering", "stream_ordering"),
+ column_names=("event.topological_ordering", "event.stream_ordering"),
from_token=from_bound,
to_token=to_bound,
engine=self.database_engine,
@@ -1133,32 +1156,51 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
select_keywords = "SELECT"
join_clause = ""
+ # Using DISTINCT in this SELECT query is quite expensive, because it
+ # requires the engine to sort on the entire (not limited) result set,
+ # i.e. the entire events table. Only use it in scenarios that could result
+ # in the same event ID occurring multiple times in the results.
+ needs_distinct = False
if event_filter and event_filter.labels:
# If we're not filtering on a label, then joining on event_labels will
# return as many row for a single event as the number of labels it has. To
# avoid this, only join if we're filtering on at least one label.
- join_clause = """
+ join_clause += """
LEFT JOIN event_labels
USING (event_id, room_id, topological_ordering)
"""
if len(event_filter.labels) > 1:
- # Using DISTINCT in this SELECT query is quite expensive, because it
- # requires the engine to sort on the entire (not limited) result set,
- # i.e. the entire events table. We only need to use it when we're
- # filtering on more than two labels, because that's the only scenario
- # in which we can possibly to get multiple times the same event ID in
- # the results.
- select_keywords += "DISTINCT"
+ # Multiple labels could cause the same event to appear multiple times.
+ needs_distinct = True
+
+ # If there is a filter on relation_senders and relation_types join to the
+ # relations table.
+ if event_filter and (
+ event_filter.relation_senders or event_filter.relation_types
+ ):
+ # Filtering by relations could cause the same event to appear multiple
+ # times (since there's no limit on the number of relations to an event).
+ needs_distinct = True
+ join_clause += """
+ LEFT JOIN event_relations AS relation ON (event.event_id = relation.relates_to_id)
+ """
+ if event_filter.relation_senders:
+ join_clause += """
+ LEFT JOIN events AS related_event ON (relation.event_id = related_event.event_id)
+ """
+
+ if needs_distinct:
+ select_keywords += " DISTINCT"
sql = """
%(select_keywords)s
- event_id, instance_name,
- topological_ordering, stream_ordering
- FROM events
+ event.event_id, event.instance_name,
+ event.topological_ordering, event.stream_ordering
+ FROM events AS event
%(join_clause)s
- WHERE outlier = ? AND room_id = ? AND %(bounds)s
- ORDER BY topological_ordering %(order)s,
- stream_ordering %(order)s LIMIT ?
+ WHERE event.outlier = ? AND event.room_id = ? AND %(bounds)s
+ ORDER BY event.topological_ordering %(order)s,
+ event.stream_ordering %(order)s LIMIT ?
""" % {
"select_keywords": select_keywords,
"join_clause": join_clause,
|