diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 7e7ad0f798..e86d9805f1 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -189,10 +189,12 @@ class ReplicationClientHandler(AbstractReplicationClientHandler):
"""
self.send_command(FederationAckCommand(token))
- def send_user_sync(self, user_id, is_syncing, last_sync_ms):
+ def send_user_sync(self, instance_id, user_id, is_syncing, last_sync_ms):
"""Poke the master that a user has started/stopped syncing.
"""
- self.send_command(UserSyncCommand(user_id, is_syncing, last_sync_ms))
+ self.send_command(
+ UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms)
+ )
def send_remove_pusher(self, app_id, push_key, user_id):
"""Poke the master to remove a pusher for a user
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index 5a6b734094..e4eec643f7 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -207,30 +207,32 @@ class UserSyncCommand(Command):
Format::
- USER_SYNC <user_id> <state> <last_sync_ms>
+ USER_SYNC <instance_id> <user_id> <state> <last_sync_ms>
Where <state> is either "start" or "stop"
"""
NAME = "USER_SYNC"
- def __init__(self, user_id, is_syncing, last_sync_ms):
+ def __init__(self, instance_id, user_id, is_syncing, last_sync_ms):
+ self.instance_id = instance_id
self.user_id = user_id
self.is_syncing = is_syncing
self.last_sync_ms = last_sync_ms
@classmethod
def from_line(cls, line):
- user_id, state, last_sync_ms = line.split(" ", 2)
+ instance_id, user_id, state, last_sync_ms = line.split(" ", 3)
if state not in ("start", "end"):
raise Exception("Invalid USER_SYNC state %r" % (state,))
- return cls(user_id, state == "start", int(last_sync_ms))
+ return cls(instance_id, user_id, state == "start", int(last_sync_ms))
def to_line(self):
return " ".join(
(
+ self.instance_id,
self.user_id,
"start" if self.is_syncing else "end",
str(self.last_sync_ms),
@@ -238,6 +240,30 @@ class UserSyncCommand(Command):
)
+class ClearUserSyncsCommand(Command):
+ """Sent by the client to inform the server that it should drop all
+ information about syncing users sent by the client.
+
+ Mainly used when client is about to shut down.
+
+ Format::
+
+ CLEAR_USER_SYNC <instance_id>
+ """
+
+ NAME = "CLEAR_USER_SYNC"
+
+ def __init__(self, instance_id):
+ self.instance_id = instance_id
+
+ @classmethod
+ def from_line(cls, line):
+ return cls(line)
+
+ def to_line(self):
+ return self.instance_id
+
+
class FederationAckCommand(Command):
"""Sent by the client when it has processed up to a given point in the
federation stream. This allows the master to drop in-memory caches of the
@@ -398,6 +424,7 @@ _COMMANDS = (
InvalidateCacheCommand,
UserIpCommand,
RemoteServerUpCommand,
+ ClearUserSyncsCommand,
) # type: Tuple[Type[Command], ...]
# Map of command name to command type.
@@ -420,6 +447,7 @@ VALID_CLIENT_COMMANDS = (
ReplicateCommand.NAME,
PingCommand.NAME,
UserSyncCommand.NAME,
+ ClearUserSyncsCommand.NAME,
FederationAckCommand.NAME,
RemovePusherCommand.NAME,
InvalidateCacheCommand.NAME,
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index f81d2e2442..dae246825f 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -423,9 +423,12 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
async def on_USER_SYNC(self, cmd):
await self.streamer.on_user_sync(
- self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
+ cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
)
+ async def on_CLEAR_USER_SYNC(self, cmd):
+ await self.streamer.on_clear_user_syncs(cmd.instance_id)
+
async def on_REPLICATE(self, cmd):
# Subscribe to all streams we're publishing to.
for stream_name in self.streamer.streams_by_name:
@@ -551,6 +554,8 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
):
BaseReplicationStreamProtocol.__init__(self, clock)
+ self.instance_id = hs.get_instance_id()
+
self.client_name = client_name
self.server_name = server_name
self.handler = handler
@@ -580,7 +585,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
currently_syncing = self.handler.get_currently_syncing_users()
now = self.clock.time_msec()
for user_id in currently_syncing:
- self.send_command(UserSyncCommand(user_id, True, now))
+ self.send_command(UserSyncCommand(self.instance_id, user_id, True, now))
# We've now finished connecting to so inform the client handler
self.handler.update_connection(self)
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 4374e99e32..8b6067e20d 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -251,14 +251,19 @@ class ReplicationStreamer(object):
self.federation_sender.federation_ack(token)
@measure_func("repl.on_user_sync")
- async def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms):
+ async def on_user_sync(self, instance_id, user_id, is_syncing, last_sync_ms):
"""A client has started/stopped syncing on a worker.
"""
user_sync_counter.inc()
await self.presence_handler.update_external_syncs_row(
- conn_id, user_id, is_syncing, last_sync_ms
+ instance_id, user_id, is_syncing, last_sync_ms
)
+ async def on_clear_user_syncs(self, instance_id):
+ """A replication client wants us to drop all their UserSync data.
+ """
+ await self.presence_handler.update_external_syncs_clear(instance_id)
+
@measure_func("repl.on_remove_pusher")
async def on_remove_pusher(self, app_id, push_key, user_id):
"""A client has asked us to remove a pusher
@@ -321,14 +326,6 @@ class ReplicationStreamer(object):
except ValueError:
pass
- # We need to tell the presence handler that the connection has been
- # lost so that it can handle any ongoing syncs on that connection.
- run_as_background_process(
- "update_external_syncs_clear",
- self.presence_handler.update_external_syncs_clear,
- connection.conn_id,
- )
-
def _batch_updates(updates):
"""Takes a list of updates of form [(token, row)] and sets the token to
|