summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/handlers/sync.py46
1 files changed, 35 insertions, 11 deletions
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 5f060241b4..f04676e3b8 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -101,6 +101,7 @@ class JoinedSyncResult:
     room_id = attr.ib(type=str)
     timeline = attr.ib(type=TimelineBatch)
     state = attr.ib(type=StateMap[EventBase])
+    state_delta = attr.ib(type=StateMap[EventBase])
     ephemeral = attr.ib(type=List[JsonDict])
     account_data = attr.ib(type=List[JsonDict])
     unread_notifications = attr.ib(type=JsonDict)
@@ -743,7 +744,7 @@ class SyncHandler(object):
         since_token: Optional[StreamToken],
         now_token: StreamToken,
         full_state: bool,
-    ) -> StateMap[EventBase]:
+    ) -> Tuple[StateMap[EventBase], StateMap[EventBase]]:
         """ Works out the difference in state between the start of the timeline
         and the previous sync.
 
@@ -754,6 +755,10 @@ class SyncHandler(object):
             since_token: Token of the end of the previous batch. May be None.
             now_token: Token of the end of the current batch.
             full_state: Whether to force returning the full state.
+
+        Returns:
+            2-tuple of state delta and extra membership events to include (for
+            lazy loading).
         """
         # TODO(mjark) Check if the state events were received by the server
         # after the previous sync, since we need to include those state
@@ -795,6 +800,7 @@ class SyncHandler(object):
                 if event.is_state()
             }
 
+            extra_memberships = {}  # type: StateMap[str]
             if full_state:
                 if batch:
                     current_state_ids = await self.state_store.get_state_ids_for_event(
@@ -885,7 +891,7 @@ class SyncHandler(object):
                         # So we fish out all the member events corresponding to the
                         # timeline here, and then dedupe any redundant ones below.
 
-                        state_ids = await self.state_store.get_state_ids_for_event(
+                        extra_memberships = await self.state_store.get_state_ids_for_event(
                             batch.events[0].event_id,
                             # we only want members!
                             state_filter=StateFilter.from_types(
@@ -908,29 +914,41 @@ class SyncHandler(object):
                     # only send members which aren't in our LruCache (either
                     # because they're new to this client or have been pushed out
                     # of the cache)
-                    logger.debug("filtering state from %r...", state_ids)
+                    logger.debug("filtering state from %r...", extra_memberships)
+                    extra_memberships = {
+                        t: event_id
+                        for t, event_id in extra_memberships.items()
+                        if cache.get(t[1]) != event_id
+                    }
                     state_ids = {
                         t: event_id
-                        for t, event_id in iteritems(state_ids)
+                        for t, event_id in state_ids.items()
                         if cache.get(t[1]) != event_id
                     }
-                    logger.debug("...to %r", state_ids)
+                    logger.debug("...to %r", extra_memberships)
 
                 # add any member IDs we are about to send into our LruCache
                 for t, event_id in itertools.chain(
-                    state_ids.items(), timeline_state.items()
+                    state_ids.items(), timeline_state.items(), extra_memberships.items()
                 ):
                     if t[0] == EventTypes.Member:
                         cache.set(t[1], event_id)
 
         state = {}  # type: Dict[str, EventBase]
-        if state_ids:
-            state = await self.store.get_events(list(state_ids.values()))
+        if state_ids or extra_memberships:
+            state = await self.store.get_events(list(itertools.chain(extra_memberships.values(), state_ids.values())))
+
+        logger.info("State: %s", state)
 
         return {
             (e.type, e.state_key): e
             for e in sync_config.filter_collection.filter_room_state(
-                list(state.values())
+                list(state[e_id] for e_id in state_ids.values() if e_id in state)
+            )
+        }, {
+            (e.type, e.state_key): e
+            for e in sync_config.filter_collection.filter_room_state(
+                list(state[e_id] for e_id in extra_memberships.values() if e_id in state)
             )
         }
 
@@ -1422,7 +1440,7 @@ class SyncHandler(object):
         if since_token:
             for joined_sync in sync_result_builder.joined:
                 it = itertools.chain(
-                    joined_sync.timeline.events, itervalues(joined_sync.state)
+                    joined_sync.timeline.events, itervalues(joined_sync.state_delta)
                 )
                 for event in it:
                     if event.type == EventTypes.Member:
@@ -1842,10 +1860,15 @@ class SyncHandler(object):
         ):
             return
 
-        state = await self.compute_state_delta(
+        state_delta, extra_memberships = await self.compute_state_delta(
             room_id, batch, sync_config, since_token, now_token, full_state=full_state
         )
 
+        logger.info("state delta: %s", state_delta)
+        logger.info("extra_memberships: %s", extra_memberships)
+
+        state = dict(itertools.chain(state_delta.items(), extra_memberships.items()))
+
         summary = {}  # type: Optional[JsonDict]
 
         # we include a summary in room responses when we're lazy loading
@@ -1875,6 +1898,7 @@ class SyncHandler(object):
                 room_id=room_id,
                 timeline=batch,
                 state=state,
+                state_delta=state_delta,
                 ephemeral=ephemeral,
                 account_data=account_data_events,
                 unread_notifications=unread_notifications,