summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/handlers/_base.py7
-rw-r--r--synapse/handlers/federation.py25
-rw-r--r--synapse/handlers/presence.py4
-rw-r--r--synapse/handlers/typing.py4
-rw-r--r--synapse/notifier.py75
-rw-r--r--synapse/storage/events.py3
-rw-r--r--synapse/types.py19
-rw-r--r--tests/handlers/test_federation.py4
-rw-r--r--tests/handlers/test_room.py8
-rw-r--r--tests/handlers/test_typing.py12
-rw-r--r--tests/rest/client/v1/test_presence.py15
-rw-r--r--tests/utils.py2
12 files changed, 123 insertions, 55 deletions
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index ddc5c21e7d..833ff41377 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -105,7 +105,9 @@ class BaseHandler(object):
         if not suppress_auth:
             self.auth.check(event, auth_events=context.current_state)
 
-        yield self.store.persist_event(event, context=context)
+        (event_stream_id, max_stream_id) = yield self.store.persist_event(
+            event, context=context
+        )
 
         federation_handler = self.hs.get_handlers().federation_handler
 
@@ -142,7 +144,8 @@ class BaseHandler(object):
         with PreserveLoggingContext():
             # Don't block waiting on waking up all the listeners.
             notify_d = self.notifier.on_new_room_event(
-                event, extra_users=extra_users
+                event, event_stream_id, max_stream_id,
+                extra_users=extra_users
             )
 
         def log_failure(f):
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 7d9906039e..bc0f7b0ee7 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -160,7 +160,7 @@ class FederationHandler(BaseHandler):
                     )
 
         try:
