diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 7c5be251bd..b2fcfc9bfe 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -15,11 +15,13 @@
# limitations under the License.
import logging
-from typing import Iterable, List, Set
+from typing import TYPE_CHECKING, Awaitable, Iterable, List, Optional, Set
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
+from synapse.events import EventBase
+from synapse.events.snapshot import EventContext
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import (
@@ -40,9 +42,12 @@ from synapse.storage.roommember import (
from synapse.types import Collection, get_domain_from_id
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import intern_string
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
+from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
from synapse.util.metrics import Measure
+if TYPE_CHECKING:
+ from synapse.state import _StateCacheEntry
+
logger = logging.getLogger(__name__)
@@ -150,12 +155,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
@cached(max_entries=100000, iterable=True)
- def get_users_in_room(self, room_id):
+ def get_users_in_room(self, room_id: str):
return self.db_pool.runInteraction(
"get_users_in_room", self.get_users_in_room_txn, room_id
)
- def get_users_in_room_txn(self, txn, room_id):
+ def get_users_in_room_txn(self, txn, room_id: str) -> List[str]:
# If we can assume current_state_events.membership is up to date
# then we can avoid a join, which is a Very Good Thing given how
# frequently this function gets called.
@@ -178,11 +183,11 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return [r[0] for r in txn]
@cached(max_entries=100000)
- def get_room_summary(self, room_id):
+ def get_room_summary(self, room_id: str):
""" Get the details of a room roughly suitable for use by the room
summary extension to /sync. Useful when lazy loading room members.
Args:
- room_id (str): The room ID to query
+ room_id: The room ID to query
Returns:
Deferred[dict[str, MemberSummary]:
dict of membership states, pointing to a MemberSummary named tuple.
@@ -261,78 +266,59 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return self.db_pool.runInteraction("get_room_summary", _get_room_summary_txn)
- def _get_user_counts_in_room_txn(self, txn, room_id):
- """
- Get the user count in a room by membership.
-
- Args:
- room_id (str)
- membership (Membership)
-
- Returns:
- Deferred[int]
- """
- sql = """
- SELECT m.membership, count(*) FROM room_memberships as m
- INNER JOIN current_state_events as c USING(event_id)
- WHERE c.type = 'm.room.member' AND c.room_id = ?
- GROUP BY m.membership
- """
-
- txn.execute(sql, (room_id,))
- return {row[0]: row[1] for row in txn}
-
@cached()
- def get_invited_rooms_for_local_user(self, user_id):
- """ Get all the rooms the *local* user is invited to
+ def get_invited_rooms_for_local_user(self, user_id: str) -> Awaitable[RoomsForUser]:
+ """Get all the rooms the *local* user is invited to.
Args:
- user_id (str): The user ID.
+ user_id: The user ID.
+
Returns:
- A deferred list of RoomsForUser.
+ A awaitable list of RoomsForUser.
"""
return self.get_rooms_for_local_user_where_membership_is(
user_id, [Membership.INVITE]
)
- @defer.inlineCallbacks
- def get_invite_for_local_user_in_room(self, user_id, room_id):
- """Gets the invite for the given *local* user and room
+ async def get_invite_for_local_user_in_room(
+ self, user_id: str, room_id: str
+ ) -> Optional[RoomsForUser]:
+ """Gets the invite for the given *local* user and room.
Args:
- user_id (str)
- room_id (str)
+ user_id: The user ID to find the invite of.
+ room_id: The room to user was invited to.
Returns:
- Deferred: Resolves to either a RoomsForUser or None if no invite was
- found.
+ Either a RoomsForUser or None if no invite was found.
"""
- invites = yield self.get_invited_rooms_for_local_user(user_id)
+ invites = await self.get_invited_rooms_for_local_user(user_id)
for invite in invites:
if invite.room_id == room_id:
return invite
return None
- @defer.inlineCallbacks
- def get_rooms_for_local_user_where_membership_is(self, user_id, membership_list):
- """ Get all the rooms for this *local* user where the membership for this user
+ async def get_rooms_for_local_user_where_membership_is(
+ self, user_id: str, membership_list: List[str]
+ ) -> Optional[List[RoomsForUser]]:
+ """Get all the rooms for this *local* user where the membership for this user
matches one in the membership list.
Filters out forgotten rooms.
Args:
- user_id (str): The user ID.
- membership_list (list): A list of synapse.api.constants.Membership
- values which the user must be in.
+ user_id: The user ID.
+ membership_list: A list of synapse.api.constants.Membership
+ values which the user must be in.
Returns:
- Deferred[list[RoomsForUser]]
+ The RoomsForUser that the user matches the membership types.
"""
if not membership_list:
- return defer.succeed(None)
+ return None
- rooms = yield self.db_pool.runInteraction(
+ rooms = await self.db_pool.runInteraction(
"get_rooms_for_local_user_where_membership_is",
self._get_rooms_for_local_user_where_membership_is_txn,
user_id,
@@ -340,12 +326,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
# Now we filter out forgotten rooms
- forgotten_rooms = yield self.get_forgotten_rooms_for_user(user_id)
+ forgotten_rooms = await self.get_forgotten_rooms_for_user(user_id)
return [room for room in rooms if room.room_id not in forgotten_rooms]
def _get_rooms_for_local_user_where_membership_is_txn(
- self, txn, user_id, membership_list
- ):
+ self, txn, user_id: str, membership_list: List[str]
+ ) -> List[RoomsForUser]:
# Paranoia check.
if not self.hs.is_mine_id(user_id):
raise Exception(
@@ -374,14 +360,14 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return results
@cached(max_entries=500000, iterable=True)
- def get_rooms_for_user_with_stream_ordering(self, user_id):
+ def get_rooms_for_user_with_stream_ordering(self, user_id: str):
"""Returns a set of room_ids the user is currently joined to.
If a remote user only returns rooms this server is currently
participating in.
Args:
- user_id (str)
+ user_id
Returns:
Deferred[frozenset[GetRoomsForUserWithStreamOrdering]]: Returns
@@ -394,7 +380,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
user_id,
)
- def _get_rooms_for_user_with_stream_ordering_txn(self, txn, user_id):
+ def _get_rooms_for_user_with_stream_ordering_txn(self, txn, user_id: str):
# We use `current_state_events` here and not `local_current_membership`
# as a) this gets called with remote users and b) this only gets called
# for rooms the server is participating in.
@@ -458,37 +444,39 @@ class RoomMemberWorkerStore(EventsWorkerStore):
_get_users_server_still_shares_room_with_txn,
)
- @defer.inlineCallbacks
- def get_rooms_for_user(self, user_id, on_invalidate=None):
+ async def get_rooms_for_user(self, user_id: str, on_invalidate=None):
"""Returns a set of room_ids the user is currently joined to.
If a remote user only returns rooms this server is currently
participating in.
"""
- rooms = yield self.get_rooms_for_user_with_stream_ordering(
+ rooms = await self.get_rooms_for_user_with_stream_ordering(
user_id, on_invalidate=on_invalidate
)
return frozenset(r.room_id for r in rooms)
- @cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True)
- def get_users_who_share_room_with_user(self, user_id, cache_context):
+ @cached(max_entries=500000, cache_context=True, iterable=True)
+ async def get_users_who_share_room_with_user(
+ self, user_id: str, cache_context: _CacheContext
+ ) -> Set[str]:
"""Returns the set of users who share a room with `user_id`
"""
- room_ids = yield self.get_rooms_for_user(
+ room_ids = await self.get_rooms_for_user(
user_id, on_invalidate=cache_context.invalidate
)
user_who_share_room = set()
for room_id in room_ids:
- user_ids = yield self.get_users_in_room(
+ user_ids = await self.get_users_in_room(
room_id, on_invalidate=cache_context.invalidate
)
user_who_share_room.update(user_ids)
return user_who_share_room
- @defer.inlineCallbacks
- def get_joined_users_from_context(self, event, context):
+ async def get_joined_users_from_context(
+ self, event: EventBase, context: EventContext
+ ):
state_group = context.state_group
if not state_group:
# If state_group is None it means it has yet to be assigned a
@@ -497,14 +485,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# To do this we set the state_group to a new object as object() != object()
state_group = object()
- current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
- result = yield self._get_joined_users_from_context(
+ current_state_ids = await context.get_current_state_ids()
+ return await self._get_joined_users_from_context(
event.room_id, state_group, current_state_ids, event=event, context=context
)
- return result
- @defer.inlineCallbacks
- def get_joined_users_from_state(self, room_id, state_entry):
+ async def get_joined_users_from_state(self, room_id, state_entry):
state_group = state_entry.state_group
if not state_group:
# If state_group is None it means it has yet to be assigned a
@@ -514,16 +500,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
state_group = object()
with Measure(self._clock, "get_joined_users_from_state"):
- return (
- yield self._get_joined_users_from_context(
- room_id, state_group, state_entry.state, context=state_entry
- )
+ return await self._get_joined_users_from_context(
+ room_id, state_group, state_entry.state, context=state_entry
)
- @cachedInlineCallbacks(
- num_args=2, cache_context=True, iterable=True, max_entries=100000
- )
- def _get_joined_users_from_context(
+ @cached(num_args=2, cache_context=True, iterable=True, max_entries=100000)
+ async def _get_joined_users_from_context(
self,
room_id,
state_group,
@@ -535,7 +517,6 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# We don't use `state_group`, it's there so that we can cache based
# on it. However, it's important that it's never None, since two current_states
# with a state_group of None are likely to be different.
- # See bulk_get_push_rules_for_room for how we work around this.
assert state_group is not None
users_in_room = {}
@@ -588,7 +569,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
missing_member_event_ids.append(event_id)
if missing_member_event_ids:
- event_to_memberships = yield self._get_joined_profiles_from_event_ids(
+ event_to_memberships = await self._get_joined_profiles_from_event_ids(
missing_member_event_ids
)
users_in_room.update((row for row in event_to_memberships.values() if row))
@@ -612,12 +593,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
list_name="event_ids",
inlineCallbacks=True,
)
- def _get_joined_profiles_from_event_ids(self, event_ids):
+ def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]):
"""For given set of member event_ids check if they point to a join
event and if so return the associated user and profile info.
Args:
- event_ids (Iterable[str]): The member event IDs to lookup
+ event_ids: The member event IDs to lookup
Returns:
Deferred[dict[str, Tuple[str, ProfileInfo]|None]]: Map from event ID
@@ -644,8 +625,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
for row in rows
}
- @cachedInlineCallbacks(max_entries=10000)
- def is_host_joined(self, room_id, host):
+ @cached(max_entries=10000)
+ async def is_host_joined(self, room_id: str, host: str) -> bool:
if "%" in host or "_" in host:
raise Exception("Invalid host name")
@@ -664,7 +645,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# the returned user actually has the correct domain.
like_clause = "%:" + host
- rows = yield self.db_pool.execute(
+ rows = await self.db_pool.execute(
"is_host_joined", None, sql, room_id, like_clause
)
@@ -678,50 +659,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return True
- @cachedInlineCallbacks()
- def was_host_joined(self, room_id, host):
- """Check whether the server is or ever was in the room.
-
- Args:
- room_id (str)
- host (str)
-
- Returns:
- Deferred: Resolves to True if the host is/was in the room, otherwise
- False.
- """
- if "%" in host or "_" in host:
- raise Exception("Invalid host name")
-
- sql = """
- SELECT user_id FROM room_memberships
- WHERE room_id = ?
- AND user_id LIKE ?
- AND membership = 'join'
- LIMIT 1
- """
-
- # We do need to be careful to ensure that host doesn't have any wild cards
- # in it, but we checked above for known ones and we'll check below that
- # the returned user actually has the correct domain.
- like_clause = "%:" + host
-
- rows = yield self.db_pool.execute(
- "was_host_joined", None, sql, room_id, like_clause
- )
-
- if not rows:
- return False
-
- user_id = rows[0][0]
- if get_domain_from_id(user_id) != host:
- # This can only happen if the host name has something funky in it
- raise Exception("Invalid host name")
-
- return True
-
- @defer.inlineCallbacks
- def get_joined_hosts(self, room_id, state_entry):
+ async def get_joined_hosts(self, room_id: str, state_entry):
state_group = state_entry.state_group
if not state_group:
# If state_group is None it means it has yet to be assigned a
@@ -731,32 +669,28 @@ class RoomMemberWorkerStore(EventsWorkerStore):
state_group = object()
with Measure(self._clock, "get_joined_hosts"):
- return (
- yield self._get_joined_hosts(
- room_id, state_group, state_entry.state, state_entry=state_entry
- )
+ return await self._get_joined_hosts(
+ room_id, state_group, state_entry.state, state_entry=state_entry
)
- @cachedInlineCallbacks(num_args=2, max_entries=10000, iterable=True)
- # @defer.inlineCallbacks
- def _get_joined_hosts(self, room_id, state_group, current_state_ids, state_entry):
+ @cached(num_args=2, max_entries=10000, iterable=True)
+ async def _get_joined_hosts(
+ self, room_id, state_group, current_state_ids, state_entry
+ ):
# We don't use `state_group`, its there so that we can cache based
# on it. However, its important that its never None, since two current_state's
# with a state_group of None are likely to be different.
- # See bulk_get_push_rules_for_room for how we work around this.
assert state_group is not None
- cache = yield self._get_joined_hosts_cache(room_id)
- joined_hosts = yield cache.get_destinations(state_entry)
-
- return joined_hosts
+ cache = await self._get_joined_hosts_cache(room_id)
+ return await cache.get_destinations(state_entry)
@cached(max_entries=10000)
- def _get_joined_hosts_cache(self, room_id):
+ def _get_joined_hosts_cache(self, room_id: str) -> "_JoinedHostsCache":
return _JoinedHostsCache(self, room_id)
- @cachedInlineCallbacks(num_args=2)
- def did_forget(self, user_id, room_id):
+ @cached(num_args=2)
+ async def did_forget(self, user_id: str, room_id: str) -> bool:
"""Returns whether user_id has elected to discard history for room_id.
Returns False if they have since re-joined."""
@@ -778,15 +712,15 @@ class RoomMemberWorkerStore(EventsWorkerStore):
rows = txn.fetchall()
return rows[0][0]
- count = yield self.db_pool.runInteraction("did_forget_membership", f)
+ count = await self.db_pool.runInteraction("did_forget_membership", f)
return count == 0
@cached()
- def get_forgotten_rooms_for_user(self, user_id):
+ def get_forgotten_rooms_for_user(self, user_id: str):
"""Gets all rooms the user has forgotten.
Args:
- user_id (str)
+ user_id
Returns:
Deferred[set[str]]
@@ -819,18 +753,17 @@ class RoomMemberWorkerStore(EventsWorkerStore):
"get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn
)
- @defer.inlineCallbacks
- def get_rooms_user_has_been_in(self, user_id):
+ async def get_rooms_user_has_been_in(self, user_id: str) -> Set[str]:
"""Get all rooms that the user has ever been in.
Args:
- user_id (str)
+ user_id: The user ID to get the rooms of.
Returns:
- Deferred[set[str]]: Set of room IDs.
+ Set of room IDs.
"""
- room_ids = yield self.db_pool.simple_select_onecol(
+ room_ids = await self.db_pool.simple_select_onecol(
table="room_memberships",
keyvalues={"membership": Membership.JOIN, "user_id": user_id},
retcol="room_id",
@@ -905,8 +838,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
where_clause="forgotten = 1",
)
- @defer.inlineCallbacks
- def _background_add_membership_profile(self, progress, batch_size):
+ async def _background_add_membership_profile(self, progress, batch_size):
target_min_stream_id = progress.get(
"target_min_stream_id_inclusive", self._min_stream_order_on_start
)
@@ -971,19 +903,18 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
return len(rows)
- result = yield self.db_pool.runInteraction(
+ result = await self.db_pool.runInteraction(
_MEMBERSHIP_PROFILE_UPDATE_NAME, add_membership_profile_txn
)
if not result:
- yield self.db_pool.updates._end_background_update(
+ await self.db_pool.updates._end_background_update(
_MEMBERSHIP_PROFILE_UPDATE_NAME
)
return result
- @defer.inlineCallbacks
- def _background_current_state_membership(self, progress, batch_size):
+ async def _background_current_state_membership(self, progress, batch_size):
"""Update the new membership column on current_state_events.
This works by iterating over all rooms in alphebetical order.
@@ -1029,14 +960,14 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
# string, which will compare before all room IDs correctly.
last_processed_room = progress.get("last_processed_room", "")
- row_count, finished = yield self.db_pool.runInteraction(
+ row_count, finished = await self.db_pool.runInteraction(
"_background_current_state_membership_update",
_background_current_state_membership_txn,
last_processed_room,
)
if finished:
- yield self.db_pool.updates._end_background_update(
+ await self.db_pool.updates._end_background_update(
_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME
)
@@ -1047,7 +978,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super(RoomMemberStore, self).__init__(database, db_conn, hs)
- def forget(self, user_id, room_id):
+ def forget(self, user_id: str, room_id: str):
"""Indicate that user_id wishes to discard history for room_id."""
def f(txn):
@@ -1088,17 +1019,19 @@ class _JoinedHostsCache(object):
self._len = 0
- @defer.inlineCallbacks
- def get_destinations(self, state_entry):
+ async def get_destinations(self, state_entry: "_StateCacheEntry") -> Set[str]:
"""Get set of destinations for a state entry
Args:
- state_entry(synapse.state._StateCacheEntry)
+ state_entry
+
+ Returns:
+ The destinations as a set.
"""
if state_entry.state_group == self.state_group:
return frozenset(self.hosts_to_joined_users)
- with (yield self.linearizer.queue(())):
+ with (await self.linearizer.queue(())):
if state_entry.state_group == self.state_group:
pass
elif state_entry.prev_group == self.state_group:
@@ -1110,7 +1043,7 @@ class _JoinedHostsCache(object):
user_id = state_key
known_joins = self.hosts_to_joined_users.setdefault(host, set())
- event = yield self.store.get_event(event_id)
+ event = await self.store.get_event(event_id)
if event.membership == Membership.JOIN:
known_joins.add(user_id)
else:
@@ -1119,7 +1052,7 @@ class _JoinedHostsCache(object):
if not known_joins:
self.hosts_to_joined_users.pop(host, None)
else:
- joined_users = yield self.store.get_joined_users_from_state(
+ joined_users = await self.store.get_joined_users_from_state(
self.room_id, state_entry
)
|