summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/handlers/sliding_sync.py350
-rw-r--r--synapse/rest/client/sync.py6
-rw-r--r--synapse/server.py1
-rw-r--r--synapse/storage/databases/main/state_deltas.py37
-rw-r--r--synapse/types/handlers/__init__.py3
-rw-r--r--synapse/types/rest/client/__init__.py5
6 files changed, 362 insertions, 40 deletions
diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync.py
index 3231574402..2b74f1c9c9 100644
--- a/synapse/handlers/sliding_sync.py
+++ b/synapse/handlers/sliding_sync.py
@@ -18,6 +18,7 @@
 #
 #
 import logging
+from enum import Enum
 from itertools import chain
 from typing import (
     TYPE_CHECKING,
@@ -34,6 +35,7 @@ from typing import (
 
 import attr
 from immutabledict import immutabledict
+from typing_extensions import assert_never
 
 from synapse.api.constants import AccountDataTypes, Direction, EventTypes, Membership
 from synapse.events import EventBase
@@ -52,6 +54,7 @@ from synapse.types import (
     RoomStreamToken,
     SlidingSyncStreamToken,
     StateMap,
+    StrCollection,
     StreamKeyType,
     StreamToken,
     UserID,
@@ -361,6 +364,8 @@ class SlidingSyncHandler:
         self.push_rules_handler = hs.get_push_rules_handler()
         self.rooms_to_exclude_globally = hs.config.server.rooms_to_exclude_from_sync
 
+        self.connection_store = SlidingSyncConnectionStore()
+
     async def wait_for_sync_for_user(
         self,
         requester: Requester,
@@ -464,6 +469,11 @@ class SlidingSyncHandler:
             # See https://github.com/matrix-org/matrix-doc/issues/1144
             raise NotImplementedError()
 
+        await self.connection_store.mark_token_seen(
+            sync_config=sync_config,
+            from_token=from_token,
+        )
+
         # Get all of the room IDs that the user should be able to see in the sync
         # response
         has_lists = sync_config.lists is not None and len(sync_config.lists) > 0
@@ -613,7 +623,7 @@ class SlidingSyncHandler:
         @tag_args
         async def handle_room(room_id: str) -> None:
             room_sync_result = await self.get_room_sync_data(
-                user=sync_config.user,
+                sync_config=sync_config,
                 room_id=room_id,
                 room_sync_config=relevant_room_map[room_id],
                 room_membership_for_user_at_to_token=room_membership_for_user_map[
@@ -635,11 +645,22 @@ class SlidingSyncHandler:
             to_token=to_token,
         )
 
-        # TODO: Update this when we implement per-connection state
-        connection_token = 0
+        if has_lists or has_room_subscriptions:
+            connection_position = await self.connection_store.record_rooms(
+                sync_config=sync_config,
+                from_token=from_token,
+                sent_room_ids=relevant_room_map.keys(),
+                # TODO: We need to calculate which rooms have had updates since the `from_token` but were not included in the `sent_room_ids`
+                unsent_room_ids=[],
+            )
+        elif from_token:
+            connection_position = from_token.connection_position
+        else:
+            # Initial sync without a `from_token` starts at `0`
+            connection_position = 0
 
         return SlidingSyncResult(
-            next_pos=SlidingSyncStreamToken(to_token, connection_token),
+            next_pos=SlidingSyncStreamToken(to_token, connection_position),
             lists=lists,
             rooms=rooms,
             extensions=extensions,
@@ -1370,7 +1391,7 @@ class SlidingSyncHandler:
 
     async def get_room_sync_data(
         self,
-        user: UserID,
+        sync_config: SlidingSyncConfig,
         room_id: str,
         room_sync_config: RoomSyncConfig,
         room_membership_for_user_at_to_token: _RoomMembershipForUser,
@@ -1392,6 +1413,37 @@ class SlidingSyncHandler:
             from_token: The point in the stream to sync from.
             to_token: The point in the stream to sync up to.
         """
+        user = sync_config.user
+
+        # Determine whether we should limit the timeline to the token range.
+        #
+        # We should return historical messages (before token range) in the
+        # following cases because we want clients to be able to show a basic
+        # screen of information:
+        #  - Initial sync (because no `from_token` to limit us anyway)
+        #  - When users `newly_joined`
+        #  - For an incremental sync where we haven't sent it down this
+        #    connection before
+        from_bound = None
+        initial = True
+        if from_token and not room_membership_for_user_at_to_token.newly_joined:
+            room_status = await self.connection_store.have_sent_room(
+                sync_config=sync_config,
+                connection_token=from_token.connection_position,
+                room_id=room_id,
+            )
+            if room_status.status == HaveSentRoomFlag.LIVE:
+                from_bound = from_token.stream_token.room_key
+                initial = False
+            elif room_status.status == HaveSentRoomFlag.PREVIOUSLY:
+                assert room_status.last_token is not None
+                from_bound = room_status.last_token
+                initial = False
+            elif room_status.status == HaveSentRoomFlag.NEVER:
+                from_bound = None
+                initial = True
+            else:
+                assert_never(room_status.status)
 
         # Assemble the list of timeline events
         #
@@ -1418,36 +1470,23 @@ class SlidingSyncHandler:
             prev_batch_token = to_token
 
             # We're going to paginate backwards from the `to_token`
-            from_bound = to_token.room_key
+            to_bound = to_token.room_key
             # People shouldn't see past their leave/ban event
             if room_membership_for_user_at_to_token.membership in (
                 Membership.LEAVE,
                 Membership.BAN,
             ):
-                from_bound = (
+                to_bound = (
                     room_membership_for_user_at_to_token.event_pos.to_room_stream_token()
                 )
 
-            # Determine whether we should limit the timeline to the token range.
-            #
-            # We should return historical messages (before token range) in the
-            # following cases because we want clients to be able to show a basic
-            # screen of information:
-            #  - Initial sync (because no `from_token` to limit us anyway)
-            #  - When users `newly_joined`
-            #  - TODO: For an incremental sync where we haven't sent it down this
-            #    connection before
-            to_bound = (
-                from_token.stream_token.room_key
-                if from_token is not None
-                and not room_membership_for_user_at_to_token.newly_joined
-                else None
-            )
-
             timeline_events, new_room_key = await self.store.paginate_room_events(
                 room_id=room_id,
-                from_key=from_bound,
-                to_key=to_bound,
+                # The bounds are reversed so we can paginate backwards
+                # (from newer to older events) starting at to_bound.
+                # This ensures we fill the `limit` with the newest events first,
+                from_key=to_bound,
+                to_key=from_bound,
                 direction=Direction.BACKWARDS,
                 # We add one so we can determine if there are enough events to saturate
                 # the limit or not (see `limited`)
@@ -1564,12 +1603,6 @@ class SlidingSyncHandler:
         # indicate to the client that a state reset happened. Perhaps we should indicate
         # this by setting `initial: True` and empty `required_state`.
 
-        # TODO: Since we can't determine whether we've already sent a room down this
-        # Sliding Sync connection before (we plan to add this optimization in the
-        # future), we're always returning the requested room state instead of
-        # updates.
-        initial = True
-
         # Check whether the room has a name set
         name_state_ids = await self.get_current_state_ids_at(
             room_id=room_id,
@@ -1715,9 +1748,22 @@ class SlidingSyncHandler:
                 to_token=to_token,
             )
         else:
-            # TODO: Once we can figure out if we've sent a room down this connection before,
-            # we can return updates instead of the full required state.
-            raise NotImplementedError()
+            assert from_bound is not None
+
+            # TODO: Limit the number of state events we're about to send down
+            # the room, if its too many we should change this to an
+            # `initial=True`?
+            deltas = await self.store.get_current_state_deltas_for_room(
+                room_id=room_id,
+                from_token=from_bound,
+                to_token=to_token.room_key,
+            )
+            # TODO: Filter room state before fetching events
+            # TODO: Handle state resets where event_id is None
+            events = await self.store.get_events(
+                [d.event_id for d in deltas if d.event_id]
+            )
+            room_state = {(s.type, s.state_key): s for s in events.values()}
 
         required_room_state: StateMap[EventBase] = {}
         if required_state_filter != StateFilter.none():
@@ -1863,7 +1909,7 @@ class SlidingSyncHandler:
             to_token: The point in the stream to sync up to.
         """
         user_id = sync_config.user.to_string()
-        device_id = sync_config.device_id
+        device_id = sync_config.requester.device_id
 
         # Skip if the extension is not enabled
         if not to_device_request.enabled:
@@ -1939,7 +1985,7 @@ class SlidingSyncHandler:
             from_token: The point in the stream to sync from.
         """
         user_id = sync_config.user.to_string()
-        device_id = sync_config.device_id
+        device_id = sync_config.requester.device_id
 
         # Skip if the extension is not enabled
         if not e2ee_request.enabled:
@@ -2094,3 +2140,235 @@ class SlidingSyncHandler:
             global_account_data_map=global_account_data_map,
             account_data_by_room_map=account_data_by_room_map,
         )
+
+
+class HaveSentRoomFlag(Enum):
+    """Flag for whether we have sent the room down a sliding sync connection.
+
+    The valid state changes here are:
+        NEVER -> LIVE
+        LIVE -> PREVIOUSLY
+        PREVIOUSLY -> LIVE
+    """
+
+    # The room has never been sent down (or we have forgotten we have sent it
+    # down).
+    NEVER = 1
+
+    # We have previously sent the room down, but there are updates that we
+    # haven't sent down.
+    PREVIOUSLY = 2
+
+    # We have sent the room down and the client has received all updates.
+    LIVE = 3
+
+
+@attr.s(auto_attribs=True, slots=True, frozen=True)
+class HaveSentRoom:
+    """Whether we have sent the room down a sliding sync connection.
+
+    Attributes:
+        status: Flag of if we have or haven't sent down the room
+        last_token: If the flag is `PREVIOUSLY` then this is non-null and
+            contains the last stream token of the last updates we sent down
+            the room, i.e. we still need to send everything since then to the
+            client.
+    """
+
+    status: HaveSentRoomFlag
+    last_token: Optional[RoomStreamToken]
+
+    @staticmethod
+    def previously(last_token: RoomStreamToken) -> "HaveSentRoom":
+        """Constructor for `PREVIOUSLY` flag."""
+        return HaveSentRoom(HaveSentRoomFlag.PREVIOUSLY, last_token)
+
+
+HAVE_SENT_ROOM_NEVER = HaveSentRoom(HaveSentRoomFlag.NEVER, None)
+HAVE_SENT_ROOM_LIVE = HaveSentRoom(HaveSentRoomFlag.LIVE, None)
+
+
+@attr.s(auto_attribs=True)
+class SlidingSyncConnectionStore:
+    """In-memory store of per-connection state, including what rooms we have
+    previously sent down a sliding sync connection.
+
+    Note: This is NOT safe to run in a worker setup because connection positions will
+    point to different sets of rooms on different workers. e.g. for the same connection,
+    a connection position of 5 might have totally different states on worker A and
+    worker B.
+
+     One complication that we need to deal with here is needing to handle requests being
+    resent, i.e. if we sent down a room in a response that the client received, we must
+    consider the room *not* sent when we get the request again.
+
+    This is handled by using an integer "token", which is returned to the client
+    as part of the sync token. For each connection we store a mapping from
+    tokens to the room states, and create a new entry when we send down new
+    rooms.
+
+    Note that for any given sliding sync connection we will only store a maximum
+    of two different tokens: the previous token from the request and a new token
+    sent in the response. When we receive a request with a given token, we then
+    clear out all other entries with a different token.
+
+    Attributes:
+        _connections: Mapping from `(user_id, conn_id)` to mapping of `token`
+            to mapping of room ID to `HaveSentRoom`.
+    """
+
+    # `(user_id, conn_id)` -> `token` -> `room_id` -> `HaveSentRoom`
+    _connections: Dict[Tuple[str, str], Dict[int, Dict[str, HaveSentRoom]]] = (
+        attr.Factory(dict)
+    )
+
+    async def have_sent_room(
+        self, sync_config: SlidingSyncConfig, connection_token: int, room_id: str
+    ) -> HaveSentRoom:
+        """For the given user_id/conn_id/token, return whether we have
+        previously sent the room down
+        """
+
+        conn_key = self._get_connection_key(sync_config)
+        sync_statuses = self._connections.setdefault(conn_key, {})
+        room_status = sync_statuses.get(connection_token, {}).get(
+            room_id, HAVE_SENT_ROOM_NEVER
+        )
+
+        return room_status
+
+    async def record_rooms(
+        self,
+        sync_config: SlidingSyncConfig,
+        from_token: Optional[SlidingSyncStreamToken],
+        *,
+        sent_room_ids: StrCollection,
+        unsent_room_ids: StrCollection,
+    ) -> int:
+        """Record which rooms we have/haven't sent down in a new response
+
+        Attributes:
+            sync_config
+            from_token: The since token from the request, if any
+            sent_room_ids: The set of room IDs that we have sent down as
+                part of this request (only needs to be ones we didn't
+                previously sent down).
+            unsent_room_ids: The set of room IDs that have had updates
+                since the `from_token`, but which were not included in
+                this request
+        """
+        prev_connection_token = 0
+        if from_token is not None:
+            prev_connection_token = from_token.connection_position
+
+        # If there are no changes then this is a noop.
+        if not sent_room_ids and not unsent_room_ids:
+            return prev_connection_token
+
+        conn_key = self._get_connection_key(sync_config)
+        sync_statuses = self._connections.setdefault(conn_key, {})
+
+        # Generate a new token, removing any existing entries in that token
+        # (which can happen if requests get resent).
+        new_store_token = prev_connection_token + 1
+        sync_statuses.pop(new_store_token, None)
+
+        # Copy over and update the room mappings.
+        new_room_statuses = dict(sync_statuses.get(prev_connection_token, {}))
+
+        # Whether we have updated the `new_room_statuses`, if we don't by the
+        # end we can treat this as a noop.
+        have_updated = False
+        for room_id in sent_room_ids:
+            new_room_statuses[room_id] = HAVE_SENT_ROOM_LIVE
+            have_updated = True
+
+        # Whether we add/update the entries for unsent rooms depends on the
+        # existing entry:
+        #   - LIVE: We have previously sent down everything up to
+        #     `last_room_token, so we update the entry to be `PREVIOUSLY` with
+        #     `last_room_token`.
+        #   - PREVIOUSLY: We have previously sent down everything up to *a*
+        #     given token, so we don't need to update the entry.
+        #   - NEVER: We have never previously sent down the room, and we haven't
+        #     sent anything down this time either so we leave it as NEVER.
+
+        # Work out the new state for unsent rooms that were `LIVE`.
+        if from_token:
+            new_unsent_state = HaveSentRoom.previously(from_token.stream_token.room_key)
+        else:
+            new_unsent_state = HAVE_SENT_ROOM_NEVER
+
+        for room_id in unsent_room_ids:
+            prev_state = new_room_statuses.get(room_id)
+            if prev_state is not None and prev_state.status == HaveSentRoomFlag.LIVE:
+                new_room_statuses[room_id] = new_unsent_state
+                have_updated = True
+
+        if not have_updated:
+            return prev_connection_token
+
+        sync_statuses[new_store_token] = new_room_statuses
+
+        return new_store_token
+
+    async def mark_token_seen(
+        self,
+        sync_config: SlidingSyncConfig,
+        from_token: Optional[SlidingSyncStreamToken],
+    ) -> None:
+        """We have received a request with the given token, so we can clear out
+        any other tokens associated with the connection.
+
+        If there is no from token then we have started afresh, and so we delete
+        all tokens associated with the device.
+        """
+        # Clear out any tokens for the connection that doesn't match the one
+        # from the request.
+
+        conn_key = self._get_connection_key(sync_config)
+        sync_statuses = self._connections.pop(conn_key, {})
+        if from_token is None:
+            return
+
+        sync_statuses = {
+            connection_token: room_statuses
+            for connection_token, room_statuses in sync_statuses.items()
+            if connection_token == from_token.connection_position
+        }
+        if sync_statuses:
+            self._connections[conn_key] = sync_statuses
+
+    @staticmethod
+    def _get_connection_key(sync_config: SlidingSyncConfig) -> Tuple[str, str]:
+        """Return a unique identifier for this connection.
+
+        The first part is simply the user ID.
+
+        The second part is generally a combination of device ID and conn_id.
+        However, both these two are optional (e.g. puppet access tokens don't
+        have device IDs), so this handles those edge cases.
+
+        We use this over the raw `conn_id` to avoid clashes between different
+        clients that use the same `conn_id`. Imagine a user uses a web client
+        that uses `conn_id: main_sync_loop` and an Android client that also has
+        a `conn_id: main_sync_loop`.
+        """
+
+        user_id = sync_config.user.to_string()
+
+        # Only one sliding sync connection is allowed per given conn_id (empty
+        # or not).
+        conn_id = sync_config.conn_id or ""
+
+        if sync_config.requester.device_id:
+            return (user_id, f"D/{sync_config.requester.device_id}/{conn_id}")
+
+        if sync_config.requester.access_token_id:
+            # If we don't have a device, then the access token ID should be a
+            # stable ID.
+            return (user_id, f"A/{sync_config.requester.access_token_id}/{conn_id}")
+
+        # If we have neither then its likely an AS or some weird token. Either
+        # way we can just fail here.
+        raise Exception("Cannot use sliding sync with access token type")
diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index 7cf1f56435..bf3ac8d483 100644
--- a/synapse/rest/client/sync.py
+++ b/synapse/rest/client/sync.py
@@ -881,7 +881,6 @@ class SlidingSyncRestServlet(RestServlet):
         )
 
         user = requester.user
-        device_id = requester.device_id
 
         timeout = parse_integer(request, "timeout", default=0)
         # Position in the stream
@@ -902,11 +901,12 @@ class SlidingSyncRestServlet(RestServlet):
 
         sync_config = SlidingSyncConfig(
             user=user,
-            device_id=device_id,
+            requester=requester,
             # FIXME: Currently, we're just manually copying the fields from the
-            # `SlidingSyncBody` into the config. How can we gurantee into the future
+            # `SlidingSyncBody` into the config. How can we guarantee into the future
             # that we don't forget any? I would like something more structured like
             # `copy_attributes(from=body, to=config)`
+            conn_id=body.conn_id,
             lists=body.lists,
             room_subscriptions=body.room_subscriptions,
             extensions=body.extensions,
diff --git a/synapse/server.py b/synapse/server.py
index 4a3f9ff934..46b9d83a04 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -559,6 +559,7 @@ class HomeServer(metaclass=abc.ABCMeta):
     def get_sync_handler(self) -> SyncHandler:
         return SyncHandler(self)
 
+    @cache_in_self
     def get_sliding_sync_handler(self) -> SlidingSyncHandler:
         return SlidingSyncHandler(self)
 
diff --git a/synapse/storage/databases/main/state_deltas.py b/synapse/storage/databases/main/state_deltas.py
index 036972ac25..da3ebe66b8 100644
--- a/synapse/storage/databases/main/state_deltas.py
+++ b/synapse/storage/databases/main/state_deltas.py
@@ -26,6 +26,8 @@ import attr
 
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import LoggingTransaction
+from synapse.storage.databases.main.stream import _filter_results_by_stream
+from synapse.types import RoomStreamToken
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 logger = logging.getLogger(__name__)
@@ -156,3 +158,38 @@ class StateDeltasStore(SQLBaseStore):
             "get_max_stream_id_in_current_state_deltas",
             self._get_max_stream_id_in_current_state_deltas_txn,
         )
+
+    async def get_current_state_deltas_for_room(
+        self, room_id: str, from_token: RoomStreamToken, to_token: RoomStreamToken
+    ) -> List[StateDelta]:
+        """Get the state deltas between two tokens."""
+
+        def get_current_state_deltas_for_room_txn(
+            txn: LoggingTransaction,
+        ) -> List[StateDelta]:
+            sql = """
+                SELECT instance_name, stream_id, type, state_key, event_id, prev_event_id
+                FROM current_state_delta_stream
+                WHERE room_id = ? AND ? < stream_id AND stream_id <= ?
+                ORDER BY stream_id ASC
+            """
+            txn.execute(
+                sql, (room_id, from_token.stream, to_token.get_max_stream_pos())
+            )
+
+            return [
+                StateDelta(
+                    stream_id=row[1],
+                    room_id=room_id,
+                    event_type=row[2],
+                    state_key=row[3],
+                    event_id=row[4],
+                    prev_event_id=row[5],
+                )
+                for row in txn
+                if _filter_results_by_stream(from_token, to_token, row[0], row[1])
+            ]
+
+        return await self.db_pool.runInteraction(
+            "get_current_state_deltas_for_room", get_current_state_deltas_for_room_txn
+        )
diff --git a/synapse/types/handlers/__init__.py b/synapse/types/handlers/__init__.py
index 479222a18d..f3141b05a0 100644
--- a/synapse/types/handlers/__init__.py
+++ b/synapse/types/handlers/__init__.py
@@ -35,6 +35,7 @@ from synapse.types import (
     DeviceListUpdates,
     JsonDict,
     JsonMapping,
+    Requester,
     SlidingSyncStreamToken,
     StreamToken,
     UserID,
@@ -109,7 +110,7 @@ class SlidingSyncConfig(SlidingSyncBody):
     """
 
     user: UserID
-    device_id: Optional[str]
+    requester: Requester
 
     # Pydantic config
     class Config:
diff --git a/synapse/types/rest/client/__init__.py b/synapse/types/rest/client/__init__.py
index 34e07ddac5..dfe3b1e0f7 100644
--- a/synapse/types/rest/client/__init__.py
+++ b/synapse/types/rest/client/__init__.py
@@ -120,6 +120,9 @@ class SlidingSyncBody(RequestBodyModel):
     Sliding Sync API request body.
 
     Attributes:
+        conn_id: An optional string to identify this connection to the server.
+            Only one sliding sync connection is allowed per given conn_id (empty
+            or not).
         lists: Sliding window API. A map of list key to list information
             (:class:`SlidingSyncList`). Max lists: 100. The list keys should be
             arbitrary strings which the client is using to refer to the list. Keep this
@@ -343,6 +346,8 @@ class SlidingSyncBody(RequestBodyModel):
         e2ee: Optional[E2eeExtension] = None
         account_data: Optional[AccountDataExtension] = None
 
+    conn_id: Optional[str]
+
     # mypy workaround via https://github.com/pydantic/pydantic/issues/156#issuecomment-1130883884
     if TYPE_CHECKING:
         lists: Optional[Dict[str, SlidingSyncList]] = None