diff options
Diffstat (limited to 'synapse/replication/tcp/handler.py')
-rw-r--r-- | synapse/replication/tcp/handler.py | 30 |
1 files changed, 16 insertions, 14 deletions
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 6f7054d5af..2d1d119c7c 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -278,19 +278,24 @@ class ReplicationCommandHandler: # Check if this is the last of a batch of updates rows = self._pending_batches.pop(stream_name, []) rows.append(row) - await self.on_rdata(stream_name, cmd.token, rows) + await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows) - async def on_rdata(self, stream_name: str, token: int, rows: list): + async def on_rdata( + self, stream_name: str, instance_name: str, token: int, rows: list + ): """Called to handle a batch of replication data with a given stream token. Args: stream_name: name of the replication stream for this batch of rows + instance_name: the instance that wrote the rows. token: stream token for this batch of rows rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row. """ logger.debug("Received rdata %s -> %s", stream_name, token) - await self._replication_data_handler.on_rdata(stream_name, token, rows) + await self._replication_data_handler.on_rdata( + stream_name, instance_name, token, rows + ) async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand): if cmd.instance_name == self._instance_name: @@ -314,15 +319,7 @@ class ReplicationCommandHandler: self._pending_batches.pop(cmd.stream_name, []) # Find where we previously streamed up to. - current_token = self._replication_data_handler.get_streams_to_replicate().get( - cmd.stream_name - ) - if current_token is None: - logger.warning( - "Got POSITION for stream we're not subscribed to: %s", - cmd.stream_name, - ) - return + current_token = stream.current_token() # If the position token matches our current token then we're up to # date and there's nothing to do. Otherwise, fetch all updates @@ -333,7 +330,9 @@ class ReplicationCommandHandler: updates, current_token, missing_updates, - ) = await stream.get_updates_since(current_token, cmd.token) + ) = await stream.get_updates_since( + cmd.instance_name, current_token, cmd.token + ) # TODO: add some tests for this @@ -342,7 +341,10 @@ class ReplicationCommandHandler: for token, rows in _batch_updates(updates): await self.on_rdata( - cmd.stream_name, token, [stream.parse_row(row) for row in rows], + cmd.stream_name, + cmd.instance_name, + token, + [stream.parse_row(row) for row in rows], ) # We've now caught up to position sent to us, notify handler. |