diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index cab1660830..6ab10db328 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -354,7 +354,8 @@ class RoomMemberStore(SQLBaseStore):
desc="who_forgot"
)
- def get_joined_users_from_context(self, room_id, state_group, state_ids):
+ def get_joined_users_from_context(self, event, context):
+ state_group = context.state_group
if not state_group:
# If state_group is None it means it has yet to be assigned a
# state group, i.e. we need to make sure that calls with a state_group
@@ -363,12 +364,24 @@ class RoomMemberStore(SQLBaseStore):
state_group = object()
return self._get_joined_users_from_context(
- room_id, state_group, state_ids
+ event.room_id, state_group, context.current_state_ids, event=event,
+ )
+
+ def get_joined_users_from_state(self, room_id, state_group, state_ids):
+ if not state_group:
+ # If state_group is None it means it has yet to be assigned a
+ # state group, i.e. we need to make sure that calls with a state_group
+ # of None don't hit previous cached calls with a None state_group.
+ # To do this we set the state_group to a new object as object() != object()
+ state_group = object()
+
+ return self._get_joined_users_from_context(
+ room_id, state_group, state_ids,
)
@cachedInlineCallbacks(num_args=2, cache_context=True)
def _get_joined_users_from_context(self, room_id, state_group, current_state_ids,
- cache_context):
+ cache_context, event=None):
# We don't use `state_group`, its there so that we can cache based
# on it. However, its important that its never None, since two current_state's
# with a state_group of None are likely to be different.
@@ -393,7 +406,13 @@ class RoomMemberStore(SQLBaseStore):
desc="_get_joined_users_from_context",
)
- defer.returnValue(set(row["user_id"] for row in rows))
+ users_in_room = set(row["user_id"] for row in rows)
+ if event is not None and event.type == EventTypes.Member:
+ if event.membership == Membership.JOIN:
+ if event.event_id in member_event_ids:
+ users_in_room.add(event.state_key)
+
+ defer.returnValue(users_in_room)
def is_host_joined(self, room_id, host, state_group, state_ids):
if not state_group:
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 56bfdc0b55..dce5a2f135 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -108,9 +108,6 @@ class StateStore(SQLBaseStore):
state_event_ids = dict(context.current_state_ids)
- if event.is_state():
- state_event_ids[(event.type, event.state_key)] = event.event_id
-
self._simple_insert_txn(
txn,
table="state_groups",
|