From a7bdf98d01d2225a479753a85ba81adf02b16a32 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 5 Aug 2020 21:38:57 +0100 Subject: Rename database classes to make some sense (#8033) --- synapse/storage/databases/main/appservice.py | 374 +++++++++++++++++++++++++++ 1 file changed, 374 insertions(+) create mode 100644 synapse/storage/databases/main/appservice.py (limited to 'synapse/storage/databases/main/appservice.py') diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py new file mode 100644 index 0000000000..055a3962dc --- /dev/null +++ b/synapse/storage/databases/main/appservice.py @@ -0,0 +1,374 @@ +# -*- coding: utf-8 -*- +# Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import re + +from canonicaljson import json + +from twisted.internet import defer + +from synapse.appservice import AppServiceTransaction +from synapse.config.appservice import load_appservices +from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.events_worker import EventsWorkerStore + +logger = logging.getLogger(__name__) + + +def _make_exclusive_regex(services_cache): + # We precompile a regex constructed from all the regexes that the AS's + # have registered for exclusive users. + exclusive_user_regexes = [ + regex.pattern + for service in services_cache + for regex in service.get_exclusive_user_regexes() + ] + if exclusive_user_regexes: + exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes) + exclusive_user_regex = re.compile(exclusive_user_regex) + else: + # We handle this case specially otherwise the constructed regex + # will always match + exclusive_user_regex = None + + return exclusive_user_regex + + +class ApplicationServiceWorkerStore(SQLBaseStore): + def __init__(self, database: DatabasePool, db_conn, hs): + self.services_cache = load_appservices( + hs.hostname, hs.config.app_service_config_files + ) + self.exclusive_user_regex = _make_exclusive_regex(self.services_cache) + + super(ApplicationServiceWorkerStore, self).__init__(database, db_conn, hs) + + def get_app_services(self): + return self.services_cache + + def get_if_app_services_interested_in_user(self, user_id): + """Check if the user is one associated with an app service (exclusively) + """ + if self.exclusive_user_regex: + return bool(self.exclusive_user_regex.match(user_id)) + else: + return False + + def get_app_service_by_user_id(self, user_id): + """Retrieve an application service from their user ID. + + All application services have associated with them a particular user ID. + There is no distinguishing feature on the user ID which indicates it + represents an application service. This function allows you to map from + a user ID to an application service. + + Args: + user_id(str): The user ID to see if it is an application service. + Returns: + synapse.appservice.ApplicationService or None. + """ + for service in self.services_cache: + if service.sender == user_id: + return service + return None + + def get_app_service_by_token(self, token): + """Get the application service with the given appservice token. + + Args: + token (str): The application service token. + Returns: + synapse.appservice.ApplicationService or None. + """ + for service in self.services_cache: + if service.token == token: + return service + return None + + def get_app_service_by_id(self, as_id): + """Get the application service with the given appservice ID. + + Args: + as_id (str): The application service ID. + Returns: + synapse.appservice.ApplicationService or None. + """ + for service in self.services_cache: + if service.id == as_id: + return service + return None + + +class ApplicationServiceStore(ApplicationServiceWorkerStore): + # This is currently empty due to there not being any AS storage functions + # that can't be run on the workers. Since this may change in future, and + # to keep consistency with the other stores, we keep this empty class for + # now. + pass + + +class ApplicationServiceTransactionWorkerStore( + ApplicationServiceWorkerStore, EventsWorkerStore +): + @defer.inlineCallbacks + def get_appservices_by_state(self, state): + """Get a list of application services based on their state. + + Args: + state(ApplicationServiceState): The state to filter on. + Returns: + A Deferred which resolves to a list of ApplicationServices, which + may be empty. + """ + results = yield self.db_pool.simple_select_list( + "application_services_state", {"state": state}, ["as_id"] + ) + # NB: This assumes this class is linked with ApplicationServiceStore + as_list = self.get_app_services() + services = [] + + for res in results: + for service in as_list: + if service.id == res["as_id"]: + services.append(service) + return services + + @defer.inlineCallbacks + def get_appservice_state(self, service): + """Get the application service state. + + Args: + service(ApplicationService): The service whose state to set. + Returns: + A Deferred which resolves to ApplicationServiceState. + """ + result = yield self.db_pool.simple_select_one( + "application_services_state", + {"as_id": service.id}, + ["state"], + allow_none=True, + desc="get_appservice_state", + ) + if result: + return result.get("state") + return None + + def set_appservice_state(self, service, state): + """Set the application service state. + + Args: + service(ApplicationService): The service whose state to set. + state(ApplicationServiceState): The connectivity state to apply. + Returns: + A Deferred which resolves when the state was set successfully. + """ + return self.db_pool.simple_upsert( + "application_services_state", {"as_id": service.id}, {"state": state} + ) + + def create_appservice_txn(self, service, events): + """Atomically creates a new transaction for this application service + with the given list of events. + + Args: + service(ApplicationService): The service who the transaction is for. + events(list): A list of events to put in the transaction. + Returns: + AppServiceTransaction: A new transaction. + """ + + def _create_appservice_txn(txn): + # work out new txn id (highest txn id for this service += 1) + # The highest id may be the last one sent (in which case it is last_txn) + # or it may be the highest in the txns list (which are waiting to be/are + # being sent) + last_txn_id = self._get_last_txn(txn, service.id) + + txn.execute( + "SELECT MAX(txn_id) FROM application_services_txns WHERE as_id=?", + (service.id,), + ) + highest_txn_id = txn.fetchone()[0] + if highest_txn_id is None: + highest_txn_id = 0 + + new_txn_id = max(highest_txn_id, last_txn_id) + 1 + + # Insert new txn into txn table + event_ids = json.dumps([e.event_id for e in events]) + txn.execute( + "INSERT INTO application_services_txns(as_id, txn_id, event_ids) " + "VALUES(?,?,?)", + (service.id, new_txn_id, event_ids), + ) + return AppServiceTransaction(service=service, id=new_txn_id, events=events) + + return self.db_pool.runInteraction( + "create_appservice_txn", _create_appservice_txn + ) + + def complete_appservice_txn(self, txn_id, service): + """Completes an application service transaction. + + Args: + txn_id(str): The transaction ID being completed. + service(ApplicationService): The application service which was sent + this transaction. + Returns: + A Deferred which resolves if this transaction was stored + successfully. + """ + txn_id = int(txn_id) + + def _complete_appservice_txn(txn): + # Debugging query: Make sure the txn being completed is EXACTLY +1 from + # what was there before. If it isn't, we've got problems (e.g. the AS + # has probably missed some events), so whine loudly but still continue, + # since it shouldn't fail completion of the transaction. + last_txn_id = self._get_last_txn(txn, service.id) + if (last_txn_id + 1) != txn_id: + logger.error( + "appservice: Completing a transaction which has an ID > 1 from " + "the last ID sent to this AS. We've either dropped events or " + "sent it to the AS out of order. FIX ME. last_txn=%s " + "completing_txn=%s service_id=%s", + last_txn_id, + txn_id, + service.id, + ) + + # Set current txn_id for AS to 'txn_id' + self.db_pool.simple_upsert_txn( + txn, + "application_services_state", + {"as_id": service.id}, + {"last_txn": txn_id}, + ) + + # Delete txn + self.db_pool.simple_delete_txn( + txn, + "application_services_txns", + {"txn_id": txn_id, "as_id": service.id}, + ) + + return self.db_pool.runInteraction( + "complete_appservice_txn", _complete_appservice_txn + ) + + @defer.inlineCallbacks + def get_oldest_unsent_txn(self, service): + """Get the oldest transaction which has not been sent for this + service. + + Args: + service(ApplicationService): The app service to get the oldest txn. + Returns: + A Deferred which resolves to an AppServiceTransaction or + None. + """ + + def _get_oldest_unsent_txn(txn): + # Monotonically increasing txn ids, so just select the smallest + # one in the txns table (we delete them when they are sent) + txn.execute( + "SELECT * FROM application_services_txns WHERE as_id=?" + " ORDER BY txn_id ASC LIMIT 1", + (service.id,), + ) + rows = self.db_pool.cursor_to_dict(txn) + if not rows: + return None + + entry = rows[0] + + return entry + + entry = yield self.db_pool.runInteraction( + "get_oldest_unsent_appservice_txn", _get_oldest_unsent_txn + ) + + if not entry: + return None + + event_ids = db_to_json(entry["event_ids"]) + + events = yield self.get_events_as_list(event_ids) + + return AppServiceTransaction(service=service, id=entry["txn_id"], events=events) + + def _get_last_txn(self, txn, service_id): + txn.execute( + "SELECT last_txn FROM application_services_state WHERE as_id=?", + (service_id,), + ) + last_txn_id = txn.fetchone() + if last_txn_id is None or last_txn_id[0] is None: # no row exists + return 0 + else: + return int(last_txn_id[0]) # select 'last_txn' col + + def set_appservice_last_pos(self, pos): + def set_appservice_last_pos_txn(txn): + txn.execute( + "UPDATE appservice_stream_position SET stream_ordering = ?", (pos,) + ) + + return self.db_pool.runInteraction( + "set_appservice_last_pos", set_appservice_last_pos_txn + ) + + @defer.inlineCallbacks + def get_new_events_for_appservice(self, current_id, limit): + """Get all new evnets""" + + def get_new_events_for_appservice_txn(txn): + sql = ( + "SELECT e.stream_ordering, e.event_id" + " FROM events AS e" + " WHERE" + " (SELECT stream_ordering FROM appservice_stream_position)" + " < e.stream_ordering" + " AND e.stream_ordering <= ?" + " ORDER BY e.stream_ordering ASC" + " LIMIT ?" + ) + + txn.execute(sql, (current_id, limit)) + rows = txn.fetchall() + + upper_bound = current_id + if len(rows) == limit: + upper_bound = rows[-1][0] + + return upper_bound, [row[1] for row in rows] + + upper_bound, event_ids = yield self.db_pool.runInteraction( + "get_new_events_for_appservice", get_new_events_for_appservice_txn + ) + + events = yield self.get_events_as_list(event_ids) + + return upper_bound, events + + +class ApplicationServiceTransactionStore(ApplicationServiceTransactionWorkerStore): + # This is currently empty due to there not being any AS storage functions + # that can't be run on the workers. Since this may change in future, and + # to keep consistency with the other stores, we keep this empty class for + # now. + pass -- cgit 1.5.1 From a3a59bab7bb3b69dcfc5620e6f3ac51af3f0f965 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 12 Aug 2020 09:28:48 -0400 Subject: Convert appservice, group server, profile and more databases to async (#8066) --- changelog.d/8066.misc | 1 + synapse/storage/databases/main/appservice.py | 34 ++++------ synapse/storage/databases/main/filtering.py | 8 +-- synapse/storage/databases/main/group_server.py | 86 ++++++++++++-------------- synapse/storage/databases/main/presence.py | 7 +-- synapse/storage/databases/main/profile.py | 21 +++---- synapse/storage/databases/main/relations.py | 19 +++--- synapse/storage/databases/main/transactions.py | 7 +-- tests/storage/test_appservice.py | 24 +++---- 9 files changed, 91 insertions(+), 116 deletions(-) create mode 100644 changelog.d/8066.misc (limited to 'synapse/storage/databases/main/appservice.py') diff --git a/changelog.d/8066.misc b/changelog.d/8066.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8066.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 055a3962dc..5cf1a88399 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -18,8 +18,6 @@ import re from canonicaljson import json -from twisted.internet import defer - from synapse.appservice import AppServiceTransaction from synapse.config.appservice import load_appservices from synapse.storage._base import SQLBaseStore, db_to_json @@ -124,17 +122,15 @@ class ApplicationServiceStore(ApplicationServiceWorkerStore): class ApplicationServiceTransactionWorkerStore( ApplicationServiceWorkerStore, EventsWorkerStore ): - @defer.inlineCallbacks - def get_appservices_by_state(self, state): + async def get_appservices_by_state(self, state): """Get a list of application services based on their state. Args: state(ApplicationServiceState): The state to filter on. Returns: - A Deferred which resolves to a list of ApplicationServices, which - may be empty. + A list of ApplicationServices, which may be empty. """ - results = yield self.db_pool.simple_select_list( + results = await self.db_pool.simple_select_list( "application_services_state", {"state": state}, ["as_id"] ) # NB: This assumes this class is linked with ApplicationServiceStore @@ -147,16 +143,15 @@ class ApplicationServiceTransactionWorkerStore( services.append(service) return services - @defer.inlineCallbacks - def get_appservice_state(self, service): + async def get_appservice_state(self, service): """Get the application service state. Args: service(ApplicationService): The service whose state to set. Returns: - A Deferred which resolves to ApplicationServiceState. + An ApplicationServiceState. """ - result = yield self.db_pool.simple_select_one( + result = await self.db_pool.simple_select_one( "application_services_state", {"as_id": service.id}, ["state"], @@ -270,16 +265,14 @@ class ApplicationServiceTransactionWorkerStore( "complete_appservice_txn", _complete_appservice_txn ) - @defer.inlineCallbacks - def get_oldest_unsent_txn(self, service): + async def get_oldest_unsent_txn(self, service): """Get the oldest transaction which has not been sent for this service. Args: service(ApplicationService): The app service to get the oldest txn. Returns: - A Deferred which resolves to an AppServiceTransaction or - None. + An AppServiceTransaction or None. """ def _get_oldest_unsent_txn(txn): @@ -298,7 +291,7 @@ class ApplicationServiceTransactionWorkerStore( return entry - entry = yield self.db_pool.runInteraction( + entry = await self.db_pool.runInteraction( "get_oldest_unsent_appservice_txn", _get_oldest_unsent_txn ) @@ -307,7 +300,7 @@ class ApplicationServiceTransactionWorkerStore( event_ids = db_to_json(entry["event_ids"]) - events = yield self.get_events_as_list(event_ids) + events = await self.get_events_as_list(event_ids) return AppServiceTransaction(service=service, id=entry["txn_id"], events=events) @@ -332,8 +325,7 @@ class ApplicationServiceTransactionWorkerStore( "set_appservice_last_pos", set_appservice_last_pos_txn ) - @defer.inlineCallbacks - def get_new_events_for_appservice(self, current_id, limit): + async def get_new_events_for_appservice(self, current_id, limit): """Get all new evnets""" def get_new_events_for_appservice_txn(txn): @@ -357,11 +349,11 @@ class ApplicationServiceTransactionWorkerStore( return upper_bound, [row[1] for row in rows] - upper_bound, event_ids = yield self.db_pool.runInteraction( + upper_bound, event_ids = await self.db_pool.runInteraction( "get_new_events_for_appservice", get_new_events_for_appservice_txn ) - events = yield self.get_events_as_list(event_ids) + events = await self.get_events_as_list(event_ids) return upper_bound, events diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py index cae6bda80e..45a1760170 100644 --- a/synapse/storage/databases/main/filtering.py +++ b/synapse/storage/databases/main/filtering.py @@ -17,12 +17,12 @@ from canonicaljson import encode_canonical_json from synapse.api.errors import Codes, SynapseError from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.util.caches.descriptors import cachedInlineCallbacks +from synapse.util.caches.descriptors import cached class FilteringStore(SQLBaseStore): - @cachedInlineCallbacks(num_args=2) - def get_user_filter(self, user_localpart, filter_id): + @cached(num_args=2) + async def get_user_filter(self, user_localpart, filter_id): # filter_id is BIGINT UNSIGNED, so if it isn't a number, fail # with a coherent error message rather than 500 M_UNKNOWN. try: @@ -30,7 +30,7 @@ class FilteringStore(SQLBaseStore): except ValueError: raise SynapseError(400, "Invalid filter ID", Codes.INVALID_PARAM) - def_json = yield self.db_pool.simple_select_one_onecol( + def_json = await self.db_pool.simple_select_one_onecol( table="user_filters", keyvalues={"user_id": user_localpart, "filter_id": filter_id}, retcol="filter_json", diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py index 75ea6d4b2f..380db3a3f3 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py @@ -14,12 +14,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Tuple - -from twisted.internet import defer +from typing import List, Optional, Tuple from synapse.api.errors import SynapseError from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.types import JsonDict from synapse.util import json_encoder # The category ID for the "default" category. We don't store as null in the @@ -210,9 +209,8 @@ class GroupServerWorkerStore(SQLBaseStore): "get_rooms_for_summary", _get_rooms_for_summary_txn ) - @defer.inlineCallbacks - def get_group_categories(self, group_id): - rows = yield self.db_pool.simple_select_list( + async def get_group_categories(self, group_id): + rows = await self.db_pool.simple_select_list( table="group_room_categories", keyvalues={"group_id": group_id}, retcols=("category_id", "is_public", "profile"), @@ -227,9 +225,8 @@ class GroupServerWorkerStore(SQLBaseStore): for row in rows } - @defer.inlineCallbacks - def get_group_category(self, group_id, category_id): - category = yield self.db_pool.simple_select_one( + async def get_group_category(self, group_id, category_id): + category = await self.db_pool.simple_select_one( table="group_room_categories", keyvalues={"group_id": group_id, "category_id": category_id}, retcols=("is_public", "profile"), @@ -240,9 +237,8 @@ class GroupServerWorkerStore(SQLBaseStore): return category - @defer.inlineCallbacks - def get_group_roles(self, group_id): - rows = yield self.db_pool.simple_select_list( + async def get_group_roles(self, group_id): + rows = await self.db_pool.simple_select_list( table="group_roles", keyvalues={"group_id": group_id}, retcols=("role_id", "is_public", "profile"), @@ -257,9 +253,8 @@ class GroupServerWorkerStore(SQLBaseStore): for row in rows } - @defer.inlineCallbacks - def get_group_role(self, group_id, role_id): - role = yield self.db_pool.simple_select_one( + async def get_group_role(self, group_id, role_id): + role = await self.db_pool.simple_select_one( table="group_roles", keyvalues={"group_id": group_id, "role_id": role_id}, retcols=("is_public", "profile"), @@ -448,12 +443,11 @@ class GroupServerWorkerStore(SQLBaseStore): "get_attestations_need_renewals", _get_attestations_need_renewals_txn ) - @defer.inlineCallbacks - def get_remote_attestation(self, group_id, user_id): + async def get_remote_attestation(self, group_id, user_id): """Get the attestation that proves the remote agrees that the user is in the group. """ - row = yield self.db_pool.simple_select_one( + row = await self.db_pool.simple_select_one( table="group_attestations_remote", keyvalues={"group_id": group_id, "user_id": user_id}, retcols=("valid_until_ms", "attestation_json"), @@ -499,13 +493,13 @@ class GroupServerWorkerStore(SQLBaseStore): "get_all_groups_for_user", _get_all_groups_for_user_txn ) - def get_groups_changes_for_user(self, user_id, from_token, to_token): + async def get_groups_changes_for_user(self, user_id, from_token, to_token): from_token = int(from_token) has_changed = self._group_updates_stream_cache.has_entity_changed( user_id, from_token ) if not has_changed: - return defer.succeed([]) + return [] def _get_groups_changes_for_user_txn(txn): sql = """ @@ -525,7 +519,7 @@ class GroupServerWorkerStore(SQLBaseStore): for group_id, membership, gtype, content_json in txn ] - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_groups_changes_for_user", _get_groups_changes_for_user_txn ) @@ -1087,31 +1081,31 @@ class GroupServerStore(GroupServerWorkerStore): desc="update_group_publicity", ) - @defer.inlineCallbacks - def register_user_group_membership( + async def register_user_group_membership( self, - group_id, - user_id, - membership, - is_admin=False, - content={}, - local_attestation=None, - remote_attestation=None, - is_publicised=False, - ): + group_id: str, + user_id: str, + membership: str, + is_admin: bool = False, + content: JsonDict = {}, + local_attestation: Optional[dict] = None, + remote_attestation: Optional[dict] = None, + is_publicised: bool = False, + ) -> int: """Registers that a local user is a member of a (local or remote) group. Args: - group_id (str) - user_id (str) - membership (str) - is_admin (bool) - content (dict): Content of the membership, e.g. includes the inviter + group_id: The group the member is being added to. + user_id: THe user ID to add to the group. + membership: The type of group membership. + is_admin: Whether the user should be added as a group admin. + content: Content of the membership, e.g. includes the inviter if the user has been invited. - local_attestation (dict): If remote group then store the fact that we + local_attestation: If remote group then store the fact that we have given out an attestation, else None. - remote_attestation (dict): If remote group then store the remote + remote_attestation: If remote group then store the remote attestation from the group, else None. + is_publicised: Whether this should be publicised. """ def _register_user_group_membership_txn(txn, next_id): @@ -1188,18 +1182,17 @@ class GroupServerStore(GroupServerWorkerStore): return next_id with self._group_updates_id_gen.get_next() as next_id: - res = yield self.db_pool.runInteraction( + res = await self.db_pool.runInteraction( "register_user_group_membership", _register_user_group_membership_txn, next_id, ) return res - @defer.inlineCallbacks - def create_group( + async def create_group( self, group_id, user_id, name, avatar_url, short_description, long_description - ): - yield self.db_pool.simple_insert( + ) -> None: + await self.db_pool.simple_insert( table="groups", values={ "group_id": group_id, @@ -1212,9 +1205,8 @@ class GroupServerStore(GroupServerWorkerStore): desc="create_group", ) - @defer.inlineCallbacks - def update_group_profile(self, group_id, profile): - yield self.db_pool.simple_update_one( + async def update_group_profile(self, group_id, profile): + await self.db_pool.simple_update_one( table="groups", keyvalues={"group_id": group_id}, updatevalues=profile, diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py index 99e66dc6e9..59ba12820a 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py @@ -15,8 +15,6 @@ from typing import List, Tuple -from twisted.internet import defer - from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause from synapse.storage.presence import UserPresenceState from synapse.util.caches.descriptors import cached, cachedList @@ -24,14 +22,13 @@ from synapse.util.iterutils import batch_iter class PresenceStore(SQLBaseStore): - @defer.inlineCallbacks - def update_presence(self, presence_states): + async def update_presence(self, presence_states): stream_ordering_manager = self._presence_id_gen.get_next_mult( len(presence_states) ) with stream_ordering_manager as stream_orderings: - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "update_presence", self._update_presence_txn, stream_orderings, diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py index 4a4f2cb385..b8261357d4 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py @@ -13,18 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - from synapse.api.errors import StoreError from synapse.storage._base import SQLBaseStore from synapse.storage.databases.main.roommember import ProfileInfo class ProfileWorkerStore(SQLBaseStore): - @defer.inlineCallbacks - def get_profileinfo(self, user_localpart): + async def get_profileinfo(self, user_localpart): try: - profile = yield self.db_pool.simple_select_one( + profile = await self.db_pool.simple_select_one( table="profiles", keyvalues={"user_id": user_localpart}, retcols=("displayname", "avatar_url"), @@ -118,14 +115,13 @@ class ProfileStore(ProfileWorkerStore): desc="update_remote_profile_cache", ) - @defer.inlineCallbacks - def maybe_delete_remote_profile_cache(self, user_id): + async def maybe_delete_remote_profile_cache(self, user_id): """Check if we still care about the remote user's profile, and if we don't then remove their profile from the cache """ - subscribed = yield self.is_subscribed_remote_profile_for_user(user_id) + subscribed = await self.is_subscribed_remote_profile_for_user(user_id) if not subscribed: - yield self.db_pool.simple_delete( + await self.db_pool.simple_delete( table="remote_profile_cache", keyvalues={"user_id": user_id}, desc="delete_remote_profile_cache", @@ -151,11 +147,10 @@ class ProfileStore(ProfileWorkerStore): _get_remote_profile_cache_entries_that_expire_txn, ) - @defer.inlineCallbacks - def is_subscribed_remote_profile_for_user(self, user_id): + async def is_subscribed_remote_profile_for_user(self, user_id): """Check whether we are interested in a remote user's profile. """ - res = yield self.db_pool.simple_select_one_onecol( + res = await self.db_pool.simple_select_one_onecol( table="group_users", keyvalues={"user_id": user_id}, retcol="user_id", @@ -166,7 +161,7 @@ class ProfileStore(ProfileWorkerStore): if res: return True - res = yield self.db_pool.simple_select_one_onecol( + res = await self.db_pool.simple_select_one_onecol( table="group_invites", keyvalues={"user_id": user_id}, retcol="user_id", diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index b81f1449b7..a9ceffc20e 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -14,10 +14,12 @@ # limitations under the License. import logging +from typing import Optional import attr from synapse.api.constants import RelationTypes +from synapse.events import EventBase from synapse.storage._base import SQLBaseStore from synapse.storage.databases.main.stream import generate_pagination_where_clause from synapse.storage.relations import ( @@ -25,7 +27,7 @@ from synapse.storage.relations import ( PaginationChunk, RelationPaginationToken, ) -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks +from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) @@ -227,18 +229,18 @@ class RelationsWorkerStore(SQLBaseStore): "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn ) - @cachedInlineCallbacks() - def get_applicable_edit(self, event_id): + @cached() + async def get_applicable_edit(self, event_id: str) -> Optional[EventBase]: """Get the most recent edit (if any) that has happened for the given event. Correctly handles checking whether edits were allowed to happen. Args: - event_id (str): The original event ID + event_id: The original event ID Returns: - Deferred[EventBase|None]: Returns the most recent edit, if any. + The most recent edit, if any. """ # We only allow edits for `m.room.message` events that have the same sender @@ -268,15 +270,14 @@ class RelationsWorkerStore(SQLBaseStore): if row: return row[0] - edit_id = yield self.db_pool.runInteraction( + edit_id = await self.db_pool.runInteraction( "get_applicable_edit", _get_applicable_edit_txn ) if not edit_id: - return + return None - edit_event = yield self.get_event(edit_id, allow_none=True) - return edit_event + return await self.get_event(edit_id, allow_none=True) def has_user_annotated_event(self, parent_id, event_type, aggregation_key, sender): """Check if a user has already annotated an event with the same key diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index 8804c0e4ac..52668dbdf9 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -18,8 +18,6 @@ from collections import namedtuple from canonicaljson import encode_canonical_json -from twisted.internet import defer - from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool @@ -126,8 +124,7 @@ class TransactionStore(SQLBaseStore): desc="set_received_txn_response", ) - @defer.inlineCallbacks - def get_destination_retry_timings(self, destination): + async def get_destination_retry_timings(self, destination): """Gets the current retry timings (if any) for a given destination. Args: @@ -142,7 +139,7 @@ class TransactionStore(SQLBaseStore): if result is not SENTINEL: return result - result = yield self.db_pool.runInteraction( + result = await self.db_pool.runInteraction( "get_destination_retry_timings", self._get_destination_retry_timings, destination, diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 1b516b7976..98b74890d5 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -178,14 +178,14 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_appservice_state_none(self): service = Mock(id="999") - state = yield self.store.get_appservice_state(service) + state = yield defer.ensureDeferred(self.store.get_appservice_state(service)) self.assertEquals(None, state) @defer.inlineCallbacks def test_get_appservice_state_up(self): yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.UP) service = Mock(id=self.as_list[0]["id"]) - state = yield self.store.get_appservice_state(service) + state = yield defer.ensureDeferred(self.store.get_appservice_state(service)) self.assertEquals(ApplicationServiceState.UP, state) @defer.inlineCallbacks @@ -194,13 +194,13 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): yield self._set_state(self.as_list[1]["id"], ApplicationServiceState.DOWN) yield self._set_state(self.as_list[2]["id"], ApplicationServiceState.DOWN) service = Mock(id=self.as_list[1]["id"]) - state = yield self.store.get_appservice_state(service) + state = yield defer.ensureDeferred(self.store.get_appservice_state(service)) self.assertEquals(ApplicationServiceState.DOWN, state) @defer.inlineCallbacks def test_get_appservices_by_state_none(self): - services = yield self.store.get_appservices_by_state( - ApplicationServiceState.DOWN + services = yield defer.ensureDeferred( + self.store.get_appservices_by_state(ApplicationServiceState.DOWN) ) self.assertEquals(0, len(services)) @@ -339,7 +339,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): def test_get_oldest_unsent_txn_none(self): service = Mock(id=self.as_list[0]["id"]) - txn = yield self.store.get_oldest_unsent_txn(service) + txn = yield defer.ensureDeferred(self.store.get_oldest_unsent_txn(service)) self.assertEquals(None, txn) @defer.inlineCallbacks @@ -349,14 +349,14 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): other_events = [Mock(event_id="e5"), Mock(event_id="e6")] # we aren't testing store._base stuff here, so mock this out - self.store.get_events_as_list = Mock(return_value=events) + self.store.get_events_as_list = Mock(return_value=defer.succeed(events)) yield self._insert_txn(self.as_list[1]["id"], 9, other_events) yield self._insert_txn(service.id, 10, events) yield self._insert_txn(service.id, 11, other_events) yield self._insert_txn(service.id, 12, other_events) - txn = yield self.store.get_oldest_unsent_txn(service) + txn = yield defer.ensureDeferred(self.store.get_oldest_unsent_txn(service)) self.assertEquals(service, txn.service) self.assertEquals(10, txn.id) self.assertEquals(events, txn.events) @@ -366,8 +366,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.DOWN) yield self._set_state(self.as_list[1]["id"], ApplicationServiceState.UP) - services = yield self.store.get_appservices_by_state( - ApplicationServiceState.DOWN + services = yield defer.ensureDeferred( + self.store.get_appservices_by_state(ApplicationServiceState.DOWN) ) self.assertEquals(1, len(services)) self.assertEquals(self.as_list[0]["id"], services[0].id) @@ -379,8 +379,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): yield self._set_state(self.as_list[2]["id"], ApplicationServiceState.DOWN) yield self._set_state(self.as_list[3]["id"], ApplicationServiceState.UP) - services = yield self.store.get_appservices_by_state( - ApplicationServiceState.DOWN + services = yield defer.ensureDeferred( + self.store.get_appservices_by_state(ApplicationServiceState.DOWN) ) self.assertEquals(2, len(services)) self.assertEquals( -- cgit 1.5.1