author     Eric Eastwood <eric.eastwood@beta.gouv.fr>  2024-05-23 12:06:16 -0500
committer  GitHub <noreply@github.com>                 2024-05-23 12:06:16 -0500
commit     c97251d5ba53b905036b3181afaa9c792777d1ff
tree       b129b1a1517b1857f3ab4ad2463fb3b80d4a5708
parent     Log exceptions when failing to auto-join new user according to the `auto_join...
Add Sliding Sync `/sync/e2ee` endpoint for To-Device messages (#17167)
This is being introduced as part of Sliding Sync but doesn't have any sliding window component. It's just a way to get E2EE events without having to sit through a big initial sync (`/sync` v2), and it keeps encryption events from getting backed up behind the main sync response, or vice versa.

Part of some Sliding Sync simplification/experimentation. See [this discussion](https://github.com/element-hq/synapse/pull/17167#discussion_r1610495866) for why it may not be as useful as we thought.

Based on:

 - https://github.com/matrix-org/matrix-spec-proposals/pull/3575
 - https://github.com/matrix-org/matrix-spec-proposals/pull/3885
 - https://github.com/matrix-org/matrix-spec-proposals/pull/3884
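For illustration, a minimal client-side sketch of an initial call to the new endpoint. The homeserver URL and access token are placeholders; the path, query parameters, and response fields are the ones added in this PR, and the server needs `msc3575_enabled` turned on under `experimental_features`:

```python
import requests  # any HTTP client works; this is just a sketch

HOMESERVER = "https://matrix.example.com"  # placeholder
ACCESS_TOKEN = "syt_..."                   # placeholder

resp = requests.get(
    f"{HOMESERVER}/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee",
    params={"timeout": 0},  # no `since` yet, so the server responds immediately
    headers={"Authorization": f"Bearer {ACCESS_TOKEN}"},
    timeout=30,
)
resp.raise_for_status()
body = resp.json()

next_batch = body["next_batch"]  # pass back as `since` on the next request
to_device_events = body.get("to_device", {}).get("events", [])
otk_counts = body["device_one_time_keys_count"]
unused_fallback_key_types = body["device_unused_fallback_key_types"]
```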
-rw-r--r--  changelog.d/17167.feature                 1
-rw-r--r--  synapse/config/experimental.py            3
-rw-r--r--  synapse/handlers/sync.py                247
-rw-r--r--  synapse/rest/client/sync.py             171
-rw-r--r--  tests/rest/client/test_devices.py       144
-rw-r--r--  tests/rest/client/test_sendtodevice.py   71
-rw-r--r--  tests/rest/client/test_sync.py          399
7 files changed, 861 insertions, 175 deletions
diff --git a/changelog.d/17167.feature b/changelog.d/17167.feature
new file mode 100644
index 0000000000..5ad31db974
--- /dev/null
+++ b/changelog.d/17167.feature
@@ -0,0 +1 @@
+Add experimental [MSC3575](https://github.com/matrix-org/matrix-spec-proposals/pull/3575) Sliding Sync `/sync/e2ee` endpoint for To-Device messages and device encryption info.
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index 749452ce93..cda7afc5c4 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -332,6 +332,9 @@ class ExperimentalConfig(Config):
         # MSC3391: Removing account data.
         self.msc3391_enabled = experimental.get("msc3391_enabled", False)
 
+        # MSC3575 (Sliding Sync API endpoints)
+        self.msc3575_enabled: bool = experimental.get("msc3575_enabled", False)
+
         # MSC3773: Thread notifications
         self.msc3773_enabled: bool = experimental.get("msc3773_enabled", False)
 
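For reference, the `/sync/e2ee` servlet added below is only registered when this flag is set. A minimal sketch (not Synapse's actual `Config` class) of what the lookup above reads, with the corresponding hypothetical homeserver.yaml shown as a comment:

```python
# experimental_features:
#   msc3575_enabled: true
experimental = {"msc3575_enabled": True}  # parsed `experimental_features` section

# Mirrors the lookup added above; defaults to False when the key is absent.
msc3575_enabled: bool = experimental.get("msc3575_enabled", False)
print(msc3575_enabled)  # True -> SlidingSyncE2eeRestServlet gets registered
```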
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index b7917a99d6..ac5bddd52f 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -28,11 +28,14 @@ from typing import (
     Dict,
     FrozenSet,
     List,
+    Literal,
     Mapping,
     Optional,
     Sequence,
     Set,
     Tuple,
+    Union,
+    overload,
 )
 
 import attr
@@ -128,6 +131,8 @@ class SyncVersion(Enum):
 
     # Traditional `/sync` endpoint
     SYNC_V2 = "sync_v2"
+    # Part of MSC3575 Sliding Sync
+    E2EE_SYNC = "e2ee_sync"
 
 
 @attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -280,6 +285,26 @@ class SyncResult:
         )
 
 
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class E2eeSyncResult:
+    """
+    Attributes:
+        next_batch: Token for the next sync
+        to_device: List of direct messages for the device.
+        device_lists: List of user_ids whose devices have changed
+        device_one_time_keys_count: Dict of algorithm to count for one time keys
+            for this device
+        device_unused_fallback_key_types: List of key types that have an unused fallback
+            key
+    """
+
+    next_batch: StreamToken
+    to_device: List[JsonDict]
+    device_lists: DeviceListUpdates
+    device_one_time_keys_count: JsonMapping
+    device_unused_fallback_key_types: List[str]
+
+
 class SyncHandler:
     def __init__(self, hs: "HomeServer"):
         self.hs_config = hs.config
@@ -322,6 +347,31 @@ class SyncHandler:
 
         self.rooms_to_exclude_globally = hs.config.server.rooms_to_exclude_from_sync
 
+    @overload
+    async def wait_for_sync_for_user(
+        self,
+        requester: Requester,
+        sync_config: SyncConfig,
+        sync_version: Literal[SyncVersion.SYNC_V2],
+        request_key: SyncRequestKey,
+        since_token: Optional[StreamToken] = None,
+        timeout: int = 0,
+        full_state: bool = False,
+    ) -> SyncResult: ...
+
+    @overload
+    async def wait_for_sync_for_user(
+        self,
+        requester: Requester,
+        sync_config: SyncConfig,
+        sync_version: Literal[SyncVersion.E2EE_SYNC],
+        request_key: SyncRequestKey,
+        since_token: Optional[StreamToken] = None,
+        timeout: int = 0,
+        full_state: bool = False,
+    ) -> E2eeSyncResult: ...
+
+    @overload
     async def wait_for_sync_for_user(
         self,
         requester: Requester,
@@ -331,7 +381,18 @@ class SyncHandler:
         since_token: Optional[StreamToken] = None,
         timeout: int = 0,
         full_state: bool = False,
-    ) -> SyncResult:
+    ) -> Union[SyncResult, E2eeSyncResult]: ...
+
+    async def wait_for_sync_for_user(
+        self,
+        requester: Requester,
+        sync_config: SyncConfig,
+        sync_version: SyncVersion,
+        request_key: SyncRequestKey,
+        since_token: Optional[StreamToken] = None,
+        timeout: int = 0,
+        full_state: bool = False,
+    ) -> Union[SyncResult, E2eeSyncResult]:
         """Get the sync for a client if we have new data for it now. Otherwise
         wait for new data to arrive on the server. If the timeout expires, then
         return an empty sync result.
@@ -344,8 +405,10 @@ class SyncHandler:
             since_token: The point in the stream to sync from.
             timeout: How long to wait for new data to arrive before giving up.
             full_state: Whether to return the full state for each room.
+
         Returns:
             When `SyncVersion.SYNC_V2`, returns a full `SyncResult`.
+            When `SyncVersion.E2EE_SYNC`, returns an `E2eeSyncResult`.
         """
         # If the user is not part of the mau group, then check that limits have
         # not been exceeded (if not part of the group by this point, almost certain
@@ -366,6 +429,29 @@ class SyncHandler:
         logger.debug("Returning sync response for %s", user_id)
         return res
 
+    @overload
+    async def _wait_for_sync_for_user(
+        self,
+        sync_config: SyncConfig,
+        sync_version: Literal[SyncVersion.SYNC_V2],
+        since_token: Optional[StreamToken],
+        timeout: int,
+        full_state: bool,
+        cache_context: ResponseCacheContext[SyncRequestKey],
+    ) -> SyncResult: ...
+
+    @overload
+    async def _wait_for_sync_for_user(
+        self,
+        sync_config: SyncConfig,
+        sync_version: Literal[SyncVersion.E2EE_SYNC],
+        since_token: Optional[StreamToken],
+        timeout: int,
+        full_state: bool,
+        cache_context: ResponseCacheContext[SyncRequestKey],
+    ) -> E2eeSyncResult: ...
+
+    @overload
     async def _wait_for_sync_for_user(
         self,
         sync_config: SyncConfig,
@@ -374,7 +460,17 @@ class SyncHandler:
         timeout: int,
         full_state: bool,
         cache_context: ResponseCacheContext[SyncRequestKey],
-    ) -> SyncResult:
+    ) -> Union[SyncResult, E2eeSyncResult]: ...
+
+    async def _wait_for_sync_for_user(
+        self,
+        sync_config: SyncConfig,
+        sync_version: SyncVersion,
+        since_token: Optional[StreamToken],
+        timeout: int,
+        full_state: bool,
+        cache_context: ResponseCacheContext[SyncRequestKey],
+    ) -> Union[SyncResult, E2eeSyncResult]:
         """The start of the machinery that produces a /sync response.
 
         See https://spec.matrix.org/v1.1/client-server-api/#syncing for full details.
@@ -417,14 +513,16 @@ class SyncHandler:
         if timeout == 0 or since_token is None or full_state:
             # we are going to return immediately, so don't bother calling
             # notifier.wait_for_events.
-            result: SyncResult = await self.current_sync_for_user(
-                sync_config, sync_version, since_token, full_state=full_state
+            result: Union[SyncResult, E2eeSyncResult] = (
+                await self.current_sync_for_user(
+                    sync_config, sync_version, since_token, full_state=full_state
+                )
             )
         else:
             # Otherwise, we wait for something to happen and report it to the user.
             async def current_sync_callback(
                 before_token: StreamToken, after_token: StreamToken
-            ) -> SyncResult:
+            ) -> Union[SyncResult, E2eeSyncResult]:
                 return await self.current_sync_for_user(
                     sync_config, sync_version, since_token
                 )
@@ -456,14 +554,43 @@ class SyncHandler:
 
         return result
 
+    @overload
+    async def current_sync_for_user(
+        self,
+        sync_config: SyncConfig,
+        sync_version: Literal[SyncVersion.SYNC_V2],
+        since_token: Optional[StreamToken] = None,
+        full_state: bool = False,
+    ) -> SyncResult: ...
+
+    @overload
+    async def current_sync_for_user(
+        self,
+        sync_config: SyncConfig,
+        sync_version: Literal[SyncVersion.E2EE_SYNC],
+        since_token: Optional[StreamToken] = None,
+        full_state: bool = False,
+    ) -> E2eeSyncResult: ...
+
+    @overload
     async def current_sync_for_user(
         self,
         sync_config: SyncConfig,
         sync_version: SyncVersion,
         since_token: Optional[StreamToken] = None,
         full_state: bool = False,
-    ) -> SyncResult:
-        """Generates the response body of a sync result, represented as a SyncResult.
+    ) -> Union[SyncResult, E2eeSyncResult]: ...
+
+    async def current_sync_for_user(
+        self,
+        sync_config: SyncConfig,
+        sync_version: SyncVersion,
+        since_token: Optional[StreamToken] = None,
+        full_state: bool = False,
+    ) -> Union[SyncResult, E2eeSyncResult]:
+        """
+        Generates the response body of a sync result, represented as a
+        `SyncResult`/`E2eeSyncResult`.
 
         This is a wrapper around `generate_sync_result` which starts an open tracing
         span to track the sync. See `generate_sync_result` for the next part of your
@@ -474,15 +601,25 @@ class SyncHandler:
             sync_version: Determines what kind of sync response to generate.
             since_token: The point in the stream to sync from.p.
             full_state: Whether to return the full state for each room.
+
         Returns:
             When `SyncVersion.SYNC_V2`, returns a full `SyncResult`.
+            When `SyncVersion.E2EE_SYNC`, returns an `E2eeSyncResult`.
         """
         with start_active_span("sync.current_sync_for_user"):
             log_kv({"since_token": since_token})
+
             # Go through the `/sync` v2 path
             if sync_version == SyncVersion.SYNC_V2:
-                sync_result: SyncResult = await self.generate_sync_result(
-                    sync_config, since_token, full_state
+                sync_result: Union[SyncResult, E2eeSyncResult] = (
+                    await self.generate_sync_result(
+                        sync_config, since_token, full_state
+                    )
+                )
+            # Go through the MSC3575 Sliding Sync `/sync/e2ee` path
+            elif sync_version == SyncVersion.E2EE_SYNC:
+                sync_result = await self.generate_e2ee_sync_result(
+                    sync_config, since_token
                 )
             else:
                 raise Exception(
@@ -1691,6 +1828,96 @@ class SyncHandler:
             next_batch=sync_result_builder.now_token,
         )
 
+    async def generate_e2ee_sync_result(
+        self,
+        sync_config: SyncConfig,
+        since_token: Optional[StreamToken] = None,
+    ) -> E2eeSyncResult:
+        """
+        Generates the response body of a MSC3575 Sliding Sync `/sync/e2ee` result.
+
+        This is represented by a `E2eeSyncResult` struct, which is built from small
+        pieces using a `SyncResultBuilder`. The `sync_result_builder` is passed as a
+        mutable ("inout") parameter to various helper functions. These retrieve and
+        process the data which forms the sync body, often writing to the
+        `sync_result_builder` to store their output.
+
+        At the end, we transfer data from the `sync_result_builder` to a new `E2eeSyncResult`
+        instance to signify that the sync calculation is complete.
+        """
+        user_id = sync_config.user.to_string()
+        app_service = self.store.get_app_service_by_user_id(user_id)
+        if app_service:
+            # We no longer support AS users using /sync directly.
+            # See https://github.com/matrix-org/matrix-doc/issues/1144
+            raise NotImplementedError()
+
+        sync_result_builder = await self.get_sync_result_builder(
+            sync_config,
+            since_token,
+            full_state=False,
+        )
+
+        # 1. Calculate `to_device` events
+        await self._generate_sync_entry_for_to_device(sync_result_builder)
+
+        # 2. Calculate `device_lists`
+        # Device list updates are sent if a since token is provided.
+        device_lists = DeviceListUpdates()
+        include_device_list_updates = bool(since_token and since_token.device_list_key)
+        if include_device_list_updates:
+            # Note that _generate_sync_entry_for_rooms sets sync_result_builder.joined, which
+            # is used in calculate_user_changes below.
+            #
+            # TODO: Running `_generate_sync_entry_for_rooms()` is a lot of work just to
+            # figure out the membership changes/derived info needed for
+            # `_generate_sync_entry_for_device_list()`. In the future, we should try to
+            # refactor this away.
+            (
+                newly_joined_rooms,
+                newly_left_rooms,
+            ) = await self._generate_sync_entry_for_rooms(sync_result_builder)
+
+            # This uses the sync_result_builder.joined which is set in
+            # `_generate_sync_entry_for_rooms`, if that didn't find any joined
+            # rooms for some reason it is a no-op.
+            (
+                newly_joined_or_invited_or_knocked_users,
+                newly_left_users,
+            ) = sync_result_builder.calculate_user_changes()
+
+            device_lists = await self._generate_sync_entry_for_device_list(
+                sync_result_builder,
+                newly_joined_rooms=newly_joined_rooms,
+                newly_joined_or_invited_or_knocked_users=newly_joined_or_invited_or_knocked_users,
+                newly_left_rooms=newly_left_rooms,
+                newly_left_users=newly_left_users,
+            )
+
+        # 3. Calculate `device_one_time_keys_count` and `device_unused_fallback_key_types`
+        device_id = sync_config.device_id
+        one_time_keys_count: JsonMapping = {}
+        unused_fallback_key_types: List[str] = []
+        if device_id:
+            # TODO: We should have a way to let clients differentiate between the states of:
+            #   * no change in OTK count since the provided since token
+            #   * the server has zero OTKs left for this device
+            #  Spec issue: https://github.com/matrix-org/matrix-doc/issues/3298
+            one_time_keys_count = await self.store.count_e2e_one_time_keys(
+                user_id, device_id
+            )
+            unused_fallback_key_types = list(
+                await self.store.get_e2e_unused_fallback_key_types(user_id, device_id)
+            )
+
+        return E2eeSyncResult(
+            to_device=sync_result_builder.to_device,
+            device_lists=device_lists,
+            device_one_time_keys_count=one_time_keys_count,
+            device_unused_fallback_key_types=unused_fallback_key_types,
+            next_batch=sync_result_builder.now_token,
+        )
+
     async def get_sync_result_builder(
         self,
         sync_config: SyncConfig,
@@ -1889,7 +2116,7 @@ class SyncHandler:
         users_that_have_changed = (
             await self._device_handler.get_device_changes_in_shared_rooms(
                 user_id,
-                sync_result_builder.joined_room_ids,
+                joined_room_ids,
                 from_token=since_token,
                 now_token=sync_result_builder.now_token,
             )
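The `@overload` stack above is what lets type checkers narrow the return type of `wait_for_sync_for_user`/`current_sync_for_user` from the `sync_version` argument. A self-contained sketch of the same pattern, using toy stand-ins rather than the real Synapse classes:

```python
from enum import Enum
from typing import Literal, Union, overload


class SyncVersion(Enum):
    SYNC_V2 = "sync_v2"
    E2EE_SYNC = "e2ee_sync"


class SyncResult: ...
class E2eeSyncResult: ...


@overload
def current_sync(version: Literal[SyncVersion.SYNC_V2]) -> SyncResult: ...
@overload
def current_sync(version: Literal[SyncVersion.E2EE_SYNC]) -> E2eeSyncResult: ...
@overload
def current_sync(version: SyncVersion) -> Union[SyncResult, E2eeSyncResult]: ...


def current_sync(version: SyncVersion) -> Union[SyncResult, E2eeSyncResult]:
    # The real work happens in the handler above; this only shows the dispatch shape.
    if version == SyncVersion.SYNC_V2:
        return SyncResult()
    elif version == SyncVersion.E2EE_SYNC:
        return E2eeSyncResult()
    raise Exception(f"Unknown sync_version: {version}")


# A checker picks the second overload here, so `result` is typed as E2eeSyncResult.
result: E2eeSyncResult = current_sync(SyncVersion.E2EE_SYNC)
```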
diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index 4a57eaf930..27ea943e31 100644
--- a/synapse/rest/client/sync.py
+++ b/synapse/rest/client/sync.py
@@ -567,5 +567,176 @@ class SyncRestServlet(RestServlet):
         return result
 
 
+class SlidingSyncE2eeRestServlet(RestServlet):
+    """
+    API endpoint for MSC3575 Sliding Sync `/sync/e2ee`. This is being introduced as part
+    of Sliding Sync but doesn't have any sliding window component. It's just a way to
+    get E2EE events without having to sit through a big initial sync (`/sync` v2). And
+    we can avoid encryption events being backed up by the main sync response.
+
+    Having To-Device messages split out to this sync endpoint also helps when clients
+    need to have 2 or more sync streams open at a time, e.g. a push notification process
+    and a main process. This can cause the two processes to race to fetch the To-Device
+    events, resulting in the need for complex synchronisation rules to ensure the token
+    is correctly and atomically exchanged between processes.
+
+    GET parameters::
+        timeout(int): How long to wait for new events in milliseconds.
+        since(batch_token): Batch token when asking for incremental deltas.
+
+    Response JSON::
+        {
+            "next_batch": // batch token for the next /sync
+            "to_device": {
+                // list of to-device events
+                "events": [
+                    {
+                        "content": { "algorithm": "m.olm.v1.curve25519-aes-sha2", "ciphertext": { ... }, "org.matrix.msgid": "abcd", "session_id": "abcd" },
+                        "type": "m.room.encrypted",
+                        "sender": "@alice:example.com",
+                    }
+                    // ...
+                ]
+            },
+            "device_lists": {
+                "changed": ["@alice:example.com"],
+                "left": ["@bob:example.com"]
+            },
+            "device_one_time_keys_count": {
+                "signed_curve25519": 50
+            },
+            "device_unused_fallback_key_types": [
+                "signed_curve25519"
+            ]
+        }
+    """
+
+    PATTERNS = client_patterns(
+        "/org.matrix.msc3575/sync/e2ee$", releases=[], v1=False, unstable=True
+    )
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__()
+        self.hs = hs
+        self.auth = hs.get_auth()
+        self.store = hs.get_datastores().main
+        self.sync_handler = hs.get_sync_handler()
+
+        # Filtering only matters for the `device_lists` because it requires a bunch of
+        # derived information from rooms (see how `_generate_sync_entry_for_rooms()`
+        # prepares a bunch of data for `_generate_sync_entry_for_device_list()`).
+        self.only_member_events_filter_collection = FilterCollection(
+            self.hs,
+            {
+                "room": {
+                    # We only care about membership events for the `device_lists`.
+                    # Membership will tell us whether a user has joined/left a room and
+                    # if there are new devices to encrypt for.
+                    "timeline": {
+                        "types": ["m.room.member"],
+                    },
+                    "state": {
+                        "types": ["m.room.member"],
+                    },
+                    # We don't want any extra account_data generated because it's not
+                    # returned by this endpoint. This helps us avoid work in
+                    # `_generate_sync_entry_for_rooms()`
+                    "account_data": {
+                        "not_types": ["*"],
+                    },
+                    # We don't want any extra ephemeral data generated because it's not
+                    # returned by this endpoint. This helps us avoid work in
+                    # `_generate_sync_entry_for_rooms()`
+                    "ephemeral": {
+                        "not_types": ["*"],
+                    },
+                },
+                # We don't want any extra account_data generated because it's not
+                # returned by this endpoint. (This is just here for good measure)
+                "account_data": {
+                    "not_types": ["*"],
+                },
+                # We don't want any extra presence data generated because it's not
+                # returned by this endpoint. (This is just here for good measure)
+                "presence": {
+                    "not_types": ["*"],
+                },
+            },
+        )
+
+    async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+        requester = await self.auth.get_user_by_req(request, allow_guest=True)
+        user = requester.user
+        device_id = requester.device_id
+
+        timeout = parse_integer(request, "timeout", default=0)
+        since = parse_string(request, "since")
+
+        sync_config = SyncConfig(
+            user=user,
+            filter_collection=self.only_member_events_filter_collection,
+            is_guest=requester.is_guest,
+            device_id=device_id,
+        )
+
+        since_token = None
+        if since is not None:
+            since_token = await StreamToken.from_string(self.store, since)
+
+        # Request cache key
+        request_key = (
+            SyncVersion.E2EE_SYNC,
+            user,
+            timeout,
+            since,
+        )
+
+        # Gather data for the response
+        sync_result = await self.sync_handler.wait_for_sync_for_user(
+            requester,
+            sync_config,
+            SyncVersion.E2EE_SYNC,
+            request_key,
+            since_token=since_token,
+            timeout=timeout,
+            full_state=False,
+        )
+
+        # The client may have disconnected by now; don't bother to serialize the
+        # response if so.
+        if request._disconnected:
+            logger.info("Client has disconnected; not serializing response.")
+            return 200, {}
+
+        response: JsonDict = defaultdict(dict)
+        response["next_batch"] = await sync_result.next_batch.to_string(self.store)
+
+        if sync_result.to_device:
+            response["to_device"] = {"events": sync_result.to_device}
+
+        if sync_result.device_lists.changed:
+            response["device_lists"]["changed"] = list(sync_result.device_lists.changed)
+        if sync_result.device_lists.left:
+            response["device_lists"]["left"] = list(sync_result.device_lists.left)
+
+        # We always include this because https://github.com/vector-im/element-android/issues/3725
+        # The spec isn't terribly clear on when this can be omitted and how a client would tell
+        # the difference between "no keys present" and "nothing changed" in terms of whole field
+        # absent / individual key type entry absent
+        # Corresponding synapse issue: https://github.com/matrix-org/synapse/issues/10456
+        response["device_one_time_keys_count"] = sync_result.device_one_time_keys_count
+
+        # https://github.com/matrix-org/matrix-doc/blob/54255851f642f84a4f1aaf7bc063eebe3d76752b/proposals/2732-olm-fallback-keys.md
+        # states that this field should always be included, as long as the server supports the feature.
+        response["device_unused_fallback_key_types"] = (
+            sync_result.device_unused_fallback_key_types
+        )
+
+        return 200, response
+
+
 def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     SyncRestServlet(hs).register(http_server)
+
+    if hs.config.experimental.msc3575_enabled:
+        SlidingSyncE2eeRestServlet(hs).register(http_server)
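A rough sketch of the client loop the servlet docstring implies: long-poll with `timeout` and acknowledge delivered to-device messages by echoing `next_batch` back as `since` (homeserver URL and token are placeholders):

```python
import requests

HOMESERVER = "https://matrix.example.com"  # placeholder
ACCESS_TOKEN = "syt_..."                   # placeholder
ENDPOINT = f"{HOMESERVER}/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee"

since = None
for _ in range(3):  # a real client would loop indefinitely
    params = {"timeout": 30000}  # long-poll for up to 30s once we have a `since`
    if since is not None:
        params["since"] = since
    body = requests.get(
        ENDPOINT,
        params=params,
        headers={"Authorization": f"Bearer {ACCESS_TOKEN}"},
        timeout=60,
    ).json()

    for event in body.get("to_device", {}).get("events", []):
        print(event["type"], event["sender"])  # a real client hands these to its E2EE layer

    # Sending this back as `since` is what acknowledges (and deletes) the to-device
    # messages server-side; until then they are re-delivered, as the tests below show.
    since = body["next_batch"]
```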
diff --git a/tests/rest/client/test_devices.py b/tests/rest/client/test_devices.py
index 2b360732ac..a3ed12a38f 100644
--- a/tests/rest/client/test_devices.py
+++ b/tests/rest/client/test_devices.py
@@ -24,8 +24,8 @@ from twisted.internet.defer import ensureDeferred
 from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.errors import NotFoundError
-from synapse.rest import admin, devices, room, sync
-from synapse.rest.client import account, keys, login, register
+from synapse.rest import admin, devices, sync
+from synapse.rest.client import keys, login, register
 from synapse.server import HomeServer
 from synapse.types import JsonDict, UserID, create_requester
 from synapse.util import Clock
@@ -33,146 +33,6 @@ from synapse.util import Clock
 from tests import unittest
 
 
-class DeviceListsTestCase(unittest.HomeserverTestCase):
-    """Tests regarding device list changes."""
-
-    servlets = [
-        admin.register_servlets_for_client_rest_resource,
-        login.register_servlets,
-        register.register_servlets,
-        account.register_servlets,
-        room.register_servlets,
-        sync.register_servlets,
-        devices.register_servlets,
-    ]
-
-    def test_receiving_local_device_list_changes(self) -> None:
-        """Tests that a local users that share a room receive each other's device list
-        changes.
-        """
-        # Register two users
-        test_device_id = "TESTDEVICE"
-        alice_user_id = self.register_user("alice", "correcthorse")
-        alice_access_token = self.login(
-            alice_user_id, "correcthorse", device_id=test_device_id
-        )
-
-        bob_user_id = self.register_user("bob", "ponyponypony")
-        bob_access_token = self.login(bob_user_id, "ponyponypony")
-
-        # Create a room for them to coexist peacefully in
-        new_room_id = self.helper.create_room_as(
-            alice_user_id, is_public=True, tok=alice_access_token
-        )
-        self.assertIsNotNone(new_room_id)
-
-        # Have Bob join the room
-        self.helper.invite(
-            new_room_id, alice_user_id, bob_user_id, tok=alice_access_token
-        )
-        self.helper.join(new_room_id, bob_user_id, tok=bob_access_token)
-
-        # Now have Bob initiate an initial sync (in order to get a since token)
-        channel = self.make_request(
-            "GET",
-            "/sync",
-            access_token=bob_access_token,
-        )
-        self.assertEqual(channel.code, 200, channel.json_body)
-        next_batch_token = channel.json_body["next_batch"]
-
-        # ...and then an incremental sync. This should block until the sync stream is woken up,
-        # which we hope will happen as a result of Alice updating their device list.
-        bob_sync_channel = self.make_request(
-            "GET",
-            f"/sync?since={next_batch_token}&timeout=30000",
-            access_token=bob_access_token,
-            # Start the request, then continue on.
-            await_result=False,
-        )
-
-        # Have alice update their device list
-        channel = self.make_request(
-            "PUT",
-            f"/devices/{test_device_id}",
-            {
-                "display_name": "New Device Name",
-            },
-            access_token=alice_access_token,
-        )
-        self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
-
-        # Check that bob's incremental sync contains the updated device list.
-        # If not, the client would only receive the device list update on the
-        # *next* sync.
-        bob_sync_channel.await_result()
-        self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body)
-
-        changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get(
-            "changed", []
-        )
-        self.assertIn(alice_user_id, changed_device_lists, bob_sync_channel.json_body)
-
-    def test_not_receiving_local_device_list_changes(self) -> None:
-        """Tests a local users DO NOT receive device updates from each other if they do not
-        share a room.
-        """
-        # Register two users
-        test_device_id = "TESTDEVICE"
-        alice_user_id = self.register_user("alice", "correcthorse")
-        alice_access_token = self.login(
-            alice_user_id, "correcthorse", device_id=test_device_id
-        )
-
-        bob_user_id = self.register_user("bob", "ponyponypony")
-        bob_access_token = self.login(bob_user_id, "ponyponypony")
-
-        # These users do not share a room. They are lonely.
-
-        # Have Bob initiate an initial sync (in order to get a since token)
-        channel = self.make_request(
-            "GET",
-            "/sync",
-            access_token=bob_access_token,
-        )
-        self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
-        next_batch_token = channel.json_body["next_batch"]
-
-        # ...and then an incremental sync. This should block until the sync stream is woken up,
-        # which we hope will happen as a result of Alice updating their device list.
-        bob_sync_channel = self.make_request(
-            "GET",
-            f"/sync?since={next_batch_token}&timeout=1000",
-            access_token=bob_access_token,
-            # Start the request, then continue on.
-            await_result=False,
-        )
-
-        # Have alice update their device list
-        channel = self.make_request(
-            "PUT",
-            f"/devices/{test_device_id}",
-            {
-                "display_name": "New Device Name",
-            },
-            access_token=alice_access_token,
-        )
-        self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
-
-        # Check that bob's incremental sync does not contain the updated device list.
-        bob_sync_channel.await_result()
-        self.assertEqual(
-            bob_sync_channel.code, HTTPStatus.OK, bob_sync_channel.json_body
-        )
-
-        changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get(
-            "changed", []
-        )
-        self.assertNotIn(
-            alice_user_id, changed_device_lists, bob_sync_channel.json_body
-        )
-
-
 class DevicesTestCase(unittest.HomeserverTestCase):
     servlets = [
         admin.register_servlets,
diff --git a/tests/rest/client/test_sendtodevice.py b/tests/rest/client/test_sendtodevice.py
index 2f994ad553..5ef501c6d5 100644
--- a/tests/rest/client/test_sendtodevice.py
+++ b/tests/rest/client/test_sendtodevice.py
@@ -18,15 +18,39 @@
 # [This file includes modifications made by New Vector Limited]
 #
 #
+from parameterized import parameterized_class
 
 from synapse.api.constants import EduTypes
 from synapse.rest import admin
 from synapse.rest.client import login, sendtodevice, sync
+from synapse.types import JsonDict
 
 from tests.unittest import HomeserverTestCase, override_config
 
 
+@parameterized_class(
+    ("sync_endpoint", "experimental_features"),
+    [
+        ("/sync", {}),
+        (
+            "/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee",
+            # Enable sliding sync
+            {"msc3575_enabled": True},
+        ),
+    ],
+)
 class SendToDeviceTestCase(HomeserverTestCase):
+    """
+    Test that `/sendToDevice` messages are delivered to the users receiving them over `/sync`.
+
+    Attributes:
+        sync_endpoint: The endpoint under test to use for syncing.
+        experimental_features: The experimental features homeserver config to use.
+    """
+
+    sync_endpoint: str
+    experimental_features: JsonDict
+
     servlets = [
         admin.register_servlets,
         login.register_servlets,
@@ -34,6 +58,11 @@ class SendToDeviceTestCase(HomeserverTestCase):
         sync.register_servlets,
     ]
 
+    def default_config(self) -> JsonDict:
+        config = super().default_config()
+        config["experimental_features"] = self.experimental_features
+        return config
+
     def test_user_to_user(self) -> None:
         """A to-device message from one user to another should get delivered"""
 
@@ -54,7 +83,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
         self.assertEqual(chan.code, 200, chan.result)
 
         # check it appears
-        channel = self.make_request("GET", "/sync", access_token=user2_tok)
+        channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
         self.assertEqual(channel.code, 200, channel.result)
         expected_result = {
             "events": [
@@ -67,15 +96,19 @@ class SendToDeviceTestCase(HomeserverTestCase):
         }
         self.assertEqual(channel.json_body["to_device"], expected_result)
 
-        # it should re-appear if we do another sync
-        channel = self.make_request("GET", "/sync", access_token=user2_tok)
+        # it should re-appear if we do another sync because the to-device message is not
+        # deleted until we acknowledge it by sending a `?since=...` parameter in the
+        # next sync request corresponding to the `next_batch` value from the response.
+        channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
         self.assertEqual(channel.code, 200, channel.result)
         self.assertEqual(channel.json_body["to_device"], expected_result)
 
         # it should *not* appear if we do an incremental sync
         sync_token = channel.json_body["next_batch"]
         channel = self.make_request(
-            "GET", f"/sync?since={sync_token}", access_token=user2_tok
+            "GET",
+            f"{self.sync_endpoint}?since={sync_token}",
+            access_token=user2_tok,
         )
         self.assertEqual(channel.code, 200, channel.result)
         self.assertEqual(channel.json_body.get("to_device", {}).get("events", []), [])
@@ -99,15 +132,19 @@ class SendToDeviceTestCase(HomeserverTestCase):
             )
             self.assertEqual(chan.code, 200, chan.result)
 
-        # now sync: we should get two of the three
-        channel = self.make_request("GET", "/sync", access_token=user2_tok)
+        # now sync: we should get two of the three (because burst_count=2)
+        channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
         self.assertEqual(channel.code, 200, channel.result)
         msgs = channel.json_body["to_device"]["events"]
         self.assertEqual(len(msgs), 2)
         for i in range(2):
             self.assertEqual(
                 msgs[i],
-                {"sender": user1, "type": "m.room_key_request", "content": {"idx": i}},
+                {
+                    "sender": user1,
+                    "type": "m.room_key_request",
+                    "content": {"idx": i},
+                },
             )
         sync_token = channel.json_body["next_batch"]
 
@@ -125,7 +162,9 @@ class SendToDeviceTestCase(HomeserverTestCase):
 
         # ... which should arrive
         channel = self.make_request(
-            "GET", f"/sync?since={sync_token}", access_token=user2_tok
+            "GET",
+            f"{self.sync_endpoint}?since={sync_token}",
+            access_token=user2_tok,
         )
         self.assertEqual(channel.code, 200, channel.result)
         msgs = channel.json_body["to_device"]["events"]
@@ -159,7 +198,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
             )
 
         # now sync: we should get two of the three
-        channel = self.make_request("GET", "/sync", access_token=user2_tok)
+        channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
         self.assertEqual(channel.code, 200, channel.result)
         msgs = channel.json_body["to_device"]["events"]
         self.assertEqual(len(msgs), 2)
@@ -193,7 +232,9 @@ class SendToDeviceTestCase(HomeserverTestCase):
 
         # ... which should arrive
         channel = self.make_request(
-            "GET", f"/sync?since={sync_token}", access_token=user2_tok
+            "GET",
+            f"{self.sync_endpoint}?since={sync_token}",
+            access_token=user2_tok,
         )
         self.assertEqual(channel.code, 200, channel.result)
         msgs = channel.json_body["to_device"]["events"]
@@ -217,7 +258,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
         user2_tok = self.login("u2", "pass", "d2")
 
         # Do an initial sync
-        channel = self.make_request("GET", "/sync", access_token=user2_tok)
+        channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
         self.assertEqual(channel.code, 200, channel.result)
         sync_token = channel.json_body["next_batch"]
 
@@ -233,7 +274,9 @@ class SendToDeviceTestCase(HomeserverTestCase):
             self.assertEqual(chan.code, 200, chan.result)
 
         channel = self.make_request(
-            "GET", f"/sync?since={sync_token}&timeout=300000", access_token=user2_tok
+            "GET",
+            f"{self.sync_endpoint}?since={sync_token}&timeout=300000",
+            access_token=user2_tok,
         )
         self.assertEqual(channel.code, 200, channel.result)
         messages = channel.json_body.get("to_device", {}).get("events", [])
@@ -241,7 +284,9 @@ class SendToDeviceTestCase(HomeserverTestCase):
         sync_token = channel.json_body["next_batch"]
 
         channel = self.make_request(
-            "GET", f"/sync?since={sync_token}&timeout=300000", access_token=user2_tok
+            "GET",
+            f"{self.sync_endpoint}?since={sync_token}&timeout=300000",
+            access_token=user2_tok,
         )
         self.assertEqual(channel.code, 200, channel.result)
         messages = channel.json_body.get("to_device", {}).get("events", [])
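The test changes in this file (and in `test_sync.py` below) lean on `parameterized_class` to run one test class against both `/sync` v2 and `/sync/e2ee`. A standalone illustration of how that decorator stamps out one subclass per parameter tuple:

```python
import unittest

from parameterized import parameterized_class


@parameterized_class(
    ("sync_endpoint", "experimental_features"),
    [
        ("/sync", {}),
        (
            "/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee",
            {"msc3575_enabled": True},
        ),
    ],
)
class EndpointVariantTestCase(unittest.TestCase):
    # The decorator generates one subclass per tuple above, with these class
    # attributes filled in from that tuple.
    sync_endpoint: str
    experimental_features: dict

    def test_has_an_endpoint(self) -> None:
        self.assertTrue(self.sync_endpoint.startswith("/"))


if __name__ == "__main__":
    unittest.main()
```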
diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py
index 417a87feb2..daeb1d3ddd 100644
--- a/tests/rest/client/test_sync.py
+++ b/tests/rest/client/test_sync.py
@@ -21,7 +21,7 @@
 import json
 from typing import List
 
-from parameterized import parameterized
+from parameterized import parameterized, parameterized_class
 
 from twisted.test.proto_helpers import MemoryReactor
 
@@ -688,24 +688,180 @@ class SyncCacheTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 200, channel.json_body)
 
 
+@parameterized_class(
+    ("sync_endpoint", "experimental_features"),
+    [
+        ("/sync", {}),
+        (
+            "/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee",
+            # Enable sliding sync
+            {"msc3575_enabled": True},
+        ),
+    ],
+)
 class DeviceListSyncTestCase(unittest.HomeserverTestCase):
+    """
+    Tests regarding device list (`device_lists`) changes.
+
+    Attributes:
+        sync_endpoint: The endpoint under test to use for syncing.
+        experimental_features: The experimental features homeserver config to use.
+    """
+
+    sync_endpoint: str
+    experimental_features: JsonDict
+
     servlets = [
         synapse.rest.admin.register_servlets,
         login.register_servlets,
+        room.register_servlets,
         sync.register_servlets,
         devices.register_servlets,
     ]
 
+    def default_config(self) -> JsonDict:
+        config = super().default_config()
+        config["experimental_features"] = self.experimental_features
+        return config
+
+    def test_receiving_local_device_list_changes(self) -> None:
+        """Tests that local users that share a room receive each other's device list
+        changes.
+        """
+        # Register two users
+        test_device_id = "TESTDEVICE"
+        alice_user_id = self.register_user("alice", "correcthorse")
+        alice_access_token = self.login(
+            alice_user_id, "correcthorse", device_id=test_device_id
+        )
+
+        bob_user_id = self.register_user("bob", "ponyponypony")
+        bob_access_token = self.login(bob_user_id, "ponyponypony")
+
+        # Create a room for them to coexist peacefully in
+        new_room_id = self.helper.create_room_as(
+            alice_user_id, is_public=True, tok=alice_access_token
+        )
+        self.assertIsNotNone(new_room_id)
+
+        # Have Bob join the room
+        self.helper.invite(
+            new_room_id, alice_user_id, bob_user_id, tok=alice_access_token
+        )
+        self.helper.join(new_room_id, bob_user_id, tok=bob_access_token)
+
+        # Now have Bob initiate an initial sync (in order to get a since token)
+        channel = self.make_request(
+            "GET",
+            self.sync_endpoint,
+            access_token=bob_access_token,
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+        next_batch_token = channel.json_body["next_batch"]
+
+        # ...and then an incremental sync. This should block until the sync stream is woken up,
+        # which we hope will happen as a result of Alice updating their device list.
+        bob_sync_channel = self.make_request(
+            "GET",
+            f"{self.sync_endpoint}?since={next_batch_token}&timeout=30000",
+            access_token=bob_access_token,
+            # Start the request, then continue on.
+            await_result=False,
+        )
+
+        # Have alice update their device list
+        channel = self.make_request(
+            "PUT",
+            f"/devices/{test_device_id}",
+            {
+                "display_name": "New Device Name",
+            },
+            access_token=alice_access_token,
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        # Check that bob's incremental sync contains the updated device list.
+        # If not, the client would only receive the device list update on the
+        # *next* sync.
+        bob_sync_channel.await_result()
+        self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body)
+
+        changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get(
+            "changed", []
+        )
+        self.assertIn(alice_user_id, changed_device_lists, bob_sync_channel.json_body)
+
+    def test_not_receiving_local_device_list_changes(self) -> None:
+        """Tests that local users DO NOT receive device updates from each other if they do not
+        share a room.
+        """
+        # Register two users
+        test_device_id = "TESTDEVICE"
+        alice_user_id = self.register_user("alice", "correcthorse")
+        alice_access_token = self.login(
+            alice_user_id, "correcthorse", device_id=test_device_id
+        )
+
+        bob_user_id = self.register_user("bob", "ponyponypony")
+        bob_access_token = self.login(bob_user_id, "ponyponypony")
+
+        # These users do not share a room. They are lonely.
+
+        # Have Bob initiate an initial sync (in order to get a since token)
+        channel = self.make_request(
+            "GET",
+            self.sync_endpoint,
+            access_token=bob_access_token,
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+        next_batch_token = channel.json_body["next_batch"]
+
+        # ...and then an incremental sync. This should block until the sync stream is woken up,
+        # which we hope will happen as a result of Alice updating their device list.
+        bob_sync_channel = self.make_request(
+            "GET",
+            f"{self.sync_endpoint}?since={next_batch_token}&timeout=1000",
+            access_token=bob_access_token,
+            # Start the request, then continue on.
+            await_result=False,
+        )
+
+        # Have alice update their device list
+        channel = self.make_request(
+            "PUT",
+            f"/devices/{test_device_id}",
+            {
+                "display_name": "New Device Name",
+            },
+            access_token=alice_access_token,
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        # Check that bob's incremental sync does not contain the updated device list.
+        bob_sync_channel.await_result()
+        self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body)
+
+        changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get(
+            "changed", []
+        )
+        self.assertNotIn(
+            alice_user_id, changed_device_lists, bob_sync_channel.json_body
+        )
+
     def test_user_with_no_rooms_receives_self_device_list_updates(self) -> None:
         """Tests that a user with no rooms still receives their own device list updates"""
-        device_id = "TESTDEVICE"
+        test_device_id = "TESTDEVICE"
 
         # Register a user and login, creating a device
-        self.user_id = self.register_user("kermit", "monkey")
-        self.tok = self.login("kermit", "monkey", device_id=device_id)
+        alice_user_id = self.register_user("alice", "correcthorse")
+        alice_access_token = self.login(
+            alice_user_id, "correcthorse", device_id=test_device_id
+        )
 
         # Request an initial sync
-        channel = self.make_request("GET", "/sync", access_token=self.tok)
+        channel = self.make_request(
+            "GET", self.sync_endpoint, access_token=alice_access_token
+        )
         self.assertEqual(channel.code, 200, channel.json_body)
         next_batch = channel.json_body["next_batch"]
 
@@ -713,19 +869,19 @@ class DeviceListSyncTestCase(unittest.HomeserverTestCase):
         # It won't return until something has happened
         incremental_sync_channel = self.make_request(
             "GET",
-            f"/sync?since={next_batch}&timeout=30000",
-            access_token=self.tok,
+            f"{self.sync_endpoint}?since={next_batch}&timeout=30000",
+            access_token=alice_access_token,
             await_result=False,
         )
 
         # Change our device's display name
         channel = self.make_request(
             "PUT",
-            f"devices/{device_id}",
+            f"devices/{test_device_id}",
             {
                 "display_name": "freeze ray",
             },
-            access_token=self.tok,
+            access_token=alice_access_token,
         )
         self.assertEqual(channel.code, 200, channel.json_body)
 
@@ -739,7 +895,230 @@ class DeviceListSyncTestCase(unittest.HomeserverTestCase):
         ).get("changed", [])
 
         self.assertIn(
-            self.user_id, device_list_changes, incremental_sync_channel.json_body
+            alice_user_id, device_list_changes, incremental_sync_channel.json_body
+        )
+
+
+@parameterized_class(
+    ("sync_endpoint", "experimental_features"),
+    [
+        ("/sync", {}),
+        (
+            "/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee",
+            # Enable sliding sync
+            {"msc3575_enabled": True},
+        ),
+    ],
+)
+class DeviceOneTimeKeysSyncTestCase(unittest.HomeserverTestCase):
+    """
+    Tests regarding device one time keys (`device_one_time_keys_count`) changes.
+
+    Attributes:
+        sync_endpoint: The endpoint under test to use for syncing.
+        experimental_features: The experimental features homeserver config to use.
+    """
+
+    sync_endpoint: str
+    experimental_features: JsonDict
+
+    servlets = [
+        synapse.rest.admin.register_servlets,
+        login.register_servlets,
+        sync.register_servlets,
+        devices.register_servlets,
+    ]
+
+    def default_config(self) -> JsonDict:
+        config = super().default_config()
+        config["experimental_features"] = self.experimental_features
+        return config
+
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.e2e_keys_handler = hs.get_e2e_keys_handler()
+
+    def test_no_device_one_time_keys(self) -> None:
+        """
+        Tests that when no one-time keys are set, the response still includes the default
+        `signed_curve25519` entry in `device_one_time_keys_count`
+        """
+        test_device_id = "TESTDEVICE"
+
+        alice_user_id = self.register_user("alice", "correcthorse")
+        alice_access_token = self.login(
+            alice_user_id, "correcthorse", device_id=test_device_id
+        )
+
+        # Request an initial sync
+        channel = self.make_request(
+            "GET", self.sync_endpoint, access_token=alice_access_token
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        # Check for those one time key counts
+        self.assertDictEqual(
+            channel.json_body["device_one_time_keys_count"],
+            # Note that "signed_curve25519" is always returned in key count responses
+            # regardless of whether we uploaded any keys for it. This is necessary until
+            # https://github.com/matrix-org/matrix-doc/issues/3298 is fixed.
+            {"signed_curve25519": 0},
+            channel.json_body["device_one_time_keys_count"],
+        )
+
+    def test_returns_device_one_time_keys(self) -> None:
+        """
+        Tests that one time keys for the device/user are counted correctly in the `/sync`
+        response
+        """
+        test_device_id = "TESTDEVICE"
+
+        alice_user_id = self.register_user("alice", "correcthorse")
+        alice_access_token = self.login(
+            alice_user_id, "correcthorse", device_id=test_device_id
+        )
+
+        # Upload one time keys for the user/device
+        keys: JsonDict = {
+            "alg1:k1": "key1",
+            "alg2:k2": {"key": "key2", "signatures": {"k1": "sig1"}},
+            "alg2:k3": {"key": "key3"},
+        }
+        res = self.get_success(
+            self.e2e_keys_handler.upload_keys_for_user(
+                alice_user_id, test_device_id, {"one_time_keys": keys}
+            )
+        )
+        # Note that "signed_curve25519" is always returned in key count responses
+        # regardless of whether we uploaded any keys for it. This is necessary until
+        # https://github.com/matrix-org/matrix-doc/issues/3298 is fixed.
+        self.assertDictEqual(
+            res,
+            {"one_time_key_counts": {"alg1": 1, "alg2": 2, "signed_curve25519": 0}},
+        )
+
+        # Request an initial sync
+        channel = self.make_request(
+            "GET", self.sync_endpoint, access_token=alice_access_token
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        # Check for those one time key counts
+        self.assertDictEqual(
+            channel.json_body["device_one_time_keys_count"],
+            {"alg1": 1, "alg2": 2, "signed_curve25519": 0},
+            channel.json_body["device_one_time_keys_count"],
+        )
+
+
+@parameterized_class(
+    ("sync_endpoint", "experimental_features"),
+    [
+        ("/sync", {}),
+        (
+            "/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee",
+            # Enable sliding sync
+            {"msc3575_enabled": True},
+        ),
+    ],
+)
+class DeviceUnusedFallbackKeySyncTestCase(unittest.HomeserverTestCase):
+    """
+    Tests regarding device unused fallback keys (`device_unused_fallback_key_types`) changes.
+
+    Attributes:
+        sync_endpoint: The endpoint under test to use for syncing.
+        experimental_features: The experimental features homeserver config to use.
+    """
+
+    sync_endpoint: str
+    experimental_features: JsonDict
+
+    servlets = [
+        synapse.rest.admin.register_servlets,
+        login.register_servlets,
+        sync.register_servlets,
+        devices.register_servlets,
+    ]
+
+    def default_config(self) -> JsonDict:
+        config = super().default_config()
+        config["experimental_features"] = self.experimental_features
+        return config
+
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.store = self.hs.get_datastores().main
+        self.e2e_keys_handler = hs.get_e2e_keys_handler()
+
+    def test_no_device_unused_fallback_key(self) -> None:
+        """
+        Test that when no unused fallback key is set, it just returns an empty list. The MSC
+        says "The device_unused_fallback_key_types parameter must be present if the
+        server supports fallback keys.",
+        https://github.com/matrix-org/matrix-spec-proposals/blob/54255851f642f84a4f1aaf7bc063eebe3d76752b/proposals/2732-olm-fallback-keys.md
+        """
+        test_device_id = "TESTDEVICE"
+
+        alice_user_id = self.register_user("alice", "correcthorse")
+        alice_access_token = self.login(
+            alice_user_id, "correcthorse", device_id=test_device_id
+        )
+
+        # Request an initial sync
+        channel = self.make_request(
+            "GET", self.sync_endpoint, access_token=alice_access_token
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        # Check for the unused fallback key types
+        self.assertListEqual(
+            channel.json_body["device_unused_fallback_key_types"],
+            [],
+            channel.json_body["device_unused_fallback_key_types"],
+        )
+
+    def test_returns_device_unused_fallback_key_types(self) -> None:
+        """
+        Tests that the device's unused fallback key types are returned correctly in the
+        `/sync` response
+        """
+        test_device_id = "TESTDEVICE"
+
+        alice_user_id = self.register_user("alice", "correcthorse")
+        alice_access_token = self.login(
+            alice_user_id, "correcthorse", device_id=test_device_id
+        )
+
+        # We shouldn't have any unused fallback keys yet
+        res = self.get_success(
+            self.store.get_e2e_unused_fallback_key_types(alice_user_id, test_device_id)
+        )
+        self.assertEqual(res, [])
+
+        # Upload a fallback key for the user/device
+        fallback_key = {"alg1:k1": "fallback_key1"}
+        self.get_success(
+            self.e2e_keys_handler.upload_keys_for_user(
+                alice_user_id,
+                test_device_id,
+                {"fallback_keys": fallback_key},
+            )
+        )
+        # We should now have an unused alg1 key
+        fallback_res = self.get_success(
+            self.store.get_e2e_unused_fallback_key_types(alice_user_id, test_device_id)
+        )
+        self.assertEqual(fallback_res, ["alg1"], fallback_res)
+
+        # Request an initial sync
+        channel = self.make_request(
+            "GET", self.sync_endpoint, access_token=alice_access_token
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        # Check for the unused fallback key types
+        self.assertListEqual(
+            channel.json_body["device_unused_fallback_key_types"],
+            ["alg1"],
+            channel.json_body["device_unused_fallback_key_types"],
         )