diff options
Diffstat (limited to 'synapse/storage')
-rw-r--r-- | synapse/storage/roommember.py | 70 |
1 files changed, 49 insertions, 21 deletions
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 27b7d8eb13..e59e65529b 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -15,6 +15,8 @@ from twisted.internet import defer +from collections import namedtuple + from ._base import SQLBaseStore from synapse.api.constants import Membership @@ -24,6 +26,12 @@ import logging logger = logging.getLogger(__name__) +RoomsForUser = namedtuple( + "RoomsForUser", + ("room_id", "sender", "membership") +) + + class RoomMemberStore(SQLBaseStore): def _store_room_member_txn(self, txn, event): @@ -163,19 +171,37 @@ class RoomMemberStore(SQLBaseStore): membership_list (list): A list of synapse.api.constants.Membership values which the user must be in. Returns: - A list of RoomMemberEvent objects + A list of dictionary objects, with room_id, membership and sender + defined. """ if not membership_list: return defer.succeed(None) - args = [user_id] - args.extend(membership_list) - where_clause = "user_id = ? AND (%s)" % ( " OR ".join(["membership = ?" for _ in membership_list]), ) - return self._get_members_query(where_clause, args) + args = [user_id] + args.extend(membership_list) + + def f(txn): + sql = ( + "SELECT m.room_id, m.sender, m.membership" + " FROM room_memberships as m" + " INNER JOIN current_state_events as c" + " ON m.event_id = c.event_id" + " WHERE %s" + ) % (where_clause,) + + txn.execute(sql, args) + return [ + RoomsForUser(**r) for r in self.cursor_to_dict(txn) + ] + + return self.runInteraction( + "get_rooms_for_user_where_membership_is", + f + ) def get_joined_hosts_for_room(self, room_id): return self._simple_select_onecol( @@ -213,26 +239,28 @@ class RoomMemberStore(SQLBaseStore): results = self._parse_events_txn(txn, rows) return results - @defer.inlineCallbacks def user_rooms_intersect(self, user_id_list): """ Checks whether all the users whose IDs are given in a list share a room. """ - user_list_clause = " OR ".join(["m.user_id = ?"] * len(user_id_list)) - sql = ( - "SELECT m.room_id FROM room_memberships as m " - "INNER JOIN current_state_events as c " - "ON m.event_id = c.event_id " - "WHERE m.membership = 'join' " - "AND (%(clause)s) " - # TODO(paul): We've got duplicate rows in the database somewhere - # so we have to DISTINCT m.user_id here - "GROUP BY m.room_id HAVING COUNT(DISTINCT m.user_id) = ?" - ) % {"clause": user_list_clause} + def interaction(txn): + user_list_clause = " OR ".join(["m.user_id = ?"] * len(user_id_list)) + sql = ( + "SELECT m.room_id FROM room_memberships as m " + "INNER JOIN current_state_events as c " + "ON m.event_id = c.event_id " + "WHERE m.membership = 'join' " + "AND (%(clause)s) " + # TODO(paul): We've got duplicate rows in the database somewhere + # so we have to DISTINCT m.user_id here + "GROUP BY m.room_id HAVING COUNT(DISTINCT m.user_id) = ?" + ) % {"clause": user_list_clause} + + args = list(user_id_list) + args.append(len(user_id_list)) - args = list(user_id_list) - args.append(len(user_id_list)) + txn.execute(sql, args) - rows = yield self._execute(None, sql, *args) + return len(txn.fetchall()) > 0 - defer.returnValue(len(rows) > 0) + return self.runInteraction("user_rooms_intersect", interaction) |