From 9bb2eac71962970d02842bca441f4bcdbbf93a11 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 22 Feb 2023 15:29:09 -0500 Subject: Bump black from 22.12.0 to 23.1.0 (#15103) --- synapse/storage/databases/main/deviceinbox.py | 5 ++--- synapse/storage/databases/main/devices.py | 4 ++-- synapse/storage/databases/main/e2e_room_keys.py | 2 +- synapse/storage/databases/main/end_to_end_keys.py | 8 ++++---- synapse/storage/databases/main/event_federation.py | 1 - synapse/storage/databases/main/events.py | 1 - synapse/storage/databases/main/events_bg_updates.py | 4 ++-- synapse/storage/databases/main/events_worker.py | 2 +- synapse/storage/databases/main/media_repository.py | 1 - synapse/storage/databases/main/pusher.py | 3 --- synapse/storage/databases/main/receipts.py | 1 - synapse/storage/databases/main/room.py | 1 - synapse/storage/databases/main/search.py | 2 -- synapse/storage/databases/main/state.py | 1 - synapse/storage/databases/main/stats.py | 2 +- synapse/storage/databases/main/stream.py | 1 + synapse/storage/databases/main/transactions.py | 1 - synapse/storage/databases/main/user_directory.py | 1 - synapse/storage/databases/state/bg_updates.py | 1 - synapse/storage/databases/state/store.py | 7 ++----- 20 files changed, 16 insertions(+), 33 deletions(-) (limited to 'synapse/storage/databases') diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index 8e61aba454..0d75d9739a 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -721,8 +721,8 @@ class DeviceInboxWorkerStore(SQLBaseStore): ], ) - for (user_id, messages_by_device) in edu["messages"].items(): - for (device_id, msg) in messages_by_device.items(): + for user_id, messages_by_device in edu["messages"].items(): + for device_id, msg in messages_by_device.items(): with start_active_span("store_outgoing_to_device_message"): set_tag(SynapseTags.TO_DEVICE_EDU_ID, edu["sender"]) set_tag(SynapseTags.TO_DEVICE_EDU_ID, edu["message_id"]) @@ -959,7 +959,6 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore): def _remove_dead_devices_from_device_inbox_txn( txn: LoggingTransaction, ) -> Tuple[int, bool]: - if "max_stream_id" in progress: max_stream_id = progress["max_stream_id"] else: diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 1ca66d57d4..0dd15f16ff 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -512,7 +512,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): results.append(("org.matrix.signing_key_update", result)) if issue_8631_logger.isEnabledFor(logging.DEBUG): - for (user_id, edu) in results: + for user_id, edu in results: issue_8631_logger.debug( "device update to %s for %s from %s to %s: %s", destination, @@ -1316,7 +1316,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): ) """ count = 0 - for (destination, user_id, stream_id, device_id) in rows: + for destination, user_id, stream_id, device_id in rows: txn.execute( delete_sql, (destination, user_id, stream_id, stream_id, device_id) ) diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py index 6240f9a75e..9f8d2e4bea 100644 --- a/synapse/storage/databases/main/e2e_room_keys.py +++ b/synapse/storage/databases/main/e2e_room_keys.py @@ -108,7 +108,7 @@ class 
EndToEndRoomKeyStore(SQLBaseStore): raise StoreError(404, "No backup with that version exists") values = [] - for (room_id, session_id, room_key) in room_keys: + for room_id, session_id, room_key in room_keys: values.append( ( user_id, diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 2c2d145666..b9c39b1718 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -268,7 +268,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker ) # add each cross-signing signature to the correct device in the result dict. - for (user_id, key_id, device_id, signature) in cross_sigs_result: + for user_id, key_id, device_id, signature in cross_sigs_result: target_device_result = result[user_id][device_id] # We've only looked up cross-signatures for non-deleted devices with key # data. @@ -311,7 +311,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker # devices. user_list = [] user_device_list = [] - for (user_id, device_id) in query_list: + for user_id, device_id in query_list: if device_id is None: user_list.append(user_id) else: @@ -353,7 +353,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker txn.execute(sql, query_params) - for (user_id, device_id, display_name, key_json) in txn: + for user_id, device_id, display_name, key_json in txn: assert device_id is not None if include_deleted_devices: deleted_devices.remove((user_id, device_id)) @@ -382,7 +382,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker signature_query_clauses = [] signature_query_params = [] - for (user_id, device_id) in device_query: + for user_id, device_id in device_query: signature_query_clauses.append( "target_user_id = ? AND target_device_id = ? AND user_id = ?" ) diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index ca780cca36..ff3edeb716 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -1612,7 +1612,6 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas latest_events: List[str], limit: int, ) -> List[str]: - seen_events = set(earliest_events) front = set(latest_events) - seen_events event_results: List[str] = [] diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 7996cbb557..73b8aea16c 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -469,7 +469,6 @@ class PersistEventsStore: txn: LoggingTransaction, events: List[EventBase], ) -> None: - # We only care about state events, so this if there are no state events. 
if not any(e.is_state() for e in events): return diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index 584536111d..0a275e6ce6 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -709,7 +709,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): nbrows = 0 last_row_event_id = "" - for (event_id, event_json_raw) in results: + for event_id, event_json_raw in results: try: event_json = db_to_json(event_json_raw) @@ -1167,7 +1167,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): results = list(txn) # (event_id, parent_id, rel_type) for each relation relations_to_insert: List[Tuple[str, str, str]] = [] - for (event_id, event_json_raw) in results: + for event_id, event_json_raw in results: try: event_json = db_to_json(event_json_raw) except Exception as e: diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 6d0ef10258..b7e7498125 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -1493,7 +1493,7 @@ class EventsWorkerStore(SQLBaseStore): txn.execute(redactions_sql + clause, args) - for (redacter, redacted) in txn: + for redacter, redacted in txn: d = event_dict.get(redacted) if d: d.redactions.append(redacter) diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index b202c5eb87..fa8be214ce 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -196,7 +196,6 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): def get_local_media_by_user_paginate_txn( txn: LoggingTransaction, ) -> Tuple[List[Dict[str, Any]], int]: - # Set ordering order_by_column = MediaSortOrder(order_by).value diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index df53e726e6..fddbc07afa 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -344,7 +344,6 @@ class PusherWorkerStore(SQLBaseStore): last_user = progress.get("last_user", "") def _delete_pushers(txn: LoggingTransaction) -> int: - sql = """ SELECT name FROM users WHERE deactivated = ? and name > ? 
@@ -392,7 +391,6 @@ class PusherWorkerStore(SQLBaseStore): last_pusher = progress.get("last_pusher", 0) def _delete_pushers(txn: LoggingTransaction) -> int: - sql = """ SELECT p.id, access_token FROM pushers AS p LEFT JOIN access_tokens AS a ON (p.access_token = a.id) @@ -449,7 +447,6 @@ class PusherWorkerStore(SQLBaseStore): last_pusher = progress.get("last_pusher", 0) def _delete_pushers(txn: LoggingTransaction) -> int: - sql = """ SELECT p.id, p.user_name, p.app_id, p.pushkey FROM pushers AS p diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index dddf49c2d5..92a82240ab 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -887,7 +887,6 @@ class ReceiptsBackgroundUpdateStore(SQLBaseStore): def _populate_receipt_event_stream_ordering_txn( txn: LoggingTransaction, ) -> bool: - if "max_stream_id" in progress: max_stream_id = progress["max_stream_id"] else: diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 644bbb8878..39f89291b2 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -2168,7 +2168,6 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): def _get_event_report_txn( txn: LoggingTransaction, report_id: int ) -> Optional[Dict[str, Any]]: - sql = """ SELECT er.id, diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index 3fe433f66c..a7aae661d8 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -122,7 +122,6 @@ class SearchWorkerStore(SQLBaseStore): class SearchBackgroundUpdateStore(SearchWorkerStore): - EVENT_SEARCH_UPDATE_NAME = "event_search" EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order" EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin" @@ -615,7 +614,6 @@ class SearchStore(SearchBackgroundUpdateStore): """ count_args = [search_query] + count_args elif isinstance(self.database_engine, Sqlite3Engine): - # We use CROSS JOIN here to ensure we use the right indexes. 
# https://sqlite.org/optoverview.html#crossjoin # diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index ba325d390b..ebb2ae964f 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -490,7 +490,6 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): - CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index" DELETE_CURRENT_STATE_UPDATE_NAME = "delete_old_current_state_events" diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py index d7b7d0c3c9..d3393d8e49 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py @@ -461,7 +461,7 @@ class StatsStore(StateDeltasStore): insert_cols = [] qargs = [] - for (key, val) in chain( + for key, val in chain( keyvalues.items(), absolutes.items(), additive_relatives.items() ): insert_cols.append(key) diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 818c46182e..ac5fbf6b86 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -87,6 +87,7 @@ MAX_STREAM_SIZE = 1000 _STREAM_TOKEN = "stream" _TOPOLOGICAL_TOKEN = "topological" + # Used as return values for pagination APIs @attr.s(slots=True, frozen=True, auto_attribs=True) class _EventDictReturn: diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index 6b33d809b6..6d72bd9f67 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -573,7 +573,6 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): def get_destination_rooms_paginate_txn( txn: LoggingTransaction, ) -> Tuple[List[JsonDict], int]: - if direction == Direction.BACKWARDS: order = "DESC" else: diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index 30af4b3b6c..c3f2b61bd5 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -98,7 +98,6 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): async def _populate_user_directory_createtables( self, progress: JsonDict, batch_size: int ) -> int: - # Get all the rooms that we want to process. 
def _make_staging_area(txn: LoggingTransaction) -> None: sql = ( diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py index d743282f13..097dea5182 100644 --- a/synapse/storage/databases/state/bg_updates.py +++ b/synapse/storage/databases/state/bg_updates.py @@ -251,7 +251,6 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore): class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): - STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" STATE_GROUPS_ROOM_INDEX_UPDATE_NAME = "state_groups_room_id_idx" diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index 1a7232b276..89b1faa6c8 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -257,14 +257,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): member_filter, non_member_filter = state_filter.get_member_split() # Now we look them up in the member and non-member caches - ( - non_member_state, - incomplete_groups_nm, - ) = self._get_state_for_groups_using_cache( + non_member_state, incomplete_groups_nm = self._get_state_for_groups_using_cache( groups, self._state_group_cache, state_filter=non_member_filter ) - (member_state, incomplete_groups_m,) = self._get_state_for_groups_using_cache( + member_state, incomplete_groups_m = self._get_state_for_groups_using_cache( groups, self._state_group_members_cache, state_filter=member_filter ) -- cgit 1.5.1 From 335f52d595c2c32e4b512b97e2851bc98b819ca7 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Fri, 24 Feb 2023 13:39:45 +0000 Subject: Improve handling of non-ASCII characters in user directory search (#15143) * Fix a long-standing bug where non-ASCII characters in search terms, including accented letters, would not match characters in a different case. * Fix a long-standing bug where search terms using combining accents would not match display names using precomposed accents and vice versa. To fully take effect, the user directory must be rebuilt after this change. Fixes #14630. Signed-off-by: Sean Quah --- changelog.d/15143.misc | 1 + synapse/storage/databases/main/user_directory.py | 52 ++++++++- tests/storage/test_user_directory.py | 133 +++++++++++++++++++++++ 3 files changed, 184 insertions(+), 2 deletions(-) create mode 100644 changelog.d/15143.misc (limited to 'synapse/storage/databases') diff --git a/changelog.d/15143.misc b/changelog.d/15143.misc new file mode 100644 index 0000000000..cff4518811 --- /dev/null +++ b/changelog.d/15143.misc @@ -0,0 +1 @@ +Fix a long-standing bug where the user directory search was not case-insensitive for accented characters. diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index c3f2b61bd5..f16a509ac4 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -14,6 +14,7 @@ import logging import re +import unicodedata from typing import ( TYPE_CHECKING, Iterable, @@ -490,6 +491,11 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): values={"display_name": display_name, "avatar_url": avatar_url}, ) + # The display name that goes into the database index. 
+ index_display_name = display_name + if index_display_name is not None: + index_display_name = _filter_text_for_index(index_display_name) + if isinstance(self.database_engine, PostgresEngine): # We weight the localpart most highly, then display name and finally # server name @@ -507,11 +513,15 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): user_id, get_localpart_from_id(user_id), get_domain_from_id(user_id), - display_name, + index_display_name, ), ) elif isinstance(self.database_engine, Sqlite3Engine): - value = "%s %s" % (user_id, display_name) if display_name else user_id + value = ( + "%s %s" % (user_id, index_display_name) + if index_display_name + else user_id + ) self.db_pool.simple_upsert_txn( txn, table="user_directory_search", @@ -896,6 +906,41 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): return {"limited": limited, "results": results[0:limit]} +def _filter_text_for_index(text: str) -> str: + """Transforms text before it is inserted into the user directory index, or searched + for in the user directory index. + + Note that the user directory search table needs to be rebuilt whenever this function + changes. + """ + # Lowercase the text, to make searches case-insensitive. + # This is necessary for both PostgreSQL and SQLite. PostgreSQL's + # `to_tsquery/to_tsvector` functions don't lowercase non-ASCII characters when using + # the "C" collation, while SQLite just doesn't lowercase non-ASCII characters at + # all. + text = text.lower() + + # Normalize the text. NFKC normalization has two effects: + # 1. It canonicalizes the text, ie. maps all visually identical strings to the same + # string. For example, ["e", "◌́"] is mapped to ["é"]. + # 2. It maps strings that are roughly equivalent to the same string. + # For example, ["dž"] is mapped to ["d", "ž"], ["①"] to ["1"] and ["i⁹"] to + # ["i", "9"]. + text = unicodedata.normalize("NFKC", text) + + # Note that nothing is done to make searches accent-insensitive. + # That could be achieved by converting to NFKD form instead (with combining accents + # split out) and filtering out combining accents using `unicodedata.combining(c)`. + # The downside of this may be noisier search results, since search terms with + # explicit accents will match characters with no accents, or completely different + # accents. + # + # text = unicodedata.normalize("NFKD", text) + # text = "".join([c for c in text if not unicodedata.combining(c)]) + + return text + + def _parse_query_sqlite(search_term: str) -> str: """Takes a plain unicode string from the user and converts it into a form that can be passed to database. @@ -905,6 +950,7 @@ def _parse_query_sqlite(search_term: str) -> str: We specifically add both a prefix and non prefix matching term so that exact matches get ranked higher. """ + search_term = _filter_text_for_index(search_term) # Pull out the individual words, discarding any non-word characters. results = _parse_words(search_term) @@ -917,6 +963,8 @@ def _parse_query_postgres(search_term: str) -> Tuple[str, str, str]: We use this so that we can add prefix matching, which isn't something that is supported by default. 
""" + search_term = _filter_text_for_index(search_term) + escaped_words = [] for word in _parse_words(search_term): # Postgres tsvector and tsquery quoting rules: diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py index 2d169684cf..43b724c4dd 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py @@ -504,6 +504,139 @@ class UserDirectoryStoreTestCase(HomeserverTestCase): {"user_id": BELA, "display_name": "Bela", "avatar_url": None}, ) + @override_config({"user_directory": {"search_all_users": True}}) + def test_search_user_dir_ascii_case_insensitivity(self) -> None: + """Tests that a user can look up another user by searching for their name in a + different case. + """ + CHARLIE = "@someuser:example.org" + self.get_success( + self.store.update_profile_in_user_dir(CHARLIE, "Charlie", None) + ) + + r = self.get_success(self.store.search_user_dir(ALICE, "cHARLIE", 10)) + self.assertFalse(r["limited"]) + self.assertEqual(1, len(r["results"])) + self.assertDictEqual( + r["results"][0], + {"user_id": CHARLIE, "display_name": "Charlie", "avatar_url": None}, + ) + + @override_config({"user_directory": {"search_all_users": True}}) + def test_search_user_dir_unicode_case_insensitivity(self) -> None: + """Tests that a user can look up another user by searching for their name in a + different case. + """ + IVAN = "@someuser:example.org" + self.get_success(self.store.update_profile_in_user_dir(IVAN, "Иван", None)) + + r = self.get_success(self.store.search_user_dir(ALICE, "иВАН", 10)) + self.assertFalse(r["limited"]) + self.assertEqual(1, len(r["results"])) + self.assertDictEqual( + r["results"][0], + {"user_id": IVAN, "display_name": "Иван", "avatar_url": None}, + ) + + @override_config({"user_directory": {"search_all_users": True}}) + def test_search_user_dir_dotted_dotless_i_case_insensitivity(self) -> None: + """Tests that a user can look up another user by searching for their name in a + different case, when their name contains dotted or dotless "i"s. + + Some languages have dotted and dotless versions of "i", which are considered to + be different letters: i <-> İ, ı <-> I. To make things difficult, they reuse the + ASCII "i" and "I" code points, despite having different lowercase / uppercase + forms. + """ + USER = "@someuser:example.org" + + expected_matches = [ + # (search_term, display_name) + # A search for "i" should match "İ". + ("iiiii", "İİİİİ"), + # A search for "I" should match "ı". + ("IIIII", "ııııı"), + # A search for "ı" should match "I". + ("ııııı", "IIIII"), + # A search for "İ" should match "i". + ("İİİİİ", "iiiii"), + ] + + for search_term, display_name in expected_matches: + self.get_success( + self.store.update_profile_in_user_dir(USER, display_name, None) + ) + + r = self.get_success(self.store.search_user_dir(ALICE, search_term, 10)) + self.assertFalse(r["limited"]) + self.assertEqual( + 1, + len(r["results"]), + f"searching for {search_term!r} did not match {display_name!r}", + ) + self.assertDictEqual( + r["results"][0], + {"user_id": USER, "display_name": display_name, "avatar_url": None}, + ) + + # We don't test for negative matches, to allow implementations that consider all + # the i variants to be the same. 
+ + test_search_user_dir_dotted_dotless_i_case_insensitivity.skip = "not supported" # type: ignore + + @override_config({"user_directory": {"search_all_users": True}}) + def test_search_user_dir_unicode_normalization(self) -> None: + """Tests that a user can look up another user by searching for their name with + either composed or decomposed accents. + """ + AMELIE = "@someuser:example.org" + + expected_matches = [ + # (search_term, display_name) + ("Ame\u0301lie", "Amélie"), + ("Amélie", "Ame\u0301lie"), + ] + + for search_term, display_name in expected_matches: + self.get_success( + self.store.update_profile_in_user_dir(AMELIE, display_name, None) + ) + + r = self.get_success(self.store.search_user_dir(ALICE, search_term, 10)) + self.assertFalse(r["limited"]) + self.assertEqual( + 1, + len(r["results"]), + f"searching for {search_term!r} did not match {display_name!r}", + ) + self.assertDictEqual( + r["results"][0], + {"user_id": AMELIE, "display_name": display_name, "avatar_url": None}, + ) + + @override_config({"user_directory": {"search_all_users": True}}) + def test_search_user_dir_accent_insensitivity(self) -> None: + """Tests that a user can look up another user by searching for their name + without any accents. + """ + AMELIE = "@someuser:example.org" + self.get_success(self.store.update_profile_in_user_dir(AMELIE, "Amélie", None)) + + r = self.get_success(self.store.search_user_dir(ALICE, "amelie", 10)) + self.assertFalse(r["limited"]) + self.assertEqual(1, len(r["results"])) + self.assertDictEqual( + r["results"][0], + {"user_id": AMELIE, "display_name": "Amélie", "avatar_url": None}, + ) + + # It may be desirable for "é"s in search terms to not match plain "e"s and we + # really don't want "é"s in search terms to match "e"s with different accents. + # But we don't test for this to allow implementations that consider all + # "e"-lookalikes to be the same. + + test_search_user_dir_accent_insensitivity.skip = "not supported yet" # type: ignore + class UserDirectoryStoreTestCaseWithIcu(UserDirectoryStoreTestCase): use_icu = True -- cgit 1.5.1 From 1c95ddd09bbc46046a3412e7bb03a87aa3b6f65a Mon Sep 17 00:00:00 2001 From: Shay Date: Fri, 24 Feb 2023 13:15:29 -0800 Subject: Batch up storing state groups when creating new room (#14918) --- changelog.d/14918.misc | 1 + synapse/events/snapshot.py | 49 +++++++++++ synapse/handlers/message.py | 16 ++-- synapse/handlers/room.py | 37 ++++---- synapse/handlers/room_batch.py | 4 +- synapse/handlers/room_member.py | 13 ++- synapse/storage/databases/state/store.py | 119 ++++++++++++++++++++++++++ tests/handlers/test_message.py | 25 ++++-- tests/handlers/test_register.py | 3 +- tests/push/test_bulk_push_rule_evaluator.py | 13 +-- tests/rest/client/test_rooms.py | 4 +- tests/storage/test_event_chain.py | 6 +- tests/storage/test_state.py | 126 ++++++++++++++++++++++++++++ tests/unittest.py | 4 +- 14 files changed, 371 insertions(+), 49 deletions(-) create mode 100644 changelog.d/14918.misc (limited to 'synapse/storage/databases') diff --git a/changelog.d/14918.misc b/changelog.d/14918.misc new file mode 100644 index 0000000000..828794354a --- /dev/null +++ b/changelog.d/14918.misc @@ -0,0 +1 @@ +Batch up storing state groups when creating a new room. 
\ No newline at end of file diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index e0d82ad81c..a91a5d1e3c 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -23,6 +23,7 @@ from synapse.types import JsonDict, StateMap if TYPE_CHECKING: from synapse.storage.controllers import StorageControllers + from synapse.storage.databases import StateGroupDataStore from synapse.storage.databases.main import DataStore from synapse.types.state import StateFilter @@ -348,6 +349,54 @@ class UnpersistedEventContext(UnpersistedEventContextBase): partial_state: bool state_map_before_event: Optional[StateMap[str]] = None + @classmethod + async def batch_persist_unpersisted_contexts( + cls, + events_and_context: List[Tuple[EventBase, "UnpersistedEventContextBase"]], + room_id: str, + last_known_state_group: int, + datastore: "StateGroupDataStore", + ) -> List[Tuple[EventBase, EventContext]]: + """ + Takes a list of events and their associated unpersisted contexts and persists + the unpersisted contexts, returning a list of events and persisted contexts. + Note that all the events must be in a linear chain (ie a <- b <- c). + + Args: + events_and_context: A list of events and their unpersisted contexts + room_id: the room_id for the events + last_known_state_group: the last persisted state group + datastore: a state datastore + """ + amended_events_and_context = await datastore.store_state_deltas_for_batched( + events_and_context, room_id, last_known_state_group + ) + + events_and_persisted_context = [] + for event, unpersisted_context in amended_events_and_context: + if event.is_state(): + context = EventContext( + storage=unpersisted_context._storage, + state_group=unpersisted_context.state_group_after_event, + state_group_before_event=unpersisted_context.state_group_before_event, + state_delta_due_to_event=unpersisted_context.state_delta_due_to_event, + partial_state=unpersisted_context.partial_state, + prev_group=unpersisted_context.state_group_before_event, + delta_ids=unpersisted_context.state_delta_due_to_event, + ) + else: + context = EventContext( + storage=unpersisted_context._storage, + state_group=unpersisted_context.state_group_after_event, + state_group_before_event=unpersisted_context.state_group_before_event, + state_delta_due_to_event=unpersisted_context.state_delta_due_to_event, + partial_state=unpersisted_context.partial_state, + prev_group=unpersisted_context.prev_group_for_state_group_before_event, + delta_ids=unpersisted_context.delta_ids_to_state_group_before_event, + ) + events_and_persisted_context.append((event, context)) + return events_and_persisted_context + async def get_prev_state_ids( self, state_filter: Optional["StateFilter"] = None ) -> StateMap[str]: diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index aa90d0000d..e433d6b01f 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -574,7 +574,7 @@ class EventCreationHandler: state_map: Optional[StateMap[str]] = None, for_batch: bool = False, current_state_group: Optional[int] = None, - ) -> Tuple[EventBase, EventContext]: + ) -> Tuple[EventBase, UnpersistedEventContextBase]: """ Given a dict from a client, create a new event. 
If bool for_batch is true, will create an event using the prev_event_ids, and will create an event context for @@ -721,8 +721,6 @@ class EventCreationHandler: current_state_group=current_state_group, ) - context = await unpersisted_context.persist(event) - # In an ideal world we wouldn't need the second part of this condition. However, # this behaviour isn't spec'd yet, meaning we should be able to deactivate this # behaviour. Another reason is that this code is also evaluated each time a new @@ -739,7 +737,7 @@ class EventCreationHandler: assert state_map is not None prev_event_id = state_map.get((EventTypes.Member, event.sender)) else: - prev_state_ids = await context.get_prev_state_ids( + prev_state_ids = await unpersisted_context.get_prev_state_ids( StateFilter.from_types([(EventTypes.Member, None)]) ) prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender)) @@ -764,8 +762,7 @@ class EventCreationHandler: ) self.validator.validate_new(event, self.config) - - return event, context + return event, unpersisted_context async def _is_exempt_from_privacy_policy( self, builder: EventBuilder, requester: Requester @@ -1005,7 +1002,7 @@ class EventCreationHandler: max_retries = 5 for i in range(max_retries): try: - event, context = await self.create_event( + event, unpersisted_context = await self.create_event( requester, event_dict, txn_id=txn_id, @@ -1016,6 +1013,7 @@ class EventCreationHandler: historical=historical, depth=depth, ) + context = await unpersisted_context.persist(event) assert self.hs.is_mine_id(event.sender), "User must be our own: %s" % ( event.sender, @@ -1190,7 +1188,6 @@ class EventCreationHandler: if for_batch: assert prev_event_ids is not None assert state_map is not None - assert current_state_group is not None auth_ids = self._event_auth_handler.compute_auth_events(builder, state_map) event = await builder.build( prev_event_ids=prev_event_ids, auth_event_ids=auth_ids, depth=depth @@ -2046,7 +2043,7 @@ class EventCreationHandler: max_retries = 5 for i in range(max_retries): try: - event, context = await self.create_event( + event, unpersisted_context = await self.create_event( requester, { "type": EventTypes.Dummy, @@ -2055,6 +2052,7 @@ class EventCreationHandler: "sender": user_id, }, ) + context = await unpersisted_context.persist(event) event.internal_metadata.proactively_send = False diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index a26ec02284..b1784638f4 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -51,6 +51,7 @@ from synapse.api.filtering import Filter from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.event_auth import validate_event_for_room_version from synapse.events import EventBase +from synapse.events.snapshot import UnpersistedEventContext from synapse.events.utils import copy_and_fixup_power_levels_contents from synapse.handlers.relations import BundledAggregations from synapse.module_api import NOT_SPAM @@ -211,7 +212,7 @@ class RoomCreationHandler: # the required power level to send the tombstone event. 
( tombstone_event, - tombstone_context, + tombstone_unpersisted_context, ) = await self.event_creation_handler.create_event( requester, { @@ -225,6 +226,9 @@ class RoomCreationHandler: }, }, ) + tombstone_context = await tombstone_unpersisted_context.persist( + tombstone_event + ) validate_event_for_room_version(tombstone_event) await self._event_auth_handler.check_auth_rules_from_context( tombstone_event @@ -1092,7 +1096,7 @@ class RoomCreationHandler: content: JsonDict, for_batch: bool, **kwargs: Any, - ) -> Tuple[EventBase, synapse.events.snapshot.EventContext]: + ) -> Tuple[EventBase, synapse.events.snapshot.UnpersistedEventContextBase]: """ Creates an event and associated event context. Args: @@ -1111,20 +1115,23 @@ class RoomCreationHandler: event_dict = create_event_dict(etype, content, **kwargs) - new_event, new_context = await self.event_creation_handler.create_event( + ( + new_event, + new_unpersisted_context, + ) = await self.event_creation_handler.create_event( creator, event_dict, prev_event_ids=prev_event, depth=depth, state_map=state_map, for_batch=for_batch, - current_state_group=current_state_group, ) + depth += 1 prev_event = [new_event.event_id] state_map[(new_event.type, new_event.state_key)] = new_event.event_id - return new_event, new_context + return new_event, new_unpersisted_context try: config = self._presets_dict[preset_config] @@ -1134,10 +1141,10 @@ class RoomCreationHandler: ) creation_content.update({"creator": creator_id}) - creation_event, creation_context = await create_event( + creation_event, unpersisted_creation_context = await create_event( EventTypes.Create, creation_content, False ) - + creation_context = await unpersisted_creation_context.persist(creation_event) logger.debug("Sending %s in new room", EventTypes.Member) ev = await self.event_creation_handler.handle_new_client_event( requester=creator, @@ -1181,7 +1188,6 @@ class RoomCreationHandler: power_event, power_context = await create_event( EventTypes.PowerLevels, pl_content, True ) - current_state_group = power_context._state_group events_to_send.append((power_event, power_context)) else: power_level_content: JsonDict = { @@ -1230,14 +1236,12 @@ class RoomCreationHandler: power_level_content, True, ) - current_state_group = pl_context._state_group events_to_send.append((pl_event, pl_context)) if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state: room_alias_event, room_alias_context = await create_event( EventTypes.CanonicalAlias, {"alias": room_alias.to_string()}, True ) - current_state_group = room_alias_context._state_group events_to_send.append((room_alias_event, room_alias_context)) if (EventTypes.JoinRules, "") not in initial_state: @@ -1246,7 +1250,6 @@ class RoomCreationHandler: {"join_rule": config["join_rules"]}, True, ) - current_state_group = join_rules_context._state_group events_to_send.append((join_rules_event, join_rules_context)) if (EventTypes.RoomHistoryVisibility, "") not in initial_state: @@ -1255,7 +1258,6 @@ class RoomCreationHandler: {"history_visibility": config["history_visibility"]}, True, ) - current_state_group = visibility_context._state_group events_to_send.append((visibility_event, visibility_context)) if config["guest_can_join"]: @@ -1265,14 +1267,12 @@ class RoomCreationHandler: {EventContentFields.GUEST_ACCESS: GuestAccess.CAN_JOIN}, True, ) - current_state_group = guest_access_context._state_group events_to_send.append((guest_access_event, guest_access_context)) for (etype, state_key), content in initial_state.items(): event, context = 
await create_event( etype, content, True, state_key=state_key ) - current_state_group = context._state_group events_to_send.append((event, context)) if config["encrypted"]: @@ -1284,9 +1284,16 @@ class RoomCreationHandler: ) events_to_send.append((encryption_event, encryption_context)) + datastore = self.hs.get_datastores().state + events_and_context = ( + await UnpersistedEventContext.batch_persist_unpersisted_contexts( + events_to_send, room_id, current_state_group, datastore + ) + ) + last_event = await self.event_creation_handler.handle_new_client_event( creator, - events_to_send, + events_and_context, ignore_shadow_ban=True, ratelimit=False, ) diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py index 5d4ca0e2d2..bf9df60218 100644 --- a/synapse/handlers/room_batch.py +++ b/synapse/handlers/room_batch.py @@ -327,7 +327,7 @@ class RoomBatchHandler: # Mark all events as historical event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True - event, context = await self.event_creation_handler.create_event( + event, unpersisted_context = await self.event_creation_handler.create_event( await self.create_requester_for_user_id_from_app_service( ev["sender"], app_service_requester.app_service ), @@ -345,7 +345,7 @@ class RoomBatchHandler: historical=True, depth=inherited_depth, ) - + context = await unpersisted_context.persist(event) assert context._state_group # Normally this is done when persisting the event but we have to diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index a965c7ec76..de7476f300 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -414,7 +414,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): max_retries = 5 for i in range(max_retries): try: - event, context = await self.event_creation_handler.create_event( + ( + event, + unpersisted_context, + ) = await self.event_creation_handler.create_event( requester, { "type": EventTypes.Member, @@ -435,7 +438,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): outlier=outlier, historical=historical, ) - + context = await unpersisted_context.persist(event) prev_state_ids = await context.get_prev_state_ids( StateFilter.from_types([(EventTypes.Member, None)]) ) @@ -1944,7 +1947,10 @@ class RoomMemberMasterHandler(RoomMemberHandler): max_retries = 5 for i in range(max_retries): try: - event, context = await self.event_creation_handler.create_event( + ( + event, + unpersisted_context, + ) = await self.event_creation_handler.create_event( requester, event_dict, txn_id=txn_id, @@ -1952,6 +1958,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): auth_event_ids=auth_event_ids, outlier=True, ) + context = await unpersisted_context.persist(event) event.internal_metadata.out_of_band_membership = True result_event = ( diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index 89b1faa6c8..bf4cdfdf29 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -18,6 +18,8 @@ from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Se import attr from synapse.api.constants import EventTypes +from synapse.events import EventBase +from synapse.events.snapshot import UnpersistedEventContext, UnpersistedEventContextBase from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( DatabasePool, @@ -401,6 +403,123 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): fetched_keys=non_member_types, ) + 
async def store_state_deltas_for_batched( + self, + events_and_context: List[Tuple[EventBase, UnpersistedEventContextBase]], + room_id: str, + prev_group: int, + ) -> List[Tuple[EventBase, UnpersistedEventContext]]: + """Generate and store state deltas for a group of events and contexts created to be + batch persisted. Note that all the events must be in a linear chain (ie a <- b <- c). + + Args: + events_and_context: the events to generate and store a state groups for + and their associated contexts + room_id: the id of the room the events were created for + prev_group: the state group of the last event persisted before the batched events + were created + """ + + def insert_deltas_group_txn( + txn: LoggingTransaction, + events_and_context: List[Tuple[EventBase, UnpersistedEventContext]], + prev_group: int, + ) -> List[Tuple[EventBase, UnpersistedEventContext]]: + """Generate and store state groups for the provided events and contexts. + + Requires that we have the state as a delta from the last persisted state group. + + Returns: + A list of state groups + """ + is_in_db = self.db_pool.simple_select_one_onecol_txn( + txn, + table="state_groups", + keyvalues={"id": prev_group}, + retcol="id", + allow_none=True, + ) + if not is_in_db: + raise Exception( + "Trying to persist state with unpersisted prev_group: %r" + % (prev_group,) + ) + + num_state_groups = sum( + 1 for event, _ in events_and_context if event.is_state() + ) + + state_groups = self._state_group_seq_gen.get_next_mult_txn( + txn, num_state_groups + ) + + sg_before = prev_group + state_group_iter = iter(state_groups) + for event, context in events_and_context: + if not event.is_state(): + context.state_group_after_event = sg_before + context.state_group_before_event = sg_before + continue + + sg_after = next(state_group_iter) + context.state_group_after_event = sg_after + context.state_group_before_event = sg_before + context.state_delta_due_to_event = { + (event.type, event.state_key): event.event_id + } + sg_before = sg_after + + self.db_pool.simple_insert_many_txn( + txn, + table="state_groups", + keys=("id", "room_id", "event_id"), + values=[ + (context.state_group_after_event, room_id, event.event_id) + for event, context in events_and_context + if event.is_state() + ], + ) + + self.db_pool.simple_insert_many_txn( + txn, + table="state_group_edges", + keys=("state_group", "prev_state_group"), + values=[ + ( + context.state_group_after_event, + context.state_group_before_event, + ) + for event, context in events_and_context + if event.is_state() + ], + ) + + self.db_pool.simple_insert_many_txn( + txn, + table="state_groups_state", + keys=("state_group", "room_id", "type", "state_key", "event_id"), + values=[ + ( + context.state_group_after_event, + room_id, + key[0], + key[1], + state_id, + ) + for event, context in events_and_context + if context.state_delta_due_to_event is not None + for key, state_id in context.state_delta_due_to_event.items() + ], + ) + return events_and_context + + return await self.db_pool.runInteraction( + "store_state_deltas_for_batched.insert_deltas_group", + insert_deltas_group_txn, + events_and_context, + prev_group, + ) + async def store_state_group( self, event_id: str, diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py index 69d384442f..9691d66b48 100644 --- a/tests/handlers/test_message.py +++ b/tests/handlers/test_message.py @@ -18,7 +18,7 @@ from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import EventTypes from synapse.events import 
EventBase -from synapse.events.snapshot import EventContext +from synapse.events.snapshot import EventContext, UnpersistedEventContextBase from synapse.rest import admin from synapse.rest.client import login, room from synapse.server import HomeServer @@ -79,7 +79,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase): return memberEvent, memberEventContext - def _create_duplicate_event(self, txn_id: str) -> Tuple[EventBase, EventContext]: + def _create_duplicate_event( + self, txn_id: str + ) -> Tuple[EventBase, UnpersistedEventContextBase]: """Create a new event with the given transaction ID. All events produced by this method will be considered duplicates. """ @@ -107,7 +109,8 @@ class EventCreationTestCase(unittest.HomeserverTestCase): txn_id = "something_suitably_random" - event1, context = self._create_duplicate_event(txn_id) + event1, unpersisted_context = self._create_duplicate_event(txn_id) + context = self.get_success(unpersisted_context.persist(event1)) ret_event1 = self.get_success( self.handler.handle_new_client_event( @@ -119,7 +122,8 @@ class EventCreationTestCase(unittest.HomeserverTestCase): self.assertEqual(event1.event_id, ret_event1.event_id) - event2, context = self._create_duplicate_event(txn_id) + event2, unpersisted_context = self._create_duplicate_event(txn_id) + context = self.get_success(unpersisted_context.persist(event2)) # We want to test that the deduplication at the persit event end works, # so we want to make sure we test with different events. @@ -140,7 +144,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase): # Let's test that calling `persist_event` directly also does the right # thing. - event3, context = self._create_duplicate_event(txn_id) + event3, unpersisted_context = self._create_duplicate_event(txn_id) + context = self.get_success(unpersisted_context.persist(event3)) + self.assertNotEqual(event1.event_id, event3.event_id) ret_event3, event_pos3, _ = self.get_success( @@ -154,7 +160,8 @@ class EventCreationTestCase(unittest.HomeserverTestCase): # Let's test that calling `persist_events` directly also does the right # thing. - event4, context = self._create_duplicate_event(txn_id) + event4, unpersisted_context = self._create_duplicate_event(txn_id) + context = self.get_success(unpersisted_context.persist(event4)) self.assertNotEqual(event1.event_id, event3.event_id) events, _ = self.get_success( @@ -174,8 +181,10 @@ class EventCreationTestCase(unittest.HomeserverTestCase): txn_id = "something_else_suitably_random" # Create two duplicate events to persist at the same time - event1, context1 = self._create_duplicate_event(txn_id) - event2, context2 = self._create_duplicate_event(txn_id) + event1, unpersisted_context1 = self._create_duplicate_event(txn_id) + context1 = self.get_success(unpersisted_context1.persist(event1)) + event2, unpersisted_context2 = self._create_duplicate_event(txn_id) + context2 = self.get_success(unpersisted_context2.persist(event2)) # Ensure their event IDs are different to start with self.assertNotEqual(event1.event_id, event2.event_id) diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 1db99b3c00..aff1ec4758 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -507,7 +507,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): # Lower the permissions of the inviter. 
event_creation_handler = self.hs.get_event_creation_handler() requester = create_requester(inviter) - event, context = self.get_success( + event, unpersisted_context = self.get_success( event_creation_handler.create_event( requester, { @@ -519,6 +519,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): }, ) ) + context = self.get_success(unpersisted_context.persist(event)) self.get_success( event_creation_handler.handle_new_client_event( requester, events_and_context=[(event, context)] diff --git a/tests/push/test_bulk_push_rule_evaluator.py b/tests/push/test_bulk_push_rule_evaluator.py index dce6899e78..1458076a90 100644 --- a/tests/push/test_bulk_push_rule_evaluator.py +++ b/tests/push/test_bulk_push_rule_evaluator.py @@ -130,7 +130,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): # Create a new message event, and try to evaluate it under the dodgy # power level event. - event, context = self.get_success( + event, unpersisted_context = self.get_success( self.event_creation_handler.create_event( self.requester, { @@ -145,6 +145,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): prev_event_ids=[pl_event_id], ) ) + context = self.get_success(unpersisted_context.persist(event)) bulk_evaluator = BulkPushRuleEvaluator(self.hs) # should not raise @@ -170,7 +171,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): """Ensure that push rules are not calculated when disabled in the config""" # Create a new message event which should cause a notification. - event, context = self.get_success( + event, unpersisted_context = self.get_success( self.event_creation_handler.create_event( self.requester, { @@ -184,6 +185,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): }, ) ) + context = self.get_success(unpersisted_context.persist(event)) bulk_evaluator = BulkPushRuleEvaluator(self.hs) # Mock the method which calculates push rules -- we do this instead of @@ -200,7 +202,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): ) -> bool: """Returns true iff the `mentions` trigger an event push action.""" # Create a new message event which should cause a notification. - event, context = self.get_success( + event, unpersisted_context = self.get_success( self.event_creation_handler.create_event( self.requester, { @@ -211,7 +213,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): }, ) ) - + context = self.get_success(unpersisted_context.persist(event)) # Execute the push rule machinery. self.get_success(bulk_evaluator.action_for_events_by_user([(event, context)])) @@ -390,7 +392,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): bulk_evaluator = BulkPushRuleEvaluator(self.hs) # Create & persist an event to use as the parent of the relation. 
- event, context = self.get_success( + event, unpersisted_context = self.get_success( self.event_creation_handler.create_event( self.requester, { @@ -404,6 +406,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): }, ) ) + context = self.get_success(unpersisted_context.persist(event)) self.get_success( self.event_creation_handler.handle_new_client_event( self.requester, events_and_context=[(event, context)] diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 4dd763096d..a4900703c4 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -713,7 +713,7 @@ class RoomsCreateTestCase(RoomBase): self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) assert channel.resource_usage is not None - self.assertEqual(33, channel.resource_usage.db_txn_count) + self.assertEqual(30, channel.resource_usage.db_txn_count) def test_post_room_initial_state(self) -> None: # POST with initial_state config key, expect new room id @@ -726,7 +726,7 @@ class RoomsCreateTestCase(RoomBase): self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) assert channel.resource_usage is not None - self.assertEqual(36, channel.resource_usage.db_txn_count) + self.assertEqual(32, channel.resource_usage.db_txn_count) def test_post_room_visibility_key(self) -> None: # POST with visibility config key, expect new room id diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py index 73d11e7786..e39b63edac 100644 --- a/tests/storage/test_event_chain.py +++ b/tests/storage/test_event_chain.py @@ -522,7 +522,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): latest_event_ids = self.get_success( self.store.get_prev_events_for_room(room_id) ) - event, context = self.get_success( + event, unpersisted_context = self.get_success( event_handler.create_event( self.requester, { @@ -535,6 +535,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): prev_event_ids=latest_event_ids, ) ) + context = self.get_success(unpersisted_context.persist(event)) self.get_success( event_handler.handle_new_client_event( self.requester, events_and_context=[(event, context)] @@ -544,7 +545,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): assert state_ids1 is not None state1 = set(state_ids1.values()) - event, context = self.get_success( + event, unpersisted_context = self.get_success( event_handler.create_event( self.requester, { @@ -557,6 +558,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): prev_event_ids=latest_event_ids, ) ) + context = self.get_success(unpersisted_context.persist(event)) self.get_success( event_handler.handle_new_client_event( self.requester, events_and_context=[(event, context)] diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index e82c03f597..62aed6af0a 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -496,3 +496,129 @@ class StateStoreTestCase(HomeserverTestCase): self.assertEqual(is_all, True) self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict) + + def test_batched_state_group_storing(self) -> None: + creation_event = self.inject_state_event( + self.room, self.u_alice, EventTypes.Create, "", {} + ) + state_to_event = self.get_success( + self.storage.state.get_state_groups( + self.room.to_string(), [creation_event.event_id] + ) + ) + current_state_group = list(state_to_event.keys())[0] + + # create some 
unpersisted events and event contexts to store against room + events_and_context = [] + builder = self.event_builder_factory.for_room_version( + RoomVersions.V1, + { + "type": EventTypes.Name, + "sender": self.u_alice.to_string(), + "state_key": "", + "room_id": self.room.to_string(), + "content": {"name": "first rename of room"}, + }, + ) + + event1, unpersisted_context1 = self.get_success( + self.event_creation_handler.create_new_client_event(builder) + ) + events_and_context.append((event1, unpersisted_context1)) + + builder2 = self.event_builder_factory.for_room_version( + RoomVersions.V1, + { + "type": EventTypes.JoinRules, + "sender": self.u_alice.to_string(), + "state_key": "", + "room_id": self.room.to_string(), + "content": {"join_rule": "private"}, + }, + ) + + event2, unpersisted_context2 = self.get_success( + self.event_creation_handler.create_new_client_event(builder2) + ) + events_and_context.append((event2, unpersisted_context2)) + + builder3 = self.event_builder_factory.for_room_version( + RoomVersions.V1, + { + "type": EventTypes.Message, + "sender": self.u_alice.to_string(), + "room_id": self.room.to_string(), + "content": {"body": "hello from event 3", "msgtype": "m.text"}, + }, + ) + + event3, unpersisted_context3 = self.get_success( + self.event_creation_handler.create_new_client_event(builder3) + ) + events_and_context.append((event3, unpersisted_context3)) + + builder4 = self.event_builder_factory.for_room_version( + RoomVersions.V1, + { + "type": EventTypes.JoinRules, + "sender": self.u_alice.to_string(), + "state_key": "", + "room_id": self.room.to_string(), + "content": {"join_rule": "public"}, + }, + ) + + event4, unpersisted_context4 = self.get_success( + self.event_creation_handler.create_new_client_event(builder4) + ) + events_and_context.append((event4, unpersisted_context4)) + + processed_events_and_context = self.get_success( + self.hs.get_datastores().state.store_state_deltas_for_batched( + events_and_context, self.room.to_string(), current_state_group + ) + ) + + # check that only state events are in state_groups, and all state events are in state_groups + res = self.get_success( + self.store.db_pool.simple_select_list( + table="state_groups", + keyvalues=None, + retcols=("event_id",), + ) + ) + + events = [] + for result in res: + self.assertNotIn(event3.event_id, result) + events.append(result.get("event_id")) + + for event, _ in processed_events_and_context: + if event.is_state(): + self.assertIn(event.event_id, events) + + # check that each unique state has state group in state_groups_state and that the + # type/state key is correct, and check that each state event's state group + # has an entry and prev event in state_group_edges + for event, context in processed_events_and_context: + if event.is_state(): + state = self.get_success( + self.store.db_pool.simple_select_list( + table="state_groups_state", + keyvalues={"state_group": context.state_group_after_event}, + retcols=("type", "state_key"), + ) + ) + self.assertEqual(event.type, state[0].get("type")) + self.assertEqual(event.state_key, state[0].get("state_key")) + + groups = self.get_success( + self.store.db_pool.simple_select_list( + table="state_group_edges", + keyvalues={"state_group": str(context.state_group_after_event)}, + retcols=("*",), + ) + ) + self.assertEqual( + context.state_group_before_event, groups[0].get("prev_state_group") + ) diff --git a/tests/unittest.py b/tests/unittest.py index b21e7f1221..f9160faa1d 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -723,7 +723,7 
@@ class HomeserverTestCase(TestCase): event_creator = self.hs.get_event_creation_handler() requester = create_requester(user) - event, context = self.get_success( + event, unpersisted_context = self.get_success( event_creator.create_event( requester, { @@ -735,7 +735,7 @@ class HomeserverTestCase(TestCase): prev_event_ids=prev_event_ids, ) ) - + context = self.get_success(unpersisted_context.persist(event)) if soft_failed: event.internal_metadata.soft_failed = True -- cgit 1.5.1 From b40657314e03583f45ad49504711698a70735313 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Mon, 27 Feb 2023 14:19:19 +0000 Subject: Add module API callbacks for adding and deleting local 3PID associations (#15044 --- changelog.d/15044.feature | 1 + docs/modules/third_party_rules_callbacks.md | 45 ++++++++- docs/upgrade.md | 24 +++++ synapse/events/third_party_rules.py | 63 +++++++++++++ synapse/handlers/auth.py | 49 ++++++---- synapse/handlers/deactivate_account.py | 20 ++-- synapse/module_api/__init__.py | 10 ++ synapse/rest/admin/users.py | 11 ++- synapse/rest/client/account.py | 9 +- synapse/storage/databases/main/registration.py | 13 --- tests/push/test_email.py | 6 +- tests/rest/client/test_third_party_rules.py | 121 +++++++++++++++++++++++++ 12 files changed, 324 insertions(+), 48 deletions(-) create mode 100644 changelog.d/15044.feature (limited to 'synapse/storage/databases') diff --git a/changelog.d/15044.feature b/changelog.d/15044.feature new file mode 100644 index 0000000000..91e5cda8c3 --- /dev/null +++ b/changelog.d/15044.feature @@ -0,0 +1 @@ +Add two new Third Party Rules module API callbacks: [`on_add_user_third_party_identifier`](https://matrix-org.github.io/synapse/v1.79/modules/third_party_rules_callbacks.html#on_add_user_third_party_identifier) and [`on_remove_user_third_party_identifier`](https://matrix-org.github.io/synapse/v1.79/modules/third_party_rules_callbacks.html#on_remove_user_third_party_identifier). \ No newline at end of file diff --git a/docs/modules/third_party_rules_callbacks.md b/docs/modules/third_party_rules_callbacks.md index 888e43bd10..4a27d976fb 100644 --- a/docs/modules/third_party_rules_callbacks.md +++ b/docs/modules/third_party_rules_callbacks.md @@ -254,6 +254,11 @@ If multiple modules implement this callback, Synapse runs them all in order. _First introduced in Synapse v1.56.0_ +** +This callback is deprecated in favour of the `on_add_user_third_party_identifier` callback, which +features the same functionality. The only difference is in name. +** + ```python async def on_threepid_bind(user_id: str, medium: str, address: str) -> None: ``` @@ -268,6 +273,44 @@ server_. If multiple modules implement this callback, Synapse runs them all in order. +### `on_add_user_third_party_identifier` + +_First introduced in Synapse v1.79.0_ + +```python +async def on_add_user_third_party_identifier(user_id: str, medium: str, address: str) -> None: +``` + +Called after successfully creating an association between a user and a third-party identifier +(email address, phone number). The module is given the Matrix ID of the user the +association is for, as well as the medium (`email` or `msisdn`) and address of the +third-party identifier (i.e. an email address). + +Note that this callback is _not_ called if a user attempts to bind their third-party identifier +to an identity server (via a call to [`POST +/_matrix/client/v3/account/3pid/bind`](https://spec.matrix.org/v1.5/client-server-api/#post_matrixclientv3account3pidbind)). 
+ +If multiple modules implement this callback, Synapse runs them all in order. + +### `on_remove_user_third_party_identifier` + +_First introduced in Synapse v1.79.0_ + +```python +async def on_remove_user_third_party_identifier(user_id: str, medium: str, address: str) -> None: +``` + +Called after successfully removing an association between a user and a third-party identifier +(email address, phone number). The module is given the Matrix ID of the user the +association is for, as well as the medium (`email` or `msisdn`) and address of the +third-party identifier (i.e. an email address). + +Note that this callback is _not_ called if a user attempts to unbind their third-party +identifier from an identity server (via a call to [`POST +/_matrix/client/v3/account/3pid/unbind`](https://spec.matrix.org/v1.5/client-server-api/#post_matrixclientv3account3pidunbind)). + +If multiple modules implement this callback, Synapse runs them all in order. + ## Example The example below is a module that implements the third-party rules callback @@ -300,4 +343,4 @@ class EventCensorer: ) event_dict["content"] = new_event_content return event_dict -``` +``` \ No newline at end of file diff --git a/docs/upgrade.md b/docs/upgrade.md index 15167b8c58..f06e874054 100644 --- a/docs/upgrade.md +++ b/docs/upgrade.md @@ -88,6 +88,30 @@ process, for example: dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb ``` +# Upgrading to v1.79.0 + +## The `on_threepid_bind` module callback method has been deprecated + +Synapse v1.79.0 deprecates the +[`on_threepid_bind`](modules/third_party_rules_callbacks.md#on_threepid_bind) +"third-party rules" Synapse module callback method in favour of a new module method, +[`on_add_user_third_party_identifier`](modules/third_party_rules_callbacks.md#on_add_user_third_party_identifier). +`on_threepid_bind` will be removed in a future version of Synapse. You should check whether any Synapse +modules in use in your deployment are making use of `on_threepid_bind`, and update them where possible. + +The arguments and functionality of the new method are the same. + +The justification behind the name change is that the old method's name, `on_threepid_bind`, was +misleading. A user is considered to "bind" their third-party ID to their Matrix ID only if they +do so via an [identity server](https://spec.matrix.org/latest/identity-service-api/) +(so that users on other homeservers may find them). But this method was not called in that case - +it was only called when a user added a third-party identifier on the local homeserver. + +Module developers may also be interested in the related +[`on_remove_user_third_party_identifier`](modules/third_party_rules_callbacks.md#on_remove_user_third_party_identifier) +module callback method that was also added in Synapse v1.79.0. This new method is called when a +user removes a third-party identifier from their account. 
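As a rough sketch of how a module might consume the two new callbacks documented above, the hypothetical module below registers both hooks through `ModuleApi.register_third_party_rules_callbacks` and simply logs each local 3PID association change. The module name, configuration handling and logging behaviour are illustrative assumptions, not part of this patch; only the callback names and `(user_id, medium, address)` signatures come from the change itself.

```python
import logging

from synapse.module_api import ModuleApi

logger = logging.getLogger(__name__)


class ThreePidAuditModule:
    """Hypothetical module that logs local 3PID association changes."""

    def __init__(self, config: dict, api: ModuleApi):
        # Both keyword arguments were added to register_third_party_rules_callbacks
        # in Synapse v1.79.0 by this change.
        api.register_third_party_rules_callbacks(
            on_add_user_third_party_identifier=self.on_add_3pid,
            on_remove_user_third_party_identifier=self.on_remove_3pid,
        )

    async def on_add_3pid(self, user_id: str, medium: str, address: str) -> None:
        # Called after a 3PID (email or msisdn) is associated on the local homeserver.
        logger.info("3PID added for %s: %s %s", user_id, medium, address)

    async def on_remove_3pid(self, user_id: str, medium: str, address: str) -> None:
        # Called after a local 3PID association is removed, e.g. via the admin API
        # or account deactivation.
        logger.info("3PID removed for %s: %s %s", user_id, medium, address)
```

Note that, as the deprecation notes above describe, a module that previously relied on `on_threepid_bind` for this purpose could switch to `on_add_user_third_party_identifier` without changing its callback signature.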
+ # Upgrading to v1.78.0 ## Deprecate the `/_synapse/admin/v1/media//delete` admin API diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py index 9a25ed419b..3e4d52c8d8 100644 --- a/synapse/events/third_party_rules.py +++ b/synapse/events/third_party_rules.py @@ -45,6 +45,8 @@ CHECK_CAN_DEACTIVATE_USER_CALLBACK = Callable[[str, bool], Awaitable[bool]] ON_PROFILE_UPDATE_CALLBACK = Callable[[str, ProfileInfo, bool, bool], Awaitable] ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK = Callable[[str, bool, bool], Awaitable] ON_THREEPID_BIND_CALLBACK = Callable[[str, str, str], Awaitable] +ON_ADD_USER_THIRD_PARTY_IDENTIFIER_CALLBACK = Callable[[str, str, str], Awaitable] +ON_REMOVE_USER_THIRD_PARTY_IDENTIFIER_CALLBACK = Callable[[str, str, str], Awaitable] def load_legacy_third_party_event_rules(hs: "HomeServer") -> None: @@ -172,6 +174,12 @@ class ThirdPartyEventRules: ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK ] = [] self._on_threepid_bind_callbacks: List[ON_THREEPID_BIND_CALLBACK] = [] + self._on_add_user_third_party_identifier_callbacks: List[ + ON_ADD_USER_THIRD_PARTY_IDENTIFIER_CALLBACK + ] = [] + self._on_remove_user_third_party_identifier_callbacks: List[ + ON_REMOVE_USER_THIRD_PARTY_IDENTIFIER_CALLBACK + ] = [] def register_third_party_rules_callbacks( self, @@ -191,6 +199,12 @@ class ThirdPartyEventRules: ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK ] = None, on_threepid_bind: Optional[ON_THREEPID_BIND_CALLBACK] = None, + on_add_user_third_party_identifier: Optional[ + ON_ADD_USER_THIRD_PARTY_IDENTIFIER_CALLBACK + ] = None, + on_remove_user_third_party_identifier: Optional[ + ON_REMOVE_USER_THIRD_PARTY_IDENTIFIER_CALLBACK + ] = None, ) -> None: """Register callbacks from modules for each hook.""" if check_event_allowed is not None: @@ -228,6 +242,11 @@ class ThirdPartyEventRules: if on_threepid_bind is not None: self._on_threepid_bind_callbacks.append(on_threepid_bind) + if on_add_user_third_party_identifier is not None: + self._on_add_user_third_party_identifier_callbacks.append( + on_add_user_third_party_identifier + ) + async def check_event_allowed( self, event: EventBase, @@ -511,6 +530,9 @@ class ThirdPartyEventRules: local homeserver, not when it's created on an identity server (and then kept track of so that it can be unbound on the same IS later on). + THIS MODULE CALLBACK METHOD HAS BEEN DEPRECATED. Please use the + `on_add_user_third_party_identifier` callback method instead. + Args: user_id: the user being associated with the threepid. medium: the threepid's medium. @@ -523,3 +545,44 @@ class ThirdPartyEventRules: logger.exception( "Failed to run module API callback %s: %s", callback, e ) + + async def on_add_user_third_party_identifier( + self, user_id: str, medium: str, address: str + ) -> None: + """Called when an association between a user's Matrix ID and a third-party ID + (email, phone number) has successfully been registered on the homeserver. + + Args: + user_id: The User ID included in the association. + medium: The medium of the third-party ID (email, msisdn). + address: The address of the third-party ID (i.e. an email address). 
+ """ + for callback in self._on_add_user_third_party_identifier_callbacks: + try: + await callback(user_id, medium, address) + except Exception as e: + logger.exception( + "Failed to run module API callback %s: %s", callback, e + ) + + async def on_remove_user_third_party_identifier( + self, user_id: str, medium: str, address: str + ) -> None: + """Called when an association between a user's Matrix ID and a third-party ID + (email, phone number) has been successfully removed on the homeserver. + + This is called *after* any known bindings on identity servers for this + association have been removed. + + Args: + user_id: The User ID included in the removed association. + medium: The medium of the third-party ID (email, msisdn). + address: The address of the third-party ID (i.e. an email address). + """ + for callback in self._on_remove_user_third_party_identifier_callbacks: + try: + await callback(user_id, medium, address) + except Exception as e: + logger.exception( + "Failed to run module API callback %s: %s", callback, e + ) diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index b12bc4c9a3..308e38edea 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -1542,6 +1542,17 @@ class AuthHandler: async def add_threepid( self, user_id: str, medium: str, address: str, validated_at: int ) -> None: + """ + Adds an association between a user's Matrix ID and a third-party ID (email, + phone number). + + Args: + user_id: The ID of the user to associate. + medium: The medium of the third-party ID (email, msisdn). + address: The address of the third-party ID (i.e. an email address). + validated_at: The timestamp in ms of when the validation that the user owns + this third-party ID occurred. + """ # check if medium has a valid value if medium not in ["email", "msisdn"]: raise SynapseError( @@ -1566,42 +1577,44 @@ class AuthHandler: user_id, medium, address, validated_at, self.hs.get_clock().time_msec() ) + # Inform Synapse modules that a 3PID association has been created. + await self._third_party_rules.on_add_user_third_party_identifier( + user_id, medium, address + ) + + # Deprecated method for informing Synapse modules that a 3PID association + # has successfully been created. await self._third_party_rules.on_threepid_bind(user_id, medium, address) - async def delete_threepid( - self, user_id: str, medium: str, address: str, id_server: Optional[str] = None - ) -> bool: - """Attempts to unbind the 3pid on the identity servers and deletes it - from the local database. + async def delete_local_threepid( + self, user_id: str, medium: str, address: str + ) -> None: + """Deletes an association between a third-party ID and a user ID from the local + database. This method does not unbind the association from any identity servers. + + If `medium` is 'email' and a pusher is associated with this third-party ID, the + pusher will also be deleted. Args: user_id: ID of user to remove the 3pid from. medium: The medium of the 3pid being removed: "email" or "msisdn". address: The 3pid address to remove. - id_server: Use the given identity server when unbinding - any threepids. If None then will attempt to unbind using the - identity server specified when binding (if known). - - Returns: - Returns True if successfully unbound the 3pid on - the identity server, False if identity server doesn't support the - unbind API. 
""" - # 'Canonicalise' email addresses as per above if medium == "email": address = canonicalise_email(address) - result = await self.hs.get_identity_handler().try_unbind_threepid( - user_id, medium, address, id_server + await self.store.user_delete_threepid(user_id, medium, address) + + # Inform Synapse modules that a 3PID association has been deleted. + await self._third_party_rules.on_remove_user_third_party_identifier( + user_id, medium, address ) - await self.store.user_delete_threepid(user_id, medium, address) if medium == "email": await self.store.delete_pusher_by_app_id_pushkey_user_id( app_id="m.email", pushkey=address, user_id=user_id ) - return result async def hash(self, password: str) -> str: """Computes a secure hash of password. diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index d24f649382..d31263c717 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -100,26 +100,28 @@ class DeactivateAccountHandler: # unbinding identity_server_supports_unbinding = True - # Retrieve the 3PIDs this user has bound to an identity server - threepids = await self.store.user_get_bound_threepids(user_id) - - for threepid in threepids: + # Attempt to unbind any known bound threepids to this account from identity + # server(s). + bound_threepids = await self.store.user_get_bound_threepids(user_id) + for threepid in bound_threepids: try: result = await self._identity_handler.try_unbind_threepid( user_id, threepid["medium"], threepid["address"], id_server ) - identity_server_supports_unbinding &= result except Exception: # Do we want this to be a fatal error or should we carry on? logger.exception("Failed to remove threepid from ID server") raise SynapseError(400, "Failed to remove threepid from ID server") - await self.store.user_delete_threepid( + + identity_server_supports_unbinding &= result + + # Remove any local threepid associations for this account. + local_threepids = await self.store.user_get_threepids(user_id) + for threepid in local_threepids: + await self._auth_handler.delete_local_threepid( user_id, threepid["medium"], threepid["address"] ) - # Remove all 3PIDs this user has bound to the homeserver - await self.store.user_delete_threepids(user_id) - # delete any devices belonging to the user, which will also # delete corresponding access tokens. await self._device_handler.delete_all_devices_for_user(user_id) diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 1964276a54..424239e3df 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -64,9 +64,11 @@ from synapse.events.third_party_rules import ( CHECK_EVENT_ALLOWED_CALLBACK, CHECK_THREEPID_CAN_BE_INVITED_CALLBACK, CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK, + ON_ADD_USER_THIRD_PARTY_IDENTIFIER_CALLBACK, ON_CREATE_ROOM_CALLBACK, ON_NEW_EVENT_CALLBACK, ON_PROFILE_UPDATE_CALLBACK, + ON_REMOVE_USER_THIRD_PARTY_IDENTIFIER_CALLBACK, ON_THREEPID_BIND_CALLBACK, ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK, ) @@ -357,6 +359,12 @@ class ModuleApi: ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK ] = None, on_threepid_bind: Optional[ON_THREEPID_BIND_CALLBACK] = None, + on_add_user_third_party_identifier: Optional[ + ON_ADD_USER_THIRD_PARTY_IDENTIFIER_CALLBACK + ] = None, + on_remove_user_third_party_identifier: Optional[ + ON_REMOVE_USER_THIRD_PARTY_IDENTIFIER_CALLBACK + ] = None, ) -> None: """Registers callbacks for third party event rules capabilities. 
@@ -373,6 +381,8 @@ class ModuleApi: on_profile_update=on_profile_update, on_user_deactivation_status_changed=on_user_deactivation_status_changed, on_threepid_bind=on_threepid_bind, + on_add_user_third_party_identifier=on_add_user_third_party_identifier, + on_remove_user_third_party_identifier=on_remove_user_third_party_identifier, ) def register_presence_router_callbacks( diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 7cc4db20d6..357e9a574d 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -304,13 +304,20 @@ class UserRestServletV2(RestServlet): # remove old threepids for medium, address in del_threepids: try: - await self.auth_handler.delete_threepid( - user_id, medium, address, None + # Attempt to remove any known bindings of this third-party ID + # and user ID from identity servers. + await self.hs.get_identity_handler().try_unbind_threepid( + user_id, medium, address, id_server=None ) except Exception: logger.exception("Failed to remove threepids") raise SynapseError(500, "Failed to remove threepids") + # Delete the local association of this user ID and third-party ID. + await self.auth_handler.delete_local_threepid( + user_id, medium, address + ) + # add new threepids current_time = self.hs.get_clock().time_msec() for medium, address in add_threepids: diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py index 662f5bf762..484d7440a4 100644 --- a/synapse/rest/client/account.py +++ b/synapse/rest/client/account.py @@ -768,7 +768,9 @@ class ThreepidDeleteRestServlet(RestServlet): user_id = requester.user.to_string() try: - ret = await self.auth_handler.delete_threepid( + # Attempt to remove any known bindings of this third-party ID + # and user ID from identity servers. + ret = await self.hs.get_identity_handler().try_unbind_threepid( user_id, body.medium, body.address, body.id_server ) except Exception: @@ -783,6 +785,11 @@ class ThreepidDeleteRestServlet(RestServlet): else: id_server_unbind_result = "no-support" + # Delete the local association of this user ID and third-party ID. 
+ await self.auth_handler.delete_local_threepid( + user_id, body.medium, body.address + ) + return 200, {"id_server_unbind_result": id_server_unbind_result} diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 9a55e17624..717237e024 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -1002,19 +1002,6 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): desc="user_delete_threepid", ) - async def user_delete_threepids(self, user_id: str) -> None: - """Delete all threepid this user has bound - - Args: - user_id: The user id to delete all threepids of - - """ - await self.db_pool.simple_delete( - "user_threepids", - keyvalues={"user_id": user_id}, - desc="user_delete_threepids", - ) - async def add_user_bound_threepid( self, user_id: str, medium: str, address: str, id_server: str ) -> None: diff --git a/tests/push/test_email.py b/tests/push/test_email.py index 0a3aca5c50..4ea5472eb4 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py @@ -369,10 +369,8 @@ class EmailPusherTests(HomeserverTestCase): # disassociate the user's email address self.get_success( - self.auth_handler.delete_threepid( - user_id=self.user_id, - medium="email", - address="a@example.com", + self.auth_handler.delete_local_threepid( + user_id=self.user_id, medium="email", address="a@example.com" ) ) diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py index c0f93f898a..3b99513707 100644 --- a/tests/rest/client/test_third_party_rules.py +++ b/tests/rest/client/test_third_party_rules.py @@ -934,3 +934,124 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): # Check that the mock was called with the right parameters self.assertEqual(args, (user_id, "email", "foo@example.com")) + + def test_on_add_and_remove_user_third_party_identifier(self) -> None: + """Tests that the on_add_user_third_party_identifier and + on_remove_user_third_party_identifier module callbacks are called + just before associating and removing a 3PID to/from an account. + """ + # Pretend to be a Synapse module and register both callbacks as mocks. + third_party_rules = self.hs.get_third_party_event_rules() + on_add_user_third_party_identifier_callback_mock = Mock( + return_value=make_awaitable(None) + ) + on_remove_user_third_party_identifier_callback_mock = Mock( + return_value=make_awaitable(None) + ) + third_party_rules._on_threepid_bind_callbacks.append( + on_add_user_third_party_identifier_callback_mock + ) + third_party_rules._on_threepid_bind_callbacks.append( + on_remove_user_third_party_identifier_callback_mock + ) + + # Register an admin user. + self.register_user("admin", "password", admin=True) + admin_tok = self.login("admin", "password") + + # Also register a normal user we can modify. + user_id = self.register_user("user", "password") + + # Add a 3PID to the user. + channel = self.make_request( + "PUT", + "/_synapse/admin/v2/users/%s" % user_id, + { + "threepids": [ + { + "medium": "email", + "address": "foo@example.com", + }, + ], + }, + access_token=admin_tok, + ) + + # Check that the mocked add callback was called with the appropriate + # 3PID details. 
+ self.assertEqual(channel.code, 200, channel.json_body) + on_add_user_third_party_identifier_callback_mock.assert_called_once() + args = on_add_user_third_party_identifier_callback_mock.call_args[0] + self.assertEqual(args, (user_id, "email", "foo@example.com")) + + # Now remove the 3PID from the user + channel = self.make_request( + "PUT", + "/_synapse/admin/v2/users/%s" % user_id, + { + "threepids": [], + }, + access_token=admin_tok, + ) + + # Check that the mocked remove callback was called with the appropriate + # 3PID details. + self.assertEqual(channel.code, 200, channel.json_body) + on_remove_user_third_party_identifier_callback_mock.assert_called_once() + args = on_remove_user_third_party_identifier_callback_mock.call_args[0] + self.assertEqual(args, (user_id, "email", "foo@example.com")) + + def test_on_remove_user_third_party_identifier_is_called_on_deactivate( + self, + ) -> None: + """Tests that the on_remove_user_third_party_identifier module callback is called + when a user is deactivated and their third-party ID associations are deleted. + """ + # Pretend to be a Synapse module and register both callbacks as mocks. + third_party_rules = self.hs.get_third_party_event_rules() + on_remove_user_third_party_identifier_callback_mock = Mock( + return_value=make_awaitable(None) + ) + third_party_rules._on_threepid_bind_callbacks.append( + on_remove_user_third_party_identifier_callback_mock + ) + + # Register an admin user. + self.register_user("admin", "password", admin=True) + admin_tok = self.login("admin", "password") + + # Also register a normal user we can modify. + user_id = self.register_user("user", "password") + + # Add a 3PID to the user. + channel = self.make_request( + "PUT", + "/_synapse/admin/v2/users/%s" % user_id, + { + "threepids": [ + { + "medium": "email", + "address": "foo@example.com", + }, + ], + }, + access_token=admin_tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Now deactivate the user. + channel = self.make_request( + "PUT", + "/_synapse/admin/v2/users/%s" % user_id, + { + "deactivated": True, + }, + access_token=admin_tok, + ) + + # Check that the mocked remove callback was called with the appropriate + # 3PID details. 
+ self.assertEqual(channel.code, 200, channel.json_body) + on_remove_user_third_party_identifier_callback_mock.assert_called_once() + args = on_remove_user_third_party_identifier_callback_mock.call_args[0] + self.assertEqual(args, (user_id, "email", "foo@example.com")) -- cgit 1.5.1 From 93f7955eba50c827f96e1b2e8e44ef22a98cecc4 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Tue, 28 Feb 2023 13:09:10 +0100 Subject: Admin API endpoint to delete a reported event (#15116) * Admin api to delete event report * lint + tests * newsfile * Apply suggestions from code review Co-authored-by: David Robertson * revert changes - move to WorkerStore * update unit test * Note that timestamp is in millseconds --------- Co-authored-by: David Robertson --- changelog.d/15116.feature | 1 + docs/admin_api/event_reports.md | 14 ++++ synapse/rest/admin/event_reports.py | 41 ++++++++-- synapse/storage/databases/main/room.py | 36 ++++++++- tests/rest/admin/test_event_reports.py | 143 ++++++++++++++++++++++++++++++++- 5 files changed, 224 insertions(+), 11 deletions(-) create mode 100644 changelog.d/15116.feature (limited to 'synapse/storage/databases') diff --git a/changelog.d/15116.feature b/changelog.d/15116.feature new file mode 100644 index 0000000000..087d8dc7f1 --- /dev/null +++ b/changelog.d/15116.feature @@ -0,0 +1 @@ +Add an [admin API](https://matrix-org.github.io/synapse/latest/usage/administration/admin_api/index.html) to delete a [specific event report](https://spec.matrix.org/v1.6/client-server-api/#reporting-content). \ No newline at end of file diff --git a/docs/admin_api/event_reports.md b/docs/admin_api/event_reports.md index beec8bb7ef..83f7dc37f4 100644 --- a/docs/admin_api/event_reports.md +++ b/docs/admin_api/event_reports.md @@ -169,3 +169,17 @@ The following fields are returned in the JSON response body: * `canonical_alias`: string - The canonical alias of the room. `null` if the room does not have a canonical alias set. * `event_json`: object - Details of the original event that was reported. + +# Delete a specific event report + +This API deletes a specific event report. If the request is successful, the response body +will be an empty JSON object. + +The api is: +``` +DELETE /_synapse/admin/v1/event_reports/ +``` + +**URL parameters:** + +* `report_id`: string - The ID of the event report. 
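As an illustrative sketch of calling the new deletion endpoint documented above (not part of the patch), the snippet below uses the third-party `requests` library; the homeserver URL, admin access token and report ID are assumptions for the example.

```python
# Hypothetical client-side call to the new admin endpoint.
import requests

SYNAPSE_URL = "http://localhost:8008"  # assumed local homeserver
ADMIN_TOKEN = "syt_example_admin_token"  # assumed admin access token
report_id = 2  # assumed ID of an existing event report

resp = requests.delete(
    f"{SYNAPSE_URL}/_synapse/admin/v1/event_reports/{report_id}",
    headers={"Authorization": f"Bearer {ADMIN_TOKEN}"},
)

# A successful deletion returns 200 with an empty JSON object ({}).
# An unknown report ID returns 404 (M_NOT_FOUND, "Event report not found");
# a negative or non-numeric report ID returns 400 (M_INVALID_PARAM).
print(resp.status_code, resp.json())
```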
diff --git a/synapse/rest/admin/event_reports.py b/synapse/rest/admin/event_reports.py index a3beb74e2c..c546ef7e23 100644 --- a/synapse/rest/admin/event_reports.py +++ b/synapse/rest/admin/event_reports.py @@ -53,11 +53,11 @@ class EventReportsRestServlet(RestServlet): PATTERNS = admin_patterns("/event_reports$") def __init__(self, hs: "HomeServer"): - self.auth = hs.get_auth() - self.store = hs.get_datastores().main + self._auth = hs.get_auth() + self._store = hs.get_datastores().main async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - await assert_requester_is_admin(self.auth, request) + await assert_requester_is_admin(self._auth, request) start = parse_integer(request, "from", default=0) limit = parse_integer(request, "limit", default=100) @@ -79,7 +79,7 @@ class EventReportsRestServlet(RestServlet): errcode=Codes.INVALID_PARAM, ) - event_reports, total = await self.store.get_event_reports_paginate( + event_reports, total = await self._store.get_event_reports_paginate( start, limit, direction, user_id, room_id ) ret = {"event_reports": event_reports, "total": total} @@ -108,13 +108,13 @@ class EventReportDetailRestServlet(RestServlet): PATTERNS = admin_patterns("/event_reports/(?P[^/]*)$") def __init__(self, hs: "HomeServer"): - self.auth = hs.get_auth() - self.store = hs.get_datastores().main + self._auth = hs.get_auth() + self._store = hs.get_datastores().main async def on_GET( self, request: SynapseRequest, report_id: str ) -> Tuple[int, JsonDict]: - await assert_requester_is_admin(self.auth, request) + await assert_requester_is_admin(self._auth, request) message = ( "The report_id parameter must be a string representing a positive integer." @@ -131,8 +131,33 @@ class EventReportDetailRestServlet(RestServlet): HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM ) - ret = await self.store.get_event_report(resolved_report_id) + ret = await self._store.get_event_report(resolved_report_id) if not ret: raise NotFoundError("Event report not found") return HTTPStatus.OK, ret + + async def on_DELETE( + self, request: SynapseRequest, report_id: str + ) -> Tuple[int, JsonDict]: + await assert_requester_is_admin(self._auth, request) + + message = ( + "The report_id parameter must be a string representing a positive integer." + ) + try: + resolved_report_id = int(report_id) + except ValueError: + raise SynapseError( + HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM + ) + + if resolved_report_id < 0: + raise SynapseError( + HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM + ) + + if await self._store.delete_event_report(resolved_report_id): + return HTTPStatus.OK, {} + + raise NotFoundError("Event report not found") diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 39f89291b2..a2e9519cb6 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -1417,6 +1417,27 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): get_un_partial_stated_rooms_from_stream_txn, ) + async def delete_event_report(self, report_id: int) -> bool: + """Remove an event report from database. + + Args: + report_id: Report to delete + + Returns: + Whether the report was successfully deleted or not. 
+ """ + try: + await self.db_pool.simple_delete_one( + table="event_reports", + keyvalues={"id": report_id}, + desc="delete_event_report", + ) + except StoreError: + # Deletion failed because report does not exist + return False + + return True + class _BackgroundUpdates: REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory" @@ -2139,7 +2160,19 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): reason: Optional[str], content: JsonDict, received_ts: int, - ) -> None: + ) -> int: + """Add an event report + + Args: + room_id: Room that contains the reported event. + event_id: The reported event. + user_id: User who reports the event. + reason: Description that the user specifies. + content: Report request body (score and reason). + received_ts: Time when the user submitted the report (milliseconds). + Returns: + Id of the event report. + """ next_id = self._event_reports_id_gen.get_next() await self.db_pool.simple_insert( table="event_reports", @@ -2154,6 +2187,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): }, desc="add_event_report", ) + return next_id async def get_event_report(self, report_id: int) -> Optional[Dict[str, Any]]: """Retrieve an event report diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py index 233eba3516..f189b07769 100644 --- a/tests/rest/admin/test_event_reports.py +++ b/tests/rest/admin/test_event_reports.py @@ -78,7 +78,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): """ Try to get an event report without authentication. """ - channel = self.make_request("GET", self.url, b"{}") + channel = self.make_request("GET", self.url, {}) self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @@ -473,7 +473,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): """ Try to get event report without authentication. """ - channel = self.make_request("GET", self.url, b"{}") + channel = self.make_request("GET", self.url, {}) self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) @@ -599,3 +599,142 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase): self.assertIn("room_id", content["event_json"]) self.assertIn("sender", content["event_json"]) self.assertIn("content", content["event_json"]) + + +class DeleteEventReportTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self._store = hs.get_datastores().main + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.other_user_tok = self.login("user", "pass") + + # create report + event_id = self.get_success( + self._store.add_event_report( + "room_id", + "event_id", + self.other_user, + "this makes me sad", + {}, + self.clock.time_msec(), + ) + ) + + self.url = f"/_synapse/admin/v1/event_reports/{event_id}" + + def test_no_auth(self) -> None: + """ + Try to delete event report without authentication. 
+ """ + channel = self.make_request("DELETE", self.url) + + self.assertEqual(401, channel.code, msg=channel.json_body) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_requester_is_no_admin(self) -> None: + """ + If the user is not a server admin, an error 403 is returned. + """ + + channel = self.make_request( + "DELETE", + self.url, + access_token=self.other_user_tok, + ) + + self.assertEqual(403, channel.code, msg=channel.json_body) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_delete_success(self) -> None: + """ + Testing delete a report. + """ + + channel = self.make_request( + "DELETE", + self.url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual({}, channel.json_body) + + channel = self.make_request( + "GET", + self.url, + access_token=self.admin_user_tok, + ) + + # check that report was deleted + self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + def test_invalid_report_id(self) -> None: + """ + Testing that an invalid `report_id` returns a 400. + """ + + # `report_id` is negative + channel = self.make_request( + "DELETE", + "/_synapse/admin/v1/event_reports/-123", + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + self.assertEqual( + "The report_id parameter must be a string representing a positive integer.", + channel.json_body["error"], + ) + + # `report_id` is a non-numerical string + channel = self.make_request( + "DELETE", + "/_synapse/admin/v1/event_reports/abcdef", + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + self.assertEqual( + "The report_id parameter must be a string representing a positive integer.", + channel.json_body["error"], + ) + + # `report_id` is undefined + channel = self.make_request( + "DELETE", + "/_synapse/admin/v1/event_reports/", + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + self.assertEqual( + "The report_id parameter must be a string representing a positive integer.", + channel.json_body["error"], + ) + + def test_report_id_not_found(self) -> None: + """ + Testing that a not existing `report_id` returns a 404. + """ + + channel = self.make_request( + "DELETE", + "/_synapse/admin/v1/event_reports/123", + access_token=self.admin_user_tok, + ) + + self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + self.assertEqual("Event report not found", channel.json_body["error"]) -- cgit 1.5.1 From 682d31c7023b6b7299e74bc631e4d2acc60f91ac Mon Sep 17 00:00:00 2001 From: reivilibre Date: Tue, 28 Feb 2023 16:37:19 +0000 Subject: Allow use of the `/filter` Client-Server APIs on workers. 
(#15134) --- changelog.d/15134.feature | 1 + docker/configure_workers_and_start.py | 1 + docs/workers.md | 1 + synapse/rest/__init__.py | 3 +-- synapse/storage/databases/main/__init__.py | 4 ++-- synapse/storage/databases/main/filtering.py | 25 +++++++++++++++++++++---- 6 files changed, 27 insertions(+), 8 deletions(-) create mode 100644 changelog.d/15134.feature (limited to 'synapse/storage/databases') diff --git a/changelog.d/15134.feature b/changelog.d/15134.feature new file mode 100644 index 0000000000..0dbb30bc8f --- /dev/null +++ b/changelog.d/15134.feature @@ -0,0 +1 @@ +Allow use of the `/filter` Client-Server APIs on workers. \ No newline at end of file diff --git a/docker/configure_workers_and_start.py b/docker/configure_workers_and_start.py index 58c62f2231..7f615e5066 100755 --- a/docker/configure_workers_and_start.py +++ b/docker/configure_workers_and_start.py @@ -142,6 +142,7 @@ WORKERS_CONFIG: Dict[str, Dict[str, Any]] = { "^/_matrix/client/(api/v1|r0|v3|unstable/.*)/rooms/.*/aliases", "^/_matrix/client/v1/rooms/.*/timestamp_to_event$", "^/_matrix/client/(api/v1|r0|v3|unstable)/search", + "^/_matrix/client/(r0|v3|unstable)/user/.*/filter(/|$)", ], "shared_extra_conf": {}, "worker_extra_conf": "", diff --git a/docs/workers.md b/docs/workers.md index 2eb970ffa6..35a96f12a9 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -232,6 +232,7 @@ information. ^/_matrix/client/(api/v1|r0|v3|unstable)/joined_rooms$ ^/_matrix/client/v1/rooms/.*/timestamp_to_event$ ^/_matrix/client/(api/v1|r0|v3|unstable)/search$ + ^/_matrix/client/(r0|v3|unstable)/user/.*/filter(/|$) # Encryption requests ^/_matrix/client/(r0|v3|unstable)/keys/query$ diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index 14c4e6ebbb..c327f15043 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -108,8 +108,7 @@ class ClientRestResource(JsonResource): if is_main_process: logout.register_servlets(hs, client_resource) sync.register_servlets(hs, client_resource) - if is_main_process: - filter.register_servlets(hs, client_resource) + filter.register_servlets(hs, client_resource) account.register_servlets(hs, client_resource) register.register_servlets(hs, client_resource) if is_main_process: diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 837dc7646e..dc3948c170 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -43,7 +43,7 @@ from .event_federation import EventFederationStore from .event_push_actions import EventPushActionsStore from .events_bg_updates import EventsBackgroundUpdatesStore from .events_forward_extremities import EventForwardExtremitiesStore -from .filtering import FilteringStore +from .filtering import FilteringWorkerStore from .keys import KeyStore from .lock import LockStore from .media_repository import MediaRepositoryStore @@ -99,7 +99,7 @@ class DataStore( EventFederationStore, MediaRepositoryStore, RejectionsStore, - FilteringStore, + FilteringWorkerStore, PusherStore, PushRuleStore, ApplicationServiceTransactionStore, diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py index 12f3b601f1..8e57c8e5a0 100644 --- a/synapse/storage/databases/main/filtering.py +++ b/synapse/storage/databases/main/filtering.py @@ -17,7 +17,7 @@ from typing import Optional, Tuple, Union, cast from canonicaljson import encode_canonical_json -from synapse.api.errors import Codes, SynapseError +from synapse.api.errors import 
Codes, StoreError, SynapseError from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import LoggingTransaction from synapse.types import JsonDict @@ -46,8 +46,6 @@ class FilteringWorkerStore(SQLBaseStore): return db_to_json(def_json) - -class FilteringStore(FilteringWorkerStore): async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> int: def_json = encode_canonical_json(user_filter) @@ -79,4 +77,23 @@ class FilteringStore(FilteringWorkerStore): return filter_id - return await self.db_pool.runInteraction("add_user_filter", _do_txn) + attempts = 0 + while True: + # Try a few times. + # This is technically needed if a user tries to create two filters at once, + # leading to two concurrent transactions. + # The failure case would be: + # - SELECT filter_id ... filter_json = ? → both transactions return no rows + # - SELECT MAX(filter_id) ... → both transactions return e.g. 5 + # - INSERT INTO ... → both transactions insert filter_id = 6 + # One of the transactions will commit. The other will get a unique key + # constraint violation error (IntegrityError). This is not the same as a + # serialisability violation, which would be automatically retried by + # `runInteraction`. + try: + return await self.db_pool.runInteraction("add_user_filter", _do_txn) + except self.db_pool.engine.module.IntegrityError: + attempts += 1 + + if attempts >= 5: + raise StoreError(500, "Couldn't generate a filter ID.") -- cgit 1.5.1 From d62cd940cb38e706f7fadc279017b0be3f3f29a3 Mon Sep 17 00:00:00 2001 From: reivilibre Date: Tue, 28 Feb 2023 17:11:26 +0000 Subject: Fix a long-standing bug where an initial sync would not respond to changes to the list of ignored users if there was an initial sync cached. (#15163) --- changelog.d/15163.bugfix | 1 + synapse/rest/client/sync.py | 25 +++++++++++++++++++-- synapse/storage/databases/main/account_data.py | 31 ++++++++++++++++++++++++++ tests/storage/test_account_data.py | 22 ++++++++++++++++++ 4 files changed, 77 insertions(+), 2 deletions(-) create mode 100644 changelog.d/15163.bugfix (limited to 'synapse/storage/databases') diff --git a/changelog.d/15163.bugfix b/changelog.d/15163.bugfix new file mode 100644 index 0000000000..7ff1cd4463 --- /dev/null +++ b/changelog.d/15163.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where an initial sync would not respond to changes to the list of ignored users if there was an initial sync cached. \ No newline at end of file diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index f2013faeb2..8fcb8ac3d9 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -16,7 +16,7 @@ import logging from collections import defaultdict from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union -from synapse.api.constants import EduTypes, Membership, PresenceState +from synapse.api.constants import AccountDataTypes, EduTypes, Membership, PresenceState from synapse.api.errors import Codes, StoreError, SynapseError from synapse.api.filtering import FilterCollection from synapse.api.presence import UserPresenceState @@ -139,7 +139,28 @@ class SyncRestServlet(RestServlet): device_id, ) - request_key = (user, timeout, since, filter_id, full_state, device_id) + # Stream position of the last ignored users account data event for this user, + # if we're initial syncing. + # We include this in the request key to invalidate an initial sync + # in the response cache once the set of ignored users has changed. 
+ # (We filter out ignored users from timeline events, so our sync response + # is invalid once the set of ignored users changes.) + last_ignore_accdata_streampos: Optional[int] = None + if not since: + # No `since`, so this is an initial sync. + last_ignore_accdata_streampos = await self.store.get_latest_stream_id_for_global_account_data_by_type_for_user( + user.to_string(), AccountDataTypes.IGNORED_USER_LIST + ) + + request_key = ( + user, + timeout, + since, + filter_id, + full_state, + device_id, + last_ignore_accdata_streampos, + ) if filter_id is None: filter_collection = self.filtering.DEFAULT_FILTER_COLLECTION diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index 95567826f2..308d19440f 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -237,6 +237,37 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) else: return None + async def get_latest_stream_id_for_global_account_data_by_type_for_user( + self, user_id: str, data_type: str + ) -> Optional[int]: + """ + Returns: + The stream ID of the account data, + or None if there is no such account data. + """ + + def get_latest_stream_id_for_global_account_data_by_type_for_user_txn( + txn: LoggingTransaction, + ) -> Optional[int]: + sql = """ + SELECT stream_id FROM account_data + WHERE user_id = ? AND account_data_type = ? + ORDER BY stream_id DESC + LIMIT 1 + """ + txn.execute(sql, (user_id, data_type)) + + row = txn.fetchone() + if row: + return row[0] + else: + return None + + return await self.db_pool.runInteraction( + "get_latest_stream_id_for_global_account_data_by_type_for_user", + get_latest_stream_id_for_global_account_data_by_type_for_user_txn, + ) + @cached(num_args=2, tree=True) async def get_account_data_for_room( self, user_id: str, room_id: str diff --git a/tests/storage/test_account_data.py b/tests/storage/test_account_data.py index 1bfd11ceae..b12691a9d3 100644 --- a/tests/storage/test_account_data.py +++ b/tests/storage/test_account_data.py @@ -140,3 +140,25 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase): # No one ignores the user now. self.assert_ignored(self.user, set()) self.assert_ignorers("@other:test", set()) + + def test_ignoring_users_with_latest_stream_ids(self) -> None: + """Test that ignoring users updates the latest stream ID for the ignored + user list account data.""" + + def get_latest_ignore_streampos(user_id: str) -> Optional[int]: + return self.get_success( + self.store.get_latest_stream_id_for_global_account_data_by_type_for_user( + user_id, AccountDataTypes.IGNORED_USER_LIST + ) + ) + + self.assertIsNone(get_latest_ignore_streampos("@user:test")) + + self._update_ignore_list("@other:test", "@another:remote") + + self.assertEqual(get_latest_ignore_streampos("@user:test"), 2) + + # Add one user, remove one user, and leave one user. + self._update_ignore_list("@foo:test", "@another:remote") + + self.assertEqual(get_latest_ignore_streampos("@user:test"), 3) -- cgit 1.5.1 From 2b78981736f9004f99b1760e3e77b234f92755a7 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 28 Feb 2023 18:49:28 +0000 Subject: Remove support for aggregating reactions (#15172) It turns out that no clients rely on server-side aggregation of `m.annotation` relationships: it's just not very useful as currently implemented. It's also non-trivial to calculate. 
I want to remove it from MSC2677, so to keep the implementation in line, let's remove it here. --- changelog.d/15172.feature | 1 + synapse/events/utils.py | 5 - synapse/handlers/relations.py | 76 +-------- synapse/storage/databases/main/cache.py | 3 - synapse/storage/databases/main/events.py | 4 - .../storage/databases/main/events_bg_updates.py | 3 - synapse/storage/databases/main/relations.py | 137 ---------------- tests/rest/client/test_relations.py | 178 ++++----------------- 8 files changed, 30 insertions(+), 377 deletions(-) create mode 100644 changelog.d/15172.feature (limited to 'synapse/storage/databases') diff --git a/changelog.d/15172.feature b/changelog.d/15172.feature new file mode 100644 index 0000000000..3f789edb7f --- /dev/null +++ b/changelog.d/15172.feature @@ -0,0 +1 @@ +Remove support for server-side aggregation of reactions. diff --git a/synapse/events/utils.py b/synapse/events/utils.py index ebf8c7ed83..eaa6cad4af 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -516,11 +516,6 @@ class EventClientSerializer: # being serialized. serialized_aggregations = {} - if event_aggregations.annotations: - serialized_aggregations[ - RelationTypes.ANNOTATION - ] = event_aggregations.annotations - if event_aggregations.references: serialized_aggregations[ RelationTypes.REFERENCE diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index 0fb15391e0..553053b694 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -60,13 +60,12 @@ class BundledAggregations: Some values require additional processing during serialization. """ - annotations: Optional[JsonDict] = None references: Optional[JsonDict] = None replace: Optional[EventBase] = None thread: Optional[_ThreadAggregation] = None def __bool__(self) -> bool: - return bool(self.annotations or self.references or self.replace or self.thread) + return bool(self.references or self.replace or self.thread) class RelationsHandler: @@ -227,67 +226,6 @@ class RelationsHandler: e.msg, ) - async def get_annotations_for_events( - self, event_ids: Collection[str], ignored_users: FrozenSet[str] = frozenset() - ) -> Dict[str, List[JsonDict]]: - """Get a list of annotations to the given events, grouped by event type and - aggregation key, sorted by count. - - This is used e.g. to get the what and how many reactions have happened - on an event. - - Args: - event_ids: Fetch events that relate to these event IDs. - ignored_users: The users ignored by the requesting user. - - Returns: - A map of event IDs to a list of groups of annotations that match. - Each entry is a dict with `type`, `key` and `count` fields. - """ - # Get the base results for all users. - full_results = await self._main_store.get_aggregation_groups_for_events( - event_ids - ) - - # Avoid additional logic if there are no ignored users. - if not ignored_users: - return { - event_id: results - for event_id, results in full_results.items() - if results - } - - # Then subtract off the results for any ignored users. - ignored_results = await self._main_store.get_aggregation_groups_for_users( - [event_id for event_id, results in full_results.items() if results], - ignored_users, - ) - - filtered_results = {} - for event_id, results in full_results.items(): - # If no annotations, skip. - if not results: - continue - - # If there are not ignored results for this event, copy verbatim. - if event_id not in ignored_results: - filtered_results[event_id] = results - continue - - # Otherwise, subtract out the ignored results. 
- event_ignored_results = ignored_results[event_id] - for result in results: - key = (result["type"], result["key"]) - if key in event_ignored_results: - # Ensure to not modify the cache. - result = result.copy() - result["count"] -= event_ignored_results[key] - if result["count"] <= 0: - continue - filtered_results.setdefault(event_id, []).append(result) - - return filtered_results - async def get_references_for_events( self, event_ids: Collection[str], ignored_users: FrozenSet[str] = frozenset() ) -> Dict[str, List[_RelatedEvent]]: @@ -531,17 +469,6 @@ class RelationsHandler: # (as that is what makes it part of the thread). relations_by_id[latest_thread_event.event_id] = RelationTypes.THREAD - async def _fetch_annotations() -> None: - """Fetch any annotations (ie, reactions) to bundle with this event.""" - annotations_by_event_id = await self.get_annotations_for_events( - events_by_id.keys(), ignored_users=ignored_users - ) - for event_id, annotations in annotations_by_event_id.items(): - if annotations: - results.setdefault(event_id, BundledAggregations()).annotations = { - "chunk": annotations - } - async def _fetch_references() -> None: """Fetch any references to bundle with this event.""" references_by_event_id = await self.get_references_for_events( @@ -575,7 +502,6 @@ class RelationsHandler: await make_deferred_yieldable( gather_results( ( - run_in_background(_fetch_annotations), run_in_background(_fetch_references), run_in_background(_fetch_edits), ) diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index 5b66431691..096dec7f87 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -266,9 +266,6 @@ class CacheInvalidationWorkerStore(SQLBaseStore): if relates_to: self._attempt_to_invalidate_cache("get_relations_for_event", (relates_to,)) self._attempt_to_invalidate_cache("get_references_for_event", (relates_to,)) - self._attempt_to_invalidate_cache( - "get_aggregation_groups_for_event", (relates_to,) - ) self._attempt_to_invalidate_cache("get_applicable_edit", (relates_to,)) self._attempt_to_invalidate_cache("get_thread_summary", (relates_to,)) self._attempt_to_invalidate_cache("get_thread_participated", (relates_to,)) diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 73b8aea16c..a8a4ed4436 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -2024,10 +2024,6 @@ class PersistEventsStore: self.store._invalidate_cache_and_stream( txn, self.store.get_relations_for_event, (redacted_relates_to,) ) - if rel_type == RelationTypes.ANNOTATION: - self.store._invalidate_cache_and_stream( - txn, self.store.get_aggregation_groups_for_event, (redacted_relates_to,) - ) if rel_type == RelationTypes.REFERENCE: self.store._invalidate_cache_and_stream( txn, self.store.get_references_for_event, (redacted_relates_to,) diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index 0a275e6ce6..daef3685b0 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -1219,9 +1219,6 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): self._invalidate_cache_and_stream( # type: ignore[attr-defined] txn, self.get_relations_for_event, cache_tuple # type: ignore[attr-defined] ) - self._invalidate_cache_and_stream( # type: ignore[attr-defined] - txn, self.get_aggregation_groups_for_event, 
cache_tuple # type: ignore[attr-defined] - ) self._invalidate_cache_and_stream( # type: ignore[attr-defined] txn, self.get_thread_summary, cache_tuple # type: ignore[attr-defined] ) diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index fa3266c081..bc3a83919c 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -397,143 +397,6 @@ class RelationsWorkerStore(SQLBaseStore): ) return result is not None - @cached() - async def get_aggregation_groups_for_event( - self, event_id: str - ) -> Sequence[JsonDict]: - raise NotImplementedError() - - @cachedList( - cached_method_name="get_aggregation_groups_for_event", list_name="event_ids" - ) - async def get_aggregation_groups_for_events( - self, event_ids: Collection[str] - ) -> Mapping[str, Optional[List[JsonDict]]]: - """Get a list of annotations on the given events, grouped by event type and - aggregation key, sorted by count. - - This is used e.g. to get the what and how many reactions have happend - on an event. - - Args: - event_ids: Fetch events that relate to these event IDs. - - Returns: - A map of event IDs to a list of groups of annotations that match. - Each entry is a dict with `type`, `key` and `count` fields. - """ - # The number of entries to return per event ID. - limit = 5 - - clause, args = make_in_list_sql_clause( - self.database_engine, "relates_to_id", event_ids - ) - args.append(RelationTypes.ANNOTATION) - - sql = f""" - SELECT - relates_to_id, - annotation.type, - aggregation_key, - COUNT(DISTINCT annotation.sender) - FROM events AS annotation - INNER JOIN event_relations USING (event_id) - INNER JOIN events AS parent ON - parent.event_id = relates_to_id - AND parent.room_id = annotation.room_id - WHERE - {clause} - AND relation_type = ? - GROUP BY relates_to_id, annotation.type, aggregation_key - ORDER BY relates_to_id, COUNT(*) DESC - """ - - def _get_aggregation_groups_for_events_txn( - txn: LoggingTransaction, - ) -> Mapping[str, List[JsonDict]]: - txn.execute(sql, args) - - result: Dict[str, List[JsonDict]] = {} - for event_id, type, key, count in cast( - List[Tuple[str, str, str, int]], txn - ): - event_results = result.setdefault(event_id, []) - - # Limit the number of results per event ID. - if len(event_results) == limit: - continue - - event_results.append({"type": type, "key": key, "count": count}) - - return result - - return await self.db_pool.runInteraction( - "get_aggregation_groups_for_events", _get_aggregation_groups_for_events_txn - ) - - async def get_aggregation_groups_for_users( - self, event_ids: Collection[str], users: FrozenSet[str] - ) -> Dict[str, Dict[Tuple[str, str], int]]: - """Fetch the partial aggregations for an event for specific users. - - This is used, in conjunction with get_aggregation_groups_for_event, to - remove information from the results for ignored users. - - Args: - event_ids: Fetch events that relate to these event IDs. - users: The users to fetch information for. - - Returns: - A map of event ID to a map of (event type, aggregation key) to a - count of users. 
- """ - - if not users: - return {} - - events_sql, args = make_in_list_sql_clause( - self.database_engine, "relates_to_id", event_ids - ) - - users_sql, users_args = make_in_list_sql_clause( - self.database_engine, "annotation.sender", users - ) - args.extend(users_args) - args.append(RelationTypes.ANNOTATION) - - sql = f""" - SELECT - relates_to_id, - annotation.type, - aggregation_key, - COUNT(DISTINCT annotation.sender) - FROM events AS annotation - INNER JOIN event_relations USING (event_id) - INNER JOIN events AS parent ON - parent.event_id = relates_to_id - AND parent.room_id = annotation.room_id - WHERE {events_sql} AND {users_sql} AND relation_type = ? - GROUP BY relates_to_id, annotation.type, aggregation_key - ORDER BY relates_to_id, COUNT(*) DESC - """ - - def _get_aggregation_groups_for_users_txn( - txn: LoggingTransaction, - ) -> Dict[str, Dict[Tuple[str, str], int]]: - txn.execute(sql, args) - - result: Dict[str, Dict[Tuple[str, str], int]] = {} - for event_id, type, key, count in cast( - List[Tuple[str, str, str, int]], txn - ): - result.setdefault(event_id, {})[(type, key)] = count - - return result - - return await self.db_pool.runInteraction( - "get_aggregation_groups_for_users", _get_aggregation_groups_for_users_txn - ) - @cached() async def get_references_for_event(self, event_id: str) -> List[JsonDict]: raise NotImplementedError() diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index c8a6911d5e..a8a0a16141 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -1080,48 +1080,6 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): ] assert_bundle(self._find_event_in_chunk(chunk)) - def test_annotation(self) -> None: - """ - Test that annotations get correctly bundled. - """ - # Setup by sending a variety of relations. - self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token - ) - self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") - - def assert_annotations(bundled_aggregations: JsonDict) -> None: - self.assertEqual( - { - "chunk": [ - {"type": "m.reaction", "key": "a", "count": 2}, - {"type": "m.reaction", "key": "b", "count": 1}, - ] - }, - bundled_aggregations, - ) - - self._test_bundled_aggregations(RelationTypes.ANNOTATION, assert_annotations, 7) - - def test_annotation_to_annotation(self) -> None: - """Any relation to an annotation should be ignored.""" - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - event_id = channel.json_body["event_id"] - self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", "b", parent_id=event_id - ) - - # Fetch the initial annotation event to see if it has bundled aggregations. - channel = self.make_request( - "GET", - f"/_matrix/client/v3/rooms/{self.room}/event/{event_id}", - access_token=self.user_token, - ) - self.assertEquals(200, channel.code, channel.json_body) - # The first annotationt should not have any bundled aggregations. - self.assertNotIn("m.relations", channel.json_body["unsigned"]) - def test_reference(self) -> None: """ Test that references get correctly bundled. 
@@ -1138,7 +1096,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): bundled_aggregations, ) - self._test_bundled_aggregations(RelationTypes.REFERENCE, assert_annotations, 7) + self._test_bundled_aggregations(RelationTypes.REFERENCE, assert_annotations, 6) def test_thread(self) -> None: """ @@ -1183,7 +1141,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): # The "user" sent the root event and is making queries for the bundled # aggregations: they have participated. - self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 7) + self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 6) # The "user2" sent replies in the thread and is making queries for the # bundled aggregations: they have participated. # @@ -1208,9 +1166,10 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): channel = self._send_relation(RelationTypes.THREAD, "m.room.test") thread_2 = channel.json_body["event_id"] - self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_2 + channel = self._send_relation( + RelationTypes.REFERENCE, "org.matrix.test", parent_id=thread_2 ) + reference_event_id = channel.json_body["event_id"] def assert_thread(bundled_aggregations: JsonDict) -> None: self.assertEqual(2, bundled_aggregations.get("count")) @@ -1235,17 +1194,15 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): self.assert_dict( { "m.relations": { - RelationTypes.ANNOTATION: { - "chunk": [ - {"type": "m.reaction", "key": "a", "count": 1}, - ] + RelationTypes.REFERENCE: { + "chunk": [{"event_id": reference_event_id}] }, } }, bundled_aggregations["latest_event"].get("unsigned"), ) - self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 7) + self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 6) def test_nested_thread(self) -> None: """ @@ -1363,10 +1320,11 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): channel = self._send_relation(RelationTypes.THREAD, "m.room.test") thread_id = channel.json_body["event_id"] - # Annotate the thread. - self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_id + # Make a reference to the thread. + channel = self._send_relation( + RelationTypes.REFERENCE, "org.matrix.test", parent_id=thread_id ) + reference_event_id = channel.json_body["event_id"] channel = self.make_request( "GET", @@ -1377,9 +1335,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): self.assertEqual( channel.json_body["unsigned"].get("m.relations"), { - RelationTypes.ANNOTATION: { - "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}] - }, + RelationTypes.REFERENCE: {"chunk": [{"event_id": reference_event_id}]}, }, ) @@ -1396,9 +1352,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): self.assertEqual( thread_message["unsigned"].get("m.relations"), { - RelationTypes.ANNOTATION: { - "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}] - }, + RelationTypes.REFERENCE: {"chunk": [{"event_id": reference_event_id}]}, }, ) @@ -1410,7 +1364,8 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): Note that the spec allows for a server to return additional fields beyond what is specified. """ - self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + channel = self._send_relation(RelationTypes.REFERENCE, "org.matrix.test") + reference_event_id = channel.json_body["event_id"] # Note that the sync filter does not include "unsigned" as a field. 
filter = urllib.parse.quote_plus( @@ -1428,7 +1383,12 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): # Ensure there's bundled aggregations on it. self.assertIn("unsigned", parent_event) - self.assertIn("m.relations", parent_event["unsigned"]) + self.assertEqual( + parent_event["unsigned"].get("m.relations"), + { + RelationTypes.REFERENCE: {"chunk": [{"event_id": reference_event_id}]}, + }, + ) class RelationIgnoredUserTestCase(BaseRelationsTestCase): @@ -1475,53 +1435,8 @@ class RelationIgnoredUserTestCase(BaseRelationsTestCase): return before_aggregations[relation_type], after_aggregations[relation_type] - def test_annotation(self) -> None: - """Annotations should ignore""" - # Send 2 from us, 2 from the to be ignored user. - allowed_event_ids = [] - ignored_event_ids = [] - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") - allowed_event_ids.append(channel.json_body["event_id"]) - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="b") - allowed_event_ids.append(channel.json_body["event_id"]) - channel = self._send_relation( - RelationTypes.ANNOTATION, - "m.reaction", - key="a", - access_token=self.user2_token, - ) - ignored_event_ids.append(channel.json_body["event_id"]) - channel = self._send_relation( - RelationTypes.ANNOTATION, - "m.reaction", - key="c", - access_token=self.user2_token, - ) - ignored_event_ids.append(channel.json_body["event_id"]) - - before_aggregations, after_aggregations = self._test_ignored_user( - RelationTypes.ANNOTATION, allowed_event_ids, ignored_event_ids - ) - - self.assertCountEqual( - before_aggregations["chunk"], - [ - {"type": "m.reaction", "key": "a", "count": 2}, - {"type": "m.reaction", "key": "b", "count": 1}, - {"type": "m.reaction", "key": "c", "count": 1}, - ], - ) - - self.assertCountEqual( - after_aggregations["chunk"], - [ - {"type": "m.reaction", "key": "a", "count": 1}, - {"type": "m.reaction", "key": "b", "count": 1}, - ], - ) - def test_reference(self) -> None: - """Annotations should ignore""" + """Aggregations should exclude reference relations from ignored users""" channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test") allowed_event_ids = [channel.json_body["event_id"]] @@ -1544,7 +1459,7 @@ class RelationIgnoredUserTestCase(BaseRelationsTestCase): ) def test_thread(self) -> None: - """Annotations should ignore""" + """Aggregations should exclude thread relations from ignored users""" channel = self._send_relation(RelationTypes.THREAD, "m.room.test") allowed_event_ids = [channel.json_body["event_id"]] @@ -1618,43 +1533,6 @@ class RelationRedactionTestCase(BaseRelationsTestCase): for t in threads ] - def test_redact_relation_annotation(self) -> None: - """ - Test that annotations of an event are properly handled after the - annotation is redacted. - - The redacted relation should not be included in bundled aggregations or - the response to relations. - """ - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - to_redact_event_id = channel.json_body["event_id"] - - channel = self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token - ) - unredacted_event_id = channel.json_body["event_id"] - - # Both relations should exist.
- event_ids = self._get_related_events() - relations = self._get_bundled_aggregations() - self.assertCountEqual(event_ids, [to_redact_event_id, unredacted_event_id]) - self.assertEquals( - relations["m.annotation"], - {"chunk": [{"type": "m.reaction", "key": "a", "count": 2}]}, - ) - - # Redact one of the reactions. - self._redact(to_redact_event_id) - - # The unredacted relation should still exist. - event_ids = self._get_related_events() - relations = self._get_bundled_aggregations() - self.assertEquals(event_ids, [unredacted_event_id]) - self.assertEquals( - relations["m.annotation"], - {"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]}, - ) - def test_redact_relation_thread(self) -> None: """ Test that thread replies are properly handled after the thread reply redacted. @@ -1775,14 +1653,14 @@ class RelationRedactionTestCase(BaseRelationsTestCase): is redacted. """ # Add a relation - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍") + channel = self._send_relation(RelationTypes.REFERENCE, "org.matrix.test") related_event_id = channel.json_body["event_id"] # The relations should exist. event_ids = self._get_related_events() relations = self._get_bundled_aggregations() self.assertEqual(len(event_ids), 1) - self.assertIn(RelationTypes.ANNOTATION, relations) + self.assertIn(RelationTypes.REFERENCE, relations) # Redact the original event. self._redact(self.parent_id) @@ -1792,8 +1670,8 @@ class RelationRedactionTestCase(BaseRelationsTestCase): relations = self._get_bundled_aggregations() self.assertEquals(event_ids, [related_event_id]) self.assertEquals( - relations["m.annotation"], - {"chunk": [{"type": "m.reaction", "key": "👍", "count": 1}]}, + relations[RelationTypes.REFERENCE], + {"chunk": [{"event_id": related_event_id}]}, ) def test_redact_parent_thread(self) -> None: -- cgit 1.5.1 From 65f10afb64127dc9412e24860c5e8a78f3dc9863 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Thu, 2 Mar 2023 11:38:46 +0100 Subject: Move event_reports to `RoomWorkerStore` (#15165) --- changelog.d/15165.misc | 1 + synapse/storage/databases/main/room.py | 354 ++++++++++++++++----------------- 2 files changed, 178 insertions(+), 177 deletions(-) create mode 100644 changelog.d/15165.misc (limited to 'synapse/storage/databases') diff --git a/changelog.d/15165.misc b/changelog.d/15165.misc new file mode 100644 index 0000000000..a75be84dac --- /dev/null +++ b/changelog.d/15165.misc @@ -0,0 +1 @@ +Move `get_event_report` and `get_event_reports_paginate` from `RoomStore` to `RoomWorkerStore`. \ No newline at end of file diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index a2e9519cb6..3825bd6079 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -1417,6 +1417,183 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): get_un_partial_stated_rooms_from_stream_txn, ) + async def get_event_report(self, report_id: int) -> Optional[Dict[str, Any]]: + """Retrieve an event report + + Args: + report_id: ID of reported event in database + Returns: + JSON dict of information from an event report or None if the + report does not exist. 
+ """ + + def _get_event_report_txn( + txn: LoggingTransaction, report_id: int + ) -> Optional[Dict[str, Any]]: + sql = """ + SELECT + er.id, + er.received_ts, + er.room_id, + er.event_id, + er.user_id, + er.content, + events.sender, + room_stats_state.canonical_alias, + room_stats_state.name, + event_json.json AS event_json + FROM event_reports AS er + LEFT JOIN events + ON events.event_id = er.event_id + JOIN event_json + ON event_json.event_id = er.event_id + JOIN room_stats_state + ON room_stats_state.room_id = er.room_id + WHERE er.id = ? + """ + + txn.execute(sql, [report_id]) + row = txn.fetchone() + + if not row: + return None + + event_report = { + "id": row[0], + "received_ts": row[1], + "room_id": row[2], + "event_id": row[3], + "user_id": row[4], + "score": db_to_json(row[5]).get("score"), + "reason": db_to_json(row[5]).get("reason"), + "sender": row[6], + "canonical_alias": row[7], + "name": row[8], + "event_json": db_to_json(row[9]), + } + + return event_report + + return await self.db_pool.runInteraction( + "get_event_report", _get_event_report_txn, report_id + ) + + async def get_event_reports_paginate( + self, + start: int, + limit: int, + direction: Direction = Direction.BACKWARDS, + user_id: Optional[str] = None, + room_id: Optional[str] = None, + ) -> Tuple[List[Dict[str, Any]], int]: + """Retrieve a paginated list of event reports + + Args: + start: event offset to begin the query from + limit: number of rows to retrieve + direction: Whether to fetch the most recent first (backwards) or the + oldest first (forwards) + user_id: search for user_id. Ignored if user_id is None + room_id: search for room_id. Ignored if room_id is None + Returns: + Tuple of: + json list of event reports + total number of event reports matching the filter criteria + """ + + def _get_event_reports_paginate_txn( + txn: LoggingTransaction, + ) -> Tuple[List[Dict[str, Any]], int]: + filters = [] + args: List[object] = [] + + if user_id: + filters.append("er.user_id LIKE ?") + args.extend(["%" + user_id + "%"]) + if room_id: + filters.append("er.room_id LIKE ?") + args.extend(["%" + room_id + "%"]) + + if direction == Direction.BACKWARDS: + order = "DESC" + else: + order = "ASC" + + where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else "" + + # We join on room_stats_state despite not using any columns from it + # because the join can influence the number of rows returned; + # e.g. a room that doesn't have state, maybe because it was deleted. + # The query returning the total count should be consistent with + # the query returning the results. + sql = """ + SELECT COUNT(*) as total_event_reports + FROM event_reports AS er + JOIN room_stats_state ON room_stats_state.room_id = er.room_id + {} + """.format( + where_clause + ) + txn.execute(sql, args) + count = cast(Tuple[int], txn.fetchone())[0] + + sql = """ + SELECT + er.id, + er.received_ts, + er.room_id, + er.event_id, + er.user_id, + er.content, + events.sender, + room_stats_state.canonical_alias, + room_stats_state.name + FROM event_reports AS er + LEFT JOIN events + ON events.event_id = er.event_id + JOIN room_stats_state + ON room_stats_state.room_id = er.room_id + {where_clause} + ORDER BY er.received_ts {order} + LIMIT ? + OFFSET ? 
+ """.format( + where_clause=where_clause, + order=order, + ) + + args += [limit, start] + txn.execute(sql, args) + + event_reports = [] + for row in txn: + try: + s = db_to_json(row[5]).get("score") + r = db_to_json(row[5]).get("reason") + except Exception: + logger.error("Unable to parse json from event_reports: %s", row[0]) + continue + event_reports.append( + { + "id": row[0], + "received_ts": row[1], + "room_id": row[2], + "event_id": row[3], + "user_id": row[4], + "score": s, + "reason": r, + "sender": row[6], + "canonical_alias": row[7], + "name": row[8], + } + ) + + return event_reports, count + + return await self.db_pool.runInteraction( + "get_event_reports_paginate", _get_event_reports_paginate_txn + ) + async def delete_event_report(self, report_id: int) -> bool: """Remove an event report from database. @@ -2189,183 +2366,6 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): ) return next_id - async def get_event_report(self, report_id: int) -> Optional[Dict[str, Any]]: - """Retrieve an event report - - Args: - report_id: ID of reported event in database - Returns: - JSON dict of information from an event report or None if the - report does not exist. - """ - - def _get_event_report_txn( - txn: LoggingTransaction, report_id: int - ) -> Optional[Dict[str, Any]]: - sql = """ - SELECT - er.id, - er.received_ts, - er.room_id, - er.event_id, - er.user_id, - er.content, - events.sender, - room_stats_state.canonical_alias, - room_stats_state.name, - event_json.json AS event_json - FROM event_reports AS er - LEFT JOIN events - ON events.event_id = er.event_id - JOIN event_json - ON event_json.event_id = er.event_id - JOIN room_stats_state - ON room_stats_state.room_id = er.room_id - WHERE er.id = ? - """ - - txn.execute(sql, [report_id]) - row = txn.fetchone() - - if not row: - return None - - event_report = { - "id": row[0], - "received_ts": row[1], - "room_id": row[2], - "event_id": row[3], - "user_id": row[4], - "score": db_to_json(row[5]).get("score"), - "reason": db_to_json(row[5]).get("reason"), - "sender": row[6], - "canonical_alias": row[7], - "name": row[8], - "event_json": db_to_json(row[9]), - } - - return event_report - - return await self.db_pool.runInteraction( - "get_event_report", _get_event_report_txn, report_id - ) - - async def get_event_reports_paginate( - self, - start: int, - limit: int, - direction: Direction = Direction.BACKWARDS, - user_id: Optional[str] = None, - room_id: Optional[str] = None, - ) -> Tuple[List[Dict[str, Any]], int]: - """Retrieve a paginated list of event reports - - Args: - start: event offset to begin the query from - limit: number of rows to retrieve - direction: Whether to fetch the most recent first (backwards) or the - oldest first (forwards) - user_id: search for user_id. Ignored if user_id is None - room_id: search for room_id. 
Ignored if room_id is None - Returns: - Tuple of: - json list of event reports - total number of event reports matching the filter criteria - """ - - def _get_event_reports_paginate_txn( - txn: LoggingTransaction, - ) -> Tuple[List[Dict[str, Any]], int]: - filters = [] - args: List[object] = [] - - if user_id: - filters.append("er.user_id LIKE ?") - args.extend(["%" + user_id + "%"]) - if room_id: - filters.append("er.room_id LIKE ?") - args.extend(["%" + room_id + "%"]) - - if direction == Direction.BACKWARDS: - order = "DESC" - else: - order = "ASC" - - where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else "" - - # We join on room_stats_state despite not using any columns from it - # because the join can influence the number of rows returned; - # e.g. a room that doesn't have state, maybe because it was deleted. - # The query returning the total count should be consistent with - # the query returning the results. - sql = """ - SELECT COUNT(*) as total_event_reports - FROM event_reports AS er - JOIN room_stats_state ON room_stats_state.room_id = er.room_id - {} - """.format( - where_clause - ) - txn.execute(sql, args) - count = cast(Tuple[int], txn.fetchone())[0] - - sql = """ - SELECT - er.id, - er.received_ts, - er.room_id, - er.event_id, - er.user_id, - er.content, - events.sender, - room_stats_state.canonical_alias, - room_stats_state.name - FROM event_reports AS er - LEFT JOIN events - ON events.event_id = er.event_id - JOIN room_stats_state - ON room_stats_state.room_id = er.room_id - {where_clause} - ORDER BY er.received_ts {order} - LIMIT ? - OFFSET ? - """.format( - where_clause=where_clause, - order=order, - ) - - args += [limit, start] - txn.execute(sql, args) - - event_reports = [] - for row in txn: - try: - s = db_to_json(row[5]).get("score") - r = db_to_json(row[5]).get("reason") - except Exception: - logger.error("Unable to parse json from event_reports: %s", row[0]) - continue - event_reports.append( - { - "id": row[0], - "received_ts": row[1], - "room_id": row[2], - "event_id": row[3], - "user_id": row[4], - "score": s, - "reason": r, - "sender": row[6], - "canonical_alias": row[7], - "name": row[8], - } - ) - - return event_reports, count - - return await self.db_pool.runInteraction( - "get_event_reports_paginate", _get_event_reports_paginate_txn - ) - async def block_room(self, room_id: str, user_id: str) -> None: """Marks the room as blocked. -- cgit 1.5.1 From 1eea662780a6325af0a61ceb447b4c91a2d3ac98 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Thu, 2 Mar 2023 18:27:00 +0000 Subject: Add a `get_next_txn` method to `StreamIdGenerator` to match `MultiWriterIdGenerator` (#15191 --- changelog.d/15191.misc | 1 + synapse/storage/databases/main/account_data.py | 11 ++----- synapse/storage/util/id_generators.py | 45 +++++++++++++++++++++++++- synapse/storage/util/sequence.py | 2 +- 4 files changed, 48 insertions(+), 11 deletions(-) create mode 100644 changelog.d/15191.misc (limited to 'synapse/storage/databases') diff --git a/changelog.d/15191.misc b/changelog.d/15191.misc new file mode 100644 index 0000000000..579f76d451 --- /dev/null +++ b/changelog.d/15191.misc @@ -0,0 +1 @@ +Add a `get_next_txn` method to `StreamIdGenerator` to match `MultiWriterIdGenerator`. 
\ No newline at end of file diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index 308d19440f..2d2ba74347 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -40,7 +40,6 @@ from synapse.storage.databases.main.push_rule import PushRulesWorkerStore from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import ( AbstractStreamIdGenerator, - AbstractStreamIdTracker, MultiWriterIdGenerator, StreamIdGenerator, ) @@ -64,14 +63,12 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) ): super().__init__(database, db_conn, hs) - # `_can_write_to_account_data` indicates whether the current worker is allowed - # to write account data. A value of `True` implies that `_account_data_id_gen` - # is an `AbstractStreamIdGenerator` and not just a tracker. - self._account_data_id_gen: AbstractStreamIdTracker self._can_write_to_account_data = ( self._instance_name in hs.config.worker.writers.account_data ) + self._account_data_id_gen: AbstractStreamIdGenerator + if isinstance(database.engine, PostgresEngine): self._account_data_id_gen = MultiWriterIdGenerator( db_conn=db_conn, @@ -558,7 +555,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) The maximum stream ID. """ assert self._can_write_to_account_data - assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator) content_json = json_encoder.encode(content) @@ -598,7 +594,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) data to delete. """ assert self._can_write_to_account_data - assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator) def _remove_account_data_for_room_txn( txn: LoggingTransaction, next_id: int @@ -663,7 +658,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) The maximum stream ID. """ assert self._can_write_to_account_data - assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator) async with self._account_data_id_gen.get_next() as next_id: await self.db_pool.runInteraction( @@ -770,7 +764,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) to delete. """ assert self._can_write_to_account_data - assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator) def _remove_account_data_for_user_txn( txn: LoggingTransaction, next_id: int diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 9adff3f4f5..334d3d718b 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -158,6 +158,15 @@ class AbstractStreamIdGenerator(AbstractStreamIdTracker): """ raise NotImplementedError() + @abc.abstractmethod + def get_next_txn(self, txn: LoggingTransaction) -> int: + """ + Usage: + stream_id_gen.get_next_txn(txn) + # ... persist events ... + """ + raise NotImplementedError() + class StreamIdGenerator(AbstractStreamIdGenerator): """Generates and tracks stream IDs for a stream with a single writer. @@ -263,6 +272,40 @@ class StreamIdGenerator(AbstractStreamIdGenerator): return _AsyncCtxManagerWrapper(manager()) + def get_next_txn(self, txn: LoggingTransaction) -> int: + """ + Retrieve the next stream ID from within a database transaction. + + Clean-up functions will be called when the transaction finishes. + + Args: + txn: The database transaction object. + + Returns: + The next stream ID. 
+ """ + if not self._is_writer: + raise Exception("Tried to allocate stream ID on non-writer") + + # Get the next stream ID. + with self._lock: + self._current += self._step + next_id = self._current + + self._unfinished_ids[next_id] = next_id + + def clear_unfinished_id(id_to_clear: int) -> None: + """A function to mark processing this ID as finished""" + with self._lock: + self._unfinished_ids.pop(id_to_clear) + + # Mark this ID as finished once the database transaction itself finishes. + txn.call_after(clear_unfinished_id, next_id) + txn.call_on_exception(clear_unfinished_id, next_id) + + # Return the new ID. + return next_id + def get_current_token(self) -> int: if not self._is_writer: return self._current @@ -568,7 +611,7 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator): """ Usage: - stream_id = stream_id_gen.get_next(txn) + stream_id = stream_id_gen.get_next_txn(txn) # ... persist event ... """ diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py index 75268cbe15..80915216de 100644 --- a/synapse/storage/util/sequence.py +++ b/synapse/storage/util/sequence.py @@ -205,7 +205,7 @@ class LocalSequenceGenerator(SequenceGenerator): """ Args: get_first_callback: a callback which is called on the first call to - get_next_id_txn; should return the curreent maximum id + get_next_id_txn; should return the current maximum id """ # the callback. this is cleared after it is called, so that it can be GCed. self._callback: Optional[GetFirstCallbackType] = get_first_callback -- cgit 1.5.1 From 15e975f68fc354843a0647e53f285696e86de89b Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Fri, 3 Mar 2023 10:51:57 +0000 Subject: Experimental MSC3890 Implementation: Fix deleting account data when using an account data writer worker (#14869) --- changelog.d/14869.bugfix | 1 + synapse/handlers/account_data.py | 7 ------ synapse/storage/databases/main/account_data.py | 34 ++++++++++++-------------- 3 files changed, 16 insertions(+), 26 deletions(-) create mode 100644 changelog.d/14869.bugfix (limited to 'synapse/storage/databases') diff --git a/changelog.d/14869.bugfix b/changelog.d/14869.bugfix new file mode 100644 index 0000000000..865b597741 --- /dev/null +++ b/changelog.d/14869.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in v1.75.0rc1 that caused experimental support for deleting account data to raise an internal server error while using an account data writer worker. \ No newline at end of file diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py index 797de46dbc..7e01c18c6c 100644 --- a/synapse/handlers/account_data.py +++ b/synapse/handlers/account_data.py @@ -155,9 +155,6 @@ class AccountDataHandler: max_stream_id = await self._store.remove_account_data_for_room( user_id, room_id, account_data_type ) - if max_stream_id is None: - # The referenced account data did not exist, so no delete occurred. - return None self._notifier.on_new_event( StreamKeyType.ACCOUNT_DATA, max_stream_id, users=[user_id] @@ -230,9 +227,6 @@ class AccountDataHandler: max_stream_id = await self._store.remove_account_data_for_user( user_id, account_data_type ) - if max_stream_id is None: - # The referenced account data did not exist, so no delete occurred. 
- return None self._notifier.on_new_event( StreamKeyType.ACCOUNT_DATA, max_stream_id, users=[user_id] @@ -248,7 +242,6 @@ class AccountDataHandler: instance_name=random.choice(self._account_data_writers), user_id=user_id, account_data_type=account_data_type, - content={}, ) return response["max_stream_id"] diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index 2d2ba74347..a9843f6e17 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -581,7 +581,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) async def remove_account_data_for_room( self, user_id: str, room_id: str, account_data_type: str - ) -> Optional[int]: + ) -> int: """Delete the room account data for the user of a given type. Args: @@ -632,15 +632,13 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) next_id, ) - if not row_updated: - return None - - self._account_data_stream_cache.entity_has_changed(user_id, next_id) - self.get_room_account_data_for_user.invalidate((user_id,)) - self.get_account_data_for_room.invalidate((user_id, room_id)) - self.get_account_data_for_room_and_type.prefill( - (user_id, room_id, account_data_type), {} - ) + if row_updated: + self._account_data_stream_cache.entity_has_changed(user_id, next_id) + self.get_room_account_data_for_user.invalidate((user_id,)) + self.get_account_data_for_room.invalidate((user_id, room_id)) + self.get_account_data_for_room_and_type.prefill( + (user_id, room_id, account_data_type), {} + ) return self._account_data_id_gen.get_current_token() @@ -747,7 +745,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) self, user_id: str, account_data_type: str, - ) -> Optional[int]: + ) -> int: """ Delete a single piece of user account data by type. @@ -833,14 +831,12 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) next_id, ) - if not row_updated: - return None - - self._account_data_stream_cache.entity_has_changed(user_id, next_id) - self.get_global_account_data_for_user.invalidate((user_id,)) - self.get_global_account_data_by_type_for_user.prefill( - (user_id, account_data_type), {} - ) + if row_updated: + self._account_data_stream_cache.entity_has_changed(user_id, next_id) + self.get_global_account_data_for_user.invalidate((user_id,)) + self.get_global_account_data_by_type_for_user.prefill( + (user_id, account_data_type), {} + ) return self._account_data_id_gen.get_current_token() -- cgit 1.5.1 From 02f74f3a997a4356b5bda957ebc51a829dad15f9 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 3 Mar 2023 08:13:37 -0500 Subject: Combine AbstractStreamIdTracker and AbstractStreamIdGenerator. (#15192) AbstractStreamIdTracker (now) has only a single sub-class: AbstractStreamIdGenerator, combine them to simplify some code and remove any direct references to AbstractStreamIdTracker. 
--- changelog.d/15192.misc | 1 + synapse/storage/databases/main/devices.py | 7 ++----- synapse/storage/databases/main/events_worker.py | 5 ++--- synapse/storage/databases/main/push_rule.py | 3 +-- synapse/storage/databases/main/pusher.py | 3 +-- synapse/storage/databases/main/receipts.py | 6 +++--- synapse/storage/util/id_generators.py | 17 +++++------------ 7 files changed, 15 insertions(+), 27 deletions(-) create mode 100644 changelog.d/15192.misc (limited to 'synapse/storage/databases') diff --git a/changelog.d/15192.misc b/changelog.d/15192.misc new file mode 100644 index 0000000000..1076686875 --- /dev/null +++ b/changelog.d/15192.misc @@ -0,0 +1 @@ +Combine `AbstractStreamIdTracker` and `AbstractStreamIdGenerator`. diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 0dd15f16ff..5503621ad6 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -52,7 +52,6 @@ from synapse.storage.databases.main.roommember import RoomMemberWorkerStore from synapse.storage.types import Cursor from synapse.storage.util.id_generators import ( AbstractStreamIdGenerator, - AbstractStreamIdTracker, StreamIdGenerator, ) from synapse.types import JsonDict, StrCollection, get_verify_key_from_cross_signing_key @@ -91,7 +90,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): # In the worker store this is an ID tracker which we overwrite in the non-worker # class below that is used on the main process. - self._device_list_id_gen: AbstractStreamIdTracker = StreamIdGenerator( + self._device_list_id_gen = StreamIdGenerator( db_conn, hs.get_replication_notifier(), "device_lists_stream", @@ -712,9 +711,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): The new stream ID. """ - # TODO: this looks like it's _writing_. Should this be on DeviceStore rather - # than DeviceWorkerStore? - async with self._device_list_id_gen.get_next() as stream_id: # type: ignore[attr-defined] + async with self._device_list_id_gen.get_next() as stream_id: await self.db_pool.runInteraction( "add_user_sig_change_to_streams", self._add_user_signature_change_txn, diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index b7e7498125..20b7a68362 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -72,7 +72,6 @@ from synapse.storage.engines import PostgresEngine from synapse.storage.types import Cursor from synapse.storage.util.id_generators import ( AbstractStreamIdGenerator, - AbstractStreamIdTracker, MultiWriterIdGenerator, StreamIdGenerator, ) @@ -187,8 +186,8 @@ class EventsWorkerStore(SQLBaseStore): ): super().__init__(database, db_conn, hs) - self._stream_id_gen: AbstractStreamIdTracker - self._backfill_id_gen: AbstractStreamIdTracker + self._stream_id_gen: AbstractStreamIdGenerator + self._backfill_id_gen: AbstractStreamIdGenerator if isinstance(database.engine, PostgresEngine): # If we're using Postgres than we can use `MultiWriterIdGenerator` # regardless of whether this process writes to the streams or not. 
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 9b2bbe060d..9f862f00c1 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -46,7 +46,6 @@ from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException from synapse.storage.util.id_generators import ( AbstractStreamIdGenerator, - AbstractStreamIdTracker, IdGenerator, StreamIdGenerator, ) @@ -118,7 +117,7 @@ class PushRulesWorkerStore( # In the worker store this is an ID tracker which we overwrite in the non-worker # class below that is used on the main process. - self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator( + self._push_rules_stream_id_gen = StreamIdGenerator( db_conn, hs.get_replication_notifier(), "push_rules_stream", diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index fddbc07afa..9a24f7a655 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -36,7 +36,6 @@ from synapse.storage.database import ( ) from synapse.storage.util.id_generators import ( AbstractStreamIdGenerator, - AbstractStreamIdTracker, StreamIdGenerator, ) from synapse.types import JsonDict @@ -60,7 +59,7 @@ class PusherWorkerStore(SQLBaseStore): # In the worker store this is an ID tracker which we overwrite in the non-worker # class below that is used on the main process. - self._pushers_id_gen: AbstractStreamIdTracker = StreamIdGenerator( + self._pushers_id_gen = StreamIdGenerator( db_conn, hs.get_replication_notifier(), "pushers", diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 92a82240ab..074942b167 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -39,7 +39,7 @@ from synapse.storage.database import ( from synapse.storage.engines import PostgresEngine from synapse.storage.engines._base import IsolationLevel from synapse.storage.util.id_generators import ( - AbstractStreamIdTracker, + AbstractStreamIdGenerator, MultiWriterIdGenerator, StreamIdGenerator, ) @@ -65,7 +65,7 @@ class ReceiptsWorkerStore(SQLBaseStore): # In the worker store this is an ID tracker which we overwrite in the non-worker # class below that is used on the main process. - self._receipts_id_gen: AbstractStreamIdTracker + self._receipts_id_gen: AbstractStreamIdGenerator if isinstance(database.engine, PostgresEngine): self._can_write_to_receipts = ( @@ -768,7 +768,7 @@ class ReceiptsWorkerStore(SQLBaseStore): "insert_receipt_conv", self._graph_to_linear, room_id, event_ids ) - async with self._receipts_id_gen.get_next() as stream_id: # type: ignore[attr-defined] + async with self._receipts_id_gen.get_next() as stream_id: event_ts = await self.db_pool.runInteraction( "insert_linearized_receipt", self._insert_linearized_receipt_txn, diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 334d3d718b..d2c874b9a8 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -93,8 +93,11 @@ def _load_current_id( return res -class AbstractStreamIdTracker(metaclass=abc.ABCMeta): - """Tracks the "current" stream ID of a stream that may have multiple writers. 
+class AbstractStreamIdGenerator(metaclass=abc.ABCMeta): + """Generates or tracks stream IDs for a stream that may have multiple writers. + + Each stream ID represents a write transaction, whose completion is tracked + so that the "current" stream ID of the stream can be determined. Stream IDs are monotonically increasing or decreasing integers representing write transactions. The "current" stream ID is the stream ID such that all transactions @@ -130,16 +133,6 @@ class AbstractStreamIdTracker(metaclass=abc.ABCMeta): """ raise NotImplementedError() - -class AbstractStreamIdGenerator(AbstractStreamIdTracker): - """Generates stream IDs for a stream that may have multiple writers. - - Each stream ID represents a write transaction, whose completion is tracked - so that the "current" stream ID of the stream can be determined. - - See `AbstractStreamIdTracker` for more details. - """ - @abc.abstractmethod def get_next(self) -> AsyncContextManager[int]: """ -- cgit 1.5.1 From c69aae94cda9b62b2a82584b2f5ee72a95feb435 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 7 Mar 2023 08:51:34 +0000 Subject: Split up txn for fetching device keys (#15215) We look up keys in batches, but we should do that outside of the transaction to avoid starving the database pool. --- changelog.d/15215.misc | 1 + synapse/storage/database.py | 10 +++++++++- synapse/storage/databases/main/end_to_end_keys.py | 24 +++++++++++++++-------- 3 files changed, 26 insertions(+), 9 deletions(-) create mode 100644 changelog.d/15215.misc (limited to 'synapse/storage/databases') diff --git a/changelog.d/15215.misc b/changelog.d/15215.misc new file mode 100644 index 0000000000..fe52a56a7e --- /dev/null +++ b/changelog.d/15215.misc @@ -0,0 +1 @@ +Refactor database transaction for query users' devices to reduce database pool contention. diff --git a/synapse/storage/database.py b/synapse/storage/database.py index feaa6cdd07..5efe31aa19 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -672,7 +672,15 @@ class DatabasePool: f = cast(types.FunctionType, func) # type: ignore[redundant-cast] if f.__closure__: for i, cell in enumerate(f.__closure__): - if inspect.isgenerator(cell.cell_contents): + try: + contents = cell.cell_contents + except ValueError: + # cell.cell_contents can raise if the "cell" is empty, + # which indicates that the variable is currently + # unbound. 
+ continue + + if inspect.isgenerator(contents): logger.error( "Programming error: function %s references generator %s " "via its closure", diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index b9c39b1718..a3b6c8ae8e 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -244,9 +244,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker set_tag("include_all_devices", include_all_devices) set_tag("include_deleted_devices", include_deleted_devices) - result = await self.db_pool.runInteraction( - "get_e2e_device_keys", - self._get_e2e_device_keys_txn, + result = await self._get_e2e_device_keys( query_list, include_all_devices, include_deleted_devices, @@ -285,9 +283,8 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker log_kv(result) return result - def _get_e2e_device_keys_txn( + async def _get_e2e_device_keys( self, - txn: LoggingTransaction, query_list: Collection[Tuple[str, Optional[str]]], include_all_devices: bool = False, include_deleted_devices: bool = False, @@ -319,7 +316,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker if user_list: user_id_in_list_clause, user_args = make_in_list_sql_clause( - txn.database_engine, "user_id", user_list + self.database_engine, "user_id", user_list ) query_clauses.append(user_id_in_list_clause) query_params_list.append(user_args) @@ -332,13 +329,16 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker user_device_id_in_list_clause, user_device_args, ) = make_tuple_in_list_sql_clause( - txn.database_engine, ("user_id", "device_id"), user_device_batch + self.database_engine, ("user_id", "device_id"), user_device_batch ) query_clauses.append(user_device_id_in_list_clause) query_params_list.append(user_device_args) result: Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]] = {} - for query_clause, query_params in zip(query_clauses, query_params_list): + + def get_e2e_device_keys_txn( + txn: LoggingTransaction, query_clause: str, query_params: list + ) -> None: sql = ( "SELECT user_id, device_id, " " d.display_name, " @@ -361,6 +361,14 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker display_name, db_to_json(key_json) if key_json else None ) + for query_clause, query_params in zip(query_clauses, query_params_list): + await self.db_pool.runInteraction( + "_get_e2e_device_keys", + get_e2e_device_keys_txn, + query_clause, + query_params, + ) + if include_deleted_devices: for user_id, device_id in deleted_devices: if device_id is None: -- cgit 1.5.1
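The `get_next_txn` method added in #15191 above lets code that is already inside a database transaction allocate a stream ID, with the ID's book-keeping cleaned up when that transaction finishes (matching what `MultiWriterIdGenerator` already offered). The following is a minimal, self-contained sketch of that idea only, not Synapse's implementation: `FakeTransaction` and `ToyStreamIdGenerator` are made-up stand-ins for `LoggingTransaction` and `StreamIdGenerator`, and only the success path is modelled.

import threading
from typing import Callable, Dict, List


class FakeTransaction:
    # Minimal stand-in exposing the call_after hook used by the real generator.
    def __init__(self) -> None:
        self._after: List[Callable[[], None]] = []

    def call_after(self, fn: Callable[..., None], *args: object) -> None:
        self._after.append(lambda: fn(*args))

    def complete(self) -> None:
        # Run the registered clean-up callbacks, as if the txn just committed.
        for callback in self._after:
            callback()


class ToyStreamIdGenerator:
    # Single-writer stream ID allocator whose "current token" only advances
    # once the transaction that allocated an ID has finished.
    def __init__(self, current: int = 0) -> None:
        self._lock = threading.Lock()
        self._current = current
        self._unfinished: Dict[int, int] = {}

    def get_next_txn(self, txn: FakeTransaction) -> int:
        with self._lock:
            self._current += 1
            next_id = self._current
            self._unfinished[next_id] = next_id

        def clear_unfinished_id(id_to_clear: int) -> None:
            with self._lock:
                self._unfinished.pop(id_to_clear)

        # Only mark the ID as finished when the transaction completes.
        txn.call_after(clear_unfinished_id, next_id)
        return next_id

    def get_current_token(self) -> int:
        with self._lock:
            if self._unfinished:
                return min(self._unfinished) - 1
            return self._current


txn = FakeTransaction()
gen = ToyStreamIdGenerator()
stream_id = gen.get_next_txn(txn)  # allocate while "inside" the transaction
assert gen.get_current_token() == stream_id - 1  # not yet visible to readers
txn.complete()
assert gen.get_current_token() == stream_id  # visible once the txn finishes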
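The point of the final patch (#15215) is that looping over many query batches inside one long-lived transaction ties up a pooled database connection for the whole lookup; running one short transaction per batch lets other work interleave. A rough sketch of that general pattern, using plain `sqlite3` rather than Synapse's `DatabasePool`, with a hypothetical `device_keys` table:

import sqlite3
from typing import Dict, Iterable, List, Tuple


def batches(items: List[str], size: int) -> Iterable[List[str]]:
    # Yield fixed-size slices of the query list.
    for i in range(0, len(items), size):
        yield items[i : i + size]


def get_keys_for_users(
    conn: sqlite3.Connection, user_ids: List[str], batch_size: int = 100
) -> Dict[Tuple[str, str], str]:
    result: Dict[Tuple[str, str], str] = {}
    for batch in batches(user_ids, batch_size):
        placeholders = ", ".join("?" for _ in batch)
        # One short transaction per batch instead of a single long-lived one.
        with conn:
            rows = conn.execute(
                "SELECT user_id, device_id, key_json FROM device_keys "
                f"WHERE user_id IN ({placeholders})",
                batch,
            ).fetchall()
        for user_id, device_id, key_json in rows:
            result[(user_id, device_id)] = key_json
    return result


conn = sqlite3.connect(":memory:")
with conn:
    conn.execute("CREATE TABLE device_keys (user_id TEXT, device_id TEXT, key_json TEXT)")
    conn.execute("INSERT INTO device_keys VALUES ('@alice:hs', 'DEVICE1', '{}')")
print(get_keys_for_users(conn, ["@alice:hs", "@bob:hs"], batch_size=1))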
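Separately, #15215 adds a `try`/`except ValueError` around `cell.cell_contents` in `DatabasePool`'s closure check. The behaviour it guards against is plain Python and can be reproduced in isolation: reading the contents of a closure cell whose variable has not been assigned yet raises `ValueError`.

def outer():
    def inner():
        return later  # closes over `later`, which is assigned only below

    cell = inner.__closure__[0]
    try:
        cell.cell_contents
    except ValueError:
        # Matches the guarded case in the patch: the variable is still unbound.
        print("closure cell is currently unbound")

    later = 1  # the cell is now bound
    print(cell.cell_contents)
    return inner


outer()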