From a7bdf98d01d2225a479753a85ba81adf02b16a32 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 5 Aug 2020 21:38:57 +0100 Subject: Rename database classes to make some sense (#8033) --- .../storage/databases/main/monthly_active_users.py | 361 +++++++++++++++++++++ 1 file changed, 361 insertions(+) create mode 100644 synapse/storage/databases/main/monthly_active_users.py (limited to 'synapse/storage/databases/main/monthly_active_users.py') diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py new file mode 100644 index 0000000000..02b01d9619 --- /dev/null +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -0,0 +1,361 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from typing import List + +from twisted.internet import defer + +from synapse.storage._base import SQLBaseStore +from synapse.storage.database import DatabasePool, make_in_list_sql_clause +from synapse.util.caches.descriptors import cached + +logger = logging.getLogger(__name__) + +# Number of msec of granularity to store the monthly_active_user timestamp +# This means it is not necessary to update the table on every request +LAST_SEEN_GRANULARITY = 60 * 60 * 1000 + + +class MonthlyActiveUsersWorkerStore(SQLBaseStore): + def __init__(self, database: DatabasePool, db_conn, hs): + super(MonthlyActiveUsersWorkerStore, self).__init__(database, db_conn, hs) + self._clock = hs.get_clock() + self.hs = hs + + @cached(num_args=0) + def get_monthly_active_count(self): + """Generates current count of monthly active users + + Returns: + Defered[int]: Number of current monthly active users + """ + + def _count_users(txn): + sql = "SELECT COALESCE(count(*), 0) FROM monthly_active_users" + txn.execute(sql) + (count,) = txn.fetchone() + return count + + return self.db_pool.runInteraction("count_users", _count_users) + + @cached(num_args=0) + def get_monthly_active_count_by_service(self): + """Generates current count of monthly active users broken down by service. + A service is typically an appservice but also includes native matrix users. + Since the `monthly_active_users` table is populated from the `user_ips` table + `config.track_appservice_user_ips` must be set to `true` for this + method to return anything other than native matrix users. + + Returns: + Deferred[dict]: dict that includes a mapping between app_service_id + and the number of occurrences. + + """ + + def _count_users_by_service(txn): + sql = """ + SELECT COALESCE(appservice_id, 'native'), COALESCE(count(*), 0) + FROM monthly_active_users + LEFT JOIN users ON monthly_active_users.user_id=users.name + GROUP BY appservice_id; + """ + + txn.execute(sql) + result = txn.fetchall() + return dict(result) + + return self.db_pool.runInteraction( + "count_users_by_service", _count_users_by_service + ) + + async def get_registered_reserved_users(self) -> List[str]: + """Of the reserved threepids defined in config, retrieve those that are associated + with registered users + + Returns: + User IDs of actual users that are reserved + """ + users = [] + + for tp in self.hs.config.mau_limits_reserved_threepids[ + : self.hs.config.max_mau_value + ]: + user_id = await self.hs.get_datastore().get_user_id_by_threepid( + tp["medium"], tp["address"] + ) + if user_id: + users.append(user_id) + + return users + + @cached(num_args=1) + def user_last_seen_monthly_active(self, user_id): + """ + Checks if a given user is part of the monthly active user group + Arguments: + user_id (str): user to add/update + Return: + Deferred[int] : timestamp since last seen, None if never seen + + """ + + return self.db_pool.simple_select_one_onecol( + table="monthly_active_users", + keyvalues={"user_id": user_id}, + retcol="timestamp", + allow_none=True, + desc="user_last_seen_monthly_active", + ) + + +class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): + def __init__(self, database: DatabasePool, db_conn, hs): + super(MonthlyActiveUsersStore, self).__init__(database, db_conn, hs) + + self._limit_usage_by_mau = hs.config.limit_usage_by_mau + self._mau_stats_only = hs.config.mau_stats_only + self._max_mau_value = hs.config.max_mau_value + + # Do not add more reserved users than the total allowable number + # cur = LoggingTransaction( + self.db_pool.new_transaction( + db_conn, + "initialise_mau_threepids", + [], + [], + self._initialise_reserved_users, + hs.config.mau_limits_reserved_threepids[: self._max_mau_value], + ) + + def _initialise_reserved_users(self, txn, threepids): + """Ensures that reserved threepids are accounted for in the MAU table, should + be called on start up. + + Args: + txn (cursor): + threepids (list[dict]): List of threepid dicts to reserve + """ + + # XXX what is this function trying to achieve? It upserts into + # monthly_active_users for each *registered* reserved mau user, but why? + # + # - shouldn't there already be an entry for each reserved user (at least + # if they have been active recently)? + # + # - if it's important that the timestamp is kept up to date, why do we only + # run this at startup? + + for tp in threepids: + user_id = self.get_user_id_by_threepid_txn(txn, tp["medium"], tp["address"]) + + if user_id: + is_support = self.is_support_user_txn(txn, user_id) + if not is_support: + # We do this manually here to avoid hitting #6791 + self.db_pool.simple_upsert_txn( + txn, + table="monthly_active_users", + keyvalues={"user_id": user_id}, + values={"timestamp": int(self._clock.time_msec())}, + ) + else: + logger.warning("mau limit reserved threepid %s not found in db" % tp) + + async def reap_monthly_active_users(self): + """Cleans out monthly active user table to ensure that no stale + entries exist. + """ + + def _reap_users(txn, reserved_users): + """ + Args: + reserved_users (tuple): reserved users to preserve + """ + + thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) + + in_clause, in_clause_args = make_in_list_sql_clause( + self.database_engine, "user_id", reserved_users + ) + + txn.execute( + "DELETE FROM monthly_active_users WHERE timestamp < ? AND NOT %s" + % (in_clause,), + [thirty_days_ago] + in_clause_args, + ) + + if self._limit_usage_by_mau: + # If MAU user count still exceeds the MAU threshold, then delete on + # a least recently active basis. + # Note it is not possible to write this query using OFFSET due to + # incompatibilities in how sqlite and postgres support the feature. + # Sqlite requires 'LIMIT -1 OFFSET ?', the LIMIT must be present, + # while Postgres does not require 'LIMIT', but also does not support + # negative LIMIT values. So there is no way to write it that both can + # support + + # Limit must be >= 0 for postgres + num_of_non_reserved_users_to_remove = max( + self._max_mau_value - len(reserved_users), 0 + ) + + # It is important to filter reserved users twice to guard + # against the case where the reserved user is present in the + # SELECT, meaning that a legitimate mau is deleted. + sql = """ + DELETE FROM monthly_active_users + WHERE user_id NOT IN ( + SELECT user_id FROM monthly_active_users + WHERE NOT %s + ORDER BY timestamp DESC + LIMIT ? + ) + AND NOT %s + """ % ( + in_clause, + in_clause, + ) + + query_args = ( + in_clause_args + + [num_of_non_reserved_users_to_remove] + + in_clause_args + ) + txn.execute(sql, query_args) + + # It seems poor to invalidate the whole cache. Postgres supports + # 'Returning' which would allow me to invalidate only the + # specific users, but sqlite has no way to do this and instead + # I would need to SELECT and the DELETE which without locking + # is racy. + # Have resolved to invalidate the whole cache for now and do + # something about it if and when the perf becomes significant + self._invalidate_all_cache_and_stream( + txn, self.user_last_seen_monthly_active + ) + self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ()) + + reserved_users = await self.get_registered_reserved_users() + await self.db_pool.runInteraction( + "reap_monthly_active_users", _reap_users, reserved_users + ) + + @defer.inlineCallbacks + def upsert_monthly_active_user(self, user_id): + """Updates or inserts the user into the monthly active user table, which + is used to track the current MAU usage of the server + + Args: + user_id (str): user to add/update + + Returns: + Deferred + """ + # Support user never to be included in MAU stats. Note I can't easily call this + # from upsert_monthly_active_user_txn because then I need a _txn form of + # is_support_user which is complicated because I want to cache the result. + # Therefore I call it here and ignore the case where + # upsert_monthly_active_user_txn is called directly from + # _initialise_reserved_users reasoning that it would be very strange to + # include a support user in this context. + + is_support = yield self.is_support_user(user_id) + if is_support: + return + + yield self.db_pool.runInteraction( + "upsert_monthly_active_user", self.upsert_monthly_active_user_txn, user_id + ) + + def upsert_monthly_active_user_txn(self, txn, user_id): + """Updates or inserts monthly active user member + + We consciously do not call is_support_txn from this method because it + is not possible to cache the response. is_support_txn will be false in + almost all cases, so it seems reasonable to call it only for + upsert_monthly_active_user and to call is_support_txn manually + for cases where upsert_monthly_active_user_txn is called directly, + like _initialise_reserved_users + + In short, don't call this method with support users. (Support users + should not appear in the MAU stats). + + Args: + txn (cursor): + user_id (str): user to add/update + + Returns: + bool: True if a new entry was created, False if an + existing one was updated. + """ + + # Am consciously deciding to lock the table on the basis that is ought + # never be a big table and alternative approaches (batching multiple + # upserts into a single txn) introduced a lot of extra complexity. + # See https://github.com/matrix-org/synapse/issues/3854 for more + is_insert = self.db_pool.simple_upsert_txn( + txn, + table="monthly_active_users", + keyvalues={"user_id": user_id}, + values={"timestamp": int(self._clock.time_msec())}, + ) + + self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ()) + self._invalidate_cache_and_stream( + txn, self.get_monthly_active_count_by_service, () + ) + self._invalidate_cache_and_stream( + txn, self.user_last_seen_monthly_active, (user_id,) + ) + + return is_insert + + @defer.inlineCallbacks + def populate_monthly_active_users(self, user_id): + """Checks on the state of monthly active user limits and optionally + add the user to the monthly active tables + + Args: + user_id(str): the user_id to query + """ + if self._limit_usage_by_mau or self._mau_stats_only: + # Trial users and guests should not be included as part of MAU group + is_guest = yield self.is_guest(user_id) + if is_guest: + return + is_trial = yield self.is_trial_user(user_id) + if is_trial: + return + + last_seen_timestamp = yield self.user_last_seen_monthly_active(user_id) + now = self.hs.get_clock().time_msec() + + # We want to reduce to the total number of db writes, and are happy + # to trade accuracy of timestamp in order to lighten load. This means + # We always insert new users (where MAU threshold has not been reached), + # but only update if we have not previously seen the user for + # LAST_SEEN_GRANULARITY ms + if last_seen_timestamp is None: + # In the case where mau_stats_only is True and limit_usage_by_mau is + # False, there is no point in checking get_monthly_active_count - it + # adds no value and will break the logic if max_mau_value is exceeded. + if not self._limit_usage_by_mau: + yield self.upsert_monthly_active_user(user_id) + else: + count = yield self.get_monthly_active_count() + if count < self._max_mau_value: + yield self.upsert_monthly_active_user(user_id) + elif now - last_seen_timestamp > LAST_SEEN_GRANULARITY: + yield self.upsert_monthly_active_user(user_id) -- cgit 1.5.1 From 7f837959ea25ef50b3675c9c2596ef42592dc127 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 7 Aug 2020 13:36:29 -0400 Subject: Convert directory, e2e_room_keys, end_to_end_keys, monthly_active_users database to async (#8042) --- changelog.d/8042.misc | 1 + synapse/storage/databases/main/devices.py | 12 ++-- synapse/storage/databases/main/directory.py | 51 ++++++++------- synapse/storage/databases/main/e2e_room_keys.py | 30 +++++---- synapse/storage/databases/main/end_to_end_keys.py | 73 +++++++++++----------- .../storage/databases/main/monthly_active_users.py | 31 ++++----- tests/handlers/test_appservice.py | 2 +- tests/storage/test_directory.py | 32 +++++++--- tests/storage/test_end_to_end_keys.py | 12 ++-- tests/storage/test_monthly_active_users.py | 17 +++-- 10 files changed, 141 insertions(+), 120 deletions(-) create mode 100644 changelog.d/8042.misc (limited to 'synapse/storage/databases/main/monthly_active_users.py') diff --git a/changelog.d/8042.misc b/changelog.d/8042.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8042.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 81e64de126..7a5f0bab05 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -136,7 +136,9 @@ class DeviceWorkerStore(SQLBaseStore): master_key_by_user = {} self_signing_key_by_user = {} for user in users: - cross_signing_key = yield self.get_e2e_cross_signing_key(user, "master") + cross_signing_key = yield defer.ensureDeferred( + self.get_e2e_cross_signing_key(user, "master") + ) if cross_signing_key: key_id, verify_key = get_verify_key_from_cross_signing_key( cross_signing_key @@ -149,8 +151,8 @@ class DeviceWorkerStore(SQLBaseStore): "device_id": verify_key.version, } - cross_signing_key = yield self.get_e2e_cross_signing_key( - user, "self_signing" + cross_signing_key = yield defer.ensureDeferred( + self.get_e2e_cross_signing_key(user, "self_signing") ) if cross_signing_key: key_id, verify_key = get_verify_key_from_cross_signing_key( @@ -246,7 +248,7 @@ class DeviceWorkerStore(SQLBaseStore): destination (str): The host the device updates are intended for from_stream_id (int): The minimum stream_id to filter updates by, exclusive query_map (Dict[(str, str): (int, str|None)]): Dictionary mapping - user_id/device_id to update stream_id and the relevent json-encoded + user_id/device_id to update stream_id and the relevant json-encoded opentracing context Returns: @@ -599,7 +601,7 @@ class DeviceWorkerStore(SQLBaseStore): between the requested tokens due to the limit. The token returned can be used in a subsequent call to this - function to get further updatees. + function to get further updates. The updates are a list of 2-tuples of stream ID and the row data """ diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py index 7819bfcbb3..037e02603c 100644 --- a/synapse/storage/databases/main/directory.py +++ b/synapse/storage/databases/main/directory.py @@ -14,30 +14,29 @@ # limitations under the License. from collections import namedtuple -from typing import Optional - -from twisted.internet import defer +from typing import Iterable, Optional from synapse.api.errors import SynapseError from synapse.storage._base import SQLBaseStore +from synapse.types import RoomAlias from synapse.util.caches.descriptors import cached RoomAliasMapping = namedtuple("RoomAliasMapping", ("room_id", "room_alias", "servers")) class DirectoryWorkerStore(SQLBaseStore): - @defer.inlineCallbacks - def get_association_from_room_alias(self, room_alias): - """ Get's the room_id and server list for a given room_alias + async def get_association_from_room_alias( + self, room_alias: RoomAlias + ) -> Optional[RoomAliasMapping]: + """Gets the room_id and server list for a given room_alias Args: - room_alias (RoomAlias) + room_alias: The alias to translate to an ID. Returns: - Deferred: results in namedtuple with keys "room_id" and - "servers" or None if no association can be found + The room alias mapping or None if no association can be found. """ - room_id = yield self.db_pool.simple_select_one_onecol( + room_id = await self.db_pool.simple_select_one_onecol( "room_aliases", {"room_alias": room_alias.to_string()}, "room_id", @@ -48,7 +47,7 @@ class DirectoryWorkerStore(SQLBaseStore): if not room_id: return None - servers = yield self.db_pool.simple_select_onecol( + servers = await self.db_pool.simple_select_onecol( "room_alias_servers", {"room_alias": room_alias.to_string()}, "server", @@ -79,18 +78,20 @@ class DirectoryWorkerStore(SQLBaseStore): class DirectoryStore(DirectoryWorkerStore): - @defer.inlineCallbacks - def create_room_alias_association(self, room_alias, room_id, servers, creator=None): + async def create_room_alias_association( + self, + room_alias: RoomAlias, + room_id: str, + servers: Iterable[str], + creator: Optional[str] = None, + ) -> None: """ Creates an association between a room alias and room_id/servers Args: - room_alias (RoomAlias) - room_id (str) - servers (list) - creator (str): Optional user_id of creator. - - Returns: - Deferred + room_alias: The alias to create. + room_id: The target of the alias. + servers: A list of servers through which it may be possible to join the room + creator: Optional user_id of creator. """ def alias_txn(txn): @@ -118,24 +119,22 @@ class DirectoryStore(DirectoryWorkerStore): ) try: - ret = yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "create_room_alias_association", alias_txn ) except self.database_engine.module.IntegrityError: raise SynapseError( 409, "Room alias %s already exists" % room_alias.to_string() ) - return ret - @defer.inlineCallbacks - def delete_room_alias(self, room_alias): - room_id = yield self.db_pool.runInteraction( + async def delete_room_alias(self, room_alias: RoomAlias) -> str: + room_id = await self.db_pool.runInteraction( "delete_room_alias", self._delete_room_alias_txn, room_alias ) return room_id - def _delete_room_alias_txn(self, txn, room_alias): + def _delete_room_alias_txn(self, txn, room_alias: RoomAlias) -> str: txn.execute( "SELECT room_id FROM room_aliases WHERE room_alias = ?", (room_alias.to_string(),), diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py index c4aaec3993..2eeb9f97dc 100644 --- a/synapse/storage/databases/main/e2e_room_keys.py +++ b/synapse/storage/databases/main/e2e_room_keys.py @@ -14,8 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - from synapse.api.errors import StoreError from synapse.logging.opentracing import log_kv, trace from synapse.storage._base import SQLBaseStore, db_to_json @@ -23,8 +21,9 @@ from synapse.util import json_encoder class EndToEndRoomKeyStore(SQLBaseStore): - @defer.inlineCallbacks - def update_e2e_room_key(self, user_id, version, room_id, session_id, room_key): + async def update_e2e_room_key( + self, user_id, version, room_id, session_id, room_key + ): """Replaces the encrypted E2E room key for a given session in a given backup Args: @@ -37,7 +36,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): StoreError """ - yield self.db_pool.simple_update_one( + await self.db_pool.simple_update_one( table="e2e_room_keys", keyvalues={ "user_id": user_id, @@ -54,8 +53,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): desc="update_e2e_room_key", ) - @defer.inlineCallbacks - def add_e2e_room_keys(self, user_id, version, room_keys): + async def add_e2e_room_keys(self, user_id, version, room_keys): """Bulk add room keys to a given backup. Args: @@ -88,13 +86,12 @@ class EndToEndRoomKeyStore(SQLBaseStore): } ) - yield self.db_pool.simple_insert_many( + await self.db_pool.simple_insert_many( table="e2e_room_keys", values=values, desc="add_e2e_room_keys" ) @trace - @defer.inlineCallbacks - def get_e2e_room_keys(self, user_id, version, room_id=None, session_id=None): + async def get_e2e_room_keys(self, user_id, version, room_id=None, session_id=None): """Bulk get the E2E room keys for a given backup, optionally filtered to a given room, or a given session. @@ -109,7 +106,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): the backup (or for the specified room) Returns: - A deferred list of dicts giving the session_data and message metadata for + A list of dicts giving the session_data and message metadata for these room keys. """ @@ -124,7 +121,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): if session_id: keyvalues["session_id"] = session_id - rows = yield self.db_pool.simple_select_list( + rows = await self.db_pool.simple_select_list( table="e2e_room_keys", keyvalues=keyvalues, retcols=( @@ -242,8 +239,9 @@ class EndToEndRoomKeyStore(SQLBaseStore): ) @trace - @defer.inlineCallbacks - def delete_e2e_room_keys(self, user_id, version, room_id=None, session_id=None): + async def delete_e2e_room_keys( + self, user_id, version, room_id=None, session_id=None + ): """Bulk delete the E2E room keys for a given backup, optionally filtered to a given room or a given session. @@ -258,7 +256,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): the backup (or for the specified room) Returns: - A deferred of the deletion transaction + The deletion transaction """ keyvalues = {"user_id": user_id, "version": int(version)} @@ -267,7 +265,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): if session_id: keyvalues["session_id"] = session_id - yield self.db_pool.simple_delete( + await self.db_pool.simple_delete( table="e2e_room_keys", keyvalues=keyvalues, desc="delete_e2e_room_keys" ) diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 6126376a6f..f93e0d320d 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -14,12 +14,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List, Tuple +from typing import Dict, Iterable, List, Optional, Tuple from canonicaljson import encode_canonical_json from twisted.enterprise.adbapi import Connection -from twisted.internet import defer from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.storage._base import SQLBaseStore, db_to_json @@ -31,8 +30,7 @@ from synapse.util.iterutils import batch_iter class EndToEndKeyWorkerStore(SQLBaseStore): @trace - @defer.inlineCallbacks - def get_e2e_device_keys( + async def get_e2e_device_keys( self, query_list, include_all_devices=False, include_deleted_devices=False ): """Fetch a list of device keys. @@ -52,7 +50,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): if not query_list: return {} - results = yield self.db_pool.runInteraction( + results = await self.db_pool.runInteraction( "get_e2e_device_keys", self._get_e2e_device_keys_txn, query_list, @@ -175,8 +173,9 @@ class EndToEndKeyWorkerStore(SQLBaseStore): log_kv(result) return result - @defer.inlineCallbacks - def get_e2e_one_time_keys(self, user_id, device_id, key_ids): + async def get_e2e_one_time_keys( + self, user_id: str, device_id: str, key_ids: List[str] + ) -> Dict[Tuple[str, str], str]: """Retrieve a number of one-time keys for a user Args: @@ -186,11 +185,10 @@ class EndToEndKeyWorkerStore(SQLBaseStore): retrieve Returns: - deferred resolving to Dict[(str, str), str]: map from (algorithm, - key_id) to json string for key + A map from (algorithm, key_id) to json string for key """ - rows = yield self.db_pool.simple_select_many_batch( + rows = await self.db_pool.simple_select_many_batch( table="e2e_one_time_keys_json", column="key_id", iterable=key_ids, @@ -202,17 +200,21 @@ class EndToEndKeyWorkerStore(SQLBaseStore): log_kv({"message": "Fetched one time keys for user", "one_time_keys": result}) return result - @defer.inlineCallbacks - def add_e2e_one_time_keys(self, user_id, device_id, time_now, new_keys): + async def add_e2e_one_time_keys( + self, + user_id: str, + device_id: str, + time_now: int, + new_keys: Iterable[Tuple[str, str, str]], + ) -> None: """Insert some new one time keys for a device. Errors if any of the keys already exist. Args: - user_id(str): id of user to get keys for - device_id(str): id of device to get keys for - time_now(long): insertion time to record (ms since epoch) - new_keys(iterable[(str, str, str)]: keys to add - each a tuple of - (algorithm, key_id, key json) + user_id: id of user to get keys for + device_id: id of device to get keys for + time_now: insertion time to record (ms since epoch) + new_keys: keys to add - each a tuple of (algorithm, key_id, key json) """ def _add_e2e_one_time_keys(txn): @@ -242,7 +244,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): txn, self.count_e2e_one_time_keys, (user_id, device_id) ) - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "add_e2e_one_time_keys_insert", _add_e2e_one_time_keys ) @@ -269,22 +271,23 @@ class EndToEndKeyWorkerStore(SQLBaseStore): "count_e2e_one_time_keys", _count_e2e_one_time_keys ) - @defer.inlineCallbacks - def get_e2e_cross_signing_key(self, user_id, key_type, from_user_id=None): + async def get_e2e_cross_signing_key( + self, user_id: str, key_type: str, from_user_id: Optional[str] = None + ) -> Optional[dict]: """Returns a user's cross-signing key. Args: - user_id (str): the user whose key is being requested - key_type (str): the type of key that is being requested: either 'master' + user_id: the user whose key is being requested + key_type: the type of key that is being requested: either 'master' for a master key, 'self_signing' for a self-signing key, or 'user_signing' for a user-signing key - from_user_id (str): if specified, signatures made by this user on + from_user_id: if specified, signatures made by this user on the self-signing key will be included in the result Returns: dict of the key data or None if not found """ - res = yield self.get_e2e_cross_signing_keys_bulk([user_id], from_user_id) + res = await self.get_e2e_cross_signing_keys_bulk([user_id], from_user_id) user_keys = res.get(user_id) if not user_keys: return None @@ -450,28 +453,26 @@ class EndToEndKeyWorkerStore(SQLBaseStore): return keys - @defer.inlineCallbacks - def get_e2e_cross_signing_keys_bulk( - self, user_ids: List[str], from_user_id: str = None - ) -> defer.Deferred: + async def get_e2e_cross_signing_keys_bulk( + self, user_ids: List[str], from_user_id: Optional[str] = None + ) -> Dict[str, Dict[str, dict]]: """Returns the cross-signing keys for a set of users. Args: - user_ids (list[str]): the users whose keys are being requested - from_user_id (str): if specified, signatures made by this user on + user_ids: the users whose keys are being requested + from_user_id: if specified, signatures made by this user on the self-signing keys will be included in the result Returns: - Deferred[dict[str, dict[str, dict]]]: map of user ID to key type to - key data. If a user's cross-signing keys were not found, either - their user ID will not be in the dict, or their user ID will map - to None. + A map of user ID to key type to key data. If a user's cross-signing + keys were not found, either their user ID will not be in the dict, + or their user ID will map to None. """ - result = yield self._get_bare_e2e_cross_signing_keys_bulk(user_ids) + result = await self._get_bare_e2e_cross_signing_keys_bulk(user_ids) if from_user_id: - result = yield self.db_pool.runInteraction( + result = await self.db_pool.runInteraction( "get_e2e_cross_signing_signatures", self._get_e2e_cross_signing_signatures_txn, result, diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py index 02b01d9619..e71cdd2cb4 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -15,8 +15,6 @@ import logging from typing import List -from twisted.internet import defer - from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool, make_in_list_sql_clause from synapse.util.caches.descriptors import cached @@ -252,16 +250,12 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): "reap_monthly_active_users", _reap_users, reserved_users ) - @defer.inlineCallbacks - def upsert_monthly_active_user(self, user_id): + async def upsert_monthly_active_user(self, user_id: str) -> None: """Updates or inserts the user into the monthly active user table, which is used to track the current MAU usage of the server Args: - user_id (str): user to add/update - - Returns: - Deferred + user_id: user to add/update """ # Support user never to be included in MAU stats. Note I can't easily call this # from upsert_monthly_active_user_txn because then I need a _txn form of @@ -271,11 +265,11 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): # _initialise_reserved_users reasoning that it would be very strange to # include a support user in this context. - is_support = yield self.is_support_user(user_id) + is_support = await self.is_support_user(user_id) if is_support: return - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "upsert_monthly_active_user", self.upsert_monthly_active_user_txn, user_id ) @@ -322,8 +316,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): return is_insert - @defer.inlineCallbacks - def populate_monthly_active_users(self, user_id): + async def populate_monthly_active_users(self, user_id): """Checks on the state of monthly active user limits and optionally add the user to the monthly active tables @@ -332,14 +325,14 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): """ if self._limit_usage_by_mau or self._mau_stats_only: # Trial users and guests should not be included as part of MAU group - is_guest = yield self.is_guest(user_id) + is_guest = await self.is_guest(user_id) if is_guest: return - is_trial = yield self.is_trial_user(user_id) + is_trial = await self.is_trial_user(user_id) if is_trial: return - last_seen_timestamp = yield self.user_last_seen_monthly_active(user_id) + last_seen_timestamp = await self.user_last_seen_monthly_active(user_id) now = self.hs.get_clock().time_msec() # We want to reduce to the total number of db writes, and are happy @@ -352,10 +345,10 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): # False, there is no point in checking get_monthly_active_count - it # adds no value and will break the logic if max_mau_value is exceeded. if not self._limit_usage_by_mau: - yield self.upsert_monthly_active_user(user_id) + await self.upsert_monthly_active_user(user_id) else: - count = yield self.get_monthly_active_count() + count = await self.get_monthly_active_count() if count < self._max_mau_value: - yield self.upsert_monthly_active_user(user_id) + await self.upsert_monthly_active_user(user_id) elif now - last_seen_timestamp > LAST_SEEN_GRANULARITY: - yield self.upsert_monthly_active_user(user_id) + await self.upsert_monthly_active_user(user_id) diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index 628f7d8db0..2a0b7c1b56 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -120,7 +120,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): self.mock_as_api.query_alias.return_value = make_awaitable(True) self.mock_store.get_app_services.return_value = services - self.mock_store.get_association_from_room_alias.return_value = defer.succeed( + self.mock_store.get_association_from_room_alias.return_value = make_awaitable( Mock(room_id=room_id, servers=servers) ) diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py index 4e128e1047..daac947cb2 100644 --- a/tests/storage/test_directory.py +++ b/tests/storage/test_directory.py @@ -34,8 +34,10 @@ class DirectoryStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_room_to_alias(self): - yield self.store.create_room_alias_association( - room_alias=self.alias, room_id=self.room.to_string(), servers=["test"] + yield defer.ensureDeferred( + self.store.create_room_alias_association( + room_alias=self.alias, room_id=self.room.to_string(), servers=["test"] + ) ) self.assertEquals( @@ -45,24 +47,36 @@ class DirectoryStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_alias_to_room(self): - yield self.store.create_room_alias_association( - room_alias=self.alias, room_id=self.room.to_string(), servers=["test"] + yield defer.ensureDeferred( + self.store.create_room_alias_association( + room_alias=self.alias, room_id=self.room.to_string(), servers=["test"] + ) ) self.assertObjectHasAttributes( {"room_id": self.room.to_string(), "servers": ["test"]}, - (yield self.store.get_association_from_room_alias(self.alias)), + ( + yield defer.ensureDeferred( + self.store.get_association_from_room_alias(self.alias) + ) + ), ) @defer.inlineCallbacks def test_delete_alias(self): - yield self.store.create_room_alias_association( - room_alias=self.alias, room_id=self.room.to_string(), servers=["test"] + yield defer.ensureDeferred( + self.store.create_room_alias_association( + room_alias=self.alias, room_id=self.room.to_string(), servers=["test"] + ) ) - room_id = yield self.store.delete_room_alias(self.alias) + room_id = yield defer.ensureDeferred(self.store.delete_room_alias(self.alias)) self.assertEqual(self.room.to_string(), room_id) self.assertIsNone( - (yield self.store.get_association_from_room_alias(self.alias)) + ( + yield defer.ensureDeferred( + self.store.get_association_from_room_alias(self.alias) + ) + ) ) diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py index 398d546280..9f8d30373b 100644 --- a/tests/storage/test_end_to_end_keys.py +++ b/tests/storage/test_end_to_end_keys.py @@ -34,7 +34,9 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): yield self.store.set_e2e_device_keys("user", "device", now, json) - res = yield self.store.get_e2e_device_keys((("user", "device"),)) + res = yield defer.ensureDeferred( + self.store.get_e2e_device_keys((("user", "device"),)) + ) self.assertIn("user", res) self.assertIn("device", res["user"]) dev = res["user"]["device"] @@ -63,7 +65,9 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): yield self.store.set_e2e_device_keys("user", "device", now, json) yield self.store.store_device("user", "device", "display_name") - res = yield self.store.get_e2e_device_keys((("user", "device"),)) + res = yield defer.ensureDeferred( + self.store.get_e2e_device_keys((("user", "device"),)) + ) self.assertIn("user", res) self.assertIn("device", res["user"]) dev = res["user"]["device"] @@ -85,8 +89,8 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): yield self.store.set_e2e_device_keys("user2", "device1", now, {"key": "json21"}) yield self.store.set_e2e_device_keys("user2", "device2", now, {"key": "json22"}) - res = yield self.store.get_e2e_device_keys( - (("user1", "device1"), ("user2", "device2")) + res = yield defer.ensureDeferred( + self.store.get_e2e_device_keys((("user1", "device1"), ("user2", "device2"))) ) self.assertIn("user1", res) self.assertIn("device1", res["user1"]) diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index 259f2215f1..e793781a26 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -19,6 +19,7 @@ from twisted.internet import defer from synapse.api.constants import UserTypes from tests import unittest +from tests.test_utils import make_awaitable from tests.unittest import default_config, override_config FORTY_DAYS = 40 * 24 * 60 * 60 @@ -230,7 +231,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): ) self.get_success(d) - self.store.upsert_monthly_active_user = Mock() + self.store.upsert_monthly_active_user = Mock( + side_effect=lambda user_id: make_awaitable(None) + ) d = self.store.populate_monthly_active_users(user_id) self.get_success(d) @@ -238,7 +241,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.store.upsert_monthly_active_user.assert_not_called() def test_populate_monthly_users_should_update(self): - self.store.upsert_monthly_active_user = Mock() + self.store.upsert_monthly_active_user = Mock( + side_effect=lambda user_id: make_awaitable(None) + ) self.store.is_trial_user = Mock(return_value=defer.succeed(False)) @@ -251,7 +256,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.store.upsert_monthly_active_user.assert_called_once() def test_populate_monthly_users_should_not_update(self): - self.store.upsert_monthly_active_user = Mock() + self.store.upsert_monthly_active_user = Mock( + side_effect=lambda user_id: make_awaitable(None) + ) self.store.is_trial_user = Mock(return_value=defer.succeed(False)) self.store.user_last_seen_monthly_active = Mock( @@ -333,7 +340,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): @override_config({"limit_usage_by_mau": False, "mau_stats_only": False}) def test_no_users_when_not_tracking(self): - self.store.upsert_monthly_active_user = Mock() + self.store.upsert_monthly_active_user = Mock( + side_effect=lambda user_id: make_awaitable(None) + ) self.get_success(self.store.populate_monthly_active_users("@user:sever")) -- cgit 1.5.1