summary refs log tree commit diff
path: root/synapse/notifier.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/notifier.py')
-rw-r--r--synapse/notifier.py82
1 files changed, 49 insertions, 33 deletions
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 4e091314e6..87c120a59c 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -15,11 +15,13 @@
 
 import logging
 from collections import namedtuple
+from typing import Callable, Iterable, List, TypeVar
 
 from prometheus_client import Counter
 
 from twisted.internet import defer
 
+import synapse.server
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import AuthError
 from synapse.handlers.presence import format_user_presence_state
@@ -40,12 +42,14 @@ users_woken_by_stream_counter = Counter(
     "synapse_notifier_users_woken_by_stream", "", ["stream"]
 )
 
+T = TypeVar("T")
+
 
 # TODO(paul): Should be shared somewhere
-def count(func, l):
-    """Return the number of items in l for which func returns true."""
+def count(func: Callable[[T], bool], it: Iterable[T]) -> int:
+    """Return the number of items in it for which func returns true."""
     n = 0
-    for x in l:
+    for x in it:
         if func(x):
             n += 1
     return n
@@ -154,16 +158,22 @@ class Notifier(object):
 
     UNUSED_STREAM_EXPIRY_MS = 10 * 60 * 1000
 
-    def __init__(self, hs):
+    def __init__(self, hs: "synapse.server.HomeServer"):
         self.user_to_user_stream = {}
         self.room_to_user_streams = {}
 
         self.hs = hs
+        self.storage = hs.get_storage()
         self.event_sources = hs.get_event_sources()
         self.store = hs.get_datastore()
         self.pending_new_room_events = []
 
-        self.replication_callbacks = []
+        # Called when there are new things to stream over replication
+        self.replication_callbacks = []  # type: List[Callable[[], None]]
+
+        # Called when remote servers have come back online after having been
+        # down.
+        self.remote_server_up_callbacks = []  # type: List[Callable[[str], None]]
 
         self.clock = hs.get_clock()
         self.appservice_handler = hs.get_application_service_handler()
@@ -204,7 +214,7 @@ class Notifier(object):
             "synapse_notifier_users", "", [], lambda: len(self.user_to_user_stream)
         )
 
-    def add_replication_callback(self, cb):
+    def add_replication_callback(self, cb: Callable[[], None]):
         """Add a callback that will be called when some new data is available.
         Callback is not given any arguments. It should *not* return a Deferred - if
         it needs to do any asynchronous work, a background thread should be started and
@@ -265,10 +275,9 @@ class Notifier(object):
             "room_key", room_stream_id, users=extra_users, rooms=[event.room_id]
         )
 
-    @defer.inlineCallbacks
-    def _notify_app_services(self, room_stream_id):
+    async def _notify_app_services(self, room_stream_id):
         try:
-            yield self.appservice_handler.notify_interested_services(room_stream_id)
+            await self.appservice_handler.notify_interested_services(room_stream_id)
         except Exception:
             logger.exception("Error notifying application services of event")
 
@@ -303,8 +312,7 @@ class Notifier(object):
         without waking up any of the normal user event streams"""
         self.notify_replication()
 
