From a7bdf98d01d2225a479753a85ba81adf02b16a32 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 5 Aug 2020 21:38:57 +0100 Subject: Rename database classes to make some sense (#8033) --- synapse/storage/databases/main/registration.py | 1588 ++++++++++++++++++++++++ 1 file changed, 1588 insertions(+) create mode 100644 synapse/storage/databases/main/registration.py (limited to 'synapse/storage/databases/main/registration.py') diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py new file mode 100644 index 0000000000..f618629e09 --- /dev/null +++ b/synapse/storage/databases/main/registration.py @@ -0,0 +1,1588 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2017-2018 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import re +from typing import Optional + +from twisted.internet import defer +from twisted.internet.defer import Deferred + +from synapse.api.constants import UserTypes +from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage._base import SQLBaseStore +from synapse.storage.database import DatabasePool +from synapse.storage.types import Cursor +from synapse.storage.util.sequence import build_sequence_generator +from synapse.types import UserID +from synapse.util.caches.descriptors import cached, cachedInlineCallbacks + +THIRTY_MINUTES_IN_MS = 30 * 60 * 1000 + +logger = logging.getLogger(__name__) + + +class RegistrationWorkerStore(SQLBaseStore): + def __init__(self, database: DatabasePool, db_conn, hs): + super(RegistrationWorkerStore, self).__init__(database, db_conn, hs) + + self.config = hs.config + self.clock = hs.get_clock() + + self._user_id_seq = build_sequence_generator( + database.engine, find_max_generated_user_id_localpart, "user_id_seq", + ) + + @cached() + def get_user_by_id(self, user_id): + return self.db_pool.simple_select_one( + table="users", + keyvalues={"name": user_id}, + retcols=[ + "name", + "password_hash", + "is_guest", + "admin", + "consent_version", + "consent_server_notice_sent", + "appservice_id", + "creation_ts", + "user_type", + "deactivated", + ], + allow_none=True, + desc="get_user_by_id", + ) + + @defer.inlineCallbacks + def is_trial_user(self, user_id): + """Checks if user is in the "trial" period, i.e. within the first + N days of registration defined by `mau_trial_days` config + + Args: + user_id (str) + + Returns: + Deferred[bool] + """ + + info = yield self.get_user_by_id(user_id) + if not info: + return False + + now = self.clock.time_msec() + trial_duration_ms = self.config.mau_trial_days * 24 * 60 * 60 * 1000 + is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms + return is_trial + + @cached() + def get_user_by_access_token(self, token): + """Get a user from the given access token. + + Args: + token (str): The access token of a user. + Returns: + defer.Deferred: None, if the token did not match, otherwise dict + including the keys `name`, `is_guest`, `device_id`, `token_id`, + `valid_until_ms`. + """ + return self.db_pool.runInteraction( + "get_user_by_access_token", self._query_for_auth, token + ) + + @cachedInlineCallbacks() + def get_expiration_ts_for_user(self, user_id): + """Get the expiration timestamp for the account bearing a given user ID. + + Args: + user_id (str): The ID of the user. + Returns: + defer.Deferred: None, if the account has no expiration timestamp, + otherwise int representation of the timestamp (as a number of + milliseconds since epoch). + """ + res = yield self.db_pool.simple_select_one_onecol( + table="account_validity", + keyvalues={"user_id": user_id}, + retcol="expiration_ts_ms", + allow_none=True, + desc="get_expiration_ts_for_user", + ) + return res + + @defer.inlineCallbacks + def set_account_validity_for_user( + self, user_id, expiration_ts, email_sent, renewal_token=None + ): + """Updates the account validity properties of the given account, with the + given values. + + Args: + user_id (str): ID of the account to update properties for. + expiration_ts (int): New expiration date, as a timestamp in milliseconds + since epoch. + email_sent (bool): True means a renewal email has been sent for this + account and there's no need to send another one for the current validity + period. + renewal_token (str): Renewal token the user can use to extend the validity + of their account. Defaults to no token. + """ + + def set_account_validity_for_user_txn(txn): + self.db_pool.simple_update_txn( + txn=txn, + table="account_validity", + keyvalues={"user_id": user_id}, + updatevalues={ + "expiration_ts_ms": expiration_ts, + "email_sent": email_sent, + "renewal_token": renewal_token, + }, + ) + self._invalidate_cache_and_stream( + txn, self.get_expiration_ts_for_user, (user_id,) + ) + + yield self.db_pool.runInteraction( + "set_account_validity_for_user", set_account_validity_for_user_txn + ) + + @defer.inlineCallbacks + def set_renewal_token_for_user(self, user_id, renewal_token): + """Defines a renewal token for a given user. + + Args: + user_id (str): ID of the user to set the renewal token for. + renewal_token (str): Random unique string that will be used to renew the + user's account. + + Raises: + StoreError: The provided token is already set for another user. + """ + yield self.db_pool.simple_update_one( + table="account_validity", + keyvalues={"user_id": user_id}, + updatevalues={"renewal_token": renewal_token}, + desc="set_renewal_token_for_user", + ) + + @defer.inlineCallbacks + def get_user_from_renewal_token(self, renewal_token): + """Get a user ID from a renewal token. + + Args: + renewal_token (str): The renewal token to perform the lookup with. + + Returns: + defer.Deferred[str]: The ID of the user to which the token belongs. + """ + res = yield self.db_pool.simple_select_one_onecol( + table="account_validity", + keyvalues={"renewal_token": renewal_token}, + retcol="user_id", + desc="get_user_from_renewal_token", + ) + + return res + + @defer.inlineCallbacks + def get_renewal_token_for_user(self, user_id): + """Get the renewal token associated with a given user ID. + + Args: + user_id (str): The user ID to lookup a token for. + + Returns: + defer.Deferred[str]: The renewal token associated with this user ID. + """ + res = yield self.db_pool.simple_select_one_onecol( + table="account_validity", + keyvalues={"user_id": user_id}, + retcol="renewal_token", + desc="get_renewal_token_for_user", + ) + + return res + + @defer.inlineCallbacks + def get_users_expiring_soon(self): + """Selects users whose account will expire in the [now, now + renew_at] time + window (see configuration for account_validity for information on what renew_at + refers to). + + Returns: + Deferred: Resolves to a list[dict[user_id (str), expiration_ts_ms (int)]] + """ + + def select_users_txn(txn, now_ms, renew_at): + sql = ( + "SELECT user_id, expiration_ts_ms FROM account_validity" + " WHERE email_sent = ? AND (expiration_ts_ms - ?) <= ?" + ) + values = [False, now_ms, renew_at] + txn.execute(sql, values) + return self.db_pool.cursor_to_dict(txn) + + res = yield self.db_pool.runInteraction( + "get_users_expiring_soon", + select_users_txn, + self.clock.time_msec(), + self.config.account_validity.renew_at, + ) + + return res + + @defer.inlineCallbacks + def set_renewal_mail_status(self, user_id, email_sent): + """Sets or unsets the flag that indicates whether a renewal email has been sent + to the user (and the user hasn't renewed their account yet). + + Args: + user_id (str): ID of the user to set/unset the flag for. + email_sent (bool): Flag which indicates whether a renewal email has been sent + to this user. + """ + yield self.db_pool.simple_update_one( + table="account_validity", + keyvalues={"user_id": user_id}, + updatevalues={"email_sent": email_sent}, + desc="set_renewal_mail_status", + ) + + @defer.inlineCallbacks + def delete_account_validity_for_user(self, user_id): + """Deletes the entry for the given user in the account validity table, removing + their expiration date and renewal token. + + Args: + user_id (str): ID of the user to remove from the account validity table. + """ + yield self.db_pool.simple_delete_one( + table="account_validity", + keyvalues={"user_id": user_id}, + desc="delete_account_validity_for_user", + ) + + async def is_server_admin(self, user): + """Determines if a user is an admin of this homeserver. + + Args: + user (UserID): user ID of the user to test + + Returns (bool): + true iff the user is a server admin, false otherwise. + """ + res = await self.db_pool.simple_select_one_onecol( + table="users", + keyvalues={"name": user.to_string()}, + retcol="admin", + allow_none=True, + desc="is_server_admin", + ) + + return bool(res) if res else False + + def set_server_admin(self, user, admin): + """Sets whether a user is an admin of this homeserver. + + Args: + user (UserID): user ID of the user to test + admin (bool): true iff the user is to be a server admin, + false otherwise. + """ + + def set_server_admin_txn(txn): + self.db_pool.simple_update_one_txn( + txn, "users", {"name": user.to_string()}, {"admin": 1 if admin else 0} + ) + self._invalidate_cache_and_stream( + txn, self.get_user_by_id, (user.to_string(),) + ) + + return self.db_pool.runInteraction("set_server_admin", set_server_admin_txn) + + def _query_for_auth(self, txn, token): + sql = ( + "SELECT users.name, users.is_guest, access_tokens.id as token_id," + " access_tokens.device_id, access_tokens.valid_until_ms" + " FROM users" + " INNER JOIN access_tokens on users.name = access_tokens.user_id" + " WHERE token = ?" + ) + + txn.execute(sql, (token,)) + rows = self.db_pool.cursor_to_dict(txn) + if rows: + return rows[0] + + return None + + @cachedInlineCallbacks() + def is_real_user(self, user_id): + """Determines if the user is a real user, ie does not have a 'user_type'. + + Args: + user_id (str): user id to test + + Returns: + Deferred[bool]: True if user 'user_type' is null or empty string + """ + res = yield self.db_pool.runInteraction( + "is_real_user", self.is_real_user_txn, user_id + ) + return res + + @cached() + def is_support_user(self, user_id): + """Determines if the user is of type UserTypes.SUPPORT + + Args: + user_id (str): user id to test + + Returns: + Deferred[bool]: True if user is of type UserTypes.SUPPORT + """ + return self.db_pool.runInteraction( + "is_support_user", self.is_support_user_txn, user_id + ) + + def is_real_user_txn(self, txn, user_id): + res = self.db_pool.simple_select_one_onecol_txn( + txn=txn, + table="users", + keyvalues={"name": user_id}, + retcol="user_type", + allow_none=True, + ) + return res is None + + def is_support_user_txn(self, txn, user_id): + res = self.db_pool.simple_select_one_onecol_txn( + txn=txn, + table="users", + keyvalues={"name": user_id}, + retcol="user_type", + allow_none=True, + ) + return True if res == UserTypes.SUPPORT else False + + def get_users_by_id_case_insensitive(self, user_id): + """Gets users that match user_id case insensitively. + Returns a mapping of user_id -> password_hash. + """ + + def f(txn): + sql = "SELECT name, password_hash FROM users WHERE lower(name) = lower(?)" + txn.execute(sql, (user_id,)) + return dict(txn) + + return self.db_pool.runInteraction("get_users_by_id_case_insensitive", f) + + async def get_user_by_external_id( + self, auth_provider: str, external_id: str + ) -> str: + """Look up a user by their external auth id + + Args: + auth_provider: identifier for the remote auth provider + external_id: id on that system + + Returns: + str|None: the mxid of the user, or None if they are not known + """ + return await self.db_pool.simple_select_one_onecol( + table="user_external_ids", + keyvalues={"auth_provider": auth_provider, "external_id": external_id}, + retcol="user_id", + allow_none=True, + desc="get_user_by_external_id", + ) + + @defer.inlineCallbacks + def count_all_users(self): + """Counts all users registered on the homeserver.""" + + def _count_users(txn): + txn.execute("SELECT COUNT(*) AS users FROM users") + rows = self.db_pool.cursor_to_dict(txn) + if rows: + return rows[0]["users"] + return 0 + + ret = yield self.db_pool.runInteraction("count_users", _count_users) + return ret + + def count_daily_user_type(self): + """ + Counts 1) native non guest users + 2) native guests users + 3) bridged users + who registered on the homeserver in the past 24 hours + """ + + def _count_daily_user_type(txn): + yesterday = int(self._clock.time()) - (60 * 60 * 24) + + sql = """ + SELECT user_type, COALESCE(count(*), 0) AS count FROM ( + SELECT + CASE + WHEN is_guest=0 AND appservice_id IS NULL THEN 'native' + WHEN is_guest=1 AND appservice_id IS NULL THEN 'guest' + WHEN is_guest=0 AND appservice_id IS NOT NULL THEN 'bridged' + END AS user_type + FROM users + WHERE creation_ts > ? + ) AS t GROUP BY user_type + """ + results = {"native": 0, "guest": 0, "bridged": 0} + txn.execute(sql, (yesterday,)) + for row in txn: + results[row[0]] = row[1] + return results + + return self.db_pool.runInteraction( + "count_daily_user_type", _count_daily_user_type + ) + + @defer.inlineCallbacks + def count_nonbridged_users(self): + def _count_users(txn): + txn.execute( + """ + SELECT COALESCE(COUNT(*), 0) FROM users + WHERE appservice_id IS NULL + """ + ) + (count,) = txn.fetchone() + return count + + ret = yield self.db_pool.runInteraction("count_users", _count_users) + return ret + + @defer.inlineCallbacks + def count_real_users(self): + """Counts all users without a special user_type registered on the homeserver.""" + + def _count_users(txn): + txn.execute("SELECT COUNT(*) AS users FROM users where user_type is null") + rows = self.db_pool.cursor_to_dict(txn) + if rows: + return rows[0]["users"] + return 0 + + ret = yield self.db_pool.runInteraction("count_real_users", _count_users) + return ret + + async def generate_user_id(self) -> str: + """Generate a suitable localpart for a guest user + + Returns: a (hopefully) free localpart + """ + next_id = await self.db_pool.runInteraction( + "generate_user_id", self._user_id_seq.get_next_id_txn + ) + + return str(next_id) + + async def get_user_id_by_threepid(self, medium: str, address: str) -> Optional[str]: + """Returns user id from threepid + + Args: + medium: threepid medium e.g. email + address: threepid address e.g. me@example.com + + Returns: + The user ID or None if no user id/threepid mapping exists + """ + user_id = await self.db_pool.runInteraction( + "get_user_id_by_threepid", self.get_user_id_by_threepid_txn, medium, address + ) + return user_id + + def get_user_id_by_threepid_txn(self, txn, medium, address): + """Returns user id from threepid + + Args: + txn (cursor): + medium (str): threepid medium e.g. email + address (str): threepid address e.g. me@example.com + + Returns: + str|None: user id or None if no user id/threepid mapping exists + """ + ret = self.db_pool.simple_select_one_txn( + txn, + "user_threepids", + {"medium": medium, "address": address}, + ["user_id"], + True, + ) + if ret: + return ret["user_id"] + return None + + @defer.inlineCallbacks + def user_add_threepid(self, user_id, medium, address, validated_at, added_at): + yield self.db_pool.simple_upsert( + "user_threepids", + {"medium": medium, "address": address}, + {"user_id": user_id, "validated_at": validated_at, "added_at": added_at}, + ) + + @defer.inlineCallbacks + def user_get_threepids(self, user_id): + ret = yield self.db_pool.simple_select_list( + "user_threepids", + {"user_id": user_id}, + ["medium", "address", "validated_at", "added_at"], + "user_get_threepids", + ) + return ret + + def user_delete_threepid(self, user_id, medium, address): + return self.db_pool.simple_delete( + "user_threepids", + keyvalues={"user_id": user_id, "medium": medium, "address": address}, + desc="user_delete_threepid", + ) + + def user_delete_threepids(self, user_id: str): + """Delete all threepid this user has bound + + Args: + user_id: The user id to delete all threepids of + + """ + return self.db_pool.simple_delete( + "user_threepids", + keyvalues={"user_id": user_id}, + desc="user_delete_threepids", + ) + + def add_user_bound_threepid(self, user_id, medium, address, id_server): + """The server proxied a bind request to the given identity server on + behalf of the given user. We need to remember this in case the user + asks us to unbind the threepid. + + Args: + user_id (str) + medium (str) + address (str) + id_server (str) + + Returns: + Deferred + """ + # We need to use an upsert, in case they user had already bound the + # threepid + return self.db_pool.simple_upsert( + table="user_threepid_id_server", + keyvalues={ + "user_id": user_id, + "medium": medium, + "address": address, + "id_server": id_server, + }, + values={}, + insertion_values={}, + desc="add_user_bound_threepid", + ) + + def user_get_bound_threepids(self, user_id): + """Get the threepids that a user has bound to an identity server through the homeserver + The homeserver remembers where binds to an identity server occurred. Using this + method can retrieve those threepids. + + Args: + user_id (str): The ID of the user to retrieve threepids for + + Returns: + Deferred[list[dict]]: List of dictionaries containing the following: + medium (str): The medium of the threepid (e.g "email") + address (str): The address of the threepid (e.g "bob@example.com") + """ + return self.db_pool.simple_select_list( + table="user_threepid_id_server", + keyvalues={"user_id": user_id}, + retcols=["medium", "address"], + desc="user_get_bound_threepids", + ) + + def remove_user_bound_threepid(self, user_id, medium, address, id_server): + """The server proxied an unbind request to the given identity server on + behalf of the given user, so we remove the mapping of threepid to + identity server. + + Args: + user_id (str) + medium (str) + address (str) + id_server (str) + + Returns: + Deferred + """ + return self.db_pool.simple_delete( + table="user_threepid_id_server", + keyvalues={ + "user_id": user_id, + "medium": medium, + "address": address, + "id_server": id_server, + }, + desc="remove_user_bound_threepid", + ) + + def get_id_servers_user_bound(self, user_id, medium, address): + """Get the list of identity servers that the server proxied bind + requests to for given user and threepid + + Args: + user_id (str) + medium (str) + address (str) + + Returns: + Deferred[list[str]]: Resolves to a list of identity servers + """ + return self.db_pool.simple_select_onecol( + table="user_threepid_id_server", + keyvalues={"user_id": user_id, "medium": medium, "address": address}, + retcol="id_server", + desc="get_id_servers_user_bound", + ) + + @cachedInlineCallbacks() + def get_user_deactivated_status(self, user_id): + """Retrieve the value for the `deactivated` property for the provided user. + + Args: + user_id (str): The ID of the user to retrieve the status for. + + Returns: + defer.Deferred(bool): The requested value. + """ + + res = yield self.db_pool.simple_select_one_onecol( + table="users", + keyvalues={"name": user_id}, + retcol="deactivated", + desc="get_user_deactivated_status", + ) + + # Convert the integer into a boolean. + return res == 1 + + def get_threepid_validation_session( + self, medium, client_secret, address=None, sid=None, validated=True + ): + """Gets a session_id and last_send_attempt (if available) for a + combination of validation metadata + + Args: + medium (str|None): The medium of the 3PID + address (str|None): The address of the 3PID + sid (str|None): The ID of the validation session + client_secret (str): A unique string provided by the client to help identify this + validation attempt + validated (bool|None): Whether sessions should be filtered by + whether they have been validated already or not. None to + perform no filtering + + Returns: + Deferred[dict|None]: A dict containing the following: + * address - address of the 3pid + * medium - medium of the 3pid + * client_secret - a secret provided by the client for this validation session + * session_id - ID of the validation session + * send_attempt - a number serving to dedupe send attempts for this session + * validated_at - timestamp of when this session was validated if so + + Otherwise None if a validation session is not found + """ + if not client_secret: + raise SynapseError( + 400, "Missing parameter: client_secret", errcode=Codes.MISSING_PARAM + ) + + keyvalues = {"client_secret": client_secret} + if medium: + keyvalues["medium"] = medium + if address: + keyvalues["address"] = address + if sid: + keyvalues["session_id"] = sid + + assert address or sid + + def get_threepid_validation_session_txn(txn): + sql = """ + SELECT address, session_id, medium, client_secret, + last_send_attempt, validated_at + FROM threepid_validation_session WHERE %s + """ % ( + " AND ".join("%s = ?" % k for k in keyvalues.keys()), + ) + + if validated is not None: + sql += " AND validated_at IS " + ("NOT NULL" if validated else "NULL") + + sql += " LIMIT 1" + + txn.execute(sql, list(keyvalues.values())) + rows = self.db_pool.cursor_to_dict(txn) + if not rows: + return None + + return rows[0] + + return self.db_pool.runInteraction( + "get_threepid_validation_session", get_threepid_validation_session_txn + ) + + def delete_threepid_session(self, session_id): + """Removes a threepid validation session from the database. This can + be done after validation has been performed and whatever action was + waiting on it has been carried out + + Args: + session_id (str): The ID of the session to delete + """ + + def delete_threepid_session_txn(txn): + self.db_pool.simple_delete_txn( + txn, + table="threepid_validation_token", + keyvalues={"session_id": session_id}, + ) + self.db_pool.simple_delete_txn( + txn, + table="threepid_validation_session", + keyvalues={"session_id": session_id}, + ) + + return self.db_pool.runInteraction( + "delete_threepid_session", delete_threepid_session_txn + ) + + +class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): + def __init__(self, database: DatabasePool, db_conn, hs): + super(RegistrationBackgroundUpdateStore, self).__init__(database, db_conn, hs) + + self.clock = hs.get_clock() + self.config = hs.config + + self.db_pool.updates.register_background_index_update( + "access_tokens_device_index", + index_name="access_tokens_device_id", + table="access_tokens", + columns=["user_id", "device_id"], + ) + + self.db_pool.updates.register_background_index_update( + "users_creation_ts", + index_name="users_creation_ts", + table="users", + columns=["creation_ts"], + ) + + # we no longer use refresh tokens, but it's possible that some people + # might have a background update queued to build this index. Just + # clear the background update. + self.db_pool.updates.register_noop_background_update( + "refresh_tokens_device_index" + ) + + self.db_pool.updates.register_background_update_handler( + "user_threepids_grandfather", self._bg_user_threepids_grandfather + ) + + self.db_pool.updates.register_background_update_handler( + "users_set_deactivated_flag", self._background_update_set_deactivated_flag + ) + + @defer.inlineCallbacks + def _background_update_set_deactivated_flag(self, progress, batch_size): + """Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1 + for each of them. + """ + + last_user = progress.get("user_id", "") + + def _background_update_set_deactivated_flag_txn(txn): + txn.execute( + """ + SELECT + users.name, + COUNT(access_tokens.token) AS count_tokens, + COUNT(user_threepids.address) AS count_threepids + FROM users + LEFT JOIN access_tokens ON (access_tokens.user_id = users.name) + LEFT JOIN user_threepids ON (user_threepids.user_id = users.name) + WHERE (users.password_hash IS NULL OR users.password_hash = '') + AND (users.appservice_id IS NULL OR users.appservice_id = '') + AND users.is_guest = 0 + AND users.name > ? + GROUP BY users.name + ORDER BY users.name ASC + LIMIT ?; + """, + (last_user, batch_size), + ) + + rows = self.db_pool.cursor_to_dict(txn) + + if not rows: + return True, 0 + + rows_processed_nb = 0 + + for user in rows: + if not user["count_tokens"] and not user["count_threepids"]: + self.set_user_deactivated_status_txn(txn, user["name"], True) + rows_processed_nb += 1 + + logger.info("Marked %d rows as deactivated", rows_processed_nb) + + self.db_pool.updates._background_update_progress_txn( + txn, "users_set_deactivated_flag", {"user_id": rows[-1]["name"]} + ) + + if batch_size > len(rows): + return True, len(rows) + else: + return False, len(rows) + + end, nb_processed = yield self.db_pool.runInteraction( + "users_set_deactivated_flag", _background_update_set_deactivated_flag_txn + ) + + if end: + yield self.db_pool.updates._end_background_update( + "users_set_deactivated_flag" + ) + + return nb_processed + + @defer.inlineCallbacks + def _bg_user_threepids_grandfather(self, progress, batch_size): + """We now track which identity servers a user binds their 3PID to, so + we need to handle the case of existing bindings where we didn't track + this. + + We do this by grandfathering in existing user threepids assuming that + they used one of the server configured trusted identity servers. + """ + id_servers = set(self.config.trusted_third_party_id_servers) + + def _bg_user_threepids_grandfather_txn(txn): + sql = """ + INSERT INTO user_threepid_id_server + (user_id, medium, address, id_server) + SELECT user_id, medium, address, ? + FROM user_threepids + """ + + txn.executemany(sql, [(id_server,) for id_server in id_servers]) + + if id_servers: + yield self.db_pool.runInteraction( + "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn + ) + + yield self.db_pool.updates._end_background_update("user_threepids_grandfather") + + return 1 + + +class RegistrationStore(RegistrationBackgroundUpdateStore): + def __init__(self, database: DatabasePool, db_conn, hs): + super(RegistrationStore, self).__init__(database, db_conn, hs) + + self._account_validity = hs.config.account_validity + + if self._account_validity.enabled: + self._clock.call_later( + 0.0, + run_as_background_process, + "account_validity_set_expiration_dates", + self._set_expiration_date_when_missing, + ) + + # Create a background job for culling expired 3PID validity tokens + def start_cull(): + # run as a background process to make sure that the database transactions + # have a logcontext to report to + return run_as_background_process( + "cull_expired_threepid_validation_tokens", + self.cull_expired_threepid_validation_tokens, + ) + + hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS) + + @defer.inlineCallbacks + def add_access_token_to_user(self, user_id, token, device_id, valid_until_ms): + """Adds an access token for the given user. + + Args: + user_id (str): The user ID. + token (str): The new access token to add. + device_id (str): ID of the device to associate with the access + token + valid_until_ms (int|None): when the token is valid until. None for + no expiry. + Raises: + StoreError if there was a problem adding this. + """ + next_id = self._access_tokens_id_gen.get_next() + + yield self.db_pool.simple_insert( + "access_tokens", + { + "id": next_id, + "user_id": user_id, + "token": token, + "device_id": device_id, + "valid_until_ms": valid_until_ms, + }, + desc="add_access_token_to_user", + ) + + def register_user( + self, + user_id, + password_hash=None, + was_guest=False, + make_guest=False, + appservice_id=None, + create_profile_with_displayname=None, + admin=False, + user_type=None, + ): + """Attempts to register an account. + + Args: + user_id (str): The desired user ID to register. + password_hash (str|None): Optional. The password hash for this user. + was_guest (bool): Optional. Whether this is a guest account being + upgraded to a non-guest account. + make_guest (boolean): True if the the new user should be guest, + false to add a regular user account. + appservice_id (str): The ID of the appservice registering the user. + create_profile_with_displayname (unicode): Optionally create a profile for + the user, setting their displayname to the given value + admin (boolean): is an admin user? + user_type (str|None): type of user. One of the values from + api.constants.UserTypes, or None for a normal user. + + Raises: + StoreError if the user_id could not be registered. + + Returns: + Deferred + """ + return self.db_pool.runInteraction( + "register_user", + self._register_user, + user_id, + password_hash, + was_guest, + make_guest, + appservice_id, + create_profile_with_displayname, + admin, + user_type, + ) + + def _register_user( + self, + txn, + user_id, + password_hash, + was_guest, + make_guest, + appservice_id, + create_profile_with_displayname, + admin, + user_type, + ): + user_id_obj = UserID.from_string(user_id) + + now = int(self.clock.time()) + + try: + if was_guest: + # Ensure that the guest user actually exists + # ``allow_none=False`` makes this raise an exception + # if the row isn't in the database. + self.db_pool.simple_select_one_txn( + txn, + "users", + keyvalues={"name": user_id, "is_guest": 1}, + retcols=("name",), + allow_none=False, + ) + + self.db_pool.simple_update_one_txn( + txn, + "users", + keyvalues={"name": user_id, "is_guest": 1}, + updatevalues={ + "password_hash": password_hash, + "upgrade_ts": now, + "is_guest": 1 if make_guest else 0, + "appservice_id": appservice_id, + "admin": 1 if admin else 0, + "user_type": user_type, + }, + ) + else: + self.db_pool.simple_insert_txn( + txn, + "users", + values={ + "name": user_id, + "password_hash": password_hash, + "creation_ts": now, + "is_guest": 1 if make_guest else 0, + "appservice_id": appservice_id, + "admin": 1 if admin else 0, + "user_type": user_type, + }, + ) + + except self.database_engine.module.IntegrityError: + raise StoreError(400, "User ID already taken.", errcode=Codes.USER_IN_USE) + + if self._account_validity.enabled: + self.set_expiration_date_for_user_txn(txn, user_id) + + if create_profile_with_displayname: + # set a default displayname serverside to avoid ugly race + # between auto-joins and clients trying to set displaynames + # + # *obviously* the 'profiles' table uses localpart for user_id + # while everything else uses the full mxid. + txn.execute( + "INSERT INTO profiles(user_id, displayname) VALUES (?,?)", + (user_id_obj.localpart, create_profile_with_displayname), + ) + + if self.hs.config.stats_enabled: + # we create a new completed user statistics row + + # we don't strictly need current_token since this user really can't + # have any state deltas before now (as it is a new user), but still, + # we include it for completeness. + current_token = self._get_max_stream_id_in_current_state_deltas_txn(txn) + self._update_stats_delta_txn( + txn, now, "user", user_id, {}, complete_with_stream_id=current_token + ) + + self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) + txn.call_after(self.is_guest.invalidate, (user_id,)) + + def record_user_external_id( + self, auth_provider: str, external_id: str, user_id: str + ) -> Deferred: + """Record a mapping from an external user id to a mxid + + Args: + auth_provider: identifier for the remote auth provider + external_id: id on that system + user_id: complete mxid that it is mapped to + """ + return self.db_pool.simple_insert( + table="user_external_ids", + values={ + "auth_provider": auth_provider, + "external_id": external_id, + "user_id": user_id, + }, + desc="record_user_external_id", + ) + + def user_set_password_hash(self, user_id, password_hash): + """ + NB. This does *not* evict any cache because the one use for this + removes most of the entries subsequently anyway so it would be + pointless. Use flush_user separately. + """ + + def user_set_password_hash_txn(txn): + self.db_pool.simple_update_one_txn( + txn, "users", {"name": user_id}, {"password_hash": password_hash} + ) + self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) + + return self.db_pool.runInteraction( + "user_set_password_hash", user_set_password_hash_txn + ) + + def user_set_consent_version(self, user_id, consent_version): + """Updates the user table to record privacy policy consent + + Args: + user_id (str): full mxid of the user to update + consent_version (str): version of the policy the user has consented + to + + Raises: + StoreError(404) if user not found + """ + + def f(txn): + self.db_pool.simple_update_one_txn( + txn, + table="users", + keyvalues={"name": user_id}, + updatevalues={"consent_version": consent_version}, + ) + self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) + + return self.db_pool.runInteraction("user_set_consent_version", f) + + def user_set_consent_server_notice_sent(self, user_id, consent_version): + """Updates the user table to record that we have sent the user a server + notice about privacy policy consent + + Args: + user_id (str): full mxid of the user to update + consent_version (str): version of the policy we have notified the + user about + + Raises: + StoreError(404) if user not found + """ + + def f(txn): + self.db_pool.simple_update_one_txn( + txn, + table="users", + keyvalues={"name": user_id}, + updatevalues={"consent_server_notice_sent": consent_version}, + ) + self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) + + return self.db_pool.runInteraction("user_set_consent_server_notice_sent", f) + + def user_delete_access_tokens(self, user_id, except_token_id=None, device_id=None): + """ + Invalidate access tokens belonging to a user + + Args: + user_id (str): ID of user the tokens belong to + except_token_id (str): list of access_tokens IDs which should + *not* be deleted + device_id (str|None): ID of device the tokens are associated with. + If None, tokens associated with any device (or no device) will + be deleted + Returns: + defer.Deferred[list[str, int, str|None, int]]: a list of + (token, token id, device id) for each of the deleted tokens + """ + + def f(txn): + keyvalues = {"user_id": user_id} + if device_id is not None: + keyvalues["device_id"] = device_id + + items = keyvalues.items() + where_clause = " AND ".join(k + " = ?" for k, _ in items) + values = [v for _, v in items] + if except_token_id: + where_clause += " AND id != ?" + values.append(except_token_id) + + txn.execute( + "SELECT token, id, device_id FROM access_tokens WHERE %s" + % where_clause, + values, + ) + tokens_and_devices = [(r[0], r[1], r[2]) for r in txn] + + for token, _, _ in tokens_and_devices: + self._invalidate_cache_and_stream( + txn, self.get_user_by_access_token, (token,) + ) + + txn.execute("DELETE FROM access_tokens WHERE %s" % where_clause, values) + + return tokens_and_devices + + return self.db_pool.runInteraction("user_delete_access_tokens", f) + + def delete_access_token(self, access_token): + def f(txn): + self.db_pool.simple_delete_one_txn( + txn, table="access_tokens", keyvalues={"token": access_token} + ) + + self._invalidate_cache_and_stream( + txn, self.get_user_by_access_token, (access_token,) + ) + + return self.db_pool.runInteraction("delete_access_token", f) + + @cachedInlineCallbacks() + def is_guest(self, user_id): + res = yield self.db_pool.simple_select_one_onecol( + table="users", + keyvalues={"name": user_id}, + retcol="is_guest", + allow_none=True, + desc="is_guest", + ) + + return res if res else False + + def add_user_pending_deactivation(self, user_id): + """ + Adds a user to the table of users who need to be parted from all the rooms they're + in + """ + return self.db_pool.simple_insert( + "users_pending_deactivation", + values={"user_id": user_id}, + desc="add_user_pending_deactivation", + ) + + def del_user_pending_deactivation(self, user_id): + """ + Removes the given user to the table of users who need to be parted from all the + rooms they're in, effectively marking that user as fully deactivated. + """ + # XXX: This should be simple_delete_one but we failed to put a unique index on + # the table, so somehow duplicate entries have ended up in it. + return self.db_pool.simple_delete( + "users_pending_deactivation", + keyvalues={"user_id": user_id}, + desc="del_user_pending_deactivation", + ) + + def get_user_pending_deactivation(self): + """ + Gets one user from the table of users waiting to be parted from all the rooms + they're in. + """ + return self.db_pool.simple_select_one_onecol( + "users_pending_deactivation", + keyvalues={}, + retcol="user_id", + allow_none=True, + desc="get_users_pending_deactivation", + ) + + def validate_threepid_session(self, session_id, client_secret, token, current_ts): + """Attempt to validate a threepid session using a token + + Args: + session_id (str): The id of a validation session + client_secret (str): A unique string provided by the client to + help identify this validation attempt + token (str): A validation token + current_ts (int): The current unix time in milliseconds. Used for + checking token expiry status + + Raises: + ThreepidValidationError: if a matching validation token was not found or has + expired + + Returns: + deferred str|None: A str representing a link to redirect the user + to if there is one. + """ + + # Insert everything into a transaction in order to run atomically + def validate_threepid_session_txn(txn): + row = self.db_pool.simple_select_one_txn( + txn, + table="threepid_validation_session", + keyvalues={"session_id": session_id}, + retcols=["client_secret", "validated_at"], + allow_none=True, + ) + + if not row: + raise ThreepidValidationError(400, "Unknown session_id") + retrieved_client_secret = row["client_secret"] + validated_at = row["validated_at"] + + if retrieved_client_secret != client_secret: + raise ThreepidValidationError( + 400, "This client_secret does not match the provided session_id" + ) + + row = self.db_pool.simple_select_one_txn( + txn, + table="threepid_validation_token", + keyvalues={"session_id": session_id, "token": token}, + retcols=["expires", "next_link"], + allow_none=True, + ) + + if not row: + raise ThreepidValidationError( + 400, "Validation token not found or has expired" + ) + expires = row["expires"] + next_link = row["next_link"] + + # If the session is already validated, no need to revalidate + if validated_at: + return next_link + + if expires <= current_ts: + raise ThreepidValidationError( + 400, "This token has expired. Please request a new one" + ) + + # Looks good. Validate the session + self.db_pool.simple_update_txn( + txn, + table="threepid_validation_session", + keyvalues={"session_id": session_id}, + updatevalues={"validated_at": self.clock.time_msec()}, + ) + + return next_link + + # Return next_link if it exists + return self.db_pool.runInteraction( + "validate_threepid_session_txn", validate_threepid_session_txn + ) + + def upsert_threepid_validation_session( + self, + medium, + address, + client_secret, + send_attempt, + session_id, + validated_at=None, + ): + """Upsert a threepid validation session + Args: + medium (str): The medium of the 3PID + address (str): The address of the 3PID + client_secret (str): A unique string provided by the client to + help identify this validation attempt + send_attempt (int): The latest send_attempt on this session + session_id (str): The id of this validation session + validated_at (int|None): The unix timestamp in milliseconds of + when the session was marked as valid + """ + insertion_values = { + "medium": medium, + "address": address, + "client_secret": client_secret, + } + + if validated_at: + insertion_values["validated_at"] = validated_at + + return self.db_pool.simple_upsert( + table="threepid_validation_session", + keyvalues={"session_id": session_id}, + values={"last_send_attempt": send_attempt}, + insertion_values=insertion_values, + desc="upsert_threepid_validation_session", + ) + + def start_or_continue_validation_session( + self, + medium, + address, + session_id, + client_secret, + send_attempt, + next_link, + token, + token_expires, + ): + """Creates a new threepid validation session if it does not already + exist and associates a new validation token with it + + Args: + medium (str): The medium of the 3PID + address (str): The address of the 3PID + session_id (str): The id of this validation session + client_secret (str): A unique string provided by the client to + help identify this validation attempt + send_attempt (int): The latest send_attempt on this session + next_link (str|None): The link to redirect the user to upon + successful validation + token (str): The validation token + token_expires (int): The timestamp for which after the token + will no longer be valid + """ + + def start_or_continue_validation_session_txn(txn): + # Create or update a validation session + self.db_pool.simple_upsert_txn( + txn, + table="threepid_validation_session", + keyvalues={"session_id": session_id}, + values={"last_send_attempt": send_attempt}, + insertion_values={ + "medium": medium, + "address": address, + "client_secret": client_secret, + }, + ) + + # Create a new validation token with this session ID + self.db_pool.simple_insert_txn( + txn, + table="threepid_validation_token", + values={ + "session_id": session_id, + "token": token, + "next_link": next_link, + "expires": token_expires, + }, + ) + + return self.db_pool.runInteraction( + "start_or_continue_validation_session", + start_or_continue_validation_session_txn, + ) + + def cull_expired_threepid_validation_tokens(self): + """Remove threepid validation tokens with expiry dates that have passed""" + + def cull_expired_threepid_validation_tokens_txn(txn, ts): + sql = """ + DELETE FROM threepid_validation_token WHERE + expires < ? + """ + return txn.execute(sql, (ts,)) + + return self.db_pool.runInteraction( + "cull_expired_threepid_validation_tokens", + cull_expired_threepid_validation_tokens_txn, + self.clock.time_msec(), + ) + + @defer.inlineCallbacks + def set_user_deactivated_status(self, user_id, deactivated): + """Set the `deactivated` property for the provided user to the provided value. + + Args: + user_id (str): The ID of the user to set the status for. + deactivated (bool): The value to set for `deactivated`. + """ + + yield self.db_pool.runInteraction( + "set_user_deactivated_status", + self.set_user_deactivated_status_txn, + user_id, + deactivated, + ) + + def set_user_deactivated_status_txn(self, txn, user_id, deactivated): + self.db_pool.simple_update_one_txn( + txn=txn, + table="users", + keyvalues={"name": user_id}, + updatevalues={"deactivated": 1 if deactivated else 0}, + ) + self._invalidate_cache_and_stream( + txn, self.get_user_deactivated_status, (user_id,) + ) + + @defer.inlineCallbacks + def _set_expiration_date_when_missing(self): + """ + Retrieves the list of registered users that don't have an expiration date, and + adds an expiration date for each of them. + """ + + def select_users_with_no_expiration_date_txn(txn): + """Retrieves the list of registered users with no expiration date from the + database, filtering out deactivated users. + """ + sql = ( + "SELECT users.name FROM users" + " LEFT JOIN account_validity ON (users.name = account_validity.user_id)" + " WHERE account_validity.user_id is NULL AND users.deactivated = 0;" + ) + txn.execute(sql, []) + + res = self.db_pool.cursor_to_dict(txn) + if res: + for user in res: + self.set_expiration_date_for_user_txn( + txn, user["name"], use_delta=True + ) + + yield self.db_pool.runInteraction( + "get_users_with_no_expiration_date", + select_users_with_no_expiration_date_txn, + ) + + def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False): + """Sets an expiration date to the account with the given user ID. + + Args: + user_id (str): User ID to set an expiration date for. + use_delta (bool): If set to False, the expiration date for the user will be + now + validity period. If set to True, this expiration date will be a + random value in the [now + period - d ; now + period] range, d being a + delta equal to 10% of the validity period. + """ + now_ms = self._clock.time_msec() + expiration_ts = now_ms + self._account_validity.period + + if use_delta: + expiration_ts = self.rand.randrange( + expiration_ts - self._account_validity.startup_job_max_delta, + expiration_ts, + ) + + self.db_pool.simple_upsert_txn( + txn, + "account_validity", + keyvalues={"user_id": user_id}, + values={"expiration_ts_ms": expiration_ts, "email_sent": False}, + ) + + +def find_max_generated_user_id_localpart(cur: Cursor) -> int: + """ + Gets the localpart of the max current generated user ID. + + Generated user IDs are integers, so we find the largest integer user ID + already taken and return that. + """ + + # We bound between '@0' and '@a' to avoid pulling the entire table + # out. + cur.execute("SELECT name FROM users WHERE '@0' <= name AND name < '@a'") + + regex = re.compile(r"^@(\d+):") + + max_found = 0 + + for (user_id,) in cur: + match = regex.search(user_id) + if match: + max_found = max(int(match.group(1)), max_found) + return max_found -- cgit 1.5.1 From a0acdfa9e93ae63a3adee264d5420fdd1d38d76e Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 11 Aug 2020 17:21:13 -0400 Subject: Converts event_federation and registration databases to async/await (#8061) --- changelog.d/8061.misc | 1 + synapse/storage/databases/main/event_federation.py | 38 ++-- synapse/storage/databases/main/registration.py | 233 ++++++++++----------- synapse/storage/databases/state/bg_updates.py | 18 +- tests/handlers/test_register.py | 11 +- tests/storage/test_monthly_active_users.py | 8 +- tests/storage/test_registration.py | 18 +- 7 files changed, 150 insertions(+), 177 deletions(-) create mode 100644 changelog.d/8061.misc (limited to 'synapse/storage/databases/main/registration.py') diff --git a/changelog.d/8061.misc b/changelog.d/8061.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8061.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index eddb32b4d3..484875f989 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -15,9 +15,7 @@ import itertools import logging from queue import Empty, PriorityQueue -from typing import Dict, List, Optional, Set, Tuple - -from twisted.internet import defer +from typing import Dict, Iterable, List, Optional, Set, Tuple from synapse.api.errors import StoreError from synapse.metrics.background_process_metrics import run_as_background_process @@ -286,17 +284,13 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas return dict(txn) - @defer.inlineCallbacks - def get_max_depth_of(self, event_ids): + async def get_max_depth_of(self, event_ids: List[str]) -> int: """Returns the max depth of a set of event IDs Args: - event_ids (list[str]) - - Returns - Deferred[int] + event_ids: The event IDs to calculate the max depth of. """ - rows = yield self.db_pool.simple_select_many_batch( + rows = await self.db_pool.simple_select_many_batch( table="events", column="event_id", iterable=event_ids, @@ -550,9 +544,8 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas return event_results - @defer.inlineCallbacks - def get_missing_events(self, room_id, earliest_events, latest_events, limit): - ids = yield self.db_pool.runInteraction( + async def get_missing_events(self, room_id, earliest_events, latest_events, limit): + ids = await self.db_pool.runInteraction( "get_missing_events", self._get_missing_events, room_id, @@ -560,7 +553,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas latest_events, limit, ) - events = yield self.get_events_as_list(ids) + events = await self.get_events_as_list(ids) return events def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit): @@ -595,17 +588,13 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas event_results.reverse() return event_results - @defer.inlineCallbacks - def get_successor_events(self, event_ids): + async def get_successor_events(self, event_ids: Iterable[str]) -> List[str]: """Fetch all events that have the given events as a prev event Args: - event_ids (iterable[str]) - - Returns: - Deferred[list[str]] + event_ids: The events to use as the previous events. """ - rows = yield self.db_pool.simple_select_many_batch( + rows = await self.db_pool.simple_select_many_batch( table="event_edges", column="prev_event_id", iterable=event_ids, @@ -674,8 +663,7 @@ class EventFederationStore(EventFederationWorkerStore): txn.execute(query, (room_id,)) txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,)) - @defer.inlineCallbacks - def _background_delete_non_state_event_auth(self, progress, batch_size): + async def _background_delete_non_state_event_auth(self, progress, batch_size): def delete_event_auth(txn): target_min_stream_id = progress.get("target_min_stream_id_inclusive") max_stream_id = progress.get("max_stream_id_exclusive") @@ -714,12 +702,12 @@ class EventFederationStore(EventFederationWorkerStore): return min_stream_id >= target_min_stream_id - result = yield self.db_pool.runInteraction( + result = await self.db_pool.runInteraction( self.EVENT_AUTH_STATE_ONLY, delete_event_auth ) if not result: - yield self.db_pool.updates._end_background_update( + await self.db_pool.updates._end_background_update( self.EVENT_AUTH_STATE_ONLY ) diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index f618629e09..402ae25571 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -17,9 +17,8 @@ import logging import re -from typing import Optional +from typing import Dict, List, Optional -from twisted.internet import defer from twisted.internet.defer import Deferred from synapse.api.constants import UserTypes @@ -30,7 +29,7 @@ from synapse.storage.database import DatabasePool from synapse.storage.types import Cursor from synapse.storage.util.sequence import build_sequence_generator from synapse.types import UserID -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks +from synapse.util.caches.descriptors import cached THIRTY_MINUTES_IN_MS = 30 * 60 * 1000 @@ -69,19 +68,15 @@ class RegistrationWorkerStore(SQLBaseStore): desc="get_user_by_id", ) - @defer.inlineCallbacks - def is_trial_user(self, user_id): + async def is_trial_user(self, user_id: str) -> bool: """Checks if user is in the "trial" period, i.e. within the first N days of registration defined by `mau_trial_days` config Args: - user_id (str) - - Returns: - Deferred[bool] + user_id: The user to check for trial status. """ - info = yield self.get_user_by_id(user_id) + info = await self.get_user_by_id(user_id) if not info: return False @@ -105,41 +100,42 @@ class RegistrationWorkerStore(SQLBaseStore): "get_user_by_access_token", self._query_for_auth, token ) - @cachedInlineCallbacks() - def get_expiration_ts_for_user(self, user_id): + @cached() + async def get_expiration_ts_for_user(self, user_id: str) -> Optional[None]: """Get the expiration timestamp for the account bearing a given user ID. Args: - user_id (str): The ID of the user. + user_id: The ID of the user. Returns: - defer.Deferred: None, if the account has no expiration timestamp, - otherwise int representation of the timestamp (as a number of - milliseconds since epoch). + None, if the account has no expiration timestamp, otherwise int + representation of the timestamp (as a number of milliseconds since epoch). """ - res = yield self.db_pool.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( table="account_validity", keyvalues={"user_id": user_id}, retcol="expiration_ts_ms", allow_none=True, desc="get_expiration_ts_for_user", ) - return res - @defer.inlineCallbacks - def set_account_validity_for_user( - self, user_id, expiration_ts, email_sent, renewal_token=None - ): + async def set_account_validity_for_user( + self, + user_id: str, + expiration_ts: int, + email_sent: bool, + renewal_token: Optional[str] = None, + ) -> None: """Updates the account validity properties of the given account, with the given values. Args: - user_id (str): ID of the account to update properties for. - expiration_ts (int): New expiration date, as a timestamp in milliseconds + user_id: ID of the account to update properties for. + expiration_ts: New expiration date, as a timestamp in milliseconds since epoch. - email_sent (bool): True means a renewal email has been sent for this - account and there's no need to send another one for the current validity + email_sent: True means a renewal email has been sent for this account + and there's no need to send another one for the current validity period. - renewal_token (str): Renewal token the user can use to extend the validity + renewal_token: Renewal token the user can use to extend the validity of their account. Defaults to no token. """ @@ -158,75 +154,69 @@ class RegistrationWorkerStore(SQLBaseStore): txn, self.get_expiration_ts_for_user, (user_id,) ) - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "set_account_validity_for_user", set_account_validity_for_user_txn ) - @defer.inlineCallbacks - def set_renewal_token_for_user(self, user_id, renewal_token): + async def set_renewal_token_for_user( + self, user_id: str, renewal_token: str + ) -> None: """Defines a renewal token for a given user. Args: - user_id (str): ID of the user to set the renewal token for. - renewal_token (str): Random unique string that will be used to renew the + user_id: ID of the user to set the renewal token for. + renewal_token: Random unique string that will be used to renew the user's account. Raises: StoreError: The provided token is already set for another user. """ - yield self.db_pool.simple_update_one( + await self.db_pool.simple_update_one( table="account_validity", keyvalues={"user_id": user_id}, updatevalues={"renewal_token": renewal_token}, desc="set_renewal_token_for_user", ) - @defer.inlineCallbacks - def get_user_from_renewal_token(self, renewal_token): + async def get_user_from_renewal_token(self, renewal_token: str) -> str: """Get a user ID from a renewal token. Args: - renewal_token (str): The renewal token to perform the lookup with. + renewal_token: The renewal token to perform the lookup with. Returns: - defer.Deferred[str]: The ID of the user to which the token belongs. + The ID of the user to which the token belongs. """ - res = yield self.db_pool.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( table="account_validity", keyvalues={"renewal_token": renewal_token}, retcol="user_id", desc="get_user_from_renewal_token", ) - return res - - @defer.inlineCallbacks - def get_renewal_token_for_user(self, user_id): + async def get_renewal_token_for_user(self, user_id: str) -> str: """Get the renewal token associated with a given user ID. Args: - user_id (str): The user ID to lookup a token for. + user_id: The user ID to lookup a token for. Returns: - defer.Deferred[str]: The renewal token associated with this user ID. + The renewal token associated with this user ID. """ - res = yield self.db_pool.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( table="account_validity", keyvalues={"user_id": user_id}, retcol="renewal_token", desc="get_renewal_token_for_user", ) - return res - - @defer.inlineCallbacks - def get_users_expiring_soon(self): + async def get_users_expiring_soon(self) -> List[Dict[str, int]]: """Selects users whose account will expire in the [now, now + renew_at] time window (see configuration for account_validity for information on what renew_at refers to). Returns: - Deferred: Resolves to a list[dict[user_id (str), expiration_ts_ms (int)]] + A list of dictionaries mapping user ID to expiration time (in milliseconds). """ def select_users_txn(txn, now_ms, renew_at): @@ -238,53 +228,49 @@ class RegistrationWorkerStore(SQLBaseStore): txn.execute(sql, values) return self.db_pool.cursor_to_dict(txn) - res = yield self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_users_expiring_soon", select_users_txn, self.clock.time_msec(), self.config.account_validity.renew_at, ) - return res - - @defer.inlineCallbacks - def set_renewal_mail_status(self, user_id, email_sent): + async def set_renewal_mail_status(self, user_id: str, email_sent: bool) -> None: """Sets or unsets the flag that indicates whether a renewal email has been sent to the user (and the user hasn't renewed their account yet). Args: - user_id (str): ID of the user to set/unset the flag for. - email_sent (bool): Flag which indicates whether a renewal email has been sent + user_id: ID of the user to set/unset the flag for. + email_sent: Flag which indicates whether a renewal email has been sent to this user. """ - yield self.db_pool.simple_update_one( + await self.db_pool.simple_update_one( table="account_validity", keyvalues={"user_id": user_id}, updatevalues={"email_sent": email_sent}, desc="set_renewal_mail_status", ) - @defer.inlineCallbacks - def delete_account_validity_for_user(self, user_id): + async def delete_account_validity_for_user(self, user_id: str) -> None: """Deletes the entry for the given user in the account validity table, removing their expiration date and renewal token. Args: - user_id (str): ID of the user to remove from the account validity table. + user_id: ID of the user to remove from the account validity table. """ - yield self.db_pool.simple_delete_one( + await self.db_pool.simple_delete_one( table="account_validity", keyvalues={"user_id": user_id}, desc="delete_account_validity_for_user", ) - async def is_server_admin(self, user): + async def is_server_admin(self, user: UserID) -> bool: """Determines if a user is an admin of this homeserver. Args: - user (UserID): user ID of the user to test + user: user ID of the user to test - Returns (bool): + Returns: true iff the user is a server admin, false otherwise. """ res = await self.db_pool.simple_select_one_onecol( @@ -332,32 +318,31 @@ class RegistrationWorkerStore(SQLBaseStore): return None - @cachedInlineCallbacks() - def is_real_user(self, user_id): + @cached() + async def is_real_user(self, user_id: str) -> bool: """Determines if the user is a real user, ie does not have a 'user_type'. Args: - user_id (str): user id to test + user_id: user id to test Returns: - Deferred[bool]: True if user 'user_type' is null or empty string + True if user 'user_type' is null or empty string """ - res = yield self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "is_real_user", self.is_real_user_txn, user_id ) - return res @cached() - def is_support_user(self, user_id): + async def is_support_user(self, user_id: str) -> bool: """Determines if the user is of type UserTypes.SUPPORT Args: - user_id (str): user id to test + user_id: user id to test Returns: - Deferred[bool]: True if user is of type UserTypes.SUPPORT + True if user is of type UserTypes.SUPPORT """ - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "is_support_user", self.is_support_user_txn, user_id ) @@ -413,8 +398,7 @@ class RegistrationWorkerStore(SQLBaseStore): desc="get_user_by_external_id", ) - @defer.inlineCallbacks - def count_all_users(self): + async def count_all_users(self): """Counts all users registered on the homeserver.""" def _count_users(txn): @@ -424,8 +408,7 @@ class RegistrationWorkerStore(SQLBaseStore): return rows[0]["users"] return 0 - ret = yield self.db_pool.runInteraction("count_users", _count_users) - return ret + return await self.db_pool.runInteraction("count_users", _count_users) def count_daily_user_type(self): """ @@ -460,8 +443,7 @@ class RegistrationWorkerStore(SQLBaseStore): "count_daily_user_type", _count_daily_user_type ) - @defer.inlineCallbacks - def count_nonbridged_users(self): + async def count_nonbridged_users(self): def _count_users(txn): txn.execute( """ @@ -472,11 +454,9 @@ class RegistrationWorkerStore(SQLBaseStore): (count,) = txn.fetchone() return count - ret = yield self.db_pool.runInteraction("count_users", _count_users) - return ret + return await self.db_pool.runInteraction("count_users", _count_users) - @defer.inlineCallbacks - def count_real_users(self): + async def count_real_users(self): """Counts all users without a special user_type registered on the homeserver.""" def _count_users(txn): @@ -486,8 +466,7 @@ class RegistrationWorkerStore(SQLBaseStore): return rows[0]["users"] return 0 - ret = yield self.db_pool.runInteraction("count_real_users", _count_users) - return ret + return await self.db_pool.runInteraction("count_real_users", _count_users) async def generate_user_id(self) -> str: """Generate a suitable localpart for a guest user @@ -537,23 +516,20 @@ class RegistrationWorkerStore(SQLBaseStore): return ret["user_id"] return None - @defer.inlineCallbacks - def user_add_threepid(self, user_id, medium, address, validated_at, added_at): - yield self.db_pool.simple_upsert( + async def user_add_threepid(self, user_id, medium, address, validated_at, added_at): + await self.db_pool.simple_upsert( "user_threepids", {"medium": medium, "address": address}, {"user_id": user_id, "validated_at": validated_at, "added_at": added_at}, ) - @defer.inlineCallbacks - def user_get_threepids(self, user_id): - ret = yield self.db_pool.simple_select_list( + async def user_get_threepids(self, user_id): + return await self.db_pool.simple_select_list( "user_threepids", {"user_id": user_id}, ["medium", "address", "validated_at", "added_at"], "user_get_threepids", ) - return ret def user_delete_threepid(self, user_id, medium, address): return self.db_pool.simple_delete( @@ -668,18 +644,18 @@ class RegistrationWorkerStore(SQLBaseStore): desc="get_id_servers_user_bound", ) - @cachedInlineCallbacks() - def get_user_deactivated_status(self, user_id): + @cached() + async def get_user_deactivated_status(self, user_id: str) -> bool: """Retrieve the value for the `deactivated` property for the provided user. Args: - user_id (str): The ID of the user to retrieve the status for. + user_id: The ID of the user to retrieve the status for. Returns: - defer.Deferred(bool): The requested value. + True if the user was deactivated, false if the user is still active. """ - res = yield self.db_pool.simple_select_one_onecol( + res = await self.db_pool.simple_select_one_onecol( table="users", keyvalues={"name": user_id}, retcol="deactivated", @@ -818,8 +794,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): "users_set_deactivated_flag", self._background_update_set_deactivated_flag ) - @defer.inlineCallbacks - def _background_update_set_deactivated_flag(self, progress, batch_size): + async def _background_update_set_deactivated_flag(self, progress, batch_size): """Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1 for each of them. """ @@ -870,19 +845,18 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): else: return False, len(rows) - end, nb_processed = yield self.db_pool.runInteraction( + end, nb_processed = await self.db_pool.runInteraction( "users_set_deactivated_flag", _background_update_set_deactivated_flag_txn ) if end: - yield self.db_pool.updates._end_background_update( + await self.db_pool.updates._end_background_update( "users_set_deactivated_flag" ) return nb_processed - @defer.inlineCallbacks - def _bg_user_threepids_grandfather(self, progress, batch_size): + async def _bg_user_threepids_grandfather(self, progress, batch_size): """We now track which identity servers a user binds their 3PID to, so we need to handle the case of existing bindings where we didn't track this. @@ -903,11 +877,11 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): txn.executemany(sql, [(id_server,) for id_server in id_servers]) if id_servers: - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn ) - yield self.db_pool.updates._end_background_update("user_threepids_grandfather") + await self.db_pool.updates._end_background_update("user_threepids_grandfather") return 1 @@ -937,23 +911,26 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS) - @defer.inlineCallbacks - def add_access_token_to_user(self, user_id, token, device_id, valid_until_ms): + async def add_access_token_to_user( + self, + user_id: str, + token: str, + device_id: Optional[str], + valid_until_ms: Optional[int], + ) -> None: """Adds an access token for the given user. Args: - user_id (str): The user ID. - token (str): The new access token to add. - device_id (str): ID of the device to associate with the access - token - valid_until_ms (int|None): when the token is valid until. None for - no expiry. + user_id: The user ID. + token: The new access token to add. + device_id: ID of the device to associate with the access token + valid_until_ms: when the token is valid until. None for no expiry. Raises: StoreError if there was a problem adding this. """ next_id = self._access_tokens_id_gen.get_next() - yield self.db_pool.simple_insert( + await self.db_pool.simple_insert( "access_tokens", { "id": next_id, @@ -1097,7 +1074,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): ) self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) - txn.call_after(self.is_guest.invalidate, (user_id,)) def record_user_external_id( self, auth_provider: str, external_id: str, user_id: str @@ -1241,9 +1217,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): return self.db_pool.runInteraction("delete_access_token", f) - @cachedInlineCallbacks() - def is_guest(self, user_id): - res = yield self.db_pool.simple_select_one_onecol( + @cached() + async def is_guest(self, user_id: str) -> bool: + res = await self.db_pool.simple_select_one_onecol( table="users", keyvalues={"name": user_id}, retcol="is_guest", @@ -1481,16 +1457,17 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): self.clock.time_msec(), ) - @defer.inlineCallbacks - def set_user_deactivated_status(self, user_id, deactivated): + async def set_user_deactivated_status( + self, user_id: str, deactivated: bool + ) -> None: """Set the `deactivated` property for the provided user to the provided value. Args: - user_id (str): The ID of the user to set the status for. - deactivated (bool): The value to set for `deactivated`. + user_id: The ID of the user to set the status for. + deactivated: The value to set for `deactivated`. """ - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "set_user_deactivated_status", self.set_user_deactivated_status_txn, user_id, @@ -1507,9 +1484,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): self._invalidate_cache_and_stream( txn, self.get_user_deactivated_status, (user_id,) ) + txn.call_after(self.is_guest.invalidate, (user_id,)) - @defer.inlineCallbacks - def _set_expiration_date_when_missing(self): + async def _set_expiration_date_when_missing(self): """ Retrieves the list of registered users that don't have an expiration date, and adds an expiration date for each of them. @@ -1533,7 +1510,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): txn, user["name"], use_delta=True ) - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "get_users_with_no_expiration_date", select_users_with_no_expiration_date_txn, ) diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py index 1e2d584098..139085b672 100644 --- a/synapse/storage/databases/state/bg_updates.py +++ b/synapse/storage/databases/state/bg_updates.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool from synapse.storage.engines import PostgresEngine @@ -198,8 +196,7 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): columns=["room_id"], ) - @defer.inlineCallbacks - def _background_deduplicate_state(self, progress, batch_size): + async def _background_deduplicate_state(self, progress, batch_size): """This background update will slowly deduplicate state by reencoding them as deltas. """ @@ -212,7 +209,7 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): batch_size = max(1, int(batch_size / BATCH_SIZE_SCALE_FACTOR)) if max_group is None: - rows = yield self.db_pool.execute( + rows = await self.db_pool.execute( "_background_deduplicate_state", None, "SELECT coalesce(max(id), 0) FROM state_groups", @@ -330,19 +327,18 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): return False, batch_size - finished, result = yield self.db_pool.runInteraction( + finished, result = await self.db_pool.runInteraction( self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, reindex_txn ) if finished: - yield self.db_pool.updates._end_background_update( + await self.db_pool.updates._end_background_update( self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME ) return result * BATCH_SIZE_SCALE_FACTOR - @defer.inlineCallbacks - def _background_index_state(self, progress, batch_size): + async def _background_index_state(self, progress, batch_size): def reindex_txn(conn): conn.rollback() if isinstance(self.database_engine, PostgresEngine): @@ -365,9 +361,9 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): ) txn.execute("DROP INDEX IF EXISTS state_groups_state_id") - yield self.db_pool.runWithConnection(reindex_txn) + await self.db_pool.runWithConnection(reindex_txn) - yield self.db_pool.updates._end_background_update( + await self.db_pool.updates._end_background_update( self.STATE_GROUP_INDEX_UPDATE_NAME ) diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 6d45c4b233..e364b1bd62 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -22,6 +22,7 @@ from synapse.api.errors import Codes, ResourceLimitError, SynapseError from synapse.handlers.register import RegistrationHandler from synapse.types import RoomAlias, UserID, create_requester +from tests.test_utils import make_awaitable from tests.unittest import override_config from .. import unittest @@ -187,7 +188,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): room_alias_str = "#room:test" self.hs.config.auto_join_rooms = [room_alias_str] - self.store.is_real_user = Mock(return_value=defer.succeed(False)) + self.store.is_real_user = Mock(return_value=make_awaitable(False)) user_id = self.get_success(self.handler.register_user(localpart="support")) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) self.assertEqual(len(rooms), 0) @@ -199,8 +200,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase): def test_auto_create_auto_join_rooms_when_user_is_the_first_real_user(self): room_alias_str = "#room:test" - self.store.count_real_users = Mock(return_value=defer.succeed(1)) - self.store.is_real_user = Mock(return_value=defer.succeed(True)) + self.store.count_real_users = Mock(return_value=make_awaitable(1)) + self.store.is_real_user = Mock(return_value=make_awaitable(True)) user_id = self.get_success(self.handler.register_user(localpart="real")) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) directory_handler = self.hs.get_handlers().directory_handler @@ -214,8 +215,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase): room_alias_str = "#room:test" self.hs.config.auto_join_rooms = [room_alias_str] - self.store.count_real_users = Mock(return_value=defer.succeed(2)) - self.store.is_real_user = Mock(return_value=defer.succeed(True)) + self.store.count_real_users = Mock(return_value=make_awaitable(2)) + self.store.is_real_user = Mock(return_value=make_awaitable(True)) user_id = self.get_success(self.handler.register_user(localpart="real")) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) self.assertEqual(len(rooms), 0) diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index e793781a26..9870c74883 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -300,8 +300,12 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.get_success(self.store.register_user(user_id=user2, password_hash=None)) now = int(self.hs.get_clock().time_msec()) - self.store.user_add_threepid(user1, "email", user1_email, now, now) - self.store.user_add_threepid(user2, "email", user2_email, now, now) + self.get_success( + self.store.user_add_threepid(user1, "email", user1_email, now, now) + ) + self.get_success( + self.store.user_add_threepid(user2, "email", user2_email, now, now) + ) users = self.get_success(self.store.get_registered_reserved_users()) self.assertEqual(len(users), len(threepids)) diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index 71a40a0a49..840db66072 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -58,8 +58,10 @@ class RegistrationStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_add_tokens(self): yield self.store.register_user(self.user_id, self.pwhash) - yield self.store.add_access_token_to_user( - self.user_id, self.tokens[1], self.device_id, valid_until_ms=None + yield defer.ensureDeferred( + self.store.add_access_token_to_user( + self.user_id, self.tokens[1], self.device_id, valid_until_ms=None + ) ) result = yield self.store.get_user_by_access_token(self.tokens[1]) @@ -74,11 +76,15 @@ class RegistrationStoreTestCase(unittest.TestCase): def test_user_delete_access_tokens(self): # add some tokens yield self.store.register_user(self.user_id, self.pwhash) - yield self.store.add_access_token_to_user( - self.user_id, self.tokens[0], device_id=None, valid_until_ms=None + yield defer.ensureDeferred( + self.store.add_access_token_to_user( + self.user_id, self.tokens[0], device_id=None, valid_until_ms=None + ) ) - yield self.store.add_access_token_to_user( - self.user_id, self.tokens[1], self.device_id, valid_until_ms=None + yield defer.ensureDeferred( + self.store.add_access_token_to_user( + self.user_id, self.tokens[1], self.device_id, valid_until_ms=None + ) ) # now delete some -- cgit 1.5.1 From 6b7ce1d332766bc2da9e99b22a452e0813a2aae3 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 14 Aug 2020 09:25:40 -0400 Subject: Remove some unused database functions. (#8085) --- changelog.d/8085.misc | 1 + synapse/storage/databases/main/event_federation.py | 13 -- synapse/storage/databases/main/events_worker.py | 170 +-------------------- synapse/storage/databases/main/presence.py | 21 --- synapse/storage/databases/main/registration.py | 37 ----- synapse/storage/databases/main/room.py | 4 - .../delta/58/13remove_presence_allow_inbound.sql | 17 +++ 7 files changed, 19 insertions(+), 244 deletions(-) create mode 100644 changelog.d/8085.misc create mode 100644 synapse/storage/databases/main/schema/delta/58/13remove_presence_allow_inbound.sql (limited to 'synapse/storage/databases/main/registration.py') diff --git a/changelog.d/8085.misc b/changelog.d/8085.misc new file mode 100644 index 0000000000..c3da1e297c --- /dev/null +++ b/changelog.d/8085.misc @@ -0,0 +1 @@ +Remove some unused database functions. diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 484875f989..431bd76693 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -257,11 +257,6 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas # Return all events where not all sets can reach them. return {eid for eid, n in event_to_missing_sets.items() if n} - def get_oldest_events_in_room(self, room_id): - return self.db_pool.runInteraction( - "get_oldest_events_in_room", self._get_oldest_events_in_room_txn, room_id - ) - def get_oldest_events_with_depth_in_room(self, room_id): return self.db_pool.runInteraction( "get_oldest_events_with_depth_in_room", @@ -303,14 +298,6 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas else: return max(row["depth"] for row in rows) - def _get_oldest_events_in_room_txn(self, txn, room_id): - return self.db_pool.simple_select_onecol_txn( - txn, - table="event_backward_extremities", - keyvalues={"room_id": room_id}, - retcol="event_id", - ) - def get_prev_events_for_room(self, room_id: str): """ Gets a subset of the current forward extremities in the given room. diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 755b7a2a85..5687448e3d 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -43,7 +43,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_cla from synapse.storage.database import DatabasePool from synapse.storage.util.id_generators import StreamIdGenerator from synapse.types import get_domain_from_id -from synapse.util.caches.descriptors import Cache, cached, cachedInlineCallbacks +from synapse.util.caches.descriptors import Cache, cachedInlineCallbacks from synapse.util.iterutils import batch_iter from synapse.util.metrics import Measure @@ -137,42 +137,6 @@ class EventsWorkerStore(SQLBaseStore): desc="get_received_ts", ) - def get_received_ts_by_stream_pos(self, stream_ordering): - """Given a stream ordering get an approximate timestamp of when it - happened. - - This is done by simply taking the received ts of the first event that - has a stream ordering greater than or equal to the given stream pos. - If none exists returns the current time, on the assumption that it must - have happened recently. - - Args: - stream_ordering (int) - - Returns: - Deferred[int] - """ - - def _get_approximate_received_ts_txn(txn): - sql = """ - SELECT received_ts FROM events - WHERE stream_ordering >= ? - LIMIT 1 - """ - - txn.execute(sql, (stream_ordering,)) - row = txn.fetchone() - if row and row[0]: - ts = row[0] - else: - ts = self.clock.time_msec() - - return ts - - return self.db_pool.runInteraction( - "get_approximate_received_ts", _get_approximate_received_ts_txn - ) - @defer.inlineCallbacks def get_event( self, @@ -923,36 +887,6 @@ class EventsWorkerStore(SQLBaseStore): ) return results - def _get_total_state_event_counts_txn(self, txn, room_id): - """ - See get_total_state_event_counts. - """ - # We join against the events table as that has an index on room_id - sql = """ - SELECT COUNT(*) FROM state_events - INNER JOIN events USING (room_id, event_id) - WHERE room_id=? - """ - txn.execute(sql, (room_id,)) - row = txn.fetchone() - return row[0] if row else 0 - - def get_total_state_event_counts(self, room_id): - """ - Gets the total number of state events in a room. - - Args: - room_id (str) - - Returns: - Deferred[int] - """ - return self.db_pool.runInteraction( - "get_total_state_event_counts", - self._get_total_state_event_counts_txn, - room_id, - ) - def _get_current_state_event_counts_txn(self, txn, room_id): """ See get_current_state_event_counts. @@ -1222,97 +1156,6 @@ class EventsWorkerStore(SQLBaseStore): return rows, to_token, True - @cached(num_args=5, max_entries=10) - def get_all_new_events( - self, - last_backfill_id, - last_forward_id, - current_backfill_id, - current_forward_id, - limit, - ): - """Get all the new events that have arrived at the server either as - new events or as backfilled events""" - have_backfill_events = last_backfill_id != current_backfill_id - have_forward_events = last_forward_id != current_forward_id - - if not have_backfill_events and not have_forward_events: - return defer.succeed(AllNewEventsResult([], [], [], [], [])) - - def get_all_new_events_txn(txn): - sql = ( - "SELECT e.stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts" - " FROM events AS e" - " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events USING (event_id)" - " WHERE ? < stream_ordering AND stream_ordering <= ?" - " ORDER BY stream_ordering ASC" - " LIMIT ?" - ) - if have_forward_events: - txn.execute(sql, (last_forward_id, current_forward_id, limit)) - new_forward_events = txn.fetchall() - - if len(new_forward_events) == limit: - upper_bound = new_forward_events[-1][0] - else: - upper_bound = current_forward_id - - sql = ( - "SELECT event_stream_ordering, event_id, state_group" - " FROM ex_outlier_stream" - " WHERE ? > event_stream_ordering" - " AND event_stream_ordering >= ?" - " ORDER BY event_stream_ordering DESC" - ) - txn.execute(sql, (last_forward_id, upper_bound)) - forward_ex_outliers = txn.fetchall() - else: - new_forward_events = [] - forward_ex_outliers = [] - - sql = ( - "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts" - " FROM events AS e" - " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events USING (event_id)" - " WHERE ? > stream_ordering AND stream_ordering >= ?" - " ORDER BY stream_ordering DESC" - " LIMIT ?" - ) - if have_backfill_events: - txn.execute(sql, (-last_backfill_id, -current_backfill_id, limit)) - new_backfill_events = txn.fetchall() - - if len(new_backfill_events) == limit: - upper_bound = new_backfill_events[-1][0] - else: - upper_bound = current_backfill_id - - sql = ( - "SELECT -event_stream_ordering, event_id, state_group" - " FROM ex_outlier_stream" - " WHERE ? > event_stream_ordering" - " AND event_stream_ordering >= ?" - " ORDER BY event_stream_ordering DESC" - ) - txn.execute(sql, (-last_backfill_id, -upper_bound)) - backward_ex_outliers = txn.fetchall() - else: - new_backfill_events = [] - backward_ex_outliers = [] - - return AllNewEventsResult( - new_forward_events, - new_backfill_events, - forward_ex_outliers, - backward_ex_outliers, - ) - - return self.db_pool.runInteraction("get_all_new_events", get_all_new_events_txn) - async def is_event_after(self, event_id1, event_id2): """Returns True if event_id1 is after event_id2 in the stream """ @@ -1357,14 +1200,3 @@ class EventsWorkerStore(SQLBaseStore): return self.db_pool.runInteraction( desc="get_next_event_to_expire", func=get_next_event_to_expire_txn ) - - -AllNewEventsResult = namedtuple( - "AllNewEventsResult", - [ - "new_forward_events", - "new_backfill_events", - "forward_ex_outliers", - "backward_ex_outliers", - ], -) diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py index fd213d2dfd..9f691e5792 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py @@ -157,24 +157,3 @@ class PresenceStore(SQLBaseStore): def get_current_presence_token(self): return self._presence_id_gen.get_current_token() - - def allow_presence_visible(self, observed_localpart, observer_userid): - return self.db_pool.simple_insert( - table="presence_allow_inbound", - values={ - "observed_user_id": observed_localpart, - "observer_user_id": observer_userid, - }, - desc="allow_presence_visible", - or_ignore=True, - ) - - def disallow_presence_visible(self, observed_localpart, observer_userid): - return self.db_pool.simple_delete_one( - table="presence_allow_inbound", - keyvalues={ - "observed_user_id": observed_localpart, - "observer_user_id": observer_userid, - }, - desc="disallow_presence_visible", - ) diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 402ae25571..7965a52e30 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -1345,43 +1345,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): "validate_threepid_session_txn", validate_threepid_session_txn ) - def upsert_threepid_validation_session( - self, - medium, - address, - client_secret, - send_attempt, - session_id, - validated_at=None, - ): - """Upsert a threepid validation session - Args: - medium (str): The medium of the 3PID - address (str): The address of the 3PID - client_secret (str): A unique string provided by the client to - help identify this validation attempt - send_attempt (int): The latest send_attempt on this session - session_id (str): The id of this validation session - validated_at (int|None): The unix timestamp in milliseconds of - when the session was marked as valid - """ - insertion_values = { - "medium": medium, - "address": address, - "client_secret": client_secret, - } - - if validated_at: - insertion_values["validated_at"] = validated_at - - return self.db_pool.simple_upsert( - table="threepid_validation_session", - keyvalues={"session_id": session_id}, - values={"last_send_attempt": send_attempt}, - insertion_values=insertion_values, - desc="upsert_threepid_validation_session", - ) - def start_or_continue_validation_session( self, medium, diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index f4008e6221..aef08c7e12 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -35,10 +35,6 @@ from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) -OpsLevel = collections.namedtuple( - "OpsLevel", ("ban_level", "kick_level", "redact_level") -) - RatelimitOverride = collections.namedtuple( "RatelimitOverride", ("messages_per_second", "burst_count") ) diff --git a/synapse/storage/databases/main/schema/delta/58/13remove_presence_allow_inbound.sql b/synapse/storage/databases/main/schema/delta/58/13remove_presence_allow_inbound.sql new file mode 100644 index 0000000000..15421b99ac --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/13remove_presence_allow_inbound.sql @@ -0,0 +1,17 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- This table is no longer used. +DROP TABLE IF EXISTS presence_allow_inbound; -- cgit 1.5.1 From ac77cdb64e50c9fdfc00cccbc7b96f42057aa741 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 14 Aug 2020 12:37:59 -0400 Subject: Add a shadow-banned flag to users. (#8092) --- changelog.d/8092.feature | 1 + synapse/api/auth.py | 12 ++++++++++- synapse/handlers/register.py | 8 +++++++ synapse/replication/http/register.py | 4 ++++ synapse/storage/databases/main/registration.py | 9 +++++++- .../main/schema/delta/58/09shadow_ban.sql | 18 ++++++++++++++++ synapse/types.py | 25 +++++++++++++++++++--- tests/storage/test_cleanup_extrems.py | 4 ++-- tests/storage/test_event_metrics.py | 2 +- tests/storage/test_roommember.py | 2 +- tests/test_federation.py | 2 +- tests/unittest.py | 8 +++++-- 12 files changed, 83 insertions(+), 12 deletions(-) create mode 100644 changelog.d/8092.feature create mode 100644 synapse/storage/databases/main/schema/delta/58/09shadow_ban.sql (limited to 'synapse/storage/databases/main/registration.py') diff --git a/changelog.d/8092.feature b/changelog.d/8092.feature new file mode 100644 index 0000000000..813e6d0903 --- /dev/null +++ b/changelog.d/8092.feature @@ -0,0 +1 @@ +Add support for shadow-banning users (ignoring any message send requests). diff --git a/synapse/api/auth.py b/synapse/api/auth.py index d8190f92ab..7aab764360 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -213,6 +213,7 @@ class Auth(object): user = user_info["user"] token_id = user_info["token_id"] is_guest = user_info["is_guest"] + shadow_banned = user_info["shadow_banned"] # Deny the request if the user account has expired. if self._account_validity.enabled and not allow_expired: @@ -252,7 +253,12 @@ class Auth(object): opentracing.set_tag("device_id", device_id) return synapse.types.create_requester( - user, token_id, is_guest, device_id, app_service=app_service + user, + token_id, + is_guest, + shadow_banned, + device_id, + app_service=app_service, ) except KeyError: raise MissingClientTokenError() @@ -297,6 +303,7 @@ class Auth(object): dict that includes: `user` (UserID) `is_guest` (bool) + `shadow_banned` (bool) `token_id` (int|None): access token id. May be None if guest `device_id` (str|None): device corresponding to access token Raises: @@ -356,6 +363,7 @@ class Auth(object): ret = { "user": user, "is_guest": True, + "shadow_banned": False, "token_id": None, # all guests get the same device id "device_id": GUEST_DEVICE_ID, @@ -365,6 +373,7 @@ class Auth(object): ret = { "user": user, "is_guest": False, + "shadow_banned": False, "token_id": None, "device_id": None, } @@ -488,6 +497,7 @@ class Auth(object): "user": UserID.from_string(ret.get("name")), "token_id": ret.get("token_id", None), "is_guest": False, + "shadow_banned": ret.get("shadow_banned"), "device_id": ret.get("device_id"), "valid_until_ms": ret.get("valid_until_ms"), } diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index c94209ab3d..999bc6efb5 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -142,6 +142,7 @@ class RegistrationHandler(BaseHandler): address=None, bind_emails=[], by_admin=False, + shadow_banned=False, ): """Registers a new client on the server. @@ -159,6 +160,7 @@ class RegistrationHandler(BaseHandler): bind_emails (List[str]): list of emails to bind to this account. by_admin (bool): True if this registration is being made via the admin api, otherwise False. + shadow_banned (bool): Shadow-ban the created user. Returns: str: user_id Raises: @@ -194,6 +196,7 @@ class RegistrationHandler(BaseHandler): admin=admin, user_type=user_type, address=address, + shadow_banned=shadow_banned, ) if self.hs.config.user_directory_search_all_users: @@ -224,6 +227,7 @@ class RegistrationHandler(BaseHandler): make_guest=make_guest, create_profile_with_displayname=default_display_name, address=address, + shadow_banned=shadow_banned, ) # Successfully registered @@ -529,6 +533,7 @@ class RegistrationHandler(BaseHandler): admin=False, user_type=None, address=None, + shadow_banned=False, ): """Register user in the datastore. @@ -546,6 +551,7 @@ class RegistrationHandler(BaseHandler): user_type (str|None): type of user. One of the values from api.constants.UserTypes, or None for a normal user. address (str|None): the IP address used to perform the registration. + shadow_banned (bool): Whether to shadow-ban the user Returns: Awaitable @@ -561,6 +567,7 @@ class RegistrationHandler(BaseHandler): admin=admin, user_type=user_type, address=address, + shadow_banned=shadow_banned, ) else: return self.store.register_user( @@ -572,6 +579,7 @@ class RegistrationHandler(BaseHandler): create_profile_with_displayname=create_profile_with_displayname, admin=admin, user_type=user_type, + shadow_banned=shadow_banned, ) async def register_device( diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py index ce9420aa69..a02b27474d 100644 --- a/synapse/replication/http/register.py +++ b/synapse/replication/http/register.py @@ -44,6 +44,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): admin, user_type, address, + shadow_banned, ): """ Args: @@ -60,6 +61,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): user_type (str|None): type of user. One of the values from api.constants.UserTypes, or None for a normal user. address (str|None): the IP address used to perform the regitration. + shadow_banned (bool): Whether to shadow-ban the user """ return { "password_hash": password_hash, @@ -70,6 +72,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): "admin": admin, "user_type": user_type, "address": address, + "shadow_banned": shadow_banned, } async def _handle_request(self, request, user_id): @@ -87,6 +90,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): admin=content["admin"], user_type=content["user_type"], address=content["address"], + shadow_banned=content["shadow_banned"], ) return 200, {} diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 7965a52e30..de50fa6e94 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -304,7 +304,7 @@ class RegistrationWorkerStore(SQLBaseStore): def _query_for_auth(self, txn, token): sql = ( - "SELECT users.name, users.is_guest, access_tokens.id as token_id," + "SELECT users.name, users.is_guest, users.shadow_banned, access_tokens.id as token_id," " access_tokens.device_id, access_tokens.valid_until_ms" " FROM users" " INNER JOIN access_tokens on users.name = access_tokens.user_id" @@ -952,6 +952,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): create_profile_with_displayname=None, admin=False, user_type=None, + shadow_banned=False, ): """Attempts to register an account. @@ -968,6 +969,8 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): admin (boolean): is an admin user? user_type (str|None): type of user. One of the values from api.constants.UserTypes, or None for a normal user. + shadow_banned (bool): Whether the user is shadow-banned, + i.e. they may be told their requests succeeded but we ignore them. Raises: StoreError if the user_id could not be registered. @@ -986,6 +989,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): create_profile_with_displayname, admin, user_type, + shadow_banned, ) def _register_user( @@ -999,6 +1003,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): create_profile_with_displayname, admin, user_type, + shadow_banned, ): user_id_obj = UserID.from_string(user_id) @@ -1028,6 +1033,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): "appservice_id": appservice_id, "admin": 1 if admin else 0, "user_type": user_type, + "shadow_banned": shadow_banned, }, ) else: @@ -1042,6 +1048,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): "appservice_id": appservice_id, "admin": 1 if admin else 0, "user_type": user_type, + "shadow_banned": shadow_banned, }, ) diff --git a/synapse/storage/databases/main/schema/delta/58/09shadow_ban.sql b/synapse/storage/databases/main/schema/delta/58/09shadow_ban.sql new file mode 100644 index 0000000000..260b009b48 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/09shadow_ban.sql @@ -0,0 +1,18 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- A shadow-banned user may be told that their requests succeeded when they were +-- actually ignored. +ALTER TABLE users ADD COLUMN shadow_banned BOOLEAN; diff --git a/synapse/types.py b/synapse/types.py index 9e580f4295..bc36cdde30 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -51,7 +51,15 @@ JsonDict = Dict[str, Any] class Requester( namedtuple( - "Requester", ["user", "access_token_id", "is_guest", "device_id", "app_service"] + "Requester", + [ + "user", + "access_token_id", + "is_guest", + "shadow_banned", + "device_id", + "app_service", + ], ) ): """ @@ -62,6 +70,7 @@ class Requester( access_token_id (int|None): *ID* of the access token used for this request, or None if it came via the appservice API or similar is_guest (bool): True if the user making this request is a guest user + shadow_banned (bool): True if the user making this request has been shadow-banned. device_id (str|None): device_id which was set at authentication time app_service (ApplicationService|None): the AS requesting on behalf of the user """ @@ -77,6 +86,7 @@ class Requester( "user_id": self.user.to_string(), "access_token_id": self.access_token_id, "is_guest": self.is_guest, + "shadow_banned": self.shadow_banned, "device_id": self.device_id, "app_server_id": self.app_service.id if self.app_service else None, } @@ -101,13 +111,19 @@ class Requester( user=UserID.from_string(input["user_id"]), access_token_id=input["access_token_id"], is_guest=input["is_guest"], + shadow_banned=input["shadow_banned"], device_id=input["device_id"], app_service=appservice, ) def create_requester( - user_id, access_token_id=None, is_guest=False, device_id=None, app_service=None + user_id, + access_token_id=None, + is_guest=False, + shadow_banned=False, + device_id=None, + app_service=None, ): """ Create a new ``Requester`` object @@ -117,6 +133,7 @@ def create_requester( access_token_id (int|None): *ID* of the access token used for this request, or None if it came via the appservice API or similar is_guest (bool): True if the user making this request is a guest user + shadow_banned (bool): True if the user making this request is shadow-banned. device_id (str|None): device_id which was set at authentication time app_service (ApplicationService|None): the AS requesting on behalf of the user @@ -125,7 +142,9 @@ def create_requester( """ if not isinstance(user_id, UserID): user_id = UserID.from_string(user_id) - return Requester(user_id, access_token_id, is_guest, device_id, app_service) + return Requester( + user_id, access_token_id, is_guest, shadow_banned, device_id, app_service + ) def get_domain_from_id(string): diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py index 3fab5a5248..8e9a650f9f 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py @@ -38,7 +38,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): # Create a test user and room self.user = UserID("alice", "test") - self.requester = Requester(self.user, None, False, None, None) + self.requester = Requester(self.user, None, False, False, None, None) info, _ = self.get_success(self.room_creator.create_room(self.requester, {})) self.room_id = info["room_id"] @@ -260,7 +260,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase): # Create a test user and room self.user = UserID.from_string(self.register_user("user1", "password")) self.token1 = self.login("user1", "password") - self.requester = Requester(self.user, None, False, None, None) + self.requester = Requester(self.user, None, False, False, None, None) info, _ = self.get_success(self.room_creator.create_room(self.requester, {})) self.room_id = info["room_id"] self.event_creator = homeserver.get_event_creation_handler() diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py index a7b85004e5..949846fe33 100644 --- a/tests/storage/test_event_metrics.py +++ b/tests/storage/test_event_metrics.py @@ -27,7 +27,7 @@ class ExtremStatisticsTestCase(HomeserverTestCase): room_creator = self.hs.get_room_creation_handler() user = UserID("alice", "test") - requester = Requester(user, None, False, None, None) + requester = Requester(user, None, False, False, None, None) # Real events, forward extremities events = [(3, 2), (6, 2), (4, 6)] diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 17c9da4838..d98fe8754d 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -187,7 +187,7 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase): # Now let's create a room, which will insert a membership user = UserID("alice", "test") - requester = Requester(user, None, False, None, None) + requester = Requester(user, None, False, False, None, None) self.get_success(self.room_creator.create_room(requester, {})) # Register the background update to run again. diff --git a/tests/test_federation.py b/tests/test_federation.py index f2fa42bfb9..4a4548433f 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -42,7 +42,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): ) user_id = UserID("us", "test") - our_user = Requester(user_id, None, False, None, None) + our_user = Requester(user_id, None, False, False, None, None) room_creator = self.homeserver.get_room_creation_handler() room_deferred = ensureDeferred( room_creator.create_room( diff --git a/tests/unittest.py b/tests/unittest.py index d0bba3ddef..7b80999a74 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -250,7 +250,11 @@ class HomeserverTestCase(TestCase): async def get_user_by_req(request, allow_guest=False, rights="access"): return create_requester( - UserID.from_string(self.helper.auth_user_id), 1, False, None + UserID.from_string(self.helper.auth_user_id), + 1, + False, + False, + None, ) self.hs.get_auth().get_user_by_req = get_user_by_req @@ -540,7 +544,7 @@ class HomeserverTestCase(TestCase): """ event_creator = self.hs.get_event_creation_handler() secrets = self.hs.get_secrets() - requester = Requester(user, None, False, None, None) + requester = Requester(user, None, False, False, None, None) event, context = self.get_success( event_creator.create_event( -- cgit 1.5.1 From 050e20e7ca56c3a5985fdcf64012800c153260f2 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 17 Aug 2020 12:18:01 -0400 Subject: Convert some of the general database methods to async (#8100) --- changelog.d/8100.misc | 1 + synapse/storage/database.py | 23 ++++++++----------- synapse/storage/databases/main/appservice.py | 2 +- synapse/storage/databases/main/events_worker.py | 16 +++++++------ synapse/storage/databases/main/registration.py | 8 +++---- synapse/storage/databases/main/roommember.py | 4 ++-- tests/handlers/test_profile.py | 4 ++-- tests/handlers/test_typing.py | 2 +- tests/storage/test_appservice.py | 16 +++++++++---- tests/storage/test_base.py | 16 ++++++++----- tests/storage/test_event_push_actions.py | 30 +++++++++++++------------ tests/storage/test_main.py | 2 +- tests/storage/test_profile.py | 4 ++-- 13 files changed, 69 insertions(+), 59 deletions(-) create mode 100644 changelog.d/8100.misc (limited to 'synapse/storage/databases/main/registration.py') diff --git a/changelog.d/8100.misc b/changelog.d/8100.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8100.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 4ada6f5563..8a9e06efcf 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -332,8 +332,7 @@ class DatabasePool(object): """ return self._db_pool.running - @defer.inlineCallbacks - def _check_safe_to_upsert(self): + async def _check_safe_to_upsert(self): """ Is it safe to use native UPSERT? @@ -342,7 +341,7 @@ class DatabasePool(object): If the background updates have not completed, wait 15 sec and check again. """ - updates = yield self.simple_select_list( + updates = await self.simple_select_list( "background_updates", keyvalues=None, retcols=["update_name"], @@ -614,8 +613,7 @@ class DatabasePool(object): # "Simple" SQL API methods that operate on a single table with no JOINs, # no complex WHERE clauses, just a dict of values for columns. - @defer.inlineCallbacks - def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"): + async def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"): """Executes an INSERT query on the named table. Args: @@ -631,7 +629,7 @@ class DatabasePool(object): `or_ignore` is True """ try: - yield self.runInteraction(desc, self.simple_insert_txn, table, values) + await self.runInteraction(desc, self.simple_insert_txn, table, values) except self.engine.module.IntegrityError: # We have to do or_ignore flag at this layer, since we can't reuse # a cursor after we receive an error from the db. @@ -684,8 +682,7 @@ class DatabasePool(object): txn.executemany(sql, vals) - @defer.inlineCallbacks - def simple_upsert( + async def simple_upsert( self, table, keyvalues, @@ -714,14 +711,14 @@ class DatabasePool(object): inserting lock (bool): True to lock the table when doing the upsert. Returns: - Deferred(None or bool): Native upserts always return None. Emulated + None or bool: Native upserts always return None. Emulated upserts return True if a new entry was created, False if an existing one was updated. """ attempts = 0 while True: try: - result = yield self.runInteraction( + return await self.runInteraction( desc, self.simple_upsert_txn, table, @@ -730,7 +727,6 @@ class DatabasePool(object): insertion_values, lock=lock, ) - return result except self.engine.module.IntegrityError as e: attempts += 1 if attempts >= 5: @@ -1121,8 +1117,7 @@ class DatabasePool(object): return cls.cursor_to_dict(txn) - @defer.inlineCallbacks - def simple_select_many_batch( + async def simple_select_many_batch( self, table, column, @@ -1156,7 +1151,7 @@ class DatabasePool(object): it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size) ] for chunk in chunks: - rows = yield self.runInteraction( + rows = await self.runInteraction( desc, self.simple_select_many_txn, table, diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 5cf1a88399..02568a2391 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -169,7 +169,7 @@ class ApplicationServiceTransactionWorkerStore( service(ApplicationService): The service whose state to set. state(ApplicationServiceState): The connectivity state to apply. Returns: - A Deferred which resolves when the state was set successfully. + An Awaitable which resolves when the state was set successfully. """ return self.db_pool.simple_upsert( "application_services_state", {"as_id": service.id}, {"state": state} diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 5687448e3d..8c63a0dc4d 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -847,13 +847,15 @@ class EventsWorkerStore(SQLBaseStore): """Given a list of event ids, check if we have already processed and stored them as non outliers. """ - rows = yield self.db_pool.simple_select_many_batch( - table="events", - retcols=("event_id",), - column="event_id", - iterable=list(event_ids), - keyvalues={"outlier": False}, - desc="have_events_in_timeline", + rows = yield defer.ensureDeferred( + self.db_pool.simple_select_many_batch( + table="events", + retcols=("event_id",), + column="event_id", + iterable=list(event_ids), + keyvalues={"outlier": False}, + desc="have_events_in_timeline", + ) ) return {r["event_id"] for r in rows} diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index de50fa6e94..068ad22b30 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -17,9 +17,7 @@ import logging import re -from typing import Dict, List, Optional - -from twisted.internet.defer import Deferred +from typing import Awaitable, Dict, List, Optional from synapse.api.constants import UserTypes from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError @@ -563,7 +561,7 @@ class RegistrationWorkerStore(SQLBaseStore): id_server (str) Returns: - Deferred + Awaitable """ # We need to use an upsert, in case they user had already bound the # threepid @@ -1084,7 +1082,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): def record_user_external_id( self, auth_provider: str, external_id: str, user_id: str - ) -> Deferred: + ) -> Awaitable: """Record a mapping from an external user id to a mxid Args: diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 1cc8c08ed0..161edbeccb 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -767,13 +767,13 @@ class RoomMemberWorkerStore(EventsWorkerStore): return set(room_ids) - def get_membership_from_event_ids( + async def get_membership_from_event_ids( self, member_event_ids: Iterable[str] ) -> List[dict]: """Get user_id and membership of a set of event IDs. """ - return self.db_pool.simple_select_many_batch( + return await self.db_pool.simple_select_many_batch( table="room_memberships", column="event_id", iterable=member_event_ids, diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index d70e1fc608..b609b30d4a 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -64,7 +64,7 @@ class ProfileTestCase(unittest.TestCase): self.bob = UserID.from_string("@4567:test") self.alice = UserID.from_string("@alice:remote") - yield self.store.create_profile(self.frank.localpart) + yield defer.ensureDeferred(self.store.create_profile(self.frank.localpart)) self.handler = hs.get_profile_handler() self.hs = hs @@ -157,7 +157,7 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_incoming_fed_query(self): - yield self.store.create_profile("caroline") + yield defer.ensureDeferred(self.store.create_profile("caroline")) yield self.store.set_profile_displayname("caroline", "Caroline") response = yield defer.ensureDeferred( diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 64afd581bc..e01de158e5 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -156,7 +156,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ([], 0) ) self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: None - self.datastore.set_received_txn_response = lambda *args, **kwargs: defer.succeed( + self.datastore.set_received_txn_response = lambda *args, **kwargs: make_awaitable( None ) diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 98b74890d5..a425e66f37 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -207,7 +207,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_set_appservices_state_down(self): service = Mock(id=self.as_list[1]["id"]) - yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN) + yield defer.ensureDeferred( + self.store.set_appservice_state(service, ApplicationServiceState.DOWN) + ) rows = yield self.db_pool.runQuery( self.engine.convert_param_style( "SELECT as_id FROM application_services_state WHERE state=?" @@ -219,9 +221,15 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_set_appservices_state_multiple_up(self): service = Mock(id=self.as_list[1]["id"]) - yield self.store.set_appservice_state(service, ApplicationServiceState.UP) - yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN) - yield self.store.set_appservice_state(service, ApplicationServiceState.UP) + yield defer.ensureDeferred( + self.store.set_appservice_state(service, ApplicationServiceState.UP) + ) + yield defer.ensureDeferred( + self.store.set_appservice_state(service, ApplicationServiceState.DOWN) + ) + yield defer.ensureDeferred( + self.store.set_appservice_state(service, ApplicationServiceState.UP) + ) rows = yield self.db_pool.runQuery( self.engine.convert_param_style( "SELECT as_id FROM application_services_state WHERE state=?" diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index efcaeef1e7..13bcac743a 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -66,8 +66,10 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_insert_1col(self): self.mock_txn.rowcount = 1 - yield self.datastore.db_pool.simple_insert( - table="tablename", values={"columname": "Value"} + yield defer.ensureDeferred( + self.datastore.db_pool.simple_insert( + table="tablename", values={"columname": "Value"} + ) ) self.mock_txn.execute.assert_called_with( @@ -78,10 +80,12 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_insert_3cols(self): self.mock_txn.rowcount = 1 - yield self.datastore.db_pool.simple_insert( - table="tablename", - # Use OrderedDict() so we can assert on the SQL generated - values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]), + yield defer.ensureDeferred( + self.datastore.db_pool.simple_insert( + table="tablename", + # Use OrderedDict() so we can assert on the SQL generated + values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]), + ) ) self.mock_txn.execute.assert_called_with( diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index 857db071d4..238bad5b45 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -142,20 +142,22 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def test_find_first_stream_ordering_after_ts(self): def add_event(so, ts): - return self.store.db_pool.simple_insert( - "events", - { - "stream_ordering": so, - "received_ts": ts, - "event_id": "event%i" % so, - "type": "", - "room_id": "", - "content": "", - "processed": True, - "outlier": False, - "topological_ordering": 0, - "depth": 0, - }, + return defer.ensureDeferred( + self.store.db_pool.simple_insert( + "events", + { + "stream_ordering": so, + "received_ts": ts, + "event_id": "event%i" % so, + "type": "", + "room_id": "", + "content": "", + "processed": True, + "outlier": False, + "topological_ordering": 0, + "depth": 0, + }, + ) ) # start with the base case where there are no events in the table diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py index ab0df5ea93..fbf8af940a 100644 --- a/tests/storage/test_main.py +++ b/tests/storage/test_main.py @@ -35,7 +35,7 @@ class DataStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_users_paginate(self): yield self.store.register_user(self.user.to_string(), "pass") - yield self.store.create_profile(self.user.localpart) + yield defer.ensureDeferred(self.store.create_profile(self.user.localpart)) yield self.store.set_profile_displayname(self.user.localpart, self.displayname) users, total = yield self.store.get_users_paginate( diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py index 9b6f7211ae..9d5b8aa47d 100644 --- a/tests/storage/test_profile.py +++ b/tests/storage/test_profile.py @@ -33,7 +33,7 @@ class ProfileStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_displayname(self): - yield self.store.create_profile(self.u_frank.localpart) + yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart)) yield self.store.set_profile_displayname(self.u_frank.localpart, "Frank") @@ -43,7 +43,7 @@ class ProfileStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_avatar_url(self): - yield self.store.create_profile(self.u_frank.localpart) + yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart)) yield self.store.set_profile_avatar_url( self.u_frank.localpart, "http://my.site/here" -- cgit 1.5.1 From 3f49f74610197d32fe73678cabc10f08732e66b8 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Mon, 24 Aug 2020 11:33:55 +0100 Subject: Don't fail /submit_token requests on incorrect session ID if request_token_inhibit_3pid_errors is turned on (#7991) * Don't raise session_id errors on submit_token if request_token_inhibit_3pid_errors is set * Changelog * Also wait some time before responding to /requestToken * Incorporate review * Update synapse/storage/databases/main/registration.py Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> * Incorporate review Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> --- changelog.d/7991.misc | 1 + synapse/rest/client/v2_alpha/account.py | 10 +++++++++ synapse/rest/client/v2_alpha/register.py | 7 ++++++ synapse/storage/databases/main/registration.py | 25 ++++++++++++++++----- tests/storage/test_registration.py | 31 ++++++++++++++++++++++++++ 5 files changed, 68 insertions(+), 6 deletions(-) create mode 100644 changelog.d/7991.misc (limited to 'synapse/storage/databases/main/registration.py') diff --git a/changelog.d/7991.misc b/changelog.d/7991.misc new file mode 100644 index 0000000000..1562e3af9e --- /dev/null +++ b/changelog.d/7991.misc @@ -0,0 +1 @@ +Don't fail `/submit_token` requests on incorrect session ID if `request_token_inhibit_3pid_errors` is turned on. diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 203e76b9f2..3481477731 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +import random from http import HTTPStatus from synapse.api.constants import LoginType @@ -109,6 +110,9 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): if self.config.request_token_inhibit_3pid_errors: # Make the client think the operation succeeded. See the rationale in the # comments for request_token_inhibit_3pid_errors. + # Also wait for some random amount of time between 100ms and 1s to make it + # look like we did something. + await self.hs.clock.sleep(random.randint(1, 10) / 10) return 200, {"sid": random_string(16)} raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND) @@ -448,6 +452,9 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): if self.config.request_token_inhibit_3pid_errors: # Make the client think the operation succeeded. See the rationale in the # comments for request_token_inhibit_3pid_errors. + # Also wait for some random amount of time between 100ms and 1s to make it + # look like we did something. + await self.hs.clock.sleep(random.randint(1, 10) / 10) return 200, {"sid": random_string(16)} raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) @@ -516,6 +523,9 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): if self.hs.config.request_token_inhibit_3pid_errors: # Make the client think the operation succeeded. See the rationale in the # comments for request_token_inhibit_3pid_errors. + # Also wait for some random amount of time between 100ms and 1s to make it + # look like we did something. + await self.hs.clock.sleep(random.randint(1, 10) / 10) return 200, {"sid": random_string(16)} raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index be0e680ac5..51372cdb5e 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -16,6 +16,7 @@ import hmac import logging +import random from typing import List, Union import synapse @@ -131,6 +132,9 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): if self.hs.config.request_token_inhibit_3pid_errors: # Make the client think the operation succeeded. See the rationale in the # comments for request_token_inhibit_3pid_errors. + # Also wait for some random amount of time between 100ms and 1s to make it + # look like we did something. + await self.hs.clock.sleep(random.randint(1, 10) / 10) return 200, {"sid": random_string(16)} raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) @@ -203,6 +207,9 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): if self.hs.config.request_token_inhibit_3pid_errors: # Make the client think the operation succeeded. See the rationale in the # comments for request_token_inhibit_3pid_errors. + # Also wait for some random amount of time between 100ms and 1s to make it + # look like we did something. + await self.hs.clock.sleep(random.randint(1, 10) / 10) return 200, {"sid": random_string(16)} raise SynapseError( diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 068ad22b30..321a51cc6a 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -889,6 +889,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): super(RegistrationStore, self).__init__(database, db_conn, hs) self._account_validity = hs.config.account_validity + self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors if self._account_validity.enabled: self._clock.call_later( @@ -1302,15 +1303,22 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): ) if not row: - raise ThreepidValidationError(400, "Unknown session_id") + if self._ignore_unknown_session_error: + # If we need to inhibit the error caused by an incorrect session ID, + # use None as placeholder values for the client secret and the + # validation timestamp. + # It shouldn't be an issue because they're both only checked after + # the token check, which should fail. And if it doesn't for some + # reason, the next check is on the client secret, which is NOT NULL, + # so we don't have to worry about the client secret matching by + # accident. + row = {"client_secret": None, "validated_at": None} + else: + raise ThreepidValidationError(400, "Unknown session_id") + retrieved_client_secret = row["client_secret"] validated_at = row["validated_at"] - if retrieved_client_secret != client_secret: - raise ThreepidValidationError( - 400, "This client_secret does not match the provided session_id" - ) - row = self.db_pool.simple_select_one_txn( txn, table="threepid_validation_token", @@ -1326,6 +1334,11 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): expires = row["expires"] next_link = row["next_link"] + if retrieved_client_secret != client_secret: + raise ThreepidValidationError( + 400, "This client_secret does not match the provided session_id" + ) + # If the session is already validated, no need to revalidate if validated_at: return next_link diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index 840db66072..58f827d8d3 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -17,6 +17,7 @@ from twisted.internet import defer from synapse.api.constants import UserTypes +from synapse.api.errors import ThreepidValidationError from tests import unittest from tests.utils import setup_test_homeserver @@ -122,3 +123,33 @@ class RegistrationStoreTestCase(unittest.TestCase): ) res = yield self.store.is_support_user(SUPPORT_USER) self.assertTrue(res) + + @defer.inlineCallbacks + def test_3pid_inhibit_invalid_validation_session_error(self): + """Tests that enabling the configuration option to inhibit 3PID errors on + /requestToken also inhibits validation errors caused by an unknown session ID. + """ + + # Check that, with the config setting set to false (the default value), a + # validation error is caused by the unknown session ID. + try: + yield defer.ensureDeferred( + self.store.validate_threepid_session( + "fake_sid", "fake_client_secret", "fake_token", 0, + ) + ) + except ThreepidValidationError as e: + self.assertEquals(e.msg, "Unknown session_id", e) + + # Set the config setting to true. + self.store._ignore_unknown_session_error = True + + # Check that now the validation error is caused by the token not matching. + try: + yield defer.ensureDeferred( + self.store.validate_threepid_session( + "fake_sid", "fake_client_secret", "fake_token", 0, + ) + ) + except ThreepidValidationError as e: + self.assertEquals(e.msg, "Validation token not found or has expired", e) -- cgit 1.5.1 From 4c6c56dc58aba7af92f531655c2355d8f25e529c Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 26 Aug 2020 07:19:32 -0400 Subject: Convert simple_select_one and simple_select_one_onecol to async (#8162) --- changelog.d/8162.misc | 1 + synapse/storage/database.py | 36 +++++++++++--- synapse/storage/databases/main/devices.py | 14 +++--- synapse/storage/databases/main/directory.py | 4 +- synapse/storage/databases/main/e2e_room_keys.py | 8 ++-- synapse/storage/databases/main/events_worker.py | 10 ++-- synapse/storage/databases/main/group_server.py | 18 ++++--- synapse/storage/databases/main/media_repository.py | 13 +++-- .../storage/databases/main/monthly_active_users.py | 15 +++--- synapse/storage/databases/main/profile.py | 17 ++++--- synapse/storage/databases/main/receipts.py | 6 ++- synapse/storage/databases/main/registration.py | 10 ++-- synapse/storage/databases/main/rejections.py | 5 +- synapse/storage/databases/main/room.py | 10 ++-- synapse/storage/databases/main/state.py | 4 +- synapse/storage/databases/main/stats.py | 10 ++-- synapse/storage/databases/main/user_directory.py | 9 ++-- tests/handlers/test_profile.py | 56 +++++++++++++++++----- tests/handlers/test_typing.py | 4 +- tests/module_api/test_api.py | 2 +- tests/storage/test_base.py | 28 ++++++----- tests/storage/test_devices.py | 8 ++-- tests/storage/test_profile.py | 23 +++++++-- tests/storage/test_registration.py | 2 +- tests/storage/test_room.py | 20 ++++++-- 25 files changed, 220 insertions(+), 113 deletions(-) create mode 100644 changelog.d/8162.misc (limited to 'synapse/storage/databases/main/registration.py') diff --git a/changelog.d/8162.misc b/changelog.d/8162.misc new file mode 100644 index 0000000000..e26764dea1 --- /dev/null +++ b/changelog.d/8162.misc @@ -0,0 +1 @@ + Convert various parts of the codebase to async/await. diff --git a/synapse/storage/database.py b/synapse/storage/database.py index bc327e344e..181c3ec249 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -29,9 +29,11 @@ from typing import ( Tuple, TypeVar, Union, + overload, ) from prometheus_client import Histogram +from typing_extensions import Literal from twisted.enterprise import adbapi from twisted.internet import defer @@ -1020,14 +1022,36 @@ class DatabasePool(object): return txn.execute_batch(sql, args) - def simple_select_one( + @overload + async def simple_select_one( + self, + table: str, + keyvalues: Dict[str, Any], + retcols: Iterable[str], + allow_none: Literal[False] = False, + desc: str = "simple_select_one", + ) -> Dict[str, Any]: + ... + + @overload + async def simple_select_one( + self, + table: str, + keyvalues: Dict[str, Any], + retcols: Iterable[str], + allow_none: Literal[True] = True, + desc: str = "simple_select_one", + ) -> Optional[Dict[str, Any]]: + ... + + async def simple_select_one( self, table: str, keyvalues: Dict[str, Any], retcols: Iterable[str], allow_none: bool = False, desc: str = "simple_select_one", - ) -> defer.Deferred: + ) -> Optional[Dict[str, Any]]: """Executes a SELECT query on the named table, which is expected to return a single row, returning multiple columns from it. @@ -1038,18 +1062,18 @@ class DatabasePool(object): allow_none: If true, return None instead of failing if the SELECT statement returns no rows """ - return self.runInteraction( + return await self.runInteraction( desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none ) - def simple_select_one_onecol( + async def simple_select_one_onecol( self, table: str, keyvalues: Dict[str, Any], retcol: Iterable[str], allow_none: bool = False, desc: str = "simple_select_one_onecol", - ) -> defer.Deferred: + ) -> Optional[Any]: """Executes a SELECT query on the named table, which is expected to return a single row, returning a single column from it. @@ -1061,7 +1085,7 @@ class DatabasePool(object): statement returns no rows desc: description of the transaction, for logging and metrics """ - return self.runInteraction( + return await self.runInteraction( desc, self.simple_select_one_onecol_txn, table, diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 03b45dbc4d..a811a39eb5 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Dict, Iterable, List, Optional, Set, Tuple +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple from synapse.api.errors import Codes, StoreError from synapse.logging.opentracing import ( @@ -47,7 +47,7 @@ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes" class DeviceWorkerStore(SQLBaseStore): - def get_device(self, user_id: str, device_id: str): + async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]: """Retrieve a device. Only returns devices that are not marked as hidden. @@ -55,11 +55,11 @@ class DeviceWorkerStore(SQLBaseStore): user_id: The ID of the user which owns the device device_id: The ID of the device to retrieve Returns: - defer.Deferred for a dict containing the device information + A dict containing the device information Raises: StoreError: if the device is not found """ - return self.db_pool.simple_select_one( + return await self.db_pool.simple_select_one( table="devices", keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, retcols=("user_id", "device_id", "display_name"), @@ -656,11 +656,13 @@ class DeviceWorkerStore(SQLBaseStore): ) @cached(max_entries=10000) - def get_device_list_last_stream_id_for_remote(self, user_id: str): + async def get_device_list_last_stream_id_for_remote( + self, user_id: str + ) -> Optional[Any]: """Get the last stream_id we got for a user. May be None if we haven't got any information for them. """ - return self.db_pool.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( table="device_lists_remote_extremeties", keyvalues={"user_id": user_id}, retcol="stream_id", diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py index 037e02603c..301d5d845a 100644 --- a/synapse/storage/databases/main/directory.py +++ b/synapse/storage/databases/main/directory.py @@ -59,8 +59,8 @@ class DirectoryWorkerStore(SQLBaseStore): return RoomAliasMapping(room_id, room_alias.to_string(), servers) - def get_room_alias_creator(self, room_alias): - return self.db_pool.simple_select_one_onecol( + async def get_room_alias_creator(self, room_alias: str) -> str: + return await self.db_pool.simple_select_one_onecol( table="room_aliases", keyvalues={"room_alias": room_alias}, retcol="creator", diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py index 2eeb9f97dc..46c3e33cc6 100644 --- a/synapse/storage/databases/main/e2e_room_keys.py +++ b/synapse/storage/databases/main/e2e_room_keys.py @@ -223,15 +223,15 @@ class EndToEndRoomKeyStore(SQLBaseStore): return ret - def count_e2e_room_keys(self, user_id, version): + async def count_e2e_room_keys(self, user_id: str, version: str) -> int: """Get the number of keys in a backup version. Args: - user_id (str): the user whose backup we're querying - version (str): the version ID of the backup we're querying about + user_id: the user whose backup we're querying + version: the version ID of the backup we're querying about """ - return self.db_pool.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( table="e2e_room_keys", keyvalues={"user_id": user_id, "version": version}, retcol="COUNT(*)", diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index e1241a724b..d59d73938a 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -119,19 +119,19 @@ class EventsWorkerStore(SQLBaseStore): super().process_replication_rows(stream_name, instance_name, token, rows) - def get_received_ts(self, event_id): + async def get_received_ts(self, event_id: str) -> Optional[int]: """Get received_ts (when it was persisted) for the event. Raises an exception for unknown events. Args: - event_id (str) + event_id: The event ID to query. Returns: - Deferred[int|None]: Timestamp in milliseconds, or None for events - that were persisted before received_ts was implemented. + Timestamp in milliseconds, or None for events that were persisted + before received_ts was implemented. """ - return self.db_pool.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( table="events", keyvalues={"event_id": event_id}, retcol="received_ts", diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py index a488e0924b..c39864f59f 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple from synapse.api.errors import SynapseError from synapse.storage._base import SQLBaseStore, db_to_json @@ -28,8 +28,8 @@ _DEFAULT_ROLE_ID = "" class GroupServerWorkerStore(SQLBaseStore): - def get_group(self, group_id): - return self.db_pool.simple_select_one( + async def get_group(self, group_id: str) -> Optional[Dict[str, Any]]: + return await self.db_pool.simple_select_one( table="groups", keyvalues={"group_id": group_id}, retcols=( @@ -351,8 +351,10 @@ class GroupServerWorkerStore(SQLBaseStore): ) return bool(result) - def is_user_admin_in_group(self, group_id, user_id): - return self.db_pool.simple_select_one_onecol( + async def is_user_admin_in_group( + self, group_id: str, user_id: str + ) -> Optional[bool]: + return await self.db_pool.simple_select_one_onecol( table="group_users", keyvalues={"group_id": group_id, "user_id": user_id}, retcol="is_admin", @@ -360,10 +362,12 @@ class GroupServerWorkerStore(SQLBaseStore): desc="is_user_admin_in_group", ) - def is_user_invited_to_local_group(self, group_id, user_id): + async def is_user_invited_to_local_group( + self, group_id: str, user_id: str + ) -> Optional[bool]: """Has the group server invited a user? """ - return self.db_pool.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( table="group_invites", keyvalues={"group_id": group_id, "user_id": user_id}, retcol="user_id", diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index 80fc1cd009..4ae255ebd8 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -12,6 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any, Dict, Optional + from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool @@ -37,12 +39,13 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): def __init__(self, database: DatabasePool, db_conn, hs): super(MediaRepositoryStore, self).__init__(database, db_conn, hs) - def get_local_media(self, media_id): + async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]: """Get the metadata for a local piece of media + Returns: None if the media_id doesn't exist. """ - return self.db_pool.simple_select_one( + return await self.db_pool.simple_select_one( "local_media_repository", {"media_id": media_id}, ( @@ -191,8 +194,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): desc="store_local_thumbnail", ) - def get_cached_remote_media(self, origin, media_id): - return self.db_pool.simple_select_one( + async def get_cached_remote_media( + self, origin, media_id: str + ) -> Optional[Dict[str, Any]]: + return await self.db_pool.simple_select_one( "remote_media_cache", {"media_origin": origin, "media_id": media_id}, ( diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py index e71cdd2cb4..fe30552c08 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -99,17 +99,18 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): return users @cached(num_args=1) - def user_last_seen_monthly_active(self, user_id): + async def user_last_seen_monthly_active(self, user_id: str) -> int: """ - Checks if a given user is part of the monthly active user group - Arguments: - user_id (str): user to add/update - Return: - Deferred[int] : timestamp since last seen, None if never seen + Checks if a given user is part of the monthly active user group + Arguments: + user_id: user to add/update + + Return: + Timestamp since last seen, None if never seen """ - return self.db_pool.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( table="monthly_active_users", keyvalues={"user_id": user_id}, retcol="timestamp", diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py index b8261357d4..b8233c4848 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any, Dict, Optional from synapse.api.errors import StoreError from synapse.storage._base import SQLBaseStore @@ -19,7 +20,7 @@ from synapse.storage.databases.main.roommember import ProfileInfo class ProfileWorkerStore(SQLBaseStore): - async def get_profileinfo(self, user_localpart): + async def get_profileinfo(self, user_localpart: str) -> ProfileInfo: try: profile = await self.db_pool.simple_select_one( table="profiles", @@ -38,24 +39,26 @@ class ProfileWorkerStore(SQLBaseStore): avatar_url=profile["avatar_url"], display_name=profile["displayname"] ) - def get_profile_displayname(self, user_localpart): - return self.db_pool.simple_select_one_onecol( + async def get_profile_displayname(self, user_localpart: str) -> str: + return await self.db_pool.simple_select_one_onecol( table="profiles", keyvalues={"user_id": user_localpart}, retcol="displayname", desc="get_profile_displayname", ) - def get_profile_avatar_url(self, user_localpart): - return self.db_pool.simple_select_one_onecol( + async def get_profile_avatar_url(self, user_localpart: str) -> str: + return await self.db_pool.simple_select_one_onecol( table="profiles", keyvalues={"user_id": user_localpart}, retcol="avatar_url", desc="get_profile_avatar_url", ) - def get_from_remote_profile_cache(self, user_id): - return self.db_pool.simple_select_one( + async def get_from_remote_profile_cache( + self, user_id: str + ) -> Optional[Dict[str, Any]]: + return await self.db_pool.simple_select_one( table="remote_profile_cache", keyvalues={"user_id": user_id}, retcols=("displayname", "avatar_url"), diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 6821476ee0..cea5ac9a68 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -71,8 +71,10 @@ class ReceiptsWorkerStore(SQLBaseStore): ) @cached(num_args=3) - def get_last_receipt_event_id_for_user(self, user_id, room_id, receipt_type): - return self.db_pool.simple_select_one_onecol( + async def get_last_receipt_event_id_for_user( + self, user_id: str, room_id: str, receipt_type: str + ) -> Optional[str]: + return await self.db_pool.simple_select_one_onecol( table="receipts_linearized", keyvalues={ "room_id": room_id, diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 321a51cc6a..eced53d470 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -17,7 +17,7 @@ import logging import re -from typing import Awaitable, Dict, List, Optional +from typing import Any, Awaitable, Dict, List, Optional from synapse.api.constants import UserTypes from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError @@ -46,8 +46,8 @@ class RegistrationWorkerStore(SQLBaseStore): ) @cached() - def get_user_by_id(self, user_id): - return self.db_pool.simple_select_one( + async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]: + return await self.db_pool.simple_select_one( table="users", keyvalues={"name": user_id}, retcols=[ @@ -1259,12 +1259,12 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): desc="del_user_pending_deactivation", ) - def get_user_pending_deactivation(self): + async def get_user_pending_deactivation(self) -> Optional[str]: """ Gets one user from the table of users waiting to be parted from all the rooms they're in. """ - return self.db_pool.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( "users_pending_deactivation", keyvalues={}, retcol="user_id", diff --git a/synapse/storage/databases/main/rejections.py b/synapse/storage/databases/main/rejections.py index cf9ba51205..1e361aaa9a 100644 --- a/synapse/storage/databases/main/rejections.py +++ b/synapse/storage/databases/main/rejections.py @@ -14,6 +14,7 @@ # limitations under the License. import logging +from typing import Optional from synapse.storage._base import SQLBaseStore @@ -21,8 +22,8 @@ logger = logging.getLogger(__name__) class RejectionsStore(SQLBaseStore): - def get_rejection_reason(self, event_id): - return self.db_pool.simple_select_one_onecol( + async def get_rejection_reason(self, event_id: str) -> Optional[str]: + return await self.db_pool.simple_select_one_onecol( table="rejections", retcol="reason", keyvalues={"event_id": event_id}, diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index b3772be2b2..97ecdb16e4 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -73,15 +73,15 @@ class RoomWorkerStore(SQLBaseStore): self.config = hs.config - def get_room(self, room_id): + async def get_room(self, room_id: str) -> dict: """Retrieve a room. Args: - room_id (str): The ID of the room to retrieve. + room_id: The ID of the room to retrieve. Returns: A dict containing the room information, or None if the room is unknown. """ - return self.db_pool.simple_select_one( + return await self.db_pool.simple_select_one( table="rooms", keyvalues={"room_id": room_id}, retcols=("room_id", "is_public", "creator"), @@ -330,8 +330,8 @@ class RoomWorkerStore(SQLBaseStore): return ret_val @cached(max_entries=10000) - def is_room_blocked(self, room_id): - return self.db_pool.simple_select_one_onecol( + async def is_room_blocked(self, room_id: str) -> Optional[bool]: + return await self.db_pool.simple_select_one_onecol( table="blocked_rooms", keyvalues={"room_id": room_id}, retcol="1", diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index 991233a9bc..458f169617 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -260,8 +260,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): return event.content.get("canonical_alias") @cached(max_entries=50000) - def _get_state_group_for_event(self, event_id): - return self.db_pool.simple_select_one_onecol( + async def _get_state_group_for_event(self, event_id: str) -> Optional[int]: + return await self.db_pool.simple_select_one_onecol( table="event_to_state_groups", keyvalues={"event_id": event_id}, retcol="state_group", diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py index 802c9019b9..9fe97af56a 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py @@ -211,11 +211,11 @@ class StatsStore(StateDeltasStore): return len(rooms_to_work_on) - def get_stats_positions(self): + async def get_stats_positions(self) -> int: """ Returns the stats processor positions. """ - return self.db_pool.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( table="stats_incremental_position", keyvalues={}, retcol="stream_id", @@ -300,7 +300,7 @@ class StatsStore(StateDeltasStore): return slice_list @cached() - def get_earliest_token_for_stats(self, stats_type, id): + async def get_earliest_token_for_stats(self, stats_type: str, id: str) -> int: """ Fetch the "earliest token". This is used by the room stats delta processor to ignore deltas that have been processed between the @@ -308,11 +308,11 @@ class StatsStore(StateDeltasStore): being calculated. Returns: - Deferred[int] + The earliest token. """ table, id_col = TYPE_TO_TABLE[stats_type] - return self.db_pool.simple_select_one_onecol( + return await self.db_pool.simple_select_one_onecol( "%s_current" % (table,), keyvalues={id_col: id}, retcol="completed_delta_stream_id", diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index af21fe457a..20cbcd851c 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -15,6 +15,7 @@ import logging import re +from typing import Any, Dict, Optional from synapse.api.constants import EventTypes, JoinRules from synapse.storage.database import DatabasePool @@ -527,8 +528,8 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): ) @cached() - def get_user_in_directory(self, user_id): - return self.db_pool.simple_select_one( + async def get_user_in_directory(self, user_id: str) -> Optional[Dict[str, Any]]: + return await self.db_pool.simple_select_one( table="user_directory", keyvalues={"user_id": user_id}, retcols=("display_name", "avatar_url"), @@ -663,8 +664,8 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): users.update(rows) return list(users) - def get_user_directory_stream_pos(self): - return self.db_pool.simple_select_one_onecol( + async def get_user_directory_stream_pos(self) -> int: + return await self.db_pool.simple_select_one_onecol( table="user_directory_stream_pos", keyvalues={}, retcol="stream_id", diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index b609b30d4a..60ebc95f3e 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -71,7 +71,9 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_my_name(self): - yield self.store.set_profile_displayname(self.frank.localpart, "Frank") + yield defer.ensureDeferred( + self.store.set_profile_displayname(self.frank.localpart, "Frank") + ) displayname = yield defer.ensureDeferred( self.handler.get_displayname(self.frank) @@ -104,7 +106,12 @@ class ProfileTestCase(unittest.TestCase): ) self.assertEquals( - (yield self.store.get_profile_displayname(self.frank.localpart)), "Frank", + ( + yield defer.ensureDeferred( + self.store.get_profile_displayname(self.frank.localpart) + ) + ), + "Frank", ) @defer.inlineCallbacks @@ -112,10 +119,17 @@ class ProfileTestCase(unittest.TestCase): self.hs.config.enable_set_displayname = False # Setting displayname for the first time is allowed - yield self.store.set_profile_displayname(self.frank.localpart, "Frank") + yield defer.ensureDeferred( + self.store.set_profile_displayname(self.frank.localpart, "Frank") + ) self.assertEquals( - (yield self.store.get_profile_displayname(self.frank.localpart)), "Frank", + ( + yield defer.ensureDeferred( + self.store.get_profile_displayname(self.frank.localpart) + ) + ), + "Frank", ) # Setting displayname a second time is forbidden @@ -158,7 +172,9 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_incoming_fed_query(self): yield defer.ensureDeferred(self.store.create_profile("caroline")) - yield self.store.set_profile_displayname("caroline", "Caroline") + yield defer.ensureDeferred( + self.store.set_profile_displayname("caroline", "Caroline") + ) response = yield defer.ensureDeferred( self.query_handlers["profile"]( @@ -170,8 +186,10 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_my_avatar(self): - yield self.store.set_profile_avatar_url( - self.frank.localpart, "http://my.server/me.png" + yield defer.ensureDeferred( + self.store.set_profile_avatar_url( + self.frank.localpart, "http://my.server/me.png" + ) ) avatar_url = yield defer.ensureDeferred(self.handler.get_avatar_url(self.frank)) @@ -188,7 +206,11 @@ class ProfileTestCase(unittest.TestCase): ) self.assertEquals( - (yield self.store.get_profile_avatar_url(self.frank.localpart)), + ( + yield defer.ensureDeferred( + self.store.get_profile_avatar_url(self.frank.localpart) + ) + ), "http://my.server/pic.gif", ) @@ -202,7 +224,11 @@ class ProfileTestCase(unittest.TestCase): ) self.assertEquals( - (yield self.store.get_profile_avatar_url(self.frank.localpart)), + ( + yield defer.ensureDeferred( + self.store.get_profile_avatar_url(self.frank.localpart) + ) + ), "http://my.server/me.png", ) @@ -211,12 +237,18 @@ class ProfileTestCase(unittest.TestCase): self.hs.config.enable_set_avatar_url = False # Setting displayname for the first time is allowed - yield self.store.set_profile_avatar_url( - self.frank.localpart, "http://my.server/me.png" + yield defer.ensureDeferred( + self.store.set_profile_avatar_url( + self.frank.localpart, "http://my.server/me.png" + ) ) self.assertEquals( - (yield self.store.get_profile_avatar_url(self.frank.localpart)), + ( + yield defer.ensureDeferred( + self.store.get_profile_avatar_url(self.frank.localpart) + ) + ), "http://my.server/me.png", ) diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index e01de158e5..834b4a0af6 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -144,9 +144,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.datastore.get_users_in_room = get_users_in_room - self.datastore.get_user_directory_stream_pos.return_value = ( + self.datastore.get_user_directory_stream_pos.side_effect = ( # we deliberately return a non-None stream pos to avoid doing an initial_spam - defer.succeed(1) + lambda: make_awaitable(1) ) self.datastore.get_current_state_deltas.return_value = (0, None) diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py index 807cd65dd6..04de0b9dbe 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py @@ -35,7 +35,7 @@ class ModuleApiTestCase(HomeserverTestCase): # Check that the new user exists with all provided attributes self.assertEqual(user_id, "@bob:test") self.assertTrue(access_token) - self.assertTrue(self.store.get_user_by_id(user_id)) + self.assertTrue(self.get_success(self.store.get_user_by_id(user_id))) # Check that the email was assigned emails = self.get_success(self.store.user_get_threepids(user_id)) diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index 13bcac743a..bf22540d99 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -97,8 +97,10 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.rowcount = 1 self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)])) - value = yield self.datastore.db_pool.simple_select_one_onecol( - table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol" + value = yield defer.ensureDeferred( + self.datastore.db_pool.simple_select_one_onecol( + table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol" + ) ) self.assertEquals("Value", value) @@ -111,10 +113,12 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.rowcount = 1 self.mock_txn.fetchone.return_value = (1, 2, 3) - ret = yield self.datastore.db_pool.simple_select_one( - table="tablename", - keyvalues={"keycol": "TheKey"}, - retcols=["colA", "colB", "colC"], + ret = yield defer.ensureDeferred( + self.datastore.db_pool.simple_select_one( + table="tablename", + keyvalues={"keycol": "TheKey"}, + retcols=["colA", "colB", "colC"], + ) ) self.assertEquals({"colA": 1, "colB": 2, "colC": 3}, ret) @@ -127,11 +131,13 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.rowcount = 0 self.mock_txn.fetchone.return_value = None - ret = yield self.datastore.db_pool.simple_select_one( - table="tablename", - keyvalues={"keycol": "Not here"}, - retcols=["colA"], - allow_none=True, + ret = yield defer.ensureDeferred( + self.datastore.db_pool.simple_select_one( + table="tablename", + keyvalues={"keycol": "Not here"}, + retcols=["colA"], + allow_none=True, + ) ) self.assertFalse(ret) diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index 87ed8f8cd1..34ae8c9da7 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -38,7 +38,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase): self.store.store_device("user_id", "device_id", "display_name") ) - res = yield self.store.get_device("user_id", "device_id") + res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id")) self.assertDictContainsSubset( { "user_id": "user_id", @@ -111,12 +111,12 @@ class DeviceStoreTestCase(tests.unittest.TestCase): self.store.store_device("user_id", "device_id", "display_name 1") ) - res = yield self.store.get_device("user_id", "device_id") + res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id")) self.assertEqual("display_name 1", res["display_name"]) # do a no-op first yield defer.ensureDeferred(self.store.update_device("user_id", "device_id")) - res = yield self.store.get_device("user_id", "device_id") + res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id")) self.assertEqual("display_name 1", res["display_name"]) # do the update @@ -127,7 +127,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase): ) # check it worked - res = yield self.store.get_device("user_id", "device_id") + res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id")) self.assertEqual("display_name 2", res["display_name"]) @defer.inlineCallbacks diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py index 9d5b8aa47d..3fd0a38cf5 100644 --- a/tests/storage/test_profile.py +++ b/tests/storage/test_profile.py @@ -35,21 +35,34 @@ class ProfileStoreTestCase(unittest.TestCase): def test_displayname(self): yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart)) - yield self.store.set_profile_displayname(self.u_frank.localpart, "Frank") + yield defer.ensureDeferred( + self.store.set_profile_displayname(self.u_frank.localpart, "Frank") + ) self.assertEquals( - "Frank", (yield self.store.get_profile_displayname(self.u_frank.localpart)) + "Frank", + ( + yield defer.ensureDeferred( + self.store.get_profile_displayname(self.u_frank.localpart) + ) + ), ) @defer.inlineCallbacks def test_avatar_url(self): yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart)) - yield self.store.set_profile_avatar_url( - self.u_frank.localpart, "http://my.site/here" + yield defer.ensureDeferred( + self.store.set_profile_avatar_url( + self.u_frank.localpart, "http://my.site/here" + ) ) self.assertEquals( "http://my.site/here", - (yield self.store.get_profile_avatar_url(self.u_frank.localpart)), + ( + yield defer.ensureDeferred( + self.store.get_profile_avatar_url(self.u_frank.localpart) + ) + ), ) diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index 58f827d8d3..70c55cd650 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -53,7 +53,7 @@ class RegistrationStoreTestCase(unittest.TestCase): "user_type": None, "deactivated": 0, }, - (yield self.store.get_user_by_id(self.user_id)), + (yield defer.ensureDeferred(self.store.get_user_by_id(self.user_id))), ) @defer.inlineCallbacks diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index d07b985a8e..bc8400f240 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -54,12 +54,14 @@ class RoomStoreTestCase(unittest.TestCase): "creator": self.u_creator.to_string(), "is_public": True, }, - (yield self.store.get_room(self.room.to_string())), + (yield defer.ensureDeferred(self.store.get_room(self.room.to_string()))), ) @defer.inlineCallbacks def test_get_room_unknown_room(self): - self.assertIsNone((yield self.store.get_room("!uknown:test")),) + self.assertIsNone( + (yield defer.ensureDeferred(self.store.get_room("!uknown:test"))) + ) @defer.inlineCallbacks def test_get_room_with_stats(self): @@ -69,12 +71,22 @@ class RoomStoreTestCase(unittest.TestCase): "creator": self.u_creator.to_string(), "public": True, }, - (yield self.store.get_room_with_stats(self.room.to_string())), + ( + yield defer.ensureDeferred( + self.store.get_room_with_stats(self.room.to_string()) + ) + ), ) @defer.inlineCallbacks def test_get_room_with_stats_unknown_room(self): - self.assertIsNone((yield self.store.get_room_with_stats("!uknown:test")),) + self.assertIsNone( + ( + yield defer.ensureDeferred( + self.store.get_room_with_stats("!uknown:test") + ) + ), + ) class RoomEventsStoreTestCase(unittest.TestCase): -- cgit 1.5.1 From 4a739c73b404284253a548f60197e70c6c385645 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 27 Aug 2020 07:08:38 -0400 Subject: Convert simple_update* and simple_select* to async (#8173) --- changelog.d/8173.misc | 1 + synapse/handlers/room.py | 6 +-- synapse/storage/database.py | 29 ++++++------ synapse/storage/databases/main/__init__.py | 8 ++-- synapse/storage/databases/main/directory.py | 6 +-- synapse/storage/databases/main/e2e_room_keys.py | 26 ++++++---- synapse/storage/databases/main/event_federation.py | 4 +- synapse/storage/databases/main/group_server.py | 55 +++++++++++++--------- synapse/storage/databases/main/media_repository.py | 16 ++++--- synapse/storage/databases/main/profile.py | 18 ++++--- synapse/storage/databases/main/receipts.py | 8 ++-- synapse/storage/databases/main/registration.py | 22 +++++---- synapse/storage/databases/main/room.py | 4 +- synapse/storage/databases/main/user_directory.py | 4 +- tests/handlers/test_stats.py | 4 +- tests/storage/test_base.py | 26 ++++++---- tests/storage/test_directory.py | 6 ++- tests/storage/test_main.py | 4 +- tests/test_federation.py | 50 +++++++++----------- 19 files changed, 164 insertions(+), 133 deletions(-) create mode 100644 changelog.d/8173.misc (limited to 'synapse/storage/databases/main/registration.py') diff --git a/changelog.d/8173.misc b/changelog.d/8173.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8173.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index e4788ef86b..236a37f777 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -51,7 +51,7 @@ from synapse.types import ( create_requester, ) from synapse.util import stringutils -from synapse.util.async_helpers import Linearizer, maybe_awaitable +from synapse.util.async_helpers import Linearizer from synapse.util.caches.response_cache import ResponseCache from synapse.visibility import filter_events_for_client @@ -1329,9 +1329,7 @@ class RoomShutdownHandler(object): ratelimit=False, ) - aliases_for_room = await maybe_awaitable( - self.store.get_aliases_for_room(room_id) - ) + aliases_for_room = await self.store.get_aliases_for_room(room_id) await self.store.update_aliases_for_room( room_id, new_room_id, requester_user_id diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 181c3ec249..38010af600 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -1132,13 +1132,13 @@ class DatabasePool(object): return [r[0] for r in txn] - def simple_select_onecol( + async def simple_select_onecol( self, table: str, keyvalues: Optional[Dict[str, Any]], retcol: str, desc: str = "simple_select_onecol", - ) -> defer.Deferred: + ) -> List[Any]: """Executes a SELECT query on the named table, which returns a list comprising of the values of the named column from the selected rows. @@ -1148,19 +1148,19 @@ class DatabasePool(object): retcol: column whos value we wish to retrieve. Returns: - Deferred: Results in a list + Results in a list """ - return self.runInteraction( + return await self.runInteraction( desc, self.simple_select_onecol_txn, table, keyvalues, retcol ) - def simple_select_list( + async def simple_select_list( self, table: str, keyvalues: Optional[Dict[str, Any]], retcols: Iterable[str], desc: str = "simple_select_list", - ) -> defer.Deferred: + ) -> List[Dict[str, Any]]: """Executes a SELECT query on the named table, which may return zero or more rows, returning the result as a list of dicts. @@ -1170,10 +1170,11 @@ class DatabasePool(object): column names and values to select the rows with, or None to not apply a WHERE clause. retcols: the names of the columns to return + Returns: - defer.Deferred: resolves to list[dict[str, Any]] + A list of dictionaries. """ - return self.runInteraction( + return await self.runInteraction( desc, self.simple_select_list_txn, table, keyvalues, retcols ) @@ -1299,14 +1300,14 @@ class DatabasePool(object): txn.execute(sql, values) return cls.cursor_to_dict(txn) - def simple_update( + async def simple_update( self, table: str, keyvalues: Dict[str, Any], updatevalues: Dict[str, Any], desc: str, - ) -> defer.Deferred: - return self.runInteraction( + ) -> int: + return await self.runInteraction( desc, self.simple_update_txn, table, keyvalues, updatevalues ) @@ -1332,13 +1333,13 @@ class DatabasePool(object): return txn.rowcount - def simple_update_one( + async def simple_update_one( self, table: str, keyvalues: Dict[str, Any], updatevalues: Dict[str, Any], desc: str = "simple_update_one", - ) -> defer.Deferred: + ) -> None: """Executes an UPDATE query on the named table, setting new values for columns in a row matching the key values. @@ -1347,7 +1348,7 @@ class DatabasePool(object): keyvalues: dict of column names and values to select the row with updatevalues: dict giving column names and values to update """ - return self.runInteraction( + await self.runInteraction( desc, self.simple_update_one_txn, table, keyvalues, updatevalues ) diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 0934ae276c..8b9b6eb472 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -18,6 +18,7 @@ import calendar import logging import time +from typing import Any, Dict, List from synapse.api.constants import PresenceState from synapse.config.homeserver import HomeServerConfig @@ -476,14 +477,13 @@ class DataStore( "generate_user_daily_visits", _generate_user_daily_visits ) - def get_users(self): + async def get_users(self) -> List[Dict[str, Any]]: """Function to retrieve a list of users in users table. - Args: Returns: - defer.Deferred: resolves to list[dict[str, Any]] + A list of dictionaries representing users. """ - return self.db_pool.simple_select_list( + return await self.db_pool.simple_select_list( table="users", keyvalues={}, retcols=[ diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py index 301d5d845a..405b5eafa5 100644 --- a/synapse/storage/databases/main/directory.py +++ b/synapse/storage/databases/main/directory.py @@ -14,7 +14,7 @@ # limitations under the License. from collections import namedtuple -from typing import Iterable, Optional +from typing import Iterable, List, Optional from synapse.api.errors import SynapseError from synapse.storage._base import SQLBaseStore @@ -68,8 +68,8 @@ class DirectoryWorkerStore(SQLBaseStore): ) @cached(max_entries=5000) - def get_aliases_for_room(self, room_id): - return self.db_pool.simple_select_onecol( + async def get_aliases_for_room(self, room_id: str) -> List[str]: + return await self.db_pool.simple_select_onecol( "room_aliases", {"room_id": room_id}, "room_alias", diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py index 46c3e33cc6..82f9d870fd 100644 --- a/synapse/storage/databases/main/e2e_room_keys.py +++ b/synapse/storage/databases/main/e2e_room_keys.py @@ -14,6 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional + from synapse.api.errors import StoreError from synapse.logging.opentracing import log_kv, trace from synapse.storage._base import SQLBaseStore, db_to_json @@ -368,18 +370,22 @@ class EndToEndRoomKeyStore(SQLBaseStore): ) @trace - def update_e2e_room_keys_version( - self, user_id, version, info=None, version_etag=None - ): + async def update_e2e_room_keys_version( + self, + user_id: str, + version: str, + info: Optional[dict] = None, + version_etag: Optional[int] = None, + ) -> None: """Update a given backup version Args: - user_id(str): the user whose backup version we're updating - version(str): the version ID of the backup version we're updating - info (dict): the new backup version info to store. If None, then - the backup version info is not updated - version_etag (Optional[int]): etag of the keys in the backup. If - None, then the etag is not updated + user_id: the user whose backup version we're updating + version: the version ID of the backup version we're updating + info: the new backup version info to store. If None, then the backup + version info is not updated. + version_etag: etag of the keys in the backup. If None, then the etag + is not updated. """ updatevalues = {} @@ -389,7 +395,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): updatevalues["etag"] = version_etag if updatevalues: - return self.db_pool.simple_update( + await self.db_pool.simple_update( table="e2e_room_keys_versions", keyvalues={"user_id": user_id, "version": version}, updatevalues=updatevalues, diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index e6a97b018c..6e5761c7b7 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -368,8 +368,8 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas ) @cached(max_entries=5000, iterable=True) - def get_latest_event_ids_in_room(self, room_id): - return self.db_pool.simple_select_onecol( + async def get_latest_event_ids_in_room(self, room_id: str) -> List[str]: + return await self.db_pool.simple_select_onecol( table="event_forward_extremities", keyvalues={"room_id": room_id}, retcol="event_id", diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py index c39864f59f..e3ead71853 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py @@ -44,24 +44,26 @@ class GroupServerWorkerStore(SQLBaseStore): desc="get_group", ) - def get_users_in_group(self, group_id, include_private=False): + async def get_users_in_group( + self, group_id: str, include_private: bool = False + ) -> List[Dict[str, Any]]: # TODO: Pagination keyvalues = {"group_id": group_id} if not include_private: keyvalues["is_public"] = True - return self.db_pool.simple_select_list( + return await self.db_pool.simple_select_list( table="group_users", keyvalues=keyvalues, retcols=("user_id", "is_public", "is_admin"), desc="get_users_in_group", ) - def get_invited_users_in_group(self, group_id): + async def get_invited_users_in_group(self, group_id: str) -> List[str]: # TODO: Pagination - return self.db_pool.simple_select_onecol( + return await self.db_pool.simple_select_onecol( table="group_invites", keyvalues={"group_id": group_id}, retcol="user_id", @@ -265,15 +267,14 @@ class GroupServerWorkerStore(SQLBaseStore): return role - def get_local_groups_for_room(self, room_id): + async def get_local_groups_for_room(self, room_id: str) -> List[str]: """Get all of the local group that contain a given room Args: - room_id (str): The ID of a room + room_id: The ID of a room Returns: - Deferred[list[str]]: A twisted.Deferred containing a list of group ids - containing this room + A list of group ids containing this room """ - return self.db_pool.simple_select_onecol( + return await self.db_pool.simple_select_onecol( table="group_rooms", keyvalues={"room_id": room_id}, retcol="group_id", @@ -422,10 +423,10 @@ class GroupServerWorkerStore(SQLBaseStore): "get_users_membership_info_in_group", _get_users_membership_in_group_txn ) - def get_publicised_groups_for_user(self, user_id): + async def get_publicised_groups_for_user(self, user_id: str) -> List[str]: """Get all groups a user is publicising """ - return self.db_pool.simple_select_onecol( + return await self.db_pool.simple_select_onecol( table="local_group_membership", keyvalues={"user_id": user_id, "membership": "join", "is_publicised": True}, retcol="group_id", @@ -466,8 +467,8 @@ class GroupServerWorkerStore(SQLBaseStore): return None - def get_joined_groups(self, user_id): - return self.db_pool.simple_select_onecol( + async def get_joined_groups(self, user_id: str) -> List[str]: + return await self.db_pool.simple_select_onecol( table="local_group_membership", keyvalues={"user_id": user_id, "membership": "join"}, retcol="group_id", @@ -585,14 +586,14 @@ class GroupServerWorkerStore(SQLBaseStore): class GroupServerStore(GroupServerWorkerStore): - def set_group_join_policy(self, group_id, join_policy): + async def set_group_join_policy(self, group_id: str, join_policy: str) -> None: """Set the join policy of a group. join_policy can be one of: * "invite" * "open" """ - return self.db_pool.simple_update_one( + await self.db_pool.simple_update_one( table="groups", keyvalues={"group_id": group_id}, updatevalues={"join_policy": join_policy}, @@ -1050,8 +1051,10 @@ class GroupServerStore(GroupServerWorkerStore): desc="add_room_to_group", ) - def update_room_in_group_visibility(self, group_id, room_id, is_public): - return self.db_pool.simple_update( + async def update_room_in_group_visibility( + self, group_id: str, room_id: str, is_public: bool + ) -> int: + return await self.db_pool.simple_update( table="group_rooms", keyvalues={"group_id": group_id, "room_id": room_id}, updatevalues={"is_public": is_public}, @@ -1076,10 +1079,12 @@ class GroupServerStore(GroupServerWorkerStore): "remove_room_from_group", _remove_room_from_group_txn ) - def update_group_publicity(self, group_id, user_id, publicise): + async def update_group_publicity( + self, group_id: str, user_id: str, publicise: bool + ) -> None: """Update whether the user is publicising their membership of the group """ - return self.db_pool.simple_update_one( + await self.db_pool.simple_update_one( table="local_group_membership", keyvalues={"group_id": group_id, "user_id": user_id}, updatevalues={"is_publicised": publicise}, @@ -1218,20 +1223,24 @@ class GroupServerStore(GroupServerWorkerStore): desc="update_group_profile", ) - def update_attestation_renewal(self, group_id, user_id, attestation): + async def update_attestation_renewal( + self, group_id: str, user_id: str, attestation: dict + ) -> None: """Update an attestation that we have renewed """ - return self.db_pool.simple_update_one( + await self.db_pool.simple_update_one( table="group_attestations_renewals", keyvalues={"group_id": group_id, "user_id": user_id}, updatevalues={"valid_until_ms": attestation["valid_until_ms"]}, desc="update_attestation_renewal", ) - def update_remote_attestion(self, group_id, user_id, attestation): + async def update_remote_attestion( + self, group_id: str, user_id: str, attestation: dict + ) -> None: """Update an attestation that a remote has renewed """ - return self.db_pool.simple_update_one( + await self.db_pool.simple_update_one( table="group_attestations_remote", keyvalues={"group_id": group_id, "user_id": user_id}, updatevalues={ diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index 4ae255ebd8..fc223f5a2a 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool @@ -84,9 +84,9 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): desc="store_local_media", ) - def mark_local_media_as_safe(self, media_id: str): + async def mark_local_media_as_safe(self, media_id: str) -> None: """Mark a local media as safe from quarantining.""" - return self.db_pool.simple_update_one( + await self.db_pool.simple_update_one( table="local_media_repository", keyvalues={"media_id": media_id}, updatevalues={"safe_from_quarantine": True}, @@ -158,8 +158,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): desc="store_url_cache", ) - def get_local_media_thumbnails(self, media_id): - return self.db_pool.simple_select_list( + async def get_local_media_thumbnails(self, media_id: str) -> List[Dict[str, Any]]: + return await self.db_pool.simple_select_list( "local_media_repository_thumbnails", {"media_id": media_id}, ( @@ -271,8 +271,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): "update_cached_last_access_time", update_cache_txn ) - def get_remote_media_thumbnails(self, origin, media_id): - return self.db_pool.simple_select_list( + async def get_remote_media_thumbnails( + self, origin: str, media_id: str + ) -> List[Dict[str, Any]]: + return await self.db_pool.simple_select_list( "remote_media_cache_thumbnails", {"media_origin": origin, "media_id": media_id}, ( diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py index b8233c4848..858fd92420 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py @@ -71,16 +71,20 @@ class ProfileWorkerStore(SQLBaseStore): table="profiles", values={"user_id": user_localpart}, desc="create_profile" ) - def set_profile_displayname(self, user_localpart, new_displayname): - return self.db_pool.simple_update_one( + async def set_profile_displayname( + self, user_localpart: str, new_displayname: str + ) -> None: + await self.db_pool.simple_update_one( table="profiles", keyvalues={"user_id": user_localpart}, updatevalues={"displayname": new_displayname}, desc="set_profile_displayname", ) - def set_profile_avatar_url(self, user_localpart, new_avatar_url): - return self.db_pool.simple_update_one( + async def set_profile_avatar_url( + self, user_localpart: str, new_avatar_url: str + ) -> None: + await self.db_pool.simple_update_one( table="profiles", keyvalues={"user_id": user_localpart}, updatevalues={"avatar_url": new_avatar_url}, @@ -106,8 +110,10 @@ class ProfileStore(ProfileWorkerStore): desc="add_remote_profile_cache", ) - def update_remote_profile_cache(self, user_id, displayname, avatar_url): - return self.db_pool.simple_update( + async def update_remote_profile_cache( + self, user_id: str, displayname: str, avatar_url: str + ) -> int: + return await self.db_pool.simple_update( table="remote_profile_cache", keyvalues={"user_id": user_id}, updatevalues={ diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index cea5ac9a68..436f22ad2d 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -16,7 +16,7 @@ import abc import logging -from typing import List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple from twisted.internet import defer @@ -62,8 +62,10 @@ class ReceiptsWorkerStore(SQLBaseStore): return {r["user_id"] for r in receipts} @cached(num_args=2) - def get_receipts_for_room(self, room_id, receipt_type): - return self.db_pool.simple_select_list( + async def get_receipts_for_room( + self, room_id: str, receipt_type: str + ) -> List[Dict[str, Any]]: + return await self.db_pool.simple_select_list( table="receipts_linearized", keyvalues={"room_id": room_id, "receipt_type": receipt_type}, retcols=("user_id", "event_id"), diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index eced53d470..48bda66f3e 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -578,20 +578,20 @@ class RegistrationWorkerStore(SQLBaseStore): desc="add_user_bound_threepid", ) - def user_get_bound_threepids(self, user_id): + async def user_get_bound_threepids(self, user_id: str) -> List[Dict[str, Any]]: """Get the threepids that a user has bound to an identity server through the homeserver The homeserver remembers where binds to an identity server occurred. Using this method can retrieve those threepids. Args: - user_id (str): The ID of the user to retrieve threepids for + user_id: The ID of the user to retrieve threepids for Returns: - Deferred[list[dict]]: List of dictionaries containing the following: + List of dictionaries containing the following keys: medium (str): The medium of the threepid (e.g "email") address (str): The address of the threepid (e.g "bob@example.com") """ - return self.db_pool.simple_select_list( + return await self.db_pool.simple_select_list( table="user_threepid_id_server", keyvalues={"user_id": user_id}, retcols=["medium", "address"], @@ -623,19 +623,21 @@ class RegistrationWorkerStore(SQLBaseStore): desc="remove_user_bound_threepid", ) - def get_id_servers_user_bound(self, user_id, medium, address): + async def get_id_servers_user_bound( + self, user_id: str, medium: str, address: str + ) -> List[str]: """Get the list of identity servers that the server proxied bind requests to for given user and threepid Args: - user_id (str) - medium (str) - address (str) + user_id: The user to query for identity servers. + medium: The medium to query for identity servers. + address: The address to query for identity servers. Returns: - Deferred[list[str]]: Resolves to a list of identity servers + A list of identity servers """ - return self.db_pool.simple_select_onecol( + return await self.db_pool.simple_select_onecol( table="user_threepid_id_server", keyvalues={"user_id": user_id, "medium": medium, "address": address}, retcol="id_server", diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 97ecdb16e4..66d7135413 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -125,8 +125,8 @@ class RoomWorkerStore(SQLBaseStore): "get_room_with_stats", get_room_with_stats_txn, room_id ) - def get_public_room_ids(self): - return self.db_pool.simple_select_onecol( + async def get_public_room_ids(self) -> List[str]: + return await self.db_pool.simple_select_onecol( table="rooms", keyvalues={"is_public": True}, retcol="room_id", diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index 20cbcd851c..a9f2e93614 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -537,8 +537,8 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): desc="get_user_in_directory", ) - def update_user_directory_stream_pos(self, stream_id): - return self.db_pool.simple_update_one( + async def update_user_directory_stream_pos(self, stream_id: str) -> None: + await self.db_pool.simple_update_one( table="user_directory_stream_pos", keyvalues={}, updatevalues={"stream_id": stream_id}, diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py index 88b05c23a0..a609f148c0 100644 --- a/tests/handlers/test_stats.py +++ b/tests/handlers/test_stats.py @@ -81,8 +81,8 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) - def get_all_room_state(self): - return self.store.db_pool.simple_select_list( + async def get_all_room_state(self): + return await self.store.db_pool.simple_select_list( "room_stats_state", None, retcols=("name", "topic", "canonical_alias") ) diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index bf22540d99..64abe8cc49 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -148,8 +148,10 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)])) self.mock_txn.description = (("colA", None, None, None, None, None, None),) - ret = yield self.datastore.db_pool.simple_select_list( - table="tablename", keyvalues={"keycol": "A set"}, retcols=["colA"] + ret = yield defer.ensureDeferred( + self.datastore.db_pool.simple_select_list( + table="tablename", keyvalues={"keycol": "A set"}, retcols=["colA"] + ) ) self.assertEquals([{"colA": 1}, {"colA": 2}, {"colA": 3}], ret) @@ -161,10 +163,12 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_update_one_1col(self): self.mock_txn.rowcount = 1 - yield self.datastore.db_pool.simple_update_one( - table="tablename", - keyvalues={"keycol": "TheKey"}, - updatevalues={"columnname": "New Value"}, + yield defer.ensureDeferred( + self.datastore.db_pool.simple_update_one( + table="tablename", + keyvalues={"keycol": "TheKey"}, + updatevalues={"columnname": "New Value"}, + ) ) self.mock_txn.execute.assert_called_with( @@ -176,10 +180,12 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_update_one_4cols(self): self.mock_txn.rowcount = 1 - yield self.datastore.db_pool.simple_update_one( - table="tablename", - keyvalues=OrderedDict([("colA", 1), ("colB", 2)]), - updatevalues=OrderedDict([("colC", 3), ("colD", 4)]), + yield defer.ensureDeferred( + self.datastore.db_pool.simple_update_one( + table="tablename", + keyvalues=OrderedDict([("colA", 1), ("colB", 2)]), + updatevalues=OrderedDict([("colC", 3), ("colD", 4)]), + ) ) self.mock_txn.execute.assert_called_with( diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py index daac947cb2..da93ca3980 100644 --- a/tests/storage/test_directory.py +++ b/tests/storage/test_directory.py @@ -42,7 +42,11 @@ class DirectoryStoreTestCase(unittest.TestCase): self.assertEquals( ["#my-room:test"], - (yield self.store.get_aliases_for_room(self.room.to_string())), + ( + yield defer.ensureDeferred( + self.store.get_aliases_for_room(self.room.to_string()) + ) + ), ) @defer.inlineCallbacks diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py index fbf8af940a..954338a592 100644 --- a/tests/storage/test_main.py +++ b/tests/storage/test_main.py @@ -36,7 +36,9 @@ class DataStoreTestCase(unittest.TestCase): def test_get_users_paginate(self): yield self.store.register_user(self.user.to_string(), "pass") yield defer.ensureDeferred(self.store.create_profile(self.user.localpart)) - yield self.store.set_profile_displayname(self.user.localpart, self.displayname) + yield defer.ensureDeferred( + self.store.set_profile_displayname(self.user.localpart, self.displayname) + ) users, total = yield self.store.get_users_paginate( 0, 10, name="bc", guests=False diff --git a/tests/test_federation.py b/tests/test_federation.py index 4a4548433f..27a7fc9ed7 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -15,8 +15,9 @@ from mock import Mock -from twisted.internet.defer import ensureDeferred, maybeDeferred, succeed +from twisted.internet.defer import succeed +from synapse.api.errors import FederationError from synapse.events import make_event_from_dict from synapse.logging.context import LoggingContext from synapse.types import Requester, UserID @@ -44,22 +45,17 @@ class MessageAcceptTests(unittest.HomeserverTestCase): user_id = UserID("us", "test") our_user = Requester(user_id, None, False, False, None, None) room_creator = self.homeserver.get_room_creation_handler() - room_deferred = ensureDeferred( + self.room_id = self.get_success( room_creator.create_room( our_user, room_creator._presets_dict["public_chat"], ratelimit=False ) - ) - self.reactor.advance(0.1) - self.room_id = self.successResultOf(room_deferred)[0]["room_id"] + )[0]["room_id"] self.store = self.homeserver.get_datastore() # Figure out what the most recent event is - most_recent = self.successResultOf( - maybeDeferred( - self.homeserver.get_datastore().get_latest_event_ids_in_room, - self.room_id, - ) + most_recent = self.get_success( + self.homeserver.get_datastore().get_latest_event_ids_in_room(self.room_id) )[0] join_event = make_event_from_dict( @@ -89,19 +85,18 @@ class MessageAcceptTests(unittest.HomeserverTestCase): ) # Send the join, it should return None (which is not an error) - d = ensureDeferred( - self.handler.on_receive_pdu( - "test.serv", join_event, sent_to_us_directly=True - ) + self.assertEqual( + self.get_success( + self.handler.on_receive_pdu( + "test.serv", join_event, sent_to_us_directly=True + ) + ), + None, ) - self.reactor.advance(1) - self.assertEqual(self.successResultOf(d), None) # Make sure we actually joined the room self.assertEqual( - self.successResultOf( - maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id) - )[0], + self.get_success(self.store.get_latest_event_ids_in_room(self.room_id))[0], "$join:test.serv", ) @@ -119,8 +114,8 @@ class MessageAcceptTests(unittest.HomeserverTestCase): self.http_client.post_json = post_json # Figure out what the most recent event is - most_recent = self.successResultOf( - maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id) + most_recent = self.get_success( + self.store.get_latest_event_ids_in_room(self.room_id) )[0] # Now lie about an event @@ -140,17 +135,14 @@ class MessageAcceptTests(unittest.HomeserverTestCase): ) with LoggingContext(request="lying_event"): - d = ensureDeferred( + failure = self.get_failure( self.handler.on_receive_pdu( "test.serv", lying_event, sent_to_us_directly=True - ) + ), + FederationError, ) - # Step the reactor, so the database fetches come back - self.reactor.advance(1) - # on_receive_pdu should throw an error - failure = self.failureResultOf(d) self.assertEqual( failure.value.args[0], ( @@ -160,8 +152,8 @@ class MessageAcceptTests(unittest.HomeserverTestCase): ) # Make sure the invalid event isn't there - extrem = maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id) - self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv") + extrem = self.get_success(self.store.get_latest_event_ids_in_room(self.room_id)) + self.assertEqual(extrem[0], "$join:test.serv") def test_retry_device_list_resync(self): """Tests that device lists are marked as stale if they couldn't be synced, and -- cgit 1.5.1 From 9b7ac03af3e7ceae7d1933db566ee407cfdef72d Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 27 Aug 2020 13:38:41 -0400 Subject: Convert calls of async database methods to async (#8166) --- changelog.d/8166.misc | 1 + synapse/federation/persistence.py | 16 +++++++----- synapse/federation/units.py | 4 +-- synapse/storage/databases/main/appservice.py | 6 ++--- synapse/storage/databases/main/devices.py | 4 +-- synapse/storage/databases/main/group_server.py | 30 ++++++++++++++++------ synapse/storage/databases/main/keys.py | 26 +++++++++++-------- synapse/storage/databases/main/media_repository.py | 22 ++++++++-------- synapse/storage/databases/main/openid.py | 6 +++-- synapse/storage/databases/main/profile.py | 10 +++++--- synapse/storage/databases/main/registration.py | 29 ++++++++++----------- synapse/storage/databases/main/room.py | 16 ++++++++---- synapse/storage/databases/main/stats.py | 10 ++++---- synapse/storage/databases/main/transactions.py | 18 +++++++------ 14 files changed, 114 insertions(+), 84 deletions(-) create mode 100644 changelog.d/8166.misc (limited to 'synapse/storage/databases/main/registration.py') diff --git a/changelog.d/8166.misc b/changelog.d/8166.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8166.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py index d68b4bd670..769cd5de28 100644 --- a/synapse/federation/persistence.py +++ b/synapse/federation/persistence.py @@ -21,7 +21,9 @@ These actions are mostly only used by the :py:mod:`.replication` module. import logging +from synapse.federation.units import Transaction from synapse.logging.utils import log_function +from synapse.types import JsonDict logger = logging.getLogger(__name__) @@ -49,15 +51,15 @@ class TransactionActions(object): return self.store.get_received_txn_response(transaction.transaction_id, origin) @log_function - def set_response(self, origin, transaction, code, response): + async def set_response( + self, origin: str, transaction: Transaction, code: int, response: JsonDict + ) -> None: """ Persist how we responded to a transaction. - - Returns: - Deferred """ - if not transaction.transaction_id: + transaction_id = transaction.transaction_id # type: ignore + if not transaction_id: raise RuntimeError("Cannot persist a transaction with no transaction_id") - return self.store.set_received_txn_response( - transaction.transaction_id, origin, code, response + await self.store.set_received_txn_response( + transaction_id, origin, code, response ) diff --git a/synapse/federation/units.py b/synapse/federation/units.py index 6b32e0dcbf..64d98fc8f6 100644 --- a/synapse/federation/units.py +++ b/synapse/federation/units.py @@ -107,9 +107,7 @@ class Transaction(JsonEncodedObject): if "edus" in kwargs and not kwargs["edus"]: del kwargs["edus"] - super(Transaction, self).__init__( - transaction_id=transaction_id, pdus=pdus, **kwargs - ) + super().__init__(transaction_id=transaction_id, pdus=pdus, **kwargs) @staticmethod def create_new(pdus, **kwargs): diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 77723f7d4d..92f56f1602 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -161,16 +161,14 @@ class ApplicationServiceTransactionWorkerStore( return result.get("state") return None - def set_appservice_state(self, service, state): + async def set_appservice_state(self, service, state) -> None: """Set the application service state. Args: service(ApplicationService): The service whose state to set. state(ApplicationServiceState): The connectivity state to apply. - Returns: - An Awaitable which resolves when the state was set successfully. """ - return self.db_pool.simple_upsert( + await self.db_pool.simple_upsert( "application_services_state", {"as_id": service.id}, {"state": state} ) diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index a811a39eb5..ecd3f3b310 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -716,11 +716,11 @@ class DeviceWorkerStore(SQLBaseStore): return {row["user_id"] for row in rows} - def mark_remote_user_device_cache_as_stale(self, user_id: str): + async def mark_remote_user_device_cache_as_stale(self, user_id: str) -> None: """Records that the server has reason to believe the cache of the devices for the remote users is out of date. """ - return self.db_pool.simple_upsert( + await self.db_pool.simple_upsert( table="device_lists_remote_resync", keyvalues={"user_id": user_id}, values={}, diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py index e3ead71853..8acf254bf3 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py @@ -742,7 +742,13 @@ class GroupServerStore(GroupServerWorkerStore): desc="remove_room_from_summary", ) - def upsert_group_category(self, group_id, category_id, profile, is_public): + async def upsert_group_category( + self, + group_id: str, + category_id: str, + profile: Optional[JsonDict], + is_public: Optional[bool], + ) -> None: """Add/update room category for group """ insertion_values = {} @@ -758,7 +764,7 @@ class GroupServerStore(GroupServerWorkerStore): else: update_values["is_public"] = is_public - return self.db_pool.simple_upsert( + await self.db_pool.simple_upsert( table="group_room_categories", keyvalues={"group_id": group_id, "category_id": category_id}, values=update_values, @@ -773,7 +779,13 @@ class GroupServerStore(GroupServerWorkerStore): desc="remove_group_category", ) - def upsert_group_role(self, group_id, role_id, profile, is_public): + async def upsert_group_role( + self, + group_id: str, + role_id: str, + profile: Optional[JsonDict], + is_public: Optional[bool], + ) -> None: """Add/remove user role """ insertion_values = {} @@ -789,7 +801,7 @@ class GroupServerStore(GroupServerWorkerStore): else: update_values["is_public"] = is_public - return self.db_pool.simple_upsert( + await self.db_pool.simple_upsert( table="group_roles", keyvalues={"group_id": group_id, "role_id": role_id}, values=update_values, @@ -938,10 +950,10 @@ class GroupServerStore(GroupServerWorkerStore): desc="remove_user_from_summary", ) - def add_group_invite(self, group_id, user_id): + async def add_group_invite(self, group_id: str, user_id: str) -> None: """Record that the group server has invited a user """ - return self.db_pool.simple_insert( + await self.db_pool.simple_insert( table="group_invites", values={"group_id": group_id, "user_id": user_id}, desc="add_group_invite", @@ -1044,8 +1056,10 @@ class GroupServerStore(GroupServerWorkerStore): "remove_user_from_group", _remove_user_from_group_txn ) - def add_room_to_group(self, group_id, room_id, is_public): - return self.db_pool.simple_insert( + async def add_room_to_group( + self, group_id: str, room_id: str, is_public: bool + ) -> None: + await self.db_pool.simple_insert( table="group_rooms", values={"group_id": group_id, "room_id": room_id, "is_public": is_public}, desc="add_room_to_group", diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py index fadcad51e7..1c0a049c55 100644 --- a/synapse/storage/databases/main/keys.py +++ b/synapse/storage/databases/main/keys.py @@ -140,22 +140,28 @@ class KeyStore(SQLBaseStore): for i in invalidations: invalidate((i,)) - def store_server_keys_json( - self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes - ): + async def store_server_keys_json( + self, + server_name: str, + key_id: str, + from_server: str, + ts_now_ms: int, + ts_expires_ms: int, + key_json_bytes: bytes, + ) -> None: """Stores the JSON bytes for a set of keys from a server The JSON should be signed by the originating server, the intermediate server, and by this server. Updates the value for the (server_name, key_id, from_server) triplet if one already existed. Args: - server_name (str): The name of the server. - key_id (str): The identifer of the key this JSON is for. - from_server (str): The server this JSON was fetched from. - ts_now_ms (int): The time now in milliseconds. - ts_valid_until_ms (int): The time when this json stops being valid. - key_json (bytes): The encoded JSON. + server_name: The name of the server. + key_id: The identifer of the key this JSON is for. + from_server: The server this JSON was fetched from. + ts_now_ms: The time now in milliseconds. + ts_valid_until_ms: The time when this json stops being valid. + key_json_bytes: The encoded JSON. """ - return self.db_pool.simple_upsert( + await self.db_pool.simple_upsert( table="server_keys_json", keyvalues={ "server_name": server_name, diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index 8361dd63d9..3919ecad69 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -60,7 +60,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): desc="get_local_media", ) - def store_local_media( + async def store_local_media( self, media_id, media_type, @@ -69,8 +69,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): media_length, user_id, url_cache=None, - ): - return self.db_pool.simple_insert( + ) -> None: + await self.db_pool.simple_insert( "local_media_repository", { "media_id": media_id, @@ -141,10 +141,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): return self.db_pool.runInteraction("get_url_cache", get_url_cache_txn) - def store_url_cache( + async def store_url_cache( self, url, response_code, etag, expires_ts, og, media_id, download_ts ): - return self.db_pool.simple_insert( + await self.db_pool.simple_insert( "local_media_repository_url_cache", { "url": url, @@ -172,7 +172,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): desc="get_local_media_thumbnails", ) - def store_local_thumbnail( + async def store_local_thumbnail( self, media_id, thumbnail_width, @@ -181,7 +181,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): thumbnail_method, thumbnail_length, ): - return self.db_pool.simple_insert( + await self.db_pool.simple_insert( "local_media_repository_thumbnails", { "media_id": media_id, @@ -212,7 +212,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): desc="get_cached_remote_media", ) - def store_cached_remote_media( + async def store_cached_remote_media( self, origin, media_id, @@ -222,7 +222,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): upload_name, filesystem_id, ): - return self.db_pool.simple_insert( + await self.db_pool.simple_insert( "remote_media_cache", { "media_origin": origin, @@ -288,7 +288,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): desc="get_remote_media_thumbnails", ) - def store_remote_media_thumbnail( + async def store_remote_media_thumbnail( self, origin, media_id, @@ -299,7 +299,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): thumbnail_method, thumbnail_length, ): - return self.db_pool.simple_insert( + await self.db_pool.simple_insert( "remote_media_cache_thumbnails", { "media_origin": origin, diff --git a/synapse/storage/databases/main/openid.py b/synapse/storage/databases/main/openid.py index dcd1ff911a..4db8949da7 100644 --- a/synapse/storage/databases/main/openid.py +++ b/synapse/storage/databases/main/openid.py @@ -2,8 +2,10 @@ from synapse.storage._base import SQLBaseStore class OpenIdStore(SQLBaseStore): - def insert_open_id_token(self, token, ts_valid_until_ms, user_id): - return self.db_pool.simple_insert( + async def insert_open_id_token( + self, token: str, ts_valid_until_ms: int, user_id: str + ) -> None: + await self.db_pool.simple_insert( table="open_id_tokens", values={ "token": token, diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py index 858fd92420..301875a672 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py @@ -66,8 +66,8 @@ class ProfileWorkerStore(SQLBaseStore): desc="get_from_remote_profile_cache", ) - def create_profile(self, user_localpart): - return self.db_pool.simple_insert( + async def create_profile(self, user_localpart: str) -> None: + await self.db_pool.simple_insert( table="profiles", values={"user_id": user_localpart}, desc="create_profile" ) @@ -93,13 +93,15 @@ class ProfileWorkerStore(SQLBaseStore): class ProfileStore(ProfileWorkerStore): - def add_remote_profile_cache(self, user_id, displayname, avatar_url): + async def add_remote_profile_cache( + self, user_id: str, displayname: str, avatar_url: str + ) -> None: """Ensure we are caching the remote user's profiles. This should only be called when `is_subscribed_remote_profile_for_user` would return true for the user. """ - return self.db_pool.simple_upsert( + await self.db_pool.simple_upsert( table="remote_profile_cache", keyvalues={"user_id": user_id}, values={ diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 48bda66f3e..28f7ae0430 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -17,7 +17,7 @@ import logging import re -from typing import Any, Awaitable, Dict, List, Optional +from typing import Any, Dict, List, Optional from synapse.api.constants import UserTypes from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError @@ -549,23 +549,22 @@ class RegistrationWorkerStore(SQLBaseStore): desc="user_delete_threepids", ) - def add_user_bound_threepid(self, user_id, medium, address, id_server): + async def add_user_bound_threepid( + self, user_id: str, medium: str, address: str, id_server: str + ): """The server proxied a bind request to the given identity server on behalf of the given user. We need to remember this in case the user asks us to unbind the threepid. Args: - user_id (str) - medium (str) - address (str) - id_server (str) - - Returns: - Awaitable + user_id + medium + address + id_server """ # We need to use an upsert, in case they user had already bound the # threepid - return self.db_pool.simple_upsert( + await self.db_pool.simple_upsert( table="user_threepid_id_server", keyvalues={ "user_id": user_id, @@ -1083,9 +1082,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) - def record_user_external_id( + async def record_user_external_id( self, auth_provider: str, external_id: str, user_id: str - ) -> Awaitable: + ) -> None: """Record a mapping from an external user id to a mxid Args: @@ -1093,7 +1092,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): external_id: id on that system user_id: complete mxid that it is mapped to """ - return self.db_pool.simple_insert( + await self.db_pool.simple_insert( table="user_external_ids", values={ "auth_provider": auth_provider, @@ -1237,12 +1236,12 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): return res if res else False - def add_user_pending_deactivation(self, user_id): + async def add_user_pending_deactivation(self, user_id: str) -> None: """ Adds a user to the table of users who need to be parted from all the rooms they're in """ - return self.db_pool.simple_insert( + await self.db_pool.simple_insert( "users_pending_deactivation", values={"user_id": user_id}, desc="add_user_pending_deactivation", diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 66d7135413..a92641c339 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -27,7 +27,7 @@ from synapse.api.room_versions import RoomVersion, RoomVersions from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.search import SearchStore -from synapse.types import ThirdPartyInstanceID +from synapse.types import JsonDict, ThirdPartyInstanceID from synapse.util import json_encoder from synapse.util.caches.descriptors import cached @@ -1296,11 +1296,17 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): return self.db_pool.runInteraction("get_rooms", f) - def add_event_report( - self, room_id, event_id, user_id, reason, content, received_ts - ): + async def add_event_report( + self, + room_id: str, + event_id: str, + user_id: str, + reason: str, + content: JsonDict, + received_ts: int, + ) -> None: next_id = self._event_reports_id_gen.get_next() - return self.db_pool.simple_insert( + await self.db_pool.simple_insert( table="event_reports", values={ "id": next_id, diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py index 9fe97af56a..7af2608ca4 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py @@ -16,7 +16,7 @@ import logging from itertools import chain -from typing import Tuple +from typing import Any, Dict, Tuple from twisted.internet.defer import DeferredLock @@ -222,11 +222,11 @@ class StatsStore(StateDeltasStore): desc="stats_incremental_position", ) - def update_room_state(self, room_id, fields): + async def update_room_state(self, room_id: str, fields: Dict[str, Any]) -> None: """ Args: - room_id (str) - fields (dict[str:Any]) + room_id + fields """ # For whatever reason some of the fields may contain null bytes, which @@ -244,7 +244,7 @@ class StatsStore(StateDeltasStore): if field and "\0" in field: fields[col] = None - return self.db_pool.simple_upsert( + await self.db_pool.simple_upsert( table="room_stats_state", keyvalues={"room_id": room_id}, values=fields, diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index 52668dbdf9..2efcc0dc66 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -21,6 +21,7 @@ from canonicaljson import encode_canonical_json from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool +from synapse.types import JsonDict from synapse.util.caches.expiringcache import ExpiringCache db_binary_type = memoryview @@ -98,20 +99,21 @@ class TransactionStore(SQLBaseStore): else: return None - def set_received_txn_response(self, transaction_id, origin, code, response_dict): - """Persist the response we returened for an incoming transaction, and + async def set_received_txn_response( + self, transaction_id: str, origin: str, code: int, response_dict: JsonDict + ) -> None: + """Persist the response we returned for an incoming transaction, and should return for subsequent transactions with the same transaction_id and origin. Args: - txn - transaction_id (str) - origin (str) - code (int) - response_json (str) + transaction_id: The incoming transaction ID. + origin: The origin server. + code: The response code. + response_dict: The response, to be encoded into JSON. """ - return self.db_pool.simple_insert( + await self.db_pool.simple_insert( table="received_transactions", values={ "transaction_id": transaction_id, -- cgit 1.5.1 From b71d4a094c4370d0229fbd4545b85c049364ecf3 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 27 Aug 2020 14:16:41 -0400 Subject: Convert simple_delete to async/await. (#8191) --- changelog.d/8191.misc | 1 + synapse/storage/database.py | 63 ++++++++++++++++++++++---- synapse/storage/databases/main/group_server.py | 28 +++++++----- synapse/storage/databases/main/registration.py | 29 ++++++------ tests/storage/test_event_push_actions.py | 6 ++- 5 files changed, 90 insertions(+), 37 deletions(-) create mode 100644 changelog.d/8191.misc (limited to 'synapse/storage/databases/main/registration.py') diff --git a/changelog.d/8191.misc b/changelog.d/8191.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8191.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/storage/database.py b/synapse/storage/database.py index ba4c0c9af6..7ab370efef 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -614,6 +614,7 @@ class DatabasePool(object): """Runs a single query for a result set. Args: + desc: description of the transaction, for logging and metrics decoder - The function which can resolve the cursor results to something meaningful. query - The query string to execute @@ -649,7 +650,7 @@ class DatabasePool(object): or_ignore: bool stating whether an exception should be raised when a conflicting row already exists. If True, False will be returned by the function instead - desc: string giving a description of the transaction + desc: description of the transaction, for logging and metrics Returns: Whether the row was inserted or not. Only useful when `or_ignore` is True @@ -686,7 +687,7 @@ class DatabasePool(object): Args: table: string giving the table name values: dict of new column names and values for them - desc: string giving a description of the transaction + desc: description of the transaction, for logging and metrics """ await self.runInteraction(desc, self.simple_insert_many_txn, table, values) @@ -700,7 +701,6 @@ class DatabasePool(object): txn: The transaction to use. table: string giving the table name values: dict of new column names and values for them - desc: string giving a description of the transaction """ if not values: return @@ -755,6 +755,7 @@ class DatabasePool(object): keyvalues: The unique key columns and their new values values: The nonunique columns and their new values insertion_values: additional key/values to use only when inserting + desc: description of the transaction, for logging and metrics lock: True to lock the table when doing the upsert. Returns: Native upserts always return None. Emulated upserts return True if a @@ -1081,6 +1082,7 @@ class DatabasePool(object): retcols: list of strings giving the names of the columns to return allow_none: If true, return None instead of failing if the SELECT statement returns no rows + desc: description of the transaction, for logging and metrics """ return await self.runInteraction( desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none @@ -1166,6 +1168,7 @@ class DatabasePool(object): table: table name keyvalues: column names and values to select the rows with retcol: column whos value we wish to retrieve. + desc: description of the transaction, for logging and metrics Returns: Results in a list @@ -1190,6 +1193,7 @@ class DatabasePool(object): column names and values to select the rows with, or None to not apply a WHERE clause. retcols: the names of the columns to return + desc: description of the transaction, for logging and metrics Returns: A list of dictionaries. @@ -1243,14 +1247,16 @@ class DatabasePool(object): """Executes a SELECT query on the named table, which may return zero or more rows, returning the result as a list of dicts. - Filters rows by if value of `column` is in `iterable`. + Filters rows by whether the value of `column` is in `iterable`. Args: table: string giving the table name column: column name to test for inclusion against `iterable` iterable: list - keyvalues: dict of column names and values to select the rows with retcols: list of strings giving the names of the columns to return + keyvalues: dict of column names and values to select the rows with + desc: description of the transaction, for logging and metrics + batch_size: the number of rows for each select query """ results = [] # type: List[Dict[str, Any]] @@ -1291,7 +1297,7 @@ class DatabasePool(object): """Executes a SELECT query on the named table, which may return zero or more rows, returning the result as a list of dicts. - Filters rows by if value of `column` is in `iterable`. + Filters rows by whether the value of `column` is in `iterable`. Args: txn: Transaction object @@ -1367,6 +1373,7 @@ class DatabasePool(object): table: string giving the table name keyvalues: dict of column names and values to select the row with updatevalues: dict giving column names and values to update + desc: description of the transaction, for logging and metrics """ await self.runInteraction( desc, self.simple_update_one_txn, table, keyvalues, updatevalues @@ -1426,6 +1433,7 @@ class DatabasePool(object): Args: table: string giving the table name keyvalues: dict of column names and values to select the row with + desc: description of the transaction, for logging and metrics """ await self.runInteraction(desc, self.simple_delete_one_txn, table, keyvalues) @@ -1451,13 +1459,38 @@ class DatabasePool(object): if txn.rowcount > 1: raise StoreError(500, "More than one row matched (%s)" % (table,)) - def simple_delete(self, table: str, keyvalues: Dict[str, Any], desc: str): - return self.runInteraction(desc, self.simple_delete_txn, table, keyvalues) + async def simple_delete( + self, table: str, keyvalues: Dict[str, Any], desc: str + ) -> int: + """Executes a DELETE query on the named table. + + Filters rows by the key-value pairs. + + Args: + table: string giving the table name + keyvalues: dict of column names and values to select the row with + desc: description of the transaction, for logging and metrics + + Returns: + The number of deleted rows. + """ + return await self.runInteraction(desc, self.simple_delete_txn, table, keyvalues) @staticmethod def simple_delete_txn( txn: LoggingTransaction, table: str, keyvalues: Dict[str, Any] ) -> int: + """Executes a DELETE query on the named table. + + Filters rows by the key-value pairs. + + Args: + table: string giving the table name + keyvalues: dict of column names and values to select the row with + + Returns: + The number of deleted rows. + """ sql = "DELETE FROM %s WHERE %s" % ( table, " AND ".join("%s = ?" % (k,) for k in keyvalues), @@ -1474,6 +1507,20 @@ class DatabasePool(object): keyvalues: Dict[str, Any], desc: str, ) -> int: + """Executes a DELETE query on the named table. + + Filters rows by if value of `column` is in `iterable`. + + Args: + table: string giving the table name + column: column name to test for inclusion against `iterable` + iterable: list + keyvalues: dict of column names and values to select the rows with + desc: description of the transaction, for logging and metrics + + Returns: + Number rows deleted + """ return await self.runInteraction( desc, self.simple_delete_many_txn, table, column, iterable, keyvalues ) diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py index 8acf254bf3..6c60171888 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py @@ -728,11 +728,13 @@ class GroupServerStore(GroupServerWorkerStore): }, ) - def remove_room_from_summary(self, group_id, room_id, category_id): + async def remove_room_from_summary( + self, group_id: str, room_id: str, category_id: str + ) -> int: if category_id is None: category_id = _DEFAULT_CATEGORY_ID - return self.db_pool.simple_delete( + return await self.db_pool.simple_delete( table="group_summary_rooms", keyvalues={ "group_id": group_id, @@ -772,8 +774,8 @@ class GroupServerStore(GroupServerWorkerStore): desc="upsert_group_category", ) - def remove_group_category(self, group_id, category_id): - return self.db_pool.simple_delete( + async def remove_group_category(self, group_id: str, category_id: str) -> int: + return await self.db_pool.simple_delete( table="group_room_categories", keyvalues={"group_id": group_id, "category_id": category_id}, desc="remove_group_category", @@ -809,8 +811,8 @@ class GroupServerStore(GroupServerWorkerStore): desc="upsert_group_role", ) - def remove_group_role(self, group_id, role_id): - return self.db_pool.simple_delete( + async def remove_group_role(self, group_id: str, role_id: str) -> int: + return await self.db_pool.simple_delete( table="group_roles", keyvalues={"group_id": group_id, "role_id": role_id}, desc="remove_group_role", @@ -940,11 +942,13 @@ class GroupServerStore(GroupServerWorkerStore): }, ) - def remove_user_from_summary(self, group_id, user_id, role_id): + async def remove_user_from_summary( + self, group_id: str, user_id: str, role_id: str + ) -> int: if role_id is None: role_id = _DEFAULT_ROLE_ID - return self.db_pool.simple_delete( + return await self.db_pool.simple_delete( table="group_summary_users", keyvalues={"group_id": group_id, "role_id": role_id, "user_id": user_id}, desc="remove_user_from_summary", @@ -1264,16 +1268,16 @@ class GroupServerStore(GroupServerWorkerStore): desc="update_remote_attestion", ) - def remove_attestation_renewal(self, group_id, user_id): + async def remove_attestation_renewal(self, group_id: str, user_id: str) -> int: """Remove an attestation that we thought we should renew, but actually shouldn't. Ideally this would never get called as we would never incorrectly try and do attestations for local users on local groups. Args: - group_id (str) - user_id (str) + group_id + user_id """ - return self.db_pool.simple_delete( + return await self.db_pool.simple_delete( table="group_attestations_renewals", keyvalues={"group_id": group_id, "user_id": user_id}, desc="remove_attestation_renewal", diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 28f7ae0430..12689f4308 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -529,21 +529,21 @@ class RegistrationWorkerStore(SQLBaseStore): "user_get_threepids", ) - def user_delete_threepid(self, user_id, medium, address): - return self.db_pool.simple_delete( + async def user_delete_threepid(self, user_id, medium, address) -> None: + await self.db_pool.simple_delete( "user_threepids", keyvalues={"user_id": user_id, "medium": medium, "address": address}, desc="user_delete_threepid", ) - def user_delete_threepids(self, user_id: str): + async def user_delete_threepids(self, user_id: str) -> None: """Delete all threepid this user has bound Args: user_id: The user id to delete all threepids of """ - return self.db_pool.simple_delete( + await self.db_pool.simple_delete( "user_threepids", keyvalues={"user_id": user_id}, desc="user_delete_threepids", @@ -597,21 +597,20 @@ class RegistrationWorkerStore(SQLBaseStore): desc="user_get_bound_threepids", ) - def remove_user_bound_threepid(self, user_id, medium, address, id_server): + async def remove_user_bound_threepid( + self, user_id: str, medium: str, address: str, id_server: str + ) -> None: """The server proxied an unbind request to the given identity server on behalf of the given user, so we remove the mapping of threepid to identity server. Args: - user_id (str) - medium (str) - address (str) - id_server (str) - - Returns: - Deferred + user_id + medium + address + id_server """ - return self.db_pool.simple_delete( + await self.db_pool.simple_delete( table="user_threepid_id_server", keyvalues={ "user_id": user_id, @@ -1247,14 +1246,14 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): desc="add_user_pending_deactivation", ) - def del_user_pending_deactivation(self, user_id): + async def del_user_pending_deactivation(self, user_id: str) -> None: """ Removes the given user to the table of users who need to be parted from all the rooms they're in, effectively marking that user as fully deactivated. """ # XXX: This should be simple_delete_one but we failed to put a unique index on # the table, so somehow duplicate entries have ended up in it. - return self.db_pool.simple_delete( + await self.db_pool.simple_delete( "users_pending_deactivation", keyvalues={"user_id": user_id}, desc="del_user_pending_deactivation", diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index 238bad5b45..0e7427e57a 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -123,8 +123,10 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): yield _inject_actions(6, PlAIN_NOTIF) yield _rotate(7) - yield self.store.db_pool.simple_delete( - table="event_push_actions", keyvalues={"1": 1}, desc="" + yield defer.ensureDeferred( + self.store.db_pool.simple_delete( + table="event_push_actions", keyvalues={"1": 1}, desc="" + ) ) yield _assert_counts(1, 0) -- cgit 1.5.1 From d58fda99ffd29c15254788476fb8775370eebec5 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 28 Aug 2020 11:34:50 -0400 Subject: Convert `event_push_actions`, `registration`, and `roommember` datastores to async (#8197) --- changelog.d/8197.misc | 1 + .../storage/databases/main/event_push_actions.py | 38 ++-- synapse/storage/databases/main/registration.py | 238 +++++++++++---------- synapse/storage/databases/main/roommember.py | 52 ++--- 4 files changed, 169 insertions(+), 160 deletions(-) create mode 100644 changelog.d/8197.misc (limited to 'synapse/storage/databases/main/registration.py') diff --git a/changelog.d/8197.misc b/changelog.d/8197.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8197.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index e8834b2162..384c39f4d7 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -15,7 +15,7 @@ # limitations under the License. import logging -from typing import List +from typing import Dict, List, Union from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json @@ -383,19 +383,20 @@ class EventPushActionsWorkerStore(SQLBaseStore): # Now return the first `limit` return notifs[:limit] - def get_if_maybe_push_in_range_for_user(self, user_id, min_stream_ordering): + async def get_if_maybe_push_in_range_for_user( + self, user_id: str, min_stream_ordering: int + ) -> bool: """A fast check to see if there might be something to push for the user since the given stream ordering. May return false positives. Useful to know whether to bother starting a pusher on start up or not. Args: - user_id (str) - min_stream_ordering (int) + user_id + min_stream_ordering Returns: - Deferred[bool]: True if there may be push to process, False if - there definitely isn't. + True if there may be push to process, False if there definitely isn't. """ def _get_if_maybe_push_in_range_for_user_txn(txn): @@ -408,22 +409,20 @@ class EventPushActionsWorkerStore(SQLBaseStore): txn.execute(sql, (user_id, min_stream_ordering)) return bool(txn.fetchone()) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_if_maybe_push_in_range_for_user", _get_if_maybe_push_in_range_for_user_txn, ) - async def add_push_actions_to_staging(self, event_id, user_id_actions): + async def add_push_actions_to_staging( + self, event_id: str, user_id_actions: Dict[str, List[Union[dict, str]]] + ) -> None: """Add the push actions for the event to the push action staging area. Args: - event_id (str) - user_id_actions (dict[str, list[dict|str])]): A dictionary mapping - user_id to list of push actions, where an action can either be - a string or dict. - - Returns: - Deferred + event_id + user_id_actions: A mapping of user_id to list of push actions, where + an action can either be a string or dict. """ if not user_id_actions: @@ -507,7 +506,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): "Found stream ordering 1 day ago: it's %d", self.stream_ordering_day_ago ) - def find_first_stream_ordering_after_ts(self, ts): + async def find_first_stream_ordering_after_ts(self, ts: int) -> int: """Gets the stream ordering corresponding to a given timestamp. Specifically, finds the stream_ordering of the first event that was @@ -516,13 +515,12 @@ class EventPushActionsWorkerStore(SQLBaseStore): relatively slow. Args: - ts (int): timestamp in millis + ts: timestamp in millis Returns: - Deferred[int]: stream ordering of the first event received on/after - the timestamp + stream ordering of the first event received on/after the timestamp """ - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "_find_first_stream_ordering_after_ts_txn", self._find_first_stream_ordering_after_ts_txn, ts, diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 12689f4308..000b8ccff4 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -17,7 +17,7 @@ import logging import re -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple from synapse.api.constants import UserTypes from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError @@ -84,17 +84,17 @@ class RegistrationWorkerStore(SQLBaseStore): return is_trial @cached() - def get_user_by_access_token(self, token): + async def get_user_by_access_token(self, token: str) -> Optional[dict]: """Get a user from the given access token. Args: - token (str): The access token of a user. + token: The access token of a user. Returns: - defer.Deferred: None, if the token did not match, otherwise dict - including the keys `name`, `is_guest`, `device_id`, `token_id`, - `valid_until_ms`. + None, if the token did not match, otherwise dict + including the keys `name`, `is_guest`, `device_id`, `token_id`, + `valid_until_ms`. """ - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_user_by_access_token", self._query_for_auth, token ) @@ -281,13 +281,12 @@ class RegistrationWorkerStore(SQLBaseStore): return bool(res) if res else False - def set_server_admin(self, user, admin): + async def set_server_admin(self, user: UserID, admin: bool) -> None: """Sets whether a user is an admin of this homeserver. Args: - user (UserID): user ID of the user to test - admin (bool): true iff the user is to be a server admin, - false otherwise. + user: user ID of the user to test + admin: true iff the user is to be a server admin, false otherwise. """ def set_server_admin_txn(txn): @@ -298,7 +297,7 @@ class RegistrationWorkerStore(SQLBaseStore): txn, self.get_user_by_id, (user.to_string(),) ) - return self.db_pool.runInteraction("set_server_admin", set_server_admin_txn) + await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn) def _query_for_auth(self, txn, token): sql = ( @@ -364,9 +363,11 @@ class RegistrationWorkerStore(SQLBaseStore): ) return True if res == UserTypes.SUPPORT else False - def get_users_by_id_case_insensitive(self, user_id): + async def get_users_by_id_case_insensitive(self, user_id: str) -> Dict[str, str]: """Gets users that match user_id case insensitively. - Returns a mapping of user_id -> password_hash. + + Returns: + A mapping of user_id -> password_hash. """ def f(txn): @@ -374,7 +375,7 @@ class RegistrationWorkerStore(SQLBaseStore): txn.execute(sql, (user_id,)) return dict(txn) - return self.db_pool.runInteraction("get_users_by_id_case_insensitive", f) + return await self.db_pool.runInteraction("get_users_by_id_case_insensitive", f) async def get_user_by_external_id( self, auth_provider: str, external_id: str @@ -408,7 +409,7 @@ class RegistrationWorkerStore(SQLBaseStore): return await self.db_pool.runInteraction("count_users", _count_users) - def count_daily_user_type(self): + async def count_daily_user_type(self) -> Dict[str, int]: """ Counts 1) native non guest users 2) native guests users @@ -437,7 +438,7 @@ class RegistrationWorkerStore(SQLBaseStore): results[row[0]] = row[1] return results - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "count_daily_user_type", _count_daily_user_type ) @@ -663,24 +664,29 @@ class RegistrationWorkerStore(SQLBaseStore): # Convert the integer into a boolean. return res == 1 - def get_threepid_validation_session( - self, medium, client_secret, address=None, sid=None, validated=True - ): + async def get_threepid_validation_session( + self, + medium: Optional[str], + client_secret: str, + address: Optional[str] = None, + sid: Optional[str] = None, + validated: Optional[bool] = True, + ) -> Optional[Dict[str, Any]]: """Gets a session_id and last_send_attempt (if available) for a combination of validation metadata Args: - medium (str|None): The medium of the 3PID - address (str|None): The address of the 3PID - sid (str|None): The ID of the validation session - client_secret (str): A unique string provided by the client to help identify this + medium: The medium of the 3PID + client_secret: A unique string provided by the client to help identify this validation attempt - validated (bool|None): Whether sessions should be filtered by + address: The address of the 3PID + sid: The ID of the validation session + validated: Whether sessions should be filtered by whether they have been validated already or not. None to perform no filtering Returns: - Deferred[dict|None]: A dict containing the following: + A dict containing the following: * address - address of the 3pid * medium - medium of the 3pid * client_secret - a secret provided by the client for this validation session @@ -726,17 +732,17 @@ class RegistrationWorkerStore(SQLBaseStore): return rows[0] - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_threepid_validation_session", get_threepid_validation_session_txn ) - def delete_threepid_session(self, session_id): + async def delete_threepid_session(self, session_id: str) -> None: """Removes a threepid validation session from the database. This can be done after validation has been performed and whatever action was waiting on it has been carried out Args: - session_id (str): The ID of the session to delete + session_id: The ID of the session to delete """ def delete_threepid_session_txn(txn): @@ -751,7 +757,7 @@ class RegistrationWorkerStore(SQLBaseStore): keyvalues={"session_id": session_id}, ) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "delete_threepid_session", delete_threepid_session_txn ) @@ -941,43 +947,40 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): desc="add_access_token_to_user", ) - def register_user( + async def register_user( self, - user_id, - password_hash=None, - was_guest=False, - make_guest=False, - appservice_id=None, - create_profile_with_displayname=None, - admin=False, - user_type=None, - shadow_banned=False, - ): + user_id: str, + password_hash: Optional[str] = None, + was_guest: bool = False, + make_guest: bool = False, + appservice_id: Optional[str] = None, + create_profile_with_displayname: Optional[str] = None, + admin: bool = False, + user_type: Optional[str] = None, + shadow_banned: bool = False, + ) -> None: """Attempts to register an account. Args: - user_id (str): The desired user ID to register. - password_hash (str|None): Optional. The password hash for this user. - was_guest (bool): Optional. Whether this is a guest account being - upgraded to a non-guest account. - make_guest (boolean): True if the the new user should be guest, - false to add a regular user account. - appservice_id (str): The ID of the appservice registering the user. - create_profile_with_displayname (unicode): Optionally create a profile for + user_id: The desired user ID to register. + password_hash: Optional. The password hash for this user. + was_guest: Whether this is a guest account being upgraded to a + non-guest account. + make_guest: True if the the new user should be guest, false to add a + regular user account. + appservice_id: The ID of the appservice registering the user. + create_profile_with_displayname: Optionally create a profile for the user, setting their displayname to the given value - admin (boolean): is an admin user? - user_type (str|None): type of user. One of the values from - api.constants.UserTypes, or None for a normal user. - shadow_banned (bool): Whether the user is shadow-banned, - i.e. they may be told their requests succeeded but we ignore them. + admin: is an admin user? + user_type: type of user. One of the values from api.constants.UserTypes, + or None for a normal user. + shadow_banned: Whether the user is shadow-banned, i.e. they may be + told their requests succeeded but we ignore them. Raises: StoreError if the user_id could not be registered. - - Returns: - Deferred """ - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "register_user", self._register_user, user_id, @@ -1101,7 +1104,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): desc="record_user_external_id", ) - def user_set_password_hash(self, user_id, password_hash): + async def user_set_password_hash(self, user_id: str, password_hash: str) -> None: """ NB. This does *not* evict any cache because the one use for this removes most of the entries subsequently anyway so it would be @@ -1114,17 +1117,18 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): ) self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "user_set_password_hash", user_set_password_hash_txn ) - def user_set_consent_version(self, user_id, consent_version): + async def user_set_consent_version( + self, user_id: str, consent_version: str + ) -> None: """Updates the user table to record privacy policy consent Args: - user_id (str): full mxid of the user to update - consent_version (str): version of the policy the user has consented - to + user_id: full mxid of the user to update + consent_version: version of the policy the user has consented to Raises: StoreError(404) if user not found @@ -1139,16 +1143,17 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): ) self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) - return self.db_pool.runInteraction("user_set_consent_version", f) + await self.db_pool.runInteraction("user_set_consent_version", f) - def user_set_consent_server_notice_sent(self, user_id, consent_version): + async def user_set_consent_server_notice_sent( + self, user_id: str, consent_version: str + ) -> None: """Updates the user table to record that we have sent the user a server notice about privacy policy consent Args: - user_id (str): full mxid of the user to update - consent_version (str): version of the policy we have notified the - user about + user_id: full mxid of the user to update + consent_version: version of the policy we have notified the user about Raises: StoreError(404) if user not found @@ -1163,22 +1168,25 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): ) self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) - return self.db_pool.runInteraction("user_set_consent_server_notice_sent", f) + await self.db_pool.runInteraction("user_set_consent_server_notice_sent", f) - def user_delete_access_tokens(self, user_id, except_token_id=None, device_id=None): + async def user_delete_access_tokens( + self, + user_id: str, + except_token_id: Optional[str] = None, + device_id: Optional[str] = None, + ) -> List[Tuple[str, int, Optional[str]]]: """ Invalidate access tokens belonging to a user Args: - user_id (str): ID of user the tokens belong to - except_token_id (str): list of access_tokens IDs which should - *not* be deleted - device_id (str|None): ID of device the tokens are associated with. + user_id: ID of user the tokens belong to + except_token_id: access_tokens ID which should *not* be deleted + device_id: ID of device the tokens are associated with. If None, tokens associated with any device (or no device) will be deleted Returns: - defer.Deferred[list[str, int, str|None, int]]: a list of - (token, token id, device id) for each of the deleted tokens + A tuple of (token, token id, device id) for each of the deleted tokens """ def f(txn): @@ -1209,9 +1217,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): return tokens_and_devices - return self.db_pool.runInteraction("user_delete_access_tokens", f) + return await self.db_pool.runInteraction("user_delete_access_tokens", f) - def delete_access_token(self, access_token): + async def delete_access_token(self, access_token: str) -> None: def f(txn): self.db_pool.simple_delete_one_txn( txn, table="access_tokens", keyvalues={"token": access_token} @@ -1221,7 +1229,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): txn, self.get_user_by_access_token, (access_token,) ) - return self.db_pool.runInteraction("delete_access_token", f) + await self.db_pool.runInteraction("delete_access_token", f) @cached() async def is_guest(self, user_id: str) -> bool: @@ -1272,24 +1280,25 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): desc="get_users_pending_deactivation", ) - def validate_threepid_session(self, session_id, client_secret, token, current_ts): + async def validate_threepid_session( + self, session_id: str, client_secret: str, token: str, current_ts: int + ) -> Optional[str]: """Attempt to validate a threepid session using a token Args: - session_id (str): The id of a validation session - client_secret (str): A unique string provided by the client to - help identify this validation attempt - token (str): A validation token - current_ts (int): The current unix time in milliseconds. Used for - checking token expiry status + session_id: The id of a validation session + client_secret: A unique string provided by the client to help identify + this validation attempt + token: A validation token + current_ts: The current unix time in milliseconds. Used for checking + token expiry status Raises: ThreepidValidationError: if a matching validation token was not found or has expired Returns: - deferred str|None: A str representing a link to redirect the user - to if there is one. + A str representing a link to redirect the user to if there is one. """ # Insert everything into a transaction in order to run atomically @@ -1359,36 +1368,35 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): return next_link # Return next_link if it exists - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "validate_threepid_session_txn", validate_threepid_session_txn ) - def start_or_continue_validation_session( + async def start_or_continue_validation_session( self, - medium, - address, - session_id, - client_secret, - send_attempt, - next_link, - token, - token_expires, - ): + medium: str, + address: str, + session_id: str, + client_secret: str, + send_attempt: int, + next_link: Optional[str], + token: str, + token_expires: int, + ) -> None: """Creates a new threepid validation session if it does not already exist and associates a new validation token with it Args: - medium (str): The medium of the 3PID - address (str): The address of the 3PID - session_id (str): The id of this validation session - client_secret (str): A unique string provided by the client to - help identify this validation attempt - send_attempt (int): The latest send_attempt on this session - next_link (str|None): The link to redirect the user to upon - successful validation - token (str): The validation token - token_expires (int): The timestamp for which after the token - will no longer be valid + medium: The medium of the 3PID + address: The address of the 3PID + session_id: The id of this validation session + client_secret: A unique string provided by the client to help + identify this validation attempt + send_attempt: The latest send_attempt on this session + next_link: The link to redirect the user to upon successful validation + token: The validation token + token_expires: The timestamp for which after the token will no + longer be valid """ def start_or_continue_validation_session_txn(txn): @@ -1417,12 +1425,12 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): }, ) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "start_or_continue_validation_session", start_or_continue_validation_session_txn, ) - def cull_expired_threepid_validation_tokens(self): + async def cull_expired_threepid_validation_tokens(self) -> None: """Remove threepid validation tokens with expiry dates that have passed""" def cull_expired_threepid_validation_tokens_txn(txn, ts): @@ -1430,9 +1438,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): DELETE FROM threepid_validation_token WHERE expires < ? """ - return txn.execute(sql, (ts,)) + txn.execute(sql, (ts,)) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "cull_expired_threepid_validation_tokens", cull_expired_threepid_validation_tokens_txn, self.clock.time_msec(), diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 161edbeccb..c4ce5c17c2 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -15,7 +15,7 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Awaitable, Iterable, List, Optional, Set +from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Set from synapse.api.constants import EventTypes, Membership from synapse.events import EventBase @@ -152,8 +152,8 @@ class RoomMemberWorkerStore(EventsWorkerStore): ) @cached(max_entries=100000, iterable=True) - def get_users_in_room(self, room_id: str): - return self.db_pool.runInteraction( + async def get_users_in_room(self, room_id: str) -> List[str]: + return await self.db_pool.runInteraction( "get_users_in_room", self.get_users_in_room_txn, room_id ) @@ -180,14 +180,13 @@ class RoomMemberWorkerStore(EventsWorkerStore): return [r[0] for r in txn] @cached(max_entries=100000) - def get_room_summary(self, room_id: str): + async def get_room_summary(self, room_id: str) -> Dict[str, MemberSummary]: """ Get the details of a room roughly suitable for use by the room summary extension to /sync. Useful when lazy loading room members. Args: room_id: The room ID to query Returns: - Deferred[dict[str, MemberSummary]: - dict of membership states, pointing to a MemberSummary named tuple. + dict of membership states, pointing to a MemberSummary named tuple. """ def _get_room_summary_txn(txn): @@ -261,20 +260,22 @@ class RoomMemberWorkerStore(EventsWorkerStore): return res - return self.db_pool.runInteraction("get_room_summary", _get_room_summary_txn) + return await self.db_pool.runInteraction( + "get_room_summary", _get_room_summary_txn + ) @cached() - def get_invited_rooms_for_local_user(self, user_id: str) -> Awaitable[RoomsForUser]: + async def get_invited_rooms_for_local_user(self, user_id: str) -> RoomsForUser: """Get all the rooms the *local* user is invited to. Args: user_id: The user ID. Returns: - A awaitable list of RoomsForUser. + A list of RoomsForUser. """ - return self.get_rooms_for_local_user_where_membership_is( + return await self.get_rooms_for_local_user_where_membership_is( user_id, [Membership.INVITE] ) @@ -357,7 +358,9 @@ class RoomMemberWorkerStore(EventsWorkerStore): return results @cached(max_entries=500000, iterable=True) - def get_rooms_for_user_with_stream_ordering(self, user_id: str): + async def get_rooms_for_user_with_stream_ordering( + self, user_id: str + ) -> FrozenSet[GetRoomsForUserWithStreamOrdering]: """Returns a set of room_ids the user is currently joined to. If a remote user only returns rooms this server is currently @@ -367,17 +370,18 @@ class RoomMemberWorkerStore(EventsWorkerStore): user_id Returns: - Deferred[frozenset[GetRoomsForUserWithStreamOrdering]]: Returns - the rooms the user is in currently, along with the stream ordering - of the most recent join for that user and room. + Returns the rooms the user is in currently, along with the stream + ordering of the most recent join for that user and room. """ - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_rooms_for_user_with_stream_ordering", self._get_rooms_for_user_with_stream_ordering_txn, user_id, ) - def _get_rooms_for_user_with_stream_ordering_txn(self, txn, user_id: str): + def _get_rooms_for_user_with_stream_ordering_txn( + self, txn, user_id: str + ) -> FrozenSet[GetRoomsForUserWithStreamOrdering]: # We use `current_state_events` here and not `local_current_membership` # as a) this gets called with remote users and b) this only gets called # for rooms the server is participating in. @@ -404,9 +408,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): """ txn.execute(sql, (user_id, Membership.JOIN)) - results = frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn) - - return results + return frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn) async def get_users_server_still_shares_room_with( self, user_ids: Collection[str] @@ -711,14 +713,14 @@ class RoomMemberWorkerStore(EventsWorkerStore): return count == 0 @cached() - def get_forgotten_rooms_for_user(self, user_id: str): + async def get_forgotten_rooms_for_user(self, user_id: str) -> Set[str]: """Gets all rooms the user has forgotten. Args: - user_id + user_id: The user ID to query the rooms of. Returns: - Deferred[set[str]] + The forgotten rooms. """ def _get_forgotten_rooms_for_user_txn(txn): @@ -744,7 +746,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): txn.execute(sql, (user_id,)) return {row[0] for row in txn if row[1] == 0} - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn ) @@ -973,7 +975,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): def __init__(self, database: DatabasePool, db_conn, hs): super(RoomMemberStore, self).__init__(database, db_conn, hs) - def forget(self, user_id: str, room_id: str): + async def forget(self, user_id: str, room_id: str) -> None: """Indicate that user_id wishes to discard history for room_id.""" def f(txn): @@ -994,7 +996,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): txn, self.get_forgotten_rooms_for_user, (user_id,) ) - return self.db_pool.runInteraction("forget_membership", f) + await self.db_pool.runInteraction("forget_membership", f) class _JoinedHostsCache(object): -- cgit 1.5.1 From b4826d6eb18eeb05c973882efd4c0f5d68359449 Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Fri, 28 Aug 2020 17:39:14 +0100 Subject: Fix incorrect return signature --- synapse/storage/databases/main/registration.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/storage/databases/main/registration.py') diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 000b8ccff4..01f20c03c2 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -99,7 +99,7 @@ class RegistrationWorkerStore(SQLBaseStore): ) @cached() - async def get_expiration_ts_for_user(self, user_id: str) -> Optional[None]: + async def get_expiration_ts_for_user(self, user_id: str) -> Optional[int]: """Get the expiration timestamp for the account bearing a given user ID. Args: -- cgit 1.5.1