summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/12321.misc1
-rwxr-xr-xsynapse/_scripts/synapse_port_db.py1
-rw-r--r--synapse/config/server.py8
-rw-r--r--synapse/handlers/device.py132
-rw-r--r--synapse/replication/slave/storage/devices.py1
-rw-r--r--synapse/storage/databases/main/__init__.py1
-rw-r--r--synapse/storage/databases/main/devices.py217
-rw-r--r--synapse/storage/schema/__init__.py1
-rw-r--r--synapse/storage/schema/main/delta/69/01device_list_oubound_by_room.sql38
-rw-r--r--tests/federation/test_federation_sender.py23
-rw-r--r--tests/storage/test_devices.py14
11 files changed, 390 insertions, 47 deletions
diff --git a/changelog.d/12321.misc b/changelog.d/12321.misc
new file mode 100644
index 0000000000..200e7c44fe
--- /dev/null
+++ b/changelog.d/12321.misc
@@ -0,0 +1 @@
+Add ground work for speeding up device list updates for users in large numbers of rooms.
diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py
index c38666da18..6324df883b 100755
--- a/synapse/_scripts/synapse_port_db.py
+++ b/synapse/_scripts/synapse_port_db.py
@@ -97,6 +97,7 @@ BOOLEAN_COLUMNS = {
     "users": ["shadow_banned"],
     "e2e_fallback_keys_json": ["used"],
     "access_tokens": ["used"],
+    "device_lists_changes_in_room": ["converted_to_destinations"],
 }
 
 
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 0f90302c95..b3a9e50752 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -680,6 +680,14 @@ class ServerConfig(Config):
             config.get("use_account_validity_in_account_status") or False
         )
 
