summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2016-02-02 15:28:43 +0000
committerErik Johnston <erik@matrix.org>2016-02-02 15:28:43 +0000
commit04ad93e6fdc65a372b114b50aec8c1201f051735 (patch)
treee715eff5335a5b889dab126a27fbbd6bea5cb94d
parentMerge pull request #548 from matrix-org/dbkr/fix_guest_db_column (diff)
parentComments (diff)
downloadsynapse-04ad93e6fdc65a372b114b50aec8c1201f051735.tar.xz
Merge pull request #545 from matrix-org/erikj/sync
Move /sync state calculations from rest to handler
-rw-r--r--synapse/handlers/sync.py165
-rw-r--r--synapse/rest/client/v2_alpha/sync.py75
2 files changed, 99 insertions, 141 deletions
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 075566417f..8d8d10da33 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -23,6 +23,7 @@ from twisted.internet import defer
 
 import collections
 import logging
+import itertools
 
 logger = logging.getLogger(__name__)
 
@@ -672,35 +673,10 @@ class SyncHandler(BaseHandler):
                                            account_data_by_room,
                                            all_ephemeral_by_room,
                                            batch, full_state=False):
-        if full_state:
-            state = yield self.get_state_at(room_id, now_token)
-
-        elif batch.limited:
-            current_state = yield self.get_state_at(room_id, now_token)
-
-            state_at_previous_sync = yield self.get_state_at(
-                room_id, stream_position=since_token
-            )
-
-            state = yield self.compute_state_delta(
-                since_token=since_token,
-                previous_state=state_at_previous_sync,
-                current_state=current_state,
-            )
-        else:
-            state = {
-                (event.type, event.state_key): event
-                for event in batch.events if event.is_state()
-            }
-
-        just_joined = yield self.check_joined_room(sync_config, state)
-        if just_joined:
-            state = yield self.get_state_at(room_id, now_token)
-
-        state = {
-            (e.type, e.state_key): e
-            for e in sync_config.filter_collection.filter_room_state(state.values())
-        }
+        state = yield self.compute_state_delta(
+            room_id, batch, sync_config, since_token, now_token,
+            full_state=full_state
+        )
 
         account_data = self.account_data_for_room(
             room_id, tags_by_room, account_data_by_room
@@ -766,30 +742,11 @@ class SyncHandler(BaseHandler):
 
         logger.debug("Recents %r", batch)
 
-        state_events_at_leave = yield self.store.get_state_for_event(
-            leave_event_id
+        state_events_delta = yield self.compute_state_delta(
+            room_id, batch, sync_config, since_token, leave_token,
+            full_state=full_state
         )
 
-        if not full_state:
-            state_at_previous_sync = yield self.get_state_at(
-                room_id, stream_position=since_token
-            )
-
-            state_events_delta = yield self.compute_state_delta(
-                since_token=since_token,
-                previous_state=state_at_previous_sync,
-                current_state=state_events_at_leave,
-            )
-        else:
-            state_events_delta = state_events_at_leave
-
-        state_events_delta = {
-            (e.type, e.state_key): e
-            for e in sync_config.filter_collection.filter_room_state(
-                state_events_delta.values()
-            )
-        }
-
         account_data = self.account_data_for_room(
             room_id, tags_by_room, account_data_by_room
         )
@@ -843,15 +800,19 @@ class SyncHandler(BaseHandler):
             state = {}
         defer.returnValue(state)
 
-    def compute_state_delta(self, since_token, previous_state, current_state):
-        """ Works out the differnce in state between the current state and the
-        state the client got when it last performed a sync.
-
-        :param str since_token: the point we are comparing against
-        :param dict[(str,str), synapse.events.FrozenEvent] previous_state: the
-            state to compare to
-        :param dict[(str,str), synapse.events.FrozenEvent] current_state: the
-            new state
+    @defer.inlineCallbacks
+    def compute_state_delta(self, room_id, batch, sync_config, since_token, now_token,
+                            full_state):
+        """ Works out the differnce in state between the start of the timeline
+        and the previous sync.
+
+        :param str room_id
+        :param TimelineBatch batch: The timeline batch for the room that will
+            be sent to the user.
+        :param sync_config
+        :param str since_token: Token of the end of the previous batch. May be None.
+        :param str now_token: Token of the end of the current batch.
+        :param bool full_state: Whether to force returning the full state.
 
         :returns A new event dictionary
         """
@@ -860,12 +821,50 @@ class SyncHandler(BaseHandler):
         # updates even if they occured logically before the previous event.
         # TODO(mjark) Check for new redactions in the state events.
 
-        state_delta = {}
-        for key, event in current_state.iteritems():
-            if (key not in previous_state or
-                    previous_state[key].event_id != event.event_id):
-                state_delta[key] = event
-        return state_delta
+        if full_state:
+            if batch:
+                state = yield self.store.get_state_for_event(batch.events[0].event_id)
+            else:
+                state = yield self.get_state_at(
+                    room_id, stream_position=now_token
+                )
+
+            timeline_state = {
+                (event.type, event.state_key): event
+                for event in batch.events if event.is_state()
+            }
+
+            state = _calculate_state(
+                timeline_contains=timeline_state,
+                timeline_start=state,
+                previous={},
+            )
+        elif batch.limited:
+            state_at_previous_sync = yield self.get_state_at(
+                room_id, stream_position=since_token
+            )
+
+            state_at_timeline_start = yield self.store.get_state_for_event(
+                batch.events[0].event_id
+            )
+
+            timeline_state = {
+                (event.type, event.state_key): event
+                for event in batch.events if event.is_state()
+            }
+
+            state = _calculate_state(
+                timeline_contains=timeline_state,
+                timeline_start=state_at_timeline_start,
+                previous=state_at_previous_sync,
+            )
+        else:
+            state = {}
+
+        defer.returnValue({
+            (e.type, e.state_key): e
+            for e in sync_config.filter_collection.filter_room_state(state.values())
+        })
 
     def check_joined_room(self, sync_config, state_delta):
         """
@@ -912,3 +911,37 @@ def _action_has_highlight(actions):
             pass
 
     return False
+
+
+def _calculate_state(timeline_contains, timeline_start, previous):
+    """Works out what state to include in a sync response.
+
+    Args:
+        timeline_contains (dict): state in the timeline
+        timeline_start (dict): state at the start of the timeline
+        previous (dict): state at the end of the previous sync (or empty dict
+            if this is an initial sync)
+
+    Returns:
+        dict
+    """
+    event_id_to_state = {
+        e.event_id: e
+        for e in itertools.chain(
+            timeline_contains.values(),
+            previous.values(),
+            timeline_start.values(),
+        )
+    }
+
+    tc_ids = set(e.event_id for e in timeline_contains.values())
+    p_ids = set(e.event_id for e in previous.values())
+    ts_ids = set(e.event_id for e in timeline_start.values())
+
+    state_ids = (ts_ids - p_ids) - tc_ids
+
+    evs = (event_id_to_state[e] for e in state_ids)
+    return {
+        (e.type, e.state_key): e
+        for e in evs
+    }
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index 07b5b5dfd5..140ce2704b 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -20,7 +20,6 @@ from synapse.http.servlet import (
 )
 from synapse.handlers.sync import SyncConfig
 from synapse.types import StreamToken
-from synapse.events import FrozenEvent
 from synapse.events.utils import (
     serialize_event, format_event_for_client_v2_without_room_id,
 )
@@ -287,9 +286,6 @@ class SyncRestServlet(RestServlet):
         state_dict = room.state
         timeline_events = room.timeline.events
 
-        state_dict = SyncRestServlet._rollback_state_for_timeline(
-            state_dict, timeline_events)
-
         state_events = state_dict.values()
 
         serialized_state = [serialize(e) for e in state_events]
@@ -314,77 +310,6 @@ class SyncRestServlet(RestServlet):
 
         return result
 
-    @staticmethod
-    def _rollback_state_for_timeline(state, timeline):
-        """
-        Wind the state dictionary backwards, so that it represents the
-        state at the start of the timeline, rather than at the end.
-
-        :param dict[(str, str), synapse.events.EventBase] state: the
-            state dictionary. Will be updated to the state before the timeline.
-        :param list[synapse.events.EventBase] timeline: the event timeline
-        :return: updated state dictionary
-        """
-
-        result = state.copy()
-
-        for timeline_event in reversed(timeline):
-            if not timeline_event.is_state():
-                continue
-
-            event_key = (timeline_event.type, timeline_event.state_key)
-
-            logger.debug("Considering %s for removal", event_key)
-
-            state_event = result.get(event_key)
-            if (state_event is None or
-                    state_event.event_id != timeline_event.event_id):
-                # the event in the timeline isn't present in the state
-                # dictionary.
-                #
-                # the most likely cause for this is that there was a fork in
-                # the event graph, and the state is no longer valid. Really,
-                # the event shouldn't be in the timeline. We're going to ignore
-                # it for now, however.
-                logger.debug("Found state event %r in timeline which doesn't "
-                             "match state dictionary", timeline_event)
-                continue
-
-            prev_event_id = timeline_event.unsigned.get("replaces_state", None)
-
-            prev_content = timeline_event.unsigned.get('prev_content')
-            prev_sender = timeline_event.unsigned.get('prev_sender')
-            # Empircally it seems possible for the event to have a
-            # "replaces_state" key but not a prev_content or prev_sender
-            # markjh conjectures that it could be due to the server not
-            # having a copy of that event.
-            # If this is the case the we ignore the previous event. This will
-            # cause the displayname calculations on the client to be incorrect
-            if prev_event_id is None or not prev_content or not prev_sender:
-                logger.debug(
-                    "Removing %r from the state dict, as it is missing"
-                    " prev_content (prev_event_id=%r)",
-                    timeline_event.event_id, prev_event_id
-                )
-                del result[event_key]
-            else:
-                logger.debug(
-                    "Replacing %r with %r in state dict",
-                    timeline_event.event_id, prev_event_id
-                )
-                result[event_key] = FrozenEvent({
-                    "type": timeline_event.type,
-                    "state_key": timeline_event.state_key,
-                    "content": prev_content,
-                    "sender": prev_sender,
-                    "event_id": prev_event_id,
-                    "room_id": timeline_event.room_id,
-                })
-
-            logger.debug("New value: %r", result.get(event_key))
-
-        return result
-
 
 def register_servlets(hs, http_server):
     SyncRestServlet(hs).register(http_server)