summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/7318.misc1
-rw-r--r--docs/tcp_replication.md6
-rw-r--r--synapse/api/constants.py2
-rw-r--r--synapse/app/generic_worker.py85
-rw-r--r--synapse/handlers/events.py20
-rw-r--r--synapse/handlers/initial_sync.py10
-rw-r--r--synapse/handlers/presence.py210
-rw-r--r--synapse/replication/tcp/commands.py7
-rw-r--r--synapse/replication/tcp/handler.py15
-rw-r--r--synapse/server.pyi2
10 files changed, 199 insertions, 159 deletions
diff --git a/changelog.d/7318.misc b/changelog.d/7318.misc
new file mode 100644
index 0000000000..676f285377
--- /dev/null
+++ b/changelog.d/7318.misc
@@ -0,0 +1 @@
+Move catchup of replication streams logic to worker.
diff --git a/docs/tcp_replication.md b/docs/tcp_replication.md
index 3be8e50c4c..b922d9cf7e 100644
--- a/docs/tcp_replication.md
+++ b/docs/tcp_replication.md
@@ -196,7 +196,7 @@ Asks the server for the current position of all streams.
 
 #### USER_SYNC (C)
 
-   A user has started or stopped syncing
+   A user has started or stopped syncing on this process.
 
 #### CLEAR_USER_SYNC (C)
 
@@ -216,10 +216,6 @@ Asks the server for the current position of all streams.
 
    Inform the server a cache should be invalidated
 
-#### SYNC (S, C)
-
-   Used exclusively in tests
-
 ### REMOTE_SERVER_UP (S, C)
 
    Inform other processes that a remote server may have come back online.
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index fda2c2e5bb..bcaf2c3600 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -97,6 +97,8 @@ class EventTypes(object):
 
     Retention = "m.room.retention"
 
+    Presence = "m.presence"
+
 
 class RejectedReason(object):
     AUTH_ERROR = "auth_error"
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 37afd2f810..2a56fe0bd5 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -17,6 +17,9 @@
 import contextlib
 import logging
 import sys
+from typing import Dict, Iterable
+
+from typing_extensions import ContextManager
 
 from twisted.internet import defer, reactor
 from twisted.web.resource import NoResource
@@ -38,14 +41,14 @@ from synapse.config.homeserver import HomeServerConfig
 from synapse.config.logger import setup_logging
 from synapse.federation import send_queue
 from synapse.federation.transport.server import TransportLayerServer
-from synapse.handlers.presence import PresenceHandler, get_interested_parties
+from synapse.handlers.presence import BasePresenceHandler, get_interested_parties
 from synapse.http.server import JsonResource
 from synapse.http.servlet import RestServlet, parse_json_object_from_request
 from synapse.http.site import SynapseSite
 from synapse.logging.context import LoggingContext
 from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy
 from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.replication.slave.storage._base import BaseSlavedStore, __func__
+from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
 from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
 from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
@@ -225,23 +228,32 @@ class KeyUploadServlet(RestServlet):
             return 200, {"one_time_key_counts": result}
 
 
+class _NullContextManager(ContextManager[None]):
+    """A context manager which does nothing."""
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        pass
+
+
 UPDATE_SYNCING_USERS_MS = 10 * 1000
 
 
-class GenericWorkerPresence(object):
+class GenericWorkerPresence(BasePresenceHandler):
     def __init__(self, hs):
+        super().__init__(hs)
         self.hs = hs
         self.is_mine_id = hs.is_mine_id
         self.http_client = hs.get_simple_http_client()
-        self.store = hs.get_datastore()
-        self.user_to_num_current_syncs = {}
-        self.clock = hs.get_clock()
+
+        self._presence_enabled = hs.config.use_presence
+
+        # The number of ongoing syncs on this process, by user id.
+        # Empty if _presence_enabled is false.
+        self._user_to_num_current_syncs = {}  # type: Dict[str, int]
+
         self.notifier = hs.get_notifier()
         self.instance_id = hs.get_instance_id()
 
-        active_presence = self.store.take_presence_startup_info()
-        self.user_to_current_state = {state.user_id: state for state in active_presence}
-
         # user_id -> last_sync_ms. Lists the users that have stopped syncing
         # but we haven't notified the master of that yet
         self.users_going_offline = {}
@@ -259,13 +271,13 @@ class GenericWorkerPresence(object):
         )
 
     def _on_shutdown(self):
-        if self.hs.config.use_presence:
+        if self._presence_enabled:
             self.hs.get_tcp_replication().send_command(
                 ClearUserSyncsCommand(self.instance_id)
             )
 
     def send_user_sync(self, user_id, is_syncing, last_sync_ms):
