summary refs log tree commit diff
path: root/synapse/app/generic_worker.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/app/generic_worker.py')
-rw-r--r--synapse/app/generic_worker.py85
1 files changed, 49 insertions, 36 deletions
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 37afd2f810..2a56fe0bd5 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -17,6 +17,9 @@
 import contextlib
 import logging
 import sys
+from typing import Dict, Iterable
+
+from typing_extensions import ContextManager
 
 from twisted.internet import defer, reactor
 from twisted.web.resource import NoResource
@@ -38,14 +41,14 @@ from synapse.config.homeserver import HomeServerConfig
 from synapse.config.logger import setup_logging
 from synapse.federation import send_queue
 from synapse.federation.transport.server import TransportLayerServer
-from synapse.handlers.presence import PresenceHandler, get_interested_parties
+from synapse.handlers.presence import BasePresenceHandler, get_interested_parties
 from synapse.http.server import JsonResource
 from synapse.http.servlet import RestServlet, parse_json_object_from_request
 from synapse.http.site import SynapseSite
 from synapse.logging.context import LoggingContext
 from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy
 from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.replication.slave.storage._base import BaseSlavedStore, __func__
+from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
 from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
 from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
@@ -225,23 +228,32 @@ class KeyUploadServlet(RestServlet):
             return 200, {"one_time_key_counts": result}
 
 
+class _NullContextManager(ContextManager[None]):
+    """A context manager which does nothing."""
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        pass
+
+
 UPDATE_SYNCING_USERS_MS = 10 * 1000
 
 
-class GenericWorkerPresence(object):
+class GenericWorkerPresence(BasePresenceHandler):
     def __init__(self, hs):
+        super().__init__(hs)
         self.hs = hs
         self.is_mine_id = hs.is_mine_id
         self.http_client = hs.get_simple_http_client()
-        self.store = hs.get_datastore()
-        self.user_to_num_current_syncs = {}
-        self.clock = hs.get_clock()
+
+        self._presence_enabled = hs.config.use_presence
+
+        # The number of ongoing syncs on this process, by user id.
+        # Empty if _presence_enabled is false.
+        self._user_to_num_current_syncs = {}  # type: Dict[str, int]
+
         self.notifier = hs.get_notifier()
         self.instance_id = hs.get_instance_id()
 
-        active_presence = self.store.take_presence_startup_info()
-        self.user_to_current_state = {state.user_id: state for state in active_presence}
-
         # user_id -> last_sync_ms. Lists the users that have stopped syncing
         # but we haven't notified the master of that yet
         self.users_going_offline = {}
@@ -259,13 +271,13 @@ class GenericWorkerPresence(object):
         )
 
     def _on_shutdown(self):
-        if self.hs.config.use_presence:
+        if self._presence_enabled:
             self.hs.get_tcp_replication().send_command(
                 ClearUserSyncsCommand(self.instance_id)
             )
 
     def send_user_sync(self, user_id, is_syncing, last_sync_ms):
-        if self.hs.config.use_presence:
+        if self._presence_enabled:
             self.hs.get_tcp_replication().send_user_sync(
                 self.instance_id, user_id, is_syncing, last_sync_ms
             )
@@ -307,28 +319,33 @@ class GenericWorkerPresence(object):
         # TODO Hows this supposed to work?
         return defer.succeed(None)
 
-    get_states = __func__(PresenceHandler.get_states)
-    get_state = __func__(PresenceHandler.get_state)
-    current_state_for_users = __func__(PresenceHandler.current_state_for_users)
+    async def user_syncing(
+        self, user_id: str, affect_presence: bool
+    ) -> ContextManager[None]:
+        """Record that a user is syncing.
+
+        Called by the sync and events servlets to record that a user has connected to
+        this worker and is waiting for some events.
+        """
+        if not affect_presence or not self._presence_enabled:
+            return _NullContextManager()
 
-    def user_syncing(self, user_id, affect_presence):
-        if affect_presence:
-            curr_sync = self.user_to_num_current_syncs.get(user_id, 0)
-            self.user_to_num_current_syncs[user_id] = curr_sync + 1
+        curr_sync = self._user_to_num_current_syncs.get(user_id, 0)
+        self._user_to_num_current_syncs[user_id] = curr_sync + 1
 
-            # If we went from no in flight sync to some, notify replication
-            if self.user_to_num_current_syncs[user_id] == 1:
-                self.mark_as_coming_online(user_id)
+        # If we went from no in flight sync to some, notify replication
+        if self._user_to_num_current_syncs[user_id] == 1:
+            self.mark_as_coming_online(user_id)
 
         def _end():
             # We check that the user_id is in user_to_num_current_syncs because
             # user_to_num_current_syncs may have been cleared if we are
             # shutting down.
-            if affect_presence and user_id in self.user_to_num_current_syncs:
-                self.user_to_num_current_syncs[user_id] -= 1
+            if user_id in self._user_to_num_current_syncs:
+                self._user_to_num_current_syncs[user_id] -= 1
 
                 # If we went from one in flight sync to non, notify replication
-                if self.user_to_num_current_syncs[user_id] == 0:
+                if self._user_to_num_current_syncs[user_id] == 0:
                     self.mark_as_going_offline(user_id)
 
         @contextlib.contextmanager
@@ -338,7 +355,7 @@ class GenericWorkerPresence(object):
             finally:
                 _end()
 
-        return defer.succeed(_user_syncing())
+        return _user_syncing()
 
     @defer.inlineCallbacks
     def notify_from_replication(self, states, stream_id):
@@ -373,15 +390,12 @@ class GenericWorkerPresence(object):
         stream_id = token
         yield self.notify_from_replication(states, stream_id)
 
-    def get_currently_syncing_users(self):
-        if self.hs.config.use_presence:
-            return [
-                user_id
-                for user_id, count in self.user_to_num_current_syncs.items()
-                if count > 0
-            ]
-        else:
-            return set()
+    def get_currently_syncing_users_for_replication(self) -> Iterable[str]:
+        return [
+            user_id
+            for user_id, count in self._user_to_num_current_syncs.items()
+            if count > 0
+        ]
 
 
 class GenericWorkerTyping(object):
@@ -625,8 +639,7 @@ class GenericWorkerReplicationHandler(ReplicationDataHandler):
 
         self.store = hs.get_datastore()
         self.typing_handler = hs.get_typing_handler()
-        # NB this is a SynchrotronPresence, not a normal PresenceHandler
-        self.presence_handler = hs.get_presence_handler()
+        self.presence_handler = hs.get_presence_handler()  # type: GenericWorkerPresence
         self.notifier = hs.get_notifier()
 
         self.notify_pushers = hs.config.start_pushers