diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 0cd89260f2..866d64e679 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -20,8 +20,8 @@ from collections import namedtuple
from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
-from synapse.api.constants import Membership
-from synapse.types import UserID
+from synapse.api.constants import Membership, EventTypes
+from synapse.types import get_domain_from_id
import logging
@@ -36,7 +36,7 @@ RoomsForUser = namedtuple(
class RoomMemberStore(SQLBaseStore):
- def _store_room_members_txn(self, txn, events):
+ def _store_room_members_txn(self, txn, events, backfilled):
"""Store a room member in the database.
"""
self._simple_insert_many_txn(
@@ -56,33 +56,70 @@ class RoomMemberStore(SQLBaseStore):
for event in events:
txn.call_after(self.get_rooms_for_user.invalidate, (event.state_key,))
- txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
txn.call_after(
self._membership_stream_cache.entity_has_changed,
event.state_key, event.internal_metadata.stream_ordering
)
+ txn.call_after(
+ self.get_invited_rooms_for_user.invalidate, (event.state_key,)
+ )
- def get_room_member(self, user_id, room_id):
- """Retrieve the current state of a room member.
+ # We update the local_invites table only if the event is "current",
+ # i.e., its something that has just happened.
+ # The only current event that can also be an outlier is if its an
+ # invite that has come in across federation.
+ is_new_state = not backfilled and (
+ not event.internal_metadata.is_outlier()
+ or event.internal_metadata.is_invite_from_remote()
+ )
+ is_mine = self.hs.is_mine_id(event.state_key)
+ if is_new_state and is_mine:
+ if event.membership == Membership.INVITE:
+ self._simple_insert_txn(
+ txn,
+ table="local_invites",
+ values={
+ "event_id": event.event_id,
+ "invitee": event.state_key,
+ "inviter": event.sender,
+ "room_id": event.room_id,
+ "stream_id": event.internal_metadata.stream_ordering,
+ }
+ )
+ else:
+ sql = (
+ "UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE"
+ " room_id = ? AND invitee = ? AND locally_rejected is NULL"
+ " AND replaced_by is NULL"
+ )
+
+ txn.execute(sql, (
+ event.internal_metadata.stream_ordering,
+ event.event_id,
+ event.room_id,
+ event.state_key,
+ ))
- Args:
- user_id (str): The member's user ID.
- room_id (str): The room the member is in.
- Returns:
- Deferred: Results in a MembershipEvent or None.
- """
- return self.runInteraction(
- "get_room_member",
- self._get_members_events_txn,
- room_id,
- user_id=user_id,
- ).addCallback(
- self._get_events
- ).addCallback(
- lambda events: events[0] if events else None
+ @defer.inlineCallbacks
+ def locally_reject_invite(self, user_id, room_id):
+ sql = (
+ "UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE"
+ " room_id = ? AND invitee = ? AND locally_rejected is NULL"
+ " AND replaced_by is NULL"
)
+ def f(txn, stream_ordering):
+ txn.execute(sql, (
+ stream_ordering,
+ True,
+ room_id,
+ user_id,
+ ))
+
+ with self._stream_id_gen.get_next() as stream_ordering:
+ yield self.runInteraction("locally_reject_invite", f, stream_ordering)
+
@cached(max_entries=5000)
def get_users_in_room(self, room_id):
def f(txn):
@@ -96,51 +133,36 @@ class RoomMemberStore(SQLBaseStore):
return [r["user_id"] for r in rows]
return self.runInteraction("get_users_in_room", f)
- def get_room_members(self, room_id, membership=None):
- """Retrieve the current room member list for a room.
-
- Args:
- room_id (str): The room to get the list of members.
- membership (synapse.api.constants.Membership): The filter to apply
- to this list, or None to return all members with some state
- associated with this room.
- Returns:
- list of namedtuples representing the members in this room.
- """
- return self.runInteraction(
- "get_room_members",
- self._get_members_events_txn,
- room_id,
- membership=membership,
- ).addCallback(self._get_events)
-
@cached()
- def get_invites_for_user(self, user_id):
- """ Get all the invite events for a user
+ def get_invited_rooms_for_user(self, user_id):
+ """ Get all the rooms the user is invited to
Args:
user_id (str): The user ID.
Returns:
- A deferred list of event objects.
+ A deferred list of RoomsForUser.
"""
return self.get_rooms_for_user_where_membership_is(
user_id, [Membership.INVITE]
- ).addCallback(lambda invites: self._get_events([
- invite.event_id for invite in invites
- ]))
+ )
+
+ @defer.inlineCallbacks
+ def get_invite_for_user_in_room(self, user_id, room_id):
+ """Gets the invite for the given user and room
- def get_leave_and_ban_events_for_user(self, user_id):
- """ Get all the leave events for a user
Args:
- user_id (str): The user ID.
+ user_id (str)
+ room_id (str)
+
Returns:
- A deferred list of event objects.
+ Deferred: Resolves to either a RoomsForUser or None if no invite was
+ found.
"""
- return self.get_rooms_for_user_where_membership_is(
- user_id, (Membership.LEAVE, Membership.BAN)
- ).addCallback(lambda leaves: self._get_events([
- leave.event_id for leave in leaves
- ]))
+ invites = yield self.get_invited_rooms_for_user(user_id)
+ for invite in invites:
+ if invite.room_id == room_id:
+ defer.returnValue(invite)
+ defer.returnValue(None)
def get_rooms_for_user_where_membership_is(self, user_id, membership_list):
""" Get all the rooms for this user where the membership for this user
@@ -165,57 +187,55 @@ class RoomMemberStore(SQLBaseStore):
def _get_rooms_for_user_where_membership_is_txn(self, txn, user_id,
membership_list):
- where_clause = "user_id = ? AND (%s) AND forgotten = 0" % (
- " OR ".join(["membership = ?" for _ in membership_list]),
- )
- args = [user_id]
- args.extend(membership_list)
+ do_invite = Membership.INVITE in membership_list
+ membership_list = [m for m in membership_list if m != Membership.INVITE]
- sql = (
- "SELECT m.room_id, m.sender, m.membership, m.event_id, e.stream_ordering"
- " FROM current_state_events as c"
- " INNER JOIN room_memberships as m"
- " ON m.event_id = c.event_id"
- " INNER JOIN events as e"
- " ON e.event_id = c.event_id"
- " AND m.room_id = c.room_id"
- " AND m.user_id = c.state_key"
- " WHERE %s"
- ) % (where_clause,)
-
- txn.execute(sql, args)
- return [
- RoomsForUser(**r) for r in self.cursor_to_dict(txn)
- ]
+ results = []
+ if membership_list:
+ where_clause = "user_id = ? AND (%s) AND forgotten = 0" % (
+ " OR ".join(["membership = ?" for _ in membership_list]),
+ )
- @cached(max_entries=5000)
- def get_joined_hosts_for_room(self, room_id):
- return self.runInteraction(
- "get_joined_hosts_for_room",
- self._get_joined_hosts_for_room_txn,
- room_id,
- )
+ args = [user_id]
+ args.extend(membership_list)
- def _get_joined_hosts_for_room_txn(self, txn, room_id):
- rows = self._get_members_rows_txn(
- txn,
- room_id, membership=Membership.JOIN
- )
+ sql = (
+ "SELECT m.room_id, m.sender, m.membership, m.event_id, e.stream_ordering"
+ " FROM current_state_events as c"
+ " INNER JOIN room_memberships as m"
+ " ON m.event_id = c.event_id"
+ " INNER JOIN events as e"
+ " ON e.event_id = c.event_id"
+ " AND m.room_id = c.room_id"
+ " AND m.user_id = c.state_key"
+ " WHERE %s"
+ ) % (where_clause,)
+
+ txn.execute(sql, args)
+ results = [
+ RoomsForUser(**r) for r in self.cursor_to_dict(txn)
+ ]
- joined_domains = set(
- UserID.from_string(r["user_id"]).domain
- for r in rows
- )
+ if do_invite:
+ sql = (
+ "SELECT i.room_id, inviter, i.event_id, e.stream_ordering"
+ " FROM local_invites as i"
+ " INNER JOIN events as e USING (event_id)"
+ " WHERE invitee = ? AND locally_rejected is NULL"
+ " AND replaced_by is NULL"
+ )
- return joined_domains
+ txn.execute(sql, (user_id,))
+ results.extend(RoomsForUser(
+ room_id=r["room_id"],
+ sender=r["inviter"],
+ event_id=r["event_id"],
+ stream_ordering=r["stream_ordering"],
+ membership=Membership.INVITE,
+ ) for r in self.cursor_to_dict(txn))
- def _get_members_events_txn(self, txn, room_id, membership=None, user_id=None):
- rows = self._get_members_rows_txn(
- txn,
- room_id, membership, user_id,
- )
- return [r["event_id"] for r in rows]
+ return results
def _get_members_rows_txn(self, txn, room_id, membership=None, user_id=None):
where_clause = "c.room_id = ?"
@@ -251,7 +271,6 @@ class RoomMemberStore(SQLBaseStore):
user_id, membership_list=[Membership.JOIN],
)
- @defer.inlineCallbacks
def forget(self, user_id, room_id):
"""Indicate that user_id wishes to discard history for room_id."""
def f(txn):
@@ -266,10 +285,13 @@ class RoomMemberStore(SQLBaseStore):
" room_id = ?"
)
txn.execute(sql, (user_id, room_id))
- yield self.runInteraction("forget_membership", f)
- self.was_forgotten_at.invalidate_all()
- self.who_forgot_in_room.invalidate_all()
- self.did_forget.invalidate((user_id, room_id))
+
+ txn.call_after(self.was_forgotten_at.invalidate_all)
+ txn.call_after(self.did_forget.invalidate, (user_id, room_id))
+ self._invalidate_cache_and_stream(
+ txn, self.who_forgot_in_room, (room_id,)
+ )
+ return self.runInteraction("forget_membership", f)
@cachedInlineCallbacks(num_args=2)
def did_forget(self, user_id, room_id):
@@ -297,7 +319,8 @@ class RoomMemberStore(SQLBaseStore):
@cachedInlineCallbacks(num_args=3)
def was_forgotten_at(self, user_id, room_id, event_id):
- """Returns whether user_id has elected to discard history for room_id at event_id.
+ """Returns whether user_id has elected to discard history for room_id at
+ event_id.
event_id must be a membership event."""
def f(txn):
@@ -330,3 +353,98 @@ class RoomMemberStore(SQLBaseStore):
},
desc="who_forgot"
)
+
+ def get_joined_users_from_context(self, event, context):
+ state_group = context.state_group
+ if not state_group:
+ # If state_group is None it means it has yet to be assigned a
+ # state group, i.e. we need to make sure that calls with a state_group
+ # of None don't hit previous cached calls with a None state_group.
+ # To do this we set the state_group to a new object as object() != object()
+ state_group = object()
+
+ return self._get_joined_users_from_context(
+ event.room_id, state_group, context.current_state_ids, event=event,
+ )
+
+ def get_joined_users_from_state(self, room_id, state_group, state_ids):
+ if not state_group:
+ # If state_group is None it means it has yet to be assigned a
+ # state group, i.e. we need to make sure that calls with a state_group
+ # of None don't hit previous cached calls with a None state_group.
+ # To do this we set the state_group to a new object as object() != object()
+ state_group = object()
+
+ return self._get_joined_users_from_context(
+ room_id, state_group, state_ids,
+ )
+
+ @cachedInlineCallbacks(num_args=2, cache_context=True)
+ def _get_joined_users_from_context(self, room_id, state_group, current_state_ids,
+ cache_context, event=None):
+ # We don't use `state_group`, its there so that we can cache based
+ # on it. However, its important that its never None, since two current_state's
+ # with a state_group of None are likely to be different.
+ # See bulk_get_push_rules_for_room for how we work around this.
+ assert state_group is not None
+
+ member_event_ids = [
+ e_id
+ for key, e_id in current_state_ids.iteritems()
+ if key[0] == EventTypes.Member
+ ]
+
+ rows = yield self._simple_select_many_batch(
+ table="room_memberships",
+ column="event_id",
+ iterable=member_event_ids,
+ retcols=['user_id'],
+ keyvalues={
+ "membership": Membership.JOIN,
+ },
+ batch_size=500,
+ desc="_get_joined_users_from_context",
+ )
+
+ users_in_room = set(row["user_id"] for row in rows)
+ if event is not None and event.type == EventTypes.Member:
+ if event.membership == Membership.JOIN:
+ if event.event_id in member_event_ids:
+ users_in_room.add(event.state_key)
+
+ defer.returnValue(users_in_room)
+
+ def is_host_joined(self, room_id, host, state_group, state_ids):
+ if not state_group:
+ # If state_group is None it means it has yet to be assigned a
+ # state group, i.e. we need to make sure that calls with a state_group
+ # of None don't hit previous cached calls with a None state_group.
+ # To do this we set the state_group to a new object as object() != object()
+ state_group = object()
+
+ return self._is_host_joined(
+ room_id, host, state_group, state_ids
+ )
+
+ @cachedInlineCallbacks(num_args=3)
+ def _is_host_joined(self, room_id, host, state_group, current_state_ids):
+ # We don't use `state_group`, its there so that we can cache based
+ # on it. However, its important that its never None, since two current_state's
+ # with a state_group of None are likely to be different.
+ # See bulk_get_push_rules_for_room for how we work around this.
+ assert state_group is not None
+
+ for (etype, state_key), event_id in current_state_ids.items():
+ if etype == EventTypes.Member:
+ try:
+ if get_domain_from_id(state_key) != host:
+ continue
+ except:
+ logger.warn("state_key not user_id: %s", state_key)
+ continue
+
+ event = yield self.get_event(event_id, allow_none=True)
+ if event and event.content["membership"] == Membership.JOIN:
+ defer.returnValue(True)
+
+ defer.returnValue(False)
|