summary refs log tree commit diff
path: root/synapse/storage/roommember.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/roommember.py')
-rw-r--r--synapse/storage/roommember.py83
1 files changed, 46 insertions, 37 deletions
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index e59e65529b..65ffb4627f 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -17,9 +17,10 @@ from twisted.internet import defer
 
 from collections import namedtuple
 
-from ._base import SQLBaseStore
+from ._base import SQLBaseStore, cached
 
 from synapse.api.constants import Membership
+from synapse.types import UserID
 
 import logging
 
@@ -39,7 +40,7 @@ class RoomMemberStore(SQLBaseStore):
         """
         try:
             target_user_id = event.state_key
-            domain = self.hs.parse_userid(target_user_id).domain
+            domain = UserID.from_string(target_user_id).domain
         except:
             logger.exception(
                 "Failed to parse target_user_id=%s", target_user_id
@@ -84,7 +85,7 @@ class RoomMemberStore(SQLBaseStore):
             for e in member_events:
                 try:
                     joined_domains.add(
-                        self.hs.parse_userid(e.state_key).domain
+                        UserID.from_string(e.state_key).domain
                     )
                 except:
                     # FIXME: How do we deal with invalid user ids in the db?
@@ -97,6 +98,8 @@ class RoomMemberStore(SQLBaseStore):
 
                 txn.execute(sql, (event.room_id, domain))
 
+        self.get_rooms_for_user.invalidate(target_user_id)
+
     @defer.inlineCallbacks
     def get_room_member(self, user_id, room_id):
         """Retrieve the current state of a room member.
@@ -177,6 +180,14 @@ class RoomMemberStore(SQLBaseStore):
         if not membership_list:
             return defer.succeed(None)
 
+        return self.runInteraction(
+            "get_rooms_for_user_where_membership_is",
+            self._get_rooms_for_user_where_membership_is_txn,
+            user_id, membership_list
+        )
+
+    def _get_rooms_for_user_where_membership_is_txn(self, txn, user_id,
+                                                    membership_list):
         where_clause = "user_id = ? AND (%s)" % (
             " OR ".join(["membership = ?" for _ in membership_list]),
         )
@@ -184,24 +195,18 @@ class RoomMemberStore(SQLBaseStore):
         args = [user_id]
         args.extend(membership_list)
 
-        def f(txn):
-            sql = (
-                "SELECT m.room_id, m.sender, m.membership"
-                " FROM room_memberships as m"
-                " INNER JOIN current_state_events as c"
-                " ON m.event_id = c.event_id"
-                " WHERE %s"
-            ) % (where_clause,)
-
-            txn.execute(sql, args)
-            return [
-                RoomsForUser(**r) for r in self.cursor_to_dict(txn)
-            ]
+        sql = (
+            "SELECT m.room_id, m.sender, m.membership"
+            " FROM room_memberships as m"
+            " INNER JOIN current_state_events as c"
+            " ON m.event_id = c.event_id"
+            " WHERE %s"
+        ) % (where_clause,)
 
-        return self.runInteraction(
-            "get_rooms_for_user_where_membership_is",
-            f
-        )
+        txn.execute(sql, args)
+        return [
+            RoomsForUser(**r) for r in self.cursor_to_dict(txn)
+        ]
 
     def get_joined_hosts_for_room(self, room_id):
         return self._simple_select_onecol(
@@ -239,28 +244,32 @@ class RoomMemberStore(SQLBaseStore):
         results = self._parse_events_txn(txn, rows)
         return results
 
+    @cached()
+    def get_rooms_for_user(self, user_id):
+        return self.get_rooms_for_user_where_membership_is(
+            user_id, membership_list=[Membership.JOIN],
+        )
+
+    @defer.inlineCallbacks
     def user_rooms_intersect(self, user_id_list):
         """ Checks whether all the users whose IDs are given in a list share a
         room.
+
+        This is a "hot path" function that's called a lot, e.g. by presence for
+        generating the event stream. As such, it is implemented locally by
+        wrapping logic around heavily-cached database queries.
         """
-        def interaction(txn):
-            user_list_clause = " OR ".join(["m.user_id = ?"] * len(user_id_list))
-            sql = (
-                "SELECT m.room_id FROM room_memberships as m "
-                "INNER JOIN current_state_events as c "
-                "ON m.event_id = c.event_id "
-                "WHERE m.membership = 'join' "
-                "AND (%(clause)s) "
-                # TODO(paul): We've got duplicate rows in the database somewhere
-                #   so we have to DISTINCT m.user_id here
-                "GROUP BY m.room_id HAVING COUNT(DISTINCT m.user_id) = ?"
-            ) % {"clause": user_list_clause}
+        if len(user_id_list) < 2:
+            defer.returnValue(True)
+
+        deferreds = [self.get_rooms_for_user(u) for u in user_id_list]
 
-            args = list(user_id_list)
-            args.append(len(user_id_list))
+        results = yield defer.DeferredList(deferreds, consumeErrors=True)
 
-            txn.execute(sql, args)
+        # A list of sets of strings giving room IDs for each user
+        room_id_lists = [set([r.room_id for r in result[1]]) for result in results]
 
-            return len(txn.fetchall()) > 0
+        # There isn't a setintersection(*list_of_sets)
+        ret = len(room_id_lists.pop(0).intersection(*room_id_lists)) > 0
 
-        return self.runInteraction("user_rooms_intersect", interaction)
+        defer.returnValue(ret)