diff options
Diffstat (limited to 'synapse/streams/events.py')
-rw-r--r-- | synapse/streams/events.py | 73 |
1 files changed, 48 insertions, 25 deletions
diff --git a/synapse/streams/events.py b/synapse/streams/events.py index 27c7734b36..36174a811b 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -26,16 +26,7 @@ class RoomEventSource(object): self.store = hs.get_datastore() @defer.inlineCallbacks - def get_keys_for_user(self, user): - events = yield self.store.get_rooms_for_user_where_membership_is( - user.to_string(), - (Membership.JOIN,), - ) - - defer.returnValue(set([e.room_id for e in events])) - - @defer.inlineCallbacks - def get_new_events_for_user(self, user, from_token, limit, key=None): + def get_new_events_for_user(self, user, from_token, limit): # We just ignore the key for now. to_key = yield self.get_current_token_part() @@ -56,7 +47,7 @@ class RoomEventSource(object): return self.store.get_room_events_max_id() @defer.inlineCallbacks - def get_pagination_rows(self, from_token, to_token, limit, key): + def get_pagination_rows(self, user, from_token, to_token, limit, key): to_key = to_token.events_key if to_token else None events, next_key = yield self.store.paginate_room_events( @@ -73,14 +64,14 @@ class RoomEventSource(object): defer.returnValue((events, next_token)) -class PresenceStreamSource(object): - SIGNAL_NAME = "PresenceStreamSource" +class PresenceSource(object): + SIGNAL_NAME = "PresenceSource" def __init__(self, hs): self.hs = hs self.clock = hs.get_clock() - def get_new_events_for_user(self, user, from_token, limit, key=None): + def get_new_events_for_user(self, user, from_token, limit): from_key = int(from_token.presence_key) presence = self.hs.get_handlers().presence_handler @@ -97,7 +88,7 @@ class PresenceStreamSource(object): data = [x[1].make_event(user=x[0], clock=clock) for x in updates] end_token = from_token.copy_and_replace( - "presence_key", latest_serial + 1 + "presence_key", latest_serial ) return ((data, end_token)) else: @@ -106,18 +97,52 @@ class PresenceStreamSource(object): ) return (([], end_token)) - def get_keys_for_user(self, user): - return defer.succeed(["moose"]) - def get_current_token_part(self): presence = self.hs.get_handlers().presence_handler return presence._user_cachemap_latest_serial + def get_pagination_rows(self, user, from_token, to_token, limit, key): + from_key = int(from_token.presence_key) + + if to_token: + to_key = int(to_token.presence_key) + else: + to_key = -1 + + presence = self.hs.get_handlers().presence_handler + cachemap = presence._user_cachemap + + # TODO(paul): limit, and filter by visibility + updates = [(k, cachemap[k]) for k in cachemap + if to_key < cachemap[k].serial < from_key] + + if updates: + clock = self.clock + + earliest_serial = max([x[1].serial for x in updates]) + data = [x[1].make_event(user=x[0], clock=clock) for x in updates] + + if to_token: + next_token = to_token + else: + next_token = from_token + + next_token = next_token.copy_and_replace( + "presence_key", earliest_serial + ) + return ((data, next_token)) + else: + if not to_token: + to_token = from_token.copy_and_replace( + "presence_key", 0 + ) + return (([], to_token)) + class EventSources(object): SOURCE_TYPES = [ RoomEventSource, - PresenceStreamSource, + PresenceSource, ] def __init__(self, hs): @@ -130,15 +155,13 @@ class EventSources(object): @defer.inlineCallbacks def get_current_token(self): events_key = yield self.sources[0].get_current_token_part() - token = EventSources.create_token(events_key, "0") + presence_key = yield self.sources[1].get_current_token_part() + token = EventSources.create_token(events_key, presence_key) defer.returnValue(token) class StreamSource(object): - def get_keys_for_user(self, user): - raise NotImplementedError("get_keys_for_user") - - def get_new_events_for_user(self, user, from_token, limit, key=None): + def get_new_events_for_user(self, user, from_token, limit): raise NotImplementedError("get_new_events_for_user") def get_current_token_part(self): @@ -146,6 +169,6 @@ class StreamSource(object): class PaginationSource(object): - def get_pagination_rows(self, from_token, to_token, limit, key): + def get_pagination_rows(self, user, from_token, to_token, limit, key): raise NotImplementedError("get_rows") |