-            yield self._handle_new_event(
+            _, event_stream_id, max_stream_id = yield self._handle_new_event(
                 origin,
                 event,
                 state=state,
@@ -203,7 +203,8 @@ class FederationHandler(BaseHandler):
 
             with PreserveLoggingContext():
                 d = self.notifier.on_new_room_event(
-                    event, extra_users=extra_users
+                    event, event_stream_id, max_stream_id,
+                    extra_users=extra_users
                 )
 
             def log_failure(f):
@@ -561,7 +562,7 @@ class FederationHandler(BaseHandler):
                 if e.event_id in auth_ids
             }
 
-            yield self._handle_new_event(
+            _, event_stream_id, max_stream_id = yield self._handle_new_event(
                 origin,
                 new_event,
                 state=state,
@@ -571,7 +572,8 @@ class FederationHandler(BaseHandler):
 
             with PreserveLoggingContext():
                 d = self.notifier.on_new_room_event(
-                    new_event, extra_users=[joinee]
+                    new_event, event_stream_id, max_stream_id,
+                    extra_users=[joinee]
                 )
 
             def log_failure(f):
@@ -637,7 +639,9 @@ class FederationHandler(BaseHandler):
 
         event.internal_metadata.outlier = False
 
-        context = yield self._handle_new_event(origin, event)
+        context, event_stream_id, max_stream_id = yield self._handle_new_event(
+            origin, event
+        )
 
         logger.debug(
             "on_send_join_request: After _handle_new_event: %s, sigs: %s",
@@ -653,7 +657,7 @@ class FederationHandler(BaseHandler):
 
         with PreserveLoggingContext():
             d = self.notifier.on_new_room_event(
-                event, extra_users=extra_users
+                event, event_stream_id, max_stream_id, extra_users=extra_users
             )
 
         def log_failure(f):
@@ -727,7 +731,7 @@ class FederationHandler(BaseHandler):
 
         context = yield self.state_handler.compute_event_context(event)
 
-        yield self.store.persist_event(
+        event_stream_id, max_stream_id = yield self.store.persist_event(
             event,
             context=context,
             backfilled=False,
@@ -736,7 +740,8 @@ class FederationHandler(BaseHandler):
         target_user = UserID.from_string(event.state_key)
         with PreserveLoggingContext():
             d = self.notifier.on_new_room_event(
-                event, extra_users=[target_user],
+                event, event_stream_id, max_stream_id,
+                extra_users=[target_user],
             )
 
         def log_failure(f):
@@ -914,7 +919,7 @@ class FederationHandler(BaseHandler):
             )
             raise
 
-        yield self.store.persist_event(
+        event_stream_id, max_stream_id = yield self.store.persist_event(
             event,
             context=context,
             backfilled=backfilled,
@@ -922,7 +927,7 @@ class FederationHandler(BaseHandler):
             current_state=current_state,
         )
 
-        defer.returnValue(context)
+        defer.returnValue((context, event_stream_id, max_stream_id))
 
     @defer.inlineCallbacks
     def on_query_auth(self, origin, event_id, remote_auth_chain, rejects,
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 28688d532d..7db4b062d2 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -345,6 +345,8 @@ class PresenceHandler(BaseHandler):
         curr_users = yield rm_handler.get_room_members(room_id)
 
         for local_user in [c for c in curr_users if self.hs.is_mine(c)]:
+            statuscache = self._get_or_offline_usercache(local_user)
+            statuscache.update({}, serial=self._user_cachemap_latest_serial)
             self.push_update_to_local_and_remote(
                 observed_user=local_user,
                 users_to_push=[user],
@@ -820,6 +822,8 @@ class PresenceHandler(BaseHandler):
                                room_ids=[], statuscache=None):
         with PreserveLoggingContext():
             self.notifier.on_new_user_event(
+                "presence_key",
+                self._user_cachemap_latest_serial,
                 users_to_push,
                 room_ids,
             )
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 64fe51aa3e..a9895292c2 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -218,7 +218,9 @@ class TypingNotificationHandler(BaseHandler):
         self._room_serials[room_id] = self._latest_room_serial
 
         with PreserveLoggingContext():
-            self.notifier.on_new_user_event(rooms=[room_id])
+            self.notifier.on_new_user_event(
+                "typing_key", self._latest_room_serial, rooms=[room_id]
+            )
 
 
 class TypingNotificationEventSource(object):
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 214a2b28ca..4d10c05038 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -52,12 +52,11 @@ class _NotificationListener(object):
     def notified(self):
         return self.deferred.called
 
-    def notify(self):
+    def notify(self, token):
         """ Inform whoever is listening about the new events.
         """
-
         try:
-            self.deferred.callback(None)
+            self.deferred.callback(token)
         except defer.AlreadyCalledError:
             pass
 
@@ -73,15 +72,18 @@ class _NotifierUserStream(object):
     """
 
     def __init__(self, user, rooms, current_token, appservice=None):
-        self.user = user
+        self.user = str(user)
         self.appservice = appservice
         self.listeners = set()
-        self.rooms = rooms
+        self.rooms = set(rooms)
         self.current_token = current_token
 
-    def notify(self, new_token):
+    def notify(self, stream_key, stream_id):
+        self.current_token = self.current_token.copy_and_replace(
+            stream_key, stream_id
+        )
         for listener in self.listeners:
-            listener.notify(new_token)
+            listener.notify(self.current_token)
         self.listeners.clear()
 
     def remove(self, notifier):
@@ -117,6 +119,7 @@ class Notifier(object):
 
         self.event_sources = hs.get_event_sources()
         self.store = hs.get_datastore()
+        self.pending_new_room_events = []
 
         self.clock = hs.get_clock()
 
@@ -153,9 +156,21 @@ class Notifier(object):
             lambda: count(bool, self.appservice_to_user_streams.values()),
         )
 
+    def notify_pending_new_room_events(self, max_room_stream_id):
+        pending = sorted(self.pending_new_room_events)
+        self.pending_new_room_events = []
+        for event, room_stream_id, extra_users in pending:
+            if room_stream_id > max_room_stream_id:
+                self.pending_new_room_events.append((
+                    event, room_stream_id, extra_users
+                ))
+            else:
+                self._on_new_room_event(event, room_stream_id, extra_users)
+
     @log_function
     @defer.inlineCallbacks
-    def on_new_room_event(self, event, new_token, extra_users=[]):
+    def on_new_room_event(self, event, room_stream_id, max_room_stream_id,
+                          extra_users=[]):
         """ Used by handlers to inform the notifier something has happened
         in the room, room event wise.
 
@@ -163,8 +178,18 @@ class Notifier(object):
         listening to the room, and any listeners for the users in the
         `extra_users` param.
         """
-        assert isinstance(new_token, StreamToken)
         yield run_on_reactor()
+
+        self.notify_pending_new_room_events(max_room_stream_id)
+
+        if room_stream_id > max_room_stream_id:
+            self.pending_new_room_events.append((
+                event, room_stream_id, extra_users
+            ))
+        else:
+            self._on_new_room_event(event, room_stream_id, extra_users)
+
+    def _on_new_room_event(self, event, room_stream_id, extra_users=[]):
         # poke any interested application service.
         self.hs.get_handlers().appservice_handler.notify_interested_services(
             event
@@ -197,33 +222,32 @@ class Notifier(object):
 
         for user_stream in user_streams:
             try:
-                user_stream.notify(new_token)
+                user_stream.notify("room_key", "s%d" % (room_stream_id,))
             except:
                 logger.exception("Failed to notify listener")
 
     @defer.inlineCallbacks
     @log_function
-    def on_new_user_event(self, new_token, users=[], rooms=[]):
+    def on_new_user_event(self, stream_key, new_token, users=[], rooms=[]):
         """ Used to inform listeners that something has happend
         presence/user event wise.
 
         Will wake up all listeners for the given users and rooms.
         """
-        assert isinstance(new_token, StreamToken)
         yield run_on_reactor()
         user_streams = set()
 
         for user in users:
             user_stream = self.user_to_user_stream.get(user)
-            if user_stream:
-                user_stream.add(user_stream)
+            if user_stream is not None:
+                user_streams.add(user_stream)
 
         for room in rooms:
             user_streams |= self.room_to_user_streams.get(room, set())
 
         for user_stream in user_streams:
             try:
-                user_streams.notify(new_token)
+                user_stream.notify(stream_key, new_token)
             except:
                 logger.exception("Failed to notify listener")
 
@@ -236,12 +260,12 @@ class Notifier(object):
 
         deferred = defer.Deferred()
 
-        user_stream = self.user_to_user_streams.get(user)
+        user = str(user)
+        user_stream = self.user_to_user_stream.get(user)
         if user_stream is None:
-            appservice = yield self.store.get_app_service_by_user_id(
-                user.to_string()
-            )
+            appservice = yield self.store.get_app_service_by_user_id(user)
             current_token = yield self.event_sources.get_current_token()
+            rooms = yield self.store.get_rooms_for_user(user)
             user_stream = _NotifierUserStream(
                 user=user,
                 rooms=rooms,
@@ -252,8 +276,9 @@ class Notifier(object):
         else:
             current_token = user_stream.current_token
 
+        listener = [_NotificationListener(deferred)]
+
         if timeout and not current_token.is_after(from_token):
-            listener = [_NotificationListener(deferred)]
             user_stream.listeners.add(listener[0])
 
         if current_token.is_after(from_token):
@@ -334,7 +359,7 @@ class Notifier(object):
         self.user_to_user_stream[user_stream.user] = user_stream
 
         for room in user_stream.rooms:
-            s = self.room_to_user_stream.setdefault(room, set())
+            s = self.room_to_user_streams.setdefault(room, set())
             s.add(user_stream)
 
         if user_stream.appservice:
@@ -343,10 +368,12 @@ class Notifier(object):
             ).add(user_stream)
 
     def _user_joined_room(self, user, room_id):
+        user = str(user)
         new_user_stream = self.user_to_user_stream.get(user)
-        room_streams = self.room_to_user_streams.setdefault(room_id, set())
-        room_streams.add(new_user_stream)
-        new_user_stream.rooms.add(room_id)
+        if new_user_stream is not None:
+            room_streams = self.room_to_user_streams.setdefault(room_id, set())
+            room_streams.add(new_user_stream)
+            new_user_stream.rooms.add(room_id)
 
 
 def _discard_if_notified(listener_set):
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index a5a6869079..7d6df5f4c6 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -64,6 +64,9 @@ class EventsStore(SQLBaseStore):
         except _RollbackButIsFineException:
             pass
 
+        max_persisted_id = yield self._stream_id_gen.get_max_token(self)
+        defer.returnValue((stream_ordering, max_persisted_id))
+
     @defer.inlineCallbacks
     def get_event(self, event_id, check_redacted=True,
                   get_prev_content=False, allow_rejected=False,
diff --git a/synapse/types.py b/synapse/types.py
index 0f16867d75..d89a04f7c3 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -70,6 +70,8 @@ class DomainSpecificString(
         """Return a string encoding the fields of the structure object."""
         return "%s%s:%s" % (self.SIGIL, self.localpart, self.domain)
 
+    __str__ = to_string
+
     @classmethod
     def create(cls, localpart, domain,):
         return cls(localpart=localpart, domain=domain)
@@ -107,7 +109,6 @@ class StreamToken(
     def from_string(cls, string):
         try:
             keys = string.split(cls._SEPARATOR)
-
             return cls(*keys)
         except:
             raise SynapseError(400, "Invalid Token")
@@ -115,6 +116,22 @@ class StreamToken(
     def to_string(self):
         return self._SEPARATOR.join([str(k) for k in self])
 
+    @property
+    def room_stream_id(self):
+        # TODO(markjh): Awful hack to work around hacks in the presence tests
+        if type(self.room_key) is int:
+            return self.room_key
+        else:
+            return int(self.room_key[1:].split("-")[-1])
+
+    def is_after(self, other_token):
+        """Does this token contain events that the other doesn't?"""
+        return (
+            (other_token.room_stream_id < self.room_stream_id)
+            or (int(other_token.presence_key) < int(self.presence_key))
+            or (int(other_token.typing_key) < int(self.typing_key))
+        )
+
     def copy_and_replace(self, key, new_value):
         d = self._asdict()
         d[key] = new_value
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index 08d2404b6c..f3821242bc 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -83,7 +83,7 @@ class FederationTestCase(unittest.TestCase):
             "hashes": {"sha256":"AcLrgtUIqqwaGoHhrEvYG1YLDIsVPYJdSRGhkp3jJp8"},
         })
 
-        self.datastore.persist_event.return_value = defer.succeed(None)
+        self.datastore.persist_event.return_value = defer.succeed((1,1))
         self.datastore.get_room.return_value = defer.succeed(True)
         self.auth.check_host_in_room.return_value = defer.succeed(True)
 
@@ -126,5 +126,5 @@ class FederationTestCase(unittest.TestCase):
         self.auth.check.assert_called_once_with(ANY, auth_events={})
 
         self.notifier.on_new_room_event.assert_called_once_with(
-            ANY, extra_users=[]
+            ANY, 1, 1, extra_users=[]
         )
diff --git a/tests/handlers/test_room.py b/tests/handlers/test_room.py
index 6417f73309..a2d7635995 100644
--- a/tests/handlers/test_room.py
+++ b/tests/handlers/test_room.py
@@ -87,6 +87,8 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
         self.ratelimiter = hs.get_ratelimiter()
         self.ratelimiter.send_message.return_value = (True, 0)
 
+        self.datastore.persist_event.return_value = (1,1)
+
     @defer.inlineCallbacks
     def test_invite(self):
         room_id = "!foo:red"
@@ -160,7 +162,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
             event, context=context,
         )
         self.notifier.on_new_room_event.assert_called_once_with(
-            event, extra_users=[UserID.from_string(target_user_id)]
+            event, 1, 1, extra_users=[UserID.from_string(target_user_id)]
         )
         self.assertFalse(self.datastore.get_room.called)
         self.assertFalse(self.datastore.store_room.called)
@@ -226,7 +228,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
             event, context=context
         )
         self.notifier.on_new_room_event.assert_called_once_with(
-            event, extra_users=[user]
+            event, 1, 1, extra_users=[user]
         )
 
         join_signal_observer.assert_called_with(
@@ -304,7 +306,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
             event, context=context
         )
         self.notifier.on_new_room_event.assert_called_once_with(
-            event, extra_users=[user]
+            event, 1, 1, extra_users=[user]
         )
 
         leave_signal_observer.assert_called_with(
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index b318d4944a..7ccbe2ea9c 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -183,7 +183,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
         )
 
         self.on_new_user_event.assert_has_calls([
-            call(rooms=[self.room_id]),
+            call('typing_key', 1, rooms=[self.room_id]),
         ])
 
         self.assertEquals(self.event_source.get_current_key(), 1)
@@ -246,7 +246,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
         )
 
         self.on_new_user_event.assert_has_calls([
-            call(rooms=[self.room_id]),
+            call('typing_key', 1, rooms=[self.room_id]),
         ])
 
         self.assertEquals(self.event_source.get_current_key(), 1)
@@ -300,7 +300,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
         )
 
         self.on_new_user_event.assert_has_calls([
-            call(rooms=[self.room_id]),
+            call('typing_key', 1, rooms=[self.room_id]),
         ])
 
         yield put_json.await_calls()
@@ -332,7 +332,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
         )
 
         self.on_new_user_event.assert_has_calls([
-            call(rooms=[self.room_id]),
+            call('typing_key', 1, rooms=[self.room_id]),
         ])
         self.on_new_user_event.reset_mock()
 
