From a7bdf98d01d2225a479753a85ba81adf02b16a32 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 5 Aug 2020 21:38:57 +0100 Subject: Rename database classes to make some sense (#8033) --- synapse/storage/databases/main/keys.py | 210 +++++++++++++++++++++++++++++++++ 1 file changed, 210 insertions(+) create mode 100644 synapse/storage/databases/main/keys.py (limited to 'synapse/storage/databases/main/keys.py') diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py new file mode 100644 index 0000000000..384e9c5eb0 --- /dev/null +++ b/synapse/storage/databases/main/keys.py @@ -0,0 +1,210 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2019 New Vector Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools +import logging + +from signedjson.key import decode_verify_key_bytes + +from synapse.storage._base import SQLBaseStore +from synapse.storage.keys import FetchKeyResult +from synapse.util.caches.descriptors import cached, cachedList +from synapse.util.iterutils import batch_iter + +logger = logging.getLogger(__name__) + + +db_binary_type = memoryview + + +class KeyStore(SQLBaseStore): + """Persistence for signature verification keys + """ + + @cached() + def _get_server_verify_key(self, server_name_and_key_id): + raise NotImplementedError() + + @cachedList( + cached_method_name="_get_server_verify_key", list_name="server_name_and_key_ids" + ) + def get_server_verify_keys(self, server_name_and_key_ids): + """ + Args: + server_name_and_key_ids (iterable[Tuple[str, str]]): + iterable of (server_name, key-id) tuples to fetch keys for + + Returns: + Deferred: resolves to dict[Tuple[str, str], FetchKeyResult|None]: + map from (server_name, key_id) -> FetchKeyResult, or None if the key is + unknown + """ + keys = {} + + def _get_keys(txn, batch): + """Processes a batch of keys to fetch, and adds the result to `keys`.""" + + # batch_iter always returns tuples so it's safe to do len(batch) + sql = ( + "SELECT server_name, key_id, verify_key, ts_valid_until_ms " + "FROM server_signature_keys WHERE 1=0" + ) + " OR (server_name=? AND key_id=?)" * len(batch) + + txn.execute(sql, tuple(itertools.chain.from_iterable(batch))) + + for row in txn: + server_name, key_id, key_bytes, ts_valid_until_ms = row + + if ts_valid_until_ms is None: + # Old keys may be stored with a ts_valid_until_ms of null, + # in which case we treat this as if it was set to `0`, i.e. + # it won't match key requests that define a minimum + # `ts_valid_until_ms`. 
+ ts_valid_until_ms = 0 + + res = FetchKeyResult( + verify_key=decode_verify_key_bytes(key_id, bytes(key_bytes)), + valid_until_ts=ts_valid_until_ms, + ) + keys[(server_name, key_id)] = res + + def _txn(txn): + for batch in batch_iter(server_name_and_key_ids, 50): + _get_keys(txn, batch) + return keys + + return self.db_pool.runInteraction("get_server_verify_keys", _txn) + + def store_server_verify_keys(self, from_server, ts_added_ms, verify_keys): + """Stores NACL verification keys for remote servers. + Args: + from_server (str): Where the verification keys were looked up + ts_added_ms (int): The time to record that the key was added + verify_keys (iterable[tuple[str, str, FetchKeyResult]]): + keys to be stored. Each entry is a triplet of + (server_name, key_id, key). + """ + key_values = [] + value_values = [] + invalidations = [] + for server_name, key_id, fetch_result in verify_keys: + key_values.append((server_name, key_id)) + value_values.append( + ( + from_server, + ts_added_ms, + fetch_result.valid_until_ts, + db_binary_type(fetch_result.verify_key.encode()), + ) + ) + # invalidate takes a tuple corresponding to the params of + # _get_server_verify_key. _get_server_verify_key only takes one + # param, which is itself the 2-tuple (server_name, key_id). + invalidations.append((server_name, key_id)) + + def _invalidate(res): + f = self._get_server_verify_key.invalidate + for i in invalidations: + f((i,)) + return res + + return self.db_pool.runInteraction( + "store_server_verify_keys", + self.db_pool.simple_upsert_many_txn, + table="server_signature_keys", + key_names=("server_name", "key_id"), + key_values=key_values, + value_names=( + "from_server", + "ts_added_ms", + "ts_valid_until_ms", + "verify_key", + ), + value_values=value_values, + ).addCallback(_invalidate) + + def store_server_keys_json( + self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes + ): + """Stores the JSON bytes for a set of keys from a server + The JSON should be signed by the originating server, the intermediate + server, and by this server. Updates the value for the + (server_name, key_id, from_server) triplet if one already existed. + Args: + server_name (str): The name of the server. + key_id (str): The identifer of the key this JSON is for. + from_server (str): The server this JSON was fetched from. + ts_now_ms (int): The time now in milliseconds. + ts_valid_until_ms (int): The time when this json stops being valid. + key_json (bytes): The encoded JSON. + """ + return self.db_pool.simple_upsert( + table="server_keys_json", + keyvalues={ + "server_name": server_name, + "key_id": key_id, + "from_server": from_server, + }, + values={ + "server_name": server_name, + "key_id": key_id, + "from_server": from_server, + "ts_added_ms": ts_now_ms, + "ts_valid_until_ms": ts_expires_ms, + "key_json": db_binary_type(key_json_bytes), + }, + desc="store_server_keys_json", + ) + + def get_server_keys_json(self, server_keys): + """Retrive the key json for a list of server_keys and key ids. + If no keys are found for a given server, key_id and source then + that server, key_id, and source triplet entry will be an empty list. + The JSON is returned as a byte array so that it can be efficiently + used in an HTTP response. + Args: + server_keys (list): List of (server_name, key_id, source) triplets. 
+ Returns: + Deferred[dict[Tuple[str, str, str|None], list[dict]]]: + Dict mapping (server_name, key_id, source) triplets to lists of dicts + """ + + def _get_server_keys_json_txn(txn): + results = {} + for server_name, key_id, from_server in server_keys: + keyvalues = {"server_name": server_name} + if key_id is not None: + keyvalues["key_id"] = key_id + if from_server is not None: + keyvalues["from_server"] = from_server + rows = self.db_pool.simple_select_list_txn( + txn, + "server_keys_json", + keyvalues=keyvalues, + retcols=( + "key_id", + "from_server", + "ts_added_ms", + "ts_valid_until_ms", + "key_json", + ), + ) + results[(server_name, key_id, from_server)] = rows + return results + + return self.db_pool.runInteraction( + "get_server_keys_json", _get_server_keys_json_txn + ) -- cgit 1.5.1 From 76c43f086a3c5feeab1fd0f916086930d44fb1cf Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 20 Aug 2020 06:39:55 -0400 Subject: Do not assume calls to runInteraction return Deferreds. (#8133) --- changelog.d/8133.misc | 1 + synapse/crypto/keyring.py | 7 +++--- synapse/module_api/__init__.py | 10 +++++--- synapse/spam_checker_api/__init__.py | 6 +++-- synapse/storage/databases/main/group_server.py | 7 +++--- synapse/storage/databases/main/keys.py | 28 ++++++++++++---------- .../storage/databases/main/user_erasure_store.py | 13 +++++----- 7 files changed, 41 insertions(+), 31 deletions(-) create mode 100644 changelog.d/8133.misc (limited to 'synapse/storage/databases/main/keys.py') diff --git a/changelog.d/8133.misc b/changelog.d/8133.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8133.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 28ef7cfdb9..81c4b430b2 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -757,9 +757,8 @@ class ServerKeyFetcher(BaseV2KeyFetcher): except Exception: logger.exception("Error getting keys %s from %s", key_ids, server_name) - return await yieldable_gather_results( - get_key, keys_to_fetch.items() - ).addCallback(lambda _: results) + await yieldable_gather_results(get_key, keys_to_fetch.items()) + return results async def get_server_verify_key_v2_direct(self, server_name, key_ids): """ @@ -769,7 +768,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher): key_ids (iterable[str]): Returns: - Deferred[dict[str, FetchKeyResult]]: map from key ID to lookup result + dict[str, FetchKeyResult]: map from key ID to lookup result Raises: KeyLookupError if there was a problem making the lookup diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index c2fb757d9a..ae0e359a77 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -167,8 +167,10 @@ class ModuleApi(object): external_id: id on that system user_id: complete mxid that it is mapped to """ - return self._store.record_user_external_id( - auth_provider_id, remote_user_id, registered_user_id + return defer.ensureDeferred( + self._store.record_user_external_id( + auth_provider_id, remote_user_id, registered_user_id + ) ) def generate_short_term_login_token( @@ -223,7 +225,9 @@ class ModuleApi(object): Returns: Deferred[object]: result of func """ - return self._store.db_pool.runInteraction(desc, func, *args, **kwargs) + return defer.ensureDeferred( + self._store.db_pool.runInteraction(desc, func, *args, **kwargs) + ) def complete_sso_login( self, registered_user_id: str, request: SynapseRequest, client_redirect_url: 
str diff --git a/synapse/spam_checker_api/__init__.py b/synapse/spam_checker_api/__init__.py index 4d9b13ac04..7f63f1bfa0 100644 --- a/synapse/spam_checker_api/__init__.py +++ b/synapse/spam_checker_api/__init__.py @@ -48,8 +48,10 @@ class SpamCheckerApi(object): twisted.internet.defer.Deferred[list(synapse.events.FrozenEvent)]: The filtered state events in the room. """ - state_ids = yield self._store.get_filtered_current_state_ids( - room_id=room_id, state_filter=StateFilter.from_types(types) + state_ids = yield defer.ensureDeferred( + self._store.get_filtered_current_state_ids( + room_id=room_id, state_filter=StateFilter.from_types(types) + ) ) state = yield defer.ensureDeferred(self._store.get_events(state_ids.values())) return state.values() diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py index 380db3a3f3..0e3b8739c6 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py @@ -341,14 +341,15 @@ class GroupServerWorkerStore(SQLBaseStore): "get_users_for_summary_by_role", _get_users_for_summary_txn ) - def is_user_in_group(self, user_id, group_id): - return self.db_pool.simple_select_one_onecol( + async def is_user_in_group(self, user_id: str, group_id: str) -> bool: + result = await self.db_pool.simple_select_one_onecol( table="group_users", keyvalues={"group_id": group_id, "user_id": user_id}, retcol="user_id", allow_none=True, desc="is_user_in_group", - ).addCallback(lambda r: bool(r)) + ) + return bool(result) def is_user_admin_in_group(self, group_id, user_id): return self.db_pool.simple_select_one_onecol( diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py index 384e9c5eb0..fadcad51e7 100644 --- a/synapse/storage/databases/main/keys.py +++ b/synapse/storage/databases/main/keys.py @@ -16,6 +16,7 @@ import itertools import logging +from typing import Iterable, Tuple from signedjson.key import decode_verify_key_bytes @@ -88,12 +89,17 @@ class KeyStore(SQLBaseStore): return self.db_pool.runInteraction("get_server_verify_keys", _txn) - def store_server_verify_keys(self, from_server, ts_added_ms, verify_keys): + async def store_server_verify_keys( + self, + from_server: str, + ts_added_ms: int, + verify_keys: Iterable[Tuple[str, str, FetchKeyResult]], + ) -> None: """Stores NACL verification keys for remote servers. Args: - from_server (str): Where the verification keys were looked up - ts_added_ms (int): The time to record that the key was added - verify_keys (iterable[tuple[str, str, FetchKeyResult]]): + from_server: Where the verification keys were looked up + ts_added_ms: The time to record that the key was added + verify_keys: keys to be stored. Each entry is a triplet of (server_name, key_id, key). """ @@ -115,13 +121,7 @@ class KeyStore(SQLBaseStore): # param, which is itself the 2-tuple (server_name, key_id). 
invalidations.append((server_name, key_id)) - def _invalidate(res): - f = self._get_server_verify_key.invalidate - for i in invalidations: - f((i,)) - return res - - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "store_server_verify_keys", self.db_pool.simple_upsert_many_txn, table="server_signature_keys", @@ -134,7 +134,11 @@ class KeyStore(SQLBaseStore): "verify_key", ), value_values=value_values, - ).addCallback(_invalidate) + ) + + invalidate = self._get_server_verify_key.invalidate + for i in invalidations: + invalidate((i,)) def store_server_keys_json( self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes diff --git a/synapse/storage/databases/main/user_erasure_store.py b/synapse/storage/databases/main/user_erasure_store.py index da23fe7355..e3547e53b3 100644 --- a/synapse/storage/databases/main/user_erasure_store.py +++ b/synapse/storage/databases/main/user_erasure_store.py @@ -13,30 +13,29 @@ # See the License for the specific language governing permissions and # limitations under the License. -import operator - from synapse.storage._base import SQLBaseStore from synapse.util.caches.descriptors import cached, cachedList class UserErasureWorkerStore(SQLBaseStore): @cached() - def is_user_erased(self, user_id): + async def is_user_erased(self, user_id: str) -> bool: """ Check if the given user id has requested erasure Args: - user_id (str): full user id to check + user_id: full user id to check Returns: - Deferred[bool]: True if the user has requested erasure + True if the user has requested erasure """ - return self.db_pool.simple_select_onecol( + result = await self.db_pool.simple_select_onecol( table="erased_users", keyvalues={"user_id": user_id}, retcol="1", desc="is_user_erased", - ).addCallback(operator.truth) + ) + return bool(result) @cachedList(cached_method_name="is_user_erased", list_name="user_ids") async def are_users_erased(self, user_ids): -- cgit 1.5.1 From 9b7ac03af3e7ceae7d1933db566ee407cfdef72d Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 27 Aug 2020 13:38:41 -0400 Subject: Convert calls of async database methods to async (#8166) --- changelog.d/8166.misc | 1 + synapse/federation/persistence.py | 16 +++++++----- synapse/federation/units.py | 4 +-- synapse/storage/databases/main/appservice.py | 6 ++--- synapse/storage/databases/main/devices.py | 4 +-- synapse/storage/databases/main/group_server.py | 30 ++++++++++++++++------ synapse/storage/databases/main/keys.py | 26 +++++++++++-------- synapse/storage/databases/main/media_repository.py | 22 ++++++++-------- synapse/storage/databases/main/openid.py | 6 +++-- synapse/storage/databases/main/profile.py | 10 +++++--- synapse/storage/databases/main/registration.py | 29 ++++++++++----------- synapse/storage/databases/main/room.py | 16 ++++++++---- synapse/storage/databases/main/stats.py | 10 ++++---- synapse/storage/databases/main/transactions.py | 18 +++++++------ 14 files changed, 114 insertions(+), 84 deletions(-) create mode 100644 changelog.d/8166.misc (limited to 'synapse/storage/databases/main/keys.py') diff --git a/changelog.d/8166.misc b/changelog.d/8166.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8166.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. 
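The change just made to `store_server_verify_keys` is the template for every conversion in this patch series: a method that returned a Deferred and hung follow-up work off `.addCallback` becomes an `async def` that awaits the interaction and then does the follow-up inline. A minimal before/after sketch of that shape — `run_interaction` and `invalidate` are stand-ins for `db_pool.runInteraction` and a `@cached` descriptor's `invalidate`, not real Synapse signatures:

    from twisted.internet import defer

    def fake_run_interaction(desc, keys):
        # Stand-in for db_pool.runInteraction: an already-fired Deferred.
        return defer.succeed(None)

    # Before: return the Deferred, threading the result through a callback
    # that fires once the database interaction has completed.
    def store_keys_deferred(run_interaction, invalidate, keys):
        def _invalidate(res):
            for key in keys:
                invalidate((key,))
            return res

        return run_interaction("store_keys", keys).addCallback(_invalidate)

    # After: await the interaction, then invalidate in straight-line code.
    # The function now returns None instead of a Deferred.
    async def store_keys_async(run_interaction, invalidate, keys) -> None:
        await run_interaction("store_keys", keys)
        for key in keys:
            invalidate((key,))

The async form is harder to get wrong: there is no callback whose result has to be threaded through, and the invalidation cannot be skipped by a caller that drops the returned Deferred.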
diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py index d68b4bd670..769cd5de28 100644 --- a/synapse/federation/persistence.py +++ b/synapse/federation/persistence.py @@ -21,7 +21,9 @@ These actions are mostly only used by the :py:mod:`.replication` module. import logging +from synapse.federation.units import Transaction from synapse.logging.utils import log_function +from synapse.types import JsonDict logger = logging.getLogger(__name__) @@ -49,15 +51,15 @@ class TransactionActions(object): return self.store.get_received_txn_response(transaction.transaction_id, origin) @log_function - def set_response(self, origin, transaction, code, response): + async def set_response( + self, origin: str, transaction: Transaction, code: int, response: JsonDict + ) -> None: """ Persist how we responded to a transaction. - - Returns: - Deferred """ - if not transaction.transaction_id: + transaction_id = transaction.transaction_id # type: ignore + if not transaction_id: raise RuntimeError("Cannot persist a transaction with no transaction_id") - return self.store.set_received_txn_response( - transaction.transaction_id, origin, code, response + await self.store.set_received_txn_response( + transaction_id, origin, code, response ) diff --git a/synapse/federation/units.py b/synapse/federation/units.py index 6b32e0dcbf..64d98fc8f6 100644 --- a/synapse/federation/units.py +++ b/synapse/federation/units.py @@ -107,9 +107,7 @@ class Transaction(JsonEncodedObject): if "edus" in kwargs and not kwargs["edus"]: del kwargs["edus"] - super(Transaction, self).__init__( - transaction_id=transaction_id, pdus=pdus, **kwargs - ) + super().__init__(transaction_id=transaction_id, pdus=pdus, **kwargs) @staticmethod def create_new(pdus, **kwargs): diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 77723f7d4d..92f56f1602 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -161,16 +161,14 @@ class ApplicationServiceTransactionWorkerStore( return result.get("state") return None - def set_appservice_state(self, service, state): + async def set_appservice_state(self, service, state) -> None: """Set the application service state. Args: service(ApplicationService): The service whose state to set. state(ApplicationServiceState): The connectivity state to apply. - Returns: - An Awaitable which resolves when the state was set successfully. """ - return self.db_pool.simple_upsert( + await self.db_pool.simple_upsert( "application_services_state", {"as_id": service.id}, {"state": state} ) diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index a811a39eb5..ecd3f3b310 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -716,11 +716,11 @@ class DeviceWorkerStore(SQLBaseStore): return {row["user_id"] for row in rows} - def mark_remote_user_device_cache_as_stale(self, user_id: str): + async def mark_remote_user_device_cache_as_stale(self, user_id: str) -> None: """Records that the server has reason to believe the cache of the devices for the remote users is out of date. 
""" - return self.db_pool.simple_upsert( + await self.db_pool.simple_upsert( table="device_lists_remote_resync", keyvalues={"user_id": user_id}, values={}, diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py index e3ead71853..8acf254bf3 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py @@ -742,7 +742,13 @@ class GroupServerStore(GroupServerWorkerStore): desc="remove_room_from_summary", ) - def upsert_group_category(self, group_id, category_id, profile, is_public): + async def upsert_group_category( + self, + group_id: str, + category_id: str, + profile: Optional[JsonDict], + is_public: Optional[bool], + ) -> None: """Add/update room category for group """ insertion_values = {} @@ -758,7 +764,7 @@ class GroupServerStore(GroupServerWorkerStore): else: update_values["is_public"] = is_public - return self.db_pool.simple_upsert( + await self.db_pool.simple_upsert( table="group_room_categories", keyvalues={"group_id": group_id, "category_id": category_id}, values=update_values, @@ -773,7 +779,13 @@ class GroupServerStore(GroupServerWorkerStore): desc="remove_group_category", ) - def upsert_group_role(self, group_id, role_id, profile, is_public): + async def upsert_group_role( + self, + group_id: str, + role_id: str, + profile: Optional[JsonDict], + is_public: Optional[bool], + ) -> None: """Add/remove user role """ insertion_values = {} @@ -789,7 +801,7 @@ class GroupServerStore(GroupServerWorkerStore): else: update_values["is_public"] = is_public - return self.db_pool.simple_upsert( + await self.db_pool.simple_upsert( table="group_roles", keyvalues={"group_id": group_id, "role_id": role_id}, values=update_values, @@ -938,10 +950,10 @@ class GroupServerStore(GroupServerWorkerStore): desc="remove_user_from_summary", ) - def add_group_invite(self, group_id, user_id): + async def add_group_invite(self, group_id: str, user_id: str) -> None: """Record that the group server has invited a user """ - return self.db_pool.simple_insert( + await self.db_pool.simple_insert( table="group_invites", values={"group_id": group_id, "user_id": user_id}, desc="add_group_invite", @@ -1044,8 +1056,10 @@ class GroupServerStore(GroupServerWorkerStore): "remove_user_from_group", _remove_user_from_group_txn ) - def add_room_to_group(self, group_id, room_id, is_public): - return self.db_pool.simple_insert( + async def add_room_to_group( + self, group_id: str, room_id: str, is_public: bool + ) -> None: + await self.db_pool.simple_insert( table="group_rooms", values={"group_id": group_id, "room_id": room_id, "is_public": is_public}, desc="add_room_to_group", diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py index fadcad51e7..1c0a049c55 100644 --- a/synapse/storage/databases/main/keys.py +++ b/synapse/storage/databases/main/keys.py @@ -140,22 +140,28 @@ class KeyStore(SQLBaseStore): for i in invalidations: invalidate((i,)) - def store_server_keys_json( - self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes - ): + async def store_server_keys_json( + self, + server_name: str, + key_id: str, + from_server: str, + ts_now_ms: int, + ts_expires_ms: int, + key_json_bytes: bytes, + ) -> None: """Stores the JSON bytes for a set of keys from a server The JSON should be signed by the originating server, the intermediate server, and by this server. Updates the value for the (server_name, key_id, from_server) triplet if one already existed. 
Args: - server_name (str): The name of the server. - key_id (str): The identifer of the key this JSON is for. - from_server (str): The server this JSON was fetched from. - ts_now_ms (int): The time now in milliseconds. - ts_valid_until_ms (int): The time when this json stops being valid. - key_json (bytes): The encoded JSON. + server_name: The name of the server. + key_id: The identifer of the key this JSON is for. + from_server: The server this JSON was fetched from. + ts_now_ms: The time now in milliseconds. + ts_valid_until_ms: The time when this json stops being valid. + key_json_bytes: The encoded JSON. """ - return self.db_pool.simple_upsert( + await self.db_pool.simple_upsert( table="server_keys_json", keyvalues={ "server_name": server_name, diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index 8361dd63d9..3919ecad69 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -60,7 +60,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): desc="get_local_media", ) - def store_local_media( + async def store_local_media( self, media_id, media_type, @@ -69,8 +69,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): media_length, user_id, url_cache=None, - ): - return self.db_pool.simple_insert( + ) -> None: + await self.db_pool.simple_insert( "local_media_repository", { "media_id": media_id, @@ -141,10 +141,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): return self.db_pool.runInteraction("get_url_cache", get_url_cache_txn) - def store_url_cache( + async def store_url_cache( self, url, response_code, etag, expires_ts, og, media_id, download_ts ): - return self.db_pool.simple_insert( + await self.db_pool.simple_insert( "local_media_repository_url_cache", { "url": url, @@ -172,7 +172,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): desc="get_local_media_thumbnails", ) - def store_local_thumbnail( + async def store_local_thumbnail( self, media_id, thumbnail_width, @@ -181,7 +181,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): thumbnail_method, thumbnail_length, ): - return self.db_pool.simple_insert( + await self.db_pool.simple_insert( "local_media_repository_thumbnails", { "media_id": media_id, @@ -212,7 +212,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): desc="get_cached_remote_media", ) - def store_cached_remote_media( + async def store_cached_remote_media( self, origin, media_id, @@ -222,7 +222,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): upload_name, filesystem_id, ): - return self.db_pool.simple_insert( + await self.db_pool.simple_insert( "remote_media_cache", { "media_origin": origin, @@ -288,7 +288,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): desc="get_remote_media_thumbnails", ) - def store_remote_media_thumbnail( + async def store_remote_media_thumbnail( self, origin, media_id, @@ -299,7 +299,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): thumbnail_method, thumbnail_length, ): - return self.db_pool.simple_insert( + await self.db_pool.simple_insert( "remote_media_cache_thumbnails", { "media_origin": origin, diff --git a/synapse/storage/databases/main/openid.py b/synapse/storage/databases/main/openid.py index dcd1ff911a..4db8949da7 100644 --- a/synapse/storage/databases/main/openid.py +++ b/synapse/storage/databases/main/openid.py @@ -2,8 +2,10 @@ 
from synapse.storage._base import SQLBaseStore class OpenIdStore(SQLBaseStore): - def insert_open_id_token(self, token, ts_valid_until_ms, user_id): - return self.db_pool.simple_insert( + async def insert_open_id_token( + self, token: str, ts_valid_until_ms: int, user_id: str + ) -> None: + await self.db_pool.simple_insert( table="open_id_tokens", values={ "token": token, diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py index 858fd92420..301875a672 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py @@ -66,8 +66,8 @@ class ProfileWorkerStore(SQLBaseStore): desc="get_from_remote_profile_cache", ) - def create_profile(self, user_localpart): - return self.db_pool.simple_insert( + async def create_profile(self, user_localpart: str) -> None: + await self.db_pool.simple_insert( table="profiles", values={"user_id": user_localpart}, desc="create_profile" ) @@ -93,13 +93,15 @@ class ProfileWorkerStore(SQLBaseStore): class ProfileStore(ProfileWorkerStore): - def add_remote_profile_cache(self, user_id, displayname, avatar_url): + async def add_remote_profile_cache( + self, user_id: str, displayname: str, avatar_url: str + ) -> None: """Ensure we are caching the remote user's profiles. This should only be called when `is_subscribed_remote_profile_for_user` would return true for the user. """ - return self.db_pool.simple_upsert( + await self.db_pool.simple_upsert( table="remote_profile_cache", keyvalues={"user_id": user_id}, values={ diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 48bda66f3e..28f7ae0430 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -17,7 +17,7 @@ import logging import re -from typing import Any, Awaitable, Dict, List, Optional +from typing import Any, Dict, List, Optional from synapse.api.constants import UserTypes from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError @@ -549,23 +549,22 @@ class RegistrationWorkerStore(SQLBaseStore): desc="user_delete_threepids", ) - def add_user_bound_threepid(self, user_id, medium, address, id_server): + async def add_user_bound_threepid( + self, user_id: str, medium: str, address: str, id_server: str + ): """The server proxied a bind request to the given identity server on behalf of the given user. We need to remember this in case the user asks us to unbind the threepid. 
Args: - user_id (str) - medium (str) - address (str) - id_server (str) - - Returns: - Awaitable + user_id + medium + address + id_server """ # We need to use an upsert, in case they user had already bound the # threepid - return self.db_pool.simple_upsert( + await self.db_pool.simple_upsert( table="user_threepid_id_server", keyvalues={ "user_id": user_id, @@ -1083,9 +1082,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) - def record_user_external_id( + async def record_user_external_id( self, auth_provider: str, external_id: str, user_id: str - ) -> Awaitable: + ) -> None: """Record a mapping from an external user id to a mxid Args: @@ -1093,7 +1092,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): external_id: id on that system user_id: complete mxid that it is mapped to """ - return self.db_pool.simple_insert( + await self.db_pool.simple_insert( table="user_external_ids", values={ "auth_provider": auth_provider, @@ -1237,12 +1236,12 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): return res if res else False - def add_user_pending_deactivation(self, user_id): + async def add_user_pending_deactivation(self, user_id: str) -> None: """ Adds a user to the table of users who need to be parted from all the rooms they're in """ - return self.db_pool.simple_insert( + await self.db_pool.simple_insert( "users_pending_deactivation", values={"user_id": user_id}, desc="add_user_pending_deactivation", diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 66d7135413..a92641c339 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -27,7 +27,7 @@ from synapse.api.room_versions import RoomVersion, RoomVersions from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.search import SearchStore -from synapse.types import ThirdPartyInstanceID +from synapse.types import JsonDict, ThirdPartyInstanceID from synapse.util import json_encoder from synapse.util.caches.descriptors import cached @@ -1296,11 +1296,17 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): return self.db_pool.runInteraction("get_rooms", f) - def add_event_report( - self, room_id, event_id, user_id, reason, content, received_ts - ): + async def add_event_report( + self, + room_id: str, + event_id: str, + user_id: str, + reason: str, + content: JsonDict, + received_ts: int, + ) -> None: next_id = self._event_reports_id_gen.get_next() - return self.db_pool.simple_insert( + await self.db_pool.simple_insert( table="event_reports", values={ "id": next_id, diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py index 9fe97af56a..7af2608ca4 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py @@ -16,7 +16,7 @@ import logging from itertools import chain -from typing import Tuple +from typing import Any, Dict, Tuple from twisted.internet.defer import DeferredLock @@ -222,11 +222,11 @@ class StatsStore(StateDeltasStore): desc="stats_incremental_position", ) - def update_room_state(self, room_id, fields): + async def update_room_state(self, room_id: str, fields: Dict[str, Any]) -> None: """ Args: - room_id (str) - fields (dict[str:Any]) + room_id + fields """ # For whatever reason some of the fields may contain null bytes, which @@ 
-244,7 +244,7 @@ class StatsStore(StateDeltasStore): if field and "\0" in field: fields[col] = None - return self.db_pool.simple_upsert( + await self.db_pool.simple_upsert( table="room_stats_state", keyvalues={"room_id": room_id}, values=fields, diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index 52668dbdf9..2efcc0dc66 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -21,6 +21,7 @@ from canonicaljson import encode_canonical_json from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool +from synapse.types import JsonDict from synapse.util.caches.expiringcache import ExpiringCache db_binary_type = memoryview @@ -98,20 +99,21 @@ class TransactionStore(SQLBaseStore): else: return None - def set_received_txn_response(self, transaction_id, origin, code, response_dict): - """Persist the response we returened for an incoming transaction, and + async def set_received_txn_response( + self, transaction_id: str, origin: str, code: int, response_dict: JsonDict + ) -> None: + """Persist the response we returned for an incoming transaction, and should return for subsequent transactions with the same transaction_id and origin. Args: - txn - transaction_id (str) - origin (str) - code (int) - response_json (str) + transaction_id: The incoming transaction ID. + origin: The origin server. + code: The response code. + response_dict: The response, to be encoded into JSON. """ - return self.db_pool.simple_insert( + await self.db_pool.simple_insert( table="received_transactions", values={ "transaction_id": transaction_id, -- cgit 1.5.1 From 5c03134d0f8dd157ea1800ce1a4bcddbdb73ddf1 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 28 Aug 2020 07:54:27 -0400 Subject: Convert additional database code to async/await. (#8195) --- changelog.d/8195.misc | 1 + synapse/appservice/__init__.py | 19 ++- synapse/federation/persistence.py | 19 ++- synapse/handlers/federation.py | 4 +- synapse/storage/databases/main/appservice.py | 15 +- synapse/storage/databases/main/deviceinbox.py | 12 +- synapse/storage/databases/main/e2e_room_keys.py | 30 ++-- synapse/storage/databases/main/event_federation.py | 71 ++++---- synapse/storage/databases/main/group_server.py | 187 +++++++++++++-------- synapse/storage/databases/main/keys.py | 24 +-- synapse/storage/databases/main/transactions.py | 39 +++-- 11 files changed, 246 insertions(+), 175 deletions(-) create mode 100644 changelog.d/8195.misc (limited to 'synapse/storage/databases/main/keys.py') diff --git a/changelog.d/8195.misc b/changelog.d/8195.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8195.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index 1ffdc1ed95..69a7182ef4 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -14,11 +14,16 @@ # limitations under the License. 
import logging import re +from typing import TYPE_CHECKING from synapse.api.constants import EventTypes +from synapse.appservice.api import ApplicationServiceApi from synapse.types import GroupID, get_domain_from_id from synapse.util.caches.descriptors import cached +if TYPE_CHECKING: + from synapse.storage.databases.main import DataStore + logger = logging.getLogger(__name__) @@ -35,19 +40,19 @@ class AppServiceTransaction(object): self.id = id self.events = events - def send(self, as_api): + async def send(self, as_api: ApplicationServiceApi) -> bool: """Sends this transaction using the provided AS API interface. Args: - as_api(ApplicationServiceApi): The API to use to send. + as_api: The API to use to send. Returns: - An Awaitable which resolves to True if the transaction was sent. + True if the transaction was sent. """ - return as_api.push_bulk( + return await as_api.push_bulk( service=self.service, events=self.events, txn_id=self.id ) - def complete(self, store): + async def complete(self, store: "DataStore") -> None: """Completes this transaction as successful. Marks this transaction ID on the application service and removes the @@ -55,10 +60,8 @@ class AppServiceTransaction(object): Args: store: The database store to operate on. - Returns: - A Deferred which resolves to True if the transaction was completed. """ - return store.complete_appservice_txn(service=self.service, txn_id=self.id) + await store.complete_appservice_txn(service=self.service, txn_id=self.id) class ApplicationService(object): diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py index 769cd5de28..de1fe7da38 100644 --- a/synapse/federation/persistence.py +++ b/synapse/federation/persistence.py @@ -20,6 +20,7 @@ These actions are mostly only used by the :py:mod:`.replication` module. """ import logging +from typing import Optional, Tuple from synapse.federation.units import Transaction from synapse.logging.utils import log_function @@ -36,25 +37,27 @@ class TransactionActions(object): self.store = datastore @log_function - def have_responded(self, origin, transaction): - """ Have we already responded to a transaction with the same id and + async def have_responded( + self, origin: str, transaction: Transaction + ) -> Optional[Tuple[int, JsonDict]]: + """Have we already responded to a transaction with the same id and origin? Returns: - Deferred: Results in `None` if we have not previously responded to - this transaction or a 2-tuple of `(int, dict)` representing the - response code and response body. + `None` if we have not previously responded to this transaction or a + 2-tuple of `(int, dict)` representing the response code and response body. """ - if not transaction.transaction_id: + transaction_id = transaction.transaction_id # type: ignore + if not transaction_id: raise RuntimeError("Cannot persist a transaction with no transaction_id") - return self.store.get_received_txn_response(transaction.transaction_id, origin) + return await self.store.get_received_txn_response(transaction_id, origin) @log_function async def set_response( self, origin: str, transaction: Transaction, code: int, response: JsonDict ) -> None: - """ Persist how we responded to a transaction. + """Persist how we responded to a transaction. 
""" transaction_id = transaction.transaction_id # type: ignore if not transaction_id: diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 155d087413..16389a0dca 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1879,8 +1879,8 @@ class FederationHandler(BaseHandler): else: return None - def get_min_depth_for_context(self, context): - return self.store.get_min_depth(context) + async def get_min_depth_for_context(self, context): + return await self.store.get_min_depth(context) async def _handle_new_event( self, origin, event, state=None, auth_events=None, backfilled=False diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 92f56f1602..454c0bc50c 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -172,7 +172,7 @@ class ApplicationServiceTransactionWorkerStore( "application_services_state", {"as_id": service.id}, {"state": state} ) - def create_appservice_txn(self, service, events): + async def create_appservice_txn(self, service, events): """Atomically creates a new transaction for this application service with the given list of events. @@ -209,20 +209,17 @@ class ApplicationServiceTransactionWorkerStore( ) return AppServiceTransaction(service=service, id=new_txn_id, events=events) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "create_appservice_txn", _create_appservice_txn ) - def complete_appservice_txn(self, txn_id, service): + async def complete_appservice_txn(self, txn_id, service) -> None: """Completes an application service transaction. Args: txn_id(str): The transaction ID being completed. service(ApplicationService): The application service which was sent this transaction. - Returns: - A Deferred which resolves if this transaction was stored - successfully. """ txn_id = int(txn_id) @@ -258,7 +255,7 @@ class ApplicationServiceTransactionWorkerStore( {"txn_id": txn_id, "as_id": service.id}, ) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "complete_appservice_txn", _complete_appservice_txn ) @@ -312,13 +309,13 @@ class ApplicationServiceTransactionWorkerStore( else: return int(last_txn_id[0]) # select 'last_txn' col - def set_appservice_last_pos(self, pos): + async def set_appservice_last_pos(self, pos) -> None: def set_appservice_last_pos_txn(txn): txn.execute( "UPDATE appservice_stream_position SET stream_ordering = ?", (pos,) ) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "set_appservice_last_pos", set_appservice_last_pos_txn ) diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index bb85637a95..0044433110 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -190,15 +190,15 @@ class DeviceInboxWorkerStore(SQLBaseStore): ) @trace - def delete_device_msgs_for_remote(self, destination, up_to_stream_id): + async def delete_device_msgs_for_remote( + self, destination: str, up_to_stream_id: int + ) -> None: """Used to delete messages when the remote destination acknowledges their receipt. Args: - destination(str): The destination server_name - up_to_stream_id(int): Where to delete messages up to. - Returns: - A deferred that resolves when the messages have been deleted. + destination: The destination server_name + up_to_stream_id: Where to delete messages up to. 
""" def delete_messages_for_remote_destination_txn(txn): @@ -209,7 +209,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): ) txn.execute(sql, (destination, up_to_stream_id)) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn ) diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py index 82f9d870fd..12cecceec2 100644 --- a/synapse/storage/databases/main/e2e_room_keys.py +++ b/synapse/storage/databases/main/e2e_room_keys.py @@ -151,7 +151,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): return sessions - def get_e2e_room_keys_multi(self, user_id, version, room_keys): + async def get_e2e_room_keys_multi(self, user_id, version, room_keys): """Get multiple room keys at a time. The difference between this function and get_e2e_room_keys is that this function can be used to retrieve multiple specific keys at a time, whereas get_e2e_room_keys is used for @@ -166,10 +166,10 @@ class EndToEndRoomKeyStore(SQLBaseStore): that we want to query Returns: - Deferred[dict[str, dict[str, dict]]]: a map of room IDs to session IDs to room key + dict[str, dict[str, dict]]: a map of room IDs to session IDs to room key """ - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_e2e_room_keys_multi", self._get_e2e_room_keys_multi_txn, user_id, @@ -283,7 +283,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): raise StoreError(404, "No current backup version") return row[0] - def get_e2e_room_keys_version_info(self, user_id, version=None): + async def get_e2e_room_keys_version_info(self, user_id, version=None): """Get info metadata about a version of our room_keys backup. Args: @@ -293,7 +293,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): Raises: StoreError: with code 404 if there are no e2e_room_keys_versions present Returns: - A deferred dict giving the info metadata for this backup version, with + A dict giving the info metadata for this backup version, with fields including: version(str) algorithm(str) @@ -324,12 +324,12 @@ class EndToEndRoomKeyStore(SQLBaseStore): result["etag"] = 0 return result - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_e2e_room_keys_version_info", _get_e2e_room_keys_version_info_txn ) @trace - def create_e2e_room_keys_version(self, user_id, info): + async def create_e2e_room_keys_version(self, user_id: str, info: dict) -> str: """Atomically creates a new version of this user's e2e_room_keys store with the given version info. @@ -338,7 +338,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): info(dict): the info about the backup version to be created Returns: - A deferred string for the newly created version ID + The newly created version ID """ def _create_e2e_room_keys_version_txn(txn): @@ -365,7 +365,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): return new_version - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "create_e2e_room_keys_version_txn", _create_e2e_room_keys_version_txn ) @@ -403,13 +403,15 @@ class EndToEndRoomKeyStore(SQLBaseStore): ) @trace - def delete_e2e_room_keys_version(self, user_id, version=None): + async def delete_e2e_room_keys_version( + self, user_id: str, version: Optional[str] = None + ) -> None: """Delete a given backup version of the user's room keys. Doesn't delete their actual key data. Args: - user_id(str): the user whose backup version we're deleting - version(str): Optional. 
the version ID of the backup version we're deleting + user_id: the user whose backup version we're deleting + version: Optional. the version ID of the backup version we're deleting If missing, we delete the current backup version info. Raises: StoreError: with code 404 if there are no e2e_room_keys_versions present, @@ -430,13 +432,13 @@ class EndToEndRoomKeyStore(SQLBaseStore): keyvalues={"user_id": user_id, "version": this_version}, ) - return self.db_pool.simple_update_one_txn( + self.db_pool.simple_update_one_txn( txn, table="e2e_room_keys_versions", keyvalues={"user_id": user_id, "version": this_version}, updatevalues={"deleted": 1}, ) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "delete_e2e_room_keys_version", _delete_e2e_room_keys_version_txn ) diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 6e5761c7b7..0b69aa6a94 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -59,7 +59,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas include_given: include the given events in result Returns: - list of event_ids + An awaitable which resolve to a list of event_ids """ return await self.db_pool.runInteraction( "get_auth_chain_ids", @@ -95,7 +95,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas return list(results) - def get_auth_chain_difference(self, state_sets: List[Set[str]]): + async def get_auth_chain_difference(self, state_sets: List[Set[str]]) -> Set[str]: """Given sets of state events figure out the auth chain difference (as per state res v2 algorithm). @@ -104,10 +104,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas chain. Returns: - Deferred[Set[str]] + The set of the difference in auth chains. """ - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_auth_chain_difference", self._get_auth_chain_difference_txn, state_sets, @@ -252,8 +252,8 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas # Return all events where not all sets can reach them. return {eid for eid, n in event_to_missing_sets.items() if n} - def get_oldest_events_with_depth_in_room(self, room_id): - return self.db_pool.runInteraction( + async def get_oldest_events_with_depth_in_room(self, room_id): + return await self.db_pool.runInteraction( "get_oldest_events_with_depth_in_room", self.get_oldest_events_with_depth_in_room_txn, room_id, @@ -293,7 +293,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas else: return max(row["depth"] for row in rows) - def get_prev_events_for_room(self, room_id: str): + async def get_prev_events_for_room(self, room_id: str) -> List[str]: """ Gets a subset of the current forward extremities in the given room. @@ -301,14 +301,14 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas events which refer to hundreds of prev_events. Args: - room_id (str): room_id + room_id: room_id Returns: - Deferred[List[str]]: the event ids of the forward extremites + The event ids of the forward extremities. 
""" - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_prev_events_for_room", self._get_prev_events_for_room_txn, room_id ) @@ -328,17 +328,19 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas return [row[0] for row in txn] - def get_rooms_with_many_extremities(self, min_count, limit, room_id_filter): + async def get_rooms_with_many_extremities( + self, min_count: int, limit: int, room_id_filter: Iterable[str] + ) -> List[str]: """Get the top rooms with at least N extremities. Args: - min_count (int): The minimum number of extremities - limit (int): The maximum number of rooms to return. - room_id_filter (iterable[str]): room_ids to exclude from the results + min_count: The minimum number of extremities + limit: The maximum number of rooms to return. + room_id_filter: room_ids to exclude from the results Returns: - Deferred[list]: At most `limit` room IDs that have at least - `min_count` extremities, sorted by extremity count. + At most `limit` room IDs that have at least `min_count` extremities, + sorted by extremity count. """ def _get_rooms_with_many_extremities_txn(txn): @@ -363,7 +365,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas txn.execute(sql, query_args) return [room_id for room_id, in txn] - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_rooms_with_many_extremities", _get_rooms_with_many_extremities_txn ) @@ -376,10 +378,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas desc="get_latest_event_ids_in_room", ) - def get_min_depth(self, room_id): - """ For hte given room, get the minimum depth we have seen for it. + async def get_min_depth(self, room_id: str) -> int: + """For the given room, get the minimum depth we have seen for it. """ - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_min_depth", self._get_min_depth_interaction, room_id ) @@ -394,7 +396,9 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas return int(min_depth) if min_depth is not None else None - def get_forward_extremeties_for_room(self, room_id, stream_ordering): + async def get_forward_extremeties_for_room( + self, room_id: str, stream_ordering: int + ) -> List[str]: """For a given room_id and stream_ordering, return the forward extremeties of the room at that point in "time". @@ -402,11 +406,11 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas stream_orderings from that point. Args: - room_id (str): - stream_ordering (int): + room_id: + stream_ordering: Returns: - deferred, which resolves to a list of event_ids + A list of event_ids """ # We want to make the cache more effective, so we clamp to the last # change before the given ordering. @@ -422,10 +426,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas if last_change > self.stream_ordering_month_ago: stream_ordering = min(last_change, stream_ordering) - return self._get_forward_extremeties_for_room(room_id, stream_ordering) + return await self._get_forward_extremeties_for_room(room_id, stream_ordering) @cached(max_entries=5000, num_args=2) - def _get_forward_extremeties_for_room(self, room_id, stream_ordering): + async def _get_forward_extremeties_for_room(self, room_id, stream_ordering): """For a given room_id and stream_ordering, return the forward extremeties of the room at that point in "time". 
@@ -450,19 +454,18 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas txn.execute(sql, (stream_ordering, room_id)) return [event_id for event_id, in txn] - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn ) - async def get_backfill_events(self, room_id, event_list, limit): + async def get_backfill_events(self, room_id: str, event_list: list, limit: int): """Get a list of Events for a given topic that occurred before (and including) the events in event_list. Return a list of max size `limit` Args: - txn - room_id (str) - event_list (list) - limit (int) + room_id + event_list + limit """ event_ids = await self.db_pool.runInteraction( "get_backfill_events", @@ -631,8 +634,8 @@ class EventFederationStore(EventFederationWorkerStore): _delete_old_forward_extrem_cache_txn, ) - def clean_room_for_join(self, room_id): - return self.db_pool.runInteraction( + async def clean_room_for_join(self, room_id): + return await self.db_pool.runInteraction( "clean_room_for_join", self._clean_room_for_join_txn, room_id ) diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py index 6c60171888..ccfbb2135e 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union from synapse.api.errors import SynapseError from synapse.storage._base import SQLBaseStore, db_to_json @@ -70,7 +70,9 @@ class GroupServerWorkerStore(SQLBaseStore): desc="get_invited_users_in_group", ) - def get_rooms_in_group(self, group_id: str, include_private: bool = False): + async def get_rooms_in_group( + self, group_id: str, include_private: bool = False + ) -> List[Dict[str, Union[str, bool]]]: """Retrieve the rooms that belong to a given group. Does not return rooms that lack members. 
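With `get_rooms_in_group` now annotated as returning `List[Dict[str, Union[str, bool]]]`, call sites both read more simply and become checkable by mypy. A hedged caller sketch — `store` is assumed to be a `GroupServerWorkerStore`, and the dict keys follow the docstring in the next hunk:

    from typing import List

    async def public_room_ids(store, group_id: str) -> List[str]:
        # Awaiting directly replaces the old Deferred plumbing; each entry
        # has the form {"room_id": "...", "is_public": True/False}.
        rooms = await store.get_rooms_in_group(group_id, include_private=False)
        return [room["room_id"] for room in rooms if room["is_public"]]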
@@ -79,8 +81,7 @@ class GroupServerWorkerStore(SQLBaseStore): include_private: Whether to return private rooms in results Returns: - Deferred[List[Dict[str, str|bool]]]: A list of dictionaries, each in the - form of: + A list of dictionaries, each in the form of: { "room_id": "!a_room_id:example.com", # The ID of the room @@ -117,13 +118,13 @@ class GroupServerWorkerStore(SQLBaseStore): for room_id, is_public in txn ] - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_rooms_in_group", _get_rooms_in_group_txn ) - def get_rooms_for_summary_by_category( + async def get_rooms_for_summary_by_category( self, group_id: str, include_private: bool = False, - ): + ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: """Get the rooms and categories that should be included in a summary request Args: @@ -131,7 +132,7 @@ class GroupServerWorkerStore(SQLBaseStore): include_private: Whether to return private rooms in results Returns: - Deferred[Tuple[List, Dict]]: A tuple containing: + A tuple containing: * A list of dictionaries with the keys: * "room_id": str, the room ID @@ -207,7 +208,7 @@ class GroupServerWorkerStore(SQLBaseStore): return rooms, categories - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_rooms_for_summary", _get_rooms_for_summary_txn ) @@ -281,10 +282,11 @@ class GroupServerWorkerStore(SQLBaseStore): desc="get_local_groups_for_room", ) - def get_users_for_summary_by_role(self, group_id, include_private=False): + async def get_users_for_summary_by_role(self, group_id, include_private=False): """Get the users and roles that should be included in a summary request - Returns ([users], [roles]) + Returns: + ([users], [roles]) """ def _get_users_for_summary_txn(txn): @@ -338,7 +340,7 @@ class GroupServerWorkerStore(SQLBaseStore): return users, roles - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_users_for_summary_by_role", _get_users_for_summary_txn ) @@ -376,7 +378,7 @@ class GroupServerWorkerStore(SQLBaseStore): allow_none=True, ) - def get_users_membership_info_in_group(self, group_id, user_id): + async def get_users_membership_info_in_group(self, group_id, user_id): """Get a dict describing the membership of a user in a group. 
@@ -376,7 +378,7 @@ class GroupServerWorkerStore(SQLBaseStore):
             allow_none=True,
         )
 
-    def get_users_membership_info_in_group(self, group_id, user_id):
+    async def get_users_membership_info_in_group(self, group_id, user_id):
         """Get a dict describing the membership of a user in a group.
 
         Example if joined:
 
@@ -387,7 +389,8 @@ class GroupServerWorkerStore(SQLBaseStore):
                 "is_privileged": False,
             }
 
-        Returns an empty dict if the user is not join/invite/etc
+        Returns:
+            An empty dict if the user is not joined/invited/etc
         """
 
         def _get_users_membership_in_group_txn(txn):
@@ -419,7 +422,7 @@ class GroupServerWorkerStore(SQLBaseStore):
 
             return {}
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_users_membership_info_in_group", _get_users_membership_in_group_txn
         )
 
@@ -433,7 +436,7 @@ class GroupServerWorkerStore(SQLBaseStore):
             desc="get_publicised_groups_for_user",
         )
 
-    def get_attestations_need_renewals(self, valid_until_ms):
+    async def get_attestations_need_renewals(self, valid_until_ms):
         """Get all attestations that need to be renewed until given time
         """
 
@@ -445,7 +448,7 @@ class GroupServerWorkerStore(SQLBaseStore):
             txn.execute(sql, (valid_until_ms,))
             return self.db_pool.cursor_to_dict(txn)
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_attestations_need_renewals", _get_attestations_need_renewals_txn
         )
 
@@ -475,7 +478,7 @@ class GroupServerWorkerStore(SQLBaseStore):
             desc="get_joined_groups",
         )
 
-    def get_all_groups_for_user(self, user_id, now_token):
+    async def get_all_groups_for_user(self, user_id, now_token):
         def _get_all_groups_for_user_txn(txn):
             sql = """
                 SELECT group_id, type, membership, u.content
@@ -495,7 +498,7 @@ class GroupServerWorkerStore(SQLBaseStore):
                 for row in txn
             ]
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_all_groups_for_user", _get_all_groups_for_user_txn
         )
 
@@ -600,8 +603,27 @@ class GroupServerStore(GroupServerWorkerStore):
             desc="set_group_join_policy",
         )
 
-    def add_room_to_summary(self, group_id, room_id, category_id, order, is_public):
-        return self.db_pool.runInteraction(
+    async def add_room_to_summary(
+        self,
+        group_id: str,
+        room_id: str,
+        category_id: str,
+        order: int,
+        is_public: Optional[bool],
+    ) -> None:
+        """Add (or update) room's entry in summary.
+
+        Args:
+            group_id
+            room_id
+            category_id: If not None then adds the category to the end of
+                the summary if it's not already there.
+            order: If not None inserts the room at that position, e.g. an order
+                of 1 will put the room first. Otherwise, the room gets added to
+                the end.
+            is_public
+        """
+        await self.db_pool.runInteraction(
             "add_room_to_summary",
             self._add_room_to_summary_txn,
             group_id,
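The `order` argument documented above follows one rule across the summary helpers: an order of 1 puts the entry first, and None appends. As a plain-list analogy (an illustrative sketch only; the store implements this with SQL against the summary tables):

    from typing import List, Optional

    def place(summary: List[str], entry: str, order: Optional[int]) -> None:
        # order is 1-based: order=1 means "first"; None means "append".
        if order is None:
            summary.append(entry)
        else:
            summary.insert(order - 1, entry)

    rooms = ["!b:example.com"]
    place(rooms, "!a:example.com", 1)     # -> ["!a:example.com", "!b:example.com"]
    place(rooms, "!c:example.com", None)  # -> appended at the end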
@@ -612,18 +634,26 @@ class GroupServerStore(GroupServerWorkerStore):
         )
 
     def _add_room_to_summary_txn(
-        self, txn, group_id, room_id, category_id, order, is_public
-    ):
+        self,
+        txn,
+        group_id: str,
+        room_id: str,
+        category_id: str,
+        order: int,
+        is_public: Optional[bool],
+    ) -> None:
         """Add (or update) room's entry in summary.
 
         Args:
-            group_id (str)
-            room_id (str)
-            category_id (str): If not None then adds the category to the end of
-                the summary if its not already there. [Optional]
-            order (int): If not None inserts the room at that position, e.g.
-                an order of 1 will put the room first. Otherwise, the room gets
-                added to the end.
+            txn
+            group_id
+            room_id
+            category_id: If not None then adds the category to the end of
+                the summary if it's not already there.
+            order: If not None inserts the room at that position, e.g. an order
+                of 1 will put the room first. Otherwise, the room gets added to
+                the end.
+            is_public
         """
         room_in_group = self.db_pool.simple_select_one_onecol_txn(
             txn,
@@ -818,8 +848,27 @@ class GroupServerStore(GroupServerWorkerStore):
             desc="remove_group_role",
         )
 
-    def add_user_to_summary(self, group_id, user_id, role_id, order, is_public):
-        return self.db_pool.runInteraction(
+    async def add_user_to_summary(
+        self,
+        group_id: str,
+        user_id: str,
+        role_id: str,
+        order: int,
+        is_public: Optional[bool],
+    ) -> None:
+        """Add (or update) user's entry in summary.
+
+        Args:
+            group_id
+            user_id
+            role_id: If not None then adds the role to the end of the summary if
+                it's not already there.
+            order: If not None inserts the user at that position, e.g. an order
+                of 1 will put the user first. Otherwise, the user gets added to
+                the end.
+            is_public
+        """
+        await self.db_pool.runInteraction(
             "add_user_to_summary",
             self._add_user_to_summary_txn,
             group_id,
@@ -830,18 +879,26 @@ class GroupServerStore(GroupServerWorkerStore):
         )
 
     def _add_user_to_summary_txn(
-        self, txn, group_id, user_id, role_id, order, is_public
+        self,
+        txn,
+        group_id: str,
+        user_id: str,
+        role_id: str,
+        order: int,
+        is_public: Optional[bool],
     ):
         """Add (or update) user's entry in summary.
 
         Args:
-            group_id (str)
-            user_id (str)
-            role_id (str): If not None then adds the role to the end of
-                the summary if its not already there. [Optional]
-            order (int): If not None inserts the user at that position, e.g.
-                an order of 1 will put the user first. Otherwise, the user gets
-                added to the end.
+            txn
+            group_id
+            user_id
+            role_id: If not None then adds the role to the end of the summary if
+                it's not already there.
+            order: If not None inserts the user at that position, e.g. an order
+                of 1 will put the user first. Otherwise, the user gets added to
+                the end.
+            is_public
         """
         user_in_group = self.db_pool.simple_select_one_onecol_txn(
             txn,
@@ -963,27 +1020,26 @@ class GroupServerStore(GroupServerWorkerStore):
             desc="add_group_invite",
         )
 
-    def add_user_to_group(
+    async def add_user_to_group(
         self,
-        group_id,
-        user_id,
-        is_admin=False,
-        is_public=True,
-        local_attestation=None,
-        remote_attestation=None,
-    ):
+        group_id: str,
+        user_id: str,
+        is_admin: bool = False,
+        is_public: bool = True,
+        local_attestation: dict = None,
+        remote_attestation: dict = None,
+    ) -> None:
         """Add a user to the group server.
 
         Args:
-            group_id (str)
-            user_id (str)
-            is_admin (bool)
-            is_public (bool)
-            local_attestation (dict): The attestation the GS created to give
-                to the remote server. Optional if the user and group are on the
-                same server
-            remote_attestation (dict): The attestation given to GS by remote
+            group_id
+            user_id
+            is_admin
+            is_public
+            local_attestation: The attestation the GS created to give to the
+                remote server. Optional if the user and group are on the same
+                server
+            remote_attestation: The attestation given to GS by remote
             server.
+            Optional if the user and group are on the same server
         """
 
         def _add_user_to_group_txn(txn):
@@ -1026,9 +1082,9 @@ class GroupServerStore(GroupServerWorkerStore):
                 },
             )
 
-        return self.db_pool.runInteraction("add_user_to_group", _add_user_to_group_txn)
+        await self.db_pool.runInteraction("add_user_to_group", _add_user_to_group_txn)
 
-    def remove_user_from_group(self, group_id, user_id):
+    async def remove_user_from_group(self, group_id: str, user_id: str) -> None:
         def _remove_user_from_group_txn(txn):
             self.db_pool.simple_delete_txn(
                 txn,
@@ -1056,7 +1112,7 @@ class GroupServerStore(GroupServerWorkerStore):
                 keyvalues={"group_id": group_id, "user_id": user_id},
             )
 
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "remove_user_from_group", _remove_user_from_group_txn
         )
 
@@ -1079,7 +1135,7 @@ class GroupServerStore(GroupServerWorkerStore):
             desc="update_room_in_group_visibility",
         )
 
-    def remove_room_from_group(self, group_id, room_id):
+    async def remove_room_from_group(self, group_id: str, room_id: str) -> None:
         def _remove_room_from_group_txn(txn):
             self.db_pool.simple_delete_txn(
                 txn,
@@ -1093,7 +1149,7 @@ class GroupServerStore(GroupServerWorkerStore):
                 keyvalues={"group_id": group_id, "room_id": room_id},
             )
 
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "remove_room_from_group", _remove_room_from_group_txn
         )
 
@@ -1286,14 +1342,11 @@ class GroupServerStore(GroupServerWorkerStore):
     def get_group_stream_token(self):
         return self._group_updates_id_gen.get_current_token()
 
-    def delete_group(self, group_id):
+    async def delete_group(self, group_id: str) -> None:
         """Deletes a group fully from the database.
 
         Args:
-            group_id (str)
-
-        Returns:
-            Deferred
+            group_id: The group ID to delete.
         """
 
         def _delete_group_txn(txn):
@@ -1317,4 +1370,4 @@ class GroupServerStore(GroupServerWorkerStore):
                 txn, table=table, keyvalues={"group_id": group_id}
             )
 
-        return self.db_pool.runInteraction("delete_group", _delete_group_txn)
+        await self.db_pool.runInteraction("delete_group", _delete_group_txn)
diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py
index 1c0a049c55..ad43bb05ab 100644
--- a/synapse/storage/databases/main/keys.py
+++ b/synapse/storage/databases/main/keys.py
@@ -16,7 +16,7 @@
 import itertools
 import logging
 
-from typing import Iterable, Tuple
+from typing import Dict, Iterable, List, Optional, Tuple
 
 from signedjson.key import decode_verify_key_bytes
 
@@ -42,16 +42,17 @@ class KeyStore(SQLBaseStore):
     @cachedList(
         cached_method_name="_get_server_verify_key", list_name="server_name_and_key_ids"
     )
-    def get_server_verify_keys(self, server_name_and_key_ids):
+    async def get_server_verify_keys(
+        self, server_name_and_key_ids: Iterable[Tuple[str, str]]
+    ) -> Dict[Tuple[str, str], Optional[FetchKeyResult]]:
         """
         Args:
-            server_name_and_key_ids (iterable[Tuple[str, str]]):
+            server_name_and_key_ids:
                 iterable of (server_name, key-id) tuples to fetch keys for
 
         Returns:
-            Deferred: resolves to dict[Tuple[str, str], FetchKeyResult|None]:
-            map from (server_name, key_id) -> FetchKeyResult, or None if the key is
-            unknown
+            A map from (server_name, key_id) -> FetchKeyResult, or None if the
+            key is unknown
         """
         keys = {}
 
@@ -87,7 +88,7 @@ class KeyStore(SQLBaseStore):
             _get_keys(txn, batch)
         return keys
 
-        return self.db_pool.runInteraction("get_server_verify_keys", _txn)
+        return await self.db_pool.runInteraction("get_server_verify_keys", _txn)
 
     async def store_server_verify_keys(
         self,
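With get_server_verify_keys annotated as above, callers can await the batched lookup directly and use the None entries to detect unknown keys (a hypothetical caller, simplified):

    async def report_keys(store, pairs):
        # pairs: an iterable of (server_name, key_id) tuples.
        results = await store.get_server_verify_keys(pairs)
        for (server_name, key_id), result in results.items():
            if result is None:
                print("no verify key known for", server_name, key_id)
            else:
                # FetchKeyResult carries the key and its validity horizon.
                print(server_name, key_id, "valid until", result.valid_until_ts)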
desc="store_server_keys_json", ) - def get_server_keys_json(self, server_keys): + async def get_server_keys_json( + self, server_keys: Iterable[Tuple[str, Optional[str], Optional[str]]] + ) -> Dict[Tuple[str, Optional[str], Optional[str]], List[dict]]: """Retrive the key json for a list of server_keys and key ids. If no keys are found for a given server, key_id and source then that server, key_id, and source triplet entry will be an empty list. @@ -188,8 +191,7 @@ class KeyStore(SQLBaseStore): Args: server_keys (list): List of (server_name, key_id, source) triplets. Returns: - Deferred[dict[Tuple[str, str, str|None], list[dict]]]: - Dict mapping (server_name, key_id, source) triplets to lists of dicts + A mapping from (server_name, key_id, source) triplets to a list of dicts """ def _get_server_keys_json_txn(txn): @@ -215,6 +217,6 @@ class KeyStore(SQLBaseStore): results[(server_name, key_id, from_server)] = rows return results - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_server_keys_json", _get_server_keys_json_txn ) diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index 2efcc0dc66..5b31aab700 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -15,6 +15,7 @@ import logging from collections import namedtuple +from typing import Optional, Tuple from canonicaljson import encode_canonical_json @@ -56,21 +57,23 @@ class TransactionStore(SQLBaseStore): expiry_ms=5 * 60 * 1000, ) - def get_received_txn_response(self, transaction_id, origin): + async def get_received_txn_response( + self, transaction_id: str, origin: str + ) -> Optional[Tuple[int, JsonDict]]: """For an incoming transaction from a given origin, check if we have already responded to it. If so, return the response code and response body (as a dict). Args: - transaction_id (str) - origin(str) + transaction_id + origin Returns: - tuple: None if we have not previously responded to - this transaction or a 2-tuple of (int, dict) + None if we have not previously responded to this transaction or a + 2-tuple of (int, dict) """ - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_received_txn_response", self._get_received_txn_response, transaction_id, @@ -166,21 +169,25 @@ class TransactionStore(SQLBaseStore): else: return None - def set_destination_retry_timings( - self, destination, failure_ts, retry_last_ts, retry_interval - ): + async def set_destination_retry_timings( + self, + destination: str, + failure_ts: Optional[int], + retry_last_ts: int, + retry_interval: int, + ) -> None: """Sets the current retry timings for a given destination. Both timings should be zero if retrying is no longer occuring. 
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 2efcc0dc66..5b31aab700 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -15,6 +15,7 @@
 
 import logging
 from collections import namedtuple
+from typing import Optional, Tuple
 
 from canonicaljson import encode_canonical_json
 
@@ -56,21 +57,23 @@ class TransactionStore(SQLBaseStore):
             expiry_ms=5 * 60 * 1000,
         )
 
-    def get_received_txn_response(self, transaction_id, origin):
+    async def get_received_txn_response(
+        self, transaction_id: str, origin: str
+    ) -> Optional[Tuple[int, JsonDict]]:
         """For an incoming transaction from a given origin, check if we have
         already responded to it. If so, return the response code and response
         body (as a dict).
 
         Args:
-            transaction_id (str)
-            origin(str)
+            transaction_id
+            origin
 
         Returns:
-            tuple: None if we have not previously responded to
-            this transaction or a 2-tuple of (int, dict)
+            None if we have not previously responded to this transaction or a
+            2-tuple of (int, dict)
         """
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_received_txn_response",
             self._get_received_txn_response,
             transaction_id,
@@ -166,21 +169,25 @@ class TransactionStore(SQLBaseStore):
         else:
             return None
 
-    def set_destination_retry_timings(
-        self, destination, failure_ts, retry_last_ts, retry_interval
-    ):
+    async def set_destination_retry_timings(
+        self,
+        destination: str,
+        failure_ts: Optional[int],
+        retry_last_ts: int,
+        retry_interval: int,
+    ) -> None:
         """Sets the current retry timings for a given destination.
         Both timings should be zero if retrying is no longer occurring.
 
         Args:
-            destination (str)
-            failure_ts (int|None) - when the server started failing (ms since
-                epoch)
-            retry_last_ts (int) - time of last retry attempt in unix epoch ms
-            retry_interval (int) - how long until next retry in ms
+            destination
+            failure_ts: when the server started failing (ms since epoch)
+            retry_last_ts: time of last retry attempt in unix epoch ms
+            retry_interval: how long until next retry in ms
         """
 
         self._destination_retry_cache.pop(destination, None)
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "set_destination_retry_timings",
             self._set_destination_retry_timings,
             destination,
@@ -256,13 +263,13 @@ class TransactionStore(SQLBaseStore):
             "cleanup_transactions", self._cleanup_transactions
         )
 
-    def _cleanup_transactions(self):
+    async def _cleanup_transactions(self) -> None:
         now = self._clock.time_msec()
         month_ago = now - 30 * 24 * 60 * 60 * 1000
 
         def _cleanup_transactions_txn(txn):
             txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,))
 
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "_cleanup_transactions", _cleanup_transactions_txn
         )
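set_destination_retry_timings now states its contract in the signature: failure_ts is the epoch-ms time the destination started failing (or None), and both timings go to zero once retrying stops. A caller applying a doubling backoff might look like this (a sketch under those assumptions; Synapse's real retry logic lives elsewhere and is more involved):

    import time

    async def note_attempt(store, destination, succeeded, failure_ts, retry_interval):
        now_ms = int(time.time() * 1000)
        if succeeded:
            # Retrying is no longer occurring: zero out both timings.
            await store.set_destination_retry_timings(destination, None, 0, 0)
        else:
            await store.set_destination_retry_timings(
                destination,
                failure_ts if failure_ts is not None else now_ms,  # first failure
                now_ms,  # this retry attempt
                max(2 * retry_interval, 60_000),  # assumed doubling, 1-minute floor
            )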