diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 075566417f..3109e30414 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -23,6 +23,7 @@ from twisted.internet import defer
import collections
import logging
+import itertools
logger = logging.getLogger(__name__)
@@ -672,35 +673,10 @@ class SyncHandler(BaseHandler):
account_data_by_room,
all_ephemeral_by_room,
batch, full_state=False):
- if full_state:
- state = yield self.get_state_at(room_id, now_token)
-
- elif batch.limited:
- current_state = yield self.get_state_at(room_id, now_token)
-
- state_at_previous_sync = yield self.get_state_at(
- room_id, stream_position=since_token
- )
-
- state = yield self.compute_state_delta(
- since_token=since_token,
- previous_state=state_at_previous_sync,
- current_state=current_state,
- )
- else:
- state = {
- (event.type, event.state_key): event
- for event in batch.events if event.is_state()
- }
-
- just_joined = yield self.check_joined_room(sync_config, state)
- if just_joined:
- state = yield self.get_state_at(room_id, now_token)
-
- state = {
- (e.type, e.state_key): e
- for e in sync_config.filter_collection.filter_room_state(state.values())
- }
+ state = yield self.compute_state_delta(
+ room_id, batch, sync_config, since_token, now_token,
+ full_state=full_state
+ )
account_data = self.account_data_for_room(
room_id, tags_by_room, account_data_by_room
@@ -766,30 +742,11 @@ class SyncHandler(BaseHandler):
logger.debug("Recents %r", batch)
- state_events_at_leave = yield self.store.get_state_for_event(
- leave_event_id
+ state_events_delta = yield self.compute_state_delta(
+ room_id, batch, sync_config, since_token, leave_token,
+ full_state=full_state
)
- if not full_state:
- state_at_previous_sync = yield self.get_state_at(
- room_id, stream_position=since_token
- )
-
- state_events_delta = yield self.compute_state_delta(
- since_token=since_token,
- previous_state=state_at_previous_sync,
- current_state=state_events_at_leave,
- )
- else:
- state_events_delta = state_events_at_leave
-
- state_events_delta = {
- (e.type, e.state_key): e
- for e in sync_config.filter_collection.filter_room_state(
- state_events_delta.values()
- )
- }
-
account_data = self.account_data_for_room(
room_id, tags_by_room, account_data_by_room
)
@@ -843,15 +800,18 @@ class SyncHandler(BaseHandler):
state = {}
defer.returnValue(state)
- def compute_state_delta(self, since_token, previous_state, current_state):
- """ Works out the differnce in state between the current state and the
- state the client got when it last performed a sync.
-
- :param str since_token: the point we are comparing against
- :param dict[(str,str), synapse.events.FrozenEvent] previous_state: the
- state to compare to
- :param dict[(str,str), synapse.events.FrozenEvent] current_state: the
- new state
+ @defer.inlineCallbacks
+ def compute_state_delta(self, room_id, batch, sync_config, since_token, now_token,
+ full_state):
+ """ Works out the differnce in state between the start of the timeline
+ and the previous sync.
+
+ :param str room_id
+ :param TimelineBatch batch
+ :param sync_config
+ :param str since_token
+ :param str now_token
+ :param bool full_state
:returns A new event dictionary
"""
@@ -860,12 +820,50 @@ class SyncHandler(BaseHandler):
# updates even if they occured logically before the previous event.
# TODO(mjark) Check for new redactions in the state events.
- state_delta = {}
- for key, event in current_state.iteritems():
- if (key not in previous_state or
- previous_state[key].event_id != event.event_id):
- state_delta[key] = event
- return state_delta
+ if full_state:
+ if batch:
+ state = yield self.store.get_state_for_event(batch.events[0].event_id)
+ else:
+ state = yield self.get_state_at(
+ room_id, stream_position=now_token
+ )
+
+ timeline_state = {
+ (event.type, event.state_key): event
+ for event in batch.events if event.is_state()
+ }
+
+ state = _calculate_state(
+ timeline_contains=timeline_state,
+ timeline_start=state,
+ previous={},
+ )
+ elif batch.limited:
+ state_at_previous_sync = yield self.get_state_at(
+ room_id, stream_position=since_token
+ )
+
+ state_at_timeline_start = yield self.store.get_state_for_event(
+ batch.events[0].event_id
+ )
+
+ timeline_state = {
+ (event.type, event.state_key): event
+ for event in batch.events if event.is_state()
+ }
+
+ state = _calculate_state(
+ timeline_contains=timeline_state,
+ timeline_start=state_at_timeline_start,
+ previous=state_at_previous_sync,
+ )
+ else:
+ state = {}
+
+ defer.returnValue({
+ (e.type, e.state_key): e
+ for e in sync_config.filter_collection.filter_room_state(state.values())
+ })
def check_joined_room(self, sync_config, state_delta):
"""
@@ -912,3 +910,37 @@ def _action_has_highlight(actions):
pass
return False
+
+
+def _calculate_state(timeline_contains, timeline_start, previous):
+ """Works out what state to include in a sync response.
+
+ Args:
+ timeline_contains (dict): state in the timeline
+ timeline_start (dict): state at the start of the timeline
+ previous (dict): state at the end of the previous sync (or empty dict
+ if there is an initial sync)
+
+ Returns:
+ dict
+ """
+ event_id_to_state = {
+ e.event_id: e
+ for e in itertools.chain(
+ timeline_contains.values(),
+ previous.values(),
+ timeline_start.values(),
+ )
+ }
+
+ tc_ids = set(e.event_id for e in timeline_contains.values())
+ p_ids = set(e.event_id for e in previous.values())
+ ts_ids = set(e.event_id for e in timeline_start.values())
+
+ state_ids = (ts_ids - p_ids) - tc_ids
+
+ evs = (event_id_to_state[e] for e in state_ids)
+ return {
+ (e.type, e.state_key): e
+ for e in evs
+ }
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index 07b5b5dfd5..140ce2704b 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -20,7 +20,6 @@ from synapse.http.servlet import (
)
from synapse.handlers.sync import SyncConfig
from synapse.types import StreamToken
-from synapse.events import FrozenEvent
from synapse.events.utils import (
serialize_event, format_event_for_client_v2_without_room_id,
)
@@ -287,9 +286,6 @@ class SyncRestServlet(RestServlet):
state_dict = room.state
timeline_events = room.timeline.events
- state_dict = SyncRestServlet._rollback_state_for_timeline(
- state_dict, timeline_events)
-
state_events = state_dict.values()
serialized_state = [serialize(e) for e in state_events]
@@ -314,77 +310,6 @@ class SyncRestServlet(RestServlet):
return result
- @staticmethod
- def _rollback_state_for_timeline(state, timeline):
- """
- Wind the state dictionary backwards, so that it represents the
- state at the start of the timeline, rather than at the end.
-
- :param dict[(str, str), synapse.events.EventBase] state: the
- state dictionary. Will be updated to the state before the timeline.
- :param list[synapse.events.EventBase] timeline: the event timeline
- :return: updated state dictionary
- """
-
- result = state.copy()
-
- for timeline_event in reversed(timeline):
- if not timeline_event.is_state():
- continue
-
- event_key = (timeline_event.type, timeline_event.state_key)
-
- logger.debug("Considering %s for removal", event_key)
-
- state_event = result.get(event_key)
- if (state_event is None or
- state_event.event_id != timeline_event.event_id):
- # the event in the timeline isn't present in the state
- # dictionary.
- #
- # the most likely cause for this is that there was a fork in
- # the event graph, and the state is no longer valid. Really,
- # the event shouldn't be in the timeline. We're going to ignore
- # it for now, however.
- logger.debug("Found state event %r in timeline which doesn't "
- "match state dictionary", timeline_event)
- continue
-
- prev_event_id = timeline_event.unsigned.get("replaces_state", None)
-
- prev_content = timeline_event.unsigned.get('prev_content')
- prev_sender = timeline_event.unsigned.get('prev_sender')
- # Empircally it seems possible for the event to have a
- # "replaces_state" key but not a prev_content or prev_sender
- # markjh conjectures that it could be due to the server not
- # having a copy of that event.
- # If this is the case the we ignore the previous event. This will
- # cause the displayname calculations on the client to be incorrect
- if prev_event_id is None or not prev_content or not prev_sender:
- logger.debug(
- "Removing %r from the state dict, as it is missing"
- " prev_content (prev_event_id=%r)",
- timeline_event.event_id, prev_event_id
- )
- del result[event_key]
- else:
- logger.debug(
- "Replacing %r with %r in state dict",
- timeline_event.event_id, prev_event_id
- )
- result[event_key] = FrozenEvent({
- "type": timeline_event.type,
- "state_key": timeline_event.state_key,
- "content": prev_content,
- "sender": prev_sender,
- "event_id": prev_event_id,
- "room_id": timeline_event.room_id,
- })
-
- logger.debug("New value: %r", result.get(event_key))
-
- return result
-
def register_servlets(hs, http_server):
SyncRestServlet(hs).register(http_server)
|