diff options
author | Erik Johnston <erik@matrix.org> | 2020-03-03 17:05:35 +0000 |
---|---|---|
committer | Erik Johnston <erik@matrix.org> | 2020-03-20 15:31:52 +0000 |
commit | 32c656865a9d1a7bcd998fb872f6d08e18f39be2 (patch) | |
tree | 04673998c1337ee498efce45ee193934c1b46751 /synapse/replication/tcp | |
parent | Remove unused token param from REPLICATE cmd (diff) | |
download | synapse-32c656865a9d1a7bcd998fb872f6d08e18f39be2.tar.xz |
Always subscribe to all streams.
This already happens since the worker merge.
Diffstat (limited to 'synapse/replication/tcp')
-rw-r--r-- | synapse/replication/tcp/protocol.py | 80 |
1 files changed, 12 insertions, 68 deletions
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index 817b84ad7f..b371d66ce7 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -53,17 +53,15 @@ import fcntl import logging import struct from collections import defaultdict -from typing import Any, DefaultDict, Dict, List, Set, Tuple +from typing import Any, DefaultDict, Dict, List, Set -from six import iteritems, iterkeys +from six import iteritems from prometheus_client import Counter -from twisted.internet import defer from twisted.protocols.basic import LineOnlyReceiver from twisted.python.failure import Failure -from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.tcp.commands import ( @@ -412,13 +410,6 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): self.server_name = server_name self.streamer = streamer - # The streams the client has subscribed to and is up to date with - self.replication_streams = set() # type: Set[str] - - # Map from stream name to list of updates to send once we've finished - # subscribing the client to the stream. - self.pending_rdata = {} # type: Dict[str, List[Tuple[int, Any]]] - def connectionMade(self): self.send_command(ServerCommand(self.server_name)) BaseReplicationStreamProtocol.connectionMade(self) @@ -434,20 +425,10 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): ) async def on_REPLICATE(self, cmd): - stream_name = cmd.stream_name - - if stream_name == "ALL": - # Subscribe to all streams we're publishing to. - deferreds = [ - run_in_background(self.subscribe_to_stream, stream) - for stream in iterkeys(self.streamer.streams_by_name) - ] - - await make_deferred_yieldable( - defer.gatherResults(deferreds, consumeErrors=True) - ) - else: - await self.subscribe_to_stream(stream_name) + # Subscribe to all streams we're publishing to. + for stream_name in self.streamer.streams_by_name: + current_token = self.streamer.get_stream_token(stream_name) + self.send_command(PositionCommand(stream_name, current_token)) async def on_FEDERATION_ACK(self, cmd): self.streamer.federation_ack(cmd.token) @@ -471,37 +452,12 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): cmd.last_seen, ) - async def subscribe_to_stream(self, stream_name): - """Subscribe the remote to a stream. - """ - self.replication_streams.discard(stream_name) - - try: - # Get current stream position. - current_token = self.streamer.get_stream_token(stream_name) - - # We send a POSITION command to ensure that they have an up to - # date token (especially useful if we didn't send any updates - # above) - self.send_command(PositionCommand(stream_name, current_token)) - - # They're now fully subscribed - self.replication_streams.add(stream_name) - except Exception as e: - logger.exception("[%s] Failed to handle REPLICATE command", self.id()) - self.send_error("failed to handle replicate: %r", e) - def stream_update(self, stream_name, token, data): """Called when a new update is available to stream to clients. We need to check if the client is interested in the stream or not """ - if stream_name in self.replication_streams: - # The client is subscribed to the stream - self.send_command(RdataCommand(stream_name, token, data)) - else: - # The client isn't subscribed - logger.debug("[%s] Dropping RDATA %r %r", self.id(), stream_name, token) + self.send_command(RdataCommand(stream_name, token, data)) def send_sync(self, data): self.send_command(SyncCommand(data)) @@ -604,7 +560,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): # Set of stream names that have been subscribe to, but haven't yet # caught up with. This is used to track when the client has been fully # connected to the remote. - self.streams_connecting = set() # type: Set[str] + self.streams_connecting = set(STREAMS_MAP) # type: Set[str] # Map of stream to batched updates. See RdataCommand for info on how # batching works. @@ -615,8 +571,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): BaseReplicationStreamProtocol.connectionMade(self) # Once we've connected subscribe to the necessary streams - for stream_name in self.handler.get_streams_to_replicate(): - self.replicate(stream_name) + self.replicate() # Tell the server if we have any users currently syncing (should only # happen on synchrotrons) @@ -628,10 +583,6 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): # We've now finished connecting to so inform the client handler self.handler.update_connection(self) - # This will happen if we don't actually subscribe to any streams - if not self.streams_connecting: - self.handler.finished_connecting() - async def on_SERVER(self, cmd): if cmd.data != self.server_name: logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data) @@ -706,19 +657,12 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand): self.handler.on_remote_server_up(cmd.data) - def replicate(self, stream_name): + def replicate(self): """Send the subscription request to the server """ - if stream_name not in STREAMS_MAP: - raise Exception("Invalid stream name %r" % (stream_name,)) - - logger.info( - "[%s] Subscribing to replication stream: %r", self.id(), stream_name, - ) - - self.streams_connecting.add(stream_name) + logger.info("[%s] Subscribing to replication streams", self.id()) - self.send_command(ReplicateCommand(stream_name)) + self.send_command(ReplicateCommand("ALL")) def on_connection_closed(self): BaseReplicationStreamProtocol.on_connection_closed(self) |