diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index 32f52e54d8..10f5c98ff8 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -422,6 +422,36 @@ class RemoteServerUpCommand(_SimpleCommand):
NAME = "REMOTE_SERVER_UP"
+class LockReleasedCommand(Command):
+ """Sent to inform other instances that a given lock has been dropped.
+
+ Format::
+
+ LOCK_RELEASED ["<instance_name>", "<lock_name>", "<lock_key>"]
+ """
+
+ NAME = "LOCK_RELEASED"
+
+ def __init__(
+ self,
+ instance_name: str,
+ lock_name: str,
+ lock_key: str,
+ ):
+ self.instance_name = instance_name
+ self.lock_name = lock_name
+ self.lock_key = lock_key
+
+ @classmethod
+ def from_line(cls: Type["LockReleasedCommand"], line: str) -> "LockReleasedCommand":
+ instance_name, lock_name, lock_key = json_decoder.decode(line)
+
+ return cls(instance_name, lock_name, lock_key)
+
+ def to_line(self) -> str:
+ return json_encoder.encode([self.instance_name, self.lock_name, self.lock_key])
+
+
_COMMANDS: Tuple[Type[Command], ...] = (
ServerCommand,
RdataCommand,
@@ -435,6 +465,7 @@ _COMMANDS: Tuple[Type[Command], ...] = (
UserIpCommand,
RemoteServerUpCommand,
ClearUserSyncsCommand,
+ LockReleasedCommand,
)
# Map of command name to command type.
@@ -448,6 +479,7 @@ VALID_SERVER_COMMANDS = (
ErrorCommand.NAME,
PingCommand.NAME,
RemoteServerUpCommand.NAME,
+ LockReleasedCommand.NAME,
)
# The commands the client is allowed to send
@@ -461,6 +493,7 @@ VALID_CLIENT_COMMANDS = (
UserIpCommand.NAME,
ErrorCommand.NAME,
RemoteServerUpCommand.NAME,
+ LockReleasedCommand.NAME,
)
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 5d108fe11b..a2cabba7b1 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -39,6 +39,7 @@ from synapse.replication.tcp.commands import (
ClearUserSyncsCommand,
Command,
FederationAckCommand,
+ LockReleasedCommand,
PositionCommand,
RdataCommand,
RemoteServerUpCommand,
@@ -248,6 +249,9 @@ class ReplicationCommandHandler:
if self._is_master or self._should_insert_client_ips:
self.subscribe_to_channel("USER_IP")
+ if hs.config.redis.redis_enabled:
+ self._notifier.add_lock_released_callback(self.on_lock_released)
+
def subscribe_to_channel(self, channel_name: str) -> None:
"""
Indicates that we wish to subscribe to a Redis channel by name.
@@ -648,6 +652,17 @@ class ReplicationCommandHandler:
self._notifier.notify_remote_server_up(cmd.data)
+ def on_LOCK_RELEASED(
+ self, conn: IReplicationConnection, cmd: LockReleasedCommand
+ ) -> None:
+ """Called when we get a new LOCK_RELEASED command."""
+ if cmd.instance_name == self._instance_name:
+ return
+
+ self._notifier.notify_lock_released(
+ cmd.instance_name, cmd.lock_name, cmd.lock_key
+ )
+
def new_connection(self, connection: IReplicationConnection) -> None:
"""Called when we have a new connection."""
self._connections.append(connection)
@@ -754,6 +769,13 @@ class ReplicationCommandHandler:
"""
self.send_command(RdataCommand(stream_name, self._instance_name, token, data))
+ def on_lock_released(
+ self, instance_name: str, lock_name: str, lock_key: str
+ ) -> None:
+ """Called when we released a lock and should notify other instances."""
+ if instance_name == self._instance_name:
+ self.send_command(LockReleasedCommand(instance_name, lock_name, lock_key))
+
UpdateToken = TypeVar("UpdateToken")
UpdateRow = TypeVar("UpdateRow")
|