diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 2d6b75e47e..26b8e1a172 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -331,6 +331,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
"get_invited_rooms_for_local_user", (state_key,)
)
self._attempt_to_invalidate_cache("get_rooms_for_user", (state_key,))
+ self._attempt_to_invalidate_cache(
+ "_get_rooms_for_local_user_where_membership_is_inner", (state_key,)
+ )
self._attempt_to_invalidate_cache(
"did_forget",
@@ -393,6 +396,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self._attempt_to_invalidate_cache("get_thread_id_for_receipts", None)
self._attempt_to_invalidate_cache("get_invited_rooms_for_local_user", None)
self._attempt_to_invalidate_cache("get_rooms_for_user", None)
+ self._attempt_to_invalidate_cache(
+ "_get_rooms_for_local_user_where_membership_is_inner", None
+ )
self._attempt_to_invalidate_cache("did_forget", None)
self._attempt_to_invalidate_cache("get_forgotten_rooms_for_user", None)
self._attempt_to_invalidate_cache("get_references_for_event", None)
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index f62d9f705d..640ab123f0 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -445,9 +445,11 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
if not membership_list:
return []
- rooms = await self.db_pool.runInteraction(
- "get_rooms_for_local_user_where_membership_is",
- self._get_rooms_for_local_user_where_membership_is_txn,
+ # Convert membership list to frozen set as a) it needs to be hashable,
+ # and b) we don't care about the order.
+ membership_list = frozenset(membership_list)
+
+ rooms = await self._get_rooms_for_local_user_where_membership_is_inner(
user_id,
membership_list,
)
@@ -466,6 +468,24 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
return [room for room in rooms if room.room_id not in rooms_to_exclude]
+ @cached(max_entries=1000, tree=True)
+ async def _get_rooms_for_local_user_where_membership_is_inner(
+ self,
+ user_id: str,
+ membership_list: Collection[str],
+ ) -> Sequence[RoomsForUser]:
+ if not membership_list:
+ return []
+
+ rooms = await self.db_pool.runInteraction(
+ "get_rooms_for_local_user_where_membership_is",
+ self._get_rooms_for_local_user_where_membership_is_txn,
+ user_id,
+ membership_list,
+ )
+
+ return rooms
+
def _get_rooms_for_local_user_where_membership_is_txn(
self,
txn: LoggingTransaction,
|