summary refs log tree commit diff
diff options
context:
space:
mode:
authorEric Eastwood <eric.eastwood@beta.gouv.fr>2024-07-22 15:40:06 -0500
committerGitHub <noreply@github.com>2024-07-22 15:40:06 -0500
commitde05a642460fa04f6e279fa166f032e9ff94b4b0 (patch)
treee8dfd8ef089fc635b5e292613d4f63552bb6aa38
parentSS: Implement `$ME` support (#17469) (diff)
downloadsynapse-de05a642460fa04f6e279fa166f032e9ff94b4b0.tar.xz
Sliding Sync: Add E2EE extension (MSC3884) (#17454)
Spec: [MSC3884](https://github.com/matrix-org/matrix-spec-proposals/pull/3884)

Based on [MSC3575](https://github.com/matrix-org/matrix-spec-proposals/pull/3575): Sliding Sync
-rw-r--r--changelog.d/17454.feature1
-rw-r--r--synapse/handlers/device.py17
-rw-r--r--synapse/handlers/sliding_sync.py107
-rw-r--r--synapse/rest/client/keys.py10
-rw-r--r--synapse/rest/client/sync.py32
-rw-r--r--synapse/types/__init__.py7
-rw-r--r--synapse/types/handlers/__init__.py48
-rw-r--r--synapse/types/rest/client/__init__.py10
-rw-r--r--tests/rest/client/test_sync.py825
9 files changed, 1023 insertions, 34 deletions
diff --git a/changelog.d/17454.feature b/changelog.d/17454.feature
new file mode 100644
index 0000000000..bb088371bf
--- /dev/null
+++ b/changelog.d/17454.feature
@@ -0,0 +1 @@
+Add E2EE extension support to experimental [MSC3575](https://github.com/matrix-org/matrix-spec-proposals/pull/3575) Sliding Sync `/sync` endpoint.
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 0432d97109..4fc6fcd7ae 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -39,6 +39,7 @@ from synapse.metrics.background_process_metrics import (
 )
 from synapse.storage.databases.main.client_ips import DeviceLastConnectionInfo
 from synapse.types import (
+    DeviceListUpdates,
     JsonDict,
     JsonMapping,
     ScheduledTask,
@@ -214,7 +215,7 @@ class DeviceWorkerHandler:
     @cancellable
     async def get_user_ids_changed(
         self, user_id: str, from_token: StreamToken
-    ) -> JsonDict:
+    ) -> DeviceListUpdates:
         """Get list of users that have had the devices updated, or have newly
         joined a room, that `user_id` may be interested in.
         """
@@ -341,11 +342,19 @@ class DeviceWorkerHandler:
             possibly_joined = set()
             possibly_left = set()
 
-        result = {"changed": list(possibly_joined), "left": list(possibly_left)}
+        device_list_updates = DeviceListUpdates(
+            changed=possibly_joined,
+            left=possibly_left,
+        )
 
-        log_kv(result)
+        log_kv(
+            {
+                "changed": device_list_updates.changed,
+                "left": device_list_updates.left,
+            }
+        )
 
-        return result
+        return device_list_updates
 
     async def on_federation_query_user_devices(self, user_id: str) -> JsonDict:
         if not self.hs.is_mine(UserID.from_string(user_id)):
diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync.py
index c362afa6e2..886d7c7159 100644
--- a/synapse/handlers/sliding_sync.py
+++ b/synapse/handlers/sliding_sync.py
@@ -19,7 +19,18 @@
 #
 import logging
 from itertools import chain
-from typing import TYPE_CHECKING, Any, Dict, Final, List, Mapping, Optional, Set, Tuple
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    Final,
+    List,
+    Mapping,
+    Optional,
+    Sequence,
+    Set,
+    Tuple,
+)
 
 import attr
 from immutabledict import immutabledict
@@ -33,6 +44,7 @@ from synapse.storage.databases.main.roommember import extract_heroes_from_room_s
 from synapse.storage.databases.main.stream import CurrentStateDeltaMembership
 from synapse.storage.roommember import MemberSummary
 from synapse.types import (
+    DeviceListUpdates,
     JsonDict,
     PersistedEventPosition,
     Requester,
@@ -343,6 +355,7 @@ class SlidingSyncHandler:
         self.notifier = hs.get_notifier()
         self.event_sources = hs.get_event_sources()
         self.relations_handler = hs.get_relations_handler()
+        self.device_handler = hs.get_device_handler()
         self.rooms_to_exclude_globally = hs.config.server.rooms_to_exclude_from_sync
 
     async def wait_for_sync_for_user(
@@ -371,10 +384,6 @@ class SlidingSyncHandler:
         # auth_blocking will occur)
         await self.auth_blocking.check_auth_blocking(requester=requester)
 
-        # TODO: If the To-Device extension is enabled and we have a `from_token`, delete
-        # any to-device messages before that token (since we now know that the device
-        # has received them). (see sync v2 for how to do this)
-
         # If we're working with a user-provided token, we need to make sure to wait for
         # this worker to catch up with the token so we don't skip past any incoming
         # events or future events if the user is nefariously, manually modifying the
@@ -617,7 +626,9 @@ class SlidingSyncHandler:
             await concurrently_execute(handle_room, relevant_room_map, 10)
 
         extensions = await self.get_extensions_response(
-            sync_config=sync_config, to_token=to_token
+            sync_config=sync_config,
+            from_token=from_token,
+            to_token=to_token,
         )
 
         return SlidingSyncResult(
@@ -1776,33 +1787,47 @@ class SlidingSyncHandler:
         self,
         sync_config: SlidingSyncConfig,
         to_token: StreamToken,
+        from_token: Optional[StreamToken],
     ) -> SlidingSyncResult.Extensions:
         """Handle extension requests.
 
         Args:
             sync_config: Sync configuration
             to_token: The point in the stream to sync up to.
+            from_token: The point in the stream to sync from.
         """
 
         if sync_config.extensions is None:
             return SlidingSyncResult.Extensions()
 
         to_device_response = None
-        if sync_config.extensions.to_device:
-            to_device_response = await self.get_to_device_extensions_response(
+        if sync_config.extensions.to_device is not None:
+            to_device_response = await self.get_to_device_extension_response(
                 sync_config=sync_config,
                 to_device_request=sync_config.extensions.to_device,
                 to_token=to_token,
             )
 
-        return SlidingSyncResult.Extensions(to_device=to_device_response)
+        e2ee_response = None
+        if sync_config.extensions.e2ee is not None:
+            e2ee_response = await self.get_e2ee_extension_response(
+                sync_config=sync_config,
+                e2ee_request=sync_config.extensions.e2ee,
+                to_token=to_token,
+                from_token=from_token,
+            )
 
-    async def get_to_device_extensions_response(
+        return SlidingSyncResult.Extensions(
+            to_device=to_device_response,
+            e2ee=e2ee_response,
+        )
+
+    async def get_to_device_extension_response(
         self,
         sync_config: SlidingSyncConfig,
         to_device_request: SlidingSyncConfig.Extensions.ToDeviceExtension,
         to_token: StreamToken,
-    ) -> SlidingSyncResult.Extensions.ToDeviceExtension:
+    ) -> Optional[SlidingSyncResult.Extensions.ToDeviceExtension]:
         """Handle to-device extension (MSC3885)
 
         Args:
@@ -1810,14 +1835,16 @@ class SlidingSyncHandler:
             to_device_request: The to-device extension from the request
             to_token: The point in the stream to sync up to.
         """
-
         user_id = sync_config.user.to_string()
         device_id = sync_config.device_id
 
+        # Skip if the extension is not enabled
+        if not to_device_request.enabled:
+            return None
+
         # Check that this request has a valid device ID (not all requests have
-        # to belong to a device, and so device_id is None), and that the
-        # extension is enabled.
-        if device_id is None or not to_device_request.enabled:
+        # to belong to a device, and so device_id is None)
+        if device_id is None:
             return SlidingSyncResult.Extensions.ToDeviceExtension(
                 next_batch=f"{to_token.to_device_key}",
                 events=[],
@@ -1868,3 +1895,53 @@ class SlidingSyncHandler:
             next_batch=f"{stream_id}",
             events=messages,
         )
+
+    async def get_e2ee_extension_response(
+        self,
+        sync_config: SlidingSyncConfig,
+        e2ee_request: SlidingSyncConfig.Extensions.E2eeExtension,
+        to_token: StreamToken,
+        from_token: Optional[StreamToken],
+    ) -> Optional[SlidingSyncResult.Extensions.E2eeExtension]:
+        """Handle E2EE device extension (MSC3884)
+
+        Args:
+            sync_config: Sync configuration
+            e2ee_request: The e2ee extension from the request
+            to_token: The point in the stream to sync up to.
+            from_token: The point in the stream to sync from.
+        """
+        user_id = sync_config.user.to_string()
+        device_id = sync_config.device_id
+
+        # Skip if the extension is not enabled
+        if not e2ee_request.enabled:
+            return None
+
+        device_list_updates: Optional[DeviceListUpdates] = None
+        if from_token is not None:
+            # TODO: This should take into account the `from_token` and `to_token`
+            device_list_updates = await self.device_handler.get_user_ids_changed(
+                user_id=user_id,
+                from_token=from_token,
+            )
+
+        device_one_time_keys_count: Mapping[str, int] = {}
+        device_unused_fallback_key_types: Sequence[str] = []
+        if device_id:
+            # TODO: We should have a way to let clients differentiate between the states of:
+            #   * no change in OTK count since the provided since token
+            #   * the server has zero OTKs left for this device
+            #  Spec issue: https://github.com/matrix-org/matrix-doc/issues/3298
+            device_one_time_keys_count = await self.store.count_e2e_one_time_keys(
+                user_id, device_id
+            )
+            device_unused_fallback_key_types = (
+                await self.store.get_e2e_unused_fallback_key_types(user_id, device_id)
+            )
+
+        return SlidingSyncResult.Extensions.E2eeExtension(
+            device_list_updates=device_list_updates,
+            device_one_time_keys_count=device_one_time_keys_count,
+            device_unused_fallback_key_types=device_unused_fallback_key_types,
+        )
diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py
index 67de634eab..eddad7d5b8 100644
--- a/synapse/rest/client/keys.py
+++ b/synapse/rest/client/keys.py
@@ -256,9 +256,15 @@ class KeyChangesServlet(RestServlet):
 
         user_id = requester.user.to_string()
 
-        results = await self.device_handler.get_user_ids_changed(user_id, from_token)
+        device_list_updates = await self.device_handler.get_user_ids_changed(
+            user_id, from_token
+        )
+
+        response: JsonDict = {}
+        response["changed"] = list(device_list_updates.changed)
+        response["left"] = list(device_list_updates.left)
 
-        return 200, results
+        return 200, response
 
 
 class OneTimeKeyServlet(RestServlet):
diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index 1d8cbfdf00..93fe1d439e 100644
--- a/synapse/rest/client/sync.py
+++ b/synapse/rest/client/sync.py
@@ -1081,15 +1081,41 @@ class SlidingSyncRestServlet(RestServlet):
     async def encode_extensions(
         self, requester: Requester, extensions: SlidingSyncResult.Extensions
     ) -> JsonDict:
-        result = {}
+        serialized_extensions: JsonDict = {}
 
         if extensions.to_device is not None:
-            result["to_device"] = {
+            serialized_extensions["to_device"] = {
                 "next_batch": extensions.to_device.next_batch,
                 "events": extensions.to_device.events,
             }
 
-        return result
+        if extensions.e2ee is not None:
+            serialized_extensions["e2ee"] = {
+                # We always include this because
+                # https://github.com/vector-im/element-android/issues/3725. The spec
+                # isn't terribly clear on when this can be omitted and how a client
+                # would tell the difference between "no keys present" and "nothing
+                # changed" in terms of whole field absent / individual key type entry
+                # absent Corresponding synapse issue:
+                # https://github.com/matrix-org/synapse/issues/10456
+                "device_one_time_keys_count": extensions.e2ee.device_one_time_keys_count,
+                # https://github.com/matrix-org/matrix-doc/blob/54255851f642f84a4f1aaf7bc063eebe3d76752b/proposals/2732-olm-fallback-keys.md
+                # states that this field should always be included, as long as the
+                # server supports the feature.
+                "device_unused_fallback_key_types": extensions.e2ee.device_unused_fallback_key_types,
+            }
+
+            if extensions.e2ee.device_list_updates is not None:
+                serialized_extensions["e2ee"]["device_lists"] = {}
+
+                serialized_extensions["e2ee"]["device_lists"]["changed"] = list(
+                    extensions.e2ee.device_list_updates.changed
+                )
+                serialized_extensions["e2ee"]["device_lists"]["left"] = list(
+                    extensions.e2ee.device_list_updates.left
+                )
+
+        return serialized_extensions
 
 
 def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py
index 3962ecc996..046cdc29cd 100644
--- a/synapse/types/__init__.py
+++ b/synapse/types/__init__.py
@@ -1219,11 +1219,12 @@ class ReadReceipt:
 @attr.s(slots=True, frozen=True, auto_attribs=True)
 class DeviceListUpdates:
     """
-    An object containing a diff of information regarding other users' device lists, intended for
-    a recipient to carry out device list tracking.
+    An object containing a diff of information regarding other users' device lists,
+    intended for a recipient to carry out device list tracking.
 
     Attributes:
-        changed: A set of users whose device lists have changed recently.
+        changed: A set of users who have updated their device identity or
+            cross-signing keys, or who now share an encrypted room with.
         left: A set of users who the recipient no longer needs to track the device lists of.
             Typically when those users no longer share any end-to-end encryption enabled rooms.
     """
diff --git a/synapse/types/handlers/__init__.py b/synapse/types/handlers/__init__.py
index 409120470a..4c6c42db04 100644
--- a/synapse/types/handlers/__init__.py
+++ b/synapse/types/handlers/__init__.py
@@ -18,7 +18,7 @@
 #
 #
 from enum import Enum
-from typing import TYPE_CHECKING, Dict, Final, List, Optional, Sequence, Tuple
+from typing import TYPE_CHECKING, Dict, Final, List, Mapping, Optional, Sequence, Tuple
 
 import attr
 from typing_extensions import TypedDict
@@ -31,7 +31,7 @@ else:
     from pydantic import Extra
 
 from synapse.events import EventBase
-from synapse.types import JsonDict, JsonMapping, StreamToken, UserID
+from synapse.types import DeviceListUpdates, JsonDict, JsonMapping, StreamToken, UserID
 from synapse.types.rest.client import SlidingSyncBody
 
 if TYPE_CHECKING:
@@ -264,6 +264,7 @@ class SlidingSyncResult:
 
         Attributes:
             to_device: The to-device extension (MSC3885)
+            e2ee: The E2EE device extension (MSC3884)
         """
 
         @attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -282,10 +283,51 @@ class SlidingSyncResult:
             def __bool__(self) -> bool:
                 return bool(self.events)
 
+        @attr.s(slots=True, frozen=True, auto_attribs=True)
+        class E2eeExtension:
+            """The E2EE device extension (MSC3884)
+
+            Attributes:
+                device_list_updates: List of user_ids whose devices have changed or left (only
+                    present on incremental syncs).
+                device_one_time_keys_count: Map from key algorithm to the number of
+                    unclaimed one-time keys currently held on the server for this device. If
+                    an algorithm is unlisted, the count for that algorithm is assumed to be
+                    zero. If this entire parameter is missing, the count for all algorithms
+                    is assumed to be zero.
+                device_unused_fallback_key_types: List of unused fallback key algorithms
+                    for this device.
+            """
+
+            # Only present on incremental syncs
+            device_list_updates: Optional[DeviceListUpdates]
+            device_one_time_keys_count: Mapping[str, int]
+            device_unused_fallback_key_types: Sequence[str]
+
+            def __bool__(self) -> bool:
+                # Note that "signed_curve25519" is always returned in key count responses
+                # regardless of whether we uploaded any keys for it. This is necessary until
+                # https://github.com/matrix-org/matrix-doc/issues/3298 is fixed.
+                #
+                # Also related:
+                # https://github.com/element-hq/element-android/issues/3725 and
+                # https://github.com/matrix-org/synapse/issues/10456
+                default_otk = self.device_one_time_keys_count.get("signed_curve25519")
+                more_than_default_otk = len(self.device_one_time_keys_count) > 1 or (
+                    default_otk is not None and default_otk > 0
+                )
+
+                return bool(
+                    more_than_default_otk
+                    or self.device_list_updates
+                    or self.device_unused_fallback_key_types
+                )
+
         to_device: Optional[ToDeviceExtension] = None
+        e2ee: Optional[E2eeExtension] = None
 
         def __bool__(self) -> bool:
-            return bool(self.to_device)
+            return bool(self.to_device or self.e2ee)
 
     next_pos: StreamToken
     lists: Dict[str, SlidingWindowList]
diff --git a/synapse/types/rest/client/__init__.py b/synapse/types/rest/client/__init__.py
index dbe37bc712..f3c45a0d6a 100644
--- a/synapse/types/rest/client/__init__.py
+++ b/synapse/types/rest/client/__init__.py
@@ -313,7 +313,17 @@ class SlidingSyncBody(RequestBodyModel):
 
                 return value
 
+        class E2eeExtension(RequestBodyModel):
+            """The E2EE device extension (MSC3884)
+
+            Attributes:
+                enabled
+            """
+
+            enabled: Optional[StrictBool] = False
+
         to_device: Optional[ToDeviceExtension] = None
+        e2ee: Optional[E2eeExtension] = None
 
     # mypy workaround via https://github.com/pydantic/pydantic/issues/156#issuecomment-1130883884
     if TYPE_CHECKING:
diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py
index a88bdb5c14..2628869de6 100644
--- a/tests/rest/client/test_sync.py
+++ b/tests/rest/client/test_sync.py
@@ -59,6 +59,7 @@ from tests.federation.transport.test_knocking import (
 )
 from tests.server import FakeChannel, TimedOutException
 from tests.test_utils.event_injection import mark_event_as_partial_state
+from tests.unittest import skip_unless
 
 logger = logging.getLogger(__name__)
 
@@ -1113,12 +1114,11 @@ class DeviceUnusedFallbackKeySyncTestCase(unittest.HomeserverTestCase):
         self.assertEqual(res, [])
 
         # Upload a fallback key for the user/device
-        fallback_key = {"alg1:k1": "fallback_key1"}
         self.get_success(
             self.e2e_keys_handler.upload_keys_for_user(
                 alice_user_id,
                 test_device_id,
-                {"fallback_keys": fallback_key},
+                {"fallback_keys": {"alg1:k1": "fallback_key1"}},
             )
         )
         # We should now have an unused alg1 key
@@ -1252,6 +1252,8 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
         self.store = hs.get_datastores().main
         self.event_sources = hs.get_event_sources()
         self.storage_controllers = hs.get_storage_controllers()
+        self.account_data_handler = hs.get_account_data_handler()
+        self.notifier = hs.get_notifier()
 
     def _assertRequiredStateIncludes(
         self,
@@ -1377,6 +1379,52 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
 
         return room_id
 
+    def _bump_notifier_wait_for_events(self, user_id: str) -> None:
+        """
+        Wake-up a `notifier.wait_for_events(user_id)` call without affecting the Sliding
+        Sync results.
+        """
+        # We're expecting some new activity from this point onwards
+        from_token = self.event_sources.get_current_token()
+
+        triggered_notifier_wait_for_events = False
+
+        async def _on_new_acivity(
+            before_token: StreamToken, after_token: StreamToken
+        ) -> bool:
+            nonlocal triggered_notifier_wait_for_events
+            triggered_notifier_wait_for_events = True
+            return True
+
+        # Listen for some new activity for the user. We're just trying to confirm that
+        # our bump below actually does what we think it does (triggers new activity for
+        # the user).
+        result_awaitable = self.notifier.wait_for_events(
+            user_id,
+            1000,
+            _on_new_acivity,
+            from_token=from_token,
+        )
+
+        # Update the account data so that `notifier.wait_for_events(...)` wakes up.
+        # We're bumping account data because it won't show up in the Sliding Sync
+        # response so it won't affect whether we have results.
+        self.get_success(
+            self.account_data_handler.add_account_data_for_user(
+                user_id,
+                "org.matrix.foobarbaz",
+                {"foo": "bar"},
+            )
+        )
+
+        # Wait for our notifier result
+        self.get_success(result_awaitable)
+
+        if not triggered_notifier_wait_for_events:
+            raise AssertionError(
+                "Expected `notifier.wait_for_events(...)` to be triggered"
+            )
+
     def test_sync_list(self) -> None:
         """
         Test that room IDs show up in the Sliding Sync `lists`
@@ -1482,6 +1530,124 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
         # with because we weren't able to find anything new yet.
         self.assertEqual(channel.json_body["pos"], future_position_token_serialized)
 
+    def test_wait_for_new_data(self) -> None:
+        """
+        Test to make sure that the Sliding Sync request waits for new data to arrive.
+
+        (Only applies to incremental syncs with a `timeout` specified)
+        """
+        user1_id = self.register_user("user1", "pass")
+        user1_tok = self.login(user1_id, "pass")
+        user2_id = self.register_user("user2", "pass")
+        user2_tok = self.login(user2_id, "pass")
+
+        room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
+        self.helper.join(room_id, user1_id, tok=user1_tok)
+
+        from_token = self.event_sources.get_current_token()
+
+        # Make the Sliding Sync request
+        channel = self.make_request(
+            "POST",
+            self.sync_endpoint
+            + "?timeout=10000"
+            + f"&pos={self.get_success(from_token.to_string(self.store))}",
+            {
+                "lists": {
+                    "foo-list": {
+                        "ranges": [[0, 0]],
+                        "required_state": [],
+                        "timeline_limit": 1,
+                    }
+                }
+            },
+            access_token=user1_tok,
+            await_result=False,
+        )
+        # Block for 5 seconds to make sure we are `notifier.wait_for_events(...)`
+        with self.assertRaises(TimedOutException):
+            channel.await_result(timeout_ms=5000)
+        # Bump the room with new events to trigger new results
+        event_response1 = self.helper.send(
+            room_id, "new activity in room", tok=user1_tok
+        )
+        # Should respond before the 10 second timeout
+        channel.await_result(timeout_ms=3000)
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        # Check to make sure the new event is returned
+        self.assertEqual(
+            [
+                event["event_id"]
+                for event in channel.json_body["rooms"][room_id]["timeline"]
+            ],
+            [
+                event_response1["event_id"],
+            ],
+            channel.json_body["rooms"][room_id]["timeline"],
+        )
+
+    # TODO: Once we remove `ops`, we should be able to add a `RoomResult.__bool__` to
+    # check if there are any updates since the `from_token`.
+    @skip_unless(
+        False,
+        "Once we remove ops from the Sliding Sync response, this test should pass",
+    )
+    def test_wait_for_new_data_timeout(self) -> None:
+        """
+        Test to make sure that the Sliding Sync request waits for new data to arrive but
+        no data ever arrives so we timeout. We're also making sure that the default data
+        doesn't trigger a false-positive for new data.
+        """
+        user1_id = self.register_user("user1", "pass")
+        user1_tok = self.login(user1_id, "pass")
+        user2_id = self.register_user("user2", "pass")
+        user2_tok = self.login(user2_id, "pass")
+
+        room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
+        self.helper.join(room_id, user1_id, tok=user1_tok)
+
+        from_token = self.event_sources.get_current_token()
+
+        # Make the Sliding Sync request
+        channel = self.make_request(
+            "POST",
+            self.sync_endpoint
+            + "?timeout=10000"
+            + f"&pos={self.get_success(from_token.to_string(self.store))}",
+            {
+                "lists": {
+                    "foo-list": {
+                        "ranges": [[0, 0]],
+                        "required_state": [],
+                        "timeline_limit": 1,
+                    }
+                }
+            },
+            access_token=user1_tok,
+            await_result=False,
+        )
+        # Block for 5 seconds to make sure we are `notifier.wait_for_events(...)`
+        with self.assertRaises(TimedOutException):
+            channel.await_result(timeout_ms=5000)
+        # Wake-up `notifier.wait_for_events(...)` that will cause us test
+        # `SlidingSyncResult.__bool__` for new results.
+        self._bump_notifier_wait_for_events(user1_id)
+        # Block for a little bit more to ensure we don't see any new results.
+        with self.assertRaises(TimedOutException):
+            channel.await_result(timeout_ms=4000)
+        # Wait for the sync to complete (wait for the rest of the 10 second timeout,
+        # 5000 + 4000 + 1200 > 10000)
+        channel.await_result(timeout_ms=1200)
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        # We still see rooms because that's how Sliding Sync lists work but we reached
+        # the timeout before seeing them
+        self.assertEqual(
+            [event["event_id"] for event in channel.json_body["rooms"].keys()],
+            [room_id],
+        )
+
     def test_filter_list(self) -> None:
         """
         Test that filters apply to `lists`
@@ -1508,11 +1674,11 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
         )
 
         # Create a normal room
-        room_id = self.helper.create_room_as(user1_id, tok=user2_tok)
+        room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
         self.helper.join(room_id, user1_id, tok=user1_tok)
 
         # Create a room that user1 is invited to
-        invite_room_id = self.helper.create_room_as(user1_id, tok=user2_tok)
+        invite_room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
         self.helper.invite(invite_room_id, src=user2_id, targ=user1_id, tok=user2_tok)
 
         # Make the Sliding Sync request
@@ -4320,10 +4486,59 @@ class SlidingSyncToDeviceExtensionTestCase(unittest.HomeserverTestCase):
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.store = hs.get_datastores().main
+        self.event_sources = hs.get_event_sources()
+        self.account_data_handler = hs.get_account_data_handler()
+        self.notifier = hs.get_notifier()
         self.sync_endpoint = (
             "/_matrix/client/unstable/org.matrix.simplified_msc3575/sync"
         )
 
+    def _bump_notifier_wait_for_events(self, user_id: str) -> None:
+        """
+        Wake-up a `notifier.wait_for_events(user_id)` call without affecting the Sliding
+        Sync results.
+        """
+        # We're expecting some new activity from this point onwards
+        from_token = self.event_sources.get_current_token()
+
+        triggered_notifier_wait_for_events = False
+
+        async def _on_new_acivity(
+            before_token: StreamToken, after_token: StreamToken
+        ) -> bool:
+            nonlocal triggered_notifier_wait_for_events
+            triggered_notifier_wait_for_events = True
+            return True
+
+        # Listen for some new activity for the user. We're just trying to confirm that
+        # our bump below actually does what we think it does (triggers new activity for
+        # the user).
+        result_awaitable = self.notifier.wait_for_events(
+            user_id,
+            1000,
+            _on_new_acivity,
+            from_token=from_token,
+        )
+
+        # Update the account data so that `notifier.wait_for_events(...)` wakes up.
+        # We're bumping account data because it won't show up in the Sliding Sync
+        # response so it won't affect whether we have results.
+        self.get_success(
+            self.account_data_handler.add_account_data_for_user(
+                user_id,
+                "org.matrix.foobarbaz",
+                {"foo": "bar"},
+            )
+        )
+
+        # Wait for our notifier result
+        self.get_success(result_awaitable)
+
+        if not triggered_notifier_wait_for_events:
+            raise AssertionError(
+                "Expected `notifier.wait_for_events(...)` to be triggered"
+            )
+
     def _assert_to_device_response(
         self, channel: FakeChannel, expected_messages: List[JsonDict]
     ) -> str:
@@ -4487,3 +4702,605 @@ class SlidingSyncToDeviceExtensionTestCase(unittest.HomeserverTestCase):
             access_token=user1_tok,
         )
         self._assert_to_device_response(channel, [])
+
+    def test_wait_for_new_data(self) -> None:
+        """
+        Test to make sure that the Sliding Sync request waits for new data to arrive.
+
+        (Only applies to incremental syncs with a `timeout` specified)
+        """
+        user1_id = self.register_user("user1", "pass")
+        user1_tok = self.login(user1_id, "pass", "d1")
+        user2_id = self.register_user("u2", "pass")
+        user2_tok = self.login(user2_id, "pass", "d2")
+
+        from_token = self.event_sources.get_current_token()
+
+        # Make the Sliding Sync request
+        channel = self.make_request(
+            "POST",
+            self.sync_endpoint
+            + "?timeout=10000"
+            + f"&pos={self.get_success(from_token.to_string(self.store))}",
+            {
+                "lists": {},
+                "extensions": {
+                    "to_device": {
+                        "enabled": True,
+                    }
+                },
+            },
+            access_token=user1_tok,
+            await_result=False,
+        )
+        # Block for 5 seconds to make sure we are `notifier.wait_for_events(...)`
+        with self.assertRaises(TimedOutException):
+            channel.await_result(timeout_ms=5000)
+        # Bump the to-device messages to trigger new results
+        test_msg = {"foo": "bar"}
+        send_to_device_channel = self.make_request(
+            "PUT",
+            "/_matrix/client/r0/sendToDevice/m.test/1234",
+            content={"messages": {user1_id: {"d1": test_msg}}},
+            access_token=user2_tok,
+        )
+        self.assertEqual(
+            send_to_device_channel.code, 200, send_to_device_channel.result
+        )
+        # Should respond before the 10 second timeout
+        channel.await_result(timeout_ms=3000)
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        self._assert_to_device_response(
+            channel,
+            [{"content": test_msg, "sender": user2_id, "type": "m.test"}],
+        )
+
+    def test_wait_for_new_data_timeout(self) -> None:
+        """
+        Test to make sure that the Sliding Sync request waits for new data to arrive but
+        no data ever arrives so we timeout. We're also making sure that the default data
+        from the To-Device extension doesn't trigger a false-positive for new data.
+        """
+        user1_id = self.register_user("user1", "pass")
+        user1_tok = self.login(user1_id, "pass")
+
+        from_token = self.event_sources.get_current_token()
+
+        # Make the Sliding Sync request
+        channel = self.make_request(
+            "POST",
+            self.sync_endpoint
+            + "?timeout=10000"
+            + f"&pos={self.get_success(from_token.to_string(self.store))}",
+            {
+                "lists": {},
+                "extensions": {
+                    "to_device": {
+                        "enabled": True,
+                    }
+                },
+            },
+            access_token=user1_tok,
+            await_result=False,
+        )
+        # Block for 5 seconds to make sure we are `notifier.wait_for_events(...)`
+        with self.assertRaises(TimedOutException):
+            channel.await_result(timeout_ms=5000)
+        # Wake-up `notifier.wait_for_events(...)` that will cause us test
+        # `SlidingSyncResult.__bool__` for new results.
+        self._bump_notifier_wait_for_events(user1_id)
+        # Block for a little bit more to ensure we don't see any new results.
+        with self.assertRaises(TimedOutException):
+            channel.await_result(timeout_ms=4000)
+        # Wait for the sync to complete (wait for the rest of the 10 second timeout,
+        # 5000 + 4000 + 1200 > 10000)
+        channel.await_result(timeout_ms=1200)
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        self._assert_to_device_response(channel, [])
+
+
+class SlidingSyncE2eeExtensionTestCase(unittest.HomeserverTestCase):
+    """Tests for the e2ee sliding sync extension"""
+
+    servlets = [
+        synapse.rest.admin.register_servlets,
+        login.register_servlets,
+        room.register_servlets,
+        sync.register_servlets,
+        devices.register_servlets,
+    ]
+
+    def default_config(self) -> JsonDict:
+        config = super().default_config()
+        # Enable sliding sync
+        config["experimental_features"] = {"msc3575_enabled": True}
+        return config
+
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.store = hs.get_datastores().main
+        self.event_sources = hs.get_event_sources()
+        self.e2e_keys_handler = hs.get_e2e_keys_handler()
+        self.account_data_handler = hs.get_account_data_handler()
+        self.notifier = hs.get_notifier()
+        self.sync_endpoint = (
+            "/_matrix/client/unstable/org.matrix.simplified_msc3575/sync"
+        )
+
+    def _bump_notifier_wait_for_events(self, user_id: str) -> None:
+        """
+        Wake-up a `notifier.wait_for_events(user_id)` call without affecting the Sliding
+        Sync results.
+        """
+        # We're expecting some new activity from this point onwards
+        from_token = self.event_sources.get_current_token()
+
+        triggered_notifier_wait_for_events = False
+
+        async def _on_new_acivity(
+            before_token: StreamToken, after_token: StreamToken
+        ) -> bool:
+            nonlocal triggered_notifier_wait_for_events
+            triggered_notifier_wait_for_events = True
+            return True
+
+        # Listen for some new activity for the user. We're just trying to confirm that
+        # our bump below actually does what we think it does (triggers new activity for
+        # the user).
+        result_awaitable = self.notifier.wait_for_events(
+            user_id,
+            1000,
+            _on_new_acivity,
+            from_token=from_token,
+        )
+
+        # Update the account data so that `notifier.wait_for_events(...)` wakes up.
+        # We're bumping account data because it won't show up in the Sliding Sync
+        # response so it won't affect whether we have results.
+        self.get_success(
+            self.account_data_handler.add_account_data_for_user(
+                user_id,
+                "org.matrix.foobarbaz",
+                {"foo": "bar"},
+            )
+        )
+
+        # Wait for our notifier result
+        self.get_success(result_awaitable)
+
+        if not triggered_notifier_wait_for_events:
+            raise AssertionError(
+                "Expected `notifier.wait_for_events(...)` to be triggered"
+            )
+
+    def test_no_data_initial_sync(self) -> None:
+        """
+        Test that enabling e2ee extension works during an intitial sync, even if there
+        is no-data
+        """
+        user1_id = self.register_user("user1", "pass")
+        user1_tok = self.login(user1_id, "pass")
+
+        # Make an initial Sliding Sync request with the e2ee extension enabled
+        channel = self.make_request(
+            "POST",
+            self.sync_endpoint,
+            {
+                "lists": {},
+                "extensions": {
+                    "e2ee": {
+                        "enabled": True,
+                    }
+                },
+            },
+            access_token=user1_tok,
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        # Device list updates are only present for incremental syncs
+        self.assertIsNone(channel.json_body["extensions"]["e2ee"].get("device_lists"))
+
+        # Both of these should be present even when empty
+        self.assertEqual(
+            channel.json_body["extensions"]["e2ee"]["device_one_time_keys_count"],
+            {
+                # This is always present because of
+                # https://github.com/element-hq/element-android/issues/3725 and
+                # https://github.com/matrix-org/synapse/issues/10456
+                "signed_curve25519": 0
+            },
+        )
+        self.assertEqual(
+            channel.json_body["extensions"]["e2ee"]["device_unused_fallback_key_types"],
+            [],
+        )
+
+    def test_no_data_incremental_sync(self) -> None:
+        """
+        Test that enabling e2ee extension works during an incremental sync, even if
+        there is no-data
+        """
+        user1_id = self.register_user("user1", "pass")
+        user1_tok = self.login(user1_id, "pass")
+
+        from_token = self.event_sources.get_current_token()
+
+        # Make an incremental Sliding Sync request with the e2ee extension enabled
+        channel = self.make_request(
+            "POST",
+            self.sync_endpoint
+            + f"?pos={self.get_success(from_token.to_string(self.store))}",
+            {
+                "lists": {},
+                "extensions": {
+                    "e2ee": {
+                        "enabled": True,
+                    }
+                },
+            },
+            access_token=user1_tok,
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        # Device list shows up for incremental syncs
+        self.assertEqual(
+            channel.json_body["extensions"]["e2ee"]
+            .get("device_lists", {})
+            .get("changed"),
+            [],
+        )
+        self.assertEqual(
+            channel.json_body["extensions"]["e2ee"].get("device_lists", {}).get("left"),
+            [],
+        )
+
+        # Both of these should be present even when empty
+        self.assertEqual(
+            channel.json_body["extensions"]["e2ee"]["device_one_time_keys_count"],
+            {
+                # Note that "signed_curve25519" is always returned in key count responses
+                # regardless of whether we uploaded any keys for it. This is necessary until
+                # https://github.com/matrix-org/matrix-doc/issues/3298 is fixed.
+                #
+                # Also related:
+                # https://github.com/element-hq/element-android/issues/3725 and
+                # https://github.com/matrix-org/synapse/issues/10456
+                "signed_curve25519": 0
+            },
+        )
+        self.assertEqual(
+            channel.json_body["extensions"]["e2ee"]["device_unused_fallback_key_types"],
+            [],
+        )
+
+    def test_wait_for_new_data(self) -> None:
+        """
+        Test to make sure that the Sliding Sync request waits for new data to arrive.
+
+        (Only applies to incremental syncs with a `timeout` specified)
+        """
+        user1_id = self.register_user("user1", "pass")
+        user1_tok = self.login(user1_id, "pass")
+        user2_id = self.register_user("user2", "pass")
+        user2_tok = self.login(user2_id, "pass")
+        test_device_id = "TESTDEVICE"
+        user3_id = self.register_user("user3", "pass")
+        user3_tok = self.login(user3_id, "pass", device_id=test_device_id)
+
+        room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
+        self.helper.join(room_id, user1_id, tok=user1_tok)
+        self.helper.join(room_id, user3_id, tok=user3_tok)
+
+        from_token = self.event_sources.get_current_token()
+
+        # Make the Sliding Sync request
+        channel = self.make_request(
+            "POST",
+            self.sync_endpoint
+            + "?timeout=10000"
+            + f"&pos={self.get_success(from_token.to_string(self.store))}",
+            {
+                "lists": {},
+                "extensions": {
+                    "e2ee": {
+                        "enabled": True,
+                    }
+                },
+            },
+            access_token=user1_tok,
+            await_result=False,
+        )
+        # Block for 5 seconds to make sure we are `notifier.wait_for_events(...)`
+        with self.assertRaises(TimedOutException):
+            channel.await_result(timeout_ms=5000)
+        # Bump the device lists to trigger new results
+        # Have user3 update their device list
+        device_update_channel = self.make_request(
+            "PUT",
+            f"/devices/{test_device_id}",
+            {
+                "display_name": "New Device Name",
+            },
+            access_token=user3_tok,
+        )
+        self.assertEqual(
+            device_update_channel.code, 200, device_update_channel.json_body
+        )
+        # Should respond before the 10 second timeout
+        channel.await_result(timeout_ms=3000)
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        # We should see the device list update
+        self.assertEqual(
+            channel.json_body["extensions"]["e2ee"]
+            .get("device_lists", {})
+            .get("changed"),
+            [user3_id],
+        )
+        self.assertEqual(
+            channel.json_body["extensions"]["e2ee"].get("device_lists", {}).get("left"),
+            [],
+        )
+
+    def test_wait_for_new_data_timeout(self) -> None:
+        """
+        Test to make sure that the Sliding Sync request waits for new data to arrive but
+        no data ever arrives so we timeout. We're also making sure that the default data
+        from the E2EE extension doesn't trigger a false-positive for new data (see
+        `device_one_time_keys_count.signed_curve25519`).
+        """
+        user1_id = self.register_user("user1", "pass")
+        user1_tok = self.login(user1_id, "pass")
+
+        from_token = self.event_sources.get_current_token()
+
+        # Make the Sliding Sync request
+        channel = self.make_request(
+            "POST",
+            self.sync_endpoint
+            + "?timeout=10000"
+            + f"&pos={self.get_success(from_token.to_string(self.store))}",
+            {
+                "lists": {},
+                "extensions": {
+                    "e2ee": {
+                        "enabled": True,
+                    }
+                },
+            },
+            access_token=user1_tok,
+            await_result=False,
+        )
+        # Block for 5 seconds to make sure we are `notifier.wait_for_events(...)`
+        with self.assertRaises(TimedOutException):
+            channel.await_result(timeout_ms=5000)
+        # Wake-up `notifier.wait_for_events(...)` that will cause us test
+        # `SlidingSyncResult.__bool__` for new results.
+        self._bump_notifier_wait_for_events(user1_id)
+        # Block for a little bit more to ensure we don't see any new results.
+        with self.assertRaises(TimedOutException):
+            channel.await_result(timeout_ms=4000)
+        # Wait for the sync to complete (wait for the rest of the 10 second timeout,
+        # 5000 + 4000 + 1200 > 10000)
+        channel.await_result(timeout_ms=1200)
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        # Device lists are present for incremental syncs but empty because no device changes
+        self.assertEqual(
+            channel.json_body["extensions"]["e2ee"]
+            .get("device_lists", {})
+            .get("changed"),
+            [],
+        )
+        self.assertEqual(
+            channel.json_body["extensions"]["e2ee"].get("device_lists", {}).get("left"),
+            [],
+        )
+
+        # Both of these should be present even when empty
+        self.assertEqual(
+            channel.json_body["extensions"]["e2ee"]["device_one_time_keys_count"],
+            {
+                # Note that "signed_curve25519" is always returned in key count responses
+                # regardless of whether we uploaded any keys for it. This is necessary until
+                # https://github.com/matrix-org/matrix-doc/issues/3298 is fixed.
+                #
+                # Also related:
+                # https://github.com/element-hq/element-android/issues/3725 and
+                # https://github.com/matrix-org/synapse/issues/10456
+                "signed_curve25519": 0
+            },
+        )
+        self.assertEqual(
+            channel.json_body["extensions"]["e2ee"]["device_unused_fallback_key_types"],
+            [],
+        )
+
+    def test_device_lists(self) -> None:
+        """
+        Test that device list updates are included in the response
+        """
+        user1_id = self.register_user("user1", "pass")
+        user1_tok = self.login(user1_id, "pass")
+        user2_id = self.register_user("user2", "pass")
+        user2_tok = self.login(user2_id, "pass")
+
+        test_device_id = "TESTDEVICE"
+        user3_id = self.register_user("user3", "pass")
+        user3_tok = self.login(user3_id, "pass", device_id=test_device_id)
+
+        user4_id = self.register_user("user4", "pass")
+        user4_tok = self.login(user4_id, "pass")
+
+        room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
+        self.helper.join(room_id, user1_id, tok=user1_tok)
+        self.helper.join(room_id, user3_id, tok=user3_tok)
+        self.helper.join(room_id, user4_id, tok=user4_tok)
+
+        from_token = self.event_sources.get_current_token()
+
+        # Have user3 update their device list
+        channel = self.make_request(
+            "PUT",
+            f"/devices/{test_device_id}",
+            {
+                "display_name": "New Device Name",
+            },
+            access_token=user3_tok,
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        # User4 leaves the room
+        self.helper.leave(room_id, user4_id, tok=user4_tok)
+
+        # Make an incremental Sliding Sync request with the e2ee extension enabled
+        channel = self.make_request(
+            "POST",
+            self.sync_endpoint
+            + f"?pos={self.get_success(from_token.to_string(self.store))}",
+            {
+                "lists": {},
+                "extensions": {
+                    "e2ee": {
+                        "enabled": True,
+                    }
+                },
+            },
+            access_token=user1_tok,
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        # Device list updates show up
+        self.assertEqual(
+            channel.json_body["extensions"]["e2ee"]
+            .get("device_lists", {})
+            .get("changed"),
+            [user3_id],
+        )
+        self.assertEqual(
+            channel.json_body["extensions"]["e2ee"].get("device_lists", {}).get("left"),
+            [user4_id],
+        )
+
+    def test_device_one_time_keys_count(self) -> None:
+        """
+        Test that `device_one_time_keys_count` are included in the response
+        """
+        test_device_id = "TESTDEVICE"
+        user1_id = self.register_user("user1", "pass")
+        user1_tok = self.login(user1_id, "pass", device_id=test_device_id)
+
+        # Upload one time keys for the user/device
+        keys: JsonDict = {
+            "alg1:k1": "key1",
+            "alg2:k2": {"key": "key2", "signatures": {"k1": "sig1"}},
+            "alg2:k3": {"key": "key3"},
+        }
+        upload_keys_response = self.get_success(
+            self.e2e_keys_handler.upload_keys_for_user(
+                user1_id, test_device_id, {"one_time_keys": keys}
+            )
+        )
+        self.assertDictEqual(
+            upload_keys_response,
+            {
+                "one_time_key_counts": {
+                    "alg1": 1,
+                    "alg2": 2,
+                    # Note that "signed_curve25519" is always returned in key count responses
+                    # regardless of whether we uploaded any keys for it. This is necessary until
+                    # https://github.com/matrix-org/matrix-doc/issues/3298 is fixed.
+                    #
+                    # Also related:
+                    # https://github.com/element-hq/element-android/issues/3725 and
+                    # https://github.com/matrix-org/synapse/issues/10456
+                    "signed_curve25519": 0,
+                }
+            },
+        )
+
+        # Make a Sliding Sync request with the e2ee extension enabled
+        channel = self.make_request(
+            "POST",
+            self.sync_endpoint,
+            {
+                "lists": {},
+                "extensions": {
+                    "e2ee": {
+                        "enabled": True,
+                    }
+                },
+            },
+            access_token=user1_tok,
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        # Check for those one time key counts
+        self.assertEqual(
+            channel.json_body["extensions"]["e2ee"].get("device_one_time_keys_count"),
+            {
+                "alg1": 1,
+                "alg2": 2,
+                # Note that "signed_curve25519" is always returned in key count responses
+                # regardless of whether we uploaded any keys for it. This is necessary until
+                # https://github.com/matrix-org/matrix-doc/issues/3298 is fixed.
+                #
+                # Also related:
+                # https://github.com/element-hq/element-android/issues/3725 and
+                # https://github.com/matrix-org/synapse/issues/10456
+                "signed_curve25519": 0,
+            },
+        )
+
+    def test_device_unused_fallback_key_types(self) -> None:
+        """
+        Test that `device_unused_fallback_key_types` are included in the response
+        """
+        test_device_id = "TESTDEVICE"
+        user1_id = self.register_user("user1", "pass")
+        user1_tok = self.login(user1_id, "pass", device_id=test_device_id)
+
+        # We shouldn't have any unused fallback keys yet
+        res = self.get_success(
+            self.store.get_e2e_unused_fallback_key_types(user1_id, test_device_id)
+        )
+        self.assertEqual(res, [])
+
+        # Upload a fallback key for the user/device
+        self.get_success(
+            self.e2e_keys_handler.upload_keys_for_user(
+                user1_id,
+                test_device_id,
+                {"fallback_keys": {"alg1:k1": "fallback_key1"}},
+            )
+        )
+        # We should now have an unused alg1 key
+        fallback_res = self.get_success(
+            self.store.get_e2e_unused_fallback_key_types(user1_id, test_device_id)
+        )
+        self.assertEqual(fallback_res, ["alg1"], fallback_res)
+
+        # Make a Sliding Sync request with the e2ee extension enabled
+        channel = self.make_request(
+            "POST",
+            self.sync_endpoint,
+            {
+                "lists": {},
+                "extensions": {
+                    "e2ee": {
+                        "enabled": True,
+                    }
+                },
+            },
+            access_token=user1_tok,
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        # Check for the unused fallback key types
+        self.assertListEqual(
+            channel.json_body["extensions"]["e2ee"].get(
+                "device_unused_fallback_key_types"
+            ),
+            ["alg1"],
+        )