From cc9bb3dc3f299d451ab523dea192e48c32e87c68 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 22 Jul 2020 12:29:15 -0400 Subject: Convert the message handler to async/await. (#7884) --- tests/storage/test_state.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'tests/storage/test_state.py') diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index 0b88308ff4..a0e133cd4a 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -64,8 +64,8 @@ class StateStoreTestCase(tests.unittest.TestCase): }, ) - event, context = yield self.event_creation_handler.create_new_client_event( - builder + event, context = yield defer.ensureDeferred( + self.event_creation_handler.create_new_client_event(builder) ) yield self.storage.persistence.persist_event(event, context) -- cgit 1.5.1 From 3345c166a45cb4a8f87c583ee0476c2bca5c41bd Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 28 Jul 2020 16:09:53 -0400 Subject: Convert storage layer to async/await. (#7963) --- changelog.d/7963.misc | 1 + synapse/storage/persist_events.py | 40 ++++---- synapse/storage/purge_events.py | 38 ++++--- synapse/storage/state.py | 207 ++++++++++++++++++++------------------ tests/storage/test_purge.py | 8 +- tests/storage/test_room.py | 6 +- tests/storage/test_state.py | 64 +++++++----- tests/test_visibility.py | 14 ++- tests/utils.py | 16 +-- tox.ini | 1 + 10 files changed, 210 insertions(+), 185 deletions(-) create mode 100644 changelog.d/7963.misc (limited to 'tests/storage/test_state.py') diff --git a/changelog.d/7963.misc b/changelog.d/7963.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/7963.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index 78fbdcdee8..4a164834d9 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -25,7 +25,7 @@ from prometheus_client import Counter, Histogram from twisted.internet import defer from synapse.api.constants import EventTypes, Membership -from synapse.events import FrozenEvent +from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable from synapse.metrics.background_process_metrics import run_as_background_process @@ -192,12 +192,11 @@ class EventsPersistenceStorage(object): self._event_persist_queue = _EventPeristenceQueue() self._state_resolution_handler = hs.get_state_resolution_handler() - @defer.inlineCallbacks - def persist_events( + async def persist_events( self, - events_and_contexts: List[Tuple[FrozenEvent, EventContext]], + events_and_contexts: List[Tuple[EventBase, EventContext]], backfilled: bool = False, - ): + ) -> int: """ Write events to the database Args: @@ -207,7 +206,7 @@ class EventsPersistenceStorage(object): which might update the current state etc. Returns: - Deferred[int]: the stream ordering of the latest persisted event + the stream ordering of the latest persisted event """ partitioned = {} for event, ctx in events_and_contexts: @@ -223,22 +222,19 @@ class EventsPersistenceStorage(object): for room_id in partitioned: self._maybe_start_persisting(room_id) - yield make_deferred_yieldable( + await make_deferred_yieldable( defer.gatherResults(deferreds, consumeErrors=True) ) - max_persisted_id = yield self.main_store.get_current_events_token() - - return max_persisted_id + return self.main_store.get_current_events_token() - @defer.inlineCallbacks - def persist_event( - self, event: FrozenEvent, context: EventContext, backfilled: bool = False - ): + async def persist_event( + self, event: EventBase, context: EventContext, backfilled: bool = False + ) -> Tuple[int, int]: """ Returns: - Deferred[Tuple[int, int]]: the stream ordering of ``event``, - and the stream ordering of the latest persisted event + The stream ordering of `event`, and the stream ordering of the + latest persisted event """ deferred = self._event_persist_queue.add_to_queue( event.room_id, [(event, context)], backfilled=backfilled @@ -246,9 +242,9 @@ class EventsPersistenceStorage(object): self._maybe_start_persisting(event.room_id) - yield make_deferred_yieldable(deferred) + await make_deferred_yieldable(deferred) - max_persisted_id = yield self.main_store.get_current_events_token() + max_persisted_id = self.main_store.get_current_events_token() return (event.internal_metadata.stream_ordering, max_persisted_id) def _maybe_start_persisting(self, room_id: str): @@ -262,7 +258,7 @@ class EventsPersistenceStorage(object): async def _persist_events( self, - events_and_contexts: List[Tuple[FrozenEvent, EventContext]], + events_and_contexts: List[Tuple[EventBase, EventContext]], backfilled: bool = False, ): """Calculates the change to current state and forward extremities, and @@ -439,7 +435,7 @@ class EventsPersistenceStorage(object): async def _calculate_new_extremities( self, room_id: str, - event_contexts: List[Tuple[FrozenEvent, EventContext]], + event_contexts: List[Tuple[EventBase, EventContext]], latest_event_ids: List[str], ): """Calculates the new forward extremities for a room given events to @@ -497,7 +493,7 @@ class EventsPersistenceStorage(object): async def _get_new_state_after_events( self, room_id: str, - events_context: List[Tuple[FrozenEvent, EventContext]], + events_context: List[Tuple[EventBase, EventContext]], old_latest_event_ids: Iterable[str], new_latest_event_ids: Iterable[str], ) -> Tuple[Optional[StateMap[str]], Optional[StateMap[str]]]: @@ -683,7 +679,7 @@ class EventsPersistenceStorage(object): async def _is_server_still_joined( self, room_id: str, - ev_ctx_rm: List[Tuple[FrozenEvent, EventContext]], + ev_ctx_rm: List[Tuple[EventBase, EventContext]], delta: DeltaState, current_state: Optional[StateMap[str]], potentially_left_users: Set[str], diff --git a/synapse/storage/purge_events.py b/synapse/storage/purge_events.py index fdc0abf5cf..79d9f06e2e 100644 --- a/synapse/storage/purge_events.py +++ b/synapse/storage/purge_events.py @@ -15,8 +15,7 @@ import itertools import logging - -from twisted.internet import defer +from typing import Set logger = logging.getLogger(__name__) @@ -28,49 +27,48 @@ class PurgeEventsStorage(object): def __init__(self, hs, stores): self.stores = stores - @defer.inlineCallbacks - def purge_room(self, room_id: str): + async def purge_room(self, room_id: str): """Deletes all record of a room """ - state_groups_to_delete = yield self.stores.main.purge_room(room_id) - yield self.stores.state.purge_room_state(room_id, state_groups_to_delete) + state_groups_to_delete = await self.stores.main.purge_room(room_id) + await self.stores.state.purge_room_state(room_id, state_groups_to_delete) - @defer.inlineCallbacks - def purge_history(self, room_id, token, delete_local_events): + async def purge_history( + self, room_id: str, token: str, delete_local_events: bool + ) -> None: """Deletes room history before a certain point Args: - room_id (str): + room_id: The room ID - token (str): A topological token to delete events before + token: A topological token to delete events before - delete_local_events (bool): + delete_local_events: if True, we will delete local events as well as remote ones (instead of just marking them as outliers and deleting their state groups). """ - state_groups = yield self.stores.main.purge_history( + state_groups = await self.stores.main.purge_history( room_id, token, delete_local_events ) logger.info("[purge] finding state groups that can be deleted") - sg_to_delete = yield self._find_unreferenced_groups(state_groups) + sg_to_delete = await self._find_unreferenced_groups(state_groups) - yield self.stores.state.purge_unreferenced_state_groups(room_id, sg_to_delete) + await self.stores.state.purge_unreferenced_state_groups(room_id, sg_to_delete) - @defer.inlineCallbacks - def _find_unreferenced_groups(self, state_groups): + async def _find_unreferenced_groups(self, state_groups: Set[int]) -> Set[int]: """Used when purging history to figure out which state groups can be deleted. Args: - state_groups (set[int]): Set of state groups referenced by events + state_groups: Set of state groups referenced by events that are going to be deleted. Returns: - Deferred[set[int]] The set of state groups that can be deleted. + The set of state groups that can be deleted. """ # Graph of state group -> previous group graph = {} @@ -93,7 +91,7 @@ class PurgeEventsStorage(object): current_search = set(itertools.islice(next_to_search, 100)) next_to_search -= current_search - referenced = yield self.stores.main.get_referenced_state_groups( + referenced = await self.stores.main.get_referenced_state_groups( current_search ) referenced_groups |= referenced @@ -102,7 +100,7 @@ class PurgeEventsStorage(object): # groups that are referenced. current_search -= referenced - edges = yield self.stores.state.get_previous_state_groups(current_search) + edges = await self.stores.state.get_previous_state_groups(current_search) prevs = set(edges.values()) # We don't bother re-handling groups we've already seen diff --git a/synapse/storage/state.py b/synapse/storage/state.py index dc568476f4..49ee9c9a74 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -14,13 +14,12 @@ # limitations under the License. import logging -from typing import Iterable, List, TypeVar +from typing import Dict, Iterable, List, Optional, Set, Tuple, TypeVar import attr -from twisted.internet import defer - from synapse.api.constants import EventTypes +from synapse.events import EventBase from synapse.types import StateMap logger = logging.getLogger(__name__) @@ -34,16 +33,16 @@ class StateFilter(object): """A filter used when querying for state. Attributes: - types (dict[str, set[str]|None]): Map from type to set of state keys (or - None). This specifies which state_keys for the given type to fetch - from the DB. If None then all events with that type are fetched. If - the set is empty then no events with that type are fetched. - include_others (bool): Whether to fetch events with types that do not + types: Map from type to set of state keys (or None). This specifies + which state_keys for the given type to fetch from the DB. If None + then all events with that type are fetched. If the set is empty + then no events with that type are fetched. + include_others: Whether to fetch events with types that do not appear in `types`. """ - types = attr.ib() - include_others = attr.ib(default=False) + types = attr.ib(type=Dict[str, Optional[Set[str]]]) + include_others = attr.ib(default=False, type=bool) def __attrs_post_init__(self): # If `include_others` is set we canonicalise the filter by removing @@ -52,36 +51,35 @@ class StateFilter(object): self.types = {k: v for k, v in self.types.items() if v is not None} @staticmethod - def all(): + def all() -> "StateFilter": """Creates a filter that fetches everything. Returns: - StateFilter + The new state filter. """ return StateFilter(types={}, include_others=True) @staticmethod - def none(): + def none() -> "StateFilter": """Creates a filter that fetches nothing. Returns: - StateFilter + The new state filter. """ return StateFilter(types={}, include_others=False) @staticmethod - def from_types(types): + def from_types(types: Iterable[Tuple[str, Optional[str]]]) -> "StateFilter": """Creates a filter that only fetches the given types Args: - types (Iterable[tuple[str, str|None]]): A list of type and state - keys to fetch. A state_key of None fetches everything for - that type + types: A list of type and state keys to fetch. A state_key of None + fetches everything for that type Returns: - StateFilter + The new state filter. """ - type_dict = {} + type_dict = {} # type: Dict[str, Optional[Set[str]]] for typ, s in types: if typ in type_dict: if type_dict[typ] is None: @@ -91,24 +89,24 @@ class StateFilter(object): type_dict[typ] = None continue - type_dict.setdefault(typ, set()).add(s) + type_dict.setdefault(typ, set()).add(s) # type: ignore return StateFilter(types=type_dict) @staticmethod - def from_lazy_load_member_list(members): + def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter": """Creates a filter that returns all non-member events, plus the member events for the given users Args: - members (iterable[str]): Set of user IDs + members: Set of user IDs Returns: - StateFilter + The new state filter """ return StateFilter(types={EventTypes.Member: set(members)}, include_others=True) - def return_expanded(self): + def return_expanded(self) -> "StateFilter": """Creates a new StateFilter where type wild cards have been removed (except for memberships). The returned filter is a superset of the current one, i.e. anything that passes the current filter will pass @@ -130,7 +128,7 @@ class StateFilter(object): return all non-member events Returns: - StateFilter + The new state filter. """ if self.is_full(): @@ -167,7 +165,7 @@ class StateFilter(object): include_others=True, ) - def make_sql_filter_clause(self): + def make_sql_filter_clause(self) -> Tuple[str, List[str]]: """Converts the filter to an SQL clause. For example: @@ -179,13 +177,12 @@ class StateFilter(object): Returns: - tuple[str, list]: The SQL string (may be empty) and arguments. An - empty SQL string is returned when the filter matches everything - (i.e. is "full"). + The SQL string (may be empty) and arguments. An empty SQL string is + returned when the filter matches everything (i.e. is "full"). """ where_clause = "" - where_args = [] + where_args = [] # type: List[str] if self.is_full(): return where_clause, where_args @@ -221,7 +218,7 @@ class StateFilter(object): return where_clause, where_args - def max_entries_returned(self): + def max_entries_returned(self) -> Optional[int]: """Returns the maximum number of entries this filter will return if known, otherwise returns None. @@ -260,33 +257,33 @@ class StateFilter(object): return filtered_state - def is_full(self): + def is_full(self) -> bool: """Whether this filter fetches everything or not Returns: - bool + True if the filter fetches everything. """ return self.include_others and not self.types - def has_wildcards(self): + def has_wildcards(self) -> bool: """Whether the filter includes wildcards or is attempting to fetch specific state. Returns: - bool + True if the filter includes wildcards. """ return self.include_others or any( state_keys is None for state_keys in self.types.values() ) - def concrete_types(self): + def concrete_types(self) -> List[Tuple[str, str]]: """Returns a list of concrete type/state_keys (i.e. not None) that will be fetched. This will be a complete list if `has_wildcards` returns False, but otherwise will be a subset (or even empty). Returns: - list[tuple[str,str]] + A list of type/state_keys tuples. """ return [ (t, s) @@ -295,7 +292,7 @@ class StateFilter(object): for s in state_keys ] - def get_member_split(self): + def get_member_split(self) -> Tuple["StateFilter", "StateFilter"]: """Return the filter split into two: one which assumes it's exclusively matching against member state, and one which assumes it's matching against non member state. @@ -307,7 +304,7 @@ class StateFilter(object): state caches). Returns: - tuple[StateFilter, StateFilter]: The member and non member filters + The member and non member filters """ if EventTypes.Member in self.types: @@ -340,6 +337,9 @@ class StateGroupStorage(object): """Given a state group try to return a previous group and a delta between the old and the new. + Args: + state_group: The state group used to retrieve state deltas. + Returns: Deferred[Tuple[Optional[int], Optional[StateMap[str]]]]: (prev_group, delta_ids) @@ -347,55 +347,59 @@ class StateGroupStorage(object): return self.stores.state.get_state_group_delta(state_group) - @defer.inlineCallbacks - def get_state_groups_ids(self, _room_id, event_ids): + async def get_state_groups_ids( + self, _room_id: str, event_ids: Iterable[str] + ) -> Dict[int, StateMap[str]]: """Get the event IDs of all the state for the state groups for the given events Args: - _room_id (str): id of the room for these events - event_ids (iterable[str]): ids of the events + _room_id: id of the room for these events + event_ids: ids of the events Returns: - Deferred[dict[int, StateMap[str]]]: - dict of state_group_id -> (dict of (type, state_key) -> event id) + dict of state_group_id -> (dict of (type, state_key) -> event id) """ if not event_ids: return {} - event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) + event_to_groups = await self.stores.main._get_state_group_for_events(event_ids) groups = set(event_to_groups.values()) - group_to_state = yield self.stores.state._get_state_for_groups(groups) + group_to_state = await self.stores.state._get_state_for_groups(groups) return group_to_state - @defer.inlineCallbacks - def get_state_ids_for_group(self, state_group): + async def get_state_ids_for_group(self, state_group: int) -> StateMap[str]: """Get the event IDs of all the state in the given state group Args: - state_group (int) + state_group: A state group for which we want to get the state IDs. Returns: - Deferred[dict]: Resolves to a map of (type, state_key) -> event_id + Resolves to a map of (type, state_key) -> event_id """ - group_to_state = yield self._get_state_for_groups((state_group,)) + group_to_state = await self._get_state_for_groups((state_group,)) return group_to_state[state_group] - @defer.inlineCallbacks - def get_state_groups(self, room_id, event_ids): + async def get_state_groups( + self, room_id: str, event_ids: Iterable[str] + ) -> Dict[int, List[EventBase]]: """ Get the state groups for the given list of event_ids + + Args: + room_id: ID of the room for these events. + event_ids: The event IDs to retrieve state for. + Returns: - Deferred[dict[int, list[EventBase]]]: - dict of state_group_id -> list of state events. + dict of state_group_id -> list of state events. """ if not event_ids: return {} - group_to_ids = yield self.get_state_groups_ids(room_id, event_ids) + group_to_ids = await self.get_state_groups_ids(room_id, event_ids) - state_event_map = yield self.stores.main.get_events( + state_event_map = await self.stores.main.get_events( [ ev_id for group_ids in group_to_ids.values() @@ -423,31 +427,34 @@ class StateGroupStorage(object): groups: list of state group IDs to query state_filter: The state filter used to fetch state from the database. + Returns: Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map. """ return self.stores.state._get_state_groups_from_groups(groups, state_filter) - @defer.inlineCallbacks - def get_state_for_events(self, event_ids, state_filter=StateFilter.all()): + async def get_state_for_events( + self, event_ids: List[str], state_filter: StateFilter = StateFilter.all() + ): """Given a list of event_ids and type tuples, return a list of state dicts for each event. + Args: - event_ids (list[string]) - state_filter (StateFilter): The state filter used to fetch state - from the database. + event_ids: The events to fetch the state of. + state_filter: The state filter used to fetch state. + Returns: - deferred: A dict of (event_id) -> (type, state_key) -> [state_events] + A dict of (event_id) -> (type, state_key) -> [state_events] """ - event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) + event_to_groups = await self.stores.main._get_state_group_for_events(event_ids) groups = set(event_to_groups.values()) - group_to_state = yield self.stores.state._get_state_for_groups( + group_to_state = await self.stores.state._get_state_for_groups( groups, state_filter ) - state_event_map = yield self.stores.main.get_events( + state_event_map = await self.stores.main.get_events( [ev_id for sd in group_to_state.values() for ev_id in sd.values()], get_prev_content=False, ) @@ -463,24 +470,24 @@ class StateGroupStorage(object): return {event: event_to_state[event] for event in event_ids} - @defer.inlineCallbacks - def get_state_ids_for_events(self, event_ids, state_filter=StateFilter.all()): + async def get_state_ids_for_events( + self, event_ids: List[str], state_filter: StateFilter = StateFilter.all() + ): """ Get the state dicts corresponding to a list of events, containing the event_ids of the state events (as opposed to the events themselves) Args: - event_ids(list(str)): events whose state should be returned - state_filter (StateFilter): The state filter used to fetch state - from the database. + event_ids: events whose state should be returned + state_filter: The state filter used to fetch state from the database. Returns: - A deferred dict from event_id -> (type, state_key) -> event_id + A dict from event_id -> (type, state_key) -> event_id """ - event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) + event_to_groups = await self.stores.main._get_state_group_for_events(event_ids) groups = set(event_to_groups.values()) - group_to_state = yield self.stores.state._get_state_for_groups( + group_to_state = await self.stores.state._get_state_for_groups( groups, state_filter ) @@ -491,36 +498,36 @@ class StateGroupStorage(object): return {event: event_to_state[event] for event in event_ids} - @defer.inlineCallbacks - def get_state_for_event(self, event_id, state_filter=StateFilter.all()): + async def get_state_for_event( + self, event_id: str, state_filter: StateFilter = StateFilter.all() + ): """ Get the state dict corresponding to a particular event Args: - event_id(str): event whose state should be returned - state_filter (StateFilter): The state filter used to fetch state - from the database. + event_id: event whose state should be returned + state_filter: The state filter used to fetch state from the database. Returns: - A deferred dict from (type, state_key) -> state_event + A dict from (type, state_key) -> state_event """ - state_map = yield self.get_state_for_events([event_id], state_filter) + state_map = await self.get_state_for_events([event_id], state_filter) return state_map[event_id] - @defer.inlineCallbacks - def get_state_ids_for_event(self, event_id, state_filter=StateFilter.all()): + async def get_state_ids_for_event( + self, event_id: str, state_filter: StateFilter = StateFilter.all() + ): """ Get the state dict corresponding to a particular event Args: - event_id(str): event whose state should be returned - state_filter (StateFilter): The state filter used to fetch state - from the database. + event_id: event whose state should be returned + state_filter: The state filter used to fetch state from the database. Returns: A deferred dict from (type, state_key) -> state_event """ - state_map = yield self.get_state_ids_for_events([event_id], state_filter) + state_map = await self.get_state_ids_for_events([event_id], state_filter) return state_map[event_id] def _get_state_for_groups( @@ -530,9 +537,8 @@ class StateGroupStorage(object): filtering by type/state_key Args: - groups (iterable[int]): list of state groups for which we want - to get the state. - state_filter (StateFilter): The state filter used to fetch state + groups: list of state groups for which we want to get the state. + state_filter: The state filter used to fetch state. from the database. Returns: Deferred[dict[int, StateMap[str]]]: Dict of state group to state map. @@ -540,18 +546,23 @@ class StateGroupStorage(object): return self.stores.state._get_state_for_groups(groups, state_filter) def store_state_group( - self, event_id, room_id, prev_group, delta_ids, current_state_ids + self, + event_id: str, + room_id: str, + prev_group: Optional[int], + delta_ids: Optional[dict], + current_state_ids: dict, ): """Store a new set of state, returning a newly assigned state group. Args: - event_id (str): The event ID for which the state was calculated - room_id (str) - prev_group (int|None): A previous state group for the room, optional. - delta_ids (dict|None): The delta between state at `prev_group` and + event_id: The event ID for which the state was calculated. + room_id: ID of the room for which the state was calculated. + prev_group: A previous state group for the room, optional. + delta_ids: The delta between state at `prev_group` and `current_state_ids`, if `prev_group` was given. Same format as `current_state_ids`. - current_state_ids (dict): The state to store. Map of (type, state_key) + current_state_ids: The state to store. Map of (type, state_key) to event_id. Returns: diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py index b9fafaa1a6..a6012c973d 100644 --- a/tests/storage/test_purge.py +++ b/tests/storage/test_purge.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from twisted.internet import defer + from synapse.rest.client.v1 import room from tests.unittest import HomeserverTestCase @@ -49,7 +51,9 @@ class PurgeTests(HomeserverTestCase): event = self.successResultOf(event) # Purge everything before this topological token - purge = storage.purge_events.purge_history(self.room_id, event, True) + purge = defer.ensureDeferred( + storage.purge_events.purge_history(self.room_id, event, True) + ) self.pump() self.assertEqual(self.successResultOf(purge), None) @@ -88,7 +92,7 @@ class PurgeTests(HomeserverTestCase): ) # Purge everything before this topological token - purge = storage.purge_history(self.room_id, event, True) + purge = defer.ensureDeferred(storage.purge_history(self.room_id, event, True)) self.pump() f = self.failureResultOf(purge) self.assertIn("greater than forward", f.value.args[0]) diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index 1d77b4a2d6..a5f250d477 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -97,8 +97,10 @@ class RoomEventsStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def inject_room_event(self, **kwargs): - yield self.storage.persistence.persist_event( - self.event_factory.create_event(room_id=self.room.to_string(), **kwargs) + yield defer.ensureDeferred( + self.storage.persistence.persist_event( + self.event_factory.create_event(room_id=self.room.to_string(), **kwargs) + ) ) @defer.inlineCallbacks diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index a0e133cd4a..6a48b9d3b3 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -68,7 +68,9 @@ class StateStoreTestCase(tests.unittest.TestCase): self.event_creation_handler.create_new_client_event(builder) ) - yield self.storage.persistence.persist_event(event, context) + yield defer.ensureDeferred( + self.storage.persistence.persist_event(event, context) + ) return event @@ -87,8 +89,8 @@ class StateStoreTestCase(tests.unittest.TestCase): self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} ) - state_group_map = yield self.storage.state.get_state_groups_ids( - self.room, [e2.event_id] + state_group_map = yield defer.ensureDeferred( + self.storage.state.get_state_groups_ids(self.room, [e2.event_id]) ) self.assertEqual(len(state_group_map), 1) state_map = list(state_group_map.values())[0] @@ -106,8 +108,8 @@ class StateStoreTestCase(tests.unittest.TestCase): self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} ) - state_group_map = yield self.storage.state.get_state_groups( - self.room, [e2.event_id] + state_group_map = yield defer.ensureDeferred( + self.storage.state.get_state_groups(self.room, [e2.event_id]) ) self.assertEqual(len(state_group_map), 1) state_list = list(state_group_map.values())[0] @@ -148,7 +150,9 @@ class StateStoreTestCase(tests.unittest.TestCase): ) # check we get the full state as of the final event - state = yield self.storage.state.get_state_for_event(e5.event_id) + state = yield defer.ensureDeferred( + self.storage.state.get_state_for_event(e5.event_id) + ) self.assertIsNotNone(e4) @@ -164,22 +168,28 @@ class StateStoreTestCase(tests.unittest.TestCase): ) # check we can filter to the m.room.name event (with a '' state key) - state = yield self.storage.state.get_state_for_event( - e5.event_id, StateFilter.from_types([(EventTypes.Name, "")]) + state = yield defer.ensureDeferred( + self.storage.state.get_state_for_event( + e5.event_id, StateFilter.from_types([(EventTypes.Name, "")]) + ) ) self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) # check we can filter to the m.room.name event (with a wildcard None state key) - state = yield self.storage.state.get_state_for_event( - e5.event_id, StateFilter.from_types([(EventTypes.Name, None)]) + state = yield defer.ensureDeferred( + self.storage.state.get_state_for_event( + e5.event_id, StateFilter.from_types([(EventTypes.Name, None)]) + ) ) self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) # check we can grab the m.room.member events (with a wildcard None state key) - state = yield self.storage.state.get_state_for_event( - e5.event_id, StateFilter.from_types([(EventTypes.Member, None)]) + state = yield defer.ensureDeferred( + self.storage.state.get_state_for_event( + e5.event_id, StateFilter.from_types([(EventTypes.Member, None)]) + ) ) self.assertStateMapEqual( @@ -188,12 +198,14 @@ class StateStoreTestCase(tests.unittest.TestCase): # check we can grab a specific room member without filtering out the # other event types - state = yield self.storage.state.get_state_for_event( - e5.event_id, - state_filter=StateFilter( - types={EventTypes.Member: {self.u_alice.to_string()}}, - include_others=True, - ), + state = yield defer.ensureDeferred( + self.storage.state.get_state_for_event( + e5.event_id, + state_filter=StateFilter( + types={EventTypes.Member: {self.u_alice.to_string()}}, + include_others=True, + ), + ) ) self.assertStateMapEqual( @@ -206,11 +218,13 @@ class StateStoreTestCase(tests.unittest.TestCase): ) # check that we can grab everything except members - state = yield self.storage.state.get_state_for_event( - e5.event_id, - state_filter=StateFilter( - types={EventTypes.Member: set()}, include_others=True - ), + state = yield defer.ensureDeferred( + self.storage.state.get_state_for_event( + e5.event_id, + state_filter=StateFilter( + types={EventTypes.Member: set()}, include_others=True + ), + ) ) self.assertStateMapEqual( @@ -222,8 +236,8 @@ class StateStoreTestCase(tests.unittest.TestCase): ####################################################### room_id = self.room.to_string() - group_ids = yield self.storage.state.get_state_groups_ids( - room_id, [e5.event_id] + group_ids = yield defer.ensureDeferred( + self.storage.state.get_state_groups_ids(room_id, [e5.event_id]) ) group = list(group_ids.keys())[0] diff --git a/tests/test_visibility.py b/tests/test_visibility.py index a7a36174ea..531a9b9118 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -40,7 +40,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): self.store = self.hs.get_datastore() self.storage = self.hs.get_storage() - yield create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM") + yield defer.ensureDeferred(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM")) @defer.inlineCallbacks def test_filtering(self): @@ -140,7 +140,9 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): event, context = yield defer.ensureDeferred( self.event_creation_handler.create_new_client_event(builder) ) - yield self.storage.persistence.persist_event(event, context) + yield defer.ensureDeferred( + self.storage.persistence.persist_event(event, context) + ) return event @defer.inlineCallbacks @@ -162,7 +164,9 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): self.event_creation_handler.create_new_client_event(builder) ) - yield self.storage.persistence.persist_event(event, context) + yield defer.ensureDeferred( + self.storage.persistence.persist_event(event, context) + ) return event @defer.inlineCallbacks @@ -183,7 +187,9 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): self.event_creation_handler.create_new_client_event(builder) ) - yield self.storage.persistence.persist_event(event, context) + yield defer.ensureDeferred( + self.storage.persistence.persist_event(event, context) + ) return event @defer.inlineCallbacks diff --git a/tests/utils.py b/tests/utils.py index ac643679aa..b33b6860d4 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -638,14 +638,8 @@ class DeferredMockCallable(object): ) -@defer.inlineCallbacks -def create_room(hs, room_id, creator_id): +async def create_room(hs, room_id: str, creator_id: str): """Creates and persist a creation event for the given room - - Args: - hs - room_id (str) - creator_id (str) """ persistence_store = hs.get_storage().persistence @@ -653,7 +647,7 @@ def create_room(hs, room_id, creator_id): event_builder_factory = hs.get_event_builder_factory() event_creation_handler = hs.get_event_creation_handler() - yield store.store_room( + await store.store_room( room_id=room_id, room_creator_user_id=creator_id, is_public=False, @@ -671,8 +665,6 @@ def create_room(hs, room_id, creator_id): }, ) - event, context = yield defer.ensureDeferred( - event_creation_handler.create_new_client_event(builder) - ) + event, context = await event_creation_handler.create_new_client_event(builder) - yield persistence_store.persist_event(event, context) + await persistence_store.persist_event(event, context) diff --git a/tox.ini b/tox.ini index 595ab3ba66..a394f6eadc 100644 --- a/tox.ini +++ b/tox.ini @@ -206,6 +206,7 @@ commands = mypy \ synapse/storage/data_stores/main/ui_auth.py \ synapse/storage/database.py \ synapse/storage/engines \ + synapse/storage/state.py \ synapse/storage/util \ synapse/streams \ synapse/util/caches/stream_change_cache.py \ -- cgit 1.5.1 From b3a97d6dac7f9f619b02e213bb8a745d65983d0d Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 30 Jul 2020 07:20:41 -0400 Subject: Convert some of the data store to async. (#7976) --- changelog.d/7976.misc | 1 + .../storage/data_stores/main/event_push_actions.py | 92 ++++++++++---------- synapse/storage/data_stores/main/room.py | 98 ++++++++++------------ synapse/storage/data_stores/main/state.py | 57 ++++++------- synapse/storage/data_stores/main/stats.py | 53 ++++++------ synapse/storage/data_stores/state/store.py | 37 ++++---- synapse/storage/state.py | 11 +-- tests/storage/test_event_push_actions.py | 12 ++- tests/storage/test_room.py | 24 +++--- tests/storage/test_state.py | 12 +-- 10 files changed, 190 insertions(+), 207 deletions(-) create mode 100644 changelog.d/7976.misc (limited to 'tests/storage/test_state.py') diff --git a/changelog.d/7976.misc b/changelog.d/7976.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/7976.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/storage/data_stores/main/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py index 18297cf3b8..ad82838901 100644 --- a/synapse/storage/data_stores/main/event_push_actions.py +++ b/synapse/storage/data_stores/main/event_push_actions.py @@ -15,11 +15,10 @@ # limitations under the License. import logging +from typing import List from canonicaljson import json -from twisted.internet import defer - from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json from synapse.storage.database import Database @@ -166,8 +165,9 @@ class EventPushActionsWorkerStore(SQLBaseStore): return {"notify_count": notify_count, "highlight_count": highlight_count} - @defer.inlineCallbacks - def get_push_action_users_in_range(self, min_stream_ordering, max_stream_ordering): + async def get_push_action_users_in_range( + self, min_stream_ordering, max_stream_ordering + ): def f(txn): sql = ( "SELECT DISTINCT(user_id) FROM event_push_actions WHERE" @@ -176,26 +176,28 @@ class EventPushActionsWorkerStore(SQLBaseStore): txn.execute(sql, (min_stream_ordering, max_stream_ordering)) return [r[0] for r in txn] - ret = yield self.db.runInteraction("get_push_action_users_in_range", f) + ret = await self.db.runInteraction("get_push_action_users_in_range", f) return ret - @defer.inlineCallbacks - def get_unread_push_actions_for_user_in_range_for_http( - self, user_id, min_stream_ordering, max_stream_ordering, limit=20 - ): + async def get_unread_push_actions_for_user_in_range_for_http( + self, + user_id: str, + min_stream_ordering: int, + max_stream_ordering: int, + limit: int = 20, + ) -> List[dict]: """Get a list of the most recent unread push actions for a given user, within the given stream ordering range. Called by the httppusher. Args: - user_id (str): The user to fetch push actions for. - min_stream_ordering(int): The exclusive lower bound on the + user_id: The user to fetch push actions for. + min_stream_ordering: The exclusive lower bound on the stream ordering of event push actions to fetch. - max_stream_ordering(int): The inclusive upper bound on the + max_stream_ordering: The inclusive upper bound on the stream ordering of event push actions to fetch. - limit (int): The maximum number of rows to return. + limit: The maximum number of rows to return. Returns: - A promise which resolves to a list of dicts with the keys "event_id", - "room_id", "stream_ordering", "actions". + A list of dicts with the keys "event_id", "room_id", "stream_ordering", "actions". The list will be ordered by ascending stream_ordering. The list will have between 0~limit entries. """ @@ -228,7 +230,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): txn.execute(sql, args) return txn.fetchall() - after_read_receipt = yield self.db.runInteraction( + after_read_receipt = await self.db.runInteraction( "get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt ) @@ -256,7 +258,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): txn.execute(sql, args) return txn.fetchall() - no_read_receipt = yield self.db.runInteraction( + no_read_receipt = await self.db.runInteraction( "get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt ) @@ -280,23 +282,25 @@ class EventPushActionsWorkerStore(SQLBaseStore): # one of the subqueries may have hit the limit. return notifs[:limit] - @defer.inlineCallbacks - def get_unread_push_actions_for_user_in_range_for_email( - self, user_id, min_stream_ordering, max_stream_ordering, limit=20 - ): + async def get_unread_push_actions_for_user_in_range_for_email( + self, + user_id: str, + min_stream_ordering: int, + max_stream_ordering: int, + limit: int = 20, + ) -> List[dict]: """Get a list of the most recent unread push actions for a given user, within the given stream ordering range. Called by the emailpusher Args: - user_id (str): The user to fetch push actions for. - min_stream_ordering(int): The exclusive lower bound on the + user_id: The user to fetch push actions for. + min_stream_ordering: The exclusive lower bound on the stream ordering of event push actions to fetch. - max_stream_ordering(int): The inclusive upper bound on the + max_stream_ordering: The inclusive upper bound on the stream ordering of event push actions to fetch. - limit (int): The maximum number of rows to return. + limit: The maximum number of rows to return. Returns: - A promise which resolves to a list of dicts with the keys "event_id", - "room_id", "stream_ordering", "actions", "received_ts". + A list of dicts with the keys "event_id", "room_id", "stream_ordering", "actions", "received_ts". The list will be ordered by descending received_ts. The list will have between 0~limit entries. """ @@ -328,7 +332,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): txn.execute(sql, args) return txn.fetchall() - after_read_receipt = yield self.db.runInteraction( + after_read_receipt = await self.db.runInteraction( "get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt ) @@ -356,7 +360,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): txn.execute(sql, args) return txn.fetchall() - no_read_receipt = yield self.db.runInteraction( + no_read_receipt = await self.db.runInteraction( "get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt ) @@ -461,17 +465,13 @@ class EventPushActionsWorkerStore(SQLBaseStore): "add_push_actions_to_staging", _add_push_actions_to_staging_txn ) - @defer.inlineCallbacks - def remove_push_actions_from_staging(self, event_id): + async def remove_push_actions_from_staging(self, event_id: str) -> None: """Called if we failed to persist the event to ensure that stale push actions don't build up in the DB - - Args: - event_id (str) """ try: - res = yield self.db.simple_delete( + res = await self.db.simple_delete( table="event_push_actions_staging", keyvalues={"event_id": event_id}, desc="remove_push_actions_from_staging", @@ -606,8 +606,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): return range_end - @defer.inlineCallbacks - def get_time_of_last_push_action_before(self, stream_ordering): + async def get_time_of_last_push_action_before(self, stream_ordering): def f(txn): sql = ( "SELECT e.received_ts" @@ -620,7 +619,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): txn.execute(sql, (stream_ordering,)) return txn.fetchone() - result = yield self.db.runInteraction("get_time_of_last_push_action_before", f) + result = await self.db.runInteraction("get_time_of_last_push_action_before", f) return result[0] if result else None @@ -650,8 +649,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): self._start_rotate_notifs, 30 * 60 * 1000 ) - @defer.inlineCallbacks - def get_push_actions_for_user( + async def get_push_actions_for_user( self, user_id, before=None, limit=50, only_highlight=False ): def f(txn): @@ -682,18 +680,17 @@ class EventPushActionsStore(EventPushActionsWorkerStore): txn.execute(sql, args) return self.db.cursor_to_dict(txn) - push_actions = yield self.db.runInteraction("get_push_actions_for_user", f) + push_actions = await self.db.runInteraction("get_push_actions_for_user", f) for pa in push_actions: pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"]) return push_actions - @defer.inlineCallbacks - def get_latest_push_action_stream_ordering(self): + async def get_latest_push_action_stream_ordering(self): def f(txn): txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions") return txn.fetchone() - result = yield self.db.runInteraction( + result = await self.db.runInteraction( "get_latest_push_action_stream_ordering", f ) return result[0] or 0 @@ -747,8 +744,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): def _start_rotate_notifs(self): return run_as_background_process("rotate_notifs", self._rotate_notifs) - @defer.inlineCallbacks - def _rotate_notifs(self): + async def _rotate_notifs(self): if self._doing_notif_rotation or self.stream_ordering_day_ago is None: return self._doing_notif_rotation = True @@ -757,12 +753,12 @@ class EventPushActionsStore(EventPushActionsWorkerStore): while True: logger.info("Rotating notifications") - caught_up = yield self.db.runInteraction( + caught_up = await self.db.runInteraction( "_rotate_notifs", self._rotate_notifs_txn ) if caught_up: break - yield self.hs.get_clock().sleep(self._rotate_delay) + await self.hs.get_clock().sleep(self._rotate_delay) finally: self._doing_notif_rotation = False diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py index d2e1e36e7f..ab48052cdc 100644 --- a/synapse/storage/data_stores/main/room.py +++ b/synapse/storage/data_stores/main/room.py @@ -23,8 +23,6 @@ from typing import Any, Dict, List, Optional, Tuple from canonicaljson import json -from twisted.internet import defer - from synapse.api.constants import EventTypes from synapse.api.errors import StoreError from synapse.api.room_versions import RoomVersion, RoomVersions @@ -32,7 +30,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.data_stores.main.search import SearchStore from synapse.storage.database import Database, LoggingTransaction from synapse.types import ThirdPartyInstanceID -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks +from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) @@ -192,8 +190,7 @@ class RoomWorkerStore(SQLBaseStore): return self.db.runInteraction("count_public_rooms", _count_public_rooms_txn) - @defer.inlineCallbacks - def get_largest_public_rooms( + async def get_largest_public_rooms( self, network_tuple: Optional[ThirdPartyInstanceID], search_filter: Optional[dict], @@ -330,10 +327,10 @@ class RoomWorkerStore(SQLBaseStore): return results - ret_val = yield self.db.runInteraction( + ret_val = await self.db.runInteraction( "get_largest_public_rooms", _get_largest_public_rooms_txn ) - defer.returnValue(ret_val) + return ret_val @cached(max_entries=10000) def is_room_blocked(self, room_id): @@ -509,8 +506,8 @@ class RoomWorkerStore(SQLBaseStore): "get_rooms_paginate", _get_rooms_paginate_txn, ) - @cachedInlineCallbacks(max_entries=10000) - def get_ratelimit_for_user(self, user_id): + @cached(max_entries=10000) + async def get_ratelimit_for_user(self, user_id): """Check if there are any overrides for ratelimiting for the given user @@ -522,7 +519,7 @@ class RoomWorkerStore(SQLBaseStore): of RatelimitOverride are None or 0 then ratelimitng has been disabled for that user entirely. """ - row = yield self.db.simple_select_one( + row = await self.db.simple_select_one( table="ratelimit_override", keyvalues={"user_id": user_id}, retcols=("messages_per_second", "burst_count"), @@ -538,8 +535,8 @@ class RoomWorkerStore(SQLBaseStore): else: return None - @cachedInlineCallbacks() - def get_retention_policy_for_room(self, room_id): + @cached() + async def get_retention_policy_for_room(self, room_id): """Get the retention policy for a given room. If no retention policy has been found for this room, returns a policy defined @@ -566,19 +563,17 @@ class RoomWorkerStore(SQLBaseStore): return self.db.cursor_to_dict(txn) - ret = yield self.db.runInteraction( + ret = await self.db.runInteraction( "get_retention_policy_for_room", get_retention_policy_for_room_txn, ) # If we don't know this room ID, ret will be None, in this case return the default # policy. if not ret: - defer.returnValue( - { - "min_lifetime": self.config.retention_default_min_lifetime, - "max_lifetime": self.config.retention_default_max_lifetime, - } - ) + return { + "min_lifetime": self.config.retention_default_min_lifetime, + "max_lifetime": self.config.retention_default_max_lifetime, + } row = ret[0] @@ -592,7 +587,7 @@ class RoomWorkerStore(SQLBaseStore): if row["max_lifetime"] is None: row["max_lifetime"] = self.config.retention_default_max_lifetime - defer.returnValue(row) + return row def get_media_mxcs_in_room(self, room_id): """Retrieves all the local and remote media MXC URIs in a given room @@ -881,8 +876,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore): self._background_add_rooms_room_version_column, ) - @defer.inlineCallbacks - def _background_insert_retention(self, progress, batch_size): + async def _background_insert_retention(self, progress, batch_size): """Retrieves a list of all rooms within a range and inserts an entry for each of them into the room_retention table. NULLs the property's columns if missing from the retention event in the room's @@ -940,14 +934,14 @@ class RoomBackgroundUpdateStore(SQLBaseStore): else: return False - end = yield self.db.runInteraction( + end = await self.db.runInteraction( "insert_room_retention", _background_insert_retention_txn, ) if end: - yield self.db.updates._end_background_update("insert_room_retention") + await self.db.updates._end_background_update("insert_room_retention") - defer.returnValue(batch_size) + return batch_size async def _background_add_rooms_room_version_column( self, progress: dict, batch_size: int @@ -1096,8 +1090,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): lock=False, ) - @defer.inlineCallbacks - def store_room( + async def store_room( self, room_id: str, room_creator_user_id: str, @@ -1140,7 +1133,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): ) with self._public_room_id_gen.get_next() as next_id: - yield self.db.runInteraction("store_room_txn", store_room_txn, next_id) + await self.db.runInteraction("store_room_txn", store_room_txn, next_id) except Exception as e: logger.error("store_room with room_id=%s failed: %s", room_id, e) raise StoreError(500, "Problem creating room.") @@ -1165,8 +1158,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): lock=False, ) - @defer.inlineCallbacks - def set_room_is_public(self, room_id, is_public): + async def set_room_is_public(self, room_id, is_public): def set_room_is_public_txn(txn, next_id): self.db.simple_update_one_txn( txn, @@ -1206,13 +1198,12 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): ) with self._public_room_id_gen.get_next() as next_id: - yield self.db.runInteraction( + await self.db.runInteraction( "set_room_is_public", set_room_is_public_txn, next_id ) self.hs.get_notifier().on_new_replication_data() - @defer.inlineCallbacks - def set_room_is_public_appservice( + async def set_room_is_public_appservice( self, room_id, appservice_id, network_id, is_public ): """Edit the appservice/network specific public room list. @@ -1287,7 +1278,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): ) with self._public_room_id_gen.get_next() as next_id: - yield self.db.runInteraction( + await self.db.runInteraction( "set_room_is_public_appservice", set_room_is_public_appservice_txn, next_id, @@ -1327,52 +1318,47 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): def get_current_public_room_stream_id(self): return self._public_room_id_gen.get_current_token() - @defer.inlineCallbacks - def block_room(self, room_id, user_id): + async def block_room(self, room_id: str, user_id: str) -> None: """Marks the room as blocked. Can be called multiple times. Args: - room_id (str): Room to block - user_id (str): Who blocked it - - Returns: - Deferred + room_id: Room to block + user_id: Who blocked it """ - yield self.db.simple_upsert( + await self.db.simple_upsert( table="blocked_rooms", keyvalues={"room_id": room_id}, values={}, insertion_values={"user_id": user_id}, desc="block_room", ) - yield self.db.runInteraction( + await self.db.runInteraction( "block_room_invalidation", self._invalidate_cache_and_stream, self.is_room_blocked, (room_id,), ) - @defer.inlineCallbacks - def get_rooms_for_retention_period_in_range( - self, min_ms, max_ms, include_null=False - ): + async def get_rooms_for_retention_period_in_range( + self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False + ) -> Dict[str, dict]: """Retrieves all of the rooms within the given retention range. Optionally includes the rooms which don't have a retention policy. Args: - min_ms (int|None): Duration in milliseconds that define the lower limit of + min_ms: Duration in milliseconds that define the lower limit of the range to handle (exclusive). If None, doesn't set a lower limit. - max_ms (int|None): Duration in milliseconds that define the upper limit of + max_ms: Duration in milliseconds that define the upper limit of the range to handle (inclusive). If None, doesn't set an upper limit. - include_null (bool): Whether to include rooms which retention policy is NULL + include_null: Whether to include rooms which retention policy is NULL in the returned set. Returns: - dict[str, dict]: The rooms within this range, along with their retention - policy. The key is "room_id", and maps to a dict describing the retention - policy associated with this room ID. The keys for this nested dict are - "min_lifetime" (int|None), and "max_lifetime" (int|None). + The rooms within this range, along with their retention + policy. The key is "room_id", and maps to a dict describing the retention + policy associated with this room ID. The keys for this nested dict are + "min_lifetime" (int|None), and "max_lifetime" (int|None). """ def get_rooms_for_retention_period_in_range_txn(txn): @@ -1431,9 +1417,9 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): return rooms_dict - rooms = yield self.db.runInteraction( + rooms = await self.db.runInteraction( "get_rooms_for_retention_period_in_range", get_rooms_for_retention_period_in_range_txn, ) - defer.returnValue(rooms) + return rooms diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py index bb38a04ede..a360699408 100644 --- a/synapse/storage/data_stores/main/state.py +++ b/synapse/storage/data_stores/main/state.py @@ -16,12 +16,12 @@ import collections.abc import logging from collections import namedtuple - -from twisted.internet import defer +from typing import Iterable, Optional, Set from synapse.api.constants import EventTypes, Membership from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion +from synapse.events import EventBase from synapse.storage._base import SQLBaseStore from synapse.storage.data_stores.main.events_worker import EventsWorkerStore from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore @@ -108,28 +108,27 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): create_event = await self.get_create_event_for_room(room_id) return create_event.content.get("room_version", "1") - @defer.inlineCallbacks - def get_room_predecessor(self, room_id): + async def get_room_predecessor(self, room_id: str) -> Optional[dict]: """Get the predecessor of an upgraded room if it exists. Otherwise return None. Args: - room_id (str) + room_id: The room ID. Returns: - Deferred[dict|None]: A dictionary containing the structure of the predecessor - field from the room's create event. The structure is subject to other servers, - but it is expected to be: - * room_id (str): The room ID of the predecessor room - * event_id (str): The ID of the tombstone event in the predecessor room + A dictionary containing the structure of the predecessor + field from the room's create event. The structure is subject to other servers, + but it is expected to be: + * room_id (str): The room ID of the predecessor room + * event_id (str): The ID of the tombstone event in the predecessor room - None if a predecessor key is not found, or is not a dictionary. + None if a predecessor key is not found, or is not a dictionary. Raises: NotFoundError if the given room is unknown """ # Retrieve the room's create event - create_event = yield self.get_create_event_for_room(room_id) + create_event = await self.get_create_event_for_room(room_id) # Retrieve the predecessor key of the create event predecessor = create_event.content.get("predecessor", None) @@ -140,20 +139,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): return predecessor - @defer.inlineCallbacks - def get_create_event_for_room(self, room_id): + async def get_create_event_for_room(self, room_id: str) -> EventBase: """Get the create state event for a room. Args: - room_id (str) + room_id: The room ID. Returns: - Deferred[EventBase]: The room creation event. + The room creation event. Raises: NotFoundError if the room is unknown """ - state_ids = yield self.get_current_state_ids(room_id) + state_ids = await self.get_current_state_ids(room_id) create_id = state_ids.get((EventTypes.Create, "")) # If we can't find the create event, assume we've hit a dead end @@ -161,7 +159,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): raise NotFoundError("Unknown room %s" % (room_id,)) # Retrieve the room's create event and return - create_event = yield self.get_event(create_id) + create_event = await self.get_event(create_id) return create_event @cached(max_entries=100000, iterable=True) @@ -237,18 +235,17 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): "get_filtered_current_state_ids", _get_filtered_current_state_ids_txn ) - @defer.inlineCallbacks - def get_canonical_alias_for_room(self, room_id): + async def get_canonical_alias_for_room(self, room_id: str) -> Optional[str]: """Get canonical alias for room, if any Args: - room_id (str) + room_id: The room ID Returns: - Deferred[str|None]: The canonical alias, if any + The canonical alias, if any """ - state = yield self.get_filtered_current_state_ids( + state = await self.get_filtered_current_state_ids( room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")]) ) @@ -256,7 +253,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): if not event_id: return - event = yield self.get_event(event_id, allow_none=True) + event = await self.get_event(event_id, allow_none=True) if not event: return @@ -292,19 +289,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): return {row["event_id"]: row["state_group"] for row in rows} - @defer.inlineCallbacks - def get_referenced_state_groups(self, state_groups): + async def get_referenced_state_groups( + self, state_groups: Iterable[int] + ) -> Set[int]: """Check if the state groups are referenced by events. Args: - state_groups (Iterable[int]) + state_groups Returns: - Deferred[set[int]]: The subset of state groups that are - referenced. + The subset of state groups that are referenced. """ - rows = yield self.db.simple_select_many_batch( + rows = await self.db.simple_select_many_batch( table="event_to_state_groups", column="state_group", iterable=state_groups, diff --git a/synapse/storage/data_stores/main/stats.py b/synapse/storage/data_stores/main/stats.py index 380c1ec7da..922400a7c3 100644 --- a/synapse/storage/data_stores/main/stats.py +++ b/synapse/storage/data_stores/main/stats.py @@ -16,8 +16,8 @@ import logging from itertools import chain +from typing import Tuple -from twisted.internet import defer from twisted.internet.defer import DeferredLock from synapse.api.constants import EventTypes, Membership @@ -97,13 +97,12 @@ class StatsStore(StateDeltasStore): """ return (ts // self.stats_bucket_size) * self.stats_bucket_size - @defer.inlineCallbacks - def _populate_stats_process_users(self, progress, batch_size): + async def _populate_stats_process_users(self, progress, batch_size): """ This is a background update which regenerates statistics for users. """ if not self.stats_enabled: - yield self.db.updates._end_background_update("populate_stats_process_users") + await self.db.updates._end_background_update("populate_stats_process_users") return 1 last_user_id = progress.get("last_user_id", "") @@ -118,20 +117,20 @@ class StatsStore(StateDeltasStore): txn.execute(sql, (last_user_id, batch_size)) return [r for r, in txn] - users_to_work_on = yield self.db.runInteraction( + users_to_work_on = await self.db.runInteraction( "_populate_stats_process_users", _get_next_batch ) # No more rooms -- complete the transaction. if not users_to_work_on: - yield self.db.updates._end_background_update("populate_stats_process_users") + await self.db.updates._end_background_update("populate_stats_process_users") return 1 for user_id in users_to_work_on: - yield self._calculate_and_set_initial_state_for_user(user_id) + await self._calculate_and_set_initial_state_for_user(user_id) progress["last_user_id"] = user_id - yield self.db.runInteraction( + await self.db.runInteraction( "populate_stats_process_users", self.db.updates._background_update_progress_txn, "populate_stats_process_users", @@ -140,13 +139,12 @@ class StatsStore(StateDeltasStore): return len(users_to_work_on) - @defer.inlineCallbacks - def _populate_stats_process_rooms(self, progress, batch_size): + async def _populate_stats_process_rooms(self, progress, batch_size): """ This is a background update which regenerates statistics for rooms. """ if not self.stats_enabled: - yield self.db.updates._end_background_update("populate_stats_process_rooms") + await self.db.updates._end_background_update("populate_stats_process_rooms") return 1 last_room_id = progress.get("last_room_id", "") @@ -161,20 +159,20 @@ class StatsStore(StateDeltasStore): txn.execute(sql, (last_room_id, batch_size)) return [r for r, in txn] - rooms_to_work_on = yield self.db.runInteraction( + rooms_to_work_on = await self.db.runInteraction( "populate_stats_rooms_get_batch", _get_next_batch ) # No more rooms -- complete the transaction. if not rooms_to_work_on: - yield self.db.updates._end_background_update("populate_stats_process_rooms") + await self.db.updates._end_background_update("populate_stats_process_rooms") return 1 for room_id in rooms_to_work_on: - yield self._calculate_and_set_initial_state_for_room(room_id) + await self._calculate_and_set_initial_state_for_room(room_id) progress["last_room_id"] = room_id - yield self.db.runInteraction( + await self.db.runInteraction( "_populate_stats_process_rooms", self.db.updates._background_update_progress_txn, "populate_stats_process_rooms", @@ -696,16 +694,16 @@ class StatsStore(StateDeltasStore): return room_deltas, user_deltas - @defer.inlineCallbacks - def _calculate_and_set_initial_state_for_room(self, room_id): + async def _calculate_and_set_initial_state_for_room( + self, room_id: str + ) -> Tuple[dict, dict, int]: """Calculate and insert an entry into room_stats_current. Args: - room_id (str) + room_id: The room ID under calculation. Returns: - Deferred[tuple[dict, dict, int]]: A tuple of room state, membership - counts and stream position. + A tuple of room state, membership counts and stream position. """ def _fetch_current_state_stats(txn): @@ -767,11 +765,11 @@ class StatsStore(StateDeltasStore): current_state_events_count, users_in_room, pos, - ) = yield self.db.runInteraction( + ) = await self.db.runInteraction( "get_initial_state_for_room", _fetch_current_state_stats ) - state_event_map = yield self.get_events(event_ids, get_prev_content=False) + state_event_map = await self.get_events(event_ids, get_prev_content=False) room_state = { "join_rules": None, @@ -806,11 +804,11 @@ class StatsStore(StateDeltasStore): event.content.get("m.federate", True) is True ) - yield self.update_room_state(room_id, room_state) + await self.update_room_state(room_id, room_state) local_users_in_room = [u for u in users_in_room if self.hs.is_mine_id(u)] - yield self.update_stats_delta( + await self.update_stats_delta( ts=self.clock.time_msec(), stats_type="room", stats_id=room_id, @@ -826,8 +824,7 @@ class StatsStore(StateDeltasStore): }, ) - @defer.inlineCallbacks - def _calculate_and_set_initial_state_for_user(self, user_id): + async def _calculate_and_set_initial_state_for_user(self, user_id): def _calculate_and_set_initial_state_for_user_txn(txn): pos = self._get_max_stream_id_in_current_state_deltas_txn(txn) @@ -842,12 +839,12 @@ class StatsStore(StateDeltasStore): (count,) = txn.fetchone() return count, pos - joined_rooms, pos = yield self.db.runInteraction( + joined_rooms, pos = await self.db.runInteraction( "calculate_and_set_initial_state_for_user", _calculate_and_set_initial_state_for_user_txn, ) - yield self.update_stats_delta( + await self.update_stats_delta( ts=self.clock.time_msec(), stats_type="user", stats_id=user_id, diff --git a/synapse/storage/data_stores/state/store.py b/synapse/storage/data_stores/state/store.py index 128c09a2cf..7dada7f75f 100644 --- a/synapse/storage/data_stores/state/store.py +++ b/synapse/storage/data_stores/state/store.py @@ -139,10 +139,9 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): "get_state_group_delta", _get_state_group_delta_txn ) - @defer.inlineCallbacks - def _get_state_groups_from_groups( + async def _get_state_groups_from_groups( self, groups: List[int], state_filter: StateFilter - ): + ) -> Dict[int, StateMap[str]]: """Returns the state groups for a given set of groups from the database, filtering on types of state events. @@ -151,13 +150,13 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): state_filter: The state filter used to fetch state from the database. Returns: - Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map. + Dict of state group to state map. """ results = {} chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)] for chunk in chunks: - res = yield self.db.runInteraction( + res = await self.db.runInteraction( "_get_state_groups_from_groups", self._get_state_groups_from_groups_txn, chunk, @@ -206,10 +205,9 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): return state_filter.filter_state(state_dict_ids), not missing_types - @defer.inlineCallbacks - def _get_state_for_groups( + async def _get_state_for_groups( self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all() - ): + ) -> Dict[int, StateMap[str]]: """Gets the state at each of a list of state groups, optionally filtering by type/state_key @@ -219,7 +217,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): state_filter: The state filter used to fetch state from the database. Returns: - Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map. + Dict of state group to state map. """ member_filter, non_member_filter = state_filter.get_member_split() @@ -228,14 +226,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): ( non_member_state, incomplete_groups_nm, - ) = yield self._get_state_for_groups_using_cache( + ) = self._get_state_for_groups_using_cache( groups, self._state_group_cache, state_filter=non_member_filter ) - ( - member_state, - incomplete_groups_m, - ) = yield self._get_state_for_groups_using_cache( + (member_state, incomplete_groups_m,) = self._get_state_for_groups_using_cache( groups, self._state_group_members_cache, state_filter=member_filter ) @@ -256,7 +251,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): # Help the cache hit ratio by expanding the filter a bit db_state_filter = state_filter.return_expanded() - group_to_state_dict = yield self._get_state_groups_from_groups( + group_to_state_dict = await self._get_state_groups_from_groups( list(incomplete_groups), state_filter=db_state_filter ) @@ -576,19 +571,19 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): ((sg,) for sg in state_groups_to_delete), ) - @defer.inlineCallbacks - def get_previous_state_groups(self, state_groups): + async def get_previous_state_groups( + self, state_groups: Iterable[int] + ) -> Dict[int, int]: """Fetch the previous groups of the given state groups. Args: - state_groups (Iterable[int]) + state_groups Returns: - Deferred[dict[int, int]]: mapping from state group to previous - state group. + A mapping from state group to previous state group. """ - rows = yield self.db.simple_select_many_batch( + rows = await self.db.simple_select_many_batch( table="state_group_edges", column="prev_state_group", iterable=state_groups, diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 49ee9c9a74..534883361f 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -14,7 +14,7 @@ # limitations under the License. import logging -from typing import Dict, Iterable, List, Optional, Set, Tuple, TypeVar +from typing import Awaitable, Dict, Iterable, List, Optional, Set, Tuple, TypeVar import attr @@ -419,7 +419,7 @@ class StateGroupStorage(object): def _get_state_groups_from_groups( self, groups: List[int], state_filter: StateFilter - ): + ) -> Awaitable[Dict[int, StateMap[str]]]: """Returns the state groups for a given set of groups, filtering on types of state events. @@ -429,7 +429,7 @@ class StateGroupStorage(object): from the database. Returns: - Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map. + Dict of state group to state map. """ return self.stores.state._get_state_groups_from_groups(groups, state_filter) @@ -532,7 +532,7 @@ class StateGroupStorage(object): def _get_state_for_groups( self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all() - ): + ) -> Awaitable[Dict[int, StateMap[str]]]: """Gets the state at each of a list of state groups, optionally filtering by type/state_key @@ -540,8 +540,9 @@ class StateGroupStorage(object): groups: list of state groups for which we want to get the state. state_filter: The state filter used to fetch state. from the database. + Returns: - Deferred[dict[int, StateMap[str]]]: Dict of state group to state map. + Dict of state group to state map. """ return self.stores.state._get_state_for_groups(groups, state_filter) diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index 43dbeb42c5..2b1580feeb 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -39,14 +39,18 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def test_get_unread_push_actions_for_user_in_range_for_http(self): - yield self.store.get_unread_push_actions_for_user_in_range_for_http( - USER_ID, 0, 1000, 20 + yield defer.ensureDeferred( + self.store.get_unread_push_actions_for_user_in_range_for_http( + USER_ID, 0, 1000, 20 + ) ) @defer.inlineCallbacks def test_get_unread_push_actions_for_user_in_range_for_email(self): - yield self.store.get_unread_push_actions_for_user_in_range_for_email( - USER_ID, 0, 1000, 20 + yield defer.ensureDeferred( + self.store.get_unread_push_actions_for_user_in_range_for_email( + USER_ID, 0, 1000, 20 + ) ) @defer.inlineCallbacks diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index a5f250d477..d07b985a8e 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -37,11 +37,13 @@ class RoomStoreTestCase(unittest.TestCase): self.alias = RoomAlias.from_string("#a-room-name:test") self.u_creator = UserID.from_string("@creator:test") - yield self.store.store_room( - self.room.to_string(), - room_creator_user_id=self.u_creator.to_string(), - is_public=True, - room_version=RoomVersions.V1, + yield defer.ensureDeferred( + self.store.store_room( + self.room.to_string(), + room_creator_user_id=self.u_creator.to_string(), + is_public=True, + room_version=RoomVersions.V1, + ) ) @defer.inlineCallbacks @@ -88,11 +90,13 @@ class RoomEventsStoreTestCase(unittest.TestCase): self.room = RoomID.from_string("!abcde:test") - yield self.store.store_room( - self.room.to_string(), - room_creator_user_id="@creator:text", - is_public=True, - room_version=RoomVersions.V1, + yield defer.ensureDeferred( + self.store.store_room( + self.room.to_string(), + room_creator_user_id="@creator:text", + is_public=True, + room_version=RoomVersions.V1, + ) ) @defer.inlineCallbacks diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index 6a48b9d3b3..8bd12fa847 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -44,11 +44,13 @@ class StateStoreTestCase(tests.unittest.TestCase): self.room = RoomID.from_string("!abc123:test") - yield self.store.store_room( - self.room.to_string(), - room_creator_user_id="@creator:text", - is_public=True, - room_version=RoomVersions.V1, + yield defer.ensureDeferred( + self.store.store_room( + self.room.to_string(), + room_creator_user_id="@creator:text", + is_public=True, + room_version=RoomVersions.V1, + ) ) @defer.inlineCallbacks -- cgit 1.5.1