From 4cff617df1ba6f241fee6957cc44859f57edcc0e Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 25 Mar 2020 14:54:01 +0000 Subject: Move catchup of replication streams to worker. (#7024) This changes the replication protocol so that the server does not send down `RDATA` for rows that happened before the client connected. Instead, the server will send a `POSITION` and clients then query the database (or master out of band) to get up to date. --- synapse/replication/tcp/protocol.py | 206 +++++++++++++----------------------- 1 file changed, 71 insertions(+), 135 deletions(-) (limited to 'synapse/replication/tcp/protocol.py') diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index bc1482a9bb..f81d2e2442 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -35,9 +35,7 @@ indicate which side is sending, these are *not* included on the wire:: > PING 1490197665618 < NAME synapse.app.appservice < PING 1490197665618 - < REPLICATE events 1 - < REPLICATE backfill 1 - < REPLICATE caches 1 + < REPLICATE > POSITION events 1 > POSITION backfill 1 > POSITION caches 1 @@ -53,17 +51,15 @@ import fcntl import logging import struct from collections import defaultdict -from typing import Any, DefaultDict, Dict, List, Set, Tuple +from typing import Any, DefaultDict, Dict, List, Set -from six import iteritems, iterkeys +from six import iteritems from prometheus_client import Counter -from twisted.internet import defer from twisted.protocols.basic import LineOnlyReceiver from twisted.python.failure import Failure -from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.tcp.commands import ( @@ -82,11 +78,16 @@ from synapse.replication.tcp.commands import ( SyncCommand, UserSyncCommand, ) -from synapse.replication.tcp.streams import STREAMS_MAP +from synapse.replication.tcp.streams import STREAMS_MAP, Stream from synapse.types import Collection from synapse.util import Clock from synapse.util.stringutils import random_string +MYPY = False +if MYPY: + from synapse.server import HomeServer + + connection_close_counter = Counter( "synapse_replication_tcp_protocol_close_reason", "", ["reason_type"] ) @@ -411,16 +412,6 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): self.server_name = server_name self.streamer = streamer - # The streams the client has subscribed to and is up to date with - self.replication_streams = set() # type: Set[str] - - # The streams the client is currently subscribing to. - self.connecting_streams = set() # type: Set[str] - - # Map from stream name to list of updates to send once we've finished - # subscribing the client to the stream. - self.pending_rdata = {} # type: Dict[str, List[Tuple[int, Any]]] - def connectionMade(self): self.send_command(ServerCommand(self.server_name)) BaseReplicationStreamProtocol.connectionMade(self) @@ -436,21 +427,10 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): ) async def on_REPLICATE(self, cmd): - stream_name = cmd.stream_name - token = cmd.token - - if stream_name == "ALL": - # Subscribe to all streams we're publishing to. - deferreds = [ - run_in_background(self.subscribe_to_stream, stream, token) - for stream in iterkeys(self.streamer.streams_by_name) - ] - - await make_deferred_yieldable( - defer.gatherResults(deferreds, consumeErrors=True) - ) - else: - await self.subscribe_to_stream(stream_name, token) + # Subscribe to all streams we're publishing to. + for stream_name in self.streamer.streams_by_name: + current_token = self.streamer.get_stream_token(stream_name) + self.send_command(PositionCommand(stream_name, current_token)) async def on_FEDERATION_ACK(self, cmd): self.streamer.federation_ack(cmd.token) @@ -474,87 +454,12 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): cmd.last_seen, ) - async def subscribe_to_stream(self, stream_name, token): - """Subscribe the remote to a stream. - - This invloves checking if they've missed anything and sending those - updates down if they have. During that time new updates for the stream - are queued and sent once we've sent down any missed updates. - """ - self.replication_streams.discard(stream_name) - self.connecting_streams.add(stream_name) - - try: - # Get missing updates - updates, current_token = await self.streamer.get_stream_updates( - stream_name, token - ) - - # Send all the missing updates - for update in updates: - token, row = update[0], update[1] - self.send_command(RdataCommand(stream_name, token, row)) - - # We send a POSITION command to ensure that they have an up to - # date token (especially useful if we didn't send any updates - # above) - self.send_command(PositionCommand(stream_name, current_token)) - - # Now we can send any updates that came in while we were subscribing - pending_rdata = self.pending_rdata.pop(stream_name, []) - updates = [] - for token, update in pending_rdata: - # If the token is null, it is part of a batch update. Batches - # are multiple updates that share a single token. To denote - # this, the token is set to None for all tokens in the batch - # except for the last. If we find a None token, we keep looking - # through tokens until we find one that is not None and then - # process all previous updates in the batch as if they had the - # final token. - if token is None: - # Store this update as part of a batch - updates.append(update) - continue - - if token <= current_token: - # This update or batch of updates is older than - # current_token, dismiss it - updates = [] - continue - - updates.append(update) - - # Send all updates that are part of this batch with the - # found token - for update in updates: - self.send_command(RdataCommand(stream_name, token, update)) - - # Clear stored updates - updates = [] - - # They're now fully subscribed - self.replication_streams.add(stream_name) - except Exception as e: - logger.exception("[%s] Failed to handle REPLICATE command", self.id()) - self.send_error("failed to handle replicate: %r", e) - finally: - self.connecting_streams.discard(stream_name) - def stream_update(self, stream_name, token, data): """Called when a new update is available to stream to clients. We need to check if the client is interested in the stream or not """ - if stream_name in self.replication_streams: - # The client is subscribed to the stream - self.send_command(RdataCommand(stream_name, token, data)) - elif stream_name in self.connecting_streams: - # The client is being subscribed to the stream - logger.debug("[%s] Queuing RDATA %r %r", self.id(), stream_name, token) - self.pending_rdata.setdefault(stream_name, []).append((token, data)) - else: - # The client isn't subscribed - logger.debug("[%s] Dropping RDATA %r %r", self.id(), stream_name, token) + self.send_command(RdataCommand(stream_name, token, data)) def send_sync(self, data): self.send_command(SyncCommand(data)) @@ -638,6 +543,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): def __init__( self, + hs: "HomeServer", client_name: str, server_name: str, clock: Clock, @@ -649,22 +555,25 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): self.server_name = server_name self.handler = handler + self.streams = { + stream.NAME: stream(hs) for stream in STREAMS_MAP.values() + } # type: Dict[str, Stream] + # Set of stream names that have been subscribe to, but haven't yet # caught up with. This is used to track when the client has been fully # connected to the remote. - self.streams_connecting = set() # type: Set[str] + self.streams_connecting = set(STREAMS_MAP) # type: Set[str] # Map of stream to batched updates. See RdataCommand for info on how # batching works. - self.pending_batches = {} # type: Dict[str, Any] + self.pending_batches = {} # type: Dict[str, List[Any]] def connectionMade(self): self.send_command(NameCommand(self.client_name)) BaseReplicationStreamProtocol.connectionMade(self) # Once we've connected subscribe to the necessary streams - for stream_name, token in iteritems(self.handler.get_streams_to_replicate()): - self.replicate(stream_name, token) + self.replicate() # Tell the server if we have any users currently syncing (should only # happen on synchrotrons) @@ -676,10 +585,6 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): # We've now finished connecting to so inform the client handler self.handler.update_connection(self) - # This will happen if we don't actually subscribe to any streams - if not self.streams_connecting: - self.handler.finished_connecting() - async def on_SERVER(self, cmd): if cmd.data != self.server_name: logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data) @@ -697,7 +602,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): ) raise - if cmd.token is None: + if cmd.token is None or stream_name in self.streams_connecting: # I.e. this is part of a batch of updates for this stream. Batch # until we get an update for the stream with a non None token self.pending_batches.setdefault(stream_name, []).append(row) @@ -707,14 +612,55 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): rows.append(row) await self.handler.on_rdata(stream_name, cmd.token, rows) - async def on_POSITION(self, cmd): - # When we get a `POSITION` command it means we've finished getting - # missing updates for the given stream, and are now up to date. + async def on_POSITION(self, cmd: PositionCommand): + stream = self.streams.get(cmd.stream_name) + if not stream: + logger.error("Got POSITION for unknown stream: %s", cmd.stream_name) + return + + # Find where we previously streamed up to. + current_token = self.handler.get_streams_to_replicate().get(cmd.stream_name) + if current_token is None: + logger.warning( + "Got POSITION for stream we're not subscribed to: %s", cmd.stream_name + ) + return + + # Fetch all updates between then and now. + limited = True + while limited: + updates, current_token, limited = await stream.get_updates_since( + current_token, cmd.token + ) + + # Check if the connection was closed underneath us, if so we bail + # rather than risk having concurrent catch ups going on. + if self.state == ConnectionStates.CLOSED: + return + + if updates: + await self.handler.on_rdata( + cmd.stream_name, + current_token, + [stream.parse_row(update[1]) for update in updates], + ) + + # We've now caught up to position sent to us, notify handler. + await self.handler.on_position(cmd.stream_name, cmd.token) + self.streams_connecting.discard(cmd.stream_name) if not self.streams_connecting: self.handler.finished_connecting() - await self.handler.on_position(cmd.stream_name, cmd.token) + # Check if the connection was closed underneath us, if so we bail + # rather than risk having concurrent catch ups going on. + if self.state == ConnectionStates.CLOSED: + return + + # Handle any RDATA that came in while we were catching up. + rows = self.pending_batches.pop(cmd.stream_name, []) + if rows: + await self.handler.on_rdata(cmd.stream_name, rows[-1].token, rows) async def on_SYNC(self, cmd): self.handler.on_sync(cmd.data) @@ -722,22 +668,12 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand): self.handler.on_remote_server_up(cmd.data) - def replicate(self, stream_name, token): + def replicate(self): """Send the subscription request to the server """ - if stream_name not in STREAMS_MAP: - raise Exception("Invalid stream name %r" % (stream_name,)) - - logger.info( - "[%s] Subscribing to replication stream: %r from %r", - self.id(), - stream_name, - token, - ) - - self.streams_connecting.add(stream_name) + logger.info("[%s] Subscribing to replication streams", self.id()) - self.send_command(ReplicateCommand(stream_name, token)) + self.send_command(ReplicateCommand()) def on_connection_closed(self): BaseReplicationStreamProtocol.on_connection_closed(self) -- cgit 1.5.1 From 4f21c33be301b8ea6369039c3ad8baa51878e4d5 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 30 Mar 2020 16:37:24 +0100 Subject: Remove usage of "conn_id" for presence. (#7128) * Remove `conn_id` usage for UserSyncCommand. Each tcp replication connection is assigned a "conn_id", which is used to give an ID to a remotely connected worker. In a redis world, there will no longer be a one to one mapping between connection and instance, so instead we need to replace such usages with an ID generated by the remote instances and included in the replicaiton commands. This really only effects UserSyncCommand. * Add CLEAR_USER_SYNCS command that is sent on shutdown. This should help with the case where a synchrotron gets restarted gracefully, rather than rely on 5 minute timeout. --- changelog.d/7128.misc | 1 + docs/tcp_replication.md | 6 ++++++ synapse/app/generic_worker.py | 20 ++++++++++++++++---- synapse/replication/tcp/client.py | 6 ++++-- synapse/replication/tcp/commands.py | 36 ++++++++++++++++++++++++++++++++---- synapse/replication/tcp/protocol.py | 9 +++++++-- synapse/replication/tcp/resource.py | 17 +++++++---------- synapse/server.py | 11 +++++++++++ synapse/server.pyi | 2 ++ 9 files changed, 86 insertions(+), 22 deletions(-) create mode 100644 changelog.d/7128.misc (limited to 'synapse/replication/tcp/protocol.py') diff --git a/changelog.d/7128.misc b/changelog.d/7128.misc new file mode 100644 index 0000000000..5703f6d2ec --- /dev/null +++ b/changelog.d/7128.misc @@ -0,0 +1 @@ +Add explicit `instance_id` for USER_SYNC commands and remove implicit `conn_id` usage. diff --git a/docs/tcp_replication.md b/docs/tcp_replication.md index d4f7d9ec18..3be8e50c4c 100644 --- a/docs/tcp_replication.md +++ b/docs/tcp_replication.md @@ -198,6 +198,12 @@ Asks the server for the current position of all streams. A user has started or stopped syncing +#### CLEAR_USER_SYNC (C) + + The server should clear all associated user sync data from the worker. + + This is used when a worker is shutting down. + #### FEDERATION_ACK (C) Acknowledge receipt of some federation data diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index fba7ad9551..1ee266f7c5 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -65,6 +65,7 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto from synapse.replication.slave.storage.room import RoomStore from synapse.replication.slave.storage.transactions import SlavedTransactionStore from synapse.replication.tcp.client import ReplicationClientHandler +from synapse.replication.tcp.commands import ClearUserSyncsCommand from synapse.replication.tcp.streams import ( AccountDataStream, DeviceListsStream, @@ -124,7 +125,6 @@ from synapse.types import ReadReceipt from synapse.util.async_helpers import Linearizer from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole -from synapse.util.stringutils import random_string from synapse.util.versionstring import get_version_string logger = logging.getLogger("synapse.app.generic_worker") @@ -233,6 +233,7 @@ class GenericWorkerPresence(object): self.user_to_num_current_syncs = {} self.clock = hs.get_clock() self.notifier = hs.get_notifier() + self.instance_id = hs.get_instance_id() active_presence = self.store.take_presence_startup_info() self.user_to_current_state = {state.user_id: state for state in active_presence} @@ -245,13 +246,24 @@ class GenericWorkerPresence(object): self.send_stop_syncing, UPDATE_SYNCING_USERS_MS ) - self.process_id = random_string(16) - logger.info("Presence process_id is %r", self.process_id) + hs.get_reactor().addSystemEventTrigger( + "before", + "shutdown", + run_as_background_process, + "generic_presence.on_shutdown", + self._on_shutdown, + ) + + def _on_shutdown(self): + if self.hs.config.use_presence: + self.hs.get_tcp_replication().send_command( + ClearUserSyncsCommand(self.instance_id) + ) def send_user_sync(self, user_id, is_syncing, last_sync_ms): if self.hs.config.use_presence: self.hs.get_tcp_replication().send_user_sync( - user_id, is_syncing, last_sync_ms + self.instance_id, user_id, is_syncing, last_sync_ms ) def mark_as_coming_online(self, user_id): diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 7e7ad0f798..e86d9805f1 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -189,10 +189,12 @@ class ReplicationClientHandler(AbstractReplicationClientHandler): """ self.send_command(FederationAckCommand(token)) - def send_user_sync(self, user_id, is_syncing, last_sync_ms): + def send_user_sync(self, instance_id, user_id, is_syncing, last_sync_ms): """Poke the master that a user has started/stopped syncing. """ - self.send_command(UserSyncCommand(user_id, is_syncing, last_sync_ms)) + self.send_command( + UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms) + ) def send_remove_pusher(self, app_id, push_key, user_id): """Poke the master to remove a pusher for a user diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index 5a6b734094..e4eec643f7 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -207,30 +207,32 @@ class UserSyncCommand(Command): Format:: - USER_SYNC + USER_SYNC Where is either "start" or "stop" """ NAME = "USER_SYNC" - def __init__(self, user_id, is_syncing, last_sync_ms): + def __init__(self, instance_id, user_id, is_syncing, last_sync_ms): + self.instance_id = instance_id self.user_id = user_id self.is_syncing = is_syncing self.last_sync_ms = last_sync_ms @classmethod def from_line(cls, line): - user_id, state, last_sync_ms = line.split(" ", 2) + instance_id, user_id, state, last_sync_ms = line.split(" ", 3) if state not in ("start", "end"): raise Exception("Invalid USER_SYNC state %r" % (state,)) - return cls(user_id, state == "start", int(last_sync_ms)) + return cls(instance_id, user_id, state == "start", int(last_sync_ms)) def to_line(self): return " ".join( ( + self.instance_id, self.user_id, "start" if self.is_syncing else "end", str(self.last_sync_ms), @@ -238,6 +240,30 @@ class UserSyncCommand(Command): ) +class ClearUserSyncsCommand(Command): + """Sent by the client to inform the server that it should drop all + information about syncing users sent by the client. + + Mainly used when client is about to shut down. + + Format:: + + CLEAR_USER_SYNC + """ + + NAME = "CLEAR_USER_SYNC" + + def __init__(self, instance_id): + self.instance_id = instance_id + + @classmethod + def from_line(cls, line): + return cls(line) + + def to_line(self): + return self.instance_id + + class FederationAckCommand(Command): """Sent by the client when it has processed up to a given point in the federation stream. This allows the master to drop in-memory caches of the @@ -398,6 +424,7 @@ _COMMANDS = ( InvalidateCacheCommand, UserIpCommand, RemoteServerUpCommand, + ClearUserSyncsCommand, ) # type: Tuple[Type[Command], ...] # Map of command name to command type. @@ -420,6 +447,7 @@ VALID_CLIENT_COMMANDS = ( ReplicateCommand.NAME, PingCommand.NAME, UserSyncCommand.NAME, + ClearUserSyncsCommand.NAME, FederationAckCommand.NAME, RemovePusherCommand.NAME, InvalidateCacheCommand.NAME, diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index f81d2e2442..dae246825f 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -423,9 +423,12 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): async def on_USER_SYNC(self, cmd): await self.streamer.on_user_sync( - self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms + cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms ) + async def on_CLEAR_USER_SYNC(self, cmd): + await self.streamer.on_clear_user_syncs(cmd.instance_id) + async def on_REPLICATE(self, cmd): # Subscribe to all streams we're publishing to. for stream_name in self.streamer.streams_by_name: @@ -551,6 +554,8 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): ): BaseReplicationStreamProtocol.__init__(self, clock) + self.instance_id = hs.get_instance_id() + self.client_name = client_name self.server_name = server_name self.handler = handler @@ -580,7 +585,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): currently_syncing = self.handler.get_currently_syncing_users() now = self.clock.time_msec() for user_id in currently_syncing: - self.send_command(UserSyncCommand(user_id, True, now)) + self.send_command(UserSyncCommand(self.instance_id, user_id, True, now)) # We've now finished connecting to so inform the client handler self.handler.update_connection(self) diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index 4374e99e32..8b6067e20d 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -251,14 +251,19 @@ class ReplicationStreamer(object): self.federation_sender.federation_ack(token) @measure_func("repl.on_user_sync") - async def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms): + async def on_user_sync(self, instance_id, user_id, is_syncing, last_sync_ms): """A client has started/stopped syncing on a worker. """ user_sync_counter.inc() await self.presence_handler.update_external_syncs_row( - conn_id, user_id, is_syncing, last_sync_ms + instance_id, user_id, is_syncing, last_sync_ms ) + async def on_clear_user_syncs(self, instance_id): + """A replication client wants us to drop all their UserSync data. + """ + await self.presence_handler.update_external_syncs_clear(instance_id) + @measure_func("repl.on_remove_pusher") async def on_remove_pusher(self, app_id, push_key, user_id): """A client has asked us to remove a pusher @@ -321,14 +326,6 @@ class ReplicationStreamer(object): except ValueError: pass - # We need to tell the presence handler that the connection has been - # lost so that it can handle any ongoing syncs on that connection. - run_as_background_process( - "update_external_syncs_clear", - self.presence_handler.update_external_syncs_clear, - connection.conn_id, - ) - def _batch_updates(updates): """Takes a list of updates of form [(token, row)] and sets the token to diff --git a/synapse/server.py b/synapse/server.py index c7ca2bda0d..cd86475d6b 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -103,6 +103,7 @@ from synapse.storage import DataStores, Storage from synapse.streams.events import EventSources from synapse.util import Clock from synapse.util.distributor import Distributor +from synapse.util.stringutils import random_string logger = logging.getLogger(__name__) @@ -230,6 +231,8 @@ class HomeServer(object): self._listening_services = [] self.start_time = None + self.instance_id = random_string(5) + self.clock = Clock(reactor) self.distributor = Distributor() self.ratelimiter = Ratelimiter() @@ -242,6 +245,14 @@ class HomeServer(object): for depname in kwargs: setattr(self, depname, kwargs[depname]) + def get_instance_id(self): + """A unique ID for this synapse process instance. + + This is used to distinguish running instances in worker-based + deployments. + """ + return self.instance_id + def setup(self): logger.info("Setting up.") self.start_time = int(self.get_clock().time()) diff --git a/synapse/server.pyi b/synapse/server.pyi index 3844f0e12f..9d1dfa71e7 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -114,3 +114,5 @@ class HomeServer(object): pass def is_mine_id(self, domain_id: str) -> bool: pass + def get_instance_id(self) -> str: + pass -- cgit 1.5.1 From 5016b162fcf0372fe35404c64f80aeaf21461f31 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 6 Apr 2020 09:58:42 +0100 Subject: Move client command handling out of TCP protocol (#7185) The aim here is to move the command handling out of the TCP protocol classes and to also merge the client and server command handling (so that we can reuse them for redis protocol). This PR simply moves the client paths to the new `ReplicationCommandHandler`, a future PR will move the server paths too. --- changelog.d/7185.misc | 1 + synapse/app/admin_cmd.py | 12 -- synapse/app/generic_worker.py | 9 +- synapse/replication/tcp/__init__.py | 30 ++- synapse/replication/tcp/client.py | 179 +++--------------- synapse/replication/tcp/handler.py | 252 +++++++++++++++++++++++++ synapse/replication/tcp/protocol.py | 197 +++---------------- synapse/server.py | 8 +- synapse/server.pyi | 7 +- tests/replication/slave/storage/_base.py | 15 +- tests/replication/tcp/streams/_base.py | 38 ++-- tests/replication/tcp/streams/test_receipts.py | 1 - 12 files changed, 378 insertions(+), 371 deletions(-) create mode 100644 changelog.d/7185.misc create mode 100644 synapse/replication/tcp/handler.py (limited to 'synapse/replication/tcp/protocol.py') diff --git a/changelog.d/7185.misc b/changelog.d/7185.misc new file mode 100644 index 0000000000..deb9ca7021 --- /dev/null +++ b/changelog.d/7185.misc @@ -0,0 +1 @@ +Move client command handling out of TCP protocol. diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py index 1c7c6ec0c8..a37818fe9a 100644 --- a/synapse/app/admin_cmd.py +++ b/synapse/app/admin_cmd.py @@ -43,7 +43,6 @@ from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.slave.storage.room import RoomStore -from synapse.replication.tcp.client import ReplicationClientHandler from synapse.server import HomeServer from synapse.util.logcontext import LoggingContext from synapse.util.versionstring import get_version_string @@ -79,17 +78,6 @@ class AdminCmdServer(HomeServer): def start_listening(self, listeners): pass - def build_tcp_replication(self): - return AdminCmdReplicationHandler(self) - - -class AdminCmdReplicationHandler(ReplicationClientHandler): - async def on_rdata(self, stream_name, token, rows): - pass - - def get_streams_to_replicate(self): - return {} - @defer.inlineCallbacks def export_data_command(hs, args): diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 174bef360f..dcd0709a02 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -64,7 +64,7 @@ from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.slave.storage.room import RoomStore from synapse.replication.slave.storage.transactions import SlavedTransactionStore -from synapse.replication.tcp.client import ReplicationClientHandler +from synapse.replication.tcp.client import ReplicationDataHandler from synapse.replication.tcp.commands import ClearUserSyncsCommand from synapse.replication.tcp.streams import ( AccountDataStream, @@ -603,7 +603,7 @@ class GenericWorkerServer(HomeServer): def remove_pusher(self, app_id, push_key, user_id): self.get_tcp_replication().send_remove_pusher(app_id, push_key, user_id) - def build_tcp_replication(self): + def build_replication_data_handler(self): return GenericWorkerReplicationHandler(self) def build_presence_handler(self): @@ -613,7 +613,7 @@ class GenericWorkerServer(HomeServer): return GenericWorkerTyping(self) -class GenericWorkerReplicationHandler(ReplicationClientHandler): +class GenericWorkerReplicationHandler(ReplicationDataHandler): def __init__(self, hs): super(GenericWorkerReplicationHandler, self).__init__(hs.get_datastore()) @@ -644,9 +644,6 @@ class GenericWorkerReplicationHandler(ReplicationClientHandler): args.update(self.send_handler.stream_positions()) return args - def get_currently_syncing_users(self): - return self.presence_handler.get_currently_syncing_users() - async def process_and_notify(self, stream_name, token, rows): try: if self.send_handler: diff --git a/synapse/replication/tcp/__init__.py b/synapse/replication/tcp/__init__.py index 81c2ea7ee9..523a1358d4 100644 --- a/synapse/replication/tcp/__init__.py +++ b/synapse/replication/tcp/__init__.py @@ -20,11 +20,31 @@ Further details can be found in docs/tcp_replication.rst Structure of the module: - * client.py - the client classes used for workers to connect to master + * handler.py - the classes used to handle sending/receiving commands to + replication * command.py - the definitions of all the valid commands - * protocol.py - contains bot the client and server protocol implementations, - these should not be used directly - * resource.py - the server classes that accepts and handle client connections - * streams.py - the definitons of all the valid streams + * protocol.py - the TCP protocol classes + * resource.py - handles streaming stream updates to replications + * streams/ - the definitons of all the valid streams + +The general interaction of the classes are: + + +---------------------+ + | ReplicationStreamer | + +---------------------+ + | + v + +---------------------------+ +----------------------+ + | ReplicationCommandHandler |---->|ReplicationDataHandler| + +---------------------------+ +----------------------+ + | ^ + v | + +-------------+ + | Protocols | + | (TCP/redis) | + +-------------+ + +Where the ReplicationDataHandler (or subclasses) handles incoming stream +updates. """ diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index e86d9805f1..700ae79158 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -16,26 +16,16 @@ """ import logging -from typing import Dict, List, Optional +from typing import TYPE_CHECKING, Dict -from twisted.internet import defer from twisted.internet.protocol import ReconnectingClientFactory from synapse.replication.slave.storage._base import BaseSlavedStore -from synapse.replication.tcp.protocol import ( - AbstractReplicationClientHandler, - ClientReplicationStreamProtocol, -) - -from .commands import ( - Command, - FederationAckCommand, - InvalidateCacheCommand, - RemoteServerUpCommand, - RemovePusherCommand, - UserIpCommand, - UserSyncCommand, -) +from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol + +if TYPE_CHECKING: + from synapse.server import HomeServer + from synapse.replication.tcp.handler import ReplicationCommandHandler logger = logging.getLogger(__name__) @@ -44,16 +34,20 @@ class ReplicationClientFactory(ReconnectingClientFactory): """Factory for building connections to the master. Will reconnect if the connection is lost. - Accepts a handler that will be called when new data is available or data - is required. + Accepts a handler that is passed to `ClientReplicationStreamProtocol`. """ initialDelay = 0.1 maxDelay = 1 # Try at least once every N seconds - def __init__(self, hs, client_name, handler: AbstractReplicationClientHandler): + def __init__( + self, + hs: "HomeServer", + client_name: str, + command_handler: "ReplicationCommandHandler", + ): self.client_name = client_name - self.handler = handler + self.command_handler = command_handler self.server_name = hs.config.server_name self.hs = hs self._clock = hs.get_clock() # As self.clock is defined in super class @@ -66,7 +60,11 @@ class ReplicationClientFactory(ReconnectingClientFactory): def buildProtocol(self, addr): logger.info("Connected to replication: %r", addr) return ClientReplicationStreamProtocol( - self.hs, self.client_name, self.server_name, self._clock, self.handler, + self.hs, + self.client_name, + self.server_name, + self._clock, + self.command_handler, ) def clientConnectionLost(self, connector, reason): @@ -78,41 +76,17 @@ class ReplicationClientFactory(ReconnectingClientFactory): ReconnectingClientFactory.clientConnectionFailed(self, connector, reason) -class ReplicationClientHandler(AbstractReplicationClientHandler): - """A base handler that can be passed to the ReplicationClientFactory. +class ReplicationDataHandler: + """Handles incoming stream updates from replication. - By default proxies incoming replication data to the SlaveStore. + This instance notifies the slave data store about updates. Can be subclassed + to handle updates in additional ways. """ def __init__(self, store: BaseSlavedStore): self.store = store - # The current connection. None if we are currently (re)connecting - self.connection = None - - # Any pending commands to be sent once a new connection has been - # established - self.pending_commands = [] # type: List[Command] - - # Map from string -> deferred, to wake up when receiveing a SYNC with - # the given string. - # Used for tests. - self.awaiting_syncs = {} # type: Dict[str, defer.Deferred] - - # The factory used to create connections. - self.factory = None # type: Optional[ReplicationClientFactory] - - def start_replication(self, hs): - """Helper method to start a replication connection to the remote server - using TCP. - """ - client_name = hs.config.worker_name - self.factory = ReplicationClientFactory(hs, client_name, self) - host = hs.config.worker_replication_host - port = hs.config.worker_replication_port - hs.get_reactor().connectTCP(host, port, self.factory) - - async def on_rdata(self, stream_name, token, rows): + async def on_rdata(self, stream_name: str, token: int, rows: list): """Called to handle a batch of replication data with a given stream token. By default this just pokes the slave store. Can be overridden in subclasses to @@ -124,30 +98,8 @@ class ReplicationClientHandler(AbstractReplicationClientHandler): rows (list): a list of Stream.ROW_TYPE objects as returned by Stream.parse_row. """ - logger.debug("Received rdata %s -> %s", stream_name, token) self.store.process_replication_rows(stream_name, token, rows) - async def on_position(self, stream_name, token): - """Called when we get new position data. By default this just pokes - the slave store. - - Can be overriden in subclasses to handle more. - """ - self.store.process_replication_rows(stream_name, token, []) - - def on_sync(self, data): - """When we received a SYNC we wake up any deferreds that were waiting - for the sync with the given data. - - Used by tests. - """ - d = self.awaiting_syncs.pop(data, None) - if d: - d.callback(data) - - def on_remote_server_up(self, server: str): - """Called when get a new REMOTE_SERVER_UP command.""" - def get_streams_to_replicate(self) -> Dict[str, int]: """Called when a new connection has been established and we need to subscribe to streams. @@ -163,85 +115,10 @@ class ReplicationClientHandler(AbstractReplicationClientHandler): args["account_data"] = user_account_data elif room_account_data: args["account_data"] = room_account_data - return args - def get_currently_syncing_users(self): - """Get the list of currently syncing users (if any). This is called - when a connection has been established and we need to send the - currently syncing users. (Overriden by the synchrotron's only) - """ - return [] - - def send_command(self, cmd): - """Send a command to master (when we get establish a connection if we - don't have one already.) - """ - if self.connection: - self.connection.send_command(cmd) - else: - logger.warning("Queuing command as not connected: %r", cmd.NAME) - self.pending_commands.append(cmd) - - def send_federation_ack(self, token): - """Ack data for the federation stream. This allows the master to drop - data stored purely in memory. - """ - self.send_command(FederationAckCommand(token)) - - def send_user_sync(self, instance_id, user_id, is_syncing, last_sync_ms): - """Poke the master that a user has started/stopped syncing. - """ - self.send_command( - UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms) - ) - - def send_remove_pusher(self, app_id, push_key, user_id): - """Poke the master to remove a pusher for a user - """ - cmd = RemovePusherCommand(app_id, push_key, user_id) - self.send_command(cmd) - - def send_invalidate_cache(self, cache_func, keys): - """Poke the master to invalidate a cache. - """ - cmd = InvalidateCacheCommand(cache_func.__name__, keys) - self.send_command(cmd) - - def send_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen): - """Tell the master that the user made a request. - """ - cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen) - self.send_command(cmd) - - def send_remote_server_up(self, server: str): - self.send_command(RemoteServerUpCommand(server)) - - def await_sync(self, data): - """Returns a deferred that is resolved when we receive a SYNC command - with given data. - - [Not currently] used by tests. - """ - return self.awaiting_syncs.setdefault(data, defer.Deferred()) - - def update_connection(self, connection): - """Called when a connection has been established (or lost with None). - """ - self.connection = connection - if connection: - for cmd in self.pending_commands: - connection.send_command(cmd) - self.pending_commands = [] - - def finished_connecting(self): - """Called when we have successfully subscribed and caught up to all - streams we're interested in. - """ - logger.info("Finished connecting to server") + async def on_position(self, stream_name: str, token: int): + self.store.process_replication_rows(stream_name, token, []) - # We don't reset the delay any earlier as otherwise if there is a - # problem during start up we'll end up tight looping connecting to the - # server. - if self.factory: - self.factory.resetDelay() + def on_remote_server_up(self, server: str): + """Called when get a new REMOTE_SERVER_UP command.""" diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py new file mode 100644 index 0000000000..12a1cfd6d1 --- /dev/null +++ b/synapse/replication/tcp/handler.py @@ -0,0 +1,252 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import Any, Callable, Dict, List, Optional, Set + +from prometheus_client import Counter + +from synapse.replication.tcp.client import ReplicationClientFactory +from synapse.replication.tcp.commands import ( + Command, + FederationAckCommand, + InvalidateCacheCommand, + PositionCommand, + RdataCommand, + RemoteServerUpCommand, + RemovePusherCommand, + SyncCommand, + UserIpCommand, + UserSyncCommand, +) +from synapse.replication.tcp.streams import STREAMS_MAP, Stream +from synapse.util.async_helpers import Linearizer + +logger = logging.getLogger(__name__) + + +# number of updates received for each RDATA stream +inbound_rdata_count = Counter( + "synapse_replication_tcp_protocol_inbound_rdata_count", "", ["stream_name"] +) + + +class ReplicationCommandHandler: + """Handles incoming commands from replication as well as sending commands + back out to connections. + """ + + def __init__(self, hs): + self._replication_data_handler = hs.get_replication_data_handler() + self._presence_handler = hs.get_presence_handler() + + # Set of streams that we've caught up with. + self._streams_connected = set() # type: Set[str] + + self._streams = { + stream.NAME: stream(hs) for stream in STREAMS_MAP.values() + } # type: Dict[str, Stream] + + self._position_linearizer = Linearizer("replication_position") + + # Map of stream to batched updates. See RdataCommand for info on how + # batching works. + self._pending_batches = {} # type: Dict[str, List[Any]] + + # The factory used to create connections. + self._factory = None # type: Optional[ReplicationClientFactory] + + # The current connection. None if we are currently (re)connecting + self._connection = None + + def start_replication(self, hs): + """Helper method to start a replication connection to the remote server + using TCP. + """ + client_name = hs.config.worker_name + self._factory = ReplicationClientFactory(hs, client_name, self) + host = hs.config.worker_replication_host + port = hs.config.worker_replication_port + hs.get_reactor().connectTCP(host, port, self._factory) + + async def on_RDATA(self, cmd: RdataCommand): + stream_name = cmd.stream_name + inbound_rdata_count.labels(stream_name).inc() + + try: + row = STREAMS_MAP[stream_name].parse_row(cmd.row) + except Exception: + logger.exception("Failed to parse RDATA: %r %r", stream_name, cmd.row) + raise + + if cmd.token is None or stream_name not in self._streams_connected: + # I.e. either this is part of a batch of updates for this stream (in + # which case batch until we get an update for the stream with a non + # None token) or we're currently connecting so we queue up rows. + self._pending_batches.setdefault(stream_name, []).append(row) + else: + # Check if this is the last of a batch of updates + rows = self._pending_batches.pop(stream_name, []) + rows.append(row) + await self.on_rdata(stream_name, cmd.token, rows) + + async def on_rdata(self, stream_name: str, token: int, rows: list): + """Called to handle a batch of replication data with a given stream token. + + Args: + stream_name: name of the replication stream for this batch of rows + token: stream token for this batch of rows + rows: a list of Stream.ROW_TYPE objects as returned by + Stream.parse_row. + """ + logger.debug("Received rdata %s -> %s", stream_name, token) + await self._replication_data_handler.on_rdata(stream_name, token, rows) + + async def on_POSITION(self, cmd: PositionCommand): + stream = self._streams.get(cmd.stream_name) + if not stream: + logger.error("Got POSITION for unknown stream: %s", cmd.stream_name) + return + + # We protect catching up with a linearizer in case the replication + # connection reconnects under us. + with await self._position_linearizer.queue(cmd.stream_name): + # We're about to go and catch up with the stream, so mark as connecting + # to stop RDATA being handled at the same time by removing stream from + # list of connected streams. We also clear any batched up RDATA from + # before we got the POSITION. + self._streams_connected.discard(cmd.stream_name) + self._pending_batches.clear() + + # Find where we previously streamed up to. + current_token = self._replication_data_handler.get_streams_to_replicate().get( + cmd.stream_name + ) + if current_token is None: + logger.warning( + "Got POSITION for stream we're not subscribed to: %s", + cmd.stream_name, + ) + return + + # Fetch all updates between then and now. + limited = True + while limited: + updates, current_token, limited = await stream.get_updates_since( + current_token, cmd.token + ) + if updates: + await self.on_rdata( + cmd.stream_name, + current_token, + [stream.parse_row(update[1]) for update in updates], + ) + + # We've now caught up to position sent to us, notify handler. + await self._replication_data_handler.on_position(cmd.stream_name, cmd.token) + + # Handle any RDATA that came in while we were catching up. + rows = self._pending_batches.pop(cmd.stream_name, []) + if rows: + await self._replication_data_handler.on_rdata( + cmd.stream_name, rows[-1].token, rows + ) + + self._streams_connected.add(cmd.stream_name) + + async def on_SYNC(self, cmd: SyncCommand): + pass + + async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand): + """"Called when get a new REMOTE_SERVER_UP command.""" + self._replication_data_handler.on_remote_server_up(cmd.data) + + def get_currently_syncing_users(self): + """Get the list of currently syncing users (if any). This is called + when a connection has been established and we need to send the + currently syncing users. + """ + return self._presence_handler.get_currently_syncing_users() + + def update_connection(self, connection): + """Called when a connection has been established (or lost with None). + """ + self._connection = connection + + def finished_connecting(self): + """Called when we have successfully subscribed and caught up to all + streams we're interested in. + """ + logger.info("Finished connecting to server") + + # We don't reset the delay any earlier as otherwise if there is a + # problem during start up we'll end up tight looping connecting to the + # server. + if self._factory: + self._factory.resetDelay() + + def send_command(self, cmd: Command): + """Send a command to master (when we get establish a connection if we + don't have one already.) + """ + if self._connection: + self._connection.send_command(cmd) + else: + logger.warning("Dropping command as not connected: %r", cmd.NAME) + + def send_federation_ack(self, token: int): + """Ack data for the federation stream. This allows the master to drop + data stored purely in memory. + """ + self.send_command(FederationAckCommand(token)) + + def send_user_sync( + self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int + ): + """Poke the master that a user has started/stopped syncing. + """ + self.send_command( + UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms) + ) + + def send_remove_pusher(self, app_id: str, push_key: str, user_id: str): + """Poke the master to remove a pusher for a user + """ + cmd = RemovePusherCommand(app_id, push_key, user_id) + self.send_command(cmd) + + def send_invalidate_cache(self, cache_func: Callable, keys: tuple): + """Poke the master to invalidate a cache. + """ + cmd = InvalidateCacheCommand(cache_func.__name__, keys) + self.send_command(cmd) + + def send_user_ip( + self, + user_id: str, + access_token: str, + ip: str, + user_agent: str, + device_id: str, + last_seen: int, + ): + """Tell the master that the user made a request. + """ + cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen) + self.send_command(cmd) + + def send_remote_server_up(self, server: str): + self.send_command(RemoteServerUpCommand(server)) diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index dae246825f..f2a37f568e 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -46,12 +46,11 @@ indicate which side is sending, these are *not* included on the wire:: > ERROR server stopping * connection closed by server * """ -import abc import fcntl import logging import struct from collections import defaultdict -from typing import Any, DefaultDict, Dict, List, Set +from typing import TYPE_CHECKING, DefaultDict, List from six import iteritems @@ -78,13 +77,12 @@ from synapse.replication.tcp.commands import ( SyncCommand, UserSyncCommand, ) -from synapse.replication.tcp.streams import STREAMS_MAP, Stream from synapse.types import Collection from synapse.util import Clock from synapse.util.stringutils import random_string -MYPY = False -if MYPY: +if TYPE_CHECKING: + from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.server import HomeServer @@ -475,71 +473,6 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): self.streamer.lost_connection(self) -class AbstractReplicationClientHandler(metaclass=abc.ABCMeta): - """ - The interface for the handler that should be passed to - ClientReplicationStreamProtocol - """ - - @abc.abstractmethod - async def on_rdata(self, stream_name, token, rows): - """Called to handle a batch of replication data with a given stream token. - - Args: - stream_name (str): name of the replication stream for this batch of rows - token (int): stream token for this batch of rows - rows (list): a list of Stream.ROW_TYPE objects as returned by - Stream.parse_row. - """ - raise NotImplementedError() - - @abc.abstractmethod - async def on_position(self, stream_name, token): - """Called when we get new position data.""" - raise NotImplementedError() - - @abc.abstractmethod - def on_sync(self, data): - """Called when get a new SYNC command.""" - raise NotImplementedError() - - @abc.abstractmethod - async def on_remote_server_up(self, server: str): - """Called when get a new REMOTE_SERVER_UP command.""" - raise NotImplementedError() - - @abc.abstractmethod - def get_streams_to_replicate(self): - """Called when a new connection has been established and we need to - subscribe to streams. - - Returns: - map from stream name to the most recent update we have for - that stream (ie, the point we want to start replicating from) - """ - raise NotImplementedError() - - @abc.abstractmethod - def get_currently_syncing_users(self): - """Get the list of currently syncing users (if any). This is called - when a connection has been established and we need to send the - currently syncing users.""" - raise NotImplementedError() - - @abc.abstractmethod - def update_connection(self, connection): - """Called when a connection has been established (or lost with None). - """ - raise NotImplementedError() - - @abc.abstractmethod - def finished_connecting(self): - """Called when we have successfully subscribed and caught up to all - streams we're interested in. - """ - raise NotImplementedError() - - class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): VALID_INBOUND_COMMANDS = VALID_SERVER_COMMANDS VALID_OUTBOUND_COMMANDS = VALID_CLIENT_COMMANDS @@ -550,7 +483,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): client_name: str, server_name: str, clock: Clock, - handler: AbstractReplicationClientHandler, + command_handler: "ReplicationCommandHandler", ): BaseReplicationStreamProtocol.__init__(self, clock) @@ -558,20 +491,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): self.client_name = client_name self.server_name = server_name - self.handler = handler - - self.streams = { - stream.NAME: stream(hs) for stream in STREAMS_MAP.values() - } # type: Dict[str, Stream] - - # Set of stream names that have been subscribe to, but haven't yet - # caught up with. This is used to track when the client has been fully - # connected to the remote. - self.streams_connecting = set(STREAMS_MAP) # type: Set[str] - - # Map of stream to batched updates. See RdataCommand for info on how - # batching works. - self.pending_batches = {} # type: Dict[str, List[Any]] + self.handler = command_handler def connectionMade(self): self.send_command(NameCommand(self.client_name)) @@ -589,89 +509,39 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): # We've now finished connecting to so inform the client handler self.handler.update_connection(self) + self.handler.finished_connecting() - async def on_SERVER(self, cmd): - if cmd.data != self.server_name: - logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data) - self.send_error("Wrong remote") - - async def on_RDATA(self, cmd): - stream_name = cmd.stream_name - inbound_rdata_count.labels(stream_name).inc() - - try: - row = STREAMS_MAP[stream_name].parse_row(cmd.row) - except Exception: - logger.exception( - "[%s] Failed to parse RDATA: %r %r", self.id(), stream_name, cmd.row - ) - raise - - if cmd.token is None or stream_name in self.streams_connecting: - # I.e. this is part of a batch of updates for this stream. Batch - # until we get an update for the stream with a non None token - self.pending_batches.setdefault(stream_name, []).append(row) - else: - # Check if this is the last of a batch of updates - rows = self.pending_batches.pop(stream_name, []) - rows.append(row) - await self.handler.on_rdata(stream_name, cmd.token, rows) - - async def on_POSITION(self, cmd: PositionCommand): - stream = self.streams.get(cmd.stream_name) - if not stream: - logger.error("Got POSITION for unknown stream: %s", cmd.stream_name) - return - - # Find where we previously streamed up to. - current_token = self.handler.get_streams_to_replicate().get(cmd.stream_name) - if current_token is None: - logger.warning( - "Got POSITION for stream we're not subscribed to: %s", cmd.stream_name - ) - return - - # Fetch all updates between then and now. - limited = True - while limited: - updates, current_token, limited = await stream.get_updates_since( - current_token, cmd.token - ) - - # Check if the connection was closed underneath us, if so we bail - # rather than risk having concurrent catch ups going on. - if self.state == ConnectionStates.CLOSED: - return - - if updates: - await self.handler.on_rdata( - cmd.stream_name, - current_token, - [stream.parse_row(update[1]) for update in updates], - ) + async def handle_command(self, cmd: Command): + """Handle a command we have received over the replication stream. - # We've now caught up to position sent to us, notify handler. - await self.handler.on_position(cmd.stream_name, cmd.token) + Delegates to `command_handler.on_`, which must return an + awaitable. - self.streams_connecting.discard(cmd.stream_name) - if not self.streams_connecting: - self.handler.finished_connecting() + Args: + cmd: received command + """ + handled = False - # Check if the connection was closed underneath us, if so we bail - # rather than risk having concurrent catch ups going on. - if self.state == ConnectionStates.CLOSED: - return + # First call any command handlers on this instance. These are for TCP + # specific handling. + cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None) + if cmd_func: + await cmd_func(cmd) + handled = True - # Handle any RDATA that came in while we were catching up. - rows = self.pending_batches.pop(cmd.stream_name, []) - if rows: - await self.handler.on_rdata(cmd.stream_name, rows[-1].token, rows) + # Then call out to the handler. + cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None) + if cmd_func: + await cmd_func(cmd) + handled = True - async def on_SYNC(self, cmd): - self.handler.on_sync(cmd.data) + if not handled: + logger.warning("Unhandled command: %r", cmd) - async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand): - self.handler.on_remote_server_up(cmd.data) + async def on_SERVER(self, cmd): + if cmd.data != self.server_name: + logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data) + self.send_error("Wrong remote") def replicate(self): """Send the subscription request to the server @@ -768,8 +638,3 @@ tcp_outbound_commands = LaterGauge( for k, count in iteritems(p.outbound_commands_counter) }, ) - -# number of updates received for each RDATA stream -inbound_rdata_count = Counter( - "synapse_replication_tcp_protocol_inbound_rdata_count", "", ["stream_name"] -) diff --git a/synapse/server.py b/synapse/server.py index 9228e1c892..9d273c980c 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -87,6 +87,8 @@ from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.notifier import Notifier from synapse.push.action_generator import ActionGenerator from synapse.push.pusherpool import PusherPool +from synapse.replication.tcp.client import ReplicationDataHandler +from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.replication.tcp.resource import ReplicationStreamer from synapse.rest.media.v1.media_repository import ( MediaRepository, @@ -206,6 +208,7 @@ class HomeServer(object): "password_policy_handler", "storage", "replication_streamer", + "replication_data_handler", ] REQUIRED_ON_MASTER_STARTUP = ["user_directory_handler", "stats_handler"] @@ -468,7 +471,7 @@ class HomeServer(object): return ReadMarkerHandler(self) def build_tcp_replication(self): - raise NotImplementedError() + return ReplicationCommandHandler(self) def build_action_generator(self): return ActionGenerator(self) @@ -562,6 +565,9 @@ class HomeServer(object): def build_replication_streamer(self) -> ReplicationStreamer: return ReplicationStreamer(self) + def build_replication_data_handler(self): + return ReplicationDataHandler(self.get_datastore()) + def remove_pusher(self, app_id, push_key, user_id): return self.get_pusherpool().remove_pusher(app_id, push_key, user_id) diff --git a/synapse/server.pyi b/synapse/server.pyi index 9d1dfa71e7..9013e9bac9 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -19,6 +19,7 @@ import synapse.handlers.set_password import synapse.http.client import synapse.notifier import synapse.replication.tcp.client +import synapse.replication.tcp.handler import synapse.rest.media.v1.media_repository import synapse.server_notices.server_notices_manager import synapse.server_notices.server_notices_sender @@ -106,7 +107,11 @@ class HomeServer(object): pass def get_tcp_replication( self, - ) -> synapse.replication.tcp.client.ReplicationClientHandler: + ) -> synapse.replication.tcp.handler.ReplicationCommandHandler: + pass + def get_replication_data_handler( + self, + ) -> synapse.replication.tcp.client.ReplicationDataHandler: pass def get_federation_registry( self, diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 2a1e7c7166..8902a5ab69 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -17,8 +17,9 @@ from mock import Mock, NonCallableMock from synapse.replication.tcp.client import ( ReplicationClientFactory, - ReplicationClientHandler, + ReplicationDataHandler, ) +from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory from synapse.storage.database import make_conn @@ -51,15 +52,19 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase): self.event_id = 0 server_factory = ReplicationStreamProtocolFactory(self.hs) - self.streamer = server_factory.streamer + self.streamer = hs.get_replication_streamer() - handler_factory = Mock() - self.replication_handler = ReplicationClientHandler(self.slaved_store) - self.replication_handler.factory = handler_factory + # We now do some gut wrenching so that we have a client that is based + # off of the slave store rather than the main store. + self.replication_handler = ReplicationCommandHandler(self.hs) + self.replication_handler._replication_data_handler = ReplicationDataHandler( + self.slaved_store + ) client_factory = ReplicationClientFactory( self.hs, "client_name", self.replication_handler ) + client_factory.handler = self.replication_handler server = server_factory.buildProtocol(None) client = client_factory.buildProtocol(None) diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py index a755fe2879..32238fe79a 100644 --- a/tests/replication/tcp/streams/_base.py +++ b/tests/replication/tcp/streams/_base.py @@ -15,7 +15,7 @@ from mock import Mock -from synapse.replication.tcp.commands import ReplicateCommand +from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory @@ -26,15 +26,20 @@ from tests.server import FakeTransport class BaseStreamTestCase(unittest.HomeserverTestCase): """Base class for tests of the replication streams""" + def make_homeserver(self, reactor, clock): + self.test_handler = Mock(wraps=TestReplicationDataHandler()) + return self.setup_test_homeserver(replication_data_handler=self.test_handler) + def prepare(self, reactor, clock, hs): # build a replication server - server_factory = ReplicationStreamProtocolFactory(self.hs) - self.streamer = server_factory.streamer + server_factory = ReplicationStreamProtocolFactory(hs) + self.streamer = hs.get_replication_streamer() self.server = server_factory.buildProtocol(None) - self.test_handler = Mock(wraps=TestReplicationClientHandler()) + repl_handler = ReplicationCommandHandler(hs) + repl_handler.handler = self.test_handler self.client = ClientReplicationStreamProtocol( - hs, "client", "test", clock, self.test_handler, + hs, "client", "test", clock, repl_handler, ) self._client_transport = None @@ -69,13 +74,9 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): self.streamer.on_notifier_poke() self.pump(0.1) - def replicate_stream(self): - """Make the client end a REPLICATE command to set up a subscription to a stream""" - self.client.send_command(ReplicateCommand()) - -class TestReplicationClientHandler(object): - """Drop-in for ReplicationClientHandler which just collects RDATA rows""" +class TestReplicationDataHandler: + """Drop-in for ReplicationDataHandler which just collects RDATA rows""" def __init__(self): self.streams = set() @@ -88,18 +89,9 @@ class TestReplicationClientHandler(object): positions[stream] = max(token, positions.get(stream, 0)) return positions - def get_currently_syncing_users(self): - return [] - - def update_connection(self, connection): - pass - - def finished_connecting(self): - pass - - async def on_position(self, stream_name, token): - """Called when we get new position data.""" - async def on_rdata(self, stream_name, token, rows): for r in rows: self._received_rdata_rows.append((stream_name, token, r)) + + async def on_position(self, stream_name, token): + pass diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py index 0ec0825a0e..a0206f7363 100644 --- a/tests/replication/tcp/streams/test_receipts.py +++ b/tests/replication/tcp/streams/test_receipts.py @@ -24,7 +24,6 @@ class ReceiptsStreamTestCase(BaseStreamTestCase): self.reconnect() # make the client subscribe to the receipts stream - self.replicate_stream() self.test_handler.streams.add("receipts") # tell the master to send a new receipt -- cgit 1.5.1 From 82498ee9019747eb86ed753e08fac0990d4ac8b9 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 7 Apr 2020 10:51:07 +0100 Subject: Move server command handling out of TCP protocol (#7187) This completes the merging of server and client command processing. --- changelog.d/7187.misc | 1 + synapse/replication/tcp/handler.py | 177 ++++++++++++++++++++++++++++++++---- synapse/replication/tcp/protocol.py | 165 +++++++++++---------------------- synapse/replication/tcp/resource.py | 163 ++++++--------------------------- 4 files changed, 237 insertions(+), 269 deletions(-) create mode 100644 changelog.d/7187.misc (limited to 'synapse/replication/tcp/protocol.py') diff --git a/changelog.d/7187.misc b/changelog.d/7187.misc new file mode 100644 index 0000000000..60d68ae877 --- /dev/null +++ b/changelog.d/7187.misc @@ -0,0 +1 @@ +Move server command handling out of TCP protocol. diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 12a1cfd6d1..8ec0119697 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -19,8 +19,10 @@ from typing import Any, Callable, Dict, List, Optional, Set from prometheus_client import Counter +from synapse.metrics import LaterGauge from synapse.replication.tcp.client import ReplicationClientFactory from synapse.replication.tcp.commands import ( + ClearUserSyncsCommand, Command, FederationAckCommand, InvalidateCacheCommand, @@ -28,10 +30,12 @@ from synapse.replication.tcp.commands import ( RdataCommand, RemoteServerUpCommand, RemovePusherCommand, + ReplicateCommand, SyncCommand, UserIpCommand, UserSyncCommand, ) +from synapse.replication.tcp.protocol import AbstractConnection from synapse.replication.tcp.streams import STREAMS_MAP, Stream from synapse.util.async_helpers import Linearizer @@ -42,6 +46,13 @@ logger = logging.getLogger(__name__) inbound_rdata_count = Counter( "synapse_replication_tcp_protocol_inbound_rdata_count", "", ["stream_name"] ) +user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "") +federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "") +remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "") +invalidate_cache_counter = Counter( + "synapse_replication_tcp_resource_invalidate_cache", "" +) +user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "") class ReplicationCommandHandler: @@ -52,6 +63,10 @@ class ReplicationCommandHandler: def __init__(self, hs): self._replication_data_handler = hs.get_replication_data_handler() self._presence_handler = hs.get_presence_handler() + self._store = hs.get_datastore() + self._notifier = hs.get_notifier() + self._clock = hs.get_clock() + self._instance_id = hs.get_instance_id() # Set of streams that we've caught up with. self._streams_connected = set() # type: Set[str] @@ -69,8 +84,26 @@ class ReplicationCommandHandler: # The factory used to create connections. self._factory = None # type: Optional[ReplicationClientFactory] - # The current connection. None if we are currently (re)connecting - self._connection = None + # The currently connected connections. + self._connections = [] # type: List[AbstractConnection] + + LaterGauge( + "synapse_replication_tcp_resource_total_connections", + "", + [], + lambda: len(self._connections), + ) + + self._is_master = hs.config.worker_app is None + + self._federation_sender = None + if self._is_master and not hs.config.send_federation: + self._federation_sender = hs.get_federation_sender() + + self._server_notices_sender = None + if self._is_master: + self._server_notices_sender = hs.get_server_notices_sender() + self._notifier.add_remote_server_up_callback(self.send_remote_server_up) def start_replication(self, hs): """Helper method to start a replication connection to the remote server @@ -82,6 +115,70 @@ class ReplicationCommandHandler: port = hs.config.worker_replication_port hs.get_reactor().connectTCP(host, port, self._factory) + async def on_REPLICATE(self, cmd: ReplicateCommand): + # We only want to announce positions by the writer of the streams. + # Currently this is just the master process. + if not self._is_master: + return + + for stream_name, stream in self._streams.items(): + current_token = stream.current_token() + self.send_command(PositionCommand(stream_name, current_token)) + + async def on_USER_SYNC(self, cmd: UserSyncCommand): + user_sync_counter.inc() + + if self._is_master: + await self._presence_handler.update_external_syncs_row( + cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms + ) + + async def on_CLEAR_USER_SYNC(self, cmd: ClearUserSyncsCommand): + if self._is_master: + await self._presence_handler.update_external_syncs_clear(cmd.instance_id) + + async def on_FEDERATION_ACK(self, cmd: FederationAckCommand): + federation_ack_counter.inc() + + if self._federation_sender: + self._federation_sender.federation_ack(cmd.token) + + async def on_REMOVE_PUSHER(self, cmd: RemovePusherCommand): + remove_pusher_counter.inc() + + if self._is_master: + await self._store.delete_pusher_by_app_id_pushkey_user_id( + app_id=cmd.app_id, pushkey=cmd.push_key, user_id=cmd.user_id + ) + + self._notifier.on_new_replication_data() + + async def on_INVALIDATE_CACHE(self, cmd: InvalidateCacheCommand): + invalidate_cache_counter.inc() + + if self._is_master: + # We invalidate the cache locally, but then also stream that to other + # workers. + await self._store.invalidate_cache_and_stream( + cmd.cache_func, tuple(cmd.keys) + ) + + async def on_USER_IP(self, cmd: UserIpCommand): + user_ip_cache_counter.inc() + + if self._is_master: + await self._store.insert_client_ip( + cmd.user_id, + cmd.access_token, + cmd.ip, + cmd.user_agent, + cmd.device_id, + cmd.last_seen, + ) + + if self._server_notices_sender: + await self._server_notices_sender.on_user_ip(cmd.user_id) + async def on_RDATA(self, cmd: RdataCommand): stream_name = cmd.stream_name inbound_rdata_count.labels(stream_name).inc() @@ -174,6 +271,9 @@ class ReplicationCommandHandler: """"Called when get a new REMOTE_SERVER_UP command.""" self._replication_data_handler.on_remote_server_up(cmd.data) + if self._is_master: + self._notifier.notify_remote_server_up(cmd.data) + def get_currently_syncing_users(self): """Get the list of currently syncing users (if any). This is called when a connection has been established and we need to send the @@ -181,29 +281,63 @@ class ReplicationCommandHandler: """ return self._presence_handler.get_currently_syncing_users() - def update_connection(self, connection): - """Called when a connection has been established (or lost with None). + def new_connection(self, connection: AbstractConnection): + """Called when we have a new connection. """ - self._connection = connection + self._connections.append(connection) + + # If we are connected to replication as a client (rather than a server) + # we need to reset the reconnection delay on the client factory (which + # is used to do exponential back off when the connection drops). + # + # Ideally we would reset the delay when we've "fully established" the + # connection (for some definition thereof) to stop us from tightlooping + # on reconnection if something fails after this point and we drop the + # connection. Unfortunately, we don't really have a better definition of + # "fully established" than the connection being established. + if self._factory: + self._factory.resetDelay() + + # Tell the server if we have any users currently syncing (should only + # happen on synchrotrons) + currently_syncing = self.get_currently_syncing_users() + now = self._clock.time_msec() + for user_id in currently_syncing: + connection.send_command( + UserSyncCommand(self._instance_id, user_id, True, now) + ) - def finished_connecting(self): - """Called when we have successfully subscribed and caught up to all - streams we're interested in. + def lost_connection(self, connection: AbstractConnection): + """Called when a connection is closed/lost. """ - logger.info("Finished connecting to server") + try: + self._connections.remove(connection) + except ValueError: + pass - # We don't reset the delay any earlier as otherwise if there is a - # problem during start up we'll end up tight looping connecting to the - # server. - if self._factory: - self._factory.resetDelay() + def connected(self) -> bool: + """Do we have any replication connections open? + + Is used by e.g. `ReplicationStreamer` to no-op if nothing is connected. + """ + return bool(self._connections) def send_command(self, cmd: Command): - """Send a command to master (when we get establish a connection if we - don't have one already.) + """Send a command to all connected connections. """ - if self._connection: - self._connection.send_command(cmd) + if self._connections: + for connection in self._connections: + try: + connection.send_command(cmd) + except Exception: + # We probably want to catch some types of exceptions here + # and log them as warnings (e.g. connection gone), but I + # can't find what those exception types they would be. + logger.exception( + "Failed to write command %s to connection %s", + cmd.NAME, + connection, + ) else: logger.warning("Dropping command as not connected: %r", cmd.NAME) @@ -250,3 +384,10 @@ class ReplicationCommandHandler: def send_remote_server_up(self, server: str): self.send_command(RemoteServerUpCommand(server)) + + def stream_update(self, stream_name: str, token: str, data: Any): + """Called when a new update is available to stream to clients. + + We need to check if the client is interested in the stream or not + """ + self.send_command(RdataCommand(stream_name, token, data)) diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index f2a37f568e..9aabb9c586 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -46,6 +46,7 @@ indicate which side is sending, these are *not* included on the wire:: > ERROR server stopping * connection closed by server * """ +import abc import fcntl import logging import struct @@ -69,13 +70,8 @@ from synapse.replication.tcp.commands import ( ErrorCommand, NameCommand, PingCommand, - PositionCommand, - RdataCommand, - RemoteServerUpCommand, ReplicateCommand, ServerCommand, - SyncCommand, - UserSyncCommand, ) from synapse.types import Collection from synapse.util import Clock @@ -118,7 +114,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): are only sent by the server. On receiving a new command it calls `on_` with the parsed - command. + command before delegating to `ReplicationCommandHandler.on_`. It also sends `PING` periodically, and correctly times out remote connections (if they send a `PING` command) @@ -134,8 +130,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): max_line_buffer = 10000 - def __init__(self, clock): + def __init__(self, clock: Clock, handler: "ReplicationCommandHandler"): self.clock = clock + self.command_handler = handler self.last_received_command = self.clock.time_msec() self.last_sent_command = 0 @@ -175,6 +172,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): # can time us out. self.send_command(PingCommand(self.clock.time_msec())) + self.command_handler.new_connection(self) + def send_ping(self): """Periodically sends a ping and checks if we should close the connection due to the other side timing out. @@ -243,13 +242,31 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): async def handle_command(self, cmd: Command): """Handle a command we have received over the replication stream. - By default delegates to on_, which should return an awaitable. + First calls `self.on_` if it exists, then calls + `self.command_handler.on_` if it exists. This allows for + protocol level handling of commands (e.g. PINGs), before delegating to + the handler. Args: cmd: received command """ - handler = getattr(self, "on_%s" % (cmd.NAME,)) - await handler(cmd) + handled = False + + # First call any command handlers on this instance. These are for TCP + # specific handling. + cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None) + if cmd_func: + await cmd_func(cmd) + handled = True + + # Then call out to the handler. + cmd_func = getattr(self.command_handler, "on_%s" % (cmd.NAME,), None) + if cmd_func: + await cmd_func(cmd) + handled = True + + if not handled: + logger.warning("Unhandled command: %r", cmd) def close(self): logger.warning("[%s] Closing connection", self.id()) @@ -378,6 +395,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): self.state = ConnectionStates.CLOSED self.pending_commands = [] + self.command_handler.lost_connection(self) + if self.transport: self.transport.unregisterProducer() @@ -404,74 +423,21 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): VALID_INBOUND_COMMANDS = VALID_CLIENT_COMMANDS VALID_OUTBOUND_COMMANDS = VALID_SERVER_COMMANDS - def __init__(self, server_name, clock, streamer): - BaseReplicationStreamProtocol.__init__(self, clock) # Old style class + def __init__( + self, server_name: str, clock: Clock, handler: "ReplicationCommandHandler" + ): + super().__init__(clock, handler) self.server_name = server_name - self.streamer = streamer def connectionMade(self): self.send_command(ServerCommand(self.server_name)) - BaseReplicationStreamProtocol.connectionMade(self) - self.streamer.new_connection(self) + super().connectionMade() async def on_NAME(self, cmd): logger.info("[%s] Renamed to %r", self.id(), cmd.data) self.name = cmd.data - async def on_USER_SYNC(self, cmd): - await self.streamer.on_user_sync( - cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms - ) - - async def on_CLEAR_USER_SYNC(self, cmd): - await self.streamer.on_clear_user_syncs(cmd.instance_id) - - async def on_REPLICATE(self, cmd): - # Subscribe to all streams we're publishing to. - for stream_name in self.streamer.streams_by_name: - current_token = self.streamer.get_stream_token(stream_name) - self.send_command(PositionCommand(stream_name, current_token)) - - async def on_FEDERATION_ACK(self, cmd): - self.streamer.federation_ack(cmd.token) - - async def on_REMOVE_PUSHER(self, cmd): - await self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id) - - async def on_INVALIDATE_CACHE(self, cmd): - await self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys) - - async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand): - self.streamer.on_remote_server_up(cmd.data) - - async def on_USER_IP(self, cmd): - self.streamer.on_user_ip( - cmd.user_id, - cmd.access_token, - cmd.ip, - cmd.user_agent, - cmd.device_id, - cmd.last_seen, - ) - - def stream_update(self, stream_name, token, data): - """Called when a new update is available to stream to clients. - - We need to check if the client is interested in the stream or not - """ - self.send_command(RdataCommand(stream_name, token, data)) - - def send_sync(self, data): - self.send_command(SyncCommand(data)) - - def send_remote_server_up(self, server: str): - self.send_command(RemoteServerUpCommand(server)) - - def on_connection_closed(self): - BaseReplicationStreamProtocol.on_connection_closed(self) - self.streamer.lost_connection(self) - class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): VALID_INBOUND_COMMANDS = VALID_SERVER_COMMANDS @@ -485,59 +451,18 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): clock: Clock, command_handler: "ReplicationCommandHandler", ): - BaseReplicationStreamProtocol.__init__(self, clock) - - self.instance_id = hs.get_instance_id() + super().__init__(clock, command_handler) self.client_name = client_name self.server_name = server_name - self.handler = command_handler def connectionMade(self): self.send_command(NameCommand(self.client_name)) - BaseReplicationStreamProtocol.connectionMade(self) + super().connectionMade() # Once we've connected subscribe to the necessary streams self.replicate() - # Tell the server if we have any users currently syncing (should only - # happen on synchrotrons) - currently_syncing = self.handler.get_currently_syncing_users() - now = self.clock.time_msec() - for user_id in currently_syncing: - self.send_command(UserSyncCommand(self.instance_id, user_id, True, now)) - - # We've now finished connecting to so inform the client handler - self.handler.update_connection(self) - self.handler.finished_connecting() - - async def handle_command(self, cmd: Command): - """Handle a command we have received over the replication stream. - - Delegates to `command_handler.on_`, which must return an - awaitable. - - Args: - cmd: received command - """ - handled = False - - # First call any command handlers on this instance. These are for TCP - # specific handling. - cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None) - if cmd_func: - await cmd_func(cmd) - handled = True - - # Then call out to the handler. - cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None) - if cmd_func: - await cmd_func(cmd) - handled = True - - if not handled: - logger.warning("Unhandled command: %r", cmd) - async def on_SERVER(self, cmd): if cmd.data != self.server_name: logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data) @@ -550,9 +475,21 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): self.send_command(ReplicateCommand()) - def on_connection_closed(self): - BaseReplicationStreamProtocol.on_connection_closed(self) - self.handler.update_connection(None) + +class AbstractConnection(abc.ABC): + """An interface for replication connections. + """ + + @abc.abstractmethod + def send_command(self, cmd: Command): + """Send the command down the connection + """ + pass + + +# This tells python that `BaseReplicationStreamProtocol` implements the +# interface. +AbstractConnection.register(BaseReplicationStreamProtocol) # The following simply registers metrics for the replication connections diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index 30021ee309..b2d6baa2a2 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -17,7 +17,7 @@ import logging import random -from typing import Any, Dict, List +from typing import Dict from six import itervalues @@ -25,24 +25,14 @@ from prometheus_client import Counter from twisted.internet.protocol import Factory -from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.util.metrics import Measure, measure_func - -from .protocol import ServerReplicationStreamProtocol -from .streams import STREAMS_MAP, Stream -from .streams.federation import FederationStream +from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol +from synapse.replication.tcp.streams import STREAMS_MAP, FederationStream, Stream +from synapse.util.metrics import Measure stream_updates_counter = Counter( "synapse_replication_tcp_resource_stream_updates", "", ["stream_name"] ) -user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "") -federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "") -remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "") -invalidate_cache_counter = Counter( - "synapse_replication_tcp_resource_invalidate_cache", "" -) -user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "") logger = logging.getLogger(__name__) @@ -52,13 +42,23 @@ class ReplicationStreamProtocolFactory(Factory): """ def __init__(self, hs): - self.streamer = hs.get_replication_streamer() + self.command_handler = hs.get_tcp_replication() self.clock = hs.get_clock() self.server_name = hs.config.server_name + # If we've created a `ReplicationStreamProtocolFactory` then we're + # almost certainly registering a replication listener, so let's ensure + # that we've started a `ReplicationStreamer` instance to actually push + # data. + # + # (This is a bit of a weird place to do this, but the alternatives such + # as putting this in `HomeServer.setup()`, requires either passing the + # listener config again or always starting a `ReplicationStreamer`.) + hs.get_replication_streamer() + def buildProtocol(self, addr): return ServerReplicationStreamProtocol( - self.server_name, self.clock, self.streamer + self.server_name, self.clock, self.command_handler ) @@ -78,16 +78,6 @@ class ReplicationStreamer(object): self._replication_torture_level = hs.config.replication_torture_level - # Current connections. - self.connections = [] # type: List[ServerReplicationStreamProtocol] - - LaterGauge( - "synapse_replication_tcp_resource_total_connections", - "", - [], - lambda: len(self.connections), - ) - # List of streams that clients can subscribe to. # We only support federation stream if federation sending hase been # disabled on the master. @@ -104,18 +94,12 @@ class ReplicationStreamer(object): self.federation_sender = hs.get_federation_sender() self.notifier.add_replication_callback(self.on_notifier_poke) - self.notifier.add_remote_server_up_callback(self.send_remote_server_up) # Keeps track of whether we are currently checking for updates self.is_looping = False self.pending_updates = False - hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.on_shutdown) - - def on_shutdown(self): - # close all connections on shutdown - for conn in self.connections: - conn.send_error("server shutting down") + self.command_handler = hs.get_tcp_replication() def get_streams(self) -> Dict[str, Stream]: """Get a mapp from stream name to stream instance. @@ -129,7 +113,7 @@ class ReplicationStreamer(object): This should get called each time new data is available, even if it is currently being executed, so that nothing gets missed """ - if not self.connections: + if not self.command_handler.connected(): # Don't bother if nothing is listening. We still need to advance # the stream tokens otherwise they'll fall beihind forever for stream in self.streams: @@ -186,9 +170,7 @@ class ReplicationStreamer(object): raise logger.debug( - "Sending %d updates to %d connections", - len(updates), - len(self.connections), + "Sending %d updates", len(updates), ) if updates: @@ -204,112 +186,19 @@ class ReplicationStreamer(object): # token. See RdataCommand for more details. batched_updates = _batch_updates(updates) - for conn in self.connections: - for token, row in batched_updates: - try: - conn.stream_update(stream.NAME, token, row) - except Exception: - logger.exception("Failed to replicate") + for token, row in batched_updates: + try: + self.command_handler.stream_update( + stream.NAME, token, row + ) + except Exception: + logger.exception("Failed to replicate") logger.debug("No more pending updates, breaking poke loop") finally: self.pending_updates = False self.is_looping = False - def get_stream_token(self, stream_name): - """For a given stream get all updates since token. This is called when - a client first subscribes to a stream. - """ - stream = self.streams_by_name.get(stream_name, None) - if not stream: - raise Exception("unknown stream %s", stream_name) - - return stream.current_token() - - @measure_func("repl.federation_ack") - def federation_ack(self, token): - """We've received an ack for federation stream from a client. - """ - federation_ack_counter.inc() - if self.federation_sender: - self.federation_sender.federation_ack(token) - - @measure_func("repl.on_user_sync") - async def on_user_sync(self, instance_id, user_id, is_syncing, last_sync_ms): - """A client has started/stopped syncing on a worker. - """ - user_sync_counter.inc() - await self.presence_handler.update_external_syncs_row( - instance_id, user_id, is_syncing, last_sync_ms - ) - - async def on_clear_user_syncs(self, instance_id): - """A replication client wants us to drop all their UserSync data. - """ - await self.presence_handler.update_external_syncs_clear(instance_id) - - @measure_func("repl.on_remove_pusher") - async def on_remove_pusher(self, app_id, push_key, user_id): - """A client has asked us to remove a pusher - """ - remove_pusher_counter.inc() - await self.store.delete_pusher_by_app_id_pushkey_user_id( - app_id=app_id, pushkey=push_key, user_id=user_id - ) - - self.notifier.on_new_replication_data() - - @measure_func("repl.on_invalidate_cache") - async def on_invalidate_cache(self, cache_func: str, keys: List[Any]): - """The client has asked us to invalidate a cache - """ - invalidate_cache_counter.inc() - - # We invalidate the cache locally, but then also stream that to other - # workers. - await self.store.invalidate_cache_and_stream(cache_func, tuple(keys)) - - @measure_func("repl.on_user_ip") - async def on_user_ip( - self, user_id, access_token, ip, user_agent, device_id, last_seen - ): - """The client saw a user request - """ - user_ip_cache_counter.inc() - await self.store.insert_client_ip( - user_id, access_token, ip, user_agent, device_id, last_seen - ) - await self._server_notices_sender.on_user_ip(user_id) - - @measure_func("repl.on_remote_server_up") - def on_remote_server_up(self, server: str): - self.notifier.notify_remote_server_up(server) - - def send_remote_server_up(self, server: str): - for conn in self.connections: - conn.send_remote_server_up(server) - - def send_sync_to_all_connections(self, data): - """Sends a SYNC command to all clients. - - Used in tests. - """ - for conn in self.connections: - conn.send_sync(data) - - def new_connection(self, connection): - """A new client connection has been established - """ - self.connections.append(connection) - - def lost_connection(self, connection): - """A client connection has been lost - """ - try: - self.connections.remove(connection) - except ValueError: - pass - def _batch_updates(updates): """Takes a list of updates of form [(token, row)] and sets the token to -- cgit 1.5.1 From e13c6c7a96dc71200eef9f966bee21c27ae54426 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 7 Apr 2020 17:21:13 +0100 Subject: Handle one-word replication commands correctly `REPLICATE` is now a valid command, and it's nice if you can issue it from the console without remembering to call it `REPLICATE ` with a trailing space. --- synapse/replication/tcp/protocol.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) (limited to 'synapse/replication/tcp/protocol.py') diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index 9aabb9c586..9276ed2965 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -201,15 +201,23 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): ) self.send_error("ping timeout") - def lineReceived(self, line): + def lineReceived(self, line: bytes): """Called when we've received a line """ if line.strip() == "": # Ignore blank lines return - line = line.decode("utf-8") - cmd_name, rest_of_line = line.split(" ", 1) + linestr = line.decode("utf-8") + + # split at the first " ", handling one-word commands + idx = linestr.index(" ") + if idx >= 0: + cmd_name = linestr[:idx] + rest_of_line = linestr[idx + 1 :] + else: + cmd_name = linestr + rest_of_line = "" if cmd_name not in self.VALID_INBOUND_COMMANDS: logger.error("[%s] invalid command %s", self.id(), cmd_name) -- cgit 1.5.1 From 51f7eaf908a84fcaf231899e2bf1beae14ae72c0 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 22 Apr 2020 13:07:41 +0100 Subject: Add ability to run replication protocol over redis. (#7040) This is configured via the `redis` config options. --- changelog.d/7040.feature | 1 + stubs/txredisapi.pyi | 40 +++++++ synapse/app/homeserver.py | 6 + synapse/config/homeserver.py | 2 + synapse/config/redis.py | 35 ++++++ synapse/python_dependencies.py | 1 + synapse/replication/tcp/client.py | 2 +- synapse/replication/tcp/commands.py | 18 +++ synapse/replication/tcp/handler.py | 50 +++++++-- synapse/replication/tcp/protocol.py | 38 ++----- synapse/replication/tcp/redis.py | 181 +++++++++++++++++++++++++++++++ tests/replication/slave/storage/_base.py | 4 +- 12 files changed, 342 insertions(+), 36 deletions(-) create mode 100644 changelog.d/7040.feature create mode 100644 stubs/txredisapi.pyi create mode 100644 synapse/config/redis.py create mode 100644 synapse/replication/tcp/redis.py (limited to 'synapse/replication/tcp/protocol.py') diff --git a/changelog.d/7040.feature b/changelog.d/7040.feature new file mode 100644 index 0000000000..ce6140fdd1 --- /dev/null +++ b/changelog.d/7040.feature @@ -0,0 +1 @@ +Add support for running replication over Redis when using workers. diff --git a/stubs/txredisapi.pyi b/stubs/txredisapi.pyi new file mode 100644 index 0000000000..763d3fb404 --- /dev/null +++ b/stubs/txredisapi.pyi @@ -0,0 +1,40 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Contains *incomplete* type hints for txredisapi. +""" + +from typing import List, Optional, Union + +class RedisProtocol: + def publish(self, channel: str, message: bytes): ... + +class SubscriberProtocol: + def subscribe(self, channels: Union[str, List[str]]): ... + +def lazyConnection( + host: str = ..., + port: int = ..., + dbid: Optional[int] = ..., + reconnect: bool = ..., + charset: str = ..., + password: Optional[str] = ..., + connectTimeout: Optional[int] = ..., + replyTimeout: Optional[int] = ..., + convertNumbers: bool = ..., +) -> RedisProtocol: ... + +class SubscriberFactory: + def buildProtocol(self, addr): ... diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 49df63acd0..cbd1ea475a 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -273,6 +273,12 @@ class SynapseHomeServer(HomeServer): def start_listening(self, listeners): config = self.get_config() + if config.redis_enabled: + # If redis is enabled we connect via the replication command handler + # in the same way as the workers (since we're effectively a client + # rather than a server). + self.get_tcp_replication().start_replication(self) + for listener in listeners: if listener["type"] == "http": self._listening_services.extend(self._listener_http(config, listener)) diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index b4bca08b20..be6c6afa74 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -31,6 +31,7 @@ from .password import PasswordConfig from .password_auth_providers import PasswordAuthProviderConfig from .push import PushConfig from .ratelimiting import RatelimitConfig +from .redis import RedisConfig from .registration import RegistrationConfig from .repository import ContentRepositoryConfig from .room_directory import RoomDirectoryConfig @@ -82,4 +83,5 @@ class HomeServerConfig(RootConfig): RoomDirectoryConfig, ThirdPartyRulesConfig, TracerConfig, + RedisConfig, ] diff --git a/synapse/config/redis.py b/synapse/config/redis.py new file mode 100644 index 0000000000..81a27619ec --- /dev/null +++ b/synapse/config/redis.py @@ -0,0 +1,35 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.config._base import Config +from synapse.python_dependencies import check_requirements + + +class RedisConfig(Config): + section = "redis" + + def read_config(self, config, **kwargs): + redis_config = config.get("redis", {}) + self.redis_enabled = redis_config.get("enabled", False) + + if not self.redis_enabled: + return + + check_requirements("redis") + + self.redis_host = redis_config.get("host", "localhost") + self.redis_port = redis_config.get("port", 6379) + self.redis_dbid = redis_config.get("dbid") + self.redis_password = redis_config.get("password") diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 8de8cb2c12..733c51b758 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -98,6 +98,7 @@ CONDITIONAL_REQUIREMENTS = { "sentry": ["sentry-sdk>=0.7.2"], "opentracing": ["jaeger-client>=4.0.0", "opentracing>=2.2.0"], "jwt": ["pyjwt>=1.6.4"], + "redis": ["txredisapi>=1.4.7"], } ALL_OPTIONAL_REQUIREMENTS = set() # type: Set[str] diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 700ae79158..2d07b8b2d0 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -30,7 +30,7 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class ReplicationClientFactory(ReconnectingClientFactory): +class DirectTcpReplicationClientFactory(ReconnectingClientFactory): """Factory for building connections to the master. Will reconnect if the connection is lost. diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index 5ec89d0fb8..5002efe6a0 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -454,3 +454,21 @@ VALID_CLIENT_COMMANDS = ( ErrorCommand.NAME, RemoteServerUpCommand.NAME, ) + + +def parse_command_from_line(line: str) -> Command: + """Parses a command from a received line. + + Line should already be stripped of whitespace and be checked if blank. + """ + + idx = line.index(" ") + if idx >= 0: + cmd_name = line[:idx] + rest_of_line = line[idx + 1 :] + else: + cmd_name = line + rest_of_line = "" + + cmd_cls = COMMAND_MAP[cmd_name] + return cmd_cls.from_line(rest_of_line) diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index e32e68e8c4..5b5ee2c13e 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -30,8 +30,10 @@ from typing import ( from prometheus_client import Counter +from twisted.internet.protocol import ReconnectingClientFactory + from synapse.metrics import LaterGauge -from synapse.replication.tcp.client import ReplicationClientFactory +from synapse.replication.tcp.client import DirectTcpReplicationClientFactory from synapse.replication.tcp.commands import ( ClearUserSyncsCommand, Command, @@ -92,7 +94,7 @@ class ReplicationCommandHandler: self._pending_batches = {} # type: Dict[str, List[Any]] # The factory used to create connections. - self._factory = None # type: Optional[ReplicationClientFactory] + self._factory = None # type: Optional[ReconnectingClientFactory] # The currently connected connections. self._connections = [] # type: List[AbstractConnection] @@ -119,11 +121,45 @@ class ReplicationCommandHandler: """Helper method to start a replication connection to the remote server using TCP. """ - client_name = hs.config.worker_name - self._factory = ReplicationClientFactory(hs, client_name, self) - host = hs.config.worker_replication_host - port = hs.config.worker_replication_port - hs.get_reactor().connectTCP(host, port, self._factory) + if hs.config.redis.redis_enabled: + from synapse.replication.tcp.redis import ( + RedisDirectTcpReplicationClientFactory, + ) + import txredisapi + + logger.info( + "Connecting to redis (host=%r port=%r DBID=%r)", + hs.config.redis_host, + hs.config.redis_port, + hs.config.redis_dbid, + ) + + # We need two connections to redis, one for the subscription stream and + # one to send commands to (as you can't send further redis commands to a + # connection after SUBSCRIBE is called). + + # First create the connection for sending commands. + outbound_redis_connection = txredisapi.lazyConnection( + host=hs.config.redis_host, + port=hs.config.redis_port, + dbid=hs.config.redis_dbid, + password=hs.config.redis.redis_password, + reconnect=True, + ) + + # Now create the factory/connection for the subscription stream. + self._factory = RedisDirectTcpReplicationClientFactory( + hs, outbound_redis_connection + ) + hs.get_reactor().connectTCP( + hs.config.redis.redis_host, hs.config.redis.redis_port, self._factory, + ) + else: + client_name = hs.config.worker_name + self._factory = DirectTcpReplicationClientFactory(hs, client_name, self) + host = hs.config.worker_replication_host + port = hs.config.worker_replication_port + hs.get_reactor().connectTCP(host, port, self._factory) async def on_REPLICATE(self, cmd: ReplicateCommand): # We only want to announce positions by the writer of the streams. diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index 9276ed2965..7240acb0a2 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -63,7 +63,6 @@ from twisted.python.failure import Failure from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.tcp.commands import ( - COMMAND_MAP, VALID_CLIENT_COMMANDS, VALID_SERVER_COMMANDS, Command, @@ -72,6 +71,7 @@ from synapse.replication.tcp.commands import ( PingCommand, ReplicateCommand, ServerCommand, + parse_command_from_line, ) from synapse.types import Collection from synapse.util import Clock @@ -210,38 +210,24 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): linestr = line.decode("utf-8") - # split at the first " ", handling one-word commands - idx = linestr.index(" ") - if idx >= 0: - cmd_name = linestr[:idx] - rest_of_line = linestr[idx + 1 :] - else: - cmd_name = linestr - rest_of_line = "" + try: + cmd = parse_command_from_line(linestr) + except Exception as e: + logger.exception("[%s] failed to parse line: %r", self.id(), linestr) + self.send_error("failed to parse line: %r (%r):" % (e, linestr)) + return - if cmd_name not in self.VALID_INBOUND_COMMANDS: - logger.error("[%s] invalid command %s", self.id(), cmd_name) - self.send_error("invalid command: %s", cmd_name) + if cmd.NAME not in self.VALID_INBOUND_COMMANDS: + logger.error("[%s] invalid command %s", self.id(), cmd.NAME) + self.send_error("invalid command: %s", cmd.NAME) return self.last_received_command = self.clock.time_msec() - self.inbound_commands_counter[cmd_name] = ( - self.inbound_commands_counter[cmd_name] + 1 + self.inbound_commands_counter[cmd.NAME] = ( + self.inbound_commands_counter[cmd.NAME] + 1 ) - cmd_cls = COMMAND_MAP[cmd_name] - try: - cmd = cmd_cls.from_line(rest_of_line) - except Exception as e: - logger.exception( - "[%s] failed to parse line %r: %r", self.id(), cmd_name, rest_of_line - ) - self.send_error( - "failed to parse line for %r: %r (%r):" % (cmd_name, e, rest_of_line) - ) - return - # Now lets try and call on_ function run_as_background_process( "replication-" + cmd.get_logcontext_id(), self.handle_command, cmd diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py new file mode 100644 index 0000000000..4c08425735 --- /dev/null +++ b/synapse/replication/tcp/redis.py @@ -0,0 +1,181 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import TYPE_CHECKING + +import txredisapi + +from synapse.logging.context import PreserveLoggingContext +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.replication.tcp.commands import ( + Command, + ReplicateCommand, + parse_command_from_line, +) +from synapse.replication.tcp.protocol import AbstractConnection + +if TYPE_CHECKING: + from synapse.replication.tcp.handler import ReplicationCommandHandler + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): + """Connection to redis subscribed to replication stream. + + Parses incoming messages from redis into replication commands, and passes + them to `ReplicationCommandHandler` + + Due to the vagaries of `txredisapi` we don't want to have a custom + constructor, so instead we expect the defined attributes below to be set + immediately after initialisation. + + Attributes: + handler: The command handler to handle incoming commands. + stream_name: The *redis* stream name to subscribe to (not anything to + do with Synapse replication streams). + outbound_redis_connection: The connection to redis to use to send + commands. + """ + + handler = None # type: ReplicationCommandHandler + stream_name = None # type: str + outbound_redis_connection = None # type: txredisapi.RedisProtocol + + def connectionMade(self): + logger.info("Connected to redis instance") + self.subscribe(self.stream_name) + self.send_command(ReplicateCommand()) + + self.handler.new_connection(self) + + def messageReceived(self, pattern: str, channel: str, message: str): + """Received a message from redis. + """ + + if message.strip() == "": + # Ignore blank lines + return + + try: + cmd = parse_command_from_line(message) + except Exception: + logger.exception( + "[%s] failed to parse line: %r", message, + ) + return + + # Now lets try and call on_ function + run_as_background_process( + "replication-" + cmd.get_logcontext_id(), self.handle_command, cmd + ) + + async def handle_command(self, cmd: Command): + """Handle a command we have received over the replication stream. + + By default delegates to on_, which should return an awaitable. + + Args: + cmd: received command + """ + handled = False + + # First call any command handlers on this instance. These are for redis + # specific handling. + cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None) + if cmd_func: + await cmd_func(cmd) + handled = True + + # Then call out to the handler. + cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None) + if cmd_func: + await cmd_func(cmd) + handled = True + + if not handled: + logger.warning("Unhandled command: %r", cmd) + + def connectionLost(self, reason): + logger.info("Lost connection to redis instance") + self.handler.lost_connection(self) + + def send_command(self, cmd: Command): + """Send a command if connection has been established. + + Args: + cmd (Command) + """ + string = "%s %s" % (cmd.NAME, cmd.to_line()) + if "\n" in string: + raise Exception("Unexpected newline in command: %r", string) + + encoded_string = string.encode("utf-8") + + async def _send(): + with PreserveLoggingContext(): + # Note that we use the other connection as we can't send + # commands using the subscription connection. + await self.outbound_redis_connection.publish( + self.stream_name, encoded_string + ) + + run_as_background_process("send-cmd", _send) + + +class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory): + """This is a reconnecting factory that connects to redis and immediately + subscribes to a stream. + + Args: + hs + outbound_redis_connection: A connection to redis that will be used to + send outbound commands (this is seperate to the redis connection + used to subscribe). + """ + + maxDelay = 5 + continueTrying = True + protocol = RedisSubscriber + + def __init__( + self, hs: "HomeServer", outbound_redis_connection: txredisapi.RedisProtocol + ): + + super().__init__() + + # This sets the password on the RedisFactory base class (as + # SubscriberFactory constructor doesn't pass it through). + self.password = hs.config.redis.redis_password + + self.handler = hs.get_tcp_replication() + self.stream_name = hs.hostname + + self.outbound_redis_connection = outbound_redis_connection + + def buildProtocol(self, addr): + p = super().buildProtocol(addr) # type: RedisSubscriber + + # We do this here rather than add to the constructor of `RedisSubcriber` + # as to do so would involve overriding `buildProtocol` entirely, however + # the base method does some other things than just instantiating the + # protocol. + p.handler = self.handler + p.outbound_redis_connection = self.outbound_redis_connection + p.stream_name = self.stream_name + + return p diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 8902a5ab69..395c7d0306 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -16,7 +16,7 @@ from mock import Mock, NonCallableMock from synapse.replication.tcp.client import ( - ReplicationClientFactory, + DirectTcpReplicationClientFactory, ReplicationDataHandler, ) from synapse.replication.tcp.handler import ReplicationCommandHandler @@ -61,7 +61,7 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase): self.slaved_store ) - client_factory = ReplicationClientFactory( + client_factory = DirectTcpReplicationClientFactory( self.hs, "client_name", self.replication_handler ) client_factory.handler = self.replication_handler -- cgit 1.5.1 From 841c581c401e1435c2c7820a803b3f0c574eb8b6 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 22 Apr 2020 16:26:19 +0100 Subject: Fix replication metrics when using redis (#7325) --- changelog.d/7325.feature | 1 + synapse/replication/tcp/protocol.py | 52 ++++++++++++------------------------- synapse/replication/tcp/redis.py | 14 +++++++++- 3 files changed, 30 insertions(+), 37 deletions(-) create mode 100644 changelog.d/7325.feature (limited to 'synapse/replication/tcp/protocol.py') diff --git a/changelog.d/7325.feature b/changelog.d/7325.feature new file mode 100644 index 0000000000..ce6140fdd1 --- /dev/null +++ b/changelog.d/7325.feature @@ -0,0 +1 @@ +Add support for running replication over Redis when using workers. diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index 7240acb0a2..e3f64eba8f 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -50,10 +50,7 @@ import abc import fcntl import logging import struct -from collections import defaultdict -from typing import TYPE_CHECKING, DefaultDict, List - -from six import iteritems +from typing import TYPE_CHECKING, List from prometheus_client import Counter @@ -86,6 +83,18 @@ connection_close_counter = Counter( "synapse_replication_tcp_protocol_close_reason", "", ["reason_type"] ) +tcp_inbound_commands_counter = Counter( + "synapse_replication_tcp_protocol_inbound_commands", + "Number of commands received from replication, by command and name of process connected to", + ["command", "name"], +) + +tcp_outbound_commands_counter = Counter( + "synapse_replication_tcp_protocol_outbound_commands", + "Number of commands sent to replication, by command and name of process connected to", + ["command", "name"], +) + # A list of all connected protocols. This allows us to send metrics about the # connections. connected_connections = [] @@ -151,9 +160,6 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): # The LoopingCall for sending pings. self._send_ping_loop = None - self.inbound_commands_counter = defaultdict(int) # type: DefaultDict[str, int] - self.outbound_commands_counter = defaultdict(int) # type: DefaultDict[str, int] - def connectionMade(self): logger.info("[%s] Connection established", self.id()) @@ -224,9 +230,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): self.last_received_command = self.clock.time_msec() - self.inbound_commands_counter[cmd.NAME] = ( - self.inbound_commands_counter[cmd.NAME] + 1 - ) + tcp_inbound_commands_counter.labels(cmd.NAME, self.name).inc() # Now lets try and call on_ function run_as_background_process( @@ -292,9 +296,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): self._queue_command(cmd) return - self.outbound_commands_counter[cmd.NAME] = ( - self.outbound_commands_counter[cmd.NAME] + 1 - ) + tcp_outbound_commands_counter.labels(cmd.NAME, self.name).inc() + string = "%s %s" % (cmd.NAME, cmd.to_line()) if "\n" in string: raise Exception("Unexpected newline in command: %r", string) @@ -546,26 +549,3 @@ tcp_transport_kernel_read_buffer = LaterGauge( for p in connected_connections }, ) - - -tcp_inbound_commands = LaterGauge( - "synapse_replication_tcp_protocol_inbound_commands", - "", - ["command", "name"], - lambda: { - (k, p.name): count - for p in connected_connections - for k, count in iteritems(p.inbound_commands_counter) - }, -) - -tcp_outbound_commands = LaterGauge( - "synapse_replication_tcp_protocol_outbound_commands", - "", - ["command", "name"], - lambda: { - (k, p.name): count - for p in connected_connections - for k, count in iteritems(p.outbound_commands_counter) - }, -) diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py index 4c08425735..49b3ed0c5e 100644 --- a/synapse/replication/tcp/redis.py +++ b/synapse/replication/tcp/redis.py @@ -25,7 +25,11 @@ from synapse.replication.tcp.commands import ( ReplicateCommand, parse_command_from_line, ) -from synapse.replication.tcp.protocol import AbstractConnection +from synapse.replication.tcp.protocol import ( + AbstractConnection, + tcp_inbound_commands_counter, + tcp_outbound_commands_counter, +) if TYPE_CHECKING: from synapse.replication.tcp.handler import ReplicationCommandHandler @@ -79,6 +83,10 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): ) return + # We use "redis" as the name here as we don't have 1:1 connections to + # remote instances. + tcp_inbound_commands_counter.labels(cmd.NAME, "redis").inc() + # Now lets try and call on_ function run_as_background_process( "replication-" + cmd.get_logcontext_id(), self.handle_command, cmd @@ -126,6 +134,10 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): encoded_string = string.encode("utf-8") + # We use "redis" as the name here as we don't have 1:1 connections to + # remote instances. + tcp_outbound_commands_counter.labels(cmd.NAME, "redis").inc() + async def _send(): with PreserveLoggingContext(): # Note that we use the other connection as we can't send -- cgit 1.5.1 From 3eab76ad43e1de4074550c468d11318d497ff632 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 29 Apr 2020 14:10:59 +0100 Subject: Don't relay REMOTE_SERVER_UP cmds to same conn. (#7352) For direct TCP connections we need the master to relay REMOTE_SERVER_UP commands to the other connections so that all instances get notified about it. The old implementation just relayed to all connections, assuming that sending back to the original sender of the command was safe. This is not true for redis, where commands sent get echoed back to the sender, which was causing master to effectively infinite loop sending and then re-receiving REMOTE_SERVER_UP commands that it sent. The fix is to ensure that we only relay to *other* connections and not to the connection we received the notification from. Fixes #7334. --- changelog.d/7352.feature | 1 + synapse/notifier.py | 9 ---- synapse/replication/tcp/handler.py | 63 ++++++++++++++++++++------ synapse/replication/tcp/protocol.py | 2 +- synapse/replication/tcp/redis.py | 2 +- tests/replication/tcp/test_remote_server_up.py | 62 +++++++++++++++++++++++++ 6 files changed, 114 insertions(+), 25 deletions(-) create mode 100644 changelog.d/7352.feature create mode 100644 tests/replication/tcp/test_remote_server_up.py (limited to 'synapse/replication/tcp/protocol.py') diff --git a/changelog.d/7352.feature b/changelog.d/7352.feature new file mode 100644 index 0000000000..ce6140fdd1 --- /dev/null +++ b/changelog.d/7352.feature @@ -0,0 +1 @@ +Add support for running replication over Redis when using workers. diff --git a/synapse/notifier.py b/synapse/notifier.py index 6132727cbd..88a5a97caf 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -220,12 +220,6 @@ class Notifier(object): """ self.replication_callbacks.append(cb) - def add_remote_server_up_callback(self, cb: Callable[[str], None]): - """Add a callback that will be called when synapse detects a server - has been - """ - self.remote_server_up_callbacks.append(cb) - def on_new_room_event( self, event, room_stream_id, max_room_stream_id, extra_users=[] ): @@ -544,6 +538,3 @@ class Notifier(object): # circular dependencies. if self.federation_sender: self.federation_sender.wake_destination(server) - - for cb in self.remote_server_up_callbacks: - cb(server) diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 3a8c7c7e2d..b8f49a8d0f 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -117,7 +117,6 @@ class ReplicationCommandHandler: self._server_notices_sender = None if self._is_master: self._server_notices_sender = hs.get_server_notices_sender() - self._notifier.add_remote_server_up_callback(self.send_remote_server_up) def start_replication(self, hs): """Helper method to start a replication connection to the remote server @@ -163,7 +162,7 @@ class ReplicationCommandHandler: port = hs.config.worker_replication_port hs.get_reactor().connectTCP(host, port, self._factory) - async def on_REPLICATE(self, cmd: ReplicateCommand): + async def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand): # We only want to announce positions by the writer of the streams. # Currently this is just the master process. if not self._is_master: @@ -173,7 +172,7 @@ class ReplicationCommandHandler: current_token = stream.current_token() self.send_command(PositionCommand(stream_name, current_token)) - async def on_USER_SYNC(self, cmd: UserSyncCommand): + async def on_USER_SYNC(self, conn: AbstractConnection, cmd: UserSyncCommand): user_sync_counter.inc() if self._is_master: @@ -181,17 +180,23 @@ class ReplicationCommandHandler: cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms ) - async def on_CLEAR_USER_SYNC(self, cmd: ClearUserSyncsCommand): + async def on_CLEAR_USER_SYNC( + self, conn: AbstractConnection, cmd: ClearUserSyncsCommand + ): if self._is_master: await self._presence_handler.update_external_syncs_clear(cmd.instance_id) - async def on_FEDERATION_ACK(self, cmd: FederationAckCommand): + async def on_FEDERATION_ACK( + self, conn: AbstractConnection, cmd: FederationAckCommand + ): federation_ack_counter.inc() if self._federation_sender: self._federation_sender.federation_ack(cmd.token) - async def on_REMOVE_PUSHER(self, cmd: RemovePusherCommand): + async def on_REMOVE_PUSHER( + self, conn: AbstractConnection, cmd: RemovePusherCommand + ): remove_pusher_counter.inc() if self._is_master: @@ -201,7 +206,9 @@ class ReplicationCommandHandler: self._notifier.on_new_replication_data() - async def on_INVALIDATE_CACHE(self, cmd: InvalidateCacheCommand): + async def on_INVALIDATE_CACHE( + self, conn: AbstractConnection, cmd: InvalidateCacheCommand + ): invalidate_cache_counter.inc() if self._is_master: @@ -211,7 +218,7 @@ class ReplicationCommandHandler: cmd.cache_func, tuple(cmd.keys) ) - async def on_USER_IP(self, cmd: UserIpCommand): + async def on_USER_IP(self, conn: AbstractConnection, cmd: UserIpCommand): user_ip_cache_counter.inc() if self._is_master: @@ -227,7 +234,7 @@ class ReplicationCommandHandler: if self._server_notices_sender: await self._server_notices_sender.on_user_ip(cmd.user_id) - async def on_RDATA(self, cmd: RdataCommand): + async def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand): stream_name = cmd.stream_name inbound_rdata_count.labels(stream_name).inc() @@ -278,7 +285,7 @@ class ReplicationCommandHandler: logger.debug("Received rdata %s -> %s", stream_name, token) await self._replication_data_handler.on_rdata(stream_name, token, rows) - async def on_POSITION(self, cmd: PositionCommand): + async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand): stream = self._streams.get(cmd.stream_name) if not stream: logger.error("Got POSITION for unknown stream: %s", cmd.stream_name) @@ -332,12 +339,30 @@ class ReplicationCommandHandler: self._streams_connected.add(cmd.stream_name) - async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand): + async def on_REMOTE_SERVER_UP( + self, conn: AbstractConnection, cmd: RemoteServerUpCommand + ): """"Called when get a new REMOTE_SERVER_UP command.""" self._replication_data_handler.on_remote_server_up(cmd.data) - if self._is_master: - self._notifier.notify_remote_server_up(cmd.data) + self._notifier.notify_remote_server_up(cmd.data) + + # We relay to all other connections to ensure every instance gets the + # notification. + # + # When configured to use redis we'll always only have one connection and + # so this is a no-op (all instances will have already received the same + # REMOTE_SERVER_UP command). + # + # For direct TCP connections this will relay to all other connections + # connected to us. When on master this will correctly fan out to all + # other direct TCP clients and on workers there'll only be the one + # connection to master. + # + # (The logic here should also be sound if we have a mix of Redis and + # direct TCP connections so long as there is only one traffic route + # between two instances, but that is not currently supported). + self.send_command(cmd, ignore_conn=conn) def new_connection(self, connection: AbstractConnection): """Called when we have a new connection. @@ -382,11 +407,21 @@ class ReplicationCommandHandler: """ return bool(self._connections) - def send_command(self, cmd: Command): + def send_command( + self, cmd: Command, ignore_conn: Optional[AbstractConnection] = None + ): """Send a command to all connected connections. + + Args: + cmd + ignore_conn: If set don't send command to the given connection. + Used when relaying commands from one connection to all others. """ if self._connections: for connection in self._connections: + if connection == ignore_conn: + continue + try: connection.send_command(cmd) except Exception: diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index e3f64eba8f..4198eece71 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -260,7 +260,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): # Then call out to the handler. cmd_func = getattr(self.command_handler, "on_%s" % (cmd.NAME,), None) if cmd_func: - await cmd_func(cmd) + await cmd_func(self, cmd) handled = True if not handled: diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py index 49b3ed0c5e..617e860f95 100644 --- a/synapse/replication/tcp/redis.py +++ b/synapse/replication/tcp/redis.py @@ -112,7 +112,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): # Then call out to the handler. cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None) if cmd_func: - await cmd_func(cmd) + await cmd_func(self, cmd) handled = True if not handled: diff --git a/tests/replication/tcp/test_remote_server_up.py b/tests/replication/tcp/test_remote_server_up.py new file mode 100644 index 0000000000..d1c15caeb0 --- /dev/null +++ b/tests/replication/tcp/test_remote_server_up.py @@ -0,0 +1,62 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Tuple + +from twisted.internet.interfaces import IProtocol +from twisted.test.proto_helpers import StringTransport + +from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory + +from tests.unittest import HomeserverTestCase + + +class RemoteServerUpTestCase(HomeserverTestCase): + def prepare(self, reactor, clock, hs): + self.factory = ReplicationStreamProtocolFactory(hs) + + def _make_client(self) -> Tuple[IProtocol, StringTransport]: + """Create a new direct TCP replication connection + """ + + proto = self.factory.buildProtocol(("127.0.0.1", 0)) + transport = StringTransport() + proto.makeConnection(transport) + + # We can safely ignore the commands received during connection. + self.pump() + transport.clear() + + return proto, transport + + def test_relay(self): + """Test that Synapse will relay REMOTE_SERVER_UP commands to all + other connections, but not the one that sent it. + """ + + proto1, transport1 = self._make_client() + + # We shouldn't receive an echo. + proto1.dataReceived(b"REMOTE_SERVER_UP example.com\n") + self.pump() + self.assertEqual(transport1.value(), b"") + + # But we should see an echo if we connect another client + proto2, transport2 = self._make_client() + proto1.dataReceived(b"REMOTE_SERVER_UP example.com\n") + + self.pump() + self.assertEqual(transport1.value(), b"") + self.assertEqual(transport2.value(), b"REMOTE_SERVER_UP example.com\n") -- cgit 1.5.1