-        if self.hs.config.use_presence:
+        if self._presence_enabled:
             self.hs.get_tcp_replication().send_user_sync(
                 self.instance_id, user_id, is_syncing, last_sync_ms
             )
@@ -307,28 +319,33 @@ class GenericWorkerPresence(object):
         # TODO Hows this supposed to work?
         return defer.succeed(None)
 
-    get_states = __func__(PresenceHandler.get_states)
-    get_state = __func__(PresenceHandler.get_state)
-    current_state_for_users = __func__(PresenceHandler.current_state_for_users)
+    async def user_syncing(
+        self, user_id: str, affect_presence: bool
+    ) -> ContextManager[None]:
+        """Record that a user is syncing.
+
+        Called by the sync and events servlets to record that a user has connected to
+        this worker and is waiting for some events.
+        """
+        if not affect_presence or not self._presence_enabled:
+            return _NullContextManager()
 
-    def user_syncing(self, user_id, affect_presence):
-        if affect_presence:
-            curr_sync = self.user_to_num_current_syncs.get(user_id, 0)
-            self.user_to_num_current_syncs[user_id] = curr_sync + 1
+        curr_sync = self._user_to_num_current_syncs.get(user_id, 0)
+        self._user_to_num_current_syncs[user_id] = curr_sync + 1
 
-            # If we went from no in flight sync to some, notify replication
-            if self.user_to_num_current_syncs[user_id] == 1:
-                self.mark_as_coming_online(user_id)
+        # If we went from no in flight sync to some, notify replication
+        if self._user_to_num_current_syncs[user_id] == 1:
+            self.mark_as_coming_online(user_id)
 
         def _end():
             # We check that the user_id is in user_to_num_current_syncs because
             # user_to_num_current_syncs may have been cleared if we are
             # shutting down.
-            if affect_presence and user_id in self.user_to_num_current_syncs:
-                self.user_to_num_current_syncs[user_id] -= 1
+            if user_id in self._user_to_num_current_syncs:
+                self._user_to_num_current_syncs[user_id] -= 1
 
                 # If we went from one in flight sync to non, notify replication
-                if self.user_to_num_current_syncs[user_id] == 0:
+                if self._user_to_num_current_syncs[user_id] == 0:
                     self.mark_as_going_offline(user_id)
 
         @contextlib.contextmanager
@@ -338,7 +355,7 @@ class GenericWorkerPresence(object):
             finally:
                 _end()
 
-        return defer.succeed(_user_syncing())
+        return _user_syncing()
 
     @defer.inlineCallbacks
     def notify_from_replication(self, states, stream_id):
@@ -373,15 +390,12 @@ class GenericWorkerPresence(object):
         stream_id = token
         yield self.notify_from_replication(states, stream_id)
 
-    def get_currently_syncing_users(self):
-        if self.hs.config.use_presence:
-            return [
-                user_id
-                for user_id, count in self.user_to_num_current_syncs.items()
-                if count > 0
-            ]
-        else:
-            return set()
+    def get_currently_syncing_users_for_replication(self) -> Iterable[str]:
+        return [
+            user_id
+            for user_id, count in self._user_to_num_current_syncs.items()
+            if count > 0
+        ]
 
 
 class GenericWorkerTyping(object):
@@ -625,8 +639,7 @@ class GenericWorkerReplicationHandler(ReplicationDataHandler):
 
         self.store = hs.get_datastore()
         self.typing_handler = hs.get_typing_handler()
-        # NB this is a SynchrotronPresence, not a normal PresenceHandler
-        self.presence_handler = hs.get_presence_handler()
+        self.presence_handler = hs.get_presence_handler()  # type: GenericWorkerPresence
         self.notifier = hs.get_notifier()
 
         self.notify_pushers = hs.config.start_pushers
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index ec18a42a68..71a89f09c7 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -19,6 +19,7 @@ import random
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import AuthError, SynapseError
 from synapse.events import EventBase
+from synapse.handlers.presence import format_user_presence_state
 from synapse.logging.utils import log_function
 from synapse.types import UserID
 from synapse.visibility import filter_events_for_client
@@ -97,6 +98,8 @@ class EventStreamHandler(BaseHandler):
                 explicit_room_id=room_id,
             )
 
+            time_now = self.clock.time_msec()
+
             # When the user joins a new room, or another user joins a currently
             # joined room, we need to send down presence for those users.
             to_add = []
