From c4456114e1a5471bb61cb45605e782263dc8233c Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Sun, 1 Jan 2023 03:40:46 +0000 Subject: Add experimental support for MSC3391: deleting account data (#14714) --- synapse/storage/databases/main/account_data.py | 219 +++++++++++++++++++++++-- 1 file changed, 206 insertions(+), 13 deletions(-) (limited to 'synapse/storage/databases') diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index 07908c41d9..e59776f434 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -123,7 +123,11 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) async def get_account_data_for_user( self, user_id: str ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]: - """Get all the client account_data for a user. + """ + Get all the client account_data for a user. + + If experimental MSC3391 support is enabled, any entries with an empty + content body are excluded; as this means they have been deleted. Args: user_id: The user to get the account_data for. @@ -135,27 +139,48 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) def get_account_data_for_user_txn( txn: LoggingTransaction, ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]: - rows = self.db_pool.simple_select_list_txn( - txn, - "account_data", - {"user_id": user_id}, - ["account_data_type", "content"], - ) + # The 'content != '{}' condition below prevents us from using + # `simple_select_list_txn` here, as it doesn't support conditions + # other than 'equals'. + sql = """ + SELECT account_data_type, content FROM account_data + WHERE user_id = ? + """ + + # If experimental MSC3391 support is enabled, then account data entries + # with an empty content are considered "deleted". So skip adding them to + # the results. + if self.hs.config.experimental.msc3391_enabled: + sql += " AND content != '{}'" + + txn.execute(sql, (user_id,)) + rows = self.db_pool.cursor_to_dict(txn) global_account_data = { row["account_data_type"]: db_to_json(row["content"]) for row in rows } - rows = self.db_pool.simple_select_list_txn( - txn, - "room_account_data", - {"user_id": user_id}, - ["room_id", "account_data_type", "content"], - ) + # The 'content != '{}' condition below prevents us from using + # `simple_select_list_txn` here, as it doesn't support conditions + # other than 'equals'. + sql = """ + SELECT room_id, account_data_type, content FROM room_account_data + WHERE user_id = ? + """ + + # If experimental MSC3391 support is enabled, then account data entries + # with an empty content are considered "deleted". So skip adding them to + # the results. + if self.hs.config.experimental.msc3391_enabled: + sql += " AND content != '{}'" + + txn.execute(sql, (user_id,)) + rows = self.db_pool.cursor_to_dict(txn) by_room: Dict[str, Dict[str, JsonDict]] = {} for row in rows: room_data = by_room.setdefault(row["room_id"], {}) + room_data[row["account_data_type"]] = db_to_json(row["content"]) return global_account_data, by_room @@ -469,6 +494,72 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) return self._account_data_id_gen.get_current_token() + async def remove_account_data_for_room( + self, user_id: str, room_id: str, account_data_type: str + ) -> Optional[int]: + """Delete the room account data for the user of a given type. 
+ + Args: + user_id: The user to remove account_data for. + room_id: The room ID to scope the request to. + account_data_type: The account data type to delete. + + Returns: + The maximum stream position, or None if there was no matching room account + data to delete. + """ + assert self._can_write_to_account_data + assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator) + + def _remove_account_data_for_room_txn( + txn: LoggingTransaction, next_id: int + ) -> bool: + """ + Args: + txn: The transaction object. + next_id: The stream_id to update any existing rows to. + + Returns: + True if an entry in room_account_data had its content set to '{}', + otherwise False. This informs callers of whether there actually was an + existing room account data entry to delete, or if the call was a no-op. + """ + # We can't use `simple_update` as it doesn't have the ability to specify + # where clauses other than '=', which we need for `content != '{}'` below. + sql = """ + UPDATE room_account_data + SET stream_id = ?, content = '{}' + WHERE user_id = ? + AND room_id = ? + AND account_data_type = ? + AND content != '{}' + """ + txn.execute( + sql, + (next_id, user_id, room_id, account_data_type), + ) + # Return true if any rows were updated. + return txn.rowcount != 0 + + async with self._account_data_id_gen.get_next() as next_id: + row_updated = await self.db_pool.runInteraction( + "remove_account_data_for_room", + _remove_account_data_for_room_txn, + next_id, + ) + + if not row_updated: + return None + + self._account_data_stream_cache.entity_has_changed(user_id, next_id) + self.get_account_data_for_user.invalidate((user_id,)) + self.get_account_data_for_room.invalidate((user_id, room_id)) + self.get_account_data_for_room_and_type.prefill( + (user_id, room_id, account_data_type), {} + ) + + return self._account_data_id_gen.get_current_token() + async def add_account_data_for_user( self, user_id: str, account_data_type: str, content: JsonDict ) -> int: @@ -569,6 +660,108 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) self._invalidate_cache_and_stream(txn, self.ignored_by, (ignored_user_id,)) self._invalidate_cache_and_stream(txn, self.ignored_users, (user_id,)) + async def remove_account_data_for_user( + self, + user_id: str, + account_data_type: str, + ) -> Optional[int]: + """ + Delete a single piece of user account data by type. + + A "delete" is performed by updating a potentially existing row in the + "account_data" database table for (user_id, account_data_type) and + setting its content to "{}". + + Args: + user_id: The user ID to modify the account data of. + account_data_type: The type to remove. + + Returns: + The maximum stream position, or None if there was no matching account data + to delete. + """ + assert self._can_write_to_account_data + assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator) + + def _remove_account_data_for_user_txn( + txn: LoggingTransaction, next_id: int + ) -> bool: + """ + Args: + txn: The transaction object. + next_id: The stream_id to update any existing rows to. + + Returns: + True if an entry in account_data had its content set to '{}', otherwise + False. This informs callers of whether there actually was an existing + account data entry to delete, or if the call was a no-op. + """ + # We can't use `simple_update` as it doesn't have the ability to specify + # where clauses other than '=', which we need for `content != '{}'` below. 
+ sql = """ + UPDATE account_data + SET stream_id = ?, content = '{}' + WHERE user_id = ? + AND account_data_type = ? + AND content != '{}' + """ + txn.execute(sql, (next_id, user_id, account_data_type)) + if txn.rowcount == 0: + # We didn't update any rows. This means that there was no matching room + # account data entry to delete in the first place. + return False + + # Ignored users get denormalized into a separate table as an optimisation. + if account_data_type == AccountDataTypes.IGNORED_USER_LIST: + # If this method was called with the ignored users account data type, we + # simply delete all ignored users. + + # First pull all the users that this user ignores. + previously_ignored_users = set( + self.db_pool.simple_select_onecol_txn( + txn, + table="ignored_users", + keyvalues={"ignorer_user_id": user_id}, + retcol="ignored_user_id", + ) + ) + + # Then delete them from the database. + self.db_pool.simple_delete_txn( + txn, + table="ignored_users", + keyvalues={"ignorer_user_id": user_id}, + ) + + # Invalidate the cache for ignored users which were removed. + for ignored_user_id in previously_ignored_users: + self._invalidate_cache_and_stream( + txn, self.ignored_by, (ignored_user_id,) + ) + + # Invalidate for this user the cache tracking ignored users. + self._invalidate_cache_and_stream(txn, self.ignored_users, (user_id,)) + + return True + + async with self._account_data_id_gen.get_next() as next_id: + row_updated = await self.db_pool.runInteraction( + "remove_account_data_for_user", + _remove_account_data_for_user_txn, + next_id, + ) + + if not row_updated: + return None + + self._account_data_stream_cache.entity_has_changed(user_id, next_id) + self.get_account_data_for_user.invalidate((user_id,)) + self.get_global_account_data_by_type_for_user.prefill( + (user_id, account_data_type), {} + ) + + return self._account_data_id_gen.get_current_token() + async def purge_account_data_for_user(self, user_id: str) -> None: """ Removes ALL the account data for a user. -- cgit 1.5.1 From db1cfe9c80a707995fcad8f3faa839acb247068a Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Wed, 4 Jan 2023 11:49:26 +0000 Subject: Update all stream IDs after processing replication rows (#14723) This creates a new store method, `process_replication_position` that is called after `process_replication_rows`. By moving stream ID advances here this guarantees any relevant cache invalidations will have been applied before the stream is advanced. This avoids race conditions where Python switches between threads mid way through processing the `process_replication_rows` method where stream IDs may be advanced before caches are invalidated due to class resolution ordering. 
See this comment/issue for further discussion: https://github.com/matrix-org/synapse/issues/14158#issuecomment-1344048703 --- changelog.d/14723.bugfix | 1 + synapse/replication/tcp/client.py | 3 +++ synapse/storage/_base.py | 17 ++++++++++++++++- synapse/storage/databases/main/account_data.py | 14 ++++++++++---- synapse/storage/databases/main/cache.py | 11 ++++++++--- synapse/storage/databases/main/deviceinbox.py | 7 +++++++ synapse/storage/databases/main/devices.py | 11 +++++++++-- synapse/storage/databases/main/events_worker.py | 15 ++++++++++----- synapse/storage/databases/main/presence.py | 8 +++++++- synapse/storage/databases/main/push_rule.py | 7 +++++++ synapse/storage/databases/main/pusher.py | 6 +++--- synapse/storage/databases/main/receipts.py | 7 +++++++ synapse/storage/databases/main/tags.py | 8 +++++++- 13 files changed, 95 insertions(+), 20 deletions(-) create mode 100644 changelog.d/14723.bugfix (limited to 'synapse/storage/databases') diff --git a/changelog.d/14723.bugfix b/changelog.d/14723.bugfix new file mode 100644 index 0000000000..e1f89cee35 --- /dev/null +++ b/changelog.d/14723.bugfix @@ -0,0 +1 @@ +Ensure stream IDs are always updated after caches get invalidated with workers. Contributed by Nick @ Beeper (@fizzadar). diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 658d89210d..b5e40da533 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -152,6 +152,9 @@ class ReplicationDataHandler: rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row. """ self.store.process_replication_rows(stream_name, instance_name, token, rows) + # NOTE: this must be called after process_replication_rows to ensure any + # cache invalidations are first handled before any stream ID advances. + self.store.process_replication_position(stream_name, instance_name, token) if self.send_handler: await self.send_handler.process_replication_rows(stream_name, token, rows) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 69abf6fa87..41d9111019 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -57,7 +57,22 @@ class SQLBaseStore(metaclass=ABCMeta): token: int, rows: Iterable[Any], ) -> None: - pass + """ + Used by storage classes to invalidate caches based on incoming replication data. These + must not update any ID generators, use `process_replication_position`. + """ + + def process_replication_position( # noqa: B027 (no-op by design) + self, + stream_name: str, + instance_name: str, + token: int, + ) -> None: + """ + Used by storage classes to advance ID generators based on incoming replication data. This + is called after process_replication_rows such that caches are invalidated before any token + positions advance. 
+ """ def _invalidate_state_caches( self, room_id: str, members_changed: Collection[str] diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index e59776f434..86032897f5 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -436,10 +436,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) token: int, rows: Iterable[Any], ) -> None: - if stream_name == TagAccountDataStream.NAME: - self._account_data_id_gen.advance(instance_name, token) - elif stream_name == AccountDataStream.NAME: - self._account_data_id_gen.advance(instance_name, token) + if stream_name == AccountDataStream.NAME: for row in rows: if not row.room_id: self.get_global_account_data_by_type_for_user.invalidate( @@ -454,6 +451,15 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) super().process_replication_rows(stream_name, instance_name, token, rows) + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == TagAccountDataStream.NAME: + self._account_data_id_gen.advance(instance_name, token) + elif stream_name == AccountDataStream.NAME: + self._account_data_id_gen.advance(instance_name, token) + super().process_replication_position(stream_name, instance_name, token) + async def add_account_data_to_room( self, user_id: str, room_id: str, account_data_type: str, content: JsonDict ) -> int: diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index a58668a380..2179a8bf59 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -164,9 +164,6 @@ class CacheInvalidationWorkerStore(SQLBaseStore): backfilled=True, ) elif stream_name == CachesStream.NAME: - if self._cache_id_gen: - self._cache_id_gen.advance(instance_name, token) - for row in rows: if row.cache_func == CURRENT_STATE_CACHE_NAME: if row.keys is None: @@ -182,6 +179,14 @@ class CacheInvalidationWorkerStore(SQLBaseStore): super().process_replication_rows(stream_name, instance_name, token, rows) + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == CachesStream.NAME: + if self._cache_id_gen: + self._cache_id_gen.advance(instance_name, token) + super().process_replication_position(stream_name, instance_name, token) + def _process_event_stream_row(self, token: int, row: EventsStreamRow) -> None: data = row.data diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index 48a54d9cb8..713be91c5d 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -157,6 +157,13 @@ class DeviceInboxWorkerStore(SQLBaseStore): ) return super().process_replication_rows(stream_name, instance_name, token, rows) + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == ToDeviceStream.NAME: + self._device_inbox_id_gen.advance(instance_name, token) + super().process_replication_position(stream_name, instance_name, token) + def get_to_device_stream_token(self) -> int: return self._device_inbox_id_gen.get_current_token() diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index a5bb4d404e..db877e3f13 100644 --- a/synapse/storage/databases/main/devices.py +++ 
b/synapse/storage/databases/main/devices.py @@ -162,14 +162,21 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any] ) -> None: if stream_name == DeviceListsStream.NAME: - self._device_list_id_gen.advance(instance_name, token) self._invalidate_caches_for_devices(token, rows) elif stream_name == UserSignatureStream.NAME: - self._device_list_id_gen.advance(instance_name, token) for row in rows: self._user_signature_stream_cache.entity_has_changed(row.user_id, token) return super().process_replication_rows(stream_name, instance_name, token, rows) + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == DeviceListsStream.NAME: + self._device_list_id_gen.advance(instance_name, token) + elif stream_name == UserSignatureStream.NAME: + self._device_list_id_gen.advance(instance_name, token) + super().process_replication_position(stream_name, instance_name, token) + def _invalidate_caches_for_devices( self, token: int, rows: Iterable[DeviceListsStream.DeviceListsStreamRow] ) -> None: diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 761b15a815..d150fa8a94 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -388,11 +388,7 @@ class EventsWorkerStore(SQLBaseStore): token: int, rows: Iterable[Any], ) -> None: - if stream_name == EventsStream.NAME: - self._stream_id_gen.advance(instance_name, token) - elif stream_name == BackfillStream.NAME: - self._backfill_id_gen.advance(instance_name, -token) - elif stream_name == UnPartialStatedEventStream.NAME: + if stream_name == UnPartialStatedEventStream.NAME: for row in rows: assert isinstance(row, UnPartialStatedEventStreamRow) @@ -405,6 +401,15 @@ class EventsWorkerStore(SQLBaseStore): super().process_replication_rows(stream_name, instance_name, token, rows) + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == EventsStream.NAME: + self._stream_id_gen.advance(instance_name, token) + elif stream_name == BackfillStream.NAME: + self._backfill_id_gen.advance(instance_name, -token) + super().process_replication_position(stream_name, instance_name, token) + async def have_censored_event(self, event_id: str) -> bool: """Check if an event has been censored, i.e. if the content of the event has been erased from the database due to a redaction. 
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py index 9769a18a9d..7b60815043 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py @@ -439,8 +439,14 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore) rows: Iterable[Any], ) -> None: if stream_name == PresenceStream.NAME: - self._presence_id_gen.advance(instance_name, token) for row in rows: self.presence_stream_cache.entity_has_changed(row.user_id, token) self._get_presence_for_user.invalidate((row.user_id,)) return super().process_replication_rows(stream_name, instance_name, token, rows) + + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == PresenceStream.NAME: + self._presence_id_gen.advance(instance_name, token) + super().process_replication_position(stream_name, instance_name, token) diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index d4c64c46ad..d4e4b777da 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -154,6 +154,13 @@ class PushRulesWorkerStore( self.push_rules_stream_cache.entity_has_changed(row.user_id, token) return super().process_replication_rows(stream_name, instance_name, token, rows) + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == PushRulesStream.NAME: + self._push_rules_stream_id_gen.advance(instance_name, token) + super().process_replication_position(stream_name, instance_name, token) + @cached(max_entries=5000) async def get_push_rules_for_user(self, user_id: str) -> FilteredPushRules: rows = await self.db_pool.simple_select_list( diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index 40fd781a6a..7f24a3b6ec 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -111,12 +111,12 @@ class PusherWorkerStore(SQLBaseStore): def get_pushers_stream_token(self) -> int: return self._pushers_id_gen.get_current_token() - def process_replication_rows( - self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any] + def process_replication_position( + self, stream_name: str, instance_name: str, token: int ) -> None: if stream_name == PushersStream.NAME: self._pushers_id_gen.advance(instance_name, token) - return super().process_replication_rows(stream_name, instance_name, token, rows) + super().process_replication_position(stream_name, instance_name, token) async def get_pushers_by_app_id_and_pushkey( self, app_id: str, pushkey: str diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index e06725f69c..86f5bce5f0 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -588,6 +588,13 @@ class ReceiptsWorkerStore(SQLBaseStore): return super().process_replication_rows(stream_name, instance_name, token, rows) + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == ReceiptsStream.NAME: + self._receipts_id_gen.advance(instance_name, token) + super().process_replication_position(stream_name, instance_name, token) + def _insert_linearized_receipt_txn( self, txn: LoggingTransaction, diff --git a/synapse/storage/databases/main/tags.py 
b/synapse/storage/databases/main/tags.py index b0f5de67a3..e23c927e02 100644 --- a/synapse/storage/databases/main/tags.py +++ b/synapse/storage/databases/main/tags.py @@ -300,13 +300,19 @@ class TagsWorkerStore(AccountDataWorkerStore): rows: Iterable[Any], ) -> None: if stream_name == TagAccountDataStream.NAME: - self._account_data_id_gen.advance(instance_name, token) for row in rows: self.get_tags_for_user.invalidate((row.user_id,)) self._account_data_stream_cache.entity_has_changed(row.user_id, token) super().process_replication_rows(stream_name, instance_name, token, rows) + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == TagAccountDataStream.NAME: + self._account_data_id_gen.advance(instance_name, token) + super().process_replication_position(stream_name, instance_name, token) + class TagsStore(TagsWorkerStore): pass -- cgit 1.5.1 From ba4ea7d13ffae53644b206222af95a5171faa27c Mon Sep 17 00:00:00 2001 From: reivilibre Date: Tue, 10 Jan 2023 11:17:59 +0000 Subject: Batch up replication requests to request the resyncing of remote users's devices. (#14716) --- changelog.d/14716.misc | 1 + synapse/handlers/device.py | 124 +++++++++++++++++++++++------- synapse/handlers/devicemessage.py | 2 +- synapse/handlers/e2e_keys.py | 93 +++++++++++++--------- synapse/handlers/federation_event.py | 2 +- synapse/replication/http/devices.py | 74 +++++++++++++++++- synapse/storage/databases/main/devices.py | 30 ++++++-- synapse/types/__init__.py | 4 + synapse/util/async_helpers.py | 55 ++++++++++++- 9 files changed, 306 insertions(+), 79 deletions(-) create mode 100644 changelog.d/14716.misc (limited to 'synapse/storage/databases') diff --git a/changelog.d/14716.misc b/changelog.d/14716.misc new file mode 100644 index 0000000000..ef9522e01d --- /dev/null +++ b/changelog.d/14716.misc @@ -0,0 +1 @@ +Batch up replication requests to request the resyncing of remote users's devices. \ No newline at end of file diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index d4750a32e6..89864e1119 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from http import HTTPStatus from typing import ( TYPE_CHECKING, Any, @@ -33,6 +34,7 @@ from synapse.api.errors import ( Codes, FederationDeniedError, HttpResponseException, + InvalidAPICallError, RequestSendFailed, SynapseError, ) @@ -45,6 +47,7 @@ from synapse.types import ( JsonDict, StreamKeyType, StreamToken, + UserID, get_domain_from_id, get_verify_key_from_cross_signing_key, ) @@ -893,12 +896,47 @@ class DeviceListWorkerUpdater: def __init__(self, hs: "HomeServer"): from synapse.replication.http.devices import ( + ReplicationMultiUserDevicesResyncRestServlet, ReplicationUserDevicesResyncRestServlet, ) self._user_device_resync_client = ( ReplicationUserDevicesResyncRestServlet.make_client(hs) ) + self._multi_user_device_resync_client = ( + ReplicationMultiUserDevicesResyncRestServlet.make_client(hs) + ) + + async def multi_user_device_resync( + self, user_ids: List[str], mark_failed_as_stale: bool = True + ) -> Dict[str, Optional[JsonDict]]: + """ + Like `user_device_resync` but operates on multiple users **from the same origin** + at once. + + Returns: + Dict from User ID to the same Dict as `user_device_resync`. + """ + # mark_failed_as_stale is not sent. Ensure this doesn't break expectations. 
+ assert mark_failed_as_stale + + if not user_ids: + # Shortcut empty requests + return {} + + try: + return await self._multi_user_device_resync_client(user_ids=user_ids) + except SynapseError as err: + if not ( + err.code == HTTPStatus.NOT_FOUND and err.errcode == Codes.UNRECOGNIZED + ): + raise + + # Fall back to single requests + result: Dict[str, Optional[JsonDict]] = {} + for user_id in user_ids: + result[user_id] = await self._user_device_resync_client(user_id=user_id) + return result async def user_device_resync( self, user_id: str, mark_failed_as_stale: bool = True @@ -913,8 +951,10 @@ class DeviceListWorkerUpdater: A dict with device info as under the "devices" in the result of this request: https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid + None when we weren't able to fetch the device info for some reason, + e.g. due to a connection problem. """ - return await self._user_device_resync_client(user_id=user_id) + return (await self.multi_user_device_resync([user_id]))[user_id] class DeviceListUpdater(DeviceListWorkerUpdater): @@ -1160,19 +1200,66 @@ class DeviceListUpdater(DeviceListWorkerUpdater): # Allow future calls to retry resyncinc out of sync device lists. self._resync_retry_in_progress = False + async def multi_user_device_resync( + self, user_ids: List[str], mark_failed_as_stale: bool = True + ) -> Dict[str, Optional[JsonDict]]: + """ + Like `user_device_resync` but operates on multiple users **from the same origin** + at once. + + Returns: + Dict from User ID to the same Dict as `user_device_resync`. + """ + if not user_ids: + return {} + + origins = {UserID.from_string(user_id).domain for user_id in user_ids} + + if len(origins) != 1: + raise InvalidAPICallError(f"Only one origin permitted, got {origins!r}") + + result = {} + failed = set() + # TODO(Perf): Actually batch these up + for user_id in user_ids: + user_result, user_failed = await self._user_device_resync_returning_failed( + user_id + ) + result[user_id] = user_result + if user_failed: + failed.add(user_id) + + if mark_failed_as_stale: + await self.store.mark_remote_users_device_caches_as_stale(failed) + + return result + async def user_device_resync( self, user_id: str, mark_failed_as_stale: bool = True ) -> Optional[JsonDict]: + result, failed = await self._user_device_resync_returning_failed(user_id) + + if failed and mark_failed_as_stale: + # Mark the remote user's device list as stale so we know we need to retry + # it later. + await self.store.mark_remote_users_device_caches_as_stale((user_id,)) + + return result + + async def _user_device_resync_returning_failed( + self, user_id: str + ) -> Tuple[Optional[JsonDict], bool]: """Fetches all devices for a user and updates the device cache with them. Args: user_id: The user's id whose device_list will be updated. - mark_failed_as_stale: Whether to mark the user's device list as stale - if the attempt to resync failed. Returns: - A dict with device info as under the "devices" in the result of this - request: - https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid + - A dict with device info as under the "devices" in the result of this + request: + https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid + None when we weren't able to fetch the device info for some reason, + e.g. due to a connection problem. + - True iff the resync failed and the device list should be marked as stale. 
""" logger.debug("Attempting to resync the device list for %s", user_id) log_kv({"message": "Doing resync to update device list."}) @@ -1181,12 +1268,7 @@ class DeviceListUpdater(DeviceListWorkerUpdater): try: result = await self.federation.query_user_devices(origin, user_id) except NotRetryingDestination: - if mark_failed_as_stale: - # Mark the remote user's device list as stale so we know we need to retry - # it later. - await self.store.mark_remote_user_device_cache_as_stale(user_id) - - return None + return None, True except (RequestSendFailed, HttpResponseException) as e: logger.warning( "Failed to handle device list update for %s: %s", @@ -1194,23 +1276,18 @@ class DeviceListUpdater(DeviceListWorkerUpdater): e, ) - if mark_failed_as_stale: - # Mark the remote user's device list as stale so we know we need to retry - # it later. - await self.store.mark_remote_user_device_cache_as_stale(user_id) - # We abort on exceptions rather than accepting the update # as otherwise synapse will 'forget' that its device list # is out of date. If we bail then we will retry the resync # next time we get a device list update for this user_id. # This makes it more likely that the device lists will # eventually become consistent. - return None + return None, True except FederationDeniedError as e: set_tag("error", True) log_kv({"reason": "FederationDeniedError"}) logger.info(e) - return None + return None, False except Exception as e: set_tag("error", True) log_kv( @@ -1218,12 +1295,7 @@ class DeviceListUpdater(DeviceListWorkerUpdater): ) logger.exception("Failed to handle device list update for %s", user_id) - if mark_failed_as_stale: - # Mark the remote user's device list as stale so we know we need to retry - # it later. - await self.store.mark_remote_user_device_cache_as_stale(user_id) - - return None + return None, True log_kv({"result": result}) stream_id = result["stream_id"] devices = result["devices"] @@ -1305,7 +1377,7 @@ class DeviceListUpdater(DeviceListWorkerUpdater): # point. self._seen_updates[user_id] = {stream_id} - return result + return result, False async def process_cross_signing_key_update( self, diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index 75e89850f5..00c403db49 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -195,7 +195,7 @@ class DeviceMessageHandler: sender_user_id, unknown_devices, ) - await self.store.mark_remote_user_device_cache_as_stale(sender_user_id) + await self.store.mark_remote_users_device_caches_as_stale((sender_user_id,)) # Immediately attempt a resync in the background run_in_background(self._user_device_resync, user_id=sender_user_id) diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 5fe102e2f2..d2188ca08f 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -36,8 +36,8 @@ from synapse.types import ( get_domain_from_id, get_verify_key_from_cross_signing_key, ) -from synapse.util import json_decoder, unwrapFirstError -from synapse.util.async_helpers import Linearizer, delay_cancellation +from synapse.util import json_decoder +from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.cancellation import cancellable from synapse.util.retryutils import NotRetryingDestination @@ -238,24 +238,28 @@ class E2eKeysHandler: # Now fetch any devices that we don't have in our cache # TODO It might make sense to propagate cancellations into the # deferreds which are querying remote homeservers. 
- await make_deferred_yieldable( - delay_cancellation( - defer.gatherResults( - [ - run_in_background( - self._query_devices_for_destination, - results, - cross_signing_keys, - failures, - destination, - queries, - timeout, - ) - for destination, queries in remote_queries_not_in_cache.items() - ], - consumeErrors=True, - ).addErrback(unwrapFirstError) + logger.debug( + "%d destinations to query devices for", len(remote_queries_not_in_cache) + ) + + async def _query( + destination_queries: Tuple[str, Dict[str, Iterable[str]]] + ) -> None: + destination, queries = destination_queries + return await self._query_devices_for_destination( + results, + cross_signing_keys, + failures, + destination, + queries, + timeout, ) + + await concurrently_execute( + _query, + remote_queries_not_in_cache.items(), + 10, + delay_cancellation=True, ) ret = {"device_keys": results, "failures": failures} @@ -300,28 +304,41 @@ class E2eKeysHandler: # queries. We use the more efficient batched query_client_keys for all # remaining users user_ids_updated = [] - for (user_id, device_list) in destination_query.items(): - if user_id in user_ids_updated: - continue - if device_list: - continue + # Perform a user device resync for each user only once and only as long as: + # - they have an empty device_list + # - they are in some rooms that this server can see + users_to_resync_devices = { + user_id + for (user_id, device_list) in destination_query.items() + if (not device_list) and (await self.store.get_rooms_for_user(user_id)) + } - room_ids = await self.store.get_rooms_for_user(user_id) - if not room_ids: - continue + logger.debug( + "%d users to resync devices for from destination %s", + len(users_to_resync_devices), + destination, + ) - # We've decided we're sharing a room with this user and should - # probably be tracking their device lists. However, we haven't - # done an initial sync on the device list so we do it now. - try: - resync_results = ( - await self.device_handler.device_list_updater.user_device_resync( - user_id - ) + try: + user_resync_results = ( + await self.device_handler.device_list_updater.multi_user_device_resync( + list(users_to_resync_devices) ) + ) + for user_id in users_to_resync_devices: + resync_results = user_resync_results[user_id] + if resync_results is None: - raise ValueError("Device resync failed") + # TODO: It's weird that we'll store a failure against a + # destination, yet continue processing users from that + # destination. + # We might want to consider changing this, but for now + # I'm leaving it as I found it. + failures[destination] = _exception_to_failure( + ValueError(f"Device resync failed for {user_id!r}") + ) + continue # Add the device keys to the results. 
user_devices = resync_results["devices"] @@ -339,8 +356,8 @@ class E2eKeysHandler: if self_signing_key: cross_signing_keys["self_signing_keys"][user_id] = self_signing_key - except Exception as e: - failures[destination] = _exception_to_failure(e) + except Exception as e: + failures[destination] = _exception_to_failure(e) if len(destination_query) == len(user_ids_updated): # We've updated all the users in the query and we do not need to diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 31df7f55cc..6df000faaf 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -1423,7 +1423,7 @@ class FederationEventHandler: """ try: - await self._store.mark_remote_user_device_cache_as_stale(sender) + await self._store.mark_remote_users_device_caches_as_stale((sender,)) # Immediately attempt a resync in the background if self._config.worker.worker_app: diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py index 7c4941c3d3..ea5c08e6cf 100644 --- a/synapse/replication/http/devices.py +++ b/synapse/replication/http/devices.py @@ -13,12 +13,13 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Optional, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple from twisted.web.server import Request from synapse.http.server import HttpServer from synapse.http.servlet import parse_json_object_from_request +from synapse.logging.opentracing import active_span from synapse.replication.http._base import ReplicationEndpoint from synapse.types import JsonDict @@ -84,6 +85,76 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint): return 200, user_devices +class ReplicationMultiUserDevicesResyncRestServlet(ReplicationEndpoint): + """Ask master to resync the device list for multiple users from the same + remote server by contacting their server. + + This must happen on master so that the results can be correctly cached in + the database and streamed to workers. + + Request format: + + POST /_synapse/replication/multi_user_device_resync + + { + "user_ids": ["@alice:example.org", "@bob:example.org", ...] + } + + Response is roughly equivalent to ` /_matrix/federation/v1/user/devices/:user_id` + response, but there is a map from user ID to response, e.g.: + + { + "@alice:example.org": { + "devices": [ + { + "device_id": "JLAFKJWSCS", + "keys": { ... }, + "device_display_name": "Alice's Mobile Phone" + } + ] + }, + ... 
+ } + """ + + NAME = "multi_user_device_resync" + PATH_ARGS = () + CACHE = False + + def __init__(self, hs: "HomeServer"): + super().__init__(hs) + + from synapse.handlers.device import DeviceHandler + + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.device_list_updater = handler.device_list_updater + + self.store = hs.get_datastores().main + self.clock = hs.get_clock() + + @staticmethod + async def _serialize_payload(user_ids: List[str]) -> JsonDict: # type: ignore[override] + return {"user_ids": user_ids} + + async def _handle_request( # type: ignore[override] + self, request: Request + ) -> Tuple[int, Dict[str, Optional[JsonDict]]]: + content = parse_json_object_from_request(request) + user_ids: List[str] = content["user_ids"] + + logger.info("Resync for %r", user_ids) + span = active_span() + if span: + span.set_tag("user_ids", f"{user_ids!r}") + + multi_user_devices = await self.device_list_updater.multi_user_device_resync( + user_ids + ) + + return 200, multi_user_devices + + class ReplicationUploadKeysForUserRestServlet(ReplicationEndpoint): """Ask master to upload keys for the user and send them out over federation to update other servers. @@ -151,4 +222,5 @@ class ReplicationUploadKeysForUserRestServlet(ReplicationEndpoint): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: ReplicationUserDevicesResyncRestServlet(hs).register(http_server) + ReplicationMultiUserDevicesResyncRestServlet(hs).register(http_server) ReplicationUploadKeysForUserRestServlet(hs).register(http_server) diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index db877e3f13..b067664473 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -54,7 +54,7 @@ from synapse.storage.util.id_generators import ( AbstractStreamIdTracker, StreamIdGenerator, ) -from synapse.types import JsonDict, get_verify_key_from_cross_signing_key +from synapse.types import JsonDict, StrCollection, get_verify_key_from_cross_signing_key from synapse.util import json_decoder, json_encoder from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.lrucache import LruCache @@ -1069,16 +1069,30 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): return {row["user_id"] for row in rows} - async def mark_remote_user_device_cache_as_stale(self, user_id: str) -> None: + async def mark_remote_users_device_caches_as_stale( + self, user_ids: StrCollection + ) -> None: """Records that the server has reason to believe the cache of the devices for the remote users is out of date. """ - await self.db_pool.simple_upsert( - table="device_lists_remote_resync", - keyvalues={"user_id": user_id}, - values={}, - insertion_values={"added_ts": self._clock.time_msec()}, - desc="mark_remote_user_device_cache_as_stale", + + def _mark_remote_users_device_caches_as_stale_txn( + txn: LoggingTransaction, + ) -> None: + # TODO add insertion_values support to simple_upsert_many and use + # that! 
+ for user_id in user_ids: + self.db_pool.simple_upsert_txn( + txn, + table="device_lists_remote_resync", + keyvalues={"user_id": user_id}, + values={}, + insertion_values={"added_ts": self._clock.time_msec()}, + ) + + await self.db_pool.runInteraction( + "mark_remote_users_device_caches_as_stale", + _mark_remote_users_device_caches_as_stale_txn, ) async def mark_remote_user_device_cache_as_valid(self, user_id: str) -> None: diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py index f2d436ddc3..0c725eb967 100644 --- a/synapse/types/__init__.py +++ b/synapse/types/__init__.py @@ -77,6 +77,10 @@ JsonMapping = Mapping[str, Any] # A JSON-serialisable object. JsonSerializable = object +# Collection[str] that does not include str itself; str being a Sequence[str] +# is very misleading and results in bugs. +StrCollection = Union[Tuple[str, ...], List[str], Set[str]] + # Note that this seems to require inheriting *directly* from Interface in order # for mypy-zope to realize it is an interface. diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index d24c4f68c4..01e3cd46f6 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -205,7 +205,10 @@ T = TypeVar("T") async def concurrently_execute( - func: Callable[[T], Any], args: Iterable[T], limit: int + func: Callable[[T], Any], + args: Iterable[T], + limit: int, + delay_cancellation: bool = False, ) -> None: """Executes the function with each argument concurrently while limiting the number of concurrent executions. @@ -215,6 +218,8 @@ async def concurrently_execute( args: List of arguments to pass to func, each invocation of func gets a single argument. limit: Maximum number of conccurent executions. + delay_cancellation: Whether to delay cancellation until after the invocations + have finished. Returns: None, when all function invocations have finished. The return values @@ -233,9 +238,16 @@ async def concurrently_execute( # We use `itertools.islice` to handle the case where the number of args is # less than the limit, avoiding needlessly spawning unnecessary background # tasks. - await yieldable_gather_results( - _concurrently_execute_inner, (value for value in itertools.islice(it, limit)) - ) + if delay_cancellation: + await yieldable_gather_results_delaying_cancellation( + _concurrently_execute_inner, + (value for value in itertools.islice(it, limit)), + ) + else: + await yieldable_gather_results( + _concurrently_execute_inner, + (value for value in itertools.islice(it, limit)), + ) P = ParamSpec("P") @@ -292,6 +304,41 @@ async def yieldable_gather_results( raise dfe.subFailure.value from None +async def yieldable_gather_results_delaying_cancellation( + func: Callable[Concatenate[T, P], Awaitable[R]], + iter: Iterable[T], + *args: P.args, + **kwargs: P.kwargs, +) -> List[R]: + """Executes the function with each argument concurrently. + Cancellation is delayed until after all the results have been gathered. + + See `yieldable_gather_results`. 
+ + Args: + func: Function to execute that returns a Deferred + iter: An iterable that yields items that get passed as the first + argument to the function + *args: Arguments to be passed to each call to func + **kwargs: Keyword arguments to be passed to each call to func + + Returns + A list containing the results of the function + """ + try: + return await make_deferred_yieldable( + delay_cancellation( + defer.gatherResults( + [run_in_background(func, item, *args, **kwargs) for item in iter], # type: ignore[arg-type] + consumeErrors=True, + ) + ) + ) + except defer.FirstError as dfe: + assert isinstance(dfe.subFailure.value, BaseException) + raise dfe.subFailure.value from None + + T1 = TypeVar("T1") T2 = TypeVar("T2") T3 = TypeVar("T3") -- cgit 1.5.1
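
The first patch in this series implements MSC3391 account data deletion as a soft delete: instead of removing the row, it rewrites the existing (user_id, account_data_type) row with content = '{}' and a fresh stream_id, and the trailing "AND content != '{}'" guard makes repeat deletions no-ops, which the store surfaces by returning None rather than a stream position. The snippet below is a small standalone sketch of that UPDATE against an in-memory SQLite table; the trimmed schema, helper name, and sample data are illustrative stand-ins, not Synapse code.

    import sqlite3

    # In-memory stand-in for Synapse's account_data table, trimmed to the
    # columns the patch's UPDATE statement touches.
    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE account_data "
        "(user_id TEXT, account_data_type TEXT, stream_id INTEGER, content TEXT)"
    )
    conn.execute(
        "INSERT INTO account_data VALUES (?, ?, ?, ?)",
        ("@alice:example.org", "m.direct", 1, '{"@bob:example.org": ["!r:example.org"]}'),
    )


    def remove_account_data_for_user(user_id: str, account_data_type: str, next_id: int) -> bool:
        """Sketch of the UPDATE used by _remove_account_data_for_user_txn."""
        cur = conn.execute(
            """
            UPDATE account_data
            SET stream_id = ?, content = '{}'
            WHERE user_id = ?
                AND account_data_type = ?
                AND content != '{}'
            """,
            (next_id, user_id, account_data_type),
        )
        # rowcount == 0 means there was no non-deleted entry to delete, which
        # the real store reports to callers by returning None.
        return cur.rowcount != 0


    assert remove_account_data_for_user("@alice:example.org", "m.direct", 2)
    # Deleting again is a no-op: the content != '{}' guard skips rows that are
    # already marked as deleted.
    assert not remove_account_data_for_user("@alice:example.org", "m.direct", 3)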
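
The second patch is easiest to read as a calling contract: process_replication_rows only invalidates caches, and process_replication_position, which the replication client now calls immediately afterwards, is the only place stream ID generators advance. The ordering matters because an advanced token is what tells other code that everything up to that position is visible, so the invalidation has to land first or a stale cached value could be served for a position the store claims to have processed. The toy classes below sketch that contract in plain Python; the class names, the "account_data" stream string, and the dict-based cache are simplified stand-ins for Synapse's SQLBaseStore subclasses, stream name constants, and @cached descriptors.

    from typing import Any, Dict, Iterable


    class ToyBaseStore:
        """Stand-in for SQLBaseStore: both hooks are no-ops by default."""

        def process_replication_rows(
            self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
        ) -> None:
            """Invalidate caches here; never advance ID generators."""

        def process_replication_position(
            self, stream_name: str, instance_name: str, token: int
        ) -> None:
            """Advance ID generators here; called after process_replication_rows."""


    class ToyAccountDataStore(ToyBaseStore):
        def __init__(self) -> None:
            # user_id -> cached account data, plus the store's current stream position.
            self.cache: Dict[str, Dict[str, Any]] = {"@alice:example.org": {"m.direct": {}}}
            self.current_token = 0

        def process_replication_rows(
            self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
        ) -> None:
            if stream_name == "account_data":
                for user_id in rows:
                    self.cache.pop(user_id, None)  # step 1: invalidate
            super().process_replication_rows(stream_name, instance_name, token, rows)

        def process_replication_position(
            self, stream_name: str, instance_name: str, token: int
        ) -> None:
            if stream_name == "account_data":
                self.current_token = max(self.current_token, token)  # step 2: advance
            super().process_replication_position(stream_name, instance_name, token)


    # The replication client calls the hooks in this order (cf. the
    # ReplicationDataHandler change in synapse/replication/tcp/client.py above),
    # so a reader that observes the advanced token can no longer hit a stale
    # cache entry for that position.
    store = ToyAccountDataStore()
    store.process_replication_rows("account_data", "worker1", 42, ["@alice:example.org"])
    store.process_replication_position("account_data", "worker1", 42)
    assert store.current_token == 42
    assert "@alice:example.org" not in store.cache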