summary refs log tree commit diff
path: root/synapse/replication/tcp
diff options
context:
space:
mode:
authorBen Banfield-Zanin <benbz@matrix.org>2020-09-15 11:44:49 +0100
committerBen Banfield-Zanin <benbz@matrix.org>2020-09-15 11:44:49 +0100
commit1a7d96aa6ff81638f2ea696fdee2ec44e7bff75a (patch)
tree1839e80f89c53b34ff1b36974305c6cb0c94aab4 /synapse/replication/tcp
parentFix group server for older synapse (diff)
parentClarify changelog. (diff)
downloadsynapse-bbz/info-mainline-1.20.0.tar.xz
Merge remote-tracking branch 'origin/release-v1.20.0' into bbz/info-mainline-1.20.0 github/bbz/info-mainline-1.20.0 bbz/info-mainline-1.20.0
Diffstat (limited to '')
-rw-r--r--synapse/replication/tcp/__init__.py2
-rw-r--r--synapse/replication/tcp/client.py16
-rw-r--r--synapse/replication/tcp/commands.py34
-rw-r--r--synapse/replication/tcp/handler.py413
-rw-r--r--synapse/replication/tcp/protocol.py62
-rw-r--r--synapse/replication/tcp/redis.py61
-rw-r--r--synapse/replication/tcp/resource.py2
-rw-r--r--synapse/replication/tcp/streams/_base.py96
-rw-r--r--synapse/replication/tcp/streams/events.py10
9 files changed, 417 insertions, 279 deletions
diff --git a/synapse/replication/tcp/__init__.py b/synapse/replication/tcp/__init__.py

index 523a1358d4..1b8718b11d 100644 --- a/synapse/replication/tcp/__init__.py +++ b/synapse/replication/tcp/__init__.py
@@ -25,7 +25,7 @@ Structure of the module: * command.py - the definitions of all the valid commands * protocol.py - the TCP protocol classes * resource.py - handles streaming stream updates to replications - * streams/ - the definitons of all the valid streams + * streams/ - the definitions of all the valid streams The general interaction of the classes are: diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index df29732f51..d6ecf5b327 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py
@@ -14,7 +14,6 @@ # limitations under the License. """A replication client for use by synapse workers. """ -import heapq import logging from typing import TYPE_CHECKING, Dict, List, Tuple @@ -24,6 +23,7 @@ from twisted.internet.protocol import ReconnectingClientFactory from synapse.api.constants import EventTypes from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol +from synapse.replication.tcp.streams import TypingStream from synapse.replication.tcp.streams.events import ( EventsStream, EventsStreamEventRow, @@ -33,8 +33,8 @@ from synapse.util.async_helpers import timeout_deferred from synapse.util.metrics import Measure if TYPE_CHECKING: - from synapse.server import HomeServer from synapse.replication.tcp.handler import ReplicationCommandHandler + from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -104,6 +104,7 @@ class ReplicationDataHandler: self._clock = hs.get_clock() self._streams = hs.get_replication_streams() self._instance_name = hs.get_instance_name() + self._typing_handler = hs.get_typing_handler() # Map from stream to list of deferreds waiting for the stream to # arrive at a particular position. The lists are sorted by stream position. @@ -127,6 +128,12 @@ class ReplicationDataHandler: """ self.store.process_replication_rows(stream_name, instance_name, token, rows) + if stream_name == TypingStream.NAME: + self._typing_handler.process_replication_rows(token, rows) + self.notifier.on_new_event( + "typing_key", token, rooms=[row.room_id for row in rows] + ) + if stream_name == EventsStream.NAME: # We shouldn't get multiple rows per token for events stream, so # we don't need to optimise this for multiple rows. @@ -211,9 +218,8 @@ class ReplicationDataHandler: waiting_list = self._streams_to_waiters.setdefault(stream_name, []) - # We insert into the list using heapq as it is more efficient than - # pushing then resorting each time. - heapq.heappush(waiting_list, (position, deferred)) + waiting_list.append((position, deferred)) + waiting_list.sort(key=lambda t: t[0]) # We measure here to get in flight counts and average waiting time. with Measure(self._clock, "repl.wait_for_stream_position"): diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index c04f622816..8cd47770c1 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py
@@ -19,17 +19,9 @@ allowed to be sent by which side. """ import abc import logging -import platform from typing import Tuple, Type -if platform.python_implementation() == "PyPy": - import json - - _json_encoder = json.JSONEncoder() -else: - import simplejson as json # type: ignore[no-redef] # noqa: F821 - - _json_encoder = json.JSONEncoder(namedtuple_as_object=False) # type: ignore[call-arg] # noqa: F821 +from synapse.util import json_decoder, json_encoder logger = logging.getLogger(__name__) @@ -54,7 +46,7 @@ class Command(metaclass=abc.ABCMeta): @abc.abstractmethod def to_line(self) -> str: - """Serialises the comamnd for the wire. Does not include the command + """Serialises the command for the wire. Does not include the command prefix. """ @@ -131,7 +123,7 @@ class RdataCommand(Command): stream_name, instance_name, None if token == "batch" else int(token), - json.loads(row_json), + json_decoder.decode(row_json), ) def to_line(self): @@ -140,7 +132,7 @@ class RdataCommand(Command): self.stream_name, self.instance_name, str(self.token) if self.token is not None else "batch", - _json_encoder.encode(self.row), + json_encoder.encode(self.row), ) ) @@ -149,7 +141,7 @@ class RdataCommand(Command): class PositionCommand(Command): - """Sent by the server to tell the client the stream postition without + """Sent by the server to tell the client the stream position without needing to send an RDATA. Format:: @@ -188,7 +180,7 @@ class ErrorCommand(_SimpleCommand): class PingCommand(_SimpleCommand): - """Sent by either side as a keep alive. The data is arbitary (often timestamp) + """Sent by either side as a keep alive. The data is arbitrary (often timestamp) """ NAME = "PING" @@ -300,20 +292,22 @@ class FederationAckCommand(Command): Format:: - FEDERATION_ACK <token> + FEDERATION_ACK <instance_name> <token> """ NAME = "FEDERATION_ACK" - def __init__(self, token): + def __init__(self, instance_name, token): + self.instance_name = instance_name self.token = token @classmethod def from_line(cls, line): - return cls(int(line)) + instance_name, token = line.split(" ") + return cls(instance_name, int(token)) def to_line(self): - return str(self.token) + return "%s %s" % (self.instance_name, self.token) class RemovePusherCommand(Command): @@ -363,7 +357,7 @@ class UserIpCommand(Command): def from_line(cls, line): user_id, jsn = line.split(" ", 1) - access_token, ip, user_agent, device_id, last_seen = json.loads(jsn) + access_token, ip, user_agent, device_id, last_seen = json_decoder.decode(jsn) return cls(user_id, access_token, ip, user_agent, device_id, last_seen) @@ -371,7 +365,7 @@ class UserIpCommand(Command): return ( self.user_id + " " - + _json_encoder.encode( + + json_encoder.encode( ( self.access_token, self.ip, diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index cbcf46f3ae..1c303f3a46 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py
@@ -13,15 +13,28 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging -from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar +from typing import ( + Any, + Awaitable, + Dict, + Iterable, + Iterator, + List, + Optional, + Set, + Tuple, + TypeVar, + Union, +) from prometheus_client import Counter +from typing_extensions import Deque from twisted.internet.protocol import ReconnectingClientFactory from synapse.metrics import LaterGauge +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.tcp.client import DirectTcpReplicationClientFactory from synapse.replication.tcp.commands import ( ClearUserSyncsCommand, @@ -43,8 +56,8 @@ from synapse.replication.tcp.streams import ( EventsStream, FederationStream, Stream, + TypingStream, ) -from synapse.util.async_helpers import Linearizer logger = logging.getLogger(__name__) @@ -56,12 +69,16 @@ inbound_rdata_count = Counter( user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "") federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "") remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "") -invalidate_cache_counter = Counter( - "synapse_replication_tcp_resource_invalidate_cache", "" -) + user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "") +# the type of the entries in _command_queues_by_stream +_StreamCommandQueue = Deque[ + Tuple[Union[RdataCommand, PositionCommand], AbstractConnection] +] + + class ReplicationCommandHandler: """Handles incoming commands from replication as well as sending commands back out to connections. @@ -97,6 +114,14 @@ class ReplicationCommandHandler: continue + if isinstance(stream, TypingStream): + # Only add TypingStream as a source on the instance in charge of + # typing. + if hs.config.worker.writers.typing == hs.get_instance_name(): + self._streams_to_replicate.append(stream) + + continue + # Only add any other streams if we're on master. if hs.config.worker_app is not None: continue @@ -108,12 +133,8 @@ class ReplicationCommandHandler: self._streams_to_replicate.append(stream) - self._position_linearizer = Linearizer( - "replication_position", clock=self._clock - ) - - # Map of stream to batched updates. See RdataCommand for info on how - # batching works. + # Map of stream name to batched updates. See RdataCommand for info on + # how batching works. self._pending_batches = {} # type: Dict[str, List[Any]] # The factory used to create connections. @@ -123,9 +144,6 @@ class ReplicationCommandHandler: # outgoing replication commands to.) self._connections = [] # type: List[AbstractConnection] - # For each connection, the incoming streams that are coming from that connection - self._streams_by_connection = {} # type: Dict[AbstractConnection, Set[str]] - LaterGauge( "synapse_replication_tcp_resource_total_connections", "", @@ -133,6 +151,32 @@ class ReplicationCommandHandler: lambda: len(self._connections), ) + # When POSITION or RDATA commands arrive, we stick them in a queue and process + # them in order in a separate background process. + + # the streams which are currently being processed by _unsafe_process_queue + self._processing_streams = set() # type: Set[str] + + # for each stream, a queue of commands that are awaiting processing, and the + # connection that they arrived on. + self._command_queues_by_stream = { + stream_name: _StreamCommandQueue() for stream_name in self._streams + } + + # For each connection, the incoming stream names that have received a POSITION + # from that connection. + self._streams_by_connection = {} # type: Dict[AbstractConnection, Set[str]] + + LaterGauge( + "synapse_replication_tcp_command_queue", + "Number of inbound RDATA/POSITION commands queued for processing", + ["stream_name"], + lambda: { + (stream_name,): len(queue) + for stream_name, queue in self._command_queues_by_stream.items() + }, + ) + self._is_master = hs.config.worker_app is None self._federation_sender = None @@ -143,15 +187,75 @@ class ReplicationCommandHandler: if self._is_master: self._server_notices_sender = hs.get_server_notices_sender() + def _add_command_to_stream_queue( + self, conn: AbstractConnection, cmd: Union[RdataCommand, PositionCommand] + ) -> None: + """Queue the given received command for processing + + Adds the given command to the per-stream queue, and processes the queue if + necessary + """ + stream_name = cmd.stream_name + queue = self._command_queues_by_stream.get(stream_name) + if queue is None: + logger.error("Got %s for unknown stream: %s", cmd.NAME, stream_name) + return + + queue.append((cmd, conn)) + + # if we're already processing this stream, there's nothing more to do: + # the new entry on the queue will get picked up in due course + if stream_name in self._processing_streams: + return + + # fire off a background process to start processing the queue. + run_as_background_process( + "process-replication-data", self._unsafe_process_queue, stream_name + ) + + async def _unsafe_process_queue(self, stream_name: str): + """Processes the command queue for the given stream, until it is empty + + Does not check if there is already a thread processing the queue, hence "unsafe" + """ + assert stream_name not in self._processing_streams + + self._processing_streams.add(stream_name) + try: + queue = self._command_queues_by_stream.get(stream_name) + while queue: + cmd, conn = queue.popleft() + try: + await self._process_command(cmd, conn, stream_name) + except Exception: + logger.exception("Failed to handle command %s", cmd) + finally: + self._processing_streams.discard(stream_name) + + async def _process_command( + self, + cmd: Union[PositionCommand, RdataCommand], + conn: AbstractConnection, + stream_name: str, + ) -> None: + if isinstance(cmd, PositionCommand): + await self._process_position(stream_name, conn, cmd) + elif isinstance(cmd, RdataCommand): + await self._process_rdata(stream_name, conn, cmd) + else: + # This shouldn't be possible + raise Exception("Unrecognised command %s in stream queue", cmd.NAME) + def start_replication(self, hs): """Helper method to start a replication connection to the remote server using TCP. """ if hs.config.redis.redis_enabled: + import txredisapi + from synapse.replication.tcp.redis import ( RedisDirectTcpReplicationClientFactory, ) - import txredisapi logger.info( "Connecting to redis (host=%r port=%r)", @@ -198,7 +302,7 @@ class ReplicationCommandHandler: """ return self._streams_to_replicate - async def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand): + def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand): self.send_positions_to_connection(conn) def send_positions_to_connection(self, conn: AbstractConnection): @@ -217,57 +321,73 @@ class ReplicationCommandHandler: ) ) - async def on_USER_SYNC(self, conn: AbstractConnection, cmd: UserSyncCommand): + def on_USER_SYNC( + self, conn: AbstractConnection, cmd: UserSyncCommand + ) -> Optional[Awaitable[None]]: user_sync_counter.inc() if self._is_master: - await self._presence_handler.update_external_syncs_row( + return self._presence_handler.update_external_syncs_row( cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms ) + else: + return None - async def on_CLEAR_USER_SYNC( + def on_CLEAR_USER_SYNC( self, conn: AbstractConnection, cmd: ClearUserSyncsCommand - ): + ) -> Optional[Awaitable[None]]: if self._is_master: - await self._presence_handler.update_external_syncs_clear(cmd.instance_id) + return self._presence_handler.update_external_syncs_clear(cmd.instance_id) + else: + return None - async def on_FEDERATION_ACK( - self, conn: AbstractConnection, cmd: FederationAckCommand - ): + def on_FEDERATION_ACK(self, conn: AbstractConnection, cmd: FederationAckCommand): federation_ack_counter.inc() if self._federation_sender: - self._federation_sender.federation_ack(cmd.token) + self._federation_sender.federation_ack(cmd.instance_name, cmd.token) - async def on_REMOVE_PUSHER( + def on_REMOVE_PUSHER( self, conn: AbstractConnection, cmd: RemovePusherCommand - ): + ) -> Optional[Awaitable[None]]: remove_pusher_counter.inc() if self._is_master: - await self._store.delete_pusher_by_app_id_pushkey_user_id( - app_id=cmd.app_id, pushkey=cmd.push_key, user_id=cmd.user_id - ) + return self._handle_remove_pusher(cmd) + else: + return None + + async def _handle_remove_pusher(self, cmd: RemovePusherCommand): + await self._store.delete_pusher_by_app_id_pushkey_user_id( + app_id=cmd.app_id, pushkey=cmd.push_key, user_id=cmd.user_id + ) - self._notifier.on_new_replication_data() + self._notifier.on_new_replication_data() - async def on_USER_IP(self, conn: AbstractConnection, cmd: UserIpCommand): + def on_USER_IP( + self, conn: AbstractConnection, cmd: UserIpCommand + ) -> Optional[Awaitable[None]]: user_ip_cache_counter.inc() if self._is_master: - await self._store.insert_client_ip( - cmd.user_id, - cmd.access_token, - cmd.ip, - cmd.user_agent, - cmd.device_id, - cmd.last_seen, - ) + return self._handle_user_ip(cmd) + else: + return None + + async def _handle_user_ip(self, cmd: UserIpCommand): + await self._store.insert_client_ip( + cmd.user_id, + cmd.access_token, + cmd.ip, + cmd.user_agent, + cmd.device_id, + cmd.last_seen, + ) - if self._server_notices_sender: - await self._server_notices_sender.on_user_ip(cmd.user_id) + assert self._server_notices_sender is not None + await self._server_notices_sender.on_user_ip(cmd.user_id) - async def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand): + def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand): if cmd.instance_name == self._instance_name: # Ignore RDATA that are just our own echoes return @@ -275,42 +395,71 @@ class ReplicationCommandHandler: stream_name = cmd.stream_name inbound_rdata_count.labels(stream_name).inc() - try: - row = STREAMS_MAP[stream_name].parse_row(cmd.row) - except Exception: - logger.exception("Failed to parse RDATA: %r %r", stream_name, cmd.row) - raise - - # We linearize here for two reasons: + # We put the received command into a queue here for two reasons: # 1. so we don't try and concurrently handle multiple rows for the # same stream, and # 2. so we don't race with getting a POSITION command and fetching # missing RDATA. - with await self._position_linearizer.queue(cmd.stream_name): - # make sure that we've processed a POSITION for this stream *on this - # connection*. (A POSITION on another connection is no good, as there - # is no guarantee that we have seen all the intermediate updates.) - sbc = self._streams_by_connection.get(conn) - if not sbc or stream_name not in sbc: - # Let's drop the row for now, on the assumption we'll receive a - # `POSITION` soon and we'll catch up correctly then. - logger.debug( - "Discarding RDATA for unconnected stream %s -> %s", - stream_name, - cmd.token, - ) - return - - if cmd.token is None: - # I.e. this is part of a batch of updates for this stream (in - # which case batch until we get an update for the stream with a non - # None token). - self._pending_batches.setdefault(stream_name, []).append(row) - else: - # Check if this is the last of a batch of updates - rows = self._pending_batches.pop(stream_name, []) - rows.append(row) - await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows) + + self._add_command_to_stream_queue(conn, cmd) + + async def _process_rdata( + self, stream_name: str, conn: AbstractConnection, cmd: RdataCommand + ) -> None: + """Process an RDATA command + + Called after the command has been popped off the queue of inbound commands + """ + try: + row = STREAMS_MAP[stream_name].parse_row(cmd.row) + except Exception as e: + raise Exception( + "Failed to parse RDATA: %r %r" % (stream_name, cmd.row) + ) from e + + # make sure that we've processed a POSITION for this stream *on this + # connection*. (A POSITION on another connection is no good, as there + # is no guarantee that we have seen all the intermediate updates.) + sbc = self._streams_by_connection.get(conn) + if not sbc or stream_name not in sbc: + # Let's drop the row for now, on the assumption we'll receive a + # `POSITION` soon and we'll catch up correctly then. + logger.debug( + "Discarding RDATA for unconnected stream %s -> %s", + stream_name, + cmd.token, + ) + return + + if cmd.token is None: + # I.e. this is part of a batch of updates for this stream (in + # which case batch until we get an update for the stream with a non + # None token). + self._pending_batches.setdefault(stream_name, []).append(row) + return + + # Check if this is the last of a batch of updates + rows = self._pending_batches.pop(stream_name, []) + rows.append(row) + + stream = self._streams[stream_name] + + # Find where we previously streamed up to. + current_token = stream.current_token(cmd.instance_name) + + # Discard this data if this token is earlier than the current + # position. Note that streams can be reset (in which case you + # expect an earlier token), but that must be preceded by a + # POSITION command. + if cmd.token <= current_token: + logger.debug( + "Discarding RDATA from stream %s at position %s before previous position %s", + stream_name, + cmd.token, + current_token, + ) + else: + await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows) async def on_rdata( self, stream_name: str, instance_name: str, token: int, rows: list @@ -329,78 +478,74 @@ class ReplicationCommandHandler: stream_name, instance_name, token, rows ) - async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand): + def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand): if cmd.instance_name == self._instance_name: # Ignore POSITION that are just our own echoes return logger.info("Handling '%s %s'", cmd.NAME, cmd.to_line()) - stream_name = cmd.stream_name - stream = self._streams.get(stream_name) - if not stream: - logger.error("Got POSITION for unknown stream: %s", stream_name) - return + self._add_command_to_stream_queue(conn, cmd) - # We protect catching up with a linearizer in case the replication - # connection reconnects under us. - with await self._position_linearizer.queue(stream_name): - # We're about to go and catch up with the stream, so remove from set - # of connected streams. - for streams in self._streams_by_connection.values(): - streams.discard(stream_name) - - # We clear the pending batches for the stream as the fetching of the - # missing updates below will fetch all rows in the batch. - self._pending_batches.pop(stream_name, []) - - # Find where we previously streamed up to. - current_token = stream.current_token(cmd.instance_name) - - # If the position token matches our current token then we're up to - # date and there's nothing to do. Otherwise, fetch all updates - # between then and now. - missing_updates = cmd.token != current_token - while missing_updates: - logger.info( - "Fetching replication rows for '%s' between %i and %i", - stream_name, - current_token, - cmd.token, - ) - ( - updates, - current_token, - missing_updates, - ) = await stream.get_updates_since( - cmd.instance_name, current_token, cmd.token - ) + async def _process_position( + self, stream_name: str, conn: AbstractConnection, cmd: PositionCommand + ) -> None: + """Process a POSITION command - # TODO: add some tests for this + Called after the command has been popped off the queue of inbound commands + """ + stream = self._streams[stream_name] - # Some streams return multiple rows with the same stream IDs, - # which need to be processed in batches. + # We're about to go and catch up with the stream, so remove from set + # of connected streams. + for streams in self._streams_by_connection.values(): + streams.discard(stream_name) - for token, rows in _batch_updates(updates): - await self.on_rdata( - stream_name, - cmd.instance_name, - token, - [stream.parse_row(row) for row in rows], - ) + # We clear the pending batches for the stream as the fetching of the + # missing updates below will fetch all rows in the batch. + self._pending_batches.pop(stream_name, []) - logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token) + # Find where we previously streamed up to. + current_token = stream.current_token(cmd.instance_name) - # We've now caught up to position sent to us, notify handler. - await self._replication_data_handler.on_position( - cmd.stream_name, cmd.instance_name, cmd.token + # If the position token matches our current token then we're up to + # date and there's nothing to do. Otherwise, fetch all updates + # between then and now. + missing_updates = cmd.token != current_token + while missing_updates: + logger.info( + "Fetching replication rows for '%s' between %i and %i", + stream_name, + current_token, + cmd.token, + ) + (updates, current_token, missing_updates) = await stream.get_updates_since( + cmd.instance_name, current_token, cmd.token ) - self._streams_by_connection.setdefault(conn, set()).add(stream_name) + # TODO: add some tests for this - async def on_REMOTE_SERVER_UP( - self, conn: AbstractConnection, cmd: RemoteServerUpCommand - ): + # Some streams return multiple rows with the same stream IDs, + # which need to be processed in batches. + + for token, rows in _batch_updates(updates): + await self.on_rdata( + stream_name, + cmd.instance_name, + token, + [stream.parse_row(row) for row in rows], + ) + + logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token) + + # We've now caught up to position sent to us, notify handler. + await self._replication_data_handler.on_position( + cmd.stream_name, cmd.instance_name, cmd.token + ) + + self._streams_by_connection.setdefault(conn, set()).add(stream_name) + + def on_REMOTE_SERVER_UP(self, conn: AbstractConnection, cmd: RemoteServerUpCommand): """"Called when get a new REMOTE_SERVER_UP command.""" self._replication_data_handler.on_remote_server_up(cmd.data) @@ -505,7 +650,7 @@ class ReplicationCommandHandler: """Ack data for the federation stream. This allows the master to drop data stored purely in memory. """ - self.send_command(FederationAckCommand(token)) + self.send_command(FederationAckCommand(self._instance_name, token)) def send_user_sync( self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index 4198eece71..0b0d204e64 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py
@@ -50,6 +50,7 @@ import abc import fcntl import logging import struct +from inspect import isawaitable from typing import TYPE_CHECKING, List from prometheus_client import Counter @@ -57,8 +58,12 @@ from prometheus_client import Counter from twisted.protocols.basic import LineOnlyReceiver from twisted.python.failure import Failure +from synapse.logging.context import PreserveLoggingContext from synapse.metrics import LaterGauge -from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.metrics.background_process_metrics import ( + BackgroundProcessLoggingContext, + run_as_background_process, +) from synapse.replication.tcp.commands import ( VALID_CLIENT_COMMANDS, VALID_SERVER_COMMANDS, @@ -108,7 +113,7 @@ PING_TIMEOUT_MULTIPLIER = 5 PING_TIMEOUT_MS = PING_TIME * PING_TIMEOUT_MULTIPLIER -class ConnectionStates(object): +class ConnectionStates: CONNECTING = "connecting" ESTABLISHED = "established" PAUSED = "paused" @@ -124,6 +129,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): On receiving a new command it calls `on_<COMMAND_NAME>` with the parsed command before delegating to `ReplicationCommandHandler.on_<COMMAND_NAME>`. + `ReplicationCommandHandler.on_<COMMAND_NAME>` can optionally return a coroutine; + if so, that will get run as a background process. It also sends `PING` periodically, and correctly times out remote connections (if they send a `PING` command) @@ -160,6 +167,12 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): # The LoopingCall for sending pings. self._send_ping_loop = None + # a logcontext which we use for processing incoming commands. We declare it as a + # background process so that the CPU stats get reported to prometheus. + ctx_name = "replication-conn-%s" % self.conn_id + self._logging_context = BackgroundProcessLoggingContext(ctx_name) + self._logging_context.request = ctx_name + def connectionMade(self): logger.info("[%s] Connection established", self.id()) @@ -210,6 +223,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): def lineReceived(self, line: bytes): """Called when we've received a line """ + with PreserveLoggingContext(self._logging_context): + self._parse_and_dispatch_line(line) + + def _parse_and_dispatch_line(self, line: bytes): if line.strip() == "": # Ignore blank lines return @@ -232,18 +249,17 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): tcp_inbound_commands_counter.labels(cmd.NAME, self.name).inc() - # Now lets try and call on_<CMD_NAME> function - run_as_background_process( - "replication-" + cmd.get_logcontext_id(), self.handle_command, cmd - ) + self.handle_command(cmd) - async def handle_command(self, cmd: Command): + def handle_command(self, cmd: Command) -> None: """Handle a command we have received over the replication stream. First calls `self.on_<COMMAND>` if it exists, then calls - `self.command_handler.on_<COMMAND>` if it exists. This allows for - protocol level handling of commands (e.g. PINGs), before delegating to - the handler. + `self.command_handler.on_<COMMAND>` if it exists (which can optionally + return an Awaitable). + + This allows for protocol level handling of commands (e.g. PINGs), before + delegating to the handler. Args: cmd: received command @@ -254,13 +270,22 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): # specific handling. cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None) if cmd_func: - await cmd_func(cmd) + cmd_func(cmd) handled = True # Then call out to the handler. cmd_func = getattr(self.command_handler, "on_%s" % (cmd.NAME,), None) if cmd_func: - await cmd_func(self, cmd) + res = cmd_func(self, cmd) + + # the handler might be a coroutine: fire it off as a background process + # if so. + + if isawaitable(res): + run_as_background_process( + "replication-" + cmd.get_logcontext_id(), lambda: res + ) + handled = True if not handled: @@ -317,7 +342,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): def _queue_command(self, cmd): """Queue the command until the connection is ready to write to again. """ - logger.debug("[%s] Queing as conn %r, cmd: %r", self.id(), self.state, cmd) + logger.debug("[%s] Queueing as conn %r, cmd: %r", self.id(), self.state, cmd) self.pending_commands.append(cmd) if len(self.pending_commands) > self.max_line_buffer: @@ -336,10 +361,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): for cmd in pending: self.send_command(cmd) - async def on_PING(self, line): + def on_PING(self, line): self.received_ping = True - async def on_ERROR(self, cmd): + def on_ERROR(self, cmd): logger.error("[%s] Remote reported error: %r", self.id(), cmd.data) def pauseProducing(self): @@ -397,6 +422,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): if self.transport: self.transport.unregisterProducer() + # mark the logging context as finished + self._logging_context.__exit__(None, None, None) + def __str__(self): addr = None if self.transport: @@ -431,7 +459,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): self.send_command(ServerCommand(self.server_name)) super().connectionMade() - async def on_NAME(self, cmd): + def on_NAME(self, cmd): logger.info("[%s] Renamed to %r", self.id(), cmd.data) self.name = cmd.data @@ -460,7 +488,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): # Once we've connected subscribe to the necessary streams self.replicate() - async def on_SERVER(self, cmd): + def on_SERVER(self, cmd): if cmd.data != self.server_name: logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data) self.send_error("Wrong remote") diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index e776b63183..f225e533de 100644 --- a/synapse/replication/tcp/redis.py +++ b/synapse/replication/tcp/redis.py
@@ -14,12 +14,16 @@ # limitations under the License. import logging +from inspect import isawaitable from typing import TYPE_CHECKING import txredisapi -from synapse.logging.context import make_deferred_yieldable -from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable +from synapse.metrics.background_process_metrics import ( + BackgroundProcessLoggingContext, + run_as_background_process, +) from synapse.replication.tcp.commands import ( Command, ReplicateCommand, @@ -66,6 +70,15 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): stream_name = None # type: str outbound_redis_connection = None # type: txredisapi.RedisProtocol + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # a logcontext which we use for processing incoming commands. We declare it as a + # background process so that the CPU stats get reported to prometheus. + self._logging_context = BackgroundProcessLoggingContext( + "replication_command_handler" + ) + def connectionMade(self): logger.info("Connected to redis") super().connectionMade() @@ -92,7 +105,10 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): def messageReceived(self, pattern: str, channel: str, message: str): """Received a message from redis. """ + with PreserveLoggingContext(self._logging_context): + self._parse_and_dispatch_message(message) + def _parse_and_dispatch_message(self, message: str): if message.strip() == "": # Ignore blank lines return @@ -109,42 +125,41 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): # remote instances. tcp_inbound_commands_counter.labels(cmd.NAME, "redis").inc() - # Now lets try and call on_<CMD_NAME> function - run_as_background_process( - "replication-" + cmd.get_logcontext_id(), self.handle_command, cmd - ) + self.handle_command(cmd) - async def handle_command(self, cmd: Command): + def handle_command(self, cmd: Command) -> None: """Handle a command we have received over the replication stream. - By default delegates to on_<COMMAND>, which should return an awaitable. + Delegates to `self.handler.on_<COMMAND>` (which can optionally return an + Awaitable). Args: cmd: received command """ - handled = False - - # First call any command handlers on this instance. These are for redis - # specific handling. - cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None) - if cmd_func: - await cmd_func(cmd) - handled = True - # Then call out to the handler. cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None) - if cmd_func: - await cmd_func(self, cmd) - handled = True - - if not handled: + if not cmd_func: logger.warning("Unhandled command: %r", cmd) + return + + res = cmd_func(self, cmd) + + # the handler might be a coroutine: fire it off as a background process + # if so. + + if isawaitable(res): + run_as_background_process( + "replication-" + cmd.get_logcontext_id(), lambda: res + ) def connectionLost(self, reason): logger.info("Lost connection to redis") super().connectionLost(reason) self.handler.lost_connection(self) + # mark the logging context as finished + self._logging_context.__exit__(None, None, None) + def send_command(self, cmd: Command): """Send a command if connection has been established. @@ -177,7 +192,7 @@ class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory): Args: hs outbound_redis_connection: A connection to redis that will be used to - send outbound commands (this is seperate to the redis connection + send outbound commands (this is separate to the redis connection used to subscribe). """ diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 41569305df..04d894fb3d 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py
@@ -58,7 +58,7 @@ class ReplicationStreamProtocolFactory(Factory): ) -class ReplicationStreamer(object): +class ReplicationStreamer: """Handles replication connections. This needs to be poked when new replication data may be available. When new diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 4acefc8a96..682d47f402 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py
@@ -79,7 +79,7 @@ StreamUpdateResult = Tuple[List[Tuple[Token, StreamRow]], Token, bool] UpdateFunction = Callable[[str, Token, Token, int], Awaitable[StreamUpdateResult]] -class Stream(object): +class Stream: """Base class for the streams. Provides a `get_updates()` function that returns new updates since the last @@ -198,26 +198,6 @@ def current_token_without_instance( return lambda instance_name: current_token() -def db_query_to_update_function( - query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]] -) -> UpdateFunction: - """Wraps a db query function which returns a list of rows to make it - suitable for use as an `update_function` for the Stream class - """ - - async def update_function(instance_name, from_token, upto_token, limit): - rows = await query_function(from_token, upto_token, limit) - updates = [(row[0], row[1:]) for row in rows] - limited = False - if len(updates) >= limit: - upto_token = updates[-1][0] - limited = True - - return updates, upto_token, limited - - return update_function - - def make_http_update_function(hs, stream_name: str) -> UpdateFunction: """Makes a suitable function for use as an `update_function` that queries the master process for updates. @@ -264,7 +244,7 @@ class BackfillStream(Stream): super().__init__( hs.get_instance_name(), current_token_without_instance(store.get_current_backfill_token), - db_query_to_update_function(store.get_all_new_backfill_event_rows), + store.get_all_new_backfill_event_rows, ) @@ -291,9 +271,7 @@ class PresenceStream(Stream): if hs.config.worker_app is None: # on the master, query the presence handler presence_handler = hs.get_presence_handler() - update_function = db_query_to_update_function( - presence_handler.get_all_presence_updates - ) + update_function = presence_handler.get_all_presence_updates else: # Query master process update_function = make_http_update_function(hs, self.NAME) @@ -316,13 +294,12 @@ class TypingStream(Stream): def __init__(self, hs): typing_handler = hs.get_typing_handler() - if hs.config.worker_app is None: - # on the master, query the typing handler - update_function = db_query_to_update_function( - typing_handler.get_all_typing_updates - ) + writer_instance = hs.config.worker.writers.typing + if writer_instance == hs.get_instance_name(): + # On the writer, query the typing handler + update_function = typing_handler.get_all_typing_updates else: - # Query master process + # Query the typing writer process update_function = make_http_update_function(hs, self.NAME) super().__init__( @@ -352,7 +329,7 @@ class ReceiptsStream(Stream): super().__init__( hs.get_instance_name(), current_token_without_instance(store.get_max_receipt_stream_id), - db_query_to_update_function(store.get_all_updated_receipts), + store.get_all_updated_receipts, ) @@ -367,26 +344,17 @@ class PushRulesStream(Stream): def __init__(self, hs): self.store = hs.get_datastore() + super(PushRulesStream, self).__init__( - hs.get_instance_name(), self._current_token, self._update_function + hs.get_instance_name(), + self._current_token, + self.store.get_all_push_rule_updates, ) def _current_token(self, instance_name: str) -> int: - push_rules_token, _ = self.store.get_push_rules_stream_token() + push_rules_token = self.store.get_max_push_rules_stream_id() return push_rules_token - async def _update_function( - self, instance_name: str, from_token: Token, to_token: Token, limit: int - ): - rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit) - - limited = False - if len(rows) == limit: - to_token = rows[-1][0] - limited = True - - return [(row[0], (row[2],)) for row in rows], to_token, limited - class PushersStream(Stream): """A user has added/changed/removed a pusher @@ -406,7 +374,7 @@ class PushersStream(Stream): super().__init__( hs.get_instance_name(), current_token_without_instance(store.get_pushers_stream_token), - db_query_to_update_function(store.get_all_updated_pushers_rows), + store.get_all_updated_pushers_rows, ) @@ -434,27 +402,13 @@ class CachesStream(Stream): ROW_TYPE = CachesStreamRow def __init__(self, hs): - self.store = hs.get_datastore() + store = hs.get_datastore() super().__init__( hs.get_instance_name(), - self.store.get_cache_stream_token, - self._update_function, + store.get_cache_stream_token_for_writer, + store.get_all_updated_caches, ) - async def _update_function( - self, instance_name: str, from_token: int, upto_token: int, limit: int - ): - rows = await self.store.get_all_updated_caches( - instance_name, from_token, upto_token, limit - ) - updates = [(row[0], row[1:]) for row in rows] - limited = False - if len(updates) >= limit: - upto_token = updates[-1][0] - limited = True - - return updates, upto_token, limited - class PublicRoomsStream(Stream): """The public rooms list changed @@ -478,7 +432,7 @@ class PublicRoomsStream(Stream): super().__init__( hs.get_instance_name(), current_token_without_instance(store.get_current_public_room_stream_id), - db_query_to_update_function(store.get_all_new_public_rooms), + store.get_all_new_public_rooms, ) @@ -499,7 +453,7 @@ class DeviceListsStream(Stream): super().__init__( hs.get_instance_name(), current_token_without_instance(store.get_device_stream_token), - db_query_to_update_function(store.get_all_device_list_changes_for_remotes), + store.get_all_device_list_changes_for_remotes, ) @@ -517,7 +471,7 @@ class ToDeviceStream(Stream): super().__init__( hs.get_instance_name(), current_token_without_instance(store.get_to_device_stream_token), - db_query_to_update_function(store.get_all_new_device_messages), + store.get_all_new_device_messages, ) @@ -537,7 +491,7 @@ class TagAccountDataStream(Stream): super().__init__( hs.get_instance_name(), current_token_without_instance(store.get_max_account_data_stream_id), - db_query_to_update_function(store.get_all_updated_tags), + store.get_all_updated_tags, ) @@ -625,7 +579,7 @@ class GroupServerStream(Stream): super().__init__( hs.get_instance_name(), current_token_without_instance(store.get_group_stream_token), - db_query_to_update_function(store.get_all_groups_changes), + store.get_all_groups_changes, ) @@ -643,7 +597,5 @@ class UserSignatureStream(Stream): super().__init__( hs.get_instance_name(), current_token_without_instance(store.get_device_stream_token), - db_query_to_update_function( - store.get_all_user_signature_changes_for_remotes - ), + store.get_all_user_signature_changes_for_remotes, ) diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index f370390331..f929fc3954 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py
@@ -13,16 +13,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import heapq -from collections import Iterable +from collections.abc import Iterable from typing import List, Tuple, Type import attr from ._base import Stream, StreamUpdateResult, Token, current_token_without_instance - """Handling of the 'events' replication stream This stream contains rows of various types. Each row therefore contains a 'type' @@ -51,20 +49,20 @@ data part are: @attr.s(slots=True, frozen=True) -class EventsStreamRow(object): +class EventsStreamRow: """A parsed row from the events replication stream""" type = attr.ib() # str: the TypeId of one of the *EventsStreamRows data = attr.ib() # BaseEventsStreamRow -class BaseEventsStreamRow(object): +class BaseEventsStreamRow: """Base class for rows to be sent in the events stream. Specifies how to identify, serialize and deserialize the different types. """ - # Unique string that ids the type. Must be overriden in sub classes. + # Unique string that ids the type. Must be overridden in sub classes. TypeId = None # type: str @classmethod