summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erikj@element.io>2024-07-19 16:19:15 +0100
committerGitHub <noreply@github.com>2024-07-19 16:19:15 +0100
commitd3f9afd8d9db8c80b342177b9ab162c79357c431 (patch)
tree70eec1581cff3d605db6da7f6f491e179268cf26
parentGenerate room sync data concurrently (#17458) (diff)
downloadsynapse-d3f9afd8d9db8c80b342177b9ab162c79357c431.tar.xz
Add a cache on `get_rooms_for_local_user_where_membership_is` (#17460)
As it gets used in sliding sync.

We basically invalidate it in all the same places as
`get_rooms_for_user`. Most of the changes are due to needing the
arguments you pass in to be hashable (which lists aren't)
-rw-r--r--changelog.d/17460.misc1
-rw-r--r--synapse/api/constants.py2
-rw-r--r--synapse/storage/_base.py6
-rw-r--r--synapse/storage/databases/main/cache.py6
-rw-r--r--synapse/storage/databases/main/roommember.py26
-rw-r--r--tests/handlers/test_sync.py1
6 files changed, 38 insertions, 4 deletions
diff --git a/changelog.d/17460.misc b/changelog.d/17460.misc
new file mode 100644
index 0000000000..fd99da5a95
--- /dev/null
+++ b/changelog.d/17460.misc
@@ -0,0 +1 @@
+Add cache to `get_rooms_for_local_user_where_membership_is` to speed up sliding sync.
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 12d18137e0..85001d9676 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -50,7 +50,7 @@ class Membership:
     KNOCK: Final = "knock"
     LEAVE: Final = "leave"
     BAN: Final = "ban"
-    LIST: Final = {INVITE, JOIN, KNOCK, LEAVE, BAN}
+    LIST: Final = frozenset((INVITE, JOIN, KNOCK, LEAVE, BAN))
 
 
 class PresenceState:
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 881888fa93..066f3d08ae 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -120,6 +120,9 @@ class SQLBaseStore(metaclass=ABCMeta):
                 "get_user_in_room_with_profile", (room_id, user_id)
             )
             self._attempt_to_invalidate_cache("get_rooms_for_user", (user_id,))
+            self._attempt_to_invalidate_cache(
+                "_get_rooms_for_local_user_where_membership_is_inner", (user_id,)
+            )
 
         # Purge other caches based on room state.
         self._attempt_to_invalidate_cache("get_room_summary", (room_id,))
@@ -146,6 +149,9 @@ class SQLBaseStore(metaclass=ABCMeta):
         self._attempt_to_invalidate_cache("does_pair_of_users_share_a_room", None)
         self._attempt_to_invalidate_cache("get_user_in_room_with_profile", None)
         self._attempt_to_invalidate_cache("get_rooms_for_user", None)
+        self._attempt_to_invalidate_cache(
+            "_get_rooms_for_local_user_where_membership_is_inner", None
+        )
         self._attempt_to_invalidate_cache("get_room_summary", (room_id,))
 
     def _attempt_to_invalidate_cache(
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 2d6b75e47e..26b8e1a172 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -331,6 +331,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
                 "get_invited_rooms_for_local_user", (state_key,)
             )
             self._attempt_to_invalidate_cache("get_rooms_for_user", (state_key,))
+            self._attempt_to_invalidate_cache(
+                "_get_rooms_for_local_user_where_membership_is_inner", (state_key,)
+            )
 
             self._attempt_to_invalidate_cache(
                 "did_forget",
@@ -393,6 +396,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
         self._attempt_to_invalidate_cache("get_thread_id_for_receipts", None)
         self._attempt_to_invalidate_cache("get_invited_rooms_for_local_user", None)
         self._attempt_to_invalidate_cache("get_rooms_for_user", None)
+        self._attempt_to_invalidate_cache(
+            "_get_rooms_for_local_user_where_membership_is_inner", None
+        )
         self._attempt_to_invalidate_cache("did_forget", None)
         self._attempt_to_invalidate_cache("get_forgotten_rooms_for_user", None)
         self._attempt_to_invalidate_cache("get_references_for_event", None)
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index f62d9f705d..640ab123f0 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -445,9 +445,11 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
         if not membership_list:
             return []
 
-        rooms = await self.db_pool.runInteraction(
-            "get_rooms_for_local_user_where_membership_is",
-            self._get_rooms_for_local_user_where_membership_is_txn,
+        # Convert membership list to frozen set as a) it needs to be hashable,
+        # and b) we don't care about the order.
+        membership_list = frozenset(membership_list)
+
+        rooms = await self._get_rooms_for_local_user_where_membership_is_inner(
             user_id,
             membership_list,
         )
@@ -466,6 +468,24 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
 
         return [room for room in rooms if room.room_id not in rooms_to_exclude]
 
+    @cached(max_entries=1000, tree=True)
+    async def _get_rooms_for_local_user_where_membership_is_inner(
+        self,
+        user_id: str,
+        membership_list: Collection[str],
+    ) -> Sequence[RoomsForUser]:
+        if not membership_list:
+            return []
+
+        rooms = await self.db_pool.runInteraction(
+            "get_rooms_for_local_user_where_membership_is",
+            self._get_rooms_for_local_user_where_membership_is_txn,
+            user_id,
+            membership_list,
+        )
+
+        return rooms
+
     def _get_rooms_for_local_user_where_membership_is_txn(
         self,
         txn: LoggingTransaction,
diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index 77aafa492e..fa55f76916 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -211,6 +211,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
 
         # Blow away caches (supported room versions can only change due to a restart).
         self.store.get_rooms_for_user.invalidate_all()
+        self.store._get_rooms_for_local_user_where_membership_is_inner.invalidate_all()
         self.store._get_event_cache.clear()
         self.store._event_ref.clear()