diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index b2fcfc9bfe..91a8b43da3 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -15,9 +15,7 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Awaitable, Iterable, List, Optional, Set
-
-from twisted.internet import defer
+from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Set
from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase
@@ -92,8 +90,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
lambda: self._known_servers_count,
)
- @defer.inlineCallbacks
- def _count_known_servers(self):
+ async def _count_known_servers(self):
"""
Count the servers that this server knows about.
@@ -121,7 +118,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
txn.execute(query)
return list(txn)[0][0]
- count = yield self.db_pool.runInteraction("get_known_servers", _transact)
+ count = await self.db_pool.runInteraction("get_known_servers", _transact)
# We always know about ourselves, even if we have nothing in
# room_memberships (for example, the server is new).
@@ -155,8 +152,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
@cached(max_entries=100000, iterable=True)
- def get_users_in_room(self, room_id: str):
- return self.db_pool.runInteraction(
+ async def get_users_in_room(self, room_id: str) -> List[str]:
+ return await self.db_pool.runInteraction(
"get_users_in_room", self.get_users_in_room_txn, room_id
)
@@ -183,14 +180,13 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return [r[0] for r in txn]
@cached(max_entries=100000)
- def get_room_summary(self, room_id: str):
+ async def get_room_summary(self, room_id: str) -> Dict[str, MemberSummary]:
""" Get the details of a room roughly suitable for use by the room
summary extension to /sync. Useful when lazy loading room members.
Args:
room_id: The room ID to query
Returns:
- Deferred[dict[str, MemberSummary]:
- dict of membership states, pointing to a MemberSummary named tuple.
+ dict of membership states, pointing to a MemberSummary named tuple.
"""
def _get_room_summary_txn(txn):
@@ -264,20 +260,22 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return res
- return self.db_pool.runInteraction("get_room_summary", _get_room_summary_txn)
+ return await self.db_pool.runInteraction(
+ "get_room_summary", _get_room_summary_txn
+ )
@cached()
- def get_invited_rooms_for_local_user(self, user_id: str) -> Awaitable[RoomsForUser]:
+ async def get_invited_rooms_for_local_user(self, user_id: str) -> RoomsForUser:
"""Get all the rooms the *local* user is invited to.
Args:
user_id: The user ID.
Returns:
- A awaitable list of RoomsForUser.
+ A list of RoomsForUser.
"""
- return self.get_rooms_for_local_user_where_membership_is(
+ return await self.get_rooms_for_local_user_where_membership_is(
user_id, [Membership.INVITE]
)
@@ -300,8 +298,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return None
async def get_rooms_for_local_user_where_membership_is(
- self, user_id: str, membership_list: List[str]
- ) -> Optional[List[RoomsForUser]]:
+ self, user_id: str, membership_list: Collection[str]
+ ) -> List[RoomsForUser]:
"""Get all the rooms for this *local* user where the membership for this user
matches one in the membership list.
@@ -316,7 +314,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
The RoomsForUser that the user matches the membership types.
"""
if not membership_list:
- return None
+ return []
rooms = await self.db_pool.runInteraction(
"get_rooms_for_local_user_where_membership_is",
@@ -360,7 +358,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return results
@cached(max_entries=500000, iterable=True)
- def get_rooms_for_user_with_stream_ordering(self, user_id: str):
+ async def get_rooms_for_user_with_stream_ordering(
+ self, user_id: str
+ ) -> FrozenSet[GetRoomsForUserWithStreamOrdering]:
"""Returns a set of room_ids the user is currently joined to.
If a remote user only returns rooms this server is currently
@@ -370,17 +370,18 @@ class RoomMemberWorkerStore(EventsWorkerStore):
user_id
Returns:
- Deferred[frozenset[GetRoomsForUserWithStreamOrdering]]: Returns
- the rooms the user is in currently, along with the stream ordering
- of the most recent join for that user and room.
+ Returns the rooms the user is in currently, along with the stream
+ ordering of the most recent join for that user and room.
"""
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_rooms_for_user_with_stream_ordering",
self._get_rooms_for_user_with_stream_ordering_txn,
user_id,
)
- def _get_rooms_for_user_with_stream_ordering_txn(self, txn, user_id: str):
+ def _get_rooms_for_user_with_stream_ordering_txn(
+ self, txn, user_id: str
+ ) -> FrozenSet[GetRoomsForUserWithStreamOrdering]:
# We use `current_state_events` here and not `local_current_membership`
# as a) this gets called with remote users and b) this only gets called
# for rooms the server is participating in.
@@ -407,9 +408,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
"""
txn.execute(sql, (user_id, Membership.JOIN))
- results = frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn)
-
- return results
+ return frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn)
async def get_users_server_still_shares_room_with(
self, user_ids: Collection[str]
@@ -589,11 +588,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
raise NotImplementedError()
@cachedList(
- cached_method_name="_get_joined_profile_from_event_id",
- list_name="event_ids",
- inlineCallbacks=True,
+ cached_method_name="_get_joined_profile_from_event_id", list_name="event_ids",
)
- def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]):
+ async def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]):
"""For given set of member event_ids check if they point to a join
event and if so return the associated user and profile info.
@@ -601,11 +598,11 @@ class RoomMemberWorkerStore(EventsWorkerStore):
event_ids: The member event IDs to lookup
Returns:
- Deferred[dict[str, Tuple[str, ProfileInfo]|None]]: Map from event ID
+ dict[str, Tuple[str, ProfileInfo]|None]: Map from event ID
to `user_id` and ProfileInfo (or None if not join event).
"""
- rows = yield self.db_pool.simple_select_many_batch(
+ rows = await self.db_pool.simple_select_many_batch(
table="room_memberships",
column="event_id",
iterable=event_ids,
@@ -716,14 +713,14 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return count == 0
@cached()
- def get_forgotten_rooms_for_user(self, user_id: str):
+ async def get_forgotten_rooms_for_user(self, user_id: str) -> Set[str]:
"""Gets all rooms the user has forgotten.
Args:
- user_id
+ user_id: The user ID to query the rooms of.
Returns:
- Deferred[set[str]]
+ The forgotten rooms.
"""
def _get_forgotten_rooms_for_user_txn(txn):
@@ -749,7 +746,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
txn.execute(sql, (user_id,))
return {row[0] for row in txn if row[1] == 0}
- return self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn
)
@@ -772,13 +769,13 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return set(room_ids)
- def get_membership_from_event_ids(
+ async def get_membership_from_event_ids(
self, member_event_ids: Iterable[str]
) -> List[dict]:
"""Get user_id and membership of a set of event IDs.
"""
- return self.db_pool.simple_select_many_batch(
+ return await self.db_pool.simple_select_many_batch(
table="room_memberships",
column="event_id",
iterable=member_event_ids,
@@ -978,7 +975,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super(RoomMemberStore, self).__init__(database, db_conn, hs)
- def forget(self, user_id: str, room_id: str):
+ async def forget(self, user_id: str, room_id: str) -> None:
"""Indicate that user_id wishes to discard history for room_id."""
def f(txn):
@@ -999,10 +996,10 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
txn, self.get_forgotten_rooms_for_user, (user_id,)
)
- return self.db_pool.runInteraction("forget_membership", f)
+ await self.db_pool.runInteraction("forget_membership", f)
-class _JoinedHostsCache(object):
+class _JoinedHostsCache:
"""Cache for joined hosts in a room that is optimised to handle updates
via state deltas.
"""
|