| diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 1d0f0058a2..b2b9b928c9 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -823,15 +823,17 @@ class SyncHandler(BaseHandler):
         # TODO(mjark) Check for new redactions in the state events.
 
         with Measure(self.clock, "compute_state_delta"):
+            current_state = yield self.get_state_at(
+                room_id, stream_position=now_token
+            )
+
             if full_state:
                 if batch:
                     state = yield self.store.get_state_for_event(
                         batch.events[0].event_id
                     )
                 else:
-                    state = yield self.get_state_at(
-                        room_id, stream_position=now_token
-                    )
+                    state = current_state
 
                 timeline_state = {
                     (event.type, event.state_key): event
@@ -842,6 +844,7 @@ class SyncHandler(BaseHandler):
                     timeline_contains=timeline_state,
                     timeline_start=state,
                     previous={},
+                    current=current_state,
                 )
             elif batch.limited:
                 state_at_previous_sync = yield self.get_state_at(
@@ -861,6 +864,7 @@ class SyncHandler(BaseHandler):
                     timeline_contains=timeline_state,
                     timeline_start=state_at_timeline_start,
                     previous=state_at_previous_sync,
+                    current=current_state,
                 )
             else:
                 state = {}
@@ -920,7 +924,7 @@ def _action_has_highlight(actions):
     return False
 
 
-def _calculate_state(timeline_contains, timeline_start, previous):
+def _calculate_state(timeline_contains, timeline_start, previous, current):
     """Works out what state to include in a sync response.
 
     Args:
@@ -928,6 +932,7 @@ def _calculate_state(timeline_contains, timeline_start, previous):
         timeline_start (dict): state at the start of the timeline
         previous (dict): state at the end of the previous sync (or empty dict
             if this is an initial sync)
+        current (dict): state at the end of the timeline
 
     Returns:
         dict
@@ -938,14 +943,16 @@ def _calculate_state(timeline_contains, timeline_start, previous):
             timeline_contains.values(),
             previous.values(),
             timeline_start.values(),
+            current.values(),
         )
     }
 
+    c_ids = set(e.event_id for e in current.values())
     tc_ids = set(e.event_id for e in timeline_contains.values())
     p_ids = set(e.event_id for e in previous.values())
     ts_ids = set(e.event_id for e in timeline_start.values())
 
-    state_ids = (ts_ids - p_ids) - tc_ids
+    state_ids = ((c_ids | ts_ids) - p_ids) - tc_ids
 
     evs = (event_id_to_state[e] for e in state_ids)
     return {
 |