From 62b1ce85398f52e7d6137e77083294d0c90af459 Mon Sep 17 00:00:00 2001 From: Will Hunt Date: Sun, 5 Jul 2020 16:32:02 +0100 Subject: isort 5 compatibility (#7786) The CI appears to use the latest version of isort, which is a problem when isort gets a major version bump. Rather than try to pin the version, I've done the necessary to make isort5 happy with synapse. --- tests/test_utils/event_injection.py | 2 -- 1 file changed, 2 deletions(-) (limited to 'tests/test_utils/event_injection.py') diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py index 431e9f8e5e..43297b530c 100644 --- a/tests/test_utils/event_injection.py +++ b/tests/test_utils/event_injection.py @@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from typing import Optional, Tuple import synapse.server @@ -25,7 +24,6 @@ from synapse.types import Collection from tests.test_utils import get_awaitable_result - """ Utility functions for poking events into the storage of the server under test. """ -- cgit 1.5.1 From cc9bb3dc3f299d451ab523dea192e48c32e87c68 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 22 Jul 2020 12:29:15 -0400 Subject: Convert the message handler to async/await. (#7884) --- changelog.d/7884.misc | 1 + synapse/handlers/message.py | 288 ++++++++++++++------------- tests/events/test_snapshot.py | 36 ++-- tests/replication/tcp/streams/test_events.py | 76 ++++--- tests/storage/test_roommember.py | 56 +++--- tests/storage/test_state.py | 4 +- tests/test_utils/event_injection.py | 28 +-- tests/test_visibility.py | 14 +- tests/unittest.py | 4 +- tests/utils.py | 4 +- 10 files changed, 273 insertions(+), 238 deletions(-) create mode 100644 changelog.d/7884.misc (limited to 'tests/test_utils/event_injection.py') diff --git a/changelog.d/7884.misc b/changelog.d/7884.misc new file mode 100644 index 0000000000..36c7d4de67 --- /dev/null +++ b/changelog.d/7884.misc @@ -0,0 +1 @@ +Convert the message handler to async/await. diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index c47764a4ce..172a7214b2 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -15,12 +15,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import logging -from typing import TYPE_CHECKING, Optional, Tuple +from typing import TYPE_CHECKING, List, Optional, Tuple from canonicaljson import encode_canonical_json, json -from twisted.internet import defer -from twisted.internet.defer import succeed from twisted.internet.interfaces import IDelayedCall from synapse import event_auth @@ -41,13 +39,22 @@ from synapse.api.errors import ( from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions from synapse.api.urls import ConsentURIBuilder from synapse.events import EventBase +from synapse.events.builder import EventBuilder +from synapse.events.snapshot import EventContext from synapse.events.validator import EventValidator from synapse.logging.context import run_in_background from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.http.send_event import ReplicationSendEventRestServlet from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour from synapse.storage.state import StateFilter -from synapse.types import Collection, RoomAlias, UserID, create_requester +from synapse.types import ( + Collection, + Requester, + RoomAlias, + StreamToken, + UserID, + create_requester, +) from synapse.util.async_helpers import Linearizer from synapse.util.frozenutils import frozendict_json_encoder from synapse.util.metrics import measure_func @@ -84,14 +91,22 @@ class MessageHandler(object): "_schedule_next_expiry", self._schedule_next_expiry ) - @defer.inlineCallbacks - def get_room_data( - self, user_id=None, room_id=None, event_type=None, state_key="", is_guest=False - ): + async def get_room_data( + self, + user_id: str = None, + room_id: str = None, + event_type: Optional[str] = None, + state_key: str = "", + is_guest: bool = False, + ) -> dict: """ Get data from a room. Args: - event : The room path event + user_id + room_id + event_type + state_key + is_guest Returns: The path data content. Raises: @@ -100,30 +115,29 @@ class MessageHandler(object): ( membership, membership_event_id, - ) = yield self.auth.check_user_in_room_or_world_readable( + ) = await self.auth.check_user_in_room_or_world_readable( room_id, user_id, allow_departed_users=True ) if membership == Membership.JOIN: - data = yield self.state.get_current_state(room_id, event_type, state_key) + data = await self.state.get_current_state(room_id, event_type, state_key) elif membership == Membership.LEAVE: key = (event_type, state_key) - room_state = yield self.state_store.get_state_for_events( + room_state = await self.state_store.get_state_for_events( [membership_event_id], StateFilter.from_types([key]) ) data = room_state[membership_event_id].get(key) return data - @defer.inlineCallbacks - def get_state_events( + async def get_state_events( self, - user_id, - room_id, - state_filter=StateFilter.all(), - at_token=None, - is_guest=False, - ): + user_id: str, + room_id: str, + state_filter: StateFilter = StateFilter.all(), + at_token: Optional[StreamToken] = None, + is_guest: bool = False, + ) -> List[dict]: """Retrieve all state events for a given room. If the user is joined to the room then return the current state. If the user has left the room return the state events from when they left. If an explicit @@ -131,15 +145,14 @@ class MessageHandler(object): visible. Args: - user_id(str): The user requesting state events. - room_id(str): The room ID to get all state events from. - state_filter (StateFilter): The state filter used to fetch state - from the database. 
- at_token(StreamToken|None): the stream token of the at which we are requesting + user_id: The user requesting state events. + room_id: The room ID to get all state events from. + state_filter: The state filter used to fetch state from the database. + at_token: the stream token of the at which we are requesting the stats. If the user is not allowed to view the state as of that stream token, we raise a 403 SynapseError. If None, returns the current state based on the current_state_events table. - is_guest(bool): whether this user is a guest + is_guest: whether this user is a guest Returns: A list of dicts representing state events. [{}, {}, {}] Raises: @@ -153,20 +166,20 @@ class MessageHandler(object): # get_recent_events_for_room operates by topo ordering. This therefore # does not reliably give you the state at the given stream position. # (https://github.com/matrix-org/synapse/issues/3305) - last_events, _ = yield self.store.get_recent_events_for_room( + last_events, _ = await self.store.get_recent_events_for_room( room_id, end_token=at_token.room_key, limit=1 ) if not last_events: raise NotFoundError("Can't find event for token %s" % (at_token,)) - visible_events = yield filter_events_for_client( + visible_events = await filter_events_for_client( self.storage, user_id, last_events, filter_send_to_client=False ) event = last_events[0] if visible_events: - room_state = yield self.state_store.get_state_for_events( + room_state = await self.state_store.get_state_for_events( [event.event_id], state_filter=state_filter ) room_state = room_state[event.event_id] @@ -180,23 +193,23 @@ class MessageHandler(object): ( membership, membership_event_id, - ) = yield self.auth.check_user_in_room_or_world_readable( + ) = await self.auth.check_user_in_room_or_world_readable( room_id, user_id, allow_departed_users=True ) if membership == Membership.JOIN: - state_ids = yield self.store.get_filtered_current_state_ids( + state_ids = await self.store.get_filtered_current_state_ids( room_id, state_filter=state_filter ) - room_state = yield self.store.get_events(state_ids.values()) + room_state = await self.store.get_events(state_ids.values()) elif membership == Membership.LEAVE: - room_state = yield self.state_store.get_state_for_events( + room_state = await self.state_store.get_state_for_events( [membership_event_id], state_filter=state_filter ) room_state = room_state[membership_event_id] now = self.clock.time_msec() - events = yield self._event_serializer.serialize_events( + events = await self._event_serializer.serialize_events( room_state.values(), now, # We don't bother bundling aggregations in when asked for state @@ -205,15 +218,14 @@ class MessageHandler(object): ) return events - @defer.inlineCallbacks - def get_joined_members(self, requester, room_id): + async def get_joined_members(self, requester: Requester, room_id: str) -> dict: """Get all the joined members in the room and their profile information. If the user has left the room return the state events from when they left. Args: - requester(Requester): The user requesting state events. - room_id(str): The room ID to get all state events from. + requester: The user requesting state events. + room_id: The room ID to get all state events from. Returns: A dict of user_id to profile info """ @@ -221,7 +233,7 @@ class MessageHandler(object): if not requester.app_service: # We check AS auth after fetching the room membership, as it # requires us to pull out all joined members anyway. 
- membership, _ = yield self.auth.check_user_in_room_or_world_readable( + membership, _ = await self.auth.check_user_in_room_or_world_readable( room_id, user_id, allow_departed_users=True ) if membership != Membership.JOIN: @@ -229,7 +241,7 @@ class MessageHandler(object): "Getting joined members after leaving is not implemented" ) - users_with_profile = yield self.state.get_current_users_in_room(room_id) + users_with_profile = await self.state.get_current_users_in_room(room_id) # If this is an AS, double check that they are allowed to see the members. # This can either be because the AS user is in the room or because there @@ -250,7 +262,7 @@ class MessageHandler(object): for user_id, profile in users_with_profile.items() } - def maybe_schedule_expiry(self, event): + def maybe_schedule_expiry(self, event: EventBase): """Schedule the expiry of an event if there's not already one scheduled, or if the one running is for an event that will expire after the provided timestamp. @@ -259,7 +271,7 @@ class MessageHandler(object): the master process, and therefore needs to be run on there. Args: - event (EventBase): The event to schedule the expiry of. + event: The event to schedule the expiry of. """ expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER) @@ -270,8 +282,7 @@ class MessageHandler(object): # a task scheduled for a timestamp that's sooner than the provided one. self._schedule_expiry_for_event(event.event_id, expiry_ts) - @defer.inlineCallbacks - def _schedule_next_expiry(self): + async def _schedule_next_expiry(self): """Retrieve the ID and the expiry timestamp of the next event to be expired, and schedule an expiry task for it. @@ -279,18 +290,18 @@ class MessageHandler(object): future call to save_expiry_ts can schedule a new expiry task. """ # Try to get the expiry timestamp of the next event to expire. - res = yield self.store.get_next_event_to_expire() + res = await self.store.get_next_event_to_expire() if res: event_id, expiry_ts = res self._schedule_expiry_for_event(event_id, expiry_ts) - def _schedule_expiry_for_event(self, event_id, expiry_ts): + def _schedule_expiry_for_event(self, event_id: str, expiry_ts: int): """Schedule an expiry task for the provided event if there's not already one scheduled at a timestamp that's sooner than the provided one. Args: - event_id (str): The ID of the event to expire. - expiry_ts (int): The timestamp at which to expire the event. + event_id: The ID of the event to expire. + expiry_ts: The timestamp at which to expire the event. """ if self._scheduled_expiry: # If the provided timestamp refers to a time before the scheduled time of the @@ -320,8 +331,7 @@ class MessageHandler(object): event_id, ) - @defer.inlineCallbacks - def _expire_event(self, event_id): + async def _expire_event(self, event_id: str): """Retrieve and expire an event that needs to be expired from the database. If the event doesn't exist in the database, log it and delete the expiry date @@ -336,12 +346,12 @@ class MessageHandler(object): try: # Expire the event if we know about it. This function also deletes the expiry # date from the database in the same database transaction. - yield self.store.expire_event(event_id) + await self.store.expire_event(event_id) except Exception as e: logger.error("Could not expire event %s: %r", event_id, e) # Schedule the expiry of the next event to expire. 
- yield self._schedule_next_expiry() + await self._schedule_next_expiry() # The duration (in ms) after which rooms should be removed @@ -423,16 +433,15 @@ class EventCreationHandler(object): self._dummy_events_threshold = hs.config.dummy_events_threshold - @defer.inlineCallbacks - def create_event( + async def create_event( self, - requester, - event_dict, - token_id=None, - txn_id=None, + requester: Requester, + event_dict: dict, + token_id: Optional[str] = None, + txn_id: Optional[str] = None, prev_event_ids: Optional[Collection[str]] = None, - require_consent=True, - ): + require_consent: bool = True, + ) -> Tuple[EventBase, EventContext]: """ Given a dict from a client, create a new event. @@ -443,31 +452,29 @@ class EventCreationHandler(object): Args: requester - event_dict (dict): An entire event - token_id (str) - txn_id (str) - + event_dict: An entire event + token_id + txn_id prev_event_ids: the forward extremities to use as the prev_events for the new event. If None, they will be requested from the database. - - require_consent (bool): Whether to check if the requester has - consented to privacy policy. + require_consent: Whether to check if the requester has + consented to the privacy policy. Raises: ResourceLimitError if server is blocked to some resource being exceeded Returns: - Tuple of created event (FrozenEvent), Context + Tuple of created event, Context """ - yield self.auth.check_auth_blocking(requester.user.to_string()) + await self.auth.check_auth_blocking(requester.user.to_string()) if event_dict["type"] == EventTypes.Create and event_dict["state_key"] == "": room_version = event_dict["content"]["room_version"] else: try: - room_version = yield self.store.get_room_version_id( + room_version = await self.store.get_room_version_id( event_dict["room_id"] ) except NotFoundError: @@ -488,15 +495,11 @@ class EventCreationHandler(object): try: if "displayname" not in content: - displayname = yield defer.ensureDeferred( - profile.get_displayname(target) - ) + displayname = await profile.get_displayname(target) if displayname is not None: content["displayname"] = displayname if "avatar_url" not in content: - avatar_url = yield defer.ensureDeferred( - profile.get_avatar_url(target) - ) + avatar_url = await profile.get_avatar_url(target) if avatar_url is not None: content["avatar_url"] = avatar_url except Exception as e: @@ -504,9 +507,9 @@ class EventCreationHandler(object): "Failed to get profile information for %r: %s", target, e ) - is_exempt = yield self._is_exempt_from_privacy_policy(builder, requester) + is_exempt = await self._is_exempt_from_privacy_policy(builder, requester) if require_consent and not is_exempt: - yield self.assert_accepted_privacy_policy(requester) + await self.assert_accepted_privacy_policy(requester) if token_id is not None: builder.internal_metadata.token_id = token_id @@ -514,7 +517,7 @@ class EventCreationHandler(object): if txn_id is not None: builder.internal_metadata.txn_id = txn_id - event, context = yield self.create_new_client_event( + event, context = await self.create_new_client_event( builder=builder, requester=requester, prev_event_ids=prev_event_ids, ) @@ -530,10 +533,10 @@ class EventCreationHandler(object): # federation as well as those created locally. As of room v3, aliases events # can be created by users that are not in the room, therefore we have to # tolerate them in event_auth.check(). 
- prev_state_ids = yield context.get_prev_state_ids() + prev_state_ids = await context.get_prev_state_ids() prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender)) prev_event = ( - yield self.store.get_event(prev_event_id, allow_none=True) + await self.store.get_event(prev_event_id, allow_none=True) if prev_event_id else None ) @@ -556,37 +559,36 @@ class EventCreationHandler(object): return (event, context) - def _is_exempt_from_privacy_policy(self, builder, requester): + async def _is_exempt_from_privacy_policy( + self, builder: EventBuilder, requester: Requester + ) -> bool: """"Determine if an event to be sent is exempt from having to consent to the privacy policy Args: - builder (synapse.events.builder.EventBuilder): event being created - requester (Requster): user requesting this event + builder: event being created + requester: user requesting this event Returns: - Deferred[bool]: true if the event can be sent without the user - consenting + true if the event can be sent without the user consenting """ # the only thing the user can do is join the server notices room. if builder.type == EventTypes.Member: membership = builder.content.get("membership", None) if membership == Membership.JOIN: - return self._is_server_notices_room(builder.room_id) + return await self._is_server_notices_room(builder.room_id) elif membership == Membership.LEAVE: # the user is always allowed to leave (but not kick people) return builder.state_key == requester.user.to_string() - return succeed(False) + return False - @defer.inlineCallbacks - def _is_server_notices_room(self, room_id): + async def _is_server_notices_room(self, room_id: str) -> bool: if self.config.server_notices_mxid is None: return False - user_ids = yield self.store.get_users_in_room(room_id) + user_ids = await self.store.get_users_in_room(room_id) return self.config.server_notices_mxid in user_ids - @defer.inlineCallbacks - def assert_accepted_privacy_policy(self, requester): + async def assert_accepted_privacy_policy(self, requester: Requester) -> None: """Check if a user has accepted the privacy policy Called when the given user is about to do something that requires @@ -595,12 +597,10 @@ class EventCreationHandler(object): raised. Args: - requester (synapse.types.Requester): - The user making the request + requester: The user making the request Returns: - Deferred[None]: returns normally if the user has consented or is - exempt + Returns normally if the user has consented or is exempt Raises: ConsentNotGivenError: if the user has not given consent yet @@ -621,7 +621,7 @@ class EventCreationHandler(object): ): return - u = yield self.store.get_user_by_id(user_id) + u = await self.store.get_user_by_id(user_id) assert u is not None if u["user_type"] in (UserTypes.SUPPORT, UserTypes.BOT): # support and bot users are not required to consent @@ -639,16 +639,20 @@ class EventCreationHandler(object): raise ConsentNotGivenError(msg=msg, consent_uri=consent_uri) async def send_nonmember_event( - self, requester, event, context, ratelimit=True + self, + requester: Requester, + event: EventBase, + context: EventContext, + ratelimit: bool = True, ) -> int: """ Persists and notifies local clients and federation of an event. Args: - event (FrozenEvent) the event to send. - context (Context) the context of the event. - ratelimit (bool): Whether to rate limit this send. - is_guest (bool): Whether the sender is a guest. + requester + event the event to send. + context: the context of the event. + ratelimit: Whether to rate limit this send. 
Return: The stream_id of the persisted event. @@ -676,19 +680,20 @@ class EventCreationHandler(object): requester=requester, event=event, context=context, ratelimit=ratelimit ) - @defer.inlineCallbacks - def deduplicate_state_event(self, event, context): + async def deduplicate_state_event( + self, event: EventBase, context: EventContext + ) -> None: """ Checks whether event is in the latest resolved state in context. If so, returns the version of the event in context. Otherwise, returns None. """ - prev_state_ids = yield context.get_prev_state_ids() + prev_state_ids = await context.get_prev_state_ids() prev_event_id = prev_state_ids.get((event.type, event.state_key)) if not prev_event_id: return - prev_event = yield self.store.get_event(prev_event_id, allow_none=True) + prev_event = await self.store.get_event(prev_event_id, allow_none=True) if not prev_event: return @@ -700,7 +705,11 @@ class EventCreationHandler(object): return async def create_and_send_nonmember_event( - self, requester, event_dict, ratelimit=True, txn_id=None + self, + requester: Requester, + event_dict: EventBase, + ratelimit: bool = True, + txn_id: Optional[str] = None, ) -> Tuple[EventBase, int]: """ Creates an event, then sends it. @@ -730,17 +739,17 @@ class EventCreationHandler(object): return event, stream_id @measure_func("create_new_client_event") - @defer.inlineCallbacks - def create_new_client_event( - self, builder, requester=None, prev_event_ids: Optional[Collection[str]] = None - ): + async def create_new_client_event( + self, + builder: EventBuilder, + requester: Optional[Requester] = None, + prev_event_ids: Optional[Collection[str]] = None, + ) -> Tuple[EventBase, EventContext]: """Create a new event for a local client Args: - builder (EventBuilder): - - requester (synapse.types.Requester|None): - + builder: + requester: prev_event_ids: the forward extremities to use as the prev_events for the new event. @@ -748,7 +757,7 @@ class EventCreationHandler(object): If None, they will be requested from the database. Returns: - Deferred[(synapse.events.EventBase, synapse.events.snapshot.EventContext)] + Tuple of created event, context """ if prev_event_ids is not None: @@ -757,10 +766,10 @@ class EventCreationHandler(object): % (len(prev_event_ids),) ) else: - prev_event_ids = yield self.store.get_prev_events_for_room(builder.room_id) + prev_event_ids = await self.store.get_prev_events_for_room(builder.room_id) - event = yield builder.build(prev_event_ids=prev_event_ids) - context = yield self.state.compute_event_context(event) + event = await builder.build(prev_event_ids=prev_event_ids) + context = await self.state.compute_event_context(event) if requester: context.app_service = requester.app_service @@ -774,7 +783,7 @@ class EventCreationHandler(object): relates_to = relation["event_id"] aggregation_key = relation["key"] - already_exists = yield self.store.has_user_annotated_event( + already_exists = await self.store.has_user_annotated_event( relates_to, event.type, aggregation_key, event.sender ) if already_exists: @@ -786,7 +795,12 @@ class EventCreationHandler(object): @measure_func("handle_new_client_event") async def handle_new_client_event( - self, requester, event, context, ratelimit=True, extra_users=[] + self, + requester: Requester, + event: EventBase, + context: EventContext, + ratelimit: bool = True, + extra_users: List[UserID] = [], ) -> int: """Processes a new event. This includes checking auth, persisting it, notifying users, sending to remote servers, etc. 
@@ -795,11 +809,11 @@ class EventCreationHandler(object): processing. Args: - requester (Requester) - event (FrozenEvent) - context (EventContext) - ratelimit (bool) - extra_users (list(UserID)): Any extra users to notify about event + requester + event + context + ratelimit + extra_users: Any extra users to notify about event Return: The stream_id of the persisted event. @@ -878,10 +892,9 @@ class EventCreationHandler(object): self.store.remove_push_actions_from_staging, event.event_id ) - @defer.inlineCallbacks - def _validate_canonical_alias( - self, directory_handler, room_alias_str, expected_room_id - ): + async def _validate_canonical_alias( + self, directory_handler, room_alias_str: str, expected_room_id: str + ) -> None: """ Ensure that the given room alias points to the expected room ID. @@ -892,9 +905,7 @@ class EventCreationHandler(object): """ room_alias = RoomAlias.from_string(room_alias_str) try: - mapping = yield defer.ensureDeferred( - directory_handler.get_association(room_alias) - ) + mapping = await directory_handler.get_association(room_alias) except SynapseError as e: # Turn M_NOT_FOUND errors into M_BAD_ALIAS errors. if e.errcode == Codes.NOT_FOUND: @@ -913,7 +924,12 @@ class EventCreationHandler(object): ) async def persist_and_notify_client_event( - self, requester, event, context, ratelimit=True, extra_users=[] + self, + requester: Requester, + event: EventBase, + context: EventContext, + ratelimit: bool = True, + extra_users: List[UserID] = [], ) -> int: """Called when we have fully built the event, have already calculated the push actions for the event, and checked auth. @@ -1106,7 +1122,7 @@ class EventCreationHandler(object): return event_stream_id - async def _bump_active_time(self, user): + async def _bump_active_time(self, user: UserID) -> None: try: presence = self.hs.get_presence_handler() await presence.bump_presence_active_time(user) diff --git a/tests/events/test_snapshot.py b/tests/events/test_snapshot.py index 640f5f3bce..3a80626224 100644 --- a/tests/events/test_snapshot.py +++ b/tests/events/test_snapshot.py @@ -41,8 +41,10 @@ class TestEventContext(unittest.HomeserverTestCase): serialize/deserialize. """ - event, context = create_event( - self.hs, room_id=self.room_id, type="m.test", sender=self.user_id, + event, context = self.get_success( + create_event( + self.hs, room_id=self.room_id, type="m.test", sender=self.user_id, + ) ) self._check_serialize_deserialize(event, context) @@ -51,12 +53,14 @@ class TestEventContext(unittest.HomeserverTestCase): """Test that an EventContext for a state event (with not previous entry) is the same after serialize/deserialize. """ - event, context = create_event( - self.hs, - room_id=self.room_id, - type="m.test", - sender=self.user_id, - state_key="", + event, context = self.get_success( + create_event( + self.hs, + room_id=self.room_id, + type="m.test", + sender=self.user_id, + state_key="", + ) ) self._check_serialize_deserialize(event, context) @@ -65,13 +69,15 @@ class TestEventContext(unittest.HomeserverTestCase): """Test that an EventContext for a state event (which replaces a previous entry) is the same after serialize/deserialize. 
""" - event, context = create_event( - self.hs, - room_id=self.room_id, - type="m.room.member", - sender=self.user_id, - state_key=self.user_id, - content={"membership": "leave"}, + event, context = self.get_success( + create_event( + self.hs, + room_id=self.room_id, + type="m.room.member", + sender=self.user_id, + state_key=self.user_id, + content={"membership": "leave"}, + ) ) self._check_serialize_deserialize(event, context) diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py index 097e1653b4..c9998e88e6 100644 --- a/tests/replication/tcp/streams/test_events.py +++ b/tests/replication/tcp/streams/test_events.py @@ -119,7 +119,9 @@ class EventsStreamTestCase(BaseStreamTestCase): OTHER_USER = "@other_user:localhost" # have the user join - inject_member_event(self.hs, self.room_id, OTHER_USER, Membership.JOIN) + self.get_success( + inject_member_event(self.hs, self.room_id, OTHER_USER, Membership.JOIN) + ) # Update existing power levels with mod at PL50 pls = self.helper.get_state( @@ -157,14 +159,16 @@ class EventsStreamTestCase(BaseStreamTestCase): # roll back all the state by de-modding the user prev_events = fork_point pls["users"][OTHER_USER] = 0 - pl_event = inject_event( - self.hs, - prev_event_ids=prev_events, - type=EventTypes.PowerLevels, - state_key="", - sender=self.user_id, - room_id=self.room_id, - content=pls, + pl_event = self.get_success( + inject_event( + self.hs, + prev_event_ids=prev_events, + type=EventTypes.PowerLevels, + state_key="", + sender=self.user_id, + room_id=self.room_id, + content=pls, + ) ) # one more bit of state that doesn't get rolled back @@ -268,7 +272,9 @@ class EventsStreamTestCase(BaseStreamTestCase): # have the users join for u in user_ids: - inject_member_event(self.hs, self.room_id, u, Membership.JOIN) + self.get_success( + inject_member_event(self.hs, self.room_id, u, Membership.JOIN) + ) # Update existing power levels with mod at PL50 pls = self.helper.get_state( @@ -306,14 +312,16 @@ class EventsStreamTestCase(BaseStreamTestCase): pl_events = [] for u in user_ids: pls["users"][u] = 0 - e = inject_event( - self.hs, - prev_event_ids=prev_events, - type=EventTypes.PowerLevels, - state_key="", - sender=self.user_id, - room_id=self.room_id, - content=pls, + e = self.get_success( + inject_event( + self.hs, + prev_event_ids=prev_events, + type=EventTypes.PowerLevels, + state_key="", + sender=self.user_id, + room_id=self.room_id, + content=pls, + ) ) prev_events = [e.event_id] pl_events.append(e) @@ -434,13 +442,15 @@ class EventsStreamTestCase(BaseStreamTestCase): body = "event %i" % (self.event_count,) self.event_count += 1 - return inject_event( - self.hs, - room_id=self.room_id, - sender=sender, - type="test_event", - content={"body": body}, - **kwargs + return self.get_success( + inject_event( + self.hs, + room_id=self.room_id, + sender=sender, + type="test_event", + content={"body": body}, + **kwargs + ) ) def _inject_state_event( @@ -459,11 +469,13 @@ class EventsStreamTestCase(BaseStreamTestCase): if body is None: body = "state event %s" % (state_key,) - return inject_event( - self.hs, - room_id=self.room_id, - sender=sender, - type="test_state_event", - state_key=state_key, - content={"body": body}, + return self.get_success( + inject_event( + self.hs, + room_id=self.room_id, + sender=sender, + type="test_state_event", + state_key=state_key, + content={"body": body}, + ) ) diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 5dd46005e6..f282921538 100644 
--- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -118,18 +118,22 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): def test_get_joined_users_from_context(self): room = self.helper.create_room_as(self.u_alice, tok=self.t_alice) - bob_event = event_injection.inject_member_event( - self.hs, room, self.u_bob, Membership.JOIN + bob_event = self.get_success( + event_injection.inject_member_event( + self.hs, room, self.u_bob, Membership.JOIN + ) ) # first, create a regular event - event, context = event_injection.create_event( - self.hs, - room_id=room, - sender=self.u_alice, - prev_event_ids=[bob_event.event_id], - type="m.test.1", - content={}, + event, context = self.get_success( + event_injection.create_event( + self.hs, + room_id=room, + sender=self.u_alice, + prev_event_ids=[bob_event.event_id], + type="m.test.1", + content={}, + ) ) users = self.get_success( @@ -140,22 +144,26 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): # Regression test for #7376: create a state event whose key matches bob's # user_id, but which is *not* a membership event, and persist that; then check # that `get_joined_users_from_context` returns the correct users for the next event. - non_member_event = event_injection.inject_event( - self.hs, - room_id=room, - sender=self.u_bob, - prev_event_ids=[bob_event.event_id], - type="m.test.2", - state_key=self.u_bob, - content={}, + non_member_event = self.get_success( + event_injection.inject_event( + self.hs, + room_id=room, + sender=self.u_bob, + prev_event_ids=[bob_event.event_id], + type="m.test.2", + state_key=self.u_bob, + content={}, + ) ) - event, context = event_injection.create_event( - self.hs, - room_id=room, - sender=self.u_alice, - prev_event_ids=[non_member_event.event_id], - type="m.test.3", - content={}, + event, context = self.get_success( + event_injection.create_event( + self.hs, + room_id=room, + sender=self.u_alice, + prev_event_ids=[non_member_event.event_id], + type="m.test.3", + content={}, + ) ) users = self.get_success( self.store.get_joined_users_from_context(event, context) diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index 0b88308ff4..a0e133cd4a 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -64,8 +64,8 @@ class StateStoreTestCase(tests.unittest.TestCase): }, ) - event, context = yield self.event_creation_handler.create_new_client_event( - builder + event, context = yield defer.ensureDeferred( + self.event_creation_handler.create_new_client_event(builder) ) yield self.storage.persistence.persist_event(event, context) diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py index 43297b530c..8522c6fc09 100644 --- a/tests/test_utils/event_injection.py +++ b/tests/test_utils/event_injection.py @@ -22,14 +22,12 @@ from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.types import Collection -from tests.test_utils import get_awaitable_result - """ Utility functions for poking events into the storage of the server under test. 
""" -def inject_member_event( +async def inject_member_event( hs: synapse.server.HomeServer, room_id: str, sender: str, @@ -46,7 +44,7 @@ def inject_member_event( if extra_content: content.update(extra_content) - return inject_event( + return await inject_event( hs, room_id=room_id, type=EventTypes.Member, @@ -57,7 +55,7 @@ def inject_member_event( ) -def inject_event( +async def inject_event( hs: synapse.server.HomeServer, room_version: Optional[str] = None, prev_event_ids: Optional[Collection[str]] = None, @@ -72,37 +70,27 @@ def inject_event( prev_event_ids: prev_events for the event. If not specified, will be looked up kwargs: fields for the event to be created """ - test_reactor = hs.get_reactor() - - event, context = create_event(hs, room_version, prev_event_ids, **kwargs) + event, context = await create_event(hs, room_version, prev_event_ids, **kwargs) - d = hs.get_storage().persistence.persist_event(event, context) - test_reactor.advance(0) - get_awaitable_result(d) + await hs.get_storage().persistence.persist_event(event, context) return event -def create_event( +async def create_event( hs: synapse.server.HomeServer, room_version: Optional[str] = None, prev_event_ids: Optional[Collection[str]] = None, **kwargs ) -> Tuple[EventBase, EventContext]: - test_reactor = hs.get_reactor() - if room_version is None: - d = hs.get_datastore().get_room_version_id(kwargs["room_id"]) - test_reactor.advance(0) - room_version = get_awaitable_result(d) + room_version = await hs.get_datastore().get_room_version_id(kwargs["room_id"]) builder = hs.get_event_builder_factory().for_room_version( KNOWN_ROOM_VERSIONS[room_version], kwargs ) - d = hs.get_event_creation_handler().create_new_client_event( + event, context = await hs.get_event_creation_handler().create_new_client_event( builder, prev_event_ids=prev_event_ids ) - test_reactor.advance(0) - event, context = get_awaitable_result(d) return event, context diff --git a/tests/test_visibility.py b/tests/test_visibility.py index f7381b2885..b371efc0df 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -53,7 +53,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): # # before we do that, we persist some other events to act as state. 
- self.inject_visibility("@admin:hs", "joined") + yield self.inject_visibility("@admin:hs", "joined") for i in range(0, 10): yield self.inject_room_member("@resident%i:hs" % i) @@ -137,8 +137,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): }, ) - event, context = yield self.event_creation_handler.create_new_client_event( - builder + event, context = yield defer.ensureDeferred( + self.event_creation_handler.create_new_client_event(builder) ) yield self.storage.persistence.persist_event(event, context) return event @@ -158,8 +158,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): }, ) - event, context = yield self.event_creation_handler.create_new_client_event( - builder + event, context = yield defer.ensureDeferred( + self.event_creation_handler.create_new_client_event(builder) ) yield self.storage.persistence.persist_event(event, context) @@ -179,8 +179,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): }, ) - event, context = yield self.event_creation_handler.create_new_client_event( - builder + event, context = yield defer.ensureDeferred( + self.event_creation_handler.create_new_client_event(builder) ) yield self.storage.persistence.persist_event(event, context) diff --git a/tests/unittest.py b/tests/unittest.py index 3175a3fa02..68d2586efd 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -603,7 +603,9 @@ class HomeserverTestCase(TestCase): user: MXID of the user to inject the membership for. membership: The membership type. """ - event_injection.inject_member_event(self.hs, room, user, membership) + self.get_success( + event_injection.inject_member_event(self.hs, room, user, membership) + ) class FederatingHomeserverTestCase(HomeserverTestCase): diff --git a/tests/utils.py b/tests/utils.py index 4d17355a5c..ac643679aa 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -671,6 +671,8 @@ def create_room(hs, room_id, creator_id): }, ) - event, context = yield event_creation_handler.create_new_client_event(builder) + event, context = yield defer.ensureDeferred( + event_creation_handler.create_new_client_event(builder) + ) yield persistence_store.persist_event(event, context) -- cgit 1.5.1 From da77520cd1c414c9341da287967feb1bab14cbec Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 1 Sep 2020 08:39:04 -0400 Subject: Convert additional databases to async/await part 2 (#8200) --- changelog.d/8200.misc | 1 + synapse/events/builder.py | 19 +++++---- synapse/handlers/message.py | 13 ++---- synapse/handlers/room_member.py | 12 +----- synapse/storage/databases/main/client_ips.py | 4 +- synapse/storage/databases/main/directory.py | 6 +-- synapse/storage/databases/main/filtering.py | 5 ++- synapse/storage/databases/main/openid.py | 8 +++- synapse/storage/databases/main/profile.py | 6 ++- synapse/storage/databases/main/push_rule.py | 10 ++--- synapse/storage/databases/main/room.py | 49 ++++++++++++---------- synapse/storage/databases/main/signatures.py | 40 ++++++++++++++---- synapse/storage/databases/main/ui_auth.py | 4 +- .../storage/databases/main/user_erasure_store.py | 8 ++-- tests/test_utils/event_injection.py | 7 ++-- 15 files changed, 111 insertions(+), 81 deletions(-) create mode 100644 changelog.d/8200.misc (limited to 'tests/test_utils/event_injection.py') diff --git a/changelog.d/8200.misc b/changelog.d/8200.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8200.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. 
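
[Editorial note: the storage-layer hunks that follow in this commit all apply the same mechanical conversion: a method that previously handed back the `Deferred` from `db_pool.runInteraction` (or used `@defer.inlineCallbacks` with `yield`) becomes an `async def` coroutine that `await`s the interaction, gaining explicit type annotations along the way. The sketch below illustrates that pattern only; `WidgetStore`, `get_widget_count`, and `FakeDatabasePool` are invented stand-ins, not real Synapse classes.]

```python
import asyncio
from typing import Any, Callable


class FakeDatabasePool:
    """Stand-in for Synapse's DatabasePool, just enough to run this sketch."""

    async def runInteraction(self, desc: str, func: Callable[..., Any]) -> Any:
        # The real pool runs `func` inside a database transaction on a worker
        # thread; here we simply call it with a dummy transaction object.
        return func(None)


class WidgetStore:
    """Hypothetical store; `get_widget_count` is not a real Synapse method."""

    def __init__(self) -> None:
        self.db_pool = FakeDatabasePool()

    # Old style (before this commit): the Deferred from runInteraction was
    # returned to the caller as-is, with no type annotations:
    #
    #     def get_widget_count(self, user_id):
    #         return self.db_pool.runInteraction("get_widget_count", f)
    #
    # New style: an annotated coroutine that awaits the interaction.
    async def get_widget_count(self, user_id: str) -> int:
        def f(txn: Any) -> int:
            # A real implementation would execute SQL against `txn` here.
            return 3

        return await self.db_pool.runInteraction("get_widget_count", f)


print(asyncio.run(WidgetStore().get_widget_count("@alice:example.com")))  # 3
```

Call sites change in step: code that previously did `yield store.some_method(...)` under `@defer.inlineCallbacks` now does `await store.some_method(...)` from an async function, which is the substitution repeated throughout the hunks below.
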
diff --git a/synapse/events/builder.py b/synapse/events/builder.py index 9ed24380dd..7878cd7044 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import Any, Dict, List, Optional, Tuple, Union import attr from nacl.signing import SigningKey @@ -97,14 +97,14 @@ class EventBuilder(object): def is_state(self): return self._state_key is not None - async def build(self, prev_event_ids): + async def build(self, prev_event_ids: List[str]) -> EventBase: """Transform into a fully signed and hashed event Args: - prev_event_ids (list[str]): The event IDs to use as the prev events + prev_event_ids: The event IDs to use as the prev events Returns: - FrozenEvent + The signed and hashed event. """ state_ids = await self._state.get_current_state_ids( @@ -114,8 +114,13 @@ class EventBuilder(object): format_version = self.room_version.event_format if format_version == EventFormatVersions.V1: - auth_events = await self._store.add_event_hashes(auth_ids) - prev_events = await self._store.add_event_hashes(prev_event_ids) + # The types of auth/prev events changes between event versions. + auth_events = await self._store.add_event_hashes( + auth_ids + ) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]] + prev_events = await self._store.add_event_hashes( + prev_event_ids + ) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]] else: auth_events = auth_ids prev_events = prev_event_ids @@ -138,7 +143,7 @@ class EventBuilder(object): "unsigned": self.unsigned, "depth": depth, "prev_state": [], - } + } # type: Dict[str, Any] if self.is_state(): event_dict["state_key"] = self._state_key diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 9d0c38f4df..72bb638167 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -49,14 +49,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.http.send_event import ReplicationSendEventRestServlet from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.state import StateFilter -from synapse.types import ( - Collection, - Requester, - RoomAlias, - StreamToken, - UserID, - create_requester, -) +from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester from synapse.util import json_decoder from synapse.util.async_helpers import Linearizer from synapse.util.frozenutils import frozendict_json_encoder @@ -446,7 +439,7 @@ class EventCreationHandler(object): event_dict: dict, token_id: Optional[str] = None, txn_id: Optional[str] = None, - prev_event_ids: Optional[Collection[str]] = None, + prev_event_ids: Optional[List[str]] = None, require_consent: bool = True, ) -> Tuple[EventBase, EventContext]: """ @@ -786,7 +779,7 @@ class EventCreationHandler(object): self, builder: EventBuilder, requester: Optional[Requester] = None, - prev_event_ids: Optional[Collection[str]] = None, + prev_event_ids: Optional[List[str]] = None, ) -> Tuple[EventBase, EventContext]: """Create a new event for a local client diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index cae4d013b8..a7962b0ada 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -38,15 +38,7 @@ from synapse.events.builder import 
create_local_event_from_event_dict from synapse.events.snapshot import EventContext from synapse.events.validator import EventValidator from synapse.storage.roommember import RoomsForUser -from synapse.types import ( - Collection, - JsonDict, - Requester, - RoomAlias, - RoomID, - StateMap, - UserID, -) +from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID from synapse.util.async_helpers import Linearizer from synapse.util.distributor import user_joined_room, user_left_room @@ -184,7 +176,7 @@ class RoomMemberHandler(object): target: UserID, room_id: str, membership: str, - prev_event_ids: Collection[str], + prev_event_ids: List[str], txn_id: Optional[str] = None, ratelimit: bool = True, content: Optional[dict] = None, diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py index 216a5925fc..c2fc847fbc 100644 --- a/synapse/storage/databases/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py @@ -396,7 +396,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): self._batch_row_update[key] = (user_agent, device_id, now) @wrap_as_background_process("update_client_ips") - def _update_client_ips_batch(self): + async def _update_client_ips_batch(self) -> None: # If the DB pool has already terminated, don't try updating if not self.db_pool.is_running(): @@ -405,7 +405,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): to_update = self._batch_row_update self._batch_row_update = {} - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update ) diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py index 405b5eafa5..e5060d4c46 100644 --- a/synapse/storage/databases/main/directory.py +++ b/synapse/storage/databases/main/directory.py @@ -159,9 +159,9 @@ class DirectoryStore(DirectoryWorkerStore): return room_id - def update_aliases_for_room( + async def update_aliases_for_room( self, old_room_id: str, new_room_id: str, creator: Optional[str] = None, - ): + ) -> None: """Repoint all of the aliases for a given room, to a different room. 
Args: @@ -189,6 +189,6 @@ class DirectoryStore(DirectoryWorkerStore): txn, self.get_aliases_for_room, (new_room_id,) ) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "_update_aliases_for_room_txn", _update_aliases_for_room_txn ) diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py index 45a1760170..d2f5b9a502 100644 --- a/synapse/storage/databases/main/filtering.py +++ b/synapse/storage/databases/main/filtering.py @@ -17,6 +17,7 @@ from canonicaljson import encode_canonical_json from synapse.api.errors import Codes, SynapseError from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.types import JsonDict from synapse.util.caches.descriptors import cached @@ -40,7 +41,7 @@ class FilteringStore(SQLBaseStore): return db_to_json(def_json) - def add_user_filter(self, user_localpart, user_filter): + async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> str: def_json = encode_canonical_json(user_filter) # Need an atomic transaction to SELECT the maximal ID so far then @@ -71,4 +72,4 @@ class FilteringStore(SQLBaseStore): return filter_id - return self.db_pool.runInteraction("add_user_filter", _do_txn) + return await self.db_pool.runInteraction("add_user_filter", _do_txn) diff --git a/synapse/storage/databases/main/openid.py b/synapse/storage/databases/main/openid.py index 4db8949da7..2aac64901b 100644 --- a/synapse/storage/databases/main/openid.py +++ b/synapse/storage/databases/main/openid.py @@ -1,3 +1,5 @@ +from typing import Optional + from synapse.storage._base import SQLBaseStore @@ -15,7 +17,9 @@ class OpenIdStore(SQLBaseStore): desc="insert_open_id_token", ) - def get_user_id_for_open_id_token(self, token, ts_now_ms): + async def get_user_id_for_open_id_token( + self, token: str, ts_now_ms: int + ) -> Optional[str]: def get_user_id_for_token_txn(txn): sql = ( "SELECT user_id FROM open_id_tokens" @@ -30,6 +34,6 @@ class OpenIdStore(SQLBaseStore): else: return rows[0][0] - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_user_id_for_token", get_user_id_for_token_txn ) diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py index 301875a672..d2e0685e9e 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py @@ -138,7 +138,9 @@ class ProfileStore(ProfileWorkerStore): desc="delete_remote_profile_cache", ) - def get_remote_profile_cache_entries_that_expire(self, last_checked): + async def get_remote_profile_cache_entries_that_expire( + self, last_checked: int + ) -> Dict[str, str]: """Get all users who haven't been checked since `last_checked` """ @@ -153,7 +155,7 @@ class ProfileStore(ProfileWorkerStore): return self.db_pool.cursor_to_dict(txn) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_remote_profile_cache_entries_that_expire", _get_remote_profile_cache_entries_that_expire_txn, ) diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 2fb5b02d7d..0de802a86b 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -18,8 +18,6 @@ import abc import logging from typing import List, Tuple, Union -from twisted.internet import defer - from synapse.push.baserules import list_with_base_rules from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.storage._base import 
SQLBaseStore, db_to_json @@ -149,9 +147,11 @@ class PushRulesWorkerStore( ) return {r["rule_id"]: False if r["enabled"] == 0 else True for r in results} - def have_push_rules_changed_for_user(self, user_id, last_id): + async def have_push_rules_changed_for_user( + self, user_id: str, last_id: int + ) -> bool: if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id): - return defer.succeed(False) + return False else: def have_push_rules_changed_txn(txn): @@ -163,7 +163,7 @@ class PushRulesWorkerStore( (count,) = txn.fetchone() return bool(count) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "have_push_rules_changed", have_push_rules_changed_txn ) diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index a92641c339..717df97301 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -89,7 +89,7 @@ class RoomWorkerStore(SQLBaseStore): allow_none=True, ) - def get_room_with_stats(self, room_id: str): + async def get_room_with_stats(self, room_id: str) -> Optional[Dict[str, Any]]: """Retrieve room with statistics. Args: @@ -121,7 +121,7 @@ class RoomWorkerStore(SQLBaseStore): res["public"] = bool(res["public"]) return res - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_room_with_stats", get_room_with_stats_txn, room_id ) @@ -133,13 +133,17 @@ class RoomWorkerStore(SQLBaseStore): desc="get_public_room_ids", ) - def count_public_rooms(self, network_tuple, ignore_non_federatable): + async def count_public_rooms( + self, + network_tuple: Optional[ThirdPartyInstanceID], + ignore_non_federatable: bool, + ) -> int: """Counts the number of public rooms as tracked in the room_stats_current and room_stats_state table. Args: - network_tuple (ThirdPartyInstanceID|None) - ignore_non_federatable (bool): If true filters out non-federatable rooms + network_tuple + ignore_non_federatable: If true filters out non-federatable rooms """ def _count_public_rooms_txn(txn): @@ -183,7 +187,7 @@ class RoomWorkerStore(SQLBaseStore): txn.execute(sql, query_args) return txn.fetchone()[0] - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "count_public_rooms", _count_public_rooms_txn ) @@ -586,15 +590,14 @@ class RoomWorkerStore(SQLBaseStore): return row - def get_media_mxcs_in_room(self, room_id): + async def get_media_mxcs_in_room(self, room_id: str) -> Tuple[List[str], List[str]]: """Retrieves all the local and remote media MXC URIs in a given room Args: - room_id (str) + room_id Returns: - The local and remote media as a lists of tuples where the key is - the hostname and the value is the media ID. + The local and remote media as a lists of the media IDs. 
""" def _get_media_mxcs_in_room_txn(txn): @@ -610,11 +613,13 @@ class RoomWorkerStore(SQLBaseStore): return local_media_mxcs, remote_media_mxcs - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_media_ids_in_room", _get_media_mxcs_in_room_txn ) - def quarantine_media_ids_in_room(self, room_id, quarantined_by): + async def quarantine_media_ids_in_room( + self, room_id: str, quarantined_by: str + ) -> int: """For a room loops through all events with media and quarantines the associated media """ @@ -627,7 +632,7 @@ class RoomWorkerStore(SQLBaseStore): txn, local_mxcs, remote_mxcs, quarantined_by ) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "quarantine_media_in_room", _quarantine_media_in_room_txn ) @@ -690,9 +695,9 @@ class RoomWorkerStore(SQLBaseStore): return local_media_mxcs, remote_media_mxcs - def quarantine_media_by_id( + async def quarantine_media_by_id( self, server_name: str, media_id: str, quarantined_by: str, - ): + ) -> int: """quarantines a single local or remote media id Args: @@ -711,11 +716,13 @@ class RoomWorkerStore(SQLBaseStore): txn, local_mxcs, remote_mxcs, quarantined_by ) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "quarantine_media_by_user", _quarantine_media_by_id_txn ) - def quarantine_media_ids_by_user(self, user_id: str, quarantined_by: str): + async def quarantine_media_ids_by_user( + self, user_id: str, quarantined_by: str + ) -> int: """quarantines all local media associated with a single user Args: @@ -727,7 +734,7 @@ class RoomWorkerStore(SQLBaseStore): local_media_ids = self._get_media_ids_by_user_txn(txn, user_id) return self._quarantine_media_txn(txn, local_media_ids, [], quarantined_by) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "quarantine_media_by_user", _quarantine_media_by_user_txn ) @@ -1284,8 +1291,8 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): ) self.hs.get_notifier().on_new_replication_data() - def get_room_count(self): - """Retrieve a list of all rooms + async def get_room_count(self) -> int: + """Retrieve the total number of rooms. """ def f(txn): @@ -1294,7 +1301,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): row = txn.fetchone() return row[0] or 0 - return self.db_pool.runInteraction("get_rooms", f) + return await self.db_pool.runInteraction("get_rooms", f) async def add_event_report( self, diff --git a/synapse/storage/databases/main/signatures.py b/synapse/storage/databases/main/signatures.py index be191dd870..c8c67953e4 100644 --- a/synapse/storage/databases/main/signatures.py +++ b/synapse/storage/databases/main/signatures.py @@ -13,9 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Dict, Iterable, List, Tuple + from unpaddedbase64 import encode_base64 from synapse.storage._base import SQLBaseStore +from synapse.storage.types import Cursor from synapse.util.caches.descriptors import cached, cachedList @@ -29,16 +32,37 @@ class SignatureWorkerStore(SQLBaseStore): @cachedList( cached_method_name="get_event_reference_hash", list_name="event_ids", num_args=1 ) - def get_event_reference_hashes(self, event_ids): + async def get_event_reference_hashes( + self, event_ids: Iterable[str] + ) -> Dict[str, Dict[str, bytes]]: + """Get all hashes for given events. + + Args: + event_ids: The event IDs to get hashes for. 
+ + Returns: + A mapping of event ID to a mapping of algorithm to hash. + """ + def f(txn): return { event_id: self._get_event_reference_hashes_txn(txn, event_id) for event_id in event_ids } - return self.db_pool.runInteraction("get_event_reference_hashes", f) + return await self.db_pool.runInteraction("get_event_reference_hashes", f) - async def add_event_hashes(self, event_ids): + async def add_event_hashes( + self, event_ids: Iterable[str] + ) -> List[Tuple[str, Dict[str, str]]]: + """ + + Args: + event_ids: The event IDs + + Returns: + A list of tuples of event ID and a mapping of algorithm to base-64 encoded hash. + """ hashes = await self.get_event_reference_hashes(event_ids) hashes = { e_id: {k: encode_base64(v) for k, v in h.items() if k == "sha256"} @@ -47,13 +71,15 @@ class SignatureWorkerStore(SQLBaseStore): return list(hashes.items()) - def _get_event_reference_hashes_txn(self, txn, event_id): + def _get_event_reference_hashes_txn( + self, txn: Cursor, event_id: str + ) -> Dict[str, bytes]: """Get all the hashes for a given PDU. Args: - txn (cursor): - event_id (str): Id for the Event. + txn: + event_id: Id for the Event. Returns: - A dict[unicode, bytes] of algorithm -> hash. + A mapping of algorithm -> hash. """ query = ( "SELECT algorithm, hash" diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py index 9eef8e57c5..b89668d561 100644 --- a/synapse/storage/databases/main/ui_auth.py +++ b/synapse/storage/databases/main/ui_auth.py @@ -290,7 +290,7 @@ class UIAuthWorkerStore(SQLBaseStore): class UIAuthStore(UIAuthWorkerStore): - def delete_old_ui_auth_sessions(self, expiration_time: int): + async def delete_old_ui_auth_sessions(self, expiration_time: int) -> None: """ Remove sessions which were last used earlier than the expiration time. @@ -299,7 +299,7 @@ class UIAuthStore(UIAuthWorkerStore): This is an epoch time in milliseconds. """ - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "delete_old_ui_auth_sessions", self._delete_old_ui_auth_sessions_txn, expiration_time, diff --git a/synapse/storage/databases/main/user_erasure_store.py b/synapse/storage/databases/main/user_erasure_store.py index e3547e53b3..2f7c95fc74 100644 --- a/synapse/storage/databases/main/user_erasure_store.py +++ b/synapse/storage/databases/main/user_erasure_store.py @@ -66,7 +66,7 @@ class UserErasureWorkerStore(SQLBaseStore): class UserErasureStore(UserErasureWorkerStore): - def mark_user_erased(self, user_id: str) -> None: + async def mark_user_erased(self, user_id: str) -> None: """Indicate that user_id wishes their message history to be erased. Args: @@ -84,9 +84,9 @@ class UserErasureStore(UserErasureWorkerStore): self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,)) - return self.db_pool.runInteraction("mark_user_erased", f) + await self.db_pool.runInteraction("mark_user_erased", f) - def mark_user_not_erased(self, user_id: str) -> None: + async def mark_user_not_erased(self, user_id: str) -> None: """Indicate that user_id is no longer erased. 
Args: @@ -106,4 +106,4 @@ class UserErasureStore(UserErasureWorkerStore): self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,)) - return self.db_pool.runInteraction("mark_user_not_erased", f) + await self.db_pool.runInteraction("mark_user_not_erased", f) diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py index 8522c6fc09..fb1ca90336 100644 --- a/tests/test_utils/event_injection.py +++ b/tests/test_utils/event_injection.py @@ -13,14 +13,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple +from typing import List, Optional, Tuple import synapse.server from synapse.api.constants import EventTypes from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import EventBase from synapse.events.snapshot import EventContext -from synapse.types import Collection """ Utility functions for poking events into the storage of the server under test. @@ -58,7 +57,7 @@ async def inject_member_event( async def inject_event( hs: synapse.server.HomeServer, room_version: Optional[str] = None, - prev_event_ids: Optional[Collection[str]] = None, + prev_event_ids: Optional[List[str]] = None, **kwargs ) -> EventBase: """Inject a generic event into a room @@ -80,7 +79,7 @@ async def inject_event( async def create_event( hs: synapse.server.HomeServer, room_version: Optional[str] = None, - prev_event_ids: Optional[Collection[str]] = None, + prev_event_ids: Optional[List[str]] = None, **kwargs ) -> Tuple[EventBase, EventContext]: if room_version is None: -- cgit 1.5.1
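
[Editorial note: after these three commits, the helpers in `tests/test_utils/event_injection.py` are plain coroutines and no longer spin the reactor themselves, so test code resolves them with `get_success()`, as the updated call sites in the diffs above do. The sketch below shows that usage pattern; the test-case name, servlet list, user/room setup, and the `@bob:test` user are illustrative assumptions, while the `get_success(...)` wrapping and the helper signatures are taken from the patches above.]

```python
from synapse.api.constants import Membership
from synapse.rest import admin
from synapse.rest.client.v1 import login, room

from tests import unittest
from tests.test_utils import event_injection


class EventInjectionUsageTestCase(unittest.HomeserverTestCase):
    # Servlets needed so the test helper can register users and create rooms.
    servlets = [
        admin.register_servlets,
        login.register_servlets,
        room.register_servlets,
    ]

    def test_inject_and_create(self) -> None:
        user_id = self.register_user("alice", "password")
        tok = self.login("alice", "password")
        room_id = self.helper.create_room_as(user_id, tok=tok)

        # inject_member_event() is now a coroutine: resolve it with
        # get_success(), which runs the reactor until the awaitable completes.
        join_event = self.get_success(
            event_injection.inject_member_event(
                self.hs, room_id, "@bob:test", Membership.JOIN
            )
        )

        # create_event() builds an event plus its EventContext without
        # persisting it, again returned as an awaitable.
        event, _context = self.get_success(
            event_injection.create_event(
                self.hs,
                room_id=room_id,
                sender=user_id,
                prev_event_ids=[join_event.event_id],
                type="m.test",
                content={},
            )
        )
        self.assertEqual(event.room_id, room_id)
```
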