summary refs log tree commit diff
path: root/synapse/handlers
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/_base.py2
-rw-r--r--synapse/handlers/directory.py4
-rw-r--r--synapse/handlers/events.py2
-rw-r--r--synapse/handlers/presence.py2
-rw-r--r--synapse/handlers/register.py6
-rw-r--r--synapse/handlers/room.py11
-rw-r--r--synapse/handlers/sync.py167
7 files changed, 117 insertions, 77 deletions
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index 744a9ee507..1423df6cf3 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -147,7 +147,7 @@ class BaseHandler(object):
         )
         if not allowed:
             raise LimitExceededError(
-                retry_after_ms=int(1000*(time_allowed - time_now)),
+                retry_after_ms=int(1000 * (time_allowed - time_now)),
             )
 
     @defer.inlineCallbacks
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 691564c651..4efecb1ffd 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -175,8 +175,8 @@ class DirectoryHandler(BaseHandler):
         # If this server is in the list of servers, return it first.
         if self.server_name in servers:
             servers = (
-                [self.server_name]
-                + [s for s in servers if s != self.server_name]
+                [self.server_name] +
+                [s for s in servers if s != self.server_name]
             )
         else:
             servers = list(servers)
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index 254b483da6..5ad8f3779a 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -130,7 +130,7 @@ class EventStreamHandler(BaseHandler):
 
                 # Add some randomness to this value to try and mitigate against
                 # thundering herds on restart.
-                timeout = random.randint(int(timeout*0.9), int(timeout*1.1))
+                timeout = random.randint(int(timeout * 0.9), int(timeout * 1.1))
 
             events, tokens = yield self.notifier.get_events_for(
                 auth_user, pagin_config, timeout,
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index d36eb3b8d7..d0c21ff5c9 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -34,7 +34,7 @@ metrics = synapse.metrics.get_metrics_for(__name__)
 
 
 # Don't bother bumping "last active" time if it differs by less than 60 seconds
-LAST_ACTIVE_GRANULARITY = 60*1000
+LAST_ACTIVE_GRANULARITY = 60 * 1000
 
 # Keep no more than this number of offline serial revisions
 MAX_OFFLINE_SERIALS = 1000
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index c11b98d0b7..b8fbcf9233 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -139,7 +139,9 @@ class RegistrationHandler(BaseHandler):
                     yield self.store.register(
                         user_id=user_id,
                         token=token,
-                        password_hash=password_hash)
+                        password_hash=password_hash,
+                        make_guest=make_guest
+                    )
 
                     yield registered_user(self.distributor, user)
                 except SynapseError:
@@ -211,7 +213,7 @@ class RegistrationHandler(BaseHandler):
                 400,
                 "User ID must only contain characters which do not"
                 " require URL encoding."
-                )
+            )
         user = UserID(localpart, self.hs.hostname)
         user_id = user.to_string()
 
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 1b3f624c67..0daf2ce28e 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -18,7 +18,7 @@ from twisted.internet import defer
 
 from ._base import BaseHandler
 
-from synapse.types import UserID, RoomAlias, RoomID
+from synapse.types import UserID, RoomAlias, RoomID, RoomStreamToken
 from synapse.api.constants import (
     EventTypes, Membership, JoinRules, RoomCreationPreset,
 )