+        # This is a temporary option that enables fully using the new
+        # `device_lists_changes_in_room` without the backwards compat code. This
+        # is primarily for testing. If enabled the server should *not* be
+        # downgraded, as it may lead to missing device list updates.
+        self.use_new_device_lists_changes_in_room = (
+            config.get("use_new_device_lists_changes_in_room") or False
+        )
+
         self.rooms_to_exclude_from_sync: List[str] = (
             config.get("exclude_rooms_from_sync") or []
         )
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index d5ccaa0c37..c710c02cf9 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -37,7 +37,10 @@ from synapse.api.errors import (
     SynapseError,
 )
 from synapse.logging.opentracing import log_kv, set_tag, trace
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import (
+    run_as_background_process,
+    wrap_as_background_process,
+)
 from synapse.types import (
     JsonDict,
     StreamToken,
@@ -278,6 +281,22 @@ class DeviceHandler(DeviceWorkerHandler):
 
         hs.get_distributor().observe("user_left_room", self.user_left_room)
 
+        # Whether `_handle_new_device_update_async` is currently processing.
+        self._handle_new_device_update_is_processing = False
+
+        # If a new device update may have happened while the loop was
+        # processing.
+        self._handle_new_device_update_new_data = False
+
+        # On start up check if there are any updates pending.
+        hs.get_reactor().callWhenRunning(self._handle_new_device_update_async)
+
+        # Used to decide if we calculate outbound pokes up front or not. By
+        # default we do to allow safely downgrading Synapse.
+        self.use_new_device_lists_changes_in_room = (
+            hs.config.server.use_new_device_lists_changes_in_room
+        )
+
     def _check_device_name_length(self, name: Optional[str]) -> None:
         """
         Checks whether a device name is longer than the maximum allowed length.
@@ -469,19 +488,26 @@ class DeviceHandler(DeviceWorkerHandler):
             # No changes to notify about, so this is a no-op.
             return
 
-        users_who_share_room = await self.store.get_users_who_share_room_with_user(
-            user_id
-        )
+        room_ids = await self.store.get_rooms_for_user(user_id)
+
+        hosts: Optional[Set[str]] = None
+        if not self.use_new_device_lists_changes_in_room:
+            hosts = set()
 
-        hosts: Set[str] = set()
-        if self.hs.is_mine_id(user_id):
-            hosts.update(get_domain_from_id(u) for u in users_who_share_room)
-            hosts.discard(self.server_name)
+            if self.hs.is_mine_id(user_id):
+                for room_id in room_ids:
+                    joined_users = await self.store.get_users_in_room(room_id)
+                    hosts.update(get_domain_from_id(u) for u in joined_users)
 
-        set_tag("target_hosts", hosts)
+                set_tag("target_hosts", hosts)
+
+                hosts.discard(self.server_name)
 
         position = await self.store.add_device_change_to_streams(
-            user_id, device_ids, list(hosts)
+            user_id,
+            device_ids,
+            hosts=hosts,
+            room_ids=room_ids,
         )
 
         if not position:
@@ -495,9 +521,12 @@ class DeviceHandler(DeviceWorkerHandler):
 
         # specify the user ID too since the user should always get their own device list
         # updates, even if they aren't in any rooms.
-        users_to_notify = users_who_share_room.union({user_id})
+        self.notifier.on_new_event(
+            "device_list_key", position, users={user_id}, rooms=room_ids
+        )
 
-        self.notifier.on_new_event("device_list_key", position, users=users_to_notify)
+        # We may need to do some processing asynchronously.
+        self._handle_new_device_update_async()
 
         if hosts:
             logger.info(
@@ -614,6 +643,85 @@ class DeviceHandler(DeviceWorkerHandler):
 
         return {"success": True}
 
+    @wrap_as_background_process("_handle_new_device_update_async")
+    async def _handle_new_device_update_async(self) -> None:
+        """Called when we have a new local device list update that we need to
+        send out over federation.
+
+        This happens in the background so as not to block the original request
+        that generated the device update.
+        """
+        if self._handle_new_device_update_is_processing:
+            self._handle_new_device_update_new_data = True
+            return
+
+        self._handle_new_device_update_is_processing = True
+
+        # The stream ID we processed previous iteration (if any), and the set of
+        # hosts we've already poked about for this update. This is so that we
+        # don't poke the same remote server about the same update repeatedly.
+        current_stream_id = None
+        hosts_already_sent_to: Set[str] = set()
+
+        try:
+            while True:
+                self._handle_new_device_update_new_data = False
+                rows = await self.store.get_uncoverted_outbound_room_pokes()
+                if not rows:
+                    # If the DB returned nothing then there is nothing left to
+                    # do, *unless* a new device list update happened during the
+                    # DB query.
+                    if self._handle_new_device_update_new_data:
+                        continue
+                    else:
+                        return
+
+                for user_id, device_id, room_id, stream_id, opentracing_context in rows:
+                    joined_user_ids = await self.store.get_users_in_room(room_id)
+                    hosts = {get_domain_from_id(u) for u in joined_user_ids}
+                    hosts.discard(self.server_name)
+
+                    # Check if we've already sent this update to some hosts
+                    if current_stream_id == stream_id:
+                        hosts -= hosts_already_sent_to
+
+                    await self.store.add_device_list_outbound_pokes(
+                        user_id=user_id,
+                        device_id=device_id,
+                        room_id=room_id,
+                        stream_id=stream_id,
+                        hosts=hosts,
+                        context=opentracing_context,
+                    )
+
+                    # Notify replication that we've updated the device list stream.
+                    self.notifier.notify_replication()
+
+                    if hosts:
+                        logger.info(
+                            "Sending device list update notif for %r to: %r",
+                            user_id,
+                            hosts,
+                        )
+                        for host in hosts:
+                            self.federation_sender.send_device_messages(
+                                host, immediate=False
+                            )
+                            log_kv(
+                                {"message": "sent device update to host", "host": host}
+                            )
+
+                    if current_stream_id != stream_id:
+                        # Clear the set of hosts we've already sent to as we're
+                        # processing a new update.
+                        hosts_already_sent_to.clear()
+
+                    hosts_already_sent_to.update(hosts)
+                    current_stream_id = stream_id
+
+        finally:
+            self._handle_new_device_update_is_processing = False
+
 
 def _update_device_from_client_ips(
     device: JsonDict, client_ips: Mapping[Tuple[str, str], Mapping[str, Any]]
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index 0ffd34f1da..f040e33bfb 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -44,6 +44,7 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
             extra_tables=[
                 ("user_signature_stream", "stream_id"),
                 ("device_lists_outbound_pokes", "stream_id"),
+                ("device_lists_changes_in_room", "stream_id"),
             ],
         )
         device_list_max = self._device_list_id_gen.get_current_token()
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 1ea0b2aa6f..cdbe3872fa 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -146,6 +146,7 @@ class DataStore(
             extra_tables=[
                 ("user_signature_stream", "stream_id"),
                 ("device_lists_outbound_pokes", "stream_id"),
+                ("device_lists_changes_in_room", "stream_id"),
             ],
         )
 
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index f08f7834d3..07eea4b3d2 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -810,6 +810,7 @@ class DeviceWorkerStore(SQLBaseStore):
                     SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
                 ) AS e
                 WHERE ? < stream_id AND stream_id <= ?
+                ORDER BY stream_id ASC
                 LIMIT ?
             """
 
@@ -1528,7 +1529,11 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         )
 
     async def add_device_change_to_streams(
-        self, user_id: str, device_ids: Collection[str], hosts: Collection[str]
+        self,
+        user_id: str,
+        device_ids: Collection[str],
+        hosts: Optional[Collection[str]],
+        room_ids: Collection[str],
     ) -> Optional[int]:
         """Persist that a user's devices have been updated, and which hosts
         (if any) should be poked.
@@ -1537,7 +1542,10 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             user_id: The ID of the user whose device changed.
             device_ids: The IDs of any changed devices. If empty, this function will
                 return None.
-            hosts: The remote destinations that should be notified of the change.
+            hosts: The remote destinations that should be notified of the change. If
+                None then the set of hosts have *not* been calculated, and will be
+                calculated later by a background task.
+            room_ids: The rooms that the user is in
 
         Returns:
             The maximum stream ID of device list updates that were added to the database, or
@@ -1546,34 +1554,62 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         if not device_ids:
             return None
 
-        async with self._device_list_id_gen.get_next_mult(
-            len(device_ids)
-        ) as stream_ids:
-            await self.db_pool.runInteraction(
-                "add_device_change_to_stream",
-                self._add_device_change_to_stream_txn,
+        context = get_active_span_text_map()
+
+        def add_device_changes_txn(
+            txn, stream_ids_for_device_change, stream_ids_for_outbound_pokes
+        ):
+            self._add_device_change_to_stream_txn(
+                txn,
                 user_id,
                 device_ids,
-                stream_ids,
+                stream_ids_for_device_change,
             )
 
-        if not hosts:
-            return stream_ids[-1]
+            self._add_device_outbound_room_poke_txn(
+                txn,
+                user_id,
+                device_ids,
+                room_ids,
+                stream_ids_for_device_change,
+                context,
+                hosts_have_been_calculated=hosts is not None,
+            )
 
-        context = get_active_span_text_map()
-        async with self._device_list_id_gen.get_next_mult(
-            len(hosts) * len(device_ids)
-        ) as stream_ids:
-            await self.db_pool.runInteraction(
-                "add_device_outbound_poke_to_stream",
-                self._add_device_outbound_poke_to_stream_txn,
+            # If the set of hosts to send to has not been calculated yet (and so
+            # `hosts` is None) or there are no `hosts` to send to, then skip
+            # trying to persist them to the DB.
+            if not hosts:
+                return
+
+            self._add_device_outbound_poke_to_stream_txn(
+                txn,
                 user_id,
                 device_ids,
                 hosts,
-                stream_ids,
+                stream_ids_for_outbound_pokes,
                 context,
             )
 
+        # `device_lists_stream` wants a stream ID per device update.
+        num_stream_ids = len(device_ids)
+
+        if hosts:
+            # `device_lists_outbound_pokes` wants a different stream ID for
+            # each row, which is a row per host per device update.
+            num_stream_ids += len(hosts) * len(device_ids)
+
+        async with self._device_list_id_gen.get_next_mult(num_stream_ids) as stream_ids:
+            stream_ids_for_device_change = stream_ids[: len(device_ids)]
+            stream_ids_for_outbound_pokes = stream_ids[len(device_ids) :]
+
+            await self.db_pool.runInteraction(
+                "add_device_change_to_stream",
+                add_device_changes_txn,
+                stream_ids_for_device_change,
+                stream_ids_for_outbound_pokes,
+            )
+
         return stream_ids[-1]
 
     def _add_device_change_to_stream_txn(
@@ -1617,7 +1653,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         user_id: str,
         device_ids: Iterable[str],
         hosts: Collection[str],
-        stream_ids: List[str],
+        stream_ids: List[int],
         context: Dict[str, str],
     ) -> None:
         for host in hosts:
@@ -1628,8 +1664,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             )
 
         now = self._clock.time_msec()
-        next_stream_id = iter(stream_ids)
+        stream_id_iterator = iter(stream_ids)
 
+        encoded_context = json_encoder.encode(context)
         self.db_pool.simple_insert_many_txn(
             txn,
             table="device_lists_outbound_pokes",
@@ -1645,16 +1682,146 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             values=[
                 (
                     destination,
-                    next(next_stream_id),
+                    next(stream_id_iterator),
                     user_id,
                     device_id,
                     False,
                     now,
-                    json_encoder.encode(context)
-                    if whitelisted_homeserver(destination)
-                    else "{}",
+                    encoded_context if whitelisted_homeserver(destination) else "{}",
                 )
                 for destination in hosts
                 for device_id in device_ids
             ],
         )
+
+    def _add_device_outbound_room_poke_txn(
+        self,
+        txn: LoggingTransaction,
+        user_id: str,
+        device_ids: Iterable[str],
+        room_ids: Collection[str],
+        stream_ids: List[str],
+        context: Dict[str, str],
+        hosts_have_been_calculated: bool,
+    ) -> None:
+        """Record the user in the room has updated their device.
+
+        Args:
+            hosts_have_been_calculated: True if `device_lists_outbound_pokes`
+                has been updated already with the updates.
+        """
+
+        # We only need to convert to outbound pokes if they are our user.
+        converted_to_destinations = (
+            hosts_have_been_calculated or not self.hs.is_mine_id(user_id)
+        )
+
+        encoded_context = json_encoder.encode(context)
+
+        # The `device_lists_changes_in_room.stream_id` column matches the
+        # corresponding `stream_id` of the update in the `device_lists_stream`
+        # table, i.e. all rows persisted for the same device update will have
+        # the same `stream_id` (but different room IDs).
+        self.db_pool.simple_insert_many_txn(
+            txn,
+            table="device_lists_changes_in_room",
+            keys=(
+                "user_id",
+                "device_id",
+                "room_id",
+                "stream_id",
+                "converted_to_destinations",
+                "opentracing_context",
+            ),
+            values=[
+                (
+                    user_id,
+                    device_id,
+                    room_id,
+                    stream_id,
+                    converted_to_destinations,
+                    encoded_context,
+                )
+                for room_id in room_ids
+                for device_id, stream_id in zip(device_ids, stream_ids)
+            ],
+        )
+
+    async def get_uncoverted_outbound_room_pokes(
+        self, limit: int = 10
+    ) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]:
+        """Get device list changes by room that have not yet been handled and
+        written to `device_lists_outbound_pokes`.
+
+        Returns:
+            A list of user ID, device ID, room ID, stream ID and optional opentracing context.
+        """
+
+        sql = """
+            SELECT user_id, device_id, room_id, stream_id, opentracing_context
+            FROM device_lists_changes_in_room
+            WHERE NOT converted_to_destinations
+            ORDER BY stream_id
+            LIMIT ?
+        """
+
+        def get_uncoverted_outbound_room_pokes_txn(txn):
+            txn.execute(sql, (limit,))
+            return txn.fetchall()
+
+        return await self.db_pool.runInteraction(
+            "get_uncoverted_outbound_room_pokes", get_uncoverted_outbound_room_pokes_txn
+        )
+
+    async def add_device_list_outbound_pokes(
+        self,
+        user_id: str,
+        device_id: str,
+        room_id: str,
+        stream_id: int,
+        hosts: Collection[str],
+        context: Optional[Dict[str, str]],
+    ) -> None:
+        """Queue the device update to be sent to the given set of hosts,
+        calculated from the room ID.
+
+        Marks the associated row in `device_lists_changes_in_room` as handled.
+        """
+
+        def add_device_list_outbound_pokes_txn(txn, stream_ids: List[int]):
+            if hosts:
+                self._add_device_outbound_poke_to_stream_txn(
+                    txn,
+                    user_id=user_id,
+                    device_ids=[device_id],
+                    hosts=hosts,
+                    stream_ids=stream_ids,
+                    context=context,
+                )
+
+            self.db_pool.simple_update_txn(
+                txn,
+                table="device_lists_changes_in_room",
+                keyvalues={
+                    "user_id": user_id,
+                    "device_id": device_id,
+                    "stream_id": stream_id,
+                    "room_id": room_id,
+                },
+                updatevalues={"converted_to_destinations": True},
+            )
+
+        if not hosts:
+            # If there are no hosts then we don't try and generate stream IDs.
+            return await self.db_pool.runInteraction(
+                "add_device_list_outbound_pokes",
+                add_device_list_outbound_pokes_txn,
+                [],
+            )
+
+        async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids:
+            return await self.db_pool.runInteraction(
+                "add_device_list_outbound_pokes",
+                add_device_list_outbound_pokes_txn,
+                stream_ids,
+            )
diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py
index ea900e0f3d..151f2aa9bb 100644
--- a/synapse/storage/schema/__init__.py
+++ b/synapse/storage/schema/__init__.py
@@ -60,6 +60,7 @@ Changes in SCHEMA_VERSION = 68:
       new events.
 
 Changes in SCHEMA_VERSION = 69:
+    - We now write to `device_lists_changes_in_room` table.
     - Use sequence to generate future `application_services_txns.txn_id`s
 """
 
