summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/7128.misc1
-rw-r--r--docs/tcp_replication.md6
-rw-r--r--synapse/app/generic_worker.py20
-rw-r--r--synapse/replication/tcp/client.py6
-rw-r--r--synapse/replication/tcp/commands.py36
-rw-r--r--synapse/replication/tcp/protocol.py9
-rw-r--r--synapse/replication/tcp/resource.py17
-rw-r--r--synapse/server.py11
-rw-r--r--synapse/server.pyi2
9 files changed, 86 insertions, 22 deletions
diff --git a/changelog.d/7128.misc b/changelog.d/7128.misc
new file mode 100644
index 0000000000..5703f6d2ec
--- /dev/null
+++ b/changelog.d/7128.misc
@@ -0,0 +1 @@
+Add explicit `instance_id` for USER_SYNC commands and remove implicit `conn_id` usage.
diff --git a/docs/tcp_replication.md b/docs/tcp_replication.md
index d4f7d9ec18..3be8e50c4c 100644
--- a/docs/tcp_replication.md
+++ b/docs/tcp_replication.md
@@ -198,6 +198,12 @@ Asks the server for the current position of all streams.
 
    A user has started or stopped syncing
 
+#### CLEAR_USER_SYNC (C)
+
+   The server should clear all associated user sync data from the worker.
+
+   This is used when a worker is shutting down.
+
 #### FEDERATION_ACK (C)
 
    Acknowledge receipt of some federation data
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index fba7ad9551..1ee266f7c5 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -65,6 +65,7 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto
 from synapse.replication.slave.storage.room import RoomStore
 from synapse.replication.slave.storage.transactions import SlavedTransactionStore
 from synapse.replication.tcp.client import ReplicationClientHandler
+from synapse.replication.tcp.commands import ClearUserSyncsCommand
 from synapse.replication.tcp.streams import (
     AccountDataStream,
     DeviceListsStream,
@@ -124,7 +125,6 @@ from synapse.types import ReadReceipt
 from synapse.util.async_helpers import Linearizer
 from synapse.util.httpresourcetree import create_resource_tree
 from synapse.util.manhole import manhole
-from synapse.util.stringutils import random_string
 from synapse.util.versionstring import get_version_string
 
 logger = logging.getLogger("synapse.app.generic_worker")
@@ -233,6 +233,7 @@ class GenericWorkerPresence(object):
         self.user_to_num_current_syncs = {}
         self.clock = hs.get_clock()
         self.notifier = hs.get_notifier()
+        self.instance_id = hs.get_instance_id()
 
         active_presence = self.store.take_presence_startup_info()
         self.user_to_current_state = {state.user_id: state for state in active_presence}
@@ -245,13 +246,24 @@ class GenericWorkerPresence(object):
             self.send_stop_syncing, UPDATE_SYNCING_USERS_MS
         )
 
-        self.process_id = random_string(16)
-        logger.info("Presence process_id is %r", self.process_id)
+        hs.get_reactor().addSystemEventTrigger(
+            "before",
+            "shutdown",
+            run_as_background_process,
+            "generic_presence.on_shutdown",
+            self._on_shutdown,
+        )
+
+    def _on_shutdown(self):
+        if self.hs.config.use_presence:
+            self.hs.get_tcp_replication().send_command(
+                ClearUserSyncsCommand(self.instance_id)
+            )
 
     def send_user_sync(self, user_id, is_syncing, last_sync_ms):
         if self.hs.config.use_presence:
             self.hs.get_tcp_replication().send_user_sync(
-                user_id, is_syncing, last_sync_ms
+                self.instance_id, user_id, is_syncing, last_sync_ms
             )
 
     def mark_as_coming_online(self, user_id):
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 7e7ad0f798..e86d9805f1 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -189,10 +189,12 @@ class ReplicationClientHandler(AbstractReplicationClientHandler):
         """
         self.send_command(FederationAckCommand(token))
 
-    def send_user_sync(self, user_id, is_syncing, last_sync_ms):
+    def send_user_sync(self, instance_id, user_id, is_syncing, last_sync_ms):
         """Poke the master that a user has started/stopped syncing.
         """
-        self.send_command(UserSyncCommand(user_id, is_syncing, last_sync_ms))
+        self.send_command(
+            UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms)
+        )
 
     def send_remove_pusher(self, app_id, push_key, user_id):
         """Poke the master to remove a pusher for a user
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index 5a6b734094..e4eec643f7 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -207,30 +207,32 @@ class UserSyncCommand(Command):
 
     Format::
 
-        USER_SYNC <user_id> <state> <last_sync_ms>
+        USER_SYNC <instance_id> <user_id> <state> <last_sync_ms>
 
     Where <state> is either "start" or "stop"
     """
 
     NAME = "USER_SYNC"
 
-    def __init__(self, user_id, is_syncing, last_sync_ms):
+    def __init__(self, instance_id, user_id, is_syncing, last_sync_ms):
+        self.instance_id = instance_id
         self.user_id = user_id
         self.is_syncing = is_syncing
         self.last_sync_ms = last_sync_ms
 
     @classmethod
     def from_line(cls, line):
-        user_id, state, last_sync_ms = line.split(" ", 2)
+        instance_id, user_id, state, last_sync_ms = line.split(" ", 3)
 
         if state not in ("start", "end"):
             raise Exception("Invalid USER_SYNC state %r" % (state,))
 
-        return cls(user_id, state == "start", int(last_sync_ms))
+        return cls(instance_id, user_id, state == "start", int(last_sync_ms))
 
     def to_line(self):
         return " ".join(
             (
+                self.instance_id,
                 self.user_id,
                 "start" if self.is_syncing else "end",
                 str(self.last_sync_ms),
@@ -238,6 +240,30 @@ class UserSyncCommand(Command):
         )
 
 
+class ClearUserSyncsCommand(Command):
+    """Sent by the client to inform the server that it should drop all
+    information about syncing users sent by the client.
+
+    Mainly used when client is about to shut down.
+
+    Format::
+
+        CLEAR_USER_SYNC <instance_id>
+    """
+
+    NAME = "CLEAR_USER_SYNC"
+
+    def __init__(self, instance_id):
+        self.instance_id = instance_id
+
+    @classmethod
+    def from_line(cls, line):
+        return cls(line)
+
+    def to_line(self):
+        return self.instance_id
+
+
 class FederationAckCommand(Command):
     """Sent by the client when it has processed up to a given point in the
     federation stream. This allows the master to drop in-memory caches of the
