diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 4989c960a6..3fda49f31f 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -21,7 +21,7 @@
#
#
-""" This module is responsible for getting events from the DB for pagination
+"""This module is responsible for getting events from the DB for pagination
and event streaming.
The order it returns events in depends on whether we are streaming forwards or
@@ -50,6 +50,8 @@ from typing import (
Dict,
Iterable,
List,
+ Literal,
+ Mapping,
Optional,
Protocol,
Set,
@@ -60,7 +62,7 @@ from typing import (
import attr
from immutabledict import immutabledict
-from typing_extensions import Literal, assert_never
+from typing_extensions import assert_never
from twisted.internet import defer
@@ -78,9 +80,10 @@ from synapse.storage.database import (
)
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
+from synapse.storage.roommember import RoomsForUserStateReset
from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.types import PersistedEventPosition, RoomStreamToken, StrCollection
-from synapse.util.caches.descriptors import cached
+from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.cancellation import cancellable
from synapse.util.iterutils import batch_iter
@@ -107,7 +110,7 @@ class PaginateFunction(Protocol):
to_key: Optional[RoomStreamToken] = None,
direction: Direction = Direction.BACKWARDS,
limit: int = 0,
- ) -> Tuple[List[EventBase], RoomStreamToken]: ...
+ ) -> Tuple[List[EventBase], RoomStreamToken, bool]: ...
# Used as return values for pagination APIs
@@ -451,6 +454,8 @@ def _filter_results_by_stream(
stream_ordering falls between the two tokens (taking a None
token to mean unbounded).
+ The token range is defined by > `lower_token` and <= `upper_token`.
+
Used to filter results from fetching events in the DB against the given
tokens. This is necessary to handle the case where the tokens include
position maps, which we handle by fetching more than necessary from the DB
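
To make the half-open bound concrete, here is a minimal standalone model of the filter described above. `MiniToken` is a simplified stand-in for `RoomStreamToken` (the real class carries multi-writer machinery); the comparison logic is the part being illustrated.

```python
from typing import Mapping, Optional


class MiniToken:
    """Simplified stand-in for RoomStreamToken: a global stream position
    plus an optional per-writer instance map."""

    def __init__(self, stream: int, instance_map: Optional[Mapping[str, int]] = None):
        self.stream = stream
        self.instance_map = instance_map or {}

    def get_stream_pos_for_instance(self, instance_name: str) -> int:
        # Writers not in the map fall back to the global stream position.
        return self.instance_map.get(instance_name, self.stream)


def filter_by_stream(
    lower: Optional[MiniToken],
    upper: Optional[MiniToken],
    instance_name: str,
    stream_ordering: int,
) -> bool:
    """Keep rows with lower < stream_ordering <= upper for their writer."""
    if lower is not None and stream_ordering <= lower.get_stream_pos_for_instance(
        instance_name
    ):
        return False
    if upper is not None and stream_ordering > upper.get_stream_pos_for_instance(
        instance_name
    ):
        return False
    return True


# Position 7 on "worker1" lies inside the range (5, 10]; position 5 does not.
assert filter_by_stream(MiniToken(5), MiniToken(10), "worker1", 7)
assert not filter_by_stream(MiniToken(5), MiniToken(10), "worker1", 5)
```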
@@ -678,7 +683,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
to_key: Optional[RoomStreamToken] = None,
direction: Direction = Direction.BACKWARDS,
limit: int = 0,
- ) -> Dict[str, Tuple[List[EventBase], RoomStreamToken]]:
+ ) -> Dict[str, Tuple[List[EventBase], RoomStreamToken, bool]]:
"""Get new room events in stream ordering since `from_key`.
Args:
@@ -694,6 +699,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
A map from room id to a tuple containing:
- list of recent events in the room
- stream ordering key for the start of the chunk of events returned.
+ - a boolean to indicate if there were more events but we hit the limit
When Direction.FORWARDS: from_key < x <= to_key, (ascending order)
When Direction.BACKWARDS: from_key >= x > to_key, (descending order)
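
For illustration, a hypothetical caller consuming the new per-room 3-tuple might look like the following; `store`, `room_ids`, and the tokens are assumed to be in scope.

```python
async def stream_rooms(store, room_ids, from_key, to_key) -> None:
    results = await store.get_room_events_stream_for_rooms(
        room_ids=room_ids,
        from_key=from_key,
        to_key=to_key,
        limit=10,
    )
    for room_id, (events, start_key, limited) in results.items():
        # limited=True: the room had more events in the range than `limit`;
        # paginate again from `start_key` to fetch the rest.
        ...
```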
@@ -749,6 +755,48 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
if self._events_stream_cache.has_entity_changed(room_id, from_id)
}
+ async def get_rooms_that_have_updates_since_sliding_sync_table(
+ self,
+ room_ids: StrCollection,
+ from_key: RoomStreamToken,
+ ) -> StrCollection:
+ """Return the rooms that probably have had updates since the given
+ token (changes that are > `from_key`)."""
+ # If the stream change cache is valid for the stream token, we can just
+ # use the result of that.
+ if from_key.stream >= self._events_stream_cache.get_earliest_known_position():
+ return self._events_stream_cache.get_entities_changed(
+ room_ids, from_key.stream
+ )
+
+ def get_rooms_that_have_updates_since_sliding_sync_table_txn(
+ txn: LoggingTransaction,
+ ) -> StrCollection:
+ sql = """
+ SELECT room_id
+ FROM sliding_sync_joined_rooms
+ WHERE {clause}
+ AND event_stream_ordering > ?
+ """
+
+ results: Set[str] = set()
+ for batch in batch_iter(room_ids, 1000):
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "room_id", batch
+ )
+
+ args.append(from_key.stream)
+ txn.execute(sql.format(clause=clause), args)
+
+ results.update(row[0] for row in txn)
+
+ return results
+
+ return await self.db_pool.runInteraction(
+ "get_rooms_that_have_updates_since_sliding_sync_table",
+ get_rooms_that_have_updates_since_sliding_sync_table_txn,
+ )
+
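
A hedged usage sketch for the new helper: a sync handler could narrow a large set of joined rooms down to the ones that need fresh data. `store` and `since_token` are assumptions; per the docstring, the result may over-report ("probably have had updates") but should never miss a changed room.

```python
async def rooms_needing_refresh(store, joined_room_ids, since_token) -> set:
    # Tolerate false positives; only these rooms need per-room work.
    changed = await store.get_rooms_that_have_updates_since_sliding_sync_table(
        room_ids=joined_room_ids,
        from_key=since_token,
    )
    return set(changed)
```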
async def paginate_room_events_by_stream_ordering(
self,
*,
@@ -757,7 +805,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
to_key: Optional[RoomStreamToken] = None,
direction: Direction = Direction.BACKWARDS,
limit: int = 0,
- ) -> Tuple[List[EventBase], RoomStreamToken]:
+ ) -> Tuple[List[EventBase], RoomStreamToken, bool]:
"""
Paginate events by `stream_ordering` in the room from the `from_key` in the
given `direction` to the `to_key` or `limit`.
@@ -772,8 +820,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
limit: Maximum number of events to return
Returns:
- The results as a list of events and a token that points to the end
- of the result set. If no events are returned then the end of the
+ The results as a list of events, a token that points to the end of
+ the result set, and a boolean to indicate if there were more events
+ but we hit the limit. If no events are returned then the end of the
stream has been reached (i.e. there are no events between `from_key`
and `to_key`).
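
A hypothetical resume loop driven by the new flag (names assumed; `Direction` is the enum from `synapse.api.constants` already used in this module):

```python
from synapse.api.constants import Direction


async def fetch_all_backwards(store, room_id, from_key) -> None:
    # Page backwards 100 events at a time until the store reports the
    # range is exhausted (limited=False).
    while True:
        events, next_key, limited = (
            await store.paginate_room_events_by_stream_ordering(
                room_id=room_id,
                from_key=from_key,
                direction=Direction.BACKWARDS,
                limit=100,
            )
        )
        for event in events:
            ...  # process this page
        if not limited:
            return
        from_key = next_key
```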
@@ -797,7 +846,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
and to_key.is_before_or_eq(from_key)
):
# Token selection matches what we do below if there are no rows
- return [], to_key if to_key else from_key
+ return [], to_key if to_key else from_key, False
# Or vice-versa, if we're looking backwards and our `from_key` is already before
# our `to_key`.
elif (
@@ -806,7 +855,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
and from_key.is_before_or_eq(to_key)
):
# Token selection matches what we do below if there are no rows
- return [], to_key if to_key else from_key
+ return [], to_key if to_key else from_key, False
# We can do a quick sanity check to see if any events have been sent in the room
# since the earlier token.
@@ -825,7 +874,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
if not has_changed:
# Token selection matches what we do below if there are no rows
- return [], to_key if to_key else from_key
+ return [], to_key if to_key else from_key, False
order, from_bound, to_bound = generate_pagination_bounds(
direction, from_key, to_key
@@ -841,7 +890,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
engine=self.database_engine,
)
- def f(txn: LoggingTransaction) -> List[_EventDictReturn]:
+ def f(txn: LoggingTransaction) -> Tuple[List[_EventDictReturn], bool]:
sql = f"""
SELECT event_id, instance_name, stream_ordering
FROM events
@@ -853,9 +902,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
"""
txn.execute(sql, (room_id, 2 * limit))
+ # Get all the rows and check if we hit the limit.
+ fetched_rows = txn.fetchall()
+ limited = len(fetched_rows) >= 2 * limit
+
rows = [
_EventDictReturn(event_id, None, stream_ordering)
- for event_id, instance_name, stream_ordering in txn
+ for event_id, instance_name, stream_ordering in fetched_rows
if _filter_results_by_stream(
lower_token=(
to_key if direction == Direction.BACKWARDS else from_key
@@ -866,10 +919,17 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
instance_name=instance_name,
stream_ordering=stream_ordering,
)
- ][:limit]
- return rows
+ ]
+
+ if len(rows) > limit:
+ limited = True
- rows = await self.db_pool.runInteraction("get_room_events_stream_for_room", f)
+ rows = rows[:limit]
+ return rows, limited
+
+ rows, limited = await self.db_pool.runInteraction(
+ "get_room_events_stream_for_room", f
+ )
ret = await self.get_events_as_list(
[r.event_id for r in rows], get_prev_content=True
@@ -886,7 +946,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
# `_paginate_room_events_by_topological_ordering_txn(...)`)
next_key = to_key if to_key else from_key
- return ret, next_key
+ return ret, next_key, limited
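
The `limited` bookkeeping above (over-fetch `2 * limit` rows, then flag the result if the DB returned a full batch or if more than `limit` rows survive token filtering) can be modelled in isolation:

```python
from typing import List, Tuple


def cap_and_flag(
    filtered_rows: List[int], fetched_count: int, limit: int
) -> Tuple[List[int], bool]:
    # A full over-fetched batch means the DB may hold rows past our cut-off.
    limited = fetched_count >= 2 * limit
    # More survivors than `limit` after token filtering also means we drop rows.
    if len(filtered_rows) > limit:
        limited = True
    return filtered_rows[:limit], limited


rows, limited = cap_and_flag(list(range(15)), fetched_count=20, limit=10)
assert len(rows) == 10 and limited is True
```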
@trace
async def get_current_state_delta_membership_changes_for_user(
@@ -927,7 +987,17 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
Returns:
All membership changes to the current state in the token range. Events are
sorted by `stream_ordering` ascending.
+
+ `event_id`/`sender` can be `None` when the server leaves a room (meaning
+ everyone locally left) or a state reset which removed the person from the
+ room. We can't tell the difference between the two cases with what's
+ available in the `current_state_delta_stream` table. To actually check for a
+ state reset, you need to check if a membership still exists in the room.
"""
+
+ assert from_key.topological is None
+ assert to_key.topological is None
+
# Start by ruling out cases where a DB query is not necessary.
if from_key == to_key:
return []
@@ -1038,6 +1108,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
membership=(
membership if membership is not None else Membership.LEAVE
),
+ # This will also be null for the same reasons if `s.event_id = null`
sender=sender,
# Prev event
prev_event_id=prev_event_id,
@@ -1072,6 +1143,203 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
if membership_change.room_id not in room_ids_to_exclude
]
+ @trace
+ async def get_sliding_sync_membership_changes(
+ self,
+ user_id: str,
+ from_key: RoomStreamToken,
+ to_key: RoomStreamToken,
+ excluded_room_ids: Optional[AbstractSet[str]] = None,
+ ) -> Dict[str, RoomsForUserStateReset]:
+ """
+ Fetch membership events that result in a meaningful membership change for a
+ given user.
+
+        A meaningful membership change is one where the `membership` value actually
+        changes. This means membership changes from `join` to `join` (like a display
+        name change) will be filtered out since they result in no meaningful change.
+
+        Note: This function only works with "live" tokens, i.e. tokens made up of
+        `stream_ordering` positions only.
+
+ We're looking for membership changes in the token range (> `from_key` and <=
+ `to_key`).
+
+ Args:
+ user_id: The user ID to fetch membership events for.
+ from_key: The point in the stream to sync from (fetching events > this point).
+ to_key: The token to fetch rooms up to (fetching events <= this point).
+ excluded_room_ids: Optional list of room IDs to exclude from the results.
+
+ Returns:
+ All meaningful membership changes to the current state in the token range.
+ Events are sorted by `stream_ordering` ascending.
+
+ `event_id`/`sender` can be `None` when the server leaves a room (meaning
+ everyone locally left) or a state reset which removed the person from the
+ room. We can't tell the difference between the two cases with what's
+ available in the `current_state_delta_stream` table. To actually check for a
+ state reset, you need to check if a membership still exists in the room.
+ """
+
+ assert from_key.topological is None
+ assert to_key.topological is None
+
+ # Start by ruling out cases where a DB query is not necessary.
+ if from_key == to_key:
+ return {}
+
+ if from_key:
+ has_changed = self._membership_stream_cache.has_entity_changed(
+ user_id, int(from_key.stream)
+ )
+ if not has_changed:
+ return {}
+
+ room_ids_to_exclude: AbstractSet[str] = set()
+ if excluded_room_ids is not None:
+ room_ids_to_exclude = excluded_room_ids
+
+ def f(txn: LoggingTransaction) -> Dict[str, RoomsForUserStateReset]:
+ # To handle tokens with a non-empty instance_map we fetch more
+ # results than necessary and then filter down
+ min_from_id = from_key.stream
+ max_to_id = to_key.get_max_stream_pos()
+
+ # This query looks at membership changes in
+ # `sliding_sync_membership_snapshots` which will not include users
+ # that were state reset out of rooms; so we need to look for that
+ # case in `current_state_delta_stream`.
+ sql = """
+ SELECT
+ room_id,
+ membership_event_id,
+ event_instance_name,
+ event_stream_ordering,
+ membership,
+ sender,
+ prev_membership,
+ room_version
+ FROM
+ (
+ SELECT
+ s.room_id,
+ s.membership_event_id,
+ s.event_instance_name,
+ s.event_stream_ordering,
+ s.membership,
+ s.sender,
+ m_prev.membership AS prev_membership
+ FROM sliding_sync_membership_snapshots as s
+ LEFT JOIN event_edges AS e ON e.event_id = s.membership_event_id
+ LEFT JOIN room_memberships AS m_prev ON m_prev.event_id = e.prev_event_id
+ WHERE s.user_id = ?
+
+ UNION ALL
+
+ SELECT
+ s.room_id,
+ e.event_id,
+ s.instance_name,
+ s.stream_id,
+ m.membership,
+ e.sender,
+ m_prev.membership AS prev_membership
+ FROM current_state_delta_stream AS s
+ LEFT JOIN events AS e ON e.event_id = s.event_id
+ LEFT JOIN room_memberships AS m ON m.event_id = s.event_id
+ LEFT JOIN room_memberships AS m_prev ON m_prev.event_id = s.prev_event_id
+ WHERE
+ s.type = ?
+ AND s.state_key = ?
+ ) AS c
+ INNER JOIN rooms USING (room_id)
+ WHERE event_stream_ordering > ? AND event_stream_ordering <= ?
+ ORDER BY event_stream_ordering ASC
+ """
+
+ txn.execute(
+ sql,
+ (user_id, EventTypes.Member, user_id, min_from_id, max_to_id),
+ )
+
+ membership_changes: Dict[str, RoomsForUserStateReset] = {}
+ for (
+ room_id,
+ membership_event_id,
+ event_instance_name,
+ event_stream_ordering,
+ membership,
+ sender,
+ prev_membership,
+ room_version_id,
+ ) in txn:
+ assert room_id is not None
+ assert event_stream_ordering is not None
+
+ if room_id in room_ids_to_exclude:
+ continue
+
+ if _filter_results_by_stream(
+ from_key,
+ to_key,
+ event_instance_name,
+ event_stream_ordering,
+ ):
+ # When the server leaves a room, it will insert new rows into the
+ # `current_state_delta_stream` table with `event_id = null` for all
+ # current state. This means we might already have a row for the
+ # leave event and then another for the same leave where the
+ # `event_id=null` but the `prev_event_id` is pointing back at the
+ # earlier leave event. We don't want to report the leave, if we
+ # already have a leave event.
+ if (
+ membership_event_id is None
+ and prev_membership == Membership.LEAVE
+ ):
+ continue
+
+                    if membership_event_id is None and room_id in membership_changes:
+                        # We already recorded a membership change for this room in
+                        # this window, so skip the extra `event_id = null` row.
+                        #
+                        # Caveat: if the user joins a room and is then state reset
+                        # out of it within the same window, this drops the state
+                        # reset and reports only the join.
+                        continue
+
+                    # When `s.event_id = null`, we can't look up the corresponding
+                    # `room_memberships` row, but we can assume the user has left the
+                    # room: this only happens when the server leaves a room (meaning
+                    # everyone locally left) or when a state reset removed the person
+                    # from the room.
+ membership = (
+ membership if membership is not None else Membership.LEAVE
+ )
+
+ if membership == prev_membership:
+ # If `membership` and `prev_membership` are the same then this
+ # is not a meaningful change so we can skip it.
+ # An example of this happening is when the user changes their display name.
+ continue
+
+ membership_change = RoomsForUserStateReset(
+ room_id=room_id,
+ sender=sender,
+ membership=membership,
+ event_id=membership_event_id,
+ event_pos=PersistedEventPosition(
+ event_instance_name, event_stream_ordering
+ ),
+ room_version_id=room_version_id,
+ )
+
+ membership_changes[room_id] = membership_change
+
+ return membership_changes
+
+ membership_changes = await self.db_pool.runInteraction(
+ "get_sliding_sync_membership_changes", f
+ )
+
+ return membership_changes
+
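
A hypothetical consumer of the returned map. The attribute names follow the `RoomsForUserStateReset` constructor call above, and the `event_id is None` branch follows the docstring's caveat about server leaves versus state resets.

```python
from synapse.api.constants import Membership


async def classify_membership_changes(store, user_id, from_key, to_key) -> None:
    changes = await store.get_sliding_sync_membership_changes(
        user_id=user_id,
        from_key=from_key,
        to_key=to_key,
    )
    for room_id, change in changes.items():
        if change.event_id is None:
            # Either the server left the room or the user was state reset out
            # of it; per the docstring we can't tell which from this row alone.
            ...
        elif change.membership == Membership.LEAVE:
            ...  # an ordinary leave/kick/ban event
```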
@cancellable
async def get_membership_changes_for_user(
self,
@@ -1121,9 +1389,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
AND e.stream_ordering > ? AND e.stream_ordering <= ?
%s
ORDER BY e.stream_ordering ASC
- """ % (
- ignore_room_clause,
- )
+ """ % (ignore_room_clause,)
txn.execute(sql, args)
@@ -1192,7 +1458,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
if limit == 0:
return [], end_token
- rows, token = await self.db_pool.runInteraction(
+ rows, token, _ = await self.db_pool.runInteraction(
"get_recent_event_ids_for_room",
self._paginate_room_events_by_topological_ordering_txn,
room_id,
@@ -1263,12 +1529,76 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return None
+ async def get_last_event_pos_in_room(
+ self,
+ room_id: str,
+ event_types: Optional[StrCollection] = None,
+ ) -> Optional[Tuple[str, PersistedEventPosition]]:
+ """
+ Returns the ID and event position of the last event in a room.
+
+ Based on `get_last_event_pos_in_room_before_stream_ordering(...)`
+
+ Args:
+ room_id
+ event_types: Optional allowlist of event types to filter by
+
+ Returns:
+            The ID of the most recent event and its position, or None if there are no
+ events in the room that match the given event types.
+ """
+
+ def _get_last_event_pos_in_room_txn(
+ txn: LoggingTransaction,
+ ) -> Optional[Tuple[str, PersistedEventPosition]]:
+ event_type_clause = ""
+ event_type_args: List[str] = []
+ if event_types is not None and len(event_types) > 0:
+ event_type_clause, event_type_args = make_in_list_sql_clause(
+ txn.database_engine, "type", event_types
+ )
+ event_type_clause = f"AND {event_type_clause}"
+
+ sql = f"""
+ SELECT event_id, stream_ordering, instance_name
+ FROM events
+ LEFT JOIN rejections USING (event_id)
+ WHERE room_id = ?
+ {event_type_clause}
+ AND NOT outlier
+ AND rejections.event_id IS NULL
+ ORDER BY stream_ordering DESC
+ LIMIT 1
+ """
+
+ txn.execute(
+ sql,
+ [room_id] + event_type_args,
+ )
+
+ row = cast(Optional[Tuple[str, int, str]], txn.fetchone())
+ if row is not None:
+ event_id, stream_ordering, instance_name = row
+
+ return event_id, PersistedEventPosition(
+ # If instance_name is null we default to "master"
+ instance_name or "master",
+ stream_ordering,
+ )
+
+ return None
+
+ return await self.db_pool.runInteraction(
+ "get_last_event_pos_in_room",
+ _get_last_event_pos_in_room_txn,
+ )
+
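
For illustration, one way a caller might use the new helper, assuming the standard `EventTypes.Message` constant as the allowlist:

```python
from synapse.api.constants import EventTypes


async def latest_message(store, room_id):
    # Most recent non-rejected, non-outlier m.room.message in the room.
    last = await store.get_last_event_pos_in_room(
        room_id, event_types=[EventTypes.Message]
    )
    if last is None:
        return None
    event_id, pos = last
    # `pos.stream` is the global stream ordering; `pos.instance_name` names
    # the event persister that wrote the event.
    return event_id, pos.stream
```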
@trace
async def get_last_event_pos_in_room_before_stream_ordering(
self,
room_id: str,
end_token: RoomStreamToken,
- event_types: Optional[Collection[str]] = None,
+ event_types: Optional[StrCollection] = None,
) -> Optional[Tuple[str, PersistedEventPosition]]:
"""
Returns the ID and event position of the last event in a room at or before a
@@ -1381,8 +1711,56 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
rooms
"""
+        # First we just get the latest positions for the rooms, as the vast
+ # majority of them will be before the given end token anyway. By doing
+ # this we can cache most rooms.
+ uncapped_results = await self._bulk_get_max_event_pos(room_ids)
+
+ # Check that the stream position for the rooms are from before the
+ # minimum position of the token. If not then we need to fetch more
+ # rows.
+ results: Dict[str, int] = {}
+ recheck_rooms: Set[str] = set()
min_token = end_token.stream
- max_token = end_token.get_max_stream_pos()
+ for room_id, stream in uncapped_results.items():
+ if stream is None:
+                # The function itself never returns None, but the cache can!
+ # See: https://github.com/element-hq/synapse/issues/17726
+ continue
+ if stream <= min_token:
+ results[room_id] = stream
+ else:
+ recheck_rooms.add(room_id)
+
+ if not recheck_rooms:
+ return results
+
+ # There shouldn't be many rooms that we need to recheck, so we do them
+ # one-by-one.
+ for room_id in recheck_rooms:
+ result = await self.get_last_event_pos_in_room_before_stream_ordering(
+ room_id, end_token
+ )
+ if result is not None:
+ results[room_id] = result[1].stream
+
+ return results
+
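
The cap-and-recheck split above can be modelled standalone: keep cheap per-room positions already at or before the token, and collect the rest for exact per-room lookups. The `None` guard mirrors the cache behaviour noted in the comment.

```python
from typing import Dict, Optional, Set, Tuple


def split_by_token(
    uncapped: Dict[str, Optional[int]], min_token: int
) -> Tuple[Dict[str, int], Set[str]]:
    results: Dict[str, int] = {}
    recheck: Set[str] = set()
    for room_id, stream in uncapped.items():
        if stream is None:
            continue  # the cached bulk lookup can hand back None entries
        if stream <= min_token:
            results[room_id] = stream
        else:
            recheck.add(room_id)
    return results, recheck


results, recheck = split_by_token({"!a": 5, "!b": 12, "!c": None}, min_token=10)
assert results == {"!a": 5} and recheck == {"!b"}
```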
+ @cached()
+ async def _get_max_event_pos(self, room_id: str) -> int:
+ raise NotImplementedError()
+
+ @cachedList(cached_method_name="_get_max_event_pos", list_name="room_ids")
+ async def _bulk_get_max_event_pos(
+ self, room_ids: StrCollection
+ ) -> Mapping[str, Optional[int]]:
+ """Fetch the max position of a persisted event in the room."""
+
+ # We need to be careful not to return positions ahead of the current
+ # positions, so we get the current token now and cap our queries to it.
+ now_token = self.get_room_max_token()
+ max_pos = now_token.get_max_stream_pos()
+
results: Dict[str, int] = {}
# First, we check for the rooms in the stream change cache to see if we
@@ -1390,31 +1768,32 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
missing_room_ids: Set[str] = set()
for room_id in room_ids:
stream_pos = self._events_stream_cache.get_max_pos_of_last_change(room_id)
- if stream_pos and stream_pos <= min_token:
+ if stream_pos is not None:
results[room_id] = stream_pos
else:
missing_room_ids.add(room_id)
+ if not missing_room_ids:
+ return results
+
# Next, we query the stream position from the DB. At first we fetch all
# positions less than the *max* stream pos in the token, then filter
# them down. We do this as a) this is a cheaper query, and b) the vast
# majority of rooms will have a latest token from before the min stream
# pos.
- def bulk_get_last_event_pos_txn(
- txn: LoggingTransaction, batch_room_ids: StrCollection
+ def bulk_get_max_event_pos_fallback_txn(
+ txn: LoggingTransaction, batched_room_ids: StrCollection
) -> Dict[str, int]:
- # This query fetches the latest stream position in the rooms before
- # the given max position.
clause, args = make_in_list_sql_clause(
- self.database_engine, "room_id", batch_room_ids
+ self.database_engine, "room_id", batched_room_ids
)
sql = f"""
SELECT room_id, (
SELECT stream_ordering FROM events AS e
LEFT JOIN rejections USING (event_id)
WHERE e.room_id = r.room_id
- AND stream_ordering <= ?
+ AND e.stream_ordering <= ?
AND NOT outlier
AND rejection_reason IS NULL
ORDER BY stream_ordering DESC
@@ -1423,72 +1802,55 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
FROM rooms AS r
WHERE {clause}
"""
- txn.execute(sql, [max_token] + args)
+ txn.execute(sql, [max_pos] + args)
return {row[0]: row[1] for row in txn}
- recheck_rooms: Set[str] = set()
- for batched in batch_iter(missing_room_ids, 1000):
- result = await self.db_pool.runInteraction(
- "bulk_get_last_event_pos_in_room_before_stream_ordering",
- bulk_get_last_event_pos_txn,
- batched,
- )
-
- # Check that the stream position for the rooms are from before the
- # minimum position of the token. If not then we need to fetch more
- # rows.
- for room_id, stream in result.items():
- if stream <= min_token:
- results[room_id] = stream
- else:
- recheck_rooms.add(room_id)
-
- if not recheck_rooms:
- return results
-
- # For the remaining rooms we need to fetch all rows between the min and
- # max stream positions in the end token, and filter out the rows that
- # are after the end token.
- #
- # This query should be fast as the range between the min and max should
- # be small.
-
- def bulk_get_last_event_pos_recheck_txn(
- txn: LoggingTransaction, batch_room_ids: StrCollection
+ # It's easier to look at the `sliding_sync_joined_rooms` table and avoid all of
+ # the joins and sub-queries.
+ def bulk_get_max_event_pos_from_sliding_sync_tables_txn(
+ txn: LoggingTransaction, batched_room_ids: StrCollection
) -> Dict[str, int]:
clause, args = make_in_list_sql_clause(
- self.database_engine, "room_id", batch_room_ids
+ self.database_engine, "room_id", batched_room_ids
)
sql = f"""
- SELECT room_id, instance_name, stream_ordering
- FROM events
- WHERE ? < stream_ordering AND stream_ordering <= ?
- AND NOT outlier
- AND rejection_reason IS NULL
- AND {clause}
- ORDER BY stream_ordering ASC
+ SELECT room_id, event_stream_ordering
+ FROM sliding_sync_joined_rooms
+ WHERE {clause}
+ ORDER BY event_stream_ordering DESC
"""
- txn.execute(sql, [min_token, max_token] + args)
-
- # We take the max stream ordering that is less than the token. Since
- # we ordered by stream ordering we just need to iterate through and
- # take the last matching stream ordering.
- txn_results: Dict[str, int] = {}
- for row in txn:
- room_id = row[0]
- event_pos = PersistedEventPosition(row[1], row[2])
- if not event_pos.persisted_after(end_token):
- txn_results[room_id] = event_pos.stream
-
- return txn_results
-
- for batched in batch_iter(recheck_rooms, 1000):
- recheck_result = await self.db_pool.runInteraction(
- "bulk_get_last_event_pos_in_room_before_stream_ordering_recheck",
- bulk_get_last_event_pos_recheck_txn,
- batched,
+ txn.execute(sql, args)
+ return {row[0]: row[1] for row in txn}
+
+ recheck_rooms: Set[str] = set()
+        for batched in batch_iter(missing_room_ids, 1000):
+ if await self.have_finished_sliding_sync_background_jobs():
+ batch_results = await self.db_pool.runInteraction(
+ "bulk_get_max_event_pos_from_sliding_sync_tables_txn",
+ bulk_get_max_event_pos_from_sliding_sync_tables_txn,
+ batched,
+ )
+ else:
+ batch_results = await self.db_pool.runInteraction(
+ "bulk_get_max_event_pos_fallback_txn",
+ bulk_get_max_event_pos_fallback_txn,
+ batched,
+ )
+ for room_id, stream_ordering in batch_results.items():
+ if stream_ordering <= now_token.stream:
+ results[room_id] = stream_ordering
+ else:
+ recheck_rooms.add(room_id)
+
+ # We now need to handle rooms where the above query returned a stream
+        # position that was potentially too new. This should happen very rarely,
+        # so we just query the rooms one-by-one.
+ for room_id in recheck_rooms:
+ result = await self.get_last_event_pos_in_room_before_stream_ordering(
+ room_id, now_token
)
- results.update(recheck_result)
+ if result is not None:
+ results[room_id] = result[1].stream
return results
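
`_bulk_get_max_event_pos` relies on Synapse's `@cached`/`@cachedList` pairing: a per-key stub that is never called directly, plus a batch method that fills the per-key cache. A simplified, decorator-free imitation of that shape (not the real descriptor machinery):

```python
from typing import Dict, Iterable, Optional

_max_pos_cache: Dict[str, Optional[int]] = {}


def _query_db_for_max_pos(room_ids: Iterable[str]) -> Dict[str, int]:
    # Stand-in for the batched SQL in the real method.
    return {}


def bulk_get_max_event_pos(room_ids: Iterable[str]) -> Dict[str, Optional[int]]:
    # Only compute keys missing from the cache, in one batched round-trip.
    missing = [r for r in room_ids if r not in _max_pos_cache]
    if missing:
        fetched = _query_db_for_max_pos(missing)
        for room_id in missing:
            # Keys the query didn't return are cached as None, which is why
            # callers of the real method must tolerate None entries.
            _max_pos_cache[room_id] = fetched.get(room_id)
    return {r: _max_pos_cache[r] for r in room_ids}
```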
@@ -1680,15 +2042,14 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
dict
"""
- stream_ordering, topological_ordering = cast(
- Tuple[int, int],
- self.db_pool.simple_select_one_txn(
- txn,
- "events",
- keyvalues={"event_id": event_id, "room_id": room_id},
- retcols=["stream_ordering", "topological_ordering"],
- ),
+ row = self.db_pool.simple_select_one_txn(
+ txn,
+ "events",
+ keyvalues={"event_id": event_id, "room_id": room_id},
+ retcols=("stream_ordering", "topological_ordering"),
)
+ stream_ordering = int(row[0])
+ topological_ordering = int(row[1])
# Paginating backwards includes the event at the token, but paginating
# forward doesn't.
@@ -1700,7 +2061,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
topological=topological_ordering, stream=stream_ordering
)
- rows, start_token = self._paginate_room_events_by_topological_ordering_txn(
+ rows, start_token, _ = self._paginate_room_events_by_topological_ordering_txn(
txn,
room_id,
before_token,
@@ -1710,7 +2071,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
)
events_before = [r.event_id for r in rows]
- rows, end_token = self._paginate_room_events_by_topological_ordering_txn(
+ rows, end_token, _ = self._paginate_room_events_by_topological_ordering_txn(
txn,
room_id,
after_token,
@@ -1882,7 +2243,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
direction: Direction = Direction.BACKWARDS,
limit: int = 0,
event_filter: Optional[Filter] = None,
- ) -> Tuple[List[_EventDictReturn], RoomStreamToken]:
+ ) -> Tuple[List[_EventDictReturn], RoomStreamToken, bool]:
"""Returns list of events before or after a given token.
Args:
@@ -1897,10 +2258,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
those that match the filter.
Returns:
- A list of _EventDictReturn and a token that points to the end of the
- result set. If no events are returned then the end of the stream has
- been reached (i.e. there are no events between `from_token` and
- `to_token`), or `limit` is zero.
+ A list of _EventDictReturn, a token that points to the end of the
+ result set, and a boolean to indicate if there were more events but
+ we hit the limit. If no events are returned then the end of the
+ stream has been reached (i.e. there are no events between
+ `from_token` and `to_token`), or `limit` is zero.
"""
# We can bail early if we're looking forwards, and our `to_key` is already
# before our `from_token`.
@@ -1910,7 +2272,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
and to_token.is_before_or_eq(from_token)
):
# Token selection matches what we do below if there are no rows
- return [], to_token if to_token else from_token
+ return [], to_token if to_token else from_token, False
# Or vice-versa, if we're looking backwards and our `from_token` is already before
# our `to_token`.
elif (
@@ -1919,7 +2281,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
and from_token.is_before_or_eq(to_token)
):
# Token selection matches what we do below if there are no rows
- return [], to_token if to_token else from_token
+ return [], to_token if to_token else from_token, False
args: List[Any] = [room_id]
@@ -1942,6 +2304,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
args.extend(filter_args)
# We fetch more events as we'll filter the result set
+ requested_limit = int(limit) * 2
args.append(int(limit) * 2)
select_keywords = "SELECT"
@@ -2006,10 +2369,14 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
}
txn.execute(sql, args)
+ # Get all the rows and check if we hit the limit.
+ fetched_rows = txn.fetchall()
+ limited = len(fetched_rows) >= requested_limit
+
# Filter the result set.
rows = [
_EventDictReturn(event_id, topological_ordering, stream_ordering)
- for event_id, instance_name, topological_ordering, stream_ordering in txn
+ for event_id, instance_name, topological_ordering, stream_ordering in fetched_rows
if _filter_results(
lower_token=(
to_token if direction == Direction.BACKWARDS else from_token
@@ -2021,7 +2388,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
topological_ordering=topological_ordering,
stream_ordering=stream_ordering,
)
- ][:limit]
+ ]
+
+ if len(rows) > limit:
+ limited = True
+
+ rows = rows[:limit]
if rows:
assert rows[-1].topological_ordering is not None
@@ -2032,7 +2404,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
# TODO (erikj): We should work out what to do here instead.
next_token = to_token if to_token else from_token
- return rows, next_token
+ return rows, next_token, limited
@trace
@tag_args
@@ -2045,7 +2417,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
direction: Direction = Direction.BACKWARDS,
limit: int = 0,
event_filter: Optional[Filter] = None,
- ) -> Tuple[List[EventBase], RoomStreamToken]:
+ ) -> Tuple[List[EventBase], RoomStreamToken, bool]:
"""
Paginate events by `topological_ordering` (tie-break with `stream_ordering`) in
the room from the `from_key` in the given `direction` to the `to_key` or
@@ -2062,8 +2434,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
event_filter: If provided filters the events to those that match the filter.
Returns:
- The results as a list of events and a token that points to the end
- of the result set. If no events are returned then the end of the
+ The results as a list of events, a token that points to the end of
+ the result set, and a boolean to indicate if there were more events
+ but we hit the limit. If no events are returned then the end of the
stream has been reached (i.e. there are no events between `from_key`
and `to_key`).
@@ -2087,7 +2460,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
):
# Token selection matches what we do in `_paginate_room_events_txn` if there
# are no rows
- return [], to_key if to_key else from_key
+ return [], to_key if to_key else from_key, False
# Or vice-versa, if we're looking backwards and our `from_key` is already before
# our `to_key`.
elif (
@@ -2097,9 +2470,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
):
# Token selection matches what we do in `_paginate_room_events_txn` if there
# are no rows
- return [], to_key if to_key else from_key
+ return [], to_key if to_key else from_key, False
- rows, token = await self.db_pool.runInteraction(
+ rows, token, limited = await self.db_pool.runInteraction(
"paginate_room_events_by_topological_ordering",
self._paginate_room_events_by_topological_ordering_txn,
room_id,
@@ -2114,7 +2487,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
[r.event_id for r in rows], get_prev_content=True
)
- return events, token
+ return events, token, limited
@cached()
async def get_id_for_instance(self, instance_name: str) -> int:
|