From 1f83255de17eb2de35fc42b91ebaaaf895771aa6 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 3 Mar 2020 14:19:23 +0000 Subject: Move stream catchup to workers. --- synapse/replication/tcp/client.py | 3 +- synapse/replication/tcp/protocol.py | 105 ++++++++++++---------------- synapse/replication/tcp/resource.py | 5 +- synapse/replication/tcp/streams/__init__.py | 6 +- 4 files changed, 54 insertions(+), 65 deletions(-) (limited to 'synapse') diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 02ab5b66ea..7e7ad0f798 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -55,6 +55,7 @@ class ReplicationClientFactory(ReconnectingClientFactory): self.client_name = client_name self.handler = handler self.server_name = hs.config.server_name + self.hs = hs self._clock = hs.get_clock() # As self.clock is defined in super class hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.stopTrying) @@ -65,7 +66,7 @@ class ReplicationClientFactory(ReconnectingClientFactory): def buildProtocol(self, addr): logger.info("Connected to replication: %r", addr) return ClientReplicationStreamProtocol( - self.client_name, self.server_name, self._clock, self.handler + self.hs, self.client_name, self.server_name, self._clock, self.handler, ) def clientConnectionLost(self, connector, reason): diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index d7ef2398fa..649312f022 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -82,7 +82,8 @@ from synapse.replication.tcp.commands import ( SyncCommand, UserSyncCommand, ) -from synapse.replication.tcp.streams import STREAMS_MAP +from synapse.replication.tcp.streams import STREAMS_MAP, Stream +from synapse.server import HomeServer from synapse.types import Collection from synapse.util import Clock from synapse.util.stringutils import random_string @@ -414,9 +415,6 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): # The streams the client has subscribed to and is up to date with self.replication_streams = set() # type: Set[str] - # The streams the client is currently subscribing to. - self.connecting_streams = set() # type: Set[str] - # Map from stream name to list of updates to send once we've finished # subscribing the client to the stream. self.pending_rdata = {} # type: Dict[str, List[Tuple[int, Any]]] @@ -482,67 +480,21 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): are queued and sent once we've sent down any missed updates. """ self.replication_streams.discard(stream_name) - self.connecting_streams.add(stream_name) try: - limited = True - while limited: - # Get missing updates - ( - updates, - current_token, - limited, - ) = await self.streamer.get_stream_updates(stream_name, token) - - # Send all the missing updates - for update in updates: - token, row = update[0], update[1] - self.send_command(RdataCommand(stream_name, token, row)) + # Get current stream position. + current_token = self.streamer.get_stream_token(stream_name) # We send a POSITION command to ensure that they have an up to # date token (especially useful if we didn't send any updates # above) self.send_command(PositionCommand(stream_name, current_token)) - # Now we can send any updates that came in while we were subscribing - pending_rdata = self.pending_rdata.pop(stream_name, []) - updates = [] - for token, update in pending_rdata: - # If the token is null, it is part of a batch update. Batches - # are multiple updates that share a single token. To denote - # this, the token is set to None for all tokens in the batch - # except for the last. If we find a None token, we keep looking - # through tokens until we find one that is not None and then - # process all previous updates in the batch as if they had the - # final token. - if token is None: - # Store this update as part of a batch - updates.append(update) - continue - - if token <= current_token: - # This update or batch of updates is older than - # current_token, dismiss it - updates = [] - continue - - updates.append(update) - - # Send all updates that are part of this batch with the - # found token - for update in updates: - self.send_command(RdataCommand(stream_name, token, update)) - - # Clear stored updates - updates = [] - # They're now fully subscribed self.replication_streams.add(stream_name) except Exception as e: logger.exception("[%s] Failed to handle REPLICATE command", self.id()) self.send_error("failed to handle replicate: %r", e) - finally: - self.connecting_streams.discard(stream_name) def stream_update(self, stream_name, token, data): """Called when a new update is available to stream to clients. @@ -552,10 +504,6 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): if stream_name in self.replication_streams: # The client is subscribed to the stream self.send_command(RdataCommand(stream_name, token, data)) - elif stream_name in self.connecting_streams: - # The client is being subscribed to the stream - logger.debug("[%s] Queuing RDATA %r %r", self.id(), stream_name, token) - self.pending_rdata.setdefault(stream_name, []).append((token, data)) else: # The client isn't subscribed logger.debug("[%s] Dropping RDATA %r %r", self.id(), stream_name, token) @@ -642,6 +590,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): def __init__( self, + hs: HomeServer, client_name: str, server_name: str, clock: Clock, @@ -653,6 +602,10 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): self.server_name = server_name self.handler = handler + self.streams = { + stream.NAME: stream(hs) for stream in STREAMS_MAP.values() + } # type: Dict[str, Stream] + # Set of stream names that have been subscribe to, but haven't yet # caught up with. This is used to track when the client has been fully # connected to the remote. @@ -660,7 +613,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): # Map of stream to batched updates. See RdataCommand for info on how # batching works. - self.pending_batches = {} # type: Dict[str, Any] + self.pending_batches = {} # type: Dict[str, List[Any]] def connectionMade(self): self.send_command(NameCommand(self.client_name)) @@ -701,7 +654,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): ) raise - if cmd.token is None: + if cmd.token is None or stream_name in self.streams_connecting: # I.e. this is part of a batch of updates for this stream. Batch # until we get an update for the stream with a non None token self.pending_batches.setdefault(stream_name, []).append(row) @@ -711,14 +664,46 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): rows.append(row) await self.handler.on_rdata(stream_name, cmd.token, rows) - async def on_POSITION(self, cmd): + async def on_POSITION(self, cmd: PositionCommand): + stream = self.streams.get(cmd.stream_name) + if not stream: + logger.error("Got POSITION for unknown stream: %s", cmd.stream_name) + return + + # Find where we previously streamed up to. + current_token = self.handler.get_streams_to_replicate().get(cmd.stream_name) + if current_token is None: + logger.warning( + "Got POSITION for stream we're not subscribed to: %s", cmd.stream_name + ) + return + + # Fetch all updates between then and now. + limited = True + while limited: + updates, current_token, limited = await stream.get_updates_since( + current_token, cmd.token + ) + if updates: + await self.handler.on_rdata( + cmd.stream_name, + current_token, + [stream.parse_row(update[1]) for update in updates], + ) + + # We've now caught up to position sent to us, notify handler. + await self.handler.on_position(cmd.stream_name, cmd.token) + # When we get a `POSITION` command it means we've finished getting # missing updates for the given stream, and are now up to date. self.streams_connecting.discard(cmd.stream_name) if not self.streams_connecting: self.handler.finished_connecting() - await self.handler.on_position(cmd.stream_name, cmd.token) + # Handle any RDATA that came in while we were catching up. + rows = self.pending_batches.pop(cmd.stream_name, []) + if rows: + await self.handler.on_rdata(cmd.stream_name, rows[-1].token, rows) async def on_SYNC(self, cmd): self.handler.on_sync(cmd.data) diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index 5be31024b7..757129b6d5 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -227,8 +227,7 @@ class ReplicationStreamer(object): self.pending_updates = False self.is_looping = False - @measure_func("repl.get_stream_updates") - async def get_stream_updates(self, stream_name, token): + def get_stream_token(self, stream_name): """For a given stream get all updates since token. This is called when a client first subscribes to a stream. """ @@ -236,7 +235,7 @@ class ReplicationStreamer(object): if not stream: raise Exception("unknown stream %s", stream_name) - return await stream.get_updates_since(token, stream.current_token()) + return stream.current_token() @measure_func("repl.federation_ack") def federation_ack(self, token): diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py index c3b9a90ca5..6f5da99f85 100644 --- a/synapse/replication/tcp/streams/__init__.py +++ b/synapse/replication/tcp/streams/__init__.py @@ -27,7 +27,8 @@ Each stream is defined by the following information: from typing import Dict, Type -from . import _base, events, federation +from synapse.replication.tcp.streams import _base, events, federation +from synapse.replication.tcp.streams._base import Stream STREAMS_MAP = { stream.NAME: stream @@ -50,3 +51,6 @@ STREAMS_MAP = { _base.UserSignatureStream, ) } # type: Dict[str, Type[_base.Stream]] + + +__all__ = ["Stream", "STREAMS_MAP"] -- cgit 1.4.1