From 491f0dab1ba5456f52b0710461fbaabc594ff1f5 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Tue, 14 Jul 2020 13:36:23 +0200 Subject: Add delete room admin endpoint (#7613) The Delete Room admin API allows server admins to remove rooms from the server and block these rooms. `DELETE /_synapse/admin/v1/rooms/` It is a combination and improvement of the "[Shutdown room](https://github.com/matrix-org/synapse/blob/develop/docs/admin_api/shutdown_room.md)" and "[Purge room](https://github.com/matrix-org/synapse/blob/develop/docs/admin_api/purge_room.md)" APIs. Fixes: #6425 It also fixes a bug in [synapse/storage/data_stores/main/room.py](synapse/storage/data_stores/main/room.py) in `get_room_with_stats`: it should return `None` if the room is unknown, but it raises an `IndexError` instead. https://github.com/matrix-org/synapse/blob/901b1fa561e3cc661d78aa96d59802cf2078cb0d/synapse/storage/data_stores/main/room.py#L99-L105 Related to: - #5575 - https://github.com/Awesome-Technologies/synapse-admin/issues/17 Signed-off-by: Dirk Klimpel dirk@klimpel.org --- tests/storage/test_room.py | 8 ++++++++ 1 file changed, 8 insertions(+) (limited to 'tests/storage/test_room.py') diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index 3b78d48896..b1dceb2918 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -55,6 +55,10 @@ class RoomStoreTestCase(unittest.TestCase): (yield self.store.get_room(self.room.to_string())), ) + @defer.inlineCallbacks + def test_get_room_unknown_room(self): + self.assertIsNone((yield self.store.get_room("!uknown:test")),) + @defer.inlineCallbacks def test_get_room_with_stats(self): self.assertDictContainsSubset( @@ -66,6 +70,10 @@ class RoomStoreTestCase(unittest.TestCase): (yield self.store.get_room_with_stats(self.room.to_string())), ) + @defer.inlineCallbacks + def test_get_room_with_stats_unknown_room(self): + self.assertIsNone((yield self.store.get_room_with_stats("!uknown:test")),) + class RoomEventsStoreTestCase(unittest.TestCase): @defer.inlineCallbacks -- cgit 1.5.1 From b975fa2e9952f1f8ac2cddb15c287768bf9b0b4e Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 24 Jul 2020 10:59:51 -0400 Subject: Convert state resolution to async/await (#7942) --- changelog.d/7942.misc | 1 + synapse/api/auth.py | 12 ++- synapse/events/builder.py | 4 +- synapse/federation/sender/__init__.py | 4 +- synapse/handlers/presence.py | 4 +- synapse/push/bulk_push_rule_evaluator.py | 4 +- synapse/state/__init__.py | 95 ++++++++---------- synapse/state/v1.py | 15 ++- synapse/state/v2.py | 107 ++++++++++----------- synapse/storage/data_stores/main/push_rule.py | 2 +- synapse/storage/data_stores/main/roommember.py | 2 +- synapse/storage/data_stores/main/user_directory.py | 4 +- synapse/storage/persist_events.py | 5 +- tests/federation/test_federation_sender.py | 19 ++-- tests/state/test_v2.py | 17 ++-- tests/storage/test_room.py | 8 +- tests/test_state.py | 72 ++++++++------ tests/test_utils/__init__.py | 7 +- 18 files changed, 198 insertions(+), 184 deletions(-) create mode 100644 changelog.d/7942.misc (limited to 'tests/storage/test_room.py') diff --git a/changelog.d/7942.misc b/changelog.d/7942.misc new file mode 100644 index 0000000000..b504cf4e6f --- /dev/null +++ b/changelog.d/7942.misc @@ -0,0 +1 @@ +Convert state resolution to async/await.
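The commit above (#7942) applies one mechanical pattern throughout the files that follow: `@defer.inlineCallbacks` methods become native `async def` coroutines, internal `yield`s become `await`s, and callers that are still generator-based wrap the returned coroutine in `defer.ensureDeferred(...)` before yielding it. The sketch below is not Synapse code (the `FakeStore`/`StateHandler` classes and their method bodies are invented for illustration), but it shows why the `ensureDeferred` bridge is needed at legacy call sites.

```python
# Minimal illustration of the Deferred -> async/await conversion pattern,
# assuming only Twisted; the class and method names are hypothetical.
from twisted.internet import defer


class FakeStore:
    """Stands in for the storage layer; returns an already-fired Deferred."""

    def get_latest_event_ids_in_room(self, room_id):
        return defer.succeed(["$event1", "$event2"])


class StateHandler:
    def __init__(self, store):
        self.store = store

    # After the conversion: a native coroutine that awaits the Deferred
    # returned by the store (Twisted Deferreds are awaitable).
    async def get_current_state_ids(self, room_id):
        event_ids = await self.store.get_latest_event_ids_in_room(room_id)
        return {("m.room.member", "@alice:test"): event_ids[0]}


@defer.inlineCallbacks
def legacy_caller(state_handler):
    # Generator-based code cannot yield a coroutine object directly, so it
    # wraps the coroutine in a Deferred first, which is the same
    # defer.ensureDeferred() call that appears throughout the diffs below.
    state_ids = yield defer.ensureDeferred(
        state_handler.get_current_state_ids("!room:example.org")
    )
    return state_ids


d = legacy_caller(StateHandler(FakeStore()))
d.addCallback(print)  # {('m.room.member', '@alice:test'): '$event1'}
```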
diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 40dc62ef6c..b53e8451e5 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -127,8 +127,10 @@ class Auth(object): if current_state: member = current_state.get((EventTypes.Member, user_id), None) else: - member = yield self.state.get_current_state( - room_id=room_id, event_type=EventTypes.Member, state_key=user_id + member = yield defer.ensureDeferred( + self.state.get_current_state( + room_id=room_id, event_type=EventTypes.Member, state_key=user_id + ) ) membership = member.membership if member else None @@ -665,8 +667,10 @@ class Auth(object): ) return member_event.membership, member_event.event_id except AuthError: - visibility = yield self.state.get_current_state( - room_id, EventTypes.RoomHistoryVisibility, "" + visibility = yield defer.ensureDeferred( + self.state.get_current_state( + room_id, EventTypes.RoomHistoryVisibility, "" + ) ) if ( visibility diff --git a/synapse/events/builder.py b/synapse/events/builder.py index 92aadfe7ef..0bb216419a 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -106,8 +106,8 @@ class EventBuilder(object): Deferred[FrozenEvent] """ - state_ids = yield self._state.get_current_state_ids( - self.room_id, prev_event_ids + state_ids = yield defer.ensureDeferred( + self._state.get_current_state_ids(self.room_id, prev_event_ids) ) auth_ids = yield self._auth.compute_auth_events(self, state_ids) diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 99ce73e081..ba4ddd2370 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -330,7 +330,9 @@ class FederationSender(object): room_id = receipt.room_id # Work out which remote servers should be poked and poke them. - domains = yield self.state.get_current_hosts_in_room(room_id) + domains = yield defer.ensureDeferred( + self.state.get_current_hosts_in_room(room_id) + ) domains = [ d for d in domains diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 8e99c83d9d..b3a3bb8c3f 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -928,8 +928,8 @@ class PresenceHandler(BasePresenceHandler): # TODO: Check that this is actually a new server joining the # room. 
- user_ids = await self.state.get_current_users_in_room(room_id) - user_ids = list(filter(self.is_mine_id, user_ids)) + users = await self.state.get_current_users_in_room(room_id) + user_ids = list(filter(self.is_mine_id, users)) states_d = await self.current_state_for_users(user_ids) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 43ffe6faf0..472ddf9f7d 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -304,7 +304,9 @@ class RulesForRoom(object): push_rules_delta_state_cache_metric.inc_hits() else: - current_state_ids = yield context.get_current_state_ids() + current_state_ids = yield defer.ensureDeferred( + context.get_current_state_ids() + ) push_rules_delta_state_cache_metric.inc_misses() push_rules_state_size_counter.inc(len(current_state_ids)) diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 495d9f04c8..25ccef5aa5 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -16,14 +16,12 @@ import logging from collections import namedtuple -from typing import Dict, Iterable, List, Optional, Set +from typing import Awaitable, Dict, Iterable, List, Optional, Set import attr from frozendict import frozendict from prometheus_client import Histogram -from twisted.internet import defer - from synapse.api.constants import EventTypes from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions from synapse.events import EventBase @@ -31,6 +29,7 @@ from synapse.events.snapshot import EventContext from synapse.logging.utils import log_function from synapse.state import v1, v2 from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour +from synapse.storage.roommember import ProfileInfo from synapse.types import StateMap from synapse.util import Clock from synapse.util.async_helpers import Linearizer @@ -108,8 +107,7 @@ class StateHandler(object): self.hs = hs self._state_resolution_handler = hs.get_state_resolution_handler() - @defer.inlineCallbacks - def get_current_state( + async def get_current_state( self, room_id, event_type=None, state_key="", latest_event_ids=None ): """ Retrieves the current state for the room. 
This is done by @@ -126,20 +124,20 @@ class StateHandler(object): map from (type, state_key) to event """ if not latest_event_ids: - latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) + latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id) logger.debug("calling resolve_state_groups from get_current_state") - ret = yield self.resolve_state_groups_for_events(room_id, latest_event_ids) + ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids) state = ret.state if event_type: event_id = state.get((event_type, state_key)) event = None if event_id: - event = yield self.store.get_event(event_id, allow_none=True) + event = await self.store.get_event(event_id, allow_none=True) return event - state_map = yield self.store.get_events( + state_map = await self.store.get_events( list(state.values()), get_prev_content=False ) state = { @@ -148,8 +146,7 @@ class StateHandler(object): return state - @defer.inlineCallbacks - def get_current_state_ids(self, room_id, latest_event_ids=None): + async def get_current_state_ids(self, room_id, latest_event_ids=None): """Get the current state, or the state at a set of events, for a room Args: @@ -164,41 +161,38 @@ class StateHandler(object): (event_type, state_key) -> event_id """ if not latest_event_ids: - latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) + latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id) logger.debug("calling resolve_state_groups from get_current_state_ids") - ret = yield self.resolve_state_groups_for_events(room_id, latest_event_ids) + ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids) state = ret.state return state - @defer.inlineCallbacks - def get_current_users_in_room(self, room_id, latest_event_ids=None): + async def get_current_users_in_room( + self, room_id: str, latest_event_ids: Optional[List[str]] = None + ) -> Dict[str, ProfileInfo]: """ Get the users who are currently in a room. Args: - room_id (str): The ID of the room. - latest_event_ids (List[str]|None): Precomputed list of latest - event IDs. Will be computed if None. + room_id: The ID of the room. + latest_event_ids: Precomputed list of latest event IDs. Will be computed if None. Returns: - Deferred[Dict[str,ProfileInfo]]: Dictionary of user IDs to their - profileinfo. + Dictionary of user IDs to their profileinfo. 
""" if not latest_event_ids: - latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) + latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id) logger.debug("calling resolve_state_groups from get_current_users_in_room") - entry = yield self.resolve_state_groups_for_events(room_id, latest_event_ids) - joined_users = yield self.store.get_joined_users_from_state(room_id, entry) + entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids) + joined_users = await self.store.get_joined_users_from_state(room_id, entry) return joined_users - @defer.inlineCallbacks - def get_current_hosts_in_room(self, room_id): - event_ids = yield self.store.get_latest_event_ids_in_room(room_id) - return (yield self.get_hosts_in_room_at_events(room_id, event_ids)) + async def get_current_hosts_in_room(self, room_id): + event_ids = await self.store.get_latest_event_ids_in_room(room_id) + return await self.get_hosts_in_room_at_events(room_id, event_ids) - @defer.inlineCallbacks - def get_hosts_in_room_at_events(self, room_id, event_ids): + async def get_hosts_in_room_at_events(self, room_id, event_ids): """Get the hosts that were in a room at the given event ids Args: @@ -208,12 +202,11 @@ class StateHandler(object): Returns: Deferred[list[str]]: the hosts in the room at the given events """ - entry = yield self.resolve_state_groups_for_events(room_id, event_ids) - joined_hosts = yield self.store.get_joined_hosts(room_id, entry) + entry = await self.resolve_state_groups_for_events(room_id, event_ids) + joined_hosts = await self.store.get_joined_hosts(room_id, entry) return joined_hosts - @defer.inlineCallbacks - def compute_event_context( + async def compute_event_context( self, event: EventBase, old_state: Optional[Iterable[EventBase]] = None ): """Build an EventContext structure for the event. @@ -278,7 +271,7 @@ class StateHandler(object): # otherwise, we'll need to resolve the state across the prev_events. logger.debug("calling resolve_state_groups from compute_event_context") - entry = yield self.resolve_state_groups_for_events( + entry = await self.resolve_state_groups_for_events( event.room_id, event.prev_event_ids() ) @@ -295,7 +288,7 @@ class StateHandler(object): # if not state_group_before_event: - state_group_before_event = yield self.state_store.store_state_group( + state_group_before_event = await self.state_store.store_state_group( event.event_id, event.room_id, prev_group=state_group_before_event_prev_group, @@ -335,7 +328,7 @@ class StateHandler(object): state_ids_after_event[key] = event.event_id delta_ids = {key: event.event_id} - state_group_after_event = yield self.state_store.store_state_group( + state_group_after_event = await self.state_store.store_state_group( event.event_id, event.room_id, prev_group=state_group_before_event, @@ -353,8 +346,7 @@ class StateHandler(object): ) @measure_func() - @defer.inlineCallbacks - def resolve_state_groups_for_events(self, room_id, event_ids): + async def resolve_state_groups_for_events(self, room_id, event_ids): """ Given a list of event_ids this method fetches the state at each event, resolves conflicts between them and returns them. 
@@ -373,7 +365,7 @@ class StateHandler(object): # map from state group id to the state in that state group (where # 'state' is a map from state key to event id) # dict[int, dict[(str, str), str]] - state_groups_ids = yield self.state_store.get_state_groups_ids( + state_groups_ids = await self.state_store.get_state_groups_ids( room_id, event_ids ) @@ -382,7 +374,7 @@ class StateHandler(object): elif len(state_groups_ids) == 1: name, state_list = list(state_groups_ids.items()).pop() - prev_group, delta_ids = yield self.state_store.get_state_group_delta(name) + prev_group, delta_ids = await self.state_store.get_state_group_delta(name) return _StateCacheEntry( state=state_list, @@ -391,9 +383,9 @@ class StateHandler(object): delta_ids=delta_ids, ) - room_version = yield self.store.get_room_version_id(room_id) + room_version = await self.store.get_room_version_id(room_id) - result = yield self._state_resolution_handler.resolve_state_groups( + result = await self._state_resolution_handler.resolve_state_groups( room_id, room_version, state_groups_ids, @@ -402,8 +394,7 @@ class StateHandler(object): ) return result - @defer.inlineCallbacks - def resolve_events(self, room_version, state_sets, event): + async def resolve_events(self, room_version, state_sets, event): logger.info( "Resolving state for %s with %d groups", event.room_id, len(state_sets) ) @@ -414,7 +405,7 @@ class StateHandler(object): state_map = {ev.event_id: ev for st in state_sets for ev in st} with Measure(self.clock, "state._resolve_events"): - new_state = yield resolve_events_with_store( + new_state = await resolve_events_with_store( self.clock, event.room_id, room_version, @@ -451,9 +442,8 @@ class StateResolutionHandler(object): reset_expiry_on_get=True, ) - @defer.inlineCallbacks @log_function - def resolve_state_groups( + async def resolve_state_groups( self, room_id, room_version, state_groups_ids, event_map, state_res_store ): """Resolves conflicts between a set of state groups @@ -479,13 +469,13 @@ class StateResolutionHandler(object): state_res_store (StateResolutionStore) Returns: - Deferred[_StateCacheEntry]: resolved state + _StateCacheEntry: resolved state """ logger.debug("resolve_state_groups state_groups %s", state_groups_ids.keys()) group_names = frozenset(state_groups_ids.keys()) - with (yield self.resolve_linearizer.queue(group_names)): + with (await self.resolve_linearizer.queue(group_names)): if self._state_cache is not None: cache = self._state_cache.get(group_names, None) if cache: @@ -517,7 +507,7 @@ class StateResolutionHandler(object): if conflicted_state: logger.info("Resolving conflicted state for %r", room_id) with Measure(self.clock, "state._resolve_events"): - new_state = yield resolve_events_with_store( + new_state = await resolve_events_with_store( self.clock, room_id, room_version, @@ -598,7 +588,7 @@ def resolve_events_with_store( state_sets: List[StateMap[str]], event_map: Optional[Dict[str, EventBase]], state_res_store: "StateResolutionStore", -): +) -> Awaitable[StateMap[str]]: """ Args: room_id: the room we are working in @@ -619,8 +609,7 @@ def resolve_events_with_store( state_res_store: a place to fetch events from Returns: - Deferred[dict[(str, str), str]]: - a map from (type, state_key) to event_id. + a map from (type, state_key) to event_id. 
""" v = KNOWN_ROOM_VERSIONS[room_version] if v.state_res == StateResolutionVersions.V1: diff --git a/synapse/state/v1.py b/synapse/state/v1.py index 7b531a8337..ab5e24841d 100644 --- a/synapse/state/v1.py +++ b/synapse/state/v1.py @@ -15,9 +15,7 @@ import hashlib import logging -from typing import Callable, Dict, List, Optional - -from twisted.internet import defer +from typing import Awaitable, Callable, Dict, List, Optional from synapse import event_auth from synapse.api.constants import EventTypes @@ -32,12 +30,11 @@ logger = logging.getLogger(__name__) POWER_KEY = (EventTypes.PowerLevels, "") -@defer.inlineCallbacks -def resolve_events_with_store( +async def resolve_events_with_store( room_id: str, state_sets: List[StateMap[str]], event_map: Optional[Dict[str, EventBase]], - state_map_factory: Callable, + state_map_factory: Callable[[List[str]], Awaitable], ): """ Args: @@ -56,7 +53,7 @@ def resolve_events_with_store( state_map_factory: will be called with a list of event_ids that are needed, and should return with - a Deferred of dict of event_id to event. + an Awaitable that resolves to a dict of event_id to event. Returns: Deferred[dict[(str, str), str]]: @@ -80,7 +77,7 @@ def resolve_events_with_store( # dict[str, FrozenEvent]: a map from state event id to event. Only includes # the state events which are in conflict (and those in event_map) - state_map = yield state_map_factory(needed_events) + state_map = await state_map_factory(needed_events) if event_map is not None: state_map.update(event_map) @@ -110,7 +107,7 @@ def resolve_events_with_store( "Asking for %d/%d auth events", len(new_needed_events), new_needed_event_count ) - state_map_new = yield state_map_factory(new_needed_events) + state_map_new = await state_map_factory(new_needed_events) for event in state_map_new.values(): if event.room_id != room_id: raise Exception( diff --git a/synapse/state/v2.py b/synapse/state/v2.py index bf6caa0946..6634955cdc 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -18,8 +18,6 @@ import itertools import logging from typing import Dict, List, Optional -from twisted.internet import defer - import synapse.state from synapse import event_auth from synapse.api.constants import EventTypes @@ -32,14 +30,13 @@ from synapse.util import Clock logger = logging.getLogger(__name__) -# We want to yield to the reactor occasionally during state res when dealing +# We want to await to the reactor occasionally during state res when dealing # with large data sets, so that we don't exhaust the reactor. This is done by -# yielding to reactor during loops every N iterations. -_YIELD_AFTER_ITERATIONS = 100 +# awaiting to reactor during loops every N iterations. +_AWAIT_AFTER_ITERATIONS = 100 -@defer.inlineCallbacks -def resolve_events_with_store( +async def resolve_events_with_store( clock: Clock, room_id: str, room_version: str, @@ -87,7 +84,7 @@ def resolve_events_with_store( # Also fetch all auth events that appear in only some of the state sets' # auth chains. 
- auth_diff = yield _get_auth_chain_difference(state_sets, event_map, state_res_store) + auth_diff = await _get_auth_chain_difference(state_sets, event_map, state_res_store) full_conflicted_set = set( itertools.chain( @@ -95,7 +92,7 @@ def resolve_events_with_store( ) ) - events = yield state_res_store.get_events( + events = await state_res_store.get_events( [eid for eid in full_conflicted_set if eid not in event_map], allow_rejected=True, ) @@ -118,14 +115,14 @@ def resolve_events_with_store( eid for eid in full_conflicted_set if _is_power_event(event_map[eid]) ) - sorted_power_events = yield _reverse_topological_power_sort( + sorted_power_events = await _reverse_topological_power_sort( clock, room_id, power_events, event_map, state_res_store, full_conflicted_set ) logger.debug("sorted %d power events", len(sorted_power_events)) # Now sequentially auth each one - resolved_state = yield _iterative_auth_checks( + resolved_state = await _iterative_auth_checks( clock, room_id, room_version, @@ -148,13 +145,13 @@ def resolve_events_with_store( logger.debug("sorting %d remaining events", len(leftover_events)) pl = resolved_state.get((EventTypes.PowerLevels, ""), None) - leftover_events = yield _mainline_sort( + leftover_events = await _mainline_sort( clock, room_id, leftover_events, pl, event_map, state_res_store ) logger.debug("resolving remaining events") - resolved_state = yield _iterative_auth_checks( + resolved_state = await _iterative_auth_checks( clock, room_id, room_version, @@ -174,8 +171,7 @@ def resolve_events_with_store( return resolved_state -@defer.inlineCallbacks -def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store): +async def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store): """Return the power level of the sender of the given event according to their auth events. @@ -188,11 +184,11 @@ def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store): Returns: Deferred[int] """ - event = yield _get_event(room_id, event_id, event_map, state_res_store) + event = await _get_event(room_id, event_id, event_map, state_res_store) pl = None for aid in event.auth_event_ids(): - aev = yield _get_event( + aev = await _get_event( room_id, aid, event_map, state_res_store, allow_none=True ) if aev and (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""): @@ -202,7 +198,7 @@ def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store): if pl is None: # Couldn't find power level. Check if they're the creator of the room for aid in event.auth_event_ids(): - aev = yield _get_event( + aev = await _get_event( room_id, aid, event_map, state_res_store, allow_none=True ) if aev and (aev.type, aev.state_key) == (EventTypes.Create, ""): @@ -221,8 +217,7 @@ def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store): return int(level) -@defer.inlineCallbacks -def _get_auth_chain_difference(state_sets, event_map, state_res_store): +async def _get_auth_chain_difference(state_sets, event_map, state_res_store): """Compare the auth chains of each state set and return the set of events that only appear in some but not all of the auth chains. 
@@ -235,7 +230,7 @@ def _get_auth_chain_difference(state_sets, event_map, state_res_store): Deferred[set[str]]: Set of event IDs """ - difference = yield state_res_store.get_auth_chain_difference( + difference = await state_res_store.get_auth_chain_difference( [set(state_set.values()) for state_set in state_sets] ) @@ -292,8 +287,7 @@ def _is_power_event(event): return False -@defer.inlineCallbacks -def _add_event_and_auth_chain_to_graph( +async def _add_event_and_auth_chain_to_graph( graph, room_id, event_id, event_map, state_res_store, auth_diff ): """Helper function for _reverse_topological_power_sort that add the event @@ -314,7 +308,7 @@ def _add_event_and_auth_chain_to_graph( eid = state.pop() graph.setdefault(eid, set()) - event = yield _get_event(room_id, eid, event_map, state_res_store) + event = await _get_event(room_id, eid, event_map, state_res_store) for aid in event.auth_event_ids(): if aid in auth_diff: if aid not in graph: @@ -323,8 +317,7 @@ def _add_event_and_auth_chain_to_graph( graph.setdefault(eid, set()).add(aid) -@defer.inlineCallbacks -def _reverse_topological_power_sort( +async def _reverse_topological_power_sort( clock, room_id, event_ids, event_map, state_res_store, auth_diff ): """Returns a list of the event_ids sorted by reverse topological ordering, @@ -344,26 +337,26 @@ def _reverse_topological_power_sort( graph = {} for idx, event_id in enumerate(event_ids, start=1): - yield _add_event_and_auth_chain_to_graph( + await _add_event_and_auth_chain_to_graph( graph, room_id, event_id, event_map, state_res_store, auth_diff ) - # We yield occasionally when we're working with large data sets to + # We await occasionally when we're working with large data sets to # ensure that we don't block the reactor loop for too long. - if idx % _YIELD_AFTER_ITERATIONS == 0: - yield clock.sleep(0) + if idx % _AWAIT_AFTER_ITERATIONS == 0: + await clock.sleep(0) event_to_pl = {} for idx, event_id in enumerate(graph, start=1): - pl = yield _get_power_level_for_sender( + pl = await _get_power_level_for_sender( room_id, event_id, event_map, state_res_store ) event_to_pl[event_id] = pl - # We yield occasionally when we're working with large data sets to + # We await occasionally when we're working with large data sets to # ensure that we don't block the reactor loop for too long. 
- if idx % _YIELD_AFTER_ITERATIONS == 0: - yield clock.sleep(0) + if idx % _AWAIT_AFTER_ITERATIONS == 0: + await clock.sleep(0) def _get_power_order(event_id): ev = event_map[event_id] @@ -378,8 +371,7 @@ def _reverse_topological_power_sort( return sorted_events -@defer.inlineCallbacks -def _iterative_auth_checks( +async def _iterative_auth_checks( clock, room_id, room_version, event_ids, base_state, event_map, state_res_store ): """Sequentially apply auth checks to each event in given list, updating the @@ -405,7 +397,7 @@ def _iterative_auth_checks( auth_events = {} for aid in event.auth_event_ids(): - ev = yield _get_event( + ev = await _get_event( room_id, aid, event_map, state_res_store, allow_none=True ) @@ -420,7 +412,7 @@ def _iterative_auth_checks( for key in event_auth.auth_types_for_event(event): if key in resolved_state: ev_id = resolved_state[key] - ev = yield _get_event(room_id, ev_id, event_map, state_res_store) + ev = await _get_event(room_id, ev_id, event_map, state_res_store) if ev.rejected_reason is None: auth_events[key] = event_map[ev_id] @@ -438,16 +430,15 @@ def _iterative_auth_checks( except AuthError: pass - # We yield occasionally when we're working with large data sets to + # We await occasionally when we're working with large data sets to # ensure that we don't block the reactor loop for too long. - if idx % _YIELD_AFTER_ITERATIONS == 0: - yield clock.sleep(0) + if idx % _AWAIT_AFTER_ITERATIONS == 0: + await clock.sleep(0) return resolved_state -@defer.inlineCallbacks -def _mainline_sort( +async def _mainline_sort( clock, room_id, event_ids, resolved_power_event_id, event_map, state_res_store ): """Returns a sorted list of event_ids sorted by mainline ordering based on @@ -474,21 +465,21 @@ def _mainline_sort( idx = 0 while pl: mainline.append(pl) - pl_ev = yield _get_event(room_id, pl, event_map, state_res_store) + pl_ev = await _get_event(room_id, pl, event_map, state_res_store) auth_events = pl_ev.auth_event_ids() pl = None for aid in auth_events: - ev = yield _get_event( + ev = await _get_event( room_id, aid, event_map, state_res_store, allow_none=True ) if ev and (ev.type, ev.state_key) == (EventTypes.PowerLevels, ""): pl = aid break - # We yield occasionally when we're working with large data sets to + # We await occasionally when we're working with large data sets to # ensure that we don't block the reactor loop for too long. - if idx != 0 and idx % _YIELD_AFTER_ITERATIONS == 0: - yield clock.sleep(0) + if idx != 0 and idx % _AWAIT_AFTER_ITERATIONS == 0: + await clock.sleep(0) idx += 1 @@ -498,23 +489,24 @@ def _mainline_sort( order_map = {} for idx, ev_id in enumerate(event_ids, start=1): - depth = yield _get_mainline_depth_for_event( + depth = await _get_mainline_depth_for_event( event_map[ev_id], mainline_map, event_map, state_res_store ) order_map[ev_id] = (depth, event_map[ev_id].origin_server_ts, ev_id) - # We yield occasionally when we're working with large data sets to + # We await occasionally when we're working with large data sets to # ensure that we don't block the reactor loop for too long. 
- if idx % _YIELD_AFTER_ITERATIONS == 0: - yield clock.sleep(0) + if idx % _AWAIT_AFTER_ITERATIONS == 0: + await clock.sleep(0) event_ids.sort(key=lambda ev_id: order_map[ev_id]) return event_ids -@defer.inlineCallbacks -def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_store): +async def _get_mainline_depth_for_event( + event, mainline_map, event_map, state_res_store +): """Get the mainline depths for the given event based on the mainline map Args: @@ -541,7 +533,7 @@ def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_stor event = None for aid in auth_events: - aev = yield _get_event( + aev = await _get_event( room_id, aid, event_map, state_res_store, allow_none=True ) if aev and (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""): @@ -552,8 +544,7 @@ def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_stor return 0 -@defer.inlineCallbacks -def _get_event(room_id, event_id, event_map, state_res_store, allow_none=False): +async def _get_event(room_id, event_id, event_map, state_res_store, allow_none=False): """Helper function to look up event in event_map, falling back to looking it up in the store @@ -569,7 +560,7 @@ def _get_event(room_id, event_id, event_map, state_res_store, allow_none=False): Deferred[Optional[FrozenEvent]] """ if event_id not in event_map: - events = yield state_res_store.get_events([event_id], allow_rejected=True) + events = await state_res_store.get_events([event_id], allow_rejected=True) event_map.update(events) event = event_map.get(event_id) diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py index d181488db7..c229248101 100644 --- a/synapse/storage/data_stores/main/push_rule.py +++ b/synapse/storage/data_stores/main/push_rule.py @@ -259,7 +259,7 @@ class PushRulesWorkerStore( # To do this we set the state_group to a new object as object() != object() state_group = object() - current_state_ids = yield context.get_current_state_ids() + current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) result = yield self._bulk_get_push_rules_for_room( event.room_id, state_group, current_state_ids, event=event ) diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py index 29765890ee..a92e401e88 100644 --- a/synapse/storage/data_stores/main/roommember.py +++ b/synapse/storage/data_stores/main/roommember.py @@ -497,7 +497,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): # To do this we set the state_group to a new object as object() != object() state_group = object() - current_state_ids = yield context.get_current_state_ids() + current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) result = yield self._get_joined_users_from_context( event.room_id, state_group, current_state_ids, event=event, context=context ) diff --git a/synapse/storage/data_stores/main/user_directory.py b/synapse/storage/data_stores/main/user_directory.py index 6b8130bf0f..942e51fd3a 100644 --- a/synapse/storage/data_stores/main/user_directory.py +++ b/synapse/storage/data_stores/main/user_directory.py @@ -198,7 +198,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): room_id ) - users_with_profile = yield state.get_current_users_in_room(room_id) + users_with_profile = yield defer.ensureDeferred( + state.get_current_users_in_room(room_id) + ) user_ids = set(users_with_profile) # Update each user in the user directory. 
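The converted loops in `synapse/state/v2.py` above keep the `_AWAIT_AFTER_ITERATIONS` counter so that state resolution over large rooms periodically hands control back to the reactor. The snippet below is a standalone sketch of that cooperative-yield pattern using plain `asyncio` rather than Synapse's `Clock`; apart from the constant, the names and the per-event work are made up for illustration.

```python
# Standalone sketch of the "await every N iterations" pattern, assuming only
# the standard library; the function and event IDs are hypothetical.
import asyncio

AWAIT_AFTER_ITERATIONS = 100  # mirrors _AWAIT_AFTER_ITERATIONS above


async def assign_order(event_ids):
    order = {}
    for idx, event_id in enumerate(event_ids, start=1):
        # Placeholder for the expensive per-event work (power-level lookups,
        # auth-chain walks, and so on in the real resolver).
        order[event_id] = hash(event_id)

        # Yield to the event loop periodically so one huge room cannot
        # monopolise the process while state resolution runs.
        if idx % AWAIT_AFTER_ITERATIONS == 0:
            await asyncio.sleep(0)

    return sorted(event_ids, key=order.get)


events = ["$event%d:example.org" % i for i in range(1000)]
print(len(asyncio.run(assign_order(events))), "events ordered")
```

In the real code the same effect is achieved with `await clock.sleep(0)`, which fires a Deferred via the Twisted reactor instead of the asyncio loop, so other requests served by the process get a chance to run between batches.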
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index fa46041676..78fbdcdee8 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -29,7 +29,6 @@ from synapse.events import FrozenEvent from synapse.events.snapshot import EventContext from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.state import StateResolutionStore from synapse.storage.data_stores import DataStores from synapse.storage.data_stores.main.events import DeltaState from synapse.types import StateMap @@ -648,6 +647,10 @@ class EventsPersistenceStorage(object): room_version = await self.main_store.get_room_version_id(room_id) logger.debug("calling resolve_state_groups from preserve_events") + + # Avoid a circular import. + from synapse.state import StateResolutionStore + res = await self._state_resolution_handler.resolve_state_groups( room_id, room_version, diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index 1a9bd5f37d..d1bd18da39 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -26,21 +26,24 @@ from synapse.rest import admin from synapse.rest.client.v1 import login from synapse.types import JsonDict, ReadReceipt +from tests.test_utils import make_awaitable from tests.unittest import HomeserverTestCase, override_config class FederationSenderReceiptsTestCases(HomeserverTestCase): def make_homeserver(self, reactor, clock): + mock_state_handler = Mock(spec=["get_current_hosts_in_room"]) + # Ensure a new Awaitable is created for each call. + mock_state_handler.get_current_hosts_in_room.side_effect = lambda room_Id: make_awaitable( + ["test", "host2"] + ) return self.setup_test_homeserver( - state_handler=Mock(spec=["get_current_hosts_in_room"]), + state_handler=mock_state_handler, federation_transport_client=Mock(spec=["send_transaction"]), ) @override_config({"send_federation": True}) def test_send_receipts(self): - mock_state_handler = self.hs.get_state_handler() - mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"] - mock_send_transaction = ( self.hs.get_federation_transport_client().send_transaction ) @@ -81,9 +84,6 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): def test_send_receipts_with_backoff(self): """Send two receipts in quick succession; the second should be flushed, but only after 20ms""" - mock_state_handler = self.hs.get_state_handler() - mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"] - mock_send_transaction = ( self.hs.get_federation_transport_client().send_transaction ) @@ -164,7 +164,6 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): def make_homeserver(self, reactor, clock): return self.setup_test_homeserver( - state_handler=Mock(spec=["get_current_hosts_in_room"]), federation_transport_client=Mock(spec=["send_transaction"]), ) @@ -174,10 +173,6 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): return c def prepare(self, reactor, clock, hs): - # stub out get_current_hosts_in_room - mock_state_handler = hs.get_state_handler() - mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"] - # stub out get_users_who_share_room_with_user so that it claims that # `@user2:host2` is in the room def get_users_who_share_room_with_user(user_id): diff --git a/tests/state/test_v2.py 
b/tests/state/test_v2.py index 38f9b423ef..f2955a9c69 100644 --- a/tests/state/test_v2.py +++ b/tests/state/test_v2.py @@ -14,6 +14,7 @@ # limitations under the License. import itertools +from typing import List import attr @@ -432,7 +433,7 @@ class StateTestCase(unittest.TestCase): state_res_store=TestStateResolutionStore(event_map), ) - state_before = self.successResultOf(state_d) + state_before = self.successResultOf(defer.ensureDeferred(state_d)) state_after = dict(state_before) if fake_event.state_key is not None: @@ -581,7 +582,7 @@ class SimpleParamStateTestCase(unittest.TestCase): state_res_store=TestStateResolutionStore(self.event_map), ) - state = self.successResultOf(state_d) + state = self.successResultOf(defer.ensureDeferred(state_d)) self.assert_dict(self.expected_combined_state, state) @@ -608,9 +609,11 @@ class TestStateResolutionStore(object): Deferred[dict[str, FrozenEvent]]: Dict from event_id to event. """ - return {eid: self.event_map[eid] for eid in event_ids if eid in self.event_map} + return defer.succeed( + {eid: self.event_map[eid] for eid in event_ids if eid in self.event_map} + ) - def _get_auth_chain(self, event_ids): + def _get_auth_chain(self, event_ids: List[str]) -> List[str]: """Gets the full auth chain for a set of events (including rejected events). @@ -622,10 +625,10 @@ class TestStateResolutionStore(object): presence of rejected events Args: - event_ids (list): The event IDs of the events to fetch the auth + event_ids: The event IDs of the events to fetch the auth chain for. Must be state events. Returns: - Deferred[list[str]]: List of event IDs of the auth chain. + List of event IDs of the auth chain. """ # Simple DFS for auth chain @@ -648,4 +651,4 @@ class TestStateResolutionStore(object): chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets] common = set(chains[0]).intersection(*chains[1:]) - return set(chains[0]).union(*chains[1:]) - common + return defer.succeed(set(chains[0]).union(*chains[1:]) - common) diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index b1dceb2918..1d77b4a2d6 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -109,7 +109,9 @@ class RoomEventsStoreTestCase(unittest.TestCase): etype=EventTypes.Name, name=name, content={"name": name}, depth=1 ) - state = yield self.store.get_current_state(room_id=self.room.to_string()) + state = yield defer.ensureDeferred( + self.store.get_current_state(room_id=self.room.to_string()) + ) self.assertEquals(1, len(state)) self.assertObjectHasAttributes( @@ -125,7 +127,9 @@ class RoomEventsStoreTestCase(unittest.TestCase): etype=EventTypes.Topic, topic=topic, content={"topic": topic}, depth=1 ) - state = yield self.store.get_current_state(room_id=self.room.to_string()) + state = yield defer.ensureDeferred( + self.store.get_current_state(room_id=self.room.to_string()) + ) self.assertEquals(1, len(state)) self.assertObjectHasAttributes( diff --git a/tests/test_state.py b/tests/test_state.py index 66f22f6813..4858e8fc59 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -97,17 +97,19 @@ class StateGroupStore(object): self._group_to_state[state_group] = dict(current_state_ids) - return state_group + return defer.succeed(state_group) def get_events(self, event_ids, **kwargs): - return { - e_id: self._event_id_to_event[e_id] - for e_id in event_ids - if e_id in self._event_id_to_event - } + return defer.succeed( + { + e_id: self._event_id_to_event[e_id] + for e_id in event_ids + if e_id in self._event_id_to_event + } + ) def 
get_state_group_delta(self, name): - return None, None + return defer.succeed((None, None)) def register_events(self, events): for e in events: @@ -120,7 +122,7 @@ class StateGroupStore(object): self._event_to_state_group[event_id] = state_group def get_room_version_id(self, room_id): - return RoomVersions.V1.identifier + return defer.succeed(RoomVersions.V1.identifier) class DictObj(dict): @@ -202,7 +204,9 @@ class StateTestCase(unittest.TestCase): context_store = {} # type: dict[str, EventContext] for event in graph.walk(): - context = yield self.state.compute_event_context(event) + context = yield defer.ensureDeferred( + self.state.compute_event_context(event) + ) self.store.register_event_context(event, context) context_store[event.event_id] = context @@ -244,7 +248,9 @@ class StateTestCase(unittest.TestCase): context_store = {} for event in graph.walk(): - context = yield self.state.compute_event_context(event) + context = yield defer.ensureDeferred( + self.state.compute_event_context(event) + ) self.store.register_event_context(event, context) context_store[event.event_id] = context @@ -300,7 +306,9 @@ class StateTestCase(unittest.TestCase): context_store = {} for event in graph.walk(): - context = yield self.state.compute_event_context(event) + context = yield defer.ensureDeferred( + self.state.compute_event_context(event) + ) self.store.register_event_context(event, context) context_store[event.event_id] = context @@ -373,7 +381,9 @@ class StateTestCase(unittest.TestCase): context_store = {} for event in graph.walk(): - context = yield self.state.compute_event_context(event) + context = yield defer.ensureDeferred( + self.state.compute_event_context(event) + ) self.store.register_event_context(event, context) context_store[event.event_id] = context @@ -411,12 +421,14 @@ class StateTestCase(unittest.TestCase): create_event(type="test2", state_key=""), ] - context = yield self.state.compute_event_context(event, old_state=old_state) + context = yield defer.ensureDeferred( + self.state.compute_event_context(event, old_state=old_state) + ) prev_state_ids = yield context.get_prev_state_ids() self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values()) - current_state_ids = yield context.get_current_state_ids() + current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) self.assertCountEqual( (e.event_id for e in old_state), current_state_ids.values() ) @@ -434,12 +446,14 @@ class StateTestCase(unittest.TestCase): create_event(type="test2", state_key=""), ] - context = yield self.state.compute_event_context(event, old_state=old_state) + context = yield defer.ensureDeferred( + self.state.compute_event_context(event, old_state=old_state) + ) prev_state_ids = yield context.get_prev_state_ids() self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values()) - current_state_ids = yield context.get_current_state_ids() + current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) self.assertCountEqual( (e.event_id for e in old_state + [event]), current_state_ids.values() ) @@ -462,7 +476,7 @@ class StateTestCase(unittest.TestCase): create_event(type="test2", state_key=""), ] - group_name = self.store.store_state_group( + group_name = yield self.store.store_state_group( prev_event_id, event.room_id, None, @@ -471,9 +485,9 @@ class StateTestCase(unittest.TestCase): ) self.store.register_event_id_state_group(prev_event_id, group_name) - context = yield self.state.compute_event_context(event) + context = yield 
defer.ensureDeferred(self.state.compute_event_context(event)) - current_state_ids = yield context.get_current_state_ids() + current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) self.assertEqual( {e.event_id for e in old_state}, set(current_state_ids.values()) @@ -494,7 +508,7 @@ class StateTestCase(unittest.TestCase): create_event(type="test2", state_key=""), ] - group_name = self.store.store_state_group( + group_name = yield self.store.store_state_group( prev_event_id, event.room_id, None, @@ -503,7 +517,7 @@ class StateTestCase(unittest.TestCase): ) self.store.register_event_id_state_group(prev_event_id, group_name) - context = yield self.state.compute_event_context(event) + context = yield defer.ensureDeferred(self.state.compute_event_context(event)) prev_state_ids = yield context.get_prev_state_ids() @@ -544,7 +558,7 @@ class StateTestCase(unittest.TestCase): event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 ) - current_state_ids = yield context.get_current_state_ids() + current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) self.assertEqual(len(current_state_ids), 6) @@ -586,7 +600,7 @@ class StateTestCase(unittest.TestCase): event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 ) - current_state_ids = yield context.get_current_state_ids() + current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) self.assertEqual(len(current_state_ids), 6) @@ -641,7 +655,7 @@ class StateTestCase(unittest.TestCase): event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 ) - current_state_ids = yield context.get_current_state_ids() + current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) self.assertEqual(old_state_2[3].event_id, current_state_ids[("test1", "1")]) @@ -669,14 +683,15 @@ class StateTestCase(unittest.TestCase): event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 ) - current_state_ids = yield context.get_current_state_ids() + current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) self.assertEqual(old_state_1[3].event_id, current_state_ids[("test1", "1")]) + @defer.inlineCallbacks def _get_context( self, event, prev_event_id_1, old_state_1, prev_event_id_2, old_state_2 ): - sg1 = self.store.store_state_group( + sg1 = yield self.store.store_state_group( prev_event_id_1, event.room_id, None, @@ -685,7 +700,7 @@ class StateTestCase(unittest.TestCase): ) self.store.register_event_id_state_group(prev_event_id_1, sg1) - sg2 = self.store.store_state_group( + sg2 = yield self.store.store_state_group( prev_event_id_2, event.room_id, None, @@ -694,4 +709,5 @@ class StateTestCase(unittest.TestCase): ) self.store.register_event_id_state_group(prev_event_id_2, sg2) - return self.state.compute_event_context(event) + result = yield defer.ensureDeferred(self.state.compute_event_context(event)) + return result diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py index 7b345b03bb..508aeba078 100644 --- a/tests/test_utils/__init__.py +++ b/tests/test_utils/__init__.py @@ -17,7 +17,7 @@ """ Utilities for running the unit tests """ -from typing import Awaitable, TypeVar +from typing import Any, Awaitable, TypeVar TV = TypeVar("TV") @@ -36,3 +36,8 @@ def get_awaitable_result(awaitable: Awaitable[TV]) -> TV: # if next didn't raise, the awaitable hasn't completed. 
raise Exception("awaitable has not yet completed") + + +async def make_awaitable(result: Any): + """Create an awaitable that just returns a result.""" + return result -- cgit 1.5.1 From 3345c166a45cb4a8f87c583ee0476c2bca5c41bd Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 28 Jul 2020 16:09:53 -0400 Subject: Convert storage layer to async/await. (#7963) --- changelog.d/7963.misc | 1 + synapse/storage/persist_events.py | 40 ++++---- synapse/storage/purge_events.py | 38 ++++--- synapse/storage/state.py | 207 ++++++++++++++++++++------------------ tests/storage/test_purge.py | 8 +- tests/storage/test_room.py | 6 +- tests/storage/test_state.py | 64 +++++++----- tests/test_visibility.py | 14 ++- tests/utils.py | 16 +-- tox.ini | 1 + 10 files changed, 210 insertions(+), 185 deletions(-) create mode 100644 changelog.d/7963.misc (limited to 'tests/storage/test_room.py') diff --git a/changelog.d/7963.misc b/changelog.d/7963.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/7963.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index 78fbdcdee8..4a164834d9 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -25,7 +25,7 @@ from prometheus_client import Counter, Histogram from twisted.internet import defer from synapse.api.constants import EventTypes, Membership -from synapse.events import FrozenEvent +from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable from synapse.metrics.background_process_metrics import run_as_background_process @@ -192,12 +192,11 @@ class EventsPersistenceStorage(object): self._event_persist_queue = _EventPeristenceQueue() self._state_resolution_handler = hs.get_state_resolution_handler() - @defer.inlineCallbacks - def persist_events( + async def persist_events( self, - events_and_contexts: List[Tuple[FrozenEvent, EventContext]], + events_and_contexts: List[Tuple[EventBase, EventContext]], backfilled: bool = False, - ): + ) -> int: """ Write events to the database Args: @@ -207,7 +206,7 @@ class EventsPersistenceStorage(object): which might update the current state etc. 
Returns: - Deferred[int]: the stream ordering of the latest persisted event + the stream ordering of the latest persisted event """ partitioned = {} for event, ctx in events_and_contexts: @@ -223,22 +222,19 @@ class EventsPersistenceStorage(object): for room_id in partitioned: self._maybe_start_persisting(room_id) - yield make_deferred_yieldable( + await make_deferred_yieldable( defer.gatherResults(deferreds, consumeErrors=True) ) - max_persisted_id = yield self.main_store.get_current_events_token() - - return max_persisted_id + return self.main_store.get_current_events_token() - @defer.inlineCallbacks - def persist_event( - self, event: FrozenEvent, context: EventContext, backfilled: bool = False - ): + async def persist_event( + self, event: EventBase, context: EventContext, backfilled: bool = False + ) -> Tuple[int, int]: """ Returns: - Deferred[Tuple[int, int]]: the stream ordering of ``event``, - and the stream ordering of the latest persisted event + The stream ordering of `event`, and the stream ordering of the + latest persisted event """ deferred = self._event_persist_queue.add_to_queue( event.room_id, [(event, context)], backfilled=backfilled @@ -246,9 +242,9 @@ class EventsPersistenceStorage(object): self._maybe_start_persisting(event.room_id) - yield make_deferred_yieldable(deferred) + await make_deferred_yieldable(deferred) - max_persisted_id = yield self.main_store.get_current_events_token() + max_persisted_id = self.main_store.get_current_events_token() return (event.internal_metadata.stream_ordering, max_persisted_id) def _maybe_start_persisting(self, room_id: str): @@ -262,7 +258,7 @@ class EventsPersistenceStorage(object): async def _persist_events( self, - events_and_contexts: List[Tuple[FrozenEvent, EventContext]], + events_and_contexts: List[Tuple[EventBase, EventContext]], backfilled: bool = False, ): """Calculates the change to current state and forward extremities, and @@ -439,7 +435,7 @@ class EventsPersistenceStorage(object): async def _calculate_new_extremities( self, room_id: str, - event_contexts: List[Tuple[FrozenEvent, EventContext]], + event_contexts: List[Tuple[EventBase, EventContext]], latest_event_ids: List[str], ): """Calculates the new forward extremities for a room given events to @@ -497,7 +493,7 @@ class EventsPersistenceStorage(object): async def _get_new_state_after_events( self, room_id: str, - events_context: List[Tuple[FrozenEvent, EventContext]], + events_context: List[Tuple[EventBase, EventContext]], old_latest_event_ids: Iterable[str], new_latest_event_ids: Iterable[str], ) -> Tuple[Optional[StateMap[str]], Optional[StateMap[str]]]: @@ -683,7 +679,7 @@ class EventsPersistenceStorage(object): async def _is_server_still_joined( self, room_id: str, - ev_ctx_rm: List[Tuple[FrozenEvent, EventContext]], + ev_ctx_rm: List[Tuple[EventBase, EventContext]], delta: DeltaState, current_state: Optional[StateMap[str]], potentially_left_users: Set[str], diff --git a/synapse/storage/purge_events.py b/synapse/storage/purge_events.py index fdc0abf5cf..79d9f06e2e 100644 --- a/synapse/storage/purge_events.py +++ b/synapse/storage/purge_events.py @@ -15,8 +15,7 @@ import itertools import logging - -from twisted.internet import defer +from typing import Set logger = logging.getLogger(__name__) @@ -28,49 +27,48 @@ class PurgeEventsStorage(object): def __init__(self, hs, stores): self.stores = stores - @defer.inlineCallbacks - def purge_room(self, room_id: str): + async def purge_room(self, room_id: str): """Deletes all record of a room """ - state_groups_to_delete 
= yield self.stores.main.purge_room(room_id) - yield self.stores.state.purge_room_state(room_id, state_groups_to_delete) + state_groups_to_delete = await self.stores.main.purge_room(room_id) + await self.stores.state.purge_room_state(room_id, state_groups_to_delete) - @defer.inlineCallbacks - def purge_history(self, room_id, token, delete_local_events): + async def purge_history( + self, room_id: str, token: str, delete_local_events: bool + ) -> None: """Deletes room history before a certain point Args: - room_id (str): + room_id: The room ID - token (str): A topological token to delete events before + token: A topological token to delete events before - delete_local_events (bool): + delete_local_events: if True, we will delete local events as well as remote ones (instead of just marking them as outliers and deleting their state groups). """ - state_groups = yield self.stores.main.purge_history( + state_groups = await self.stores.main.purge_history( room_id, token, delete_local_events ) logger.info("[purge] finding state groups that can be deleted") - sg_to_delete = yield self._find_unreferenced_groups(state_groups) + sg_to_delete = await self._find_unreferenced_groups(state_groups) - yield self.stores.state.purge_unreferenced_state_groups(room_id, sg_to_delete) + await self.stores.state.purge_unreferenced_state_groups(room_id, sg_to_delete) - @defer.inlineCallbacks - def _find_unreferenced_groups(self, state_groups): + async def _find_unreferenced_groups(self, state_groups: Set[int]) -> Set[int]: """Used when purging history to figure out which state groups can be deleted. Args: - state_groups (set[int]): Set of state groups referenced by events + state_groups: Set of state groups referenced by events that are going to be deleted. Returns: - Deferred[set[int]] The set of state groups that can be deleted. + The set of state groups that can be deleted. """ # Graph of state group -> previous group graph = {} @@ -93,7 +91,7 @@ class PurgeEventsStorage(object): current_search = set(itertools.islice(next_to_search, 100)) next_to_search -= current_search - referenced = yield self.stores.main.get_referenced_state_groups( + referenced = await self.stores.main.get_referenced_state_groups( current_search ) referenced_groups |= referenced @@ -102,7 +100,7 @@ class PurgeEventsStorage(object): # groups that are referenced. current_search -= referenced - edges = yield self.stores.state.get_previous_state_groups(current_search) + edges = await self.stores.state.get_previous_state_groups(current_search) prevs = set(edges.values()) # We don't bother re-handling groups we've already seen diff --git a/synapse/storage/state.py b/synapse/storage/state.py index dc568476f4..49ee9c9a74 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -14,13 +14,12 @@ # limitations under the License. import logging -from typing import Iterable, List, TypeVar +from typing import Dict, Iterable, List, Optional, Set, Tuple, TypeVar import attr -from twisted.internet import defer - from synapse.api.constants import EventTypes +from synapse.events import EventBase from synapse.types import StateMap logger = logging.getLogger(__name__) @@ -34,16 +33,16 @@ class StateFilter(object): """A filter used when querying for state. Attributes: - types (dict[str, set[str]|None]): Map from type to set of state keys (or - None). This specifies which state_keys for the given type to fetch - from the DB. If None then all events with that type are fetched. If - the set is empty then no events with that type are fetched. 
- include_others (bool): Whether to fetch events with types that do not + types: Map from type to set of state keys (or None). This specifies + which state_keys for the given type to fetch from the DB. If None + then all events with that type are fetched. If the set is empty + then no events with that type are fetched. + include_others: Whether to fetch events with types that do not appear in `types`. """ - types = attr.ib() - include_others = attr.ib(default=False) + types = attr.ib(type=Dict[str, Optional[Set[str]]]) + include_others = attr.ib(default=False, type=bool) def __attrs_post_init__(self): # If `include_others` is set we canonicalise the filter by removing @@ -52,36 +51,35 @@ class StateFilter(object): self.types = {k: v for k, v in self.types.items() if v is not None} @staticmethod - def all(): + def all() -> "StateFilter": """Creates a filter that fetches everything. Returns: - StateFilter + The new state filter. """ return StateFilter(types={}, include_others=True) @staticmethod - def none(): + def none() -> "StateFilter": """Creates a filter that fetches nothing. Returns: - StateFilter + The new state filter. """ return StateFilter(types={}, include_others=False) @staticmethod - def from_types(types): + def from_types(types: Iterable[Tuple[str, Optional[str]]]) -> "StateFilter": """Creates a filter that only fetches the given types Args: - types (Iterable[tuple[str, str|None]]): A list of type and state - keys to fetch. A state_key of None fetches everything for - that type + types: A list of type and state keys to fetch. A state_key of None + fetches everything for that type Returns: - StateFilter + The new state filter. """ - type_dict = {} + type_dict = {} # type: Dict[str, Optional[Set[str]]] for typ, s in types: if typ in type_dict: if type_dict[typ] is None: @@ -91,24 +89,24 @@ class StateFilter(object): type_dict[typ] = None continue - type_dict.setdefault(typ, set()).add(s) + type_dict.setdefault(typ, set()).add(s) # type: ignore return StateFilter(types=type_dict) @staticmethod - def from_lazy_load_member_list(members): + def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter": """Creates a filter that returns all non-member events, plus the member events for the given users Args: - members (iterable[str]): Set of user IDs + members: Set of user IDs Returns: - StateFilter + The new state filter """ return StateFilter(types={EventTypes.Member: set(members)}, include_others=True) - def return_expanded(self): + def return_expanded(self) -> "StateFilter": """Creates a new StateFilter where type wild cards have been removed (except for memberships). The returned filter is a superset of the current one, i.e. anything that passes the current filter will pass @@ -130,7 +128,7 @@ class StateFilter(object): return all non-member events Returns: - StateFilter + The new state filter. """ if self.is_full(): @@ -167,7 +165,7 @@ class StateFilter(object): include_others=True, ) - def make_sql_filter_clause(self): + def make_sql_filter_clause(self) -> Tuple[str, List[str]]: """Converts the filter to an SQL clause. For example: @@ -179,13 +177,12 @@ class StateFilter(object): Returns: - tuple[str, list]: The SQL string (may be empty) and arguments. An - empty SQL string is returned when the filter matches everything - (i.e. is "full"). + The SQL string (may be empty) and arguments. An empty SQL string is + returned when the filter matches everything (i.e. is "full"). 
""" where_clause = "" - where_args = [] + where_args = [] # type: List[str] if self.is_full(): return where_clause, where_args @@ -221,7 +218,7 @@ class StateFilter(object): return where_clause, where_args - def max_entries_returned(self): + def max_entries_returned(self) -> Optional[int]: """Returns the maximum number of entries this filter will return if known, otherwise returns None. @@ -260,33 +257,33 @@ class StateFilter(object): return filtered_state - def is_full(self): + def is_full(self) -> bool: """Whether this filter fetches everything or not Returns: - bool + True if the filter fetches everything. """ return self.include_others and not self.types - def has_wildcards(self): + def has_wildcards(self) -> bool: """Whether the filter includes wildcards or is attempting to fetch specific state. Returns: - bool + True if the filter includes wildcards. """ return self.include_others or any( state_keys is None for state_keys in self.types.values() ) - def concrete_types(self): + def concrete_types(self) -> List[Tuple[str, str]]: """Returns a list of concrete type/state_keys (i.e. not None) that will be fetched. This will be a complete list if `has_wildcards` returns False, but otherwise will be a subset (or even empty). Returns: - list[tuple[str,str]] + A list of type/state_keys tuples. """ return [ (t, s) @@ -295,7 +292,7 @@ class StateFilter(object): for s in state_keys ] - def get_member_split(self): + def get_member_split(self) -> Tuple["StateFilter", "StateFilter"]: """Return the filter split into two: one which assumes it's exclusively matching against member state, and one which assumes it's matching against non member state. @@ -307,7 +304,7 @@ class StateFilter(object): state caches). Returns: - tuple[StateFilter, StateFilter]: The member and non member filters + The member and non member filters """ if EventTypes.Member in self.types: @@ -340,6 +337,9 @@ class StateGroupStorage(object): """Given a state group try to return a previous group and a delta between the old and the new. + Args: + state_group: The state group used to retrieve state deltas. 
+ Returns: Deferred[Tuple[Optional[int], Optional[StateMap[str]]]]: (prev_group, delta_ids) @@ -347,55 +347,59 @@ class StateGroupStorage(object): return self.stores.state.get_state_group_delta(state_group) - @defer.inlineCallbacks - def get_state_groups_ids(self, _room_id, event_ids): + async def get_state_groups_ids( + self, _room_id: str, event_ids: Iterable[str] + ) -> Dict[int, StateMap[str]]: """Get the event IDs of all the state for the state groups for the given events Args: - _room_id (str): id of the room for these events - event_ids (iterable[str]): ids of the events + _room_id: id of the room for these events + event_ids: ids of the events Returns: - Deferred[dict[int, StateMap[str]]]: - dict of state_group_id -> (dict of (type, state_key) -> event id) + dict of state_group_id -> (dict of (type, state_key) -> event id) """ if not event_ids: return {} - event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) + event_to_groups = await self.stores.main._get_state_group_for_events(event_ids) groups = set(event_to_groups.values()) - group_to_state = yield self.stores.state._get_state_for_groups(groups) + group_to_state = await self.stores.state._get_state_for_groups(groups) return group_to_state - @defer.inlineCallbacks - def get_state_ids_for_group(self, state_group): + async def get_state_ids_for_group(self, state_group: int) -> StateMap[str]: """Get the event IDs of all the state in the given state group Args: - state_group (int) + state_group: A state group for which we want to get the state IDs. Returns: - Deferred[dict]: Resolves to a map of (type, state_key) -> event_id + Resolves to a map of (type, state_key) -> event_id """ - group_to_state = yield self._get_state_for_groups((state_group,)) + group_to_state = await self._get_state_for_groups((state_group,)) return group_to_state[state_group] - @defer.inlineCallbacks - def get_state_groups(self, room_id, event_ids): + async def get_state_groups( + self, room_id: str, event_ids: Iterable[str] + ) -> Dict[int, List[EventBase]]: """ Get the state groups for the given list of event_ids + + Args: + room_id: ID of the room for these events. + event_ids: The event IDs to retrieve state for. + Returns: - Deferred[dict[int, list[EventBase]]]: - dict of state_group_id -> list of state events. + dict of state_group_id -> list of state events. """ if not event_ids: return {} - group_to_ids = yield self.get_state_groups_ids(room_id, event_ids) + group_to_ids = await self.get_state_groups_ids(room_id, event_ids) - state_event_map = yield self.stores.main.get_events( + state_event_map = await self.stores.main.get_events( [ ev_id for group_ids in group_to_ids.values() @@ -423,31 +427,34 @@ class StateGroupStorage(object): groups: list of state group IDs to query state_filter: The state filter used to fetch state from the database. + Returns: Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map. """ return self.stores.state._get_state_groups_from_groups(groups, state_filter) - @defer.inlineCallbacks - def get_state_for_events(self, event_ids, state_filter=StateFilter.all()): + async def get_state_for_events( + self, event_ids: List[str], state_filter: StateFilter = StateFilter.all() + ): """Given a list of event_ids and type tuples, return a list of state dicts for each event. + Args: - event_ids (list[string]) - state_filter (StateFilter): The state filter used to fetch state - from the database. + event_ids: The events to fetch the state of. + state_filter: The state filter used to fetch state. 
+ Returns: - deferred: A dict of (event_id) -> (type, state_key) -> [state_events] + A dict of (event_id) -> (type, state_key) -> [state_events] """ - event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) + event_to_groups = await self.stores.main._get_state_group_for_events(event_ids) groups = set(event_to_groups.values()) - group_to_state = yield self.stores.state._get_state_for_groups( + group_to_state = await self.stores.state._get_state_for_groups( groups, state_filter ) - state_event_map = yield self.stores.main.get_events( + state_event_map = await self.stores.main.get_events( [ev_id for sd in group_to_state.values() for ev_id in sd.values()], get_prev_content=False, ) @@ -463,24 +470,24 @@ class StateGroupStorage(object): return {event: event_to_state[event] for event in event_ids} - @defer.inlineCallbacks - def get_state_ids_for_events(self, event_ids, state_filter=StateFilter.all()): + async def get_state_ids_for_events( + self, event_ids: List[str], state_filter: StateFilter = StateFilter.all() + ): """ Get the state dicts corresponding to a list of events, containing the event_ids of the state events (as opposed to the events themselves) Args: - event_ids(list(str)): events whose state should be returned - state_filter (StateFilter): The state filter used to fetch state - from the database. + event_ids: events whose state should be returned + state_filter: The state filter used to fetch state from the database. Returns: - A deferred dict from event_id -> (type, state_key) -> event_id + A dict from event_id -> (type, state_key) -> event_id """ - event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) + event_to_groups = await self.stores.main._get_state_group_for_events(event_ids) groups = set(event_to_groups.values()) - group_to_state = yield self.stores.state._get_state_for_groups( + group_to_state = await self.stores.state._get_state_for_groups( groups, state_filter ) @@ -491,36 +498,36 @@ class StateGroupStorage(object): return {event: event_to_state[event] for event in event_ids} - @defer.inlineCallbacks - def get_state_for_event(self, event_id, state_filter=StateFilter.all()): + async def get_state_for_event( + self, event_id: str, state_filter: StateFilter = StateFilter.all() + ): """ Get the state dict corresponding to a particular event Args: - event_id(str): event whose state should be returned - state_filter (StateFilter): The state filter used to fetch state - from the database. + event_id: event whose state should be returned + state_filter: The state filter used to fetch state from the database. Returns: - A deferred dict from (type, state_key) -> state_event + A dict from (type, state_key) -> state_event """ - state_map = yield self.get_state_for_events([event_id], state_filter) + state_map = await self.get_state_for_events([event_id], state_filter) return state_map[event_id] - @defer.inlineCallbacks - def get_state_ids_for_event(self, event_id, state_filter=StateFilter.all()): + async def get_state_ids_for_event( + self, event_id: str, state_filter: StateFilter = StateFilter.all() + ): """ Get the state dict corresponding to a particular event Args: - event_id(str): event whose state should be returned - state_filter (StateFilter): The state filter used to fetch state - from the database. + event_id: event whose state should be returned + state_filter: The state filter used to fetch state from the database. 
Returns: A deferred dict from (type, state_key) -> state_event """ - state_map = yield self.get_state_ids_for_events([event_id], state_filter) + state_map = await self.get_state_ids_for_events([event_id], state_filter) return state_map[event_id] def _get_state_for_groups( @@ -530,9 +537,8 @@ class StateGroupStorage(object): filtering by type/state_key Args: - groups (iterable[int]): list of state groups for which we want - to get the state. - state_filter (StateFilter): The state filter used to fetch state + groups: list of state groups for which we want to get the state. + state_filter: The state filter used to fetch state. from the database. Returns: Deferred[dict[int, StateMap[str]]]: Dict of state group to state map. @@ -540,18 +546,23 @@ class StateGroupStorage(object): return self.stores.state._get_state_for_groups(groups, state_filter) def store_state_group( - self, event_id, room_id, prev_group, delta_ids, current_state_ids + self, + event_id: str, + room_id: str, + prev_group: Optional[int], + delta_ids: Optional[dict], + current_state_ids: dict, ): """Store a new set of state, returning a newly assigned state group. Args: - event_id (str): The event ID for which the state was calculated - room_id (str) - prev_group (int|None): A previous state group for the room, optional. - delta_ids (dict|None): The delta between state at `prev_group` and + event_id: The event ID for which the state was calculated. + room_id: ID of the room for which the state was calculated. + prev_group: A previous state group for the room, optional. + delta_ids: The delta between state at `prev_group` and `current_state_ids`, if `prev_group` was given. Same format as `current_state_ids`. - current_state_ids (dict): The state to store. Map of (type, state_key) + current_state_ids: The state to store. Map of (type, state_key) to event_id. Returns: diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py index b9fafaa1a6..a6012c973d 100644 --- a/tests/storage/test_purge.py +++ b/tests/storage/test_purge.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from twisted.internet import defer + from synapse.rest.client.v1 import room from tests.unittest import HomeserverTestCase @@ -49,7 +51,9 @@ class PurgeTests(HomeserverTestCase): event = self.successResultOf(event) # Purge everything before this topological token - purge = storage.purge_events.purge_history(self.room_id, event, True) + purge = defer.ensureDeferred( + storage.purge_events.purge_history(self.room_id, event, True) + ) self.pump() self.assertEqual(self.successResultOf(purge), None) @@ -88,7 +92,7 @@ class PurgeTests(HomeserverTestCase): ) # Purge everything before this topological token - purge = storage.purge_history(self.room_id, event, True) + purge = defer.ensureDeferred(storage.purge_history(self.room_id, event, True)) self.pump() f = self.failureResultOf(purge) self.assertIn("greater than forward", f.value.args[0]) diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index 1d77b4a2d6..a5f250d477 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -97,8 +97,10 @@ class RoomEventsStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def inject_room_event(self, **kwargs): - yield self.storage.persistence.persist_event( - self.event_factory.create_event(room_id=self.room.to_string(), **kwargs) + yield defer.ensureDeferred( + self.storage.persistence.persist_event( + self.event_factory.create_event(room_id=self.room.to_string(), **kwargs) + ) ) @defer.inlineCallbacks diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index a0e133cd4a..6a48b9d3b3 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -68,7 +68,9 @@ class StateStoreTestCase(tests.unittest.TestCase): self.event_creation_handler.create_new_client_event(builder) ) - yield self.storage.persistence.persist_event(event, context) + yield defer.ensureDeferred( + self.storage.persistence.persist_event(event, context) + ) return event @@ -87,8 +89,8 @@ class StateStoreTestCase(tests.unittest.TestCase): self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} ) - state_group_map = yield self.storage.state.get_state_groups_ids( - self.room, [e2.event_id] + state_group_map = yield defer.ensureDeferred( + self.storage.state.get_state_groups_ids(self.room, [e2.event_id]) ) self.assertEqual(len(state_group_map), 1) state_map = list(state_group_map.values())[0] @@ -106,8 +108,8 @@ class StateStoreTestCase(tests.unittest.TestCase): self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} ) - state_group_map = yield self.storage.state.get_state_groups( - self.room, [e2.event_id] + state_group_map = yield defer.ensureDeferred( + self.storage.state.get_state_groups(self.room, [e2.event_id]) ) self.assertEqual(len(state_group_map), 1) state_list = list(state_group_map.values())[0] @@ -148,7 +150,9 @@ class StateStoreTestCase(tests.unittest.TestCase): ) # check we get the full state as of the final event - state = yield self.storage.state.get_state_for_event(e5.event_id) + state = yield defer.ensureDeferred( + self.storage.state.get_state_for_event(e5.event_id) + ) self.assertIsNotNone(e4) @@ -164,22 +168,28 @@ class StateStoreTestCase(tests.unittest.TestCase): ) # check we can filter to the m.room.name event (with a '' state key) - state = yield self.storage.state.get_state_for_event( - e5.event_id, StateFilter.from_types([(EventTypes.Name, "")]) + state = yield defer.ensureDeferred( + self.storage.state.get_state_for_event( + e5.event_id, StateFilter.from_types([(EventTypes.Name, "")]) + ) ) 
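The test updates in this commit all follow a single pattern: an @defer.inlineCallbacks generator can only yield Deferreds, so coroutines returned by the newly async storage methods are wrapped in defer.ensureDeferred first. Shown in isolation (illustrative; some_async_query is a stand-in name, not a real method):

```python
from twisted.internet import defer

@defer.inlineCallbacks
def test_something(self):
    # ensureDeferred turns the coroutine into a Deferred that the
    # inlineCallbacks machinery knows how to drive.
    result = yield defer.ensureDeferred(self.storage.state.some_async_query())
    self.assertIsNotNone(result)
```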
self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) # check we can filter to the m.room.name event (with a wildcard None state key) - state = yield self.storage.state.get_state_for_event( - e5.event_id, StateFilter.from_types([(EventTypes.Name, None)]) + state = yield defer.ensureDeferred( + self.storage.state.get_state_for_event( + e5.event_id, StateFilter.from_types([(EventTypes.Name, None)]) + ) ) self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) # check we can grab the m.room.member events (with a wildcard None state key) - state = yield self.storage.state.get_state_for_event( - e5.event_id, StateFilter.from_types([(EventTypes.Member, None)]) + state = yield defer.ensureDeferred( + self.storage.state.get_state_for_event( + e5.event_id, StateFilter.from_types([(EventTypes.Member, None)]) + ) ) self.assertStateMapEqual( @@ -188,12 +198,14 @@ class StateStoreTestCase(tests.unittest.TestCase): # check we can grab a specific room member without filtering out the # other event types - state = yield self.storage.state.get_state_for_event( - e5.event_id, - state_filter=StateFilter( - types={EventTypes.Member: {self.u_alice.to_string()}}, - include_others=True, - ), + state = yield defer.ensureDeferred( + self.storage.state.get_state_for_event( + e5.event_id, + state_filter=StateFilter( + types={EventTypes.Member: {self.u_alice.to_string()}}, + include_others=True, + ), + ) ) self.assertStateMapEqual( @@ -206,11 +218,13 @@ class StateStoreTestCase(tests.unittest.TestCase): ) # check that we can grab everything except members - state = yield self.storage.state.get_state_for_event( - e5.event_id, - state_filter=StateFilter( - types={EventTypes.Member: set()}, include_others=True - ), + state = yield defer.ensureDeferred( + self.storage.state.get_state_for_event( + e5.event_id, + state_filter=StateFilter( + types={EventTypes.Member: set()}, include_others=True + ), + ) ) self.assertStateMapEqual( @@ -222,8 +236,8 @@ class StateStoreTestCase(tests.unittest.TestCase): ####################################################### room_id = self.room.to_string() - group_ids = yield self.storage.state.get_state_groups_ids( - room_id, [e5.event_id] + group_ids = yield defer.ensureDeferred( + self.storage.state.get_state_groups_ids(room_id, [e5.event_id]) ) group = list(group_ids.keys())[0] diff --git a/tests/test_visibility.py b/tests/test_visibility.py index a7a36174ea..531a9b9118 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -40,7 +40,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): self.store = self.hs.get_datastore() self.storage = self.hs.get_storage() - yield create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM") + yield defer.ensureDeferred(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM")) @defer.inlineCallbacks def test_filtering(self): @@ -140,7 +140,9 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): event, context = yield defer.ensureDeferred( self.event_creation_handler.create_new_client_event(builder) ) - yield self.storage.persistence.persist_event(event, context) + yield defer.ensureDeferred( + self.storage.persistence.persist_event(event, context) + ) return event @defer.inlineCallbacks @@ -162,7 +164,9 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): self.event_creation_handler.create_new_client_event(builder) ) - yield self.storage.persistence.persist_event(event, context) + yield defer.ensureDeferred( + self.storage.persistence.persist_event(event, context) + ) return event 
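Outside the tests, callers of the converted state getters simply await them. A hedged sketch of such a caller (assumed, not from the patch):

```python
from synapse.api.constants import EventTypes
from synapse.storage.state import StateFilter

async def room_name_at(state_storage, event_id: str):
    # Only the m.room.name event is needed, so pass a narrow filter rather
    # than fetching the full state of the event.
    state = await state_storage.get_state_for_event(
        event_id, state_filter=StateFilter.from_types([(EventTypes.Name, "")])
    )
    name_event = state.get((EventTypes.Name, ""))
    return name_event.content.get("name") if name_event else None
```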
@defer.inlineCallbacks @@ -183,7 +187,9 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): self.event_creation_handler.create_new_client_event(builder) ) - yield self.storage.persistence.persist_event(event, context) + yield defer.ensureDeferred( + self.storage.persistence.persist_event(event, context) + ) return event @defer.inlineCallbacks diff --git a/tests/utils.py b/tests/utils.py index ac643679aa..b33b6860d4 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -638,14 +638,8 @@ class DeferredMockCallable(object): ) -@defer.inlineCallbacks -def create_room(hs, room_id, creator_id): +async def create_room(hs, room_id: str, creator_id: str): """Creates and persist a creation event for the given room - - Args: - hs - room_id (str) - creator_id (str) """ persistence_store = hs.get_storage().persistence @@ -653,7 +647,7 @@ def create_room(hs, room_id, creator_id): event_builder_factory = hs.get_event_builder_factory() event_creation_handler = hs.get_event_creation_handler() - yield store.store_room( + await store.store_room( room_id=room_id, room_creator_user_id=creator_id, is_public=False, @@ -671,8 +665,6 @@ def create_room(hs, room_id, creator_id): }, ) - event, context = yield defer.ensureDeferred( - event_creation_handler.create_new_client_event(builder) - ) + event, context = await event_creation_handler.create_new_client_event(builder) - yield persistence_store.persist_event(event, context) + await persistence_store.persist_event(event, context) diff --git a/tox.ini b/tox.ini index 595ab3ba66..a394f6eadc 100644 --- a/tox.ini +++ b/tox.ini @@ -206,6 +206,7 @@ commands = mypy \ synapse/storage/data_stores/main/ui_auth.py \ synapse/storage/database.py \ synapse/storage/engines \ + synapse/storage/state.py \ synapse/storage/util \ synapse/streams \ synapse/util/caches/stream_change_cache.py \ -- cgit 1.5.1 From b3a97d6dac7f9f619b02e213bb8a745d65983d0d Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 30 Jul 2020 07:20:41 -0400 Subject: Convert some of the data store to async. (#7976) --- changelog.d/7976.misc | 1 + .../storage/data_stores/main/event_push_actions.py | 92 ++++++++++---------- synapse/storage/data_stores/main/room.py | 98 ++++++++++------------ synapse/storage/data_stores/main/state.py | 57 ++++++------- synapse/storage/data_stores/main/stats.py | 53 ++++++------ synapse/storage/data_stores/state/store.py | 37 ++++---- synapse/storage/state.py | 11 +-- tests/storage/test_event_push_actions.py | 12 ++- tests/storage/test_room.py | 24 +++--- tests/storage/test_state.py | 12 +-- 10 files changed, 190 insertions(+), 207 deletions(-) create mode 100644 changelog.d/7976.misc (limited to 'tests/storage/test_room.py') diff --git a/changelog.d/7976.misc b/changelog.d/7976.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/7976.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/storage/data_stores/main/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py index 18297cf3b8..ad82838901 100644 --- a/synapse/storage/data_stores/main/event_push_actions.py +++ b/synapse/storage/data_stores/main/event_push_actions.py @@ -15,11 +15,10 @@ # limitations under the License. 
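The event_push_actions conversion that follows keeps the method signatures intact, so pusher-side callers only swap yield for await. A sketch of such a caller (illustrative; the surrounding pusher state is assumed):

```python
async def next_http_push_batch(store, user_id: str, last_seen: int, max_stream: int):
    # Per the docstring below: up to 20 dicts with "event_id", "room_id",
    # "stream_ordering" and "actions", ordered by ascending stream_ordering.
    return await store.get_unread_push_actions_for_user_in_range_for_http(
        user_id, last_seen, max_stream, limit=20
    )
```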
import logging +from typing import List from canonicaljson import json -from twisted.internet import defer - from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json from synapse.storage.database import Database @@ -166,8 +165,9 @@ class EventPushActionsWorkerStore(SQLBaseStore): return {"notify_count": notify_count, "highlight_count": highlight_count} - @defer.inlineCallbacks - def get_push_action_users_in_range(self, min_stream_ordering, max_stream_ordering): + async def get_push_action_users_in_range( + self, min_stream_ordering, max_stream_ordering + ): def f(txn): sql = ( "SELECT DISTINCT(user_id) FROM event_push_actions WHERE" @@ -176,26 +176,28 @@ class EventPushActionsWorkerStore(SQLBaseStore): txn.execute(sql, (min_stream_ordering, max_stream_ordering)) return [r[0] for r in txn] - ret = yield self.db.runInteraction("get_push_action_users_in_range", f) + ret = await self.db.runInteraction("get_push_action_users_in_range", f) return ret - @defer.inlineCallbacks - def get_unread_push_actions_for_user_in_range_for_http( - self, user_id, min_stream_ordering, max_stream_ordering, limit=20 - ): + async def get_unread_push_actions_for_user_in_range_for_http( + self, + user_id: str, + min_stream_ordering: int, + max_stream_ordering: int, + limit: int = 20, + ) -> List[dict]: """Get a list of the most recent unread push actions for a given user, within the given stream ordering range. Called by the httppusher. Args: - user_id (str): The user to fetch push actions for. - min_stream_ordering(int): The exclusive lower bound on the + user_id: The user to fetch push actions for. + min_stream_ordering: The exclusive lower bound on the stream ordering of event push actions to fetch. - max_stream_ordering(int): The inclusive upper bound on the + max_stream_ordering: The inclusive upper bound on the stream ordering of event push actions to fetch. - limit (int): The maximum number of rows to return. + limit: The maximum number of rows to return. Returns: - A promise which resolves to a list of dicts with the keys "event_id", - "room_id", "stream_ordering", "actions". + A list of dicts with the keys "event_id", "room_id", "stream_ordering", "actions". The list will be ordered by ascending stream_ordering. The list will have between 0~limit entries. """ @@ -228,7 +230,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): txn.execute(sql, args) return txn.fetchall() - after_read_receipt = yield self.db.runInteraction( + after_read_receipt = await self.db.runInteraction( "get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt ) @@ -256,7 +258,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): txn.execute(sql, args) return txn.fetchall() - no_read_receipt = yield self.db.runInteraction( + no_read_receipt = await self.db.runInteraction( "get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt ) @@ -280,23 +282,25 @@ class EventPushActionsWorkerStore(SQLBaseStore): # one of the subqueries may have hit the limit. 
return notifs[:limit] - @defer.inlineCallbacks - def get_unread_push_actions_for_user_in_range_for_email( - self, user_id, min_stream_ordering, max_stream_ordering, limit=20 - ): + async def get_unread_push_actions_for_user_in_range_for_email( + self, + user_id: str, + min_stream_ordering: int, + max_stream_ordering: int, + limit: int = 20, + ) -> List[dict]: """Get a list of the most recent unread push actions for a given user, within the given stream ordering range. Called by the emailpusher Args: - user_id (str): The user to fetch push actions for. - min_stream_ordering(int): The exclusive lower bound on the + user_id: The user to fetch push actions for. + min_stream_ordering: The exclusive lower bound on the stream ordering of event push actions to fetch. - max_stream_ordering(int): The inclusive upper bound on the + max_stream_ordering: The inclusive upper bound on the stream ordering of event push actions to fetch. - limit (int): The maximum number of rows to return. + limit: The maximum number of rows to return. Returns: - A promise which resolves to a list of dicts with the keys "event_id", - "room_id", "stream_ordering", "actions", "received_ts". + A list of dicts with the keys "event_id", "room_id", "stream_ordering", "actions", "received_ts". The list will be ordered by descending received_ts. The list will have between 0~limit entries. """ @@ -328,7 +332,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): txn.execute(sql, args) return txn.fetchall() - after_read_receipt = yield self.db.runInteraction( + after_read_receipt = await self.db.runInteraction( "get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt ) @@ -356,7 +360,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): txn.execute(sql, args) return txn.fetchall() - no_read_receipt = yield self.db.runInteraction( + no_read_receipt = await self.db.runInteraction( "get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt ) @@ -461,17 +465,13 @@ class EventPushActionsWorkerStore(SQLBaseStore): "add_push_actions_to_staging", _add_push_actions_to_staging_txn ) - @defer.inlineCallbacks - def remove_push_actions_from_staging(self, event_id): + async def remove_push_actions_from_staging(self, event_id: str) -> None: """Called if we failed to persist the event to ensure that stale push actions don't build up in the DB - - Args: - event_id (str) """ try: - res = yield self.db.simple_delete( + res = await self.db.simple_delete( table="event_push_actions_staging", keyvalues={"event_id": event_id}, desc="remove_push_actions_from_staging", @@ -606,8 +606,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): return range_end - @defer.inlineCallbacks - def get_time_of_last_push_action_before(self, stream_ordering): + async def get_time_of_last_push_action_before(self, stream_ordering): def f(txn): sql = ( "SELECT e.received_ts" @@ -620,7 +619,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): txn.execute(sql, (stream_ordering,)) return txn.fetchone() - result = yield self.db.runInteraction("get_time_of_last_push_action_before", f) + result = await self.db.runInteraction("get_time_of_last_push_action_before", f) return result[0] if result else None @@ -650,8 +649,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): self._start_rotate_notifs, 30 * 60 * 1000 ) - @defer.inlineCallbacks - def get_push_actions_for_user( + async def get_push_actions_for_user( self, user_id, before=None, limit=50, only_highlight=False ): def f(txn): @@ -682,18 +680,17 @@ class 
EventPushActionsStore(EventPushActionsWorkerStore): txn.execute(sql, args) return self.db.cursor_to_dict(txn) - push_actions = yield self.db.runInteraction("get_push_actions_for_user", f) + push_actions = await self.db.runInteraction("get_push_actions_for_user", f) for pa in push_actions: pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"]) return push_actions - @defer.inlineCallbacks - def get_latest_push_action_stream_ordering(self): + async def get_latest_push_action_stream_ordering(self): def f(txn): txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions") return txn.fetchone() - result = yield self.db.runInteraction( + result = await self.db.runInteraction( "get_latest_push_action_stream_ordering", f ) return result[0] or 0 @@ -747,8 +744,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): def _start_rotate_notifs(self): return run_as_background_process("rotate_notifs", self._rotate_notifs) - @defer.inlineCallbacks - def _rotate_notifs(self): + async def _rotate_notifs(self): if self._doing_notif_rotation or self.stream_ordering_day_ago is None: return self._doing_notif_rotation = True @@ -757,12 +753,12 @@ class EventPushActionsStore(EventPushActionsWorkerStore): while True: logger.info("Rotating notifications") - caught_up = yield self.db.runInteraction( + caught_up = await self.db.runInteraction( "_rotate_notifs", self._rotate_notifs_txn ) if caught_up: break - yield self.hs.get_clock().sleep(self._rotate_delay) + await self.hs.get_clock().sleep(self._rotate_delay) finally: self._doing_notif_rotation = False diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py index d2e1e36e7f..ab48052cdc 100644 --- a/synapse/storage/data_stores/main/room.py +++ b/synapse/storage/data_stores/main/room.py @@ -23,8 +23,6 @@ from typing import Any, Dict, List, Optional, Tuple from canonicaljson import json -from twisted.internet import defer - from synapse.api.constants import EventTypes from synapse.api.errors import StoreError from synapse.api.room_versions import RoomVersion, RoomVersions @@ -32,7 +30,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.data_stores.main.search import SearchStore from synapse.storage.database import Database, LoggingTransaction from synapse.types import ThirdPartyInstanceID -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks +from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) @@ -192,8 +190,7 @@ class RoomWorkerStore(SQLBaseStore): return self.db.runInteraction("count_public_rooms", _count_public_rooms_txn) - @defer.inlineCallbacks - def get_largest_public_rooms( + async def get_largest_public_rooms( self, network_tuple: Optional[ThirdPartyInstanceID], search_filter: Optional[dict], @@ -330,10 +327,10 @@ class RoomWorkerStore(SQLBaseStore): return results - ret_val = yield self.db.runInteraction( + ret_val = await self.db.runInteraction( "get_largest_public_rooms", _get_largest_public_rooms_txn ) - defer.returnValue(ret_val) + return ret_val @cached(max_entries=10000) def is_room_blocked(self, room_id): @@ -509,8 +506,8 @@ class RoomWorkerStore(SQLBaseStore): "get_rooms_paginate", _get_rooms_paginate_txn, ) - @cachedInlineCallbacks(max_entries=10000) - def get_ratelimit_for_user(self, user_id): + @cached(max_entries=10000) + async def get_ratelimit_for_user(self, user_id): """Check if there are any overrides for ratelimiting for the given user @@ -522,7 +519,7 @@ class 
RoomWorkerStore(SQLBaseStore): of RatelimitOverride are None or 0 then ratelimitng has been disabled for that user entirely. """ - row = yield self.db.simple_select_one( + row = await self.db.simple_select_one( table="ratelimit_override", keyvalues={"user_id": user_id}, retcols=("messages_per_second", "burst_count"), @@ -538,8 +535,8 @@ class RoomWorkerStore(SQLBaseStore): else: return None - @cachedInlineCallbacks() - def get_retention_policy_for_room(self, room_id): + @cached() + async def get_retention_policy_for_room(self, room_id): """Get the retention policy for a given room. If no retention policy has been found for this room, returns a policy defined @@ -566,19 +563,17 @@ class RoomWorkerStore(SQLBaseStore): return self.db.cursor_to_dict(txn) - ret = yield self.db.runInteraction( + ret = await self.db.runInteraction( "get_retention_policy_for_room", get_retention_policy_for_room_txn, ) # If we don't know this room ID, ret will be None, in this case return the default # policy. if not ret: - defer.returnValue( - { - "min_lifetime": self.config.retention_default_min_lifetime, - "max_lifetime": self.config.retention_default_max_lifetime, - } - ) + return { + "min_lifetime": self.config.retention_default_min_lifetime, + "max_lifetime": self.config.retention_default_max_lifetime, + } row = ret[0] @@ -592,7 +587,7 @@ class RoomWorkerStore(SQLBaseStore): if row["max_lifetime"] is None: row["max_lifetime"] = self.config.retention_default_max_lifetime - defer.returnValue(row) + return row def get_media_mxcs_in_room(self, room_id): """Retrieves all the local and remote media MXC URIs in a given room @@ -881,8 +876,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore): self._background_add_rooms_room_version_column, ) - @defer.inlineCallbacks - def _background_insert_retention(self, progress, batch_size): + async def _background_insert_retention(self, progress, batch_size): """Retrieves a list of all rooms within a range and inserts an entry for each of them into the room_retention table. 
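Accessors decorated with @cached, such as get_retention_policy_for_room above, are still consumed with await after the change. An assumed caller, for illustration:

```python
async def effective_retention(store, room_id: str):
    # Falls back to the configured defaults when the room has no policy row,
    # as described in get_retention_policy_for_room.
    policy = await store.get_retention_policy_for_room(room_id)
    return policy["min_lifetime"], policy["max_lifetime"]
```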
NULLs the property's columns if missing from the retention event in the room's @@ -940,14 +934,14 @@ class RoomBackgroundUpdateStore(SQLBaseStore): else: return False - end = yield self.db.runInteraction( + end = await self.db.runInteraction( "insert_room_retention", _background_insert_retention_txn, ) if end: - yield self.db.updates._end_background_update("insert_room_retention") + await self.db.updates._end_background_update("insert_room_retention") - defer.returnValue(batch_size) + return batch_size async def _background_add_rooms_room_version_column( self, progress: dict, batch_size: int @@ -1096,8 +1090,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): lock=False, ) - @defer.inlineCallbacks - def store_room( + async def store_room( self, room_id: str, room_creator_user_id: str, @@ -1140,7 +1133,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): ) with self._public_room_id_gen.get_next() as next_id: - yield self.db.runInteraction("store_room_txn", store_room_txn, next_id) + await self.db.runInteraction("store_room_txn", store_room_txn, next_id) except Exception as e: logger.error("store_room with room_id=%s failed: %s", room_id, e) raise StoreError(500, "Problem creating room.") @@ -1165,8 +1158,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): lock=False, ) - @defer.inlineCallbacks - def set_room_is_public(self, room_id, is_public): + async def set_room_is_public(self, room_id, is_public): def set_room_is_public_txn(txn, next_id): self.db.simple_update_one_txn( txn, @@ -1206,13 +1198,12 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): ) with self._public_room_id_gen.get_next() as next_id: - yield self.db.runInteraction( + await self.db.runInteraction( "set_room_is_public", set_room_is_public_txn, next_id ) self.hs.get_notifier().on_new_replication_data() - @defer.inlineCallbacks - def set_room_is_public_appservice( + async def set_room_is_public_appservice( self, room_id, appservice_id, network_id, is_public ): """Edit the appservice/network specific public room list. @@ -1287,7 +1278,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): ) with self._public_room_id_gen.get_next() as next_id: - yield self.db.runInteraction( + await self.db.runInteraction( "set_room_is_public_appservice", set_room_is_public_appservice_txn, next_id, @@ -1327,52 +1318,47 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): def get_current_public_room_stream_id(self): return self._public_room_id_gen.get_current_token() - @defer.inlineCallbacks - def block_room(self, room_id, user_id): + async def block_room(self, room_id: str, user_id: str) -> None: """Marks the room as blocked. Can be called multiple times. 
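block_room, whose docstring this hunk rewrites, becomes a coroutine while still invalidating the cached is_room_blocked entry in the body that follows. An assumed admin-side caller might read:

```python
async def block_and_confirm(store, room_id: str, admin_user_id: str) -> bool:
    await store.block_room(room_id, admin_user_id)
    # is_room_blocked is the @cached accessor defined earlier in room.py.
    return bool(await store.is_room_blocked(room_id))
```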
Args: - room_id (str): Room to block - user_id (str): Who blocked it - - Returns: - Deferred + room_id: Room to block + user_id: Who blocked it """ - yield self.db.simple_upsert( + await self.db.simple_upsert( table="blocked_rooms", keyvalues={"room_id": room_id}, values={}, insertion_values={"user_id": user_id}, desc="block_room", ) - yield self.db.runInteraction( + await self.db.runInteraction( "block_room_invalidation", self._invalidate_cache_and_stream, self.is_room_blocked, (room_id,), ) - @defer.inlineCallbacks - def get_rooms_for_retention_period_in_range( - self, min_ms, max_ms, include_null=False - ): + async def get_rooms_for_retention_period_in_range( + self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False + ) -> Dict[str, dict]: """Retrieves all of the rooms within the given retention range. Optionally includes the rooms which don't have a retention policy. Args: - min_ms (int|None): Duration in milliseconds that define the lower limit of + min_ms: Duration in milliseconds that define the lower limit of the range to handle (exclusive). If None, doesn't set a lower limit. - max_ms (int|None): Duration in milliseconds that define the upper limit of + max_ms: Duration in milliseconds that define the upper limit of the range to handle (inclusive). If None, doesn't set an upper limit. - include_null (bool): Whether to include rooms which retention policy is NULL + include_null: Whether to include rooms which retention policy is NULL in the returned set. Returns: - dict[str, dict]: The rooms within this range, along with their retention - policy. The key is "room_id", and maps to a dict describing the retention - policy associated with this room ID. The keys for this nested dict are - "min_lifetime" (int|None), and "max_lifetime" (int|None). + The rooms within this range, along with their retention + policy. The key is "room_id", and maps to a dict describing the retention + policy associated with this room ID. The keys for this nested dict are + "min_lifetime" (int|None), and "max_lifetime" (int|None). 
""" def get_rooms_for_retention_period_in_range_txn(txn): @@ -1431,9 +1417,9 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): return rooms_dict - rooms = yield self.db.runInteraction( + rooms = await self.db.runInteraction( "get_rooms_for_retention_period_in_range", get_rooms_for_retention_period_in_range_txn, ) - defer.returnValue(rooms) + return rooms diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py index bb38a04ede..a360699408 100644 --- a/synapse/storage/data_stores/main/state.py +++ b/synapse/storage/data_stores/main/state.py @@ -16,12 +16,12 @@ import collections.abc import logging from collections import namedtuple - -from twisted.internet import defer +from typing import Iterable, Optional, Set from synapse.api.constants import EventTypes, Membership from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion +from synapse.events import EventBase from synapse.storage._base import SQLBaseStore from synapse.storage.data_stores.main.events_worker import EventsWorkerStore from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore @@ -108,28 +108,27 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): create_event = await self.get_create_event_for_room(room_id) return create_event.content.get("room_version", "1") - @defer.inlineCallbacks - def get_room_predecessor(self, room_id): + async def get_room_predecessor(self, room_id: str) -> Optional[dict]: """Get the predecessor of an upgraded room if it exists. Otherwise return None. Args: - room_id (str) + room_id: The room ID. Returns: - Deferred[dict|None]: A dictionary containing the structure of the predecessor - field from the room's create event. The structure is subject to other servers, - but it is expected to be: - * room_id (str): The room ID of the predecessor room - * event_id (str): The ID of the tombstone event in the predecessor room + A dictionary containing the structure of the predecessor + field from the room's create event. The structure is subject to other servers, + but it is expected to be: + * room_id (str): The room ID of the predecessor room + * event_id (str): The ID of the tombstone event in the predecessor room - None if a predecessor key is not found, or is not a dictionary. + None if a predecessor key is not found, or is not a dictionary. Raises: NotFoundError if the given room is unknown """ # Retrieve the room's create event - create_event = yield self.get_create_event_for_room(room_id) + create_event = await self.get_create_event_for_room(room_id) # Retrieve the predecessor key of the create event predecessor = create_event.content.get("predecessor", None) @@ -140,20 +139,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): return predecessor - @defer.inlineCallbacks - def get_create_event_for_room(self, room_id): + async def get_create_event_for_room(self, room_id: str) -> EventBase: """Get the create state event for a room. Args: - room_id (str) + room_id: The room ID. Returns: - Deferred[EventBase]: The room creation event. + The room creation event. 
Raises: NotFoundError if the room is unknown """ - state_ids = yield self.get_current_state_ids(room_id) + state_ids = await self.get_current_state_ids(room_id) create_id = state_ids.get((EventTypes.Create, "")) # If we can't find the create event, assume we've hit a dead end @@ -161,7 +159,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): raise NotFoundError("Unknown room %s" % (room_id,)) # Retrieve the room's create event and return - create_event = yield self.get_event(create_id) + create_event = await self.get_event(create_id) return create_event @cached(max_entries=100000, iterable=True) @@ -237,18 +235,17 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): "get_filtered_current_state_ids", _get_filtered_current_state_ids_txn ) - @defer.inlineCallbacks - def get_canonical_alias_for_room(self, room_id): + async def get_canonical_alias_for_room(self, room_id: str) -> Optional[str]: """Get canonical alias for room, if any Args: - room_id (str) + room_id: The room ID Returns: - Deferred[str|None]: The canonical alias, if any + The canonical alias, if any """ - state = yield self.get_filtered_current_state_ids( + state = await self.get_filtered_current_state_ids( room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")]) ) @@ -256,7 +253,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): if not event_id: return - event = yield self.get_event(event_id, allow_none=True) + event = await self.get_event(event_id, allow_none=True) if not event: return @@ -292,19 +289,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): return {row["event_id"]: row["state_group"] for row in rows} - @defer.inlineCallbacks - def get_referenced_state_groups(self, state_groups): + async def get_referenced_state_groups( + self, state_groups: Iterable[int] + ) -> Set[int]: """Check if the state groups are referenced by events. Args: - state_groups (Iterable[int]) + state_groups Returns: - Deferred[set[int]]: The subset of state groups that are - referenced. + The subset of state groups that are referenced. """ - rows = yield self.db.simple_select_many_batch( + rows = await self.db.simple_select_many_batch( table="event_to_state_groups", column="state_group", iterable=state_groups, diff --git a/synapse/storage/data_stores/main/stats.py b/synapse/storage/data_stores/main/stats.py index 380c1ec7da..922400a7c3 100644 --- a/synapse/storage/data_stores/main/stats.py +++ b/synapse/storage/data_stores/main/stats.py @@ -16,8 +16,8 @@ import logging from itertools import chain +from typing import Tuple -from twisted.internet import defer from twisted.internet.defer import DeferredLock from synapse.api.constants import EventTypes, Membership @@ -97,13 +97,12 @@ class StatsStore(StateDeltasStore): """ return (ts // self.stats_bucket_size) * self.stats_bucket_size - @defer.inlineCallbacks - def _populate_stats_process_users(self, progress, batch_size): + async def _populate_stats_process_users(self, progress, batch_size): """ This is a background update which regenerates statistics for users. 
""" if not self.stats_enabled: - yield self.db.updates._end_background_update("populate_stats_process_users") + await self.db.updates._end_background_update("populate_stats_process_users") return 1 last_user_id = progress.get("last_user_id", "") @@ -118,20 +117,20 @@ class StatsStore(StateDeltasStore): txn.execute(sql, (last_user_id, batch_size)) return [r for r, in txn] - users_to_work_on = yield self.db.runInteraction( + users_to_work_on = await self.db.runInteraction( "_populate_stats_process_users", _get_next_batch ) # No more rooms -- complete the transaction. if not users_to_work_on: - yield self.db.updates._end_background_update("populate_stats_process_users") + await self.db.updates._end_background_update("populate_stats_process_users") return 1 for user_id in users_to_work_on: - yield self._calculate_and_set_initial_state_for_user(user_id) + await self._calculate_and_set_initial_state_for_user(user_id) progress["last_user_id"] = user_id - yield self.db.runInteraction( + await self.db.runInteraction( "populate_stats_process_users", self.db.updates._background_update_progress_txn, "populate_stats_process_users", @@ -140,13 +139,12 @@ class StatsStore(StateDeltasStore): return len(users_to_work_on) - @defer.inlineCallbacks - def _populate_stats_process_rooms(self, progress, batch_size): + async def _populate_stats_process_rooms(self, progress, batch_size): """ This is a background update which regenerates statistics for rooms. """ if not self.stats_enabled: - yield self.db.updates._end_background_update("populate_stats_process_rooms") + await self.db.updates._end_background_update("populate_stats_process_rooms") return 1 last_room_id = progress.get("last_room_id", "") @@ -161,20 +159,20 @@ class StatsStore(StateDeltasStore): txn.execute(sql, (last_room_id, batch_size)) return [r for r, in txn] - rooms_to_work_on = yield self.db.runInteraction( + rooms_to_work_on = await self.db.runInteraction( "populate_stats_rooms_get_batch", _get_next_batch ) # No more rooms -- complete the transaction. if not rooms_to_work_on: - yield self.db.updates._end_background_update("populate_stats_process_rooms") + await self.db.updates._end_background_update("populate_stats_process_rooms") return 1 for room_id in rooms_to_work_on: - yield self._calculate_and_set_initial_state_for_room(room_id) + await self._calculate_and_set_initial_state_for_room(room_id) progress["last_room_id"] = room_id - yield self.db.runInteraction( + await self.db.runInteraction( "_populate_stats_process_rooms", self.db.updates._background_update_progress_txn, "populate_stats_process_rooms", @@ -696,16 +694,16 @@ class StatsStore(StateDeltasStore): return room_deltas, user_deltas - @defer.inlineCallbacks - def _calculate_and_set_initial_state_for_room(self, room_id): + async def _calculate_and_set_initial_state_for_room( + self, room_id: str + ) -> Tuple[dict, dict, int]: """Calculate and insert an entry into room_stats_current. Args: - room_id (str) + room_id: The room ID under calculation. Returns: - Deferred[tuple[dict, dict, int]]: A tuple of room state, membership - counts and stream position. + A tuple of room state, membership counts and stream position. 
""" def _fetch_current_state_stats(txn): @@ -767,11 +765,11 @@ class StatsStore(StateDeltasStore): current_state_events_count, users_in_room, pos, - ) = yield self.db.runInteraction( + ) = await self.db.runInteraction( "get_initial_state_for_room", _fetch_current_state_stats ) - state_event_map = yield self.get_events(event_ids, get_prev_content=False) + state_event_map = await self.get_events(event_ids, get_prev_content=False) room_state = { "join_rules": None, @@ -806,11 +804,11 @@ class StatsStore(StateDeltasStore): event.content.get("m.federate", True) is True ) - yield self.update_room_state(room_id, room_state) + await self.update_room_state(room_id, room_state) local_users_in_room = [u for u in users_in_room if self.hs.is_mine_id(u)] - yield self.update_stats_delta( + await self.update_stats_delta( ts=self.clock.time_msec(), stats_type="room", stats_id=room_id, @@ -826,8 +824,7 @@ class StatsStore(StateDeltasStore): }, ) - @defer.inlineCallbacks - def _calculate_and_set_initial_state_for_user(self, user_id): + async def _calculate_and_set_initial_state_for_user(self, user_id): def _calculate_and_set_initial_state_for_user_txn(txn): pos = self._get_max_stream_id_in_current_state_deltas_txn(txn) @@ -842,12 +839,12 @@ class StatsStore(StateDeltasStore): (count,) = txn.fetchone() return count, pos - joined_rooms, pos = yield self.db.runInteraction( + joined_rooms, pos = await self.db.runInteraction( "calculate_and_set_initial_state_for_user", _calculate_and_set_initial_state_for_user_txn, ) - yield self.update_stats_delta( + await self.update_stats_delta( ts=self.clock.time_msec(), stats_type="user", stats_id=user_id, diff --git a/synapse/storage/data_stores/state/store.py b/synapse/storage/data_stores/state/store.py index 128c09a2cf..7dada7f75f 100644 --- a/synapse/storage/data_stores/state/store.py +++ b/synapse/storage/data_stores/state/store.py @@ -139,10 +139,9 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): "get_state_group_delta", _get_state_group_delta_txn ) - @defer.inlineCallbacks - def _get_state_groups_from_groups( + async def _get_state_groups_from_groups( self, groups: List[int], state_filter: StateFilter - ): + ) -> Dict[int, StateMap[str]]: """Returns the state groups for a given set of groups from the database, filtering on types of state events. @@ -151,13 +150,13 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): state_filter: The state filter used to fetch state from the database. Returns: - Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map. + Dict of state group to state map. """ results = {} chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)] for chunk in chunks: - res = yield self.db.runInteraction( + res = await self.db.runInteraction( "_get_state_groups_from_groups", self._get_state_groups_from_groups_txn, chunk, @@ -206,10 +205,9 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): return state_filter.filter_state(state_dict_ids), not missing_types - @defer.inlineCallbacks - def _get_state_for_groups( + async def _get_state_for_groups( self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all() - ): + ) -> Dict[int, StateMap[str]]: """Gets the state at each of a list of state groups, optionally filtering by type/state_key @@ -219,7 +217,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): state_filter: The state filter used to fetch state from the database. 
Returns: - Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map. + Dict of state group to state map. """ member_filter, non_member_filter = state_filter.get_member_split() @@ -228,14 +226,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): ( non_member_state, incomplete_groups_nm, - ) = yield self._get_state_for_groups_using_cache( + ) = self._get_state_for_groups_using_cache( groups, self._state_group_cache, state_filter=non_member_filter ) - ( - member_state, - incomplete_groups_m, - ) = yield self._get_state_for_groups_using_cache( + (member_state, incomplete_groups_m,) = self._get_state_for_groups_using_cache( groups, self._state_group_members_cache, state_filter=member_filter ) @@ -256,7 +251,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): # Help the cache hit ratio by expanding the filter a bit db_state_filter = state_filter.return_expanded() - group_to_state_dict = yield self._get_state_groups_from_groups( + group_to_state_dict = await self._get_state_groups_from_groups( list(incomplete_groups), state_filter=db_state_filter ) @@ -576,19 +571,19 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): ((sg,) for sg in state_groups_to_delete), ) - @defer.inlineCallbacks - def get_previous_state_groups(self, state_groups): + async def get_previous_state_groups( + self, state_groups: Iterable[int] + ) -> Dict[int, int]: """Fetch the previous groups of the given state groups. Args: - state_groups (Iterable[int]) + state_groups Returns: - Deferred[dict[int, int]]: mapping from state group to previous - state group. + A mapping from state group to previous state group. """ - rows = yield self.db.simple_select_many_batch( + rows = await self.db.simple_select_many_batch( table="state_group_edges", column="prev_state_group", iterable=state_groups, diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 49ee9c9a74..534883361f 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -14,7 +14,7 @@ # limitations under the License. import logging -from typing import Dict, Iterable, List, Optional, Set, Tuple, TypeVar +from typing import Awaitable, Dict, Iterable, List, Optional, Set, Tuple, TypeVar import attr @@ -419,7 +419,7 @@ class StateGroupStorage(object): def _get_state_groups_from_groups( self, groups: List[int], state_filter: StateFilter - ): + ) -> Awaitable[Dict[int, StateMap[str]]]: """Returns the state groups for a given set of groups, filtering on types of state events. @@ -429,7 +429,7 @@ class StateGroupStorage(object): from the database. Returns: - Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map. + Dict of state group to state map. """ return self.stores.state._get_state_groups_from_groups(groups, state_filter) @@ -532,7 +532,7 @@ class StateGroupStorage(object): def _get_state_for_groups( self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all() - ): + ) -> Awaitable[Dict[int, StateMap[str]]]: """Gets the state at each of a list of state groups, optionally filtering by type/state_key @@ -540,8 +540,9 @@ class StateGroupStorage(object): groups: list of state groups for which we want to get the state. state_filter: The state filter used to fetch state. from the database. + Returns: - Deferred[dict[int, StateMap[str]]]: Dict of state group to state map. + Dict of state group to state map. 
""" return self.stores.state._get_state_for_groups(groups, state_filter) diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index 43dbeb42c5..2b1580feeb 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -39,14 +39,18 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def test_get_unread_push_actions_for_user_in_range_for_http(self): - yield self.store.get_unread_push_actions_for_user_in_range_for_http( - USER_ID, 0, 1000, 20 + yield defer.ensureDeferred( + self.store.get_unread_push_actions_for_user_in_range_for_http( + USER_ID, 0, 1000, 20 + ) ) @defer.inlineCallbacks def test_get_unread_push_actions_for_user_in_range_for_email(self): - yield self.store.get_unread_push_actions_for_user_in_range_for_email( - USER_ID, 0, 1000, 20 + yield defer.ensureDeferred( + self.store.get_unread_push_actions_for_user_in_range_for_email( + USER_ID, 0, 1000, 20 + ) ) @defer.inlineCallbacks diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index a5f250d477..d07b985a8e 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -37,11 +37,13 @@ class RoomStoreTestCase(unittest.TestCase): self.alias = RoomAlias.from_string("#a-room-name:test") self.u_creator = UserID.from_string("@creator:test") - yield self.store.store_room( - self.room.to_string(), - room_creator_user_id=self.u_creator.to_string(), - is_public=True, - room_version=RoomVersions.V1, + yield defer.ensureDeferred( + self.store.store_room( + self.room.to_string(), + room_creator_user_id=self.u_creator.to_string(), + is_public=True, + room_version=RoomVersions.V1, + ) ) @defer.inlineCallbacks @@ -88,11 +90,13 @@ class RoomEventsStoreTestCase(unittest.TestCase): self.room = RoomID.from_string("!abcde:test") - yield self.store.store_room( - self.room.to_string(), - room_creator_user_id="@creator:text", - is_public=True, - room_version=RoomVersions.V1, + yield defer.ensureDeferred( + self.store.store_room( + self.room.to_string(), + room_creator_user_id="@creator:text", + is_public=True, + room_version=RoomVersions.V1, + ) ) @defer.inlineCallbacks diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index 6a48b9d3b3..8bd12fa847 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -44,11 +44,13 @@ class StateStoreTestCase(tests.unittest.TestCase): self.room = RoomID.from_string("!abc123:test") - yield self.store.store_room( - self.room.to_string(), - room_creator_user_id="@creator:text", - is_public=True, - room_version=RoomVersions.V1, + yield defer.ensureDeferred( + self.store.store_room( + self.room.to_string(), + room_creator_user_id="@creator:text", + is_public=True, + room_version=RoomVersions.V1, + ) ) @defer.inlineCallbacks -- cgit 1.5.1