diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 5cb26ad6db..befeb55b25 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -64,7 +64,11 @@ class SQLBaseStore(object):
def interaction(txn):
cursor = txn.execute(query, args)
- return decoder(cursor)
+ if decoder:
+ return decoder(cursor)
+ else:
+ return cursor
+
return self._db_pool.runInteraction(interaction)
def _execut_query(self, query, *args):
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index ef73be4af4..60296380e6 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -31,6 +31,38 @@ logger = logging.getLogger(__name__)
class RoomMemberStore(SQLBaseStore):
+ @defer.inlineCallbacks
+ def _store_room_member(self, event):
+ """Store a room member in the database.
+ """
+ domain = self.hs.parse_userid(event.target_user_id).domain
+
+ yield self._simple_insert(
+ "room_memberships",
+ {
+ "event_id": event.event_id,
+ "user_id": event.target_user_id,
+ "sender": event.user_id,
+ "room_id": event.room_id,
+ "membership": event.membership,
+ }
+ )
+
+ # Update room hosts table
+ if event.membership == Membership.JOIN:
+ sql = (
+ "INSERT OR IGNORE INTO room_hosts (room_id, host) "
+ "VALUES (?, ?)"
+ )
+ yield self._execute(None, sql, room_id, domain)
+ else:
+ sql = (
+ "DELETE FROM room_hosts WHERE room_id = ? AND host = ?"
+ )
+
+ yield self._execute(None, sql, room_id, domain)
+
+
def get_room_member(self, user_id, room_id):
"""Retrieve the current state of a room member.
@@ -38,36 +70,13 @@ class RoomMemberStore(SQLBaseStore):
user_id (str): The member's user ID.
room_id (str): The room the member is in.
Returns:
- namedtuple: The room member from the database, or None if this
- member does not exist.
+ Deferred: Results in a MembershipEvent or None.
"""
- query = RoomMemberTable.select_statement(
- "room_id = ? AND user_id = ? ORDER BY id DESC LIMIT 1")
- return self._execute(
- RoomMemberTable.decode_single_result,
- query, room_id, user_id,
- )
-
- def store_room_member(self, user_id, sender, room_id, membership, content):
- """Store a room member in the database.
-
- Args:
- user_id (str): The member's user ID.
- room_id (str): The room in relation to the member.
- membership (synapse.api.constants.Membership): The new membership
- state.
- content (dict): The content of the membership (JSON).
- """
- content_json = json.dumps(content)
- return self._simple_insert(RoomMemberTable.table_name, dict(
- user_id=user_id,
- sender=sender,
+ return self._get_members_by_dict(
room_id=room_id,
- membership=membership,
- content=content_json,
- ))
+ user_id=user_id
+ )
- @defer.inlineCallbacks
def get_room_members(self, room_id, membership=None):
"""Retrieve the current room member list for a room.
@@ -79,17 +88,12 @@ class RoomMemberStore(SQLBaseStore):
Returns:
list of namedtuples representing the members in this room.
"""
- query = RoomMemberTable.select_statement(
- "id IN (SELECT MAX(id) FROM " + RoomMemberTable.table_name
- + " WHERE room_id = ? GROUP BY user_id)"
- )
- res = yield self._execute(
- RoomMemberTable.decode_results, query, room_id,
- )
- # strip memberships which don't match
+
+ where = {"room_id": room_id}
if membership:
- res = [entry for entry in res if entry.membership == membership]
- defer.returnValue(res)
+ where["membership"] = membership
+
+ return self._get_members_by_dict(**membership)
def get_rooms_for_user_where_membership_is(self, user_id, membership_list):
""" Get all the rooms for this user where the membership for this user
@@ -106,67 +110,37 @@ class RoomMemberStore(SQLBaseStore):
return defer.succeed(None)
args = [user_id]
- membership_placeholder = ["membership=?"] * len(membership_list)
- where_membership = "(" + " OR ".join(membership_placeholder) + ")"
- for membership in membership_list:
- args.append(membership)
-
- query = ("SELECT room_id, membership FROM room_memberships"
- + " WHERE user_id=? AND " + where_membership
- + " GROUP BY room_id ORDER BY id DESC")
- return self._execute(
- self.cursor_to_dict, query, *args
- )
+ args.extend(membership_list)
- @defer.inlineCallbacks
- def get_joined_hosts_for_room(self, room_id):
- query = RoomMemberTable.select_statement(
- "id IN (SELECT MAX(id) FROM " + RoomMemberTable.table_name
- + " WHERE room_id = ? GROUP BY user_id)"
+ where_clause "user_id = ? AND (%s)" % (
+ " OR ".join(["membership = ?" for _ in membership_list]),
)
- res = yield self._execute(
- RoomMemberTable.decode_results, query, room_id,
- )
-
- def host_from_user_id_string(user_id):
- domain = UserID.from_string(entry.user_id, self.hs).domain
- return domain
-
- # strip memberships which don't match
- hosts = [
- host_from_user_id_string(entry.user_id)
- for entry in res
- if entry.membership == Membership.JOIN
- ]
+ return self._get_members_query(where_clause, args)
- logger.debug("Returning hosts: %s from results: %s", hosts, res)
-
- defer.returnValue(hosts)
-
- def get_max_room_member_id(self):
- return self._simple_max_id(RoomMemberTable.table_name)
-
-
-class RoomMemberTable(Table):
- table_name = "room_memberships"
-
- fields = [
- "id",
- "user_id",
- "sender",
- "room_id",
- "membership",
- "content"
- ]
+ def get_joined_hosts_for_room(self, room_id):
+ return self._simple_select_onecol(
+ "room_hosts",
+ {"room_id": room_id},
+ "host"
+ )
- class EntryType(collections.namedtuple("RoomMemberEntry", fields)):
+ def _get_members_by_dict(self, where_dict):
+ clause = " AND ".join("%s = ?" % k for k in where.keys())
+ vals = where.values()
+ return self._get_members_query(clause, vals)
- def as_event(self, event_factory):
- return event_factory.create_event(
- etype=RoomMemberEvent.TYPE,
- room_id=self.room_id,
- target_user_id=self.user_id,
- user_id=self.sender,
- content=json.loads(self.content),
- )
+ @defer.inlineCallbacks
+ def _get_members_query(self, where_clause, where_values):
+ sql = (
+ "SELECT e.* FROM events as e "
+ "INNER JOIN room_memberships as m "
+ "ON e.event_id = m.event_id "
+ "INNER JOIN current_state as c "
+ "ON m.event_id = c.event_id "
+ "WHERE %s "
+ ) % (where_clause,)
+
+ rows = yield self._execute_query(sql, where_values)
+ results = [self._parse_event_from_row(r) for r in rows]
+ defer.returnValue(results)
diff --git a/synapse/storage/schema/im.sql b/synapse/storage/schema/im.sql
index 37b7c6c74f..7f564c8540 100644
--- a/synapse/storage/schema/im.sql
+++ b/synapse/storage/schema/im.sql
@@ -17,7 +17,6 @@ CREATE TABLE IF NOT EXISTS events(
ordering INTEGER PRIMARY KEY AUTOINCREMENT,
event_id TEXT NOT NULL,
type TEXT NOT NULL,
--- sender TEXT,
room_id TEXT,
content TEXT,
unrecognized_keys TEXT
@@ -57,3 +56,8 @@ CREATE TABLE IF NOT EXISTS rooms(
is_public INTEGER,
creator TEXT
);
+
+CREATE TABLE IF NOT EXISTS room_hosts(
+ room_id TEXT NOT NULL,
+ host TEXT NOT NULL
+);
|