diff --git a/changelog.d/17595.misc b/changelog.d/17595.misc
new file mode 100644
index 0000000000..c8e040d87c
--- /dev/null
+++ b/changelog.d/17595.misc
@@ -0,0 +1 @@
+Refactor sliding sync class into multiple files.
diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync/__init__.py
index af8d7ab96c..1fcf2d149b 100644
--- a/synapse/handlers/sliding_sync.py
+++ b/synapse/handlers/sliding_sync/__init__.py
@@ -1,7 +1,7 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
-# Copyright (C) 2024 New Vector, Ltd
+# Copyright (C) 2024 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
@@ -11,36 +11,21 @@
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
-# Originally licensed under the Apache License, Version 2.0:
-# <http://www.apache.org/licenses/LICENSE-2.0>.
-#
-# [This file includes modifications made by New Vector Limited]
-#
-#
+
import enum
import logging
-import typing
-from collections import ChainMap
-from enum import Enum
from itertools import chain
from typing import (
TYPE_CHECKING,
Any,
- Callable,
Dict,
- Final,
- Generic,
List,
Literal,
Mapping,
- MutableMapping,
Optional,
- Sequence,
Set,
Tuple,
- TypeVar,
Union,
- cast,
)
import attr
@@ -55,11 +40,18 @@ from synapse.api.constants import (
EventTypes,
Membership,
)
-from synapse.api.errors import SlidingSyncUnknownPosition
from synapse.events import EventBase, StrippedStateEvent
from synapse.events.utils import parse_stripped_state_event, strip_event
-from synapse.handlers.receipts import ReceiptEventSource
from synapse.handlers.relations import BundledAggregations
+from synapse.handlers.sliding_sync.extensions import SlidingSyncExtensionHandler
+from synapse.handlers.sliding_sync.store import SlidingSyncConnectionStore
+from synapse.handlers.sliding_sync.types import (
+ HaveSentRoomFlag,
+ MutablePerConnectionState,
+ PerConnectionState,
+ RoomSyncConfig,
+ StateValues,
+)
from synapse.logging.opentracing import (
SynapseTags,
log_kv,
@@ -79,10 +71,7 @@ from synapse.storage.databases.main.stream import (
)
from synapse.storage.roommember import MemberSummary
from synapse.types import (
- DeviceListUpdates,
JsonDict,
- JsonMapping,
- MultiWriterStreamToken,
MutableStateMap,
PersistedEventPosition,
Requester,
@@ -205,267 +194,6 @@ def filter_membership_for_sync(
)
-# We can't freeze this class because we want to update it in place with the
-# de-duplicated data.
-@attr.s(slots=True, auto_attribs=True)
-class RoomSyncConfig:
- """
- Holds the config for what data we should fetch for a room in the sync response.
-
- Attributes:
- timeline_limit: The maximum number of events to return in the timeline.
-
- required_state_map: Map from state event type to state_keys requested for the
- room. The values are close to `StateKey` but actually use a syntax where you
- can provide `*` wildcard and `$LAZY` for lazy-loading room members.
- """
-
- timeline_limit: int
- required_state_map: Dict[str, Set[str]]
-
- @classmethod
- def from_room_config(
- cls,
- room_params: SlidingSyncConfig.CommonRoomParameters,
- ) -> "RoomSyncConfig":
- """
- Create a `RoomSyncConfig` from a `SlidingSyncList`/`RoomSubscription` config.
-
- Args:
- room_params: `SlidingSyncConfig.SlidingSyncList` or `SlidingSyncConfig.RoomSubscription`
- """
- required_state_map: Dict[str, Set[str]] = {}
- for (
- state_type,
- state_key,
- ) in room_params.required_state:
- # If we already have a wildcard for this specific `state_key`, we don't need
- # to add it since the wildcard already covers it.
- if state_key in required_state_map.get(StateValues.WILDCARD, set()):
- continue
-
- # If we already have a wildcard `state_key` for this `state_type`, we don't need
- # to add anything else
- if StateValues.WILDCARD in required_state_map.get(state_type, set()):
- continue
-
- # If we're getting wildcards for the `state_type` and `state_key`, that's
- # all that matters so get rid of any other entries
- if state_type == StateValues.WILDCARD and state_key == StateValues.WILDCARD:
- required_state_map = {StateValues.WILDCARD: {StateValues.WILDCARD}}
- # We can break, since we don't need to add anything else
- break
-
- # If we're getting a wildcard for the `state_type`, get rid of any other
- # entries with the same `state_key`, since the wildcard will cover it already.
- elif state_type == StateValues.WILDCARD:
- # Get rid of any entries that match the `state_key`
- #
- # Make a copy so we don't run into an error: `dictionary changed size
- # during iteration`, when we remove items
- for (
- existing_state_type,
- existing_state_key_set,
- ) in list(required_state_map.items()):
- # Make a copy so we don't run into an error: `Set changed size during
- # iteration`, when we filter out and remove items
- for existing_state_key in existing_state_key_set.copy():
- if existing_state_key == state_key:
- existing_state_key_set.remove(state_key)
-
- # If we've the left the `set()` empty, remove it from the map
- if existing_state_key_set == set():
- required_state_map.pop(existing_state_type, None)
-
- # If we're getting a wildcard `state_key`, get rid of any other state_keys
- # for this `state_type` since the wildcard will cover it already.
- if state_key == StateValues.WILDCARD:
- required_state_map[state_type] = {state_key}
- # Otherwise, just add it to the set
- else:
- if required_state_map.get(state_type) is None:
- required_state_map[state_type] = {state_key}
- else:
- required_state_map[state_type].add(state_key)
-
- return cls(
- timeline_limit=room_params.timeline_limit,
- required_state_map=required_state_map,
- )
-
- def deep_copy(self) -> "RoomSyncConfig":
- required_state_map: Dict[str, Set[str]] = {
- state_type: state_key_set.copy()
- for state_type, state_key_set in self.required_state_map.items()
- }
-
- return RoomSyncConfig(
- timeline_limit=self.timeline_limit,
- required_state_map=required_state_map,
- )
-
- def combine_room_sync_config(
- self, other_room_sync_config: "RoomSyncConfig"
- ) -> None:
- """
- Combine this `RoomSyncConfig` with another `RoomSyncConfig` and take the
- superset union of the two.
- """
- # Take the highest timeline limit
- if self.timeline_limit < other_room_sync_config.timeline_limit:
- self.timeline_limit = other_room_sync_config.timeline_limit
-
- # Union the required state
- for (
- state_type,
- state_key_set,
- ) in other_room_sync_config.required_state_map.items():
- # If we already have a wildcard for everything, we don't need to add
- # anything else
- if StateValues.WILDCARD in self.required_state_map.get(
- StateValues.WILDCARD, set()
- ):
- break
-
- # If we already have a wildcard `state_key` for this `state_type`, we don't need
- # to add anything else
- if StateValues.WILDCARD in self.required_state_map.get(state_type, set()):
- continue
-
- # If we're getting wildcards for the `state_type` and `state_key`, that's
- # all that matters so get rid of any other entries
- if (
- state_type == StateValues.WILDCARD
- and StateValues.WILDCARD in state_key_set
- ):
- self.required_state_map = {state_type: {StateValues.WILDCARD}}
- # We can break, since we don't need to add anything else
- break
-
- for state_key in state_key_set:
- # If we already have a wildcard for this specific `state_key`, we don't need
- # to add it since the wildcard already covers it.
- if state_key in self.required_state_map.get(
- StateValues.WILDCARD, set()
- ):
- continue
-
- # If we're getting a wildcard for the `state_type`, get rid of any other
- # entries with the same `state_key`, since the wildcard will cover it already.
- if state_type == StateValues.WILDCARD:
- # Get rid of any entries that match the `state_key`
- #
- # Make a copy so we don't run into an error: `dictionary changed size
- # during iteration`, when we remove items
- for existing_state_type, existing_state_key_set in list(
- self.required_state_map.items()
- ):
- # Make a copy so we don't run into an error: `Set changed size during
- # iteration`, when we filter out and remove items
- for existing_state_key in existing_state_key_set.copy():
- if existing_state_key == state_key:
- existing_state_key_set.remove(state_key)
-
- # If we've the left the `set()` empty, remove it from the map
- if existing_state_key_set == set():
- self.required_state_map.pop(existing_state_type, None)
-
- # If we're getting a wildcard `state_key`, get rid of any other state_keys
- # for this `state_type` since the wildcard will cover it already.
- if state_key == StateValues.WILDCARD:
- self.required_state_map[state_type] = {state_key}
- break
- # Otherwise, just add it to the set
- else:
- if self.required_state_map.get(state_type) is None:
- self.required_state_map[state_type] = {state_key}
- else:
- self.required_state_map[state_type].add(state_key)
-
- def must_await_full_state(
- self,
- is_mine_id: Callable[[str], bool],
- ) -> bool:
- """
- Check if we have a we're only requesting `required_state` which is completely
- satisfied even with partial state, then we don't need to `await_full_state` before
- we can return it.
-
- Also see `StateFilter.must_await_full_state(...)` for comparison
-
- Partially-stated rooms should have all state events except for remote membership
- events so if we require a remote membership event anywhere, then we need to
- return `True` (requires full state).
-
- Args:
- is_mine_id: a callable which confirms if a given state_key matches a mxid
- of a local user
- """
- wildcard_state_keys = self.required_state_map.get(StateValues.WILDCARD)
- # Requesting *all* state in the room so we have to wait
- if (
- wildcard_state_keys is not None
- and StateValues.WILDCARD in wildcard_state_keys
- ):
- return True
-
- # If the wildcards don't refer to remote user IDs, then we don't need to wait
- # for full state.
- if wildcard_state_keys is not None:
- for possible_user_id in wildcard_state_keys:
- if not possible_user_id[0].startswith(UserID.SIGIL):
- # Not a user ID
- continue
-
- localpart_hostname = possible_user_id.split(":", 1)
- if len(localpart_hostname) < 2:
- # Not a user ID
- continue
-
- if not is_mine_id(possible_user_id):
- return True
-
- membership_state_keys = self.required_state_map.get(EventTypes.Member)
- # We aren't requesting any membership events at all so the partial state will
- # cover us.
- if membership_state_keys is None:
- return False
-
- # If we're requesting entirely local users, the partial state will cover us.
- for user_id in membership_state_keys:
- if user_id == StateValues.ME:
- continue
- # We're lazy-loading membership so we can just return the state we have.
- # Lazy-loading means we include membership for any event `sender` in the
- # timeline but since we had to auth those timeline events, we will have the
- # membership state for them (including from remote senders).
- elif user_id == StateValues.LAZY:
- continue
- elif user_id == StateValues.WILDCARD:
- return False
- elif not is_mine_id(user_id):
- return True
-
- # Local users only so the partial state will cover us.
- return False
-
-
-class StateValues:
- """
- Understood values of the (type, state_key) tuple in `required_state`.
- """
-
- # Include all state events of the given type
- WILDCARD: Final = "*"
- # Lazy-load room membership events (include room membership events for any event
- # `sender` in the timeline). We only give special meaning to this value when it's a
- # `state_key`.
- LAZY: Final = "$LAZY"
- # Subsitute with the requester's user ID. Typically used by clients to get
- # the user's membership.
- ME: Final = "$ME"
-
-
class SlidingSyncHandler:
def __init__(self, hs: "HomeServer"):
self.clock = hs.get_clock()
@@ -475,12 +203,11 @@ class SlidingSyncHandler:
self.notifier = hs.get_notifier()
self.event_sources = hs.get_event_sources()
self.relations_handler = hs.get_relations_handler()
- self.device_handler = hs.get_device_handler()
- self.push_rules_handler = hs.get_push_rules_handler()
self.rooms_to_exclude_globally = hs.config.server.rooms_to_exclude_from_sync
self.is_mine_id = hs.is_mine_id
self.connection_store = SlidingSyncConnectionStore()
+ self.extensions = SlidingSyncExtensionHandler(hs)
async def wait_for_sync_for_user(
self,
@@ -868,7 +595,7 @@ class SlidingSyncHandler:
with start_active_span("sliding_sync.generate_room_entries"):
await concurrently_execute(handle_room, relevant_rooms_to_send_map, 10)
- extensions = await self.get_extensions_response(
+ extensions = await self.extensions.get_extensions_response(
sync_config=sync_config,
actual_lists=lists,
previous_connection_state=previous_connection_state,
@@ -2597,984 +2324,3 @@ class SlidingSyncHandler:
notification_count=0,
highlight_count=0,
)
-
- @trace
- async def get_extensions_response(
- self,
- sync_config: SlidingSyncConfig,
- previous_connection_state: "PerConnectionState",
- new_connection_state: "MutablePerConnectionState",
- actual_lists: Dict[str, SlidingSyncResult.SlidingWindowList],
- actual_room_ids: Set[str],
- actual_room_response_map: Dict[str, SlidingSyncResult.RoomResult],
- to_token: StreamToken,
- from_token: Optional[SlidingSyncStreamToken],
- ) -> SlidingSyncResult.Extensions:
- """Handle extension requests.
-
- Args:
- sync_config: Sync configuration
- new_connection_state: Snapshot of the current per-connection state
- new_per_connection_state: A mutable copy of the per-connection
- state, used to record updates to the state during this request.
- actual_lists: Sliding window API. A map of list key to list results in the
- Sliding Sync response.
- actual_room_ids: The actual room IDs in the the Sliding Sync response.
- actual_room_response_map: A map of room ID to room results in the the
- Sliding Sync response.
- to_token: The point in the stream to sync up to.
- from_token: The point in the stream to sync from.
- """
-
- if sync_config.extensions is None:
- return SlidingSyncResult.Extensions()
-
- to_device_response = None
- if sync_config.extensions.to_device is not None:
- to_device_response = await self.get_to_device_extension_response(
- sync_config=sync_config,
- to_device_request=sync_config.extensions.to_device,
- to_token=to_token,
- )
-
- e2ee_response = None
- if sync_config.extensions.e2ee is not None:
- e2ee_response = await self.get_e2ee_extension_response(
- sync_config=sync_config,
- e2ee_request=sync_config.extensions.e2ee,
- to_token=to_token,
- from_token=from_token,
- )
-
- account_data_response = None
- if sync_config.extensions.account_data is not None:
- account_data_response = await self.get_account_data_extension_response(
- sync_config=sync_config,
- actual_lists=actual_lists,
- actual_room_ids=actual_room_ids,
- account_data_request=sync_config.extensions.account_data,
- to_token=to_token,
- from_token=from_token,
- )
-
- receipts_response = None
- if sync_config.extensions.receipts is not None:
- receipts_response = await self.get_receipts_extension_response(
- sync_config=sync_config,
- previous_connection_state=previous_connection_state,
- new_connection_state=new_connection_state,
- actual_lists=actual_lists,
- actual_room_ids=actual_room_ids,
- actual_room_response_map=actual_room_response_map,
- receipts_request=sync_config.extensions.receipts,
- to_token=to_token,
- from_token=from_token,
- )
-
- typing_response = None
- if sync_config.extensions.typing is not None:
- typing_response = await self.get_typing_extension_response(
- sync_config=sync_config,
- actual_lists=actual_lists,
- actual_room_ids=actual_room_ids,
- actual_room_response_map=actual_room_response_map,
- typing_request=sync_config.extensions.typing,
- to_token=to_token,
- from_token=from_token,
- )
-
- return SlidingSyncResult.Extensions(
- to_device=to_device_response,
- e2ee=e2ee_response,
- account_data=account_data_response,
- receipts=receipts_response,
- typing=typing_response,
- )
-
- def find_relevant_room_ids_for_extension(
- self,
- requested_lists: Optional[List[str]],
- requested_room_ids: Optional[List[str]],
- actual_lists: Dict[str, SlidingSyncResult.SlidingWindowList],
- actual_room_ids: Set[str],
- ) -> Set[str]:
- """
- Handle the reserved `lists`/`rooms` keys for extensions. Extensions should only
- return results for rooms in the Sliding Sync response. This matches up the
- requested rooms/lists with the actual lists/rooms in the Sliding Sync response.
-
- {"lists": []} // Do not process any lists.
- {"lists": ["rooms", "dms"]} // Process only a subset of lists.
- {"lists": ["*"]} // Process all lists defined in the Sliding Window API. (This is the default.)
-
- {"rooms": []} // Do not process any specific rooms.
- {"rooms": ["!a:b", "!c:d"]} // Process only a subset of room subscriptions.
- {"rooms": ["*"]} // Process all room subscriptions defined in the Room Subscription API. (This is the default.)
-
- Args:
- requested_lists: The `lists` from the extension request.
- requested_room_ids: The `rooms` from the extension request.
- actual_lists: The actual lists from the Sliding Sync response.
- actual_room_ids: The actual room subscriptions from the Sliding Sync request.
- """
-
- # We only want to include account data for rooms that are already in the sliding
- # sync response AND that were requested in the account data request.
- relevant_room_ids: Set[str] = set()
-
- # See what rooms from the room subscriptions we should get account data for
- if requested_room_ids is not None:
- for room_id in requested_room_ids:
- # A wildcard means we process all rooms from the room subscriptions
- if room_id == "*":
- relevant_room_ids.update(actual_room_ids)
- break
-
- if room_id in actual_room_ids:
- relevant_room_ids.add(room_id)
-
- # See what rooms from the sliding window lists we should get account data for
- if requested_lists is not None:
- for list_key in requested_lists:
- # Just some typing because we share the variable name in multiple places
- actual_list: Optional[SlidingSyncResult.SlidingWindowList] = None
-
- # A wildcard means we process rooms from all lists
- if list_key == "*":
- for actual_list in actual_lists.values():
- # We only expect a single SYNC operation for any list
- assert len(actual_list.ops) == 1
- sync_op = actual_list.ops[0]
- assert sync_op.op == OperationType.SYNC
-
- relevant_room_ids.update(sync_op.room_ids)
-
- break
-
- actual_list = actual_lists.get(list_key)
- if actual_list is not None:
- # We only expect a single SYNC operation for any list
- assert len(actual_list.ops) == 1
- sync_op = actual_list.ops[0]
- assert sync_op.op == OperationType.SYNC
-
- relevant_room_ids.update(sync_op.room_ids)
-
- return relevant_room_ids
-
- @trace
- async def get_to_device_extension_response(
- self,
- sync_config: SlidingSyncConfig,
- to_device_request: SlidingSyncConfig.Extensions.ToDeviceExtension,
- to_token: StreamToken,
- ) -> Optional[SlidingSyncResult.Extensions.ToDeviceExtension]:
- """Handle to-device extension (MSC3885)
-
- Args:
- sync_config: Sync configuration
- to_device_request: The to-device extension from the request
- to_token: The point in the stream to sync up to.
- """
- user_id = sync_config.user.to_string()
- device_id = sync_config.requester.device_id
-
- # Skip if the extension is not enabled
- if not to_device_request.enabled:
- return None
-
- # Check that this request has a valid device ID (not all requests have
- # to belong to a device, and so device_id is None)
- if device_id is None:
- return SlidingSyncResult.Extensions.ToDeviceExtension(
- next_batch=f"{to_token.to_device_key}",
- events=[],
- )
-
- since_stream_id = 0
- if to_device_request.since is not None:
- # We've already validated this is an int.
- since_stream_id = int(to_device_request.since)
-
- if to_token.to_device_key < since_stream_id:
- # The since token is ahead of our current token, so we return an
- # empty response.
- logger.warning(
- "Got to-device.since from the future. since token: %r is ahead of our current to_device stream position: %r",
- since_stream_id,
- to_token.to_device_key,
- )
- return SlidingSyncResult.Extensions.ToDeviceExtension(
- next_batch=to_device_request.since,
- events=[],
- )
-
- # Delete everything before the given since token, as we know the
- # device must have received them.
- deleted = await self.store.delete_messages_for_device(
- user_id=user_id,
- device_id=device_id,
- up_to_stream_id=since_stream_id,
- )
-
- logger.debug(
- "Deleted %d to-device messages up to %d for %s",
- deleted,
- since_stream_id,
- user_id,
- )
-
- messages, stream_id = await self.store.get_messages_for_device(
- user_id=user_id,
- device_id=device_id,
- from_stream_id=since_stream_id,
- to_stream_id=to_token.to_device_key,
- limit=min(to_device_request.limit, 100), # Limit to at most 100 events
- )
-
- return SlidingSyncResult.Extensions.ToDeviceExtension(
- next_batch=f"{stream_id}",
- events=messages,
- )
-
- @trace
- async def get_e2ee_extension_response(
- self,
- sync_config: SlidingSyncConfig,
- e2ee_request: SlidingSyncConfig.Extensions.E2eeExtension,
- to_token: StreamToken,
- from_token: Optional[SlidingSyncStreamToken],
- ) -> Optional[SlidingSyncResult.Extensions.E2eeExtension]:
- """Handle E2EE device extension (MSC3884)
-
- Args:
- sync_config: Sync configuration
- e2ee_request: The e2ee extension from the request
- to_token: The point in the stream to sync up to.
- from_token: The point in the stream to sync from.
- """
- user_id = sync_config.user.to_string()
- device_id = sync_config.requester.device_id
-
- # Skip if the extension is not enabled
- if not e2ee_request.enabled:
- return None
-
- device_list_updates: Optional[DeviceListUpdates] = None
- if from_token is not None:
- # TODO: This should take into account the `from_token` and `to_token`
- device_list_updates = await self.device_handler.get_user_ids_changed(
- user_id=user_id,
- from_token=from_token.stream_token,
- )
-
- device_one_time_keys_count: Mapping[str, int] = {}
- device_unused_fallback_key_types: Sequence[str] = []
- if device_id:
- # TODO: We should have a way to let clients differentiate between the states of:
- # * no change in OTK count since the provided since token
- # * the server has zero OTKs left for this device
- # Spec issue: https://github.com/matrix-org/matrix-doc/issues/3298
- device_one_time_keys_count = await self.store.count_e2e_one_time_keys(
- user_id, device_id
- )
- device_unused_fallback_key_types = (
- await self.store.get_e2e_unused_fallback_key_types(user_id, device_id)
- )
-
- return SlidingSyncResult.Extensions.E2eeExtension(
- device_list_updates=device_list_updates,
- device_one_time_keys_count=device_one_time_keys_count,
- device_unused_fallback_key_types=device_unused_fallback_key_types,
- )
-
- @trace
- async def get_account_data_extension_response(
- self,
- sync_config: SlidingSyncConfig,
- actual_lists: Dict[str, SlidingSyncResult.SlidingWindowList],
- actual_room_ids: Set[str],
- account_data_request: SlidingSyncConfig.Extensions.AccountDataExtension,
- to_token: StreamToken,
- from_token: Optional[SlidingSyncStreamToken],
- ) -> Optional[SlidingSyncResult.Extensions.AccountDataExtension]:
- """Handle Account Data extension (MSC3959)
-
- Args:
- sync_config: Sync configuration
- actual_lists: Sliding window API. A map of list key to list results in the
- Sliding Sync response.
- actual_room_ids: The actual room IDs in the the Sliding Sync response.
- account_data_request: The account_data extension from the request
- to_token: The point in the stream to sync up to.
- from_token: The point in the stream to sync from.
- """
- user_id = sync_config.user.to_string()
-
- # Skip if the extension is not enabled
- if not account_data_request.enabled:
- return None
-
- global_account_data_map: Mapping[str, JsonMapping] = {}
- if from_token is not None:
- # TODO: This should take into account the `from_token` and `to_token`
- global_account_data_map = (
- await self.store.get_updated_global_account_data_for_user(
- user_id, from_token.stream_token.account_data_key
- )
- )
-
- have_push_rules_changed = await self.store.have_push_rules_changed_for_user(
- user_id, from_token.stream_token.push_rules_key
- )
- if have_push_rules_changed:
- global_account_data_map = dict(global_account_data_map)
- # TODO: This should take into account the `from_token` and `to_token`
- global_account_data_map[AccountDataTypes.PUSH_RULES] = (
- await self.push_rules_handler.push_rules_for_user(sync_config.user)
- )
- else:
- # TODO: This should take into account the `to_token`
- all_global_account_data = await self.store.get_global_account_data_for_user(
- user_id
- )
-
- global_account_data_map = dict(all_global_account_data)
- # TODO: This should take into account the `to_token`
- global_account_data_map[AccountDataTypes.PUSH_RULES] = (
- await self.push_rules_handler.push_rules_for_user(sync_config.user)
- )
-
- # Fetch room account data
- account_data_by_room_map: Mapping[str, Mapping[str, JsonMapping]] = {}
- relevant_room_ids = self.find_relevant_room_ids_for_extension(
- requested_lists=account_data_request.lists,
- requested_room_ids=account_data_request.rooms,
- actual_lists=actual_lists,
- actual_room_ids=actual_room_ids,
- )
- if len(relevant_room_ids) > 0:
- if from_token is not None:
- # TODO: This should take into account the `from_token` and `to_token`
- account_data_by_room_map = (
- await self.store.get_updated_room_account_data_for_user(
- user_id, from_token.stream_token.account_data_key
- )
- )
- else:
- # TODO: This should take into account the `to_token`
- account_data_by_room_map = (
- await self.store.get_room_account_data_for_user(user_id)
- )
-
- # Filter down to the relevant rooms
- account_data_by_room_map = {
- room_id: account_data_map
- for room_id, account_data_map in account_data_by_room_map.items()
- if room_id in relevant_room_ids
- }
-
- return SlidingSyncResult.Extensions.AccountDataExtension(
- global_account_data_map=global_account_data_map,
- account_data_by_room_map=account_data_by_room_map,
- )
-
- @trace
- async def get_receipts_extension_response(
- self,
- sync_config: SlidingSyncConfig,
- previous_connection_state: "PerConnectionState",
- new_connection_state: "MutablePerConnectionState",
- actual_lists: Dict[str, SlidingSyncResult.SlidingWindowList],
- actual_room_ids: Set[str],
- actual_room_response_map: Dict[str, SlidingSyncResult.RoomResult],
- receipts_request: SlidingSyncConfig.Extensions.ReceiptsExtension,
- to_token: StreamToken,
- from_token: Optional[SlidingSyncStreamToken],
- ) -> Optional[SlidingSyncResult.Extensions.ReceiptsExtension]:
- """Handle Receipts extension (MSC3960)
-
- Args:
- sync_config: Sync configuration
- previous_connection_state: The current per-connection state
- new_connection_state: A mutable copy of the per-connection
- state, used to record updates to the state.
- actual_lists: Sliding window API. A map of list key to list results in the
- Sliding Sync response.
- actual_room_ids: The actual room IDs in the the Sliding Sync response.
- actual_room_response_map: A map of room ID to room results in the the
- Sliding Sync response.
- account_data_request: The account_data extension from the request
- to_token: The point in the stream to sync up to.
- from_token: The point in the stream to sync from.
- """
- # Skip if the extension is not enabled
- if not receipts_request.enabled:
- return None
-
- relevant_room_ids = self.find_relevant_room_ids_for_extension(
- requested_lists=receipts_request.lists,
- requested_room_ids=receipts_request.rooms,
- actual_lists=actual_lists,
- actual_room_ids=actual_room_ids,
- )
-
- room_id_to_receipt_map: Dict[str, JsonMapping] = {}
- if len(relevant_room_ids) > 0:
- # We need to handle the different cases depending on if we have sent
- # down receipts previously or not, so we split the relevant rooms
- # up into different collections based on status.
- live_rooms = set()
- previously_rooms: Dict[str, MultiWriterStreamToken] = {}
- initial_rooms = set()
-
- for room_id in relevant_room_ids:
- if not from_token:
- initial_rooms.add(room_id)
- continue
-
- # If we're sending down the room from scratch again for some reason, we
- # should always resend the receipts as well (regardless of if
- # we've sent them down before). This is to mimic the behaviour
- # of what happens on initial sync, where you get a chunk of
- # timeline with all of the corresponding receipts for the events in the timeline.
- room_result = actual_room_response_map.get(room_id)
- if room_result is not None and room_result.initial:
- initial_rooms.add(room_id)
- continue
-
- room_status = previous_connection_state.receipts.have_sent_room(room_id)
- if room_status.status == HaveSentRoomFlag.LIVE:
- live_rooms.add(room_id)
- elif room_status.status == HaveSentRoomFlag.PREVIOUSLY:
- assert room_status.last_token is not None
- previously_rooms[room_id] = room_status.last_token
- elif room_status.status == HaveSentRoomFlag.NEVER:
- initial_rooms.add(room_id)
- else:
- assert_never(room_status.status)
-
- # The set of receipts that we fetched. Private receipts need to be
- # filtered out before returning.
- fetched_receipts = []
-
- # For live rooms we just fetch all receipts in those rooms since the
- # `since` token.
- if live_rooms:
- assert from_token is not None
- receipts = await self.store.get_linearized_receipts_for_rooms(
- room_ids=live_rooms,
- from_key=from_token.stream_token.receipt_key,
- to_key=to_token.receipt_key,
- )
- fetched_receipts.extend(receipts)
-
- # For rooms we've previously sent down, but aren't up to date, we
- # need to use the from token from the room status.
- if previously_rooms:
- for room_id, receipt_token in previously_rooms.items():
- # TODO: Limit the number of receipts we're about to send down
- # for the room, if its too many we should TODO
- previously_receipts = (
- await self.store.get_linearized_receipts_for_room(
- room_id=room_id,
- from_key=receipt_token,
- to_key=to_token.receipt_key,
- )
- )
- fetched_receipts.extend(previously_receipts)
-
- # For rooms we haven't previously sent down, we could send all receipts
- # from that room but we only want to include receipts for events
- # in the timeline to avoid bloating and blowing up the sync response
- # as the number of users in the room increases. (this behavior is part of the spec)
- initial_rooms_and_event_ids = [
- (room_id, event.event_id)
- for room_id in initial_rooms
- if room_id in actual_room_response_map
- for event in actual_room_response_map[room_id].timeline_events
- ]
- if initial_rooms_and_event_ids:
- initial_receipts = await self.store.get_linearized_receipts_for_events(
- room_and_event_ids=initial_rooms_and_event_ids,
- )
- fetched_receipts.extend(initial_receipts)
-
- fetched_receipts = ReceiptEventSource.filter_out_private_receipts(
- fetched_receipts, sync_config.user.to_string()
- )
-
- for receipt in fetched_receipts:
- # These fields should exist for every receipt
- room_id = receipt["room_id"]
- type = receipt["type"]
- content = receipt["content"]
-
- room_id_to_receipt_map[room_id] = {"type": type, "content": content}
-
- # Now we update the per-connection state to track which receipts we have
- # and haven't sent down.
- new_connection_state.receipts.record_sent_rooms(relevant_room_ids)
-
- if from_token:
- # Now find the set of rooms that may have receipts that we're not sending
- # down. We only need to check rooms that we have previously returned
- # receipts for (in `previous_connection_state`) because we only care about
- # updating `LIVE` rooms to `PREVIOUSLY`. The `PREVIOUSLY` rooms will just
- # stay pointing at their previous position so we don't need to waste time
- # checking those and since we default to `NEVER`, rooms that were `NEVER`
- # sent before don't need to be recorded as we'll handle them correctly when
- # they come into range for the first time.
- rooms_no_receipts = [
- room_id
- for room_id, room_status in previous_connection_state.receipts._statuses.items()
- if room_status.status == HaveSentRoomFlag.LIVE
- and room_id not in relevant_room_ids
- ]
- changed_rooms = await self.store.get_rooms_with_receipts_between(
- rooms_no_receipts,
- from_key=from_token.stream_token.receipt_key,
- to_key=to_token.receipt_key,
- )
- new_connection_state.receipts.record_unsent_rooms(
- changed_rooms, from_token.stream_token.receipt_key
- )
-
- return SlidingSyncResult.Extensions.ReceiptsExtension(
- room_id_to_receipt_map=room_id_to_receipt_map,
- )
-
- async def get_typing_extension_response(
- self,
- sync_config: SlidingSyncConfig,
- actual_lists: Dict[str, SlidingSyncResult.SlidingWindowList],
- actual_room_ids: Set[str],
- actual_room_response_map: Dict[str, SlidingSyncResult.RoomResult],
- typing_request: SlidingSyncConfig.Extensions.TypingExtension,
- to_token: StreamToken,
- from_token: Optional[SlidingSyncStreamToken],
- ) -> Optional[SlidingSyncResult.Extensions.TypingExtension]:
- """Handle Typing Notification extension (MSC3961)
-
- Args:
- sync_config: Sync configuration
- actual_lists: Sliding window API. A map of list key to list results in the
- Sliding Sync response.
- actual_room_ids: The actual room IDs in the the Sliding Sync response.
- actual_room_response_map: A map of room ID to room results in the the
- Sliding Sync response.
- account_data_request: The account_data extension from the request
- to_token: The point in the stream to sync up to.
- from_token: The point in the stream to sync from.
- """
- # Skip if the extension is not enabled
- if not typing_request.enabled:
- return None
-
- relevant_room_ids = self.find_relevant_room_ids_for_extension(
- requested_lists=typing_request.lists,
- requested_room_ids=typing_request.rooms,
- actual_lists=actual_lists,
- actual_room_ids=actual_room_ids,
- )
-
- room_id_to_typing_map: Dict[str, JsonMapping] = {}
- if len(relevant_room_ids) > 0:
- # Note: We don't need to take connection tracking into account for typing
- # notifications because they'll get anything still relevant and hasn't timed
- # out when the room comes into range. We consider the gap where the room
- # fell out of range, as long enough for any typing notifications to have
- # timed out (it's not worth the 30 seconds of data we may have missed).
- typing_source = self.event_sources.sources.typing
- typing_notifications, _ = await typing_source.get_new_events(
- user=sync_config.user,
- from_key=(from_token.stream_token.typing_key if from_token else 0),
- to_key=to_token.typing_key,
- # This is a dummy value and isn't used in the function
- limit=0,
- room_ids=relevant_room_ids,
- is_guest=False,
- )
-
- for typing_notification in typing_notifications:
- # These fields should exist for every typing notification
- room_id = typing_notification["room_id"]
- type = typing_notification["type"]
- content = typing_notification["content"]
-
- room_id_to_typing_map[room_id] = {"type": type, "content": content}
-
- return SlidingSyncResult.Extensions.TypingExtension(
- room_id_to_typing_map=room_id_to_typing_map,
- )
-
-
-class HaveSentRoomFlag(Enum):
- """Flag for whether we have sent the room down a sliding sync connection.
-
- The valid state changes here are:
- NEVER -> LIVE
- LIVE -> PREVIOUSLY
- PREVIOUSLY -> LIVE
- """
-
- # The room has never been sent down (or we have forgotten we have sent it
- # down).
- NEVER = 1
-
- # We have previously sent the room down, but there are updates that we
- # haven't sent down.
- PREVIOUSLY = 2
-
- # We have sent the room down and the client has received all updates.
- LIVE = 3
-
-
-T = TypeVar("T")
-
-
-@attr.s(auto_attribs=True, slots=True, frozen=True)
-class HaveSentRoom(Generic[T]):
- """Whether we have sent the room data down a sliding sync connection.
-
- We are generic over the type of token used, e.g. `RoomStreamToken` or
- `MultiWriterStreamToken`.
-
- Attributes:
- status: Flag of if we have or haven't sent down the room
- last_token: If the flag is `PREVIOUSLY` then this is non-null and
- contains the last stream token of the last updates we sent down
- the room, i.e. we still need to send everything since then to the
- client.
- """
-
- status: HaveSentRoomFlag
- last_token: Optional[T]
-
- @staticmethod
- def live() -> "HaveSentRoom[T]":
- return HaveSentRoom(HaveSentRoomFlag.LIVE, None)
-
- @staticmethod
- def previously(last_token: T) -> "HaveSentRoom[T]":
- """Constructor for `PREVIOUSLY` flag."""
- return HaveSentRoom(HaveSentRoomFlag.PREVIOUSLY, last_token)
-
- @staticmethod
- def never() -> "HaveSentRoom[T]":
- return HaveSentRoom(HaveSentRoomFlag.NEVER, None)
-
-
-@attr.s(auto_attribs=True, slots=True, frozen=True)
-class RoomStatusMap(Generic[T]):
- """For a given stream, e.g. events, records what we have or have not sent
- down for that stream in a given room."""
-
- # `room_id` -> `HaveSentRoom`
- _statuses: Mapping[str, HaveSentRoom[T]] = attr.Factory(dict)
-
- def have_sent_room(self, room_id: str) -> HaveSentRoom[T]:
- """Return whether we have previously sent the room down"""
- return self._statuses.get(room_id, HaveSentRoom.never())
-
- def get_mutable(self) -> "MutableRoomStatusMap[T]":
- """Get a mutable copy of this state."""
- return MutableRoomStatusMap(
- statuses=self._statuses,
- )
-
- def copy(self) -> "RoomStatusMap[T]":
- """Make a copy of the class. Useful for converting from a mutable to
- immutable version."""
-
- return RoomStatusMap(statuses=dict(self._statuses))
-
-
-class MutableRoomStatusMap(RoomStatusMap[T]):
- """A mutable version of `RoomStatusMap`"""
-
- # We use a ChainMap here so that we can easily track what has been updated
- # and what hasn't. Note that when we persist the per connection state this
- # will get flattened to a normal dict (via calling `.copy()`)
- _statuses: typing.ChainMap[str, HaveSentRoom[T]]
-
- def __init__(
- self,
- statuses: Mapping[str, HaveSentRoom[T]],
- ) -> None:
- # ChainMap requires a mutable mapping, but we're not actually going to
- # mutate it.
- statuses = cast(MutableMapping, statuses)
-
- super().__init__(
- statuses=ChainMap({}, statuses),
- )
-
- def get_updates(self) -> Mapping[str, HaveSentRoom[T]]:
- """Return only the changes that were made"""
- return self._statuses.maps[0]
-
- def record_sent_rooms(self, room_ids: StrCollection) -> None:
- """Record that we have sent these rooms in the response"""
- for room_id in room_ids:
- current_status = self._statuses.get(room_id, HaveSentRoom.never())
- if current_status.status == HaveSentRoomFlag.LIVE:
- continue
-
- self._statuses[room_id] = HaveSentRoom.live()
-
- def record_unsent_rooms(self, room_ids: StrCollection, from_token: T) -> None:
- """Record that we have not sent these rooms in the response, but there
- have been updates.
- """
- # Whether we add/update the entries for unsent rooms depends on the
- # existing entry:
- # - LIVE: We have previously sent down everything up to
- # `last_room_token, so we update the entry to be `PREVIOUSLY` with
- # `last_room_token`.
- # - PREVIOUSLY: We have previously sent down everything up to *a*
- # given token, so we don't need to update the entry.
- # - NEVER: We have never previously sent down the room, and we haven't
- # sent anything down this time either so we leave it as NEVER.
-
- for room_id in room_ids:
- current_status = self._statuses.get(room_id, HaveSentRoom.never())
- if current_status.status != HaveSentRoomFlag.LIVE:
- continue
-
- self._statuses[room_id] = HaveSentRoom.previously(from_token)
-
-
-@attr.s(auto_attribs=True)
-class PerConnectionState:
- """The per-connection state. A snapshot of what we've sent down the
- connection before.
-
- Currently, we track whether we've sent down various aspects of a given room
- before.
-
- We use the `rooms` field to store the position in the events stream for each
- room that we've previously sent to the client before. On the next request
- that includes the room, we can then send only what's changed since that
- recorded position.
-
- Same goes for the `receipts` field so we only need to send the new receipts
- since the last time you made a sync request.
-
- Attributes:
- rooms: The status of each room for the events stream.
- receipts: The status of each room for the receipts stream.
- room_configs: Map from room_id to the `RoomSyncConfig` of all
- rooms that we have previously sent down.
- """
-
- rooms: RoomStatusMap[RoomStreamToken] = attr.Factory(RoomStatusMap)
- receipts: RoomStatusMap[MultiWriterStreamToken] = attr.Factory(RoomStatusMap)
-
- room_configs: Mapping[str, RoomSyncConfig] = attr.Factory(dict)
-
- def get_mutable(self) -> "MutablePerConnectionState":
- """Get a mutable copy of this state."""
- room_configs = cast(MutableMapping[str, RoomSyncConfig], self.room_configs)
-
- return MutablePerConnectionState(
- rooms=self.rooms.get_mutable(),
- receipts=self.receipts.get_mutable(),
- room_configs=ChainMap({}, room_configs),
- )
-
- def copy(self) -> "PerConnectionState":
- return PerConnectionState(
- rooms=self.rooms.copy(),
- receipts=self.receipts.copy(),
- room_configs=dict(self.room_configs),
- )
-
-
-@attr.s(auto_attribs=True)
-class MutablePerConnectionState(PerConnectionState):
- """A mutable version of `PerConnectionState`"""
-
- rooms: MutableRoomStatusMap[RoomStreamToken]
- receipts: MutableRoomStatusMap[MultiWriterStreamToken]
-
- room_configs: typing.ChainMap[str, RoomSyncConfig]
-
- def has_updates(self) -> bool:
- return (
- bool(self.rooms.get_updates())
- or bool(self.receipts.get_updates())
- or bool(self.get_room_config_updates())
- )
-
- def get_room_config_updates(self) -> Mapping[str, RoomSyncConfig]:
- """Get updates to the room sync config"""
- return self.room_configs.maps[0]
-
-
-@attr.s(auto_attribs=True)
-class SlidingSyncConnectionStore:
- """In-memory store of per-connection state, including what rooms we have
- previously sent down a sliding sync connection.
-
- Note: This is NOT safe to run in a worker setup because connection positions will
- point to different sets of rooms on different workers. e.g. for the same connection,
- a connection position of 5 might have totally different states on worker A and
- worker B.
-
- One complication that we need to deal with here is needing to handle requests being
- resent, i.e. if we sent down a room in a response that the client received, we must
- consider the room *not* sent when we get the request again.
-
- This is handled by using an integer "token", which is returned to the client
- as part of the sync token. For each connection we store a mapping from
- tokens to the room states, and create a new entry when we send down new
- rooms.
-
- Note that for any given sliding sync connection we will only store a maximum
- of two different tokens: the previous token from the request and a new token
- sent in the response. When we receive a request with a given token, we then
- clear out all other entries with a different token.
-
- Attributes:
- _connections: Mapping from `(user_id, conn_id)` to mapping of `token`
- to mapping of room ID to `HaveSentRoom`.
- """
-
- # `(user_id, conn_id)` -> `connection_position` -> `PerConnectionState`
- _connections: Dict[Tuple[str, str], Dict[int, PerConnectionState]] = attr.Factory(
- dict
- )
-
- async def is_valid_token(
- self, sync_config: SlidingSyncConfig, connection_token: int
- ) -> bool:
- """Return whether the connection token is valid/recognized"""
- if connection_token == 0:
- return True
-
- conn_key = self._get_connection_key(sync_config)
- return connection_token in self._connections.get(conn_key, {})
-
- async def get_per_connection_state(
- self,
- sync_config: SlidingSyncConfig,
- from_token: Optional[SlidingSyncStreamToken],
- ) -> PerConnectionState:
- """Fetch the per-connection state for the token.
-
- Raises:
- SlidingSyncUnknownPosition if the connection_token is unknown
- """
- if from_token is None:
- return PerConnectionState()
-
- connection_position = from_token.connection_position
- if connection_position == 0:
- # Initial sync (request without a `from_token`) starts at `0` so
- # there is no existing per-connection state
- return PerConnectionState()
-
- conn_key = self._get_connection_key(sync_config)
- sync_statuses = self._connections.get(conn_key, {})
- connection_state = sync_statuses.get(connection_position)
-
- if connection_state is None:
- raise SlidingSyncUnknownPosition()
-
- return connection_state
-
- @trace
- async def record_new_state(
- self,
- sync_config: SlidingSyncConfig,
- from_token: Optional[SlidingSyncStreamToken],
- new_connection_state: MutablePerConnectionState,
- ) -> int:
- """Record updated per-connection state, returning the connection
- position associated with the new state.
- If there are no changes to the state this may return the same token as
- the existing per-connection state.
- """
- prev_connection_token = 0
- if from_token is not None:
- prev_connection_token = from_token.connection_position
-
- if not new_connection_state.has_updates():
- return prev_connection_token
-
- conn_key = self._get_connection_key(sync_config)
- sync_statuses = self._connections.setdefault(conn_key, {})
-
- # Generate a new token, removing any existing entries in that token
- # (which can happen if requests get resent).
- new_store_token = prev_connection_token + 1
- sync_statuses.pop(new_store_token, None)
-
- # We copy the `MutablePerConnectionState` so that the inner `ChainMap`s
- # don't grow forever.
- sync_statuses[new_store_token] = new_connection_state.copy()
-
- return new_store_token
-
- @trace
- async def mark_token_seen(
- self,
- sync_config: SlidingSyncConfig,
- from_token: Optional[SlidingSyncStreamToken],
- ) -> None:
- """We have received a request with the given token, so we can clear out
- any other tokens associated with the connection.
-
- If there is no from token then we have started afresh, and so we delete
- all tokens associated with the device.
- """
- # Clear out any tokens for the connection that doesn't match the one
- # from the request.
-
- conn_key = self._get_connection_key(sync_config)
- sync_statuses = self._connections.pop(conn_key, {})
- if from_token is None:
- return
-
- sync_statuses = {
- connection_token: room_statuses
- for connection_token, room_statuses in sync_statuses.items()
- if connection_token == from_token.connection_position
- }
- if sync_statuses:
- self._connections[conn_key] = sync_statuses
-
- @staticmethod
- def _get_connection_key(sync_config: SlidingSyncConfig) -> Tuple[str, str]:
- """Return a unique identifier for this connection.
-
- The first part is simply the user ID.
-
- The second part is generally a combination of device ID and conn_id.
- However, both these two are optional (e.g. puppet access tokens don't
- have device IDs), so this handles those edge cases.
-
- We use this over the raw `conn_id` to avoid clashes between different
- clients that use the same `conn_id`. Imagine a user uses a web client
- that uses `conn_id: main_sync_loop` and an Android client that also has
- a `conn_id: main_sync_loop`.
- """
-
- user_id = sync_config.user.to_string()
-
- # Only one sliding sync connection is allowed per given conn_id (empty
- # or not).
- conn_id = sync_config.conn_id or ""
-
- if sync_config.requester.device_id:
- return (user_id, f"D/{sync_config.requester.device_id}/{conn_id}")
-
- if sync_config.requester.access_token_id:
- # If we don't have a device, then the access token ID should be a
- # stable ID.
- return (user_id, f"A/{sync_config.requester.access_token_id}/{conn_id}")
-
- # If we have neither then its likely an AS or some weird token. Either
- # way we can just fail here.
- raise Exception("Cannot use sliding sync with access token type")
diff --git a/synapse/handlers/sliding_sync/extensions.py b/synapse/handlers/sliding_sync/extensions.py
new file mode 100644
index 0000000000..599c74429e
--- /dev/null
+++ b/synapse/handlers/sliding_sync/extensions.py
@@ -0,0 +1,660 @@
+#
+# This file is licensed under the Affero General Public License (AGPL) version 3.
+#
+# Copyright (C) 2024 New Vector, Ltd
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# See the GNU Affero General Public License for more details:
+# <https://www.gnu.org/licenses/agpl-3.0.html>.
+#
+
+import logging
+from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Sequence, Set
+
+from typing_extensions import assert_never
+
+from synapse.api.constants import AccountDataTypes
+from synapse.handlers.receipts import ReceiptEventSource
+from synapse.handlers.sliding_sync.types import (
+ HaveSentRoomFlag,
+ MutablePerConnectionState,
+ PerConnectionState,
+)
+from synapse.logging.opentracing import trace
+from synapse.types import (
+ DeviceListUpdates,
+ JsonMapping,
+ MultiWriterStreamToken,
+ SlidingSyncStreamToken,
+ StreamToken,
+)
+from synapse.types.handlers import OperationType, SlidingSyncConfig, SlidingSyncResult
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class SlidingSyncExtensionHandler:
+ """Handles the extensions to sliding sync."""
+
+ def __init__(self, hs: "HomeServer"):
+ self.store = hs.get_datastores().main
+ self.event_sources = hs.get_event_sources()
+ self.device_handler = hs.get_device_handler()
+ self.push_rules_handler = hs.get_push_rules_handler()
+
+ @trace
+ async def get_extensions_response(
+ self,
+ sync_config: SlidingSyncConfig,
+ previous_connection_state: "PerConnectionState",
+ new_connection_state: "MutablePerConnectionState",
+ actual_lists: Dict[str, SlidingSyncResult.SlidingWindowList],
+ actual_room_ids: Set[str],
+ actual_room_response_map: Dict[str, SlidingSyncResult.RoomResult],
+ to_token: StreamToken,
+ from_token: Optional[SlidingSyncStreamToken],
+ ) -> SlidingSyncResult.Extensions:
+ """Handle extension requests.
+
+ Args:
+ sync_config: Sync configuration
+            previous_connection_state: Snapshot of the per-connection state
+            new_connection_state: A mutable copy of the per-connection
+                state, used to record updates to the state during this request.
+ actual_lists: Sliding window API. A map of list key to list results in the
+ Sliding Sync response.
+            actual_room_ids: The actual room IDs in the Sliding Sync response.
+            actual_room_response_map: A map of room ID to room results in the
+                Sliding Sync response.
+ to_token: The point in the stream to sync up to.
+ from_token: The point in the stream to sync from.
+ """
+
+ if sync_config.extensions is None:
+ return SlidingSyncResult.Extensions()
+
+ to_device_response = None
+ if sync_config.extensions.to_device is not None:
+ to_device_response = await self.get_to_device_extension_response(
+ sync_config=sync_config,
+ to_device_request=sync_config.extensions.to_device,
+ to_token=to_token,
+ )
+
+ e2ee_response = None
+ if sync_config.extensions.e2ee is not None:
+ e2ee_response = await self.get_e2ee_extension_response(
+ sync_config=sync_config,
+ e2ee_request=sync_config.extensions.e2ee,
+ to_token=to_token,
+ from_token=from_token,
+ )
+
+ account_data_response = None
+ if sync_config.extensions.account_data is not None:
+ account_data_response = await self.get_account_data_extension_response(
+ sync_config=sync_config,
+ actual_lists=actual_lists,
+ actual_room_ids=actual_room_ids,
+ account_data_request=sync_config.extensions.account_data,
+ to_token=to_token,
+ from_token=from_token,
+ )
+
+ receipts_response = None
+ if sync_config.extensions.receipts is not None:
+ receipts_response = await self.get_receipts_extension_response(
+ sync_config=sync_config,
+ previous_connection_state=previous_connection_state,
+ new_connection_state=new_connection_state,
+ actual_lists=actual_lists,
+ actual_room_ids=actual_room_ids,
+ actual_room_response_map=actual_room_response_map,
+ receipts_request=sync_config.extensions.receipts,
+ to_token=to_token,
+ from_token=from_token,
+ )
+
+ typing_response = None
+ if sync_config.extensions.typing is not None:
+ typing_response = await self.get_typing_extension_response(
+ sync_config=sync_config,
+ actual_lists=actual_lists,
+ actual_room_ids=actual_room_ids,
+ actual_room_response_map=actual_room_response_map,
+ typing_request=sync_config.extensions.typing,
+ to_token=to_token,
+ from_token=from_token,
+ )
+
+ return SlidingSyncResult.Extensions(
+ to_device=to_device_response,
+ e2ee=e2ee_response,
+ account_data=account_data_response,
+ receipts=receipts_response,
+ typing=typing_response,
+ )
+
+ def find_relevant_room_ids_for_extension(
+ self,
+ requested_lists: Optional[List[str]],
+ requested_room_ids: Optional[List[str]],
+ actual_lists: Dict[str, SlidingSyncResult.SlidingWindowList],
+ actual_room_ids: Set[str],
+ ) -> Set[str]:
+ """
+ Handle the reserved `lists`/`rooms` keys for extensions. Extensions should only
+ return results for rooms in the Sliding Sync response. This matches up the
+ requested rooms/lists with the actual lists/rooms in the Sliding Sync response.
+
+ {"lists": []} // Do not process any lists.
+ {"lists": ["rooms", "dms"]} // Process only a subset of lists.
+ {"lists": ["*"]} // Process all lists defined in the Sliding Window API. (This is the default.)
+
+ {"rooms": []} // Do not process any specific rooms.
+ {"rooms": ["!a:b", "!c:d"]} // Process only a subset of room subscriptions.
+ {"rooms": ["*"]} // Process all room subscriptions defined in the Room Subscription API. (This is the default.)
+
+ Args:
+ requested_lists: The `lists` from the extension request.
+ requested_room_ids: The `rooms` from the extension request.
+ actual_lists: The actual lists from the Sliding Sync response.
+ actual_room_ids: The actual room subscriptions from the Sliding Sync request.
+ """
+
+ # We only want to include account data for rooms that are already in the sliding
+ # sync response AND that were requested in the account data request.
+ relevant_room_ids: Set[str] = set()
+
+ # See what rooms from the room subscriptions we should get account data for
+ if requested_room_ids is not None:
+ for room_id in requested_room_ids:
+ # A wildcard means we process all rooms from the room subscriptions
+ if room_id == "*":
+ relevant_room_ids.update(actual_room_ids)
+ break
+
+ if room_id in actual_room_ids:
+ relevant_room_ids.add(room_id)
+
+ # See what rooms from the sliding window lists we should get account data for
+ if requested_lists is not None:
+ for list_key in requested_lists:
+ # Just some typing because we share the variable name in multiple places
+ actual_list: Optional[SlidingSyncResult.SlidingWindowList] = None
+
+ # A wildcard means we process rooms from all lists
+ if list_key == "*":
+ for actual_list in actual_lists.values():
+ # We only expect a single SYNC operation for any list
+ assert len(actual_list.ops) == 1
+ sync_op = actual_list.ops[0]
+ assert sync_op.op == OperationType.SYNC
+
+ relevant_room_ids.update(sync_op.room_ids)
+
+ break
+
+ actual_list = actual_lists.get(list_key)
+ if actual_list is not None:
+ # We only expect a single SYNC operation for any list
+ assert len(actual_list.ops) == 1
+ sync_op = actual_list.ops[0]
+ assert sync_op.op == OperationType.SYNC
+
+ relevant_room_ids.update(sync_op.room_ids)
+
+ return relevant_room_ids
+
+ @trace
+ async def get_to_device_extension_response(
+ self,
+ sync_config: SlidingSyncConfig,
+ to_device_request: SlidingSyncConfig.Extensions.ToDeviceExtension,
+ to_token: StreamToken,
+ ) -> Optional[SlidingSyncResult.Extensions.ToDeviceExtension]:
+ """Handle to-device extension (MSC3885)
+
+ Args:
+ sync_config: Sync configuration
+ to_device_request: The to-device extension from the request
+ to_token: The point in the stream to sync up to.
+ """
+ user_id = sync_config.user.to_string()
+ device_id = sync_config.requester.device_id
+
+ # Skip if the extension is not enabled
+ if not to_device_request.enabled:
+ return None
+
+ # Check that this request has a valid device ID (not all requests have
+ # to belong to a device, and so device_id is None)
+ if device_id is None:
+ return SlidingSyncResult.Extensions.ToDeviceExtension(
+ next_batch=f"{to_token.to_device_key}",
+ events=[],
+ )
+
+ since_stream_id = 0
+ if to_device_request.since is not None:
+ # We've already validated this is an int.
+ since_stream_id = int(to_device_request.since)
+
+ if to_token.to_device_key < since_stream_id:
+ # The since token is ahead of our current token, so we return an
+ # empty response.
+ logger.warning(
+ "Got to-device.since from the future. since token: %r is ahead of our current to_device stream position: %r",
+ since_stream_id,
+ to_token.to_device_key,
+ )
+ return SlidingSyncResult.Extensions.ToDeviceExtension(
+ next_batch=to_device_request.since,
+ events=[],
+ )
+
+ # Delete everything before the given since token, as we know the
+ # device must have received them.
+ deleted = await self.store.delete_messages_for_device(
+ user_id=user_id,
+ device_id=device_id,
+ up_to_stream_id=since_stream_id,
+ )
+
+ logger.debug(
+ "Deleted %d to-device messages up to %d for %s",
+ deleted,
+ since_stream_id,
+ user_id,
+ )
+
+ messages, stream_id = await self.store.get_messages_for_device(
+ user_id=user_id,
+ device_id=device_id,
+ from_stream_id=since_stream_id,
+ to_stream_id=to_token.to_device_key,
+ limit=min(to_device_request.limit, 100), # Limit to at most 100 events
+ )
+
+ return SlidingSyncResult.Extensions.ToDeviceExtension(
+ next_batch=f"{stream_id}",
+ events=messages,
+ )
+
+ @trace
+ async def get_e2ee_extension_response(
+ self,
+ sync_config: SlidingSyncConfig,
+ e2ee_request: SlidingSyncConfig.Extensions.E2eeExtension,
+ to_token: StreamToken,
+ from_token: Optional[SlidingSyncStreamToken],
+ ) -> Optional[SlidingSyncResult.Extensions.E2eeExtension]:
+ """Handle E2EE device extension (MSC3884)
+
+ Args:
+ sync_config: Sync configuration
+ e2ee_request: The e2ee extension from the request
+ to_token: The point in the stream to sync up to.
+ from_token: The point in the stream to sync from.
+ """
+ user_id = sync_config.user.to_string()
+ device_id = sync_config.requester.device_id
+
+ # Skip if the extension is not enabled
+ if not e2ee_request.enabled:
+ return None
+
+ device_list_updates: Optional[DeviceListUpdates] = None
+ if from_token is not None:
+ # TODO: This should take into account the `from_token` and `to_token`
+ device_list_updates = await self.device_handler.get_user_ids_changed(
+ user_id=user_id,
+ from_token=from_token.stream_token,
+ )
+
+ device_one_time_keys_count: Mapping[str, int] = {}
+ device_unused_fallback_key_types: Sequence[str] = []
+ if device_id:
+ # TODO: We should have a way to let clients differentiate between the states of:
+ # * no change in OTK count since the provided since token
+ # * the server has zero OTKs left for this device
+ # Spec issue: https://github.com/matrix-org/matrix-doc/issues/3298
+ device_one_time_keys_count = await self.store.count_e2e_one_time_keys(
+ user_id, device_id
+ )
+ device_unused_fallback_key_types = (
+ await self.store.get_e2e_unused_fallback_key_types(user_id, device_id)
+ )
+
+ return SlidingSyncResult.Extensions.E2eeExtension(
+ device_list_updates=device_list_updates,
+ device_one_time_keys_count=device_one_time_keys_count,
+ device_unused_fallback_key_types=device_unused_fallback_key_types,
+ )
+
+ @trace
+ async def get_account_data_extension_response(
+ self,
+ sync_config: SlidingSyncConfig,
+ actual_lists: Dict[str, SlidingSyncResult.SlidingWindowList],
+ actual_room_ids: Set[str],
+ account_data_request: SlidingSyncConfig.Extensions.AccountDataExtension,
+ to_token: StreamToken,
+ from_token: Optional[SlidingSyncStreamToken],
+ ) -> Optional[SlidingSyncResult.Extensions.AccountDataExtension]:
+ """Handle Account Data extension (MSC3959)
+
+ Args:
+ sync_config: Sync configuration
+ actual_lists: Sliding window API. A map of list key to list results in the
+ Sliding Sync response.
+            actual_room_ids: The actual room IDs in the Sliding Sync response.
+ account_data_request: The account_data extension from the request
+ to_token: The point in the stream to sync up to.
+ from_token: The point in the stream to sync from.
+ """
+ user_id = sync_config.user.to_string()
+
+ # Skip if the extension is not enabled
+ if not account_data_request.enabled:
+ return None
+
+ global_account_data_map: Mapping[str, JsonMapping] = {}
+ if from_token is not None:
+ # TODO: This should take into account the `from_token` and `to_token`
+ global_account_data_map = (
+ await self.store.get_updated_global_account_data_for_user(
+ user_id, from_token.stream_token.account_data_key
+ )
+ )
+
+ have_push_rules_changed = await self.store.have_push_rules_changed_for_user(
+ user_id, from_token.stream_token.push_rules_key
+ )
+ if have_push_rules_changed:
+ global_account_data_map = dict(global_account_data_map)
+ # TODO: This should take into account the `from_token` and `to_token`
+ global_account_data_map[AccountDataTypes.PUSH_RULES] = (
+ await self.push_rules_handler.push_rules_for_user(sync_config.user)
+ )
+ else:
+ # TODO: This should take into account the `to_token`
+ all_global_account_data = await self.store.get_global_account_data_for_user(
+ user_id
+ )
+
+ global_account_data_map = dict(all_global_account_data)
+ # TODO: This should take into account the `to_token`
+ global_account_data_map[AccountDataTypes.PUSH_RULES] = (
+ await self.push_rules_handler.push_rules_for_user(sync_config.user)
+ )
+
+ # Fetch room account data
+ account_data_by_room_map: Mapping[str, Mapping[str, JsonMapping]] = {}
+ relevant_room_ids = self.find_relevant_room_ids_for_extension(
+ requested_lists=account_data_request.lists,
+ requested_room_ids=account_data_request.rooms,
+ actual_lists=actual_lists,
+ actual_room_ids=actual_room_ids,
+ )
+ if len(relevant_room_ids) > 0:
+ if from_token is not None:
+ # TODO: This should take into account the `from_token` and `to_token`
+ account_data_by_room_map = (
+ await self.store.get_updated_room_account_data_for_user(
+ user_id, from_token.stream_token.account_data_key
+ )
+ )
+ else:
+ # TODO: This should take into account the `to_token`
+ account_data_by_room_map = (
+ await self.store.get_room_account_data_for_user(user_id)
+ )
+
+ # Filter down to the relevant rooms
+ account_data_by_room_map = {
+ room_id: account_data_map
+ for room_id, account_data_map in account_data_by_room_map.items()
+ if room_id in relevant_room_ids
+ }
+
+ return SlidingSyncResult.Extensions.AccountDataExtension(
+ global_account_data_map=global_account_data_map,
+ account_data_by_room_map=account_data_by_room_map,
+ )
+
+ @trace
+ async def get_receipts_extension_response(
+ self,
+ sync_config: SlidingSyncConfig,
+ previous_connection_state: "PerConnectionState",
+ new_connection_state: "MutablePerConnectionState",
+ actual_lists: Dict[str, SlidingSyncResult.SlidingWindowList],
+ actual_room_ids: Set[str],
+ actual_room_response_map: Dict[str, SlidingSyncResult.RoomResult],
+ receipts_request: SlidingSyncConfig.Extensions.ReceiptsExtension,
+ to_token: StreamToken,
+ from_token: Optional[SlidingSyncStreamToken],
+ ) -> Optional[SlidingSyncResult.Extensions.ReceiptsExtension]:
+ """Handle Receipts extension (MSC3960)
+
+ Args:
+ sync_config: Sync configuration
+            previous_connection_state: The previous per-connection state
+ new_connection_state: A mutable copy of the per-connection
+ state, used to record updates to the state.
+ actual_lists: Sliding window API. A map of list key to list results in the
+ Sliding Sync response.
+            actual_room_ids: The actual room IDs in the Sliding Sync response.
+            actual_room_response_map: A map of room ID to room results in the
+                Sliding Sync response.
+            receipts_request: The receipts extension from the request
+ to_token: The point in the stream to sync up to.
+ from_token: The point in the stream to sync from.
+ """
+ # Skip if the extension is not enabled
+ if not receipts_request.enabled:
+ return None
+
+ relevant_room_ids = self.find_relevant_room_ids_for_extension(
+ requested_lists=receipts_request.lists,
+ requested_room_ids=receipts_request.rooms,
+ actual_lists=actual_lists,
+ actual_room_ids=actual_room_ids,
+ )
+
+ room_id_to_receipt_map: Dict[str, JsonMapping] = {}
+ if len(relevant_room_ids) > 0:
+ # We need to handle the different cases depending on if we have sent
+ # down receipts previously or not, so we split the relevant rooms
+ # up into different collections based on status.
+ live_rooms = set()
+ previously_rooms: Dict[str, MultiWriterStreamToken] = {}
+ initial_rooms = set()
+
+ for room_id in relevant_room_ids:
+ if not from_token:
+ initial_rooms.add(room_id)
+ continue
+
+ # If we're sending down the room from scratch again for some reason, we
+ # should always resend the receipts as well (regardless of if
+ # we've sent them down before). This is to mimic the behaviour
+ # of what happens on initial sync, where you get a chunk of
+ # timeline with all of the corresponding receipts for the events in the timeline.
+ room_result = actual_room_response_map.get(room_id)
+ if room_result is not None and room_result.initial:
+ initial_rooms.add(room_id)
+ continue
+
+ room_status = previous_connection_state.receipts.have_sent_room(room_id)
+ if room_status.status == HaveSentRoomFlag.LIVE:
+ live_rooms.add(room_id)
+ elif room_status.status == HaveSentRoomFlag.PREVIOUSLY:
+ assert room_status.last_token is not None
+ previously_rooms[room_id] = room_status.last_token
+ elif room_status.status == HaveSentRoomFlag.NEVER:
+ initial_rooms.add(room_id)
+ else:
+ assert_never(room_status.status)
+
+ # The set of receipts that we fetched. Private receipts need to be
+ # filtered out before returning.
+ fetched_receipts = []
+
+ # For live rooms we just fetch all receipts in those rooms since the
+ # `since` token.
+ if live_rooms:
+ assert from_token is not None
+ receipts = await self.store.get_linearized_receipts_for_rooms(
+ room_ids=live_rooms,
+ from_key=from_token.stream_token.receipt_key,
+ to_key=to_token.receipt_key,
+ )
+ fetched_receipts.extend(receipts)
+
+ # For rooms we've previously sent down, but aren't up to date, we
+ # need to use the from token from the room status.
+ if previously_rooms:
+ for room_id, receipt_token in previously_rooms.items():
+                # TODO: Limit the number of receipts we're about to send down
+                # for the room, if it's too many we should TODO
+ previously_receipts = (
+ await self.store.get_linearized_receipts_for_room(
+ room_id=room_id,
+ from_key=receipt_token,
+ to_key=to_token.receipt_key,
+ )
+ )
+ fetched_receipts.extend(previously_receipts)
+
+ # For rooms we haven't previously sent down, we could send all receipts
+ # from that room but we only want to include receipts for events
+ # in the timeline to avoid bloating and blowing up the sync response
+ # as the number of users in the room increases. (this behavior is part of the spec)
+ initial_rooms_and_event_ids = [
+ (room_id, event.event_id)
+ for room_id in initial_rooms
+ if room_id in actual_room_response_map
+ for event in actual_room_response_map[room_id].timeline_events
+ ]
+ if initial_rooms_and_event_ids:
+ initial_receipts = await self.store.get_linearized_receipts_for_events(
+ room_and_event_ids=initial_rooms_and_event_ids,
+ )
+ fetched_receipts.extend(initial_receipts)
+
+ fetched_receipts = ReceiptEventSource.filter_out_private_receipts(
+ fetched_receipts, sync_config.user.to_string()
+ )
+
+ for receipt in fetched_receipts:
+ # These fields should exist for every receipt
+ room_id = receipt["room_id"]
+ type = receipt["type"]
+ content = receipt["content"]
+
+ room_id_to_receipt_map[room_id] = {"type": type, "content": content}
+
+ # Now we update the per-connection state to track which receipts we have
+ # and haven't sent down.
+ new_connection_state.receipts.record_sent_rooms(relevant_room_ids)
+
+ if from_token:
+ # Now find the set of rooms that may have receipts that we're not sending
+ # down. We only need to check rooms that we have previously returned
+ # receipts for (in `previous_connection_state`) because we only care about
+ # updating `LIVE` rooms to `PREVIOUSLY`. The `PREVIOUSLY` rooms will just
+ # stay pointing at their previous position so we don't need to waste time
+ # checking those and since we default to `NEVER`, rooms that were `NEVER`
+ # sent before don't need to be recorded as we'll handle them correctly when
+ # they come into range for the first time.
+ rooms_no_receipts = [
+ room_id
+ for room_id, room_status in previous_connection_state.receipts._statuses.items()
+ if room_status.status == HaveSentRoomFlag.LIVE
+ and room_id not in relevant_room_ids
+ ]
+ changed_rooms = await self.store.get_rooms_with_receipts_between(
+ rooms_no_receipts,
+ from_key=from_token.stream_token.receipt_key,
+ to_key=to_token.receipt_key,
+ )
+ new_connection_state.receipts.record_unsent_rooms(
+ changed_rooms, from_token.stream_token.receipt_key
+ )
+
+ return SlidingSyncResult.Extensions.ReceiptsExtension(
+ room_id_to_receipt_map=room_id_to_receipt_map,
+ )
+
+ async def get_typing_extension_response(
+ self,
+ sync_config: SlidingSyncConfig,
+ actual_lists: Dict[str, SlidingSyncResult.SlidingWindowList],
+ actual_room_ids: Set[str],
+ actual_room_response_map: Dict[str, SlidingSyncResult.RoomResult],
+ typing_request: SlidingSyncConfig.Extensions.TypingExtension,
+ to_token: StreamToken,
+ from_token: Optional[SlidingSyncStreamToken],
+ ) -> Optional[SlidingSyncResult.Extensions.TypingExtension]:
+ """Handle Typing Notification extension (MSC3961)
+
+ Args:
+ sync_config: Sync configuration
+ actual_lists: Sliding window API. A map of list key to list results in the
+ Sliding Sync response.
+        actual_room_ids: The actual room IDs in the Sliding Sync response.
+        actual_room_response_map: A map of room ID to room results in the
+            Sliding Sync response.
+        typing_request: The typing extension from the request
+ to_token: The point in the stream to sync up to.
+ from_token: The point in the stream to sync from.
+ """
+ # Skip if the extension is not enabled
+ if not typing_request.enabled:
+ return None
+
+ relevant_room_ids = self.find_relevant_room_ids_for_extension(
+ requested_lists=typing_request.lists,
+ requested_room_ids=typing_request.rooms,
+ actual_lists=actual_lists,
+ actual_room_ids=actual_room_ids,
+ )
+
+ room_id_to_typing_map: Dict[str, JsonMapping] = {}
+ if len(relevant_room_ids) > 0:
+ # Note: We don't need to take connection tracking into account for typing
+ # notifications because they'll get anything still relevant and hasn't timed
+ # out when the room comes into range. We consider the gap where the room
+            # fell out of range as long enough for any typing notifications to have
+ # timed out (it's not worth the 30 seconds of data we may have missed).
+ typing_source = self.event_sources.sources.typing
+ typing_notifications, _ = await typing_source.get_new_events(
+ user=sync_config.user,
+ from_key=(from_token.stream_token.typing_key if from_token else 0),
+ to_key=to_token.typing_key,
+ # This is a dummy value and isn't used in the function
+ limit=0,
+ room_ids=relevant_room_ids,
+ is_guest=False,
+ )
+
+ for typing_notification in typing_notifications:
+ # These fields should exist for every typing notification
+ room_id = typing_notification["room_id"]
+ type = typing_notification["type"]
+ content = typing_notification["content"]
+
+ room_id_to_typing_map[room_id] = {"type": type, "content": content}
+
+ return SlidingSyncResult.Extensions.TypingExtension(
+ room_id_to_typing_map=room_id_to_typing_map,
+ )
diff --git a/synapse/handlers/sliding_sync/store.py b/synapse/handlers/sliding_sync/store.py
new file mode 100644
index 0000000000..3b727432fb
--- /dev/null
+++ b/synapse/handlers/sliding_sync/store.py
@@ -0,0 +1,200 @@
+#
+# This file is licensed under the Affero General Public License (AGPL) version 3.
+#
+# Copyright (C) 2023 New Vector, Ltd
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# See the GNU Affero General Public License for more details:
+# <https://www.gnu.org/licenses/agpl-3.0.html>.
+#
+
+import logging
+from typing import TYPE_CHECKING, Dict, Optional, Tuple
+
+import attr
+
+from synapse.api.errors import SlidingSyncUnknownPosition
+from synapse.handlers.sliding_sync.types import (
+ MutablePerConnectionState,
+ PerConnectionState,
+)
+from synapse.logging.opentracing import trace
+from synapse.types import SlidingSyncStreamToken
+from synapse.types.handlers import SlidingSyncConfig
+
+if TYPE_CHECKING:
+ pass
+
+logger = logging.getLogger(__name__)
+
+
+@attr.s(auto_attribs=True)
+class SlidingSyncConnectionStore:
+ """In-memory store of per-connection state, including what rooms we have
+ previously sent down a sliding sync connection.
+
+ Note: This is NOT safe to run in a worker setup because connection positions will
+ point to different sets of rooms on different workers. e.g. for the same connection,
+ a connection position of 5 might have totally different states on worker A and
+ worker B.
+
+ One complication that we need to deal with here is needing to handle requests being
+ resent, i.e. if we sent down a room in a response that the client received, we must
+ consider the room *not* sent when we get the request again.
+
+ This is handled by using an integer "token", which is returned to the client
+ as part of the sync token. For each connection we store a mapping from
+ tokens to the room states, and create a new entry when we send down new
+ rooms.
+
+ Note that for any given sliding sync connection we will only store a maximum
+ of two different tokens: the previous token from the request and a new token
+ sent in the response. When we receive a request with a given token, we then
+ clear out all other entries with a different token.
+
+ Attributes:
+        _connections: Mapping from `(user_id, conn_id)` to mapping of
+            `connection_position` to `PerConnectionState`.
+ """
+
+ # `(user_id, conn_id)` -> `connection_position` -> `PerConnectionState`
+ _connections: Dict[Tuple[str, str], Dict[int, PerConnectionState]] = attr.Factory(
+ dict
+ )
+
+ async def is_valid_token(
+ self, sync_config: SlidingSyncConfig, connection_token: int
+ ) -> bool:
+ """Return whether the connection token is valid/recognized"""
+ if connection_token == 0:
+ return True
+
+ conn_key = self._get_connection_key(sync_config)
+ return connection_token in self._connections.get(conn_key, {})
+
+ async def get_per_connection_state(
+ self,
+ sync_config: SlidingSyncConfig,
+ from_token: Optional[SlidingSyncStreamToken],
+ ) -> PerConnectionState:
+ """Fetch the per-connection state for the token.
+
+ Raises:
+ SlidingSyncUnknownPosition if the connection_token is unknown
+ """
+ if from_token is None:
+ return PerConnectionState()
+
+ connection_position = from_token.connection_position
+ if connection_position == 0:
+ # Initial sync (request without a `from_token`) starts at `0` so
+ # there is no existing per-connection state
+ return PerConnectionState()
+
+ conn_key = self._get_connection_key(sync_config)
+ sync_statuses = self._connections.get(conn_key, {})
+ connection_state = sync_statuses.get(connection_position)
+
+ if connection_state is None:
+ raise SlidingSyncUnknownPosition()
+
+ return connection_state
+
+ @trace
+ async def record_new_state(
+ self,
+ sync_config: SlidingSyncConfig,
+ from_token: Optional[SlidingSyncStreamToken],
+ new_connection_state: MutablePerConnectionState,
+ ) -> int:
+ """Record updated per-connection state, returning the connection
+ position associated with the new state.
+ If there are no changes to the state this may return the same token as
+ the existing per-connection state.
+ """
+ prev_connection_token = 0
+ if from_token is not None:
+ prev_connection_token = from_token.connection_position
+
+ if not new_connection_state.has_updates():
+ return prev_connection_token
+
+ conn_key = self._get_connection_key(sync_config)
+ sync_statuses = self._connections.setdefault(conn_key, {})
+
+ # Generate a new token, removing any existing entries in that token
+ # (which can happen if requests get resent).
+ new_store_token = prev_connection_token + 1
+ sync_statuses.pop(new_store_token, None)
+
+ # We copy the `MutablePerConnectionState` so that the inner `ChainMap`s
+ # don't grow forever.
+ sync_statuses[new_store_token] = new_connection_state.copy()
+
+ return new_store_token
+
+ @trace
+ async def mark_token_seen(
+ self,
+ sync_config: SlidingSyncConfig,
+ from_token: Optional[SlidingSyncStreamToken],
+ ) -> None:
+ """We have received a request with the given token, so we can clear out
+ any other tokens associated with the connection.
+
+ If there is no from token then we have started afresh, and so we delete
+ all tokens associated with the device.
+ """
+        # Clear out any tokens for the connection that don't match the one
+        # from the request.
+
+ conn_key = self._get_connection_key(sync_config)
+ sync_statuses = self._connections.pop(conn_key, {})
+ if from_token is None:
+ return
+
+ sync_statuses = {
+ connection_token: room_statuses
+ for connection_token, room_statuses in sync_statuses.items()
+ if connection_token == from_token.connection_position
+ }
+ if sync_statuses:
+ self._connections[conn_key] = sync_statuses
+
+ @staticmethod
+ def _get_connection_key(sync_config: SlidingSyncConfig) -> Tuple[str, str]:
+ """Return a unique identifier for this connection.
+
+ The first part is simply the user ID.
+
+ The second part is generally a combination of device ID and conn_id.
+ However, both these two are optional (e.g. puppet access tokens don't
+ have device IDs), so this handles those edge cases.
+
+ We use this over the raw `conn_id` to avoid clashes between different
+ clients that use the same `conn_id`. Imagine a user uses a web client
+ that uses `conn_id: main_sync_loop` and an Android client that also has
+ a `conn_id: main_sync_loop`.
+ """
+
+ user_id = sync_config.user.to_string()
+
+ # Only one sliding sync connection is allowed per given conn_id (empty
+ # or not).
+ conn_id = sync_config.conn_id or ""
+
+ if sync_config.requester.device_id:
+ return (user_id, f"D/{sync_config.requester.device_id}/{conn_id}")
+
+ if sync_config.requester.access_token_id:
+ # If we don't have a device, then the access token ID should be a
+ # stable ID.
+ return (user_id, f"A/{sync_config.requester.access_token_id}/{conn_id}")
+
+ # If we have neither then its likely an AS or some weird token. Either
+ # way we can just fail here.
+ raise Exception("Cannot use sliding sync with access token type")
diff --git a/synapse/handlers/sliding_sync/types.py b/synapse/handlers/sliding_sync/types.py
new file mode 100644
index 0000000000..003419d40a
--- /dev/null
+++ b/synapse/handlers/sliding_sync/types.py
@@ -0,0 +1,506 @@
+#
+# This file is licensed under the Affero General Public License (AGPL) version 3.
+#
+# Copyright (C) 2024 New Vector, Ltd
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# See the GNU Affero General Public License for more details:
+# <https://www.gnu.org/licenses/agpl-3.0.html>.
+#
+
+import logging
+import typing
+from collections import ChainMap
+from enum import Enum
+from typing import (
+ TYPE_CHECKING,
+ Callable,
+ Dict,
+ Final,
+ Generic,
+ Mapping,
+ MutableMapping,
+ Optional,
+ Set,
+ TypeVar,
+ cast,
+)
+
+import attr
+
+from synapse.api.constants import EventTypes
+from synapse.types import MultiWriterStreamToken, RoomStreamToken, StrCollection, UserID
+from synapse.types.handlers import SlidingSyncConfig
+
+if TYPE_CHECKING:
+ pass
+
+logger = logging.getLogger(__name__)
+
+
+class StateValues:
+ """
+ Understood values of the (type, state_key) tuple in `required_state`.
+ """
+
+ # Include all state events of the given type
+ WILDCARD: Final = "*"
+ # Lazy-load room membership events (include room membership events for any event
+ # `sender` in the timeline). We only give special meaning to this value when it's a
+ # `state_key`.
+ LAZY: Final = "$LAZY"
+    # Substitute with the requester's user ID. Typically used by clients to get
+    # the user's membership.
+ ME: Final = "$ME"
+
+
+# We can't freeze this class because we want to update it in place with the
+# de-duplicated data.
+@attr.s(slots=True, auto_attribs=True)
+class RoomSyncConfig:
+ """
+ Holds the config for what data we should fetch for a room in the sync response.
+
+ Attributes:
+ timeline_limit: The maximum number of events to return in the timeline.
+
+ required_state_map: Map from state event type to state_keys requested for the
+ room. The values are close to `StateKey` but actually use a syntax where you
+ can provide `*` wildcard and `$LAZY` for lazy-loading room members.
+ """
+
+ timeline_limit: int
+ required_state_map: Dict[str, Set[str]]
+
+ @classmethod
+ def from_room_config(
+ cls,
+ room_params: SlidingSyncConfig.CommonRoomParameters,
+ ) -> "RoomSyncConfig":
+ """
+ Create a `RoomSyncConfig` from a `SlidingSyncList`/`RoomSubscription` config.
+
+ Args:
+ room_params: `SlidingSyncConfig.SlidingSyncList` or `SlidingSyncConfig.RoomSubscription`
+ """
+ required_state_map: Dict[str, Set[str]] = {}
+ for (
+ state_type,
+ state_key,
+ ) in room_params.required_state:
+ # If we already have a wildcard for this specific `state_key`, we don't need
+ # to add it since the wildcard already covers it.
+ if state_key in required_state_map.get(StateValues.WILDCARD, set()):
+ continue
+
+ # If we already have a wildcard `state_key` for this `state_type`, we don't need
+ # to add anything else
+ if StateValues.WILDCARD in required_state_map.get(state_type, set()):
+ continue
+
+ # If we're getting wildcards for the `state_type` and `state_key`, that's
+ # all that matters so get rid of any other entries
+ if state_type == StateValues.WILDCARD and state_key == StateValues.WILDCARD:
+ required_state_map = {StateValues.WILDCARD: {StateValues.WILDCARD}}
+ # We can break, since we don't need to add anything else
+ break
+
+ # If we're getting a wildcard for the `state_type`, get rid of any other
+ # entries with the same `state_key`, since the wildcard will cover it already.
+ elif state_type == StateValues.WILDCARD:
+ # Get rid of any entries that match the `state_key`
+ #
+ # Make a copy so we don't run into an error: `dictionary changed size
+ # during iteration`, when we remove items
+ for (
+ existing_state_type,
+ existing_state_key_set,
+ ) in list(required_state_map.items()):
+ # Make a copy so we don't run into an error: `Set changed size during
+ # iteration`, when we filter out and remove items
+ for existing_state_key in existing_state_key_set.copy():
+ if existing_state_key == state_key:
+ existing_state_key_set.remove(state_key)
+
+                    # If we've left the `set()` empty, remove it from the map
+ if existing_state_key_set == set():
+ required_state_map.pop(existing_state_type, None)
+
+ # If we're getting a wildcard `state_key`, get rid of any other state_keys
+ # for this `state_type` since the wildcard will cover it already.
+ if state_key == StateValues.WILDCARD:
+ required_state_map[state_type] = {state_key}
+ # Otherwise, just add it to the set
+ else:
+ if required_state_map.get(state_type) is None:
+ required_state_map[state_type] = {state_key}
+ else:
+ required_state_map[state_type].add(state_key)
+
+ return cls(
+ timeline_limit=room_params.timeline_limit,
+ required_state_map=required_state_map,
+ )
+
+ def deep_copy(self) -> "RoomSyncConfig":
+ required_state_map: Dict[str, Set[str]] = {
+ state_type: state_key_set.copy()
+ for state_type, state_key_set in self.required_state_map.items()
+ }
+
+ return RoomSyncConfig(
+ timeline_limit=self.timeline_limit,
+ required_state_map=required_state_map,
+ )
+
+ def combine_room_sync_config(
+ self, other_room_sync_config: "RoomSyncConfig"
+ ) -> None:
+ """
+ Combine this `RoomSyncConfig` with another `RoomSyncConfig` and take the
+ superset union of the two.
+ """
+ # Take the highest timeline limit
+ if self.timeline_limit < other_room_sync_config.timeline_limit:
+ self.timeline_limit = other_room_sync_config.timeline_limit
+
+ # Union the required state
+ for (
+ state_type,
+ state_key_set,
+ ) in other_room_sync_config.required_state_map.items():
+ # If we already have a wildcard for everything, we don't need to add
+ # anything else
+ if StateValues.WILDCARD in self.required_state_map.get(
+ StateValues.WILDCARD, set()
+ ):
+ break
+
+ # If we already have a wildcard `state_key` for this `state_type`, we don't need
+ # to add anything else
+ if StateValues.WILDCARD in self.required_state_map.get(state_type, set()):
+ continue
+
+ # If we're getting wildcards for the `state_type` and `state_key`, that's
+ # all that matters so get rid of any other entries
+ if (
+ state_type == StateValues.WILDCARD
+ and StateValues.WILDCARD in state_key_set
+ ):
+ self.required_state_map = {state_type: {StateValues.WILDCARD}}
+ # We can break, since we don't need to add anything else
+ break
+
+ for state_key in state_key_set:
+ # If we already have a wildcard for this specific `state_key`, we don't need
+ # to add it since the wildcard already covers it.
+ if state_key in self.required_state_map.get(
+ StateValues.WILDCARD, set()
+ ):
+ continue
+
+ # If we're getting a wildcard for the `state_type`, get rid of any other
+ # entries with the same `state_key`, since the wildcard will cover it already.
+ if state_type == StateValues.WILDCARD:
+ # Get rid of any entries that match the `state_key`
+ #
+ # Make a copy so we don't run into an error: `dictionary changed size
+ # during iteration`, when we remove items
+ for existing_state_type, existing_state_key_set in list(
+ self.required_state_map.items()
+ ):
+ # Make a copy so we don't run into an error: `Set changed size during
+ # iteration`, when we filter out and remove items
+ for existing_state_key in existing_state_key_set.copy():
+ if existing_state_key == state_key:
+ existing_state_key_set.remove(state_key)
+
+                        # If we've left the `set()` empty, remove it from the map
+ if existing_state_key_set == set():
+ self.required_state_map.pop(existing_state_type, None)
+
+ # If we're getting a wildcard `state_key`, get rid of any other state_keys
+ # for this `state_type` since the wildcard will cover it already.
+ if state_key == StateValues.WILDCARD:
+ self.required_state_map[state_type] = {state_key}
+ break
+ # Otherwise, just add it to the set
+ else:
+ if self.required_state_map.get(state_type) is None:
+ self.required_state_map[state_type] = {state_key}
+ else:
+ self.required_state_map[state_type].add(state_key)
+
+ def must_await_full_state(
+ self,
+ is_mine_id: Callable[[str], bool],
+ ) -> bool:
+ """
+        Check whether we're only requesting `required_state` which is completely
+        satisfied even with partial state; if so, we don't need to
+        `await_full_state` before we can return it.
+
+ Also see `StateFilter.must_await_full_state(...)` for comparison
+
+ Partially-stated rooms should have all state events except for remote membership
+ events so if we require a remote membership event anywhere, then we need to
+ return `True` (requires full state).
+
+ Args:
+ is_mine_id: a callable which confirms if a given state_key matches a mxid
+ of a local user
+ """
+ wildcard_state_keys = self.required_state_map.get(StateValues.WILDCARD)
+ # Requesting *all* state in the room so we have to wait
+ if (
+ wildcard_state_keys is not None
+ and StateValues.WILDCARD in wildcard_state_keys
+ ):
+ return True
+
+ # If the wildcards don't refer to remote user IDs, then we don't need to wait
+ # for full state.
+ if wildcard_state_keys is not None:
+ for possible_user_id in wildcard_state_keys:
+ if not possible_user_id[0].startswith(UserID.SIGIL):
+ # Not a user ID
+ continue
+
+ localpart_hostname = possible_user_id.split(":", 1)
+ if len(localpart_hostname) < 2:
+ # Not a user ID
+ continue
+
+ if not is_mine_id(possible_user_id):
+ return True
+
+ membership_state_keys = self.required_state_map.get(EventTypes.Member)
+ # We aren't requesting any membership events at all so the partial state will
+ # cover us.
+ if membership_state_keys is None:
+ return False
+
+ # If we're requesting entirely local users, the partial state will cover us.
+ for user_id in membership_state_keys:
+ if user_id == StateValues.ME:
+ continue
+ # We're lazy-loading membership so we can just return the state we have.
+ # Lazy-loading means we include membership for any event `sender` in the
+ # timeline but since we had to auth those timeline events, we will have the
+ # membership state for them (including from remote senders).
+ elif user_id == StateValues.LAZY:
+ continue
+ elif user_id == StateValues.WILDCARD:
+ return False
+ elif not is_mine_id(user_id):
+ return True
+
+ # Local users only so the partial state will cover us.
+ return False
+
+
+class HaveSentRoomFlag(Enum):
+ """Flag for whether we have sent the room down a sliding sync connection.
+
+ The valid state changes here are:
+ NEVER -> LIVE
+ LIVE -> PREVIOUSLY
+ PREVIOUSLY -> LIVE
+ """
+
+ # The room has never been sent down (or we have forgotten we have sent it
+ # down).
+ NEVER = "never"
+
+ # We have previously sent the room down, but there are updates that we
+ # haven't sent down.
+ PREVIOUSLY = "previously"
+
+ # We have sent the room down and the client has received all updates.
+ LIVE = "live"
+
+
+T = TypeVar("T")
+
+
+@attr.s(auto_attribs=True, slots=True, frozen=True)
+class HaveSentRoom(Generic[T]):
+ """Whether we have sent the room data down a sliding sync connection.
+
+ We are generic over the type of token used, e.g. `RoomStreamToken` or
+ `MultiWriterStreamToken`.
+
+ Attributes:
+ status: Flag of if we have or haven't sent down the room
+ last_token: If the flag is `PREVIOUSLY` then this is non-null and
+ contains the last stream token of the last updates we sent down
+ the room, i.e. we still need to send everything since then to the
+ client.
+ """
+
+ status: HaveSentRoomFlag
+ last_token: Optional[T]
+
+ @staticmethod
+ def live() -> "HaveSentRoom[T]":
+ return HaveSentRoom(HaveSentRoomFlag.LIVE, None)
+
+ @staticmethod
+ def previously(last_token: T) -> "HaveSentRoom[T]":
+ """Constructor for `PREVIOUSLY` flag."""
+ return HaveSentRoom(HaveSentRoomFlag.PREVIOUSLY, last_token)
+
+ @staticmethod
+ def never() -> "HaveSentRoom[T]":
+ return HaveSentRoom(HaveSentRoomFlag.NEVER, None)
+
+
+@attr.s(auto_attribs=True, slots=True, frozen=True)
+class RoomStatusMap(Generic[T]):
+ """For a given stream, e.g. events, records what we have or have not sent
+ down for that stream in a given room."""
+
+ # `room_id` -> `HaveSentRoom`
+ _statuses: Mapping[str, HaveSentRoom[T]] = attr.Factory(dict)
+
+ def have_sent_room(self, room_id: str) -> HaveSentRoom[T]:
+ """Return whether we have previously sent the room down"""
+ return self._statuses.get(room_id, HaveSentRoom.never())
+
+ def get_mutable(self) -> "MutableRoomStatusMap[T]":
+ """Get a mutable copy of this state."""
+ return MutableRoomStatusMap(
+ statuses=self._statuses,
+ )
+
+ def copy(self) -> "RoomStatusMap[T]":
+ """Make a copy of the class. Useful for converting from a mutable to
+ immutable version."""
+
+ return RoomStatusMap(statuses=dict(self._statuses))
+
+
+class MutableRoomStatusMap(RoomStatusMap[T]):
+ """A mutable version of `RoomStatusMap`"""
+
+ # We use a ChainMap here so that we can easily track what has been updated
+ # and what hasn't. Note that when we persist the per connection state this
+ # will get flattened to a normal dict (via calling `.copy()`)
+ _statuses: typing.ChainMap[str, HaveSentRoom[T]]
+
+ def __init__(
+ self,
+ statuses: Mapping[str, HaveSentRoom[T]],
+ ) -> None:
+ # ChainMap requires a mutable mapping, but we're not actually going to
+ # mutate it.
+ statuses = cast(MutableMapping, statuses)
+
+ super().__init__(
+ statuses=ChainMap({}, statuses),
+ )
+
+ def get_updates(self) -> Mapping[str, HaveSentRoom[T]]:
+ """Return only the changes that were made"""
+ return self._statuses.maps[0]
+
+ def record_sent_rooms(self, room_ids: StrCollection) -> None:
+ """Record that we have sent these rooms in the response"""
+ for room_id in room_ids:
+ current_status = self._statuses.get(room_id, HaveSentRoom.never())
+ if current_status.status == HaveSentRoomFlag.LIVE:
+ continue
+
+ self._statuses[room_id] = HaveSentRoom.live()
+
+ def record_unsent_rooms(self, room_ids: StrCollection, from_token: T) -> None:
+ """Record that we have not sent these rooms in the response, but there
+ have been updates.
+ """
+ # Whether we add/update the entries for unsent rooms depends on the
+ # existing entry:
+        #   - LIVE: We have previously sent down everything up to
+        #     `last_room_token`, so we update the entry to be `PREVIOUSLY` with
+ # `last_room_token`.
+ # - PREVIOUSLY: We have previously sent down everything up to *a*
+ # given token, so we don't need to update the entry.
+ # - NEVER: We have never previously sent down the room, and we haven't
+ # sent anything down this time either so we leave it as NEVER.
+
+ for room_id in room_ids:
+ current_status = self._statuses.get(room_id, HaveSentRoom.never())
+ if current_status.status != HaveSentRoomFlag.LIVE:
+ continue
+
+ self._statuses[room_id] = HaveSentRoom.previously(from_token)
+
+
+@attr.s(auto_attribs=True)
+class PerConnectionState:
+ """The per-connection state. A snapshot of what we've sent down the
+ connection before.
+
+ Currently, we track whether we've sent down various aspects of a given room
+ before.
+
+ We use the `rooms` field to store the position in the events stream for each
+ room that we've previously sent to the client before. On the next request
+ that includes the room, we can then send only what's changed since that
+ recorded position.
+
+ Same goes for the `receipts` field so we only need to send the new receipts
+ since the last time you made a sync request.
+
+ Attributes:
+ rooms: The status of each room for the events stream.
+ receipts: The status of each room for the receipts stream.
+ room_configs: Map from room_id to the `RoomSyncConfig` of all
+ rooms that we have previously sent down.
+ """
+
+ rooms: RoomStatusMap[RoomStreamToken] = attr.Factory(RoomStatusMap)
+ receipts: RoomStatusMap[MultiWriterStreamToken] = attr.Factory(RoomStatusMap)
+
+ room_configs: Mapping[str, RoomSyncConfig] = attr.Factory(dict)
+
+ def get_mutable(self) -> "MutablePerConnectionState":
+ """Get a mutable copy of this state."""
+ room_configs = cast(MutableMapping[str, RoomSyncConfig], self.room_configs)
+
+ return MutablePerConnectionState(
+ rooms=self.rooms.get_mutable(),
+ receipts=self.receipts.get_mutable(),
+ room_configs=ChainMap({}, room_configs),
+ )
+
+ def copy(self) -> "PerConnectionState":
+ return PerConnectionState(
+ rooms=self.rooms.copy(),
+ receipts=self.receipts.copy(),
+ room_configs=dict(self.room_configs),
+ )
+
+
+@attr.s(auto_attribs=True)
+class MutablePerConnectionState(PerConnectionState):
+ """A mutable version of `PerConnectionState`"""
+
+ rooms: MutableRoomStatusMap[RoomStreamToken]
+ receipts: MutableRoomStatusMap[MultiWriterStreamToken]
+
+ room_configs: typing.ChainMap[str, RoomSyncConfig]
+
+ def has_updates(self) -> bool:
+ return (
+ bool(self.rooms.get_updates())
+ or bool(self.receipts.get_updates())
+ or bool(self.get_room_config_updates())
+ )
+
+ def get_room_config_updates(self) -> Mapping[str, RoomSyncConfig]:
+ """Get updates to the room sync config"""
+ return self.room_configs.maps[0]
|