summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/_base.py6
-rw-r--r--synapse/storage/databases/main/cache.py6
-rw-r--r--synapse/storage/databases/main/roommember.py26
3 files changed, 35 insertions, 3 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 881888fa93..066f3d08ae 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -120,6 +120,9 @@ class SQLBaseStore(metaclass=ABCMeta):
                 "get_user_in_room_with_profile", (room_id, user_id)
             )
             self._attempt_to_invalidate_cache("get_rooms_for_user", (user_id,))
+            self._attempt_to_invalidate_cache(
+                "_get_rooms_for_local_user_where_membership_is_inner", (user_id,)
+            )
 
         # Purge other caches based on room state.
         self._attempt_to_invalidate_cache("get_room_summary", (room_id,))
@@ -146,6 +149,9 @@ class SQLBaseStore(metaclass=ABCMeta):
         self._attempt_to_invalidate_cache("does_pair_of_users_share_a_room", None)
         self._attempt_to_invalidate_cache("get_user_in_room_with_profile", None)
         self._attempt_to_invalidate_cache("get_rooms_for_user", None)
+        self._attempt_to_invalidate_cache(
+            "_get_rooms_for_local_user_where_membership_is_inner", None
+        )
         self._attempt_to_invalidate_cache("get_room_summary", (room_id,))
 
     def _attempt_to_invalidate_cache(
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 2d6b75e47e..26b8e1a172 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -331,6 +331,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
                 "get_invited_rooms_for_local_user", (state_key,)
             )
             self._attempt_to_invalidate_cache("get_rooms_for_user", (state_key,))
+            self._attempt_to_invalidate_cache(
+                "_get_rooms_for_local_user_where_membership_is_inner", (state_key,)
+            )
 
             self._attempt_to_invalidate_cache(
                 "did_forget",
@@ -393,6 +396,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
         self._attempt_to_invalidate_cache("get_thread_id_for_receipts", None)
         self._attempt_to_invalidate_cache("get_invited_rooms_for_local_user", None)
         self._attempt_to_invalidate_cache("get_rooms_for_user", None)
+        self._attempt_to_invalidate_cache(
+            "_get_rooms_for_local_user_where_membership_is_inner", None
+        )
         self._attempt_to_invalidate_cache("did_forget", None)
         self._attempt_to_invalidate_cache("get_forgotten_rooms_for_user", None)
         self._attempt_to_invalidate_cache("get_references_for_event", None)
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index f62d9f705d..640ab123f0 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -445,9 +445,11 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
         if not membership_list:
             return []
 
-        rooms = await self.db_pool.runInteraction(
-            "get_rooms_for_local_user_where_membership_is",
-            self._get_rooms_for_local_user_where_membership_is_txn,
+        # Convert membership list to frozen set as a) it needs to be hashable,
+        # and b) we don't care about the order.
+        membership_list = frozenset(membership_list)
+
+        rooms = await self._get_rooms_for_local_user_where_membership_is_inner(
             user_id,
             membership_list,
         )
@@ -466,6 +468,24 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
 
         return [room for room in rooms if room.room_id not in rooms_to_exclude]
 
+    @cached(max_entries=1000, tree=True)
+    async def _get_rooms_for_local_user_where_membership_is_inner(
+        self,
+        user_id: str,
+        membership_list: Collection[str],
+    ) -> Sequence[RoomsForUser]:
+        if not membership_list:
+            return []
+
+        rooms = await self.db_pool.runInteraction(
+            "get_rooms_for_local_user_where_membership_is",
+            self._get_rooms_for_local_user_where_membership_is_txn,
+            user_id,
+            membership_list,
+        )
+
+        return rooms
+
     def _get_rooms_for_local_user_where_membership_is_txn(
         self,
         txn: LoggingTransaction,