@@ -112,19 +115,20 @@ class EventStreamHandler(BaseHandler):
                         users = await self.state.get_current_users_in_room(
                             event.room_id
                         )
-                        states = await presence_handler.get_states(users, as_event=True)
-                        to_add.extend(states)
                     else:
+                        users = [event.state_key]
 
-                        ev = await presence_handler.get_state(
-                            UserID.from_string(event.state_key), as_event=True
-                        )
-                        to_add.append(ev)
+                    states = await presence_handler.get_states(users)
+                    to_add.extend(
+                        {
+                            "type": EventTypes.Presence,
+                            "content": format_user_presence_state(state, time_now),
+                        }
+                        for state in states
+                    )
 
             events.extend(to_add)
 
-            time_now = self.clock.time_msec()
-
             chunks = await self._event_serializer.serialize_events(
                 events,
                 time_now,
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index b116500c7d..f88bad5f25 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -381,10 +381,16 @@ class InitialSyncHandler(BaseHandler):
                 return []
 
             states = await presence_handler.get_states(
-                [m.user_id for m in room_members], as_event=True
+                [m.user_id for m in room_members]
             )
 
-            return states
+            return [
+                {
+                    "type": EventTypes.Presence,
+                    "content": format_user_presence_state(s, time_now),
+                }
+                for s in states
+            ]
 
         async def get_receipts():
             receipts = await self.store.get_linearized_receipts_for_room(
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 6912165622..5cbefae177 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -21,10 +22,10 @@ The methods that define policy are:
     - PresenceHandler._handle_timeouts
     - should_notify
 """
-
+import abc
 import logging
 from contextlib import contextmanager
-from typing import Dict, List, Set
+from typing import Dict, Iterable, List, Set
 
 from six import iteritems, itervalues
 
@@ -41,7 +42,7 @@ from synapse.logging.utils import log_function
 from synapse.metrics import LaterGauge
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.presence import UserPresenceState
-from synapse.types import UserID, get_domain_from_id
+from synapse.types import JsonDict, UserID, get_domain_from_id
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches.descriptors import cached
 from synapse.util.metrics import Measure
@@ -99,13 +100,106 @@ EXTERNAL_PROCESS_EXPIRY = 5 * 60 * 1000
 assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER
 
 
-class PresenceHandler(object):
+class BasePresenceHandler(abc.ABC):
+    """Parts of the PresenceHandler that are shared between workers and master"""
+
+    def __init__(self, hs: "synapse.server.HomeServer"):
+        self.clock = hs.get_clock()
+        self.store = hs.get_datastore()
+
+        active_presence = self.store.take_presence_startup_info()
+        self.user_to_current_state = {state.user_id: state for state in active_presence}
+
+    @abc.abstractmethod
+    async def user_syncing(
+        self, user_id: str, affect_presence: bool
+    ) -> ContextManager[None]:
+        """Returns a context manager that should surround any stream requests
+        from the user.
+
+        This allows us to keep track of who is currently streaming and who isn't
+        without having to have timers outside of this module to avoid flickering
+        when users disconnect/reconnect.
+
+        Args:
+            user_id: the user that is starting a sync
+            affect_presence: If false this function will be a no-op.
+                Useful for streams that are not associated with an actual
+                client that is being used by a user.
+        """
+
+    @abc.abstractmethod
+    def get_currently_syncing_users_for_replication(self) -> Iterable[str]:
+        """Get an iterable of syncing users on this worker, to send to the presence handler
+
+        This is called when a replication connection is established. It should return
+        a list of user ids, which are then sent as USER_SYNC commands to inform the
+        process handling presence about those users.
+
+        Returns:
+            An iterable of user_id strings.
+        """
+
+    async def get_state(self, target_user: UserID) -> UserPresenceState:
+        results = await self.get_states([target_user.to_string()])
+        return results[0]
+
+    async def get_states(
+        self, target_user_ids: Iterable[str]
+    ) -> List[UserPresenceState]:
+        """Get the presence state for users."""
+
+        updates_d = await self.current_state_for_users(target_user_ids)
+        updates = list(updates_d.values())
+
+        for user_id in set(target_user_ids) - {u.user_id for u in updates}:
+            updates.append(UserPresenceState.default(user_id))
+
+        return updates
+
+    async def current_state_for_users(
+        self, user_ids: Iterable[str]
+    ) -> Dict[str, UserPresenceState]:
+        """Get the current presence state for multiple users.
+
+        Returns:
+            dict: `user_id` -> `UserPresenceState`
+        """
+        states = {
+            user_id: self.user_to_current_state.get(user_id, None)
+            for user_id in user_ids
+        }
+
+        missing = [user_id for user_id, state in iteritems(states) if not state]
+        if missing:
+            # There are things not in our in memory cache. Lets pull them out of
+            # the database.
+            res = await self.store.get_presence_for_users(missing)
+            states.update(res)
+
+            missing = [user_id for user_id, state in iteritems(states) if not state]
+            if missing:
+                new = {
+                    user_id: UserPresenceState.default(user_id) for user_id in missing
+                }
+                states.update(new)
+                self.user_to_current_state.update(new)
+
+        return states
+
+    @abc.abstractmethod
+    async def set_state(
+        self, target_user: UserID, state: JsonDict, ignore_status_msg: bool = False
+    ) -> None:
+        """Set the presence state of the user. """
+
+
+class PresenceHandler(BasePresenceHandler):
     def __init__(self, hs: "synapse.server.HomeServer"):
+        super().__init__(hs)
         self.hs = hs
         self.is_mine_id = hs.is_mine_id
         self.server_name = hs.hostname
-        self.clock = hs.get_clock()
-        self.store = hs.get_datastore()
         self.wheel_timer = WheelTimer()
         self.notifier = hs.get_notifier()
         self.federation = hs.get_federation_sender()
@@ -115,13 +209,6 @@ class PresenceHandler(object):
 
         federation_registry.register_edu_handler("m.presence", self.incoming_presence)
 
-        active_presence = self.store.take_presence_startup_info()
-
-        # A dictionary of the current state of users. This is prefilled with
-        # non-offline presence from the DB. We should fetch from the DB if
-        # we can't find a users presence in here.
-        self.user_to_current_state = {state.user_id: state for state in active_presence}
-
         LaterGauge(
             "synapse_handlers_presence_user_to_current_state_size",
             "",
@@ -130,7 +217,7 @@ class PresenceHandler(object):
         )
 
         now = self.clock.time_msec()
-        for state in active_presence:
+        for state in self.user_to_current_state.values():
             self.wheel_timer.insert(
                 now=now, obj=state.user_id, then=state.last_active_ts + IDLE_TIMER
             )
@@ -361,10 +448,18 @@ class PresenceHandler(object):
 
         timers_fired_counter.inc(len(states))
 
+        syncing_user_ids = {
+            user_id
+            for user_id, count in self.user_to_num_current_syncs.items()
+            if count
+        }
+        for user_ids in self.external_process_to_current_syncs.values():
+            syncing_user_ids.update(user_ids)
+
         changes = handle_timeouts(
             states,
             is_mine_fn=self.is_mine_id,
-            syncing_user_ids=self.get_currently_syncing_users(),
+            syncing_user_ids=syncing_user_ids,
             now=now,
         )
 
@@ -462,22 +557,9 @@ class PresenceHandler(object):
 
         return _user_syncing()
 
-    def get_currently_syncing_users(self):
-        """Get the set of user ids that are currently syncing on this HS.
-        Returns:
-            set(str): A set of user_id strings.
-        """
-        if self.hs.config.use_presence:
-            syncing_user_ids = {
-                user_id
-                for user_id, count in self.user_to_num_current_syncs.items()
-                if count
-            }
-            for user_ids in self.external_process_to_current_syncs.values():
-                syncing_user_ids.update(user_ids)
-            return syncing_user_ids
-        else:
-            return set()
+    def get_currently_syncing_users_for_replication(self) -> Iterable[str]:
+        # since we are the process handling presence, there is nothing to do here.
+        return []
 
     async def update_external_syncs_row(
         self, process_id, user_id, is_syncing, sync_time_msec
@@ -554,34 +636,6 @@ class PresenceHandler(object):
         res = await self.current_state_for_users([user_id])
         return res[user_id]
 
-    async def current_state_for_users(self, user_ids):
-        """Get the current presence state for multiple users.
-
-        Returns:
-            dict: `user_id` -> `UserPresenceState`
-        """
-        states = {
-            user_id: self.user_to_current_state.get(user_id, None)
-            for user_id in user_ids
-        }
-
-        missing = [user_id for user_id, state in iteritems(states) if not state]
-        if missing:
-            # There are things not in our in memory cache. Lets pull them out of
-            # the database.
-            res = await self.store.get_presence_for_users(missing)
-            states.update(res)
-
-            missing = [user_id for user_id, state in iteritems(states) if not state]
-            if missing:
-                new = {
-                    user_id: UserPresenceState.default(user_id) for user_id in missing
-                }
-                states.update(new)
-                self.user_to_current_state.update(new)
-
-        return states
-
     async def _persist_and_notify(self, states):
         """Persist states in the database, poke the notifier and send to
         interested remote servers
@@ -669,40 +723,6 @@ class PresenceHandler(object):
             federation_presence_counter.inc(len(updates))
             await self._update_states(updates)
 
-    async def get_state(self, target_user, as_event=False):
-        results = await self.get_states([target_user.to_string()], as_event=as_event)
-
-        return results[0]
-
-    async def get_states(self, target_user_ids, as_event=False):
-        """Get the presence state for users.
-
-        Args:
-            target_user_ids (list)
-            as_event (bool): Whether to format it as a client event or not.
-
-        Returns:
-            list
-        """
-
-        updates = await self.current_state_for_users(target_user_ids)
-        updates = list(updates.values())
-
-        for user_id in set(target_user_ids) - {u.user_id for u in updates}:
-            updates.append(UserPresenceState.default(user_id))
-
-        now = self.clock.time_msec()
-        if as_event:
-            return [
-                {
-                    "type": "m.presence",
-                    "content": format_user_presence_state(state, now),
-                }
-                for state in updates
-            ]
-        else:
-            return updates
-
     async def set_state(self, target_user, state, ignore_status_msg=False):
         """Set the presence state of the user.
         """
@@ -889,7 +909,7 @@ class PresenceHandler(object):
             user_ids = await self.state.get_current_users_in_room(room_id)
             user_ids = list(filter(self.is_mine_id, user_ids))
 
-            states = await self.current_state_for_users(user_ids)
+            states_d = await self.current_state_for_users(user_ids)
 
             # Filter out old presence, i.e. offline presence states where
             # the user hasn't been active for a week. We can change this
@@ -899,7 +919,7 @@ class PresenceHandler(object):
             now = self.clock.time_msec()
             states = [
                 state
-                for state in states.values()
+                for state in states_d.values()
                 if state.state != PresenceState.OFFLINE
                 or now - state.last_active_ts < 7 * 24 * 60 * 60 * 1000
                 or state.status_msg is not None
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index f26aee83cb..c7880d4b63 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -210,7 +210,10 @@ class ReplicateCommand(Command):
 
 class UserSyncCommand(Command):
     """Sent by the client to inform the server that a user has started or
-    stopped syncing. Used to calculate presence on the master.
+    stopped syncing on this process.
+
+    This is used by the process handling presence (typically the master) to
+    calculate who is online and who is not.
 
     Includes a timestamp of when the last user sync was.
 
@@ -218,7 +221,7 @@ class UserSyncCommand(Command):
 
         USER_SYNC <instance_id> <user_id> <state> <last_sync_ms>
 
-    Where <state> is either "start" or "stop"
+    Where <state> is either "start" or "end"
     """
 
     NAME = "USER_SYNC"
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 5b5ee2c13e..0db5a3a24d 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -337,13 +337,6 @@ class ReplicationCommandHandler:
         if self._is_master:
             self._notifier.notify_remote_server_up(cmd.data)
 
-    def get_currently_syncing_users(self):
-        """Get the list of currently syncing users (if any). This is called
-        when a connection has been established and we need to send the
-        currently syncing users.
-        """
-        return self._presence_handler.get_currently_syncing_users()
-
     def new_connection(self, connection: AbstractConnection):
         """Called when we have a new connection.
         """
@@ -361,9 +354,11 @@ class ReplicationCommandHandler:
         if self._factory:
             self._factory.resetDelay()
 
-        # Tell the server if we have any users currently syncing (should only
-        # happen on synchrotrons)
-        currently_syncing = self.get_currently_syncing_users()
+        # Tell the other end if we have any users currently syncing.
+        currently_syncing = (
+            self._presence_handler.get_currently_syncing_users_for_replication()
+        )
+
         now = self._clock.time_msec()
         for user_id in currently_syncing:
             connection.send_command(
diff --git a/synapse/server.pyi b/synapse/server.pyi
index 9013e9bac9..f1a5717028 100644
--- a/synapse/server.pyi
+++ b/synapse/server.pyi
@@ -97,7 +97,7 @@ class HomeServer(object):
         pass
     def get_notifier(self) -> synapse.notifier.Notifier:
         pass
-    def get_presence_handler(self) -> synapse.handlers.presence.PresenceHandler:
+    def get_presence_handler(self) -> synapse.handlers.presence.BasePresenceHandler:
         pass
     def get_clock(self) -> synapse.util.Clock:
         pass