summary refs log tree commit diff
path: root/synapse/storage/databases/main/devices.py
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2022-04-04 15:25:20 +0100
committerGitHub <noreply@github.com>2022-04-04 15:25:20 +0100
commit5c9e39e6192e952ba8a5bb8e5485bc6067f91699 (patch)
tree43ad8fbd061254a87c8a93c6f3d32fa029a7cb72 /synapse/storage/databases/main/devices.py
parentRemove more dead/broken dev scripts (#12355) (diff)
downloadsynapse-5c9e39e6192e952ba8a5bb8e5485bc6067f91699.tar.xz
Track device list updates per room. (#12321)
This is a first step in dealing with #7721.

The idea is basically that rather than calculating the full set of users a device list update needs to be sent to up front, we instead simply record the rooms the user was in at the time of the change. This will allow a few things:

1. we can defer calculating the set of remote servers that need to be poked about the change; and
2. during `/sync` and `/keys/changes` we can avoid also avoid calculating users who share rooms with other users, and instead just look at the rooms that have changed.

However, care needs to be taken to correctly handle server downgrades. As such this PR writes to both `device_lists_changes_in_room` and the `device_lists_outbound_pokes` table synchronously. In a future release we can then bump the database schema compat version to `69` and then we can assume that the new `device_lists_changes_in_room` exists and is handled.

There is a temporary option to disable writing to `device_lists_outbound_pokes` synchronously, allowing us to test the new code path does work (and by implication upgrading to a future release and downgrading to this one will work correctly).

Note: Ideally we'd do the calculation of room to servers on a worker (e.g. the background worker), but currently only master can write to the `device_list_outbound_pokes` table.
Diffstat (limited to 'synapse/storage/databases/main/devices.py')
-rw-r--r--synapse/storage/databases/main/devices.py217
1 files changed, 192 insertions, 25 deletions
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index f08f7834d3..07eea4b3d2 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -810,6 +810,7 @@ class DeviceWorkerStore(SQLBaseStore):
                     SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
                 ) AS e
                 WHERE ? < stream_id AND stream_id <= ?
+                ORDER BY stream_id ASC
                 LIMIT ?
             """
 
@@ -1528,7 +1529,11 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         )
 
     async def add_device_change_to_streams(
-        self, user_id: str, device_ids: Collection[str], hosts: Collection[str]
+        self,
+        user_id: str,
+        device_ids: Collection[str],
+        hosts: Optional[Collection[str]],
+        room_ids: Collection[str],
     ) -> Optional[int]:
         """Persist that a user's devices have been updated, and which hosts
         (if any) should be poked.
@@ -1537,7 +1542,10 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             user_id: The ID of the user whose device changed.
             device_ids: The IDs of any changed devices. If empty, this function will
                 return None.
-            hosts: The remote destinations that should be notified of the change.
+            hosts: The remote destinations that should be notified of the change. If
+                None then the set of hosts have *not* been calculated, and will be
+                calculated later by a background task.
+            room_ids: The rooms that the user is in
 
         Returns:
             The maximum stream ID of device list updates that were added to the database, or
@@ -1546,34 +1554,62 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         if not device_ids:
             return None
 
-        async with self._device_list_id_gen.get_next_mult(
-            len(device_ids)
-        ) as stream_ids:
-            await self.db_pool.runInteraction(
-                "add_device_change_to_stream",
-                self._add_device_change_to_stream_txn,
+        context = get_active_span_text_map()
+
+        def add_device_changes_txn(
+            txn, stream_ids_for_device_change, stream_ids_for_outbound_pokes
+        ):
+            self._add_device_change_to_stream_txn(
+                txn,
                 user_id,
                 device_ids,
-                stream_ids,
+                stream_ids_for_device_change,
             )
 
-        if not hosts:
-            return stream_ids[-1]
+            self._add_device_outbound_room_poke_txn(
+                txn,
+                user_id,
+                device_ids,
+                room_ids,
+                stream_ids_for_device_change,
+                context,
+                hosts_have_been_calculated=hosts is not None,
+            )
 
-        context = get_active_span_text_map()
-        async with self._device_list_id_gen.get_next_mult(
-            len(hosts) * len(device_ids)
-        ) as stream_ids:
-            await self.db_pool.runInteraction(
-                "add_device_outbound_poke_to_stream",
-                self._add_device_outbound_poke_to_stream_txn,
+            # If the set of hosts to send to has not been calculated yet (and so
+            # `hosts` is None) or there are no `hosts` to send to, then skip
+            # trying to persist them to the DB.
+            if not hosts:
+                return
+
+            self._add_device_outbound_poke_to_stream_txn(
+                txn,
                 user_id,
                 device_ids,
                 hosts,
-                stream_ids,
+                stream_ids_for_outbound_pokes,
                 context,
             )
 
+        # `device_lists_stream` wants a stream ID per device update.
+        num_stream_ids = len(device_ids)
+
+        if hosts:
+            # `device_lists_outbound_pokes` wants a different stream ID for
+            # each row, which is a row per host per device update.
+            num_stream_ids += len(hosts) * len(device_ids)
+
+        async with self._device_list_id_gen.get_next_mult(num_stream_ids) as stream_ids:
+            stream_ids_for_device_change = stream_ids[: len(device_ids)]
+            stream_ids_for_outbound_pokes = stream_ids[len(device_ids) :]
+
+            await self.db_pool.runInteraction(
+                "add_device_change_to_stream",
+                add_device_changes_txn,
+                stream_ids_for_device_change,
+                stream_ids_for_outbound_pokes,
+            )
+
         return stream_ids[-1]
 
     def _add_device_change_to_stream_txn(
@@ -1617,7 +1653,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         user_id: str,
         device_ids: Iterable[str],
         hosts: Collection[str],
-        stream_ids: List[str],
+        stream_ids: List[int],
         context: Dict[str, str],
     ) -> None:
         for host in hosts:
@@ -1628,8 +1664,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             )
 
         now = self._clock.time_msec()
-        next_stream_id = iter(stream_ids)
+        stream_id_iterator = iter(stream_ids)
 
+        encoded_context = json_encoder.encode(context)
         self.db_pool.simple_insert_many_txn(
             txn,
             table="device_lists_outbound_pokes",
@@ -1645,16 +1682,146 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             values=[
                 (
                     destination,
-                    next(next_stream_id),
+                    next(stream_id_iterator),
                     user_id,
                     device_id,
                     False,
                     now,
-                    json_encoder.encode(context)
-                    if whitelisted_homeserver(destination)
-                    else "{}",
+                    encoded_context if whitelisted_homeserver(destination) else "{}",
                 )
                 for destination in hosts
                 for device_id in device_ids
             ],
         )
+
+    def _add_device_outbound_room_poke_txn(
+        self,
+        txn: LoggingTransaction,
+        user_id: str,
+        device_ids: Iterable[str],
+        room_ids: Collection[str],
+        stream_ids: List[str],
+        context: Dict[str, str],
+        hosts_have_been_calculated: bool,
+    ) -> None:
+        """Record the user in the room has updated their device.
+
+        Args:
+            hosts_have_been_calculated: True if `device_lists_outbound_pokes`
+                has been updated already with the updates.
+        """
+
+        # We only need to convert to outbound pokes if they are our user.
+        converted_to_destinations = (
+            hosts_have_been_calculated or not self.hs.is_mine_id(user_id)
+        )
+
+        encoded_context = json_encoder.encode(context)
+
+        # The `device_lists_changes_in_room.stream_id` column matches the
+        # corresponding `stream_id` of the update in the `device_lists_stream`
+        # table, i.e. all rows persisted for the same device update will have
+        # the same `stream_id` (but different room IDs).
+        self.db_pool.simple_insert_many_txn(
+            txn,
+            table="device_lists_changes_in_room",
+            keys=(
+                "user_id",
+                "device_id",
+                "room_id",
+                "stream_id",
+                "converted_to_destinations",
+                "opentracing_context",
+            ),
+            values=[
+                (
+                    user_id,
+                    device_id,
+                    room_id,
+                    stream_id,
+                    converted_to_destinations,
+                    encoded_context,
+                )
+                for room_id in room_ids
+                for device_id, stream_id in zip(device_ids, stream_ids)
+            ],
+        )
+
+    async def get_uncoverted_outbound_room_pokes(
+        self, limit: int = 10
+    ) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]:
+        """Get device list changes by room that have not yet been handled and
+        written to `device_lists_outbound_pokes`.
+
+        Returns:
+            A list of user ID, device ID, room ID, stream ID and optional opentracing context.
+        """
+
+        sql = """
+            SELECT user_id, device_id, room_id, stream_id, opentracing_context
+            FROM device_lists_changes_in_room
+            WHERE NOT converted_to_destinations
+            ORDER BY stream_id
+            LIMIT ?
+        """
+
+        def get_uncoverted_outbound_room_pokes_txn(txn):
+            txn.execute(sql, (limit,))
+            return txn.fetchall()
+
+        return await self.db_pool.runInteraction(
+            "get_uncoverted_outbound_room_pokes", get_uncoverted_outbound_room_pokes_txn
+        )
+
+    async def add_device_list_outbound_pokes(
+        self,
+        user_id: str,
+        device_id: str,
+        room_id: str,
+        stream_id: int,
+        hosts: Collection[str],
+        context: Optional[Dict[str, str]],
+    ) -> None:
+        """Queue the device update to be sent to the given set of hosts,
+        calculated from the room ID.
+
+        Marks the associated row in `device_lists_changes_in_room` as handled.
+        """
+
+        def add_device_list_outbound_pokes_txn(txn, stream_ids: List[int]):
+            if hosts:
+                self._add_device_outbound_poke_to_stream_txn(
+                    txn,
+                    user_id=user_id,
+                    device_ids=[device_id],
+                    hosts=hosts,
+                    stream_ids=stream_ids,
+                    context=context,
+                )
+
+            self.db_pool.simple_update_txn(
+                txn,
+                table="device_lists_changes_in_room",
+                keyvalues={
+                    "user_id": user_id,
+                    "device_id": device_id,
+                    "stream_id": stream_id,
+                    "room_id": room_id,
+                },
+                updatevalues={"converted_to_destinations": True},
+            )
+
+        if not hosts:
+            # If there are no hosts then we don't try and generate stream IDs.
+            return await self.db_pool.runInteraction(
+                "add_device_list_outbound_pokes",
+                add_device_list_outbound_pokes_txn,
+                [],
+            )
+
+        async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids:
+            return await self.db_pool.runInteraction(
+                "add_device_list_outbound_pokes",
+                add_device_list_outbound_pokes_txn,
+                stream_ids,
+            )