summary refs log tree commit diff
path: root/synapse/storage/databases
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r--synapse/storage/databases/main/__init__.py1
-rw-r--r--synapse/storage/databases/main/devices.py217
2 files changed, 193 insertions, 25 deletions
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 1ea0b2aa6f..cdbe3872fa 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -146,6 +146,7 @@ class DataStore(
             extra_tables=[
                 ("user_signature_stream", "stream_id"),
                 ("device_lists_outbound_pokes", "stream_id"),
+                ("device_lists_changes_in_room", "stream_id"),
             ],
         )
 
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index f08f7834d3..07eea4b3d2 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -810,6 +810,7 @@ class DeviceWorkerStore(SQLBaseStore):
                     SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
                 ) AS e
                 WHERE ? < stream_id AND stream_id <= ?
+                ORDER BY stream_id ASC
                 LIMIT ?
             """
 
@@ -1528,7 +1529,11 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         )
 
     async def add_device_change_to_streams(
-        self, user_id: str, device_ids: Collection[str], hosts: Collection[str]
+        self,
+        user_id: str,
+        device_ids: Collection[str],
+        hosts: Optional[Collection[str]],
+        room_ids: Collection[str],
     ) -> Optional[int]:
         """Persist that a user's devices have been updated, and which hosts
         (if any) should be poked.
@@ -1537,7 +1542,10 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             user_id: The ID of the user whose device changed.
             device_ids: The IDs of any changed devices. If empty, this function will
                 return None.
-            hosts: The remote destinations that should be notified of the change.
+            hosts: The remote destinations that should be notified of the change. If
+                None then the set of hosts have *not* been calculated, and will be
+                calculated later by a background task.
+            room_ids: The rooms that the user is in
 
         Returns:
             The maximum stream ID of device list updates that were added to the database, or
@@ -1546,34 +1554,62 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         if not device_ids:
             return None
 
