diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index e92da7b263..95e5502bf2 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -101,8 +101,9 @@ class ReplicationCommandHandler:
self._streams_to_replicate = [] # type: List[Stream]
for stream in self._streams.values():
- if stream.NAME == CachesStream.NAME:
- # All workers can write to the cache invalidation stream.
+ if hs.config.redis.redis_enabled and stream.NAME == CachesStream.NAME:
+ # All workers can write to the cache invalidation stream when
+ # using redis.
self._streams_to_replicate.append(stream)
continue
@@ -313,11 +314,14 @@ class ReplicationCommandHandler:
# We respond with current position of all streams this instance
# replicates.
for stream in self.get_streams_to_replicate():
+ # Note that we use the current token as the prev token here (rather
+ # than stream.last_token), as we can't be sure that there have been
+ # no rows written between last token and the current token (since we
+ # might be racing with the replication sending bg process).
+ current_token = stream.current_token(self._instance_name)
self.send_command(
PositionCommand(
- stream.NAME,
- self._instance_name,
- stream.current_token(self._instance_name),
+ stream.NAME, self._instance_name, current_token, current_token,
)
)
@@ -511,16 +515,16 @@ class ReplicationCommandHandler:
# If the position token matches our current token then we're up to
# date and there's nothing to do. Otherwise, fetch all updates
# between then and now.
- missing_updates = cmd.token != current_token
+ missing_updates = cmd.prev_token != current_token
while missing_updates:
logger.info(
"Fetching replication rows for '%s' between %i and %i",
stream_name,
current_token,
- cmd.token,
+ cmd.new_token,
)
(updates, current_token, missing_updates) = await stream.get_updates_since(
- cmd.instance_name, current_token, cmd.token
+ cmd.instance_name, current_token, cmd.new_token
)
# TODO: add some tests for this
@@ -536,11 +540,11 @@ class ReplicationCommandHandler:
[stream.parse_row(row) for row in rows],
)
- logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token)
+ logger.info("Caught up with stream '%s' to %i", stream_name, cmd.new_token)
# We've now caught up to position sent to us, notify handler.
await self._replication_data_handler.on_position(
- cmd.stream_name, cmd.instance_name, cmd.token
+ cmd.stream_name, cmd.instance_name, cmd.new_token
)
self._streams_by_connection.setdefault(conn, set()).add(stream_name)
|