summary refs log tree commit diff
path: root/synapse/storage/databases/main/roommember.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/roommember.py')
-rw-r--r--synapse/storage/databases/main/roommember.py81
1 files changed, 39 insertions, 42 deletions
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py

index b2fcfc9bfe..91a8b43da3 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py
@@ -15,9 +15,7 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Awaitable, Iterable, List, Optional, Set - -from twisted.internet import defer +from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Set from synapse.api.constants import EventTypes, Membership from synapse.events import EventBase @@ -92,8 +90,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): lambda: self._known_servers_count, ) - @defer.inlineCallbacks - def _count_known_servers(self): + async def _count_known_servers(self): """ Count the servers that this server knows about. @@ -121,7 +118,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): txn.execute(query) return list(txn)[0][0] - count = yield self.db_pool.runInteraction("get_known_servers", _transact) + count = await self.db_pool.runInteraction("get_known_servers", _transact) # We always know about ourselves, even if we have nothing in # room_memberships (for example, the server is new). @@ -155,8 +152,8 @@ class RoomMemberWorkerStore(EventsWorkerStore): ) @cached(max_entries=100000, iterable=True) - def get_users_in_room(self, room_id: str): - return self.db_pool.runInteraction( + async def get_users_in_room(self, room_id: str) -> List[str]: + return await self.db_pool.runInteraction( "get_users_in_room", self.get_users_in_room_txn, room_id ) @@ -183,14 +180,13 @@ class RoomMemberWorkerStore(EventsWorkerStore): return [r[0] for r in txn] @cached(max_entries=100000) - def get_room_summary(self, room_id: str): + async def get_room_summary(self, room_id: str) -> Dict[str, MemberSummary]: """ Get the details of a room roughly suitable for use by the room summary extension to /sync. Useful when lazy loading room members. Args: room_id: The room ID to query Returns: - Deferred[dict[str, MemberSummary]: - dict of membership states, pointing to a MemberSummary named tuple. + dict of membership states, pointing to a MemberSummary named tuple. """ def _get_room_summary_txn(txn): @@ -264,20 +260,22 @@ class RoomMemberWorkerStore(EventsWorkerStore): return res - return self.db_pool.runInteraction("get_room_summary", _get_room_summary_txn) + return await self.db_pool.runInteraction( + "get_room_summary", _get_room_summary_txn + ) @cached() - def get_invited_rooms_for_local_user(self, user_id: str) -> Awaitable[RoomsForUser]: + async def get_invited_rooms_for_local_user(self, user_id: str) -> RoomsForUser: """Get all the rooms the *local* user is invited to. Args: user_id: The user ID. Returns: - A awaitable list of RoomsForUser. + A list of RoomsForUser. """ - return self.get_rooms_for_local_user_where_membership_is( + return await self.get_rooms_for_local_user_where_membership_is( user_id, [Membership.INVITE] ) @@ -300,8 +298,8 @@ class RoomMemberWorkerStore(EventsWorkerStore): return None async def get_rooms_for_local_user_where_membership_is( - self, user_id: str, membership_list: List[str] - ) -> Optional[List[RoomsForUser]]: + self, user_id: str, membership_list: Collection[str] + ) -> List[RoomsForUser]: """Get all the rooms for this *local* user where the membership for this user matches one in the membership list. @@ -316,7 +314,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): The RoomsForUser that the user matches the membership types. """ if not membership_list: - return None + return [] rooms = await self.db_pool.runInteraction( "get_rooms_for_local_user_where_membership_is", @@ -360,7 +358,9 @@ class RoomMemberWorkerStore(EventsWorkerStore): return results @cached(max_entries=500000, iterable=True) - def get_rooms_for_user_with_stream_ordering(self, user_id: str): + async def get_rooms_for_user_with_stream_ordering( + self, user_id: str + ) -> FrozenSet[GetRoomsForUserWithStreamOrdering]: """Returns a set of room_ids the user is currently joined to. If a remote user only returns rooms this server is currently @@ -370,17 +370,18 @@ class RoomMemberWorkerStore(EventsWorkerStore): user_id Returns: - Deferred[frozenset[GetRoomsForUserWithStreamOrdering]]: Returns - the rooms the user is in currently, along with the stream ordering - of the most recent join for that user and room. + Returns the rooms the user is in currently, along with the stream + ordering of the most recent join for that user and room. """ - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_rooms_for_user_with_stream_ordering", self._get_rooms_for_user_with_stream_ordering_txn, user_id, ) - def _get_rooms_for_user_with_stream_ordering_txn(self, txn, user_id: str): + def _get_rooms_for_user_with_stream_ordering_txn( + self, txn, user_id: str + ) -> FrozenSet[GetRoomsForUserWithStreamOrdering]: # We use `current_state_events` here and not `local_current_membership` # as a) this gets called with remote users and b) this only gets called # for rooms the server is participating in. @@ -407,9 +408,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): """ txn.execute(sql, (user_id, Membership.JOIN)) - results = frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn) - - return results + return frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn) async def get_users_server_still_shares_room_with( self, user_ids: Collection[str] @@ -589,11 +588,9 @@ class RoomMemberWorkerStore(EventsWorkerStore): raise NotImplementedError() @cachedList( - cached_method_name="_get_joined_profile_from_event_id", - list_name="event_ids", - inlineCallbacks=True, + cached_method_name="_get_joined_profile_from_event_id", list_name="event_ids", ) - def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]): + async def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]): """For given set of member event_ids check if they point to a join event and if so return the associated user and profile info. @@ -601,11 +598,11 @@ class RoomMemberWorkerStore(EventsWorkerStore): event_ids: The member event IDs to lookup Returns: - Deferred[dict[str, Tuple[str, ProfileInfo]|None]]: Map from event ID + dict[str, Tuple[str, ProfileInfo]|None]: Map from event ID to `user_id` and ProfileInfo (or None if not join event). """ - rows = yield self.db_pool.simple_select_many_batch( + rows = await self.db_pool.simple_select_many_batch( table="room_memberships", column="event_id", iterable=event_ids, @@ -716,14 +713,14 @@ class RoomMemberWorkerStore(EventsWorkerStore): return count == 0 @cached() - def get_forgotten_rooms_for_user(self, user_id: str): + async def get_forgotten_rooms_for_user(self, user_id: str) -> Set[str]: """Gets all rooms the user has forgotten. Args: - user_id + user_id: The user ID to query the rooms of. Returns: - Deferred[set[str]] + The forgotten rooms. """ def _get_forgotten_rooms_for_user_txn(txn): @@ -749,7 +746,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): txn.execute(sql, (user_id,)) return {row[0] for row in txn if row[1] == 0} - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn ) @@ -772,13 +769,13 @@ class RoomMemberWorkerStore(EventsWorkerStore): return set(room_ids) - def get_membership_from_event_ids( + async def get_membership_from_event_ids( self, member_event_ids: Iterable[str] ) -> List[dict]: """Get user_id and membership of a set of event IDs. """ - return self.db_pool.simple_select_many_batch( + return await self.db_pool.simple_select_many_batch( table="room_memberships", column="event_id", iterable=member_event_ids, @@ -978,7 +975,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): def __init__(self, database: DatabasePool, db_conn, hs): super(RoomMemberStore, self).__init__(database, db_conn, hs) - def forget(self, user_id: str, room_id: str): + async def forget(self, user_id: str, room_id: str) -> None: """Indicate that user_id wishes to discard history for room_id.""" def f(txn): @@ -999,10 +996,10 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): txn, self.get_forgotten_rooms_for_user, (user_id,) ) - return self.db_pool.runInteraction("forget_membership", f) + await self.db_pool.runInteraction("forget_membership", f) -class _JoinedHostsCache(object): +class _JoinedHostsCache: """Cache for joined hosts in a room that is optimised to handle updates via state deltas. """