summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/_base.py1
-rw-r--r--synapse/storage/databases/main/cache.py1
-rw-r--r--synapse/storage/databases/main/roommember.py117
3 files changed, 58 insertions, 61 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 313e8aca7d..bf42aeb8d1 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -94,6 +94,7 @@ class SQLBaseStore(metaclass=ABCMeta):
             self._attempt_to_invalidate_cache(
                 "get_rooms_for_user_with_stream_ordering", (user_id,)
             )
+            self._attempt_to_invalidate_cache("get_rooms_for_user", (user_id,))
 
         # Purge other caches based on room state.
         self._attempt_to_invalidate_cache("get_room_summary", (room_id,))
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index db6ce83a2b..3b8ed1f7ee 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -205,6 +205,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
                 self.get_rooms_for_user_with_stream_ordering.invalidate(
                     (data.state_key,)
                 )
+                self.get_rooms_for_user.invalidate((data.state_key,))
         else:
             raise Exception("Unknown events stream row type %s" % (row.type,))
 
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 8ada3cdac3..982e1f08e3 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -15,7 +15,6 @@
 import logging
 from typing import (
     TYPE_CHECKING,
-    Callable,
     Collection,
     Dict,
     FrozenSet,
@@ -52,7 +51,6 @@ from synapse.types import JsonDict, PersistedEventPosition, StateMap, get_domain
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches import intern_string
 from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
-from synapse.util.cancellation import cancellable
 from synapse.util.iterutils import batch_iter
 from synapse.util.metrics import Measure
 
@@ -600,58 +598,6 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             for room_id, instance, stream_id in txn
         )
 
-    @cachedList(
-        cached_method_name="get_rooms_for_user_with_stream_ordering",
-        list_name="user_ids",
-    )
-    async def get_rooms_for_users_with_stream_ordering(
-        self, user_ids: Collection[str]
-    ) -> Dict[str, FrozenSet[GetRoomsForUserWithStreamOrdering]]:
-        """A batched version of `get_rooms_for_user_with_stream_ordering`.
-
-        Returns:
-            Map from user_id to set of rooms that is currently in.
-        """
-        return await self.db_pool.runInteraction(
-            "get_rooms_for_users_with_stream_ordering",
-            self._get_rooms_for_users_with_stream_ordering_txn,
-            user_ids,
-        )
-
-    def _get_rooms_for_users_with_stream_ordering_txn(
-        self, txn: LoggingTransaction, user_ids: Collection[str]
-    ) -> Dict[str, FrozenSet[GetRoomsForUserWithStreamOrdering]]:
-
-        clause, args = make_in_list_sql_clause(
-            self.database_engine,
-            "c.state_key",
-            user_ids,
-        )
-
-        sql = f"""
-            SELECT c.state_key, room_id, e.instance_name, e.stream_ordering
-            FROM current_state_events AS c
-            INNER JOIN events AS e USING (room_id, event_id)
-            WHERE
-                c.type = 'm.room.member'
-                AND c.membership = ?
-                AND {clause}
-        """
-
-        txn.execute(sql, [Membership.JOIN] + args)
-
-        result: Dict[str, Set[GetRoomsForUserWithStreamOrdering]] = {
-            user_id: set() for user_id in user_ids
-        }
-        for user_id, room_id, instance, stream_id in txn:
-            result[user_id].add(
-                GetRoomsForUserWithStreamOrdering(
-                    room_id, PersistedEventPosition(instance, stream_id)
-                )
-            )
-
-        return {user_id: frozenset(v) for user_id, v in result.items()}
-
     async def get_users_server_still_shares_room_with(
         self, user_ids: Collection[str]
     ) -> Set[str]:
@@ -693,19 +639,68 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
         return {row[0] for row in txn}
 
-    @cancellable
-    async def get_rooms_for_user(
-        self, user_id: str, on_invalidate: Optional[Callable[[], None]] = None
-    ) -> FrozenSet[str]:
+    @cached(max_entries=500000, iterable=True)
+    async def get_rooms_for_user(self, user_id: str) -> FrozenSet[str]:
         """Returns a set of room_ids the user is currently joined to.
 
         If a remote user only returns rooms this server is currently
         participating in.
         """
-        rooms = await self.get_rooms_for_user_with_stream_ordering(
-            user_id, on_invalidate=on_invalidate
+        rooms = self.get_rooms_for_user_with_stream_ordering.cache.get_immediate(
+            (user_id,),
+            None,
+            update_metrics=False,
+        )
+        if rooms:
+            return frozenset(r.room_id for r in rooms)
+
+        room_ids = await self.db_pool.simple_select_onecol(
+            table="current_state_events",
+            keyvalues={
+                "type": EventTypes.Member,
+                "membership": Membership.JOIN,
+                "state_key": user_id,
+            },
+            retcol="room_id",
+            desc="get_rooms_for_user",
         )
-        return frozenset(r.room_id for r in rooms)
+
+        return frozenset(room_ids)
+
+    @cachedList(
+        cached_method_name="get_rooms_for_user",
+        list_name="user_ids",
+    )
+    async def get_rooms_for_users(
+        self, user_ids: Collection[str]
+    ) -> Dict[str, FrozenSet[str]]:
+        """A batched version of `get_rooms_for_user`.
+
+        Returns:
+            Map from user_id to set of rooms that is currently in.
+        """
+
+        rows = await self.db_pool.simple_select_many_batch(
+            table="current_state_events",
+            column="state_key",
+            iterable=user_ids,
+            retcols=(
+                "state_key",
+                "room_id",
+            ),
+            keyvalues={
+                "type": EventTypes.Member,
+                "membership": Membership.JOIN,
+            },
+            desc="get_rooms_for_users",
+        )
+
+        user_rooms: Dict[str, Set[str]] = {user_id: set() for user_id in user_ids}
+
+        for row in rows:
+            user_rooms[row["state_key"]].add(row["room_id"])
+
+        return {key: frozenset(rooms) for key, rooms in user_rooms.items()}
 
     @cached(max_entries=10000)
     async def does_pair_of_users_share_a_room(