diff --git a/changelog.d/7128.misc b/changelog.d/7128.misc
new file mode 100644
index 0000000000..5703f6d2ec
--- /dev/null
+++ b/changelog.d/7128.misc
@@ -0,0 +1 @@
+Add explicit `instance_id` for USER_SYNC commands and remove implicit `conn_id` usage.
diff --git a/docs/tcp_replication.md b/docs/tcp_replication.md
index d4f7d9ec18..3be8e50c4c 100644
--- a/docs/tcp_replication.md
+++ b/docs/tcp_replication.md
@@ -198,6 +198,12 @@ Asks the server for the current position of all streams.
A user has started or stopped syncing
+#### CLEAR_USER_SYNC (C)
+
+ The server should clear all associated user sync data from the worker.
+
+ This is used when a worker is shutting down.
+
#### FEDERATION_ACK (C)
Acknowledge receipt of some federation data
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index fba7ad9551..1ee266f7c5 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -65,6 +65,7 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto
from synapse.replication.slave.storage.room import RoomStore
from synapse.replication.slave.storage.transactions import SlavedTransactionStore
from synapse.replication.tcp.client import ReplicationClientHandler
+from synapse.replication.tcp.commands import ClearUserSyncsCommand
from synapse.replication.tcp.streams import (
AccountDataStream,
DeviceListsStream,
@@ -124,7 +125,6 @@ from synapse.types import ReadReceipt
from synapse.util.async_helpers import Linearizer
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.manhole import manhole
-from synapse.util.stringutils import random_string
from synapse.util.versionstring import get_version_string
logger = logging.getLogger("synapse.app.generic_worker")
@@ -233,6 +233,7 @@ class GenericWorkerPresence(object):
self.user_to_num_current_syncs = {}
self.clock = hs.get_clock()
self.notifier = hs.get_notifier()
+ self.instance_id = hs.get_instance_id()
active_presence = self.store.take_presence_startup_info()
self.user_to_current_state = {state.user_id: state for state in active_presence}
@@ -245,13 +246,24 @@ class GenericWorkerPresence(object):
self.send_stop_syncing, UPDATE_SYNCING_USERS_MS
)
- self.process_id = random_string(16)
- logger.info("Presence process_id is %r", self.process_id)
+ hs.get_reactor().addSystemEventTrigger(
+ "before",
+ "shutdown",
+ run_as_background_process,
+ "generic_presence.on_shutdown",
+ self._on_shutdown,
+ )
+
+ def _on_shutdown(self):
+ if self.hs.config.use_presence:
+ self.hs.get_tcp_replication().send_command(
+ ClearUserSyncsCommand(self.instance_id)
+ )
def send_user_sync(self, user_id, is_syncing, last_sync_ms):
if self.hs.config.use_presence:
self.hs.get_tcp_replication().send_user_sync(
- user_id, is_syncing, last_sync_ms
+ self.instance_id, user_id, is_syncing, last_sync_ms
)
def mark_as_coming_online(self, user_id):
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 7e7ad0f798..e86d9805f1 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -189,10 +189,12 @@ class ReplicationClientHandler(AbstractReplicationClientHandler):
"""
self.send_command(FederationAckCommand(token))
- def send_user_sync(self, user_id, is_syncing, last_sync_ms):
+ def send_user_sync(self, instance_id, user_id, is_syncing, last_sync_ms):
"""Poke the master that a user has started/stopped syncing.
"""
- self.send_command(UserSyncCommand(user_id, is_syncing, last_sync_ms))
+ self.send_command(
+ UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms)
+ )
def send_remove_pusher(self, app_id, push_key, user_id):
"""Poke the master to remove a pusher for a user
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index 5a6b734094..e4eec643f7 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -207,30 +207,32 @@ class UserSyncCommand(Command):
Format::
- USER_SYNC <user_id> <state> <last_sync_ms>
+ USER_SYNC <instance_id> <user_id> <state> <last_sync_ms>
Where <state> is either "start" or "stop"
"""
NAME = "USER_SYNC"
- def __init__(self, user_id, is_syncing, last_sync_ms):
+ def __init__(self, instance_id, user_id, is_syncing, last_sync_ms):
+ self.instance_id = instance_id
self.user_id = user_id
self.is_syncing = is_syncing
self.last_sync_ms = last_sync_ms
@classmethod
def from_line(cls, line):
- user_id, state, last_sync_ms = line.split(" ", 2)
+ instance_id, user_id, state, last_sync_ms = line.split(" ", 3)
if state not in ("start", "end"):
raise Exception("Invalid USER_SYNC state %r" % (state,))
- return cls(user_id, state == "start", int(last_sync_ms))
+ return cls(instance_id, user_id, state == "start", int(last_sync_ms))
def to_line(self):
return " ".join(
(
+ self.instance_id,
self.user_id,
"start" if self.is_syncing else "end",
str(self.last_sync_ms),
@@ -238,6 +240,30 @@ class UserSyncCommand(Command):
)
+class ClearUserSyncsCommand(Command):
+ """Sent by the client to inform the server that it should drop all
+ information about syncing users sent by the client.
+
+ Mainly used when client is about to shut down.
+
+ Format::
+
+ CLEAR_USER_SYNC <instance_id>
+ """
+
+ NAME = "CLEAR_USER_SYNC"
+
+ def __init__(self, instance_id):
+ self.instance_id = instance_id
+
+ @classmethod
+ def from_line(cls, line):
+ return cls(line)
+
+ def to_line(self):
+ return self.instance_id
+
+
class FederationAckCommand(Command):
"""Sent by the client when it has processed up to a given point in the
federation stream. This allows the master to drop in-memory caches of the
@@ -398,6 +424,7 @@ _COMMANDS = (
InvalidateCacheCommand,
UserIpCommand,
RemoteServerUpCommand,
+ ClearUserSyncsCommand,
) # type: Tuple[Type[Command], ...]
# Map of command name to command type.
@@ -420,6 +447,7 @@ VALID_CLIENT_COMMANDS = (
ReplicateCommand.NAME,
PingCommand.NAME,
UserSyncCommand.NAME,
+ ClearUserSyncsCommand.NAME,
FederationAckCommand.NAME,
RemovePusherCommand.NAME,
InvalidateCacheCommand.NAME,
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index f81d2e2442..dae246825f 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -423,9 +423,12 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
async def on_USER_SYNC(self, cmd):
await self.streamer.on_user_sync(
- self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
+ cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
)
+ async def on_CLEAR_USER_SYNC(self, cmd):
+ await self.streamer.on_clear_user_syncs(cmd.instance_id)
+
async def on_REPLICATE(self, cmd):
# Subscribe to all streams we're publishing to.
for stream_name in self.streamer.streams_by_name:
@@ -551,6 +554,8 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
):
BaseReplicationStreamProtocol.__init__(self, clock)
+ self.instance_id = hs.get_instance_id()
+
self.client_name = client_name
self.server_name = server_name
self.handler = handler
@@ -580,7 +585,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
currently_syncing = self.handler.get_currently_syncing_users()
now = self.clock.time_msec()
for user_id in currently_syncing:
- self.send_command(UserSyncCommand(user_id, True, now))
+ self.send_command(UserSyncCommand(self.instance_id, user_id, True, now))
# We've now finished connecting to so inform the client handler
self.handler.update_connection(self)
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 4374e99e32..8b6067e20d 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -251,14 +251,19 @@ class ReplicationStreamer(object):
self.federation_sender.federation_ack(token)
@measure_func("repl.on_user_sync")
- async def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms):
+ async def on_user_sync(self, instance_id, user_id, is_syncing, last_sync_ms):
"""A client has started/stopped syncing on a worker.
"""
user_sync_counter.inc()
await self.presence_handler.update_external_syncs_row(
- conn_id, user_id, is_syncing, last_sync_ms
+ instance_id, user_id, is_syncing, last_sync_ms
)
+ async def on_clear_user_syncs(self, instance_id):
+ """A replication client wants us to drop all their UserSync data.
+ """
+ await self.presence_handler.update_external_syncs_clear(instance_id)
+
@measure_func("repl.on_remove_pusher")
async def on_remove_pusher(self, app_id, push_key, user_id):
"""A client has asked us to remove a pusher
@@ -321,14 +326,6 @@ class ReplicationStreamer(object):
except ValueError:
pass
- # We need to tell the presence handler that the connection has been
- # lost so that it can handle any ongoing syncs on that connection.
- run_as_background_process(
- "update_external_syncs_clear",
- self.presence_handler.update_external_syncs_clear,
- connection.conn_id,
- )
-
def _batch_updates(updates):
"""Takes a list of updates of form [(token, row)] and sets the token to
diff --git a/synapse/server.py b/synapse/server.py
index c7ca2bda0d..cd86475d6b 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -103,6 +103,7 @@ from synapse.storage import DataStores, Storage
from synapse.streams.events import EventSources
from synapse.util import Clock
from synapse.util.distributor import Distributor
+from synapse.util.stringutils import random_string
logger = logging.getLogger(__name__)
@@ -230,6 +231,8 @@ class HomeServer(object):
self._listening_services = []
self.start_time = None
+ self.instance_id = random_string(5)
+
self.clock = Clock(reactor)
self.distributor = Distributor()
self.ratelimiter = Ratelimiter()
@@ -242,6 +245,14 @@ class HomeServer(object):
for depname in kwargs:
setattr(self, depname, kwargs[depname])
+ def get_instance_id(self):
+ """A unique ID for this synapse process instance.
+
+ This is used to distinguish running instances in worker-based
+ deployments.
+ """
+ return self.instance_id
+
def setup(self):
logger.info("Setting up.")
self.start_time = int(self.get_clock().time())
diff --git a/synapse/server.pyi b/synapse/server.pyi
index 3844f0e12f..9d1dfa71e7 100644
--- a/synapse/server.pyi
+++ b/synapse/server.pyi
@@ -114,3 +114,5 @@ class HomeServer(object):
pass
def is_mine_id(self, domain_id: str) -> bool:
pass
+ def get_instance_id(self) -> str:
+ pass
|