-        async with self._device_list_id_gen.get_next_mult(
-            len(device_ids)
-        ) as stream_ids:
-            await self.db_pool.runInteraction(
-                "add_device_change_to_stream",
-                self._add_device_change_to_stream_txn,
+        context = get_active_span_text_map()
+
+        def add_device_changes_txn(
+            txn, stream_ids_for_device_change, stream_ids_for_outbound_pokes
+        ):
+            self._add_device_change_to_stream_txn(
+                txn,
                 user_id,
                 device_ids,
-                stream_ids,
+                stream_ids_for_device_change,
             )
 
-        if not hosts:
-            return stream_ids[-1]
+            self._add_device_outbound_room_poke_txn(
+                txn,
+                user_id,
+                device_ids,
+                room_ids,
+                stream_ids_for_device_change,
+                context,
+                hosts_have_been_calculated=hosts is not None,
+            )
 
-        context = get_active_span_text_map()
-        async with self._device_list_id_gen.get_next_mult(
-            len(hosts) * len(device_ids)
-        ) as stream_ids:
-            await self.db_pool.runInteraction(
-                "add_device_outbound_poke_to_stream",
-                self._add_device_outbound_poke_to_stream_txn,
+            # If the set of hosts to send to has not been calculated yet (and so
+            # `hosts` is None) or there are no `hosts` to send to, then skip
+            # trying to persist them to the DB.
+            if not hosts:
+                return
+
+            self._add_device_outbound_poke_to_stream_txn(
+                txn,
                 user_id,
                 device_ids,
                 hosts,
-                stream_ids,
+                stream_ids_for_outbound_pokes,
                 context,
             )
 
+        # `device_lists_stream` wants a stream ID per device update.
+        num_stream_ids = len(device_ids)
+
+        if hosts:
+            # `device_lists_outbound_pokes` wants a different stream ID for
+            # each row, which is a row per host per device update.
+            num_stream_ids += len(hosts) * len(device_ids)
+
+        async with self._device_list_id_gen.get_next_mult(num_stream_ids) as stream_ids:
+            stream_ids_for_device_change = stream_ids[: len(device_ids)]
+            stream_ids_for_outbound_pokes = stream_ids[len(device_ids) :]
+
+            await self.db_pool.runInteraction(
+                "add_device_change_to_stream",
+                add_device_changes_txn,
+                stream_ids_for_device_change,
+                stream_ids_for_outbound_pokes,
+            )
+
         return stream_ids[-1]
 
     def _add_device_change_to_stream_txn(
@@ -1617,7 +1653,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         user_id: str,
         device_ids: Iterable[str],
         hosts: Collection[str],
-        stream_ids: List[str],
+        stream_ids: List[int],
         context: Dict[str, str],
     ) -> None:
         for host in hosts:
@@ -1628,8 +1664,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             )
 
         now = self._clock.time_msec()
-        next_stream_id = iter(stream_ids)
+        stream_id_iterator = iter(stream_ids)
 
+        encoded_context = json_encoder.encode(context)
         self.db_pool.simple_insert_many_txn(
             txn,
             table="device_lists_outbound_pokes",
@@ -1645,16 +1682,146 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             values=[
                 (
                     destination,
-                    next(next_stream_id),
+                    next(stream_id_iterator),
                     user_id,
                     device_id,
                     False,
                     now,
-                    json_encoder.encode(context)
-                    if whitelisted_homeserver(destination)
-                    else "{}",
+                    encoded_context if whitelisted_homeserver(destination) else "{}",
                 )
                 for destination in hosts
                 for device_id in device_ids
             ],
         )
+
+    def _add_device_outbound_room_poke_txn(
+        self,
+        txn: LoggingTransaction,
+        user_id: str,
+        device_ids: Iterable[str],
+        room_ids: Collection[str],
+        stream_ids: List[str],
+        context: Dict[str, str],
+        hosts_have_been_calculated: bool,
+    ) -> None:
+        """Record the user in the room has updated their device.
+
+        Args:
+            hosts_have_been_calculated: True if `device_lists_outbound_pokes`
+                has been updated already with the updates.
+        """
+
+        # We only need to convert to outbound pokes if they are our user.
+        converted_to_destinations = (
+            hosts_have_been_calculated or not self.hs.is_mine_id(user_id)
+        )
+
+        encoded_context = json_encoder.encode(context)
+
+        # The `device_lists_changes_in_room.stream_id` column matches the
+        # corresponding `stream_id` of the update in the `device_lists_stream`
+        # table, i.e. all rows persisted for the same device update will have
+        # the same `stream_id` (but different room IDs).
+        self.db_pool.simple_insert_many_txn(
+            txn,
+            table="device_lists_changes_in_room",
+            keys=(
+                "user_id",
+                "device_id",
+                "room_id",
+                "stream_id",
+                "converted_to_destinations",
+                "opentracing_context",
+            ),
+            values=[
+                (
+                    user_id,
+                    device_id,
+                    room_id,
+                    stream_id,
+                    converted_to_destinations,
+                    encoded_context,
+                )
+                for room_id in room_ids
+                for device_id, stream_id in zip(device_ids, stream_ids)
+            ],
+        )
+
+    async def get_uncoverted_outbound_room_pokes(
+        self, limit: int = 10
+    ) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]:
+        """Get device list changes by room that have not yet been handled and
+        written to `device_lists_outbound_pokes`.
+
+        Returns:
+            A list of user ID, device ID, room ID, stream ID and optional opentracing context.
+        """
+
+        sql = """
+            SELECT user_id, device_id, room_id, stream_id, opentracing_context
+            FROM device_lists_changes_in_room
+            WHERE NOT converted_to_destinations
+            ORDER BY stream_id
+            LIMIT ?
+        """
+
+        def get_uncoverted_outbound_room_pokes_txn(txn):
+            txn.execute(sql, (limit,))
+            return txn.fetchall()
+
+        return await self.db_pool.runInteraction(
+            "get_uncoverted_outbound_room_pokes", get_uncoverted_outbound_room_pokes_txn
+        )
+
+    async def add_device_list_outbound_pokes(
+        self,
+        user_id: str,
+        device_id: str,
+        room_id: str,
+        stream_id: int,
+        hosts: Collection[str],
+        context: Optional[Dict[str, str]],
+    ) -> None:
+        """Queue the device update to be sent to the given set of hosts,
+        calculated from the room ID.
+
+        Marks the associated row in `device_lists_changes_in_room` as handled.
+        """
+
+        def add_device_list_outbound_pokes_txn(txn, stream_ids: List[int]):
+            if hosts:
+                self._add_device_outbound_poke_to_stream_txn(
+                    txn,
+                    user_id=user_id,
+                    device_ids=[device_id],
+                    hosts=hosts,
+                    stream_ids=stream_ids,
+                    context=context,
+                )
+
+            self.db_pool.simple_update_txn(
+                txn,
+                table="device_lists_changes_in_room",
+                keyvalues={
+                    "user_id": user_id,
+                    "device_id": device_id,
+                    "stream_id": stream_id,
+                    "room_id": room_id,
+                },
+                updatevalues={"converted_to_destinations": True},
+            )
+
+        if not hosts:
+            # If there are no hosts then we don't try and generate stream IDs.
+            return await self.db_pool.runInteraction(
+                "add_device_list_outbound_pokes",
+                add_device_list_outbound_pokes_txn,
+                [],
+            )
+
+        async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids:
+            return await self.db_pool.runInteraction(
+                "add_device_list_outbound_pokes",
+                add_device_list_outbound_pokes_txn,
+                stream_ids,
+            )