diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index d72f3d0cf9..2d1d119c7c 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -278,19 +278,24 @@ class ReplicationCommandHandler:
# Check if this is the last of a batch of updates
rows = self._pending_batches.pop(stream_name, [])
rows.append(row)
- await self.on_rdata(stream_name, cmd.token, rows)
+ await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows)
- async def on_rdata(self, stream_name: str, token: int, rows: list):
+ async def on_rdata(
+ self, stream_name: str, instance_name: str, token: int, rows: list
+ ):
"""Called to handle a batch of replication data with a given stream token.
Args:
stream_name: name of the replication stream for this batch of rows
+ instance_name: the instance that wrote the rows.
token: stream token for this batch of rows
rows: a list of Stream.ROW_TYPE objects as returned by
Stream.parse_row.
"""
logger.debug("Received rdata %s -> %s", stream_name, token)
- await self._replication_data_handler.on_rdata(stream_name, token, rows)
+ await self._replication_data_handler.on_rdata(
+ stream_name, instance_name, token, rows
+ )
async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
if cmd.instance_name == self._instance_name:
@@ -325,7 +330,9 @@ class ReplicationCommandHandler:
updates,
current_token,
missing_updates,
- ) = await stream.get_updates_since(current_token, cmd.token)
+ ) = await stream.get_updates_since(
+ cmd.instance_name, current_token, cmd.token
+ )
# TODO: add some tests for this
@@ -334,7 +341,10 @@ class ReplicationCommandHandler:
for token, rows in _batch_updates(updates):
await self.on_rdata(
- cmd.stream_name, token, [stream.parse_row(row) for row in rows],
+ cmd.stream_name,
+ cmd.instance_name,
+ token,
+ [stream.parse_row(row) for row in rows],
)
# We've now caught up to position sent to us, notify handler.
|