summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2022-04-04 15:25:20 +0100
committerGitHub <noreply@github.com>2022-04-04 15:25:20 +0100
commit5c9e39e6192e952ba8a5bb8e5485bc6067f91699 (patch)
tree43ad8fbd061254a87c8a93c6f3d32fa029a7cb72 /synapse
parentRemove more dead/broken dev scripts (#12355) (diff)
downloadsynapse-5c9e39e6192e952ba8a5bb8e5485bc6067f91699.tar.xz
Track device list updates per room. (#12321)
This is a first step in dealing with #7721.

The idea is basically that rather than calculating the full set of users a device list update needs to be sent to up front, we instead simply record the rooms the user was in at the time of the change. This will allow a few things:

1. we can defer calculating the set of remote servers that need to be poked about the change; and
2. during `/sync` and `/keys/changes` we can avoid also avoid calculating users who share rooms with other users, and instead just look at the rooms that have changed.

However, care needs to be taken to correctly handle server downgrades. As such this PR writes to both `device_lists_changes_in_room` and the `device_lists_outbound_pokes` table synchronously. In a future release we can then bump the database schema compat version to `69` and then we can assume that the new `device_lists_changes_in_room` exists and is handled.

There is a temporary option to disable writing to `device_lists_outbound_pokes` synchronously, allowing us to test the new code path does work (and by implication upgrading to a future release and downgrading to this one will work correctly).

Note: Ideally we'd do the calculation of room to servers on a worker (e.g. the background worker), but currently only master can write to the `device_list_outbound_pokes` table.
Diffstat (limited to 'synapse')
-rwxr-xr-xsynapse/_scripts/synapse_port_db.py1
-rw-r--r--synapse/config/server.py8
-rw-r--r--synapse/handlers/device.py132
-rw-r--r--synapse/replication/slave/storage/devices.py1
-rw-r--r--synapse/storage/databases/main/__init__.py1
-rw-r--r--synapse/storage/databases/main/devices.py217
-rw-r--r--synapse/storage/schema/__init__.py1
-rw-r--r--synapse/storage/schema/main/delta/69/01device_list_oubound_by_room.sql38
8 files changed, 362 insertions, 37 deletions
diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py
index c38666da18..6324df883b 100755
--- a/synapse/_scripts/synapse_port_db.py
+++ b/synapse/_scripts/synapse_port_db.py
@@ -97,6 +97,7 @@ BOOLEAN_COLUMNS = {
     "users": ["shadow_banned"],
     "e2e_fallback_keys_json": ["used"],
     "access_tokens": ["used"],
+    "device_lists_changes_in_room": ["converted_to_destinations"],
 }
 
 
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 0f90302c95..b3a9e50752 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -680,6 +680,14 @@ class ServerConfig(Config):
             config.get("use_account_validity_in_account_status") or False
         )
 
+        # This is a temporary option that enables fully using the new
+        # `device_lists_changes_in_room` without the backwards compat code. This
+        # is primarily for testing. If enabled the server should *not* be
+        # downgraded, as it may lead to missing device list updates.
+        self.use_new_device_lists_changes_in_room = (
+            config.get("use_new_device_lists_changes_in_room") or False
+        )
+
         self.rooms_to_exclude_from_sync: List[str] = (
             config.get("exclude_rooms_from_sync") or []
         )
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index d5ccaa0c37..c710c02cf9 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -37,7 +37,10 @@ from synapse.api.errors import (
     SynapseError,
 )
 from synapse.logging.opentracing import log_kv, set_tag, trace
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import (
+    run_as_background_process,
+    wrap_as_background_process,
+)
 from synapse.types import (
     JsonDict,
     StreamToken,
@@ -278,6 +281,22 @@ class DeviceHandler(DeviceWorkerHandler):
 
         hs.get_distributor().observe("user_left_room", self.user_left_room)
 
+        # Whether `_handle_new_device_update_async` is currently processing.
+        self._handle_new_device_update_is_processing = False
+
+        # If a new device update may have happened while the loop was
+        # processing.
+        self._handle_new_device_update_new_data = False
+
+        # On start up check if there are any updates pending.
+        hs.get_reactor().callWhenRunning(self._handle_new_device_update_async)
+
+        # Used to decide if we calculate outbound pokes up front or not. By
+        # default we do to allow safely downgrading Synapse.
+        self.use_new_device_lists_changes_in_room = (
+            hs.config.server.use_new_device_lists_changes_in_room
+        )
+
     def _check_device_name_length(self, name: Optional[str]) -> None:
         """
         Checks whether a device name is longer than the maximum allowed length.
@@ -469,19 +488,26 @@ class DeviceHandler(DeviceWorkerHandler):
             # No changes to notify about, so this is a no-op.
             return
 
-        users_who_share_room = await self.store.get_users_who_share_room_with_user(
-            user_id
-        )
+        room_ids = await self.store.get_rooms_for_user(user_id)
+
+        hosts: Optional[Set[str]] = None
+        if not self.use_new_device_lists_changes_in_room:
+            hosts = set()
 
-        hosts: Set[str] = set()
-        if self.hs.is_mine_id(user_id):
-            hosts.update(get_domain_from_id(u) for u in users_who_share_room)
-            hosts.discard(self.server_name)
+            if self.hs.is_mine_id(user_id):
+                for room_id in room_ids:
+                    joined_users = await self.store.get_users_in_room(room_id)
+                    hosts.update(get_domain_from_id(u) for u in joined_users)
 
-        set_tag("target_hosts", hosts)
+                set_tag("target_hosts", hosts)
+
+                hosts.discard(self.server_name)
 
         position = await self.store.add_device_change_to_streams(
-            user_id, device_ids, list(hosts)
+            user_id,
+            device_ids,
+            hosts=hosts,
+            room_ids=room_ids,
         )
 
         if not position:
@@ -495,9 +521,12 @@ class DeviceHandler(DeviceWorkerHandler):
 
         # specify the user ID too since the user should always get their own device list
         # updates, even if they aren't in any rooms.
-        users_to_notify = users_who_share_room.union({user_id})
+        self.notifier.on_new_event(
+            "device_list_key", position, users={user_id}, rooms=room_ids
+        )
 
-        self.notifier.on_new_event("device_list_key", position, users=users_to_notify)
+        # We may need to do some processing asynchronously.
+        self._handle_new_device_update_async()
 
         if hosts:
             logger.info(
@@ -614,6 +643,85 @@ class DeviceHandler(DeviceWorkerHandler):
 
         return {"success": True}
 
+    @wrap_as_background_process("_handle_new_device_update_async")
+    async def _handle_new_device_update_async(self) -> None:
+        """Called when we have a new local device list update that we need to
+        send out over federation.
+
+        This happens in the background so as not to block the original request
+        that generated the device update.
+        """
+        if self._handle_new_device_update_is_processing:
+            self._handle_new_device_update_new_data = True
+            return
+
+        self._handle_new_device_update_is_processing = True
+
+        # The stream ID we processed previous iteration (if any), and the set of
+        # hosts we've already poked about for this update. This is so that we
+        # don't poke the same remote server about the same update repeatedly.
+        current_stream_id = None
+        hosts_already_sent_to: Set[str] = set()
+
+        try:
+            while True:
+                self._handle_new_device_update_new_data = False
+                rows = await self.store.get_uncoverted_outbound_room_pokes()
+                if not rows:
+                    # If the DB returned nothing then there is nothing left to
+                    # do, *unless* a new device list update happened during the
+                    # DB query.
+                    if self._handle_new_device_update_new_data:
+                        continue
+                    else:
+                        return
+
+                for user_id, device_id, room_id, stream_id, opentracing_context in rows:
+                    joined_user_ids = await self.store.get_users_in_room(room_id)
+                    hosts = {get_domain_from_id(u) for u in joined_user_ids}
+                    hosts.discard(self.server_name)
+
+                    # Check if we've already sent this update to some hosts
+                    if current_stream_id == stream_id:
+                        hosts -= hosts_already_sent_to
+
+                    await self.store.add_device_list_outbound_pokes(
+                        user_id=user_id,
+                        device_id=device_id,
+                        room_id=room_id,
+                        stream_id=stream_id,
+                        hosts=hosts,
+                        context=opentracing_context,
+                    )
+
+                    # Notify replication that we've updated the device list stream.
+                    self.notifier.notify_replication()
+
+                    if hosts:
+                        logger.info(
+                            "Sending device list update notif for %r to: %r",
+                            user_id,
+                            hosts,
+                        )
+                        for host in hosts:
+                            self.federation_sender.send_device_messages(
+                                host, immediate=False
+                            )
+                            log_kv(
+                                {"message": "sent device update to host", "host": host}
+                            )
+
+                    if current_stream_id != stream_id:
+                        # Clear the set of hosts we've already sent to as we're
+                        # processing a new update.
+                        hosts_already_sent_to.clear()
+
+                    hosts_already_sent_to.update(hosts)
+                    current_stream_id = stream_id
+
+        finally:
+            self._handle_new_device_update_is_processing = False
+
 
 def _update_device_from_client_ips(
     device: JsonDict, client_ips: Mapping[Tuple[str, str], Mapping[str, Any]]
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index 0ffd34f1da..f040e33bfb 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -44,6 +44,7 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
             extra_tables=[
                 ("user_signature_stream", "stream_id"),
                 ("device_lists_outbound_pokes", "stream_id"),
+                ("device_lists_changes_in_room", "stream_id"),
             ],
         )
         device_list_max = self._device_list_id_gen.get_current_token()
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 1ea0b2aa6f..cdbe3872fa 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -146,6 +146,7 @@ class DataStore(
             extra_tables=[
                 ("user_signature_stream", "stream_id"),
                 ("device_lists_outbound_pokes", "stream_id"),
+                ("device_lists_changes_in_room", "stream_id"),
             ],
         )
 
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index f08f7834d3..07eea4b3d2 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -810,6 +810,7 @@ class DeviceWorkerStore(SQLBaseStore):
                     SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
                 ) AS e
                 WHERE ? < stream_id AND stream_id <= ?
+                ORDER BY stream_id ASC
                 LIMIT ?
             """
 
@@ -1528,7 +1529,11 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         )
 
     async def add_device_change_to_streams(
-        self, user_id: str, device_ids: Collection[str], hosts: Collection[str]
+        self,
+        user_id: str,
+        device_ids: Collection[str],
+        hosts: Optional[Collection[str]],
+        room_ids: Collection[str],
     ) -> Optional[int]:
         """Persist that a user's devices have been updated, and which hosts
         (if any) should be poked.
@@ -1537,7 +1542,10 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             user_id: The ID of the user whose device changed.
             device_ids: The IDs of any changed devices. If empty, this function will
                 return None.
-            hosts: The remote destinations that should be notified of the change.
+            hosts: The remote destinations that should be notified of the change. If
+                None then the set of hosts have *not* been calculated, and will be
+                calculated later by a background task.
+            room_ids: The rooms that the user is in
 
         Returns:
             The maximum stream ID of device list updates that were added to the database, or
@@ -1546,34 +1554,62 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         if not device_ids:
             return None
 
-        async with self._device_list_id_gen.get_next_mult(
-            len(device_ids)
-        ) as stream_ids:
-            await self.db_pool.runInteraction(
-                "add_device_change_to_stream",
-                self._add_device_change_to_stream_txn,
+        context = get_active_span_text_map()
+
+        def add_device_changes_txn(
+            txn, stream_ids_for_device_change, stream_ids_for_outbound_pokes
+        ):
+            self._add_device_change_to_stream_txn(
+                txn,
                 user_id,
                 device_ids,
-                stream_ids,
+                stream_ids_for_device_change,
             )
 
-        if not hosts:
-            return stream_ids[-1]
+            self._add_device_outbound_room_poke_txn(
+                txn,
+                user_id,
+                device_ids,
+                room_ids,
+                stream_ids_for_device_change,
+                context,
+                hosts_have_been_calculated=hosts is not None,
+            )
 
-        context = get_active_span_text_map()
-        async with self._device_list_id_gen.get_next_mult(
-            len(hosts) * len(device_ids)
-        ) as stream_ids:
-            await self.db_pool.runInteraction(
-                "add_device_outbound_poke_to_stream",
-                self._add_device_outbound_poke_to_stream_txn,
+            # If the set of hosts to send to has not been calculated yet (and so
+            # `hosts` is None) or there are no `hosts` to send to, then skip
+            # trying to persist them to the DB.
+            if not hosts:
+                return
+
+            self._add_device_outbound_poke_to_stream_txn(
+                txn,
                 user_id,
                 device_ids,
                 hosts,
-                stream_ids,
+                stream_ids_for_outbound_pokes,
                 context,
             )
 
+        # `device_lists_stream` wants a stream ID per device update.
+        num_stream_ids = len(device_ids)
+
+        if hosts:
+            # `device_lists_outbound_pokes` wants a different stream ID for
+            # each row, which is a row per host per device update.
+            num_stream_ids += len(hosts) * len(device_ids)
+
+        async with self._device_list_id_gen.get_next_mult(num_stream_ids) as stream_ids:
+            stream_ids_for_device_change = stream_ids[: len(device_ids)]
+            stream_ids_for_outbound_pokes = stream_ids[len(device_ids) :]
+
+            await self.db_pool.runInteraction(
+                "add_device_change_to_stream",
+                add_device_changes_txn,
+                stream_ids_for_device_change,
+                stream_ids_for_outbound_pokes,
+            )
+
         return stream_ids[-1]
 
     def _add_device_change_to_stream_txn(
@@ -1617,7 +1653,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         user_id: str,
         device_ids: Iterable[str],
         hosts: Collection[str],
-        stream_ids: List[str],
+        stream_ids: List[int],
         context: Dict[str, str],
     ) -> None:
         for host in hosts:
@@ -1628,8 +1664,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             )
 
         now = self._clock.time_msec()
-        next_stream_id = iter(stream_ids)
+        stream_id_iterator = iter(stream_ids)
 
+        encoded_context = json_encoder.encode(context)
         self.db_pool.simple_insert_many_txn(
             txn,
             table="device_lists_outbound_pokes",
@@ -1645,16 +1682,146 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             values=[
                 (
                     destination,
-                    next(next_stream_id),
+                    next(stream_id_iterator),
                     user_id,
                     device_id,
                     False,
                     now,
-                    json_encoder.encode(context)
-                    if whitelisted_homeserver(destination)
-                    else "{}",
+                    encoded_context if whitelisted_homeserver(destination) else "{}",
                 )
                 for destination in hosts
                 for device_id in device_ids
             ],
         )
+
+    def _add_device_outbound_room_poke_txn(
+        self,
+        txn: LoggingTransaction,
+        user_id: str,
+        device_ids: Iterable[str],
+        room_ids: Collection[str],
+        stream_ids: List[str],
+        context: Dict[str, str],
+        hosts_have_been_calculated: bool,
+    ) -> None:
+        """Record the user in the room has updated their device.
+
+        Args:
+            hosts_have_been_calculated: True if `device_lists_outbound_pokes`
+                has been updated already with the updates.
+        """
+
+        # We only need to convert to outbound pokes if they are our user.
+        converted_to_destinations = (
+            hosts_have_been_calculated or not self.hs.is_mine_id(user_id)
+        )
+
+        encoded_context = json_encoder.encode(context)
+
+        # The `device_lists_changes_in_room.stream_id` column matches the
+        # corresponding `stream_id` of the update in the `device_lists_stream`
+        # table, i.e. all rows persisted for the same device update will have
+        # the same `stream_id` (but different room IDs).
+        self.db_pool.simple_insert_many_txn(
+            txn,
+            table="device_lists_changes_in_room",
+            keys=(
+                "user_id",
+                "device_id",
+                "room_id",
+                "stream_id",
+                "converted_to_destinations",
+                "opentracing_context",
+            ),
+            values=[
+                (
+                    user_id,
+                    device_id,
+                    room_id,
+                    stream_id,
+                    converted_to_destinations,
+                    encoded_context,
+                )
+                for room_id in room_ids
+                for device_id, stream_id in zip(device_ids, stream_ids)
+            ],
+        )
+
+    async def get_uncoverted_outbound_room_pokes(
+        self, limit: int = 10
+    ) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]:
+        """Get device list changes by room that have not yet been handled and
+        written to `device_lists_outbound_pokes`.
+
+        Returns:
+            A list of user ID, device ID, room ID, stream ID and optional opentracing context.
+        """
+
+        sql = """
+            SELECT user_id, device_id, room_id, stream_id, opentracing_context
+            FROM device_lists_changes_in_room
+            WHERE NOT converted_to_destinations
+            ORDER BY stream_id
+            LIMIT ?
+        """
+
+        def get_uncoverted_outbound_room_pokes_txn(txn):
+            txn.execute(sql, (limit,))
+            return txn.fetchall()
+
+        return await self.db_pool.runInteraction(
+            "get_uncoverted_outbound_room_pokes", get_uncoverted_outbound_room_pokes_txn
+        )
+
+    async def add_device_list_outbound_pokes(
+        self,
+        user_id: str,
+        device_id: str,
+        room_id: str,
+        stream_id: int,
+        hosts: Collection[str],
+        context: Optional[Dict[str, str]],
+    ) -> None:
+        """Queue the device update to be sent to the given set of hosts,
+        calculated from the room ID.
+
+        Marks the associated row in `device_lists_changes_in_room` as handled.
+        """
+
+        def add_device_list_outbound_pokes_txn(txn, stream_ids: List[int]):
+            if hosts:
+                self._add_device_outbound_poke_to_stream_txn(
+                    txn,
+                    user_id=user_id,
+                    device_ids=[device_id],
+                    hosts=hosts,
+                    stream_ids=stream_ids,
+                    context=context,
+                )
+
+            self.db_pool.simple_update_txn(
+                txn,
+                table="device_lists_changes_in_room",
+                keyvalues={
+                    "user_id": user_id,
+                    "device_id": device_id,
+                    "stream_id": stream_id,
+                    "room_id": room_id,
+                },
+                updatevalues={"converted_to_destinations": True},
+            )
+
+        if not hosts:
+            # If there are no hosts then we don't try and generate stream IDs.
+            return await self.db_pool.runInteraction(
+                "add_device_list_outbound_pokes",
+                add_device_list_outbound_pokes_txn,
+                [],
+            )
+
+        async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids:
+            return await self.db_pool.runInteraction(
+                "add_device_list_outbound_pokes",
+                add_device_list_outbound_pokes_txn,
+                stream_ids,
+            )
diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py
index ea900e0f3d..151f2aa9bb 100644
--- a/synapse/storage/schema/__init__.py
+++ b/synapse/storage/schema/__init__.py
@@ -60,6 +60,7 @@ Changes in SCHEMA_VERSION = 68:
       new events.
 
 Changes in SCHEMA_VERSION = 69:
+    - We now write to `device_lists_changes_in_room` table.
     - Use sequence to generate future `application_services_txns.txn_id`s
 """
 
diff --git a/synapse/storage/schema/main/delta/69/01device_list_oubound_by_room.sql b/synapse/storage/schema/main/delta/69/01device_list_oubound_by_room.sql
new file mode 100644
index 0000000000..b5b1782b2a
--- /dev/null
+++ b/synapse/storage/schema/main/delta/69/01device_list_oubound_by_room.sql
@@ -0,0 +1,38 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE device_lists_changes_in_room (
+    user_id TEXT NOT NULL,
+    device_id TEXT NOT NULL,
+    room_id TEXT NOT NULL,
+
+    -- This initially matches `device_lists_stream.stream_id`. Note that we
+    -- delete older values from `device_lists_stream`, so we can't use a foreign
+    -- constraint here.
+    --
+    -- The table will contain rows with the same `stream_id` but different
+    -- `room_id`, as for each device update we store a row per room the user is
+    -- joined to. Therefore `(stream_id, room_id)` gives a unique index.
+    stream_id BIGINT NOT NULL,
+
+    -- We have a background process which goes through this table and converts
+    -- entries into rows in `device_lists_outbound_pokes`. Once we have processed
+    -- a row, we mark it as such by setting `converted_to_destinations=TRUE`.
+    converted_to_destinations BOOLEAN NOT NULL,
+    opentracing_context TEXT
+);
+
+CREATE UNIQUE INDEX device_lists_changes_in_stream_id ON device_lists_changes_in_room(stream_id, room_id);
+CREATE INDEX device_lists_changes_in_stream_id_unconverted ON device_lists_changes_in_room(stream_id) WHERE NOT converted_to_destinations;