diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index c45d128f1b..89c87290cf 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -31,6 +31,38 @@ logger = logging.getLogger(__name__)
class RoomMemberStore(SQLBaseStore):
+ @defer.inlineCallbacks
+ def _store_room_member(self, event):
+ """Store a room member in the database.
+ """
+ domain = self.hs.parse_userid(event.target_user_id).domain
+
+ yield self._simple_insert(
+ "room_memberships",
+ {
+ "event_id": event.event_id,
+ "user_id": event.target_user_id,
+ "sender": event.user_id,
+ "room_id": event.room_id,
+ "membership": event.membership,
+ }
+ )
+
+ # Update room hosts table
+ if event.membership == Membership.JOIN:
+ sql = (
+ "INSERT OR IGNORE INTO room_hosts (room_id, host) "
+ "VALUES (?, ?)"
+ )
+ yield self._execute(None, sql, event.room_id, domain)
+ else:
+ sql = (
+ "DELETE FROM room_hosts WHERE room_id = ? AND host = ?"
+ )
+
+ yield self._execute(None, sql, event.room_id, domain)
+
+ @defer.inlineCallbacks
def get_room_member(self, user_id, room_id):
"""Retrieve the current state of a room member.
@@ -38,36 +70,15 @@ class RoomMemberStore(SQLBaseStore):
user_id (str): The member's user ID.
room_id (str): The room the member is in.
Returns:
- namedtuple: The room member from the database, or None if this
- member does not exist.
+ Deferred: Results in a MembershipEvent or None.
"""
- query = RoomMemberTable.select_statement(
- "room_id = ? AND user_id = ? ORDER BY id DESC LIMIT 1")
- return self._execute(
- RoomMemberTable.decode_single_result,
- query, room_id, user_id,
- )
+ rows = yield self._get_members_by_dict({
+ "e.room_id": room_id,
+ "m.user_id": user_id,
+ })
- def store_room_member(self, user_id, sender, room_id, membership, content):
- """Store a room member in the database.
+ defer.returnValue(rows[0] if rows else None)
- Args:
- user_id (str): The member's user ID.
- room_id (str): The room in relation to the member.
- membership (synapse.api.constants.Membership): The new membership
- state.
- content (dict): The content of the membership (JSON).
- """
- content_json = json.dumps(content)
- return self._simple_insert(RoomMemberTable.table_name, dict(
- user_id=user_id,
- sender=sender,
- room_id=room_id,
- membership=membership,
- content=content_json,
- ))
-
- @defer.inlineCallbacks
def get_room_members(self, room_id, membership=None):
"""Retrieve the current room member list for a room.
@@ -79,17 +90,12 @@ class RoomMemberStore(SQLBaseStore):
Returns:
list of namedtuples representing the members in this room.
"""
- query = RoomMemberTable.select_statement(
- "id IN (SELECT MAX(id) FROM " + RoomMemberTable.table_name
- + " WHERE room_id = ? GROUP BY user_id)"
- )
- res = yield self._execute(
- RoomMemberTable.decode_results, query, room_id,
- )
- # strip memberships which don't match
+
+ where = {"m.room_id": room_id}
if membership:
- res = [entry for entry in res if entry.membership == membership]
- defer.returnValue(res)
+ where["m.membership"] = membership
+
+ return self._get_members_by_dict(where)
def get_rooms_for_user_where_membership_is(self, user_id, membership_list):
""" Get all the rooms for this user where the membership for this user
@@ -106,70 +112,40 @@ class RoomMemberStore(SQLBaseStore):
return defer.succeed(None)
args = [user_id]
- membership_placeholder = ["membership=?"] * len(membership_list)
- where_membership = "(" + " OR ".join(membership_placeholder) + ")"
- for membership in membership_list:
- args.append(membership)
-
- # sub-select finds the row ID for the most recent (i.e. current)
- # state change of this user per room, then the outer select finds those
- query = ("SELECT room_id, membership FROM room_memberships"
- + " WHERE id IN (SELECT MAX(id) FROM room_memberships"
- + " WHERE user_id=? GROUP BY room_id)"
- + " AND " + where_membership)
- return self._execute(
- self.cursor_to_dict, query, *args
- )
+ args.extend(membership_list)
- @defer.inlineCallbacks
- def get_joined_hosts_for_room(self, room_id):
- query = RoomMemberTable.select_statement(
- "id IN (SELECT MAX(id) FROM " + RoomMemberTable.table_name
- + " WHERE room_id = ? GROUP BY user_id)"
- )
-
- res = yield self._execute(
- RoomMemberTable.decode_results, query, room_id,
+ where_clause = "user_id = ? AND (%s)" % (
+ " OR ".join(["membership = ?" for _ in membership_list]),
)
- def host_from_user_id_string(user_id):
- domain = UserID.from_string(entry.user_id, self.hs).domain
- return domain
-
- # strip memberships which don't match
- hosts = [
- host_from_user_id_string(entry.user_id)
- for entry in res
- if entry.membership == Membership.JOIN
- ]
+ return self._get_members_query(where_clause, args)
- logger.debug("Returning hosts: %s from results: %s", hosts, res)
-
- defer.returnValue(hosts)
-
- def get_max_room_member_id(self):
- return self._simple_max_id(RoomMemberTable.table_name)
-
-
-class RoomMemberTable(Table):
- table_name = "room_memberships"
-
- fields = [
- "id",
- "user_id",
- "sender",
- "room_id",
- "membership",
- "content"
- ]
+ def get_joined_hosts_for_room(self, room_id):
+ return self._simple_select_onecol(
+ "room_hosts",
+ {"room_id": room_id},
+ "host"
+ )
- class EntryType(collections.namedtuple("RoomMemberEntry", fields)):
+ def _get_members_by_dict(self, where_dict):
+ clause = " AND ".join("%s = ?" % k for k in where_dict.keys())
+ vals = where_dict.values()
+ return self._get_members_query(clause, vals)
- def as_event(self, event_factory):
- return event_factory.create_event(
- etype=RoomMemberEvent.TYPE,
- room_id=self.room_id,
- target_user_id=self.user_id,
- user_id=self.sender,
- content=json.loads(self.content),
- )
+ @defer.inlineCallbacks
+ def _get_members_query(self, where_clause, where_values):
+ sql = (
+ "SELECT e.* FROM events as e "
+ "INNER JOIN room_memberships as m "
+ "ON e.event_id = m.event_id "
+ "INNER JOIN current_state_events as c "
+ "ON m.event_id = c.event_id "
+ "WHERE %s "
+ ) % (where_clause,)
+
+ rows = yield self._execute_and_decode(sql, *where_values)
+
+ logger.debug("_get_members_query Got rows %s", rows)
+
+ results = [self._parse_event_from_row(r) for r in rows]
+ defer.returnValue(results)
|