diff --git a/changelog.d/17167.feature b/changelog.d/17167.feature
new file mode 100644
index 0000000000..5ad31db974
--- /dev/null
+++ b/changelog.d/17167.feature
@@ -0,0 +1 @@
+Add experimental [MSC3575](https://github.com/matrix-org/matrix-spec-proposals/pull/3575) Sliding Sync `/sync/e2ee` endpoint for To-Device messages and device encryption info.
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index 749452ce93..cda7afc5c4 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -332,6 +332,9 @@ class ExperimentalConfig(Config):
# MSC3391: Removing account data.
self.msc3391_enabled = experimental.get("msc3391_enabled", False)
+ # MSC3575 (Sliding Sync API endpoints)
+ self.msc3575_enabled: bool = experimental.get("msc3575_enabled", False)
+
# MSC3773: Thread notifications
self.msc3773_enabled: bool = experimental.get("msc3773_enabled", False)
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index b7917a99d6..ac5bddd52f 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -28,11 +28,14 @@ from typing import (
Dict,
FrozenSet,
List,
+ Literal,
Mapping,
Optional,
Sequence,
Set,
Tuple,
+ Union,
+ overload,
)
import attr
@@ -128,6 +131,8 @@ class SyncVersion(Enum):
# Traditional `/sync` endpoint
SYNC_V2 = "sync_v2"
+ # Part of MSC3575 Sliding Sync
+ E2EE_SYNC = "e2ee_sync"
@attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -280,6 +285,26 @@ class SyncResult:
)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class E2eeSyncResult:
+ """
+ Attributes:
+ next_batch: Token for the next sync
+ to_device: List of direct messages for the device.
+ device_lists: List of user_ids whose devices have changed
+ device_one_time_keys_count: Dict of algorithm to count for one time keys
+ for this device
+ device_unused_fallback_key_types: List of key types that have an unused fallback
+ key
+ """
+
+ next_batch: StreamToken
+ to_device: List[JsonDict]
+ device_lists: DeviceListUpdates
+ device_one_time_keys_count: JsonMapping
+ device_unused_fallback_key_types: List[str]
+
+
class SyncHandler:
def __init__(self, hs: "HomeServer"):
self.hs_config = hs.config
@@ -322,6 +347,31 @@ class SyncHandler:
self.rooms_to_exclude_globally = hs.config.server.rooms_to_exclude_from_sync
+ @overload
+ async def wait_for_sync_for_user(
+ self,
+ requester: Requester,
+ sync_config: SyncConfig,
+ sync_version: Literal[SyncVersion.SYNC_V2],
+ request_key: SyncRequestKey,
+ since_token: Optional[StreamToken] = None,
+ timeout: int = 0,
+ full_state: bool = False,
+ ) -> SyncResult: ...
+
+ @overload
+ async def wait_for_sync_for_user(
+ self,
+ requester: Requester,
+ sync_config: SyncConfig,
+ sync_version: Literal[SyncVersion.E2EE_SYNC],
+ request_key: SyncRequestKey,
+ since_token: Optional[StreamToken] = None,
+ timeout: int = 0,
+ full_state: bool = False,
+ ) -> E2eeSyncResult: ...
+
+ @overload
async def wait_for_sync_for_user(
self,
requester: Requester,
@@ -331,7 +381,18 @@ class SyncHandler:
since_token: Optional[StreamToken] = None,
timeout: int = 0,
full_state: bool = False,
- ) -> SyncResult:
+ ) -> Union[SyncResult, E2eeSyncResult]: ...
+
+ async def wait_for_sync_for_user(
+ self,
+ requester: Requester,
+ sync_config: SyncConfig,
+ sync_version: SyncVersion,
+ request_key: SyncRequestKey,
+ since_token: Optional[StreamToken] = None,
+ timeout: int = 0,
+ full_state: bool = False,
+ ) -> Union[SyncResult, E2eeSyncResult]:
"""Get the sync for a client if we have new data for it now. Otherwise
wait for new data to arrive on the server. If the timeout expires, then
return an empty sync result.
@@ -344,8 +405,10 @@ class SyncHandler:
since_token: The point in the stream to sync from.
timeout: How long to wait for new data to arrive before giving up.
full_state: Whether to return the full state for each room.
+
Returns:
When `SyncVersion.SYNC_V2`, returns a full `SyncResult`.
+ When `SyncVersion.E2EE_SYNC`, returns a `E2eeSyncResult`.
"""
# If the user is not part of the mau group, then check that limits have
# not been exceeded (if not part of the group by this point, almost certain
@@ -366,6 +429,29 @@ class SyncHandler:
logger.debug("Returning sync response for %s", user_id)
return res
+ @overload
+ async def _wait_for_sync_for_user(
+ self,
+ sync_config: SyncConfig,
+ sync_version: Literal[SyncVersion.SYNC_V2],
+ since_token: Optional[StreamToken],
+ timeout: int,
+ full_state: bool,
+ cache_context: ResponseCacheContext[SyncRequestKey],
+ ) -> SyncResult: ...
+
+ @overload
+ async def _wait_for_sync_for_user(
+ self,
+ sync_config: SyncConfig,
+ sync_version: Literal[SyncVersion.E2EE_SYNC],
+ since_token: Optional[StreamToken],
+ timeout: int,
+ full_state: bool,
+ cache_context: ResponseCacheContext[SyncRequestKey],
+ ) -> E2eeSyncResult: ...
+
+ @overload
async def _wait_for_sync_for_user(
self,
sync_config: SyncConfig,
@@ -374,7 +460,17 @@ class SyncHandler:
timeout: int,
full_state: bool,
cache_context: ResponseCacheContext[SyncRequestKey],
- ) -> SyncResult:
+ ) -> Union[SyncResult, E2eeSyncResult]: ...
+
+ async def _wait_for_sync_for_user(
+ self,
+ sync_config: SyncConfig,
+ sync_version: SyncVersion,
+ since_token: Optional[StreamToken],
+ timeout: int,
+ full_state: bool,
+ cache_context: ResponseCacheContext[SyncRequestKey],
+ ) -> Union[SyncResult, E2eeSyncResult]:
"""The start of the machinery that produces a /sync response.
See https://spec.matrix.org/v1.1/client-server-api/#syncing for full details.
@@ -417,14 +513,16 @@ class SyncHandler:
if timeout == 0 or since_token is None or full_state:
# we are going to return immediately, so don't bother calling
# notifier.wait_for_events.
- result: SyncResult = await self.current_sync_for_user(
- sync_config, sync_version, since_token, full_state=full_state
+ result: Union[SyncResult, E2eeSyncResult] = (
+ await self.current_sync_for_user(
+ sync_config, sync_version, since_token, full_state=full_state
+ )
)
else:
# Otherwise, we wait for something to happen and report it to the user.
async def current_sync_callback(
before_token: StreamToken, after_token: StreamToken
- ) -> SyncResult:
+ ) -> Union[SyncResult, E2eeSyncResult]:
return await self.current_sync_for_user(
sync_config, sync_version, since_token
)
@@ -456,14 +554,43 @@ class SyncHandler:
return result
+ @overload
+ async def current_sync_for_user(
+ self,
+ sync_config: SyncConfig,
+ sync_version: Literal[SyncVersion.SYNC_V2],
+ since_token: Optional[StreamToken] = None,
+ full_state: bool = False,
+ ) -> SyncResult: ...
+
+ @overload
+ async def current_sync_for_user(
+ self,
+ sync_config: SyncConfig,
+ sync_version: Literal[SyncVersion.E2EE_SYNC],
+ since_token: Optional[StreamToken] = None,
+ full_state: bool = False,
+ ) -> E2eeSyncResult: ...
+
+ @overload
async def current_sync_for_user(
self,
sync_config: SyncConfig,
sync_version: SyncVersion,
since_token: Optional[StreamToken] = None,
full_state: bool = False,
- ) -> SyncResult:
- """Generates the response body of a sync result, represented as a SyncResult.
+ ) -> Union[SyncResult, E2eeSyncResult]: ...
+
+ async def current_sync_for_user(
+ self,
+ sync_config: SyncConfig,
+ sync_version: SyncVersion,
+ since_token: Optional[StreamToken] = None,
+ full_state: bool = False,
+ ) -> Union[SyncResult, E2eeSyncResult]:
+ """
+ Generates the response body of a sync result, represented as a
+ `SyncResult`/`E2eeSyncResult`.
This is a wrapper around `generate_sync_result` which starts an open tracing
span to track the sync. See `generate_sync_result` for the next part of your
@@ -474,15 +601,25 @@ class SyncHandler:
sync_version: Determines what kind of sync response to generate.
since_token: The point in the stream to sync from.p.
full_state: Whether to return the full state for each room.
+
Returns:
When `SyncVersion.SYNC_V2`, returns a full `SyncResult`.
+ When `SyncVersion.E2EE_SYNC`, returns a `E2eeSyncResult`.
"""
with start_active_span("sync.current_sync_for_user"):
log_kv({"since_token": since_token})
+
# Go through the `/sync` v2 path
if sync_version == SyncVersion.SYNC_V2:
- sync_result: SyncResult = await self.generate_sync_result(
- sync_config, since_token, full_state
+ sync_result: Union[SyncResult, E2eeSyncResult] = (
+ await self.generate_sync_result(
+ sync_config, since_token, full_state
+ )
+ )
+ # Go through the MSC3575 Sliding Sync `/sync/e2ee` path
+ elif sync_version == SyncVersion.E2EE_SYNC:
+ sync_result = await self.generate_e2ee_sync_result(
+ sync_config, since_token
)
else:
raise Exception(
@@ -1691,6 +1828,96 @@ class SyncHandler:
next_batch=sync_result_builder.now_token,
)
+ async def generate_e2ee_sync_result(
+ self,
+ sync_config: SyncConfig,
+ since_token: Optional[StreamToken] = None,
+ ) -> E2eeSyncResult:
+ """
+ Generates the response body of a MSC3575 Sliding Sync `/sync/e2ee` result.
+
+ This is represented by a `E2eeSyncResult` struct, which is built from small
+ pieces using a `SyncResultBuilder`. The `sync_result_builder` is passed as a
+ mutable ("inout") parameter to various helper functions. These retrieve and
+ process the data which forms the sync body, often writing to the
+ `sync_result_builder` to store their output.
+
+ At the end, we transfer data from the `sync_result_builder` to a new `E2eeSyncResult`
+ instance to signify that the sync calculation is complete.
+ """
+ user_id = sync_config.user.to_string()
+ app_service = self.store.get_app_service_by_user_id(user_id)
+ if app_service:
+ # We no longer support AS users using /sync directly.
+ # See https://github.com/matrix-org/matrix-doc/issues/1144
+ raise NotImplementedError()
+
+ sync_result_builder = await self.get_sync_result_builder(
+ sync_config,
+ since_token,
+ full_state=False,
+ )
+
+ # 1. Calculate `to_device` events
+ await self._generate_sync_entry_for_to_device(sync_result_builder)
+
+ # 2. Calculate `device_lists`
+ # Device list updates are sent if a since token is provided.
+ device_lists = DeviceListUpdates()
+ include_device_list_updates = bool(since_token and since_token.device_list_key)
+ if include_device_list_updates:
+ # Note that _generate_sync_entry_for_rooms sets sync_result_builder.joined, which
+ # is used in calculate_user_changes below.
+ #
+ # TODO: Running `_generate_sync_entry_for_rooms()` is a lot of work just to
+ # figure out the membership changes/derived info needed for
+ # `_generate_sync_entry_for_device_list()`. In the future, we should try to
+ # refactor this away.
+ (
+ newly_joined_rooms,
+ newly_left_rooms,
+ ) = await self._generate_sync_entry_for_rooms(sync_result_builder)
+
+ # This uses the sync_result_builder.joined which is set in
+ # `_generate_sync_entry_for_rooms`, if that didn't find any joined
+ # rooms for some reason it is a no-op.
+ (
+ newly_joined_or_invited_or_knocked_users,
+ newly_left_users,
+ ) = sync_result_builder.calculate_user_changes()
+
+ device_lists = await self._generate_sync_entry_for_device_list(
+ sync_result_builder,
+ newly_joined_rooms=newly_joined_rooms,
+ newly_joined_or_invited_or_knocked_users=newly_joined_or_invited_or_knocked_users,
+ newly_left_rooms=newly_left_rooms,
+ newly_left_users=newly_left_users,
+ )
+
+ # 3. Calculate `device_one_time_keys_count` and `device_unused_fallback_key_types`
+ device_id = sync_config.device_id
+ one_time_keys_count: JsonMapping = {}
+ unused_fallback_key_types: List[str] = []
+ if device_id:
+ # TODO: We should have a way to let clients differentiate between the states of:
+ # * no change in OTK count since the provided since token
+ # * the server has zero OTKs left for this device
+ # Spec issue: https://github.com/matrix-org/matrix-doc/issues/3298
+ one_time_keys_count = await self.store.count_e2e_one_time_keys(
+ user_id, device_id
+ )
+ unused_fallback_key_types = list(
+ await self.store.get_e2e_unused_fallback_key_types(user_id, device_id)
+ )
+
+ return E2eeSyncResult(
+ to_device=sync_result_builder.to_device,
+ device_lists=device_lists,
+ device_one_time_keys_count=one_time_keys_count,
+ device_unused_fallback_key_types=unused_fallback_key_types,
+ next_batch=sync_result_builder.now_token,
+ )
+
async def get_sync_result_builder(
self,
sync_config: SyncConfig,
@@ -1889,7 +2116,7 @@ class SyncHandler:
users_that_have_changed = (
await self._device_handler.get_device_changes_in_shared_rooms(
user_id,
- sync_result_builder.joined_room_ids,
+ joined_room_ids,
from_token=since_token,
now_token=sync_result_builder.now_token,
)
diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index 4a57eaf930..27ea943e31 100644
--- a/synapse/rest/client/sync.py
+++ b/synapse/rest/client/sync.py
@@ -567,5 +567,176 @@ class SyncRestServlet(RestServlet):
return result
+class SlidingSyncE2eeRestServlet(RestServlet):
+ """
+ API endpoint for MSC3575 Sliding Sync `/sync/e2ee`. This is being introduced as part
+ of Sliding Sync but doesn't have any sliding window component. It's just a way to
+ get E2EE events without having to sit through a big initial sync (`/sync` v2). And
+ we can avoid encryption events being backed up by the main sync response.
+
+ Having To-Device messages split out to this sync endpoint also helps when clients
+ need to have 2 or more sync streams open at a time, e.g a push notification process
+ and a main process. This can cause the two processes to race to fetch the To-Device
+ events, resulting in the need for complex synchronisation rules to ensure the token
+ is correctly and atomically exchanged between processes.
+
+ GET parameters::
+ timeout(int): How long to wait for new events in milliseconds.
+ since(batch_token): Batch token when asking for incremental deltas.
+
+ Response JSON::
+ {
+ "next_batch": // batch token for the next /sync
+ "to_device": {
+ // list of to-device events
+ "events": [
+ {
+ "content: { "algorithm": "m.olm.v1.curve25519-aes-sha2", "ciphertext": { ... }, "org.matrix.msgid": "abcd", "session_id": "abcd" },
+ "type": "m.room.encrypted",
+ "sender": "@alice:example.com",
+ }
+ // ...
+ ]
+ },
+ "device_lists": {
+ "changed": ["@alice:example.com"],
+ "left": ["@bob:example.com"]
+ },
+ "device_one_time_keys_count": {
+ "signed_curve25519": 50
+ },
+ "device_unused_fallback_key_types": [
+ "signed_curve25519"
+ ]
+ }
+ """
+
+ PATTERNS = client_patterns(
+ "/org.matrix.msc3575/sync/e2ee$", releases=[], v1=False, unstable=True
+ )
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__()
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastores().main
+ self.sync_handler = hs.get_sync_handler()
+
+ # Filtering only matters for the `device_lists` because it requires a bunch of
+ # derived information from rooms (see how `_generate_sync_entry_for_rooms()`
+ # prepares a bunch of data for `_generate_sync_entry_for_device_list()`).
+ self.only_member_events_filter_collection = FilterCollection(
+ self.hs,
+ {
+ "room": {
+ # We only care about membership events for the `device_lists`.
+ # Membership will tell us whether a user has joined/left a room and
+ # if there are new devices to encrypt for.
+ "timeline": {
+ "types": ["m.room.member"],
+ },
+ "state": {
+ "types": ["m.room.member"],
+ },
+ # We don't want any extra account_data generated because it's not
+ # returned by this endpoint. This helps us avoid work in
+ # `_generate_sync_entry_for_rooms()`
+ "account_data": {
+ "not_types": ["*"],
+ },
+ # We don't want any extra ephemeral data generated because it's not
+ # returned by this endpoint. This helps us avoid work in
+ # `_generate_sync_entry_for_rooms()`
+ "ephemeral": {
+ "not_types": ["*"],
+ },
+ },
+ # We don't want any extra account_data generated because it's not
+ # returned by this endpoint. (This is just here for good measure)
+ "account_data": {
+ "not_types": ["*"],
+ },
+ # We don't want any extra presence data generated because it's not
+ # returned by this endpoint. (This is just here for good measure)
+ "presence": {
+ "not_types": ["*"],
+ },
+ },
+ )
+
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
+ user = requester.user
+ device_id = requester.device_id
+
+ timeout = parse_integer(request, "timeout", default=0)
+ since = parse_string(request, "since")
+
+ sync_config = SyncConfig(
+ user=user,
+ filter_collection=self.only_member_events_filter_collection,
+ is_guest=requester.is_guest,
+ device_id=device_id,
+ )
+
+ since_token = None
+ if since is not None:
+ since_token = await StreamToken.from_string(self.store, since)
+
+ # Request cache key
+ request_key = (
+ SyncVersion.E2EE_SYNC,
+ user,
+ timeout,
+ since,
+ )
+
+ # Gather data for the response
+ sync_result = await self.sync_handler.wait_for_sync_for_user(
+ requester,
+ sync_config,
+ SyncVersion.E2EE_SYNC,
+ request_key,
+ since_token=since_token,
+ timeout=timeout,
+ full_state=False,
+ )
+
+ # The client may have disconnected by now; don't bother to serialize the
+ # response if so.
+ if request._disconnected:
+ logger.info("Client has disconnected; not serializing response.")
+ return 200, {}
+
+ response: JsonDict = defaultdict(dict)
+ response["next_batch"] = await sync_result.next_batch.to_string(self.store)
+
+ if sync_result.to_device:
+ response["to_device"] = {"events": sync_result.to_device}
+
+ if sync_result.device_lists.changed:
+ response["device_lists"]["changed"] = list(sync_result.device_lists.changed)
+ if sync_result.device_lists.left:
+ response["device_lists"]["left"] = list(sync_result.device_lists.left)
+
+ # We always include this because https://github.com/vector-im/element-android/issues/3725
+ # The spec isn't terribly clear on when this can be omitted and how a client would tell
+ # the difference between "no keys present" and "nothing changed" in terms of whole field
+ # absent / individual key type entry absent
+ # Corresponding synapse issue: https://github.com/matrix-org/synapse/issues/10456
+ response["device_one_time_keys_count"] = sync_result.device_one_time_keys_count
+
+ # https://github.com/matrix-org/matrix-doc/blob/54255851f642f84a4f1aaf7bc063eebe3d76752b/proposals/2732-olm-fallback-keys.md
+ # states that this field should always be included, as long as the server supports the feature.
+ response["device_unused_fallback_key_types"] = (
+ sync_result.device_unused_fallback_key_types
+ )
+
+ return 200, response
+
+
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
SyncRestServlet(hs).register(http_server)
+
+ if hs.config.experimental.msc3575_enabled:
+ SlidingSyncE2eeRestServlet(hs).register(http_server)
diff --git a/tests/rest/client/test_devices.py b/tests/rest/client/test_devices.py
index 2b360732ac..a3ed12a38f 100644
--- a/tests/rest/client/test_devices.py
+++ b/tests/rest/client/test_devices.py
@@ -24,8 +24,8 @@ from twisted.internet.defer import ensureDeferred
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.errors import NotFoundError
-from synapse.rest import admin, devices, room, sync
-from synapse.rest.client import account, keys, login, register
+from synapse.rest import admin, devices, sync
+from synapse.rest.client import keys, login, register
from synapse.server import HomeServer
from synapse.types import JsonDict, UserID, create_requester
from synapse.util import Clock
@@ -33,146 +33,6 @@ from synapse.util import Clock
from tests import unittest
-class DeviceListsTestCase(unittest.HomeserverTestCase):
- """Tests regarding device list changes."""
-
- servlets = [
- admin.register_servlets_for_client_rest_resource,
- login.register_servlets,
- register.register_servlets,
- account.register_servlets,
- room.register_servlets,
- sync.register_servlets,
- devices.register_servlets,
- ]
-
- def test_receiving_local_device_list_changes(self) -> None:
- """Tests that a local users that share a room receive each other's device list
- changes.
- """
- # Register two users
- test_device_id = "TESTDEVICE"
- alice_user_id = self.register_user("alice", "correcthorse")
- alice_access_token = self.login(
- alice_user_id, "correcthorse", device_id=test_device_id
- )
-
- bob_user_id = self.register_user("bob", "ponyponypony")
- bob_access_token = self.login(bob_user_id, "ponyponypony")
-
- # Create a room for them to coexist peacefully in
- new_room_id = self.helper.create_room_as(
- alice_user_id, is_public=True, tok=alice_access_token
- )
- self.assertIsNotNone(new_room_id)
-
- # Have Bob join the room
- self.helper.invite(
- new_room_id, alice_user_id, bob_user_id, tok=alice_access_token
- )
- self.helper.join(new_room_id, bob_user_id, tok=bob_access_token)
-
- # Now have Bob initiate an initial sync (in order to get a since token)
- channel = self.make_request(
- "GET",
- "/sync",
- access_token=bob_access_token,
- )
- self.assertEqual(channel.code, 200, channel.json_body)
- next_batch_token = channel.json_body["next_batch"]
-
- # ...and then an incremental sync. This should block until the sync stream is woken up,
- # which we hope will happen as a result of Alice updating their device list.
- bob_sync_channel = self.make_request(
- "GET",
- f"/sync?since={next_batch_token}&timeout=30000",
- access_token=bob_access_token,
- # Start the request, then continue on.
- await_result=False,
- )
-
- # Have alice update their device list
- channel = self.make_request(
- "PUT",
- f"/devices/{test_device_id}",
- {
- "display_name": "New Device Name",
- },
- access_token=alice_access_token,
- )
- self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
-
- # Check that bob's incremental sync contains the updated device list.
- # If not, the client would only receive the device list update on the
- # *next* sync.
- bob_sync_channel.await_result()
- self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body)
-
- changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get(
- "changed", []
- )
- self.assertIn(alice_user_id, changed_device_lists, bob_sync_channel.json_body)
-
- def test_not_receiving_local_device_list_changes(self) -> None:
- """Tests a local users DO NOT receive device updates from each other if they do not
- share a room.
- """
- # Register two users
- test_device_id = "TESTDEVICE"
- alice_user_id = self.register_user("alice", "correcthorse")
- alice_access_token = self.login(
- alice_user_id, "correcthorse", device_id=test_device_id
- )
-
- bob_user_id = self.register_user("bob", "ponyponypony")
- bob_access_token = self.login(bob_user_id, "ponyponypony")
-
- # These users do not share a room. They are lonely.
-
- # Have Bob initiate an initial sync (in order to get a since token)
- channel = self.make_request(
- "GET",
- "/sync",
- access_token=bob_access_token,
- )
- self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
- next_batch_token = channel.json_body["next_batch"]
-
- # ...and then an incremental sync. This should block until the sync stream is woken up,
- # which we hope will happen as a result of Alice updating their device list.
- bob_sync_channel = self.make_request(
- "GET",
- f"/sync?since={next_batch_token}&timeout=1000",
- access_token=bob_access_token,
- # Start the request, then continue on.
- await_result=False,
- )
-
- # Have alice update their device list
- channel = self.make_request(
- "PUT",
- f"/devices/{test_device_id}",
- {
- "display_name": "New Device Name",
- },
- access_token=alice_access_token,
- )
- self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
-
- # Check that bob's incremental sync does not contain the updated device list.
- bob_sync_channel.await_result()
- self.assertEqual(
- bob_sync_channel.code, HTTPStatus.OK, bob_sync_channel.json_body
- )
-
- changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get(
- "changed", []
- )
- self.assertNotIn(
- alice_user_id, changed_device_lists, bob_sync_channel.json_body
- )
-
-
class DevicesTestCase(unittest.HomeserverTestCase):
servlets = [
admin.register_servlets,
diff --git a/tests/rest/client/test_sendtodevice.py b/tests/rest/client/test_sendtodevice.py
index 2f994ad553..5ef501c6d5 100644
--- a/tests/rest/client/test_sendtodevice.py
+++ b/tests/rest/client/test_sendtodevice.py
@@ -18,15 +18,39 @@
# [This file includes modifications made by New Vector Limited]
#
#
+from parameterized import parameterized_class
from synapse.api.constants import EduTypes
from synapse.rest import admin
from synapse.rest.client import login, sendtodevice, sync
+from synapse.types import JsonDict
from tests.unittest import HomeserverTestCase, override_config
+@parameterized_class(
+ ("sync_endpoint", "experimental_features"),
+ [
+ ("/sync", {}),
+ (
+ "/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee",
+ # Enable sliding sync
+ {"msc3575_enabled": True},
+ ),
+ ],
+)
class SendToDeviceTestCase(HomeserverTestCase):
+ """
+ Test `/sendToDevice` will deliver messages across to people receiving them over `/sync`.
+
+ Attributes:
+ sync_endpoint: The endpoint under test to use for syncing.
+ experimental_features: The experimental features homeserver config to use.
+ """
+
+ sync_endpoint: str
+ experimental_features: JsonDict
+
servlets = [
admin.register_servlets,
login.register_servlets,
@@ -34,6 +58,11 @@ class SendToDeviceTestCase(HomeserverTestCase):
sync.register_servlets,
]
+ def default_config(self) -> JsonDict:
+ config = super().default_config()
+ config["experimental_features"] = self.experimental_features
+ return config
+
def test_user_to_user(self) -> None:
"""A to-device message from one user to another should get delivered"""
@@ -54,7 +83,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
self.assertEqual(chan.code, 200, chan.result)
# check it appears
- channel = self.make_request("GET", "/sync", access_token=user2_tok)
+ channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
self.assertEqual(channel.code, 200, channel.result)
expected_result = {
"events": [
@@ -67,15 +96,19 @@ class SendToDeviceTestCase(HomeserverTestCase):
}
self.assertEqual(channel.json_body["to_device"], expected_result)
- # it should re-appear if we do another sync
- channel = self.make_request("GET", "/sync", access_token=user2_tok)
+ # it should re-appear if we do another sync because the to-device message is not
+ # deleted until we acknowledge it by sending a `?since=...` parameter in the
+ # next sync request corresponding to the `next_batch` value from the response.
+ channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.json_body["to_device"], expected_result)
# it should *not* appear if we do an incremental sync
sync_token = channel.json_body["next_batch"]
channel = self.make_request(
- "GET", f"/sync?since={sync_token}", access_token=user2_tok
+ "GET",
+ f"{self.sync_endpoint}?since={sync_token}",
+ access_token=user2_tok,
)
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.json_body.get("to_device", {}).get("events", []), [])
@@ -99,15 +132,19 @@ class SendToDeviceTestCase(HomeserverTestCase):
)
self.assertEqual(chan.code, 200, chan.result)
- # now sync: we should get two of the three
- channel = self.make_request("GET", "/sync", access_token=user2_tok)
+ # now sync: we should get two of the three (because burst_count=2)
+ channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
self.assertEqual(channel.code, 200, channel.result)
msgs = channel.json_body["to_device"]["events"]
self.assertEqual(len(msgs), 2)
for i in range(2):
self.assertEqual(
msgs[i],
- {"sender": user1, "type": "m.room_key_request", "content": {"idx": i}},
+ {
+ "sender": user1,
+ "type": "m.room_key_request",
+ "content": {"idx": i},
+ },
)
sync_token = channel.json_body["next_batch"]
@@ -125,7 +162,9 @@ class SendToDeviceTestCase(HomeserverTestCase):
# ... which should arrive
channel = self.make_request(
- "GET", f"/sync?since={sync_token}", access_token=user2_tok
+ "GET",
+ f"{self.sync_endpoint}?since={sync_token}",
+ access_token=user2_tok,
)
self.assertEqual(channel.code, 200, channel.result)
msgs = channel.json_body["to_device"]["events"]
@@ -159,7 +198,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
)
# now sync: we should get two of the three
- channel = self.make_request("GET", "/sync", access_token=user2_tok)
+ channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
self.assertEqual(channel.code, 200, channel.result)
msgs = channel.json_body["to_device"]["events"]
self.assertEqual(len(msgs), 2)
@@ -193,7 +232,9 @@ class SendToDeviceTestCase(HomeserverTestCase):
# ... which should arrive
channel = self.make_request(
- "GET", f"/sync?since={sync_token}", access_token=user2_tok
+ "GET",
+ f"{self.sync_endpoint}?since={sync_token}",
+ access_token=user2_tok,
)
self.assertEqual(channel.code, 200, channel.result)
msgs = channel.json_body["to_device"]["events"]
@@ -217,7 +258,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
user2_tok = self.login("u2", "pass", "d2")
# Do an initial sync
- channel = self.make_request("GET", "/sync", access_token=user2_tok)
+ channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
self.assertEqual(channel.code, 200, channel.result)
sync_token = channel.json_body["next_batch"]
@@ -233,7 +274,9 @@ class SendToDeviceTestCase(HomeserverTestCase):
self.assertEqual(chan.code, 200, chan.result)
channel = self.make_request(
- "GET", f"/sync?since={sync_token}&timeout=300000", access_token=user2_tok
+ "GET",
+ f"{self.sync_endpoint}?since={sync_token}&timeout=300000",
+ access_token=user2_tok,
)
self.assertEqual(channel.code, 200, channel.result)
messages = channel.json_body.get("to_device", {}).get("events", [])
@@ -241,7 +284,9 @@ class SendToDeviceTestCase(HomeserverTestCase):
sync_token = channel.json_body["next_batch"]
channel = self.make_request(
- "GET", f"/sync?since={sync_token}&timeout=300000", access_token=user2_tok
+ "GET",
+ f"{self.sync_endpoint}?since={sync_token}&timeout=300000",
+ access_token=user2_tok,
)
self.assertEqual(channel.code, 200, channel.result)
messages = channel.json_body.get("to_device", {}).get("events", [])
diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py
index 417a87feb2..daeb1d3ddd 100644
--- a/tests/rest/client/test_sync.py
+++ b/tests/rest/client/test_sync.py
@@ -21,7 +21,7 @@
import json
from typing import List
-from parameterized import parameterized
+from parameterized import parameterized, parameterized_class
from twisted.test.proto_helpers import MemoryReactor
@@ -688,24 +688,180 @@ class SyncCacheTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200, channel.json_body)
+@parameterized_class(
+ ("sync_endpoint", "experimental_features"),
+ [
+ ("/sync", {}),
+ (
+ "/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee",
+ # Enable sliding sync
+ {"msc3575_enabled": True},
+ ),
+ ],
+)
class DeviceListSyncTestCase(unittest.HomeserverTestCase):
+ """
+ Tests regarding device list (`device_lists`) changes.
+
+ Attributes:
+ sync_endpoint: The endpoint under test to use for syncing.
+ experimental_features: The experimental features homeserver config to use.
+ """
+
+ sync_endpoint: str
+ experimental_features: JsonDict
+
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
+ room.register_servlets,
sync.register_servlets,
devices.register_servlets,
]
+ def default_config(self) -> JsonDict:
+ config = super().default_config()
+ config["experimental_features"] = self.experimental_features
+ return config
+
+ def test_receiving_local_device_list_changes(self) -> None:
+ """Tests that a local users that share a room receive each other's device list
+ changes.
+ """
+ # Register two users
+ test_device_id = "TESTDEVICE"
+ alice_user_id = self.register_user("alice", "correcthorse")
+ alice_access_token = self.login(
+ alice_user_id, "correcthorse", device_id=test_device_id
+ )
+
+ bob_user_id = self.register_user("bob", "ponyponypony")
+ bob_access_token = self.login(bob_user_id, "ponyponypony")
+
+ # Create a room for them to coexist peacefully in
+ new_room_id = self.helper.create_room_as(
+ alice_user_id, is_public=True, tok=alice_access_token
+ )
+ self.assertIsNotNone(new_room_id)
+
+ # Have Bob join the room
+ self.helper.invite(
+ new_room_id, alice_user_id, bob_user_id, tok=alice_access_token
+ )
+ self.helper.join(new_room_id, bob_user_id, tok=bob_access_token)
+
+ # Now have Bob initiate an initial sync (in order to get a since token)
+ channel = self.make_request(
+ "GET",
+ self.sync_endpoint,
+ access_token=bob_access_token,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+ next_batch_token = channel.json_body["next_batch"]
+
+ # ...and then an incremental sync. This should block until the sync stream is woken up,
+ # which we hope will happen as a result of Alice updating their device list.
+ bob_sync_channel = self.make_request(
+ "GET",
+ f"{self.sync_endpoint}?since={next_batch_token}&timeout=30000",
+ access_token=bob_access_token,
+ # Start the request, then continue on.
+ await_result=False,
+ )
+
+ # Have alice update their device list
+ channel = self.make_request(
+ "PUT",
+ f"/devices/{test_device_id}",
+ {
+ "display_name": "New Device Name",
+ },
+ access_token=alice_access_token,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # Check that bob's incremental sync contains the updated device list.
+ # If not, the client would only receive the device list update on the
+ # *next* sync.
+ bob_sync_channel.await_result()
+ self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body)
+
+ changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get(
+ "changed", []
+ )
+ self.assertIn(alice_user_id, changed_device_lists, bob_sync_channel.json_body)
+
+ def test_not_receiving_local_device_list_changes(self) -> None:
+ """Tests a local users DO NOT receive device updates from each other if they do not
+ share a room.
+ """
+ # Register two users
+ test_device_id = "TESTDEVICE"
+ alice_user_id = self.register_user("alice", "correcthorse")
+ alice_access_token = self.login(
+ alice_user_id, "correcthorse", device_id=test_device_id
+ )
+
+ bob_user_id = self.register_user("bob", "ponyponypony")
+ bob_access_token = self.login(bob_user_id, "ponyponypony")
+
+ # These users do not share a room. They are lonely.
+
+ # Have Bob initiate an initial sync (in order to get a since token)
+ channel = self.make_request(
+ "GET",
+ self.sync_endpoint,
+ access_token=bob_access_token,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+ next_batch_token = channel.json_body["next_batch"]
+
+ # ...and then an incremental sync. This should block until the sync stream is woken up,
+ # which we hope will happen as a result of Alice updating their device list.
+ bob_sync_channel = self.make_request(
+ "GET",
+ f"{self.sync_endpoint}?since={next_batch_token}&timeout=1000",
+ access_token=bob_access_token,
+ # Start the request, then continue on.
+ await_result=False,
+ )
+
+ # Have alice update their device list
+ channel = self.make_request(
+ "PUT",
+ f"/devices/{test_device_id}",
+ {
+ "display_name": "New Device Name",
+ },
+ access_token=alice_access_token,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # Check that bob's incremental sync does not contain the updated device list.
+ bob_sync_channel.await_result()
+ self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body)
+
+ changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get(
+ "changed", []
+ )
+ self.assertNotIn(
+ alice_user_id, changed_device_lists, bob_sync_channel.json_body
+ )
+
def test_user_with_no_rooms_receives_self_device_list_updates(self) -> None:
"""Tests that a user with no rooms still receives their own device list updates"""
- device_id = "TESTDEVICE"
+ test_device_id = "TESTDEVICE"
# Register a user and login, creating a device
- self.user_id = self.register_user("kermit", "monkey")
- self.tok = self.login("kermit", "monkey", device_id=device_id)
+ alice_user_id = self.register_user("alice", "correcthorse")
+ alice_access_token = self.login(
+ alice_user_id, "correcthorse", device_id=test_device_id
+ )
# Request an initial sync
- channel = self.make_request("GET", "/sync", access_token=self.tok)
+ channel = self.make_request(
+ "GET", self.sync_endpoint, access_token=alice_access_token
+ )
self.assertEqual(channel.code, 200, channel.json_body)
next_batch = channel.json_body["next_batch"]
@@ -713,19 +869,19 @@ class DeviceListSyncTestCase(unittest.HomeserverTestCase):
# It won't return until something has happened
incremental_sync_channel = self.make_request(
"GET",
- f"/sync?since={next_batch}&timeout=30000",
- access_token=self.tok,
+ f"{self.sync_endpoint}?since={next_batch}&timeout=30000",
+ access_token=alice_access_token,
await_result=False,
)
# Change our device's display name
channel = self.make_request(
"PUT",
- f"devices/{device_id}",
+ f"devices/{test_device_id}",
{
"display_name": "freeze ray",
},
- access_token=self.tok,
+ access_token=alice_access_token,
)
self.assertEqual(channel.code, 200, channel.json_body)
@@ -739,7 +895,230 @@ class DeviceListSyncTestCase(unittest.HomeserverTestCase):
).get("changed", [])
self.assertIn(
- self.user_id, device_list_changes, incremental_sync_channel.json_body
+ alice_user_id, device_list_changes, incremental_sync_channel.json_body
+ )
+
+
+@parameterized_class(
+ ("sync_endpoint", "experimental_features"),
+ [
+ ("/sync", {}),
+ (
+ "/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee",
+ # Enable sliding sync
+ {"msc3575_enabled": True},
+ ),
+ ],
+)
+class DeviceOneTimeKeysSyncTestCase(unittest.HomeserverTestCase):
+ """
+ Tests regarding device one time keys (`device_one_time_keys_count`) changes.
+
+ Attributes:
+ sync_endpoint: The endpoint under test to use for syncing.
+ experimental_features: The experimental features homeserver config to use.
+ """
+
+ sync_endpoint: str
+ experimental_features: JsonDict
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ sync.register_servlets,
+ devices.register_servlets,
+ ]
+
+ def default_config(self) -> JsonDict:
+ config = super().default_config()
+ config["experimental_features"] = self.experimental_features
+ return config
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.e2e_keys_handler = hs.get_e2e_keys_handler()
+
+ def test_no_device_one_time_keys(self) -> None:
+ """
+ Tests when no one time keys set, it still has the default `signed_curve25519` in
+ `device_one_time_keys_count`
+ """
+ test_device_id = "TESTDEVICE"
+
+ alice_user_id = self.register_user("alice", "correcthorse")
+ alice_access_token = self.login(
+ alice_user_id, "correcthorse", device_id=test_device_id
+ )
+
+ # Request an initial sync
+ channel = self.make_request(
+ "GET", self.sync_endpoint, access_token=alice_access_token
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # Check for those one time key counts
+ self.assertDictEqual(
+ channel.json_body["device_one_time_keys_count"],
+ # Note that "signed_curve25519" is always returned in key count responses
+ # regardless of whether we uploaded any keys for it. This is necessary until
+ # https://github.com/matrix-org/matrix-doc/issues/3298 is fixed.
+ {"signed_curve25519": 0},
+ channel.json_body["device_one_time_keys_count"],
+ )
+
+ def test_returns_device_one_time_keys(self) -> None:
+ """
+ Tests that one time keys for the device/user are counted correctly in the `/sync`
+ response
+ """
+ test_device_id = "TESTDEVICE"
+
+ alice_user_id = self.register_user("alice", "correcthorse")
+ alice_access_token = self.login(
+ alice_user_id, "correcthorse", device_id=test_device_id
+ )
+
+ # Upload one time keys for the user/device
+ keys: JsonDict = {
+ "alg1:k1": "key1",
+ "alg2:k2": {"key": "key2", "signatures": {"k1": "sig1"}},
+ "alg2:k3": {"key": "key3"},
+ }
+ res = self.get_success(
+ self.e2e_keys_handler.upload_keys_for_user(
+ alice_user_id, test_device_id, {"one_time_keys": keys}
+ )
+ )
+ # Note that "signed_curve25519" is always returned in key count responses
+ # regardless of whether we uploaded any keys for it. This is necessary until
+ # https://github.com/matrix-org/matrix-doc/issues/3298 is fixed.
+ self.assertDictEqual(
+ res,
+ {"one_time_key_counts": {"alg1": 1, "alg2": 2, "signed_curve25519": 0}},
+ )
+
+ # Request an initial sync
+ channel = self.make_request(
+ "GET", self.sync_endpoint, access_token=alice_access_token
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # Check for those one time key counts
+ self.assertDictEqual(
+ channel.json_body["device_one_time_keys_count"],
+ {"alg1": 1, "alg2": 2, "signed_curve25519": 0},
+ channel.json_body["device_one_time_keys_count"],
+ )
+
+
+@parameterized_class(
+ ("sync_endpoint", "experimental_features"),
+ [
+ ("/sync", {}),
+ (
+ "/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee",
+ # Enable sliding sync
+ {"msc3575_enabled": True},
+ ),
+ ],
+)
+class DeviceUnusedFallbackKeySyncTestCase(unittest.HomeserverTestCase):
+ """
+ Tests regarding device one time keys (`device_unused_fallback_key_types`) changes.
+
+ Attributes:
+ sync_endpoint: The endpoint under test to use for syncing.
+ experimental_features: The experimental features homeserver config to use.
+ """
+
+ sync_endpoint: str
+ experimental_features: JsonDict
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ sync.register_servlets,
+ devices.register_servlets,
+ ]
+
+ def default_config(self) -> JsonDict:
+ config = super().default_config()
+ config["experimental_features"] = self.experimental_features
+ return config
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.store = self.hs.get_datastores().main
+ self.e2e_keys_handler = hs.get_e2e_keys_handler()
+
+ def test_no_device_unused_fallback_key(self) -> None:
+ """
+ Test when no unused fallback key is set, it just returns an empty list. The MSC
+ says "The device_unused_fallback_key_types parameter must be present if the
+ server supports fallback keys.",
+ https://github.com/matrix-org/matrix-spec-proposals/blob/54255851f642f84a4f1aaf7bc063eebe3d76752b/proposals/2732-olm-fallback-keys.md
+ """
+ test_device_id = "TESTDEVICE"
+
+ alice_user_id = self.register_user("alice", "correcthorse")
+ alice_access_token = self.login(
+ alice_user_id, "correcthorse", device_id=test_device_id
+ )
+
+ # Request an initial sync
+ channel = self.make_request(
+ "GET", self.sync_endpoint, access_token=alice_access_token
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # Check for those one time key counts
+ self.assertListEqual(
+ channel.json_body["device_unused_fallback_key_types"],
+ [],
+ channel.json_body["device_unused_fallback_key_types"],
+ )
+
+ def test_returns_device_one_time_keys(self) -> None:
+ """
+ Tests that device unused fallback key type is returned correctly in the `/sync`
+ """
+ test_device_id = "TESTDEVICE"
+
+ alice_user_id = self.register_user("alice", "correcthorse")
+ alice_access_token = self.login(
+ alice_user_id, "correcthorse", device_id=test_device_id
+ )
+
+ # We shouldn't have any unused fallback keys yet
+ res = self.get_success(
+ self.store.get_e2e_unused_fallback_key_types(alice_user_id, test_device_id)
+ )
+ self.assertEqual(res, [])
+
+ # Upload a fallback key for the user/device
+ fallback_key = {"alg1:k1": "fallback_key1"}
+ self.get_success(
+ self.e2e_keys_handler.upload_keys_for_user(
+ alice_user_id,
+ test_device_id,
+ {"fallback_keys": fallback_key},
+ )
+ )
+ # We should now have an unused alg1 key
+ fallback_res = self.get_success(
+ self.store.get_e2e_unused_fallback_key_types(alice_user_id, test_device_id)
+ )
+ self.assertEqual(fallback_res, ["alg1"], fallback_res)
+
+ # Request an initial sync
+ channel = self.make_request(
+ "GET", self.sync_endpoint, access_token=alice_access_token
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # Check for the unused fallback key types
+ self.assertListEqual(
+ channel.json_body["device_unused_fallback_key_types"],
+ ["alg1"],
+ channel.json_body["device_unused_fallback_key_types"],
)
|