diff --git a/changelog.d/13787.misc b/changelog.d/13787.misc
new file mode 100644
index 0000000000..a9b93717f0
--- /dev/null
+++ b/changelog.d/13787.misc
@@ -0,0 +1 @@
+Optimise get rooms for user calls. Contributed by Nick @ Beeper (@fizzadar).
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 03082fce42..f9cc5bddbc 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -273,11 +273,9 @@ class DeviceWorkerHandler:
possibly_left = possibly_changed | possibly_left
# Double check if we still share rooms with the given user.
- users_rooms = await self.store.get_rooms_for_users_with_stream_ordering(
- possibly_left
- )
+ users_rooms = await self.store.get_rooms_for_users(possibly_left)
for changed_user_id, entries in users_rooms.items():
- if any(e.room_id in room_ids for e in entries):
+ if any(rid in room_ids for rid in entries):
possibly_left.discard(changed_user_id)
else:
possibly_joined.discard(changed_user_id)
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index e75fc6b947..4abb9b6127 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -1490,16 +1490,14 @@ class SyncHandler:
since_token.device_list_key
)
if changed_users is not None:
- result = await self.store.get_rooms_for_users_with_stream_ordering(
- changed_users
- )
+ result = await self.store.get_rooms_for_users(changed_users)
for changed_user_id, entries in result.items():
# Check if the changed user shares any rooms with the user,
# or if the changed user is the syncing user (as we always
# want to include device list updates of their own devices).
if user_id == changed_user_id or any(
- e.room_id in joined_rooms for e in entries
+ rid in joined_rooms for rid in entries
):
users_that_have_changed.add(changed_user_id)
else:
@@ -1533,13 +1531,9 @@ class SyncHandler:
newly_left_users.update(left_users)
# Remove any users that we still share a room with.
- left_users_rooms = (
- await self.store.get_rooms_for_users_with_stream_ordering(
- newly_left_users
- )
- )
+ left_users_rooms = await self.store.get_rooms_for_users(newly_left_users)
for user_id, entries in left_users_rooms.items():
- if any(e.room_id in joined_rooms for e in entries):
+ if any(rid in joined_rooms for rid in entries):
newly_left_users.discard(user_id)
return DeviceListUpdates(
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 313e8aca7d..bf42aeb8d1 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -94,6 +94,7 @@ class SQLBaseStore(metaclass=ABCMeta):
self._attempt_to_invalidate_cache(
"get_rooms_for_user_with_stream_ordering", (user_id,)
)
+ self._attempt_to_invalidate_cache("get_rooms_for_user", (user_id,))
# Purge other caches based on room state.
self._attempt_to_invalidate_cache("get_room_summary", (room_id,))
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index db6ce83a2b..3b8ed1f7ee 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -205,6 +205,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self.get_rooms_for_user_with_stream_ordering.invalidate(
(data.state_key,)
)
+ self.get_rooms_for_user.invalidate((data.state_key,))
else:
raise Exception("Unknown events stream row type %s" % (row.type,))
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 8ada3cdac3..982e1f08e3 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -15,7 +15,6 @@
import logging
from typing import (
TYPE_CHECKING,
- Callable,
Collection,
Dict,
FrozenSet,
@@ -52,7 +51,6 @@ from synapse.types import JsonDict, PersistedEventPosition, StateMap, get_domain
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
-from synapse.util.cancellation import cancellable
from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure
@@ -600,58 +598,6 @@ class RoomMemberWorkerStore(EventsWorkerStore):
for room_id, instance, stream_id in txn
)
- @cachedList(
- cached_method_name="get_rooms_for_user_with_stream_ordering",
- list_name="user_ids",
- )
- async def get_rooms_for_users_with_stream_ordering(
- self, user_ids: Collection[str]
- ) -> Dict[str, FrozenSet[GetRoomsForUserWithStreamOrdering]]:
- """A batched version of `get_rooms_for_user_with_stream_ordering`.
-
- Returns:
- Map from user_id to set of rooms that is currently in.
- """
- return await self.db_pool.runInteraction(
- "get_rooms_for_users_with_stream_ordering",
- self._get_rooms_for_users_with_stream_ordering_txn,
- user_ids,
- )
-
- def _get_rooms_for_users_with_stream_ordering_txn(
- self, txn: LoggingTransaction, user_ids: Collection[str]
- ) -> Dict[str, FrozenSet[GetRoomsForUserWithStreamOrdering]]:
-
- clause, args = make_in_list_sql_clause(
- self.database_engine,
- "c.state_key",
- user_ids,
- )
-
- sql = f"""
- SELECT c.state_key, room_id, e.instance_name, e.stream_ordering
- FROM current_state_events AS c
- INNER JOIN events AS e USING (room_id, event_id)
- WHERE
- c.type = 'm.room.member'
- AND c.membership = ?
- AND {clause}
- """
-
- txn.execute(sql, [Membership.JOIN] + args)
-
- result: Dict[str, Set[GetRoomsForUserWithStreamOrdering]] = {
- user_id: set() for user_id in user_ids
- }
- for user_id, room_id, instance, stream_id in txn:
- result[user_id].add(
- GetRoomsForUserWithStreamOrdering(
- room_id, PersistedEventPosition(instance, stream_id)
- )
- )
-
- return {user_id: frozenset(v) for user_id, v in result.items()}
-
async def get_users_server_still_shares_room_with(
self, user_ids: Collection[str]
) -> Set[str]:
@@ -693,19 +639,68 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return {row[0] for row in txn}
- @cancellable
- async def get_rooms_for_user(
- self, user_id: str, on_invalidate: Optional[Callable[[], None]] = None
- ) -> FrozenSet[str]:
+ @cached(max_entries=500000, iterable=True)
+ async def get_rooms_for_user(self, user_id: str) -> FrozenSet[str]:
"""Returns a set of room_ids the user is currently joined to.
If a remote user only returns rooms this server is currently
participating in.
"""
- rooms = await self.get_rooms_for_user_with_stream_ordering(
- user_id, on_invalidate=on_invalidate
+ rooms = self.get_rooms_for_user_with_stream_ordering.cache.get_immediate(
+ (user_id,),
+ None,
+ update_metrics=False,
+ )
+ if rooms:
+ return frozenset(r.room_id for r in rooms)
+
+ room_ids = await self.db_pool.simple_select_onecol(
+ table="current_state_events",
+ keyvalues={
+ "type": EventTypes.Member,
+ "membership": Membership.JOIN,
+ "state_key": user_id,
+ },
+ retcol="room_id",
+ desc="get_rooms_for_user",
)
- return frozenset(r.room_id for r in rooms)
+
+ return frozenset(room_ids)
+
+ @cachedList(
+ cached_method_name="get_rooms_for_user",
+ list_name="user_ids",
+ )
+ async def get_rooms_for_users(
+ self, user_ids: Collection[str]
+ ) -> Dict[str, FrozenSet[str]]:
+ """A batched version of `get_rooms_for_user`.
+
+ Returns:
+ Map from user_id to set of rooms that is currently in.
+ """
+
+ rows = await self.db_pool.simple_select_many_batch(
+ table="current_state_events",
+ column="state_key",
+ iterable=user_ids,
+ retcols=(
+ "state_key",
+ "room_id",
+ ),
+ keyvalues={
+ "type": EventTypes.Member,
+ "membership": Membership.JOIN,
+ },
+ desc="get_rooms_for_users",
+ )
+
+ user_rooms: Dict[str, Set[str]] = {user_id: set() for user_id in user_ids}
+
+ for row in rows:
+ user_rooms[row["state_key"]].add(row["room_id"])
+
+ return {key: frozenset(rooms) for key, rooms in user_rooms.items()}
@cached(max_entries=10000)
async def does_pair_of_users_share_a_room(
diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index e3f38fbcc5..ab5c101eb7 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -159,6 +159,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
# Blow away caches (supported room versions can only change due to a restart).
self.store.get_rooms_for_user_with_stream_ordering.invalidate_all()
+ self.store.get_rooms_for_user.invalidate_all()
self.get_success(self.store._get_event_cache.clear())
self.store._event_ref.clear()
|