-    @defer.inlineCallbacks
-    def wait_for_events(
+    async def wait_for_events(
         self, user_id, timeout, callback, room_ids=None, from_token=StreamToken.START
     ):
         """Wait until the callback returns a non empty response or the
@@ -312,9 +320,9 @@ class Notifier(object):
         """
         user_stream = self.user_to_user_stream.get(user_id)
         if user_stream is None:
-            current_token = yield self.event_sources.get_current_token()
+            current_token = await self.event_sources.get_current_token()
             if room_ids is None:
-                room_ids = yield self.store.get_rooms_for_user(user_id)
+                room_ids = await self.store.get_rooms_for_user(user_id)
             user_stream = _NotifierUserStream(
                 user_id=user_id,
                 rooms=room_ids,
@@ -343,11 +351,11 @@ class Notifier(object):
                         self.hs.get_reactor(),
                     )
                     with PreserveLoggingContext():
-                        yield listener.deferred
+                        await listener.deferred
 
                     current_token = user_stream.current_token
 
-                    result = yield callback(prev_token, current_token)
+                    result = await callback(prev_token, current_token)
                     if result:
                         break
 
@@ -363,12 +371,11 @@ class Notifier(object):
             # This happened if there was no timeout or if the timeout had
             # already expired.
             current_token = user_stream.current_token
-            result = yield callback(prev_token, current_token)
+            result = await callback(prev_token, current_token)
 
         return result
 
-    @defer.inlineCallbacks
-    def get_events_for(
+    async def get_events_for(
         self,
         user,
         pagination_config,
@@ -390,15 +397,14 @@ class Notifier(object):
         """
         from_token = pagination_config.from_token
         if not from_token:
-            from_token = yield self.event_sources.get_current_token()
+            from_token = await self.event_sources.get_current_token()
 
         limit = pagination_config.limit
 
-        room_ids, is_joined = yield self._get_room_ids(user, explicit_room_id)
+        room_ids, is_joined = await self._get_room_ids(user, explicit_room_id)
         is_peeking = not is_joined
 
-        @defer.inlineCallbacks
-        def check_for_updates(before_token, after_token):
+        async def check_for_updates(before_token, after_token):
             if not after_token.is_after(before_token):
                 return EventStreamResult([], (from_token, from_token))
 
@@ -414,7 +420,7 @@ class Notifier(object):
                 if only_keys and name not in only_keys:
                     continue
 
-                new_events, new_key = yield source.get_new_events(
+                new_events, new_key = await source.get_new_events(
                     user=user,
                     from_key=getattr(from_token, keyname),
                     limit=limit,
@@ -424,8 +430,11 @@ class Notifier(object):
                 )
 
                 if name == "room":
-                    new_events = yield filter_events_for_client(
-                        self.store, user.to_string(), new_events, is_peeking=is_peeking
+                    new_events = await filter_events_for_client(
+                        self.storage,
+                        user.to_string(),
+                        new_events,
+                        is_peeking=is_peeking,
                     )
                 elif name == "presence":
                     now = self.clock.time_msec()
@@ -457,7 +466,7 @@ class Notifier(object):
                 user_id_for_stream,
             )
 
-        result = yield self.wait_for_events(
+        result = await self.wait_for_events(
             user_id_for_stream,
             timeout,
             check_for_updates,
@@ -467,20 +476,18 @@ class Notifier(object):
 
         return result
 
-    @defer.inlineCallbacks
-    def _get_room_ids(self, user, explicit_room_id):
-        joined_room_ids = yield self.store.get_rooms_for_user(user.to_string())
+    async def _get_room_ids(self, user, explicit_room_id):
+        joined_room_ids = await self.store.get_rooms_for_user(user.to_string())
         if explicit_room_id:
             if explicit_room_id in joined_room_ids:
                 return [explicit_room_id], True
-            if (yield self._is_world_readable(explicit_room_id)):
+            if await self._is_world_readable(explicit_room_id):
                 return [explicit_room_id], False
             raise AuthError(403, "Non-joined access not allowed")
         return joined_room_ids, True
 
-    @defer.inlineCallbacks
-    def _is_world_readable(self, room_id):
-        state = yield self.state_handler.get_current_state(
+    async def _is_world_readable(self, room_id):
+        state = await self.state_handler.get_current_state(
             room_id, EventTypes.RoomHistoryVisibility, ""
         )
         if state and "history_visibility" in state.content:
@@ -521,3 +528,12 @@ class Notifier(object):
         """Notify the any replication listeners that there's a new event"""
         for cb in self.replication_callbacks:
             cb()
+
+    def notify_remote_server_up(self, server: str):
+        """Notify any replication that a remote server has come back up
+        """
+        # We call federation_sender directly rather than registering as a
+        # callback as a) we already have a reference to it and b) it introduces
+        # circular dependencies.
+        if self.federation_sender:
+            self.federation_sender.wake_destination(server)