diff options
Diffstat (limited to 'synapse/notifier.py')
-rw-r--r-- | synapse/notifier.py | 82 |
1 files changed, 49 insertions, 33 deletions
diff --git a/synapse/notifier.py b/synapse/notifier.py index 4e091314e6..87c120a59c 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -15,11 +15,13 @@ import logging from collections import namedtuple +from typing import Callable, Iterable, List, TypeVar from prometheus_client import Counter from twisted.internet import defer +import synapse.server from synapse.api.constants import EventTypes, Membership from synapse.api.errors import AuthError from synapse.handlers.presence import format_user_presence_state @@ -40,12 +42,14 @@ users_woken_by_stream_counter = Counter( "synapse_notifier_users_woken_by_stream", "", ["stream"] ) +T = TypeVar("T") + # TODO(paul): Should be shared somewhere -def count(func, l): - """Return the number of items in l for which func returns true.""" +def count(func: Callable[[T], bool], it: Iterable[T]) -> int: + """Return the number of items in it for which func returns true.""" n = 0 - for x in l: + for x in it: if func(x): n += 1 return n @@ -154,16 +158,22 @@ class Notifier(object): UNUSED_STREAM_EXPIRY_MS = 10 * 60 * 1000 - def __init__(self, hs): + def __init__(self, hs: "synapse.server.HomeServer"): self.user_to_user_stream = {} self.room_to_user_streams = {} self.hs = hs + self.storage = hs.get_storage() self.event_sources = hs.get_event_sources() self.store = hs.get_datastore() self.pending_new_room_events = [] - self.replication_callbacks = [] + # Called when there are new things to stream over replication + self.replication_callbacks = [] # type: List[Callable[[], None]] + + # Called when remote servers have come back online after having been + # down. + self.remote_server_up_callbacks = [] # type: List[Callable[[str], None]] self.clock = hs.get_clock() self.appservice_handler = hs.get_application_service_handler() @@ -204,7 +214,7 @@ class Notifier(object): "synapse_notifier_users", "", [], lambda: len(self.user_to_user_stream) ) - def add_replication_callback(self, cb): + def add_replication_callback(self, cb: Callable[[], None]): """Add a callback that will be called when some new data is available. Callback is not given any arguments. It should *not* return a Deferred - if it needs to do any asynchronous work, a background thread should be started and @@ -265,10 +275,9 @@ class Notifier(object): "room_key", room_stream_id, users=extra_users, rooms=[event.room_id] ) - @defer.inlineCallbacks - def _notify_app_services(self, room_stream_id): + async def _notify_app_services(self, room_stream_id): try: - yield self.appservice_handler.notify_interested_services(room_stream_id) + await self.appservice_handler.notify_interested_services(room_stream_id) except Exception: logger.exception("Error notifying application services of event") @@ -303,8 +312,7 @@ class Notifier(object): without waking up any of the normal user event streams""" self.notify_replication() - @defer.inlineCallbacks - def wait_for_events( + async def wait_for_events( self, user_id, timeout, callback, room_ids=None, from_token=StreamToken.START ): """Wait until the callback returns a non empty response or the @@ -312,9 +320,9 @@ class Notifier(object): """ user_stream = self.user_to_user_stream.get(user_id) if user_stream is None: - current_token = yield self.event_sources.get_current_token() + current_token = await self.event_sources.get_current_token() if room_ids is None: - room_ids = yield self.store.get_rooms_for_user(user_id) + room_ids = await self.store.get_rooms_for_user(user_id) user_stream = _NotifierUserStream( user_id=user_id, rooms=room_ids, @@ -343,11 +351,11 @@ class Notifier(object): self.hs.get_reactor(), ) with PreserveLoggingContext(): - yield listener.deferred + await listener.deferred current_token = user_stream.current_token - result = yield callback(prev_token, current_token) + result = await callback(prev_token, current_token) if result: break @@ -363,12 +371,11 @@ class Notifier(object): # This happened if there was no timeout or if the timeout had # already expired. current_token = user_stream.current_token - result = yield callback(prev_token, current_token) + result = await callback(prev_token, current_token) return result - @defer.inlineCallbacks - def get_events_for( + async def get_events_for( self, user, pagination_config, @@ -390,15 +397,14 @@ class Notifier(object): """ from_token = pagination_config.from_token if not from_token: - from_token = yield self.event_sources.get_current_token() + from_token = await self.event_sources.get_current_token() limit = pagination_config.limit - room_ids, is_joined = yield self._get_room_ids(user, explicit_room_id) + room_ids, is_joined = await self._get_room_ids(user, explicit_room_id) is_peeking = not is_joined - @defer.inlineCallbacks - def check_for_updates(before_token, after_token): + async def check_for_updates(before_token, after_token): if not after_token.is_after(before_token): return EventStreamResult([], (from_token, from_token)) @@ -414,7 +420,7 @@ class Notifier(object): if only_keys and name not in only_keys: continue - new_events, new_key = yield source.get_new_events( + new_events, new_key = await source.get_new_events( user=user, from_key=getattr(from_token, keyname), limit=limit, @@ -424,8 +430,11 @@ class Notifier(object): ) if name == "room": - new_events = yield filter_events_for_client( - self.store, user.to_string(), new_events, is_peeking=is_peeking + new_events = await filter_events_for_client( + self.storage, + user.to_string(), + new_events, + is_peeking=is_peeking, ) elif name == "presence": now = self.clock.time_msec() @@ -457,7 +466,7 @@ class Notifier(object): user_id_for_stream, ) - result = yield self.wait_for_events( + result = await self.wait_for_events( user_id_for_stream, timeout, check_for_updates, @@ -467,20 +476,18 @@ class Notifier(object): return result - @defer.inlineCallbacks - def _get_room_ids(self, user, explicit_room_id): - joined_room_ids = yield self.store.get_rooms_for_user(user.to_string()) + async def _get_room_ids(self, user, explicit_room_id): + joined_room_ids = await self.store.get_rooms_for_user(user.to_string()) if explicit_room_id: if explicit_room_id in joined_room_ids: return [explicit_room_id], True - if (yield self._is_world_readable(explicit_room_id)): + if await self._is_world_readable(explicit_room_id): return [explicit_room_id], False raise AuthError(403, "Non-joined access not allowed") return joined_room_ids, True - @defer.inlineCallbacks - def _is_world_readable(self, room_id): - state = yield self.state_handler.get_current_state( + async def _is_world_readable(self, room_id): + state = await self.state_handler.get_current_state( room_id, EventTypes.RoomHistoryVisibility, "" ) if state and "history_visibility" in state.content: @@ -521,3 +528,12 @@ class Notifier(object): """Notify the any replication listeners that there's a new event""" for cb in self.replication_callbacks: cb() + + def notify_remote_server_up(self, server: str): + """Notify any replication that a remote server has come back up + """ + # We call federation_sender directly rather than registering as a + # callback as a) we already have a reference to it and b) it introduces + # circular dependencies. + if self.federation_sender: + self.federation_sender.wake_destination(server) |