summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/handlers/sync.py36
1 files changed, 32 insertions, 4 deletions
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index fa730ca760..c754cfdeeb 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -468,6 +468,8 @@ class SyncHandler(object):
         with Measure(self.clock, "compute_state_delta"):
 
             types = None
+            member_state_ids = {}
+
             if filter_members:
                 # We only request state for the members needed to display the
                 # timeline:
@@ -492,6 +494,13 @@ class SyncHandler(object):
                     state_ids = yield self.store.get_state_ids_for_event(
                         batch.events[0].event_id, types=types
                     )
+
+                    if filter_members:
+                        member_state_ids = {
+                            t: state_ids[t]
+                            for t in state_ids if t[0] == EventTypes.member
+                        }
+
                 else:
                     current_state_ids = yield self.get_state_at(
                         room_id, stream_position=now_token, types=types
@@ -499,6 +508,12 @@ class SyncHandler(object):
 
                     state_ids = current_state_ids
 
+                    if filter_members:
+                        member_state_ids = {
+                            t: state_ids[t]
+                            for t in state_ids if t[0] == EventTypes.member
+                        }
+
                 timeline_state = {
                     (event.type, event.state_key): event.event_id
                     for event in batch.events if event.is_state()
@@ -507,6 +522,7 @@ class SyncHandler(object):
                 state_ids = _calculate_state(
                     timeline_contains=timeline_state,
                     timeline_start=state_ids,
+                    timeline_start_members=member_state_ids,
                     previous={},
                     current=current_state_ids,
                 )
@@ -523,6 +539,12 @@ class SyncHandler(object):
                     batch.events[0].event_id, types=types
                 )
 
+                if filter_members:
+                    member_state_ids = {
+                        t: state_at_timeline_start[t]
+                        for t in state_ids if t[0] == EventTypes.member
+                    }
+
                 timeline_state = {
                     (event.type, event.state_key): event.event_id
                     for event in batch.events if event.is_state()
@@ -531,6 +553,7 @@ class SyncHandler(object):
                 state_ids = _calculate_state(
                     timeline_contains=timeline_state,
                     timeline_start=state_at_timeline_start,
+                    timeline_start_members=member_state_ids,
                     previous=state_at_previous_sync,
                     current=current_state_ids,
                 )
@@ -1440,12 +1463,16 @@ def _action_has_highlight(actions):
     return False
 
 
-def _calculate_state(timeline_contains, timeline_start, previous, current):
+def _calculate_state(timeline_contains, timeline_start, timeline_start_members,
+                     previous, current):
     """Works out what state to include in a sync response.
 
     Args:
         timeline_contains (dict): state in the timeline
         timeline_start (dict): state at the start of the timeline
+        timeline_start_members (dict): state at the start of the timeline
+            for room members who participate in this chunk of timeline.
+            Should always be a subset of timeline_start.
         previous (dict): state at the end of the previous sync (or empty dict
             if this is an initial sync)
         current (dict): state at the end of the timeline
@@ -1464,11 +1491,12 @@ def _calculate_state(timeline_contains, timeline_start, previous, current):
     }
 
     c_ids = set(e for e in current.values())
-    tc_ids = set(e for e in timeline_contains.values())
-    p_ids = set(e for e in previous.values())
     ts_ids = set(e for e in timeline_start.values())
+    tsm_ids = set(e for e in timeline_start_members.values())
+    p_ids = set(e for e in previous.values())
+    tc_ids = set(e for e in timeline_contains.values())
 
-    state_ids = ((c_ids | ts_ids) - p_ids) - tc_ids
+    state_ids = (((c_ids | ts_ids) - p_ids) - tc_ids) | tsm_ids
 
     return {
         event_id_to_key[e]: e for e in state_ids