summary refs log tree commit diff
path: root/synapse/handlers/sliding_sync.py
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2024-07-23 18:33:19 +0100
committerErik Johnston <erik@matrix.org>2024-07-23 18:33:19 +0100
commite2f3d48882ba3a874c79097af4d3b865e8b22029 (patch)
treee5f2759043a63d4abf14d9fb9f61aa8448cf56bd /synapse/handlers/sliding_sync.py
parentAdd test for cache being cleared (diff)
parentChange token names again (diff)
downloadsynapse-e2f3d48882ba3a874c79097af4d3b865e8b22029.tar.xz
Merge branch 'erikj/ss_tokens' into erikj/ss_room_store
Diffstat (limited to 'synapse/handlers/sliding_sync.py')
-rw-r--r--synapse/handlers/sliding_sync.py178
1 files changed, 130 insertions, 48 deletions
diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync.py
index 543a1f8836..69be113117 100644
--- a/synapse/handlers/sliding_sync.py
+++ b/synapse/handlers/sliding_sync.py
@@ -20,7 +20,18 @@
 import logging
 from enum import Enum
 from itertools import chain
-from typing import TYPE_CHECKING, Any, Dict, Final, List, Mapping, Optional, Set, Tuple
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    Final,
+    List,
+    Mapping,
+    Optional,
+    Sequence,
+    Set,
+    Tuple,
+)
 
 import attr
 from immutabledict import immutabledict
@@ -35,6 +46,7 @@ from synapse.storage.databases.main.roommember import extract_heroes_from_room_s
 from synapse.storage.databases.main.stream import CurrentStateDeltaMembership
 from synapse.storage.roommember import MemberSummary
 from synapse.types import (
+    DeviceListUpdates,
     JsonDict,
     PersistedEventPosition,
     Requester,
@@ -333,6 +345,9 @@ class StateValues:
     # `sender` in the timeline). We only give special meaning to this value when it's a
     # `state_key`.
     LAZY: Final = "$LAZY"
+    # Subsitute with the requester's user ID. Typically used by clients to get
+    # the user's membership.
+    ME: Final = "$ME"
 
 
 class SlidingSyncHandler:
@@ -344,6 +359,7 @@ class SlidingSyncHandler:
         self.notifier = hs.get_notifier()
         self.event_sources = hs.get_event_sources()
         self.relations_handler = hs.get_relations_handler()
+        self.device_handler = hs.get_device_handler()
         self.rooms_to_exclude_globally = hs.config.server.rooms_to_exclude_from_sync
 
         self.connection_store = SlidingSyncConnectionStore()
@@ -374,10 +390,6 @@ class SlidingSyncHandler:
         # auth_blocking will occur)
         await self.auth_blocking.check_auth_blocking(requester=requester)
 
-        # TODO: If the To-Device extension is enabled and we have a `from_token`, delete
-        # any to-device messages before that token (since we now know that the device
-        # has received them). (see sync v2 for how to do this)
-
         # If we're working with a user-provided token, we need to make sure to wait for
         # this worker to catch up with the token so we don't skip past any incoming
         # events or future events if the user is nefariously, manually modifying the
@@ -387,7 +399,7 @@ class SlidingSyncHandler:
             # this returns false, it means we timed out waiting, and we should
             # just return an empty response.
             before_wait_ts = self.clock.time_msec()
-            if not await self.notifier.wait_for_stream_token(from_token.stream):
+            if not await self.notifier.wait_for_stream_token(from_token.stream_token):
                 logger.warning(
                     "Timed out waiting for worker to catch up. Returning empty response"
                 )
@@ -425,7 +437,7 @@ class SlidingSyncHandler:
                 sync_config.user.to_string(),
                 timeout_ms,
                 current_sync_callback,
-                from_token=from_token.stream,
+                from_token=from_token.stream_token,
             )
 
         return result
@@ -472,7 +484,7 @@ class SlidingSyncHandler:
                 await self.get_room_membership_for_user_at_to_token(
                     user=sync_config.user,
                     to_token=to_token,
-                    from_token=from_token.stream if from_token else None,
+                    from_token=from_token.stream_token if from_token else None,
                 )
             )
 
@@ -515,7 +527,6 @@ class SlidingSyncHandler:
                 # Also see `StateFilter.must_await_full_state(...)` for comparison
                 lazy_loading = (
                     membership_state_keys is not None
-                    and len(membership_state_keys) == 1
                     and StateValues.LAZY in membership_state_keys
                 )
 
@@ -626,7 +637,9 @@ class SlidingSyncHandler:
             await concurrently_execute(handle_room, relevant_room_map, 10)
 
         extensions = await self.get_extensions_response(
-            sync_config=sync_config, to_token=to_token
+            sync_config=sync_config,
+            from_token=from_token,
+            to_token=to_token,
         )
 
         if has_lists or has_room_subscriptions:
@@ -638,7 +651,7 @@ class SlidingSyncHandler:
                 unsent_room_ids=[],
             )
         elif from_token:
