diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 2da2659f41..baec35ee27 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -412,16 +412,16 @@ class ApplicationServiceTransactionWorkerStore(
)
async def set_type_stream_id_for_appservice(
- self, service: ApplicationService, type: str, pos: Optional[int]
+ self, service: ApplicationService, stream_type: str, pos: Optional[int]
) -> None:
- if type not in ("read_receipt", "presence"):
+ if stream_type not in ("read_receipt", "presence"):
raise ValueError(
"Expected type to be a valid application stream id type, got %s"
- % (type,)
+ % (stream_type,)
)
def set_type_stream_id_for_appservice_txn(txn):
- stream_id_type = "%s_stream_id" % type
+ stream_id_type = "%s_stream_id" % stream_type
txn.execute(
"UPDATE application_services_state SET %s = ? WHERE as_id=?"
% stream_id_type,
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 264e625bd7..ae3afdd5d2 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -134,7 +134,10 @@ class DeviceInboxWorkerStore(SQLBaseStore):
limit: The maximum number of messages to retrieve.
Returns:
- A list of messages for the device and where in the stream the messages got to.
+ A tuple containing:
+ * A list of messages for the device.
+ * The max stream token of these messages. There may be more to retrieve
+ if the given limit was reached.
"""
has_changed = self._device_inbox_stream_cache.has_entity_changed(
user_id, last_stream_id
@@ -153,12 +156,19 @@ class DeviceInboxWorkerStore(SQLBaseStore):
txn.execute(
sql, (user_id, device_id, last_stream_id, current_stream_id, limit)
)
+
messages = []
+ stream_pos = current_stream_id
+
for row in txn:
stream_pos = row[0]
messages.append(db_to_json(row[1]))
+
+ # If the limit was not reached we know that there's no more data for this
+ # user/device pair up to current_stream_id.
if len(messages) < limit:
stream_pos = current_stream_id
+
return messages, stream_pos
return await self.db_pool.runInteraction(
@@ -260,13 +270,20 @@ class DeviceInboxWorkerStore(SQLBaseStore):
" LIMIT ?"
)
txn.execute(sql, (destination, last_stream_id, current_stream_id, limit))
+
messages = []
+ stream_pos = current_stream_id
+
for row in txn:
stream_pos = row[0]
messages.append(db_to_json(row[1]))
+
+ # If the limit was not reached we know that there's no more data for this
+ # destination up to current_stream_id.
if len(messages) < limit:
log_kv({"message": "Set stream position to current position"})
stream_pos = current_stream_id
+
return messages, stream_pos
return await self.db_pool.runInteraction(
@@ -372,8 +389,8 @@ class DeviceInboxWorkerStore(SQLBaseStore):
"""Used to send messages from this server.
Args:
- local_messages_by_user_and_device:
- Dictionary of user_id to device_id to message.
+ local_messages_by_user_then_device:
+ Dictionary of recipient user_id to recipient device_id to message.
remote_messages_by_destination:
Dictionary of destination server_name to the EDU JSON to send.
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 53576ad52f..907af10995 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -20,7 +20,7 @@ import attr
from synapse.api.constants import RelationTypes
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import LoggingTransaction
+from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause
from synapse.storage.databases.main.stream import generate_pagination_where_clause
from synapse.storage.relations import (
AggregationPaginationToken,
@@ -334,6 +334,62 @@ class RelationsWorkerStore(SQLBaseStore):
return count, latest_event
+ async def events_have_relations(
+ self,
+ parent_ids: List[str],
+ relation_senders: Optional[List[str]],
+ relation_types: Optional[List[str]],
+ ) -> List[str]:
+ """Check which events have a relationship from the given senders of the
+ given types.
+
+ Args:
+ parent_ids: The IDs of the events to check for relations.
+ relation_senders: The relation senders to check.
+ relation_types: The relation types to check.
+
+ Returns:
+ The event IDs from parent_ids that have at least one relation matching the given senders and types (all of parent_ids if neither restriction is given).
+ """
+ # If no restrictions are given then every event has the required relations.
+ if not relation_senders and not relation_types:
+ return parent_ids
+
+ sql = """
+ SELECT relates_to_id FROM event_relations
+ INNER JOIN events USING (event_id)
+ WHERE
+ %s;
+ """
+
+ def _get_if_event_has_relations(txn) -> List[str]:
+ clauses: List[str] = []
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "relates_to_id", parent_ids
+ )
+ clauses.append(clause)
+
+ if relation_senders:
+ clause, temp_args = make_in_list_sql_clause(
+ txn.database_engine, "sender", relation_senders
+ )
+ clauses.append(clause)
+ args.extend(temp_args)
+ if relation_types:
+ clause, temp_args = make_in_list_sql_clause(
+ txn.database_engine, "relation_type", relation_types
+ )
+ clauses.append(clause)
+ args.extend(temp_args)
+
+ txn.execute(sql % " AND ".join(clauses), args)
+
+ return [row[0] for row in txn]
+
+ return await self.db_pool.runInteraction(
+ "get_if_event_has_relations", _get_if_event_has_relations
+ )
+
async def has_user_annotated_event(
self, parent_id: str, event_type: str, aggregation_key: str, sender: str
) -> bool:
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index cefc77fa0f..17b398bb69 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -1751,7 +1751,12 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
)
async def block_room(self, room_id: str, user_id: str) -> None:
- """Marks the room as blocked. Can be called multiple times.
+ """Marks the room as blocked.
+
+ Can be called multiple times (though we'll only track the last user to
+ block this room).
+
+ Can be called on a room unknown to this homeserver.
Args:
room_id: Room to block
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index dc7884b1c0..42dc807d17 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -272,31 +272,37 @@ def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
args = []
if event_filter.types:
- clauses.append("(%s)" % " OR ".join("type = ?" for _ in event_filter.types))
+ clauses.append(
+ "(%s)" % " OR ".join("event.type = ?" for _ in event_filter.types)
+ )
args.extend(event_filter.types)
for typ in event_filter.not_types:
- clauses.append("type != ?")
+ clauses.append("event.type != ?")
args.append(typ)
if event_filter.senders:
- clauses.append("(%s)" % " OR ".join("sender = ?" for _ in event_filter.senders))
+ clauses.append(
+ "(%s)" % " OR ".join("event.sender = ?" for _ in event_filter.senders)
+ )
args.extend(event_filter.senders)
for sender in event_filter.not_senders:
- clauses.append("sender != ?")
+ clauses.append("event.sender != ?")
args.append(sender)
if event_filter.rooms:
- clauses.append("(%s)" % " OR ".join("room_id = ?" for _ in event_filter.rooms))
+ clauses.append(
+ "(%s)" % " OR ".join("event.room_id = ?" for _ in event_filter.rooms)
+ )
args.extend(event_filter.rooms)
for room_id in event_filter.not_rooms:
- clauses.append("room_id != ?")
+ clauses.append("event.room_id != ?")
args.append(room_id)
if event_filter.contains_url:
- clauses.append("contains_url = ?")
+ clauses.append("event.contains_url = ?")
args.append(event_filter.contains_url)
# We're only applying the "labels" filter on the database query, because applying the
@@ -307,6 +313,23 @@ def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
clauses.append("(%s)" % " OR ".join("label = ?" for _ in event_filter.labels))
args.extend(event_filter.labels)
+ # Filter on relation_senders / relation types from the joined tables.
+ if event_filter.relation_senders:
+ clauses.append(
+ "(%s)"
+ % " OR ".join(
+ "related_event.sender = ?" for _ in event_filter.relation_senders
+ )
+ )
+ args.extend(event_filter.relation_senders)
+
+ if event_filter.relation_types:
+ clauses.append(
+ "(%s)"
+ % " OR ".join("relation_type = ?" for _ in event_filter.relation_types)
+ )
+ args.extend(event_filter.relation_types)
+
return " AND ".join(clauses), args
@@ -1116,7 +1139,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
bounds = generate_pagination_where_clause(
direction=direction,
- column_names=("topological_ordering", "stream_ordering"),
+ column_names=("event.topological_ordering", "event.stream_ordering"),
from_token=from_bound,
to_token=to_bound,
engine=self.database_engine,
@@ -1133,32 +1156,51 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
select_keywords = "SELECT"
join_clause = ""
+ # Using DISTINCT in this SELECT query is quite expensive, because it
+ # requires the engine to sort on the entire (not limited) result set,
+ # i.e. the entire events table. Only use it in scenarios that could result
+ # in the same event ID occurring multiple times in the results.
+ needs_distinct = False
if event_filter and event_filter.labels:
# If we're not filtering on a label, then joining on event_labels will
# return as many row for a single event as the number of labels it has. To
# avoid this, only join if we're filtering on at least one label.
- join_clause = """
+ join_clause += """
LEFT JOIN event_labels
USING (event_id, room_id, topological_ordering)
"""
if len(event_filter.labels) > 1:
- # Using DISTINCT in this SELECT query is quite expensive, because it
- # requires the engine to sort on the entire (not limited) result set,
- # i.e. the entire events table. We only need to use it when we're
- # filtering on more than two labels, because that's the only scenario
- # in which we can possibly to get multiple times the same event ID in
- # the results.
- select_keywords += "DISTINCT"
+ # Multiple labels could cause the same event to appear multiple times.
+ needs_distinct = True
+
+ # If there is a filter on relation_senders or relation_types, join to the
+ # relations table.
+ if event_filter and (
+ event_filter.relation_senders or event_filter.relation_types
+ ):
+ # Filtering by relations could cause the same event to appear multiple
+ # times (since there's no limit on the number of relations to an event).
+ needs_distinct = True
+ join_clause += """
+ LEFT JOIN event_relations AS relation ON (event.event_id = relation.relates_to_id)
+ """
+ if event_filter.relation_senders:
+ join_clause += """
+ LEFT JOIN events AS related_event ON (relation.event_id = related_event.event_id)
+ """
+
+ if needs_distinct:
+ select_keywords += " DISTINCT"
sql = """
%(select_keywords)s
- event_id, instance_name,
- topological_ordering, stream_ordering
- FROM events
+ event.event_id, event.instance_name,
+ event.topological_ordering, event.stream_ordering
+ FROM events AS event
%(join_clause)s
- WHERE outlier = ? AND room_id = ? AND %(bounds)s
- ORDER BY topological_ordering %(order)s,
- stream_ordering %(order)s LIMIT ?
+ WHERE event.outlier = ? AND event.room_id = ? AND %(bounds)s
+ ORDER BY event.topological_ordering %(order)s,
+ event.stream_ordering %(order)s LIMIT ?
""" % {
"select_keywords": select_keywords,
"join_clause": join_clause,
|