From c978f6c4515a631f289aedb1844d8579b9334aaa Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 30 Jul 2020 08:01:33 -0400 Subject: Convert federation client to async/await. (#7975) --- tests/test_federation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'tests/test_federation.py') diff --git a/tests/test_federation.py b/tests/test_federation.py index 87a16d7d7a..c2f12c2741 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -95,7 +95,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): prev_events that said event references. """ - def post_json(destination, path, data, headers=None, timeout=0): + async def post_json(destination, path, data, headers=None, timeout=0): # If it asks us for new missing events, give them NOTHING if path.startswith("/_matrix/federation/v1/get_missing_events/"): return {"events": []} -- cgit 1.5.1 From fbe930dad28c81a5e563ddc8683f65b8279aad52 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 12 Aug 2020 12:14:34 -0400 Subject: Convert the roommember database to async/await. (#8070) --- changelog.d/8070.misc | 1 + synapse/storage/_base.py | 1 - synapse/storage/databases/main/push_rule.py | 75 -------- synapse/storage/databases/main/roommember.py | 263 ++++++++++----------------- tests/test_federation.py | 18 +- 5 files changed, 116 insertions(+), 242 deletions(-) create mode 100644 changelog.d/8070.misc (limited to 'tests/test_federation.py') diff --git a/changelog.d/8070.misc b/changelog.d/8070.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8070.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index ca800df831..6814bf5fcf 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -58,7 +58,6 @@ class SQLBaseStore(metaclass=ABCMeta): """ for host in {get_domain_from_id(u) for u in members_changed}: self._attempt_to_invalidate_cache("is_host_joined", (room_id, host)) - self._attempt_to_invalidate_cache("was_host_joined", (room_id, host)) self._attempt_to_invalidate_cache("get_users_in_room", (room_id,)) self._attempt_to_invalidate_cache("get_room_summary", (room_id,)) diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 19a0211a03..6562db5c2b 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -256,81 +256,6 @@ class PushRulesWorkerStore( ): yield self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule) - @defer.inlineCallbacks - def bulk_get_push_rules_for_room(self, event, context): - state_group = context.state_group - if not state_group: - # If state_group is None it means it has yet to be assigned a - # state group, i.e. we need to make sure that calls with a state_group - # of None don't hit previous cached calls with a None state_group. - # To do this we set the state_group to a new object as object() != object() - state_group = object() - - current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) - result = yield self._bulk_get_push_rules_for_room( - event.room_id, state_group, current_state_ids, event=event - ) - return result - - @cachedInlineCallbacks(num_args=2, cache_context=True) - def _bulk_get_push_rules_for_room( - self, room_id, state_group, current_state_ids, cache_context, event=None - ): - # We don't use `state_group`, its there so that we can cache based - # on it. However, its important that its never None, since two current_state's - # with a state_group of None are likely to be different. - # See bulk_get_push_rules_for_room for how we work around this. - assert state_group is not None - - # We also will want to generate notifs for other people in the room so - # their unread countss are correct in the event stream, but to avoid - # generating them for bot / AS users etc, we only do so for people who've - # sent a read receipt into the room. - - users_in_room = yield self._get_joined_users_from_context( - room_id, - state_group, - current_state_ids, - on_invalidate=cache_context.invalidate, - event=event, - ) - - # We ignore app service users for now. This is so that we don't fill - # up the `get_if_users_have_pushers` cache with AS entries that we - # know don't have pushers, nor even read receipts. - local_users_in_room = { - u - for u in users_in_room - if self.hs.is_mine_id(u) - and not self.get_if_app_services_interested_in_user(u) - } - - # users in the room who have pushers need to get push rules run because - # that's how their pushers work - if_users_with_pushers = yield self.get_if_users_have_pushers( - local_users_in_room, on_invalidate=cache_context.invalidate - ) - user_ids = { - uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher - } - - users_with_receipts = yield self.get_users_with_read_receipts_in_room( - room_id, on_invalidate=cache_context.invalidate - ) - - # any users with pushers must be ours: they have pushers - for uid in users_with_receipts: - if uid in local_users_in_room: - user_ids.add(uid) - - rules_by_user = yield self.bulk_get_push_rules( - user_ids, on_invalidate=cache_context.invalidate - ) - - rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None} - - return rules_by_user - @cachedList( cached_method_name="get_push_rules_enabled_for_user", list_name="user_ids", diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 7c5be251bd..b2fcfc9bfe 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -15,11 +15,13 @@ # limitations under the License. import logging -from typing import Iterable, List, Set +from typing import TYPE_CHECKING, Awaitable, Iterable, List, Optional, Set from twisted.internet import defer from synapse.api.constants import EventTypes, Membership +from synapse.events import EventBase +from synapse.events.snapshot import EventContext from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import ( @@ -40,9 +42,12 @@ from synapse.storage.roommember import ( from synapse.types import Collection, get_domain_from_id from synapse.util.async_helpers import Linearizer from synapse.util.caches import intern_string -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList +from synapse.util.caches.descriptors import _CacheContext, cached, cachedList from synapse.util.metrics import Measure +if TYPE_CHECKING: + from synapse.state import _StateCacheEntry + logger = logging.getLogger(__name__) @@ -150,12 +155,12 @@ class RoomMemberWorkerStore(EventsWorkerStore): ) @cached(max_entries=100000, iterable=True) - def get_users_in_room(self, room_id): + def get_users_in_room(self, room_id: str): return self.db_pool.runInteraction( "get_users_in_room", self.get_users_in_room_txn, room_id ) - def get_users_in_room_txn(self, txn, room_id): + def get_users_in_room_txn(self, txn, room_id: str) -> List[str]: # If we can assume current_state_events.membership is up to date # then we can avoid a join, which is a Very Good Thing given how # frequently this function gets called. @@ -178,11 +183,11 @@ class RoomMemberWorkerStore(EventsWorkerStore): return [r[0] for r in txn] @cached(max_entries=100000) - def get_room_summary(self, room_id): + def get_room_summary(self, room_id: str): """ Get the details of a room roughly suitable for use by the room summary extension to /sync. Useful when lazy loading room members. Args: - room_id (str): The room ID to query + room_id: The room ID to query Returns: Deferred[dict[str, MemberSummary]: dict of membership states, pointing to a MemberSummary named tuple. @@ -261,78 +266,59 @@ class RoomMemberWorkerStore(EventsWorkerStore): return self.db_pool.runInteraction("get_room_summary", _get_room_summary_txn) - def _get_user_counts_in_room_txn(self, txn, room_id): - """ - Get the user count in a room by membership. - - Args: - room_id (str) - membership (Membership) - - Returns: - Deferred[int] - """ - sql = """ - SELECT m.membership, count(*) FROM room_memberships as m - INNER JOIN current_state_events as c USING(event_id) - WHERE c.type = 'm.room.member' AND c.room_id = ? - GROUP BY m.membership - """ - - txn.execute(sql, (room_id,)) - return {row[0]: row[1] for row in txn} - @cached() - def get_invited_rooms_for_local_user(self, user_id): - """ Get all the rooms the *local* user is invited to + def get_invited_rooms_for_local_user(self, user_id: str) -> Awaitable[RoomsForUser]: + """Get all the rooms the *local* user is invited to. Args: - user_id (str): The user ID. + user_id: The user ID. + Returns: - A deferred list of RoomsForUser. + A awaitable list of RoomsForUser. """ return self.get_rooms_for_local_user_where_membership_is( user_id, [Membership.INVITE] ) - @defer.inlineCallbacks - def get_invite_for_local_user_in_room(self, user_id, room_id): - """Gets the invite for the given *local* user and room + async def get_invite_for_local_user_in_room( + self, user_id: str, room_id: str + ) -> Optional[RoomsForUser]: + """Gets the invite for the given *local* user and room. Args: - user_id (str) - room_id (str) + user_id: The user ID to find the invite of. + room_id: The room to user was invited to. Returns: - Deferred: Resolves to either a RoomsForUser or None if no invite was - found. + Either a RoomsForUser or None if no invite was found. """ - invites = yield self.get_invited_rooms_for_local_user(user_id) + invites = await self.get_invited_rooms_for_local_user(user_id) for invite in invites: if invite.room_id == room_id: return invite return None - @defer.inlineCallbacks - def get_rooms_for_local_user_where_membership_is(self, user_id, membership_list): - """ Get all the rooms for this *local* user where the membership for this user + async def get_rooms_for_local_user_where_membership_is( + self, user_id: str, membership_list: List[str] + ) -> Optional[List[RoomsForUser]]: + """Get all the rooms for this *local* user where the membership for this user matches one in the membership list. Filters out forgotten rooms. Args: - user_id (str): The user ID. - membership_list (list): A list of synapse.api.constants.Membership - values which the user must be in. + user_id: The user ID. + membership_list: A list of synapse.api.constants.Membership + values which the user must be in. Returns: - Deferred[list[RoomsForUser]] + The RoomsForUser that the user matches the membership types. """ if not membership_list: - return defer.succeed(None) + return None - rooms = yield self.db_pool.runInteraction( + rooms = await self.db_pool.runInteraction( "get_rooms_for_local_user_where_membership_is", self._get_rooms_for_local_user_where_membership_is_txn, user_id, @@ -340,12 +326,12 @@ class RoomMemberWorkerStore(EventsWorkerStore): ) # Now we filter out forgotten rooms - forgotten_rooms = yield self.get_forgotten_rooms_for_user(user_id) + forgotten_rooms = await self.get_forgotten_rooms_for_user(user_id) return [room for room in rooms if room.room_id not in forgotten_rooms] def _get_rooms_for_local_user_where_membership_is_txn( - self, txn, user_id, membership_list - ): + self, txn, user_id: str, membership_list: List[str] + ) -> List[RoomsForUser]: # Paranoia check. if not self.hs.is_mine_id(user_id): raise Exception( @@ -374,14 +360,14 @@ class RoomMemberWorkerStore(EventsWorkerStore): return results @cached(max_entries=500000, iterable=True) - def get_rooms_for_user_with_stream_ordering(self, user_id): + def get_rooms_for_user_with_stream_ordering(self, user_id: str): """Returns a set of room_ids the user is currently joined to. If a remote user only returns rooms this server is currently participating in. Args: - user_id (str) + user_id Returns: Deferred[frozenset[GetRoomsForUserWithStreamOrdering]]: Returns @@ -394,7 +380,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): user_id, ) - def _get_rooms_for_user_with_stream_ordering_txn(self, txn, user_id): + def _get_rooms_for_user_with_stream_ordering_txn(self, txn, user_id: str): # We use `current_state_events` here and not `local_current_membership` # as a) this gets called with remote users and b) this only gets called # for rooms the server is participating in. @@ -458,37 +444,39 @@ class RoomMemberWorkerStore(EventsWorkerStore): _get_users_server_still_shares_room_with_txn, ) - @defer.inlineCallbacks - def get_rooms_for_user(self, user_id, on_invalidate=None): + async def get_rooms_for_user(self, user_id: str, on_invalidate=None): """Returns a set of room_ids the user is currently joined to. If a remote user only returns rooms this server is currently participating in. """ - rooms = yield self.get_rooms_for_user_with_stream_ordering( + rooms = await self.get_rooms_for_user_with_stream_ordering( user_id, on_invalidate=on_invalidate ) return frozenset(r.room_id for r in rooms) - @cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True) - def get_users_who_share_room_with_user(self, user_id, cache_context): + @cached(max_entries=500000, cache_context=True, iterable=True) + async def get_users_who_share_room_with_user( + self, user_id: str, cache_context: _CacheContext + ) -> Set[str]: """Returns the set of users who share a room with `user_id` """ - room_ids = yield self.get_rooms_for_user( + room_ids = await self.get_rooms_for_user( user_id, on_invalidate=cache_context.invalidate ) user_who_share_room = set() for room_id in room_ids: - user_ids = yield self.get_users_in_room( + user_ids = await self.get_users_in_room( room_id, on_invalidate=cache_context.invalidate ) user_who_share_room.update(user_ids) return user_who_share_room - @defer.inlineCallbacks - def get_joined_users_from_context(self, event, context): + async def get_joined_users_from_context( + self, event: EventBase, context: EventContext + ): state_group = context.state_group if not state_group: # If state_group is None it means it has yet to be assigned a @@ -497,14 +485,12 @@ class RoomMemberWorkerStore(EventsWorkerStore): # To do this we set the state_group to a new object as object() != object() state_group = object() - current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) - result = yield self._get_joined_users_from_context( + current_state_ids = await context.get_current_state_ids() + return await self._get_joined_users_from_context( event.room_id, state_group, current_state_ids, event=event, context=context ) - return result - @defer.inlineCallbacks - def get_joined_users_from_state(self, room_id, state_entry): + async def get_joined_users_from_state(self, room_id, state_entry): state_group = state_entry.state_group if not state_group: # If state_group is None it means it has yet to be assigned a @@ -514,16 +500,12 @@ class RoomMemberWorkerStore(EventsWorkerStore): state_group = object() with Measure(self._clock, "get_joined_users_from_state"): - return ( - yield self._get_joined_users_from_context( - room_id, state_group, state_entry.state, context=state_entry - ) + return await self._get_joined_users_from_context( + room_id, state_group, state_entry.state, context=state_entry ) - @cachedInlineCallbacks( - num_args=2, cache_context=True, iterable=True, max_entries=100000 - ) - def _get_joined_users_from_context( + @cached(num_args=2, cache_context=True, iterable=True, max_entries=100000) + async def _get_joined_users_from_context( self, room_id, state_group, @@ -535,7 +517,6 @@ class RoomMemberWorkerStore(EventsWorkerStore): # We don't use `state_group`, it's there so that we can cache based # on it. However, it's important that it's never None, since two current_states # with a state_group of None are likely to be different. - # See bulk_get_push_rules_for_room for how we work around this. assert state_group is not None users_in_room = {} @@ -588,7 +569,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): missing_member_event_ids.append(event_id) if missing_member_event_ids: - event_to_memberships = yield self._get_joined_profiles_from_event_ids( + event_to_memberships = await self._get_joined_profiles_from_event_ids( missing_member_event_ids ) users_in_room.update((row for row in event_to_memberships.values() if row)) @@ -612,12 +593,12 @@ class RoomMemberWorkerStore(EventsWorkerStore): list_name="event_ids", inlineCallbacks=True, ) - def _get_joined_profiles_from_event_ids(self, event_ids): + def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]): """For given set of member event_ids check if they point to a join event and if so return the associated user and profile info. Args: - event_ids (Iterable[str]): The member event IDs to lookup + event_ids: The member event IDs to lookup Returns: Deferred[dict[str, Tuple[str, ProfileInfo]|None]]: Map from event ID @@ -644,8 +625,8 @@ class RoomMemberWorkerStore(EventsWorkerStore): for row in rows } - @cachedInlineCallbacks(max_entries=10000) - def is_host_joined(self, room_id, host): + @cached(max_entries=10000) + async def is_host_joined(self, room_id: str, host: str) -> bool: if "%" in host or "_" in host: raise Exception("Invalid host name") @@ -664,7 +645,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): # the returned user actually has the correct domain. like_clause = "%:" + host - rows = yield self.db_pool.execute( + rows = await self.db_pool.execute( "is_host_joined", None, sql, room_id, like_clause ) @@ -678,50 +659,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): return True - @cachedInlineCallbacks() - def was_host_joined(self, room_id, host): - """Check whether the server is or ever was in the room. - - Args: - room_id (str) - host (str) - - Returns: - Deferred: Resolves to True if the host is/was in the room, otherwise - False. - """ - if "%" in host or "_" in host: - raise Exception("Invalid host name") - - sql = """ - SELECT user_id FROM room_memberships - WHERE room_id = ? - AND user_id LIKE ? - AND membership = 'join' - LIMIT 1 - """ - - # We do need to be careful to ensure that host doesn't have any wild cards - # in it, but we checked above for known ones and we'll check below that - # the returned user actually has the correct domain. - like_clause = "%:" + host - - rows = yield self.db_pool.execute( - "was_host_joined", None, sql, room_id, like_clause - ) - - if not rows: - return False - - user_id = rows[0][0] - if get_domain_from_id(user_id) != host: - # This can only happen if the host name has something funky in it - raise Exception("Invalid host name") - - return True - - @defer.inlineCallbacks - def get_joined_hosts(self, room_id, state_entry): + async def get_joined_hosts(self, room_id: str, state_entry): state_group = state_entry.state_group if not state_group: # If state_group is None it means it has yet to be assigned a @@ -731,32 +669,28 @@ class RoomMemberWorkerStore(EventsWorkerStore): state_group = object() with Measure(self._clock, "get_joined_hosts"): - return ( - yield self._get_joined_hosts( - room_id, state_group, state_entry.state, state_entry=state_entry - ) + return await self._get_joined_hosts( + room_id, state_group, state_entry.state, state_entry=state_entry ) - @cachedInlineCallbacks(num_args=2, max_entries=10000, iterable=True) - # @defer.inlineCallbacks - def _get_joined_hosts(self, room_id, state_group, current_state_ids, state_entry): + @cached(num_args=2, max_entries=10000, iterable=True) + async def _get_joined_hosts( + self, room_id, state_group, current_state_ids, state_entry + ): # We don't use `state_group`, its there so that we can cache based # on it. However, its important that its never None, since two current_state's # with a state_group of None are likely to be different. - # See bulk_get_push_rules_for_room for how we work around this. assert state_group is not None - cache = yield self._get_joined_hosts_cache(room_id) - joined_hosts = yield cache.get_destinations(state_entry) - - return joined_hosts + cache = await self._get_joined_hosts_cache(room_id) + return await cache.get_destinations(state_entry) @cached(max_entries=10000) - def _get_joined_hosts_cache(self, room_id): + def _get_joined_hosts_cache(self, room_id: str) -> "_JoinedHostsCache": return _JoinedHostsCache(self, room_id) - @cachedInlineCallbacks(num_args=2) - def did_forget(self, user_id, room_id): + @cached(num_args=2) + async def did_forget(self, user_id: str, room_id: str) -> bool: """Returns whether user_id has elected to discard history for room_id. Returns False if they have since re-joined.""" @@ -778,15 +712,15 @@ class RoomMemberWorkerStore(EventsWorkerStore): rows = txn.fetchall() return rows[0][0] - count = yield self.db_pool.runInteraction("did_forget_membership", f) + count = await self.db_pool.runInteraction("did_forget_membership", f) return count == 0 @cached() - def get_forgotten_rooms_for_user(self, user_id): + def get_forgotten_rooms_for_user(self, user_id: str): """Gets all rooms the user has forgotten. Args: - user_id (str) + user_id Returns: Deferred[set[str]] @@ -819,18 +753,17 @@ class RoomMemberWorkerStore(EventsWorkerStore): "get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn ) - @defer.inlineCallbacks - def get_rooms_user_has_been_in(self, user_id): + async def get_rooms_user_has_been_in(self, user_id: str) -> Set[str]: """Get all rooms that the user has ever been in. Args: - user_id (str) + user_id: The user ID to get the rooms of. Returns: - Deferred[set[str]]: Set of room IDs. + Set of room IDs. """ - room_ids = yield self.db_pool.simple_select_onecol( + room_ids = await self.db_pool.simple_select_onecol( table="room_memberships", keyvalues={"membership": Membership.JOIN, "user_id": user_id}, retcol="room_id", @@ -905,8 +838,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore): where_clause="forgotten = 1", ) - @defer.inlineCallbacks - def _background_add_membership_profile(self, progress, batch_size): + async def _background_add_membership_profile(self, progress, batch_size): target_min_stream_id = progress.get( "target_min_stream_id_inclusive", self._min_stream_order_on_start ) @@ -971,19 +903,18 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore): return len(rows) - result = yield self.db_pool.runInteraction( + result = await self.db_pool.runInteraction( _MEMBERSHIP_PROFILE_UPDATE_NAME, add_membership_profile_txn ) if not result: - yield self.db_pool.updates._end_background_update( + await self.db_pool.updates._end_background_update( _MEMBERSHIP_PROFILE_UPDATE_NAME ) return result - @defer.inlineCallbacks - def _background_current_state_membership(self, progress, batch_size): + async def _background_current_state_membership(self, progress, batch_size): """Update the new membership column on current_state_events. This works by iterating over all rooms in alphebetical order. @@ -1029,14 +960,14 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore): # string, which will compare before all room IDs correctly. last_processed_room = progress.get("last_processed_room", "") - row_count, finished = yield self.db_pool.runInteraction( + row_count, finished = await self.db_pool.runInteraction( "_background_current_state_membership_update", _background_current_state_membership_txn, last_processed_room, ) if finished: - yield self.db_pool.updates._end_background_update( + await self.db_pool.updates._end_background_update( _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME ) @@ -1047,7 +978,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): def __init__(self, database: DatabasePool, db_conn, hs): super(RoomMemberStore, self).__init__(database, db_conn, hs) - def forget(self, user_id, room_id): + def forget(self, user_id: str, room_id: str): """Indicate that user_id wishes to discard history for room_id.""" def f(txn): @@ -1088,17 +1019,19 @@ class _JoinedHostsCache(object): self._len = 0 - @defer.inlineCallbacks - def get_destinations(self, state_entry): + async def get_destinations(self, state_entry: "_StateCacheEntry") -> Set[str]: """Get set of destinations for a state entry Args: - state_entry(synapse.state._StateCacheEntry) + state_entry + + Returns: + The destinations as a set. """ if state_entry.state_group == self.state_group: return frozenset(self.hosts_to_joined_users) - with (yield self.linearizer.queue(())): + with (await self.linearizer.queue(())): if state_entry.state_group == self.state_group: pass elif state_entry.prev_group == self.state_group: @@ -1110,7 +1043,7 @@ class _JoinedHostsCache(object): user_id = state_key known_joins = self.hosts_to_joined_users.setdefault(host, set()) - event = yield self.store.get_event(event_id) + event = await self.store.get_event(event_id) if event.membership == Membership.JOIN: known_joins.add(user_id) else: @@ -1119,7 +1052,7 @@ class _JoinedHostsCache(object): if not known_joins: self.hosts_to_joined_users.pop(host, None) else: - joined_users = yield self.store.get_joined_users_from_state( + joined_users = await self.store.get_joined_users_from_state( self.room_id, state_entry ) diff --git a/tests/test_federation.py b/tests/test_federation.py index c2f12c2741..f2fa42bfb9 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -1,3 +1,18 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from mock import Mock from twisted.internet.defer import ensureDeferred, maybeDeferred, succeed @@ -10,6 +25,7 @@ from synapse.util.retryutils import NotRetryingDestination from tests import unittest from tests.server import ThreadedMemoryReactorClock, setup_test_homeserver +from tests.test_utils import make_awaitable class MessageAcceptTests(unittest.HomeserverTestCase): @@ -173,7 +189,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): # Register a mock on the store so that the incoming update doesn't fail because # we don't share a room with the user. store = self.homeserver.get_datastore() - store.get_rooms_for_user = Mock(return_value=succeed(["!someroom:test"])) + store.get_rooms_for_user = Mock(return_value=make_awaitable(["!someroom:test"])) # Manually inject a fake device list update. We need this update to include at # least one prev_id so that the user's device list will need to be retried. -- cgit 1.5.1 From ac77cdb64e50c9fdfc00cccbc7b96f42057aa741 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 14 Aug 2020 12:37:59 -0400 Subject: Add a shadow-banned flag to users. (#8092) --- changelog.d/8092.feature | 1 + synapse/api/auth.py | 12 ++++++++++- synapse/handlers/register.py | 8 +++++++ synapse/replication/http/register.py | 4 ++++ synapse/storage/databases/main/registration.py | 9 +++++++- .../main/schema/delta/58/09shadow_ban.sql | 18 ++++++++++++++++ synapse/types.py | 25 +++++++++++++++++++--- tests/storage/test_cleanup_extrems.py | 4 ++-- tests/storage/test_event_metrics.py | 2 +- tests/storage/test_roommember.py | 2 +- tests/test_federation.py | 2 +- tests/unittest.py | 8 +++++-- 12 files changed, 83 insertions(+), 12 deletions(-) create mode 100644 changelog.d/8092.feature create mode 100644 synapse/storage/databases/main/schema/delta/58/09shadow_ban.sql (limited to 'tests/test_federation.py') diff --git a/changelog.d/8092.feature b/changelog.d/8092.feature new file mode 100644 index 0000000000..813e6d0903 --- /dev/null +++ b/changelog.d/8092.feature @@ -0,0 +1 @@ +Add support for shadow-banning users (ignoring any message send requests). diff --git a/synapse/api/auth.py b/synapse/api/auth.py index d8190f92ab..7aab764360 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -213,6 +213,7 @@ class Auth(object): user = user_info["user"] token_id = user_info["token_id"] is_guest = user_info["is_guest"] + shadow_banned = user_info["shadow_banned"] # Deny the request if the user account has expired. if self._account_validity.enabled and not allow_expired: @@ -252,7 +253,12 @@ class Auth(object): opentracing.set_tag("device_id", device_id) return synapse.types.create_requester( - user, token_id, is_guest, device_id, app_service=app_service + user, + token_id, + is_guest, + shadow_banned, + device_id, + app_service=app_service, ) except KeyError: raise MissingClientTokenError() @@ -297,6 +303,7 @@ class Auth(object): dict that includes: `user` (UserID) `is_guest` (bool) + `shadow_banned` (bool) `token_id` (int|None): access token id. May be None if guest `device_id` (str|None): device corresponding to access token Raises: @@ -356,6 +363,7 @@ class Auth(object): ret = { "user": user, "is_guest": True, + "shadow_banned": False, "token_id": None, # all guests get the same device id "device_id": GUEST_DEVICE_ID, @@ -365,6 +373,7 @@ class Auth(object): ret = { "user": user, "is_guest": False, + "shadow_banned": False, "token_id": None, "device_id": None, } @@ -488,6 +497,7 @@ class Auth(object): "user": UserID.from_string(ret.get("name")), "token_id": ret.get("token_id", None), "is_guest": False, + "shadow_banned": ret.get("shadow_banned"), "device_id": ret.get("device_id"), "valid_until_ms": ret.get("valid_until_ms"), } diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index c94209ab3d..999bc6efb5 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -142,6 +142,7 @@ class RegistrationHandler(BaseHandler): address=None, bind_emails=[], by_admin=False, + shadow_banned=False, ): """Registers a new client on the server. @@ -159,6 +160,7 @@ class RegistrationHandler(BaseHandler): bind_emails (List[str]): list of emails to bind to this account. by_admin (bool): True if this registration is being made via the admin api, otherwise False. + shadow_banned (bool): Shadow-ban the created user. Returns: str: user_id Raises: @@ -194,6 +196,7 @@ class RegistrationHandler(BaseHandler): admin=admin, user_type=user_type, address=address, + shadow_banned=shadow_banned, ) if self.hs.config.user_directory_search_all_users: @@ -224,6 +227,7 @@ class RegistrationHandler(BaseHandler): make_guest=make_guest, create_profile_with_displayname=default_display_name, address=address, + shadow_banned=shadow_banned, ) # Successfully registered @@ -529,6 +533,7 @@ class RegistrationHandler(BaseHandler): admin=False, user_type=None, address=None, + shadow_banned=False, ): """Register user in the datastore. @@ -546,6 +551,7 @@ class RegistrationHandler(BaseHandler): user_type (str|None): type of user. One of the values from api.constants.UserTypes, or None for a normal user. address (str|None): the IP address used to perform the registration. + shadow_banned (bool): Whether to shadow-ban the user Returns: Awaitable @@ -561,6 +567,7 @@ class RegistrationHandler(BaseHandler): admin=admin, user_type=user_type, address=address, + shadow_banned=shadow_banned, ) else: return self.store.register_user( @@ -572,6 +579,7 @@ class RegistrationHandler(BaseHandler): create_profile_with_displayname=create_profile_with_displayname, admin=admin, user_type=user_type, + shadow_banned=shadow_banned, ) async def register_device( diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py index ce9420aa69..a02b27474d 100644 --- a/synapse/replication/http/register.py +++ b/synapse/replication/http/register.py @@ -44,6 +44,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): admin, user_type, address, + shadow_banned, ): """ Args: @@ -60,6 +61,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): user_type (str|None): type of user. One of the values from api.constants.UserTypes, or None for a normal user. address (str|None): the IP address used to perform the regitration. + shadow_banned (bool): Whether to shadow-ban the user """ return { "password_hash": password_hash, @@ -70,6 +72,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): "admin": admin, "user_type": user_type, "address": address, + "shadow_banned": shadow_banned, } async def _handle_request(self, request, user_id): @@ -87,6 +90,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): admin=content["admin"], user_type=content["user_type"], address=content["address"], + shadow_banned=content["shadow_banned"], ) return 200, {} diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 7965a52e30..de50fa6e94 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -304,7 +304,7 @@ class RegistrationWorkerStore(SQLBaseStore): def _query_for_auth(self, txn, token): sql = ( - "SELECT users.name, users.is_guest, access_tokens.id as token_id," + "SELECT users.name, users.is_guest, users.shadow_banned, access_tokens.id as token_id," " access_tokens.device_id, access_tokens.valid_until_ms" " FROM users" " INNER JOIN access_tokens on users.name = access_tokens.user_id" @@ -952,6 +952,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): create_profile_with_displayname=None, admin=False, user_type=None, + shadow_banned=False, ): """Attempts to register an account. @@ -968,6 +969,8 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): admin (boolean): is an admin user? user_type (str|None): type of user. One of the values from api.constants.UserTypes, or None for a normal user. + shadow_banned (bool): Whether the user is shadow-banned, + i.e. they may be told their requests succeeded but we ignore them. Raises: StoreError if the user_id could not be registered. @@ -986,6 +989,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): create_profile_with_displayname, admin, user_type, + shadow_banned, ) def _register_user( @@ -999,6 +1003,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): create_profile_with_displayname, admin, user_type, + shadow_banned, ): user_id_obj = UserID.from_string(user_id) @@ -1028,6 +1033,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): "appservice_id": appservice_id, "admin": 1 if admin else 0, "user_type": user_type, + "shadow_banned": shadow_banned, }, ) else: @@ -1042,6 +1048,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): "appservice_id": appservice_id, "admin": 1 if admin else 0, "user_type": user_type, + "shadow_banned": shadow_banned, }, ) diff --git a/synapse/storage/databases/main/schema/delta/58/09shadow_ban.sql b/synapse/storage/databases/main/schema/delta/58/09shadow_ban.sql new file mode 100644 index 0000000000..260b009b48 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/09shadow_ban.sql @@ -0,0 +1,18 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- A shadow-banned user may be told that their requests succeeded when they were +-- actually ignored. +ALTER TABLE users ADD COLUMN shadow_banned BOOLEAN; diff --git a/synapse/types.py b/synapse/types.py index 9e580f4295..bc36cdde30 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -51,7 +51,15 @@ JsonDict = Dict[str, Any] class Requester( namedtuple( - "Requester", ["user", "access_token_id", "is_guest", "device_id", "app_service"] + "Requester", + [ + "user", + "access_token_id", + "is_guest", + "shadow_banned", + "device_id", + "app_service", + ], ) ): """ @@ -62,6 +70,7 @@ class Requester( access_token_id (int|None): *ID* of the access token used for this request, or None if it came via the appservice API or similar is_guest (bool): True if the user making this request is a guest user + shadow_banned (bool): True if the user making this request has been shadow-banned. device_id (str|None): device_id which was set at authentication time app_service (ApplicationService|None): the AS requesting on behalf of the user """ @@ -77,6 +86,7 @@ class Requester( "user_id": self.user.to_string(), "access_token_id": self.access_token_id, "is_guest": self.is_guest, + "shadow_banned": self.shadow_banned, "device_id": self.device_id, "app_server_id": self.app_service.id if self.app_service else None, } @@ -101,13 +111,19 @@ class Requester( user=UserID.from_string(input["user_id"]), access_token_id=input["access_token_id"], is_guest=input["is_guest"], + shadow_banned=input["shadow_banned"], device_id=input["device_id"], app_service=appservice, ) def create_requester( - user_id, access_token_id=None, is_guest=False, device_id=None, app_service=None + user_id, + access_token_id=None, + is_guest=False, + shadow_banned=False, + device_id=None, + app_service=None, ): """ Create a new ``Requester`` object @@ -117,6 +133,7 @@ def create_requester( access_token_id (int|None): *ID* of the access token used for this request, or None if it came via the appservice API or similar is_guest (bool): True if the user making this request is a guest user + shadow_banned (bool): True if the user making this request is shadow-banned. device_id (str|None): device_id which was set at authentication time app_service (ApplicationService|None): the AS requesting on behalf of the user @@ -125,7 +142,9 @@ def create_requester( """ if not isinstance(user_id, UserID): user_id = UserID.from_string(user_id) - return Requester(user_id, access_token_id, is_guest, device_id, app_service) + return Requester( + user_id, access_token_id, is_guest, shadow_banned, device_id, app_service + ) def get_domain_from_id(string): diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py index 3fab5a5248..8e9a650f9f 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py @@ -38,7 +38,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): # Create a test user and room self.user = UserID("alice", "test") - self.requester = Requester(self.user, None, False, None, None) + self.requester = Requester(self.user, None, False, False, None, None) info, _ = self.get_success(self.room_creator.create_room(self.requester, {})) self.room_id = info["room_id"] @@ -260,7 +260,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase): # Create a test user and room self.user = UserID.from_string(self.register_user("user1", "password")) self.token1 = self.login("user1", "password") - self.requester = Requester(self.user, None, False, None, None) + self.requester = Requester(self.user, None, False, False, None, None) info, _ = self.get_success(self.room_creator.create_room(self.requester, {})) self.room_id = info["room_id"] self.event_creator = homeserver.get_event_creation_handler() diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py index a7b85004e5..949846fe33 100644 --- a/tests/storage/test_event_metrics.py +++ b/tests/storage/test_event_metrics.py @@ -27,7 +27,7 @@ class ExtremStatisticsTestCase(HomeserverTestCase): room_creator = self.hs.get_room_creation_handler() user = UserID("alice", "test") - requester = Requester(user, None, False, None, None) + requester = Requester(user, None, False, False, None, None) # Real events, forward extremities events = [(3, 2), (6, 2), (4, 6)] diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 17c9da4838..d98fe8754d 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -187,7 +187,7 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase): # Now let's create a room, which will insert a membership user = UserID("alice", "test") - requester = Requester(user, None, False, None, None) + requester = Requester(user, None, False, False, None, None) self.get_success(self.room_creator.create_room(requester, {})) # Register the background update to run again. diff --git a/tests/test_federation.py b/tests/test_federation.py index f2fa42bfb9..4a4548433f 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -42,7 +42,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): ) user_id = UserID("us", "test") - our_user = Requester(user_id, None, False, None, None) + our_user = Requester(user_id, None, False, False, None, None) room_creator = self.homeserver.get_room_creation_handler() room_deferred = ensureDeferred( room_creator.create_room( diff --git a/tests/unittest.py b/tests/unittest.py index d0bba3ddef..7b80999a74 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -250,7 +250,11 @@ class HomeserverTestCase(TestCase): async def get_user_by_req(request, allow_guest=False, rights="access"): return create_requester( - UserID.from_string(self.helper.auth_user_id), 1, False, None + UserID.from_string(self.helper.auth_user_id), + 1, + False, + False, + None, ) self.hs.get_auth().get_user_by_req = get_user_by_req @@ -540,7 +544,7 @@ class HomeserverTestCase(TestCase): """ event_creator = self.hs.get_event_creation_handler() secrets = self.hs.get_secrets() - requester = Requester(user, None, False, None, None) + requester = Requester(user, None, False, False, None, None) event, context = self.get_success( event_creator.create_event( -- cgit 1.5.1 From 4a739c73b404284253a548f60197e70c6c385645 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 27 Aug 2020 07:08:38 -0400 Subject: Convert simple_update* and simple_select* to async (#8173) --- changelog.d/8173.misc | 1 + synapse/handlers/room.py | 6 +-- synapse/storage/database.py | 29 ++++++------ synapse/storage/databases/main/__init__.py | 8 ++-- synapse/storage/databases/main/directory.py | 6 +-- synapse/storage/databases/main/e2e_room_keys.py | 26 ++++++---- synapse/storage/databases/main/event_federation.py | 4 +- synapse/storage/databases/main/group_server.py | 55 +++++++++++++--------- synapse/storage/databases/main/media_repository.py | 16 ++++--- synapse/storage/databases/main/profile.py | 18 ++++--- synapse/storage/databases/main/receipts.py | 8 ++-- synapse/storage/databases/main/registration.py | 22 +++++---- synapse/storage/databases/main/room.py | 4 +- synapse/storage/databases/main/user_directory.py | 4 +- tests/handlers/test_stats.py | 4 +- tests/storage/test_base.py | 26 ++++++---- tests/storage/test_directory.py | 6 ++- tests/storage/test_main.py | 4 +- tests/test_federation.py | 50 +++++++++----------- 19 files changed, 164 insertions(+), 133 deletions(-) create mode 100644 changelog.d/8173.misc (limited to 'tests/test_federation.py') diff --git a/changelog.d/8173.misc b/changelog.d/8173.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8173.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index e4788ef86b..236a37f777 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -51,7 +51,7 @@ from synapse.types import ( create_requester, ) from synapse.util import stringutils -from synapse.util.async_helpers import Linearizer, maybe_awaitable +from synapse.util.async_helpers import Linearizer from synapse.util.caches.response_cache import ResponseCache from synapse.visibility import filter_events_for_client @@ -1329,9 +1329,7 @@ class RoomShutdownHandler(object): ratelimit=False, ) - aliases_for_room = await maybe_awaitable( - self.store.get_aliases_for_room(room_id) - ) + aliases_for_room = await self.store.get_aliases_for_room(room_id) await self.store.update_aliases_for_room( room_id, new_room_id, requester_user_id diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 181c3ec249..38010af600 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -1132,13 +1132,13 @@ class DatabasePool(object): return [r[0] for r in txn] - def simple_select_onecol( + async def simple_select_onecol( self, table: str, keyvalues: Optional[Dict[str, Any]], retcol: str, desc: str = "simple_select_onecol", - ) -> defer.Deferred: + ) -> List[Any]: """Executes a SELECT query on the named table, which returns a list comprising of the values of the named column from the selected rows. @@ -1148,19 +1148,19 @@ class DatabasePool(object): retcol: column whos value we wish to retrieve. Returns: - Deferred: Results in a list + Results in a list """ - return self.runInteraction( + return await self.runInteraction( desc, self.simple_select_onecol_txn, table, keyvalues, retcol ) - def simple_select_list( + async def simple_select_list( self, table: str, keyvalues: Optional[Dict[str, Any]], retcols: Iterable[str], desc: str = "simple_select_list", - ) -> defer.Deferred: + ) -> List[Dict[str, Any]]: """Executes a SELECT query on the named table, which may return zero or more rows, returning the result as a list of dicts. @@ -1170,10 +1170,11 @@ class DatabasePool(object): column names and values to select the rows with, or None to not apply a WHERE clause. retcols: the names of the columns to return + Returns: - defer.Deferred: resolves to list[dict[str, Any]] + A list of dictionaries. """ - return self.runInteraction( + return await self.runInteraction( desc, self.simple_select_list_txn, table, keyvalues, retcols ) @@ -1299,14 +1300,14 @@ class DatabasePool(object): txn.execute(sql, values) return cls.cursor_to_dict(txn) - def simple_update( + async def simple_update( self, table: str, keyvalues: Dict[str, Any], updatevalues: Dict[str, Any], desc: str, - ) -> defer.Deferred: - return self.runInteraction( + ) -> int: + return await self.runInteraction( desc, self.simple_update_txn, table, keyvalues, updatevalues ) @@ -1332,13 +1333,13 @@ class DatabasePool(object): return txn.rowcount - def simple_update_one( + async def simple_update_one( self, table: str, keyvalues: Dict[str, Any], updatevalues: Dict[str, Any], desc: str = "simple_update_one", - ) -> defer.Deferred: + ) -> None: """Executes an UPDATE query on the named table, setting new values for columns in a row matching the key values. @@ -1347,7 +1348,7 @@ class DatabasePool(object): keyvalues: dict of column names and values to select the row with updatevalues: dict giving column names and values to update """ - return self.runInteraction( + await self.runInteraction( desc, self.simple_update_one_txn, table, keyvalues, updatevalues ) diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 0934ae276c..8b9b6eb472 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -18,6 +18,7 @@ import calendar import logging import time +from typing import Any, Dict, List from synapse.api.constants import PresenceState from synapse.config.homeserver import HomeServerConfig @@ -476,14 +477,13 @@ class DataStore( "generate_user_daily_visits", _generate_user_daily_visits ) - def get_users(self): + async def get_users(self) -> List[Dict[str, Any]]: """Function to retrieve a list of users in users table. - Args: Returns: - defer.Deferred: resolves to list[dict[str, Any]] + A list of dictionaries representing users. """ - return self.db_pool.simple_select_list( + return await self.db_pool.simple_select_list( table="users", keyvalues={}, retcols=[ diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py index 301d5d845a..405b5eafa5 100644 --- a/synapse/storage/databases/main/directory.py +++ b/synapse/storage/databases/main/directory.py @@ -14,7 +14,7 @@ # limitations under the License. from collections import namedtuple -from typing import Iterable, Optional +from typing import Iterable, List, Optional from synapse.api.errors import SynapseError from synapse.storage._base import SQLBaseStore @@ -68,8 +68,8 @@ class DirectoryWorkerStore(SQLBaseStore): ) @cached(max_entries=5000) - def get_aliases_for_room(self, room_id): - return self.db_pool.simple_select_onecol( + async def get_aliases_for_room(self, room_id: str) -> List[str]: + return await self.db_pool.simple_select_onecol( "room_aliases", {"room_id": room_id}, "room_alias", diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py index 46c3e33cc6..82f9d870fd 100644 --- a/synapse/storage/databases/main/e2e_room_keys.py +++ b/synapse/storage/databases/main/e2e_room_keys.py @@ -14,6 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional + from synapse.api.errors import StoreError from synapse.logging.opentracing import log_kv, trace from synapse.storage._base import SQLBaseStore, db_to_json @@ -368,18 +370,22 @@ class EndToEndRoomKeyStore(SQLBaseStore): ) @trace - def update_e2e_room_keys_version( - self, user_id, version, info=None, version_etag=None - ): + async def update_e2e_room_keys_version( + self, + user_id: str, + version: str, + info: Optional[dict] = None, + version_etag: Optional[int] = None, + ) -> None: """Update a given backup version Args: - user_id(str): the user whose backup version we're updating - version(str): the version ID of the backup version we're updating - info (dict): the new backup version info to store. If None, then - the backup version info is not updated - version_etag (Optional[int]): etag of the keys in the backup. If - None, then the etag is not updated + user_id: the user whose backup version we're updating + version: the version ID of the backup version we're updating + info: the new backup version info to store. If None, then the backup + version info is not updated. + version_etag: etag of the keys in the backup. If None, then the etag + is not updated. """ updatevalues = {} @@ -389,7 +395,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): updatevalues["etag"] = version_etag if updatevalues: - return self.db_pool.simple_update( + await self.db_pool.simple_update( table="e2e_room_keys_versions", keyvalues={"user_id": user_id, "version": version}, updatevalues=updatevalues, diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index e6a97b018c..6e5761c7b7 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -368,8 +368,8 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas ) @cached(max_entries=5000, iterable=True) - def get_latest_event_ids_in_room(self, room_id): - return self.db_pool.simple_select_onecol( + async def get_latest_event_ids_in_room(self, room_id: str) -> List[str]: + return await self.db_pool.simple_select_onecol( table="event_forward_extremities", keyvalues={"room_id": room_id}, retcol="event_id", diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py index c39864f59f..e3ead71853 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py @@ -44,24 +44,26 @@ class GroupServerWorkerStore(SQLBaseStore): desc="get_group", ) - def get_users_in_group(self, group_id, include_private=False): + async def get_users_in_group( + self, group_id: str, include_private: bool = False + ) -> List[Dict[str, Any]]: # TODO: Pagination keyvalues = {"group_id": group_id} if not include_private: keyvalues["is_public"] = True - return self.db_pool.simple_select_list( + return await self.db_pool.simple_select_list( table="group_users", keyvalues=keyvalues, retcols=("user_id", "is_public", "is_admin"), desc="get_users_in_group", ) - def get_invited_users_in_group(self, group_id): + async def get_invited_users_in_group(self, group_id: str) -> List[str]: # TODO: Pagination - return self.db_pool.simple_select_onecol( + return await self.db_pool.simple_select_onecol( table="group_invites", keyvalues={"group_id": group_id}, retcol="user_id", @@ -265,15 +267,14 @@ class GroupServerWorkerStore(SQLBaseStore): return role - def get_local_groups_for_room(self, room_id): + async def get_local_groups_for_room(self, room_id: str) -> List[str]: """Get all of the local group that contain a given room Args: - room_id (str): The ID of a room + room_id: The ID of a room Returns: - Deferred[list[str]]: A twisted.Deferred containing a list of group ids - containing this room + A list of group ids containing this room """ - return self.db_pool.simple_select_onecol( + return await self.db_pool.simple_select_onecol( table="group_rooms", keyvalues={"room_id": room_id}, retcol="group_id", @@ -422,10 +423,10 @@ class GroupServerWorkerStore(SQLBaseStore): "get_users_membership_info_in_group", _get_users_membership_in_group_txn ) - def get_publicised_groups_for_user(self, user_id): + async def get_publicised_groups_for_user(self, user_id: str) -> List[str]: """Get all groups a user is publicising """ - return self.db_pool.simple_select_onecol( + return await self.db_pool.simple_select_onecol( table="local_group_membership", keyvalues={"user_id": user_id, "membership": "join", "is_publicised": True}, retcol="group_id", @@ -466,8 +467,8 @@ class GroupServerWorkerStore(SQLBaseStore): return None - def get_joined_groups(self, user_id): - return self.db_pool.simple_select_onecol( + async def get_joined_groups(self, user_id: str) -> List[str]: + return await self.db_pool.simple_select_onecol( table="local_group_membership", keyvalues={"user_id": user_id, "membership": "join"}, retcol="group_id", @@ -585,14 +586,14 @@ class GroupServerWorkerStore(SQLBaseStore): class GroupServerStore(GroupServerWorkerStore): - def set_group_join_policy(self, group_id, join_policy): + async def set_group_join_policy(self, group_id: str, join_policy: str) -> None: """Set the join policy of a group. join_policy can be one of: * "invite" * "open" """ - return self.db_pool.simple_update_one( + await self.db_pool.simple_update_one( table="groups", keyvalues={"group_id": group_id}, updatevalues={"join_policy": join_policy}, @@ -1050,8 +1051,10 @@ class GroupServerStore(GroupServerWorkerStore): desc="add_room_to_group", ) - def update_room_in_group_visibility(self, group_id, room_id, is_public): - return self.db_pool.simple_update( + async def update_room_in_group_visibility( + self, group_id: str, room_id: str, is_public: bool + ) -> int: + return await self.db_pool.simple_update( table="group_rooms", keyvalues={"group_id": group_id, "room_id": room_id}, updatevalues={"is_public": is_public}, @@ -1076,10 +1079,12 @@ class GroupServerStore(GroupServerWorkerStore): "remove_room_from_group", _remove_room_from_group_txn ) - def update_group_publicity(self, group_id, user_id, publicise): + async def update_group_publicity( + self, group_id: str, user_id: str, publicise: bool + ) -> None: """Update whether the user is publicising their membership of the group """ - return self.db_pool.simple_update_one( + await self.db_pool.simple_update_one( table="local_group_membership", keyvalues={"group_id": group_id, "user_id": user_id}, updatevalues={"is_publicised": publicise}, @@ -1218,20 +1223,24 @@ class GroupServerStore(GroupServerWorkerStore): desc="update_group_profile", ) - def update_attestation_renewal(self, group_id, user_id, attestation): + async def update_attestation_renewal( + self, group_id: str, user_id: str, attestation: dict + ) -> None: """Update an attestation that we have renewed """ - return self.db_pool.simple_update_one( + await self.db_pool.simple_update_one( table="group_attestations_renewals", keyvalues={"group_id": group_id, "user_id": user_id}, updatevalues={"valid_until_ms": attestation["valid_until_ms"]}, desc="update_attestation_renewal", ) - def update_remote_attestion(self, group_id, user_id, attestation): + async def update_remote_attestion( + self, group_id: str, user_id: str, attestation: dict + ) -> None: """Update an attestation that a remote has renewed """ - return self.db_pool.simple_update_one( + await self.db_pool.simple_update_one( table="group_attestations_remote", keyvalues={"group_id": group_id, "user_id": user_id}, updatevalues={ diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index 4ae255ebd8..fc223f5a2a 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool @@ -84,9 +84,9 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): desc="store_local_media", ) - def mark_local_media_as_safe(self, media_id: str): + async def mark_local_media_as_safe(self, media_id: str) -> None: """Mark a local media as safe from quarantining.""" - return self.db_pool.simple_update_one( + await self.db_pool.simple_update_one( table="local_media_repository", keyvalues={"media_id": media_id}, updatevalues={"safe_from_quarantine": True}, @@ -158,8 +158,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): desc="store_url_cache", ) - def get_local_media_thumbnails(self, media_id): - return self.db_pool.simple_select_list( + async def get_local_media_thumbnails(self, media_id: str) -> List[Dict[str, Any]]: + return await self.db_pool.simple_select_list( "local_media_repository_thumbnails", {"media_id": media_id}, ( @@ -271,8 +271,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): "update_cached_last_access_time", update_cache_txn ) - def get_remote_media_thumbnails(self, origin, media_id): - return self.db_pool.simple_select_list( + async def get_remote_media_thumbnails( + self, origin: str, media_id: str + ) -> List[Dict[str, Any]]: + return await self.db_pool.simple_select_list( "remote_media_cache_thumbnails", {"media_origin": origin, "media_id": media_id}, ( diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py index b8233c4848..858fd92420 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py @@ -71,16 +71,20 @@ class ProfileWorkerStore(SQLBaseStore): table="profiles", values={"user_id": user_localpart}, desc="create_profile" ) - def set_profile_displayname(self, user_localpart, new_displayname): - return self.db_pool.simple_update_one( + async def set_profile_displayname( + self, user_localpart: str, new_displayname: str + ) -> None: + await self.db_pool.simple_update_one( table="profiles", keyvalues={"user_id": user_localpart}, updatevalues={"displayname": new_displayname}, desc="set_profile_displayname", ) - def set_profile_avatar_url(self, user_localpart, new_avatar_url): - return self.db_pool.simple_update_one( + async def set_profile_avatar_url( + self, user_localpart: str, new_avatar_url: str + ) -> None: + await self.db_pool.simple_update_one( table="profiles", keyvalues={"user_id": user_localpart}, updatevalues={"avatar_url": new_avatar_url}, @@ -106,8 +110,10 @@ class ProfileStore(ProfileWorkerStore): desc="add_remote_profile_cache", ) - def update_remote_profile_cache(self, user_id, displayname, avatar_url): - return self.db_pool.simple_update( + async def update_remote_profile_cache( + self, user_id: str, displayname: str, avatar_url: str + ) -> int: + return await self.db_pool.simple_update( table="remote_profile_cache", keyvalues={"user_id": user_id}, updatevalues={ diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index cea5ac9a68..436f22ad2d 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -16,7 +16,7 @@ import abc import logging -from typing import List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple from twisted.internet import defer @@ -62,8 +62,10 @@ class ReceiptsWorkerStore(SQLBaseStore): return {r["user_id"] for r in receipts} @cached(num_args=2) - def get_receipts_for_room(self, room_id, receipt_type): - return self.db_pool.simple_select_list( + async def get_receipts_for_room( + self, room_id: str, receipt_type: str + ) -> List[Dict[str, Any]]: + return await self.db_pool.simple_select_list( table="receipts_linearized", keyvalues={"room_id": room_id, "receipt_type": receipt_type}, retcols=("user_id", "event_id"), diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index eced53d470..48bda66f3e 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -578,20 +578,20 @@ class RegistrationWorkerStore(SQLBaseStore): desc="add_user_bound_threepid", ) - def user_get_bound_threepids(self, user_id): + async def user_get_bound_threepids(self, user_id: str) -> List[Dict[str, Any]]: """Get the threepids that a user has bound to an identity server through the homeserver The homeserver remembers where binds to an identity server occurred. Using this method can retrieve those threepids. Args: - user_id (str): The ID of the user to retrieve threepids for + user_id: The ID of the user to retrieve threepids for Returns: - Deferred[list[dict]]: List of dictionaries containing the following: + List of dictionaries containing the following keys: medium (str): The medium of the threepid (e.g "email") address (str): The address of the threepid (e.g "bob@example.com") """ - return self.db_pool.simple_select_list( + return await self.db_pool.simple_select_list( table="user_threepid_id_server", keyvalues={"user_id": user_id}, retcols=["medium", "address"], @@ -623,19 +623,21 @@ class RegistrationWorkerStore(SQLBaseStore): desc="remove_user_bound_threepid", ) - def get_id_servers_user_bound(self, user_id, medium, address): + async def get_id_servers_user_bound( + self, user_id: str, medium: str, address: str + ) -> List[str]: """Get the list of identity servers that the server proxied bind requests to for given user and threepid Args: - user_id (str) - medium (str) - address (str) + user_id: The user to query for identity servers. + medium: The medium to query for identity servers. + address: The address to query for identity servers. Returns: - Deferred[list[str]]: Resolves to a list of identity servers + A list of identity servers """ - return self.db_pool.simple_select_onecol( + return await self.db_pool.simple_select_onecol( table="user_threepid_id_server", keyvalues={"user_id": user_id, "medium": medium, "address": address}, retcol="id_server", diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 97ecdb16e4..66d7135413 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -125,8 +125,8 @@ class RoomWorkerStore(SQLBaseStore): "get_room_with_stats", get_room_with_stats_txn, room_id ) - def get_public_room_ids(self): - return self.db_pool.simple_select_onecol( + async def get_public_room_ids(self) -> List[str]: + return await self.db_pool.simple_select_onecol( table="rooms", keyvalues={"is_public": True}, retcol="room_id", diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index 20cbcd851c..a9f2e93614 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -537,8 +537,8 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): desc="get_user_in_directory", ) - def update_user_directory_stream_pos(self, stream_id): - return self.db_pool.simple_update_one( + async def update_user_directory_stream_pos(self, stream_id: str) -> None: + await self.db_pool.simple_update_one( table="user_directory_stream_pos", keyvalues={}, updatevalues={"stream_id": stream_id}, diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py index 88b05c23a0..a609f148c0 100644 --- a/tests/handlers/test_stats.py +++ b/tests/handlers/test_stats.py @@ -81,8 +81,8 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) - def get_all_room_state(self): - return self.store.db_pool.simple_select_list( + async def get_all_room_state(self): + return await self.store.db_pool.simple_select_list( "room_stats_state", None, retcols=("name", "topic", "canonical_alias") ) diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index bf22540d99..64abe8cc49 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -148,8 +148,10 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)])) self.mock_txn.description = (("colA", None, None, None, None, None, None),) - ret = yield self.datastore.db_pool.simple_select_list( - table="tablename", keyvalues={"keycol": "A set"}, retcols=["colA"] + ret = yield defer.ensureDeferred( + self.datastore.db_pool.simple_select_list( + table="tablename", keyvalues={"keycol": "A set"}, retcols=["colA"] + ) ) self.assertEquals([{"colA": 1}, {"colA": 2}, {"colA": 3}], ret) @@ -161,10 +163,12 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_update_one_1col(self): self.mock_txn.rowcount = 1 - yield self.datastore.db_pool.simple_update_one( - table="tablename", - keyvalues={"keycol": "TheKey"}, - updatevalues={"columnname": "New Value"}, + yield defer.ensureDeferred( + self.datastore.db_pool.simple_update_one( + table="tablename", + keyvalues={"keycol": "TheKey"}, + updatevalues={"columnname": "New Value"}, + ) ) self.mock_txn.execute.assert_called_with( @@ -176,10 +180,12 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_update_one_4cols(self): self.mock_txn.rowcount = 1 - yield self.datastore.db_pool.simple_update_one( - table="tablename", - keyvalues=OrderedDict([("colA", 1), ("colB", 2)]), - updatevalues=OrderedDict([("colC", 3), ("colD", 4)]), + yield defer.ensureDeferred( + self.datastore.db_pool.simple_update_one( + table="tablename", + keyvalues=OrderedDict([("colA", 1), ("colB", 2)]), + updatevalues=OrderedDict([("colC", 3), ("colD", 4)]), + ) ) self.mock_txn.execute.assert_called_with( diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py index daac947cb2..da93ca3980 100644 --- a/tests/storage/test_directory.py +++ b/tests/storage/test_directory.py @@ -42,7 +42,11 @@ class DirectoryStoreTestCase(unittest.TestCase): self.assertEquals( ["#my-room:test"], - (yield self.store.get_aliases_for_room(self.room.to_string())), + ( + yield defer.ensureDeferred( + self.store.get_aliases_for_room(self.room.to_string()) + ) + ), ) @defer.inlineCallbacks diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py index fbf8af940a..954338a592 100644 --- a/tests/storage/test_main.py +++ b/tests/storage/test_main.py @@ -36,7 +36,9 @@ class DataStoreTestCase(unittest.TestCase): def test_get_users_paginate(self): yield self.store.register_user(self.user.to_string(), "pass") yield defer.ensureDeferred(self.store.create_profile(self.user.localpart)) - yield self.store.set_profile_displayname(self.user.localpart, self.displayname) + yield defer.ensureDeferred( + self.store.set_profile_displayname(self.user.localpart, self.displayname) + ) users, total = yield self.store.get_users_paginate( 0, 10, name="bc", guests=False diff --git a/tests/test_federation.py b/tests/test_federation.py index 4a4548433f..27a7fc9ed7 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -15,8 +15,9 @@ from mock import Mock -from twisted.internet.defer import ensureDeferred, maybeDeferred, succeed +from twisted.internet.defer import succeed +from synapse.api.errors import FederationError from synapse.events import make_event_from_dict from synapse.logging.context import LoggingContext from synapse.types import Requester, UserID @@ -44,22 +45,17 @@ class MessageAcceptTests(unittest.HomeserverTestCase): user_id = UserID("us", "test") our_user = Requester(user_id, None, False, False, None, None) room_creator = self.homeserver.get_room_creation_handler() - room_deferred = ensureDeferred( + self.room_id = self.get_success( room_creator.create_room( our_user, room_creator._presets_dict["public_chat"], ratelimit=False ) - ) - self.reactor.advance(0.1) - self.room_id = self.successResultOf(room_deferred)[0]["room_id"] + )[0]["room_id"] self.store = self.homeserver.get_datastore() # Figure out what the most recent event is - most_recent = self.successResultOf( - maybeDeferred( - self.homeserver.get_datastore().get_latest_event_ids_in_room, - self.room_id, - ) + most_recent = self.get_success( + self.homeserver.get_datastore().get_latest_event_ids_in_room(self.room_id) )[0] join_event = make_event_from_dict( @@ -89,19 +85,18 @@ class MessageAcceptTests(unittest.HomeserverTestCase): ) # Send the join, it should return None (which is not an error) - d = ensureDeferred( - self.handler.on_receive_pdu( - "test.serv", join_event, sent_to_us_directly=True - ) + self.assertEqual( + self.get_success( + self.handler.on_receive_pdu( + "test.serv", join_event, sent_to_us_directly=True + ) + ), + None, ) - self.reactor.advance(1) - self.assertEqual(self.successResultOf(d), None) # Make sure we actually joined the room self.assertEqual( - self.successResultOf( - maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id) - )[0], + self.get_success(self.store.get_latest_event_ids_in_room(self.room_id))[0], "$join:test.serv", ) @@ -119,8 +114,8 @@ class MessageAcceptTests(unittest.HomeserverTestCase): self.http_client.post_json = post_json # Figure out what the most recent event is - most_recent = self.successResultOf( - maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id) + most_recent = self.get_success( + self.store.get_latest_event_ids_in_room(self.room_id) )[0] # Now lie about an event @@ -140,17 +135,14 @@ class MessageAcceptTests(unittest.HomeserverTestCase): ) with LoggingContext(request="lying_event"): - d = ensureDeferred( + failure = self.get_failure( self.handler.on_receive_pdu( "test.serv", lying_event, sent_to_us_directly=True - ) + ), + FederationError, ) - # Step the reactor, so the database fetches come back - self.reactor.advance(1) - # on_receive_pdu should throw an error - failure = self.failureResultOf(d) self.assertEqual( failure.value.args[0], ( @@ -160,8 +152,8 @@ class MessageAcceptTests(unittest.HomeserverTestCase): ) # Make sure the invalid event isn't there - extrem = maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id) - self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv") + extrem = self.get_success(self.store.get_latest_event_ids_in_room(self.room_id)) + self.assertEqual(extrem[0], "$join:test.serv") def test_retry_device_list_resync(self): """Tests that device lists are marked as stale if they couldn't be synced, and -- cgit 1.5.1