diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index e59e65529b..65ffb4627f 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -17,9 +17,10 @@ from twisted.internet import defer
from collections import namedtuple
-from ._base import SQLBaseStore
+from ._base import SQLBaseStore, cached
from synapse.api.constants import Membership
+from synapse.types import UserID
import logging
@@ -39,7 +40,7 @@ class RoomMemberStore(SQLBaseStore):
"""
try:
target_user_id = event.state_key
- domain = self.hs.parse_userid(target_user_id).domain
+ domain = UserID.from_string(target_user_id).domain
except:
logger.exception(
"Failed to parse target_user_id=%s", target_user_id
@@ -84,7 +85,7 @@ class RoomMemberStore(SQLBaseStore):
for e in member_events:
try:
joined_domains.add(
- self.hs.parse_userid(e.state_key).domain
+ UserID.from_string(e.state_key).domain
)
except:
# FIXME: How do we deal with invalid user ids in the db?
@@ -97,6 +98,8 @@ class RoomMemberStore(SQLBaseStore):
txn.execute(sql, (event.room_id, domain))
+ self.get_rooms_for_user.invalidate(target_user_id)
+
@defer.inlineCallbacks
def get_room_member(self, user_id, room_id):
"""Retrieve the current state of a room member.
@@ -177,6 +180,14 @@ class RoomMemberStore(SQLBaseStore):
if not membership_list:
return defer.succeed(None)
+ return self.runInteraction(
+ "get_rooms_for_user_where_membership_is",
+ self._get_rooms_for_user_where_membership_is_txn,
+ user_id, membership_list
+ )
+
+ def _get_rooms_for_user_where_membership_is_txn(self, txn, user_id,
+ membership_list):
where_clause = "user_id = ? AND (%s)" % (
" OR ".join(["membership = ?" for _ in membership_list]),
)
@@ -184,24 +195,18 @@ class RoomMemberStore(SQLBaseStore):
args = [user_id]
args.extend(membership_list)
- def f(txn):
- sql = (
- "SELECT m.room_id, m.sender, m.membership"
- " FROM room_memberships as m"
- " INNER JOIN current_state_events as c"
- " ON m.event_id = c.event_id"
- " WHERE %s"
- ) % (where_clause,)
-
- txn.execute(sql, args)
- return [
- RoomsForUser(**r) for r in self.cursor_to_dict(txn)
- ]
+ sql = (
+ "SELECT m.room_id, m.sender, m.membership"
+ " FROM room_memberships as m"
+ " INNER JOIN current_state_events as c"
+ " ON m.event_id = c.event_id"
+ " WHERE %s"
+ ) % (where_clause,)
- return self.runInteraction(
- "get_rooms_for_user_where_membership_is",
- f
- )
+ txn.execute(sql, args)
+ return [
+ RoomsForUser(**r) for r in self.cursor_to_dict(txn)
+ ]
def get_joined_hosts_for_room(self, room_id):
return self._simple_select_onecol(
@@ -239,28 +244,32 @@ class RoomMemberStore(SQLBaseStore):
results = self._parse_events_txn(txn, rows)
return results
+ @cached()
+ def get_rooms_for_user(self, user_id):
+ return self.get_rooms_for_user_where_membership_is(
+ user_id, membership_list=[Membership.JOIN],
+ )
+
+ @defer.inlineCallbacks
def user_rooms_intersect(self, user_id_list):
""" Checks whether all the users whose IDs are given in a list share a
room.
+
+ This is a "hot path" function that's called a lot, e.g. by presence for
+ generating the event stream. As such, it is implemented locally by
+ wrapping logic around heavily-cached database queries.
"""
- def interaction(txn):
- user_list_clause = " OR ".join(["m.user_id = ?"] * len(user_id_list))
- sql = (
- "SELECT m.room_id FROM room_memberships as m "
- "INNER JOIN current_state_events as c "
- "ON m.event_id = c.event_id "
- "WHERE m.membership = 'join' "
- "AND (%(clause)s) "
- # TODO(paul): We've got duplicate rows in the database somewhere
- # so we have to DISTINCT m.user_id here
- "GROUP BY m.room_id HAVING COUNT(DISTINCT m.user_id) = ?"
- ) % {"clause": user_list_clause}
+ if len(user_id_list) < 2:
+ defer.returnValue(True)
+
+ deferreds = [self.get_rooms_for_user(u) for u in user_id_list]
- args = list(user_id_list)
- args.append(len(user_id_list))
+ results = yield defer.DeferredList(deferreds, consumeErrors=True)
- txn.execute(sql, args)
+ # A list of sets of strings giving room IDs for each user
+ room_id_lists = [set([r.room_id for r in result[1]]) for result in results]
- return len(txn.fetchall()) > 0
+ # There isn't a setintersection(*list_of_sets)
+ ret = len(room_id_lists.pop(0).intersection(*room_id_lists)) > 0
- return self.runInteraction("user_rooms_intersect", interaction)
+ defer.returnValue(ret)
|