diff options
Diffstat (limited to 'synapse/storage/databases/main/roommember.py')
-rw-r--r-- | synapse/storage/databases/main/roommember.py | 117 |
1 files changed, 56 insertions, 61 deletions
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 8ada3cdac3..982e1f08e3 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -15,7 +15,6 @@ import logging from typing import ( TYPE_CHECKING, - Callable, Collection, Dict, FrozenSet, @@ -52,7 +51,6 @@ from synapse.types import JsonDict, PersistedEventPosition, StateMap, get_domain from synapse.util.async_helpers import Linearizer from synapse.util.caches import intern_string from synapse.util.caches.descriptors import _CacheContext, cached, cachedList -from synapse.util.cancellation import cancellable from synapse.util.iterutils import batch_iter from synapse.util.metrics import Measure @@ -600,58 +598,6 @@ class RoomMemberWorkerStore(EventsWorkerStore): for room_id, instance, stream_id in txn ) - @cachedList( - cached_method_name="get_rooms_for_user_with_stream_ordering", - list_name="user_ids", - ) - async def get_rooms_for_users_with_stream_ordering( - self, user_ids: Collection[str] - ) -> Dict[str, FrozenSet[GetRoomsForUserWithStreamOrdering]]: - """A batched version of `get_rooms_for_user_with_stream_ordering`. - - Returns: - Map from user_id to set of rooms that is currently in. - """ - return await self.db_pool.runInteraction( - "get_rooms_for_users_with_stream_ordering", - self._get_rooms_for_users_with_stream_ordering_txn, - user_ids, - ) - - def _get_rooms_for_users_with_stream_ordering_txn( - self, txn: LoggingTransaction, user_ids: Collection[str] - ) -> Dict[str, FrozenSet[GetRoomsForUserWithStreamOrdering]]: - - clause, args = make_in_list_sql_clause( - self.database_engine, - "c.state_key", - user_ids, - ) - - sql = f""" - SELECT c.state_key, room_id, e.instance_name, e.stream_ordering - FROM current_state_events AS c - INNER JOIN events AS e USING (room_id, event_id) - WHERE - c.type = 'm.room.member' - AND c.membership = ? - AND {clause} - """ - - txn.execute(sql, [Membership.JOIN] + args) - - result: Dict[str, Set[GetRoomsForUserWithStreamOrdering]] = { - user_id: set() for user_id in user_ids - } - for user_id, room_id, instance, stream_id in txn: - result[user_id].add( - GetRoomsForUserWithStreamOrdering( - room_id, PersistedEventPosition(instance, stream_id) - ) - ) - - return {user_id: frozenset(v) for user_id, v in result.items()} - async def get_users_server_still_shares_room_with( self, user_ids: Collection[str] ) -> Set[str]: @@ -693,19 +639,68 @@ class RoomMemberWorkerStore(EventsWorkerStore): return {row[0] for row in txn} - @cancellable - async def get_rooms_for_user( - self, user_id: str, on_invalidate: Optional[Callable[[], None]] = None - ) -> FrozenSet[str]: + @cached(max_entries=500000, iterable=True) + async def get_rooms_for_user(self, user_id: str) -> FrozenSet[str]: """Returns a set of room_ids the user is currently joined to. If a remote user only returns rooms this server is currently participating in. """ - rooms = await self.get_rooms_for_user_with_stream_ordering( - user_id, on_invalidate=on_invalidate + rooms = self.get_rooms_for_user_with_stream_ordering.cache.get_immediate( + (user_id,), + None, + update_metrics=False, + ) + if rooms: + return frozenset(r.room_id for r in rooms) + + room_ids = await self.db_pool.simple_select_onecol( + table="current_state_events", + keyvalues={ + "type": EventTypes.Member, + "membership": Membership.JOIN, + "state_key": user_id, + }, + retcol="room_id", + desc="get_rooms_for_user", ) - return frozenset(r.room_id for r in rooms) + + return frozenset(room_ids) + + @cachedList( + cached_method_name="get_rooms_for_user", + list_name="user_ids", + ) + async def get_rooms_for_users( + self, user_ids: Collection[str] + ) -> Dict[str, FrozenSet[str]]: + """A batched version of `get_rooms_for_user`. + + Returns: + Map from user_id to set of rooms that is currently in. + """ + + rows = await self.db_pool.simple_select_many_batch( + table="current_state_events", + column="state_key", + iterable=user_ids, + retcols=( + "state_key", + "room_id", + ), + keyvalues={ + "type": EventTypes.Member, + "membership": Membership.JOIN, + }, + desc="get_rooms_for_users", + ) + + user_rooms: Dict[str, Set[str]] = {user_id: set() for user_id in user_ids} + + for row in rows: + user_rooms[row["state_key"]].add(row["room_id"]) + + return {key: frozenset(rooms) for key, rooms in user_rooms.items()} @cached(max_entries=10000) async def does_pair_of_users_share_a_room( |