diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 0db5a3a24d..3cdf87e140 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -161,7 +161,7 @@ class ReplicationCommandHandler:
port = hs.config.worker_replication_port
hs.get_reactor().connectTCP(host, port, self._factory)
- async def on_REPLICATE(self, cmd: ReplicateCommand):
+ async def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
# We only want to announce positions by the writer of the streams.
# Currently this is just the master process.
if not self._is_master:
@@ -171,7 +171,7 @@ class ReplicationCommandHandler:
current_token = stream.current_token()
self.send_command(PositionCommand(stream_name, current_token))
- async def on_USER_SYNC(self, cmd: UserSyncCommand):
+ async def on_USER_SYNC(self, conn: AbstractConnection, cmd: UserSyncCommand):
user_sync_counter.inc()
if self._is_master:
@@ -179,17 +179,23 @@ class ReplicationCommandHandler:
cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
)
- async def on_CLEAR_USER_SYNC(self, cmd: ClearUserSyncsCommand):
+ async def on_CLEAR_USER_SYNC(
+ self, conn: AbstractConnection, cmd: ClearUserSyncsCommand
+ ):
if self._is_master:
await self._presence_handler.update_external_syncs_clear(cmd.instance_id)
- async def on_FEDERATION_ACK(self, cmd: FederationAckCommand):
+ async def on_FEDERATION_ACK(
+ self, conn: AbstractConnection, cmd: FederationAckCommand
+ ):
federation_ack_counter.inc()
if self._federation_sender:
self._federation_sender.federation_ack(cmd.token)
- async def on_REMOVE_PUSHER(self, cmd: RemovePusherCommand):
+ async def on_REMOVE_PUSHER(
+ self, conn: AbstractConnection, cmd: RemovePusherCommand
+ ):
remove_pusher_counter.inc()
if self._is_master:
@@ -199,7 +205,9 @@ class ReplicationCommandHandler:
self._notifier.on_new_replication_data()
- async def on_INVALIDATE_CACHE(self, cmd: InvalidateCacheCommand):
+ async def on_INVALIDATE_CACHE(
+ self, conn: AbstractConnection, cmd: InvalidateCacheCommand
+ ):
invalidate_cache_counter.inc()
if self._is_master:
@@ -209,7 +217,7 @@ class ReplicationCommandHandler:
cmd.cache_func, tuple(cmd.keys)
)
- async def on_USER_IP(self, cmd: UserIpCommand):
+ async def on_USER_IP(self, conn: AbstractConnection, cmd: UserIpCommand):
user_ip_cache_counter.inc()
if self._is_master:
@@ -225,7 +233,7 @@ class ReplicationCommandHandler:
if self._server_notices_sender:
await self._server_notices_sender.on_user_ip(cmd.user_id)
- async def on_RDATA(self, cmd: RdataCommand):
+ async def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
stream_name = cmd.stream_name
inbound_rdata_count.labels(stream_name).inc()
@@ -276,7 +284,7 @@ class ReplicationCommandHandler:
logger.debug("Received rdata %s -> %s", stream_name, token)
await self._replication_data_handler.on_rdata(stream_name, token, rows)
- async def on_POSITION(self, cmd: PositionCommand):
+ async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
stream = self._streams.get(cmd.stream_name)
if not stream:
logger.error("Got POSITION for unknown stream: %s", cmd.stream_name)
@@ -330,7 +338,9 @@ class ReplicationCommandHandler:
self._streams_connected.add(cmd.stream_name)
- async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
+ async def on_REMOTE_SERVER_UP(
+ self, conn: AbstractConnection, cmd: RemoteServerUpCommand
+ ):
""""Called when get a new REMOTE_SERVER_UP command."""
self._replication_data_handler.on_remote_server_up(cmd.data)
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index e3f64eba8f..4198eece71 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -260,7 +260,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
# Then call out to the handler.
cmd_func = getattr(self.command_handler, "on_%s" % (cmd.NAME,), None)
if cmd_func:
- await cmd_func(cmd)
+ await cmd_func(self, cmd)
handled = True
if not handled:
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index 49b3ed0c5e..617e860f95 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -112,7 +112,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
# Then call out to the handler.
cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None)
if cmd_func:
- await cmd_func(cmd)
+ await cmd_func(self, cmd)
handled = True
if not handled:
|