summary refs log tree commit diff
path: root/synapse/storage/roommember.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/roommember.py')
-rw-r--r--synapse/storage/roommember.py93
1 files changed, 64 insertions, 29 deletions
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 05b275663e..e59e65529b 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2014 OpenMarket Ltd
+# Copyright 2014, 2015 OpenMarket Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,6 +15,8 @@
 
 from twisted.internet import defer
 
+from collections import namedtuple
+
 from ._base import SQLBaseStore
 
 from synapse.api.constants import Membership
@@ -24,6 +26,12 @@ import logging
 logger = logging.getLogger(__name__)
 
 
+RoomsForUser = namedtuple(
+    "RoomsForUser",
+    ("room_id", "sender", "membership")
+)
+
+
 class RoomMemberStore(SQLBaseStore):
 
     def _store_room_member_txn(self, txn, event):
@@ -123,6 +131,19 @@ class RoomMemberStore(SQLBaseStore):
         else:
             return None
 
+    def get_users_in_room(self, room_id):
+        def f(txn):
+            sql = (
+                "SELECT m.user_id FROM room_memberships as m"
+                " INNER JOIN current_state_events as c"
+                " ON m.event_id = c.event_id"
+                " WHERE m.membership = ? AND m.room_id = ?"
+            )
+
+            txn.execute(sql, (Membership.JOIN, room_id))
+            return [r[0] for r in txn.fetchall()]
+        return self.runInteraction("get_users_in_room", f)
+
     def get_room_members(self, room_id, membership=None):
         """Retrieve the current room member list for a room.
 
@@ -150,19 +171,37 @@ class RoomMemberStore(SQLBaseStore):
             membership_list (list): A list of synapse.api.constants.Membership
             values which the user must be in.
         Returns:
-            A list of RoomMemberEvent objects
+            A list of dictionary objects, with room_id, membership and sender
+            defined.
         """
         if not membership_list:
             return defer.succeed(None)
 
-        args = [user_id]
-        args.extend(membership_list)
-
         where_clause = "user_id = ? AND (%s)" % (
             " OR ".join(["membership = ?" for _ in membership_list]),
         )
 
-        return self._get_members_query(where_clause, args)
+        args = [user_id]
+        args.extend(membership_list)
+
+        def f(txn):
+            sql = (
+                "SELECT m.room_id, m.sender, m.membership"
+                " FROM room_memberships as m"
+                " INNER JOIN current_state_events as c"
+                " ON m.event_id = c.event_id"
+                " WHERE %s"
+            ) % (where_clause,)
+
+            txn.execute(sql, args)
+            return [
+                RoomsForUser(**r) for r in self.cursor_to_dict(txn)
+            ]
+
+        return self.runInteraction(
+            "get_rooms_for_user_where_membership_is",
+            f
+        )
 
     def get_joined_hosts_for_room(self, room_id):
         return self._simple_select_onecol(
@@ -183,20 +222,14 @@ class RoomMemberStore(SQLBaseStore):
         )
 
     def _get_members_query_txn(self, txn, where_clause, where_values):
-        del_sql = (
-            "SELECT event_id FROM redactions WHERE redacts = e.event_id "
-            "LIMIT 1"
-        )
-
         sql = (
-            "SELECT e.*, (%(redacted)s) AS redacted FROM events as e "
+            "SELECT e.* FROM events as e "
             "INNER JOIN room_memberships as m "
             "ON e.event_id = m.event_id "
             "INNER JOIN current_state_events as c "
             "ON m.event_id = c.event_id "
             "WHERE %(where)s "
         ) % {
-            "redacted": del_sql,
             "where": where_clause,
         }
 
@@ -206,26 +239,28 @@ class RoomMemberStore(SQLBaseStore):
         results = self._parse_events_txn(txn, rows)
         return results
 
-    @defer.inlineCallbacks
     def user_rooms_intersect(self, user_id_list):
         """ Checks whether all the users whose IDs are given in a list share a
         room.
         """
-        user_list_clause = " OR ".join(["m.user_id = ?"] * len(user_id_list))
-        sql = (
-            "SELECT m.room_id FROM room_memberships as m "
-            "INNER JOIN current_state_events as c "
-            "ON m.event_id = c.event_id "
-            "WHERE m.membership = 'join' "
-            "AND (%(clause)s) "
-            # TODO(paul): We've got duplicate rows in the database somewhere
-            #   so we have to DISTINCT m.user_id here
-            "GROUP BY m.room_id HAVING COUNT(DISTINCT m.user_id) = ?"
-        ) % {"clause": user_list_clause}
+        def interaction(txn):
+            user_list_clause = " OR ".join(["m.user_id = ?"] * len(user_id_list))
+            sql = (
+                "SELECT m.room_id FROM room_memberships as m "
+                "INNER JOIN current_state_events as c "
+                "ON m.event_id = c.event_id "
+                "WHERE m.membership = 'join' "
+                "AND (%(clause)s) "
+                # TODO(paul): We've got duplicate rows in the database somewhere
+                #   so we have to DISTINCT m.user_id here
+                "GROUP BY m.room_id HAVING COUNT(DISTINCT m.user_id) = ?"
+            ) % {"clause": user_list_clause}
+
+            args = list(user_id_list)
+            args.append(len(user_id_list))
 
-        args = list(user_id_list)
-        args.append(len(user_id_list))
+            txn.execute(sql, args)
 
-        rows = yield self._execute(None, sql, *args)
+            return len(txn.fetchall()) > 0
 
-        defer.returnValue(len(rows) > 0)
+        return self.runInteraction("user_rooms_intersect", interaction)