diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 9bf608bc90..65ffb4627f 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -17,7 +17,7 @@ from twisted.internet import defer
from collections import namedtuple
-from ._base import SQLBaseStore
+from ._base import SQLBaseStore, cached
from synapse.api.constants import Membership
from synapse.types import UserID
@@ -35,11 +35,6 @@ RoomsForUser = namedtuple(
class RoomMemberStore(SQLBaseStore):
- def __init__(self, *args, **kw):
- super(RoomMemberStore, self).__init__(*args, **kw)
-
- self._user_rooms_cache = {}
-
def _store_room_member_txn(self, txn, event):
"""Store a room member in the database.
"""
@@ -103,7 +98,7 @@ class RoomMemberStore(SQLBaseStore):
txn.execute(sql, (event.room_id, domain))
- self.invalidate_rooms_for_user(target_user_id)
+ self.get_rooms_for_user.invalidate(target_user_id)
@defer.inlineCallbacks
def get_room_member(self, user_id, room_id):
@@ -185,6 +180,14 @@ class RoomMemberStore(SQLBaseStore):
if not membership_list:
return defer.succeed(None)
+ return self.runInteraction(
+ "get_rooms_for_user_where_membership_is",
+ self._get_rooms_for_user_where_membership_is_txn,
+ user_id, membership_list
+ )
+
+ def _get_rooms_for_user_where_membership_is_txn(self, txn, user_id,
+ membership_list):
where_clause = "user_id = ? AND (%s)" % (
" OR ".join(["membership = ?" for _ in membership_list]),
)
@@ -192,24 +195,18 @@ class RoomMemberStore(SQLBaseStore):
args = [user_id]
args.extend(membership_list)
- def f(txn):
- sql = (
- "SELECT m.room_id, m.sender, m.membership"
- " FROM room_memberships as m"
- " INNER JOIN current_state_events as c"
- " ON m.event_id = c.event_id"
- " WHERE %s"
- ) % (where_clause,)
-
- txn.execute(sql, args)
- return [
- RoomsForUser(**r) for r in self.cursor_to_dict(txn)
- ]
+ sql = (
+ "SELECT m.room_id, m.sender, m.membership"
+ " FROM room_memberships as m"
+ " INNER JOIN current_state_events as c"
+ " ON m.event_id = c.event_id"
+ " WHERE %s"
+ ) % (where_clause,)
- return self.runInteraction(
- "get_rooms_for_user_where_membership_is",
- f
- )
+ txn.execute(sql, args)
+ return [
+ RoomsForUser(**r) for r in self.cursor_to_dict(txn)
+ ]
def get_joined_hosts_for_room(self, room_id):
return self._simple_select_onecol(
@@ -247,33 +244,12 @@ class RoomMemberStore(SQLBaseStore):
results = self._parse_events_txn(txn, rows)
return results
- # TODO(paul): Create a nice @cached decorator to do this
- # @cached
- # def get_foo(...)
- # ...
- # invalidate_foo = get_foo.invalidator
-
- @defer.inlineCallbacks
+ @cached()
def get_rooms_for_user(self, user_id):
- # TODO(paul): put some performance counters in here so we can easily
- # track what impact this cache is having
- if user_id in self._user_rooms_cache:
- defer.returnValue(self._user_rooms_cache[user_id])
-
- rooms = yield self.get_rooms_for_user_where_membership_is(
+ return self.get_rooms_for_user_where_membership_is(
user_id, membership_list=[Membership.JOIN],
)
- # TODO(paul): Consider applying a maximum size; just evict things at
- # random, or consider LRU?
-
- self._user_rooms_cache[user_id] = rooms
- defer.returnValue(rooms)
-
- def invalidate_rooms_for_user(self, user_id):
- if user_id in self._user_rooms_cache:
- del self._user_rooms_cache[user_id]
-
@defer.inlineCallbacks
def user_rooms_intersect(self, user_id_list):
""" Checks whether all the users whose IDs are given in a list share a
|