From a7bdf98d01d2225a479753a85ba81adf02b16a32 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 5 Aug 2020 21:38:57 +0100 Subject: Rename database classes to make some sense (#8033) --- tests/storage/test_appservice.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'tests/storage/test_appservice.py') diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index ef296e7dab..1b516b7976 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -24,11 +24,11 @@ from twisted.internet import defer from synapse.appservice import ApplicationService, ApplicationServiceState from synapse.config._base import ConfigError -from synapse.storage.data_stores.main.appservice import ( +from synapse.storage.database import DatabasePool, make_conn +from synapse.storage.databases.main.appservice import ( ApplicationServiceStore, ApplicationServiceTransactionStore, ) -from synapse.storage.database import Database, make_conn from tests import unittest from tests.utils import setup_test_homeserver @@ -391,7 +391,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): # required for ApplicationServiceTransactionStoreTestCase tests class TestTransactionStore(ApplicationServiceTransactionStore, ApplicationServiceStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(TestTransactionStore, self).__init__(database, db_conn, hs) -- cgit 1.5.1 From a3a59bab7bb3b69dcfc5620e6f3ac51af3f0f965 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 12 Aug 2020 09:28:48 -0400 Subject: Convert appservice, group server, profile and more databases to async (#8066) --- changelog.d/8066.misc | 1 + synapse/storage/databases/main/appservice.py | 34 ++++------ synapse/storage/databases/main/filtering.py | 8 +-- synapse/storage/databases/main/group_server.py | 86 ++++++++++++-------------- synapse/storage/databases/main/presence.py | 7 +-- synapse/storage/databases/main/profile.py | 21 +++---- synapse/storage/databases/main/relations.py | 19 +++--- synapse/storage/databases/main/transactions.py | 7 +-- tests/storage/test_appservice.py | 24 +++---- 9 files changed, 91 insertions(+), 116 deletions(-) create mode 100644 changelog.d/8066.misc (limited to 'tests/storage/test_appservice.py') diff --git a/changelog.d/8066.misc b/changelog.d/8066.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8066.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 055a3962dc..5cf1a88399 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -18,8 +18,6 @@ import re from canonicaljson import json -from twisted.internet import defer - from synapse.appservice import AppServiceTransaction from synapse.config.appservice import load_appservices from synapse.storage._base import SQLBaseStore, db_to_json @@ -124,17 +122,15 @@ class ApplicationServiceStore(ApplicationServiceWorkerStore): class ApplicationServiceTransactionWorkerStore( ApplicationServiceWorkerStore, EventsWorkerStore ): - @defer.inlineCallbacks - def get_appservices_by_state(self, state): + async def get_appservices_by_state(self, state): """Get a list of application services based on their state. Args: state(ApplicationServiceState): The state to filter on. Returns: - A Deferred which resolves to a list of ApplicationServices, which - may be empty. + A list of ApplicationServices, which may be empty. """ - results = yield self.db_pool.simple_select_list( + results = await self.db_pool.simple_select_list( "application_services_state", {"state": state}, ["as_id"] ) # NB: This assumes this class is linked with ApplicationServiceStore @@ -147,16 +143,15 @@ class ApplicationServiceTransactionWorkerStore( services.append(service) return services - @defer.inlineCallbacks - def get_appservice_state(self, service): + async def get_appservice_state(self, service): """Get the application service state. Args: service(ApplicationService): The service whose state to set. Returns: - A Deferred which resolves to ApplicationServiceState. + An ApplicationServiceState. """ - result = yield self.db_pool.simple_select_one( + result = await self.db_pool.simple_select_one( "application_services_state", {"as_id": service.id}, ["state"], @@ -270,16 +265,14 @@ class ApplicationServiceTransactionWorkerStore( "complete_appservice_txn", _complete_appservice_txn ) - @defer.inlineCallbacks - def get_oldest_unsent_txn(self, service): + async def get_oldest_unsent_txn(self, service): """Get the oldest transaction which has not been sent for this service. Args: service(ApplicationService): The app service to get the oldest txn. Returns: - A Deferred which resolves to an AppServiceTransaction or - None. + An AppServiceTransaction or None. """ def _get_oldest_unsent_txn(txn): @@ -298,7 +291,7 @@ class ApplicationServiceTransactionWorkerStore( return entry - entry = yield self.db_pool.runInteraction( + entry = await self.db_pool.runInteraction( "get_oldest_unsent_appservice_txn", _get_oldest_unsent_txn ) @@ -307,7 +300,7 @@ class ApplicationServiceTransactionWorkerStore( event_ids = db_to_json(entry["event_ids"]) - events = yield self.get_events_as_list(event_ids) + events = await self.get_events_as_list(event_ids) return AppServiceTransaction(service=service, id=entry["txn_id"], events=events) @@ -332,8 +325,7 @@ class ApplicationServiceTransactionWorkerStore( "set_appservice_last_pos", set_appservice_last_pos_txn ) - @defer.inlineCallbacks - def get_new_events_for_appservice(self, current_id, limit): + async def get_new_events_for_appservice(self, current_id, limit): """Get all new evnets""" def get_new_events_for_appservice_txn(txn): @@ -357,11 +349,11 @@ class ApplicationServiceTransactionWorkerStore( return upper_bound, [row[1] for row in rows] - upper_bound, event_ids = yield self.db_pool.runInteraction( + upper_bound, event_ids = await self.db_pool.runInteraction( "get_new_events_for_appservice", get_new_events_for_appservice_txn ) - events = yield self.get_events_as_list(event_ids) + events = await self.get_events_as_list(event_ids) return upper_bound, events diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py index cae6bda80e..45a1760170 100644 --- a/synapse/storage/databases/main/filtering.py +++ b/synapse/storage/databases/main/filtering.py @@ -17,12 +17,12 @@ from canonicaljson import encode_canonical_json from synapse.api.errors import Codes, SynapseError from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.util.caches.descriptors import cachedInlineCallbacks +from synapse.util.caches.descriptors import cached class FilteringStore(SQLBaseStore): - @cachedInlineCallbacks(num_args=2) - def get_user_filter(self, user_localpart, filter_id): + @cached(num_args=2) + async def get_user_filter(self, user_localpart, filter_id): # filter_id is BIGINT UNSIGNED, so if it isn't a number, fail # with a coherent error message rather than 500 M_UNKNOWN. try: @@ -30,7 +30,7 @@ class FilteringStore(SQLBaseStore): except ValueError: raise SynapseError(400, "Invalid filter ID", Codes.INVALID_PARAM) - def_json = yield self.db_pool.simple_select_one_onecol( + def_json = await self.db_pool.simple_select_one_onecol( table="user_filters", keyvalues={"user_id": user_localpart, "filter_id": filter_id}, retcol="filter_json", diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py index 75ea6d4b2f..380db3a3f3 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py @@ -14,12 +14,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Tuple - -from twisted.internet import defer +from typing import List, Optional, Tuple from synapse.api.errors import SynapseError from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.types import JsonDict from synapse.util import json_encoder # The category ID for the "default" category. We don't store as null in the @@ -210,9 +209,8 @@ class GroupServerWorkerStore(SQLBaseStore): "get_rooms_for_summary", _get_rooms_for_summary_txn ) - @defer.inlineCallbacks - def get_group_categories(self, group_id): - rows = yield self.db_pool.simple_select_list( + async def get_group_categories(self, group_id): + rows = await self.db_pool.simple_select_list( table="group_room_categories", keyvalues={"group_id": group_id}, retcols=("category_id", "is_public", "profile"), @@ -227,9 +225,8 @@ class GroupServerWorkerStore(SQLBaseStore): for row in rows } - @defer.inlineCallbacks - def get_group_category(self, group_id, category_id): - category = yield self.db_pool.simple_select_one( + async def get_group_category(self, group_id, category_id): + category = await self.db_pool.simple_select_one( table="group_room_categories", keyvalues={"group_id": group_id, "category_id": category_id}, retcols=("is_public", "profile"), @@ -240,9 +237,8 @@ class GroupServerWorkerStore(SQLBaseStore): return category - @defer.inlineCallbacks - def get_group_roles(self, group_id): - rows = yield self.db_pool.simple_select_list( + async def get_group_roles(self, group_id): + rows = await self.db_pool.simple_select_list( table="group_roles", keyvalues={"group_id": group_id}, retcols=("role_id", "is_public", "profile"), @@ -257,9 +253,8 @@ class GroupServerWorkerStore(SQLBaseStore): for row in rows } - @defer.inlineCallbacks - def get_group_role(self, group_id, role_id): - role = yield self.db_pool.simple_select_one( + async def get_group_role(self, group_id, role_id): + role = await self.db_pool.simple_select_one( table="group_roles", keyvalues={"group_id": group_id, "role_id": role_id}, retcols=("is_public", "profile"), @@ -448,12 +443,11 @@ class GroupServerWorkerStore(SQLBaseStore): "get_attestations_need_renewals", _get_attestations_need_renewals_txn ) - @defer.inlineCallbacks - def get_remote_attestation(self, group_id, user_id): + async def get_remote_attestation(self, group_id, user_id): """Get the attestation that proves the remote agrees that the user is in the group. """ - row = yield self.db_pool.simple_select_one( + row = await self.db_pool.simple_select_one( table="group_attestations_remote", keyvalues={"group_id": group_id, "user_id": user_id}, retcols=("valid_until_ms", "attestation_json"), @@ -499,13 +493,13 @@ class GroupServerWorkerStore(SQLBaseStore): "get_all_groups_for_user", _get_all_groups_for_user_txn ) - def get_groups_changes_for_user(self, user_id, from_token, to_token): + async def get_groups_changes_for_user(self, user_id, from_token, to_token): from_token = int(from_token) has_changed = self._group_updates_stream_cache.has_entity_changed( user_id, from_token ) if not has_changed: - return defer.succeed([]) + return [] def _get_groups_changes_for_user_txn(txn): sql = """ @@ -525,7 +519,7 @@ class GroupServerWorkerStore(SQLBaseStore): for group_id, membership, gtype, content_json in txn ] - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_groups_changes_for_user", _get_groups_changes_for_user_txn ) @@ -1087,31 +1081,31 @@ class GroupServerStore(GroupServerWorkerStore): desc="update_group_publicity", ) - @defer.inlineCallbacks - def register_user_group_membership( + async def register_user_group_membership( self, - group_id, - user_id, - membership, - is_admin=False, - content={}, - local_attestation=None, - remote_attestation=None, - is_publicised=False, - ): + group_id: str, + user_id: str, + membership: str, + is_admin: bool = False, + content: JsonDict = {}, + local_attestation: Optional[dict] = None, + remote_attestation: Optional[dict] = None, + is_publicised: bool = False, + ) -> int: """Registers that a local user is a member of a (local or remote) group. Args: - group_id (str) - user_id (str) - membership (str) - is_admin (bool) - content (dict): Content of the membership, e.g. includes the inviter + group_id: The group the member is being added to. + user_id: THe user ID to add to the group. + membership: The type of group membership. + is_admin: Whether the user should be added as a group admin. + content: Content of the membership, e.g. includes the inviter if the user has been invited. - local_attestation (dict): If remote group then store the fact that we + local_attestation: If remote group then store the fact that we have given out an attestation, else None. - remote_attestation (dict): If remote group then store the remote + remote_attestation: If remote group then store the remote attestation from the group, else None. + is_publicised: Whether this should be publicised. """ def _register_user_group_membership_txn(txn, next_id): @@ -1188,18 +1182,17 @@ class GroupServerStore(GroupServerWorkerStore): return next_id with self._group_updates_id_gen.get_next() as next_id: - res = yield self.db_pool.runInteraction( + res = await self.db_pool.runInteraction( "register_user_group_membership", _register_user_group_membership_txn, next_id, ) return res - @defer.inlineCallbacks - def create_group( + async def create_group( self, group_id, user_id, name, avatar_url, short_description, long_description - ): - yield self.db_pool.simple_insert( + ) -> None: + await self.db_pool.simple_insert( table="groups", values={ "group_id": group_id, @@ -1212,9 +1205,8 @@ class GroupServerStore(GroupServerWorkerStore): desc="create_group", ) - @defer.inlineCallbacks - def update_group_profile(self, group_id, profile): - yield self.db_pool.simple_update_one( + async def update_group_profile(self, group_id, profile): + await self.db_pool.simple_update_one( table="groups", keyvalues={"group_id": group_id}, updatevalues=profile, diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py index 99e66dc6e9..59ba12820a 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py @@ -15,8 +15,6 @@ from typing import List, Tuple -from twisted.internet import defer - from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause from synapse.storage.presence import UserPresenceState from synapse.util.caches.descriptors import cached, cachedList @@ -24,14 +22,13 @@ from synapse.util.iterutils import batch_iter class PresenceStore(SQLBaseStore): - @defer.inlineCallbacks - def update_presence(self, presence_states): + async def update_presence(self, presence_states): stream_ordering_manager = self._presence_id_gen.get_next_mult( len(presence_states) ) with stream_ordering_manager as stream_orderings: - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "update_presence", self._update_presence_txn, stream_orderings, diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py index 4a4f2cb385..b8261357d4 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py @@ -13,18 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - from synapse.api.errors import StoreError from synapse.storage._base import SQLBaseStore from synapse.storage.databases.main.roommember import ProfileInfo class ProfileWorkerStore(SQLBaseStore): - @defer.inlineCallbacks - def get_profileinfo(self, user_localpart): + async def get_profileinfo(self, user_localpart): try: - profile = yield self.db_pool.simple_select_one( + profile = await self.db_pool.simple_select_one( table="profiles", keyvalues={"user_id": user_localpart}, retcols=("displayname", "avatar_url"), @@ -118,14 +115,13 @@ class ProfileStore(ProfileWorkerStore): desc="update_remote_profile_cache", ) - @defer.inlineCallbacks - def maybe_delete_remote_profile_cache(self, user_id): + async def maybe_delete_remote_profile_cache(self, user_id): """Check if we still care about the remote user's profile, and if we don't then remove their profile from the cache """ - subscribed = yield self.is_subscribed_remote_profile_for_user(user_id) + subscribed = await self.is_subscribed_remote_profile_for_user(user_id) if not subscribed: - yield self.db_pool.simple_delete( + await self.db_pool.simple_delete( table="remote_profile_cache", keyvalues={"user_id": user_id}, desc="delete_remote_profile_cache", @@ -151,11 +147,10 @@ class ProfileStore(ProfileWorkerStore): _get_remote_profile_cache_entries_that_expire_txn, ) - @defer.inlineCallbacks - def is_subscribed_remote_profile_for_user(self, user_id): + async def is_subscribed_remote_profile_for_user(self, user_id): """Check whether we are interested in a remote user's profile. """ - res = yield self.db_pool.simple_select_one_onecol( + res = await self.db_pool.simple_select_one_onecol( table="group_users", keyvalues={"user_id": user_id}, retcol="user_id", @@ -166,7 +161,7 @@ class ProfileStore(ProfileWorkerStore): if res: return True - res = yield self.db_pool.simple_select_one_onecol( + res = await self.db_pool.simple_select_one_onecol( table="group_invites", keyvalues={"user_id": user_id}, retcol="user_id", diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index b81f1449b7..a9ceffc20e 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -14,10 +14,12 @@ # limitations under the License. import logging +from typing import Optional import attr from synapse.api.constants import RelationTypes +from synapse.events import EventBase from synapse.storage._base import SQLBaseStore from synapse.storage.databases.main.stream import generate_pagination_where_clause from synapse.storage.relations import ( @@ -25,7 +27,7 @@ from synapse.storage.relations import ( PaginationChunk, RelationPaginationToken, ) -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks +from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) @@ -227,18 +229,18 @@ class RelationsWorkerStore(SQLBaseStore): "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn ) - @cachedInlineCallbacks() - def get_applicable_edit(self, event_id): + @cached() + async def get_applicable_edit(self, event_id: str) -> Optional[EventBase]: """Get the most recent edit (if any) that has happened for the given event. Correctly handles checking whether edits were allowed to happen. Args: - event_id (str): The original event ID + event_id: The original event ID Returns: - Deferred[EventBase|None]: Returns the most recent edit, if any. + The most recent edit, if any. """ # We only allow edits for `m.room.message` events that have the same sender @@ -268,15 +270,14 @@ class RelationsWorkerStore(SQLBaseStore): if row: return row[0] - edit_id = yield self.db_pool.runInteraction( + edit_id = await self.db_pool.runInteraction( "get_applicable_edit", _get_applicable_edit_txn ) if not edit_id: - return + return None - edit_event = yield self.get_event(edit_id, allow_none=True) - return edit_event + return await self.get_event(edit_id, allow_none=True) def has_user_annotated_event(self, parent_id, event_type, aggregation_key, sender): """Check if a user has already annotated an event with the same key diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index 8804c0e4ac..52668dbdf9 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -18,8 +18,6 @@ from collections import namedtuple from canonicaljson import encode_canonical_json -from twisted.internet import defer - from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool @@ -126,8 +124,7 @@ class TransactionStore(SQLBaseStore): desc="set_received_txn_response", ) - @defer.inlineCallbacks - def get_destination_retry_timings(self, destination): + async def get_destination_retry_timings(self, destination): """Gets the current retry timings (if any) for a given destination. Args: @@ -142,7 +139,7 @@ class TransactionStore(SQLBaseStore): if result is not SENTINEL: return result - result = yield self.db_pool.runInteraction( + result = await self.db_pool.runInteraction( "get_destination_retry_timings", self._get_destination_retry_timings, destination, diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 1b516b7976..98b74890d5 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -178,14 +178,14 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_appservice_state_none(self): service = Mock(id="999") - state = yield self.store.get_appservice_state(service) + state = yield defer.ensureDeferred(self.store.get_appservice_state(service)) self.assertEquals(None, state) @defer.inlineCallbacks def test_get_appservice_state_up(self): yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.UP) service = Mock(id=self.as_list[0]["id"]) - state = yield self.store.get_appservice_state(service) + state = yield defer.ensureDeferred(self.store.get_appservice_state(service)) self.assertEquals(ApplicationServiceState.UP, state) @defer.inlineCallbacks @@ -194,13 +194,13 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): yield self._set_state(self.as_list[1]["id"], ApplicationServiceState.DOWN) yield self._set_state(self.as_list[2]["id"], ApplicationServiceState.DOWN) service = Mock(id=self.as_list[1]["id"]) - state = yield self.store.get_appservice_state(service) + state = yield defer.ensureDeferred(self.store.get_appservice_state(service)) self.assertEquals(ApplicationServiceState.DOWN, state) @defer.inlineCallbacks def test_get_appservices_by_state_none(self): - services = yield self.store.get_appservices_by_state( - ApplicationServiceState.DOWN + services = yield defer.ensureDeferred( + self.store.get_appservices_by_state(ApplicationServiceState.DOWN) ) self.assertEquals(0, len(services)) @@ -339,7 +339,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): def test_get_oldest_unsent_txn_none(self): service = Mock(id=self.as_list[0]["id"]) - txn = yield self.store.get_oldest_unsent_txn(service) + txn = yield defer.ensureDeferred(self.store.get_oldest_unsent_txn(service)) self.assertEquals(None, txn) @defer.inlineCallbacks @@ -349,14 +349,14 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): other_events = [Mock(event_id="e5"), Mock(event_id="e6")] # we aren't testing store._base stuff here, so mock this out - self.store.get_events_as_list = Mock(return_value=events) + self.store.get_events_as_list = Mock(return_value=defer.succeed(events)) yield self._insert_txn(self.as_list[1]["id"], 9, other_events) yield self._insert_txn(service.id, 10, events) yield self._insert_txn(service.id, 11, other_events) yield self._insert_txn(service.id, 12, other_events) - txn = yield self.store.get_oldest_unsent_txn(service) + txn = yield defer.ensureDeferred(self.store.get_oldest_unsent_txn(service)) self.assertEquals(service, txn.service) self.assertEquals(10, txn.id) self.assertEquals(events, txn.events) @@ -366,8 +366,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.DOWN) yield self._set_state(self.as_list[1]["id"], ApplicationServiceState.UP) - services = yield self.store.get_appservices_by_state( - ApplicationServiceState.DOWN + services = yield defer.ensureDeferred( + self.store.get_appservices_by_state(ApplicationServiceState.DOWN) ) self.assertEquals(1, len(services)) self.assertEquals(self.as_list[0]["id"], services[0].id) @@ -379,8 +379,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): yield self._set_state(self.as_list[2]["id"], ApplicationServiceState.DOWN) yield self._set_state(self.as_list[3]["id"], ApplicationServiceState.UP) - services = yield self.store.get_appservices_by_state( - ApplicationServiceState.DOWN + services = yield defer.ensureDeferred( + self.store.get_appservices_by_state(ApplicationServiceState.DOWN) ) self.assertEquals(2, len(services)) self.assertEquals( -- cgit 1.5.1 From 050e20e7ca56c3a5985fdcf64012800c153260f2 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 17 Aug 2020 12:18:01 -0400 Subject: Convert some of the general database methods to async (#8100) --- changelog.d/8100.misc | 1 + synapse/storage/database.py | 23 ++++++++----------- synapse/storage/databases/main/appservice.py | 2 +- synapse/storage/databases/main/events_worker.py | 16 +++++++------ synapse/storage/databases/main/registration.py | 8 +++---- synapse/storage/databases/main/roommember.py | 4 ++-- tests/handlers/test_profile.py | 4 ++-- tests/handlers/test_typing.py | 2 +- tests/storage/test_appservice.py | 16 +++++++++---- tests/storage/test_base.py | 16 ++++++++----- tests/storage/test_event_push_actions.py | 30 +++++++++++++------------ tests/storage/test_main.py | 2 +- tests/storage/test_profile.py | 4 ++-- 13 files changed, 69 insertions(+), 59 deletions(-) create mode 100644 changelog.d/8100.misc (limited to 'tests/storage/test_appservice.py') diff --git a/changelog.d/8100.misc b/changelog.d/8100.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8100.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 4ada6f5563..8a9e06efcf 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -332,8 +332,7 @@ class DatabasePool(object): """ return self._db_pool.running - @defer.inlineCallbacks - def _check_safe_to_upsert(self): + async def _check_safe_to_upsert(self): """ Is it safe to use native UPSERT? @@ -342,7 +341,7 @@ class DatabasePool(object): If the background updates have not completed, wait 15 sec and check again. """ - updates = yield self.simple_select_list( + updates = await self.simple_select_list( "background_updates", keyvalues=None, retcols=["update_name"], @@ -614,8 +613,7 @@ class DatabasePool(object): # "Simple" SQL API methods that operate on a single table with no JOINs, # no complex WHERE clauses, just a dict of values for columns. - @defer.inlineCallbacks - def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"): + async def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"): """Executes an INSERT query on the named table. Args: @@ -631,7 +629,7 @@ class DatabasePool(object): `or_ignore` is True """ try: - yield self.runInteraction(desc, self.simple_insert_txn, table, values) + await self.runInteraction(desc, self.simple_insert_txn, table, values) except self.engine.module.IntegrityError: # We have to do or_ignore flag at this layer, since we can't reuse # a cursor after we receive an error from the db. @@ -684,8 +682,7 @@ class DatabasePool(object): txn.executemany(sql, vals) - @defer.inlineCallbacks - def simple_upsert( + async def simple_upsert( self, table, keyvalues, @@ -714,14 +711,14 @@ class DatabasePool(object): inserting lock (bool): True to lock the table when doing the upsert. Returns: - Deferred(None or bool): Native upserts always return None. Emulated + None or bool: Native upserts always return None. Emulated upserts return True if a new entry was created, False if an existing one was updated. """ attempts = 0 while True: try: - result = yield self.runInteraction( + return await self.runInteraction( desc, self.simple_upsert_txn, table, @@ -730,7 +727,6 @@ class DatabasePool(object): insertion_values, lock=lock, ) - return result except self.engine.module.IntegrityError as e: attempts += 1 if attempts >= 5: @@ -1121,8 +1117,7 @@ class DatabasePool(object): return cls.cursor_to_dict(txn) - @defer.inlineCallbacks - def simple_select_many_batch( + async def simple_select_many_batch( self, table, column, @@ -1156,7 +1151,7 @@ class DatabasePool(object): it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size) ] for chunk in chunks: - rows = yield self.runInteraction( + rows = await self.runInteraction( desc, self.simple_select_many_txn, table, diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 5cf1a88399..02568a2391 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -169,7 +169,7 @@ class ApplicationServiceTransactionWorkerStore( service(ApplicationService): The service whose state to set. state(ApplicationServiceState): The connectivity state to apply. Returns: - A Deferred which resolves when the state was set successfully. + An Awaitable which resolves when the state was set successfully. """ return self.db_pool.simple_upsert( "application_services_state", {"as_id": service.id}, {"state": state} diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 5687448e3d..8c63a0dc4d 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -847,13 +847,15 @@ class EventsWorkerStore(SQLBaseStore): """Given a list of event ids, check if we have already processed and stored them as non outliers. """ - rows = yield self.db_pool.simple_select_many_batch( - table="events", - retcols=("event_id",), - column="event_id", - iterable=list(event_ids), - keyvalues={"outlier": False}, - desc="have_events_in_timeline", + rows = yield defer.ensureDeferred( + self.db_pool.simple_select_many_batch( + table="events", + retcols=("event_id",), + column="event_id", + iterable=list(event_ids), + keyvalues={"outlier": False}, + desc="have_events_in_timeline", + ) ) return {r["event_id"] for r in rows} diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index de50fa6e94..068ad22b30 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -17,9 +17,7 @@ import logging import re -from typing import Dict, List, Optional - -from twisted.internet.defer import Deferred +from typing import Awaitable, Dict, List, Optional from synapse.api.constants import UserTypes from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError @@ -563,7 +561,7 @@ class RegistrationWorkerStore(SQLBaseStore): id_server (str) Returns: - Deferred + Awaitable """ # We need to use an upsert, in case they user had already bound the # threepid @@ -1084,7 +1082,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): def record_user_external_id( self, auth_provider: str, external_id: str, user_id: str - ) -> Deferred: + ) -> Awaitable: """Record a mapping from an external user id to a mxid Args: diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 1cc8c08ed0..161edbeccb 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -767,13 +767,13 @@ class RoomMemberWorkerStore(EventsWorkerStore): return set(room_ids) - def get_membership_from_event_ids( + async def get_membership_from_event_ids( self, member_event_ids: Iterable[str] ) -> List[dict]: """Get user_id and membership of a set of event IDs. """ - return self.db_pool.simple_select_many_batch( + return await self.db_pool.simple_select_many_batch( table="room_memberships", column="event_id", iterable=member_event_ids, diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index d70e1fc608..b609b30d4a 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -64,7 +64,7 @@ class ProfileTestCase(unittest.TestCase): self.bob = UserID.from_string("@4567:test") self.alice = UserID.from_string("@alice:remote") - yield self.store.create_profile(self.frank.localpart) + yield defer.ensureDeferred(self.store.create_profile(self.frank.localpart)) self.handler = hs.get_profile_handler() self.hs = hs @@ -157,7 +157,7 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_incoming_fed_query(self): - yield self.store.create_profile("caroline") + yield defer.ensureDeferred(self.store.create_profile("caroline")) yield self.store.set_profile_displayname("caroline", "Caroline") response = yield defer.ensureDeferred( diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 64afd581bc..e01de158e5 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -156,7 +156,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ([], 0) ) self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: None - self.datastore.set_received_txn_response = lambda *args, **kwargs: defer.succeed( + self.datastore.set_received_txn_response = lambda *args, **kwargs: make_awaitable( None ) diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 98b74890d5..a425e66f37 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -207,7 +207,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_set_appservices_state_down(self): service = Mock(id=self.as_list[1]["id"]) - yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN) + yield defer.ensureDeferred( + self.store.set_appservice_state(service, ApplicationServiceState.DOWN) + ) rows = yield self.db_pool.runQuery( self.engine.convert_param_style( "SELECT as_id FROM application_services_state WHERE state=?" @@ -219,9 +221,15 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_set_appservices_state_multiple_up(self): service = Mock(id=self.as_list[1]["id"]) - yield self.store.set_appservice_state(service, ApplicationServiceState.UP) - yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN) - yield self.store.set_appservice_state(service, ApplicationServiceState.UP) + yield defer.ensureDeferred( + self.store.set_appservice_state(service, ApplicationServiceState.UP) + ) + yield defer.ensureDeferred( + self.store.set_appservice_state(service, ApplicationServiceState.DOWN) + ) + yield defer.ensureDeferred( + self.store.set_appservice_state(service, ApplicationServiceState.UP) + ) rows = yield self.db_pool.runQuery( self.engine.convert_param_style( "SELECT as_id FROM application_services_state WHERE state=?" diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index efcaeef1e7..13bcac743a 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -66,8 +66,10 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_insert_1col(self): self.mock_txn.rowcount = 1 - yield self.datastore.db_pool.simple_insert( - table="tablename", values={"columname": "Value"} + yield defer.ensureDeferred( + self.datastore.db_pool.simple_insert( + table="tablename", values={"columname": "Value"} + ) ) self.mock_txn.execute.assert_called_with( @@ -78,10 +80,12 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_insert_3cols(self): self.mock_txn.rowcount = 1 - yield self.datastore.db_pool.simple_insert( - table="tablename", - # Use OrderedDict() so we can assert on the SQL generated - values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]), + yield defer.ensureDeferred( + self.datastore.db_pool.simple_insert( + table="tablename", + # Use OrderedDict() so we can assert on the SQL generated + values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]), + ) ) self.mock_txn.execute.assert_called_with( diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index 857db071d4..238bad5b45 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -142,20 +142,22 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def test_find_first_stream_ordering_after_ts(self): def add_event(so, ts): - return self.store.db_pool.simple_insert( - "events", - { - "stream_ordering": so, - "received_ts": ts, - "event_id": "event%i" % so, - "type": "", - "room_id": "", - "content": "", - "processed": True, - "outlier": False, - "topological_ordering": 0, - "depth": 0, - }, + return defer.ensureDeferred( + self.store.db_pool.simple_insert( + "events", + { + "stream_ordering": so, + "received_ts": ts, + "event_id": "event%i" % so, + "type": "", + "room_id": "", + "content": "", + "processed": True, + "outlier": False, + "topological_ordering": 0, + "depth": 0, + }, + ) ) # start with the base case where there are no events in the table diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py index ab0df5ea93..fbf8af940a 100644 --- a/tests/storage/test_main.py +++ b/tests/storage/test_main.py @@ -35,7 +35,7 @@ class DataStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_users_paginate(self): yield self.store.register_user(self.user.to_string(), "pass") - yield self.store.create_profile(self.user.localpart) + yield defer.ensureDeferred(self.store.create_profile(self.user.localpart)) yield self.store.set_profile_displayname(self.user.localpart, self.displayname) users, total = yield self.store.get_users_paginate( diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py index 9b6f7211ae..9d5b8aa47d 100644 --- a/tests/storage/test_profile.py +++ b/tests/storage/test_profile.py @@ -33,7 +33,7 @@ class ProfileStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_displayname(self): - yield self.store.create_profile(self.u_frank.localpart) + yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart)) yield self.store.set_profile_displayname(self.u_frank.localpart, "Frank") @@ -43,7 +43,7 @@ class ProfileStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_avatar_url(self): - yield self.store.create_profile(self.u_frank.localpart) + yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart)) yield self.store.set_profile_avatar_url( self.u_frank.localpart, "http://my.site/here" -- cgit 1.5.1 From f40645e60b9cab69c953094848be61c0989a91cb Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 18 Aug 2020 16:20:49 -0400 Subject: Convert events worker database to async/await. (#8071) --- changelog.d/8071.misc | 1 + synapse/event_auth.py | 2 +- synapse/handlers/federation.py | 16 +-- synapse/handlers/message.py | 6 +- synapse/handlers/room_member.py | 2 +- synapse/spam_checker_api/__init__.py | 2 +- synapse/state/__init__.py | 2 +- synapse/storage/databases/main/event_federation.py | 30 +++-- synapse/storage/databases/main/events_worker.py | 132 ++++++++++++--------- synapse/storage/databases/main/stream.py | 1 - .../test_resource_limits_server_notices.py | 6 +- tests/storage/test_appservice.py | 3 +- 12 files changed, 106 insertions(+), 97 deletions(-) create mode 100644 changelog.d/8071.misc (limited to 'tests/storage/test_appservice.py') diff --git a/changelog.d/8071.misc b/changelog.d/8071.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8071.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/event_auth.py b/synapse/event_auth.py index c0981eee62..8c907ad596 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -47,7 +47,7 @@ def check( Args: room_version_obj: the version of the room event: the event being checked. - auth_events (dict: event-key -> event): the existing room state. + auth_events: the existing room state. Raises: AuthError if the checks fail diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 593932adb7..5b270228e7 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1777,9 +1777,7 @@ class FederationHandler(BaseHandler): """Returns the state at the event. i.e. not including said event. """ - event = await self.store.get_event( - event_id, allow_none=False, check_room_id=room_id - ) + event = await self.store.get_event(event_id, check_room_id=room_id) state_groups = await self.state_store.get_state_groups(room_id, [event_id]) @@ -1805,9 +1803,7 @@ class FederationHandler(BaseHandler): async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]: """Returns the state at the event. i.e. not including said event. """ - event = await self.store.get_event( - event_id, allow_none=False, check_room_id=room_id - ) + event = await self.store.get_event(event_id, check_room_id=room_id) state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id]) @@ -2155,9 +2151,9 @@ class FederationHandler(BaseHandler): auth_types = auth_types_for_event(event) current_state_ids = [e for k, e in current_state_ids.items() if k in auth_types] - current_auth_events = await self.store.get_events(current_state_ids) + auth_events_map = await self.store.get_events(current_state_ids) current_auth_events = { - (e.type, e.state_key): e for e in current_auth_events.values() + (e.type, e.state_key): e for e in auth_events_map.values() } try: @@ -2173,9 +2169,7 @@ class FederationHandler(BaseHandler): if not in_room: raise AuthError(403, "Host not in room.") - event = await self.store.get_event( - event_id, allow_none=False, check_room_id=room_id - ) + event = await self.store.get_event(event_id, check_room_id=room_id) # Just go through and process each event in `remote_auth_chain`. We # don't want to fall into the trap of `missing` being wrong. diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 532fc30681..b999d91d1a 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -960,7 +960,7 @@ class EventCreationHandler(object): allow_none=True, ) - is_admin_redaction = ( + is_admin_redaction = bool( original_event and event.sender != original_event.sender ) @@ -1080,8 +1080,8 @@ class EventCreationHandler(object): auth_events_ids = self.auth.compute_auth_events( event, prev_state_ids, for_verification=True ) - auth_events = await self.store.get_events(auth_events_ids) - auth_events = {(e.type, e.state_key): e for e in auth_events.values()} + auth_events_map = await self.store.get_events(auth_events_ids) + auth_events = {(e.type, e.state_key): e for e in auth_events_map.values()} room_version = await self.store.get_room_version_id(event.room_id) room_version_obj = KNOWN_ROOM_VERSIONS[room_version] diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 31705cdbdb..aa1ccde211 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -716,7 +716,7 @@ class RoomMemberHandler(object): guest_access = await self.store.get_event(guest_access_id) - return ( + return bool( guest_access and guest_access.content and "guest_access" in guest_access.content diff --git a/synapse/spam_checker_api/__init__.py b/synapse/spam_checker_api/__init__.py index 9b78924d96..4d9b13ac04 100644 --- a/synapse/spam_checker_api/__init__.py +++ b/synapse/spam_checker_api/__init__.py @@ -51,5 +51,5 @@ class SpamCheckerApi(object): state_ids = yield self._store.get_filtered_current_state_ids( room_id=room_id, state_filter=StateFilter.from_types(types) ) - state = yield self._store.get_events(state_ids.values()) + state = yield defer.ensureDeferred(self._store.get_events(state_ids.values())) return state.values() diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index a1d3884667..dba8d91eef 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -641,7 +641,7 @@ class StateResolutionStore(object): allow_rejected (bool): If True return rejected events. Returns: - Deferred[dict[str, FrozenEvent]]: Dict from event_id to event. + Awaitable[dict[str, FrozenEvent]]: Dict from event_id to event. """ return self.store.get_events( diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 431bd76693..4826be630c 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -30,7 +30,7 @@ logger = logging.getLogger(__name__) class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore): - def get_auth_chain(self, event_ids, include_given=False): + async def get_auth_chain(self, event_ids, include_given=False): """Get auth events for given event_ids. The events *must* be state events. Args: @@ -40,9 +40,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas Returns: list of events """ - return self.get_auth_chain_ids( + event_ids = await self.get_auth_chain_ids( event_ids, include_given=include_given - ).addCallback(self.get_events_as_list) + ) + return await self.get_events_as_list(event_ids) def get_auth_chain_ids( self, @@ -459,7 +460,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas "get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn ) - def get_backfill_events(self, room_id, event_list, limit): + async def get_backfill_events(self, room_id, event_list, limit): """Get a list of Events for a given topic that occurred before (and including) the events in event_list. Return a list of max size `limit` @@ -469,17 +470,15 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas event_list (list) limit (int) """ - return ( - self.db_pool.runInteraction( - "get_backfill_events", - self._get_backfill_events, - room_id, - event_list, - limit, - ) - .addCallback(self.get_events_as_list) - .addCallback(lambda l: sorted(l, key=lambda e: -e.depth)) + event_ids = await self.db_pool.runInteraction( + "get_backfill_events", + self._get_backfill_events, + room_id, + event_list, + limit, ) + events = await self.get_events_as_list(event_ids) + return sorted(events, key=lambda e: -e.depth) def _get_backfill_events(self, txn, room_id, event_list, limit): logger.debug("_get_backfill_events: %s, %r, %s", room_id, event_list, limit) @@ -540,8 +539,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas latest_events, limit, ) - events = await self.get_events_as_list(ids) - return events + return await self.get_events_as_list(ids) def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit): diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 8c63a0dc4d..e3a154a527 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -19,9 +19,10 @@ import itertools import logging import threading from collections import namedtuple -from typing import List, Optional, Tuple +from typing import Dict, Iterable, List, Optional, Tuple, overload from constantly import NamedConstant, Names +from typing_extensions import Literal from twisted.internet import defer @@ -32,7 +33,7 @@ from synapse.api.room_versions import ( EventFormatVersions, RoomVersions, ) -from synapse.events import make_event_from_dict +from synapse.events import EventBase, make_event_from_dict from synapse.events.utils import prune_event from synapse.logging.context import PreserveLoggingContext, current_context from synapse.metrics.background_process_metrics import run_as_background_process @@ -42,8 +43,8 @@ from synapse.replication.tcp.streams.events import EventsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool from synapse.storage.util.id_generators import StreamIdGenerator -from synapse.types import get_domain_from_id -from synapse.util.caches.descriptors import Cache, cachedInlineCallbacks +from synapse.types import Collection, get_domain_from_id +from synapse.util.caches.descriptors import Cache, cached from synapse.util.iterutils import batch_iter from synapse.util.metrics import Measure @@ -137,8 +138,33 @@ class EventsWorkerStore(SQLBaseStore): desc="get_received_ts", ) - @defer.inlineCallbacks - def get_event( + # Inform mypy that if allow_none is False (the default) then get_event + # always returns an EventBase. + @overload + async def get_event( + self, + event_id: str, + redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, + get_prev_content: bool = False, + allow_rejected: bool = False, + allow_none: Literal[False] = False, + check_room_id: Optional[str] = None, + ) -> EventBase: + ... + + @overload + async def get_event( + self, + event_id: str, + redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, + get_prev_content: bool = False, + allow_rejected: bool = False, + allow_none: Literal[True] = False, + check_room_id: Optional[str] = None, + ) -> Optional[EventBase]: + ... + + async def get_event( self, event_id: str, redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, @@ -146,7 +172,7 @@ class EventsWorkerStore(SQLBaseStore): allow_rejected: bool = False, allow_none: bool = False, check_room_id: Optional[str] = None, - ): + ) -> Optional[EventBase]: """Get an event from the database by event_id. Args: @@ -171,12 +197,12 @@ class EventsWorkerStore(SQLBaseStore): If there is a mismatch, behave as per allow_none. Returns: - Deferred[EventBase|None] + The event, or None if the event was not found. """ if not isinstance(event_id, str): raise TypeError("Invalid event event_id %r" % (event_id,)) - events = yield self.get_events_as_list( + events = await self.get_events_as_list( [event_id], redact_behaviour=redact_behaviour, get_prev_content=get_prev_content, @@ -194,14 +220,13 @@ class EventsWorkerStore(SQLBaseStore): return event - @defer.inlineCallbacks - def get_events( + async def get_events( self, - event_ids: List[str], + event_ids: Iterable[str], redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, get_prev_content: bool = False, allow_rejected: bool = False, - ): + ) -> Dict[str, EventBase]: """Get events from the database Args: @@ -220,9 +245,9 @@ class EventsWorkerStore(SQLBaseStore): omits rejeted events from the response. Returns: - Deferred : Dict from event_id to event. + A mapping from event_id to event. """ - events = yield self.get_events_as_list( + events = await self.get_events_as_list( event_ids, redact_behaviour=redact_behaviour, get_prev_content=get_prev_content, @@ -231,14 +256,13 @@ class EventsWorkerStore(SQLBaseStore): return {e.event_id: e for e in events} - @defer.inlineCallbacks - def get_events_as_list( + async def get_events_as_list( self, - event_ids: List[str], + event_ids: Collection[str], redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, get_prev_content: bool = False, allow_rejected: bool = False, - ): + ) -> List[EventBase]: """Get events from the database and return in a list in the same order as given by `event_ids` arg. @@ -259,8 +283,8 @@ class EventsWorkerStore(SQLBaseStore): omits rejected events from the response. Returns: - Deferred[list[EventBase]]: List of events fetched from the database. The - events are in the same order as `event_ids` arg. + List of events fetched from the database. The events are in the same + order as `event_ids` arg. Note that the returned list may be smaller than the list of event IDs if not all events could be fetched. @@ -270,7 +294,7 @@ class EventsWorkerStore(SQLBaseStore): return [] # there may be duplicates so we cast the list to a set - event_entry_map = yield self._get_events_from_cache_or_db( + event_entry_map = await self._get_events_from_cache_or_db( set(event_ids), allow_rejected=allow_rejected ) @@ -305,7 +329,7 @@ class EventsWorkerStore(SQLBaseStore): continue redacted_event_id = entry.event.redacts - event_map = yield self._get_events_from_cache_or_db([redacted_event_id]) + event_map = await self._get_events_from_cache_or_db([redacted_event_id]) original_event_entry = event_map.get(redacted_event_id) if not original_event_entry: # we don't have the redacted event (or it was rejected). @@ -371,7 +395,7 @@ class EventsWorkerStore(SQLBaseStore): if get_prev_content: if "replaces_state" in event.unsigned: - prev = yield self.get_event( + prev = await self.get_event( event.unsigned["replaces_state"], get_prev_content=False, allow_none=True, @@ -383,8 +407,7 @@ class EventsWorkerStore(SQLBaseStore): return events - @defer.inlineCallbacks - def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False): + async def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False): """Fetch a bunch of events from the cache or the database. If events are pulled from the database, they will be cached for future lookups. @@ -399,7 +422,7 @@ class EventsWorkerStore(SQLBaseStore): rejected events are omitted from the response. Returns: - Deferred[Dict[str, _EventCacheEntry]]: + Dict[str, _EventCacheEntry]: map from event id to result """ event_entry_map = self._get_events_from_cache( @@ -417,7 +440,7 @@ class EventsWorkerStore(SQLBaseStore): # the events have been redacted, and if so pulling the redaction event out # of the database to check it. # - missing_events = yield self._get_events_from_db( + missing_events = await self._get_events_from_db( missing_events_ids, allow_rejected=allow_rejected ) @@ -525,8 +548,7 @@ class EventsWorkerStore(SQLBaseStore): with PreserveLoggingContext(): self.hs.get_reactor().callFromThread(fire, event_list, e) - @defer.inlineCallbacks - def _get_events_from_db(self, event_ids, allow_rejected=False): + async def _get_events_from_db(self, event_ids, allow_rejected=False): """Fetch a bunch of events from the database. Returned events will be added to the cache for future lookups. @@ -540,7 +562,7 @@ class EventsWorkerStore(SQLBaseStore): rejected events are omitted from the response. Returns: - Deferred[Dict[str, _EventCacheEntry]]: + Dict[str, _EventCacheEntry]: map from event id to result. May return extra events which weren't asked for. """ @@ -548,7 +570,7 @@ class EventsWorkerStore(SQLBaseStore): events_to_fetch = event_ids while events_to_fetch: - row_map = yield self._enqueue_events(events_to_fetch) + row_map = await self._enqueue_events(events_to_fetch) # we need to recursively fetch any redactions of those events redaction_ids = set() @@ -650,8 +672,7 @@ class EventsWorkerStore(SQLBaseStore): return result_map - @defer.inlineCallbacks - def _enqueue_events(self, events): + async def _enqueue_events(self, events): """Fetches events from the database using the _event_fetch_list. This allows batch and bulk fetching of events - it allows us to fetch events without having to create a new transaction for each request for events. @@ -660,7 +681,7 @@ class EventsWorkerStore(SQLBaseStore): events (Iterable[str]): events to be fetched. Returns: - Deferred[Dict[str, Dict]]: map from event id to row data from the database. + Dict[str, Dict]: map from event id to row data from the database. May contain events that weren't requested. """ @@ -683,7 +704,7 @@ class EventsWorkerStore(SQLBaseStore): logger.debug("Loading %d events: %s", len(events), events) with PreserveLoggingContext(): - row_map = yield events_d + row_map = await events_d logger.debug("Loaded %d events (%d rows)", len(events), len(row_map)) return row_map @@ -842,33 +863,29 @@ class EventsWorkerStore(SQLBaseStore): # no valid redaction found for this event return None - @defer.inlineCallbacks - def have_events_in_timeline(self, event_ids): + async def have_events_in_timeline(self, event_ids): """Given a list of event ids, check if we have already processed and stored them as non outliers. """ - rows = yield defer.ensureDeferred( - self.db_pool.simple_select_many_batch( - table="events", - retcols=("event_id",), - column="event_id", - iterable=list(event_ids), - keyvalues={"outlier": False}, - desc="have_events_in_timeline", - ) + rows = await self.db_pool.simple_select_many_batch( + table="events", + retcols=("event_id",), + column="event_id", + iterable=list(event_ids), + keyvalues={"outlier": False}, + desc="have_events_in_timeline", ) return {r["event_id"] for r in rows} - @defer.inlineCallbacks - def have_seen_events(self, event_ids): + async def have_seen_events(self, event_ids): """Given a list of event ids, check if we have already processed them. Args: event_ids (iterable[str]): Returns: - Deferred[set[str]]: The events we have already seen. + set[str]: The events we have already seen. """ results = set() @@ -884,7 +901,7 @@ class EventsWorkerStore(SQLBaseStore): # break the input up into chunks of 100 input_iterator = iter(event_ids) for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []): - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "have_seen_events", have_seen_events_txn, chunk ) return results @@ -914,8 +931,7 @@ class EventsWorkerStore(SQLBaseStore): room_id, ) - @defer.inlineCallbacks - def get_room_complexity(self, room_id): + async def get_room_complexity(self, room_id): """ Get a rough approximation of the complexity of the room. This is used by remote servers to decide whether they wish to join the room or not. @@ -926,9 +942,9 @@ class EventsWorkerStore(SQLBaseStore): room_id (str) Returns: - Deferred[dict[str:int]] of complexity version to complexity. + dict[str:int] of complexity version to complexity. """ - state_events = yield self.get_current_state_event_counts(room_id) + state_events = await self.get_current_state_event_counts(room_id) # Call this one "v1", so we can introduce new ones as we want to develop # it. @@ -1165,9 +1181,9 @@ class EventsWorkerStore(SQLBaseStore): to_2, so_2 = await self.get_event_ordering(event_id2) return (to_1, so_1) > (to_2, so_2) - @cachedInlineCallbacks(max_entries=5000) - def get_event_ordering(self, event_id): - res = yield self.db_pool.simple_select_one( + @cached(max_entries=5000) + async def get_event_ordering(self, event_id): + res = await self.db_pool.simple_select_one( table="events", retcols=["topological_ordering", "stream_ordering"], keyvalues={"event_id": event_id}, diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 4377bddb8c..497f607703 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -379,7 +379,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): limit: int = 0, order: str = "DESC", ) -> Tuple[List[EventBase], str]: - """Get new room events in stream ordering since `from_key`. Args: diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 2858d13558..23db821fb7 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -104,7 +104,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} ) self._rlsn._store.get_events = Mock( - return_value=defer.succeed({"123": mock_event}) + return_value=make_awaitable({"123": mock_event}) ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) # Would be better to check the content, but once == remove blocking event @@ -122,7 +122,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} ) self._rlsn._store.get_events = Mock( - return_value=defer.succeed({"123": mock_event}) + return_value=make_awaitable({"123": mock_event}) ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) @@ -217,7 +217,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} ) self._rlsn._store.get_events = Mock( - return_value=defer.succeed({"123": mock_event}) + return_value=make_awaitable({"123": mock_event}) ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index a425e66f37..17fbde284a 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -31,6 +31,7 @@ from synapse.storage.databases.main.appservice import ( ) from tests import unittest +from tests.test_utils import make_awaitable from tests.utils import setup_test_homeserver @@ -357,7 +358,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): other_events = [Mock(event_id="e5"), Mock(event_id="e6")] # we aren't testing store._base stuff here, so mock this out - self.store.get_events_as_list = Mock(return_value=defer.succeed(events)) + self.store.get_events_as_list = Mock(return_value=make_awaitable(events)) yield self._insert_txn(self.as_list[1]["id"], 9, other_events) yield self._insert_txn(service.id, 10, events) -- cgit 1.5.1