diff options
Diffstat (limited to 'synapse/storage/databases/main/devices.py')
-rw-r--r-- | synapse/storage/databases/main/devices.py | 217 |
1 files changed, 192 insertions, 25 deletions
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index f08f7834d3..07eea4b3d2 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -810,6 +810,7 @@ class DeviceWorkerStore(SQLBaseStore): SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes ) AS e WHERE ? < stream_id AND stream_id <= ? + ORDER BY stream_id ASC LIMIT ? """ @@ -1528,7 +1529,11 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) async def add_device_change_to_streams( - self, user_id: str, device_ids: Collection[str], hosts: Collection[str] + self, + user_id: str, + device_ids: Collection[str], + hosts: Optional[Collection[str]], + room_ids: Collection[str], ) -> Optional[int]: """Persist that a user's devices have been updated, and which hosts (if any) should be poked. @@ -1537,7 +1542,10 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): user_id: The ID of the user whose device changed. device_ids: The IDs of any changed devices. If empty, this function will return None. - hosts: The remote destinations that should be notified of the change. + hosts: The remote destinations that should be notified of the change. If + None then the set of hosts have *not* been calculated, and will be + calculated later by a background task. + room_ids: The rooms that the user is in Returns: The maximum stream ID of device list updates that were added to the database, or @@ -1546,34 +1554,62 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): if not device_ids: return None - async with self._device_list_id_gen.get_next_mult( - len(device_ids) - ) as stream_ids: - await self.db_pool.runInteraction( - "add_device_change_to_stream", - self._add_device_change_to_stream_txn, + context = get_active_span_text_map() + + def add_device_changes_txn( + txn, stream_ids_for_device_change, stream_ids_for_outbound_pokes + ): + self._add_device_change_to_stream_txn( + txn, user_id, device_ids, - stream_ids, + stream_ids_for_device_change, ) - if not hosts: - return stream_ids[-1] + self._add_device_outbound_room_poke_txn( + txn, + user_id, + device_ids, + room_ids, + stream_ids_for_device_change, + context, + hosts_have_been_calculated=hosts is not None, + ) - context = get_active_span_text_map() - async with self._device_list_id_gen.get_next_mult( - len(hosts) * len(device_ids) - ) as stream_ids: - await self.db_pool.runInteraction( - "add_device_outbound_poke_to_stream", - self._add_device_outbound_poke_to_stream_txn, + # If the set of hosts to send to has not been calculated yet (and so + # `hosts` is None) or there are no `hosts` to send to, then skip + # trying to persist them to the DB. + if not hosts: + return + + self._add_device_outbound_poke_to_stream_txn( + txn, user_id, device_ids, hosts, - stream_ids, + stream_ids_for_outbound_pokes, context, ) + # `device_lists_stream` wants a stream ID per device update. + num_stream_ids = len(device_ids) + + if hosts: + # `device_lists_outbound_pokes` wants a different stream ID for + # each row, which is a row per host per device update. + num_stream_ids += len(hosts) * len(device_ids) + + async with self._device_list_id_gen.get_next_mult(num_stream_ids) as stream_ids: + stream_ids_for_device_change = stream_ids[: len(device_ids)] + stream_ids_for_outbound_pokes = stream_ids[len(device_ids) :] + + await self.db_pool.runInteraction( + "add_device_change_to_stream", + add_device_changes_txn, + stream_ids_for_device_change, + stream_ids_for_outbound_pokes, + ) + return stream_ids[-1] def _add_device_change_to_stream_txn( @@ -1617,7 +1653,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): user_id: str, device_ids: Iterable[str], hosts: Collection[str], - stream_ids: List[str], + stream_ids: List[int], context: Dict[str, str], ) -> None: for host in hosts: @@ -1628,8 +1664,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) now = self._clock.time_msec() - next_stream_id = iter(stream_ids) + stream_id_iterator = iter(stream_ids) + encoded_context = json_encoder.encode(context) self.db_pool.simple_insert_many_txn( txn, table="device_lists_outbound_pokes", @@ -1645,16 +1682,146 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): values=[ ( destination, - next(next_stream_id), + next(stream_id_iterator), user_id, device_id, False, now, - json_encoder.encode(context) - if whitelisted_homeserver(destination) - else "{}", + encoded_context if whitelisted_homeserver(destination) else "{}", ) for destination in hosts for device_id in device_ids ], ) + + def _add_device_outbound_room_poke_txn( + self, + txn: LoggingTransaction, + user_id: str, + device_ids: Iterable[str], + room_ids: Collection[str], + stream_ids: List[str], + context: Dict[str, str], + hosts_have_been_calculated: bool, + ) -> None: + """Record the user in the room has updated their device. + + Args: + hosts_have_been_calculated: True if `device_lists_outbound_pokes` + has been updated already with the updates. + """ + + # We only need to convert to outbound pokes if they are our user. + converted_to_destinations = ( + hosts_have_been_calculated or not self.hs.is_mine_id(user_id) + ) + + encoded_context = json_encoder.encode(context) + + # The `device_lists_changes_in_room.stream_id` column matches the + # corresponding `stream_id` of the update in the `device_lists_stream` + # table, i.e. all rows persisted for the same device update will have + # the same `stream_id` (but different room IDs). + self.db_pool.simple_insert_many_txn( + txn, + table="device_lists_changes_in_room", + keys=( + "user_id", + "device_id", + "room_id", + "stream_id", + "converted_to_destinations", + "opentracing_context", + ), + values=[ + ( + user_id, + device_id, + room_id, + stream_id, + converted_to_destinations, + encoded_context, + ) + for room_id in room_ids + for device_id, stream_id in zip(device_ids, stream_ids) + ], + ) + + async def get_uncoverted_outbound_room_pokes( + self, limit: int = 10 + ) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]: + """Get device list changes by room that have not yet been handled and + written to `device_lists_outbound_pokes`. + + Returns: + A list of user ID, device ID, room ID, stream ID and optional opentracing context. + """ + + sql = """ + SELECT user_id, device_id, room_id, stream_id, opentracing_context + FROM device_lists_changes_in_room + WHERE NOT converted_to_destinations + ORDER BY stream_id + LIMIT ? + """ + + def get_uncoverted_outbound_room_pokes_txn(txn): + txn.execute(sql, (limit,)) + return txn.fetchall() + + return await self.db_pool.runInteraction( + "get_uncoverted_outbound_room_pokes", get_uncoverted_outbound_room_pokes_txn + ) + + async def add_device_list_outbound_pokes( + self, + user_id: str, + device_id: str, + room_id: str, + stream_id: int, + hosts: Collection[str], + context: Optional[Dict[str, str]], + ) -> None: + """Queue the device update to be sent to the given set of hosts, + calculated from the room ID. + + Marks the associated row in `device_lists_changes_in_room` as handled. + """ + + def add_device_list_outbound_pokes_txn(txn, stream_ids: List[int]): + if hosts: + self._add_device_outbound_poke_to_stream_txn( + txn, + user_id=user_id, + device_ids=[device_id], + hosts=hosts, + stream_ids=stream_ids, + context=context, + ) + + self.db_pool.simple_update_txn( + txn, + table="device_lists_changes_in_room", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + "stream_id": stream_id, + "room_id": room_id, + }, + updatevalues={"converted_to_destinations": True}, + ) + + if not hosts: + # If there are no hosts then we don't try and generate stream IDs. + return await self.db_pool.runInteraction( + "add_device_list_outbound_pokes", + add_device_list_outbound_pokes_txn, + [], + ) + + async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids: + return await self.db_pool.runInteraction( + "add_device_list_outbound_pokes", + add_device_list_outbound_pokes_txn, + stream_ids, + ) |