diff --git a/synapse/notifier.py b/synapse/notifier.py
index 4e091314e6..87c120a59c 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -15,11 +15,13 @@
import logging
from collections import namedtuple
+from typing import Callable, Iterable, List, TypeVar
from prometheus_client import Counter
from twisted.internet import defer
+import synapse.server
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError
from synapse.handlers.presence import format_user_presence_state
@@ -40,12 +42,14 @@ users_woken_by_stream_counter = Counter(
"synapse_notifier_users_woken_by_stream", "", ["stream"]
)
+T = TypeVar("T")
+
# TODO(paul): Should be shared somewhere
-def count(func, l):
- """Return the number of items in l for which func returns true."""
+def count(func: Callable[[T], bool], it: Iterable[T]) -> int:
+ """Return the number of items in it for which func returns true."""
n = 0
- for x in l:
+ for x in it:
if func(x):
n += 1
return n
@@ -154,16 +158,22 @@ class Notifier(object):
UNUSED_STREAM_EXPIRY_MS = 10 * 60 * 1000
- def __init__(self, hs):
+ def __init__(self, hs: "synapse.server.HomeServer"):
self.user_to_user_stream = {}
self.room_to_user_streams = {}
self.hs = hs
+ self.storage = hs.get_storage()
self.event_sources = hs.get_event_sources()
self.store = hs.get_datastore()
self.pending_new_room_events = []
- self.replication_callbacks = []
+ # Called when there are new things to stream over replication
+ self.replication_callbacks = [] # type: List[Callable[[], None]]
+
+ # Called when remote servers have come back online after having been
+ # down.
+ self.remote_server_up_callbacks = [] # type: List[Callable[[str], None]]
self.clock = hs.get_clock()
self.appservice_handler = hs.get_application_service_handler()
@@ -204,7 +214,7 @@ class Notifier(object):
"synapse_notifier_users", "", [], lambda: len(self.user_to_user_stream)
)
- def add_replication_callback(self, cb):
+ def add_replication_callback(self, cb: Callable[[], None]):
"""Add a callback that will be called when some new data is available.
Callback is not given any arguments. It should *not* return a Deferred - if
it needs to do any asynchronous work, a background thread should be started and
@@ -265,10 +275,9 @@ class Notifier(object):
"room_key", room_stream_id, users=extra_users, rooms=[event.room_id]
)
- @defer.inlineCallbacks
- def _notify_app_services(self, room_stream_id):
+ async def _notify_app_services(self, room_stream_id):
try:
- yield self.appservice_handler.notify_interested_services(room_stream_id)
+ await self.appservice_handler.notify_interested_services(room_stream_id)
except Exception:
logger.exception("Error notifying application services of event")
@@ -303,8 +312,7 @@ class Notifier(object):
without waking up any of the normal user event streams"""
self.notify_replication()
- @defer.inlineCallbacks
- def wait_for_events(
+ async def wait_for_events(
self, user_id, timeout, callback, room_ids=None, from_token=StreamToken.START
):
"""Wait until the callback returns a non empty response or the
@@ -312,9 +320,9 @@ class Notifier(object):
"""
user_stream = self.user_to_user_stream.get(user_id)
if user_stream is None:
- current_token = yield self.event_sources.get_current_token()
+ current_token = await self.event_sources.get_current_token()
if room_ids is None:
- room_ids = yield self.store.get_rooms_for_user(user_id)
+ room_ids = await self.store.get_rooms_for_user(user_id)
user_stream = _NotifierUserStream(
user_id=user_id,
rooms=room_ids,
@@ -343,11 +351,11 @@ class Notifier(object):
self.hs.get_reactor(),
)
with PreserveLoggingContext():
- yield listener.deferred
+ await listener.deferred
current_token = user_stream.current_token
- result = yield callback(prev_token, current_token)
+ result = await callback(prev_token, current_token)
if result:
break
@@ -363,12 +371,11 @@ class Notifier(object):
# This happened if there was no timeout or if the timeout had
# already expired.
current_token = user_stream.current_token
- result = yield callback(prev_token, current_token)
+ result = await callback(prev_token, current_token)
return result
- @defer.inlineCallbacks
- def get_events_for(
+ async def get_events_for(
self,
user,
pagination_config,
@@ -390,15 +397,14 @@ class Notifier(object):
"""
from_token = pagination_config.from_token
if not from_token:
- from_token = yield self.event_sources.get_current_token()
+ from_token = await self.event_sources.get_current_token()
limit = pagination_config.limit
- room_ids, is_joined = yield self._get_room_ids(user, explicit_room_id)
+ room_ids, is_joined = await self._get_room_ids(user, explicit_room_id)
is_peeking = not is_joined
- @defer.inlineCallbacks
- def check_for_updates(before_token, after_token):
+ async def check_for_updates(before_token, after_token):
if not after_token.is_after(before_token):
return EventStreamResult([], (from_token, from_token))
@@ -414,7 +420,7 @@ class Notifier(object):
if only_keys and name not in only_keys:
continue
- new_events, new_key = yield source.get_new_events(
+ new_events, new_key = await source.get_new_events(
user=user,
from_key=getattr(from_token, keyname),
limit=limit,
@@ -424,8 +430,11 @@ class Notifier(object):
)
if name == "room":
- new_events = yield filter_events_for_client(
- self.store, user.to_string(), new_events, is_peeking=is_peeking
+ new_events = await filter_events_for_client(
+ self.storage,
+ user.to_string(),
+ new_events,
+ is_peeking=is_peeking,
)
elif name == "presence":
now = self.clock.time_msec()
@@ -457,7 +466,7 @@ class Notifier(object):
user_id_for_stream,
)
- result = yield self.wait_for_events(
+ result = await self.wait_for_events(
user_id_for_stream,
timeout,
check_for_updates,
@@ -467,20 +476,18 @@ class Notifier(object):
return result
- @defer.inlineCallbacks
- def _get_room_ids(self, user, explicit_room_id):
- joined_room_ids = yield self.store.get_rooms_for_user(user.to_string())
+ async def _get_room_ids(self, user, explicit_room_id):
+ joined_room_ids = await self.store.get_rooms_for_user(user.to_string())
if explicit_room_id:
if explicit_room_id in joined_room_ids:
return [explicit_room_id], True
- if (yield self._is_world_readable(explicit_room_id)):
+ if await self._is_world_readable(explicit_room_id):
return [explicit_room_id], False
raise AuthError(403, "Non-joined access not allowed")
return joined_room_ids, True
- @defer.inlineCallbacks
- def _is_world_readable(self, room_id):
- state = yield self.state_handler.get_current_state(
+ async def _is_world_readable(self, room_id):
+ state = await self.state_handler.get_current_state(
room_id, EventTypes.RoomHistoryVisibility, ""
)
if state and "history_visibility" in state.content:
@@ -521,3 +528,12 @@ class Notifier(object):
"""Notify the any replication listeners that there's a new event"""
for cb in self.replication_callbacks:
cb()
+
+ def notify_remote_server_up(self, server: str):
+ """Notify any replication that a remote server has come back up
+ """
+ # We call federation_sender directly rather than registering as a
+ # callback as a) we already have a reference to it and b) it introduces
+ # circular dependencies.
+ if self.federation_sender:
+ self.federation_sender.wake_destination(server)
|