From e19de43eb5903c3b6ccca82334971ebc57fc38de Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 4 Aug 2020 07:21:47 -0400 Subject: Convert streams to async. (#8014) --- tests/server_notices/test_resource_limits_server_notices.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'tests/server_notices') diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 99908edba3..7f70353b0d 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -275,7 +275,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): self.server_notices_manager.get_or_create_notice_room_for_user(self.user_id) ) - token = self.get_success(self.event_source.get_current_token()) + token = self.event_source.get_current_token() events, _ = self.get_success( self.store.get_recent_events_for_room( room_id, limit=100, end_token=token.room_key -- cgit 1.5.1 From d4a7829b12197faf52eb487c443ee09acafeb37e Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 6 Aug 2020 08:30:06 -0400 Subject: Convert synapse.api to async/await (#8031) --- changelog.d/8031.misc | 1 + synapse/api/auth.py | 123 ++++++++++----------- synapse/api/auth_blocking.py | 13 +-- synapse/api/filtering.py | 7 +- synapse/events/builder.py | 2 +- synapse/handlers/federation.py | 2 +- synapse/handlers/message.py | 2 +- synapse/module_api/__init__.py | 8 +- synapse/push/bulk_push_rule_evaluator.py | 2 +- synapse/replication/slave/storage/client_ips.py | 2 +- synapse/rest/client/v1/directory.py | 2 +- synapse/rest/client/v2_alpha/register.py | 2 +- synapse/storage/databases/main/client_ips.py | 5 +- tests/api/test_auth.py | 69 +++++++----- tests/api/test_filtering.py | 36 ++++-- tests/handlers/test_typing.py | 4 +- tests/rest/admin/test_user.py | 10 +- tests/rest/client/v1/test_profile.py | 4 +- tests/rest/client/v1/test_rooms.py | 6 +- tests/rest/client/v1/test_typing.py | 6 +- .../test_resource_limits_server_notices.py | 2 +- tests/unittest.py | 24 ++-- 22 files changed, 172 insertions(+), 160 deletions(-) create mode 100644 changelog.d/8031.misc (limited to 'tests/server_notices') diff --git a/changelog.d/8031.misc b/changelog.d/8031.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8031.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 2178e623da..d8190f92ab 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -13,12 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Optional +from typing import List, Optional, Tuple import pymacaroons from netaddr import IPAddress -from twisted.internet import defer from twisted.web.server import Request import synapse.types @@ -80,13 +79,14 @@ class Auth(object): self._track_appservice_user_ips = hs.config.track_appservice_user_ips self._macaroon_secret_key = hs.config.macaroon_secret_key - @defer.inlineCallbacks - def check_from_context(self, room_version: str, event, context, do_sig_check=True): - prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids()) - auth_events_ids = yield self.compute_auth_events( + async def check_from_context( + self, room_version: str, event, context, do_sig_check=True + ): + prev_state_ids = await context.get_prev_state_ids() + auth_events_ids = self.compute_auth_events( event, prev_state_ids, for_verification=True ) - auth_events = yield self.store.get_events(auth_events_ids) + auth_events = await self.store.get_events(auth_events_ids) auth_events = {(e.type, e.state_key): e for e in auth_events.values()} room_version_obj = KNOWN_ROOM_VERSIONS[room_version] @@ -94,14 +94,13 @@ class Auth(object): room_version_obj, event, auth_events=auth_events, do_sig_check=do_sig_check ) - @defer.inlineCallbacks - def check_user_in_room( + async def check_user_in_room( self, room_id: str, user_id: str, current_state: Optional[StateMap[EventBase]] = None, allow_departed_users: bool = False, - ): + ) -> EventBase: """Check if the user is in the room, or was at some point. Args: room_id: The room to check. @@ -119,37 +118,35 @@ class Auth(object): Raises: AuthError if the user is/was not in the room. Returns: - Deferred[Optional[EventBase]]: - Membership event for the user if the user was in the - room. This will be the join event if they are currently joined to - the room. This will be the leave event if they have left the room. + Membership event for the user if the user was in the + room. This will be the join event if they are currently joined to + the room. This will be the leave event if they have left the room. """ if current_state: member = current_state.get((EventTypes.Member, user_id), None) else: - member = yield defer.ensureDeferred( - self.state.get_current_state( - room_id=room_id, event_type=EventTypes.Member, state_key=user_id - ) + member = await self.state.get_current_state( + room_id=room_id, event_type=EventTypes.Member, state_key=user_id ) - membership = member.membership if member else None - if membership == Membership.JOIN: - return member + if member: + membership = member.membership - # XXX this looks totally bogus. Why do we not allow users who have been banned, - # or those who were members previously and have been re-invited? - if allow_departed_users and membership == Membership.LEAVE: - forgot = yield self.store.did_forget(user_id, room_id) - if not forgot: + if membership == Membership.JOIN: return member + # XXX this looks totally bogus. Why do we not allow users who have been banned, + # or those who were members previously and have been re-invited? + if allow_departed_users and membership == Membership.LEAVE: + forgot = await self.store.did_forget(user_id, room_id) + if not forgot: + return member + raise AuthError(403, "User %s not in room %s" % (user_id, room_id)) - @defer.inlineCallbacks - def check_host_in_room(self, room_id, host): + async def check_host_in_room(self, room_id, host): with Measure(self.clock, "check_host_in_room"): - latest_event_ids = yield self.store.is_host_joined(room_id, host) + latest_event_ids = await self.store.is_host_joined(room_id, host) return latest_event_ids def can_federate(self, event, auth_events): @@ -160,14 +157,13 @@ class Auth(object): def get_public_keys(self, invite_event): return event_auth.get_public_keys(invite_event) - @defer.inlineCallbacks - def get_user_by_req( + async def get_user_by_req( self, request: Request, allow_guest: bool = False, rights: str = "access", allow_expired: bool = False, - ): + ) -> synapse.types.Requester: """ Get a registered user's ID. Args: @@ -180,7 +176,7 @@ class Auth(object): /login will deliver access tokens regardless of expiration. Returns: - defer.Deferred: resolves to a `synapse.types.Requester` object + Resolves to the requester Raises: InvalidClientCredentialsError if no user by that token exists or the token is invalid. @@ -194,14 +190,14 @@ class Auth(object): access_token = self.get_access_token_from_request(request) - user_id, app_service = yield self._get_appservice_user_id(request) + user_id, app_service = await self._get_appservice_user_id(request) if user_id: request.authenticated_entity = user_id opentracing.set_tag("authenticated_entity", user_id) opentracing.set_tag("appservice_id", app_service.id) if ip_addr and self._track_appservice_user_ips: - yield self.store.insert_client_ip( + await self.store.insert_client_ip( user_id=user_id, access_token=access_token, ip=ip_addr, @@ -211,7 +207,7 @@ class Auth(object): return synapse.types.create_requester(user_id, app_service=app_service) - user_info = yield self.get_user_by_access_token( + user_info = await self.get_user_by_access_token( access_token, rights, allow_expired=allow_expired ) user = user_info["user"] @@ -221,7 +217,7 @@ class Auth(object): # Deny the request if the user account has expired. if self._account_validity.enabled and not allow_expired: user_id = user.to_string() - expiration_ts = yield self.store.get_expiration_ts_for_user(user_id) + expiration_ts = await self.store.get_expiration_ts_for_user(user_id) if ( expiration_ts is not None and self.clock.time_msec() >= expiration_ts @@ -235,7 +231,7 @@ class Auth(object): device_id = user_info.get("device_id") if user and access_token and ip_addr: - yield self.store.insert_client_ip( + await self.store.insert_client_ip( user_id=user.to_string(), access_token=access_token, ip=ip_addr, @@ -261,8 +257,7 @@ class Auth(object): except KeyError: raise MissingClientTokenError() - @defer.inlineCallbacks - def _get_appservice_user_id(self, request): + async def _get_appservice_user_id(self, request): app_service = self.store.get_app_service_by_token( self.get_access_token_from_request(request) ) @@ -283,14 +278,13 @@ class Auth(object): if not app_service.is_interested_in_user(user_id): raise AuthError(403, "Application service cannot masquerade as this user.") - if not (yield self.store.get_user_by_id(user_id)): + if not (await self.store.get_user_by_id(user_id)): raise AuthError(403, "Application service has not registered this user") return user_id, app_service - @defer.inlineCallbacks - def get_user_by_access_token( + async def get_user_by_access_token( self, token: str, rights: str = "access", allow_expired: bool = False, - ): + ) -> dict: """ Validate access token and get user_id from it Args: @@ -300,7 +294,7 @@ class Auth(object): allow_expired: If False, raises an InvalidClientTokenError if the token is expired Returns: - Deferred[dict]: dict that includes: + dict that includes: `user` (UserID) `is_guest` (bool) `token_id` (int|None): access token id. May be None if guest @@ -314,7 +308,7 @@ class Auth(object): if rights == "access": # first look in the database - r = yield self._look_up_user_by_access_token(token) + r = await self._look_up_user_by_access_token(token) if r: valid_until_ms = r["valid_until_ms"] if ( @@ -352,7 +346,7 @@ class Auth(object): # It would of course be much easier to store guest access # tokens in the database as well, but that would break existing # guest tokens. - stored_user = yield self.store.get_user_by_id(user_id) + stored_user = await self.store.get_user_by_id(user_id) if not stored_user: raise InvalidClientTokenError("Unknown user_id %s" % user_id) if not stored_user["is_guest"]: @@ -482,9 +476,8 @@ class Auth(object): now = self.hs.get_clock().time_msec() return now < expiry - @defer.inlineCallbacks - def _look_up_user_by_access_token(self, token): - ret = yield self.store.get_user_by_access_token(token) + async def _look_up_user_by_access_token(self, token): + ret = await self.store.get_user_by_access_token(token) if not ret: return None @@ -507,7 +500,7 @@ class Auth(object): logger.warning("Unrecognised appservice access token.") raise InvalidClientTokenError() request.authenticated_entity = service.sender - return defer.succeed(service) + return service async def is_server_admin(self, user: UserID) -> bool: """ Check if the given user is a local server admin. @@ -522,7 +515,7 @@ class Auth(object): def compute_auth_events( self, event, current_state_ids: StateMap[str], for_verification: bool = False, - ): + ) -> List[str]: """Given an event and current state return the list of event IDs used to auth an event. @@ -530,11 +523,11 @@ class Auth(object): should be added to the event's `auth_events`. Returns: - defer.Deferred(list[str]): List of event IDs. + List of event IDs. """ if event.type == EventTypes.Create: - return defer.succeed([]) + return [] # Currently we ignore the `for_verification` flag even though there are # some situations where we can drop particular auth events when adding @@ -553,7 +546,7 @@ class Auth(object): if auth_ev_id: auth_ids.append(auth_ev_id) - return defer.succeed(auth_ids) + return auth_ids async def check_can_change_room_list(self, room_id: str, user: UserID): """Determine whether the user is allowed to edit the room's entry in the @@ -636,10 +629,9 @@ class Auth(object): return query_params[0].decode("ascii") - @defer.inlineCallbacks - def check_user_in_room_or_world_readable( + async def check_user_in_room_or_world_readable( self, room_id: str, user_id: str, allow_departed_users: bool = False - ): + ) -> Tuple[str, Optional[str]]: """Checks that the user is or was in the room or the room is world readable. If it isn't then an exception is raised. @@ -650,10 +642,9 @@ class Auth(object): members but have now departed Returns: - Deferred[tuple[str, str|None]]: Resolves to the current membership of - the user in the room and the membership event ID of the user. If - the user is not in the room and never has been, then - `(Membership.JOIN, None)` is returned. + Resolves to the current membership of the user in the room and the + membership event ID of the user. If the user is not in the room and + never has been, then `(Membership.JOIN, None)` is returned. """ try: @@ -662,15 +653,13 @@ class Auth(object): # * The user is a non-guest user, and was ever in the room # * The user is a guest user, and has joined the room # else it will throw. - member_event = yield self.check_user_in_room( + member_event = await self.check_user_in_room( room_id, user_id, allow_departed_users=allow_departed_users ) return member_event.membership, member_event.event_id except AuthError: - visibility = yield defer.ensureDeferred( - self.state.get_current_state( - room_id, EventTypes.RoomHistoryVisibility, "" - ) + visibility = await self.state.get_current_state( + room_id, EventTypes.RoomHistoryVisibility, "" ) if ( visibility diff --git a/synapse/api/auth_blocking.py b/synapse/api/auth_blocking.py index 5c499b6b4e..49093bf181 100644 --- a/synapse/api/auth_blocking.py +++ b/synapse/api/auth_blocking.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api.constants import LimitBlockingTypes, UserTypes from synapse.api.errors import Codes, ResourceLimitError from synapse.config.server import is_threepid_reserved @@ -36,8 +34,7 @@ class AuthBlocking(object): self._limit_usage_by_mau = hs.config.limit_usage_by_mau self._mau_limits_reserved_threepids = hs.config.mau_limits_reserved_threepids - @defer.inlineCallbacks - def check_auth_blocking(self, user_id=None, threepid=None, user_type=None): + async def check_auth_blocking(self, user_id=None, threepid=None, user_type=None): """Checks if the user should be rejected for some external reason, such as monthly active user limiting or global disable flag @@ -60,7 +57,7 @@ class AuthBlocking(object): if user_id is not None: if user_id == self._server_notices_mxid: return - if (yield self.store.is_support_user(user_id)): + if await self.store.is_support_user(user_id): return if self._hs_disabled: @@ -76,11 +73,11 @@ class AuthBlocking(object): # If the user is already part of the MAU cohort or a trial user if user_id: - timestamp = yield self.store.user_last_seen_monthly_active(user_id) + timestamp = await self.store.user_last_seen_monthly_active(user_id) if timestamp: return - is_trial = yield self.store.is_trial_user(user_id) + is_trial = await self.store.is_trial_user(user_id) if is_trial: return elif threepid: @@ -93,7 +90,7 @@ class AuthBlocking(object): # allow registration. Support users are excluded from MAU checks. return # Else if there is no room in the MAU bucket, bail - current_mau = yield self.store.get_monthly_active_count() + current_mau = await self.store.get_monthly_active_count() if current_mau >= self._max_mau_value: raise ResourceLimitError( 403, diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index f988f62a1e..7393d6cb74 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -21,8 +21,6 @@ import jsonschema from canonicaljson import json from jsonschema import FormatChecker -from twisted.internet import defer - from synapse.api.constants import EventContentFields from synapse.api.errors import SynapseError from synapse.storage.presence import UserPresenceState @@ -137,9 +135,8 @@ class Filtering(object): super(Filtering, self).__init__() self.store = hs.get_datastore() - @defer.inlineCallbacks - def get_user_filter(self, user_localpart, filter_id): - result = yield self.store.get_user_filter(user_localpart, filter_id) + async def get_user_filter(self, user_localpart, filter_id): + result = await self.store.get_user_filter(user_localpart, filter_id) return FilterCollection(result) def add_user_filter(self, user_localpart, user_filter): diff --git a/synapse/events/builder.py b/synapse/events/builder.py index 69b53ca2bc..4e179d49b3 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -106,7 +106,7 @@ class EventBuilder(object): state_ids = await self._state.get_current_state_ids( self.room_id, prev_event_ids ) - auth_ids = await self._auth.compute_auth_events(self, state_ids) + auth_ids = self._auth.compute_auth_events(self, state_ids) format_version = self.room_version.event_format if format_version == EventFormatVersions.V1: diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index b3764dedae..593932adb7 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -2064,7 +2064,7 @@ class FederationHandler(BaseHandler): if not auth_events: prev_state_ids = await context.get_prev_state_ids() - auth_events_ids = await self.auth.compute_auth_events( + auth_events_ids = self.auth.compute_auth_events( event, prev_state_ids, for_verification=True ) auth_events_x = await self.store.get_events(auth_events_ids) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 43901d0934..708533d4d1 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1061,7 +1061,7 @@ class EventCreationHandler(object): raise SynapseError(400, "Cannot redact event from a different room") prev_state_ids = await context.get_prev_state_ids() - auth_events_ids = await self.auth.compute_auth_events( + auth_events_ids = self.auth.compute_auth_events( event, prev_state_ids, for_verification=True ) auth_events = await self.store.get_events(auth_events_ids) diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 8201849951..c2fb757d9a 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -194,12 +194,16 @@ class ModuleApi(object): synapse.api.errors.AuthError: the access token is invalid """ # see if the access token corresponds to a device - user_info = yield self._auth.get_user_by_access_token(access_token) + user_info = yield defer.ensureDeferred( + self._auth.get_user_by_access_token(access_token) + ) device_id = user_info.get("device_id") user_id = user_info["user"].to_string() if device_id: # delete the device, which will also delete its access tokens - yield self._hs.get_device_handler().delete_device(user_id, device_id) + yield defer.ensureDeferred( + self._hs.get_device_handler().delete_device(user_id, device_id) + ) else: # no associated device. Just delete the access token. yield defer.ensureDeferred( diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 04b9d8ac82..e7fcee0e87 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -120,7 +120,7 @@ class BulkPushRuleEvaluator(object): pl_event = await self.store.get_event(pl_event_id) auth_events = {POWER_KEY: pl_event} else: - auth_events_ids = await self.auth.compute_auth_events( + auth_events_ids = self.auth.compute_auth_events( event, prev_state_ids, for_verification=False ) auth_events = await self.store.get_events(auth_events_ids) diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py index 60dd3f6701..a6fdedde63 100644 --- a/synapse/replication/slave/storage/client_ips.py +++ b/synapse/replication/slave/storage/client_ips.py @@ -28,7 +28,7 @@ class SlavedClientIpStore(BaseSlavedStore): name="client_ip_last_seen", keylen=4, max_entries=50000 ) - def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id): + async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id): now = int(self._clock.time_msec()) key = (user_id, access_token, ip) diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py index 5934b1fe8b..b210015173 100644 --- a/synapse/rest/client/v1/directory.py +++ b/synapse/rest/client/v1/directory.py @@ -89,7 +89,7 @@ class ClientDirectoryServer(RestServlet): dir_handler = self.handlers.directory_handler try: - service = await self.auth.get_appservice_by_req(request) + service = self.auth.get_appservice_by_req(request) room_alias = RoomAlias.from_string(room_alias) await dir_handler.delete_appservice_association(service, room_alias) logger.info( diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index a4c079196d..c549c090b3 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -424,7 +424,7 @@ class RegisterRestServlet(RestServlet): appservice = None if self.auth.has_access_token(request): - appservice = await self.auth.get_appservice_by_req(request) + appservice = self.auth.get_appservice_by_req(request) # fork off as soon as possible for ASes which have completely # different registration flows to normal users diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py index 712c8d0264..50d71f5ebc 100644 --- a/synapse/storage/databases/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py @@ -380,8 +380,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): if self.user_ips_max_age: self._clock.looping_call(self._prune_old_user_ips, 5 * 1000) - @defer.inlineCallbacks - def insert_client_ip( + async def insert_client_ip( self, user_id, access_token, ip, user_agent, device_id, now=None ): if not now: @@ -392,7 +391,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): last_seen = self.client_ip_last_seen.get(key) except KeyError: last_seen = None - yield self.populate_monthly_active_users(user_id) + await self.populate_monthly_active_users(user_id) # Rate-limited inserts if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY: return diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 0bfb86bf1f..5d45689c8c 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -62,12 +62,15 @@ class AuthTestCase(unittest.TestCase): # this is overridden for the appservice tests self.store.get_app_service_by_token = Mock(return_value=None) + self.store.insert_client_ip = Mock(return_value=defer.succeed(None)) self.store.is_support_user = Mock(return_value=defer.succeed(False)) @defer.inlineCallbacks def test_get_user_by_req_user_valid_token(self): user_info = {"name": self.test_user, "token_id": "ditto", "device_id": "device"} - self.store.get_user_by_access_token = Mock(return_value=user_info) + self.store.get_user_by_access_token = Mock( + return_value=defer.succeed(user_info) + ) request = Mock(args={}) request.args[b"access_token"] = [self.test_token] @@ -76,23 +79,25 @@ class AuthTestCase(unittest.TestCase): self.assertEquals(requester.user.to_string(), self.test_user) def test_get_user_by_req_user_bad_token(self): - self.store.get_user_by_access_token = Mock(return_value=None) + self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) request = Mock(args={}) request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - d = self.auth.get_user_by_req(request) + d = defer.ensureDeferred(self.auth.get_user_by_req(request)) f = self.failureResultOf(d, InvalidClientTokenError).value self.assertEqual(f.code, 401) self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN") def test_get_user_by_req_user_missing_token(self): user_info = {"name": self.test_user, "token_id": "ditto"} - self.store.get_user_by_access_token = Mock(return_value=user_info) + self.store.get_user_by_access_token = Mock( + return_value=defer.succeed(user_info) + ) request = Mock(args={}) request.requestHeaders.getRawHeaders = mock_getRawHeaders() - d = self.auth.get_user_by_req(request) + d = defer.ensureDeferred(self.auth.get_user_by_req(request)) f = self.failureResultOf(d, MissingClientTokenError).value self.assertEqual(f.code, 401) self.assertEqual(f.errcode, "M_MISSING_TOKEN") @@ -103,7 +108,7 @@ class AuthTestCase(unittest.TestCase): token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None ) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_access_token = Mock(return_value=None) + self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) request = Mock(args={}) request.getClientIP.return_value = "127.0.0.1" @@ -123,7 +128,7 @@ class AuthTestCase(unittest.TestCase): ip_range_whitelist=IPSet(["192.168/16"]), ) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_access_token = Mock(return_value=None) + self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) request = Mock(args={}) request.getClientIP.return_value = "192.168.10.10" @@ -142,25 +147,25 @@ class AuthTestCase(unittest.TestCase): ip_range_whitelist=IPSet(["192.168/16"]), ) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_access_token = Mock(return_value=None) + self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) request = Mock(args={}) request.getClientIP.return_value = "131.111.8.42" request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - d = self.auth.get_user_by_req(request) + d = defer.ensureDeferred(self.auth.get_user_by_req(request)) f = self.failureResultOf(d, InvalidClientTokenError).value self.assertEqual(f.code, 401) self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN") def test_get_user_by_req_appservice_bad_token(self): self.store.get_app_service_by_token = Mock(return_value=None) - self.store.get_user_by_access_token = Mock(return_value=None) + self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) request = Mock(args={}) request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - d = self.auth.get_user_by_req(request) + d = defer.ensureDeferred(self.auth.get_user_by_req(request)) f = self.failureResultOf(d, InvalidClientTokenError).value self.assertEqual(f.code, 401) self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN") @@ -168,11 +173,11 @@ class AuthTestCase(unittest.TestCase): def test_get_user_by_req_appservice_missing_token(self): app_service = Mock(token="foobar", url="a_url", sender=self.test_user) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_access_token = Mock(return_value=None) + self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) request = Mock(args={}) request.requestHeaders.getRawHeaders = mock_getRawHeaders() - d = self.auth.get_user_by_req(request) + d = defer.ensureDeferred(self.auth.get_user_by_req(request)) f = self.failureResultOf(d, MissingClientTokenError).value self.assertEqual(f.code, 401) self.assertEqual(f.errcode, "M_MISSING_TOKEN") @@ -185,7 +190,11 @@ class AuthTestCase(unittest.TestCase): ) app_service.is_interested_in_user = Mock(return_value=True) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_access_token = Mock(return_value=None) + # This just needs to return a truth-y value. + self.store.get_user_by_id = Mock( + return_value=defer.succeed({"is_guest": False}) + ) + self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) request = Mock(args={}) request.getClientIP.return_value = "127.0.0.1" @@ -204,20 +213,22 @@ class AuthTestCase(unittest.TestCase): ) app_service.is_interested_in_user = Mock(return_value=False) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_access_token = Mock(return_value=None) + self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) request = Mock(args={}) request.getClientIP.return_value = "127.0.0.1" request.args[b"access_token"] = [self.test_token] request.args[b"user_id"] = [masquerading_user_id] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - d = self.auth.get_user_by_req(request) + d = defer.ensureDeferred(self.auth.get_user_by_req(request)) self.failureResultOf(d, AuthError) @defer.inlineCallbacks def test_get_user_from_macaroon(self): self.store.get_user_by_access_token = Mock( - return_value={"name": "@baldrick:matrix.org", "device_id": "device"} + return_value=defer.succeed( + {"name": "@baldrick:matrix.org", "device_id": "device"} + ) ) user_id = "@baldrick:matrix.org" @@ -241,8 +252,8 @@ class AuthTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_guest_user_from_macaroon(self): - self.store.get_user_by_id = Mock(return_value={"is_guest": True}) - self.store.get_user_by_access_token = Mock(return_value=None) + self.store.get_user_by_id = Mock(return_value=defer.succeed({"is_guest": True})) + self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None)) user_id = "@baldrick:matrix.org" macaroon = pymacaroons.Macaroon( @@ -282,16 +293,20 @@ class AuthTestCase(unittest.TestCase): def get_user(tok): if token != tok: - return None - return { - "name": USER_ID, - "is_guest": False, - "token_id": 1234, - "device_id": "DEVICE", - } + return defer.succeed(None) + return defer.succeed( + { + "name": USER_ID, + "is_guest": False, + "token_id": 1234, + "device_id": "DEVICE", + } + ) self.store.get_user_by_access_token = get_user - self.store.get_user_by_id = Mock(return_value={"is_guest": False}) + self.store.get_user_by_id = Mock( + return_value=defer.succeed({"is_guest": False}) + ) # check the token works request = Mock(args={}) diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index 4e67503cf0..1fab1d6b69 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -375,8 +375,10 @@ class FilteringTestCase(unittest.TestCase): event = MockEvent(sender="@foo:bar", type="m.profile") events = [event] - user_filter = yield self.filtering.get_user_filter( - user_localpart=user_localpart, filter_id=filter_id + user_filter = yield defer.ensureDeferred( + self.filtering.get_user_filter( + user_localpart=user_localpart, filter_id=filter_id + ) ) results = user_filter.filter_presence(events=events) @@ -396,8 +398,10 @@ class FilteringTestCase(unittest.TestCase): ) events = [event] - user_filter = yield self.filtering.get_user_filter( - user_localpart=user_localpart + "2", filter_id=filter_id + user_filter = yield defer.ensureDeferred( + self.filtering.get_user_filter( + user_localpart=user_localpart + "2", filter_id=filter_id + ) ) results = user_filter.filter_presence(events=events) @@ -412,8 +416,10 @@ class FilteringTestCase(unittest.TestCase): event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar") events = [event] - user_filter = yield self.filtering.get_user_filter( - user_localpart=user_localpart, filter_id=filter_id + user_filter = yield defer.ensureDeferred( + self.filtering.get_user_filter( + user_localpart=user_localpart, filter_id=filter_id + ) ) results = user_filter.filter_room_state(events=events) @@ -430,8 +436,10 @@ class FilteringTestCase(unittest.TestCase): ) events = [event] - user_filter = yield self.filtering.get_user_filter( - user_localpart=user_localpart, filter_id=filter_id + user_filter = yield defer.ensureDeferred( + self.filtering.get_user_filter( + user_localpart=user_localpart, filter_id=filter_id + ) ) results = user_filter.filter_room_state(events) @@ -465,8 +473,10 @@ class FilteringTestCase(unittest.TestCase): self.assertEquals( user_filter_json, ( - yield self.datastore.get_user_filter( - user_localpart=user_localpart, filter_id=0 + yield defer.ensureDeferred( + self.datastore.get_user_filter( + user_localpart=user_localpart, filter_id=0 + ) ) ), ) @@ -479,8 +489,10 @@ class FilteringTestCase(unittest.TestCase): user_localpart=user_localpart, user_filter=user_filter_json ) - filter = yield self.filtering.get_user_filter( - user_localpart=user_localpart, filter_id=filter_id + filter = yield defer.ensureDeferred( + self.filtering.get_user_filter( + user_localpart=user_localpart, filter_id=filter_id + ) ) self.assertEquals(filter.get_filter_json(), user_filter_json) diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 5878f74175..b7d0adb10e 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -126,10 +126,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.room_members = [] - def check_user_in_room(room_id, user_id): + async def check_user_in_room(room_id, user_id): if user_id not in [u.to_string() for u in self.room_members]: raise AuthError(401, "User is not in the room") - return defer.succeed(None) + return None hs.get_auth().check_user_in_room = check_user_in_room diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index f16eef15f7..17d0aae2e9 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -20,6 +20,8 @@ import urllib.parse from mock import Mock +from twisted.internet import defer + import synapse.rest.admin from synapse.api.constants import UserTypes from synapse.api.errors import HttpResponseException, ResourceLimitError @@ -335,7 +337,9 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): store = self.hs.get_datastore() # Set monthly active users to the limit - store.get_monthly_active_count = Mock(return_value=self.hs.config.max_mau_value) + store.get_monthly_active_count = Mock( + return_value=defer.succeed(self.hs.config.max_mau_value) + ) # Check that the blocking of monthly active users is working as expected # The registration of a new user fails due to the limit self.get_failure( @@ -588,7 +592,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Set monthly active users to the limit self.store.get_monthly_active_count = Mock( - return_value=self.hs.config.max_mau_value + return_value=defer.succeed(self.hs.config.max_mau_value) ) # Check that the blocking of monthly active users is working as expected # The registration of a new user fails due to the limit @@ -628,7 +632,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Set monthly active users to the limit self.store.get_monthly_active_count = Mock( - return_value=self.hs.config.max_mau_value + return_value=defer.succeed(self.hs.config.max_mau_value) ) # Check that the blocking of monthly active users is working as expected # The registration of a new user fails due to the limit diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py index 8df58b4a63..ace0a3c08d 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/v1/test_profile.py @@ -70,8 +70,8 @@ class MockHandlerProfileTestCase(unittest.TestCase): profile_handler=self.mock_handler, ) - def _get_user_by_req(request=None, allow_guest=False): - return defer.succeed(synapse.types.create_requester(myid)) + async def _get_user_by_req(request=None, allow_guest=False): + return synapse.types.create_requester(myid) hs.get_auth().get_user_by_req = _get_user_by_req diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 5ccda8b2bd..ef6b775ed2 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -23,8 +23,6 @@ from urllib import parse as urlparse from mock import Mock -from twisted.internet import defer - import synapse.rest.admin from synapse.api.constants import EventContentFields, EventTypes, Membership from synapse.handlers.pagination import PurgeStatus @@ -51,8 +49,8 @@ class RoomBase(unittest.HomeserverTestCase): self.hs.get_federation_handler = Mock(return_value=Mock()) - def _insert_client_ip(*args, **kwargs): - return defer.succeed(None) + async def _insert_client_ip(*args, **kwargs): + return None self.hs.get_datastore().insert_client_ip = _insert_client_ip diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index 18260bb90e..94d2bf2eb1 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -46,7 +46,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): hs.get_handlers().federation_handler = Mock() - def get_user_by_access_token(token=None, allow_guest=False): + async def get_user_by_access_token(token=None, allow_guest=False): return { "user": UserID.from_string(self.auth_user_id), "token_id": 1, @@ -55,8 +55,8 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): hs.get_auth().get_user_by_access_token = get_user_by_access_token - def _insert_client_ip(*args, **kwargs): - return defer.succeed(None) + async def _insert_client_ip(*args, **kwargs): + return None hs.get_datastore().insert_client_ip = _insert_client_ip diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 7f70353b0d..3f88abe3d2 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -258,7 +258,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): self.user_id = "@user_id:test" def test_server_notice_only_sent_once(self): - self.store.get_monthly_active_count = Mock(return_value=1000) + self.store.get_monthly_active_count = Mock(return_value=defer.succeed(1000)) self.store.user_last_seen_monthly_active = Mock( return_value=defer.succeed(1000) diff --git a/tests/unittest.py b/tests/unittest.py index 2152c693f2..d0bba3ddef 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -241,20 +241,16 @@ class HomeserverTestCase(TestCase): if hasattr(self, "user_id"): if self.hijack_auth: - def get_user_by_access_token(token=None, allow_guest=False): - return succeed( - { - "user": UserID.from_string(self.helper.auth_user_id), - "token_id": 1, - "is_guest": False, - } - ) - - def get_user_by_req(request, allow_guest=False, rights="access"): - return succeed( - create_requester( - UserID.from_string(self.helper.auth_user_id), 1, False, None - ) + async def get_user_by_access_token(token=None, allow_guest=False): + return { + "user": UserID.from_string(self.helper.auth_user_id), + "token_id": 1, + "is_guest": False, + } + + async def get_user_by_req(request, allow_guest=False, rights="access"): + return create_requester( + UserID.from_string(self.helper.auth_user_id), 1, False, None ) self.hs.get_auth().get_user_by_req = get_user_by_req -- cgit 1.5.1 From 04faa0bfa960d9f0dc60e9cf4ec270221249b7ca Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 11 Aug 2020 17:21:20 -0400 Subject: Convert tags and metrics databases to async/await (#8062) --- changelog.d/8062.misc | 1 + synapse/storage/databases/main/metrics.py | 20 ++-- synapse/storage/databases/main/tags.py | 103 +++++++++++---------- .../test_resource_limits_server_notices.py | 5 +- 4 files changed, 64 insertions(+), 65 deletions(-) create mode 100644 changelog.d/8062.misc (limited to 'tests/server_notices') diff --git a/changelog.d/8062.misc b/changelog.d/8062.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8062.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py index baa7a5092a..686052bd83 100644 --- a/synapse/storage/databases/main/metrics.py +++ b/synapse/storage/databases/main/metrics.py @@ -15,8 +15,6 @@ import typing from collections import Counter -from twisted.internet import defer - from synapse.metrics import BucketCollector from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore @@ -69,8 +67,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): res = await self.db_pool.runInteraction("read_forward_extremities", fetch) self._current_forward_extremities_amount = Counter([x[0] for x in res]) - @defer.inlineCallbacks - def count_daily_messages(self): + async def count_daily_messages(self): """ Returns an estimate of the number of messages sent in the last day. @@ -88,11 +85,9 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): (count,) = txn.fetchone() return count - ret = yield self.db_pool.runInteraction("count_messages", _count_messages) - return ret + return await self.db_pool.runInteraction("count_messages", _count_messages) - @defer.inlineCallbacks - def count_daily_sent_messages(self): + async def count_daily_sent_messages(self): def _count_messages(txn): # This is good enough as if you have silly characters in your own # hostname then thats your own fault. @@ -109,13 +104,11 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): (count,) = txn.fetchone() return count - ret = yield self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "count_daily_sent_messages", _count_messages ) - return ret - @defer.inlineCallbacks - def count_daily_active_rooms(self): + async def count_daily_active_rooms(self): def _count(txn): sql = """ SELECT COALESCE(COUNT(DISTINCT room_id), 0) FROM events @@ -126,5 +119,4 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): (count,) = txn.fetchone() return count - ret = yield self.db_pool.runInteraction("count_daily_active_rooms", _count) - return ret + return await self.db_pool.runInteraction("count_daily_active_rooms", _count) diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py index eedd2d96c3..e4e0a0c433 100644 --- a/synapse/storage/databases/main/tags.py +++ b/synapse/storage/databases/main/tags.py @@ -15,14 +15,13 @@ # limitations under the License. import logging -from typing import List, Tuple +from typing import Dict, List, Tuple from canonicaljson import json -from twisted.internet import defer - from synapse.storage._base import db_to_json from synapse.storage.databases.main.account_data import AccountDataWorkerStore +from synapse.types import JsonDict from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) @@ -30,30 +29,26 @@ logger = logging.getLogger(__name__) class TagsWorkerStore(AccountDataWorkerStore): @cached() - def get_tags_for_user(self, user_id): + async def get_tags_for_user(self, user_id: str) -> Dict[str, Dict[str, JsonDict]]: """Get all the tags for a user. Args: - user_id(str): The user to get the tags for. + user_id: The user to get the tags for. Returns: - A deferred dict mapping from room_id strings to dicts mapping from - tag strings to tag content. + A mapping from room_id strings to dicts mapping from tag strings to + tag content. """ - deferred = self.db_pool.simple_select_list( + rows = await self.db_pool.simple_select_list( "room_tags", {"user_id": user_id}, ["room_id", "tag", "content"] ) - @deferred.addCallback - def tags_by_room(rows): - tags_by_room = {} - for row in rows: - room_tags = tags_by_room.setdefault(row["room_id"], {}) - room_tags[row["tag"]] = db_to_json(row["content"]) - return tags_by_room - - return deferred + tags_by_room = {} + for row in rows: + room_tags = tags_by_room.setdefault(row["room_id"], {}) + room_tags[row["tag"]] = db_to_json(row["content"]) + return tags_by_room async def get_all_updated_tags( self, instance_name: str, last_id: int, current_id: int, limit: int @@ -127,17 +122,19 @@ class TagsWorkerStore(AccountDataWorkerStore): return results, upto_token, limited - @defer.inlineCallbacks - def get_updated_tags(self, user_id, stream_id): + async def get_updated_tags( + self, user_id: str, stream_id: int + ) -> Dict[str, List[str]]: """Get all the tags for the rooms where the tags have changed since the given version Args: user_id(str): The user to get the tags for. stream_id(int): The earliest update to get for the user. + Returns: - A deferred dict mapping from room_id strings to lists of tag - strings for all the rooms that changed since the stream_id token. + A mapping from room_id strings to lists of tag strings for all the + rooms that changed since the stream_id token. """ def get_updated_tags_txn(txn): @@ -155,47 +152,53 @@ class TagsWorkerStore(AccountDataWorkerStore): if not changed: return {} - room_ids = yield self.db_pool.runInteraction( + room_ids = await self.db_pool.runInteraction( "get_updated_tags", get_updated_tags_txn ) results = {} if room_ids: - tags_by_room = yield self.get_tags_for_user(user_id) + tags_by_room = await self.get_tags_for_user(user_id) for room_id in room_ids: results[room_id] = tags_by_room.get(room_id, {}) return results - def get_tags_for_room(self, user_id, room_id): + async def get_tags_for_room( + self, user_id: str, room_id: str + ) -> Dict[str, JsonDict]: """Get all the tags for the given room + Args: - user_id(str): The user to get tags for - room_id(str): The room to get tags for + user_id: The user to get tags for + room_id: The room to get tags for + Returns: - A deferred list of string tags. + A mapping of tags to tag content. """ - return self.db_pool.simple_select_list( + rows = await self.db_pool.simple_select_list( table="room_tags", keyvalues={"user_id": user_id, "room_id": room_id}, retcols=("tag", "content"), desc="get_tags_for_room", - ).addCallback( - lambda rows: {row["tag"]: db_to_json(row["content"]) for row in rows} ) + return {row["tag"]: db_to_json(row["content"]) for row in rows} class TagsStore(TagsWorkerStore): - @defer.inlineCallbacks - def add_tag_to_room(self, user_id, room_id, tag, content): + async def add_tag_to_room( + self, user_id: str, room_id: str, tag: str, content: JsonDict + ) -> int: """Add a tag to a room for a user. + Args: - user_id(str): The user to add a tag for. - room_id(str): The room to add a tag for. - tag(str): The tag name to add. - content(dict): A json object to associate with the tag. + user_id: The user to add a tag for. + room_id: The room to add a tag for. + tag: The tag name to add. + content: A json object to associate with the tag. + Returns: - A deferred that completes once the tag has been added. + The next account data ID. """ content_json = json.dumps(content) @@ -209,18 +212,17 @@ class TagsStore(TagsWorkerStore): self._update_revision_txn(txn, user_id, room_id, next_id) with self._account_data_id_gen.get_next() as next_id: - yield self.db_pool.runInteraction("add_tag", add_tag_txn, next_id) + await self.db_pool.runInteraction("add_tag", add_tag_txn, next_id) self.get_tags_for_user.invalidate((user_id,)) - result = self._account_data_id_gen.get_current_token() - return result + return self._account_data_id_gen.get_current_token() - @defer.inlineCallbacks - def remove_tag_from_room(self, user_id, room_id, tag): + async def remove_tag_from_room(self, user_id: str, room_id: str, tag: str) -> int: """Remove a tag from a room for a user. + Returns: - A deferred that completes once the tag has been removed + The next account data ID. """ def remove_tag_txn(txn, next_id): @@ -232,21 +234,22 @@ class TagsStore(TagsWorkerStore): self._update_revision_txn(txn, user_id, room_id, next_id) with self._account_data_id_gen.get_next() as next_id: - yield self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id) + await self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id) self.get_tags_for_user.invalidate((user_id,)) - result = self._account_data_id_gen.get_current_token() - return result + return self._account_data_id_gen.get_current_token() - def _update_revision_txn(self, txn, user_id, room_id, next_id): + def _update_revision_txn( + self, txn, user_id: str, room_id: str, next_id: int + ) -> None: """Update the latest revision of the tags for the given user and room. Args: txn: The database cursor - user_id(str): The ID of the user. - room_id(str): The ID of the room. - next_id(int): The the revision to advance to. + user_id: The ID of the user. + room_id: The ID of the room. + next_id: The the revision to advance to. """ txn.call_after( diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 3f88abe3d2..2858d13558 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -27,6 +27,7 @@ from synapse.server_notices.resource_limits_server_notices import ( ) from tests import unittest +from tests.test_utils import make_awaitable from tests.unittest import override_config from tests.utils import default_config @@ -79,7 +80,9 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): return_value=defer.succeed("!something:localhost") ) self._rlsn._store.add_tag_to_room = Mock(return_value=defer.succeed(None)) - self._rlsn._store.get_tags_for_room = Mock(return_value=defer.succeed({})) + self._rlsn._store.get_tags_for_room = Mock( + side_effect=lambda user_id, room_id: make_awaitable({}) + ) @override_config({"hs_disabled": True}) def test_maybe_send_server_notice_disabled_hs(self): -- cgit 1.5.1 From f40645e60b9cab69c953094848be61c0989a91cb Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 18 Aug 2020 16:20:49 -0400 Subject: Convert events worker database to async/await. (#8071) --- changelog.d/8071.misc | 1 + synapse/event_auth.py | 2 +- synapse/handlers/federation.py | 16 +-- synapse/handlers/message.py | 6 +- synapse/handlers/room_member.py | 2 +- synapse/spam_checker_api/__init__.py | 2 +- synapse/state/__init__.py | 2 +- synapse/storage/databases/main/event_federation.py | 30 +++-- synapse/storage/databases/main/events_worker.py | 132 ++++++++++++--------- synapse/storage/databases/main/stream.py | 1 - .../test_resource_limits_server_notices.py | 6 +- tests/storage/test_appservice.py | 3 +- 12 files changed, 106 insertions(+), 97 deletions(-) create mode 100644 changelog.d/8071.misc (limited to 'tests/server_notices') diff --git a/changelog.d/8071.misc b/changelog.d/8071.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8071.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/event_auth.py b/synapse/event_auth.py index c0981eee62..8c907ad596 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -47,7 +47,7 @@ def check( Args: room_version_obj: the version of the room event: the event being checked. - auth_events (dict: event-key -> event): the existing room state. + auth_events: the existing room state. Raises: AuthError if the checks fail diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 593932adb7..5b270228e7 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1777,9 +1777,7 @@ class FederationHandler(BaseHandler): """Returns the state at the event. i.e. not including said event. """ - event = await self.store.get_event( - event_id, allow_none=False, check_room_id=room_id - ) + event = await self.store.get_event(event_id, check_room_id=room_id) state_groups = await self.state_store.get_state_groups(room_id, [event_id]) @@ -1805,9 +1803,7 @@ class FederationHandler(BaseHandler): async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]: """Returns the state at the event. i.e. not including said event. """ - event = await self.store.get_event( - event_id, allow_none=False, check_room_id=room_id - ) + event = await self.store.get_event(event_id, check_room_id=room_id) state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id]) @@ -2155,9 +2151,9 @@ class FederationHandler(BaseHandler): auth_types = auth_types_for_event(event) current_state_ids = [e for k, e in current_state_ids.items() if k in auth_types] - current_auth_events = await self.store.get_events(current_state_ids) + auth_events_map = await self.store.get_events(current_state_ids) current_auth_events = { - (e.type, e.state_key): e for e in current_auth_events.values() + (e.type, e.state_key): e for e in auth_events_map.values() } try: @@ -2173,9 +2169,7 @@ class FederationHandler(BaseHandler): if not in_room: raise AuthError(403, "Host not in room.") - event = await self.store.get_event( - event_id, allow_none=False, check_room_id=room_id - ) + event = await self.store.get_event(event_id, check_room_id=room_id) # Just go through and process each event in `remote_auth_chain`. We # don't want to fall into the trap of `missing` being wrong. diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 532fc30681..b999d91d1a 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -960,7 +960,7 @@ class EventCreationHandler(object): allow_none=True, ) - is_admin_redaction = ( + is_admin_redaction = bool( original_event and event.sender != original_event.sender ) @@ -1080,8 +1080,8 @@ class EventCreationHandler(object): auth_events_ids = self.auth.compute_auth_events( event, prev_state_ids, for_verification=True ) - auth_events = await self.store.get_events(auth_events_ids) - auth_events = {(e.type, e.state_key): e for e in auth_events.values()} + auth_events_map = await self.store.get_events(auth_events_ids) + auth_events = {(e.type, e.state_key): e for e in auth_events_map.values()} room_version = await self.store.get_room_version_id(event.room_id) room_version_obj = KNOWN_ROOM_VERSIONS[room_version] diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 31705cdbdb..aa1ccde211 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -716,7 +716,7 @@ class RoomMemberHandler(object): guest_access = await self.store.get_event(guest_access_id) - return ( + return bool( guest_access and guest_access.content and "guest_access" in guest_access.content diff --git a/synapse/spam_checker_api/__init__.py b/synapse/spam_checker_api/__init__.py index 9b78924d96..4d9b13ac04 100644 --- a/synapse/spam_checker_api/__init__.py +++ b/synapse/spam_checker_api/__init__.py @@ -51,5 +51,5 @@ class SpamCheckerApi(object): state_ids = yield self._store.get_filtered_current_state_ids( room_id=room_id, state_filter=StateFilter.from_types(types) ) - state = yield self._store.get_events(state_ids.values()) + state = yield defer.ensureDeferred(self._store.get_events(state_ids.values())) return state.values() diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index a1d3884667..dba8d91eef 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -641,7 +641,7 @@ class StateResolutionStore(object): allow_rejected (bool): If True return rejected events. Returns: - Deferred[dict[str, FrozenEvent]]: Dict from event_id to event. + Awaitable[dict[str, FrozenEvent]]: Dict from event_id to event. """ return self.store.get_events( diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 431bd76693..4826be630c 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -30,7 +30,7 @@ logger = logging.getLogger(__name__) class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore): - def get_auth_chain(self, event_ids, include_given=False): + async def get_auth_chain(self, event_ids, include_given=False): """Get auth events for given event_ids. The events *must* be state events. Args: @@ -40,9 +40,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas Returns: list of events """ - return self.get_auth_chain_ids( + event_ids = await self.get_auth_chain_ids( event_ids, include_given=include_given - ).addCallback(self.get_events_as_list) + ) + return await self.get_events_as_list(event_ids) def get_auth_chain_ids( self, @@ -459,7 +460,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas "get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn ) - def get_backfill_events(self, room_id, event_list, limit): + async def get_backfill_events(self, room_id, event_list, limit): """Get a list of Events for a given topic that occurred before (and including) the events in event_list. Return a list of max size `limit` @@ -469,17 +470,15 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas event_list (list) limit (int) """ - return ( - self.db_pool.runInteraction( - "get_backfill_events", - self._get_backfill_events, - room_id, - event_list, - limit, - ) - .addCallback(self.get_events_as_list) - .addCallback(lambda l: sorted(l, key=lambda e: -e.depth)) + event_ids = await self.db_pool.runInteraction( + "get_backfill_events", + self._get_backfill_events, + room_id, + event_list, + limit, ) + events = await self.get_events_as_list(event_ids) + return sorted(events, key=lambda e: -e.depth) def _get_backfill_events(self, txn, room_id, event_list, limit): logger.debug("_get_backfill_events: %s, %r, %s", room_id, event_list, limit) @@ -540,8 +539,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas latest_events, limit, ) - events = await self.get_events_as_list(ids) - return events + return await self.get_events_as_list(ids) def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit): diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 8c63a0dc4d..e3a154a527 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -19,9 +19,10 @@ import itertools import logging import threading from collections import namedtuple -from typing import List, Optional, Tuple +from typing import Dict, Iterable, List, Optional, Tuple, overload from constantly import NamedConstant, Names +from typing_extensions import Literal from twisted.internet import defer @@ -32,7 +33,7 @@ from synapse.api.room_versions import ( EventFormatVersions, RoomVersions, ) -from synapse.events import make_event_from_dict +from synapse.events import EventBase, make_event_from_dict from synapse.events.utils import prune_event from synapse.logging.context import PreserveLoggingContext, current_context from synapse.metrics.background_process_metrics import run_as_background_process @@ -42,8 +43,8 @@ from synapse.replication.tcp.streams.events import EventsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool from synapse.storage.util.id_generators import StreamIdGenerator -from synapse.types import get_domain_from_id -from synapse.util.caches.descriptors import Cache, cachedInlineCallbacks +from synapse.types import Collection, get_domain_from_id +from synapse.util.caches.descriptors import Cache, cached from synapse.util.iterutils import batch_iter from synapse.util.metrics import Measure @@ -137,8 +138,33 @@ class EventsWorkerStore(SQLBaseStore): desc="get_received_ts", ) - @defer.inlineCallbacks - def get_event( + # Inform mypy that if allow_none is False (the default) then get_event + # always returns an EventBase. + @overload + async def get_event( + self, + event_id: str, + redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, + get_prev_content: bool = False, + allow_rejected: bool = False, + allow_none: Literal[False] = False, + check_room_id: Optional[str] = None, + ) -> EventBase: + ... + + @overload + async def get_event( + self, + event_id: str, + redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, + get_prev_content: bool = False, + allow_rejected: bool = False, + allow_none: Literal[True] = False, + check_room_id: Optional[str] = None, + ) -> Optional[EventBase]: + ... + + async def get_event( self, event_id: str, redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, @@ -146,7 +172,7 @@ class EventsWorkerStore(SQLBaseStore): allow_rejected: bool = False, allow_none: bool = False, check_room_id: Optional[str] = None, - ): + ) -> Optional[EventBase]: """Get an event from the database by event_id. Args: @@ -171,12 +197,12 @@ class EventsWorkerStore(SQLBaseStore): If there is a mismatch, behave as per allow_none. Returns: - Deferred[EventBase|None] + The event, or None if the event was not found. """ if not isinstance(event_id, str): raise TypeError("Invalid event event_id %r" % (event_id,)) - events = yield self.get_events_as_list( + events = await self.get_events_as_list( [event_id], redact_behaviour=redact_behaviour, get_prev_content=get_prev_content, @@ -194,14 +220,13 @@ class EventsWorkerStore(SQLBaseStore): return event - @defer.inlineCallbacks - def get_events( + async def get_events( self, - event_ids: List[str], + event_ids: Iterable[str], redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, get_prev_content: bool = False, allow_rejected: bool = False, - ): + ) -> Dict[str, EventBase]: """Get events from the database Args: @@ -220,9 +245,9 @@ class EventsWorkerStore(SQLBaseStore): omits rejeted events from the response. Returns: - Deferred : Dict from event_id to event. + A mapping from event_id to event. """ - events = yield self.get_events_as_list( + events = await self.get_events_as_list( event_ids, redact_behaviour=redact_behaviour, get_prev_content=get_prev_content, @@ -231,14 +256,13 @@ class EventsWorkerStore(SQLBaseStore): return {e.event_id: e for e in events} - @defer.inlineCallbacks - def get_events_as_list( + async def get_events_as_list( self, - event_ids: List[str], + event_ids: Collection[str], redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, get_prev_content: bool = False, allow_rejected: bool = False, - ): + ) -> List[EventBase]: """Get events from the database and return in a list in the same order as given by `event_ids` arg. @@ -259,8 +283,8 @@ class EventsWorkerStore(SQLBaseStore): omits rejected events from the response. Returns: - Deferred[list[EventBase]]: List of events fetched from the database. The - events are in the same order as `event_ids` arg. + List of events fetched from the database. The events are in the same + order as `event_ids` arg. Note that the returned list may be smaller than the list of event IDs if not all events could be fetched. @@ -270,7 +294,7 @@ class EventsWorkerStore(SQLBaseStore): return [] # there may be duplicates so we cast the list to a set - event_entry_map = yield self._get_events_from_cache_or_db( + event_entry_map = await self._get_events_from_cache_or_db( set(event_ids), allow_rejected=allow_rejected ) @@ -305,7 +329,7 @@ class EventsWorkerStore(SQLBaseStore): continue redacted_event_id = entry.event.redacts - event_map = yield self._get_events_from_cache_or_db([redacted_event_id]) + event_map = await self._get_events_from_cache_or_db([redacted_event_id]) original_event_entry = event_map.get(redacted_event_id) if not original_event_entry: # we don't have the redacted event (or it was rejected). @@ -371,7 +395,7 @@ class EventsWorkerStore(SQLBaseStore): if get_prev_content: if "replaces_state" in event.unsigned: - prev = yield self.get_event( + prev = await self.get_event( event.unsigned["replaces_state"], get_prev_content=False, allow_none=True, @@ -383,8 +407,7 @@ class EventsWorkerStore(SQLBaseStore): return events - @defer.inlineCallbacks - def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False): + async def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False): """Fetch a bunch of events from the cache or the database. If events are pulled from the database, they will be cached for future lookups. @@ -399,7 +422,7 @@ class EventsWorkerStore(SQLBaseStore): rejected events are omitted from the response. Returns: - Deferred[Dict[str, _EventCacheEntry]]: + Dict[str, _EventCacheEntry]: map from event id to result """ event_entry_map = self._get_events_from_cache( @@ -417,7 +440,7 @@ class EventsWorkerStore(SQLBaseStore): # the events have been redacted, and if so pulling the redaction event out # of the database to check it. # - missing_events = yield self._get_events_from_db( + missing_events = await self._get_events_from_db( missing_events_ids, allow_rejected=allow_rejected ) @@ -525,8 +548,7 @@ class EventsWorkerStore(SQLBaseStore): with PreserveLoggingContext(): self.hs.get_reactor().callFromThread(fire, event_list, e) - @defer.inlineCallbacks - def _get_events_from_db(self, event_ids, allow_rejected=False): + async def _get_events_from_db(self, event_ids, allow_rejected=False): """Fetch a bunch of events from the database. Returned events will be added to the cache for future lookups. @@ -540,7 +562,7 @@ class EventsWorkerStore(SQLBaseStore): rejected events are omitted from the response. Returns: - Deferred[Dict[str, _EventCacheEntry]]: + Dict[str, _EventCacheEntry]: map from event id to result. May return extra events which weren't asked for. """ @@ -548,7 +570,7 @@ class EventsWorkerStore(SQLBaseStore): events_to_fetch = event_ids while events_to_fetch: - row_map = yield self._enqueue_events(events_to_fetch) + row_map = await self._enqueue_events(events_to_fetch) # we need to recursively fetch any redactions of those events redaction_ids = set() @@ -650,8 +672,7 @@ class EventsWorkerStore(SQLBaseStore): return result_map - @defer.inlineCallbacks - def _enqueue_events(self, events): + async def _enqueue_events(self, events): """Fetches events from the database using the _event_fetch_list. This allows batch and bulk fetching of events - it allows us to fetch events without having to create a new transaction for each request for events. @@ -660,7 +681,7 @@ class EventsWorkerStore(SQLBaseStore): events (Iterable[str]): events to be fetched. Returns: - Deferred[Dict[str, Dict]]: map from event id to row data from the database. + Dict[str, Dict]: map from event id to row data from the database. May contain events that weren't requested. """ @@ -683,7 +704,7 @@ class EventsWorkerStore(SQLBaseStore): logger.debug("Loading %d events: %s", len(events), events) with PreserveLoggingContext(): - row_map = yield events_d + row_map = await events_d logger.debug("Loaded %d events (%d rows)", len(events), len(row_map)) return row_map @@ -842,33 +863,29 @@ class EventsWorkerStore(SQLBaseStore): # no valid redaction found for this event return None - @defer.inlineCallbacks - def have_events_in_timeline(self, event_ids): + async def have_events_in_timeline(self, event_ids): """Given a list of event ids, check if we have already processed and stored them as non outliers. """ - rows = yield defer.ensureDeferred( - self.db_pool.simple_select_many_batch( - table="events", - retcols=("event_id",), - column="event_id", - iterable=list(event_ids), - keyvalues={"outlier": False}, - desc="have_events_in_timeline", - ) + rows = await self.db_pool.simple_select_many_batch( + table="events", + retcols=("event_id",), + column="event_id", + iterable=list(event_ids), + keyvalues={"outlier": False}, + desc="have_events_in_timeline", ) return {r["event_id"] for r in rows} - @defer.inlineCallbacks - def have_seen_events(self, event_ids): + async def have_seen_events(self, event_ids): """Given a list of event ids, check if we have already processed them. Args: event_ids (iterable[str]): Returns: - Deferred[set[str]]: The events we have already seen. + set[str]: The events we have already seen. """ results = set() @@ -884,7 +901,7 @@ class EventsWorkerStore(SQLBaseStore): # break the input up into chunks of 100 input_iterator = iter(event_ids) for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []): - yield self.db_pool.runInteraction( + await self.db_pool.runInteraction( "have_seen_events", have_seen_events_txn, chunk ) return results @@ -914,8 +931,7 @@ class EventsWorkerStore(SQLBaseStore): room_id, ) - @defer.inlineCallbacks - def get_room_complexity(self, room_id): + async def get_room_complexity(self, room_id): """ Get a rough approximation of the complexity of the room. This is used by remote servers to decide whether they wish to join the room or not. @@ -926,9 +942,9 @@ class EventsWorkerStore(SQLBaseStore): room_id (str) Returns: - Deferred[dict[str:int]] of complexity version to complexity. + dict[str:int] of complexity version to complexity. """ - state_events = yield self.get_current_state_event_counts(room_id) + state_events = await self.get_current_state_event_counts(room_id) # Call this one "v1", so we can introduce new ones as we want to develop # it. @@ -1165,9 +1181,9 @@ class EventsWorkerStore(SQLBaseStore): to_2, so_2 = await self.get_event_ordering(event_id2) return (to_1, so_1) > (to_2, so_2) - @cachedInlineCallbacks(max_entries=5000) - def get_event_ordering(self, event_id): - res = yield self.db_pool.simple_select_one( + @cached(max_entries=5000) + async def get_event_ordering(self, event_id): + res = await self.db_pool.simple_select_one( table="events", retcols=["topological_ordering", "stream_ordering"], keyvalues={"event_id": event_id}, diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 4377bddb8c..497f607703 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -379,7 +379,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): limit: int = 0, order: str = "DESC", ) -> Tuple[List[EventBase], str]: - """Get new room events in stream ordering since `from_key`. Args: diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 2858d13558..23db821fb7 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -104,7 +104,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} ) self._rlsn._store.get_events = Mock( - return_value=defer.succeed({"123": mock_event}) + return_value=make_awaitable({"123": mock_event}) ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) # Would be better to check the content, but once == remove blocking event @@ -122,7 +122,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} ) self._rlsn._store.get_events = Mock( - return_value=defer.succeed({"123": mock_event}) + return_value=make_awaitable({"123": mock_event}) ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) @@ -217,7 +217,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} ) self._rlsn._store.get_events = Mock( - return_value=defer.succeed({"123": mock_event}) + return_value=make_awaitable({"123": mock_event}) ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index a425e66f37..17fbde284a 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -31,6 +31,7 @@ from synapse.storage.databases.main.appservice import ( ) from tests import unittest +from tests.test_utils import make_awaitable from tests.utils import setup_test_homeserver @@ -357,7 +358,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): other_events = [Mock(event_id="e5"), Mock(event_id="e6")] # we aren't testing store._base stuff here, so mock this out - self.store.get_events_as_list = Mock(return_value=defer.succeed(events)) + self.store.get_events_as_list = Mock(return_value=make_awaitable(events)) yield self._insert_txn(self.as_list[1]["id"], 9, other_events) yield self._insert_txn(service.id, 10, events) -- cgit 1.5.1 From b49a5b9307fbbc9032e28e532e9036db07555d3d Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 27 Aug 2020 17:24:37 -0400 Subject: Convert stats and related calls to async/await (#8192) --- changelog.d/8192.misc | 1 + .../storage/databases/main/monthly_active_users.py | 15 ++-- synapse/storage/databases/main/stats.py | 82 +++++++++++----------- tests/handlers/test_auth.py | 21 +++--- tests/handlers/test_register.py | 12 ++-- tests/rest/admin/test_user.py | 9 ++- .../test_resource_limits_server_notices.py | 10 +-- tests/storage/test_client_ips.py | 5 +- 8 files changed, 78 insertions(+), 77 deletions(-) create mode 100644 changelog.d/8192.misc (limited to 'tests/server_notices') diff --git a/changelog.d/8192.misc b/changelog.d/8192.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8192.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py index fe30552c08..1d793d3deb 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import List +from typing import Dict, List from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool, make_in_list_sql_clause @@ -33,11 +33,11 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): self.hs = hs @cached(num_args=0) - def get_monthly_active_count(self): + async def get_monthly_active_count(self) -> int: """Generates current count of monthly active users Returns: - Defered[int]: Number of current monthly active users + Number of current monthly active users """ def _count_users(txn): @@ -46,10 +46,10 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): (count,) = txn.fetchone() return count - return self.db_pool.runInteraction("count_users", _count_users) + return await self.db_pool.runInteraction("count_users", _count_users) @cached(num_args=0) - def get_monthly_active_count_by_service(self): + async def get_monthly_active_count_by_service(self) -> Dict[str, int]: """Generates current count of monthly active users broken down by service. A service is typically an appservice but also includes native matrix users. Since the `monthly_active_users` table is populated from the `user_ips` table @@ -57,8 +57,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): method to return anything other than native matrix users. Returns: - Deferred[dict]: dict that includes a mapping between app_service_id - and the number of occurrences. + A mapping between app_service_id and the number of occurrences. """ @@ -74,7 +73,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): result = txn.fetchall() return dict(result) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "count_users_by_service", _count_users_by_service ) diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py index 7af2608ca4..9b9bc304a8 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py @@ -15,8 +15,9 @@ # limitations under the License. import logging +from collections import Counter from itertools import chain -from typing import Any, Dict, Tuple +from typing import Any, Dict, List, Optional, Tuple from twisted.internet.defer import DeferredLock @@ -251,21 +252,23 @@ class StatsStore(StateDeltasStore): desc="update_room_state", ) - def get_statistics_for_subject(self, stats_type, stats_id, start, size=100): + async def get_statistics_for_subject( + self, stats_type: str, stats_id: str, start: str, size: int = 100 + ) -> List[dict]: """ Get statistics for a given subject. Args: - stats_type (str): The type of subject - stats_id (str): The ID of the subject (e.g. room_id or user_id) - start (int): Pagination start. Number of entries, not timestamp. - size (int): How many entries to return. + stats_type: The type of subject + stats_id: The ID of the subject (e.g. room_id or user_id) + start: Pagination start. Number of entries, not timestamp. + size: How many entries to return. Returns: - Deferred[list[dict]], where the dict has the keys of + A list of dicts, where the dict has the keys of ABSOLUTE_STATS_FIELDS[stats_type], and "bucket_size" and "end_ts". """ - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_statistics_for_subject", self._get_statistics_for_subject_txn, stats_type, @@ -319,18 +322,17 @@ class StatsStore(StateDeltasStore): allow_none=True, ) - def bulk_update_stats_delta(self, ts, updates, stream_id): + async def bulk_update_stats_delta( + self, ts: int, updates: Dict[str, Dict[str, Dict[str, Counter]]], stream_id: int + ) -> None: """Bulk update stats tables for a given stream_id and updates the stats incremental position. Args: - ts (int): Current timestamp in ms - updates(dict[str, dict[str, dict[str, Counter]]]): The updates to - commit as a mapping stats_type -> stats_id -> field -> delta. - stream_id (int): Current position. - - Returns: - Deferred + ts: Current timestamp in ms + updates: The updates to commit as a mapping of + stats_type -> stats_id -> field -> delta. + stream_id: Current position. """ def _bulk_update_stats_delta_txn(txn): @@ -355,38 +357,37 @@ class StatsStore(StateDeltasStore): updatevalues={"stream_id": stream_id}, ) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "bulk_update_stats_delta", _bulk_update_stats_delta_txn ) - def update_stats_delta( + async def update_stats_delta( self, - ts, - stats_type, - stats_id, - fields, - complete_with_stream_id, - absolute_field_overrides=None, - ): + ts: int, + stats_type: str, + stats_id: str, + fields: Dict[str, int], + complete_with_stream_id: Optional[int], + absolute_field_overrides: Optional[Dict[str, int]] = None, + ) -> None: """ Updates the statistics for a subject, with a delta (difference/relative change). Args: - ts (int): timestamp of the change - stats_type (str): "room" or "user" – the kind of subject - stats_id (str): the subject's ID (room ID or user ID) - fields (dict[str, int]): Deltas of stats values. - complete_with_stream_id (int, optional): + ts: timestamp of the change + stats_type: "room" or "user" – the kind of subject + stats_id: the subject's ID (room ID or user ID) + fields: Deltas of stats values. + complete_with_stream_id: If supplied, converts an incomplete row into a complete row, with the supplied stream_id marked as the stream_id where the row was completed. - absolute_field_overrides (dict[str, int]): Current stats values - (i.e. not deltas) of absolute fields. - Does not work with per-slice fields. + absolute_field_overrides: Current stats values (i.e. not deltas) of + absolute fields. Does not work with per-slice fields. """ - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "update_stats_delta", self._update_stats_delta_txn, ts, @@ -646,19 +647,20 @@ class StatsStore(StateDeltasStore): txn, into_table, all_dest_keyvalues, src_row ) - def get_changes_room_total_events_and_bytes(self, min_pos, max_pos): + async def get_changes_room_total_events_and_bytes( + self, min_pos: int, max_pos: int + ) -> Dict[str, Dict[str, int]]: """Fetches the counts of events in the given range of stream IDs. Args: - min_pos (int) - max_pos (int) + min_pos + max_pos Returns: - Deferred[dict[str, dict[str, int]]]: Mapping of room ID to field - changes. + Mapping of room ID to field changes. """ - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "stats_incremental_total_events_and_bytes", self.get_changes_room_total_events_and_bytes_txn, min_pos, diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index c01b04e1dc..4b3fb018b1 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -24,6 +24,7 @@ from synapse.api.errors import ResourceLimitError from synapse.handlers.auth import AuthHandler from tests import unittest +from tests.test_utils import make_awaitable from tests.utils import setup_test_homeserver @@ -142,7 +143,7 @@ class AuthTestCase(unittest.TestCase): def test_mau_limits_exceeded_large(self): self.auth_blocking._limit_usage_by_mau = True self.hs.get_datastore().get_monthly_active_count = Mock( - return_value=defer.succeed(self.large_number_of_users) + side_effect=lambda: make_awaitable(self.large_number_of_users) ) with self.assertRaises(ResourceLimitError): @@ -153,7 +154,7 @@ class AuthTestCase(unittest.TestCase): ) self.hs.get_datastore().get_monthly_active_count = Mock( - return_value=defer.succeed(self.large_number_of_users) + side_effect=lambda: make_awaitable(self.large_number_of_users) ) with self.assertRaises(ResourceLimitError): yield defer.ensureDeferred( @@ -168,7 +169,7 @@ class AuthTestCase(unittest.TestCase): # If not in monthly active cohort self.hs.get_datastore().get_monthly_active_count = Mock( - return_value=defer.succeed(self.auth_blocking._max_mau_value) + side_effect=lambda: make_awaitable(self.auth_blocking._max_mau_value) ) with self.assertRaises(ResourceLimitError): yield defer.ensureDeferred( @@ -178,7 +179,7 @@ class AuthTestCase(unittest.TestCase): ) self.hs.get_datastore().get_monthly_active_count = Mock( - return_value=defer.succeed(self.auth_blocking._max_mau_value) + side_effect=lambda: make_awaitable(self.auth_blocking._max_mau_value) ) with self.assertRaises(ResourceLimitError): yield defer.ensureDeferred( @@ -188,10 +189,10 @@ class AuthTestCase(unittest.TestCase): ) # If in monthly active cohort self.hs.get_datastore().user_last_seen_monthly_active = Mock( - return_value=defer.succeed(self.hs.get_clock().time_msec()) + side_effect=lambda user_id: make_awaitable(self.hs.get_clock().time_msec()) ) self.hs.get_datastore().get_monthly_active_count = Mock( - return_value=defer.succeed(self.auth_blocking._max_mau_value) + side_effect=lambda: make_awaitable(self.auth_blocking._max_mau_value) ) yield defer.ensureDeferred( self.auth_handler.get_access_token_for_user_id( @@ -199,10 +200,10 @@ class AuthTestCase(unittest.TestCase): ) ) self.hs.get_datastore().user_last_seen_monthly_active = Mock( - return_value=defer.succeed(self.hs.get_clock().time_msec()) + side_effect=lambda user_id: make_awaitable(self.hs.get_clock().time_msec()) ) self.hs.get_datastore().get_monthly_active_count = Mock( - return_value=defer.succeed(self.auth_blocking._max_mau_value) + side_effect=lambda: make_awaitable(self.auth_blocking._max_mau_value) ) yield defer.ensureDeferred( self.auth_handler.validate_short_term_login_token_and_get_user_id( @@ -215,7 +216,7 @@ class AuthTestCase(unittest.TestCase): self.auth_blocking._limit_usage_by_mau = True self.hs.get_datastore().get_monthly_active_count = Mock( - return_value=defer.succeed(self.small_number_of_users) + side_effect=lambda: make_awaitable(self.small_number_of_users) ) # Ensure does not raise exception yield defer.ensureDeferred( @@ -225,7 +226,7 @@ class AuthTestCase(unittest.TestCase): ) self.hs.get_datastore().get_monthly_active_count = Mock( - return_value=defer.succeed(self.small_number_of_users) + side_effect=lambda: make_awaitable(self.small_number_of_users) ) yield defer.ensureDeferred( self.auth_handler.validate_short_term_login_token_and_get_user_id( diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 5c92d0e8c9..eddf5e2498 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -15,8 +15,6 @@ from mock import Mock -from twisted.internet import defer - from synapse.api.auth import Auth from synapse.api.constants import UserTypes from synapse.api.errors import Codes, ResourceLimitError, SynapseError @@ -102,7 +100,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): def test_get_or_create_user_mau_not_blocked(self): self.hs.config.limit_usage_by_mau = True self.store.count_monthly_users = Mock( - return_value=defer.succeed(self.hs.config.max_mau_value - 1) + side_effect=lambda: make_awaitable(self.hs.config.max_mau_value - 1) ) # Ensure does not throw exception self.get_success(self.get_or_create_user(self.requester, "c", "User")) @@ -110,7 +108,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): def test_get_or_create_user_mau_blocked(self): self.hs.config.limit_usage_by_mau = True self.store.get_monthly_active_count = Mock( - return_value=defer.succeed(self.lots_of_users) + side_effect=lambda: make_awaitable(self.lots_of_users) ) self.get_failure( self.get_or_create_user(self.requester, "b", "display_name"), @@ -118,7 +116,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): ) self.store.get_monthly_active_count = Mock( - return_value=defer.succeed(self.hs.config.max_mau_value) + side_effect=lambda: make_awaitable(self.hs.config.max_mau_value) ) self.get_failure( self.get_or_create_user(self.requester, "b", "display_name"), @@ -128,14 +126,14 @@ class RegistrationTestCase(unittest.HomeserverTestCase): def test_register_mau_blocked(self): self.hs.config.limit_usage_by_mau = True self.store.get_monthly_active_count = Mock( - return_value=defer.succeed(self.lots_of_users) + side_effect=lambda: make_awaitable(self.lots_of_users) ) self.get_failure( self.handler.register_user(localpart="local_part"), ResourceLimitError ) self.store.get_monthly_active_count = Mock( - return_value=defer.succeed(self.hs.config.max_mau_value) + side_effect=lambda: make_awaitable(self.hs.config.max_mau_value) ) self.get_failure( self.handler.register_user(localpart="local_part"), ResourceLimitError diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 17d0aae2e9..160c630235 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -20,8 +20,6 @@ import urllib.parse from mock import Mock -from twisted.internet import defer - import synapse.rest.admin from synapse.api.constants import UserTypes from synapse.api.errors import HttpResponseException, ResourceLimitError @@ -29,6 +27,7 @@ from synapse.rest.client.v1 import login from synapse.rest.client.v2_alpha import sync from tests import unittest +from tests.test_utils import make_awaitable from tests.unittest import override_config @@ -338,7 +337,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): # Set monthly active users to the limit store.get_monthly_active_count = Mock( - return_value=defer.succeed(self.hs.config.max_mau_value) + side_effect=lambda: make_awaitable(self.hs.config.max_mau_value) ) # Check that the blocking of monthly active users is working as expected # The registration of a new user fails due to the limit @@ -592,7 +591,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Set monthly active users to the limit self.store.get_monthly_active_count = Mock( - return_value=defer.succeed(self.hs.config.max_mau_value) + side_effect=lambda: make_awaitable(self.hs.config.max_mau_value) ) # Check that the blocking of monthly active users is working as expected # The registration of a new user fails due to the limit @@ -632,7 +631,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Set monthly active users to the limit self.store.get_monthly_active_count = Mock( - return_value=defer.succeed(self.hs.config.max_mau_value) + side_effect=lambda: make_awaitable(self.hs.config.max_mau_value) ) # Check that the blocking of monthly active users is working as expected # The registration of a new user fails due to the limit diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 23db821fb7..973338ea71 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -67,7 +67,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): raise Exception("Failed to find reference to ResourceLimitsServerNotices") self._rlsn._store.user_last_seen_monthly_active = Mock( - return_value=defer.succeed(1000) + side_effect=lambda user_id: make_awaitable(1000) ) self._rlsn._server_notices_manager.send_notice = Mock( return_value=defer.succeed(Mock()) @@ -158,7 +158,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): """ self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None)) self._rlsn._store.user_last_seen_monthly_active = Mock( - return_value=defer.succeed(None) + side_effect=lambda user_id: make_awaitable(None) ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) @@ -261,10 +261,12 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): self.user_id = "@user_id:test" def test_server_notice_only_sent_once(self): - self.store.get_monthly_active_count = Mock(return_value=defer.succeed(1000)) + self.store.get_monthly_active_count = Mock( + side_effect=lambda: make_awaitable(1000) + ) self.store.user_last_seen_monthly_active = Mock( - return_value=defer.succeed(1000) + side_effect=lambda user_id: make_awaitable(1000) ) # Call the function multiple times to ensure we only send the notice once diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index 224ea6fd79..370c247e16 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -16,13 +16,12 @@ from mock import Mock -from twisted.internet import defer - import synapse.rest.admin from synapse.http.site import XForwardedForRequest from synapse.rest.client.v1 import login from tests import unittest +from tests.test_utils import make_awaitable from tests.unittest import override_config @@ -155,7 +154,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): user_id = "@user:server" self.store.get_monthly_active_count = Mock( - return_value=defer.succeed(lots_of_users) + side_effect=lambda: make_awaitable(lots_of_users) ) self.get_success( self.store.insert_client_ip( -- cgit 1.5.1