summary refs log tree commit diff
path: root/synapse/storage/roommember.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/storage/roommember.py27
1 files changed, 23 insertions, 4 deletions
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index cab1660830..6ab10db328 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -354,7 +354,8 @@ class RoomMemberStore(SQLBaseStore):
             desc="who_forgot"
         )
 
-    def get_joined_users_from_context(self, room_id, state_group, state_ids):
+    def get_joined_users_from_context(self, event, context):
+        state_group = context.state_group
         if not state_group:
             # If state_group is None it means it has yet to be assigned a
             # state group, i.e. we need to make sure that calls with a state_group
@@ -363,12 +364,24 @@ class RoomMemberStore(SQLBaseStore):
             state_group = object()
 
         return self._get_joined_users_from_context(
-            room_id, state_group, state_ids
+            event.room_id, state_group, context.current_state_ids, event=event,
+        )
+
+    def get_joined_users_from_state(self, room_id, state_group, state_ids):
+        if not state_group:
+            # If state_group is None it means it has yet to be assigned a
+            # state group, i.e. we need to make sure that calls with a state_group
+            # of None don't hit previous cached calls with a None state_group.
+            # To do this we set the state_group to a new object as object() != object()
+            state_group = object()
+
+        return self._get_joined_users_from_context(
+            room_id, state_group, state_ids,
         )
 
     @cachedInlineCallbacks(num_args=2, cache_context=True)
     def _get_joined_users_from_context(self, room_id, state_group, current_state_ids,
-                                       cache_context):
+                                       cache_context, event=None):
         # We don't use `state_group`, its there so that we can cache based
         # on it. However, its important that its never None, since two current_state's
         # with a state_group of None are likely to be different.
@@ -393,7 +406,13 @@ class RoomMemberStore(SQLBaseStore):
             desc="_get_joined_users_from_context",
         )
 
-        defer.returnValue(set(row["user_id"] for row in rows))
+        users_in_room = set(row["user_id"] for row in rows)
+        if event is not None and event.type == EventTypes.Member:
+            if event.membership == Membership.JOIN:
+                if event.event_id in member_event_ids:
+                    users_in_room.add(event.state_key)
+
+        defer.returnValue(users_in_room)
 
     def is_host_joined(self, room_id, host, state_group, state_ids):
         if not state_group: