From 5016b162fcf0372fe35404c64f80aeaf21461f31 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 6 Apr 2020 09:58:42 +0100 Subject: Move client command handling out of TCP protocol (#7185) The aim here is to move the command handling out of the TCP protocol classes and to also merge the client and server command handling (so that we can reuse them for redis protocol). This PR simply moves the client paths to the new `ReplicationCommandHandler`, a future PR will move the server paths too. --- synapse/replication/tcp/handler.py | 252 +++++++++++++++++++++++++++++++++++++ 1 file changed, 252 insertions(+) create mode 100644 synapse/replication/tcp/handler.py (limited to 'synapse/replication/tcp/handler.py') diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py new file mode 100644 index 0000000000..12a1cfd6d1 --- /dev/null +++ b/synapse/replication/tcp/handler.py @@ -0,0 +1,252 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import Any, Callable, Dict, List, Optional, Set + +from prometheus_client import Counter + +from synapse.replication.tcp.client import ReplicationClientFactory +from synapse.replication.tcp.commands import ( + Command, + FederationAckCommand, + InvalidateCacheCommand, + PositionCommand, + RdataCommand, + RemoteServerUpCommand, + RemovePusherCommand, + SyncCommand, + UserIpCommand, + UserSyncCommand, +) +from synapse.replication.tcp.streams import STREAMS_MAP, Stream +from synapse.util.async_helpers import Linearizer + +logger = logging.getLogger(__name__) + + +# number of updates received for each RDATA stream +inbound_rdata_count = Counter( + "synapse_replication_tcp_protocol_inbound_rdata_count", "", ["stream_name"] +) + + +class ReplicationCommandHandler: + """Handles incoming commands from replication as well as sending commands + back out to connections. + """ + + def __init__(self, hs): + self._replication_data_handler = hs.get_replication_data_handler() + self._presence_handler = hs.get_presence_handler() + + # Set of streams that we've caught up with. + self._streams_connected = set() # type: Set[str] + + self._streams = { + stream.NAME: stream(hs) for stream in STREAMS_MAP.values() + } # type: Dict[str, Stream] + + self._position_linearizer = Linearizer("replication_position") + + # Map of stream to batched updates. See RdataCommand for info on how + # batching works. + self._pending_batches = {} # type: Dict[str, List[Any]] + + # The factory used to create connections. + self._factory = None # type: Optional[ReplicationClientFactory] + + # The current connection. None if we are currently (re)connecting + self._connection = None + + def start_replication(self, hs): + """Helper method to start a replication connection to the remote server + using TCP. + """ + client_name = hs.config.worker_name + self._factory = ReplicationClientFactory(hs, client_name, self) + host = hs.config.worker_replication_host + port = hs.config.worker_replication_port + hs.get_reactor().connectTCP(host, port, self._factory) + + async def on_RDATA(self, cmd: RdataCommand): + stream_name = cmd.stream_name + inbound_rdata_count.labels(stream_name).inc() + + try: + row = STREAMS_MAP[stream_name].parse_row(cmd.row) + except Exception: + logger.exception("Failed to parse RDATA: %r %r", stream_name, cmd.row) + raise + + if cmd.token is None or stream_name not in self._streams_connected: + # I.e. either this is part of a batch of updates for this stream (in + # which case batch until we get an update for the stream with a non + # None token) or we're currently connecting so we queue up rows. + self._pending_batches.setdefault(stream_name, []).append(row) + else: + # Check if this is the last of a batch of updates + rows = self._pending_batches.pop(stream_name, []) + rows.append(row) + await self.on_rdata(stream_name, cmd.token, rows) + + async def on_rdata(self, stream_name: str, token: int, rows: list): + """Called to handle a batch of replication data with a given stream token. + + Args: + stream_name: name of the replication stream for this batch of rows + token: stream token for this batch of rows + rows: a list of Stream.ROW_TYPE objects as returned by + Stream.parse_row. + """ + logger.debug("Received rdata %s -> %s", stream_name, token) + await self._replication_data_handler.on_rdata(stream_name, token, rows) + + async def on_POSITION(self, cmd: PositionCommand): + stream = self._streams.get(cmd.stream_name) + if not stream: + logger.error("Got POSITION for unknown stream: %s", cmd.stream_name) + return + + # We protect catching up with a linearizer in case the replication + # connection reconnects under us. + with await self._position_linearizer.queue(cmd.stream_name): + # We're about to go and catch up with the stream, so mark as connecting + # to stop RDATA being handled at the same time by removing stream from + # list of connected streams. We also clear any batched up RDATA from + # before we got the POSITION. + self._streams_connected.discard(cmd.stream_name) + self._pending_batches.clear() + + # Find where we previously streamed up to. + current_token = self._replication_data_handler.get_streams_to_replicate().get( + cmd.stream_name + ) + if current_token is None: + logger.warning( + "Got POSITION for stream we're not subscribed to: %s", + cmd.stream_name, + ) + return + + # Fetch all updates between then and now. + limited = True + while limited: + updates, current_token, limited = await stream.get_updates_since( + current_token, cmd.token + ) + if updates: + await self.on_rdata( + cmd.stream_name, + current_token, + [stream.parse_row(update[1]) for update in updates], + ) + + # We've now caught up to position sent to us, notify handler. + await self._replication_data_handler.on_position(cmd.stream_name, cmd.token) + + # Handle any RDATA that came in while we were catching up. + rows = self._pending_batches.pop(cmd.stream_name, []) + if rows: + await self._replication_data_handler.on_rdata( + cmd.stream_name, rows[-1].token, rows + ) + + self._streams_connected.add(cmd.stream_name) + + async def on_SYNC(self, cmd: SyncCommand): + pass + + async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand): + """"Called when get a new REMOTE_SERVER_UP command.""" + self._replication_data_handler.on_remote_server_up(cmd.data) + + def get_currently_syncing_users(self): + """Get the list of currently syncing users (if any). This is called + when a connection has been established and we need to send the + currently syncing users. + """ + return self._presence_handler.get_currently_syncing_users() + + def update_connection(self, connection): + """Called when a connection has been established (or lost with None). + """ + self._connection = connection + + def finished_connecting(self): + """Called when we have successfully subscribed and caught up to all + streams we're interested in. + """ + logger.info("Finished connecting to server") + + # We don't reset the delay any earlier as otherwise if there is a + # problem during start up we'll end up tight looping connecting to the + # server. + if self._factory: + self._factory.resetDelay() + + def send_command(self, cmd: Command): + """Send a command to master (when we get establish a connection if we + don't have one already.) + """ + if self._connection: + self._connection.send_command(cmd) + else: + logger.warning("Dropping command as not connected: %r", cmd.NAME) + + def send_federation_ack(self, token: int): + """Ack data for the federation stream. This allows the master to drop + data stored purely in memory. + """ + self.send_command(FederationAckCommand(token)) + + def send_user_sync( + self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int + ): + """Poke the master that a user has started/stopped syncing. + """ + self.send_command( + UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms) + ) + + def send_remove_pusher(self, app_id: str, push_key: str, user_id: str): + """Poke the master to remove a pusher for a user + """ + cmd = RemovePusherCommand(app_id, push_key, user_id) + self.send_command(cmd) + + def send_invalidate_cache(self, cache_func: Callable, keys: tuple): + """Poke the master to invalidate a cache. + """ + cmd = InvalidateCacheCommand(cache_func.__name__, keys) + self.send_command(cmd) + + def send_user_ip( + self, + user_id: str, + access_token: str, + ip: str, + user_agent: str, + device_id: str, + last_seen: int, + ): + """Tell the master that the user made a request. + """ + cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen) + self.send_command(cmd) + + def send_remote_server_up(self, server: str): + self.send_command(RemoteServerUpCommand(server)) -- cgit 1.5.1 From 82498ee9019747eb86ed753e08fac0990d4ac8b9 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 7 Apr 2020 10:51:07 +0100 Subject: Move server command handling out of TCP protocol (#7187) This completes the merging of server and client command processing. --- changelog.d/7187.misc | 1 + synapse/replication/tcp/handler.py | 177 ++++++++++++++++++++++++++++++++---- synapse/replication/tcp/protocol.py | 165 +++++++++++---------------------- synapse/replication/tcp/resource.py | 163 ++++++--------------------------- 4 files changed, 237 insertions(+), 269 deletions(-) create mode 100644 changelog.d/7187.misc (limited to 'synapse/replication/tcp/handler.py') diff --git a/changelog.d/7187.misc b/changelog.d/7187.misc new file mode 100644 index 0000000000..60d68ae877 --- /dev/null +++ b/changelog.d/7187.misc @@ -0,0 +1 @@ +Move server command handling out of TCP protocol. diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 12a1cfd6d1..8ec0119697 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -19,8 +19,10 @@ from typing import Any, Callable, Dict, List, Optional, Set from prometheus_client import Counter +from synapse.metrics import LaterGauge from synapse.replication.tcp.client import ReplicationClientFactory from synapse.replication.tcp.commands import ( + ClearUserSyncsCommand, Command, FederationAckCommand, InvalidateCacheCommand, @@ -28,10 +30,12 @@ from synapse.replication.tcp.commands import ( RdataCommand, RemoteServerUpCommand, RemovePusherCommand, + ReplicateCommand, SyncCommand, UserIpCommand, UserSyncCommand, ) +from synapse.replication.tcp.protocol import AbstractConnection from synapse.replication.tcp.streams import STREAMS_MAP, Stream from synapse.util.async_helpers import Linearizer @@ -42,6 +46,13 @@ logger = logging.getLogger(__name__) inbound_rdata_count = Counter( "synapse_replication_tcp_protocol_inbound_rdata_count", "", ["stream_name"] ) +user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "") +federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "") +remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "") +invalidate_cache_counter = Counter( + "synapse_replication_tcp_resource_invalidate_cache", "" +) +user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "") class ReplicationCommandHandler: @@ -52,6 +63,10 @@ class ReplicationCommandHandler: def __init__(self, hs): self._replication_data_handler = hs.get_replication_data_handler() self._presence_handler = hs.get_presence_handler() + self._store = hs.get_datastore() + self._notifier = hs.get_notifier() + self._clock = hs.get_clock() + self._instance_id = hs.get_instance_id() # Set of streams that we've caught up with. self._streams_connected = set() # type: Set[str] @@ -69,8 +84,26 @@ class ReplicationCommandHandler: # The factory used to create connections. self._factory = None # type: Optional[ReplicationClientFactory] - # The current connection. None if we are currently (re)connecting - self._connection = None + # The currently connected connections. + self._connections = [] # type: List[AbstractConnection] + + LaterGauge( + "synapse_replication_tcp_resource_total_connections", + "", + [], + lambda: len(self._connections), + ) + + self._is_master = hs.config.worker_app is None + + self._federation_sender = None + if self._is_master and not hs.config.send_federation: + self._federation_sender = hs.get_federation_sender() + + self._server_notices_sender = None + if self._is_master: + self._server_notices_sender = hs.get_server_notices_sender() + self._notifier.add_remote_server_up_callback(self.send_remote_server_up) def start_replication(self, hs): """Helper method to start a replication connection to the remote server @@ -82,6 +115,70 @@ class ReplicationCommandHandler: port = hs.config.worker_replication_port hs.get_reactor().connectTCP(host, port, self._factory) + async def on_REPLICATE(self, cmd: ReplicateCommand): + # We only want to announce positions by the writer of the streams. + # Currently this is just the master process. + if not self._is_master: + return + + for stream_name, stream in self._streams.items(): + current_token = stream.current_token() + self.send_command(PositionCommand(stream_name, current_token)) + + async def on_USER_SYNC(self, cmd: UserSyncCommand): + user_sync_counter.inc() + + if self._is_master: + await self._presence_handler.update_external_syncs_row( + cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms + ) + + async def on_CLEAR_USER_SYNC(self, cmd: ClearUserSyncsCommand): + if self._is_master: + await self._presence_handler.update_external_syncs_clear(cmd.instance_id) + + async def on_FEDERATION_ACK(self, cmd: FederationAckCommand): + federation_ack_counter.inc() + + if self._federation_sender: + self._federation_sender.federation_ack(cmd.token) + + async def on_REMOVE_PUSHER(self, cmd: RemovePusherCommand): + remove_pusher_counter.inc() + + if self._is_master: + await self._store.delete_pusher_by_app_id_pushkey_user_id( + app_id=cmd.app_id, pushkey=cmd.push_key, user_id=cmd.user_id + ) + + self._notifier.on_new_replication_data() + + async def on_INVALIDATE_CACHE(self, cmd: InvalidateCacheCommand): + invalidate_cache_counter.inc() + + if self._is_master: + # We invalidate the cache locally, but then also stream that to other + # workers. + await self._store.invalidate_cache_and_stream( + cmd.cache_func, tuple(cmd.keys) + ) + + async def on_USER_IP(self, cmd: UserIpCommand): + user_ip_cache_counter.inc() + + if self._is_master: + await self._store.insert_client_ip( + cmd.user_id, + cmd.access_token, + cmd.ip, + cmd.user_agent, + cmd.device_id, + cmd.last_seen, + ) + + if self._server_notices_sender: + await self._server_notices_sender.on_user_ip(cmd.user_id) + async def on_RDATA(self, cmd: RdataCommand): stream_name = cmd.stream_name inbound_rdata_count.labels(stream_name).inc() @@ -174,6 +271,9 @@ class ReplicationCommandHandler: """"Called when get a new REMOTE_SERVER_UP command.""" self._replication_data_handler.on_remote_server_up(cmd.data) + if self._is_master: + self._notifier.notify_remote_server_up(cmd.data) + def get_currently_syncing_users(self): """Get the list of currently syncing users (if any). This is called when a connection has been established and we need to send the @@ -181,29 +281,63 @@ class ReplicationCommandHandler: """ return self._presence_handler.get_currently_syncing_users() - def update_connection(self, connection): - """Called when a connection has been established (or lost with None). + def new_connection(self, connection: AbstractConnection): + """Called when we have a new connection. """ - self._connection = connection + self._connections.append(connection) + + # If we are connected to replication as a client (rather than a server) + # we need to reset the reconnection delay on the client factory (which + # is used to do exponential back off when the connection drops). + # + # Ideally we would reset the delay when we've "fully established" the + # connection (for some definition thereof) to stop us from tightlooping + # on reconnection if something fails after this point and we drop the + # connection. Unfortunately, we don't really have a better definition of + # "fully established" than the connection being established. + if self._factory: + self._factory.resetDelay() + + # Tell the server if we have any users currently syncing (should only + # happen on synchrotrons) + currently_syncing = self.get_currently_syncing_users() + now = self._clock.time_msec() + for user_id in currently_syncing: + connection.send_command( + UserSyncCommand(self._instance_id, user_id, True, now) + ) - def finished_connecting(self): - """Called when we have successfully subscribed and caught up to all - streams we're interested in. + def lost_connection(self, connection: AbstractConnection): + """Called when a connection is closed/lost. """ - logger.info("Finished connecting to server") + try: + self._connections.remove(connection) + except ValueError: + pass - # We don't reset the delay any earlier as otherwise if there is a - # problem during start up we'll end up tight looping connecting to the - # server. - if self._factory: - self._factory.resetDelay() + def connected(self) -> bool: + """Do we have any replication connections open? + + Is used by e.g. `ReplicationStreamer` to no-op if nothing is connected. + """ + return bool(self._connections) def send_command(self, cmd: Command): - """Send a command to master (when we get establish a connection if we - don't have one already.) + """Send a command to all connected connections. """ - if self._connection: - self._connection.send_command(cmd) + if self._connections: + for connection in self._connections: + try: + connection.send_command(cmd) + except Exception: + # We probably want to catch some types of exceptions here + # and log them as warnings (e.g. connection gone), but I + # can't find what those exception types they would be. + logger.exception( + "Failed to write command %s to connection %s", + cmd.NAME, + connection, + ) else: logger.warning("Dropping command as not connected: %r", cmd.NAME) @@ -250,3 +384,10 @@ class ReplicationCommandHandler: def send_remote_server_up(self, server: str): self.send_command(RemoteServerUpCommand(server)) + + def stream_update(self, stream_name: str, token: str, data: Any): + """Called when a new update is available to stream to clients. + + We need to check if the client is interested in the stream or not + """ + self.send_command(RdataCommand(stream_name, token, data)) diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index f2a37f568e..9aabb9c586 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -46,6 +46,7 @@ indicate which side is sending, these are *not* included on the wire:: > ERROR server stopping * connection closed by server * """ +import abc import fcntl import logging import struct @@ -69,13 +70,8 @@ from synapse.replication.tcp.commands import ( ErrorCommand, NameCommand, PingCommand, - PositionCommand, - RdataCommand, - RemoteServerUpCommand, ReplicateCommand, ServerCommand, - SyncCommand, - UserSyncCommand, ) from synapse.types import Collection from synapse.util import Clock @@ -118,7 +114,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): are only sent by the server. On receiving a new command it calls `on_` with the parsed - command. + command before delegating to `ReplicationCommandHandler.on_`. It also sends `PING` periodically, and correctly times out remote connections (if they send a `PING` command) @@ -134,8 +130,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): max_line_buffer = 10000 - def __init__(self, clock): + def __init__(self, clock: Clock, handler: "ReplicationCommandHandler"): self.clock = clock + self.command_handler = handler self.last_received_command = self.clock.time_msec() self.last_sent_command = 0 @@ -175,6 +172,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): # can time us out. self.send_command(PingCommand(self.clock.time_msec())) + self.command_handler.new_connection(self) + def send_ping(self): """Periodically sends a ping and checks if we should close the connection due to the other side timing out. @@ -243,13 +242,31 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): async def handle_command(self, cmd: Command): """Handle a command we have received over the replication stream. - By default delegates to on_, which should return an awaitable. + First calls `self.on_` if it exists, then calls + `self.command_handler.on_` if it exists. This allows for + protocol level handling of commands (e.g. PINGs), before delegating to + the handler. Args: cmd: received command """ - handler = getattr(self, "on_%s" % (cmd.NAME,)) - await handler(cmd) + handled = False + + # First call any command handlers on this instance. These are for TCP + # specific handling. + cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None) + if cmd_func: + await cmd_func(cmd) + handled = True + + # Then call out to the handler. + cmd_func = getattr(self.command_handler, "on_%s" % (cmd.NAME,), None) + if cmd_func: + await cmd_func(cmd) + handled = True + + if not handled: + logger.warning("Unhandled command: %r", cmd) def close(self): logger.warning("[%s] Closing connection", self.id()) @@ -378,6 +395,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): self.state = ConnectionStates.CLOSED self.pending_commands = [] + self.command_handler.lost_connection(self) + if self.transport: self.transport.unregisterProducer() @@ -404,74 +423,21 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): VALID_INBOUND_COMMANDS = VALID_CLIENT_COMMANDS VALID_OUTBOUND_COMMANDS = VALID_SERVER_COMMANDS - def __init__(self, server_name, clock, streamer): - BaseReplicationStreamProtocol.__init__(self, clock) # Old style class + def __init__( + self, server_name: str, clock: Clock, handler: "ReplicationCommandHandler" + ): + super().__init__(clock, handler) self.server_name = server_name - self.streamer = streamer def connectionMade(self): self.send_command(ServerCommand(self.server_name)) - BaseReplicationStreamProtocol.connectionMade(self) - self.streamer.new_connection(self) + super().connectionMade() async def on_NAME(self, cmd): logger.info("[%s] Renamed to %r", self.id(), cmd.data) self.name = cmd.data - async def on_USER_SYNC(self, cmd): - await self.streamer.on_user_sync( - cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms - ) - - async def on_CLEAR_USER_SYNC(self, cmd): - await self.streamer.on_clear_user_syncs(cmd.instance_id) - - async def on_REPLICATE(self, cmd): - # Subscribe to all streams we're publishing to. - for stream_name in self.streamer.streams_by_name: - current_token = self.streamer.get_stream_token(stream_name) - self.send_command(PositionCommand(stream_name, current_token)) - - async def on_FEDERATION_ACK(self, cmd): - self.streamer.federation_ack(cmd.token) - - async def on_REMOVE_PUSHER(self, cmd): - await self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id) - - async def on_INVALIDATE_CACHE(self, cmd): - await self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys) - - async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand): - self.streamer.on_remote_server_up(cmd.data) - - async def on_USER_IP(self, cmd): - self.streamer.on_user_ip( - cmd.user_id, - cmd.access_token, - cmd.ip, - cmd.user_agent, - cmd.device_id, - cmd.last_seen, - ) - - def stream_update(self, stream_name, token, data): - """Called when a new update is available to stream to clients. - - We need to check if the client is interested in the stream or not - """ - self.send_command(RdataCommand(stream_name, token, data)) - - def send_sync(self, data): - self.send_command(SyncCommand(data)) - - def send_remote_server_up(self, server: str): - self.send_command(RemoteServerUpCommand(server)) - - def on_connection_closed(self): - BaseReplicationStreamProtocol.on_connection_closed(self) - self.streamer.lost_connection(self) - class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): VALID_INBOUND_COMMANDS = VALID_SERVER_COMMANDS @@ -485,59 +451,18 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): clock: Clock, command_handler: "ReplicationCommandHandler", ): - BaseReplicationStreamProtocol.__init__(self, clock) - - self.instance_id = hs.get_instance_id() + super().__init__(clock, command_handler) self.client_name = client_name self.server_name = server_name - self.handler = command_handler def connectionMade(self): self.send_command(NameCommand(self.client_name)) - BaseReplicationStreamProtocol.connectionMade(self) + super().connectionMade() # Once we've connected subscribe to the necessary streams self.replicate() - # Tell the server if we have any users currently syncing (should only - # happen on synchrotrons) - currently_syncing = self.handler.get_currently_syncing_users() - now = self.clock.time_msec() - for user_id in currently_syncing: - self.send_command(UserSyncCommand(self.instance_id, user_id, True, now)) - - # We've now finished connecting to so inform the client handler - self.handler.update_connection(self) - self.handler.finished_connecting() - - async def handle_command(self, cmd: Command): - """Handle a command we have received over the replication stream. - - Delegates to `command_handler.on_`, which must return an - awaitable. - - Args: - cmd: received command - """ - handled = False - - # First call any command handlers on this instance. These are for TCP - # specific handling. - cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None) - if cmd_func: - await cmd_func(cmd) - handled = True - - # Then call out to the handler. - cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None) - if cmd_func: - await cmd_func(cmd) - handled = True - - if not handled: - logger.warning("Unhandled command: %r", cmd) - async def on_SERVER(self, cmd): if cmd.data != self.server_name: logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data) @@ -550,9 +475,21 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): self.send_command(ReplicateCommand()) - def on_connection_closed(self): - BaseReplicationStreamProtocol.on_connection_closed(self) - self.handler.update_connection(None) + +class AbstractConnection(abc.ABC): + """An interface for replication connections. + """ + + @abc.abstractmethod + def send_command(self, cmd: Command): + """Send the command down the connection + """ + pass + + +# This tells python that `BaseReplicationStreamProtocol` implements the +# interface. +AbstractConnection.register(BaseReplicationStreamProtocol) # The following simply registers metrics for the replication connections diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index 30021ee309..b2d6baa2a2 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -17,7 +17,7 @@ import logging import random -from typing import Any, Dict, List +from typing import Dict from six import itervalues @@ -25,24 +25,14 @@ from prometheus_client import Counter from twisted.internet.protocol import Factory -from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.util.metrics import Measure, measure_func - -from .protocol import ServerReplicationStreamProtocol -from .streams import STREAMS_MAP, Stream -from .streams.federation import FederationStream +from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol +from synapse.replication.tcp.streams import STREAMS_MAP, FederationStream, Stream +from synapse.util.metrics import Measure stream_updates_counter = Counter( "synapse_replication_tcp_resource_stream_updates", "", ["stream_name"] ) -user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "") -federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "") -remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "") -invalidate_cache_counter = Counter( - "synapse_replication_tcp_resource_invalidate_cache", "" -) -user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "") logger = logging.getLogger(__name__) @@ -52,13 +42,23 @@ class ReplicationStreamProtocolFactory(Factory): """ def __init__(self, hs): - self.streamer = hs.get_replication_streamer() + self.command_handler = hs.get_tcp_replication() self.clock = hs.get_clock() self.server_name = hs.config.server_name + # If we've created a `ReplicationStreamProtocolFactory` then we're + # almost certainly registering a replication listener, so let's ensure + # that we've started a `ReplicationStreamer` instance to actually push + # data. + # + # (This is a bit of a weird place to do this, but the alternatives such + # as putting this in `HomeServer.setup()`, requires either passing the + # listener config again or always starting a `ReplicationStreamer`.) + hs.get_replication_streamer() + def buildProtocol(self, addr): return ServerReplicationStreamProtocol( - self.server_name, self.clock, self.streamer + self.server_name, self.clock, self.command_handler ) @@ -78,16 +78,6 @@ class ReplicationStreamer(object): self._replication_torture_level = hs.config.replication_torture_level - # Current connections. - self.connections = [] # type: List[ServerReplicationStreamProtocol] - - LaterGauge( - "synapse_replication_tcp_resource_total_connections", - "", - [], - lambda: len(self.connections), - ) - # List of streams that clients can subscribe to. # We only support federation stream if federation sending hase been # disabled on the master. @@ -104,18 +94,12 @@ class ReplicationStreamer(object): self.federation_sender = hs.get_federation_sender() self.notifier.add_replication_callback(self.on_notifier_poke) - self.notifier.add_remote_server_up_callback(self.send_remote_server_up) # Keeps track of whether we are currently checking for updates self.is_looping = False self.pending_updates = False - hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.on_shutdown) - - def on_shutdown(self): - # close all connections on shutdown - for conn in self.connections: - conn.send_error("server shutting down") + self.command_handler = hs.get_tcp_replication() def get_streams(self) -> Dict[str, Stream]: """Get a mapp from stream name to stream instance. @@ -129,7 +113,7 @@ class ReplicationStreamer(object): This should get called each time new data is available, even if it is currently being executed, so that nothing gets missed """ - if not self.connections: + if not self.command_handler.connected(): # Don't bother if nothing is listening. We still need to advance # the stream tokens otherwise they'll fall beihind forever for stream in self.streams: @@ -186,9 +170,7 @@ class ReplicationStreamer(object): raise logger.debug( - "Sending %d updates to %d connections", - len(updates), - len(self.connections), + "Sending %d updates", len(updates), ) if updates: @@ -204,112 +186,19 @@ class ReplicationStreamer(object): # token. See RdataCommand for more details. batched_updates = _batch_updates(updates) - for conn in self.connections: - for token, row in batched_updates: - try: - conn.stream_update(stream.NAME, token, row) - except Exception: - logger.exception("Failed to replicate") + for token, row in batched_updates: + try: + self.command_handler.stream_update( + stream.NAME, token, row + ) + except Exception: + logger.exception("Failed to replicate") logger.debug("No more pending updates, breaking poke loop") finally: self.pending_updates = False self.is_looping = False - def get_stream_token(self, stream_name): - """For a given stream get all updates since token. This is called when - a client first subscribes to a stream. - """ - stream = self.streams_by_name.get(stream_name, None) - if not stream: - raise Exception("unknown stream %s", stream_name) - - return stream.current_token() - - @measure_func("repl.federation_ack") - def federation_ack(self, token): - """We've received an ack for federation stream from a client. - """ - federation_ack_counter.inc() - if self.federation_sender: - self.federation_sender.federation_ack(token) - - @measure_func("repl.on_user_sync") - async def on_user_sync(self, instance_id, user_id, is_syncing, last_sync_ms): - """A client has started/stopped syncing on a worker. - """ - user_sync_counter.inc() - await self.presence_handler.update_external_syncs_row( - instance_id, user_id, is_syncing, last_sync_ms - ) - - async def on_clear_user_syncs(self, instance_id): - """A replication client wants us to drop all their UserSync data. - """ - await self.presence_handler.update_external_syncs_clear(instance_id) - - @measure_func("repl.on_remove_pusher") - async def on_remove_pusher(self, app_id, push_key, user_id): - """A client has asked us to remove a pusher - """ - remove_pusher_counter.inc() - await self.store.delete_pusher_by_app_id_pushkey_user_id( - app_id=app_id, pushkey=push_key, user_id=user_id - ) - - self.notifier.on_new_replication_data() - - @measure_func("repl.on_invalidate_cache") - async def on_invalidate_cache(self, cache_func: str, keys: List[Any]): - """The client has asked us to invalidate a cache - """ - invalidate_cache_counter.inc() - - # We invalidate the cache locally, but then also stream that to other - # workers. - await self.store.invalidate_cache_and_stream(cache_func, tuple(keys)) - - @measure_func("repl.on_user_ip") - async def on_user_ip( - self, user_id, access_token, ip, user_agent, device_id, last_seen - ): - """The client saw a user request - """ - user_ip_cache_counter.inc() - await self.store.insert_client_ip( - user_id, access_token, ip, user_agent, device_id, last_seen - ) - await self._server_notices_sender.on_user_ip(user_id) - - @measure_func("repl.on_remote_server_up") - def on_remote_server_up(self, server: str): - self.notifier.notify_remote_server_up(server) - - def send_remote_server_up(self, server: str): - for conn in self.connections: - conn.send_remote_server_up(server) - - def send_sync_to_all_connections(self, data): - """Sends a SYNC command to all clients. - - Used in tests. - """ - for conn in self.connections: - conn.send_sync(data) - - def new_connection(self, connection): - """A new client connection has been established - """ - self.connections.append(connection) - - def lost_connection(self, connection): - """A client connection has been lost - """ - try: - self.connections.remove(connection) - except ValueError: - pass - def _batch_updates(updates): """Takes a list of updates of form [(token, row)] and sets the token to -- cgit 1.5.1 From ce72355d7f67a986d60a7d86489b1b40f93fb152 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 7 Apr 2020 11:01:04 +0100 Subject: Fix race in replication (#7226) Fixes a race between handling `POSITION` and `RDATA` commands. We do this by simply linearizing handling of them. --- changelog.d/7226.misc | 1 + synapse/replication/tcp/handler.py | 73 +++++++++++++++++---------- synapse/replication/tcp/streams/_base.py | 3 +- synapse/storage/data_stores/main/push_rule.py | 40 +++++++-------- 4 files changed, 68 insertions(+), 49 deletions(-) create mode 100644 changelog.d/7226.misc (limited to 'synapse/replication/tcp/handler.py') diff --git a/changelog.d/7226.misc b/changelog.d/7226.misc new file mode 100644 index 0000000000..676f285377 --- /dev/null +++ b/changelog.d/7226.misc @@ -0,0 +1 @@ +Move catchup of replication streams logic to worker. diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 8ec0119697..dd71d1bc34 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -189,16 +189,34 @@ class ReplicationCommandHandler: logger.exception("Failed to parse RDATA: %r %r", stream_name, cmd.row) raise - if cmd.token is None or stream_name not in self._streams_connected: - # I.e. either this is part of a batch of updates for this stream (in - # which case batch until we get an update for the stream with a non - # None token) or we're currently connecting so we queue up rows. - self._pending_batches.setdefault(stream_name, []).append(row) - else: - # Check if this is the last of a batch of updates - rows = self._pending_batches.pop(stream_name, []) - rows.append(row) - await self.on_rdata(stream_name, cmd.token, rows) + # We linearize here for two reasons: + # 1. so we don't try and concurrently handle multiple rows for the + # same stream, and + # 2. so we don't race with getting a POSITION command and fetching + # missing RDATA. + with await self._position_linearizer.queue(cmd.stream_name): + if stream_name not in self._streams_connected: + # If the stream isn't marked as connected then we haven't seen a + # `POSITION` command yet, and so we may have missed some rows. + # Let's drop the row for now, on the assumption we'll receive a + # `POSITION` soon and we'll catch up correctly then. + logger.warning( + "Discarding RDATA for unconnected stream %s -> %s", + stream_name, + cmd.token, + ) + return + + if cmd.token is None: + # I.e. this is part of a batch of updates for this stream (in + # which case batch until we get an update for the stream with a non + # None token). + self._pending_batches.setdefault(stream_name, []).append(row) + else: + # Check if this is the last of a batch of updates + rows = self._pending_batches.pop(stream_name, []) + rows.append(row) + await self.on_rdata(stream_name, cmd.token, rows) async def on_rdata(self, stream_name: str, token: int, rows: list): """Called to handle a batch of replication data with a given stream token. @@ -221,12 +239,13 @@ class ReplicationCommandHandler: # We protect catching up with a linearizer in case the replication # connection reconnects under us. with await self._position_linearizer.queue(cmd.stream_name): - # We're about to go and catch up with the stream, so mark as connecting - # to stop RDATA being handled at the same time by removing stream from - # list of connected streams. We also clear any batched up RDATA from - # before we got the POSITION. + # We're about to go and catch up with the stream, so remove from set + # of connected streams. self._streams_connected.discard(cmd.stream_name) - self._pending_batches.clear() + + # We clear the pending batches for the stream as the fetching of the + # missing updates below will fetch all rows in the batch. + self._pending_batches.pop(cmd.stream_name, []) # Find where we previously streamed up to. current_token = self._replication_data_handler.get_streams_to_replicate().get( @@ -239,12 +258,17 @@ class ReplicationCommandHandler: ) return - # Fetch all updates between then and now. - limited = True - while limited: - updates, current_token, limited = await stream.get_updates_since( - current_token, cmd.token - ) + # If the position token matches our current token then we're up to + # date and there's nothing to do. Otherwise, fetch all updates + # between then and now. + missing_updates = cmd.token != current_token + while missing_updates: + ( + updates, + current_token, + missing_updates, + ) = await stream.get_updates_since(current_token, cmd.token) + if updates: await self.on_rdata( cmd.stream_name, @@ -255,13 +279,6 @@ class ReplicationCommandHandler: # We've now caught up to position sent to us, notify handler. await self._replication_data_handler.on_position(cmd.stream_name, cmd.token) - # Handle any RDATA that came in while we were catching up. - rows = self._pending_batches.pop(cmd.stream_name, []) - if rows: - await self._replication_data_handler.on_rdata( - cmd.stream_name, rows[-1].token, rows - ) - self._streams_connected.add(cmd.stream_name) async def on_SYNC(self, cmd: SyncCommand): diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index c14dff6c64..f56a0fd4b5 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -168,12 +168,13 @@ def make_http_update_function( async def update_function( from_token: int, upto_token: int, limit: int ) -> Tuple[List[Tuple[int, tuple]], int, bool]: - return await client( + result = await client( stream_name=stream_name, from_token=from_token, upto_token=upto_token, limit=limit, ) + return result["updates"], result["upto_token"], result["limited"] return update_function diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py index 46f9bda773..b3faafa0a4 100644 --- a/synapse/storage/data_stores/main/push_rule.py +++ b/synapse/storage/data_stores/main/push_rule.py @@ -334,6 +334,26 @@ class PushRulesWorkerStore( results.setdefault(row["user_name"], {})[row["rule_id"]] = enabled return results + def get_all_push_rule_updates(self, last_id, current_id, limit): + """Get all the push rules changes that have happend on the server""" + if last_id == current_id: + return defer.succeed([]) + + def get_all_push_rule_updates_txn(txn): + sql = ( + "SELECT stream_id, event_stream_ordering, user_id, rule_id," + " op, priority_class, priority, conditions, actions" + " FROM push_rules_stream" + " WHERE ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC LIMIT ?" + ) + txn.execute(sql, (last_id, current_id, limit)) + return txn.fetchall() + + return self.db.runInteraction( + "get_all_push_rule_updates", get_all_push_rule_updates_txn + ) + class PushRuleStore(PushRulesWorkerStore): @defer.inlineCallbacks @@ -685,26 +705,6 @@ class PushRuleStore(PushRulesWorkerStore): self.push_rules_stream_cache.entity_has_changed, user_id, stream_id ) - def get_all_push_rule_updates(self, last_id, current_id, limit): - """Get all the push rules changes that have happend on the server""" - if last_id == current_id: - return defer.succeed([]) - - def get_all_push_rule_updates_txn(txn): - sql = ( - "SELECT stream_id, event_stream_ordering, user_id, rule_id," - " op, priority_class, priority, conditions, actions" - " FROM push_rules_stream" - " WHERE ? < stream_id AND stream_id <= ?" - " ORDER BY stream_id ASC LIMIT ?" - ) - txn.execute(sql, (last_id, current_id, limit)) - return txn.fetchall() - - return self.db.runInteraction( - "get_all_push_rule_updates", get_all_push_rule_updates_txn - ) - def get_push_rules_stream_token(self): """Get the position of the push rules stream. Returns a pair of a stream id for the push_rules stream and the -- cgit 1.5.1 From 6a519a0ca06f42c35d5c88f3fa5f62cfd1553905 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 7 Apr 2020 16:59:40 +0100 Subject: Remove vestigal references to SYNC replication command We've ripped pretty much all of this out: let's remove the remains. --- synapse/replication/tcp/commands.py | 10 ---------- synapse/replication/tcp/handler.py | 4 ---- 2 files changed, 14 deletions(-) (limited to 'synapse/replication/tcp/handler.py') diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index e4eec643f7..a07a01278a 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -289,14 +289,6 @@ class FederationAckCommand(Command): return str(self.token) -class SyncCommand(Command): - """Used for testing. The client protocol implementation allows waiting - on a SYNC command with a specified data. - """ - - NAME = "SYNC" - - class RemovePusherCommand(Command): """Sent by the client to request the master remove the given pusher. @@ -419,7 +411,6 @@ _COMMANDS = ( ReplicateCommand, UserSyncCommand, FederationAckCommand, - SyncCommand, RemovePusherCommand, InvalidateCacheCommand, UserIpCommand, @@ -437,7 +428,6 @@ VALID_SERVER_COMMANDS = ( PositionCommand.NAME, ErrorCommand.NAME, PingCommand.NAME, - SyncCommand.NAME, RemoteServerUpCommand.NAME, ) diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index dd71d1bc34..2f5a299141 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -31,7 +31,6 @@ from synapse.replication.tcp.commands import ( RemoteServerUpCommand, RemovePusherCommand, ReplicateCommand, - SyncCommand, UserIpCommand, UserSyncCommand, ) @@ -281,9 +280,6 @@ class ReplicationCommandHandler: self._streams_connected.add(cmd.stream_name) - async def on_SYNC(self, cmd: SyncCommand): - pass - async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand): """"Called when get a new REMOTE_SERVER_UP command.""" self._replication_data_handler.on_remote_server_up(cmd.data) -- cgit 1.5.1 From 0f8f02bc39cf1879780f82bdb8dd581588d14cca Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 20 Apr 2020 11:43:29 +0100 Subject: On catchup, process each row with its own stream id (#7286) Other parts of the code (such as the StreamChangeCache) assume that there will not be multiple changes with the same stream id. This code was introduced in #7024, and I hope this fixes #7206. --- changelog.d/7286.misc | 1 + synapse/replication/tcp/handler.py | 73 ++++++++++++++++++++++++++++-- synapse/util/caches/stream_change_cache.py | 3 ++ 3 files changed, 72 insertions(+), 5 deletions(-) create mode 100644 changelog.d/7286.misc (limited to 'synapse/replication/tcp/handler.py') diff --git a/changelog.d/7286.misc b/changelog.d/7286.misc new file mode 100644 index 0000000000..676f285377 --- /dev/null +++ b/changelog.d/7286.misc @@ -0,0 +1 @@ +Move catchup of replication streams logic to worker. diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 2f5a299141..e32e68e8c4 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -15,7 +15,18 @@ # limitations under the License. import logging -from typing import Any, Callable, Dict, List, Optional, Set +from typing import ( + Any, + Callable, + Dict, + Iterable, + Iterator, + List, + Optional, + Set, + Tuple, + TypeVar, +) from prometheus_client import Counter @@ -268,11 +279,14 @@ class ReplicationCommandHandler: missing_updates, ) = await stream.get_updates_since(current_token, cmd.token) - if updates: + # TODO: add some tests for this + + # Some streams return multiple rows with the same stream IDs, + # which need to be processed in batches. + + for token, rows in _batch_updates(updates): await self.on_rdata( - cmd.stream_name, - current_token, - [stream.parse_row(update[1]) for update in updates], + cmd.stream_name, token, [stream.parse_row(row) for row in rows], ) # We've now caught up to position sent to us, notify handler. @@ -404,3 +418,52 @@ class ReplicationCommandHandler: We need to check if the client is interested in the stream or not """ self.send_command(RdataCommand(stream_name, token, data)) + + +UpdateToken = TypeVar("UpdateToken") +UpdateRow = TypeVar("UpdateRow") + + +def _batch_updates( + updates: Iterable[Tuple[UpdateToken, UpdateRow]] +) -> Iterator[Tuple[UpdateToken, List[UpdateRow]]]: + """Collect stream updates with the same token together + + Given a series of updates returned by Stream.get_updates_since(), collects + the updates which share the same stream_id together. + + For example: + + [(1, a), (1, b), (2, c), (3, d), (3, e)] + + becomes: + + [ + (1, [a, b]), + (2, [c]), + (3, [d, e]), + ] + """ + + update_iter = iter(updates) + + first_update = next(update_iter, None) + if first_update is None: + # empty input + return + + current_batch_token = first_update[0] + current_batch = [first_update[1]] + + for token, row in update_iter: + if token != current_batch_token: + # different token to the previous row: flush the previous + # batch and start anew + yield current_batch_token, current_batch + current_batch_token = token + current_batch = [] + + current_batch.append(row) + + # flush the final batch + yield current_batch_token, current_batch diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py index 235f64049c..c61d36a82e 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py @@ -126,6 +126,9 @@ class StreamChangeCache(object): """ assert type(stream_pos) is int + # FIXME: add a sanity check here that we are not overwriting existing + # data in self._cache + if stream_pos > self._earliest_known_stream_pos: old_pos = self._entity_to_key.get(entity, None) if old_pos is not None: -- cgit 1.5.1 From 51f7eaf908a84fcaf231899e2bf1beae14ae72c0 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 22 Apr 2020 13:07:41 +0100 Subject: Add ability to run replication protocol over redis. (#7040) This is configured via the `redis` config options. --- changelog.d/7040.feature | 1 + stubs/txredisapi.pyi | 40 +++++++ synapse/app/homeserver.py | 6 + synapse/config/homeserver.py | 2 + synapse/config/redis.py | 35 ++++++ synapse/python_dependencies.py | 1 + synapse/replication/tcp/client.py | 2 +- synapse/replication/tcp/commands.py | 18 +++ synapse/replication/tcp/handler.py | 50 +++++++-- synapse/replication/tcp/protocol.py | 38 ++----- synapse/replication/tcp/redis.py | 181 +++++++++++++++++++++++++++++++ tests/replication/slave/storage/_base.py | 4 +- 12 files changed, 342 insertions(+), 36 deletions(-) create mode 100644 changelog.d/7040.feature create mode 100644 stubs/txredisapi.pyi create mode 100644 synapse/config/redis.py create mode 100644 synapse/replication/tcp/redis.py (limited to 'synapse/replication/tcp/handler.py') diff --git a/changelog.d/7040.feature b/changelog.d/7040.feature new file mode 100644 index 0000000000..ce6140fdd1 --- /dev/null +++ b/changelog.d/7040.feature @@ -0,0 +1 @@ +Add support for running replication over Redis when using workers. diff --git a/stubs/txredisapi.pyi b/stubs/txredisapi.pyi new file mode 100644 index 0000000000..763d3fb404 --- /dev/null +++ b/stubs/txredisapi.pyi @@ -0,0 +1,40 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Contains *incomplete* type hints for txredisapi. +""" + +from typing import List, Optional, Union + +class RedisProtocol: + def publish(self, channel: str, message: bytes): ... + +class SubscriberProtocol: + def subscribe(self, channels: Union[str, List[str]]): ... + +def lazyConnection( + host: str = ..., + port: int = ..., + dbid: Optional[int] = ..., + reconnect: bool = ..., + charset: str = ..., + password: Optional[str] = ..., + connectTimeout: Optional[int] = ..., + replyTimeout: Optional[int] = ..., + convertNumbers: bool = ..., +) -> RedisProtocol: ... + +class SubscriberFactory: + def buildProtocol(self, addr): ... diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 49df63acd0..cbd1ea475a 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -273,6 +273,12 @@ class SynapseHomeServer(HomeServer): def start_listening(self, listeners): config = self.get_config() + if config.redis_enabled: + # If redis is enabled we connect via the replication command handler + # in the same way as the workers (since we're effectively a client + # rather than a server). + self.get_tcp_replication().start_replication(self) + for listener in listeners: if listener["type"] == "http": self._listening_services.extend(self._listener_http(config, listener)) diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index b4bca08b20..be6c6afa74 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -31,6 +31,7 @@ from .password import PasswordConfig from .password_auth_providers import PasswordAuthProviderConfig from .push import PushConfig from .ratelimiting import RatelimitConfig +from .redis import RedisConfig from .registration import RegistrationConfig from .repository import ContentRepositoryConfig from .room_directory import RoomDirectoryConfig @@ -82,4 +83,5 @@ class HomeServerConfig(RootConfig): RoomDirectoryConfig, ThirdPartyRulesConfig, TracerConfig, + RedisConfig, ] diff --git a/synapse/config/redis.py b/synapse/config/redis.py new file mode 100644 index 0000000000..81a27619ec --- /dev/null +++ b/synapse/config/redis.py @@ -0,0 +1,35 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.config._base import Config +from synapse.python_dependencies import check_requirements + + +class RedisConfig(Config): + section = "redis" + + def read_config(self, config, **kwargs): + redis_config = config.get("redis", {}) + self.redis_enabled = redis_config.get("enabled", False) + + if not self.redis_enabled: + return + + check_requirements("redis") + + self.redis_host = redis_config.get("host", "localhost") + self.redis_port = redis_config.get("port", 6379) + self.redis_dbid = redis_config.get("dbid") + self.redis_password = redis_config.get("password") diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 8de8cb2c12..733c51b758 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -98,6 +98,7 @@ CONDITIONAL_REQUIREMENTS = { "sentry": ["sentry-sdk>=0.7.2"], "opentracing": ["jaeger-client>=4.0.0", "opentracing>=2.2.0"], "jwt": ["pyjwt>=1.6.4"], + "redis": ["txredisapi>=1.4.7"], } ALL_OPTIONAL_REQUIREMENTS = set() # type: Set[str] diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 700ae79158..2d07b8b2d0 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -30,7 +30,7 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class ReplicationClientFactory(ReconnectingClientFactory): +class DirectTcpReplicationClientFactory(ReconnectingClientFactory): """Factory for building connections to the master. Will reconnect if the connection is lost. diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index 5ec89d0fb8..5002efe6a0 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -454,3 +454,21 @@ VALID_CLIENT_COMMANDS = ( ErrorCommand.NAME, RemoteServerUpCommand.NAME, ) + + +def parse_command_from_line(line: str) -> Command: + """Parses a command from a received line. + + Line should already be stripped of whitespace and be checked if blank. + """ + + idx = line.index(" ") + if idx >= 0: + cmd_name = line[:idx] + rest_of_line = line[idx + 1 :] + else: + cmd_name = line + rest_of_line = "" + + cmd_cls = COMMAND_MAP[cmd_name] + return cmd_cls.from_line(rest_of_line) diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index e32e68e8c4..5b5ee2c13e 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -30,8 +30,10 @@ from typing import ( from prometheus_client import Counter +from twisted.internet.protocol import ReconnectingClientFactory + from synapse.metrics import LaterGauge -from synapse.replication.tcp.client import ReplicationClientFactory +from synapse.replication.tcp.client import DirectTcpReplicationClientFactory from synapse.replication.tcp.commands import ( ClearUserSyncsCommand, Command, @@ -92,7 +94,7 @@ class ReplicationCommandHandler: self._pending_batches = {} # type: Dict[str, List[Any]] # The factory used to create connections. - self._factory = None # type: Optional[ReplicationClientFactory] + self._factory = None # type: Optional[ReconnectingClientFactory] # The currently connected connections. self._connections = [] # type: List[AbstractConnection] @@ -119,11 +121,45 @@ class ReplicationCommandHandler: """Helper method to start a replication connection to the remote server using TCP. """ - client_name = hs.config.worker_name - self._factory = ReplicationClientFactory(hs, client_name, self) - host = hs.config.worker_replication_host - port = hs.config.worker_replication_port - hs.get_reactor().connectTCP(host, port, self._factory) + if hs.config.redis.redis_enabled: + from synapse.replication.tcp.redis import ( + RedisDirectTcpReplicationClientFactory, + ) + import txredisapi + + logger.info( + "Connecting to redis (host=%r port=%r DBID=%r)", + hs.config.redis_host, + hs.config.redis_port, + hs.config.redis_dbid, + ) + + # We need two connections to redis, one for the subscription stream and + # one to send commands to (as you can't send further redis commands to a + # connection after SUBSCRIBE is called). + + # First create the connection for sending commands. + outbound_redis_connection = txredisapi.lazyConnection( + host=hs.config.redis_host, + port=hs.config.redis_port, + dbid=hs.config.redis_dbid, + password=hs.config.redis.redis_password, + reconnect=True, + ) + + # Now create the factory/connection for the subscription stream. + self._factory = RedisDirectTcpReplicationClientFactory( + hs, outbound_redis_connection + ) + hs.get_reactor().connectTCP( + hs.config.redis.redis_host, hs.config.redis.redis_port, self._factory, + ) + else: + client_name = hs.config.worker_name + self._factory = DirectTcpReplicationClientFactory(hs, client_name, self) + host = hs.config.worker_replication_host + port = hs.config.worker_replication_port + hs.get_reactor().connectTCP(host, port, self._factory) async def on_REPLICATE(self, cmd: ReplicateCommand): # We only want to announce positions by the writer of the streams. diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index 9276ed2965..7240acb0a2 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -63,7 +63,6 @@ from twisted.python.failure import Failure from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.tcp.commands import ( - COMMAND_MAP, VALID_CLIENT_COMMANDS, VALID_SERVER_COMMANDS, Command, @@ -72,6 +71,7 @@ from synapse.replication.tcp.commands import ( PingCommand, ReplicateCommand, ServerCommand, + parse_command_from_line, ) from synapse.types import Collection from synapse.util import Clock @@ -210,38 +210,24 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): linestr = line.decode("utf-8") - # split at the first " ", handling one-word commands - idx = linestr.index(" ") - if idx >= 0: - cmd_name = linestr[:idx] - rest_of_line = linestr[idx + 1 :] - else: - cmd_name = linestr - rest_of_line = "" + try: + cmd = parse_command_from_line(linestr) + except Exception as e: + logger.exception("[%s] failed to parse line: %r", self.id(), linestr) + self.send_error("failed to parse line: %r (%r):" % (e, linestr)) + return - if cmd_name not in self.VALID_INBOUND_COMMANDS: - logger.error("[%s] invalid command %s", self.id(), cmd_name) - self.send_error("invalid command: %s", cmd_name) + if cmd.NAME not in self.VALID_INBOUND_COMMANDS: + logger.error("[%s] invalid command %s", self.id(), cmd.NAME) + self.send_error("invalid command: %s", cmd.NAME) return self.last_received_command = self.clock.time_msec() - self.inbound_commands_counter[cmd_name] = ( - self.inbound_commands_counter[cmd_name] + 1 + self.inbound_commands_counter[cmd.NAME] = ( + self.inbound_commands_counter[cmd.NAME] + 1 ) - cmd_cls = COMMAND_MAP[cmd_name] - try: - cmd = cmd_cls.from_line(rest_of_line) - except Exception as e: - logger.exception( - "[%s] failed to parse line %r: %r", self.id(), cmd_name, rest_of_line - ) - self.send_error( - "failed to parse line for %r: %r (%r):" % (cmd_name, e, rest_of_line) - ) - return - # Now lets try and call on_ function run_as_background_process( "replication-" + cmd.get_logcontext_id(), self.handle_command, cmd diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py new file mode 100644 index 0000000000..4c08425735 --- /dev/null +++ b/synapse/replication/tcp/redis.py @@ -0,0 +1,181 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import TYPE_CHECKING + +import txredisapi + +from synapse.logging.context import PreserveLoggingContext +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.replication.tcp.commands import ( + Command, + ReplicateCommand, + parse_command_from_line, +) +from synapse.replication.tcp.protocol import AbstractConnection + +if TYPE_CHECKING: + from synapse.replication.tcp.handler import ReplicationCommandHandler + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): + """Connection to redis subscribed to replication stream. + + Parses incoming messages from redis into replication commands, and passes + them to `ReplicationCommandHandler` + + Due to the vagaries of `txredisapi` we don't want to have a custom + constructor, so instead we expect the defined attributes below to be set + immediately after initialisation. + + Attributes: + handler: The command handler to handle incoming commands. + stream_name: The *redis* stream name to subscribe to (not anything to + do with Synapse replication streams). + outbound_redis_connection: The connection to redis to use to send + commands. + """ + + handler = None # type: ReplicationCommandHandler + stream_name = None # type: str + outbound_redis_connection = None # type: txredisapi.RedisProtocol + + def connectionMade(self): + logger.info("Connected to redis instance") + self.subscribe(self.stream_name) + self.send_command(ReplicateCommand()) + + self.handler.new_connection(self) + + def messageReceived(self, pattern: str, channel: str, message: str): + """Received a message from redis. + """ + + if message.strip() == "": + # Ignore blank lines + return + + try: + cmd = parse_command_from_line(message) + except Exception: + logger.exception( + "[%s] failed to parse line: %r", message, + ) + return + + # Now lets try and call on_ function + run_as_background_process( + "replication-" + cmd.get_logcontext_id(), self.handle_command, cmd + ) + + async def handle_command(self, cmd: Command): + """Handle a command we have received over the replication stream. + + By default delegates to on_, which should return an awaitable. + + Args: + cmd: received command + """ + handled = False + + # First call any command handlers on this instance. These are for redis + # specific handling. + cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None) + if cmd_func: + await cmd_func(cmd) + handled = True + + # Then call out to the handler. + cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None) + if cmd_func: + await cmd_func(cmd) + handled = True + + if not handled: + logger.warning("Unhandled command: %r", cmd) + + def connectionLost(self, reason): + logger.info("Lost connection to redis instance") + self.handler.lost_connection(self) + + def send_command(self, cmd: Command): + """Send a command if connection has been established. + + Args: + cmd (Command) + """ + string = "%s %s" % (cmd.NAME, cmd.to_line()) + if "\n" in string: + raise Exception("Unexpected newline in command: %r", string) + + encoded_string = string.encode("utf-8") + + async def _send(): + with PreserveLoggingContext(): + # Note that we use the other connection as we can't send + # commands using the subscription connection. + await self.outbound_redis_connection.publish( + self.stream_name, encoded_string + ) + + run_as_background_process("send-cmd", _send) + + +class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory): + """This is a reconnecting factory that connects to redis and immediately + subscribes to a stream. + + Args: + hs + outbound_redis_connection: A connection to redis that will be used to + send outbound commands (this is seperate to the redis connection + used to subscribe). + """ + + maxDelay = 5 + continueTrying = True + protocol = RedisSubscriber + + def __init__( + self, hs: "HomeServer", outbound_redis_connection: txredisapi.RedisProtocol + ): + + super().__init__() + + # This sets the password on the RedisFactory base class (as + # SubscriberFactory constructor doesn't pass it through). + self.password = hs.config.redis.redis_password + + self.handler = hs.get_tcp_replication() + self.stream_name = hs.hostname + + self.outbound_redis_connection = outbound_redis_connection + + def buildProtocol(self, addr): + p = super().buildProtocol(addr) # type: RedisSubscriber + + # We do this here rather than add to the constructor of `RedisSubcriber` + # as to do so would involve overriding `buildProtocol` entirely, however + # the base method does some other things than just instantiating the + # protocol. + p.handler = self.handler + p.outbound_redis_connection = self.outbound_redis_connection + p.stream_name = self.stream_name + + return p diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 8902a5ab69..395c7d0306 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -16,7 +16,7 @@ from mock import Mock, NonCallableMock from synapse.replication.tcp.client import ( - ReplicationClientFactory, + DirectTcpReplicationClientFactory, ReplicationDataHandler, ) from synapse.replication.tcp.handler import ReplicationCommandHandler @@ -61,7 +61,7 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase): self.slaved_store ) - client_factory = ReplicationClientFactory( + client_factory = DirectTcpReplicationClientFactory( self.hs, "client_name", self.replication_handler ) client_factory.handler = self.replication_handler -- cgit 1.5.1 From 71a1abb8a116372556fd577ff1b85c7cfbe3c2b3 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 22 Apr 2020 22:39:04 +0100 Subject: Stop the master relaying USER_SYNC for other workers (#7318) Long story short: if we're handling presence on the current worker, we shouldn't be sending USER_SYNC commands over replication. In an attempt to figure out what is going on here, I ended up refactoring some bits of the presencehandler code, so the first 4 commits here are non-functional refactors to move this code slightly closer to sanity. (There's still plenty to do here :/). Suggest reviewing individual commits. Fixes (I hope) #7257. --- changelog.d/7318.misc | 1 + docs/tcp_replication.md | 6 +- synapse/api/constants.py | 2 + synapse/app/generic_worker.py | 85 ++++++++------- synapse/handlers/events.py | 20 ++-- synapse/handlers/initial_sync.py | 10 +- synapse/handlers/presence.py | 210 ++++++++++++++++++++---------------- synapse/replication/tcp/commands.py | 7 +- synapse/replication/tcp/handler.py | 15 +-- synapse/server.pyi | 2 +- 10 files changed, 199 insertions(+), 159 deletions(-) create mode 100644 changelog.d/7318.misc (limited to 'synapse/replication/tcp/handler.py') diff --git a/changelog.d/7318.misc b/changelog.d/7318.misc new file mode 100644 index 0000000000..676f285377 --- /dev/null +++ b/changelog.d/7318.misc @@ -0,0 +1 @@ +Move catchup of replication streams logic to worker. diff --git a/docs/tcp_replication.md b/docs/tcp_replication.md index 3be8e50c4c..b922d9cf7e 100644 --- a/docs/tcp_replication.md +++ b/docs/tcp_replication.md @@ -196,7 +196,7 @@ Asks the server for the current position of all streams. #### USER_SYNC (C) - A user has started or stopped syncing + A user has started or stopped syncing on this process. #### CLEAR_USER_SYNC (C) @@ -216,10 +216,6 @@ Asks the server for the current position of all streams. Inform the server a cache should be invalidated -#### SYNC (S, C) - - Used exclusively in tests - ### REMOTE_SERVER_UP (S, C) Inform other processes that a remote server may have come back online. diff --git a/synapse/api/constants.py b/synapse/api/constants.py index fda2c2e5bb..bcaf2c3600 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -97,6 +97,8 @@ class EventTypes(object): Retention = "m.room.retention" + Presence = "m.presence" + class RejectedReason(object): AUTH_ERROR = "auth_error" diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 37afd2f810..2a56fe0bd5 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -17,6 +17,9 @@ import contextlib import logging import sys +from typing import Dict, Iterable + +from typing_extensions import ContextManager from twisted.internet import defer, reactor from twisted.web.resource import NoResource @@ -38,14 +41,14 @@ from synapse.config.homeserver import HomeServerConfig from synapse.config.logger import setup_logging from synapse.federation import send_queue from synapse.federation.transport.server import TransportLayerServer -from synapse.handlers.presence import PresenceHandler, get_interested_parties +from synapse.handlers.presence import BasePresenceHandler, get_interested_parties from synapse.http.server import JsonResource from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.site import SynapseSite from synapse.logging.context import LoggingContext from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.replication.slave.storage._base import BaseSlavedStore, __func__ +from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage.account_data import SlavedAccountDataStore from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore from synapse.replication.slave.storage.client_ips import SlavedClientIpStore @@ -225,23 +228,32 @@ class KeyUploadServlet(RestServlet): return 200, {"one_time_key_counts": result} +class _NullContextManager(ContextManager[None]): + """A context manager which does nothing.""" + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + + UPDATE_SYNCING_USERS_MS = 10 * 1000 -class GenericWorkerPresence(object): +class GenericWorkerPresence(BasePresenceHandler): def __init__(self, hs): + super().__init__(hs) self.hs = hs self.is_mine_id = hs.is_mine_id self.http_client = hs.get_simple_http_client() - self.store = hs.get_datastore() - self.user_to_num_current_syncs = {} - self.clock = hs.get_clock() + + self._presence_enabled = hs.config.use_presence + + # The number of ongoing syncs on this process, by user id. + # Empty if _presence_enabled is false. + self._user_to_num_current_syncs = {} # type: Dict[str, int] + self.notifier = hs.get_notifier() self.instance_id = hs.get_instance_id() - active_presence = self.store.take_presence_startup_info() - self.user_to_current_state = {state.user_id: state for state in active_presence} - # user_id -> last_sync_ms. Lists the users that have stopped syncing # but we haven't notified the master of that yet self.users_going_offline = {} @@ -259,13 +271,13 @@ class GenericWorkerPresence(object): ) def _on_shutdown(self): - if self.hs.config.use_presence: + if self._presence_enabled: self.hs.get_tcp_replication().send_command( ClearUserSyncsCommand(self.instance_id) ) def send_user_sync(self, user_id, is_syncing, last_sync_ms): - if self.hs.config.use_presence: + if self._presence_enabled: self.hs.get_tcp_replication().send_user_sync( self.instance_id, user_id, is_syncing, last_sync_ms ) @@ -307,28 +319,33 @@ class GenericWorkerPresence(object): # TODO Hows this supposed to work? return defer.succeed(None) - get_states = __func__(PresenceHandler.get_states) - get_state = __func__(PresenceHandler.get_state) - current_state_for_users = __func__(PresenceHandler.current_state_for_users) + async def user_syncing( + self, user_id: str, affect_presence: bool + ) -> ContextManager[None]: + """Record that a user is syncing. + + Called by the sync and events servlets to record that a user has connected to + this worker and is waiting for some events. + """ + if not affect_presence or not self._presence_enabled: + return _NullContextManager() - def user_syncing(self, user_id, affect_presence): - if affect_presence: - curr_sync = self.user_to_num_current_syncs.get(user_id, 0) - self.user_to_num_current_syncs[user_id] = curr_sync + 1 + curr_sync = self._user_to_num_current_syncs.get(user_id, 0) + self._user_to_num_current_syncs[user_id] = curr_sync + 1 - # If we went from no in flight sync to some, notify replication - if self.user_to_num_current_syncs[user_id] == 1: - self.mark_as_coming_online(user_id) + # If we went from no in flight sync to some, notify replication + if self._user_to_num_current_syncs[user_id] == 1: + self.mark_as_coming_online(user_id) def _end(): # We check that the user_id is in user_to_num_current_syncs because # user_to_num_current_syncs may have been cleared if we are # shutting down. - if affect_presence and user_id in self.user_to_num_current_syncs: - self.user_to_num_current_syncs[user_id] -= 1 + if user_id in self._user_to_num_current_syncs: + self._user_to_num_current_syncs[user_id] -= 1 # If we went from one in flight sync to non, notify replication - if self.user_to_num_current_syncs[user_id] == 0: + if self._user_to_num_current_syncs[user_id] == 0: self.mark_as_going_offline(user_id) @contextlib.contextmanager @@ -338,7 +355,7 @@ class GenericWorkerPresence(object): finally: _end() - return defer.succeed(_user_syncing()) + return _user_syncing() @defer.inlineCallbacks def notify_from_replication(self, states, stream_id): @@ -373,15 +390,12 @@ class GenericWorkerPresence(object): stream_id = token yield self.notify_from_replication(states, stream_id) - def get_currently_syncing_users(self): - if self.hs.config.use_presence: - return [ - user_id - for user_id, count in self.user_to_num_current_syncs.items() - if count > 0 - ] - else: - return set() + def get_currently_syncing_users_for_replication(self) -> Iterable[str]: + return [ + user_id + for user_id, count in self._user_to_num_current_syncs.items() + if count > 0 + ] class GenericWorkerTyping(object): @@ -625,8 +639,7 @@ class GenericWorkerReplicationHandler(ReplicationDataHandler): self.store = hs.get_datastore() self.typing_handler = hs.get_typing_handler() - # NB this is a SynchrotronPresence, not a normal PresenceHandler - self.presence_handler = hs.get_presence_handler() + self.presence_handler = hs.get_presence_handler() # type: GenericWorkerPresence self.notifier = hs.get_notifier() self.notify_pushers = hs.config.start_pushers diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index ec18a42a68..71a89f09c7 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -19,6 +19,7 @@ import random from synapse.api.constants import EventTypes, Membership from synapse.api.errors import AuthError, SynapseError from synapse.events import EventBase +from synapse.handlers.presence import format_user_presence_state from synapse.logging.utils import log_function from synapse.types import UserID from synapse.visibility import filter_events_for_client @@ -97,6 +98,8 @@ class EventStreamHandler(BaseHandler): explicit_room_id=room_id, ) + time_now = self.clock.time_msec() + # When the user joins a new room, or another user joins a currently # joined room, we need to send down presence for those users. to_add = [] @@ -112,19 +115,20 @@ class EventStreamHandler(BaseHandler): users = await self.state.get_current_users_in_room( event.room_id ) - states = await presence_handler.get_states(users, as_event=True) - to_add.extend(states) else: + users = [event.state_key] - ev = await presence_handler.get_state( - UserID.from_string(event.state_key), as_event=True - ) - to_add.append(ev) + states = await presence_handler.get_states(users) + to_add.extend( + { + "type": EventTypes.Presence, + "content": format_user_presence_state(state, time_now), + } + for state in states + ) events.extend(to_add) - time_now = self.clock.time_msec() - chunks = await self._event_serializer.serialize_events( events, time_now, diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index b116500c7d..f88bad5f25 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -381,10 +381,16 @@ class InitialSyncHandler(BaseHandler): return [] states = await presence_handler.get_states( - [m.user_id for m in room_members], as_event=True + [m.user_id for m in room_members] ) - return states + return [ + { + "type": EventTypes.Presence, + "content": format_user_presence_state(s, time_now), + } + for s in states + ] async def get_receipts(): receipts = await self.store.get_linearized_receipts_for_room( diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 6912165622..5cbefae177 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -21,10 +22,10 @@ The methods that define policy are: - PresenceHandler._handle_timeouts - should_notify """ - +import abc import logging from contextlib import contextmanager -from typing import Dict, List, Set +from typing import Dict, Iterable, List, Set from six import iteritems, itervalues @@ -41,7 +42,7 @@ from synapse.logging.utils import log_function from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.presence import UserPresenceState -from synapse.types import UserID, get_domain_from_id +from synapse.types import JsonDict, UserID, get_domain_from_id from synapse.util.async_helpers import Linearizer from synapse.util.caches.descriptors import cached from synapse.util.metrics import Measure @@ -99,13 +100,106 @@ EXTERNAL_PROCESS_EXPIRY = 5 * 60 * 1000 assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER -class PresenceHandler(object): +class BasePresenceHandler(abc.ABC): + """Parts of the PresenceHandler that are shared between workers and master""" + + def __init__(self, hs: "synapse.server.HomeServer"): + self.clock = hs.get_clock() + self.store = hs.get_datastore() + + active_presence = self.store.take_presence_startup_info() + self.user_to_current_state = {state.user_id: state for state in active_presence} + + @abc.abstractmethod + async def user_syncing( + self, user_id: str, affect_presence: bool + ) -> ContextManager[None]: + """Returns a context manager that should surround any stream requests + from the user. + + This allows us to keep track of who is currently streaming and who isn't + without having to have timers outside of this module to avoid flickering + when users disconnect/reconnect. + + Args: + user_id: the user that is starting a sync + affect_presence: If false this function will be a no-op. + Useful for streams that are not associated with an actual + client that is being used by a user. + """ + + @abc.abstractmethod + def get_currently_syncing_users_for_replication(self) -> Iterable[str]: + """Get an iterable of syncing users on this worker, to send to the presence handler + + This is called when a replication connection is established. It should return + a list of user ids, which are then sent as USER_SYNC commands to inform the + process handling presence about those users. + + Returns: + An iterable of user_id strings. + """ + + async def get_state(self, target_user: UserID) -> UserPresenceState: + results = await self.get_states([target_user.to_string()]) + return results[0] + + async def get_states( + self, target_user_ids: Iterable[str] + ) -> List[UserPresenceState]: + """Get the presence state for users.""" + + updates_d = await self.current_state_for_users(target_user_ids) + updates = list(updates_d.values()) + + for user_id in set(target_user_ids) - {u.user_id for u in updates}: + updates.append(UserPresenceState.default(user_id)) + + return updates + + async def current_state_for_users( + self, user_ids: Iterable[str] + ) -> Dict[str, UserPresenceState]: + """Get the current presence state for multiple users. + + Returns: + dict: `user_id` -> `UserPresenceState` + """ + states = { + user_id: self.user_to_current_state.get(user_id, None) + for user_id in user_ids + } + + missing = [user_id for user_id, state in iteritems(states) if not state] + if missing: + # There are things not in our in memory cache. Lets pull them out of + # the database. + res = await self.store.get_presence_for_users(missing) + states.update(res) + + missing = [user_id for user_id, state in iteritems(states) if not state] + if missing: + new = { + user_id: UserPresenceState.default(user_id) for user_id in missing + } + states.update(new) + self.user_to_current_state.update(new) + + return states + + @abc.abstractmethod + async def set_state( + self, target_user: UserID, state: JsonDict, ignore_status_msg: bool = False + ) -> None: + """Set the presence state of the user. """ + + +class PresenceHandler(BasePresenceHandler): def __init__(self, hs: "synapse.server.HomeServer"): + super().__init__(hs) self.hs = hs self.is_mine_id = hs.is_mine_id self.server_name = hs.hostname - self.clock = hs.get_clock() - self.store = hs.get_datastore() self.wheel_timer = WheelTimer() self.notifier = hs.get_notifier() self.federation = hs.get_federation_sender() @@ -115,13 +209,6 @@ class PresenceHandler(object): federation_registry.register_edu_handler("m.presence", self.incoming_presence) - active_presence = self.store.take_presence_startup_info() - - # A dictionary of the current state of users. This is prefilled with - # non-offline presence from the DB. We should fetch from the DB if - # we can't find a users presence in here. - self.user_to_current_state = {state.user_id: state for state in active_presence} - LaterGauge( "synapse_handlers_presence_user_to_current_state_size", "", @@ -130,7 +217,7 @@ class PresenceHandler(object): ) now = self.clock.time_msec() - for state in active_presence: + for state in self.user_to_current_state.values(): self.wheel_timer.insert( now=now, obj=state.user_id, then=state.last_active_ts + IDLE_TIMER ) @@ -361,10 +448,18 @@ class PresenceHandler(object): timers_fired_counter.inc(len(states)) + syncing_user_ids = { + user_id + for user_id, count in self.user_to_num_current_syncs.items() + if count + } + for user_ids in self.external_process_to_current_syncs.values(): + syncing_user_ids.update(user_ids) + changes = handle_timeouts( states, is_mine_fn=self.is_mine_id, - syncing_user_ids=self.get_currently_syncing_users(), + syncing_user_ids=syncing_user_ids, now=now, ) @@ -462,22 +557,9 @@ class PresenceHandler(object): return _user_syncing() - def get_currently_syncing_users(self): - """Get the set of user ids that are currently syncing on this HS. - Returns: - set(str): A set of user_id strings. - """ - if self.hs.config.use_presence: - syncing_user_ids = { - user_id - for user_id, count in self.user_to_num_current_syncs.items() - if count - } - for user_ids in self.external_process_to_current_syncs.values(): - syncing_user_ids.update(user_ids) - return syncing_user_ids - else: - return set() + def get_currently_syncing_users_for_replication(self) -> Iterable[str]: + # since we are the process handling presence, there is nothing to do here. + return [] async def update_external_syncs_row( self, process_id, user_id, is_syncing, sync_time_msec @@ -554,34 +636,6 @@ class PresenceHandler(object): res = await self.current_state_for_users([user_id]) return res[user_id] - async def current_state_for_users(self, user_ids): - """Get the current presence state for multiple users. - - Returns: - dict: `user_id` -> `UserPresenceState` - """ - states = { - user_id: self.user_to_current_state.get(user_id, None) - for user_id in user_ids - } - - missing = [user_id for user_id, state in iteritems(states) if not state] - if missing: - # There are things not in our in memory cache. Lets pull them out of - # the database. - res = await self.store.get_presence_for_users(missing) - states.update(res) - - missing = [user_id for user_id, state in iteritems(states) if not state] - if missing: - new = { - user_id: UserPresenceState.default(user_id) for user_id in missing - } - states.update(new) - self.user_to_current_state.update(new) - - return states - async def _persist_and_notify(self, states): """Persist states in the database, poke the notifier and send to interested remote servers @@ -669,40 +723,6 @@ class PresenceHandler(object): federation_presence_counter.inc(len(updates)) await self._update_states(updates) - async def get_state(self, target_user, as_event=False): - results = await self.get_states([target_user.to_string()], as_event=as_event) - - return results[0] - - async def get_states(self, target_user_ids, as_event=False): - """Get the presence state for users. - - Args: - target_user_ids (list) - as_event (bool): Whether to format it as a client event or not. - - Returns: - list - """ - - updates = await self.current_state_for_users(target_user_ids) - updates = list(updates.values()) - - for user_id in set(target_user_ids) - {u.user_id for u in updates}: - updates.append(UserPresenceState.default(user_id)) - - now = self.clock.time_msec() - if as_event: - return [ - { - "type": "m.presence", - "content": format_user_presence_state(state, now), - } - for state in updates - ] - else: - return updates - async def set_state(self, target_user, state, ignore_status_msg=False): """Set the presence state of the user. """ @@ -889,7 +909,7 @@ class PresenceHandler(object): user_ids = await self.state.get_current_users_in_room(room_id) user_ids = list(filter(self.is_mine_id, user_ids)) - states = await self.current_state_for_users(user_ids) + states_d = await self.current_state_for_users(user_ids) # Filter out old presence, i.e. offline presence states where # the user hasn't been active for a week. We can change this @@ -899,7 +919,7 @@ class PresenceHandler(object): now = self.clock.time_msec() states = [ state - for state in states.values() + for state in states_d.values() if state.state != PresenceState.OFFLINE or now - state.last_active_ts < 7 * 24 * 60 * 60 * 1000 or state.status_msg is not None diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index f26aee83cb..c7880d4b63 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -210,7 +210,10 @@ class ReplicateCommand(Command): class UserSyncCommand(Command): """Sent by the client to inform the server that a user has started or - stopped syncing. Used to calculate presence on the master. + stopped syncing on this process. + + This is used by the process handling presence (typically the master) to + calculate who is online and who is not. Includes a timestamp of when the last user sync was. @@ -218,7 +221,7 @@ class UserSyncCommand(Command): USER_SYNC - Where is either "start" or "stop" + Where is either "start" or "end" """ NAME = "USER_SYNC" diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 5b5ee2c13e..0db5a3a24d 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -337,13 +337,6 @@ class ReplicationCommandHandler: if self._is_master: self._notifier.notify_remote_server_up(cmd.data) - def get_currently_syncing_users(self): - """Get the list of currently syncing users (if any). This is called - when a connection has been established and we need to send the - currently syncing users. - """ - return self._presence_handler.get_currently_syncing_users() - def new_connection(self, connection: AbstractConnection): """Called when we have a new connection. """ @@ -361,9 +354,11 @@ class ReplicationCommandHandler: if self._factory: self._factory.resetDelay() - # Tell the server if we have any users currently syncing (should only - # happen on synchrotrons) - currently_syncing = self.get_currently_syncing_users() + # Tell the other end if we have any users currently syncing. + currently_syncing = ( + self._presence_handler.get_currently_syncing_users_for_replication() + ) + now = self._clock.time_msec() for user_id in currently_syncing: connection.send_command( diff --git a/synapse/server.pyi b/synapse/server.pyi index 9013e9bac9..f1a5717028 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -97,7 +97,7 @@ class HomeServer(object): pass def get_notifier(self) -> synapse.notifier.Notifier: pass - def get_presence_handler(self) -> synapse.handlers.presence.PresenceHandler: + def get_presence_handler(self) -> synapse.handlers.presence.BasePresenceHandler: pass def get_clock(self) -> synapse.util.Clock: pass -- cgit 1.5.1 From c2e1a2110fbe9ead26b4ecbb1afd504ed035a04d Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 29 Apr 2020 12:30:36 +0100 Subject: Fix limit logic for EventsStream (#7358) * Factor out functions for injecting events into database I want to add some more flexibility to the tools for injecting events into the database, and I don't want to clutter up HomeserverTestCase with them, so let's factor them out to a new file. * Rework TestReplicationDataHandler This wasn't very easy to work with: the mock wrapping was largely superfluous, and it's useful to be able to inspect the received rows, and clear out the received list. * Fix AssertionErrors being thrown by EventsStream Part of the problem was that there was an off-by-one error in the assertion, but also the limit logic was too simple. Fix it all up and add some tests. --- changelog.d/7358.bugfix | 1 + synapse/replication/tcp/handler.py | 4 +- synapse/replication/tcp/streams/events.py | 22 +- synapse/server.pyi | 5 + synapse/storage/data_stores/main/events_worker.py | 64 +++- tests/replication/tcp/streams/_base.py | 41 ++- tests/replication/tcp/streams/test_events.py | 417 ++++++++++++++++++++++ tests/replication/tcp/streams/test_receipts.py | 10 +- tests/replication/tcp/streams/test_typing.py | 11 +- tests/rest/client/v1/utils.py | 2 +- tests/test_utils/__init__.py | 20 ++ tests/test_utils/event_injection.py | 96 +++++ tests/unittest.py | 30 +- tox.ini | 2 + 14 files changed, 658 insertions(+), 67 deletions(-) create mode 100644 changelog.d/7358.bugfix create mode 100644 tests/replication/tcp/streams/test_events.py create mode 100644 tests/test_utils/event_injection.py (limited to 'synapse/replication/tcp/handler.py') diff --git a/changelog.d/7358.bugfix b/changelog.d/7358.bugfix new file mode 100644 index 0000000000..f49c600173 --- /dev/null +++ b/changelog.d/7358.bugfix @@ -0,0 +1 @@ +Fix a bug where event updates might not be sent over replication to worker processes after the stream falls behind. diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 0db5a3a24d..3a8c7c7e2d 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -87,7 +87,9 @@ class ReplicationCommandHandler: stream.NAME: stream(hs) for stream in STREAMS_MAP.values() } # type: Dict[str, Stream] - self._position_linearizer = Linearizer("replication_position") + self._position_linearizer = Linearizer( + "replication_position", clock=self._clock + ) # Map of stream to batched updates. See RdataCommand for info on how # batching works. diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index aa50492569..52df81b1bd 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -170,22 +170,16 @@ class EventsStream(Stream): limited = False upper_limit = current_token - # next up is the state delta table - - state_rows = await self._store.get_all_updated_current_state_deltas( + # next up is the state delta table. + ( + state_rows, + upper_limit, + state_rows_limited, + ) = await self._store.get_all_updated_current_state_deltas( from_token, upper_limit, target_row_count - ) # type: List[Tuple] - - # again, if we've hit the limit there, we'll need to limit the other sources - assert len(state_rows) < target_row_count - if len(state_rows) == target_row_count: - assert state_rows[-1][0] <= upper_limit - upper_limit = state_rows[-1][0] - limited = True + ) - # FIXME: is it a given that there is only one row per stream_id in the - # state_deltas table (so that we can be sure that we have got all of the - # rows for upper_limit)? + limited = limited or state_rows_limited # finally, fetch the ex-outliers rows. We assume there are few enough of these # not to bother with the limit. diff --git a/synapse/server.pyi b/synapse/server.pyi index f1a5717028..fc5886f762 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -25,6 +25,7 @@ import synapse.server_notices.server_notices_manager import synapse.server_notices.server_notices_sender import synapse.state import synapse.storage +from synapse.events.builder import EventBuilderFactory class HomeServer(object): @property @@ -121,3 +122,7 @@ class HomeServer(object): pass def get_instance_id(self) -> str: pass + def get_event_builder_factory(self) -> EventBuilderFactory: + pass + def get_storage(self) -> synapse.storage.Storage: + pass diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index ce8be72bfe..73df6b33ba 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py @@ -19,7 +19,7 @@ import itertools import logging import threading from collections import namedtuple -from typing import List, Optional +from typing import List, Optional, Tuple from canonicaljson import json from constantly import NamedConstant, Names @@ -1084,7 +1084,28 @@ class EventsWorkerStore(SQLBaseStore): "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows ) - def get_all_updated_current_state_deltas(self, from_token, to_token, limit): + async def get_all_updated_current_state_deltas( + self, from_token: int, to_token: int, target_row_count: int + ) -> Tuple[List[Tuple], int, bool]: + """Fetch updates from current_state_delta_stream + + Args: + from_token: The previous stream token. Updates from this stream id will + be excluded. + + to_token: The current stream token (ie the upper limit). Updates up to this + stream id will be included (modulo the 'limit' param) + + target_row_count: The number of rows to try to return. If more rows are + available, we will set 'limited' in the result. In the event of a large + batch, we may return more rows than this. + Returns: + A triplet `(updates, new_last_token, limited)`, where: + * `updates` is a list of database tuples. + * `new_last_token` is the new position in stream. + * `limited` is whether there are more updates to fetch. + """ + def get_all_updated_current_state_deltas_txn(txn): sql = """ SELECT stream_id, room_id, type, state_key, event_id @@ -1092,10 +1113,45 @@ class EventsWorkerStore(SQLBaseStore): WHERE ? < stream_id AND stream_id <= ? ORDER BY stream_id ASC LIMIT ? """ - txn.execute(sql, (from_token, to_token, limit)) + txn.execute(sql, (from_token, to_token, target_row_count)) return txn.fetchall() - return self.db.runInteraction( + def get_deltas_for_stream_id_txn(txn, stream_id): + sql = """ + SELECT stream_id, room_id, type, state_key, event_id + FROM current_state_delta_stream + WHERE stream_id = ? + """ + txn.execute(sql, [stream_id]) + return txn.fetchall() + + # we need to make sure that, for every stream id in the results, we get *all* + # the rows with that stream id. + + rows = await self.db.runInteraction( "get_all_updated_current_state_deltas", get_all_updated_current_state_deltas_txn, + ) # type: List[Tuple] + + # if we've got fewer rows than the limit, we're good + if len(rows) < target_row_count: + return rows, to_token, False + + # we hit the limit, so reduce the upper limit so that we exclude the stream id + # of the last row in the result. + assert rows[-1][0] <= to_token + to_token = rows[-1][0] - 1 + + # search backwards through the list for the point to truncate + for idx in range(len(rows) - 1, 0, -1): + if rows[idx - 1][0] <= to_token: + return rows[:idx], to_token, True + + # bother. We didn't get a full set of changes for even a single + # stream id. let's run the query again, without a row limit, but for + # just one stream id. + to_token += 1 + rows = await self.db.runInteraction( + "get_deltas_for_stream_id", get_deltas_for_stream_id_txn, to_token ) + return rows, to_token, True diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py index 82f15c64e0..83e16cfe3d 100644 --- a/tests/replication/tcp/streams/_base.py +++ b/tests/replication/tcp/streams/_base.py @@ -12,10 +12,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import logging -from typing import Optional -from mock import Mock +import logging +from typing import Any, Dict, List, Optional, Tuple import attr @@ -25,6 +24,7 @@ from twisted.web.http import HTTPChannel from synapse.app.generic_worker import GenericWorkerServer from synapse.http.site import SynapseRequest +from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.tcp.client import ReplicationDataHandler from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol @@ -65,9 +65,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): # databases objects are the same. self.worker_hs.get_datastore().db = hs.get_datastore().db - self.test_handler = Mock( - wraps=TestReplicationDataHandler(self.worker_hs.get_datastore()) - ) + self.test_handler = self._build_replication_data_handler() self.worker_hs.replication_data_handler = self.test_handler repl_handler = ReplicationCommandHandler(self.worker_hs) @@ -78,6 +76,9 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): self._client_transport = None self._server_transport = None + def _build_replication_data_handler(self): + return TestReplicationDataHandler(self.worker_hs.get_datastore()) + def reconnect(self): if self._client_transport: self.client.close() @@ -174,22 +175,28 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): class TestReplicationDataHandler(ReplicationDataHandler): """Drop-in for ReplicationDataHandler which just collects RDATA rows""" - def __init__(self, hs): - super().__init__(hs) - self.streams = set() - self._received_rdata_rows = [] + def __init__(self, store: BaseSlavedStore): + super().__init__(store) + + # streams to subscribe to: map from stream id to position + self.stream_positions = {} # type: Dict[str, int] + + # list of received (stream_name, token, row) tuples + self.received_rdata_rows = [] # type: List[Tuple[str, int, Any]] def get_streams_to_replicate(self): - positions = {s: 0 for s in self.streams} - for stream, token, _ in self._received_rdata_rows: - if stream in self.streams: - positions[stream] = max(token, positions.get(stream, 0)) - return positions + return self.stream_positions async def on_rdata(self, stream_name, token, rows): await super().on_rdata(stream_name, token, rows) for r in rows: - self._received_rdata_rows.append((stream_name, token, r)) + self.received_rdata_rows.append((stream_name, token, r)) + + if ( + stream_name in self.stream_positions + and token > self.stream_positions[stream_name] + ): + self.stream_positions[stream_name] = token @attr.s() @@ -221,7 +228,7 @@ class _PushHTTPChannel(HTTPChannel): super().__init__() self.reactor = reactor - self._pull_to_push_producer = None + self._pull_to_push_producer = None # type: Optional[_PullToPushProducer] def registerProducer(self, producer, streaming): # Convert pull producers to push producer. diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py new file mode 100644 index 0000000000..1fa28084f9 --- /dev/null +++ b/tests/replication/tcp/streams/test_events.py @@ -0,0 +1,417 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional + +from synapse.api.constants import EventTypes, Membership +from synapse.events import EventBase +from synapse.replication.tcp.streams._base import _STREAM_UPDATE_TARGET_ROW_COUNT +from synapse.replication.tcp.streams.events import ( + EventsStreamCurrentStateRow, + EventsStreamEventRow, + EventsStreamRow, +) +from synapse.rest import admin +from synapse.rest.client.v1 import login, room + +from tests.replication.tcp.streams._base import BaseStreamTestCase +from tests.test_utils.event_injection import inject_event, inject_member_event + + +class EventsStreamTestCase(BaseStreamTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + super().prepare(reactor, clock, hs) + self.user_id = self.register_user("u1", "pass") + self.user_tok = self.login("u1", "pass") + + self.reconnect() + self.test_handler.stream_positions["events"] = 0 + + self.room_id = self.helper.create_room_as(tok=self.user_tok) + self.test_handler.received_rdata_rows.clear() + + def test_update_function_event_row_limit(self): + """Test replication with many non-state events + + Checks that all events are correctly replicated when there are lots of + event rows to be replicated. + """ + # disconnect, so that we can stack up some changes + self.disconnect() + + # generate lots of non-state events. We inject them using inject_event + # so that they are not send out over replication until we call self.replicate(). + events = [ + self._inject_test_event() + for _ in range(_STREAM_UPDATE_TARGET_ROW_COUNT + 1) + ] + + # also one state event + state_event = self._inject_state_event() + + # tell the notifier to catch up to avoid duplicate rows. + # workaround for https://github.com/matrix-org/synapse/issues/7360 + # FIXME remove this when the above is fixed + self.replicate() + + # check we're testing what we think we are: no rows should yet have been + # received + self.assertEqual([], self.test_handler.received_rdata_rows) + + # now reconnect to pull the updates + self.reconnect() + self.replicate() + + # we should have received all the expected rows in the right order + received_rows = self.test_handler.received_rdata_rows + for event in events: + stream_name, token, row = received_rows.pop(0) + self.assertEqual("events", stream_name) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "ev") + self.assertIsInstance(row.data, EventsStreamEventRow) + self.assertEqual(row.data.event_id, event.event_id) + + stream_name, token, row = received_rows.pop(0) + self.assertIsInstance(row, EventsStreamRow) + self.assertIsInstance(row.data, EventsStreamEventRow) + self.assertEqual(row.data.event_id, state_event.event_id) + + stream_name, token, row = received_rows.pop(0) + self.assertEqual("events", stream_name) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "state") + self.assertIsInstance(row.data, EventsStreamCurrentStateRow) + self.assertEqual(row.data.event_id, state_event.event_id) + + self.assertEqual([], received_rows) + + def test_update_function_huge_state_change(self): + """Test replication with many state events + + Ensures that all events are correctly replicated when there are lots of + state change rows to be replicated. + """ + + # we want to generate lots of state changes at a single stream ID. + # + # We do this by having two branches in the DAG. On one, we have a moderator + # which that generates lots of state; on the other, we de-op the moderator, + # thus invalidating all the state. + + OTHER_USER = "@other_user:localhost" + + # have the user join + inject_member_event(self.hs, self.room_id, OTHER_USER, Membership.JOIN) + + # Update existing power levels with mod at PL50 + pls = self.helper.get_state( + self.room_id, EventTypes.PowerLevels, tok=self.user_tok + ) + pls["users"][OTHER_USER] = 50 + self.helper.send_state( + self.room_id, EventTypes.PowerLevels, pls, tok=self.user_tok, + ) + + # this is the point in the DAG where we make a fork + fork_point = self.get_success( + self.hs.get_datastore().get_latest_event_ids_in_room(self.room_id) + ) # type: List[str] + + events = [ + self._inject_state_event(sender=OTHER_USER) + for _ in range(_STREAM_UPDATE_TARGET_ROW_COUNT) + ] + + self.replicate() + # all those events and state changes should have landed + self.assertGreaterEqual( + len(self.test_handler.received_rdata_rows), 2 * len(events) + ) + + # disconnect, so that we can stack up the changes + self.disconnect() + self.test_handler.received_rdata_rows.clear() + + # a state event which doesn't get rolled back, to check that the state + # before the huge update comes through ok + state1 = self._inject_state_event() + + # roll back all the state by de-modding the user + prev_events = fork_point + pls["users"][OTHER_USER] = 0 + pl_event = inject_event( + self.hs, + prev_event_ids=prev_events, + type=EventTypes.PowerLevels, + state_key="", + sender=self.user_id, + room_id=self.room_id, + content=pls, + ) + + # one more bit of state that doesn't get rolled back + state2 = self._inject_state_event() + + # tell the notifier to catch up to avoid duplicate rows. + # workaround for https://github.com/matrix-org/synapse/issues/7360 + # FIXME remove this when the above is fixed + self.replicate() + + # check we're testing what we think we are: no rows should yet have been + # received + self.assertEqual([], self.test_handler.received_rdata_rows) + + # now reconnect to pull the updates + self.reconnect() + self.replicate() + + # now we should have received all the expected rows in the right order. + # + # we expect: + # + # - two rows for state1 + # - the PL event row, plus state rows for the PL event and each + # of the states that got reverted. + # - two rows for state2 + + received_rows = self.test_handler.received_rdata_rows + + # first check the first two rows, which should be state1 + + stream_name, token, row = received_rows.pop(0) + self.assertEqual("events", stream_name) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "ev") + self.assertIsInstance(row.data, EventsStreamEventRow) + self.assertEqual(row.data.event_id, state1.event_id) + + stream_name, token, row = received_rows.pop(0) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "state") + self.assertIsInstance(row.data, EventsStreamCurrentStateRow) + self.assertEqual(row.data.event_id, state1.event_id) + + # now the last two rows, which should be state2 + stream_name, token, row = received_rows.pop(-2) + self.assertEqual("events", stream_name) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "ev") + self.assertIsInstance(row.data, EventsStreamEventRow) + self.assertEqual(row.data.event_id, state2.event_id) + + stream_name, token, row = received_rows.pop(-1) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "state") + self.assertIsInstance(row.data, EventsStreamCurrentStateRow) + self.assertEqual(row.data.event_id, state2.event_id) + + # that should leave us with the rows for the PL event + self.assertEqual(len(received_rows), len(events) + 2) + + stream_name, token, row = received_rows.pop(0) + self.assertEqual("events", stream_name) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "ev") + self.assertIsInstance(row.data, EventsStreamEventRow) + self.assertEqual(row.data.event_id, pl_event.event_id) + + # the state rows are unsorted + state_rows = [] # type: List[EventsStreamCurrentStateRow] + for stream_name, token, row in received_rows: + self.assertEqual("events", stream_name) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "state") + self.assertIsInstance(row.data, EventsStreamCurrentStateRow) + state_rows.append(row.data) + + state_rows.sort(key=lambda r: r.state_key) + + sr = state_rows.pop(0) + self.assertEqual(sr.type, EventTypes.PowerLevels) + self.assertEqual(sr.event_id, pl_event.event_id) + for sr in state_rows: + self.assertEqual(sr.type, "test_state_event") + # "None" indicates the state has been deleted + self.assertIsNone(sr.event_id) + + def test_update_function_state_row_limit(self): + """Test replication with many state events over several stream ids. + """ + + # we want to generate lots of state changes, but for this test, we want to + # spread out the state changes over a few stream IDs. + # + # We do this by having two branches in the DAG. On one, we have four moderators, + # each of which that generates lots of state; on the other, we de-op the users, + # thus invalidating all the state. + + NUM_USERS = 4 + STATES_PER_USER = _STREAM_UPDATE_TARGET_ROW_COUNT // 4 + 1 + + user_ids = ["@user%i:localhost" % (i,) for i in range(NUM_USERS)] + + # have the users join + for u in user_ids: + inject_member_event(self.hs, self.room_id, u, Membership.JOIN) + + # Update existing power levels with mod at PL50 + pls = self.helper.get_state( + self.room_id, EventTypes.PowerLevels, tok=self.user_tok + ) + pls["users"].update({u: 50 for u in user_ids}) + self.helper.send_state( + self.room_id, EventTypes.PowerLevels, pls, tok=self.user_tok, + ) + + # this is the point in the DAG where we make a fork + fork_point = self.get_success( + self.hs.get_datastore().get_latest_event_ids_in_room(self.room_id) + ) # type: List[str] + + events = [] # type: List[EventBase] + for user in user_ids: + events.extend( + self._inject_state_event(sender=user) for _ in range(STATES_PER_USER) + ) + + self.replicate() + + # all those events and state changes should have landed + self.assertGreaterEqual( + len(self.test_handler.received_rdata_rows), 2 * len(events) + ) + + # disconnect, so that we can stack up the changes + self.disconnect() + self.test_handler.received_rdata_rows.clear() + + # now roll back all that state by de-modding the users + prev_events = fork_point + pl_events = [] + for u in user_ids: + pls["users"][u] = 0 + e = inject_event( + self.hs, + prev_event_ids=prev_events, + type=EventTypes.PowerLevels, + state_key="", + sender=self.user_id, + room_id=self.room_id, + content=pls, + ) + prev_events = [e.event_id] + pl_events.append(e) + + # tell the notifier to catch up to avoid duplicate rows. + # workaround for https://github.com/matrix-org/synapse/issues/7360 + # FIXME remove this when the above is fixed + self.replicate() + + # check we're testing what we think we are: no rows should yet have been + # received + self.assertEqual([], self.test_handler.received_rdata_rows) + + # now reconnect to pull the updates + self.reconnect() + self.replicate() + + # we should have received all the expected rows in the right order + + received_rows = self.test_handler.received_rdata_rows + self.assertGreaterEqual(len(received_rows), len(events)) + for i in range(NUM_USERS): + # for each user, we expect the PL event row, followed by state rows for + # the PL event and each of the states that got reverted. + stream_name, token, row = received_rows.pop(0) + self.assertEqual("events", stream_name) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "ev") + self.assertIsInstance(row.data, EventsStreamEventRow) + self.assertEqual(row.data.event_id, pl_events[i].event_id) + + # the state rows are unsorted + state_rows = [] # type: List[EventsStreamCurrentStateRow] + for j in range(STATES_PER_USER + 1): + stream_name, token, row = received_rows.pop(0) + self.assertEqual("events", stream_name) + self.assertIsInstance(row, EventsStreamRow) + self.assertEqual(row.type, "state") + self.assertIsInstance(row.data, EventsStreamCurrentStateRow) + state_rows.append(row.data) + + state_rows.sort(key=lambda r: r.state_key) + + sr = state_rows.pop(0) + self.assertEqual(sr.type, EventTypes.PowerLevels) + self.assertEqual(sr.event_id, pl_events[i].event_id) + for sr in state_rows: + self.assertEqual(sr.type, "test_state_event") + # "None" indicates the state has been deleted + self.assertIsNone(sr.event_id) + + self.assertEqual([], received_rows) + + event_count = 0 + + def _inject_test_event( + self, body: Optional[str] = None, sender: Optional[str] = None, **kwargs + ) -> EventBase: + if sender is None: + sender = self.user_id + + if body is None: + body = "event %i" % (self.event_count,) + self.event_count += 1 + + return inject_event( + self.hs, + room_id=self.room_id, + sender=sender, + type="test_event", + content={"body": body}, + **kwargs + ) + + def _inject_state_event( + self, + body: Optional[str] = None, + state_key: Optional[str] = None, + sender: Optional[str] = None, + ) -> EventBase: + if sender is None: + sender = self.user_id + + if state_key is None: + state_key = "state_%i" % (self.event_count,) + self.event_count += 1 + + if body is None: + body = "state event %s" % (state_key,) + + return inject_event( + self.hs, + room_id=self.room_id, + sender=sender, + type="test_state_event", + state_key=state_key, + content={"body": body}, + ) diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py index a0206f7363..c122b8589c 100644 --- a/tests/replication/tcp/streams/test_receipts.py +++ b/tests/replication/tcp/streams/test_receipts.py @@ -12,6 +12,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +# type: ignore + +from mock import Mock + from synapse.replication.tcp.streams._base import ReceiptsStream from tests.replication.tcp.streams._base import BaseStreamTestCase @@ -20,11 +25,14 @@ USER_ID = "@feeling:blue" class ReceiptsStreamTestCase(BaseStreamTestCase): + def _build_replication_data_handler(self): + return Mock(wraps=super()._build_replication_data_handler()) + def test_receipt(self): self.reconnect() # make the client subscribe to the receipts stream - self.test_handler.streams.add("receipts") + self.test_handler.stream_positions.update({"receipts": 0}) # tell the master to send a new receipt self.get_success( diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py index f0ad6402ae..4d354a9db8 100644 --- a/tests/replication/tcp/streams/test_typing.py +++ b/tests/replication/tcp/streams/test_typing.py @@ -12,6 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from mock import Mock + from synapse.handlers.typing import RoomMember from synapse.replication.http import streams from synapse.replication.tcp.streams import TypingStream @@ -26,6 +28,9 @@ class TypingStreamTestCase(BaseStreamTestCase): streams.register_servlets, ] + def _build_replication_data_handler(self): + return Mock(wraps=super()._build_replication_data_handler()) + def test_typing(self): typing = self.hs.get_typing_handler() @@ -33,8 +38,8 @@ class TypingStreamTestCase(BaseStreamTestCase): self.reconnect() - # make the client subscribe to the receipts stream - self.test_handler.streams.add("typing") + # make the client subscribe to the typing stream + self.test_handler.stream_positions.update({"typing": 0}) typing._push_update(member=RoomMember(room_id, USER_ID), typing=True) @@ -75,6 +80,6 @@ class TypingStreamTestCase(BaseStreamTestCase): stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0] self.assertEqual(stream_name, "typing") self.assertEqual(1, len(rdata_rows)) - row = rdata_rows[0] # type: TypingStream.TypingStreamRow + row = rdata_rows[0] self.assertEqual(room_id, row.room_id) self.assertEqual([], row.user_ids) diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 371637618d..22d734e763 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -39,7 +39,7 @@ class RestHelper(object): resource = attr.ib() auth_user_id = attr.ib() - def create_room_as(self, room_creator, is_public=True, tok=None): + def create_room_as(self, room_creator=None, is_public=True, tok=None): temp_id = self.auth_user_id self.auth_user_id = room_creator path = "/_matrix/client/r0/createRoom" diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py index a7310cf12a..7b345b03bb 100644 --- a/tests/test_utils/__init__.py +++ b/tests/test_utils/__init__.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2019 New Vector Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,3 +17,22 @@ """ Utilities for running the unit tests """ +from typing import Awaitable, TypeVar + +TV = TypeVar("TV") + + +def get_awaitable_result(awaitable: Awaitable[TV]) -> TV: + """Get the result from an Awaitable which should have completed + + Asserts that the given awaitable has a result ready, and returns its value + """ + i = awaitable.__await__() + try: + next(i) + except StopIteration as e: + # awaitable returned a result + return e.value + + # if next didn't raise, the awaitable hasn't completed. + raise Exception("awaitable has not yet completed") diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py new file mode 100644 index 0000000000..8f6872761a --- /dev/null +++ b/tests/test_utils/event_injection.py @@ -0,0 +1,96 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import synapse.server +from synapse.api.constants import EventTypes +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS +from synapse.events import EventBase +from synapse.types import Collection + +from tests.test_utils import get_awaitable_result + + +""" +Utility functions for poking events into the storage of the server under test. +""" + + +def inject_member_event( + hs: synapse.server.HomeServer, + room_id: str, + sender: str, + membership: str, + target: Optional[str] = None, + extra_content: Optional[dict] = None, + **kwargs +) -> EventBase: + """Inject a membership event into a room.""" + if target is None: + target = sender + + content = {"membership": membership} + if extra_content: + content.update(extra_content) + + return inject_event( + hs, + room_id=room_id, + type=EventTypes.Member, + sender=sender, + state_key=target, + content=content, + **kwargs + ) + + +def inject_event( + hs: synapse.server.HomeServer, + room_version: Optional[str] = None, + prev_event_ids: Optional[Collection[str]] = None, + **kwargs +) -> EventBase: + """Inject a generic event into a room + + Args: + hs: the homeserver under test + room_version: the version of the room we're inserting into. + if not specified, will be looked up + prev_event_ids: prev_events for the event. If not specified, will be looked up + kwargs: fields for the event to be created + """ + test_reactor = hs.get_reactor() + + if room_version is None: + d = hs.get_datastore().get_room_version_id(kwargs["room_id"]) + test_reactor.advance(0) + room_version = get_awaitable_result(d) + + builder = hs.get_event_builder_factory().for_room_version( + KNOWN_ROOM_VERSIONS[room_version], kwargs + ) + d = hs.get_event_creation_handler().create_new_client_event( + builder, prev_event_ids=prev_event_ids + ) + test_reactor.advance(0) + event, context = get_awaitable_result(d) + + d = hs.get_storage().persistence.persist_event(event, context) + test_reactor.advance(0) + get_awaitable_result(d) + + return event diff --git a/tests/unittest.py b/tests/unittest.py index 27af5228fe..6b6f224e9c 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -32,7 +32,6 @@ from twisted.python.threadpool import ThreadPool from twisted.trial import unittest from synapse.api.constants import EventTypes, Membership -from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.config.homeserver import HomeServerConfig from synapse.config.ratelimiting import FederationRateLimitConfig from synapse.federation.transport import server as federation_server @@ -55,6 +54,7 @@ from tests.server import ( render, setup_test_homeserver, ) +from tests.test_utils import event_injection from tests.test_utils.logging_setup import setup_logging from tests.utils import default_config, setupdb @@ -596,36 +596,14 @@ class HomeserverTestCase(TestCase): """ Inject a membership event into a room. + Deprecated: use event_injection.inject_room_member directly + Args: room: Room ID to inject the event into. user: MXID of the user to inject the membership for. membership: The membership type. """ - event_builder_factory = self.hs.get_event_builder_factory() - event_creation_handler = self.hs.get_event_creation_handler() - - room_version = self.get_success( - self.hs.get_datastore().get_room_version_id(room) - ) - - builder = event_builder_factory.for_room_version( - KNOWN_ROOM_VERSIONS[room_version], - { - "type": EventTypes.Member, - "sender": user, - "state_key": user, - "room_id": room, - "content": {"membership": membership}, - }, - ) - - event, context = self.get_success( - event_creation_handler.create_new_client_event(builder) - ) - - self.get_success( - self.hs.get_storage().persistence.persist_event(event, context) - ) + event_injection.inject_member_event(self.hs, room, user, membership) class FederatingHomeserverTestCase(HomeserverTestCase): diff --git a/tox.ini b/tox.ini index 31011d7436..2630857436 100644 --- a/tox.ini +++ b/tox.ini @@ -204,6 +204,8 @@ commands = mypy \ synapse/storage/database.py \ synapse/streams \ synapse/util/caches/stream_change_cache.py \ + tests/replication/tcp/streams \ + tests/test_utils \ tests/util/test_stream_change_cache.py # To find all folders that pass mypy you run: -- cgit 1.5.1 From 3eab76ad43e1de4074550c468d11318d497ff632 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 29 Apr 2020 14:10:59 +0100 Subject: Don't relay REMOTE_SERVER_UP cmds to same conn. (#7352) For direct TCP connections we need the master to relay REMOTE_SERVER_UP commands to the other connections so that all instances get notified about it. The old implementation just relayed to all connections, assuming that sending back to the original sender of the command was safe. This is not true for redis, where commands sent get echoed back to the sender, which was causing master to effectively infinite loop sending and then re-receiving REMOTE_SERVER_UP commands that it sent. The fix is to ensure that we only relay to *other* connections and not to the connection we received the notification from. Fixes #7334. --- changelog.d/7352.feature | 1 + synapse/notifier.py | 9 ---- synapse/replication/tcp/handler.py | 63 ++++++++++++++++++++------ synapse/replication/tcp/protocol.py | 2 +- synapse/replication/tcp/redis.py | 2 +- tests/replication/tcp/test_remote_server_up.py | 62 +++++++++++++++++++++++++ 6 files changed, 114 insertions(+), 25 deletions(-) create mode 100644 changelog.d/7352.feature create mode 100644 tests/replication/tcp/test_remote_server_up.py (limited to 'synapse/replication/tcp/handler.py') diff --git a/changelog.d/7352.feature b/changelog.d/7352.feature new file mode 100644 index 0000000000..ce6140fdd1 --- /dev/null +++ b/changelog.d/7352.feature @@ -0,0 +1 @@ +Add support for running replication over Redis when using workers. diff --git a/synapse/notifier.py b/synapse/notifier.py index 6132727cbd..88a5a97caf 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -220,12 +220,6 @@ class Notifier(object): """ self.replication_callbacks.append(cb) - def add_remote_server_up_callback(self, cb: Callable[[str], None]): - """Add a callback that will be called when synapse detects a server - has been - """ - self.remote_server_up_callbacks.append(cb) - def on_new_room_event( self, event, room_stream_id, max_room_stream_id, extra_users=[] ): @@ -544,6 +538,3 @@ class Notifier(object): # circular dependencies. if self.federation_sender: self.federation_sender.wake_destination(server) - - for cb in self.remote_server_up_callbacks: - cb(server) diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 3a8c7c7e2d..b8f49a8d0f 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -117,7 +117,6 @@ class ReplicationCommandHandler: self._server_notices_sender = None if self._is_master: self._server_notices_sender = hs.get_server_notices_sender() - self._notifier.add_remote_server_up_callback(self.send_remote_server_up) def start_replication(self, hs): """Helper method to start a replication connection to the remote server @@ -163,7 +162,7 @@ class ReplicationCommandHandler: port = hs.config.worker_replication_port hs.get_reactor().connectTCP(host, port, self._factory) - async def on_REPLICATE(self, cmd: ReplicateCommand): + async def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand): # We only want to announce positions by the writer of the streams. # Currently this is just the master process. if not self._is_master: @@ -173,7 +172,7 @@ class ReplicationCommandHandler: current_token = stream.current_token() self.send_command(PositionCommand(stream_name, current_token)) - async def on_USER_SYNC(self, cmd: UserSyncCommand): + async def on_USER_SYNC(self, conn: AbstractConnection, cmd: UserSyncCommand): user_sync_counter.inc() if self._is_master: @@ -181,17 +180,23 @@ class ReplicationCommandHandler: cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms ) - async def on_CLEAR_USER_SYNC(self, cmd: ClearUserSyncsCommand): + async def on_CLEAR_USER_SYNC( + self, conn: AbstractConnection, cmd: ClearUserSyncsCommand + ): if self._is_master: await self._presence_handler.update_external_syncs_clear(cmd.instance_id) - async def on_FEDERATION_ACK(self, cmd: FederationAckCommand): + async def on_FEDERATION_ACK( + self, conn: AbstractConnection, cmd: FederationAckCommand + ): federation_ack_counter.inc() if self._federation_sender: self._federation_sender.federation_ack(cmd.token) - async def on_REMOVE_PUSHER(self, cmd: RemovePusherCommand): + async def on_REMOVE_PUSHER( + self, conn: AbstractConnection, cmd: RemovePusherCommand + ): remove_pusher_counter.inc() if self._is_master: @@ -201,7 +206,9 @@ class ReplicationCommandHandler: self._notifier.on_new_replication_data() - async def on_INVALIDATE_CACHE(self, cmd: InvalidateCacheCommand): + async def on_INVALIDATE_CACHE( + self, conn: AbstractConnection, cmd: InvalidateCacheCommand + ): invalidate_cache_counter.inc() if self._is_master: @@ -211,7 +218,7 @@ class ReplicationCommandHandler: cmd.cache_func, tuple(cmd.keys) ) - async def on_USER_IP(self, cmd: UserIpCommand): + async def on_USER_IP(self, conn: AbstractConnection, cmd: UserIpCommand): user_ip_cache_counter.inc() if self._is_master: @@ -227,7 +234,7 @@ class ReplicationCommandHandler: if self._server_notices_sender: await self._server_notices_sender.on_user_ip(cmd.user_id) - async def on_RDATA(self, cmd: RdataCommand): + async def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand): stream_name = cmd.stream_name inbound_rdata_count.labels(stream_name).inc() @@ -278,7 +285,7 @@ class ReplicationCommandHandler: logger.debug("Received rdata %s -> %s", stream_name, token) await self._replication_data_handler.on_rdata(stream_name, token, rows) - async def on_POSITION(self, cmd: PositionCommand): + async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand): stream = self._streams.get(cmd.stream_name) if not stream: logger.error("Got POSITION for unknown stream: %s", cmd.stream_name) @@ -332,12 +339,30 @@ class ReplicationCommandHandler: self._streams_connected.add(cmd.stream_name) - async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand): + async def on_REMOTE_SERVER_UP( + self, conn: AbstractConnection, cmd: RemoteServerUpCommand + ): """"Called when get a new REMOTE_SERVER_UP command.""" self._replication_data_handler.on_remote_server_up(cmd.data) - if self._is_master: - self._notifier.notify_remote_server_up(cmd.data) + self._notifier.notify_remote_server_up(cmd.data) + + # We relay to all other connections to ensure every instance gets the + # notification. + # + # When configured to use redis we'll always only have one connection and + # so this is a no-op (all instances will have already received the same + # REMOTE_SERVER_UP command). + # + # For direct TCP connections this will relay to all other connections + # connected to us. When on master this will correctly fan out to all + # other direct TCP clients and on workers there'll only be the one + # connection to master. + # + # (The logic here should also be sound if we have a mix of Redis and + # direct TCP connections so long as there is only one traffic route + # between two instances, but that is not currently supported). + self.send_command(cmd, ignore_conn=conn) def new_connection(self, connection: AbstractConnection): """Called when we have a new connection. @@ -382,11 +407,21 @@ class ReplicationCommandHandler: """ return bool(self._connections) - def send_command(self, cmd: Command): + def send_command( + self, cmd: Command, ignore_conn: Optional[AbstractConnection] = None + ): """Send a command to all connected connections. + + Args: + cmd + ignore_conn: If set don't send command to the given connection. + Used when relaying commands from one connection to all others. """ if self._connections: for connection in self._connections: + if connection == ignore_conn: + continue + try: connection.send_command(cmd) except Exception: diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index e3f64eba8f..4198eece71 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -260,7 +260,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): # Then call out to the handler. cmd_func = getattr(self.command_handler, "on_%s" % (cmd.NAME,), None) if cmd_func: - await cmd_func(cmd) + await cmd_func(self, cmd) handled = True if not handled: diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py index 49b3ed0c5e..617e860f95 100644 --- a/synapse/replication/tcp/redis.py +++ b/synapse/replication/tcp/redis.py @@ -112,7 +112,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): # Then call out to the handler. cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None) if cmd_func: - await cmd_func(cmd) + await cmd_func(self, cmd) handled = True if not handled: diff --git a/tests/replication/tcp/test_remote_server_up.py b/tests/replication/tcp/test_remote_server_up.py new file mode 100644 index 0000000000..d1c15caeb0 --- /dev/null +++ b/tests/replication/tcp/test_remote_server_up.py @@ -0,0 +1,62 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Tuple + +from twisted.internet.interfaces import IProtocol +from twisted.test.proto_helpers import StringTransport + +from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory + +from tests.unittest import HomeserverTestCase + + +class RemoteServerUpTestCase(HomeserverTestCase): + def prepare(self, reactor, clock, hs): + self.factory = ReplicationStreamProtocolFactory(hs) + + def _make_client(self) -> Tuple[IProtocol, StringTransport]: + """Create a new direct TCP replication connection + """ + + proto = self.factory.buildProtocol(("127.0.0.1", 0)) + transport = StringTransport() + proto.makeConnection(transport) + + # We can safely ignore the commands received during connection. + self.pump() + transport.clear() + + return proto, transport + + def test_relay(self): + """Test that Synapse will relay REMOTE_SERVER_UP commands to all + other connections, but not the one that sent it. + """ + + proto1, transport1 = self._make_client() + + # We shouldn't receive an echo. + proto1.dataReceived(b"REMOTE_SERVER_UP example.com\n") + self.pump() + self.assertEqual(transport1.value(), b"") + + # But we should see an echo if we connect another client + proto2, transport2 = self._make_client() + proto1.dataReceived(b"REMOTE_SERVER_UP example.com\n") + + self.pump() + self.assertEqual(transport1.value(), b"") + self.assertEqual(transport2.value(), b"REMOTE_SERVER_UP example.com\n") -- cgit 1.5.1 From 37f6823f5b91f27b9dd8de8fc0e52d5ea889647c Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 29 Apr 2020 16:23:08 +0100 Subject: Add instance name to RDATA/POSITION commands (#7364) This is primarily for allowing us to send those commands from workers, but for now simply allows us to ignore echoed RDATA/POSITION commands that we sent (we get echoes of sent commands when using redis). Currently we log a WARNING on the master process every time we receive an echoed RDATA. --- changelog.d/7364.misc | 1 + docs/tcp_replication.md | 41 +++++++++++++++++++------------- synapse/app/_base.py | 4 ++-- synapse/logging/opentracing.py | 23 ++++++++---------- synapse/replication/tcp/commands.py | 37 +++++++++++++++++++--------- synapse/replication/tcp/handler.py | 17 ++++++++++--- synapse/server.py | 13 ++++++++-- synapse/server.pyi | 2 ++ tests/replication/slave/storage/_base.py | 1 + tests/replication/tcp/test_commands.py | 6 +++-- 10 files changed, 95 insertions(+), 50 deletions(-) create mode 100644 changelog.d/7364.misc (limited to 'synapse/replication/tcp/handler.py') diff --git a/changelog.d/7364.misc b/changelog.d/7364.misc new file mode 100644 index 0000000000..bb5d727cf4 --- /dev/null +++ b/changelog.d/7364.misc @@ -0,0 +1 @@ +Add an `instance_name` to `RDATA` and `POSITION` replication commands. diff --git a/docs/tcp_replication.md b/docs/tcp_replication.md index b922d9cf7e..ab2fffbfe4 100644 --- a/docs/tcp_replication.md +++ b/docs/tcp_replication.md @@ -15,15 +15,17 @@ example flow would be (where '>' indicates master to worker and > SERVER example.com < REPLICATE - > POSITION events 53 - > RDATA events 54 ["$foo1:bar.com", ...] - > RDATA events 55 ["$foo4:bar.com", ...] + > POSITION events master 53 + > RDATA events master 54 ["$foo1:bar.com", ...] + > RDATA events master 55 ["$foo4:bar.com", ...] The example shows the server accepting a new connection and sending its identity with the `SERVER` command, followed by the client server to respond with the position of all streams. The server then periodically sends `RDATA` commands -which have the format `RDATA `, where the format of -`` is defined by the individual streams. +which have the format `RDATA `, where +the format of `` is defined by the individual streams. The +`` is the name of the Synapse process that generated the data +(usually "master"). Error reporting happens by either the client or server sending an ERROR command, and usually the connection will be closed. @@ -52,7 +54,7 @@ The basic structure of the protocol is line based, where the initial word of each line specifies the command. The rest of the line is parsed based on the command. For example, the RDATA command is defined as: - RDATA + RDATA (Note that may contains spaces, but cannot contain newlines.) @@ -136,11 +138,11 @@ the wire: < NAME synapse.app.appservice < PING 1490197665618 < REPLICATE - > POSITION events 1 - > POSITION backfill 1 - > POSITION caches 1 - > RDATA caches 2 ["get_user_by_id",["@01register-user:localhost:8823"],1490197670513] - > RDATA events 14 ["$149019767112vOHxz:localhost:8823", + > POSITION events master 1 + > POSITION backfill master 1 + > POSITION caches master 1 + > RDATA caches master 2 ["get_user_by_id",["@01register-user:localhost:8823"],1490197670513] + > RDATA events master 14 ["$149019767112vOHxz:localhost:8823", "!AFDCvgApUmpdfVjIXm:localhost:8823","m.room.guest_access","",null] < PING 1490197675618 > ERROR server stopping @@ -151,10 +153,10 @@ position without needing to send data with the `RDATA` command. An example of a batched set of `RDATA` is: - > RDATA caches batch ["get_user_by_id",["@test:localhost:8823"],1490197670513] - > RDATA caches batch ["get_user_by_id",["@test2:localhost:8823"],1490197670513] - > RDATA caches batch ["get_user_by_id",["@test3:localhost:8823"],1490197670513] - > RDATA caches 54 ["get_user_by_id",["@test4:localhost:8823"],1490197670513] + > RDATA caches master batch ["get_user_by_id",["@test:localhost:8823"],1490197670513] + > RDATA caches master batch ["get_user_by_id",["@test2:localhost:8823"],1490197670513] + > RDATA caches master batch ["get_user_by_id",["@test3:localhost:8823"],1490197670513] + > RDATA caches master 54 ["get_user_by_id",["@test4:localhost:8823"],1490197670513] In this case the client shouldn't advance their caches token until it sees the the last `RDATA`. @@ -178,6 +180,11 @@ client (C): updates, and if so then fetch them out of band. Sent in response to a REPLICATE command (but can happen at any time). + The POSITION command includes the source of the stream. Currently all streams + are written by a single process (usually "master"). If fetching missing + updates via HTTP API, rather than via the DB, then processes should make the + request to the appropriate process. + #### ERROR (S, C) There was an error @@ -234,12 +241,12 @@ Each individual cache invalidation results in a row being sent down replication, which includes the cache name (the name of the function) and they key to invalidate. For example: - > RDATA caches 550953771 ["get_user_by_id", ["@bob:example.com"], 1550574873251] + > RDATA caches master 550953771 ["get_user_by_id", ["@bob:example.com"], 1550574873251] Alternatively, an entire cache can be invalidated by sending down a `null` instead of the key. For example: - > RDATA caches 550953772 ["get_user_by_id", null, 1550574873252] + > RDATA caches master 550953772 ["get_user_by_id", null, 1550574873252] However, there are times when a number of caches need to be invalidated at the same time with the same key. To reduce traffic we batch those diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 4d84f4595a..628292b890 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -270,7 +270,7 @@ def start(hs, listeners=None): # Start the tracer synapse.logging.opentracing.init_tracer( # type: ignore[attr-defined] # noqa - hs.config + hs ) # It is now safe to start your Synapse. @@ -316,7 +316,7 @@ def setup_sentry(hs): scope.set_tag("matrix_server_name", hs.config.server_name) app = hs.config.worker_app if hs.config.worker_app else "synapse.app.homeserver" - name = hs.config.worker_name if hs.config.worker_name else "master" + name = hs.get_instance_name() scope.set_tag("worker_app", app) scope.set_tag("worker_name", name) diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index 0638cec429..5dddf57008 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -171,7 +171,7 @@ import logging import re import types from functools import wraps -from typing import Dict +from typing import TYPE_CHECKING, Dict from canonicaljson import json @@ -179,6 +179,9 @@ from twisted.internet import defer from synapse.config import ConfigError +if TYPE_CHECKING: + from synapse.server import HomeServer + # Helper class @@ -297,14 +300,11 @@ def _noop_context_manager(*args, **kwargs): # Setup -def init_tracer(config): +def init_tracer(hs: "HomeServer"): """Set the whitelists and initialise the JaegerClient tracer - - Args: - config (HomeserverConfig): The config used by the homeserver """ global opentracing - if not config.opentracer_enabled: + if not hs.config.opentracer_enabled: # We don't have a tracer opentracing = None return @@ -315,18 +315,15 @@ def init_tracer(config): "installed." ) - # Include the worker name - name = config.worker_name if config.worker_name else "master" - # Pull out the jaeger config if it was given. Otherwise set it to something sensible. # See https://github.com/jaegertracing/jaeger-client-python/blob/master/jaeger_client/config.py - set_homeserver_whitelist(config.opentracer_whitelist) + set_homeserver_whitelist(hs.config.opentracer_whitelist) JaegerConfig( - config=config.jaeger_config, - service_name="{} {}".format(config.server_name, name), - scope_manager=LogContextScopeManager(config), + config=hs.config.jaeger_config, + service_name="{} {}".format(hs.config.server_name, hs.get_instance_name()), + scope_manager=LogContextScopeManager(hs.config), ).initialize_tracer() diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index c7880d4b63..f58e384d17 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -95,7 +95,7 @@ class RdataCommand(Command): Format:: - RDATA + RDATA The `` may either be a numeric stream id OR "batch". The latter case is used to support sending multiple updates with the same stream ID. This @@ -105,33 +105,40 @@ class RdataCommand(Command): The client should batch all incoming RDATA with a token of "batch" (per stream_name) until it sees an RDATA with a numeric stream ID. + The `` is the source of the new data (usually "master"). + `` of "batch" maps to the instance variable `token` being None. An example of a batched series of RDATA:: - RDATA presence batch ["@foo:example.com", "online", ...] - RDATA presence batch ["@bar:example.com", "online", ...] - RDATA presence 59 ["@baz:example.com", "online", ...] + RDATA presence master batch ["@foo:example.com", "online", ...] + RDATA presence master batch ["@bar:example.com", "online", ...] + RDATA presence master 59 ["@baz:example.com", "online", ...] """ NAME = "RDATA" - def __init__(self, stream_name, token, row): + def __init__(self, stream_name, instance_name, token, row): self.stream_name = stream_name + self.instance_name = instance_name self.token = token self.row = row @classmethod def from_line(cls, line): - stream_name, token, row_json = line.split(" ", 2) + stream_name, instance_name, token, row_json = line.split(" ", 3) return cls( - stream_name, None if token == "batch" else int(token), json.loads(row_json) + stream_name, + instance_name, + None if token == "batch" else int(token), + json.loads(row_json), ) def to_line(self): return " ".join( ( self.stream_name, + self.instance_name, str(self.token) if self.token is not None else "batch", _json_encoder.encode(self.row), ) @@ -145,23 +152,31 @@ class PositionCommand(Command): """Sent by the server to tell the client the stream postition without needing to send an RDATA. + Format:: + + POSITION + On receipt of a POSITION command clients should check if they have missed any updates, and if so then fetch them out of band. + + The `` is the process that sent the command and is the source + of the stream. """ NAME = "POSITION" - def __init__(self, stream_name, token): + def __init__(self, stream_name, instance_name, token): self.stream_name = stream_name + self.instance_name = instance_name self.token = token @classmethod def from_line(cls, line): - stream_name, token = line.split(" ", 1) - return cls(stream_name, int(token)) + stream_name, instance_name, token = line.split(" ", 2) + return cls(stream_name, instance_name, int(token)) def to_line(self): - return " ".join((self.stream_name, str(self.token))) + return " ".join((self.stream_name, self.instance_name, str(self.token))) class ErrorCommand(_SimpleCommand): diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index b8f49a8d0f..6f7054d5af 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -79,6 +79,7 @@ class ReplicationCommandHandler: self._notifier = hs.get_notifier() self._clock = hs.get_clock() self._instance_id = hs.get_instance_id() + self._instance_name = hs.get_instance_name() # Set of streams that we've caught up with. self._streams_connected = set() # type: Set[str] @@ -156,7 +157,7 @@ class ReplicationCommandHandler: hs.config.redis.redis_host, hs.config.redis.redis_port, self._factory, ) else: - client_name = hs.config.worker_name + client_name = hs.get_instance_name() self._factory = DirectTcpReplicationClientFactory(hs, client_name, self) host = hs.config.worker_replication_host port = hs.config.worker_replication_port @@ -170,7 +171,9 @@ class ReplicationCommandHandler: for stream_name, stream in self._streams.items(): current_token = stream.current_token() - self.send_command(PositionCommand(stream_name, current_token)) + self.send_command( + PositionCommand(stream_name, self._instance_name, current_token) + ) async def on_USER_SYNC(self, conn: AbstractConnection, cmd: UserSyncCommand): user_sync_counter.inc() @@ -235,6 +238,10 @@ class ReplicationCommandHandler: await self._server_notices_sender.on_user_ip(cmd.user_id) async def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand): + if cmd.instance_name == self._instance_name: + # Ignore RDATA that are just our own echoes + return + stream_name = cmd.stream_name inbound_rdata_count.labels(stream_name).inc() @@ -286,6 +293,10 @@ class ReplicationCommandHandler: await self._replication_data_handler.on_rdata(stream_name, token, rows) async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand): + if cmd.instance_name == self._instance_name: + # Ignore POSITION that are just our own echoes + return + stream = self._streams.get(cmd.stream_name) if not stream: logger.error("Got POSITION for unknown stream: %s", cmd.stream_name) @@ -485,7 +496,7 @@ class ReplicationCommandHandler: We need to check if the client is interested in the stream or not """ - self.send_command(RdataCommand(stream_name, token, data)) + self.send_command(RdataCommand(stream_name, self._instance_name, token, data)) UpdateToken = TypeVar("UpdateToken") diff --git a/synapse/server.py b/synapse/server.py index 9d273c980c..bf97a16c09 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -234,7 +234,8 @@ class HomeServer(object): self._listening_services = [] self.start_time = None - self.instance_id = random_string(5) + self._instance_id = random_string(5) + self._instance_name = config.worker_name or "master" self.clock = Clock(reactor) self.distributor = Distributor() @@ -254,7 +255,15 @@ class HomeServer(object): This is used to distinguish running instances in worker-based deployments. """ - return self.instance_id + return self._instance_id + + def get_instance_name(self) -> str: + """A unique name for this synapse process. + + Used to identify the process over replication and in config. Does not + change over restarts. + """ + return self._instance_name def setup(self): logger.info("Setting up.") diff --git a/synapse/server.pyi b/synapse/server.pyi index fc5886f762..18043a2593 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -122,6 +122,8 @@ class HomeServer(object): pass def get_instance_id(self) -> str: pass + def get_instance_name(self) -> str: + pass def get_event_builder_factory(self) -> EventBuilderFactory: pass def get_storage(self) -> synapse.storage.Storage: diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 395c7d0306..1615dfab5e 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -57,6 +57,7 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase): # We now do some gut wrenching so that we have a client that is based # off of the slave store rather than the main store. self.replication_handler = ReplicationCommandHandler(self.hs) + self.replication_handler._instance_name = "worker" self.replication_handler._replication_data_handler = ReplicationDataHandler( self.slaved_store ) diff --git a/tests/replication/tcp/test_commands.py b/tests/replication/tcp/test_commands.py index 3cbcb513cc..7ddfd0a733 100644 --- a/tests/replication/tcp/test_commands.py +++ b/tests/replication/tcp/test_commands.py @@ -28,15 +28,17 @@ class ParseCommandTestCase(TestCase): self.assertIsInstance(cmd, ReplicateCommand) def test_parse_rdata(self): - line = 'RDATA events 6287863 ["ev", ["$eventid", "!roomid", "type", null, null, null]]' + line = 'RDATA events master 6287863 ["ev", ["$eventid", "!roomid", "type", null, null, null]]' cmd = parse_command_from_line(line) self.assertIsInstance(cmd, RdataCommand) self.assertEqual(cmd.stream_name, "events") + self.assertEqual(cmd.instance_name, "master") self.assertEqual(cmd.token, 6287863) def test_parse_rdata_batch(self): - line = 'RDATA presence batch ["@foo:example.com", "online"]' + line = 'RDATA presence master batch ["@foo:example.com", "online"]' cmd = parse_command_from_line(line) self.assertIsInstance(cmd, RdataCommand) self.assertEqual(cmd.stream_name, "presence") + self.assertEqual(cmd.instance_name, "master") self.assertIsNone(cmd.token) -- cgit 1.5.1 From 3085cde577216519d789c8160262831cb2029972 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 1 May 2020 15:21:35 +0100 Subject: Use `stream.current_token()` and remove `stream_positions()` (#7172) We move the processing of typing and federation replication traffic into their handlers so that `Stream.current_token()` points to a valid token. This allows us to remove `get_streams_to_replicate()` and `stream_positions()`. --- changelog.d/7172.misc | 1 + synapse/app/generic_worker.py | 16 ------------ synapse/replication/slave/storage/_base.py | 15 +----------- synapse/replication/slave/storage/account_data.py | 8 ------ synapse/replication/slave/storage/deviceinbox.py | 5 ---- synapse/replication/slave/storage/devices.py | 10 -------- synapse/replication/slave/storage/events.py | 6 ----- synapse/replication/slave/storage/groups.py | 5 ---- synapse/replication/slave/storage/presence.py | 9 ------- synapse/replication/slave/storage/push_rule.py | 5 ---- synapse/replication/slave/storage/pushers.py | 5 ---- synapse/replication/slave/storage/receipts.py | 5 ---- synapse/replication/slave/storage/room.py | 5 ---- synapse/replication/tcp/client.py | 19 +------------- synapse/replication/tcp/handler.py | 10 +------- tests/replication/tcp/streams/_base.py | 30 ++++++++--------------- tests/replication/tcp/streams/test_events.py | 24 ++++++++++++------ tests/replication/tcp/streams/test_receipts.py | 3 --- tests/replication/tcp/streams/test_typing.py | 3 --- 19 files changed, 30 insertions(+), 154 deletions(-) create mode 100644 changelog.d/7172.misc (limited to 'synapse/replication/tcp/handler.py') diff --git a/changelog.d/7172.misc b/changelog.d/7172.misc new file mode 100644 index 0000000000..ffecdf97fe --- /dev/null +++ b/changelog.d/7172.misc @@ -0,0 +1 @@ +Use `stream.current_token()` and remove `stream_positions()`. diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 0ace7b787d..97b9b81237 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -413,12 +413,6 @@ class GenericWorkerTyping(object): # map room IDs to sets of users currently typing self._room_typing = {} - def stream_positions(self): - # We must update this typing token from the response of the previous - # sync. In particular, the stream id may "reset" back to zero/a low - # value which we *must* use for the next replication request. - return {"typing": self._latest_room_serial} - def process_replication_rows(self, token, rows): if self._latest_room_serial > token: # The master has gone backwards. To prevent inconsistent data, just @@ -658,13 +652,6 @@ class GenericWorkerReplicationHandler(ReplicationDataHandler): ) await self.process_and_notify(stream_name, token, rows) - def get_streams_to_replicate(self): - args = super(GenericWorkerReplicationHandler, self).get_streams_to_replicate() - args.update(self.typing_handler.stream_positions()) - if self.send_handler: - args.update(self.send_handler.stream_positions()) - return args - async def process_and_notify(self, stream_name, token, rows): try: if self.send_handler: @@ -799,9 +786,6 @@ class FederationSenderHandler(object): def wake_destination(self, server: str): self.federation_sender.wake_destination(server) - def stream_positions(self): - return {"federation": self.federation_position} - async def process_replication_rows(self, stream_name, token, rows): # The federation stream contains things that we want to send out, e.g. # presence, typing, etc. diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index 751c799d94..5d7c8871a4 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -14,7 +14,7 @@ # limitations under the License. import logging -from typing import Dict, Optional +from typing import Optional import six @@ -49,19 +49,6 @@ class BaseSlavedStore(CacheInvalidationWorkerStore): self.hs = hs - def stream_positions(self) -> Dict[str, int]: - """ - Get the current positions of all the streams this store wants to subscribe to - - Returns: - map from stream name to the most recent update we have for - that stream (ie, the point we want to start replicating from) - """ - pos = {} - if self._cache_id_gen: - pos["caches"] = self._cache_id_gen.get_current_token() - return pos - def get_cache_stream_token(self): if self._cache_id_gen: return self._cache_id_gen.get_current_token() diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py index ebe94909cb..65e54b1c71 100644 --- a/synapse/replication/slave/storage/account_data.py +++ b/synapse/replication/slave/storage/account_data.py @@ -32,14 +32,6 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved def get_max_account_data_stream_id(self): return self._account_data_id_gen.get_current_token() - def stream_positions(self): - result = super(SlavedAccountDataStore, self).stream_positions() - position = self._account_data_id_gen.get_current_token() - result["user_account_data"] = position - result["room_account_data"] = position - result["tag_account_data"] = position - return result - def process_replication_rows(self, stream_name, token, rows): if stream_name == "tag_account_data": self._account_data_id_gen.advance(token) diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py index 0c237c6e0f..c923751e50 100644 --- a/synapse/replication/slave/storage/deviceinbox.py +++ b/synapse/replication/slave/storage/deviceinbox.py @@ -43,11 +43,6 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore): expiry_ms=30 * 60 * 1000, ) - def stream_positions(self): - result = super(SlavedDeviceInboxStore, self).stream_positions() - result["to_device"] = self._device_inbox_id_gen.get_current_token() - return result - def process_replication_rows(self, stream_name, token, rows): if stream_name == "to_device": self._device_inbox_id_gen.advance(token) diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index 23b1650e41..58fb0eaae3 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -48,16 +48,6 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto "DeviceListFederationStreamChangeCache", device_list_max ) - def stream_positions(self): - result = super(SlavedDeviceStore, self).stream_positions() - # The user signature stream uses the same stream ID generator as the - # device list stream, so set them both to the device list ID - # generator's current token. - current_token = self._device_list_id_gen.get_current_token() - result[DeviceListsStream.NAME] = current_token - result[UserSignatureStream.NAME] = current_token - return result - def process_replication_rows(self, stream_name, token, rows): if stream_name == DeviceListsStream.NAME: self._device_list_id_gen.advance(token) diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index e73342c657..15011259df 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -93,12 +93,6 @@ class SlavedEventStore( def get_room_min_stream_ordering(self): return self._backfill_id_gen.get_current_token() - def stream_positions(self): - result = super(SlavedEventStore, self).stream_positions() - result["events"] = self._stream_id_gen.get_current_token() - result["backfill"] = -self._backfill_id_gen.get_current_token() - return result - def process_replication_rows(self, stream_name, token, rows): if stream_name == "events": self._stream_id_gen.advance(token) diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py index 2d4fd08cf5..01bcf0e882 100644 --- a/synapse/replication/slave/storage/groups.py +++ b/synapse/replication/slave/storage/groups.py @@ -37,11 +37,6 @@ class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore): def get_group_stream_token(self): return self._group_updates_id_gen.get_current_token() - def stream_positions(self): - result = super(SlavedGroupServerStore, self).stream_positions() - result["groups"] = self._group_updates_id_gen.get_current_token() - return result - def process_replication_rows(self, stream_name, token, rows): if stream_name == "groups": self._group_updates_id_gen.advance(token) diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py index ad8f0c15a9..fae3125072 100644 --- a/synapse/replication/slave/storage/presence.py +++ b/synapse/replication/slave/storage/presence.py @@ -41,15 +41,6 @@ class SlavedPresenceStore(BaseSlavedStore): def get_current_presence_token(self): return self._presence_id_gen.get_current_token() - def stream_positions(self): - result = super(SlavedPresenceStore, self).stream_positions() - - if self.hs.config.use_presence: - position = self._presence_id_gen.get_current_token() - result["presence"] = position - - return result - def process_replication_rows(self, stream_name, token, rows): if stream_name == "presence": self._presence_id_gen.advance(token) diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py index eebd5a1fb6..6138796da4 100644 --- a/synapse/replication/slave/storage/push_rule.py +++ b/synapse/replication/slave/storage/push_rule.py @@ -37,11 +37,6 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore): def get_max_push_rules_stream_id(self): return self._push_rules_stream_id_gen.get_current_token() - def stream_positions(self): - result = super(SlavedPushRuleStore, self).stream_positions() - result["push_rules"] = self._push_rules_stream_id_gen.get_current_token() - return result - def process_replication_rows(self, stream_name, token, rows): if stream_name == "push_rules": self._push_rules_stream_id_gen.advance(token) diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py index bce8a3d115..67be337945 100644 --- a/synapse/replication/slave/storage/pushers.py +++ b/synapse/replication/slave/storage/pushers.py @@ -28,11 +28,6 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore): db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")] ) - def stream_positions(self): - result = super(SlavedPusherStore, self).stream_positions() - result["pushers"] = self._pushers_id_gen.get_current_token() - return result - def get_pushers_stream_token(self): return self._pushers_id_gen.get_current_token() diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py index d40dc6e1f5..993432edcb 100644 --- a/synapse/replication/slave/storage/receipts.py +++ b/synapse/replication/slave/storage/receipts.py @@ -42,11 +42,6 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore): def get_max_receipt_stream_id(self): return self._receipts_id_gen.get_current_token() - def stream_positions(self): - result = super(SlavedReceiptsStore, self).stream_positions() - result["receipts"] = self._receipts_id_gen.get_current_token() - return result - def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id): self.get_receipts_for_user.invalidate((user_id, receipt_type)) self._get_linearized_receipts_for_room.invalidate_many((room_id,)) diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py index 3a20f45316..10dda8708f 100644 --- a/synapse/replication/slave/storage/room.py +++ b/synapse/replication/slave/storage/room.py @@ -30,11 +30,6 @@ class RoomStore(RoomWorkerStore, BaseSlavedStore): def get_current_public_room_stream_id(self): return self._public_room_id_gen.get_current_token() - def stream_positions(self): - result = super(RoomStore, self).stream_positions() - result["public_rooms"] = self._public_room_id_gen.get_current_token() - return result - def process_replication_rows(self, stream_name, token, rows): if stream_name == "public_rooms": self._public_room_id_gen.advance(token) diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 2d07b8b2d0..5c28fd4ac3 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -16,7 +16,7 @@ """ import logging -from typing import TYPE_CHECKING, Dict +from typing import TYPE_CHECKING from twisted.internet.protocol import ReconnectingClientFactory @@ -100,23 +100,6 @@ class ReplicationDataHandler: """ self.store.process_replication_rows(stream_name, token, rows) - def get_streams_to_replicate(self) -> Dict[str, int]: - """Called when a new connection has been established and we need to - subscribe to streams. - - Returns: - map from stream name to the most recent update we have for - that stream (ie, the point we want to start replicating from) - """ - args = self.store.stream_positions() - user_account_data = args.pop("user_account_data", None) - room_account_data = args.pop("room_account_data", None) - if user_account_data: - args["account_data"] = user_account_data - elif room_account_data: - args["account_data"] = room_account_data - return args - async def on_position(self, stream_name: str, token: int): self.store.process_replication_rows(stream_name, token, []) diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 6f7054d5af..d72f3d0cf9 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -314,15 +314,7 @@ class ReplicationCommandHandler: self._pending_batches.pop(cmd.stream_name, []) # Find where we previously streamed up to. - current_token = self._replication_data_handler.get_streams_to_replicate().get( - cmd.stream_name - ) - if current_token is None: - logger.warning( - "Got POSITION for stream we're not subscribed to: %s", - cmd.stream_name, - ) - return + current_token = stream.current_token() # If the position token matches our current token then we're up to # date and there's nothing to do. Otherwise, fetch all updates diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py index 83e16cfe3d..8c104f8d1d 100644 --- a/tests/replication/tcp/streams/_base.py +++ b/tests/replication/tcp/streams/_base.py @@ -14,7 +14,7 @@ # limitations under the License. import logging -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, List, Optional, Tuple import attr @@ -22,13 +22,15 @@ from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime from twisted.internet.task import LoopingCall from twisted.web.http import HTTPChannel -from synapse.app.generic_worker import GenericWorkerServer +from synapse.app.generic_worker import ( + GenericWorkerReplicationHandler, + GenericWorkerServer, +) from synapse.http.site import SynapseRequest -from synapse.replication.slave.storage._base import BaseSlavedStore -from synapse.replication.tcp.client import ReplicationDataHandler from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory +from synapse.server import HomeServer from synapse.util import Clock from tests import unittest @@ -77,7 +79,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): self._server_transport = None def _build_replication_data_handler(self): - return TestReplicationDataHandler(self.worker_hs.get_datastore()) + return TestReplicationDataHandler(self.worker_hs) def reconnect(self): if self._client_transport: @@ -172,32 +174,20 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): self.assertEqual(request.method, b"GET") -class TestReplicationDataHandler(ReplicationDataHandler): +class TestReplicationDataHandler(GenericWorkerReplicationHandler): """Drop-in for ReplicationDataHandler which just collects RDATA rows""" - def __init__(self, store: BaseSlavedStore): - super().__init__(store) - - # streams to subscribe to: map from stream id to position - self.stream_positions = {} # type: Dict[str, int] + def __init__(self, hs: HomeServer): + super().__init__(hs) # list of received (stream_name, token, row) tuples self.received_rdata_rows = [] # type: List[Tuple[str, int, Any]] - def get_streams_to_replicate(self): - return self.stream_positions - async def on_rdata(self, stream_name, token, rows): await super().on_rdata(stream_name, token, rows) for r in rows: self.received_rdata_rows.append((stream_name, token, r)) - if ( - stream_name in self.stream_positions - and token > self.stream_positions[stream_name] - ): - self.stream_positions[stream_name] = token - @attr.s() class OneShotRequestFactory: diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py index 1fa28084f9..8bd67bb9f1 100644 --- a/tests/replication/tcp/streams/test_events.py +++ b/tests/replication/tcp/streams/test_events.py @@ -43,7 +43,6 @@ class EventsStreamTestCase(BaseStreamTestCase): self.user_tok = self.login("u1", "pass") self.reconnect() - self.test_handler.stream_positions["events"] = 0 self.room_id = self.helper.create_room_as(tok=self.user_tok) self.test_handler.received_rdata_rows.clear() @@ -80,8 +79,12 @@ class EventsStreamTestCase(BaseStreamTestCase): self.reconnect() self.replicate() - # we should have received all the expected rows in the right order - received_rows = self.test_handler.received_rdata_rows + # we should have received all the expected rows in the right order (as + # well as various cache invalidation updates which we ignore) + received_rows = [ + row for row in self.test_handler.received_rdata_rows if row[0] == "events" + ] + for event in events: stream_name, token, row = received_rows.pop(0) self.assertEqual("events", stream_name) @@ -184,7 +187,8 @@ class EventsStreamTestCase(BaseStreamTestCase): self.reconnect() self.replicate() - # now we should have received all the expected rows in the right order. + # we should have received all the expected rows in the right order (as + # well as various cache invalidation updates which we ignore) # # we expect: # @@ -193,7 +197,9 @@ class EventsStreamTestCase(BaseStreamTestCase): # of the states that got reverted. # - two rows for state2 - received_rows = self.test_handler.received_rdata_rows + received_rows = [ + row for row in self.test_handler.received_rdata_rows if row[0] == "events" + ] # first check the first two rows, which should be state1 @@ -334,9 +340,11 @@ class EventsStreamTestCase(BaseStreamTestCase): self.reconnect() self.replicate() - # we should have received all the expected rows in the right order - - received_rows = self.test_handler.received_rdata_rows + # we should have received all the expected rows in the right order (as + # well as various cache invalidation updates which we ignore) + received_rows = [ + row for row in self.test_handler.received_rdata_rows if row[0] == "events" + ] self.assertGreaterEqual(len(received_rows), len(events)) for i in range(NUM_USERS): # for each user, we expect the PL event row, followed by state rows for diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py index c122b8589c..df332ee679 100644 --- a/tests/replication/tcp/streams/test_receipts.py +++ b/tests/replication/tcp/streams/test_receipts.py @@ -31,9 +31,6 @@ class ReceiptsStreamTestCase(BaseStreamTestCase): def test_receipt(self): self.reconnect() - # make the client subscribe to the receipts stream - self.test_handler.stream_positions.update({"receipts": 0}) - # tell the master to send a new receipt self.get_success( self.hs.get_datastore().insert_receipt( diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py index 4d354a9db8..e8d17ca68a 100644 --- a/tests/replication/tcp/streams/test_typing.py +++ b/tests/replication/tcp/streams/test_typing.py @@ -38,9 +38,6 @@ class TypingStreamTestCase(BaseStreamTestCase): self.reconnect() - # make the client subscribe to the typing stream - self.test_handler.stream_positions.update({"typing": 0}) - typing._push_update(member=RoomMember(room_id, USER_ID), typing=True) self.reactor.advance(0) -- cgit 1.5.1 From 0e719f23981b8294df66ba7f38b8c7cc99fad228 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 1 May 2020 17:19:56 +0100 Subject: Thread through instance name to replication client. (#7369) For in memory streams when fetching updates on workers we need to query the source of the stream, which currently is hard coded to be master. This PR threads through the source instance we received via `POSITION` through to the update function in each stream, which can then be passed to the replication client for in memory streams. --- changelog.d/7369.misc | 1 + synapse/app/generic_worker.py | 10 +++--- synapse/replication/http/_base.py | 19 +++++++++- synapse/replication/http/streams.py | 4 ++- synapse/replication/tcp/client.py | 12 ++++--- synapse/replication/tcp/handler.py | 20 ++++++++--- synapse/replication/tcp/streams/_base.py | 50 +++++++++++++++++++------- synapse/replication/tcp/streams/events.py | 10 ++++-- synapse/replication/tcp/streams/federation.py | 4 +-- tests/replication/tcp/streams/_base.py | 4 +-- tests/replication/tcp/streams/test_receipts.py | 4 +-- tests/replication/tcp/streams/test_typing.py | 4 +-- 12 files changed, 101 insertions(+), 41 deletions(-) create mode 100644 changelog.d/7369.misc (limited to 'synapse/replication/tcp/handler.py') diff --git a/changelog.d/7369.misc b/changelog.d/7369.misc new file mode 100644 index 0000000000..060b09c888 --- /dev/null +++ b/changelog.d/7369.misc @@ -0,0 +1 @@ +Thread through instance name to replication client. diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 97b9b81237..667ad20428 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -646,13 +646,11 @@ class GenericWorkerReplicationHandler(ReplicationDataHandler): else: self.send_handler = None - async def on_rdata(self, stream_name, token, rows): - await super(GenericWorkerReplicationHandler, self).on_rdata( - stream_name, token, rows - ) - await self.process_and_notify(stream_name, token, rows) + async def on_rdata(self, stream_name, instance_name, token, rows): + await super().on_rdata(stream_name, instance_name, token, rows) + await self._process_and_notify(stream_name, instance_name, token, rows) - async def process_and_notify(self, stream_name, token, rows): + async def _process_and_notify(self, stream_name, instance_name, token, rows): try: if self.send_handler: await self.send_handler.process_replication_rows( diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index 1be1ccbdf3..f88c80ae84 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -16,6 +16,7 @@ import abc import logging import re +from inspect import signature from typing import Dict, List, Tuple from six import raise_from @@ -60,6 +61,8 @@ class ReplicationEndpoint(object): must call `register` to register the path with the HTTP server. Requests can be sent by calling the client returned by `make_client`. + Requests are sent to master process by default, but can be sent to other + named processes by specifying an `instance_name` keyword argument. Attributes: NAME (str): A name for the endpoint, added to the path as well as used @@ -91,6 +94,16 @@ class ReplicationEndpoint(object): hs, "repl." + self.NAME, timeout_ms=30 * 60 * 1000 ) + # We reserve `instance_name` as a parameter to sending requests, so we + # assert here that sub classes don't try and use the name. + assert ( + "instance_name" not in self.PATH_ARGS + ), "`instance_name` is a reserved paramater name" + assert ( + "instance_name" + not in signature(self.__class__._serialize_payload).parameters + ), "`instance_name` is a reserved paramater name" + assert self.METHOD in ("PUT", "POST", "GET") @abc.abstractmethod @@ -135,7 +148,11 @@ class ReplicationEndpoint(object): @trace(opname="outgoing_replication_request") @defer.inlineCallbacks - def send_request(**kwargs): + def send_request(instance_name="master", **kwargs): + # Currently we only support sending requests to master process. + if instance_name != "master": + raise Exception("Unknown instance") + data = yield cls._serialize_payload(**kwargs) url_args = [ diff --git a/synapse/replication/http/streams.py b/synapse/replication/http/streams.py index f35cebc710..0459f582bf 100644 --- a/synapse/replication/http/streams.py +++ b/synapse/replication/http/streams.py @@ -50,6 +50,8 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint): def __init__(self, hs): super().__init__(hs) + self._instance_name = hs.get_instance_name() + # We pull the streams from the replication steamer (if we try and make # them ourselves we end up in an import loop). self.streams = hs.get_replication_streamer().get_streams() @@ -67,7 +69,7 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint): upto_token = parse_integer(request, "upto_token", required=True) updates, upto_token, limited = await stream.get_updates_since( - from_token, upto_token + self._instance_name, from_token, upto_token ) return ( diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 5c28fd4ac3..3bbf3c3569 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -86,17 +86,19 @@ class ReplicationDataHandler: def __init__(self, store: BaseSlavedStore): self.store = store - async def on_rdata(self, stream_name: str, token: int, rows: list): + async def on_rdata( + self, stream_name: str, instance_name: str, token: int, rows: list + ): """Called to handle a batch of replication data with a given stream token. By default this just pokes the slave store. Can be overridden in subclasses to handle more. Args: - stream_name (str): name of the replication stream for this batch of rows - token (int): stream token for this batch of rows - rows (list): a list of Stream.ROW_TYPE objects as returned by - Stream.parse_row. + stream_name: name of the replication stream for this batch of rows + instance_name: the instance that wrote the rows. + token: stream token for this batch of rows + rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row. """ self.store.process_replication_rows(stream_name, token, rows) diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index d72f3d0cf9..2d1d119c7c 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -278,19 +278,24 @@ class ReplicationCommandHandler: # Check if this is the last of a batch of updates rows = self._pending_batches.pop(stream_name, []) rows.append(row) - await self.on_rdata(stream_name, cmd.token, rows) + await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows) - async def on_rdata(self, stream_name: str, token: int, rows: list): + async def on_rdata( + self, stream_name: str, instance_name: str, token: int, rows: list + ): """Called to handle a batch of replication data with a given stream token. Args: stream_name: name of the replication stream for this batch of rows + instance_name: the instance that wrote the rows. token: stream token for this batch of rows rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row. """ logger.debug("Received rdata %s -> %s", stream_name, token) - await self._replication_data_handler.on_rdata(stream_name, token, rows) + await self._replication_data_handler.on_rdata( + stream_name, instance_name, token, rows + ) async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand): if cmd.instance_name == self._instance_name: @@ -325,7 +330,9 @@ class ReplicationCommandHandler: updates, current_token, missing_updates, - ) = await stream.get_updates_since(current_token, cmd.token) + ) = await stream.get_updates_since( + cmd.instance_name, current_token, cmd.token + ) # TODO: add some tests for this @@ -334,7 +341,10 @@ class ReplicationCommandHandler: for token, rows in _batch_updates(updates): await self.on_rdata( - cmd.stream_name, token, [stream.parse_row(row) for row in rows], + cmd.stream_name, + cmd.instance_name, + token, + [stream.parse_row(row) for row in rows], ) # We've now caught up to position sent to us, notify handler. diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 4af1afd119..b0f87c365b 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -16,7 +16,7 @@ import logging from collections import namedtuple -from typing import Any, Awaitable, Callable, Iterable, List, Optional, Tuple +from typing import Any, Awaitable, Callable, List, Optional, Tuple import attr @@ -53,6 +53,7 @@ StreamUpdateResult = Tuple[List[Tuple[Token, StreamRow]], Token, bool] # # The arguments are: # +# * instance_name: the writer of the stream # * from_token: the previous stream token: the starting point for fetching the # updates # * to_token: the new stream token: the point to get updates up to @@ -62,7 +63,7 @@ StreamUpdateResult = Tuple[List[Tuple[Token, StreamRow]], Token, bool] # If there are more updates available, it should set `limited` in the result, and # it will be called again to get the next batch. # -UpdateFunction = Callable[[Token, Token, int], Awaitable[StreamUpdateResult]] +UpdateFunction = Callable[[str, Token, Token, int], Awaitable[StreamUpdateResult]] class Stream(object): @@ -93,6 +94,7 @@ class Stream(object): def __init__( self, + local_instance_name: str, current_token_function: Callable[[], Token], update_function: UpdateFunction, ): @@ -108,9 +110,11 @@ class Stream(object): stream tokens. See the UpdateFunction type definition for more info. Args: + local_instance_name: The instance name of the current process current_token_function: callback to get the current token, as above update_function: callback go get stream updates, as above """ + self.local_instance_name = local_instance_name self.current_token = current_token_function self.update_function = update_function @@ -135,14 +139,14 @@ class Stream(object): """ current_token = self.current_token() updates, current_token, limited = await self.get_updates_since( - self.last_token, current_token + self.local_instance_name, self.last_token, current_token ) self.last_token = current_token return updates, current_token, limited async def get_updates_since( - self, from_token: Token, upto_token: Token + self, instance_name: str, from_token: Token, upto_token: Token ) -> StreamUpdateResult: """Like get_updates except allows specifying from when we should stream updates @@ -160,19 +164,19 @@ class Stream(object): return [], upto_token, False updates, upto_token, limited = await self.update_function( - from_token, upto_token, _STREAM_UPDATE_TARGET_ROW_COUNT, + instance_name, from_token, upto_token, _STREAM_UPDATE_TARGET_ROW_COUNT, ) return updates, upto_token, limited def db_query_to_update_function( - query_function: Callable[[Token, Token, int], Awaitable[Iterable[tuple]]] + query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]] ) -> UpdateFunction: """Wraps a db query function which returns a list of rows to make it suitable for use as an `update_function` for the Stream class """ - async def update_function(from_token, upto_token, limit): + async def update_function(instance_name, from_token, upto_token, limit): rows = await query_function(from_token, upto_token, limit) updates = [(row[0], row[1:]) for row in rows] limited = False @@ -193,10 +197,13 @@ def make_http_update_function(hs, stream_name: str) -> UpdateFunction: client = ReplicationGetStreamUpdates.make_client(hs) async def update_function( - from_token: int, upto_token: int, limit: int + instance_name: str, from_token: int, upto_token: int, limit: int ) -> StreamUpdateResult: result = await client( - stream_name=stream_name, from_token=from_token, upto_token=upto_token, + instance_name=instance_name, + stream_name=stream_name, + from_token=from_token, + upto_token=upto_token, ) return result["updates"], result["upto_token"], result["limited"] @@ -226,6 +233,7 @@ class BackfillStream(Stream): def __init__(self, hs): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_current_backfill_token, db_query_to_update_function(store.get_all_new_backfill_event_rows), ) @@ -261,7 +269,9 @@ class PresenceStream(Stream): # Query master process update_function = make_http_update_function(hs, self.NAME) - super().__init__(store.get_current_presence_token, update_function) + super().__init__( + hs.get_instance_name(), store.get_current_presence_token, update_function + ) class TypingStream(Stream): @@ -284,7 +294,9 @@ class TypingStream(Stream): # Query master process update_function = make_http_update_function(hs, self.NAME) - super().__init__(typing_handler.get_current_token, update_function) + super().__init__( + hs.get_instance_name(), typing_handler.get_current_token, update_function + ) class ReceiptsStream(Stream): @@ -305,6 +317,7 @@ class ReceiptsStream(Stream): def __init__(self, hs): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_max_receipt_stream_id, db_query_to_update_function(store.get_all_updated_receipts), ) @@ -322,14 +335,16 @@ class PushRulesStream(Stream): def __init__(self, hs): self.store = hs.get_datastore() super(PushRulesStream, self).__init__( - self._current_token, self._update_function + hs.get_instance_name(), self._current_token, self._update_function ) def _current_token(self) -> int: push_rules_token, _ = self.store.get_push_rules_stream_token() return push_rules_token - async def _update_function(self, from_token: Token, to_token: Token, limit: int): + async def _update_function( + self, instance_name: str, from_token: Token, to_token: Token, limit: int + ): rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit) limited = False @@ -356,6 +371,7 @@ class PushersStream(Stream): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_pushers_stream_token, db_query_to_update_function(store.get_all_updated_pushers_rows), ) @@ -387,6 +403,7 @@ class CachesStream(Stream): def __init__(self, hs): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_cache_stream_token, db_query_to_update_function(store.get_all_updated_caches), ) @@ -412,6 +429,7 @@ class PublicRoomsStream(Stream): def __init__(self, hs): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_current_public_room_stream_id, db_query_to_update_function(store.get_all_new_public_rooms), ) @@ -432,6 +450,7 @@ class DeviceListsStream(Stream): def __init__(self, hs): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_device_stream_token, db_query_to_update_function(store.get_all_device_list_changes_for_remotes), ) @@ -449,6 +468,7 @@ class ToDeviceStream(Stream): def __init__(self, hs): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_to_device_stream_token, db_query_to_update_function(store.get_all_new_device_messages), ) @@ -468,6 +488,7 @@ class TagAccountDataStream(Stream): def __init__(self, hs): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_max_account_data_stream_id, db_query_to_update_function(store.get_all_updated_tags), ) @@ -487,6 +508,7 @@ class AccountDataStream(Stream): def __init__(self, hs): self.store = hs.get_datastore() super().__init__( + hs.get_instance_name(), self.store.get_max_account_data_stream_id, db_query_to_update_function(self._update_function), ) @@ -517,6 +539,7 @@ class GroupServerStream(Stream): def __init__(self, hs): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_group_stream_token, db_query_to_update_function(store.get_all_groups_changes), ) @@ -534,6 +557,7 @@ class UserSignatureStream(Stream): def __init__(self, hs): store = hs.get_datastore() super().__init__( + hs.get_instance_name(), store.get_device_stream_token, db_query_to_update_function( store.get_all_user_signature_changes_for_remotes diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index 52df81b1bd..890e75d827 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -118,11 +118,17 @@ class EventsStream(Stream): def __init__(self, hs): self._store = hs.get_datastore() super().__init__( - self._store.get_current_events_token, self._update_function, + hs.get_instance_name(), + self._store.get_current_events_token, + self._update_function, ) async def _update_function( - self, from_token: Token, current_token: Token, target_row_count: int + self, + instance_name: str, + from_token: Token, + current_token: Token, + target_row_count: int, ) -> StreamUpdateResult: # the events stream merges together three separate sources: diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py index 75133d7e40..e8bd52e389 100644 --- a/synapse/replication/tcp/streams/federation.py +++ b/synapse/replication/tcp/streams/federation.py @@ -48,8 +48,8 @@ class FederationStream(Stream): current_token = lambda: 0 update_function = self._stub_update_function - super().__init__(current_token, update_function) + super().__init__(hs.get_instance_name(), current_token, update_function) @staticmethod - async def _stub_update_function(from_token, upto_token, limit): + async def _stub_update_function(instance_name, from_token, upto_token, limit): return [], upto_token, False diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py index 8c104f8d1d..7b56d2028d 100644 --- a/tests/replication/tcp/streams/_base.py +++ b/tests/replication/tcp/streams/_base.py @@ -183,8 +183,8 @@ class TestReplicationDataHandler(GenericWorkerReplicationHandler): # list of received (stream_name, token, row) tuples self.received_rdata_rows = [] # type: List[Tuple[str, int, Any]] - async def on_rdata(self, stream_name, token, rows): - await super().on_rdata(stream_name, token, rows) + async def on_rdata(self, stream_name, instance_name, token, rows): + await super().on_rdata(stream_name, instance_name, token, rows) for r in rows: self.received_rdata_rows.append((stream_name, token, r)) diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py index df332ee679..5853314fd4 100644 --- a/tests/replication/tcp/streams/test_receipts.py +++ b/tests/replication/tcp/streams/test_receipts.py @@ -41,7 +41,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase): # there should be one RDATA command self.test_handler.on_rdata.assert_called_once() - stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0] + stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] self.assertEqual(stream_name, "receipts") self.assertEqual(1, len(rdata_rows)) row = rdata_rows[0] # type: ReceiptsStream.ReceiptsStreamRow @@ -71,7 +71,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase): # We should now have caught up and get the missing data self.test_handler.on_rdata.assert_called_once() - stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0] + stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] self.assertEqual(stream_name, "receipts") self.assertEqual(token, 3) self.assertEqual(1, len(rdata_rows)) diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py index e8d17ca68a..d25a7b194e 100644 --- a/tests/replication/tcp/streams/test_typing.py +++ b/tests/replication/tcp/streams/test_typing.py @@ -47,7 +47,7 @@ class TypingStreamTestCase(BaseStreamTestCase): self.assert_request_is_get_repl_stream_updates(request, "typing") self.test_handler.on_rdata.assert_called_once() - stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0] + stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] self.assertEqual(stream_name, "typing") self.assertEqual(1, len(rdata_rows)) row = rdata_rows[0] # type: TypingStream.TypingStreamRow @@ -74,7 +74,7 @@ class TypingStreamTestCase(BaseStreamTestCase): self.assertEqual(int(request.args[b"from_token"][0]), token) self.test_handler.on_rdata.assert_called_once() - stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0] + stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] self.assertEqual(stream_name, "typing") self.assertEqual(1, len(rdata_rows)) row = rdata_rows[0] -- cgit 1.5.1 From d78265af0c77b7d75bc5fd013515b53b01a2ac23 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 5 May 2020 18:53:38 +0100 Subject: Wait to subscribe before sending REPLICATE --- synapse/replication/tcp/handler.py | 3 ++- synapse/replication/tcp/redis.py | 52 ++++++++++++++++++++++++-------------- 2 files changed, 35 insertions(+), 20 deletions(-) (limited to 'synapse/replication/tcp/handler.py') diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 6f7054d5af..0e8a38fd80 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -99,7 +99,8 @@ class ReplicationCommandHandler: # The factory used to create connections. self._factory = None # type: Optional[ReconnectingClientFactory] - # The currently connected connections. + # The currently connected connections. (The list of places we need to send + # outgoing replication commands to.) self._connections = [] # type: List[AbstractConnection] LaterGauge( diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py index 617e860f95..61dbf4ddbe 100644 --- a/synapse/replication/tcp/redis.py +++ b/synapse/replication/tcp/redis.py @@ -18,7 +18,7 @@ from typing import TYPE_CHECKING import txredisapi -from synapse.logging.context import PreserveLoggingContext +from synapse.logging.context import make_deferred_yieldable from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.tcp.commands import ( Command, @@ -41,8 +41,14 @@ logger = logging.getLogger(__name__) class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): """Connection to redis subscribed to replication stream. - Parses incoming messages from redis into replication commands, and passes - them to `ReplicationCommandHandler` + This class fulfils two functions: + + (a) it implements the twisted Protocol API, where it handles the SUBSCRIBEd redis + connection, parsing *incoming* messages into replication commands, and passing them + to `ReplicationCommandHandler` + + (b) it implements the AbstractConnection API, where it sends *outgoing* commands + onto outbound_redis_connection. Due to the vagaries of `txredisapi` we don't want to have a custom constructor, so instead we expect the defined attributes below to be set @@ -50,8 +56,8 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): Attributes: handler: The command handler to handle incoming commands. - stream_name: The *redis* stream name to subscribe to (not anything to - do with Synapse replication streams). + stream_name: The *redis* stream name to subscribe to and publish from + (not anything to do with Synapse replication streams). outbound_redis_connection: The connection to redis to use to send commands. """ @@ -61,12 +67,22 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): outbound_redis_connection = None # type: txredisapi.RedisProtocol def connectionMade(self): - logger.info("Connected to redis instance") - self.subscribe(self.stream_name) - self.send_command(ReplicateCommand()) - + logger.info("Connected to redis") + run_as_background_process("subscribe-replication", self._send_subscribe) self.handler.new_connection(self) + async def _send_subscribe(self): + # it's important to make sure that we only send the REPLICATE command once we + # have successfully subscribed to the stream - otherwise we might miss the + # POSITION response sent back by the other end. + logger.info("Sending redis SUBSCRIBE for %s", self.stream_name) + await make_deferred_yieldable(self.subscribe(self.stream_name)) + logger.info( + "Successfully subscribed to redis stream, sending REPLICATE command" + ) + await self._async_send_command(ReplicateCommand()) + logger.info("REPLICATE successfully sent") + def messageReceived(self, pattern: str, channel: str, message: str): """Received a message from redis. """ @@ -119,7 +135,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): logger.warning("Unhandled command: %r", cmd) def connectionLost(self, reason): - logger.info("Lost connection to redis instance") + logger.info("Lost connection to redis") self.handler.lost_connection(self) def send_command(self, cmd: Command): @@ -128,6 +144,10 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): Args: cmd (Command) """ + run_as_background_process("send-cmd", self._send_command, cmd) + + async def _async_send_command(self, cmd: Command): + """Encode a replication command and send it over our outbound connection""" string = "%s %s" % (cmd.NAME, cmd.to_line()) if "\n" in string: raise Exception("Unexpected newline in command: %r", string) @@ -138,15 +158,9 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): # remote instances. tcp_outbound_commands_counter.labels(cmd.NAME, "redis").inc() - async def _send(): - with PreserveLoggingContext(): - # Note that we use the other connection as we can't send - # commands using the subscription connection. - await self.outbound_redis_connection.publish( - self.stream_name, encoded_string - ) - - run_as_background_process("send-cmd", _send) + await make_deferred_yieldable( + self.outbound_redis_connection.publish(self.stream_name, encoded_string) + ) class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory): -- cgit 1.5.1 From 5b8023dc7f31dbd682e93bd4d82655f70f5f19f5 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Tue, 5 May 2020 21:07:33 +0200 Subject: Move logs about discarded RDATA to debug (#7421) --- changelog.d/7421.misc | 1 + synapse/replication/tcp/handler.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 changelog.d/7421.misc (limited to 'synapse/replication/tcp/handler.py') diff --git a/changelog.d/7421.misc b/changelog.d/7421.misc new file mode 100644 index 0000000000..676f285377 --- /dev/null +++ b/changelog.d/7421.misc @@ -0,0 +1 @@ +Move catchup of replication streams logic to worker. diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 2d1d119c7c..bf4f1a5949 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -262,7 +262,7 @@ class ReplicationCommandHandler: # `POSITION` command yet, and so we may have missed some rows. # Let's drop the row for now, on the assumption we'll receive a # `POSITION` soon and we'll catch up correctly then. - logger.warning( + logger.debug( "Discarding RDATA for unconnected stream %s -> %s", stream_name, cmd.token, -- cgit 1.5.1 From 7f7eedbebba561771f39c22c78d1dfa1412283a5 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 5 May 2020 19:32:35 +0100 Subject: Wait for a POSITION on the right connection before accepting RDATA ... otherwise we can believe we're up to date when we're not. --- synapse/replication/tcp/handler.py | 55 +++++++++++++++++++++++++------------- synapse/replication/tcp/redis.py | 2 +- 2 files changed, 38 insertions(+), 19 deletions(-) (limited to 'synapse/replication/tcp/handler.py') diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 0e8a38fd80..ffce3db629 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -81,9 +81,6 @@ class ReplicationCommandHandler: self._instance_id = hs.get_instance_id() self._instance_name = hs.get_instance_name() - # Set of streams that we've caught up with. - self._streams_connected = set() # type: Set[str] - self._streams = { stream.NAME: stream(hs) for stream in STREAMS_MAP.values() } # type: Dict[str, Stream] @@ -103,6 +100,9 @@ class ReplicationCommandHandler: # outgoing replication commands to.) self._connections = [] # type: List[AbstractConnection] + # For each connection, the incoming streams that are coming from that connection + self._streams_by_connection = {} # type: Dict[AbstractConnection, Set[str]] + LaterGauge( "synapse_replication_tcp_resource_total_connections", "", @@ -258,12 +258,14 @@ class ReplicationCommandHandler: # 2. so we don't race with getting a POSITION command and fetching # missing RDATA. with await self._position_linearizer.queue(cmd.stream_name): - if stream_name not in self._streams_connected: - # If the stream isn't marked as connected then we haven't seen a - # `POSITION` command yet, and so we may have missed some rows. + # make sure that we've processed a POSITION for this stream *on this + # connection*. (A POSITION on another connection is no good, as there + # is no guarantee that we have seen all the intermediate updates.) + sbc = self._streams_by_connection.get(conn) + if not sbc or stream_name not in sbc: # Let's drop the row for now, on the assumption we'll receive a # `POSITION` soon and we'll catch up correctly then. - logger.warning( + logger.info( "Discarding RDATA for unconnected stream %s -> %s", stream_name, cmd.token, @@ -298,30 +300,33 @@ class ReplicationCommandHandler: # Ignore POSITION that are just our own echoes return - stream = self._streams.get(cmd.stream_name) + logger.info("Handling '%s %s'", cmd.NAME, cmd.to_line()) + + stream_name = cmd.stream_name + stream = self._streams.get(stream_name) if not stream: - logger.error("Got POSITION for unknown stream: %s", cmd.stream_name) + logger.error("Got POSITION for unknown stream: %s", stream_name) return # We protect catching up with a linearizer in case the replication # connection reconnects under us. - with await self._position_linearizer.queue(cmd.stream_name): + with await self._position_linearizer.queue(stream_name): # We're about to go and catch up with the stream, so remove from set # of connected streams. - self._streams_connected.discard(cmd.stream_name) + for streams in self._streams_by_connection.values(): + streams.discard(stream_name) # We clear the pending batches for the stream as the fetching of the # missing updates below will fetch all rows in the batch. - self._pending_batches.pop(cmd.stream_name, []) + self._pending_batches.pop(stream_name, []) # Find where we previously streamed up to. current_token = self._replication_data_handler.get_streams_to_replicate().get( - cmd.stream_name + stream_name ) if current_token is None: logger.warning( - "Got POSITION for stream we're not subscribed to: %s", - cmd.stream_name, + "Got POSITION for stream we're not subscribed to: %s", stream_name, ) return @@ -330,6 +335,12 @@ class ReplicationCommandHandler: # between then and now. missing_updates = cmd.token != current_token while missing_updates: + logger.info( + "Fetching replication rows for '%s' between %i and %i", + stream_name, + current_token, + cmd.token, + ) ( updates, current_token, @@ -343,13 +354,15 @@ class ReplicationCommandHandler: for token, rows in _batch_updates(updates): await self.on_rdata( - cmd.stream_name, token, [stream.parse_row(row) for row in rows], + stream_name, token, [stream.parse_row(row) for row in rows], ) + logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token) + # We've now caught up to position sent to us, notify handler. - await self._replication_data_handler.on_position(cmd.stream_name, cmd.token) + await self._replication_data_handler.on_position(stream_name, cmd.token) - self._streams_connected.add(cmd.stream_name) + self._streams_by_connection.setdefault(conn, set()).add(stream_name) async def on_REMOTE_SERVER_UP( self, conn: AbstractConnection, cmd: RemoteServerUpCommand @@ -407,6 +420,12 @@ class ReplicationCommandHandler: def lost_connection(self, connection: AbstractConnection): """Called when a connection is closed/lost. """ + # we no longer need _streams_by_connection for this connection. + streams = self._streams_by_connection.pop(connection, None) + if streams: + logger.info( + "Lost replication connection; streams now disconnected: %s", streams + ) try: self._connections.remove(connection) except ValueError: diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py index 61dbf4ddbe..b3ae9f401c 100644 --- a/synapse/replication/tcp/redis.py +++ b/synapse/replication/tcp/redis.py @@ -144,7 +144,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): Args: cmd (Command) """ - run_as_background_process("send-cmd", self._send_command, cmd) + run_as_background_process("send-cmd", self._async_send_command, cmd) async def _async_send_command(self, cmd: Command): """Encode a replication command and send it over our outbound connection""" -- cgit 1.5.1 From da9b2db3af55f8508689f5c5d3aab67c01a3edbd Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 7 May 2020 16:46:15 +0100 Subject: Drop support for redis.dbid (#7450) Since we only use pubsub, the dbid is irrelevant. --- changelog.d/7450.feature | 1 + synapse/config/redis.py | 1 - synapse/replication/tcp/handler.py | 4 +--- 3 files changed, 2 insertions(+), 4 deletions(-) create mode 100644 changelog.d/7450.feature (limited to 'synapse/replication/tcp/handler.py') diff --git a/changelog.d/7450.feature b/changelog.d/7450.feature new file mode 100644 index 0000000000..ce6140fdd1 --- /dev/null +++ b/changelog.d/7450.feature @@ -0,0 +1 @@ +Add support for running replication over Redis when using workers. diff --git a/synapse/config/redis.py b/synapse/config/redis.py index 81a27619ec..d5d3ca1c9e 100644 --- a/synapse/config/redis.py +++ b/synapse/config/redis.py @@ -31,5 +31,4 @@ class RedisConfig(Config): self.redis_host = redis_config.get("host", "localhost") self.redis_port = redis_config.get("port", 6379) - self.redis_dbid = redis_config.get("dbid") self.redis_password = redis_config.get("password") diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index b14a3d9fca..4328b38e9d 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -131,10 +131,9 @@ class ReplicationCommandHandler: import txredisapi logger.info( - "Connecting to redis (host=%r port=%r DBID=%r)", + "Connecting to redis (host=%r port=%r)", hs.config.redis_host, hs.config.redis_port, - hs.config.redis_dbid, ) # We need two connections to redis, one for the subscription stream and @@ -145,7 +144,6 @@ class ReplicationCommandHandler: outbound_redis_connection = txredisapi.lazyConnection( host=hs.config.redis_host, port=hs.config.redis_port, - dbid=hs.config.redis_dbid, password=hs.config.redis.redis_password, reconnect=True, ) -- cgit 1.5.1