diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index b7eb3116ae..d34376b8df 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -44,6 +44,7 @@ what sort order was used:
import logging
from typing import (
TYPE_CHECKING,
+ AbstractSet,
Any,
Collection,
Dict,
@@ -62,7 +63,7 @@ from typing_extensions import Literal
from twisted.internet import defer
-from synapse.api.constants import Direction
+from synapse.api.constants import Direction, EventTypes, Membership
from synapse.api.filtering import Filter
from synapse.events import EventBase
from synapse.logging.context import make_deferred_yieldable, run_in_background
@@ -111,6 +112,32 @@ class _EventsAround:
end: RoomStreamToken
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class CurrentStateDeltaMembership:
+ """
+ A single membership change for a user in a room, together with details of the
+ membership event that it replaced (if any).
+
+ Attributes:
+ room_id: The room ID of the membership event.
+ event_id: The "current" membership event ID in this room. May be `None` when
+ the underlying `current_state_delta_stream` row has no event (e.g. the
+ server left the room or a state reset removed the user from the room).
+ event_pos: The position of the "current" membership change in the stream. When
+ built by `get_current_state_delta_membership_changes_for_user`, this is the
+ `instance_name`/`stream_id` of the `current_state_delta_stream` row rather
+ than the event's own `stream_ordering`.
+ membership: The membership state of the user in the room.
+ sender: The person who sent the membership event. May be `None` if there is
+ no event to read the sender from.
+ prev_event_id: The previous membership event in this room that was replaced by
+ the "current" one. May be `None` if there was no previous membership event.
+ prev_event_pos: The stream position of the previous membership event. May be
+ `None` if there was no previous membership event (or its position could not
+ be looked up).
+ prev_membership: The membership state of the previous membership event. May be
+ `None` if there was no previous membership event.
+ prev_sender: The person who sent the previous membership event. May be `None`
+ if there was no previous membership event.
+ """
+
+ room_id: str
+ # Event
+ event_id: Optional[str]
+ event_pos: PersistedEventPosition
+ membership: str
+ sender: Optional[str]
+ # Prev event
+ prev_event_id: Optional[str]
+ prev_event_pos: Optional[PersistedEventPosition]
+ prev_membership: Optional[str]
+ prev_sender: Optional[str]
+
+
def generate_pagination_where_clause(
direction: Direction,
column_names: Tuple[str, str],
@@ -390,6 +417,43 @@ def _filter_results(
return True
+def _filter_results_by_stream(
+ lower_token: Optional[RoomStreamToken],
+ upper_token: Optional[RoomStreamToken],
+ instance_name: str,
+ stream_ordering: int,
+) -> bool:
+ """
+ This function only works with "live" tokens with `stream_ordering` only. See
+ `_filter_results(...)` if you want to work with all tokens.
+
+ Returns True if the event persisted by the given instance at the given
+ stream_ordering falls between the two tokens (taking a None
+ token to mean unbounded).
+
+ Used to filter results from fetching events in the DB against the given
+ tokens. This is necessary to handle the case where the tokens include
+ position maps, which we handle by fetching more than necessary from the DB
+ and then filtering (rather than attempting to construct a complicated SQL
+ query).
+
+ Args:
+ lower_token: Exclusive lower bound; a row at or below this token's position
+ for `instance_name` is rejected. `None` (or a falsy token) means no
+ lower bound.
+ upper_token: Inclusive upper bound; a row above this token's position for
+ `instance_name` is rejected. `None` (or a falsy token) means no upper
+ bound.
+ instance_name: The name of the instance (writer) that persisted the row.
+ stream_ordering: The stream position of the row being tested.
+
+ Returns:
+ True if `lower_token < stream_ordering <= upper_token` when compared against
+ each token's position for `instance_name`, otherwise False.
+
+ Raises:
+ AssertionError: If a supplied token has a `topological` component, i.e. it
+ is not a "live" token.
+ """
+ if lower_token:
+ assert lower_token.topological is None
+
+ # If these are live tokens we compare the stream ordering against the
+ # writer's stream position. The lower bound is exclusive.
+ if stream_ordering <= lower_token.get_stream_pos_for_instance(instance_name):
+ return False
+
+ if upper_token:
+ assert upper_token.topological is None
+
+ # The upper bound is inclusive: only rows strictly after it are rejected.
+ if upper_token.get_stream_pos_for_instance(instance_name) < stream_ordering:
+ return False
+
+ return True
+
+
def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
# NB: This may create SQL clauses that don't optimise well (and we don't
# have indices on all possible clauses). E.g. it may create
@@ -734,6 +798,191 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return ret, key
+ async def get_current_state_delta_membership_changes_for_user(
+ self,
+ user_id: str,
+ from_key: RoomStreamToken,
+ to_key: RoomStreamToken,
+ excluded_room_ids: Optional[List[str]] = None,
+ ) -> List[CurrentStateDeltaMembership]:
+ """
+ Fetch membership events (and the previous event that was replaced by that one)
+ for a given user.
+
+ Note: This function only works with "live" tokens with `stream_ordering` only.
+
+ We're looking for membership changes in the token range (> `from_key` and <=
+ `to_key`).
+
+ Please be mindful to only use this with `from_key` and `to_key` tokens that are
+ recent enough to be after when the first local user joined the room. Otherwise,
+ the results may be incomplete or too greedy. For example, if you use a token
+ range before the first local user joined the room, you will see 0 events since
+ `current_state_delta_stream` tracks what the server thinks is the current state
+ of the room as time goes on. It does not track how state progresses from the
+ beginning of the room. So for example, when you remotely join a room, the first
+ rows will just be the state when you joined and progress from there.
+
+ You can probably reasonably use this with `/sync` because the `to_key` passed in
+ will be the "current" now token and the range will cover when the user joined
+ the room.
+
+ Args:
+ user_id: The user ID to fetch membership events for.
+ from_key: The point in the stream to sync from (fetching events > this point).
+ to_key: The token to fetch rooms up to (fetching events <= this point).
+ excluded_room_ids: Optional list of room IDs to exclude from the results.
+ The exclusion is applied to the fetched rows after the database query.
+
+ Returns:
+ All membership changes to the current state in the token range. Events are
+ sorted by `stream_ordering` ascending.
+
+ Rows with no `event_id` are reported with `membership` set to
+ `Membership.LEAVE`, except that such a row is skipped entirely when the
+ previous membership was already `Membership.LEAVE`.
+ """
+ # Start by ruling out cases where a DB query is not necessary.
+ if from_key == to_key:
+ return []
+
+ # If the membership stream cache reports no change for this user since
+ # `from_key`, there is nothing to fetch.
+ if from_key:
+ has_changed = self._membership_stream_cache.has_entity_changed(
+ user_id, int(from_key.stream)
+ )
+ if not has_changed:
+ return []
+
+ def f(txn: LoggingTransaction) -> List[CurrentStateDeltaMembership]:
+ # To handle tokens with a non-empty instance_map we fetch more
+ # results than necessary and then filter down
+ min_from_id = from_key.stream
+ max_to_id = to_key.get_max_stream_pos()
+
+ # The order here must match the order of the `?` placeholders in the query
+ # below.
+ args: List[Any] = [min_from_id, max_to_id, EventTypes.Member, user_id]
+
+ # TODO: It would be good to assert that the `from_token`/`to_token` is >=
+ # the first row in `current_state_delta_stream` for the rooms we're
+ # interested in. Otherwise, we will end up with empty results and not know
+ # it.
+
+ # We could `COALESCE(e.stream_ordering, s.stream_id)` to get more accurate
+ # stream positioning when available but given our usages, we can avoid the
+ # complexity. Between two (valid) stream tokens, we will still get all of
+ # the state changes. Since those events are persisted in a batch, valid
+ # tokens will either be before or after the batch of events.
+ #
+ # `stream_ordering` from the `events` table is more accurate when available
+ # since the `current_state_delta_stream` table only tracks that the current
+ # state is at this stream position (not what stream position the state event
+ # was added) and uses the *minimum* stream position for batches of events.
+ sql = """
+ SELECT
+ s.room_id,
+ e.event_id,
+ s.instance_name,
+ s.stream_id,
+ m.membership,
+ e.sender,
+ s.prev_event_id,
+ e_prev.instance_name AS prev_instance_name,
+ e_prev.stream_ordering AS prev_stream_ordering,
+ m_prev.membership AS prev_membership,
+ e_prev.sender AS prev_sender
+ FROM current_state_delta_stream AS s
+ LEFT JOIN events AS e ON e.event_id = s.event_id
+ LEFT JOIN room_memberships AS m ON m.event_id = s.event_id
+ LEFT JOIN events AS e_prev ON e_prev.event_id = s.prev_event_id
+ LEFT JOIN room_memberships AS m_prev ON m_prev.event_id = s.prev_event_id
+ WHERE s.stream_id > ? AND s.stream_id <= ?
+ AND s.type = ?
+ AND s.state_key = ?
+ ORDER BY s.stream_id ASC
+ """
+
+ txn.execute(sql, args)
+
+ membership_changes: List[CurrentStateDeltaMembership] = []
+ for (
+ room_id,
+ event_id,
+ instance_name,
+ stream_ordering,
+ membership,
+ sender,
+ prev_event_id,
+ prev_instance_name,
+ prev_stream_ordering,
+ prev_membership,
+ prev_sender,
+ ) in txn:
+ # These three columns are selected from `current_state_delta_stream`
+ # itself (not from a LEFT JOINed table) and are expected to be set.
+ # Note that `stream_ordering` here is `s.stream_id`.
+ assert room_id is not None
+ assert instance_name is not None
+ assert stream_ordering is not None
+
+ if _filter_results_by_stream(
+ from_key,
+ to_key,
+ instance_name,
+ stream_ordering,
+ ):
+ # When the server leaves a room, it will insert new rows into the
+ # `current_state_delta_stream` table with `event_id = null` for all
+ # current state. This means we might already have a row for the
+ # leave event and then another for the same leave where the
+ # `event_id=null` but the `prev_event_id` is pointing back at the
+ # earlier leave event. We don't want to report the leave, if we
+ # already have a leave event.
+ if event_id is None and prev_membership == Membership.LEAVE:
+ continue
+
+ membership_change = CurrentStateDeltaMembership(
+ room_id=room_id,
+ # Event
+ event_id=event_id,
+ event_pos=PersistedEventPosition(
+ instance_name=instance_name,
+ stream=stream_ordering,
+ ),
+ # When `s.event_id = null`, we won't be able to get respective
+ # `room_membership` but can assume the user has left the room
+ # because this only happens when the server leaves a room
+ # (meaning everyone locally left) or a state reset which removed
+ # the person from the room.
+ membership=(
+ membership if membership is not None else Membership.LEAVE
+ ),
+ sender=sender,
+ # Prev event
+ prev_event_id=prev_event_id,
+ # Only build a position when both parts were found for the
+ # previous event; otherwise report `None`.
+ prev_event_pos=(
+ PersistedEventPosition(
+ instance_name=prev_instance_name,
+ stream=prev_stream_ordering,
+ )
+ if (
+ prev_instance_name is not None
+ and prev_stream_ordering is not None
+ )
+ else None
+ ),
+ prev_membership=prev_membership,
+ prev_sender=prev_sender,
+ )
+
+ membership_changes.append(membership_change)
+
+ return membership_changes
+
+ membership_changes = await self.db_pool.runInteraction(
+ "get_current_state_delta_membership_changes_for_user", f
+ )
+
+ # Room exclusion is applied here in Python (using a set for membership tests)
+ # rather than in the SQL query.
+ room_ids_to_exclude: AbstractSet[str] = set()
+ if excluded_room_ids is not None:
+ room_ids_to_exclude = set(excluded_room_ids)
+
+ return [
+ membership_change
+ for membership_change in membership_changes
+ if membership_change.room_id not in room_ids_to_exclude
+ ]
+
@cancellable
async def get_membership_changes_for_user(
self,
@@ -769,10 +1018,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
ignore_room_clause = ""
if excluded_rooms is not None and len(excluded_rooms) > 0:
- ignore_room_clause = "AND e.room_id NOT IN (%s)" % ",".join(
- "?" for _ in excluded_rooms
+ ignore_room_clause, ignore_room_args = make_in_list_sql_clause(
+ txn.database_engine, "e.room_id", excluded_rooms, negative=True
)
- args = args + excluded_rooms
+ ignore_room_clause = f"AND {ignore_room_clause}"
+ args += ignore_room_args
sql = """
SELECT m.event_id, instance_name, topological_ordering, stream_ordering
@@ -1554,6 +1804,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
) -> Tuple[List[EventBase], RoomStreamToken]:
"""Returns list of events before or after a given token.
+ When Direction.FORWARDS: from_key < x <= to_key
+ When Direction.BACKWARDS: from_key >= x > to_key
+
Args:
room_id
from_key: The token used to stream from
@@ -1570,6 +1823,27 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
and `to_key`).
"""
+ # We can bail early if we're looking forwards, and our `to_key` is already
+ # at or before our `from_key` (the range `from_key < x <= to_key` is empty).
+ if (
+ direction == Direction.FORWARDS
+ and to_key is not None
+ and to_key.is_before_or_eq(from_key)
+ ):
+ # Token selection matches what we do in `_paginate_room_events_txn` if there
+ # are no rows
+ return [], to_key if to_key else from_key
+ # Or vice-versa, if we're looking backwards and our `from_key` is already at or
+ # before our `to_key` (the range `from_key >= x > to_key` is empty).
+ elif (
+ direction == Direction.BACKWARDS
+ and to_key is not None
+ and from_key.is_before_or_eq(to_key)
+ ):
+ # Token selection matches what we do in `_paginate_room_events_txn` if there
+ # are no rows
+ return [], to_key if to_key else from_key
+
rows, token = await self.db_pool.runInteraction(
"paginate_room_events",
self._paginate_room_events_txn,
|