summary refs log tree commit diff
path: root/synapse/storage/roommember.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/roommember.py')
-rw-r--r--synapse/storage/roommember.py172
1 files changed, 74 insertions, 98 deletions
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index c45d128f1b..89c87290cf 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -31,6 +31,38 @@ logger = logging.getLogger(__name__)
 
 class RoomMemberStore(SQLBaseStore):
 
+    @defer.inlineCallbacks
+    def _store_room_member(self, event):
+        """Store a room member in the database.
+        """
+        domain = self.hs.parse_userid(event.target_user_id).domain
+
+        yield self._simple_insert(
+            "room_memberships",
+            {
+                "event_id": event.event_id,
+                "user_id": event.target_user_id,
+                "sender": event.user_id,
+                "room_id": event.room_id,
+                "membership": event.membership,
+            }
+        )
+
+        # Update room hosts table
+        if event.membership == Membership.JOIN:
+            sql = (
+                "INSERT OR IGNORE INTO room_hosts (room_id, host) "
+                "VALUES (?, ?)"
+            )
+            yield self._execute(None, sql, event.room_id, domain)
+        else:
+            sql = (
+                "DELETE FROM room_hosts WHERE room_id = ? AND host = ?"
+            )
+
+            yield self._execute(None, sql, event.room_id, domain)
+
+    @defer.inlineCallbacks
     def get_room_member(self, user_id, room_id):
         """Retrieve the current state of a room member.
 
@@ -38,36 +70,15 @@ class RoomMemberStore(SQLBaseStore):
             user_id (str): The member's user ID.
             room_id (str): The room the member is in.
         Returns:
-            namedtuple: The room member from the database, or None if this
-            member does not exist.
+            Deferred: Results in a MembershipEvent or None.
         """
-        query = RoomMemberTable.select_statement(
-            "room_id = ? AND user_id = ? ORDER BY id DESC LIMIT 1")
-        return self._execute(
-            RoomMemberTable.decode_single_result,
-            query, room_id, user_id,
-        )
+        rows = yield self._get_members_by_dict({
+            "e.room_id": room_id,
+            "m.user_id": user_id,
+        })
 
-    def store_room_member(self, user_id, sender, room_id, membership, content):
-        """Store a room member in the database.
+        defer.returnValue(rows[0] if rows else None)
 
-        Args:
-            user_id (str): The member's user ID.
-            room_id (str): The room in relation to the member.
-            membership (synapse.api.constants.Membership): The new membership
-            state.
-            content (dict): The content of the membership (JSON).
-        """
-        content_json = json.dumps(content)
-        return self._simple_insert(RoomMemberTable.table_name, dict(
-            user_id=user_id,
-            sender=sender,
-            room_id=room_id,
-            membership=membership,
-            content=content_json,
-        ))
-
-    @defer.inlineCallbacks
     def get_room_members(self, room_id, membership=None):
         """Retrieve the current room member list for a room.
 
@@ -79,17 +90,12 @@ class RoomMemberStore(SQLBaseStore):
         Returns:
             list of namedtuples representing the members in this room.
         """
-        query = RoomMemberTable.select_statement(
-            "id IN (SELECT MAX(id) FROM " + RoomMemberTable.table_name
-            + " WHERE room_id = ? GROUP BY user_id)"
-        )
-        res = yield self._execute(
-            RoomMemberTable.decode_results, query, room_id,
-        )
-        # strip memberships which don't match
+
+        where = {"m.room_id": room_id}
         if membership:
-            res = [entry for entry in res if entry.membership == membership]
-        defer.returnValue(res)
+            where["m.membership"] = membership
+
+        return self._get_members_by_dict(where)
 
     def get_rooms_for_user_where_membership_is(self, user_id, membership_list):
         """ Get all the rooms for this user where the membership for this user
@@ -106,70 +112,40 @@ class RoomMemberStore(SQLBaseStore):
             return defer.succeed(None)
 
         args = [user_id]
-        membership_placeholder = ["membership=?"] * len(membership_list)
-        where_membership = "(" + " OR ".join(membership_placeholder) + ")"
-        for membership in membership_list:
-            args.append(membership)
-
-        # sub-select finds the row ID for the most recent (i.e. current)
-        # state change of this user per room, then the outer select finds those
-        query = ("SELECT room_id, membership FROM room_memberships"
-                 + " WHERE id IN (SELECT MAX(id) FROM room_memberships"
-                 + "   WHERE user_id=? GROUP BY room_id)"
-                 + " AND " + where_membership)
-        return self._execute(
-            self.cursor_to_dict, query, *args
-        )
+        args.extend(membership_list)
 
-    @defer.inlineCallbacks
-    def get_joined_hosts_for_room(self, room_id):
-        query = RoomMemberTable.select_statement(
-            "id IN (SELECT MAX(id) FROM " + RoomMemberTable.table_name
-            + " WHERE room_id = ? GROUP BY user_id)"
-        )
-
-        res = yield self._execute(
-            RoomMemberTable.decode_results, query, room_id,
+        where_clause = "user_id = ? AND (%s)" % (
+            " OR ".join(["membership = ?" for _ in membership_list]),
         )
 
-        def host_from_user_id_string(user_id):
-            domain = UserID.from_string(entry.user_id, self.hs).domain
-            return domain
-
-        # strip memberships which don't match
-        hosts = [
-            host_from_user_id_string(entry.user_id)
-            for entry in res
-            if entry.membership == Membership.JOIN
-        ]
+        return self._get_members_query(where_clause, args)
 
-        logger.debug("Returning hosts: %s from results: %s", hosts, res)
-
-        defer.returnValue(hosts)
-
-    def get_max_room_member_id(self):
-        return self._simple_max_id(RoomMemberTable.table_name)
-
-
-class RoomMemberTable(Table):
-    table_name = "room_memberships"
-
-    fields = [
-        "id",
-        "user_id",
-        "sender",
-        "room_id",
-        "membership",
-        "content"
-    ]
+    def get_joined_hosts_for_room(self, room_id):
+        return self._simple_select_onecol(
+            "room_hosts",
+            {"room_id": room_id},
+            "host"
+        )
 
-    class EntryType(collections.namedtuple("RoomMemberEntry", fields)):
+    def _get_members_by_dict(self, where_dict):
+        clause = " AND ".join("%s = ?" % k for k in where_dict.keys())
+        vals = where_dict.values()
+        return self._get_members_query(clause, vals)
 
-        def as_event(self, event_factory):
-            return event_factory.create_event(
-                etype=RoomMemberEvent.TYPE,
-                room_id=self.room_id,
-                target_user_id=self.user_id,
-                user_id=self.sender,
-                content=json.loads(self.content),
-            )
+    @defer.inlineCallbacks
+    def _get_members_query(self, where_clause, where_values):
+        sql = (
+            "SELECT e.* FROM events as e "
+            "INNER JOIN room_memberships as m "
+            "ON e.event_id = m.event_id "
+            "INNER JOIN current_state_events as c "
+            "ON m.event_id = c.event_id "
+            "WHERE %s "
+        ) % (where_clause,)
+
+        rows = yield self._execute_and_decode(sql, *where_values)
+
+        logger.debug("_get_members_query Got rows %s", rows)
+
+        results = [self._parse_event_from_row(r) for r in rows]
+        defer.returnValue(results)