summary refs log tree commit diff
path: root/synapse/streams/events.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/streams/events.py')
-rw-r--r--synapse/streams/events.py73
1 files changed, 48 insertions, 25 deletions
diff --git a/synapse/streams/events.py b/synapse/streams/events.py
index 27c7734b36..36174a811b 100644
--- a/synapse/streams/events.py
+++ b/synapse/streams/events.py
@@ -26,16 +26,7 @@ class RoomEventSource(object):
         self.store = hs.get_datastore()
 
     @defer.inlineCallbacks
-    def get_keys_for_user(self, user):
-        events = yield self.store.get_rooms_for_user_where_membership_is(
-            user.to_string(),
-            (Membership.JOIN,),
-        )
-
-        defer.returnValue(set([e.room_id for e in events]))
-
-    @defer.inlineCallbacks
-    def get_new_events_for_user(self, user, from_token, limit, key=None):
+    def get_new_events_for_user(self, user, from_token, limit):
         # We just ignore the key for now.
 
         to_key = yield self.get_current_token_part()
@@ -56,7 +47,7 @@ class RoomEventSource(object):
         return self.store.get_room_events_max_id()
 
     @defer.inlineCallbacks
-    def get_pagination_rows(self, from_token, to_token, limit, key):
+    def get_pagination_rows(self, user, from_token, to_token, limit, key):
         to_key = to_token.events_key if to_token else None
 
         events, next_key = yield self.store.paginate_room_events(
@@ -73,14 +64,14 @@ class RoomEventSource(object):
         defer.returnValue((events, next_token))
 
 
-class PresenceStreamSource(object):
-    SIGNAL_NAME = "PresenceStreamSource"
+class PresenceSource(object):
+    SIGNAL_NAME = "PresenceSource"
 
     def __init__(self, hs):
         self.hs = hs
         self.clock = hs.get_clock()
 
-    def get_new_events_for_user(self, user, from_token, limit, key=None):
+    def get_new_events_for_user(self, user, from_token, limit):
         from_key = int(from_token.presence_key)
 
         presence = self.hs.get_handlers().presence_handler
@@ -97,7 +88,7 @@ class PresenceStreamSource(object):
             data = [x[1].make_event(user=x[0], clock=clock) for x in updates]
 
             end_token = from_token.copy_and_replace(
-                "presence_key", latest_serial + 1
+                "presence_key", latest_serial
             )
             return ((data, end_token))
         else:
@@ -106,18 +97,52 @@ class PresenceStreamSource(object):
             )
             return (([], end_token))
 
-    def get_keys_for_user(self, user):
-        return defer.succeed(["moose"])
-
     def get_current_token_part(self):
         presence = self.hs.get_handlers().presence_handler
         return presence._user_cachemap_latest_serial
 
+    def get_pagination_rows(self, user, from_token, to_token, limit, key):
+        from_key = int(from_token.presence_key)
+
+        if to_token:
+            to_key = int(to_token.presence_key)
+        else:
+            to_key = -1
+
+        presence = self.hs.get_handlers().presence_handler
+        cachemap = presence._user_cachemap
+
+        # TODO(paul): limit, and filter by visibility
+        updates = [(k, cachemap[k]) for k in cachemap
+                   if to_key < cachemap[k].serial < from_key]
+
+        if updates:
+            clock = self.clock
+
+            earliest_serial = max([x[1].serial for x in updates])
+            data = [x[1].make_event(user=x[0], clock=clock) for x in updates]
+
+            if to_token:
+                next_token = to_token
+            else:
+                next_token = from_token
+
+            next_token = next_token.copy_and_replace(
+                "presence_key", earliest_serial
+            )
+            return ((data, next_token))
+        else:
+            if not to_token:
+                to_token = from_token.copy_and_replace(
+                    "presence_key", 0
+                )
+            return (([], to_token))
+
 
 class EventSources(object):
     SOURCE_TYPES = [
         RoomEventSource,
-        PresenceStreamSource,
+        PresenceSource,
     ]
 
     def __init__(self, hs):
@@ -130,15 +155,13 @@ class EventSources(object):
     @defer.inlineCallbacks
     def get_current_token(self):
         events_key = yield self.sources[0].get_current_token_part()
-        token = EventSources.create_token(events_key, "0")
+        presence_key = yield self.sources[1].get_current_token_part()
+        token = EventSources.create_token(events_key, presence_key)
         defer.returnValue(token)
 
 
 class StreamSource(object):
-    def get_keys_for_user(self, user):
-        raise NotImplementedError("get_keys_for_user")
-
-    def get_new_events_for_user(self, user, from_token, limit, key=None):
+    def get_new_events_for_user(self, user, from_token, limit):
         raise NotImplementedError("get_new_events_for_user")
 
     def get_current_token_part(self):
@@ -146,6 +169,6 @@ class StreamSource(object):
 
 
 class PaginationSource(object):
-    def get_pagination_rows(self, from_token, to_token, limit, key):
+    def get_pagination_rows(self, user, from_token, to_token, limit, key):
         raise NotImplementedError("get_rows")