summary refs log tree commit diff
path: root/synapse/replication/tcp/protocol.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication/tcp/protocol.py')
-rw-r--r--synapse/replication/tcp/protocol.py48
1 files changed, 45 insertions, 3 deletions
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py

index 429471c345..55630ba9a7 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py
@@ -451,7 +451,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): @defer.inlineCallbacks def subscribe_to_stream(self, stream_name, token): - """Subscribe the remote to a streams. + """Subscribe the remote to a stream. This invloves checking if they've missed anything and sending those updates down if they have. During that time new updates for the stream @@ -478,11 +478,36 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): # Now we can send any updates that came in while we were subscribing pending_rdata = self.pending_rdata.pop(stream_name, []) + updates = [] for token, update in pending_rdata: - # Only send updates newer than the current token - if token > current_token: + # If the token is null, it is part of a batch update. Batches + # are multiple updates that share a single token. To denote + # this, the token is set to None for all tokens in the batch + # except for the last. If we find a None token, we keep looking + # through tokens until we find one that is not None and then + # process all previous updates in the batch as if they had the + # final token. + if token is None: + # Store this update as part of a batch + updates.append(update) + continue + + if token <= current_token: + # This update or batch of updates is older than + # current_token, dismiss it + updates = [] + continue + + updates.append(update) + + # Send all updates that are part of this batch with the + # found token + for update in updates: self.send_command(RdataCommand(stream_name, token, update)) + # Clear stored updates + updates = [] + # They're now fully subscribed self.replication_streams.add(stream_name) except Exception as e: @@ -526,6 +551,11 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): self.server_name = server_name self.handler = handler + # Set of stream names that have been subscribe to, but haven't yet + # caught up with. This is used to track when the client has been fully + # connected to the remote. + self.streams_connecting = set() + # Map of stream to batched updates. See RdataCommand for info on how # batching works. self.pending_batches = {} @@ -548,6 +578,10 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): # We've now finished connecting to so inform the client handler self.handler.update_connection(self) + # This will happen if we don't actually subscribe to any streams + if not self.streams_connecting: + self.handler.finished_connecting() + def on_SERVER(self, cmd): if cmd.data != self.server_name: logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data) @@ -577,6 +611,12 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): return self.handler.on_rdata(stream_name, cmd.token, rows) def on_POSITION(self, cmd): + # When we get a `POSITION` command it means we've finished getting + # missing updates for the given stream, and are now up to date. + self.streams_connecting.discard(cmd.stream_name) + if not self.streams_connecting: + self.handler.finished_connecting() + return self.handler.on_position(cmd.stream_name, cmd.token) def on_SYNC(self, cmd): @@ -593,6 +633,8 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): self.id(), stream_name, token ) + self.streams_connecting.add(stream_name) + self.send_command(ReplicateCommand(stream_name, token)) def on_connection_closed(self):