@@ -967,7 +967,7 @@ class RoomContextHandler(BaseHandler):
         Returns:
             dict, or None if the event isn't found
         """
-        before_limit = math.floor(limit/2.)
+        before_limit = math.floor(limit / 2.)
         after_limit = limit - before_limit
 
         now_token = yield self.hs.get_event_sources().get_current_token()
@@ -1037,6 +1037,11 @@ class RoomEventSource(object):
 
         to_key = yield self.get_current_key()
 
+        from_token = RoomStreamToken.parse(from_key)
+        if from_token.topological:
+            logger.warn("Stream has topological part!!!! %r", from_key)
+            from_key = "s%s" % (from_token.stream,)
+
         app_service = yield self.store.get_app_service_by_user_id(
             user.to_string()
         )
@@ -1048,7 +1053,7 @@ class RoomEventSource(object):
                 limit=limit,
             )
         else:
-            room_events = yield self.store.get_room_changes_for_user(
+            room_events = yield self.store.get_membership_changes_for_user(
                 user.to_string(), from_key, to_key
             )
 
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 075566417f..dc686db541 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -23,6 +23,7 @@ from twisted.internet import defer
 
 import collections
 import logging
+import itertools
 
 logger = logging.getLogger(__name__)
 
@@ -478,7 +479,7 @@ class SyncHandler(BaseHandler):
         )
 
         # Get a list of membership change events that have happened.
-        rooms_changed = yield self.store.get_room_changes_for_user(
+        rooms_changed = yield self.store.get_membership_changes_for_user(
             user_id, since_token.room_key, now_token.room_key
         )
 
@@ -672,35 +673,10 @@ class SyncHandler(BaseHandler):
                                            account_data_by_room,
                                            all_ephemeral_by_room,
                                            batch, full_state=False):
-        if full_state:
-            state = yield self.get_state_at(room_id, now_token)
-
-        elif batch.limited:
-            current_state = yield self.get_state_at(room_id, now_token)
-
-            state_at_previous_sync = yield self.get_state_at(
-                room_id, stream_position=since_token
-            )
-
-            state = yield self.compute_state_delta(
-                since_token=since_token,
-                previous_state=state_at_previous_sync,
-                current_state=current_state,
-            )
-        else:
-            state = {
-                (event.type, event.state_key): event
-                for event in batch.events if event.is_state()
-            }
-
-        just_joined = yield self.check_joined_room(sync_config, state)
-        if just_joined:
-            state = yield self.get_state_at(room_id, now_token)
-
-        state = {
-            (e.type, e.state_key): e
-            for e in sync_config.filter_collection.filter_room_state(state.values())
-        }
+        state = yield self.compute_state_delta(
+            room_id, batch, sync_config, since_token, now_token,
+            full_state=full_state
+        )
 
         account_data = self.account_data_for_room(
             room_id, tags_by_room, account_data_by_room
@@ -766,30 +742,11 @@ class SyncHandler(BaseHandler):
 
         logger.debug("Recents %r", batch)
 
-        state_events_at_leave = yield self.store.get_state_for_event(
-            leave_event_id
+        state_events_delta = yield self.compute_state_delta(
+            room_id, batch, sync_config, since_token, leave_token,
+            full_state=full_state
         )
 
-        if not full_state:
-            state_at_previous_sync = yield self.get_state_at(
-                room_id, stream_position=since_token
-            )
-
-            state_events_delta = yield self.compute_state_delta(
-                since_token=since_token,
-                previous_state=state_at_previous_sync,
-                current_state=state_events_at_leave,
-            )
-        else:
-            state_events_delta = state_events_at_leave
-
-        state_events_delta = {
-            (e.type, e.state_key): e
-            for e in sync_config.filter_collection.filter_room_state(
-                state_events_delta.values()
-            )
-        }
-
         account_data = self.account_data_for_room(
             room_id, tags_by_room, account_data_by_room
         )
@@ -843,15 +800,19 @@ class SyncHandler(BaseHandler):
             state = {}
         defer.returnValue(state)
 
-    def compute_state_delta(self, since_token, previous_state, current_state):
-        """ Works out the differnce in state between the current state and the
-        state the client got when it last performed a sync.
-
-        :param str since_token: the point we are comparing against
-        :param dict[(str,str), synapse.events.FrozenEvent] previous_state: the
-            state to compare to
-        :param dict[(str,str), synapse.events.FrozenEvent] current_state: the
-            new state
+    @defer.inlineCallbacks
+    def compute_state_delta(self, room_id, batch, sync_config, since_token, now_token,
+                            full_state):
+        """ Works out the differnce in state between the start of the timeline
+        and the previous sync.
+
+        :param str room_id
+        :param TimelineBatch batch: The timeline batch for the room that will
+            be sent to the user.
+        :param sync_config
+        :param str since_token: Token of the end of the previous batch. May be None.
+        :param str now_token: Token of the end of the current batch.
+        :param bool full_state: Whether to force returning the full state.
 
         :returns A new event dictionary
         """
@@ -860,12 +821,50 @@ class SyncHandler(BaseHandler):
         # updates even if they occured logically before the previous event.
         # TODO(mjark) Check for new redactions in the state events.
 
-        state_delta = {}
-        for key, event in current_state.iteritems():
-            if (key not in previous_state or
-                    previous_state[key].event_id != event.event_id):
-                state_delta[key] = event
-        return state_delta
+        if full_state:
+            if batch:
+                state = yield self.store.get_state_for_event(batch.events[0].event_id)
+            else:
+                state = yield self.get_state_at(
+                    room_id, stream_position=now_token
+                )
+
+            timeline_state = {
+                (event.type, event.state_key): event
+                for event in batch.events if event.is_state()
+            }
+
+            state = _calculate_state(
+                timeline_contains=timeline_state,
+                timeline_start=state,
+                previous={},
+            )
+        elif batch.limited:
+            state_at_previous_sync = yield self.get_state_at(
+                room_id, stream_position=since_token
+            )
+
+            state_at_timeline_start = yield self.store.get_state_for_event(
+                batch.events[0].event_id
+            )
+
+            timeline_state = {
+                (event.type, event.state_key): event
+                for event in batch.events if event.is_state()
+            }
+
+            state = _calculate_state(
+                timeline_contains=timeline_state,
+                timeline_start=state_at_timeline_start,
+                previous=state_at_previous_sync,
+            )
+        else:
+            state = {}
+
+        defer.returnValue({
+            (e.type, e.state_key): e
+            for e in sync_config.filter_collection.filter_room_state(state.values())
+        })
 
     def check_joined_room(self, sync_config, state_delta):
         """
@@ -912,3 +911,37 @@ def _action_has_highlight(actions):
             pass
 
     return False
+
+
+def _calculate_state(timeline_contains, timeline_start, previous):
+    """Works out what state to include in a sync response.
+
+    Args:
+        timeline_contains (dict): state in the timeline
+        timeline_start (dict): state at the start of the timeline
+        previous (dict): state at the end of the previous sync (or empty dict
+            if this is an initial sync)
+
+    Returns:
+        dict
+    """
+    event_id_to_state = {
+        e.event_id: e
+        for e in itertools.chain(
+            timeline_contains.values(),
+            previous.values(),
+            timeline_start.values(),
+        )
+    }
+
+    tc_ids = set(e.event_id for e in timeline_contains.values())
+    p_ids = set(e.event_id for e in previous.values())
+    ts_ids = set(e.event_id for e in timeline_start.values())
+
+    state_ids = (ts_ids - p_ids) - tc_ids
+
+    evs = (event_id_to_state[e] for e in state_ids)
+    return {
+        (e.type, e.state_key): e
+        for e in evs
+    }