diff options
Diffstat (limited to 'synapse/storage/roommember.py')
-rw-r--r-- | synapse/storage/roommember.py | 37 |
1 files changed, 27 insertions, 10 deletions
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 1df043cd36..5038aeea03 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -15,15 +15,10 @@ from twisted.internet import defer -from synapse.types import UserID -from synapse.api.constants import Membership -from synapse.api.events.room import RoomMemberEvent - -from ._base import SQLBaseStore, Table +from ._base import SQLBaseStore +from synapse.api.constants import Membership -import collections -import json import logging logger = logging.getLogger(__name__) @@ -34,14 +29,15 @@ class RoomMemberStore(SQLBaseStore): def _store_room_member_txn(self, txn, event): """Store a room member in the database. """ - domain = self.hs.parse_userid(event.target_user_id).domain + target_user_id = event.state_key + domain = self.hs.parse_userid(target_user_id).domain self._simple_insert_txn( txn, "room_memberships", { "event_id": event.event_id, - "user_id": event.target_user_id, + "user_id": target_user_id, "sender": event.user_id, "room_id": event.room_id, "membership": event.membership, @@ -145,7 +141,28 @@ class RoomMemberStore(SQLBaseStore): rows = yield self._execute_and_decode(sql, *where_values) - logger.debug("_get_members_query Got rows %s", rows) + # logger.debug("_get_members_query Got rows %s", rows) results = [self._parse_event_from_row(r) for r in rows] defer.returnValue(results) + + @defer.inlineCallbacks + def do_users_share_a_room(self, user_list): + """ Checks whether a list of users share a room. + """ + user_list_clause = " OR ".join(["m.user_id = ?"] * len(user_list)) + sql = ( + "SELECT m.room_id FROM room_memberships as m " + "INNER JOIN current_state_events as c " + "ON m.event_id = c.event_id " + "WHERE m.membership = 'join' " + "AND (%(clause)s) " + "GROUP BY m.room_id HAVING COUNT(m.room_id) = ?" + ) % {"clause": user_list_clause} + + args = user_list + args.append(len(user_list)) + + rows = yield self._execute(None, sql, *args) + + defer.returnValue(len(rows) > 0) |