diff --git a/synapse/storage/schema/main/delta/69/01device_list_oubound_by_room.sql b/synapse/storage/schema/main/delta/69/01device_list_oubound_by_room.sql
new file mode 100644
index 0000000000..b5b1782b2a
--- /dev/null
+++ b/synapse/storage/schema/main/delta/69/01device_list_oubound_by_room.sql
@@ -0,0 +1,38 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE device_lists_changes_in_room (
+    user_id TEXT NOT NULL,
+    device_id TEXT NOT NULL,
+    room_id TEXT NOT NULL,
+
+    -- This initially matches `device_lists_stream.stream_id`. Note that we
+    -- delete older values from `device_lists_stream`, so we can't use a foreign
+    -- constraint here.
+    --
+    -- The table will contain rows with the same `stream_id` but different
+    -- `room_id`, as for each device update we store a row per room the user is
+    -- joined to. Therefore `(stream_id, room_id)` gives a unique index.
+    stream_id BIGINT NOT NULL,
+
+    -- We have a background process which goes through this table and converts
+    -- entries into rows in `device_lists_outbound_pokes`. Once we have processed
+    -- a row, we mark it as such by setting `converted_to_destinations=TRUE`.
+    converted_to_destinations BOOLEAN NOT NULL,
+    opentracing_context TEXT
+);
+
+CREATE UNIQUE INDEX device_lists_changes_in_stream_id ON device_lists_changes_in_room(stream_id, room_id);
+CREATE INDEX device_lists_changes_in_stream_id_unconverted ON device_lists_changes_in_room(stream_id) WHERE NOT converted_to_destinations;
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index e90592855a..a6e91956af 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -14,6 +14,7 @@
 from typing import Optional
 from unittest.mock import Mock
 
