summary refs log tree commit diff
path: root/synapse/storage/roommember.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/roommember.py')
-rw-r--r--synapse/storage/roommember.py70
1 files changed, 23 insertions, 47 deletions
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 9bf608bc90..65ffb4627f 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -17,7 +17,7 @@ from twisted.internet import defer
 
 from collections import namedtuple
 
-from ._base import SQLBaseStore
+from ._base import SQLBaseStore, cached
 
 from synapse.api.constants import Membership
 from synapse.types import UserID
@@ -35,11 +35,6 @@ RoomsForUser = namedtuple(
 
 class RoomMemberStore(SQLBaseStore):
 
-    def __init__(self, *args, **kw):
-        super(RoomMemberStore, self).__init__(*args, **kw)
-
-        self._user_rooms_cache = {}
-
     def _store_room_member_txn(self, txn, event):
         """Store a room member in the database.
         """
@@ -103,7 +98,7 @@ class RoomMemberStore(SQLBaseStore):
 
                 txn.execute(sql, (event.room_id, domain))
 
-        self.invalidate_rooms_for_user(target_user_id)
+        self.get_rooms_for_user.invalidate(target_user_id)
 
     @defer.inlineCallbacks
     def get_room_member(self, user_id, room_id):
@@ -185,6 +180,14 @@ class RoomMemberStore(SQLBaseStore):
         if not membership_list:
             return defer.succeed(None)
 
+        return self.runInteraction(
+            "get_rooms_for_user_where_membership_is",
+            self._get_rooms_for_user_where_membership_is_txn,
+            user_id, membership_list
+        )
+
+    def _get_rooms_for_user_where_membership_is_txn(self, txn, user_id,
+                                                    membership_list):
         where_clause = "user_id = ? AND (%s)" % (
             " OR ".join(["membership = ?" for _ in membership_list]),
         )
@@ -192,24 +195,18 @@ class RoomMemberStore(SQLBaseStore):
         args = [user_id]
         args.extend(membership_list)
 
-        def f(txn):
-            sql = (
-                "SELECT m.room_id, m.sender, m.membership"
-                " FROM room_memberships as m"
-                " INNER JOIN current_state_events as c"
-                " ON m.event_id = c.event_id"
-                " WHERE %s"
-            ) % (where_clause,)
-
-            txn.execute(sql, args)
-            return [
-                RoomsForUser(**r) for r in self.cursor_to_dict(txn)
-            ]
+        sql = (
+            "SELECT m.room_id, m.sender, m.membership"
+            " FROM room_memberships as m"
+            " INNER JOIN current_state_events as c"
+            " ON m.event_id = c.event_id"
+            " WHERE %s"
+        ) % (where_clause,)
 
-        return self.runInteraction(
-            "get_rooms_for_user_where_membership_is",
-            f
-        )
+        txn.execute(sql, args)
+        return [
+            RoomsForUser(**r) for r in self.cursor_to_dict(txn)
+        ]
 
     def get_joined_hosts_for_room(self, room_id):
         return self._simple_select_onecol(
@@ -247,33 +244,12 @@ class RoomMemberStore(SQLBaseStore):
         results = self._parse_events_txn(txn, rows)
         return results
 
-    # TODO(paul): Create a nice @cached decorator to do this
-    #    @cached
-    #    def get_foo(...)
-    #        ...
-    #    invalidate_foo = get_foo.invalidator
-
-    @defer.inlineCallbacks
+    @cached()
     def get_rooms_for_user(self, user_id):
-        # TODO(paul): put some performance counters in here so we can easily
-        #   track what impact this cache is having
-        if user_id in self._user_rooms_cache:
-            defer.returnValue(self._user_rooms_cache[user_id])
-
-        rooms = yield self.get_rooms_for_user_where_membership_is(
+        return self.get_rooms_for_user_where_membership_is(
             user_id, membership_list=[Membership.JOIN],
         )
 
-        # TODO(paul): Consider applying a maximum size; just evict things at
-        #   random, or consider LRU?
-
-        self._user_rooms_cache[user_id] = rooms
-        defer.returnValue(rooms)
-
-    def invalidate_rooms_for_user(self, user_id):
-        if user_id in self._user_rooms_cache:
-            del self._user_rooms_cache[user_id]
-
     @defer.inlineCallbacks
     def user_rooms_intersect(self, user_id_list):
         """ Checks whether all the users whose IDs are given in a list share a