summary refs log tree commit diff
path: root/synapse/handlers/sync.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/sync.py')
-rw-r--r--synapse/handlers/sync.py126
1 files changed, 90 insertions, 36 deletions
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index c8dfd02e7b..b5962f4f5a 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -35,6 +35,7 @@ SyncConfig = collections.namedtuple("SyncConfig", [
     "filter_collection",
     "is_guest",
     "request_key",
+    "device_id",
 ])
 
 
@@ -113,6 +114,7 @@ class SyncResult(collections.namedtuple("SyncResult", [
     "joined",  # JoinedSyncResult for each joined room.
     "invited",  # InvitedSyncResult for each invited room.
     "archived",  # ArchivedSyncResult for each archived room.
+    "to_device",  # List of direct messages for the device.
 ])):
     __slots__ = []
 
@@ -126,7 +128,8 @@ class SyncResult(collections.namedtuple("SyncResult", [
             self.joined or
             self.invited or
             self.archived or
-            self.account_data
+            self.account_data or
+            self.to_device
         )
 
 
@@ -139,6 +142,7 @@ class SyncHandler(object):
         self.event_sources = hs.get_event_sources()
         self.clock = hs.get_clock()
         self.response_cache = ResponseCache(hs)
+        self.state = hs.get_state_handler()
 
     def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0,
                                full_state=False):
@@ -355,11 +359,11 @@ class SyncHandler(object):
         Returns:
             A Deferred map from ((type, state_key)->Event)
         """
-        state = yield self.store.get_state_for_event(event.event_id)
+        state_ids = yield self.store.get_state_ids_for_event(event.event_id)
         if event.is_state():
-            state = state.copy()
-            state[(event.type, event.state_key)] = event
-        defer.returnValue(state)
+            state_ids = state_ids.copy()
+            state_ids[(event.type, event.state_key)] = event.event_id
+        defer.returnValue(state_ids)
 
     @defer.inlineCallbacks
     def get_state_at(self, room_id, stream_position):
@@ -412,57 +416,61 @@ class SyncHandler(object):
         with Measure(self.clock, "compute_state_delta"):
             if full_state:
                 if batch:
-                    current_state = yield self.store.get_state_for_event(
+                    current_state_ids = yield self.store.get_state_ids_for_event(
                         batch.events[-1].event_id
                     )
 
-                    state = yield self.store.get_state_for_event(
+                    state_ids = yield self.store.get_state_ids_for_event(
                         batch.events[0].event_id
                     )
                 else:
-                    current_state = yield self.get_state_at(
+                    current_state_ids = yield self.get_state_at(
                         room_id, stream_position=now_token
                     )
 
-                    state = current_state
+                    state_ids = current_state_ids
 
                 timeline_state = {
-                    (event.type, event.state_key): event
+                    (event.type, event.state_key): event.event_id
                     for event in batch.events if event.is_state()
                 }
 
-                state = _calculate_state(
+                state_ids = _calculate_state(
                     timeline_contains=timeline_state,
-                    timeline_start=state,
+                    timeline_start=state_ids,
                     previous={},
-                    current=current_state,
+                    current=current_state_ids,
                 )
             elif batch.limited:
                 state_at_previous_sync = yield self.get_state_at(
                     room_id, stream_position=since_token
                 )
 
-                current_state = yield self.store.get_state_for_event(
+                current_state_ids = yield self.store.get_state_ids_for_event(
                     batch.events[-1].event_id
                 )
 
-                state_at_timeline_start = yield self.store.get_state_for_event(
+                state_at_timeline_start = yield self.store.get_state_ids_for_event(
                     batch.events[0].event_id
                 )
 
                 timeline_state = {
-                    (event.type, event.state_key): event
+                    (event.type, event.state_key): event.event_id
                     for event in batch.events if event.is_state()
                 }
 
-                state = _calculate_state(
+                state_ids = _calculate_state(
                     timeline_contains=timeline_state,
                     timeline_start=state_at_timeline_start,
                     previous=state_at_previous_sync,
-                    current=current_state,
+                    current=current_state_ids,
                 )
             else:
-                state = {}
+                state_ids = {}
+
+        state = {}
+        if state_ids:
+            state = yield self.store.get_events(state_ids.values())
 
         defer.returnValue({
             (e.type, e.state_key): e
@@ -527,16 +535,58 @@ class SyncHandler(object):
             sync_result_builder, newly_joined_rooms, newly_joined_users
         )
 
+        yield self._generate_sync_entry_for_to_device(sync_result_builder)
+
         defer.returnValue(SyncResult(
             presence=sync_result_builder.presence,
             account_data=sync_result_builder.account_data,
             joined=sync_result_builder.joined,
             invited=sync_result_builder.invited,
             archived=sync_result_builder.archived,
+            to_device=sync_result_builder.to_device,
             next_batch=sync_result_builder.now_token,
         ))
 
     @defer.inlineCallbacks
+    def _generate_sync_entry_for_to_device(self, sync_result_builder):
+        """Generates the portion of the sync response. Populates
+        `sync_result_builder` with the result.
+
+        Args:
+            sync_result_builder(SyncResultBuilder)
+
+        Returns:
+            Deferred(dict): A dictionary containing the per room account data.
+        """
+        user_id = sync_result_builder.sync_config.user.to_string()
+        device_id = sync_result_builder.sync_config.device_id
+        now_token = sync_result_builder.now_token
+        since_stream_id = 0
+        if sync_result_builder.since_token is not None:
+            since_stream_id = int(sync_result_builder.since_token.to_device_key)
+
+        if since_stream_id != int(now_token.to_device_key):
+            # We only delete messages when a new message comes in, but that's
+            # fine so long as we delete them at some point.
+
+            logger.debug("Deleting messages up to %d", since_stream_id)
+            yield self.store.delete_messages_for_device(
+                user_id, device_id, since_stream_id
+            )
+
+            logger.debug("Getting messages up to %d", now_token.to_device_key)
+            messages, stream_id = yield self.store.get_new_messages_for_device(
+                user_id, device_id, since_stream_id, now_token.to_device_key
+            )
+            logger.debug("Got messages up to %d: %r", stream_id, messages)
+            sync_result_builder.now_token = now_token.copy_and_replace(
+                "to_device_key", stream_id
+            )
+            sync_result_builder.to_device = messages
+        else:
+            sync_result_builder.to_device = []
+
+    @defer.inlineCallbacks
     def _generate_sync_entry_for_account_data(self, sync_result_builder):
         """Generates the account data portion of the sync response. Populates
         `sync_result_builder` with the result.
@@ -626,7 +676,7 @@ class SyncHandler(object):
 
         extra_users_ids = set(newly_joined_users)
         for room_id in newly_joined_rooms:
-            users = yield self.store.get_users_in_room(room_id)
+            users = yield self.state.get_current_user_in_room(room_id)
             extra_users_ids.update(users)
         extra_users_ids.discard(user.to_string())
 
@@ -766,8 +816,13 @@ class SyncHandler(object):
             # the last sync (even if we have since left). This is to make sure
             # we do send down the room, and with full state, where necessary
             if room_id in joined_room_ids or has_join:
-                old_state = yield self.get_state_at(room_id, since_token)
-                old_mem_ev = old_state.get((EventTypes.Member, user_id), None)
+                old_state_ids = yield self.get_state_at(room_id, since_token)
+                old_mem_ev_id = old_state_ids.get((EventTypes.Member, user_id), None)
+                old_mem_ev = None
+                if old_mem_ev_id:
+                    old_mem_ev = yield self.store.get_event(
+                        old_mem_ev_id, allow_none=True
+                    )
                 if not old_mem_ev or old_mem_ev.membership != Membership.JOIN:
                     newly_joined_rooms.append(room_id)
 
@@ -1059,27 +1114,25 @@ def _calculate_state(timeline_contains, timeline_start, previous, current):
     Returns:
         dict
     """
-    event_id_to_state = {
-        e.event_id: e
-        for e in itertools.chain(
-            timeline_contains.values(),
-            previous.values(),
-            timeline_start.values(),
-            current.values(),
+    event_id_to_key = {
+        e: key
+        for key, e in itertools.chain(
+            timeline_contains.items(),
+            previous.items(),
+            timeline_start.items(),
+            current.items(),
         )
     }
 
-    c_ids = set(e.event_id for e in current.values())
-    tc_ids = set(e.event_id for e in timeline_contains.values())
-    p_ids = set(e.event_id for e in previous.values())
-    ts_ids = set(e.event_id for e in timeline_start.values())
+    c_ids = set(e for e in current.values())
+    tc_ids = set(e for e in timeline_contains.values())
+    p_ids = set(e for e in previous.values())
+    ts_ids = set(e for e in timeline_start.values())
 
     state_ids = ((c_ids | ts_ids) - p_ids) - tc_ids
 
-    evs = (event_id_to_state[e] for e in state_ids)
     return {
-        (e.type, e.state_key): e
-        for e in evs
+        event_id_to_key[e]: e for e in state_ids
     }
 
 
@@ -1103,6 +1156,7 @@ class SyncResultBuilder(object):
         self.joined = []
         self.invited = []
         self.archived = []
+        self.device = []
 
 
 class RoomSyncResultBuilder(object):