+from parameterized import parameterized_class
 from signedjson import key, sign
 from signedjson.types import BaseKey, SigningKey
 
@@ -154,6 +155,12 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
         )
 
 
+@parameterized_class(
+    [
+        {"enable_room_poke_code_path": False},
+        {"enable_room_poke_code_path": True},
+    ]
+)
 class FederationSenderDevicesTestCases(HomeserverTestCase):
     servlets = [
         admin.register_servlets,
@@ -168,17 +175,21 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
     def default_config(self):
         c = super().default_config()
         c["send_federation"] = True
+        c["use_new_device_lists_changes_in_room"] = self.enable_room_poke_code_path
         return c
 
     def prepare(self, reactor, clock, hs):
-        # stub out get_users_who_share_room_with_user so that it claims that
-        # `@user2:host2` is in the room
-        def get_users_who_share_room_with_user(user_id):
+        # stub out `get_rooms_for_user` and `get_users_in_room` so that the
+        # server thinks the user shares a room with `@user2:host2`
+        def get_rooms_for_user(user_id):
+            return defer.succeed({"!room:host1"})
+
+        hs.get_datastores().main.get_rooms_for_user = get_rooms_for_user
+
+        def get_users_in_room(room_id):
             return defer.succeed({"@user2:host2"})
 
-        hs.get_datastores().main.get_users_who_share_room_with_user = (
-            get_users_who_share_room_with_user
-        )
+        hs.get_datastores().main.get_users_in_room = get_users_in_room
 
         # whenever send_transaction is called, record the edu data
         self.edus = []
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index 21ffc5a909..d1227dd4ac 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -96,7 +96,9 @@ class DeviceStoreTestCase(HomeserverTestCase):
 
         # Add two device updates with sequential `stream_id`s
         self.get_success(
-            self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"])
+            self.store.add_device_change_to_streams(
+                "user_id", device_ids, ["somehost"], ["!some:room"]
+            )
         )
 
         # Get all device updates ever meant for this remote
@@ -122,7 +124,9 @@ class DeviceStoreTestCase(HomeserverTestCase):
             "device_id5",
         ]
         self.get_success(
-            self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"])
+            self.store.add_device_change_to_streams(
+                "user_id", device_ids, ["somehost"], ["!some:room"]
+            )
         )
 
         # Get device updates meant for this remote
@@ -144,7 +148,9 @@ class DeviceStoreTestCase(HomeserverTestCase):
         # Add some more device updates to ensure it still resumes properly
         device_ids = ["device_id6", "device_id7"]
         self.get_success(
-            self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"])
+            self.store.add_device_change_to_streams(
+                "user_id", device_ids, ["somehost"], ["!some:room"]
+            )
         )
 
         # Get the next batch of device updates
@@ -220,7 +226,7 @@ class DeviceStoreTestCase(HomeserverTestCase):
 
         self.get_success(
             self.store.add_device_change_to_streams(
-                "@user_id:test", device_ids, ["somehost"]
+                "@user_id:test", device_ids, ["somehost"], ["!some:room"]
             )
         )