@@ -352,7 +352,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
         self.clock.advance_time(11)
 
         self.on_new_user_event.assert_has_calls([
-            call(rooms=[self.room_id]),
+            call('typing_key', 2, rooms=[self.room_id]),
         ])
 
         self.assertEquals(self.event_source.get_current_key(), 2)
@@ -378,7 +378,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
         )
 
         self.on_new_user_event.assert_has_calls([
-            call(rooms=[self.room_id]),
+            call('typing_key', 3, rooms=[self.room_id]),
         ])
         self.on_new_user_event.reset_mock()
 
diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py
index 8e0c5fa630..c0c52796ad 100644
--- a/tests/rest/client/v1/test_presence.py
+++ b/tests/rest/client/v1/test_presence.py
@@ -27,6 +27,7 @@ from synapse.handlers.presence import PresenceHandler
 from synapse.rest.client.v1 import presence
 from synapse.rest.client.v1 import events
 from synapse.types import UserID
+from synapse.util.async import run_on_reactor
 
 
 OFFLINE = PresenceState.OFFLINE
@@ -264,6 +265,7 @@ class PresenceEventStreamTestCase(unittest.TestCase):
             datastore=Mock(spec=[
                 "set_presence_state",
                 "get_presence_list",
+                "get_rooms_for_user",
             ]),
             clock=Mock(spec=[
                 "call_later",
@@ -298,6 +300,9 @@ class PresenceEventStreamTestCase(unittest.TestCase):
         self.mock_datastore.get_app_service_by_user_id = Mock(
             return_value=defer.succeed(None)
         )
+        self.mock_datastore.get_rooms_for_user = (
+            lambda u: get_rooms_for_user(UserID.from_string(u))
+        )
 
         def get_profile_displayname(user_id):
             return defer.succeed("Frank")
@@ -350,19 +355,19 @@ class PresenceEventStreamTestCase(unittest.TestCase):
         self.mock_datastore.set_presence_state.return_value = defer.succeed(
             {"state": ONLINE}
         )
-        self.mock_datastore.get_presence_list.return_value = defer.succeed(
-            []
-        )
+        self.mock_datastore.get_presence_list.return_value = defer.succeed([])
 
         yield self.presence.set_state(self.u_banana, self.u_banana,
             state={"presence": ONLINE}
         )
 
+        yield run_on_reactor()
+
         (code, response) = yield self.mock_resource.trigger("GET",
-                "/events?from=0_1_0&timeout=0", None)
+                "/events?from=s0_1_0&timeout=0", None)
 
         self.assertEquals(200, code)
-        self.assertEquals({"start": "0_1_0", "end": "0_2_0", "chunk": [
+        self.assertEquals({"start": "s0_1_0", "end": "s0_2_0", "chunk": [
             {"type": "m.presence",
              "content": {
                  "user_id": "@banana:test",
diff --git a/tests/utils.py b/tests/utils.py
index cc038fecf1..a67530bd63 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -355,7 +355,7 @@ class MemoryDataStore(object):
         return []
 
     def get_room_events_max_id(self):
-        return 0  # TODO (erikj)
+        return "s0"  # TODO (erikj)
 
     def get_send_event_level(self, room_id):
         return defer.succeed(0)