diff options
Diffstat (limited to '')
-rw-r--r-- | synapse/handlers/events.py | 7 | ||||
-rw-r--r-- | synapse/handlers/presence.py | 7 | ||||
-rw-r--r-- | synapse/handlers/room.py | 15 | ||||
-rw-r--r-- | synapse/notifier.py | 85 | ||||
-rw-r--r-- | synapse/streams/events.py | 73 |
5 files changed, 107 insertions, 80 deletions
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 8c34776245..aabec37fc0 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -50,7 +50,12 @@ class EventStreamHandler(BaseHandler): if pagin_config.from_token is None: pagin_config.from_token = None - events, tokens = yield self.notifier.get_events_for(auth_user, pagin_config, timeout) + rm_handler = self.hs.get_handlers().room_member_handler + room_ids = yield rm_handler.get_rooms_for_user(auth_user) + + events, tokens = yield self.notifier.get_events_for( + auth_user, room_ids, pagin_config, timeout + ) chunks = [ e.get_dict() if isinstance(e, SynapseEvent) else e diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 8408266da0..9a690258de 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -676,12 +676,7 @@ class PresenceHandler(BaseHandler): statuscache.make_event(user=observed_user, clock=self.clock) self.notifier.on_new_user_event( - observer_user.to_string(), - event_data=statuscache.make_event( - user=observed_user, - clock=self.clock - ), - store_id=statuscache.serial + [observer_user], ) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 6fbe84ea40..19ade10a91 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -120,8 +120,11 @@ class MessageHandler(BaseHandler): else: from_token = yield self.hs.get_event_sources().get_current_token() + user = self.hs.parse_userid(user_id) + events, next_token = yield data_source.get_pagination_rows( - from_token, pagin_config.to_token, pagin_config.limit, room_id + user, from_token, pagin_config.to_token, pagin_config.limit, + room_id ) chunk = { @@ -265,6 +268,8 @@ class MessageHandler(BaseHandler): membership_list=[Membership.INVITE, Membership.JOIN] ) + user = self.hs.parse_userid(user_id) + rooms_ret = [] # FIXME (erikj): We need to not generate this token, @@ -272,8 +277,8 @@ class MessageHandler(BaseHandler): # FIXME (erikj): Fix this. presence_stream = self.hs.get_event_sources().sources[1] - presence = yield presence_stream.get_new_events_for_user( - user_id, now_token, None, None + presence, _ = yield presence_stream.get_pagination_rows( + user, now_token, None, None, None ) limit = pagin_config.limit @@ -297,7 +302,7 @@ class MessageHandler(BaseHandler): messages, token = yield self.store.get_recent_events_for_room( event.room_id, limit=limit, - end_token=now_token.events_key.to_string(), + end_token=now_token.events_key, ) d["messages"] = { @@ -311,7 +316,7 @@ class MessageHandler(BaseHandler): except: logger.exception("Failed to get snapshot") - ret = {"rooms": rooms_ret, "presence": presence[0], "end": now_token.to_string()} + ret = {"rooms": rooms_ret, "presence": presence, "end": now_token.to_string()} defer.returnValue(ret) diff --git a/synapse/notifier.py b/synapse/notifier.py index df9be29f3d..a69d5343cb 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -24,14 +24,14 @@ logger = logging.getLogger(__name__) class _NotificationListener(object): - def __init__(self, user, from_token, limit, timeout, deferred): + def __init__(self, user, rooms, from_token, limit, timeout, deferred): self.user = user self.from_token = from_token self.limit = limit self.timeout = timeout self.deferred = deferred - self.signal_key_list = [] + self.rooms = rooms self.pending_notifications = [] @@ -43,36 +43,39 @@ class _NotificationListener(object): except defer.AlreadyCalledError: pass - for signal, key in self.signal_key_list: - lst = notifier.signal_keys_to_users.get((signal, key), []) + for room in self.rooms: + lst = notifier.rooms_to_listeners.get(room, set()) + lst.discard(self) + + notifier.user_to_listeners.get(self.user, set()).discard(self) - try: - lst.remove(self) - except: - pass class Notifier(object): def __init__(self, hs): self.hs = hs - self.signal_keys_to_users = {} + self.rooms_to_listeners = {} + self.user_to_listeners = {} self.event_sources = hs.get_event_sources() + hs.get_distributor().observe( + "user_joined_room", self._user_joined_room + ) + @log_function @defer.inlineCallbacks - def on_new_room_event(self, event, store_id): + def on_new_room_event(self, event, extra_users=[]): room_id = event.room_id source = self.event_sources.sources[0] - listeners = self.signal_keys_to_users.get( - (source.SIGNAL_NAME, room_id), - [] - ) + listeners = self.rooms_to_listeners.get(room_id, set()).copy() + + for user in extra_users: + listeners |= self.user_to_listeners.get(user, set()).copy() - logger.debug("on_new_room_event self.signal_keys_to_users %s", listeners) logger.debug("on_new_room_event listeners %s", listeners) # TODO (erikj): Can we make this more efficient by hitting the @@ -82,7 +85,6 @@ class Notifier(object): listener.user, listener.from_token, listener.limit, - key=room_id, ) if events: @@ -90,20 +92,23 @@ class Notifier(object): self, events, listener.from_token, end_token ) - def on_new_user_event(self, *args, **kwargs): + @defer.inlineCallbacks + def on_new_user_event(self, users=[], rooms=[]): source = self.event_sources.sources[1] - listeners = self.signal_keys_to_users.get( - (source.SIGNAL_NAME, "moose"), - [] - ) + listeners = set() + + for user in users: + listeners |= self.user_to_listeners.get(user, set()).copy() + + for room in rooms: + listeners |= self.rooms_to_listeners.get(room, set()).copy() for listener in listeners: events, end_token = yield source.get_new_events_for_user( listener.user, listener.from_token, listener.limit, - key="moose", ) if events: @@ -111,23 +116,24 @@ class Notifier(object): self, events, listener.from_token, end_token ) - def get_events_for(self, user, pagination_config, timeout): + def get_events_for(self, user, rooms, pagination_config, timeout): deferred = defer.Deferred() self._get_events( - deferred, user, pagination_config.from_token, + deferred, user, rooms, pagination_config.from_token, pagination_config.limit, timeout ).addErrback(deferred.errback) return deferred @defer.inlineCallbacks - def _get_events(self, deferred, user, from_token, limit, timeout): + def _get_events(self, deferred, user, rooms, from_token, limit, timeout): if not from_token: from_token = yield self.event_sources.get_current_token() listener = _NotificationListener( user, + rooms, from_token, limit, timeout, @@ -137,7 +143,7 @@ class Notifier(object): if timeout: reactor.callLater(timeout/1000, self._timeout_listener, listener) - yield self._register_with_keys(listener) + self._register_with_keys(listener) yield self._check_for_updates(listener) return @@ -152,25 +158,13 @@ class Notifier(object): listener.from_token, ) - @defer.inlineCallbacks @log_function def _register_with_keys(self, listener): - signals_keys = {} - - # TODO (erikj): This can probably be replaced by a DeferredList - for source in self.event_sources.sources: - keys = yield source.get_keys_for_user(listener.user) - signals_keys.setdefault(source.SIGNAL_NAME, []).extend(keys) + for room in listener.rooms: + s = self.rooms_to_listeners.setdefault(room, set()) + s.add(listener) - for signal, keys in signals_keys.items(): - for key in keys: - s = self.signal_keys_to_users.setdefault((signal, key), []) - s.append(listener) - listener.signal_key_list.append((signal, key)) - - logger.debug("New signal_keys_to_users: %s", self.signal_keys_to_users) - - defer.returnValue(listener) + self.user_to_listeners.setdefault(listener.user, set()).add(listener) @defer.inlineCallbacks @log_function @@ -195,8 +189,13 @@ class Notifier(object): end_token = from_token - if events: listener.notify(self, events, listener.from_token, end_token) defer.returnValue(listener) + + def _user_joined_room(self, user, room_id): + new_listeners = self.user_to_listeners.get(user, set()) + + listeners = self.rooms_to_listeners.setdefault(room_id, set()) + listeners |= new_listeners diff --git a/synapse/streams/events.py b/synapse/streams/events.py index 27c7734b36..36174a811b 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -26,16 +26,7 @@ class RoomEventSource(object): self.store = hs.get_datastore() @defer.inlineCallbacks - def get_keys_for_user(self, user): - events = yield self.store.get_rooms_for_user_where_membership_is( - user.to_string(), - (Membership.JOIN,), - ) - - defer.returnValue(set([e.room_id for e in events])) - - @defer.inlineCallbacks - def get_new_events_for_user(self, user, from_token, limit, key=None): + def get_new_events_for_user(self, user, from_token, limit): # We just ignore the key for now. to_key = yield self.get_current_token_part() @@ -56,7 +47,7 @@ class RoomEventSource(object): return self.store.get_room_events_max_id() @defer.inlineCallbacks - def get_pagination_rows(self, from_token, to_token, limit, key): + def get_pagination_rows(self, user, from_token, to_token, limit, key): to_key = to_token.events_key if to_token else None events, next_key = yield self.store.paginate_room_events( @@ -73,14 +64,14 @@ class RoomEventSource(object): defer.returnValue((events, next_token)) -class PresenceStreamSource(object): - SIGNAL_NAME = "PresenceStreamSource" +class PresenceSource(object): + SIGNAL_NAME = "PresenceSource" def __init__(self, hs): self.hs = hs self.clock = hs.get_clock() - def get_new_events_for_user(self, user, from_token, limit, key=None): + def get_new_events_for_user(self, user, from_token, limit): from_key = int(from_token.presence_key) presence = self.hs.get_handlers().presence_handler @@ -97,7 +88,7 @@ class PresenceStreamSource(object): data = [x[1].make_event(user=x[0], clock=clock) for x in updates] end_token = from_token.copy_and_replace( - "presence_key", latest_serial + 1 + "presence_key", latest_serial ) return ((data, end_token)) else: @@ -106,18 +97,52 @@ class PresenceStreamSource(object): ) return (([], end_token)) - def get_keys_for_user(self, user): - return defer.succeed(["moose"]) - def get_current_token_part(self): presence = self.hs.get_handlers().presence_handler return presence._user_cachemap_latest_serial + def get_pagination_rows(self, user, from_token, to_token, limit, key): + from_key = int(from_token.presence_key) + + if to_token: + to_key = int(to_token.presence_key) + else: + to_key = -1 + + presence = self.hs.get_handlers().presence_handler + cachemap = presence._user_cachemap + + # TODO(paul): limit, and filter by visibility + updates = [(k, cachemap[k]) for k in cachemap + if to_key < cachemap[k].serial < from_key] + + if updates: + clock = self.clock + + earliest_serial = max([x[1].serial for x in updates]) + data = [x[1].make_event(user=x[0], clock=clock) for x in updates] + + if to_token: + next_token = to_token + else: + next_token = from_token + + next_token = next_token.copy_and_replace( + "presence_key", earliest_serial + ) + return ((data, next_token)) + else: + if not to_token: + to_token = from_token.copy_and_replace( + "presence_key", 0 + ) + return (([], to_token)) + class EventSources(object): SOURCE_TYPES = [ RoomEventSource, - PresenceStreamSource, + PresenceSource, ] def __init__(self, hs): @@ -130,15 +155,13 @@ class EventSources(object): @defer.inlineCallbacks def get_current_token(self): events_key = yield self.sources[0].get_current_token_part() - token = EventSources.create_token(events_key, "0") + presence_key = yield self.sources[1].get_current_token_part() + token = EventSources.create_token(events_key, presence_key) defer.returnValue(token) class StreamSource(object): - def get_keys_for_user(self, user): - raise NotImplementedError("get_keys_for_user") - - def get_new_events_for_user(self, user, from_token, limit, key=None): + def get_new_events_for_user(self, user, from_token, limit): raise NotImplementedError("get_new_events_for_user") def get_current_token_part(self): @@ -146,6 +169,6 @@ class StreamSource(object): class PaginationSource(object): - def get_pagination_rows(self, from_token, to_token, limit, key): + def get_pagination_rows(self, user, from_token, to_token, limit, key): raise NotImplementedError("get_rows") |