@@ -398,6 +424,7 @@ _COMMANDS = (
     InvalidateCacheCommand,
     UserIpCommand,
     RemoteServerUpCommand,
+    ClearUserSyncsCommand,
 )  # type: Tuple[Type[Command], ...]
 
 # Map of command name to command type.
@@ -420,6 +447,7 @@ VALID_CLIENT_COMMANDS = (
     ReplicateCommand.NAME,
     PingCommand.NAME,
     UserSyncCommand.NAME,
+    ClearUserSyncsCommand.NAME,
     FederationAckCommand.NAME,
     RemovePusherCommand.NAME,
     InvalidateCacheCommand.NAME,
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index f81d2e2442..dae246825f 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -423,9 +423,12 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
 
     async def on_USER_SYNC(self, cmd):
         await self.streamer.on_user_sync(
-            self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
+            cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
         )
 
+    async def on_CLEAR_USER_SYNC(self, cmd):
+        await self.streamer.on_clear_user_syncs(cmd.instance_id)
+
     async def on_REPLICATE(self, cmd):
         # Subscribe to all streams we're publishing to.
         for stream_name in self.streamer.streams_by_name:
@@ -551,6 +554,8 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
     ):
         BaseReplicationStreamProtocol.__init__(self, clock)
 
+        self.instance_id = hs.get_instance_id()
+
         self.client_name = client_name
         self.server_name = server_name
         self.handler = handler
@@ -580,7 +585,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
         currently_syncing = self.handler.get_currently_syncing_users()
         now = self.clock.time_msec()
         for user_id in currently_syncing:
-            self.send_command(UserSyncCommand(user_id, True, now))
+            self.send_command(UserSyncCommand(self.instance_id, user_id, True, now))
 
         # We've now finished connecting to so inform the client handler
         self.handler.update_connection(self)
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 4374e99e32..8b6067e20d 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -251,14 +251,19 @@ class ReplicationStreamer(object):
             self.federation_sender.federation_ack(token)
 
     @measure_func("repl.on_user_sync")
-    async def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms):
+    async def on_user_sync(self, instance_id, user_id, is_syncing, last_sync_ms):
         """A client has started/stopped syncing on a worker.
         """
         user_sync_counter.inc()
         await self.presence_handler.update_external_syncs_row(
-            conn_id, user_id, is_syncing, last_sync_ms
+            instance_id, user_id, is_syncing, last_sync_ms
         )
 
+    async def on_clear_user_syncs(self, instance_id):
+        """A replication client wants us to drop all their UserSync data.
+        """
+        await self.presence_handler.update_external_syncs_clear(instance_id)
+
     @measure_func("repl.on_remove_pusher")
     async def on_remove_pusher(self, app_id, push_key, user_id):
         """A client has asked us to remove a pusher
@@ -321,14 +326,6 @@ class ReplicationStreamer(object):
         except ValueError:
             pass
 
-        # We need to tell the presence handler that the connection has been
-        # lost so that it can handle any ongoing syncs on that connection.
-        run_as_background_process(
-            "update_external_syncs_clear",
-            self.presence_handler.update_external_syncs_clear,
-            connection.conn_id,
-        )
-
 
 def _batch_updates(updates):
     """Takes a list of updates of form [(token, row)] and sets the token to
diff --git a/synapse/server.py b/synapse/server.py
index c7ca2bda0d..cd86475d6b 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -103,6 +103,7 @@ from synapse.storage import DataStores, Storage
 from synapse.streams.events import EventSources
 from synapse.util import Clock
 from synapse.util.distributor import Distributor
+from synapse.util.stringutils import random_string
 
 logger = logging.getLogger(__name__)
 
@@ -230,6 +231,8 @@ class HomeServer(object):
         self._listening_services = []
         self.start_time = None
 
+        self.instance_id = random_string(5)
+
         self.clock = Clock(reactor)
         self.distributor = Distributor()
         self.ratelimiter = Ratelimiter()
@@ -242,6 +245,14 @@ class HomeServer(object):
         for depname in kwargs:
             setattr(self, depname, kwargs[depname])
 
+    def get_instance_id(self):
+        """A unique ID for this synapse process instance.
+
+        This is used to distinguish running instances in worker-based
+        deployments.
+        """
+        return self.instance_id
+
     def setup(self):
         logger.info("Setting up.")
         self.start_time = int(self.get_clock().time())
diff --git a/synapse/server.pyi b/synapse/server.pyi
index 3844f0e12f..9d1dfa71e7 100644
--- a/synapse/server.pyi
+++ b/synapse/server.pyi
@@ -114,3 +114,5 @@ class HomeServer(object):
         pass
     def is_mine_id(self, domain_id: str) -> bool:
         pass
+    def get_instance_id(self) -> str:
+        pass