author     Eric Eastwood <eric.eastwood@beta.gouv.fr>    2024-07-02 11:07:05 -0500
committer  GitHub <noreply@github.com>                   2024-07-02 11:07:05 -0500
commit     fa916558056013678e88d9dc2a2f64b161d9c77f (patch)
tree       2a726ca48f2a131047d31199a481c9b41f539f5c /synapse/storage/databases/main/stream.py
parent     Merge branch 'release-v1.110' into develop (diff)
download   synapse-fa916558056013678e88d9dc2a2f64b161d9c77f.tar.xz
Return some room data in Sliding Sync `/sync` (#17320)
 - Timeline events
 - Stripped `invite_state`

Based on [MSC3575](https://github.com/matrix-org/matrix-spec-proposals/pull/3575): Sliding Sync
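
Per MSC3575, each room entry in the Sliding Sync response can now carry the data listed above. Roughly, the per-room payload looks something like the sketch below (the field names follow the MSC; the exact structure and values are illustrative assumptions, not a verbatim response from this change):

    # Rough shape of per-room data in a Sliding Sync response (illustrative only).
    room_data = {
        "!joined:example.com": {
            # most recent timeline events for a joined room
            "timeline": [
                {
                    "type": "m.room.message",
                    "sender": "@alice:example.com",
                    "content": {"msgtype": "m.text", "body": "hello"},
                },
            ],
        },
        "!invited:example.com": {
            # stripped state events describing an invite
            "invite_state": [
                {
                    "type": "m.room.member",
                    "state_key": "@me:example.com",
                    "content": {"membership": "invite"},
                },
            ],
        },
    }
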
Diffstat (limited to 'synapse/storage/databases/main/stream.py')
-rw-r--r--  synapse/storage/databases/main/stream.py | 282
1 file changed, 278 insertions, 4 deletions
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index b7eb3116ae..d34376b8df 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -44,6 +44,7 @@ what sort order was used:
 import logging
 from typing import (
     TYPE_CHECKING,
+    AbstractSet,
     Any,
     Collection,
     Dict,
@@ -62,7 +63,7 @@ from typing_extensions import Literal
 
 from twisted.internet import defer
 
-from synapse.api.constants import Direction
+from synapse.api.constants import Direction, EventTypes, Membership
 from synapse.api.filtering import Filter
 from synapse.events import EventBase
 from synapse.logging.context import make_deferred_yieldable, run_in_background
@@ -111,6 +112,32 @@ class _EventsAround:
     end: RoomStreamToken
 
 
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class CurrentStateDeltaMembership:
+    """
+    Attributes:
+        room_id: The room ID of the membership event.
+        event_id: The "current" membership event ID in this room.
+        event_pos: The position of the "current" membership event in the event stream.
+        membership: The membership state of the user in the room.
+        sender: The person who sent the membership event.
+        prev_event_id: The previous membership event in this room that was replaced by
+            the "current" one. May be `None` if there was no previous membership event.
+        prev_event_pos: The position of the previous membership event in the event
+            stream. May be `None` if there was no previous membership event.
+        prev_membership: The membership state of the user in the previous membership
+            event. May be `None` if there was no previous membership event.
+        prev_sender: The person who sent the previous membership event. May be `None`
+            if there was no previous membership event.
+    """
+
+    room_id: str
+    # Event
+    event_id: Optional[str]
+    event_pos: PersistedEventPosition
+    membership: str
+    sender: Optional[str]
+    # Prev event
+    prev_event_id: Optional[str]
+    prev_event_pos: Optional[PersistedEventPosition]
+    prev_membership: Optional[str]
+    prev_sender: Optional[str]
+
+
 def generate_pagination_where_clause(
     direction: Direction,
     column_names: Tuple[str, str],
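
For orientation, the new attrs class can be constructed directly. A minimal sketch with made-up values (the room ID, event IDs, and stream positions below are purely illustrative):

    from synapse.storage.databases.main.stream import CurrentStateDeltaMembership
    from synapse.types import PersistedEventPosition

    # Hypothetical example: a user joining a room, replacing an earlier invite.
    change = CurrentStateDeltaMembership(
        room_id="!abc:example.com",
        # The "current" membership event and where it sits in the event stream
        event_id="$join_event",
        event_pos=PersistedEventPosition(instance_name="master", stream=105),
        membership="join",
        sender="@alice:example.com",
        # The membership event that the "current" one replaced
        prev_event_id="$invite_event",
        prev_event_pos=PersistedEventPosition(instance_name="master", stream=100),
        prev_membership="invite",
        prev_sender="@bob:example.com",
    )
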
@@ -390,6 +417,43 @@ def _filter_results(
     return True
 
 
+def _filter_results_by_stream(
+    lower_token: Optional[RoomStreamToken],
+    upper_token: Optional[RoomStreamToken],
+    instance_name: str,
+    stream_ordering: int,
+) -> bool:
+    """
+    This function only works with "live" tokens, i.e. tokens that have a
+    `stream_ordering` but no topological part. See `_filter_results(...)` if you want
+    to work with all tokens.
+
+    Returns True if the event persisted by the given instance at the given
+    stream_ordering falls between the two tokens (taking a None
+    token to mean unbounded).
+
+    Used to filter results from fetching events in the DB against the given
+    tokens. This is necessary to handle the case where the tokens include
+    position maps, which we handle by fetching more than necessary from the DB
+    and then filtering (rather than attempting to construct a complicated SQL
+    query).
+    """
+    if lower_token:
+        assert lower_token.topological is None
+
+        # If these are live tokens we compare the stream ordering against the
+        # writer's stream position.
+        if stream_ordering <= lower_token.get_stream_pos_for_instance(instance_name):
+            return False
+
+    if upper_token:
+        assert upper_token.topological is None
+
+        if upper_token.get_stream_pos_for_instance(instance_name) < stream_ordering:
+            return False
+
+    return True
+
+
 def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
     # NB: This may create SQL clauses that don't optimise well (and we don't
     # have indices on all possible clauses). E.g. it may create
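
The bounds above are exclusive on the lower token and inclusive on the upper token, mirroring the `> from_key AND <= to_key` range used elsewhere in this module. A simplified, standalone illustration of that bound logic, using plain integers as a stand-in for Synapse's multi-writer tokens:

    from typing import Optional

    def in_stream_range(
        lower: Optional[int], upper: Optional[int], stream_ordering: int
    ) -> bool:
        """Stand-in for `_filter_results_by_stream` with integers as tokens."""
        # Exclusive lower bound: drop rows at or before the lower token.
        if lower is not None and stream_ordering <= lower:
            return False
        # Inclusive upper bound: drop rows after the upper token.
        if upper is not None and stream_ordering > upper:
            return False
        return True

    assert in_stream_range(10, 20, 10) is False  # equal to lower bound: excluded
    assert in_stream_range(10, 20, 20) is True   # equal to upper bound: included
    assert in_stream_range(None, 20, 3) is True  # no lower bound: unbounded below
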
@@ -734,6 +798,191 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         return ret, key
 
+    async def get_current_state_delta_membership_changes_for_user(
+        self,
+        user_id: str,
+        from_key: RoomStreamToken,
+        to_key: RoomStreamToken,
+        excluded_room_ids: Optional[List[str]] = None,
+    ) -> List[CurrentStateDeltaMembership]:
+        """
+        Fetch membership events (and the previous event that was replaced by that one)
+        for a given user.
+
+        Note: This function only works with "live" tokens, i.e. tokens that have a
+        `stream_ordering` but no topological part.
+
+        We're looking for membership changes in the token range (> `from_key` and <=
+        `to_key`).
+
+        Please be mindful to only use this with `from_key` and `to_key` tokens that are
+        recent enough to be after the point when the first local user joined the room.
+        Otherwise, the results may be incomplete or too greedy. For example, if you use
+        a token range from before the first local user joined the room, you will see 0
+        events, since `current_state_delta_stream` only tracks what the server thinks
+        the current state of the room is as time goes on. It does not track how state
+        progressed from the beginning of the room. For example, when you remotely join
+        a room, the first rows will just be the state at the point you joined, and will
+        progress from there.
+
+        You can probably reasonably use this with `/sync` because the `to_key` passed in
+        will be the "current" now token and the range will cover when the user joined
+        the room.
+
+        Args:
+            user_id: The user ID to fetch membership events for.
+            from_key: The point in the stream to sync from (fetching events > this point).
+            to_key: The token to fetch rooms up to (fetching events <= this point).
+            excluded_room_ids: Optional list of room IDs to exclude from the results.
+
+        Returns:
+            All membership changes to the current state in the token range. Events are
+            sorted by `stream_ordering` ascending.
+        """
+        # Start by ruling out cases where a DB query is not necessary.
+        if from_key == to_key:
+            return []
+
+        if from_key:
+            has_changed = self._membership_stream_cache.has_entity_changed(
+                user_id, int(from_key.stream)
+            )
+            if not has_changed:
+                return []
+
+        def f(txn: LoggingTransaction) -> List[CurrentStateDeltaMembership]:
+            # To handle tokens with a non-empty instance_map we fetch more
+            # results than necessary and then filter down
+            min_from_id = from_key.stream
+            max_to_id = to_key.get_max_stream_pos()
+
+            args: List[Any] = [min_from_id, max_to_id, EventTypes.Member, user_id]
+
+            # TODO: It would be good to assert that the `from_token`/`to_token` is >=
+            # the first row in `current_state_delta_stream` for the rooms we're
+            # interested in. Otherwise, we will end up with empty results and not know
+            # it.
+
+            # We could `COALESCE(e.stream_ordering, s.stream_id)` to get more accurate
+            # stream positioning when available but given our usages, we can avoid the
+            # complexity. Between two (valid) stream tokens, we will still get all of
+            # the state changes. Since those events are persisted in a batch, valid
+            # tokens will either be before or after the batch of events.
+            #
+            # `stream_ordering` from the `events` table is more accurate when available
+            # since the `current_state_delta_stream` table only tracks that the current
+            # state is at this stream position (not what stream position the state event
+            # was added) and uses the *minimum* stream position for batches of events.
+            sql = """
+                SELECT
+                    s.room_id,
+                    e.event_id,
+                    s.instance_name,
+                    s.stream_id,
+                    m.membership,
+                    e.sender,
+                    s.prev_event_id,
+                    e_prev.instance_name AS prev_instance_name,
+                    e_prev.stream_ordering AS prev_stream_ordering,
+                    m_prev.membership AS prev_membership,
+                    e_prev.sender AS prev_sender
+                FROM current_state_delta_stream AS s
+                    LEFT JOIN events AS e ON e.event_id = s.event_id
+                    LEFT JOIN room_memberships AS m ON m.event_id = s.event_id
+                    LEFT JOIN events AS e_prev ON e_prev.event_id = s.prev_event_id
+                    LEFT JOIN room_memberships AS m_prev ON m_prev.event_id = s.prev_event_id
+                WHERE s.stream_id > ? AND s.stream_id <= ?
+                    AND s.type = ?
+                    AND s.state_key = ?
+                ORDER BY s.stream_id ASC
+            """
+
+            txn.execute(sql, args)
+
+            membership_changes: List[CurrentStateDeltaMembership] = []
+            for (
+                room_id,
+                event_id,
+                instance_name,
+                stream_ordering,
+                membership,
+                sender,
+                prev_event_id,
+                prev_instance_name,
+                prev_stream_ordering,
+                prev_membership,
+                prev_sender,
+            ) in txn:
+                assert room_id is not None
+                assert instance_name is not None
+                assert stream_ordering is not None
+
+                if _filter_results_by_stream(
+                    from_key,
+                    to_key,
+                    instance_name,
+                    stream_ordering,
+                ):
+                    # When the server leaves a room, it will insert new rows into the
+                    # `current_state_delta_stream` table with `event_id = null` for all
+                    # current state. This means we might already have a row for the
+                    # leave event and then another row for the same leave where
+                    # `event_id = null` but `prev_event_id` points back at the earlier
+                    # leave event. We don't want to report the leave again if we
+                    # already have a leave event.
+                    if event_id is None and prev_membership == Membership.LEAVE:
+                        continue
+
+                    membership_change = CurrentStateDeltaMembership(
+                        room_id=room_id,
+                        # Event
+                        event_id=event_id,
+                        event_pos=PersistedEventPosition(
+                            instance_name=instance_name,
+                            stream=stream_ordering,
+                        ),
+                        # When `s.event_id = null`, we won't be able to get the
+                        # respective `room_membership` row, but we can assume the user
+                        # has left the room, because this only happens when the server
+                        # leaves a room (meaning everyone locally left) or after a
+                        # state reset which removed the person from the room.
+                        membership=(
+                            membership if membership is not None else Membership.LEAVE
+                        ),
+                        sender=sender,
+                        # Prev event
+                        prev_event_id=prev_event_id,
+                        prev_event_pos=(
+                            PersistedEventPosition(
+                                instance_name=prev_instance_name,
+                                stream=prev_stream_ordering,
+                            )
+                            if (
+                                prev_instance_name is not None
+                                and prev_stream_ordering is not None
+                            )
+                            else None
+                        ),
+                        prev_membership=prev_membership,
+                        prev_sender=prev_sender,
+                    )
+
+                    membership_changes.append(membership_change)
+
+            return membership_changes
+
+        membership_changes = await self.db_pool.runInteraction(
+            "get_current_state_delta_membership_changes_for_user", f
+        )
+
+        room_ids_to_exclude: AbstractSet[str] = set()
+        if excluded_room_ids is not None:
+            room_ids_to_exclude = set(excluded_room_ids)
+
+        return [
+            membership_change
+            for membership_change in membership_changes
+            if membership_change.room_id not in room_ids_to_exclude
+        ]
+
     @cancellable
     async def get_membership_changes_for_user(
         self,
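
A caller-side sketch of how the new method might be driven from sync code. The `store`, `since_token`, and `now_token` names are assumptions for illustration; `room_key` is assumed to be the room-stream portion of a sync token:

    # Hypothetical caller-side sketch; not code from this change.
    async def membership_changes_for_sync(store, user_id, since_token, now_token):
        changes = await store.get_current_state_delta_membership_changes_for_user(
            user_id,
            from_key=since_token.room_key,  # exclusive lower bound (> from_key)
            to_key=now_token.room_key,      # inclusive upper bound (<= to_key)
            excluded_room_ids=["!ignored:example.com"],  # illustrative exclusion
        )
        # e.g. pick out rooms the user newly joined in this token range
        return [c.room_id for c in changes if c.membership == "join"]
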
@@ -769,10 +1018,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
             ignore_room_clause = ""
             if excluded_rooms is not None and len(excluded_rooms) > 0:
-                ignore_room_clause = "AND e.room_id NOT IN (%s)" % ",".join(
-                    "?" for _ in excluded_rooms
+                ignore_room_clause, ignore_room_args = make_in_list_sql_clause(
+                    txn.database_engine, "e.room_id", excluded_rooms, negative=True
                 )
-                args = args + excluded_rooms
+                ignore_room_clause = f"AND {ignore_room_clause}"
+                args += ignore_room_args
 
             sql = """
                 SELECT m.event_id, instance_name, topological_ordering, stream_ordering
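
The switch to `make_in_list_sql_clause(..., negative=True)` above delegates placeholder generation to the shared helper (which also has an engine-specific form on PostgreSQL). Conceptually it builds something along these lines; a simplified stand-in, not the helper's actual implementation:

    from typing import Iterable, List, Tuple

    def not_in_list_clause(column: str, values: Iterable[str]) -> Tuple[str, List[str]]:
        """Simplified stand-in: build a parameterised `NOT IN` clause plus its args."""
        # Callers guard against empty `values`, as the diff above does with
        # `len(excluded_rooms) > 0`, since `NOT IN ()` is not valid SQL.
        args = list(values)
        placeholders = ", ".join("?" for _ in args)
        return f"{column} NOT IN ({placeholders})", args

    clause, clause_args = not_in_list_clause("e.room_id", ["!a:x", "!b:x"])
    # clause == "e.room_id NOT IN (?, ?)", clause_args == ["!a:x", "!b:x"]
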
@@ -1554,6 +1804,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
     ) -> Tuple[List[EventBase], RoomStreamToken]:
         """Returns list of events before or after a given token.
 
+        When Direction.FORWARDS: from_key < x <= to_key
+        When Direction.BACKWARDS: from_key >= x > to_key
+
         Args:
             room_id
             from_key: The token used to stream from
@@ -1570,6 +1823,27 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             and `to_key`).
         """
 
+        # We can bail early if we're looking forwards and our `to_key` is already
+        # before our `from_key`.
+        if (
+            direction == Direction.FORWARDS
+            and to_key is not None
+            and to_key.is_before_or_eq(from_key)
+        ):
+            # Token selection matches what we do in `_paginate_room_events_txn` if there
+            # are no rows
+            return [], to_key if to_key else from_key
+        # Or vice-versa, if we're looking backwards and our `from_key` is already before
+        # our `to_key`.
+        elif (
+            direction == Direction.BACKWARDS
+            and to_key is not None
+            and from_key.is_before_or_eq(to_key)
+        ):
+            # Token selection matches what we do in `_paginate_room_events_txn` if there
+            # are no rows
+            return [], to_key if to_key else from_key
+
         rows, token = await self.db_pool.runInteraction(
             "paginate_room_events",
             self._paginate_room_events_txn,