diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync.py
index c362afa6e2..886d7c7159 100644
--- a/synapse/handlers/sliding_sync.py
+++ b/synapse/handlers/sliding_sync.py
@@ -19,7 +19,18 @@
#
import logging
from itertools import chain
-from typing import TYPE_CHECKING, Any, Dict, Final, List, Mapping, Optional, Set, Tuple
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Dict,
+ Final,
+ List,
+ Mapping,
+ Optional,
+ Sequence,
+ Set,
+ Tuple,
+)
import attr
from immutabledict import immutabledict
@@ -33,6 +44,7 @@ from synapse.storage.databases.main.roommember import extract_heroes_from_room_s
from synapse.storage.databases.main.stream import CurrentStateDeltaMembership
from synapse.storage.roommember import MemberSummary
from synapse.types import (
+ DeviceListUpdates,
JsonDict,
PersistedEventPosition,
Requester,
@@ -343,6 +355,7 @@ class SlidingSyncHandler:
self.notifier = hs.get_notifier()
self.event_sources = hs.get_event_sources()
self.relations_handler = hs.get_relations_handler()
+ self.device_handler = hs.get_device_handler()
self.rooms_to_exclude_globally = hs.config.server.rooms_to_exclude_from_sync
async def wait_for_sync_for_user(
@@ -371,10 +384,6 @@ class SlidingSyncHandler:
# auth_blocking will occur)
await self.auth_blocking.check_auth_blocking(requester=requester)
- # TODO: If the To-Device extension is enabled and we have a `from_token`, delete
- # any to-device messages before that token (since we now know that the device
- # has received them). (see sync v2 for how to do this)
-
# If we're working with a user-provided token, we need to make sure to wait for
# this worker to catch up with the token so we don't skip past any incoming
# events or future events if the user is nefariously, manually modifying the
@@ -617,7 +626,9 @@ class SlidingSyncHandler:
await concurrently_execute(handle_room, relevant_room_map, 10)
extensions = await self.get_extensions_response(
- sync_config=sync_config, to_token=to_token
+ sync_config=sync_config,
+ from_token=from_token,
+ to_token=to_token,
)
return SlidingSyncResult(
@@ -1776,33 +1787,47 @@ class SlidingSyncHandler:
self,
sync_config: SlidingSyncConfig,
to_token: StreamToken,
+ from_token: Optional[StreamToken],
) -> SlidingSyncResult.Extensions:
"""Handle extension requests.
Args:
sync_config: Sync configuration
to_token: The point in the stream to sync up to.
+ from_token: The point in the stream to sync from.
"""
if sync_config.extensions is None:
return SlidingSyncResult.Extensions()
to_device_response = None
- if sync_config.extensions.to_device:
- to_device_response = await self.get_to_device_extensions_response(
+ if sync_config.extensions.to_device is not None:
+ to_device_response = await self.get_to_device_extension_response(
sync_config=sync_config,
to_device_request=sync_config.extensions.to_device,
to_token=to_token,
)
- return SlidingSyncResult.Extensions(to_device=to_device_response)
+ e2ee_response = None
+ if sync_config.extensions.e2ee is not None:
+ e2ee_response = await self.get_e2ee_extension_response(
+ sync_config=sync_config,
+ e2ee_request=sync_config.extensions.e2ee,
+ to_token=to_token,
+ from_token=from_token,
+ )
- async def get_to_device_extensions_response(
+ return SlidingSyncResult.Extensions(
+ to_device=to_device_response,
+ e2ee=e2ee_response,
+ )
+
+ async def get_to_device_extension_response(
self,
sync_config: SlidingSyncConfig,
to_device_request: SlidingSyncConfig.Extensions.ToDeviceExtension,
to_token: StreamToken,
- ) -> SlidingSyncResult.Extensions.ToDeviceExtension:
+ ) -> Optional[SlidingSyncResult.Extensions.ToDeviceExtension]:
"""Handle to-device extension (MSC3885)
Args:
@@ -1810,14 +1835,16 @@ class SlidingSyncHandler:
to_device_request: The to-device extension from the request
to_token: The point in the stream to sync up to.
"""
-
user_id = sync_config.user.to_string()
device_id = sync_config.device_id
+ # Skip if the extension is not enabled
+ if not to_device_request.enabled:
+ return None
+
# Check that this request has a valid device ID (not all requests have
- # to belong to a device, and so device_id is None), and that the
- # extension is enabled.
- if device_id is None or not to_device_request.enabled:
+ # to belong to a device, and so device_id is None)
+ if device_id is None:
return SlidingSyncResult.Extensions.ToDeviceExtension(
next_batch=f"{to_token.to_device_key}",
events=[],
@@ -1868,3 +1895,53 @@ class SlidingSyncHandler:
next_batch=f"{stream_id}",
events=messages,
)
+
+ async def get_e2ee_extension_response(
+ self,
+ sync_config: SlidingSyncConfig,
+ e2ee_request: SlidingSyncConfig.Extensions.E2eeExtension,
+ to_token: StreamToken,
+ from_token: Optional[StreamToken],
+ ) -> Optional[SlidingSyncResult.Extensions.E2eeExtension]:
+ """Handle E2EE device extension (MSC3884)
+
+ Args:
+ sync_config: Sync configuration
+ e2ee_request: The e2ee extension from the request
+ to_token: The point in the stream to sync up to.
+ from_token: The point in the stream to sync from.
+ """
+ user_id = sync_config.user.to_string()
+ device_id = sync_config.device_id
+
+ # Skip if the extension is not enabled
+ if not e2ee_request.enabled:
+ return None
+
+ device_list_updates: Optional[DeviceListUpdates] = None
+ if from_token is not None:
+ # TODO: This should take into account the `from_token` and `to_token`
+ device_list_updates = await self.device_handler.get_user_ids_changed(
+ user_id=user_id,
+ from_token=from_token,
+ )
+
+ device_one_time_keys_count: Mapping[str, int] = {}
+ device_unused_fallback_key_types: Sequence[str] = []
+ if device_id:
+ # TODO: We should have a way to let clients differentiate between the states of:
+ # * no change in OTK count since the provided since token
+ # * the server has zero OTKs left for this device
+ # Spec issue: https://github.com/matrix-org/matrix-doc/issues/3298
+ device_one_time_keys_count = await self.store.count_e2e_one_time_keys(
+ user_id, device_id
+ )
+ device_unused_fallback_key_types = (
+ await self.store.get_e2e_unused_fallback_key_types(user_id, device_id)
+ )
+
+ return SlidingSyncResult.Extensions.E2eeExtension(
+ device_list_updates=device_list_updates,
+ device_one_time_keys_count=device_one_time_keys_count,
+ device_unused_fallback_key_types=device_unused_fallback_key_types,
+ )
|