diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 075566417f..3109e30414 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -23,6 +23,7 @@ from twisted.internet import defer
import collections
import logging
+import itertools
logger = logging.getLogger(__name__)
@@ -672,35 +673,10 @@ class SyncHandler(BaseHandler):
account_data_by_room,
all_ephemeral_by_room,
batch, full_state=False):
- if full_state:
- state = yield self.get_state_at(room_id, now_token)
-
- elif batch.limited:
- current_state = yield self.get_state_at(room_id, now_token)
-
- state_at_previous_sync = yield self.get_state_at(
- room_id, stream_position=since_token
- )
-
- state = yield self.compute_state_delta(
- since_token=since_token,
- previous_state=state_at_previous_sync,
- current_state=current_state,
- )
- else:
- state = {
- (event.type, event.state_key): event
- for event in batch.events if event.is_state()
- }
-
- just_joined = yield self.check_joined_room(sync_config, state)
- if just_joined:
- state = yield self.get_state_at(room_id, now_token)
-
- state = {
- (e.type, e.state_key): e
- for e in sync_config.filter_collection.filter_room_state(state.values())
- }
+ state = yield self.compute_state_delta(
+ room_id, batch, sync_config, since_token, now_token,
+ full_state=full_state
+ )
account_data = self.account_data_for_room(
room_id, tags_by_room, account_data_by_room
@@ -766,30 +742,11 @@ class SyncHandler(BaseHandler):
logger.debug("Recents %r", batch)
- state_events_at_leave = yield self.store.get_state_for_event(
- leave_event_id
+ state_events_delta = yield self.compute_state_delta(
+ room_id, batch, sync_config, since_token, leave_token,
+ full_state=full_state
)
- if not full_state:
- state_at_previous_sync = yield self.get_state_at(
- room_id, stream_position=since_token
- )
-
- state_events_delta = yield self.compute_state_delta(
- since_token=since_token,
- previous_state=state_at_previous_sync,
- current_state=state_events_at_leave,
- )
- else:
- state_events_delta = state_events_at_leave
-
- state_events_delta = {
- (e.type, e.state_key): e
- for e in sync_config.filter_collection.filter_room_state(
- state_events_delta.values()
- )
- }
-
account_data = self.account_data_for_room(
room_id, tags_by_room, account_data_by_room
)
@@ -843,15 +800,18 @@ class SyncHandler(BaseHandler):
state = {}
defer.returnValue(state)
- def compute_state_delta(self, since_token, previous_state, current_state):
- """ Works out the differnce in state between the current state and the
- state the client got when it last performed a sync.
-
- :param str since_token: the point we are comparing against
- :param dict[(str,str), synapse.events.FrozenEvent] previous_state: the
- state to compare to
- :param dict[(str,str), synapse.events.FrozenEvent] current_state: the
- new state
+ @defer.inlineCallbacks
+ def compute_state_delta(self, room_id, batch, sync_config, since_token, now_token,
+ full_state):
+ """ Works out the differnce in state between the start of the timeline
+ and the previous sync.
+
+ :param str room_id
+ :param TimelineBatch batch
+ :param sync_config
+ :param str since_token
+ :param str now_token
+ :param bool full_state
:returns A new event dictionary
"""
@@ -860,12 +820,50 @@ class SyncHandler(BaseHandler):
# updates even if they occured logically before the previous event.
# TODO(mjark) Check for new redactions in the state events.
- state_delta = {}
- for key, event in current_state.iteritems():
- if (key not in previous_state or
- previous_state[key].event_id != event.event_id):
- state_delta[key] = event
- return state_delta
+ if full_state:
+ if batch:
+ state = yield self.store.get_state_for_event(batch.events[0].event_id)
+ else:
+ state = yield self.get_state_at(
+ room_id, stream_position=now_token
+ )
+
+ timeline_state = {
+ (event.type, event.state_key): event
+ for event in batch.events if event.is_state()
+ }
+
+ state = _calculate_state(
+ timeline_contains=timeline_state,
+ timeline_start=state,
+ previous={},
+ )
+ elif batch.limited:
+ state_at_previous_sync = yield self.get_state_at(
+ room_id, stream_position=since_token
+ )
+
+ state_at_timeline_start = yield self.store.get_state_for_event(
+ batch.events[0].event_id
+ )
+
+ timeline_state = {
+ (event.type, event.state_key): event
+ for event in batch.events if event.is_state()
+ }
+
+ state = _calculate_state(
+ timeline_contains=timeline_state,
+ timeline_start=state_at_timeline_start,
+ previous=state_at_previous_sync,
+ )
+ else:
+ state = {}
+
+ defer.returnValue({
+ (e.type, e.state_key): e
+ for e in sync_config.filter_collection.filter_room_state(state.values())
+ })
def check_joined_room(self, sync_config, state_delta):
"""
@@ -912,3 +910,37 @@ def _action_has_highlight(actions):
pass
return False
+
+
+def _calculate_state(timeline_contains, timeline_start, previous):
+ """Works out what state to include in a sync response.
+
+ Args:
+ timeline_contains (dict): state in the timeline
+ timeline_start (dict): state at the start of the timeline
+ previous (dict): state at the end of the previous sync (or empty dict
+ if there is an initial sync)
+
+ Returns:
+ dict
+ """
+ event_id_to_state = {
+ e.event_id: e
+ for e in itertools.chain(
+ timeline_contains.values(),
+ previous.values(),
+ timeline_start.values(),
+ )
+ }
+
+ tc_ids = set(e.event_id for e in timeline_contains.values())
+ p_ids = set(e.event_id for e in previous.values())
+ ts_ids = set(e.event_id for e in timeline_start.values())
+
+ state_ids = (ts_ids - p_ids) - tc_ids
+
+ evs = (event_id_to_state[e] for e in state_ids)
+ return {
+ (e.type, e.state_key): e
+ for e in evs
+ }
|