summary refs log tree commit diff
path: root/synapse/storage/roommember.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/roommember.py')
-rw-r--r--synapse/storage/roommember.py73
1 files changed, 53 insertions, 20 deletions
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index e59e65529b..779f9ce544 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -20,6 +20,7 @@ from collections import namedtuple
 from ._base import SQLBaseStore
 
 from synapse.api.constants import Membership
+from synapse.types import UserID
 
 import logging
 
@@ -34,12 +35,17 @@ RoomsForUser = namedtuple(
 
 class RoomMemberStore(SQLBaseStore):
 
+    def __init__(self, *args, **kw):
+        super(RoomMemberStore, self).__init__(*args, **kw)
+
+        self._user_rooms_cache = {}
+
     def _store_room_member_txn(self, txn, event):
         """Store a room member in the database.
         """
         try:
             target_user_id = event.state_key
-            domain = self.hs.parse_userid(target_user_id).domain
+            domain = UserID.from_string(target_user_id).domain
         except:
             logger.exception(
                 "Failed to parse target_user_id=%s", target_user_id
@@ -84,7 +90,7 @@ class RoomMemberStore(SQLBaseStore):
             for e in member_events:
                 try:
                     joined_domains.add(
-                        self.hs.parse_userid(e.state_key).domain
+                        UserID.from_string(e.state_key).domain
                     )
                 except:
                     # FIXME: How do we deal with invalid user ids in the db?
@@ -97,6 +103,8 @@ class RoomMemberStore(SQLBaseStore):
 
                 txn.execute(sql, (event.room_id, domain))
 
+        self.invalidate_rooms_for_user(target_user_id)
+
     @defer.inlineCallbacks
     def get_room_member(self, user_id, room_id):
         """Retrieve the current state of a room member.
@@ -239,28 +247,53 @@ class RoomMemberStore(SQLBaseStore):
         results = self._parse_events_txn(txn, rows)
         return results
 
+    # TODO(paul): Create a nice @cached decorator to do this
+    #    @cached
+    #    def get_foo(...)
+    #        ...
+    #    invalidate_foo = get_foo.invalidator
+
+    @defer.inlineCallbacks
+    def get_rooms_for_user(self, user_id):
+        # TODO(paul): put some performance counters in here so we can easily
+        #   track what impact this cache is having
+        if user_id in self._user_rooms_cache:
+            defer.returnValue(self._user_rooms_cache[user_id])
+
+        rooms = yield self.get_rooms_for_user_where_membership_is(
+            user_id, membership_list=[Membership.JOIN],
+        )
+
+        # TODO(paul): Consider applying a maximum size; just evict things at
+        #   random, or consider LRU?
+
+        self._user_rooms_cache[user_id] = rooms
+        defer.returnValue(rooms)
+
+    def invalidate_rooms_for_user(self, user_id):
+        if user_id in self._user_rooms_cache:
+            del self._user_rooms_cache[user_id]
+
+    @defer.inlineCallbacks
     def user_rooms_intersect(self, user_id_list):
         """ Checks whether all the users whose IDs are given in a list share a
         room.
+
+        This is a "hot path" function that's called a lot, e.g. by presence for
+        generating the event stream. As such, it is implemented locally by
+        wrapping logic around heavily-cached database queries.
         """
-        def interaction(txn):
-            user_list_clause = " OR ".join(["m.user_id = ?"] * len(user_id_list))
-            sql = (
-                "SELECT m.room_id FROM room_memberships as m "
-                "INNER JOIN current_state_events as c "
-                "ON m.event_id = c.event_id "
-                "WHERE m.membership = 'join' "
-                "AND (%(clause)s) "
-                # TODO(paul): We've got duplicate rows in the database somewhere
-                #   so we have to DISTINCT m.user_id here
-                "GROUP BY m.room_id HAVING COUNT(DISTINCT m.user_id) = ?"
-            ) % {"clause": user_list_clause}
-
-            args = list(user_id_list)
-            args.append(len(user_id_list))
+        if len(user_id_list) < 2:
+            defer.returnValue(True)
 
-            txn.execute(sql, args)
+        deferreds = [self.get_rooms_for_user(u) for u in user_id_list]
+
+        results = yield defer.DeferredList(deferreds)
+
+        # A list of sets of strings giving room IDs for each user
+        room_id_lists = [set([r.room_id for r in result[1]]) for result in results]
 
-            return len(txn.fetchall()) > 0
+        # There isn't a setintersection(*list_of_sets)
+        ret = len(room_id_lists.pop(0).intersection(*room_id_lists)) > 0
 
-        return self.runInteraction("user_rooms_intersect", interaction)
+        defer.returnValue(ret)