From 894dae74fe8e79911c3c001c8b84620ef3985bf6 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 14 Aug 2020 07:24:26 -0400 Subject: Convert misc database code to async (#8087) --- synapse/storage/databases/main/devices.py | 5 ++--- synapse/storage/databases/main/event_push_actions.py | 9 ++++----- synapse/storage/databases/main/presence.py | 9 +++------ synapse/storage/databases/main/push_rule.py | 16 ++++++---------- synapse/storage/databases/main/pusher.py | 9 +++------ synapse/storage/databases/main/receipts.py | 5 ++--- synapse/storage/databases/main/roommember.py | 17 ++++++----------- synapse/storage/databases/main/state.py | 5 ++--- synapse/storage/databases/main/user_erasure_store.py | 13 +++++-------- 9 files changed, 33 insertions(+), 55 deletions(-) (limited to 'synapse/storage/databases') diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 2b33060480..9a786e2929 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -671,10 +671,9 @@ class DeviceWorkerStore(SQLBaseStore): @cachedList( cached_method_name="get_device_list_last_stream_id_for_remote", list_name="user_ids", - inlineCallbacks=True, ) - def get_device_list_last_stream_id_for_remotes(self, user_ids: str): - rows = yield self.db_pool.simple_select_many_batch( + async def get_device_list_last_stream_id_for_remotes(self, user_ids: str): + rows = await self.db_pool.simple_select_many_batch( table="device_lists_remote_extremeties", column="user_id", iterable=user_ids, diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index 7c246d3e4c..e8834b2162 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -21,7 +21,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool from synapse.util import json_encoder -from synapse.util.caches.descriptors import cachedInlineCallbacks +from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) @@ -86,18 +86,17 @@ class EventPushActionsWorkerStore(SQLBaseStore): self._rotate_delay = 3 self._rotate_count = 10000 - @cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000) - def get_unread_event_push_actions_by_room_for_user( + @cached(num_args=3, tree=True, max_entries=5000) + async def get_unread_event_push_actions_by_room_for_user( self, room_id, user_id, last_read_event_id ): - ret = yield self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_unread_event_push_actions_by_room", self._get_unread_counts_by_receipt_txn, room_id, user_id, last_read_event_id, ) - return ret def _get_unread_counts_by_receipt_txn( self, txn, room_id, user_id, last_read_event_id diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py index 59ba12820a..fd213d2dfd 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py @@ -130,13 +130,10 @@ class PresenceStore(SQLBaseStore): raise NotImplementedError() @cachedList( - cached_method_name="_get_presence_for_user", - list_name="user_ids", - num_args=1, - inlineCallbacks=True, + cached_method_name="_get_presence_for_user", list_name="user_ids", num_args=1, ) - def get_presence_for_users(self, user_ids): - rows = yield self.db_pool.simple_select_many_batch( + async def get_presence_for_users(self, user_ids): + rows = await self.db_pool.simple_select_many_batch( table="presence_stream", column="user_id", iterable=user_ids, diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 6562db5c2b..6aa5802977 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -170,18 +170,15 @@ class PushRulesWorkerStore( ) @cachedList( - cached_method_name="get_push_rules_for_user", - list_name="user_ids", - num_args=1, - inlineCallbacks=True, + cached_method_name="get_push_rules_for_user", list_name="user_ids", num_args=1, ) - def bulk_get_push_rules(self, user_ids): + async def bulk_get_push_rules(self, user_ids): if not user_ids: return {} results = {user_id: [] for user_id in user_ids} - rows = yield self.db_pool.simple_select_many_batch( + rows = await self.db_pool.simple_select_many_batch( table="push_rules", column="user_name", iterable=user_ids, @@ -194,7 +191,7 @@ class PushRulesWorkerStore( for row in rows: results.setdefault(row["user_name"], []).append(row) - enabled_map_by_user = yield self.bulk_get_push_rules_enabled(user_ids) + enabled_map_by_user = await self.bulk_get_push_rules_enabled(user_ids) for user_id, rules in results.items(): use_new_defaults = user_id in self._users_new_default_push_rules @@ -260,15 +257,14 @@ class PushRulesWorkerStore( cached_method_name="get_push_rules_enabled_for_user", list_name="user_ids", num_args=1, - inlineCallbacks=True, ) - def bulk_get_push_rules_enabled(self, user_ids): + async def bulk_get_push_rules_enabled(self, user_ids): if not user_ids: return {} results = {user_id: {} for user_id in user_ids} - rows = yield self.db_pool.simple_select_many_batch( + rows = await self.db_pool.simple_select_many_batch( table="push_rules_enable", column="user_name", iterable=user_ids, diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index b5200fbe79..8b793d1487 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -170,13 +170,10 @@ class PusherWorkerStore(SQLBaseStore): raise NotImplementedError() @cachedList( - cached_method_name="get_if_user_has_pusher", - list_name="user_ids", - num_args=1, - inlineCallbacks=True, + cached_method_name="get_if_user_has_pusher", list_name="user_ids", num_args=1, ) - def get_if_users_have_pushers(self, user_ids): - rows = yield self.db_pool.simple_select_many_batch( + async def get_if_users_have_pushers(self, user_ids): + rows = await self.db_pool.simple_select_many_batch( table="pushers", column="user_name", iterable=user_ids, diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 1920a8a152..579b7bb17b 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -212,9 +212,8 @@ class ReceiptsWorkerStore(SQLBaseStore): cached_method_name="_get_linearized_receipts_for_room", list_name="room_ids", num_args=3, - inlineCallbacks=True, ) - def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None): + async def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None): if not room_ids: return {} @@ -243,7 +242,7 @@ class ReceiptsWorkerStore(SQLBaseStore): return self.db_pool.cursor_to_dict(txn) - txn_results = yield self.db_pool.runInteraction( + txn_results = await self.db_pool.runInteraction( "_get_linearized_receipts_for_rooms", f ) diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index b2fcfc9bfe..1cc8c08ed0 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -17,8 +17,6 @@ import logging from typing import TYPE_CHECKING, Awaitable, Iterable, List, Optional, Set -from twisted.internet import defer - from synapse.api.constants import EventTypes, Membership from synapse.events import EventBase from synapse.events.snapshot import EventContext @@ -92,8 +90,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): lambda: self._known_servers_count, ) - @defer.inlineCallbacks - def _count_known_servers(self): + async def _count_known_servers(self): """ Count the servers that this server knows about. @@ -121,7 +118,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): txn.execute(query) return list(txn)[0][0] - count = yield self.db_pool.runInteraction("get_known_servers", _transact) + count = await self.db_pool.runInteraction("get_known_servers", _transact) # We always know about ourselves, even if we have nothing in # room_memberships (for example, the server is new). @@ -589,11 +586,9 @@ class RoomMemberWorkerStore(EventsWorkerStore): raise NotImplementedError() @cachedList( - cached_method_name="_get_joined_profile_from_event_id", - list_name="event_ids", - inlineCallbacks=True, + cached_method_name="_get_joined_profile_from_event_id", list_name="event_ids", ) - def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]): + async def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]): """For given set of member event_ids check if they point to a join event and if so return the associated user and profile info. @@ -601,11 +596,11 @@ class RoomMemberWorkerStore(EventsWorkerStore): event_ids: The member event IDs to lookup Returns: - Deferred[dict[str, Tuple[str, ProfileInfo]|None]]: Map from event ID + dict[str, Tuple[str, ProfileInfo]|None]: Map from event ID to `user_id` and ProfileInfo (or None if not join event). """ - rows = yield self.db_pool.simple_select_many_batch( + rows = await self.db_pool.simple_select_many_batch( table="room_memberships", column="event_id", iterable=event_ids, diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index 96e0378e50..991233a9bc 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -273,12 +273,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): cached_method_name="_get_state_group_for_event", list_name="event_ids", num_args=1, - inlineCallbacks=True, ) - def _get_state_group_for_events(self, event_ids): + async def _get_state_group_for_events(self, event_ids): """Returns mapping event_id -> state_group """ - rows = yield self.db_pool.simple_select_many_batch( + rows = await self.db_pool.simple_select_many_batch( table="event_to_state_groups", column="event_id", iterable=event_ids, diff --git a/synapse/storage/databases/main/user_erasure_store.py b/synapse/storage/databases/main/user_erasure_store.py index ab6cb2c1f6..da23fe7355 100644 --- a/synapse/storage/databases/main/user_erasure_store.py +++ b/synapse/storage/databases/main/user_erasure_store.py @@ -38,10 +38,8 @@ class UserErasureWorkerStore(SQLBaseStore): desc="is_user_erased", ).addCallback(operator.truth) - @cachedList( - cached_method_name="is_user_erased", list_name="user_ids", inlineCallbacks=True - ) - def are_users_erased(self, user_ids): + @cachedList(cached_method_name="is_user_erased", list_name="user_ids") + async def are_users_erased(self, user_ids): """ Checks which users in a list have requested erasure @@ -49,14 +47,14 @@ class UserErasureWorkerStore(SQLBaseStore): user_ids (iterable[str]): full user id to check Returns: - Deferred[dict[str, bool]]: + dict[str, bool]: for each user, whether the user has requested erasure. """ # this serves the dual purpose of (a) making sure we can do len and # iterate it multiple times, and (b) avoiding duplicates. user_ids = tuple(set(user_ids)) - rows = yield self.db_pool.simple_select_many_batch( + rows = await self.db_pool.simple_select_many_batch( table="erased_users", column="user_id", iterable=user_ids, @@ -65,8 +63,7 @@ class UserErasureWorkerStore(SQLBaseStore): ) erased_users = {row["user_id"] for row in rows} - res = {u: u in erased_users for u in user_ids} - return res + return {u: u in erased_users for u in user_ids} class UserErasureStore(UserErasureWorkerStore): -- cgit 1.5.1 From 6b7ce1d332766bc2da9e99b22a452e0813a2aae3 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 14 Aug 2020 09:25:40 -0400 Subject: Remove some unused database functions. (#8085) --- changelog.d/8085.misc | 1 + synapse/storage/databases/main/event_federation.py | 13 -- synapse/storage/databases/main/events_worker.py | 170 +-------------------- synapse/storage/databases/main/presence.py | 21 --- synapse/storage/databases/main/registration.py | 37 ----- synapse/storage/databases/main/room.py | 4 - .../delta/58/13remove_presence_allow_inbound.sql | 17 +++ 7 files changed, 19 insertions(+), 244 deletions(-) create mode 100644 changelog.d/8085.misc create mode 100644 synapse/storage/databases/main/schema/delta/58/13remove_presence_allow_inbound.sql (limited to 'synapse/storage/databases') diff --git a/changelog.d/8085.misc b/changelog.d/8085.misc new file mode 100644 index 0000000000..c3da1e297c --- /dev/null +++ b/changelog.d/8085.misc @@ -0,0 +1 @@ +Remove some unused database functions. diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 484875f989..431bd76693 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -257,11 +257,6 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas # Return all events where not all sets can reach them. return {eid for eid, n in event_to_missing_sets.items() if n} - def get_oldest_events_in_room(self, room_id): - return self.db_pool.runInteraction( - "get_oldest_events_in_room", self._get_oldest_events_in_room_txn, room_id - ) - def get_oldest_events_with_depth_in_room(self, room_id): return self.db_pool.runInteraction( "get_oldest_events_with_depth_in_room", @@ -303,14 +298,6 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas else: return max(row["depth"] for row in rows) - def _get_oldest_events_in_room_txn(self, txn, room_id): - return self.db_pool.simple_select_onecol_txn( - txn, - table="event_backward_extremities", - keyvalues={"room_id": room_id}, - retcol="event_id", - ) - def get_prev_events_for_room(self, room_id: str): """ Gets a subset of the current forward extremities in the given room. diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 755b7a2a85..5687448e3d 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -43,7 +43,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_cla from synapse.storage.database import DatabasePool from synapse.storage.util.id_generators import StreamIdGenerator from synapse.types import get_domain_from_id -from synapse.util.caches.descriptors import Cache, cached, cachedInlineCallbacks +from synapse.util.caches.descriptors import Cache, cachedInlineCallbacks from synapse.util.iterutils import batch_iter from synapse.util.metrics import Measure @@ -137,42 +137,6 @@ class EventsWorkerStore(SQLBaseStore): desc="get_received_ts", ) - def get_received_ts_by_stream_pos(self, stream_ordering): - """Given a stream ordering get an approximate timestamp of when it - happened. - - This is done by simply taking the received ts of the first event that - has a stream ordering greater than or equal to the given stream pos. - If none exists returns the current time, on the assumption that it must - have happened recently. - - Args: - stream_ordering (int) - - Returns: - Deferred[int] - """ - - def _get_approximate_received_ts_txn(txn): - sql = """ - SELECT received_ts FROM events - WHERE stream_ordering >= ? - LIMIT 1 - """ - - txn.execute(sql, (stream_ordering,)) - row = txn.fetchone() - if row and row[0]: - ts = row[0] - else: - ts = self.clock.time_msec() - - return ts - - return self.db_pool.runInteraction( - "get_approximate_received_ts", _get_approximate_received_ts_txn - ) - @defer.inlineCallbacks def get_event( self, @@ -923,36 +887,6 @@ class EventsWorkerStore(SQLBaseStore): ) return results - def _get_total_state_event_counts_txn(self, txn, room_id): - """ - See get_total_state_event_counts. - """ - # We join against the events table as that has an index on room_id - sql = """ - SELECT COUNT(*) FROM state_events - INNER JOIN events USING (room_id, event_id) - WHERE room_id=? - """ - txn.execute(sql, (room_id,)) - row = txn.fetchone() - return row[0] if row else 0 - - def get_total_state_event_counts(self, room_id): - """ - Gets the total number of state events in a room. - - Args: - room_id (str) - - Returns: - Deferred[int] - """ - return self.db_pool.runInteraction( - "get_total_state_event_counts", - self._get_total_state_event_counts_txn, - room_id, - ) - def _get_current_state_event_counts_txn(self, txn, room_id): """ See get_current_state_event_counts. @@ -1222,97 +1156,6 @@ class EventsWorkerStore(SQLBaseStore): return rows, to_token, True - @cached(num_args=5, max_entries=10) - def get_all_new_events( - self, - last_backfill_id, - last_forward_id, - current_backfill_id, - current_forward_id, - limit, - ): - """Get all the new events that have arrived at the server either as - new events or as backfilled events""" - have_backfill_events = last_backfill_id != current_backfill_id - have_forward_events = last_forward_id != current_forward_id - - if not have_backfill_events and not have_forward_events: - return defer.succeed(AllNewEventsResult([], [], [], [], [])) - - def get_all_new_events_txn(txn): - sql = ( - "SELECT e.stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts" - " FROM events AS e" - " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events USING (event_id)" - " WHERE ? < stream_ordering AND stream_ordering <= ?" - " ORDER BY stream_ordering ASC" - " LIMIT ?" - ) - if have_forward_events: - txn.execute(sql, (last_forward_id, current_forward_id, limit)) - new_forward_events = txn.fetchall() - - if len(new_forward_events) == limit: - upper_bound = new_forward_events[-1][0] - else: - upper_bound = current_forward_id - - sql = ( - "SELECT event_stream_ordering, event_id, state_group" - " FROM ex_outlier_stream" - " WHERE ? > event_stream_ordering" - " AND event_stream_ordering >= ?" - " ORDER BY event_stream_ordering DESC" - ) - txn.execute(sql, (last_forward_id, upper_bound)) - forward_ex_outliers = txn.fetchall() - else: - new_forward_events = [] - forward_ex_outliers = [] - - sql = ( - "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts" - " FROM events AS e" - " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events USING (event_id)" - " WHERE ? > stream_ordering AND stream_ordering >= ?" - " ORDER BY stream_ordering DESC" - " LIMIT ?" - ) - if have_backfill_events: - txn.execute(sql, (-last_backfill_id, -current_backfill_id, limit)) - new_backfill_events = txn.fetchall() - - if len(new_backfill_events) == limit: - upper_bound = new_backfill_events[-1][0] - else: - upper_bound = current_backfill_id - - sql = ( - "SELECT -event_stream_ordering, event_id, state_group" - " FROM ex_outlier_stream" - " WHERE ? > event_stream_ordering" - " AND event_stream_ordering >= ?" - " ORDER BY event_stream_ordering DESC" - ) - txn.execute(sql, (-last_backfill_id, -upper_bound)) - backward_ex_outliers = txn.fetchall() - else: - new_backfill_events = [] - backward_ex_outliers = [] - - return AllNewEventsResult( - new_forward_events, - new_backfill_events, - forward_ex_outliers, - backward_ex_outliers, - ) - - return self.db_pool.runInteraction("get_all_new_events", get_all_new_events_txn) - async def is_event_after(self, event_id1, event_id2): """Returns True if event_id1 is after event_id2 in the stream """ @@ -1357,14 +1200,3 @@ class EventsWorkerStore(SQLBaseStore): return self.db_pool.runInteraction( desc="get_next_event_to_expire", func=get_next_event_to_expire_txn ) - - -AllNewEventsResult = namedtuple( - "AllNewEventsResult", - [ - "new_forward_events", - "new_backfill_events", - "forward_ex_outliers", - "backward_ex_outliers", - ], -) diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py index fd213d2dfd..9f691e5792 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py @@ -157,24 +157,3 @@ class PresenceStore(SQLBaseStore): def get_current_presence_token(self): return self._presence_id_gen.get_current_token() - - def allow_presence_visible(self, observed_localpart, observer_userid): - return self.db_pool.simple_insert( - table="presence_allow_inbound", - values={ - "observed_user_id": observed_localpart, - "observer_user_id": observer_userid, - }, - desc="allow_presence_visible", - or_ignore=True, - ) - - def disallow_presence_visible(self, observed_localpart, observer_userid): - return self.db_pool.simple_delete_one( - table="presence_allow_inbound", - keyvalues={ - "observed_user_id": observed_localpart, - "observer_user_id": observer_userid, - }, - desc="disallow_presence_visible", - ) diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 402ae25571..7965a52e30 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -1345,43 +1345,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): "validate_threepid_session_txn", validate_threepid_session_txn ) - def upsert_threepid_validation_session( - self, - medium, - address, - client_secret, - send_attempt, - session_id, - validated_at=None, - ): - """Upsert a threepid validation session - Args: - medium (str): The medium of the 3PID - address (str): The address of the 3PID - client_secret (str): A unique string provided by the client to - help identify this validation attempt - send_attempt (int): The latest send_attempt on this session - session_id (str): The id of this validation session - validated_at (int|None): The unix timestamp in milliseconds of - when the session was marked as valid - """ - insertion_values = { - "medium": medium, - "address": address, - "client_secret": client_secret, - } - - if validated_at: - insertion_values["validated_at"] = validated_at - - return self.db_pool.simple_upsert( - table="threepid_validation_session", - keyvalues={"session_id": session_id}, - values={"last_send_attempt": send_attempt}, - insertion_values=insertion_values, - desc="upsert_threepid_validation_session", - ) - def start_or_continue_validation_session( self, medium, diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index f4008e6221..aef08c7e12 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -35,10 +35,6 @@ from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) -OpsLevel = collections.namedtuple( - "OpsLevel", ("ban_level", "kick_level", "redact_level") -) - RatelimitOverride = collections.namedtuple( "RatelimitOverride", ("messages_per_second", "burst_count") ) diff --git a/synapse/storage/databases/main/schema/delta/58/13remove_presence_allow_inbound.sql b/synapse/storage/databases/main/schema/delta/58/13remove_presence_allow_inbound.sql new file mode 100644 index 0000000000..15421b99ac --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/13remove_presence_allow_inbound.sql @@ -0,0 +1,17 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- This table is no longer used. +DROP TABLE IF EXISTS presence_allow_inbound; -- cgit 1.5.1 From e8861957d9005a9f6cad050a55a478b7706f34c9 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 14 Aug 2020 10:05:19 -0400 Subject: Convert receipts and events databases to async/await. (#8076) --- changelog.d/8076.misc | 1 + synapse/storage/databases/main/events.py | 33 ++++----- .../storage/databases/main/events_bg_updates.py | 46 +++++------- synapse/storage/databases/main/receipts.py | 82 ++++++++++++---------- 4 files changed, 80 insertions(+), 82 deletions(-) create mode 100644 changelog.d/8076.misc (limited to 'synapse/storage/databases') diff --git a/changelog.d/8076.misc b/changelog.d/8076.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8076.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 1a68bf32cb..b90e6de2d5 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -17,13 +17,11 @@ import itertools import logging from collections import OrderedDict, namedtuple -from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple +from typing import TYPE_CHECKING, Dict, Iterable, List, Set, Tuple import attr from prometheus_client import Counter -from twisted.internet import defer - import synapse.metrics from synapse.api.constants import EventContentFields, EventTypes, RelationTypes from synapse.api.room_versions import RoomVersions @@ -113,15 +111,14 @@ class PersistEventsStore: hs.config.worker.writers.events == hs.get_instance_name() ), "Can only instantiate EventsStore on master" - @defer.inlineCallbacks - def _persist_events_and_state_updates( + async def _persist_events_and_state_updates( self, events_and_contexts: List[Tuple[EventBase, EventContext]], current_state_for_room: Dict[str, StateMap[str]], state_delta_for_room: Dict[str, DeltaState], new_forward_extremeties: Dict[str, List[str]], backfilled: bool = False, - ): + ) -> None: """Persist a set of events alongside updates to the current state and forward extremities tables. @@ -136,7 +133,7 @@ class PersistEventsStore: backfilled Returns: - Deferred: resolves when the events have been persisted + Resolves when the events have been persisted """ # We want to calculate the stream orderings as late as possible, as @@ -168,7 +165,7 @@ class PersistEventsStore: for (event, context), stream in zip(events_and_contexts, stream_orderings): event.internal_metadata.stream_ordering = stream - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "persist_events", self._persist_events_txn, events_and_contexts=events_and_contexts, @@ -206,16 +203,15 @@ class PersistEventsStore: (room_id,), list(latest_event_ids) ) - @defer.inlineCallbacks - def _get_events_which_are_prevs(self, event_ids): + async def _get_events_which_are_prevs(self, event_ids: Iterable[str]) -> List[str]: """Filter the supplied list of event_ids to get those which are prev_events of existing (non-outlier/rejected) events. Args: - event_ids (Iterable[str]): event ids to filter + event_ids: event ids to filter Returns: - Deferred[List[str]]: filtered event ids + Filtered event ids """ results = [] @@ -240,14 +236,13 @@ class PersistEventsStore: results.extend(r[0] for r in txn if not db_to_json(r[1]).get("soft_failed")) for chunk in batch_iter(event_ids, 100): - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "_get_events_which_are_prevs", _get_events_which_are_prevs_txn, chunk ) return results - @defer.inlineCallbacks - def _get_prevs_before_rejected(self, event_ids): + async def _get_prevs_before_rejected(self, event_ids: Iterable[str]) -> Set[str]: """Get soft-failed ancestors to remove from the extremities. Given a set of events, find all those that have been soft-failed or @@ -259,11 +254,11 @@ class PersistEventsStore: are separated by soft failed events. Args: - event_ids (Iterable[str]): Events to find prev events for. Note - that these must have already been persisted. + event_ids: Events to find prev events for. Note that these must have + already been persisted. Returns: - Deferred[set[str]] + The previous events. """ # The set of event_ids to return. This includes all soft-failed events @@ -304,7 +299,7 @@ class PersistEventsStore: existing_prevs.add(prev_event_id) for chunk in batch_iter(event_ids, 100): - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "_get_prevs_before_rejected", _get_prevs_before_rejected_txn, chunk ) diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index 35a0e09e3c..e53c6373a8 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api.constants import EventContentFields from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool @@ -94,8 +92,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): where_clause="NOT have_censored", ) - @defer.inlineCallbacks - def _background_reindex_fields_sender(self, progress, batch_size): + async def _background_reindex_fields_sender(self, progress, batch_size): target_min_stream_id = progress["target_min_stream_id_inclusive"] max_stream_id = progress["max_stream_id_exclusive"] rows_inserted = progress.get("rows_inserted", 0) @@ -155,19 +152,18 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): return len(rows) - result = yield self.db_pool.runInteraction( + result = await self.db_pool.runInteraction( self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn ) if not result: - yield self.db_pool.updates._end_background_update( + await self.db_pool.updates._end_background_update( self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME ) return result - @defer.inlineCallbacks - def _background_reindex_origin_server_ts(self, progress, batch_size): + async def _background_reindex_origin_server_ts(self, progress, batch_size): target_min_stream_id = progress["target_min_stream_id_inclusive"] max_stream_id = progress["max_stream_id_exclusive"] rows_inserted = progress.get("rows_inserted", 0) @@ -234,19 +230,18 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): return len(rows_to_update) - result = yield self.db_pool.runInteraction( + result = await self.db_pool.runInteraction( self.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn ) if not result: - yield self.db_pool.updates._end_background_update( + await self.db_pool.updates._end_background_update( self.EVENT_ORIGIN_SERVER_TS_NAME ) return result - @defer.inlineCallbacks - def _cleanup_extremities_bg_update(self, progress, batch_size): + async def _cleanup_extremities_bg_update(self, progress, batch_size): """Background update to clean out extremities that should have been deleted previously. @@ -414,26 +409,25 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): return len(original_set) - num_handled = yield self.db_pool.runInteraction( + num_handled = await self.db_pool.runInteraction( "_cleanup_extremities_bg_update", _cleanup_extremities_bg_update_txn ) if not num_handled: - yield self.db_pool.updates._end_background_update( + await self.db_pool.updates._end_background_update( self.DELETE_SOFT_FAILED_EXTREMITIES ) def _drop_table_txn(txn): txn.execute("DROP TABLE _extremities_to_check") - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "_cleanup_extremities_bg_update_drop_table", _drop_table_txn ) return num_handled - @defer.inlineCallbacks - def _redactions_received_ts(self, progress, batch_size): + async def _redactions_received_ts(self, progress, batch_size): """Handles filling out the `received_ts` column in redactions. """ last_event_id = progress.get("last_event_id", "") @@ -480,17 +474,16 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): return len(rows) - count = yield self.db_pool.runInteraction( + count = await self.db_pool.runInteraction( "_redactions_received_ts", _redactions_received_ts_txn ) if not count: - yield self.db_pool.updates._end_background_update("redactions_received_ts") + await self.db_pool.updates._end_background_update("redactions_received_ts") return count - @defer.inlineCallbacks - def _event_fix_redactions_bytes(self, progress, batch_size): + async def _event_fix_redactions_bytes(self, progress, batch_size): """Undoes hex encoded censored redacted event JSON. """ @@ -511,16 +504,15 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): txn.execute("DROP INDEX redactions_censored_redacts") - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "_event_fix_redactions_bytes", _event_fix_redactions_bytes_txn ) - yield self.db_pool.updates._end_background_update("event_fix_redactions_bytes") + await self.db_pool.updates._end_background_update("event_fix_redactions_bytes") return 1 - @defer.inlineCallbacks - def _event_store_labels(self, progress, batch_size): + async def _event_store_labels(self, progress, batch_size): """Background update handler which will store labels for existing events.""" last_event_id = progress.get("last_event_id", "") @@ -575,11 +567,11 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): return nbrows - num_rows = yield self.db_pool.runInteraction( + num_rows = await self.db_pool.runInteraction( desc="event_store_labels", func=_event_store_labels_txn ) if not num_rows: - yield self.db_pool.updates._end_background_update("event_store_labels") + await self.db_pool.updates._end_background_update("event_store_labels") return num_rows diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 579b7bb17b..19ad1c056f 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -16,7 +16,7 @@ import abc import logging -from typing import List, Tuple +from typing import List, Optional, Tuple from twisted.internet import defer @@ -25,7 +25,7 @@ from synapse.storage.database import DatabasePool from synapse.storage.util.id_generators import StreamIdGenerator from synapse.util import json_encoder from synapse.util.async_helpers import ObservableDeferred -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList +from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache logger = logging.getLogger(__name__) @@ -56,9 +56,9 @@ class ReceiptsWorkerStore(SQLBaseStore): """ raise NotImplementedError() - @cachedInlineCallbacks() - def get_users_with_read_receipts_in_room(self, room_id): - receipts = yield self.get_receipts_for_room(room_id, "m.read") + @cached() + async def get_users_with_read_receipts_in_room(self, room_id): + receipts = await self.get_receipts_for_room(room_id, "m.read") return {r["user_id"] for r in receipts} @cached(num_args=2) @@ -84,9 +84,9 @@ class ReceiptsWorkerStore(SQLBaseStore): allow_none=True, ) - @cachedInlineCallbacks(num_args=2) - def get_receipts_for_user(self, user_id, receipt_type): - rows = yield self.db_pool.simple_select_list( + @cached(num_args=2) + async def get_receipts_for_user(self, user_id, receipt_type): + rows = await self.db_pool.simple_select_list( table="receipts_linearized", keyvalues={"user_id": user_id, "receipt_type": receipt_type}, retcols=("room_id", "event_id"), @@ -95,8 +95,7 @@ class ReceiptsWorkerStore(SQLBaseStore): return {row["room_id"]: row["event_id"] for row in rows} - @defer.inlineCallbacks - def get_receipts_for_user_with_orderings(self, user_id, receipt_type): + async def get_receipts_for_user_with_orderings(self, user_id, receipt_type): def f(txn): sql = ( "SELECT rl.room_id, rl.event_id," @@ -110,7 +109,7 @@ class ReceiptsWorkerStore(SQLBaseStore): txn.execute(sql, (user_id,)) return txn.fetchall() - rows = yield self.db_pool.runInteraction( + rows = await self.db_pool.runInteraction( "get_receipts_for_user_with_orderings", f ) return { @@ -122,56 +121,61 @@ class ReceiptsWorkerStore(SQLBaseStore): for row in rows } - @defer.inlineCallbacks - def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None): + async def get_linearized_receipts_for_rooms( + self, room_ids: List[str], to_key: int, from_key: Optional[int] = None + ) -> List[dict]: """Get receipts for multiple rooms for sending to clients. Args: - room_ids (list): List of room_ids. - to_key (int): Max stream id to fetch receipts upto. - from_key (int): Min stream id to fetch receipts from. None fetches + room_id: List of room_ids. + to_key: Max stream id to fetch receipts upto. + from_key: Min stream id to fetch receipts from. None fetches from the start. Returns: - list: A list of receipts. + A list of receipts. """ room_ids = set(room_ids) if from_key is not None: # Only ask the database about rooms where there have been new # receipts added since `from_key` - room_ids = yield self._receipts_stream_cache.get_entities_changed( + room_ids = self._receipts_stream_cache.get_entities_changed( room_ids, from_key ) - results = yield self._get_linearized_receipts_for_rooms( + results = await self._get_linearized_receipts_for_rooms( room_ids, to_key, from_key=from_key ) return [ev for res in results.values() for ev in res] - def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None): + async def get_linearized_receipts_for_room( + self, room_id: str, to_key: int, from_key: Optional[int] = None + ) -> List[dict]: """Get receipts for a single room for sending to clients. Args: - room_ids (str): The room id. - to_key (int): Max stream id to fetch receipts upto. - from_key (int): Min stream id to fetch receipts from. None fetches + room_ids: The room id. + to_key: Max stream id to fetch receipts upto. + from_key: Min stream id to fetch receipts from. None fetches from the start. Returns: - Deferred[list]: A list of receipts. + A list of receipts. """ if from_key is not None: # Check the cache first to see if any new receipts have been added # since`from_key`. If not we can no-op. if not self._receipts_stream_cache.has_entity_changed(room_id, from_key): - defer.succeed([]) + return [] - return self._get_linearized_receipts_for_room(room_id, to_key, from_key) + return await self._get_linearized_receipts_for_room(room_id, to_key, from_key) - @cachedInlineCallbacks(num_args=3, tree=True) - def _get_linearized_receipts_for_room(self, room_id, to_key, from_key=None): + @cached(num_args=3, tree=True) + async def _get_linearized_receipts_for_room( + self, room_id: str, to_key: int, from_key: Optional[int] = None + ) -> List[dict]: """See get_linearized_receipts_for_room """ @@ -195,7 +199,7 @@ class ReceiptsWorkerStore(SQLBaseStore): return rows - rows = yield self.db_pool.runInteraction("get_linearized_receipts_for_room", f) + rows = await self.db_pool.runInteraction("get_linearized_receipts_for_room", f) if not rows: return [] @@ -345,7 +349,7 @@ class ReceiptsWorkerStore(SQLBaseStore): ) def _invalidate_get_users_with_receipts_in_room( - self, room_id, receipt_type, user_id + self, room_id: str, receipt_type: str, user_id: str ): if receipt_type != "m.read": return @@ -471,15 +475,21 @@ class ReceiptsStore(ReceiptsWorkerStore): return rx_ts - @defer.inlineCallbacks - def insert_receipt(self, room_id, receipt_type, user_id, event_ids, data): + async def insert_receipt( + self, + room_id: str, + receipt_type: str, + user_id: str, + event_ids: List[str], + data: dict, + ) -> Optional[Tuple[int, int]]: """Insert a receipt, either from local client or remote server. Automatically does conversion between linearized and graph representations. """ if not event_ids: - return + return None if len(event_ids) == 1: linearized_event_id = event_ids[0] @@ -506,13 +516,13 @@ class ReceiptsStore(ReceiptsWorkerStore): else: raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,)) - linearized_event_id = yield self.db_pool.runInteraction( + linearized_event_id = await self.db_pool.runInteraction( "insert_receipt_conv", graph_to_linear ) stream_id_manager = self._receipts_id_gen.get_next() with stream_id_manager as stream_id: - event_ts = yield self.db_pool.runInteraction( + event_ts = await self.db_pool.runInteraction( "insert_linearized_receipt", self.insert_linearized_receipt_txn, room_id, @@ -534,7 +544,7 @@ class ReceiptsStore(ReceiptsWorkerStore): now - event_ts, ) - yield self.insert_graph_receipt(room_id, receipt_type, user_id, event_ids, data) + await self.insert_graph_receipt(room_id, receipt_type, user_id, event_ids, data) max_persisted_id = self._receipts_id_gen.get_current_token() -- cgit 1.5.1 From b069b78bb4fc9ce005cc84099e208497a0789ddc Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 14 Aug 2020 10:30:16 -0400 Subject: Convert pusher databases to async/await. (#8075) --- changelog.d/8075.misc | 1 + synapse/rest/client/v1/push_rule.py | 9 +-- synapse/storage/databases/main/push_rule.py | 80 ++++++++++++------------ synapse/storage/databases/main/pusher.py | 95 ++++++++++++++--------------- 4 files changed, 90 insertions(+), 95 deletions(-) create mode 100644 changelog.d/8075.misc (limited to 'synapse/storage/databases') diff --git a/changelog.d/8075.misc b/changelog.d/8075.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8075.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index 00831879f3..e2df638cc5 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - from synapse.api.errors import ( NotFoundError, StoreError, @@ -163,7 +162,7 @@ class PushRuleRestServlet(RestServlet): stream_id, _ = self.store.get_push_rules_stream_token() self.notifier.on_new_event("push_rules_key", stream_id, users=[user_id]) - def set_rule_attr(self, user_id, spec, val): + async def set_rule_attr(self, user_id, spec, val): if spec["attr"] == "enabled": if isinstance(val, dict) and "enabled" in val: val = val["enabled"] @@ -173,7 +172,9 @@ class PushRuleRestServlet(RestServlet): # bools directly, so let's not break them. raise SynapseError(400, "Value for 'enabled' must be boolean") namespaced_rule_id = _namespaced_rule_id_from_spec(spec) - return self.store.set_push_rule_enabled(user_id, namespaced_rule_id, val) + return await self.store.set_push_rule_enabled( + user_id, namespaced_rule_id, val + ) elif spec["attr"] == "actions": actions = val.get("actions") _check_actions(actions) @@ -188,7 +189,7 @@ class PushRuleRestServlet(RestServlet): if namespaced_rule_id not in rule_ids: raise SynapseError(404, "Unknown rule %r" % (namespaced_rule_id,)) - return self.store.set_push_rule_actions( + return await self.store.set_push_rule_actions( user_id, namespaced_rule_id, actions, is_default_rule ) else: diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 6aa5802977..c2289a9557 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -32,7 +32,7 @@ from synapse.storage.databases.main.roommember import RoomMemberWorkerStore from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException from synapse.storage.util.id_generators import ChainedIdGenerator from synapse.util import json_encoder -from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList +from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache logger = logging.getLogger(__name__) @@ -115,9 +115,9 @@ class PushRulesWorkerStore( """ raise NotImplementedError() - @cachedInlineCallbacks(max_entries=5000) - def get_push_rules_for_user(self, user_id): - rows = yield self.db_pool.simple_select_list( + @cached(max_entries=5000) + async def get_push_rules_for_user(self, user_id): + rows = await self.db_pool.simple_select_list( table="push_rules", keyvalues={"user_name": user_id}, retcols=( @@ -133,17 +133,15 @@ class PushRulesWorkerStore( rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))) - enabled_map = yield self.get_push_rules_enabled_for_user(user_id) + enabled_map = await self.get_push_rules_enabled_for_user(user_id) use_new_defaults = user_id in self._users_new_default_push_rules - rules = _load_rules(rows, enabled_map, use_new_defaults) + return _load_rules(rows, enabled_map, use_new_defaults) - return rules - - @cachedInlineCallbacks(max_entries=5000) - def get_push_rules_enabled_for_user(self, user_id): - results = yield self.db_pool.simple_select_list( + @cached(max_entries=5000) + async def get_push_rules_enabled_for_user(self, user_id): + results = await self.db_pool.simple_select_list( table="push_rules_enable", keyvalues={"user_name": user_id}, retcols=("user_name", "rule_id", "enabled"), @@ -202,14 +200,15 @@ class PushRulesWorkerStore( return results - @defer.inlineCallbacks - def copy_push_rule_from_room_to_room(self, new_room_id, user_id, rule): + async def copy_push_rule_from_room_to_room( + self, new_room_id: str, user_id: str, rule: dict + ) -> None: """Copy a single push rule from one room to another for a specific user. Args: - new_room_id (str): ID of the new room. - user_id (str): ID of user the push rule belongs to. - rule (Dict): A push rule. + new_room_id: ID of the new room. + user_id : ID of user the push rule belongs to. + rule: A push rule. """ # Create new rule id rule_id_scope = "/".join(rule["rule_id"].split("/")[:-1]) @@ -221,7 +220,7 @@ class PushRulesWorkerStore( condition["pattern"] = new_room_id # Add the rule for the new room - yield self.add_push_rule( + await self.add_push_rule( user_id=user_id, rule_id=new_rule_id, priority_class=rule["priority_class"], @@ -229,20 +228,19 @@ class PushRulesWorkerStore( actions=rule["actions"], ) - @defer.inlineCallbacks - def copy_push_rules_from_room_to_room_for_user( - self, old_room_id, new_room_id, user_id - ): + async def copy_push_rules_from_room_to_room_for_user( + self, old_room_id: str, new_room_id: str, user_id: str + ) -> None: """Copy all of the push rules from one room to another for a specific user. Args: - old_room_id (str): ID of the old room. - new_room_id (str): ID of the new room. - user_id (str): ID of user to copy push rules for. + old_room_id: ID of the old room. + new_room_id: ID of the new room. + user_id: ID of user to copy push rules for. """ # Retrieve push rules for this user - user_push_rules = yield self.get_push_rules_for_user(user_id) + user_push_rules = await self.get_push_rules_for_user(user_id) # Get rules relating to the old room and copy them to the new room for rule in user_push_rules: @@ -251,7 +249,7 @@ class PushRulesWorkerStore( (c.get("key") == "room_id" and c.get("pattern") == old_room_id) for c in conditions ): - yield self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule) + await self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule) @cachedList( cached_method_name="get_push_rules_enabled_for_user", @@ -328,8 +326,7 @@ class PushRulesWorkerStore( class PushRuleStore(PushRulesWorkerStore): - @defer.inlineCallbacks - def add_push_rule( + async def add_push_rule( self, user_id, rule_id, @@ -338,13 +335,13 @@ class PushRuleStore(PushRulesWorkerStore): actions, before=None, after=None, - ): + ) -> None: conditions_json = json_encoder.encode(conditions) actions_json = json_encoder.encode(actions) with self._push_rules_stream_id_gen.get_next() as ids: stream_id, event_stream_ordering = ids if before or after: - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "_add_push_rule_relative_txn", self._add_push_rule_relative_txn, stream_id, @@ -358,7 +355,7 @@ class PushRuleStore(PushRulesWorkerStore): after, ) else: - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "_add_push_rule_highest_priority_txn", self._add_push_rule_highest_priority_txn, stream_id, @@ -542,16 +539,15 @@ class PushRuleStore(PushRulesWorkerStore): }, ) - @defer.inlineCallbacks - def delete_push_rule(self, user_id, rule_id): + async def delete_push_rule(self, user_id: str, rule_id: str) -> None: """ Delete a push rule. Args specify the row to be deleted and can be any of the columns in the push_rule table, but below are the standard ones Args: - user_id (str): The matrix ID of the push rule owner - rule_id (str): The rule_id of the rule to be deleted + user_id: The matrix ID of the push rule owner + rule_id: The rule_id of the rule to be deleted """ def delete_push_rule_txn(txn, stream_id, event_stream_ordering): @@ -565,18 +561,17 @@ class PushRuleStore(PushRulesWorkerStore): with self._push_rules_stream_id_gen.get_next() as ids: stream_id, event_stream_ordering = ids - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "delete_push_rule", delete_push_rule_txn, stream_id, event_stream_ordering, ) - @defer.inlineCallbacks - def set_push_rule_enabled(self, user_id, rule_id, enabled): + async def set_push_rule_enabled(self, user_id, rule_id, enabled) -> None: with self._push_rules_stream_id_gen.get_next() as ids: stream_id, event_stream_ordering = ids - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "_set_push_rule_enabled_txn", self._set_push_rule_enabled_txn, stream_id, @@ -607,8 +602,9 @@ class PushRuleStore(PushRulesWorkerStore): op="ENABLE" if enabled else "DISABLE", ) - @defer.inlineCallbacks - def set_push_rule_actions(self, user_id, rule_id, actions, is_default_rule): + async def set_push_rule_actions( + self, user_id, rule_id, actions, is_default_rule + ) -> None: actions_json = json_encoder.encode(actions) def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering): @@ -649,7 +645,7 @@ class PushRuleStore(PushRulesWorkerStore): with self._push_rules_stream_id_gen.get_next() as ids: stream_id, event_stream_ordering = ids - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "set_push_rule_actions", set_push_rule_actions_txn, stream_id, diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index 8b793d1487..1126fd0751 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -19,10 +19,8 @@ from typing import Iterable, Iterator, List, Tuple from canonicaljson import encode_canonical_json -from twisted.internet import defer - from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList +from synapse.util.caches.descriptors import cached, cachedList logger = logging.getLogger(__name__) @@ -34,23 +32,22 @@ class PusherWorkerStore(SQLBaseStore): Drops any rows whose data cannot be decoded """ for r in rows: - dataJson = r["data"] + data_json = r["data"] try: - r["data"] = db_to_json(dataJson) + r["data"] = db_to_json(data_json) except Exception as e: logger.warning( "Invalid JSON in data for pusher %d: %s, %s", r["id"], - dataJson, + data_json, e.args[0], ) continue yield r - @defer.inlineCallbacks - def user_has_pusher(self, user_id): - ret = yield self.db_pool.simple_select_one_onecol( + async def user_has_pusher(self, user_id): + ret = await self.db_pool.simple_select_one_onecol( "pushers", {"user_name": user_id}, "id", allow_none=True ) return ret is not None @@ -61,9 +58,8 @@ class PusherWorkerStore(SQLBaseStore): def get_pushers_by_user_id(self, user_id): return self.get_pushers_by({"user_name": user_id}) - @defer.inlineCallbacks - def get_pushers_by(self, keyvalues): - ret = yield self.db_pool.simple_select_list( + async def get_pushers_by(self, keyvalues): + ret = await self.db_pool.simple_select_list( "pushers", keyvalues, [ @@ -87,16 +83,14 @@ class PusherWorkerStore(SQLBaseStore): ) return self._decode_pushers_rows(ret) - @defer.inlineCallbacks - def get_all_pushers(self): + async def get_all_pushers(self): def get_pushers(txn): txn.execute("SELECT * FROM pushers") rows = self.db_pool.cursor_to_dict(txn) return self._decode_pushers_rows(rows) - rows = yield self.db_pool.runInteraction("get_all_pushers", get_pushers) - return rows + return await self.db_pool.runInteraction("get_all_pushers", get_pushers) async def get_all_updated_pushers_rows( self, instance_name: str, last_id: int, current_id: int, limit: int @@ -164,8 +158,8 @@ class PusherWorkerStore(SQLBaseStore): "get_all_updated_pushers_rows", get_all_updated_pushers_rows_txn ) - @cachedInlineCallbacks(num_args=1, max_entries=15000) - def get_if_user_has_pusher(self, user_id): + @cached(num_args=1, max_entries=15000) + async def get_if_user_has_pusher(self, user_id): # This only exists for the cachedList decorator raise NotImplementedError() @@ -186,34 +180,38 @@ class PusherWorkerStore(SQLBaseStore): return result - @defer.inlineCallbacks - def update_pusher_last_stream_ordering( + async def update_pusher_last_stream_ordering( self, app_id, pushkey, user_id, last_stream_ordering - ): - yield self.db_pool.simple_update_one( + ) -> None: + await self.db_pool.simple_update_one( "pushers", {"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, {"last_stream_ordering": last_stream_ordering}, desc="update_pusher_last_stream_ordering", ) - @defer.inlineCallbacks - def update_pusher_last_stream_ordering_and_success( - self, app_id, pushkey, user_id, last_stream_ordering, last_success - ): + async def update_pusher_last_stream_ordering_and_success( + self, + app_id: str, + pushkey: str, + user_id: str, + last_stream_ordering: int, + last_success: int, + ) -> bool: """Update the last stream ordering position we've processed up to for the given pusher. Args: - app_id (str) - pushkey (str) - last_stream_ordering (int) - last_success (int) + app_id + pushkey + user_id + last_stream_ordering + last_success Returns: - Deferred[bool]: True if the pusher still exists; False if it has been deleted. + True if the pusher still exists; False if it has been deleted. """ - updated = yield self.db_pool.simple_update( + updated = await self.db_pool.simple_update( table="pushers", keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, updatevalues={ @@ -225,18 +223,18 @@ class PusherWorkerStore(SQLBaseStore): return bool(updated) - @defer.inlineCallbacks - def update_pusher_failing_since(self, app_id, pushkey, user_id, failing_since): - yield self.db_pool.simple_update( + async def update_pusher_failing_since( + self, app_id, pushkey, user_id, failing_since + ) -> None: + await self.db_pool.simple_update( table="pushers", keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, updatevalues={"failing_since": failing_since}, desc="update_pusher_failing_since", ) - @defer.inlineCallbacks - def get_throttle_params_by_room(self, pusher_id): - res = yield self.db_pool.simple_select_list( + async def get_throttle_params_by_room(self, pusher_id): + res = await self.db_pool.simple_select_list( "pusher_throttle", {"pusher": pusher_id}, ["room_id", "last_sent_ts", "throttle_ms"], @@ -252,11 +250,10 @@ class PusherWorkerStore(SQLBaseStore): return params_by_room - @defer.inlineCallbacks - def set_throttle_params(self, pusher_id, room_id, params): + async def set_throttle_params(self, pusher_id, room_id, params) -> None: # no need to lock because `pusher_throttle` has a primary key on # (pusher, room_id) so simple_upsert will retry - yield self.db_pool.simple_upsert( + await self.db_pool.simple_upsert( "pusher_throttle", {"pusher": pusher_id, "room_id": room_id}, params, @@ -269,8 +266,7 @@ class PusherStore(PusherWorkerStore): def get_pushers_stream_token(self): return self._pushers_id_gen.get_current_token() - @defer.inlineCallbacks - def add_pusher( + async def add_pusher( self, user_id, access_token, @@ -284,11 +280,11 @@ class PusherStore(PusherWorkerStore): data, last_stream_ordering, profile_tag="", - ): + ) -> None: with self._pushers_id_gen.get_next() as stream_id: # no need to lock because `pushers` has a unique key on # (app_id, pushkey, user_name) so simple_upsert will retry - yield self.db_pool.simple_upsert( + await self.db_pool.simple_upsert( table="pushers", keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, values={ @@ -313,15 +309,16 @@ class PusherStore(PusherWorkerStore): if user_has_pusher is not True: # invalidate, since we the user might not have had a pusher before - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "add_pusher", self._invalidate_cache_and_stream, self.get_if_user_has_pusher, (user_id,), ) - @defer.inlineCallbacks - def delete_pusher_by_app_id_pushkey_user_id(self, app_id, pushkey, user_id): + async def delete_pusher_by_app_id_pushkey_user_id( + self, app_id, pushkey, user_id + ) -> None: def delete_pusher_txn(txn, stream_id): self._invalidate_cache_and_stream( txn, self.get_if_user_has_pusher, (user_id,) @@ -348,6 +345,6 @@ class PusherStore(PusherWorkerStore): ) with self._pushers_id_gen.get_next() as stream_id: - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "delete_pusher", delete_pusher_txn, stream_id ) -- cgit 1.5.1 From ac77cdb64e50c9fdfc00cccbc7b96f42057aa741 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 14 Aug 2020 12:37:59 -0400 Subject: Add a shadow-banned flag to users. (#8092) --- changelog.d/8092.feature | 1 + synapse/api/auth.py | 12 ++++++++++- synapse/handlers/register.py | 8 +++++++ synapse/replication/http/register.py | 4 ++++ synapse/storage/databases/main/registration.py | 9 +++++++- .../main/schema/delta/58/09shadow_ban.sql | 18 ++++++++++++++++ synapse/types.py | 25 +++++++++++++++++++--- tests/storage/test_cleanup_extrems.py | 4 ++-- tests/storage/test_event_metrics.py | 2 +- tests/storage/test_roommember.py | 2 +- tests/test_federation.py | 2 +- tests/unittest.py | 8 +++++-- 12 files changed, 83 insertions(+), 12 deletions(-) create mode 100644 changelog.d/8092.feature create mode 100644 synapse/storage/databases/main/schema/delta/58/09shadow_ban.sql (limited to 'synapse/storage/databases') diff --git a/changelog.d/8092.feature b/changelog.d/8092.feature new file mode 100644 index 0000000000..813e6d0903 --- /dev/null +++ b/changelog.d/8092.feature @@ -0,0 +1 @@ +Add support for shadow-banning users (ignoring any message send requests). diff --git a/synapse/api/auth.py b/synapse/api/auth.py index d8190f92ab..7aab764360 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -213,6 +213,7 @@ class Auth(object): user = user_info["user"] token_id = user_info["token_id"] is_guest = user_info["is_guest"] + shadow_banned = user_info["shadow_banned"] # Deny the request if the user account has expired. if self._account_validity.enabled and not allow_expired: @@ -252,7 +253,12 @@ class Auth(object): opentracing.set_tag("device_id", device_id) return synapse.types.create_requester( - user, token_id, is_guest, device_id, app_service=app_service + user, + token_id, + is_guest, + shadow_banned, + device_id, + app_service=app_service, ) except KeyError: raise MissingClientTokenError() @@ -297,6 +303,7 @@ class Auth(object): dict that includes: `user` (UserID) `is_guest` (bool) + `shadow_banned` (bool) `token_id` (int|None): access token id. May be None if guest `device_id` (str|None): device corresponding to access token Raises: @@ -356,6 +363,7 @@ class Auth(object): ret = { "user": user, "is_guest": True, + "shadow_banned": False, "token_id": None, # all guests get the same device id "device_id": GUEST_DEVICE_ID, @@ -365,6 +373,7 @@ class Auth(object): ret = { "user": user, "is_guest": False, + "shadow_banned": False, "token_id": None, "device_id": None, } @@ -488,6 +497,7 @@ class Auth(object): "user": UserID.from_string(ret.get("name")), "token_id": ret.get("token_id", None), "is_guest": False, + "shadow_banned": ret.get("shadow_banned"), "device_id": ret.get("device_id"), "valid_until_ms": ret.get("valid_until_ms"), } diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index c94209ab3d..999bc6efb5 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -142,6 +142,7 @@ class RegistrationHandler(BaseHandler): address=None, bind_emails=[], by_admin=False, + shadow_banned=False, ): """Registers a new client on the server. @@ -159,6 +160,7 @@ class RegistrationHandler(BaseHandler): bind_emails (List[str]): list of emails to bind to this account. by_admin (bool): True if this registration is being made via the admin api, otherwise False. + shadow_banned (bool): Shadow-ban the created user. Returns: str: user_id Raises: @@ -194,6 +196,7 @@ class RegistrationHandler(BaseHandler): admin=admin, user_type=user_type, address=address, + shadow_banned=shadow_banned, ) if self.hs.config.user_directory_search_all_users: @@ -224,6 +227,7 @@ class RegistrationHandler(BaseHandler): make_guest=make_guest, create_profile_with_displayname=default_display_name, address=address, + shadow_banned=shadow_banned, ) # Successfully registered @@ -529,6 +533,7 @@ class RegistrationHandler(BaseHandler): admin=False, user_type=None, address=None, + shadow_banned=False, ): """Register user in the datastore. @@ -546,6 +551,7 @@ class RegistrationHandler(BaseHandler): user_type (str|None): type of user. One of the values from api.constants.UserTypes, or None for a normal user. address (str|None): the IP address used to perform the registration. + shadow_banned (bool): Whether to shadow-ban the user Returns: Awaitable @@ -561,6 +567,7 @@ class RegistrationHandler(BaseHandler): admin=admin, user_type=user_type, address=address, + shadow_banned=shadow_banned, ) else: return self.store.register_user( @@ -572,6 +579,7 @@ class RegistrationHandler(BaseHandler): create_profile_with_displayname=create_profile_with_displayname, admin=admin, user_type=user_type, + shadow_banned=shadow_banned, ) async def register_device( diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py index ce9420aa69..a02b27474d 100644 --- a/synapse/replication/http/register.py +++ b/synapse/replication/http/register.py @@ -44,6 +44,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): admin, user_type, address, + shadow_banned, ): """ Args: @@ -60,6 +61,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): user_type (str|None): type of user. One of the values from api.constants.UserTypes, or None for a normal user. address (str|None): the IP address used to perform the regitration. + shadow_banned (bool): Whether to shadow-ban the user """ return { "password_hash": password_hash, @@ -70,6 +72,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): "admin": admin, "user_type": user_type, "address": address, + "shadow_banned": shadow_banned, } async def _handle_request(self, request, user_id): @@ -87,6 +90,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): admin=content["admin"], user_type=content["user_type"], address=content["address"], + shadow_banned=content["shadow_banned"], ) return 200, {} diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 7965a52e30..de50fa6e94 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -304,7 +304,7 @@ class RegistrationWorkerStore(SQLBaseStore): def _query_for_auth(self, txn, token): sql = ( - "SELECT users.name, users.is_guest, access_tokens.id as token_id," + "SELECT users.name, users.is_guest, users.shadow_banned, access_tokens.id as token_id," " access_tokens.device_id, access_tokens.valid_until_ms" " FROM users" " INNER JOIN access_tokens on users.name = access_tokens.user_id" @@ -952,6 +952,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): create_profile_with_displayname=None, admin=False, user_type=None, + shadow_banned=False, ): """Attempts to register an account. @@ -968,6 +969,8 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): admin (boolean): is an admin user? user_type (str|None): type of user. One of the values from api.constants.UserTypes, or None for a normal user. + shadow_banned (bool): Whether the user is shadow-banned, + i.e. they may be told their requests succeeded but we ignore them. Raises: StoreError if the user_id could not be registered. @@ -986,6 +989,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): create_profile_with_displayname, admin, user_type, + shadow_banned, ) def _register_user( @@ -999,6 +1003,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): create_profile_with_displayname, admin, user_type, + shadow_banned, ): user_id_obj = UserID.from_string(user_id) @@ -1028,6 +1033,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): "appservice_id": appservice_id, "admin": 1 if admin else 0, "user_type": user_type, + "shadow_banned": shadow_banned, }, ) else: @@ -1042,6 +1048,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): "appservice_id": appservice_id, "admin": 1 if admin else 0, "user_type": user_type, + "shadow_banned": shadow_banned, }, ) diff --git a/synapse/storage/databases/main/schema/delta/58/09shadow_ban.sql b/synapse/storage/databases/main/schema/delta/58/09shadow_ban.sql new file mode 100644 index 0000000000..260b009b48 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/09shadow_ban.sql @@ -0,0 +1,18 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- A shadow-banned user may be told that their requests succeeded when they were +-- actually ignored. +ALTER TABLE users ADD COLUMN shadow_banned BOOLEAN; diff --git a/synapse/types.py b/synapse/types.py index 9e580f4295..bc36cdde30 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -51,7 +51,15 @@ JsonDict = Dict[str, Any] class Requester( namedtuple( - "Requester", ["user", "access_token_id", "is_guest", "device_id", "app_service"] + "Requester", + [ + "user", + "access_token_id", + "is_guest", + "shadow_banned", + "device_id", + "app_service", + ], ) ): """ @@ -62,6 +70,7 @@ class Requester( access_token_id (int|None): *ID* of the access token used for this request, or None if it came via the appservice API or similar is_guest (bool): True if the user making this request is a guest user + shadow_banned (bool): True if the user making this request has been shadow-banned. device_id (str|None): device_id which was set at authentication time app_service (ApplicationService|None): the AS requesting on behalf of the user """ @@ -77,6 +86,7 @@ class Requester( "user_id": self.user.to_string(), "access_token_id": self.access_token_id, "is_guest": self.is_guest, + "shadow_banned": self.shadow_banned, "device_id": self.device_id, "app_server_id": self.app_service.id if self.app_service else None, } @@ -101,13 +111,19 @@ class Requester( user=UserID.from_string(input["user_id"]), access_token_id=input["access_token_id"], is_guest=input["is_guest"], + shadow_banned=input["shadow_banned"], device_id=input["device_id"], app_service=appservice, ) def create_requester( - user_id, access_token_id=None, is_guest=False, device_id=None, app_service=None + user_id, + access_token_id=None, + is_guest=False, + shadow_banned=False, + device_id=None, + app_service=None, ): """ Create a new ``Requester`` object @@ -117,6 +133,7 @@ def create_requester( access_token_id (int|None): *ID* of the access token used for this request, or None if it came via the appservice API or similar is_guest (bool): True if the user making this request is a guest user + shadow_banned (bool): True if the user making this request is shadow-banned. device_id (str|None): device_id which was set at authentication time app_service (ApplicationService|None): the AS requesting on behalf of the user @@ -125,7 +142,9 @@ def create_requester( """ if not isinstance(user_id, UserID): user_id = UserID.from_string(user_id) - return Requester(user_id, access_token_id, is_guest, device_id, app_service) + return Requester( + user_id, access_token_id, is_guest, shadow_banned, device_id, app_service + ) def get_domain_from_id(string): diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py index 3fab5a5248..8e9a650f9f 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py @@ -38,7 +38,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): # Create a test user and room self.user = UserID("alice", "test") - self.requester = Requester(self.user, None, False, None, None) + self.requester = Requester(self.user, None, False, False, None, None) info, _ = self.get_success(self.room_creator.create_room(self.requester, {})) self.room_id = info["room_id"] @@ -260,7 +260,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase): # Create a test user and room self.user = UserID.from_string(self.register_user("user1", "password")) self.token1 = self.login("user1", "password") - self.requester = Requester(self.user, None, False, None, None) + self.requester = Requester(self.user, None, False, False, None, None) info, _ = self.get_success(self.room_creator.create_room(self.requester, {})) self.room_id = info["room_id"] self.event_creator = homeserver.get_event_creation_handler() diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py index a7b85004e5..949846fe33 100644 --- a/tests/storage/test_event_metrics.py +++ b/tests/storage/test_event_metrics.py @@ -27,7 +27,7 @@ class ExtremStatisticsTestCase(HomeserverTestCase): room_creator = self.hs.get_room_creation_handler() user = UserID("alice", "test") - requester = Requester(user, None, False, None, None) + requester = Requester(user, None, False, False, None, None) # Real events, forward extremities events = [(3, 2), (6, 2), (4, 6)] diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 17c9da4838..d98fe8754d 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -187,7 +187,7 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase): # Now let's create a room, which will insert a membership user = UserID("alice", "test") - requester = Requester(user, None, False, None, None) + requester = Requester(user, None, False, False, None, None) self.get_success(self.room_creator.create_room(requester, {})) # Register the background update to run again. diff --git a/tests/test_federation.py b/tests/test_federation.py index f2fa42bfb9..4a4548433f 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -42,7 +42,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): ) user_id = UserID("us", "test") - our_user = Requester(user_id, None, False, None, None) + our_user = Requester(user_id, None, False, False, None, None) room_creator = self.homeserver.get_room_creation_handler() room_deferred = ensureDeferred( room_creator.create_room( diff --git a/tests/unittest.py b/tests/unittest.py index d0bba3ddef..7b80999a74 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -250,7 +250,11 @@ class HomeserverTestCase(TestCase): async def get_user_by_req(request, allow_guest=False, rights="access"): return create_requester( - UserID.from_string(self.helper.auth_user_id), 1, False, None + UserID.from_string(self.helper.auth_user_id), + 1, + False, + False, + None, ) self.hs.get_auth().get_user_by_req = get_user_by_req @@ -540,7 +544,7 @@ class HomeserverTestCase(TestCase): """ event_creator = self.hs.get_event_creation_handler() secrets = self.hs.get_secrets() - requester = Requester(user, None, False, None, None) + requester = Requester(user, None, False, False, None, None) event, context = self.get_success( event_creator.create_event( -- cgit 1.5.1 From ad6190c9252aafd37cd8c229b70853bfc4ef0e64 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 17 Aug 2020 07:24:46 -0400 Subject: Convert stream database to async/await. (#8074) --- changelog.d/8074.misc | 1 + synapse/api/filtering.py | 2 +- synapse/api/presence.py | 69 ++++ synapse/federation/send_queue.py | 2 +- synapse/federation/sender/__init__.py | 2 +- synapse/federation/sender/per_destination_queue.py | 2 +- synapse/handlers/presence.py | 2 +- synapse/storage/databases/main/presence.py | 2 +- synapse/storage/databases/main/stream.py | 387 +++++++++++---------- synapse/storage/presence.py | 69 ---- tests/handlers/test_presence.py | 2 +- tests/storage/test_purge.py | 49 +-- 12 files changed, 293 insertions(+), 296 deletions(-) create mode 100644 changelog.d/8074.misc create mode 100644 synapse/api/presence.py delete mode 100644 synapse/storage/presence.py (limited to 'synapse/storage/databases') diff --git a/changelog.d/8074.misc b/changelog.d/8074.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8074.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index 7393d6cb74..a8937d2595 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -23,7 +23,7 @@ from jsonschema import FormatChecker from synapse.api.constants import EventContentFields from synapse.api.errors import SynapseError -from synapse.storage.presence import UserPresenceState +from synapse.api.presence import UserPresenceState from synapse.types import RoomID, UserID FILTER_SCHEMA = { diff --git a/synapse/api/presence.py b/synapse/api/presence.py new file mode 100644 index 0000000000..18a462f0ee --- /dev/null +++ b/synapse/api/presence.py @@ -0,0 +1,69 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import namedtuple + +from synapse.api.constants import PresenceState + + +class UserPresenceState( + namedtuple( + "UserPresenceState", + ( + "user_id", + "state", + "last_active_ts", + "last_federation_update_ts", + "last_user_sync_ts", + "status_msg", + "currently_active", + ), + ) +): + """Represents the current presence state of the user. + + user_id (str) + last_active (int): Time in msec that the user last interacted with server. + last_federation_update (int): Time in msec since either a) we sent a presence + update to other servers or b) we received a presence update, depending + on if is a local user or not. + last_user_sync (int): Time in msec that the user last *completed* a sync + (or event stream). + status_msg (str): User set status message. + """ + + def as_dict(self): + return dict(self._asdict()) + + @staticmethod + def from_dict(d): + return UserPresenceState(**d) + + def copy_and_replace(self, **kwargs): + return self._replace(**kwargs) + + @classmethod + def default(cls, user_id): + """Returns a default presence state. + """ + return cls( + user_id=user_id, + state=PresenceState.OFFLINE, + last_active_ts=0, + last_federation_update_ts=0, + last_user_sync_ts=0, + status_msg=None, + currently_active=False, + ) diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index 2b0ab2dcbf..4d65d4aeea 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -37,8 +37,8 @@ from sortedcontainers import SortedDict from twisted.internet import defer +from synapse.api.presence import UserPresenceState from synapse.metrics import LaterGauge -from synapse.storage.presence import UserPresenceState from synapse.util.metrics import Measure from .units import Edu diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 94cc63001e..e53b6ac456 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -22,6 +22,7 @@ from twisted.internet import defer import synapse import synapse.metrics +from synapse.api.presence import UserPresenceState from synapse.events import EventBase from synapse.federation.sender.per_destination_queue import PerDestinationQueue from synapse.federation.sender.transaction_manager import TransactionManager @@ -39,7 +40,6 @@ from synapse.metrics import ( events_processed_counter, ) from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage.presence import UserPresenceState from synapse.types import ReadReceipt from synapse.util.metrics import Measure, measure_func diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index 8cbc23d901..c09ffcaf4c 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -24,12 +24,12 @@ from synapse.api.errors import ( HttpResponseException, RequestSendFailed, ) +from synapse.api.presence import UserPresenceState from synapse.events import EventBase from synapse.federation.units import Edu from synapse.handlers.presence import format_user_presence_state from synapse.metrics import sent_transactions_counter from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage.presence import UserPresenceState from synapse.types import ReadReceipt from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 5387b3724f..24e1940ee5 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -33,13 +33,13 @@ from typing_extensions import ContextManager import synapse.metrics from synapse.api.constants import EventTypes, Membership, PresenceState from synapse.api.errors import SynapseError +from synapse.api.presence import UserPresenceState from synapse.logging.context import run_in_background from synapse.logging.utils import log_function from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process from synapse.state import StateHandler from synapse.storage.databases.main import DataStore -from synapse.storage.presence import UserPresenceState from synapse.types import JsonDict, UserID, get_domain_from_id from synapse.util.async_helpers import Linearizer from synapse.util.caches.descriptors import cached diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py index 9f691e5792..4e3ec02d14 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py @@ -15,8 +15,8 @@ from typing import List, Tuple +from synapse.api.presence import UserPresenceState from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause -from synapse.storage.presence import UserPresenceState from synapse.util.caches.descriptors import cached, cachedList from synapse.util.iterutils import batch_iter diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index aaf225894e..8ccfb8fc46 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -39,15 +39,17 @@ what sort order was used: import abc import logging from collections import namedtuple -from typing import Optional +from typing import Dict, Iterable, List, Optional, Tuple from twisted.internet import defer +from synapse.api.filtering import Filter +from synapse.events import EventBase from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool, make_in_list_sql_clause from synapse.storage.databases.main.events_worker import EventsWorkerStore -from synapse.storage.engines import PostgresEngine +from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine from synapse.types import RoomStreamToken from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -68,8 +70,12 @@ _EventDictReturn = namedtuple( def generate_pagination_where_clause( - direction, column_names, from_token, to_token, engine -): + direction: str, + column_names: Tuple[str, str], + from_token: Optional[Tuple[int, int]], + to_token: Optional[Tuple[int, int]], + engine: BaseDatabaseEngine, +) -> str: """Creates an SQL expression to bound the columns by the pagination tokens. @@ -90,21 +96,19 @@ def generate_pagination_where_clause( token, but include those that match the to token. Args: - direction (str): Whether we're paginating backwards("b") or - forwards ("f"). - column_names (tuple[str, str]): The column names to bound. Must *not* - be user defined as these get inserted directly into the SQL - statement without escapes. - from_token (tuple[int, int]|None): The start point for the pagination. - This is an exclusive minimum bound if direction is "f", and an - inclusive maximum bound if direction is "b". - to_token (tuple[int, int]|None): The endpoint point for the pagination. - This is an inclusive maximum bound if direction is "f", and an - exclusive minimum bound if direction is "b". + direction: Whether we're paginating backwards("b") or forwards ("f"). + column_names: The column names to bound. Must *not* be user defined as + these get inserted directly into the SQL statement without escapes. + from_token: The start point for the pagination. This is an exclusive + minimum bound if direction is "f", and an inclusive maximum bound if + direction is "b". + to_token: The endpoint point for the pagination. This is an inclusive + maximum bound if direction is "f", and an exclusive minimum bound if + direction is "b". engine: The database engine to generate the clauses for Returns: - str: The sql expression + The sql expression """ assert direction in ("b", "f") @@ -132,7 +136,12 @@ def generate_pagination_where_clause( return " AND ".join(where_clause) -def _make_generic_sql_bound(bound, column_names, values, engine): +def _make_generic_sql_bound( + bound: str, + column_names: Tuple[str, str], + values: Tuple[Optional[int], int], + engine: BaseDatabaseEngine, +) -> str: """Create an SQL expression that bounds the given column names by the values, e.g. create the equivalent of `(1, 2) < (col1, col2)`. @@ -142,18 +151,18 @@ def _make_generic_sql_bound(bound, column_names, values, engine): out manually. Args: - bound (str): The comparison operator to use. One of ">", "<", ">=", + bound: The comparison operator to use. One of ">", "<", ">=", "<=", where the values are on the left and columns on the right. - names (tuple[str, str]): The column names. Must *not* be user defined + names: The column names. Must *not* be user defined as these get inserted directly into the SQL statement without escapes. - values (tuple[int|None, int]): The values to bound the columns by. If + values: The values to bound the columns by. If the first value is None then only creates a bound on the second column. engine: The database engine to generate the SQL for Returns: - str + The SQL statement """ assert bound in (">", "<", ">=", "<=") @@ -193,7 +202,7 @@ def _make_generic_sql_bound(bound, column_names, values, engine): ) -def filter_to_clause(event_filter): +def filter_to_clause(event_filter: Filter) -> Tuple[str, List[str]]: # NB: This may create SQL clauses that don't optimise well (and we don't # have indices on all possible clauses). E.g. it may create # "room_id == X AND room_id != X", which postgres doesn't optimise. @@ -291,34 +300,35 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): def get_room_min_stream_ordering(self): raise NotImplementedError() - @defer.inlineCallbacks - def get_room_events_stream_for_rooms( - self, room_ids, from_key, to_key, limit=0, order="DESC" - ): + async def get_room_events_stream_for_rooms( + self, + room_ids: Iterable[str], + from_key: str, + to_key: str, + limit: int = 0, + order: str = "DESC", + ) -> Dict[str, Tuple[List[EventBase], str]]: """Get new room events in stream ordering since `from_key`. Args: - room_id (str) - from_key (str): Token from which no events are returned before - to_key (str): Token from which no events are returned after. (This + room_ids + from_key: Token from which no events are returned before + to_key: Token from which no events are returned after. (This is typically the current stream token) - limit (int): Maximum number of events to return - order (str): Either "DESC" or "ASC". Determines which events are + limit: Maximum number of events to return + order: Either "DESC" or "ASC". Determines which events are returned when the result is limited. If "DESC" then the most recent `limit` events are returned, otherwise returns the oldest `limit` events. Returns: - Deferred[dict[str,tuple[list[FrozenEvent], str]]] - A map from room id to a tuple containing: - - list of recent events in the room - - stream ordering key for the start of the chunk of events returned. + A map from room id to a tuple containing: + - list of recent events in the room + - stream ordering key for the start of the chunk of events returned. """ from_id = RoomStreamToken.parse_stream_token(from_key).stream - room_ids = yield self._events_stream_cache.get_entities_changed( - room_ids, from_id - ) + room_ids = self._events_stream_cache.get_entities_changed(room_ids, from_id) if not room_ids: return {} @@ -326,7 +336,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): results = {} room_ids = list(room_ids) for rm_ids in (room_ids[i : i + 20] for i in range(0, len(room_ids), 20)): - res = yield make_deferred_yieldable( + res = await make_deferred_yieldable( defer.gatherResults( [ run_in_background( @@ -361,28 +371,31 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): if self._events_stream_cache.has_entity_changed(room_id, from_key) } - @defer.inlineCallbacks - def get_room_events_stream_for_room( - self, room_id, from_key, to_key, limit=0, order="DESC" - ): + async def get_room_events_stream_for_room( + self, + room_id: str, + from_key: str, + to_key: str, + limit: int = 0, + order: str = "DESC", + ) -> Tuple[List[EventBase], str]: """Get new room events in stream ordering since `from_key`. Args: - room_id (str) - from_key (str): Token from which no events are returned before - to_key (str): Token from which no events are returned after. (This + room_id + from_key: Token from which no events are returned before + to_key: Token from which no events are returned after. (This is typically the current stream token) - limit (int): Maximum number of events to return - order (str): Either "DESC" or "ASC". Determines which events are + limit: Maximum number of events to return + order: Either "DESC" or "ASC". Determines which events are returned when the result is limited. If "DESC" then the most recent `limit` events are returned, otherwise returns the oldest `limit` events. Returns: - Deferred[tuple[list[FrozenEvent], str]]: Returns the list of - events (in ascending order) and the token from the start of - the chunk of events returned. + The list of events (in ascending order) and the token from the start + of the chunk of events returned. """ if from_key == to_key: return [], from_key @@ -390,9 +403,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): from_id = RoomStreamToken.parse_stream_token(from_key).stream to_id = RoomStreamToken.parse_stream_token(to_key).stream - has_changed = yield self._events_stream_cache.has_entity_changed( - room_id, from_id - ) + has_changed = self._events_stream_cache.has_entity_changed(room_id, from_id) if not has_changed: return [], from_key @@ -410,9 +421,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): rows = [_EventDictReturn(row[0], None, row[1]) for row in txn] return rows - rows = yield self.db_pool.runInteraction("get_room_events_stream_for_room", f) + rows = await self.db_pool.runInteraction("get_room_events_stream_for_room", f) - ret = yield self.get_events_as_list( + ret = await self.get_events_as_list( [r.event_id for r in rows], get_prev_content=True ) @@ -430,8 +441,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return ret, key - @defer.inlineCallbacks - def get_membership_changes_for_user(self, user_id, from_key, to_key): + async def get_membership_changes_for_user(self, user_id, from_key, to_key): from_id = RoomStreamToken.parse_stream_token(from_key).stream to_id = RoomStreamToken.parse_stream_token(to_key).stream @@ -460,9 +470,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return rows - rows = yield self.db_pool.runInteraction("get_membership_changes_for_user", f) + rows = await self.db_pool.runInteraction("get_membership_changes_for_user", f) - ret = yield self.get_events_as_list( + ret = await self.get_events_as_list( [r.event_id for r in rows], get_prev_content=True ) @@ -470,27 +480,26 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return ret - @defer.inlineCallbacks - def get_recent_events_for_room(self, room_id, limit, end_token): + async def get_recent_events_for_room( + self, room_id: str, limit: int, end_token: str + ) -> Tuple[List[EventBase], str]: """Get the most recent events in the room in topological ordering. Args: - room_id (str) - limit (int) - end_token (str): The stream token representing now. + room_id + limit + end_token: The stream token representing now. Returns: - Deferred[tuple[list[FrozenEvent], str]]: Returns a list of - events and a token pointing to the start of the returned - events. - The events returned are in ascending order. + A list of events and a token pointing to the start of the returned + events. The events returned are in ascending order. """ - rows, token = yield self.get_recent_event_ids_for_room( + rows, token = await self.get_recent_event_ids_for_room( room_id, limit, end_token ) - events = yield self.get_events_as_list( + events = await self.get_events_as_list( [r.event_id for r in rows], get_prev_content=True ) @@ -498,20 +507,19 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return (events, token) - @defer.inlineCallbacks - def get_recent_event_ids_for_room(self, room_id, limit, end_token): + async def get_recent_event_ids_for_room( + self, room_id: str, limit: int, end_token: str + ) -> Tuple[List[_EventDictReturn], str]: """Get the most recent events in the room in topological ordering. Args: - room_id (str) - limit (int) - end_token (str): The stream token representing now. + room_id + limit + end_token: The stream token representing now. Returns: - Deferred[tuple[list[_EventDictReturn], str]]: Returns a list of - _EventDictReturn and a token pointing to the start of the returned - events. - The events returned are in ascending order. + A list of _EventDictReturn and a token pointing to the start of the + returned events. The events returned are in ascending order. """ # Allow a zero limit here, and no-op. if limit == 0: @@ -519,7 +527,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): end_token = RoomStreamToken.parse(end_token) - rows, token = yield self.db_pool.runInteraction( + rows, token = await self.db_pool.runInteraction( "get_recent_event_ids_for_room", self._paginate_room_events_txn, room_id, @@ -532,12 +540,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return rows, token - def get_room_event_before_stream_ordering(self, room_id, stream_ordering): + def get_room_event_before_stream_ordering(self, room_id: str, stream_ordering: int): """Gets details of the first event in a room at or before a stream ordering Args: - room_id (str): - stream_ordering (int): + room_id: + stream_ordering: Returns: Deferred[(int, int, str)]: @@ -574,55 +582,56 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): ) return "t%d-%d" % (topo, token) - def get_stream_token_for_event(self, event_id): + async def get_stream_token_for_event(self, event_id: str) -> str: """The stream token for an event Args: - event_id(str): The id of the event to look up a stream token for. + event_id: The id of the event to look up a stream token for. Raises: StoreError if the event wasn't in the database. Returns: - A deferred "s%d" stream token. + A "s%d" stream token. """ - return self.db_pool.simple_select_one_onecol( + row = await self.db_pool.simple_select_one_onecol( table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering" - ).addCallback(lambda row: "s%d" % (row,)) + ) + return "s%d" % (row,) - def get_topological_token_for_event(self, event_id): + async def get_topological_token_for_event(self, event_id: str) -> str: """The stream token for an event Args: - event_id(str): The id of the event to look up a stream token for. + event_id: The id of the event to look up a stream token for. Raises: StoreError if the event wasn't in the database. Returns: - A deferred "t%d-%d" topological token. + A "t%d-%d" topological token. """ - return self.db_pool.simple_select_one( + row = await self.db_pool.simple_select_one( table="events", keyvalues={"event_id": event_id}, retcols=("stream_ordering", "topological_ordering"), desc="get_topological_token_for_event", - ).addCallback( - lambda row: "t%d-%d" % (row["topological_ordering"], row["stream_ordering"]) ) + return "t%d-%d" % (row["topological_ordering"], row["stream_ordering"]) - def get_max_topological_token(self, room_id, stream_key): + async def get_max_topological_token(self, room_id: str, stream_key: int) -> int: """Get the max topological token in a room before the given stream ordering. Args: - room_id (str) - stream_key (int) + room_id + stream_key Returns: - Deferred[int] + The maximum topological token. """ sql = ( "SELECT coalesce(max(topological_ordering), 0) FROM events" " WHERE room_id = ? AND stream_ordering < ?" ) - return self.db_pool.execute( + row = await self.db_pool.execute( "get_max_topological_token", None, sql, room_id, stream_key - ).addCallback(lambda r: r[0][0] if r else 0) + ) + return row[0][0] if row else 0 def _get_max_topological_txn(self, txn, room_id): txn.execute( @@ -634,16 +643,18 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return rows[0][0] if rows else 0 @staticmethod - def _set_before_and_after(events, rows, topo_order=True): + def _set_before_and_after( + events: List[EventBase], rows: List[_EventDictReturn], topo_order: bool = True + ): """Inserts ordering information to events' internal metadata from the DB rows. Args: - events (list[FrozenEvent]) - rows (list[_EventDictReturn]) - topo_order (bool): Whether the events were ordered topologically - or by stream ordering. If true then all rows should have a non - null topological_ordering. + events + rows + topo_order: Whether the events were ordered topologically or by stream + ordering. If true then all rows should have a non null + topological_ordering. """ for event, row in zip(events, rows): stream = row.stream_ordering @@ -656,25 +667,19 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): internal.after = str(RoomStreamToken(topo, stream)) internal.order = (int(topo) if topo else 0, int(stream)) - @defer.inlineCallbacks - def get_events_around( - self, room_id, event_id, before_limit, after_limit, event_filter=None - ): + async def get_events_around( + self, + room_id: str, + event_id: str, + before_limit: int, + after_limit: int, + event_filter: Optional[Filter] = None, + ) -> dict: """Retrieve events and pagination tokens around a given event in a room. - - Args: - room_id (str) - event_id (str) - before_limit (int) - after_limit (int) - event_filter (Filter|None) - - Returns: - dict """ - results = yield self.db_pool.runInteraction( + results = await self.db_pool.runInteraction( "get_events_around", self._get_events_around_txn, room_id, @@ -684,11 +689,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): event_filter, ) - events_before = yield self.get_events_as_list( + events_before = await self.get_events_as_list( list(results["before"]["event_ids"]), get_prev_content=True ) - events_after = yield self.get_events_as_list( + events_after = await self.get_events_as_list( list(results["after"]["event_ids"]), get_prev_content=True ) @@ -700,17 +705,23 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): } def _get_events_around_txn( - self, txn, room_id, event_id, before_limit, after_limit, event_filter - ): + self, + txn, + room_id: str, + event_id: str, + before_limit: int, + after_limit: int, + event_filter: Optional[Filter], + ) -> dict: """Retrieves event_ids and pagination tokens around a given event in a room. Args: - room_id (str) - event_id (str) - before_limit (int) - after_limit (int) - event_filter (Filter|None) + room_id + event_id + before_limit + after_limit + event_filter Returns: dict @@ -758,22 +769,23 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): "after": {"event_ids": events_after, "token": end_token}, } - @defer.inlineCallbacks - def get_all_new_events_stream(self, from_id, current_id, limit): + async def get_all_new_events_stream( + self, from_id: int, current_id: int, limit: int + ) -> Tuple[int, List[EventBase]]: """Get all new events Returns all events with from_id < stream_ordering <= current_id. Args: - from_id (int): the stream_ordering of the last event we processed - current_id (int): the stream_ordering of the most recently processed event - limit (int): the maximum number of events to return + from_id: the stream_ordering of the last event we processed + current_id: the stream_ordering of the most recently processed event + limit: the maximum number of events to return Returns: - Deferred[Tuple[int, list[FrozenEvent]]]: A tuple of (next_id, events), where - `next_id` is the next value to pass as `from_id` (it will either be the - stream_ordering of the last returned event, or, if fewer than `limit` events - were found, `current_id`. + A tuple of (next_id, events), where `next_id` is the next value to + pass as `from_id` (it will either be the stream_ordering of the + last returned event, or, if fewer than `limit` events were found, + the `current_id`). """ def get_all_new_events_stream_txn(txn): @@ -795,11 +807,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return upper_bound, [row[1] for row in rows] - upper_bound, event_ids = yield self.db_pool.runInteraction( + upper_bound, event_ids = await self.db_pool.runInteraction( "get_all_new_events_stream", get_all_new_events_stream_txn ) - events = yield self.get_events_as_list(event_ids) + events = await self.get_events_as_list(event_ids) return upper_bound, events @@ -817,21 +829,21 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): desc="get_federation_out_pos", ) - async def update_federation_out_pos(self, typ, stream_id): + async def update_federation_out_pos(self, typ: str, stream_id: int) -> None: if self._need_to_reset_federation_stream_positions: await self.db_pool.runInteraction( "_reset_federation_positions_txn", self._reset_federation_positions_txn ) self._need_to_reset_federation_stream_positions = False - return await self.db_pool.simple_update_one( + await self.db_pool.simple_update_one( table="federation_stream_position", keyvalues={"type": typ, "instance_name": self._instance_name}, updatevalues={"stream_id": stream_id}, desc="update_federation_out_pos", ) - def _reset_federation_positions_txn(self, txn): + def _reset_federation_positions_txn(self, txn) -> None: """Fiddles with the `federation_stream_position` table to make it match the configured federation sender instances during start up. """ @@ -892,39 +904,37 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): values={"stream_id": stream_id}, ) - def has_room_changed_since(self, room_id, stream_id): + def has_room_changed_since(self, room_id: str, stream_id: int) -> bool: return self._events_stream_cache.has_entity_changed(room_id, stream_id) def _paginate_room_events_txn( self, txn, - room_id, - from_token, - to_token=None, - direction="b", - limit=-1, - event_filter=None, - ): + room_id: str, + from_token: RoomStreamToken, + to_token: Optional[RoomStreamToken] = None, + direction: str = "b", + limit: int = -1, + event_filter: Optional[Filter] = None, + ) -> Tuple[List[_EventDictReturn], str]: """Returns list of events before or after a given token. Args: txn - room_id (str) - from_token (RoomStreamToken): The token used to stream from - to_token (RoomStreamToken|None): A token which if given limits the - results to only those before - direction(char): Either 'b' or 'f' to indicate whether we are - paginating forwards or backwards from `from_key`. - limit (int): The maximum number of events to return. - event_filter (Filter|None): If provided filters the events to + room_id + from_token: The token used to stream from + to_token: A token which if given limits the results to only those before + direction: Either 'b' or 'f' to indicate whether we are paginating + forwards or backwards from `from_key`. + limit: The maximum number of events to return. + event_filter: If provided filters the events to those that match the filter. Returns: - Deferred[tuple[list[_EventDictReturn], str]]: Returns the results - as a list of _EventDictReturn and a token that points to the end - of the result set. If no events are returned then the end of the - stream has been reached (i.e. there are no events between - `from_token` and `to_token`), or `limit` is zero. + A list of _EventDictReturn and a token that points to the end of the + result set. If no events are returned then the end of the stream has + been reached (i.e. there are no events between `from_token` and + `to_token`), or `limit` is zero. """ assert int(limit) >= 0 @@ -1008,35 +1018,38 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return rows, str(next_token) - @defer.inlineCallbacks - def paginate_room_events( - self, room_id, from_key, to_key=None, direction="b", limit=-1, event_filter=None - ): + async def paginate_room_events( + self, + room_id: str, + from_key: str, + to_key: Optional[str] = None, + direction: str = "b", + limit: int = -1, + event_filter: Optional[Filter] = None, + ) -> Tuple[List[EventBase], str]: """Returns list of events before or after a given token. Args: - room_id (str) - from_key (str): The token used to stream from - to_key (str|None): A token which if given limits the results to - only those before - direction(char): Either 'b' or 'f' to indicate whether we are - paginating forwards or backwards from `from_key`. - limit (int): The maximum number of events to return. - event_filter (Filter|None): If provided filters the events to - those that match the filter. + room_id + from_key: The token used to stream from + to_key: A token which if given limits the results to only those before + direction: Either 'b' or 'f' to indicate whether we are paginating + forwards or backwards from `from_key`. + limit: The maximum number of events to return. + event_filter: If provided filters the events to those that match the filter. Returns: - tuple[list[FrozenEvent], str]: Returns the results as a list of - events and a token that points to the end of the result set. If no - events are returned then the end of the stream has been reached - (i.e. there are no events between `from_key` and `to_key`). + The results as a list of events and a token that points to the end + of the result set. If no events are returned then the end of the + stream has been reached (i.e. there are no events between `from_key` + and `to_key`). """ from_key = RoomStreamToken.parse(from_key) if to_key: to_key = RoomStreamToken.parse(to_key) - rows, token = yield self.db_pool.runInteraction( + rows, token = await self.db_pool.runInteraction( "paginate_room_events", self._paginate_room_events_txn, room_id, @@ -1047,7 +1060,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): event_filter, ) - events = yield self.get_events_as_list( + events = await self.get_events_as_list( [r.event_id for r in rows], get_prev_content=True ) @@ -1057,8 +1070,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): class StreamStore(StreamWorkerStore): - def get_room_max_stream_ordering(self): + def get_room_max_stream_ordering(self) -> int: return self._stream_id_gen.get_current_token() - def get_room_min_stream_ordering(self): + def get_room_min_stream_ordering(self) -> int: return self._backfill_id_gen.get_current_token() diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py deleted file mode 100644 index 18a462f0ee..0000000000 --- a/synapse/storage/presence.py +++ /dev/null @@ -1,69 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from collections import namedtuple - -from synapse.api.constants import PresenceState - - -class UserPresenceState( - namedtuple( - "UserPresenceState", - ( - "user_id", - "state", - "last_active_ts", - "last_federation_update_ts", - "last_user_sync_ts", - "status_msg", - "currently_active", - ), - ) -): - """Represents the current presence state of the user. - - user_id (str) - last_active (int): Time in msec that the user last interacted with server. - last_federation_update (int): Time in msec since either a) we sent a presence - update to other servers or b) we received a presence update, depending - on if is a local user or not. - last_user_sync (int): Time in msec that the user last *completed* a sync - (or event stream). - status_msg (str): User set status message. - """ - - def as_dict(self): - return dict(self._asdict()) - - @staticmethod - def from_dict(d): - return UserPresenceState(**d) - - def copy_and_replace(self, **kwargs): - return self._replace(**kwargs) - - @classmethod - def default(cls, user_id): - """Returns a default presence state. - """ - return cls( - user_id=user_id, - state=PresenceState.OFFLINE, - last_active_ts=0, - last_federation_update_ts=0, - last_user_sync_ts=0, - status_msg=None, - currently_active=False, - ) diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index 05ea40a7de..306dcfe944 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -19,6 +19,7 @@ from mock import Mock, call from signedjson.key import generate_signing_key from synapse.api.constants import EventTypes, Membership, PresenceState +from synapse.api.presence import UserPresenceState from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events.builder import EventBuilder from synapse.handlers.presence import ( @@ -32,7 +33,6 @@ from synapse.handlers.presence import ( handle_update, ) from synapse.rest.client.v1 import room -from synapse.storage.presence import UserPresenceState from synapse.types import UserID, get_domain_from_id from tests import unittest diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py index a6012c973d..918387733b 100644 --- a/tests/storage/test_purge.py +++ b/tests/storage/test_purge.py @@ -15,6 +15,7 @@ from twisted.internet import defer +from synapse.api.errors import NotFoundError from synapse.rest.client.v1 import room from tests.unittest import HomeserverTestCase @@ -46,30 +47,19 @@ class PurgeTests(HomeserverTestCase): storage = self.hs.get_storage() # Get the topological token - event = store.get_topological_token_for_event(last["event_id"]) - self.pump() - event = self.successResultOf(event) - - # Purge everything before this topological token - purge = defer.ensureDeferred( - storage.purge_events.purge_history(self.room_id, event, True) + event = self.get_success( + store.get_topological_token_for_event(last["event_id"]) ) - self.pump() - self.assertEqual(self.successResultOf(purge), None) - # Try and get the events - get_first = store.get_event(first["event_id"]) - get_second = store.get_event(second["event_id"]) - get_third = store.get_event(third["event_id"]) - get_last = store.get_event(last["event_id"]) - self.pump() + # Purge everything before this topological token + self.get_success(storage.purge_events.purge_history(self.room_id, event, True)) # 1-3 should fail and last will succeed, meaning that 1-3 are deleted # and last is not. - self.failureResultOf(get_first) - self.failureResultOf(get_second) - self.failureResultOf(get_third) - self.successResultOf(get_last) + self.get_failure(store.get_event(first["event_id"]), NotFoundError) + self.get_failure(store.get_event(second["event_id"]), NotFoundError) + self.get_failure(store.get_event(third["event_id"]), NotFoundError) + self.get_success(store.get_event(last["event_id"])) def test_purge_wont_delete_extrems(self): """ @@ -84,9 +74,9 @@ class PurgeTests(HomeserverTestCase): storage = self.hs.get_datastore() # Set the topological token higher than it should be - event = storage.get_topological_token_for_event(last["event_id"]) - self.pump() - event = self.successResultOf(event) + event = self.get_success( + storage.get_topological_token_for_event(last["event_id"]) + ) event = "t{}-{}".format( *list(map(lambda x: x + 1, map(int, event[1:].split("-")))) ) @@ -98,14 +88,7 @@ class PurgeTests(HomeserverTestCase): self.assertIn("greater than forward", f.value.args[0]) # Try and get the events - get_first = storage.get_event(first["event_id"]) - get_second = storage.get_event(second["event_id"]) - get_third = storage.get_event(third["event_id"]) - get_last = storage.get_event(last["event_id"]) - self.pump() - - # Nothing is deleted. - self.successResultOf(get_first) - self.successResultOf(get_second) - self.successResultOf(get_third) - self.successResultOf(get_last) + self.get_success(storage.get_event(first["event_id"])) + self.get_success(storage.get_event(second["event_id"])) + self.get_success(storage.get_event(third["event_id"])) + self.get_success(storage.get_event(last["event_id"])) -- cgit 1.5.1 From 050e20e7ca56c3a5985fdcf64012800c153260f2 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 17 Aug 2020 12:18:01 -0400 Subject: Convert some of the general database methods to async (#8100) --- changelog.d/8100.misc | 1 + synapse/storage/database.py | 23 ++++++++----------- synapse/storage/databases/main/appservice.py | 2 +- synapse/storage/databases/main/events_worker.py | 16 +++++++------ synapse/storage/databases/main/registration.py | 8 +++---- synapse/storage/databases/main/roommember.py | 4 ++-- tests/handlers/test_profile.py | 4 ++-- tests/handlers/test_typing.py | 2 +- tests/storage/test_appservice.py | 16 +++++++++---- tests/storage/test_base.py | 16 ++++++++----- tests/storage/test_event_push_actions.py | 30 +++++++++++++------------ tests/storage/test_main.py | 2 +- tests/storage/test_profile.py | 4 ++-- 13 files changed, 69 insertions(+), 59 deletions(-) create mode 100644 changelog.d/8100.misc (limited to 'synapse/storage/databases') diff --git a/changelog.d/8100.misc b/changelog.d/8100.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8100.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 4ada6f5563..8a9e06efcf 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -332,8 +332,7 @@ class DatabasePool(object): """ return self._db_pool.running - @defer.inlineCallbacks - def _check_safe_to_upsert(self): + async def _check_safe_to_upsert(self): """ Is it safe to use native UPSERT? @@ -342,7 +341,7 @@ class DatabasePool(object): If the background updates have not completed, wait 15 sec and check again. """ - updates = yield self.simple_select_list( + updates = await self.simple_select_list( "background_updates", keyvalues=None, retcols=["update_name"], @@ -614,8 +613,7 @@ class DatabasePool(object): # "Simple" SQL API methods that operate on a single table with no JOINs, # no complex WHERE clauses, just a dict of values for columns. - @defer.inlineCallbacks - def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"): + async def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"): """Executes an INSERT query on the named table. Args: @@ -631,7 +629,7 @@ class DatabasePool(object): `or_ignore` is True """ try: - yield self.runInteraction(desc, self.simple_insert_txn, table, values) + await self.runInteraction(desc, self.simple_insert_txn, table, values) except self.engine.module.IntegrityError: # We have to do or_ignore flag at this layer, since we can't reuse # a cursor after we receive an error from the db. @@ -684,8 +682,7 @@ class DatabasePool(object): txn.executemany(sql, vals) - @defer.inlineCallbacks - def simple_upsert( + async def simple_upsert( self, table, keyvalues, @@ -714,14 +711,14 @@ class DatabasePool(object): inserting lock (bool): True to lock the table when doing the upsert. Returns: - Deferred(None or bool): Native upserts always return None. Emulated + None or bool: Native upserts always return None. Emulated upserts return True if a new entry was created, False if an existing one was updated. """ attempts = 0 while True: try: - result = yield self.runInteraction( + return await self.runInteraction( desc, self.simple_upsert_txn, table, @@ -730,7 +727,6 @@ class DatabasePool(object): insertion_values, lock=lock, ) - return result except self.engine.module.IntegrityError as e: attempts += 1 if attempts >= 5: @@ -1121,8 +1117,7 @@ class DatabasePool(object): return cls.cursor_to_dict(txn) - @defer.inlineCallbacks - def simple_select_many_batch( + async def simple_select_many_batch( self, table, column, @@ -1156,7 +1151,7 @@ class DatabasePool(object): it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size) ] for chunk in chunks: - rows = yield self.runInteraction( + rows = await self.runInteraction( desc, self.simple_select_many_txn, table, diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 5cf1a88399..02568a2391 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -169,7 +169,7 @@ class ApplicationServiceTransactionWorkerStore( service(ApplicationService): The service whose state to set. state(ApplicationServiceState): The connectivity state to apply. Returns: - A Deferred which resolves when the state was set successfully. + An Awaitable which resolves when the state was set successfully. """ return self.db_pool.simple_upsert( "application_services_state", {"as_id": service.id}, {"state": state} diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 5687448e3d..8c63a0dc4d 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -847,13 +847,15 @@ class EventsWorkerStore(SQLBaseStore): """Given a list of event ids, check if we have already processed and stored them as non outliers. """ - rows = yield self.db_pool.simple_select_many_batch( - table="events", - retcols=("event_id",), - column="event_id", - iterable=list(event_ids), - keyvalues={"outlier": False}, - desc="have_events_in_timeline", + rows = yield defer.ensureDeferred( + self.db_pool.simple_select_many_batch( + table="events", + retcols=("event_id",), + column="event_id", + iterable=list(event_ids), + keyvalues={"outlier": False}, + desc="have_events_in_timeline", + ) ) return {r["event_id"] for r in rows} diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index de50fa6e94..068ad22b30 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -17,9 +17,7 @@ import logging import re -from typing import Dict, List, Optional - -from twisted.internet.defer import Deferred +from typing import Awaitable, Dict, List, Optional from synapse.api.constants import UserTypes from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError @@ -563,7 +561,7 @@ class RegistrationWorkerStore(SQLBaseStore): id_server (str) Returns: - Deferred + Awaitable """ # We need to use an upsert, in case they user had already bound the # threepid @@ -1084,7 +1082,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): def record_user_external_id( self, auth_provider: str, external_id: str, user_id: str - ) -> Deferred: + ) -> Awaitable: """Record a mapping from an external user id to a mxid Args: diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 1cc8c08ed0..161edbeccb 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -767,13 +767,13 @@ class RoomMemberWorkerStore(EventsWorkerStore): return set(room_ids) - def get_membership_from_event_ids( + async def get_membership_from_event_ids( self, member_event_ids: Iterable[str] ) -> List[dict]: """Get user_id and membership of a set of event IDs. """ - return self.db_pool.simple_select_many_batch( + return await self.db_pool.simple_select_many_batch( table="room_memberships", column="event_id", iterable=member_event_ids, diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index d70e1fc608..b609b30d4a 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -64,7 +64,7 @@ class ProfileTestCase(unittest.TestCase): self.bob = UserID.from_string("@4567:test") self.alice = UserID.from_string("@alice:remote") - yield self.store.create_profile(self.frank.localpart) + yield defer.ensureDeferred(self.store.create_profile(self.frank.localpart)) self.handler = hs.get_profile_handler() self.hs = hs @@ -157,7 +157,7 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_incoming_fed_query(self): - yield self.store.create_profile("caroline") + yield defer.ensureDeferred(self.store.create_profile("caroline")) yield self.store.set_profile_displayname("caroline", "Caroline") response = yield defer.ensureDeferred( diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 64afd581bc..e01de158e5 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -156,7 +156,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ([], 0) ) self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: None - self.datastore.set_received_txn_response = lambda *args, **kwargs: defer.succeed( + self.datastore.set_received_txn_response = lambda *args, **kwargs: make_awaitable( None ) diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 98b74890d5..a425e66f37 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -207,7 +207,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_set_appservices_state_down(self): service = Mock(id=self.as_list[1]["id"]) - yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN) + yield defer.ensureDeferred( + self.store.set_appservice_state(service, ApplicationServiceState.DOWN) + ) rows = yield self.db_pool.runQuery( self.engine.convert_param_style( "SELECT as_id FROM application_services_state WHERE state=?" @@ -219,9 +221,15 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_set_appservices_state_multiple_up(self): service = Mock(id=self.as_list[1]["id"]) - yield self.store.set_appservice_state(service, ApplicationServiceState.UP) - yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN) - yield self.store.set_appservice_state(service, ApplicationServiceState.UP) + yield defer.ensureDeferred( + self.store.set_appservice_state(service, ApplicationServiceState.UP) + ) + yield defer.ensureDeferred( + self.store.set_appservice_state(service, ApplicationServiceState.DOWN) + ) + yield defer.ensureDeferred( + self.store.set_appservice_state(service, ApplicationServiceState.UP) + ) rows = yield self.db_pool.runQuery( self.engine.convert_param_style( "SELECT as_id FROM application_services_state WHERE state=?" diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index efcaeef1e7..13bcac743a 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -66,8 +66,10 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_insert_1col(self): self.mock_txn.rowcount = 1 - yield self.datastore.db_pool.simple_insert( - table="tablename", values={"columname": "Value"} + yield defer.ensureDeferred( + self.datastore.db_pool.simple_insert( + table="tablename", values={"columname": "Value"} + ) ) self.mock_txn.execute.assert_called_with( @@ -78,10 +80,12 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_insert_3cols(self): self.mock_txn.rowcount = 1 - yield self.datastore.db_pool.simple_insert( - table="tablename", - # Use OrderedDict() so we can assert on the SQL generated - values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]), + yield defer.ensureDeferred( + self.datastore.db_pool.simple_insert( + table="tablename", + # Use OrderedDict() so we can assert on the SQL generated + values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]), + ) ) self.mock_txn.execute.assert_called_with( diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index 857db071d4..238bad5b45 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -142,20 +142,22 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def test_find_first_stream_ordering_after_ts(self): def add_event(so, ts): - return self.store.db_pool.simple_insert( - "events", - { - "stream_ordering": so, - "received_ts": ts, - "event_id": "event%i" % so, - "type": "", - "room_id": "", - "content": "", - "processed": True, - "outlier": False, - "topological_ordering": 0, - "depth": 0, - }, + return defer.ensureDeferred( + self.store.db_pool.simple_insert( + "events", + { + "stream_ordering": so, + "received_ts": ts, + "event_id": "event%i" % so, + "type": "", + "room_id": "", + "content": "", + "processed": True, + "outlier": False, + "topological_ordering": 0, + "depth": 0, + }, + ) ) # start with the base case where there are no events in the table diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py index ab0df5ea93..fbf8af940a 100644 --- a/tests/storage/test_main.py +++ b/tests/storage/test_main.py @@ -35,7 +35,7 @@ class DataStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_users_paginate(self): yield self.store.register_user(self.user.to_string(), "pass") - yield self.store.create_profile(self.user.localpart) + yield defer.ensureDeferred(self.store.create_profile(self.user.localpart)) yield self.store.set_profile_displayname(self.user.localpart, self.displayname) users, total = yield self.store.get_users_paginate( diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py index 9b6f7211ae..9d5b8aa47d 100644 --- a/tests/storage/test_profile.py +++ b/tests/storage/test_profile.py @@ -33,7 +33,7 @@ class ProfileStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_displayname(self): - yield self.store.create_profile(self.u_frank.localpart) + yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart)) yield self.store.set_profile_displayname(self.u_frank.localpart, "Frank") @@ -43,7 +43,7 @@ class ProfileStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def test_avatar_url(self): - yield self.store.create_profile(self.u_frank.localpart) + yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart)) yield self.store.set_profile_avatar_url( self.u_frank.localpart, "http://my.site/here" -- cgit 1.5.1