diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index 74a81a0a51..0dc87aee7f 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -19,7 +19,7 @@ from synapse.storage import DataStore
from synapse.storage.event_federation import EventFederationStore
from synapse.storage.event_push_actions import EventPushActionsStore
from synapse.storage.events import EventsWorkerStore
-from synapse.storage.roommember import RoomMemberStore
+from synapse.storage.roommember import RoomMemberWorkerStore
from synapse.storage.state import StateGroupWorkerStore
from synapse.storage.stream import StreamStore
from synapse.storage.signatures import SignatureStore
@@ -39,7 +39,8 @@ logger = logging.getLogger(__name__)
# the method descriptor on the DataStore and chuck them into our class.
-class SlavedEventStore(EventsWorkerStore, StateGroupWorkerStore, BaseSlavedStore):
+class SlavedEventStore(RoomMemberWorkerStore, EventsWorkerStore,
+ StateGroupWorkerStore, BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedEventStore, self).__init__(db_conn, hs)
@@ -69,18 +70,9 @@ class SlavedEventStore(EventsWorkerStore, StateGroupWorkerStore, BaseSlavedStore
# Cached functions can't be accessed through a class instance so we need
# to reach inside the __dict__ to extract them.
- get_rooms_for_user = RoomMemberStore.__dict__["get_rooms_for_user"]
- get_users_in_room = RoomMemberStore.__dict__["get_users_in_room"]
- get_hosts_in_room = RoomMemberStore.__dict__["get_hosts_in_room"]
- get_users_who_share_room_with_user = (
- RoomMemberStore.__dict__["get_users_who_share_room_with_user"]
- )
get_latest_event_ids_in_room = EventFederationStore.__dict__[
"get_latest_event_ids_in_room"
]
- get_invited_rooms_for_user = RoomMemberStore.__dict__[
- "get_invited_rooms_for_user"
- ]
get_unread_event_push_actions_by_room_for_user = (
EventPushActionsStore.__dict__["get_unread_event_push_actions_by_room_for_user"]
)
@@ -93,7 +85,6 @@ class SlavedEventStore(EventsWorkerStore, StateGroupWorkerStore, BaseSlavedStore
get_recent_event_ids_for_room = (
StreamStore.__dict__["get_recent_event_ids_for_room"]
)
- _get_joined_hosts_cache = RoomMemberStore.__dict__["_get_joined_hosts_cache"]
has_room_changed_since = DataStore.has_room_changed_since.__func__
get_unread_push_actions_for_user_in_range_for_http = (
@@ -105,9 +96,6 @@ class SlavedEventStore(EventsWorkerStore, StateGroupWorkerStore, BaseSlavedStore
get_push_action_users_in_range = (
DataStore.get_push_action_users_in_range.__func__
)
- get_rooms_for_user_where_membership_is = (
- DataStore.get_rooms_for_user_where_membership_is.__func__
- )
get_membership_changes_for_user = (
DataStore.get_membership_changes_for_user.__func__
)
@@ -116,27 +104,15 @@ class SlavedEventStore(EventsWorkerStore, StateGroupWorkerStore, BaseSlavedStore
DataStore.get_room_events_stream_for_room.__func__
)
get_events_around = DataStore.get_events_around.__func__
- get_joined_users_from_state = DataStore.get_joined_users_from_state.__func__
- get_joined_users_from_context = DataStore.get_joined_users_from_context.__func__
- _get_joined_users_from_context = (
- RoomMemberStore.__dict__["_get_joined_users_from_context"]
- )
-
- get_joined_hosts = DataStore.get_joined_hosts.__func__
- _get_joined_hosts = RoomMemberStore.__dict__["_get_joined_hosts"]
get_recent_events_for_room = DataStore.get_recent_events_for_room.__func__
get_room_events_stream_for_rooms = (
DataStore.get_room_events_stream_for_rooms.__func__
)
- is_host_joined = RoomMemberStore.__dict__["is_host_joined"]
get_stream_token_for_event = DataStore.get_stream_token_for_event.__func__
_set_before_and_after = staticmethod(DataStore._set_before_and_after)
- _get_rooms_for_user_where_membership_is_txn = (
- DataStore._get_rooms_for_user_where_membership_is_txn.__func__
- )
_get_events_around_txn = DataStore._get_events_around_txn.__func__
get_backfill_events = DataStore.get_backfill_events.__func__
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 3e77fd3901..6574fe74b4 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -17,7 +17,7 @@ from twisted.internet import defer
from collections import namedtuple
-from ._base import SQLBaseStore
+from synapse.storage.events import EventsWorkerStore
from synapse.util.async import Linearizer
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
@@ -48,97 +48,7 @@ ProfileInfo = namedtuple(
_MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update"
-class RoomMemberStore(SQLBaseStore):
- def __init__(self, db_conn, hs):
- super(RoomMemberStore, self).__init__(db_conn, hs)
- self.register_background_update_handler(
- _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
- )
-
- def _store_room_members_txn(self, txn, events, backfilled):
- """Store a room member in the database.
- """
- self._simple_insert_many_txn(
- txn,
- table="room_memberships",
- values=[
- {
- "event_id": event.event_id,
- "user_id": event.state_key,
- "sender": event.user_id,
- "room_id": event.room_id,
- "membership": event.membership,
- "display_name": event.content.get("displayname", None),
- "avatar_url": event.content.get("avatar_url", None),
- }
- for event in events
- ]
- )
-
- for event in events:
- txn.call_after(
- self._membership_stream_cache.entity_has_changed,
- event.state_key, event.internal_metadata.stream_ordering
- )
- txn.call_after(
- self.get_invited_rooms_for_user.invalidate, (event.state_key,)
- )
-
- # We update the local_invites table only if the event is "current",
- # i.e., its something that has just happened.
- # The only current event that can also be an outlier is if its an
- # invite that has come in across federation.
- is_new_state = not backfilled and (
- not event.internal_metadata.is_outlier()
- or event.internal_metadata.is_invite_from_remote()
- )
- is_mine = self.hs.is_mine_id(event.state_key)
- if is_new_state and is_mine:
- if event.membership == Membership.INVITE:
- self._simple_insert_txn(
- txn,
- table="local_invites",
- values={
- "event_id": event.event_id,
- "invitee": event.state_key,
- "inviter": event.sender,
- "room_id": event.room_id,
- "stream_id": event.internal_metadata.stream_ordering,
- }
- )
- else:
- sql = (
- "UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE"
- " room_id = ? AND invitee = ? AND locally_rejected is NULL"
- " AND replaced_by is NULL"
- )
-
- txn.execute(sql, (
- event.internal_metadata.stream_ordering,
- event.event_id,
- event.room_id,
- event.state_key,
- ))
-
- @defer.inlineCallbacks
- def locally_reject_invite(self, user_id, room_id):
- sql = (
- "UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE"
- " room_id = ? AND invitee = ? AND locally_rejected is NULL"
- " AND replaced_by is NULL"
- )
-
- def f(txn, stream_ordering):
- txn.execute(sql, (
- stream_ordering,
- True,
- room_id,
- user_id,
- ))
-
- with self._stream_id_gen.get_next() as stream_ordering:
- yield self.runInteraction("locally_reject_invite", f, stream_ordering)
-
+class RoomMemberWorkerStore(EventsWorkerStore):
@cachedInlineCallbacks(max_entries=100000, iterable=True, cache_context=True)
def get_hosts_in_room(self, room_id, cache_context):
"""Returns the set of all hosts currently in the room
@@ -295,89 +205,6 @@ class RoomMemberStore(SQLBaseStore):
defer.returnValue(user_who_share_room)
- def forget(self, user_id, room_id):
- """Indicate that user_id wishes to discard history for room_id."""
- def f(txn):
- sql = (
- "UPDATE"
- " room_memberships"
- " SET"
- " forgotten = 1"
- " WHERE"
- " user_id = ?"
- " AND"
- " room_id = ?"
- )
- txn.execute(sql, (user_id, room_id))
-
- txn.call_after(self.was_forgotten_at.invalidate_all)
- txn.call_after(self.did_forget.invalidate, (user_id, room_id))
- self._invalidate_cache_and_stream(
- txn, self.who_forgot_in_room, (room_id,)
- )
- return self.runInteraction("forget_membership", f)
-
- @cachedInlineCallbacks(num_args=2)
- def did_forget(self, user_id, room_id):
- """Returns whether user_id has elected to discard history for room_id.
-
- Returns False if they have since re-joined."""
- def f(txn):
- sql = (
- "SELECT"
- " COUNT(*)"
- " FROM"
- " room_memberships"
- " WHERE"
- " user_id = ?"
- " AND"
- " room_id = ?"
- " AND"
- " forgotten = 0"
- )
- txn.execute(sql, (user_id, room_id))
- rows = txn.fetchall()
- return rows[0][0]
- count = yield self.runInteraction("did_forget_membership", f)
- defer.returnValue(count == 0)
-
- @cachedInlineCallbacks(num_args=3)
- def was_forgotten_at(self, user_id, room_id, event_id):
- """Returns whether user_id has elected to discard history for room_id at
- event_id.
-
- event_id must be a membership event."""
- def f(txn):
- sql = (
- "SELECT"
- " forgotten"
- " FROM"
- " room_memberships"
- " WHERE"
- " user_id = ?"
- " AND"
- " room_id = ?"
- " AND"
- " event_id = ?"
- )
- txn.execute(sql, (user_id, room_id, event_id))
- rows = txn.fetchall()
- return rows[0][0]
- forgot = yield self.runInteraction("did_forget_membership_at", f)
- defer.returnValue(forgot == 1)
-
- @cached()
- def who_forgot_in_room(self, room_id):
- return self._simple_select_list(
- table="room_memberships",
- retcols=("user_id", "event_id"),
- keyvalues={
- "room_id": room_id,
- "forgotten": 1,
- },
- desc="who_forgot"
- )
-
def get_joined_users_from_context(self, event, context):
state_group = context.state_group
if not state_group:
@@ -600,6 +427,185 @@ class RoomMemberStore(SQLBaseStore):
defer.returnValue(joined_hosts)
+ @cached(max_entries=10000, iterable=True)
+ def _get_joined_hosts_cache(self, room_id):
+ return _JoinedHostsCache(self, room_id)
+
+
+class RoomMemberStore(RoomMemberWorkerStore):
+ def __init__(self, db_conn, hs):
+ super(RoomMemberStore, self).__init__(db_conn, hs)
+ self.register_background_update_handler(
+ _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
+ )
+
+ def _store_room_members_txn(self, txn, events, backfilled):
+ """Store a room member in the database.
+ """
+ self._simple_insert_many_txn(
+ txn,
+ table="room_memberships",
+ values=[
+ {
+ "event_id": event.event_id,
+ "user_id": event.state_key,
+ "sender": event.user_id,
+ "room_id": event.room_id,
+ "membership": event.membership,
+ "display_name": event.content.get("displayname", None),
+ "avatar_url": event.content.get("avatar_url", None),
+ }
+ for event in events
+ ]
+ )
+
+ for event in events:
+ txn.call_after(
+ self._membership_stream_cache.entity_has_changed,
+ event.state_key, event.internal_metadata.stream_ordering
+ )
+ txn.call_after(
+ self.get_invited_rooms_for_user.invalidate, (event.state_key,)
+ )
+
+ # We update the local_invites table only if the event is "current",
+ # i.e., its something that has just happened.
+ # The only current event that can also be an outlier is if its an
+ # invite that has come in across federation.
+ is_new_state = not backfilled and (
+ not event.internal_metadata.is_outlier()
+ or event.internal_metadata.is_invite_from_remote()
+ )
+ is_mine = self.hs.is_mine_id(event.state_key)
+ if is_new_state and is_mine:
+ if event.membership == Membership.INVITE:
+ self._simple_insert_txn(
+ txn,
+ table="local_invites",
+ values={
+ "event_id": event.event_id,
+ "invitee": event.state_key,
+ "inviter": event.sender,
+ "room_id": event.room_id,
+ "stream_id": event.internal_metadata.stream_ordering,
+ }
+ )
+ else:
+ sql = (
+ "UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE"
+ " room_id = ? AND invitee = ? AND locally_rejected is NULL"
+ " AND replaced_by is NULL"
+ )
+
+ txn.execute(sql, (
+ event.internal_metadata.stream_ordering,
+ event.event_id,
+ event.room_id,
+ event.state_key,
+ ))
+
+ @defer.inlineCallbacks
+ def locally_reject_invite(self, user_id, room_id):
+ sql = (
+ "UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE"
+ " room_id = ? AND invitee = ? AND locally_rejected is NULL"
+ " AND replaced_by is NULL"
+ )
+
+ def f(txn, stream_ordering):
+ txn.execute(sql, (
+ stream_ordering,
+ True,
+ room_id,
+ user_id,
+ ))
+
+ with self._stream_id_gen.get_next() as stream_ordering:
+ yield self.runInteraction("locally_reject_invite", f, stream_ordering)
+
+ def forget(self, user_id, room_id):
+ """Indicate that user_id wishes to discard history for room_id."""
+ def f(txn):
+ sql = (
+ "UPDATE"
+ " room_memberships"
+ " SET"
+ " forgotten = 1"
+ " WHERE"
+ " user_id = ?"
+ " AND"
+ " room_id = ?"
+ )
+ txn.execute(sql, (user_id, room_id))
+
+ txn.call_after(self.was_forgotten_at.invalidate_all)
+ txn.call_after(self.did_forget.invalidate, (user_id, room_id))
+ self._invalidate_cache_and_stream(
+ txn, self.who_forgot_in_room, (room_id,)
+ )
+ return self.runInteraction("forget_membership", f)
+
+ @cachedInlineCallbacks(num_args=2)
+ def did_forget(self, user_id, room_id):
+ """Returns whether user_id has elected to discard history for room_id.
+
+ Returns False if they have since re-joined."""
+ def f(txn):
+ sql = (
+ "SELECT"
+ " COUNT(*)"
+ " FROM"
+ " room_memberships"
+ " WHERE"
+ " user_id = ?"
+ " AND"
+ " room_id = ?"
+ " AND"
+ " forgotten = 0"
+ )
+ txn.execute(sql, (user_id, room_id))
+ rows = txn.fetchall()
+ return rows[0][0]
+ count = yield self.runInteraction("did_forget_membership", f)
+ defer.returnValue(count == 0)
+
+ @cachedInlineCallbacks(num_args=3)
+ def was_forgotten_at(self, user_id, room_id, event_id):
+ """Returns whether user_id has elected to discard history for room_id at
+ event_id.
+
+ event_id must be a membership event."""
+ def f(txn):
+ sql = (
+ "SELECT"
+ " forgotten"
+ " FROM"
+ " room_memberships"
+ " WHERE"
+ " user_id = ?"
+ " AND"
+ " room_id = ?"
+ " AND"
+ " event_id = ?"
+ )
+ txn.execute(sql, (user_id, room_id, event_id))
+ rows = txn.fetchall()
+ return rows[0][0]
+ forgot = yield self.runInteraction("did_forget_membership_at", f)
+ defer.returnValue(forgot == 1)
+
+ @cached()
+ def who_forgot_in_room(self, room_id):
+ return self._simple_select_list(
+ table="room_memberships",
+ retcols=("user_id", "event_id"),
+ keyvalues={
+ "room_id": room_id,
+ "forgotten": 1,
+ },
+ desc="who_forgot"
+ )
+
@defer.inlineCallbacks
def _background_add_membership_profile(self, progress, batch_size):
target_min_stream_id = progress.get(
@@ -675,10 +681,6 @@ class RoomMemberStore(SQLBaseStore):
defer.returnValue(result)
- @cached(max_entries=10000, iterable=True)
- def _get_joined_hosts_cache(self, room_id):
- return _JoinedHostsCache(self, room_id)
-
class _JoinedHostsCache(object):
"""Cache for joined hosts in a room that is optimised to handle updates
|