-            connection_token = from_token.connection
+            connection_token = from_token.connection_position
         else:
             # Initial sync without a `from_token` starts at `0`
             connection_token = 0
@@ -1242,34 +1255,33 @@ class SlidingSyncHandler:
         # Assemble a map of room ID to the `stream_ordering` of the last activity that the
         # user should see in the room (<= `to_token`)
         last_activity_in_room_map: Dict[str, int] = {}
-        for room_id, room_for_user in sync_room_map.items():
-            # If they are fully-joined to the room, let's find the latest activity
-            # at/before the `to_token`.
-            if room_for_user.membership == Membership.JOIN:
-                last_event_result = (
-                    await self.store.get_last_event_pos_in_room_before_stream_ordering(
-                        room_id, to_token.room_key
-                    )
-                )
-
-                # If the room has no events at/before the `to_token`, this is probably a
-                # mistake in the code that generates the `sync_room_map` since that should
-                # only give us rooms that the user had membership in during the token range.
-                assert last_event_result is not None
-
-                _, event_pos = last_event_result
 
-                last_activity_in_room_map[room_id] = event_pos.stream
-            else:
-                # Otherwise, if the user has left/been invited/knocked/been banned from
-                # a room, they shouldn't see anything past that point.
+        for room_id, room_for_user in sync_room_map.items():
+            if room_for_user.membership != Membership.JOIN:
+                # If the user has left/been invited/knocked/been banned from a
+                # room, they shouldn't see anything past that point.
                 #
-                # FIXME: It's possible that people should see beyond this point in
-                # invited/knocked cases if for example the room has
+                # FIXME: It's possible that people should see beyond this point
+                # in invited/knocked cases if for example the room has
                 # `invite`/`world_readable` history visibility, see
                 # https://github.com/matrix-org/matrix-spec-proposals/pull/3575#discussion_r1653045932
                 last_activity_in_room_map[room_id] = room_for_user.event_pos.stream
 
+        # For fully-joined rooms, we find the latest activity at/before the
+        # `to_token`.
+        joined_room_positions = (
+            await self.store.bulk_get_last_event_pos_in_room_before_stream_ordering(
+                [
+                    room_id
+                    for room_id, room_for_user in sync_room_map.items()
+                    if room_for_user.membership == Membership.JOIN
+                ],
+                to_token.room_key,
+            )
+        )
+
+        last_activity_in_room_map.update(joined_room_positions)
+
         return sorted(
             sync_room_map.values(),
             # Sort by the last activity (stream_ordering) in the room
@@ -1414,11 +1426,11 @@ class SlidingSyncHandler:
         if from_token and not room_membership_for_user_at_to_token.newly_joined:
             room_status = await self.connection_store.have_sent_room(
                 sync_config=sync_config,
-                connection_token=from_token.connection,
+                connection_token=from_token.connection_position,
                 room_id=room_id,
             )
             if room_status.status == HaveSentRoomFlag.LIVE:
-                from_bound = from_token.stream.room_key
+                from_bound = from_token.stream_token.room_key
                 initial = False
             elif room_status.status == HaveSentRoomFlag.PREVIOUSLY:
                 assert room_status.last_token is not None
@@ -1529,7 +1541,9 @@ class SlidingSyncHandler:
                         instance_name=timeline_event.internal_metadata.instance_name,
                         stream=timeline_event.internal_metadata.stream_ordering,
                     )
-                    if persisted_position.persisted_after(from_token.stream.room_key):
+                    if persisted_position.persisted_after(
+                        from_token.stream_token.room_key
+                    ):
                         num_live += 1
                     else:
                         # Since we're iterating over the timeline events in
@@ -1699,6 +1713,8 @@ class SlidingSyncHandler:
 
                             # FIXME: We probably also care about invite, ban, kick, targets, etc
                             # but the spec only mentions "senders".
+                        elif state_key == StateValues.ME:
+                            required_state_types.append((state_type, user.to_string()))
                         else:
                             required_state_types.append((state_type, state_key))
 
@@ -1819,33 +1835,47 @@ class SlidingSyncHandler:
         self,
         sync_config: SlidingSyncConfig,
         to_token: StreamToken,
+        from_token: Optional[SlidingSyncStreamToken],
     ) -> SlidingSyncResult.Extensions:
         """Handle extension requests.
 
         Args:
             sync_config: Sync configuration
             to_token: The point in the stream to sync up to.
+            from_token: The point in the stream to sync from.
         """
 
         if sync_config.extensions is None:
             return SlidingSyncResult.Extensions()
 
         to_device_response = None
-        if sync_config.extensions.to_device:
-            to_device_response = await self.get_to_device_extensions_response(
+        if sync_config.extensions.to_device is not None:
+            to_device_response = await self.get_to_device_extension_response(
                 sync_config=sync_config,
                 to_device_request=sync_config.extensions.to_device,
                 to_token=to_token,
             )
 
-        return SlidingSyncResult.Extensions(to_device=to_device_response)
+        e2ee_response = None
+        if sync_config.extensions.e2ee is not None:
+            e2ee_response = await self.get_e2ee_extension_response(
+                sync_config=sync_config,
+                e2ee_request=sync_config.extensions.e2ee,
+                to_token=to_token,
+                from_token=from_token,
+            )
+
+        return SlidingSyncResult.Extensions(
+            to_device=to_device_response,
+            e2ee=e2ee_response,
+        )
 
-    async def get_to_device_extensions_response(
+    async def get_to_device_extension_response(
         self,
         sync_config: SlidingSyncConfig,
         to_device_request: SlidingSyncConfig.Extensions.ToDeviceExtension,
         to_token: StreamToken,
-    ) -> SlidingSyncResult.Extensions.ToDeviceExtension:
+    ) -> Optional[SlidingSyncResult.Extensions.ToDeviceExtension]:
         """Handle to-device extension (MSC3885)
 
         Args:
@@ -1853,14 +1883,16 @@ class SlidingSyncHandler:
             to_device_request: The to-device extension from the request
             to_token: The point in the stream to sync up to.
         """
-
         user_id = sync_config.user.to_string()
         device_id = sync_config.requester.device_id
 
+        # Skip if the extension is not enabled
+        if not to_device_request.enabled:
+            return None
+
         # Check that this request has a valid device ID (not all requests have
-        # to belong to a device, and so device_id is None), and that the
-        # extension is enabled.
-        if device_id is None or not to_device_request.enabled:
+        # to belong to a device, and so device_id is None)
+        if device_id is None:
             return SlidingSyncResult.Extensions.ToDeviceExtension(
                 next_batch=f"{to_token.to_device_key}",
                 events=[],
@@ -1912,6 +1944,56 @@ class SlidingSyncHandler:
             events=messages,
         )
 
+    async def get_e2ee_extension_response(
+        self,
+        sync_config: SlidingSyncConfig,
+        e2ee_request: SlidingSyncConfig.Extensions.E2eeExtension,
+        to_token: StreamToken,
+        from_token: Optional[SlidingSyncStreamToken],
+    ) -> Optional[SlidingSyncResult.Extensions.E2eeExtension]:
+        """Handle E2EE device extension (MSC3884)
+
+        Args:
+            sync_config: Sync configuration
+            e2ee_request: The e2ee extension from the request
+            to_token: The point in the stream to sync up to.
+            from_token: The point in the stream to sync from.
+        """
+        user_id = sync_config.user.to_string()
+        device_id = sync_config.requester.device_id
+
+        # Skip if the extension is not enabled
+        if not e2ee_request.enabled:
+            return None
+
+        device_list_updates: Optional[DeviceListUpdates] = None
+        if from_token is not None:
+            # TODO: This should take into account the `from_token` and `to_token`
+            device_list_updates = await self.device_handler.get_user_ids_changed(
+                user_id=user_id,
+                from_token=from_token.stream_token,
+            )
+
+        device_one_time_keys_count: Mapping[str, int] = {}
+        device_unused_fallback_key_types: Sequence[str] = []
+        if device_id:
+            # TODO: We should have a way to let clients differentiate between the states of:
+            #   * no change in OTK count since the provided since token
+            #   * the server has zero OTKs left for this device
+            #  Spec issue: https://github.com/matrix-org/matrix-doc/issues/3298
+            device_one_time_keys_count = await self.store.count_e2e_one_time_keys(
+                user_id, device_id
+            )
+            device_unused_fallback_key_types = (
+                await self.store.get_e2e_unused_fallback_key_types(user_id, device_id)
+            )
+
+        return SlidingSyncResult.Extensions.E2eeExtension(
+            device_list_updates=device_list_updates,
+            device_one_time_keys_count=device_one_time_keys_count,
+            device_unused_fallback_key_types=device_unused_fallback_key_types,
+        )
+
 
 class HaveSentRoomFlag(Enum):
     """Flag for whether we have sent the room down a sliding sync connection.
@@ -2027,7 +2109,7 @@ class SlidingSyncConnectionStore:
         """
         prev_connection_token = 0
         if from_token is not None:
-            prev_connection_token = from_token.connection
+            prev_connection_token = from_token.connection_position
 
         # If there are no changes then this is a noop.
         if not sent_room_ids and not unsent_room_ids:
@@ -2063,7 +2145,7 @@ class SlidingSyncConnectionStore:
 
         # Work out the new state for unsent rooms that were `LIVE`.
         if from_token:
-            new_unsent_state = HaveSentRoom.previously(from_token.stream.room_key)
+            new_unsent_state = HaveSentRoom.previously(from_token.stream_token.room_key)
         else:
             new_unsent_state = HAVE_SENT_ROOM_NEVER
 
@@ -2102,7 +2184,7 @@ class SlidingSyncConnectionStore:
         sync_statuses = {
             connection_token: room_statuses
             for connection_token, room_statuses in sync_statuses.items()
-            if connection_token == from_token.connection
+            if connection_token == from_token.connection_position
         }
         if sync_statuses:
             self._connections[conn_key] = sync_statuses