diff options
Diffstat (limited to 'synapse/replication')
-rw-r--r-- | synapse/replication/http/_base.py | 2 | ||||
-rw-r--r-- | synapse/replication/http/federation.py | 2 | ||||
-rw-r--r-- | synapse/replication/http/membership.py | 2 | ||||
-rw-r--r-- | synapse/replication/tcp/client.py | 4 | ||||
-rw-r--r-- | synapse/replication/tcp/commands.py | 36 | ||||
-rw-r--r-- | synapse/replication/tcp/handler.py | 30 | ||||
-rw-r--r-- | synapse/replication/tcp/protocol.py | 10 | ||||
-rw-r--r-- | synapse/replication/tcp/redis.py | 40 | ||||
-rw-r--r-- | synapse/replication/tcp/resource.py | 47 | ||||
-rw-r--r-- | synapse/replication/tcp/streams/_base.py | 11 | ||||
-rw-r--r-- | synapse/replication/tcp/streams/events.py | 6 |
11 files changed, 152 insertions, 38 deletions
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index 64edadb624..2b3972cb14 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -92,7 +92,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): if self.CACHE: self.response_cache = ResponseCache( hs, "repl." + self.NAME, timeout_ms=30 * 60 * 1000 - ) + ) # type: ResponseCache[str] # We reserve `instance_name` as a parameter to sending requests, so we # assert here that sub classes don't try and use the name. diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index 5393b9a9e7..b4f4a68b5c 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -62,7 +62,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): self.store = hs.get_datastore() self.storage = hs.get_storage() self.clock = hs.get_clock() - self.federation_handler = hs.get_handlers().federation_handler + self.federation_handler = hs.get_federation_handler() @staticmethod async def _serialize_payload(store, room_id, event_and_contexts, backfilled): diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py index 30680baee8..e7cc74a5d2 100644 --- a/synapse/replication/http/membership.py +++ b/synapse/replication/http/membership.py @@ -47,7 +47,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint): def __init__(self, hs): super().__init__(hs) - self.federation_handler = hs.get_handlers().federation_handler + self.federation_handler = hs.get_federation_handler() self.store = hs.get_datastore() self.clock = hs.get_clock() diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index e165429cad..e27ee216f0 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -191,6 +191,10 @@ class ReplicationDataHandler: async def on_position(self, stream_name: str, instance_name: str, token: int): self.store.process_replication_rows(stream_name, instance_name, token, []) + # We poke the generic "replication" notifier to wake anything up that + # may be streaming. + self.notifier.notify_replication() + def on_remote_server_up(self, server: str): """Called when get a new REMOTE_SERVER_UP command.""" diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index 8cd47770c1..ac532ed588 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -141,15 +141,23 @@ class RdataCommand(Command): class PositionCommand(Command): - """Sent by the server to tell the client the stream position without - needing to send an RDATA. + """Sent by an instance to tell others the stream position without needing to + send an RDATA. + + Two tokens are sent, the new position and the last position sent by the + instance (in an RDATA or other POSITION). The tokens are chosen so that *no* + rows were written by the instance between the `prev_token` and `new_token`. + (If an instance hasn't sent a position before then the new position can be + used for both.) Format:: - POSITION <stream_name> <instance_name> <token> + POSITION <stream_name> <instance_name> <prev_token> <new_token> - On receipt of a POSITION command clients should check if they have missed - any updates, and if so then fetch them out of band. + On receipt of a POSITION command instances should check if they have missed + any updates, and if so then fetch them out of band. Instances can check this + by comparing their view of the current token for the sending instance with + the included `prev_token`. The `<instance_name>` is the process that sent the command and is the source of the stream. @@ -157,18 +165,26 @@ class PositionCommand(Command): NAME = "POSITION" - def __init__(self, stream_name, instance_name, token): + def __init__(self, stream_name, instance_name, prev_token, new_token): self.stream_name = stream_name self.instance_name = instance_name - self.token = token + self.prev_token = prev_token + self.new_token = new_token @classmethod def from_line(cls, line): - stream_name, instance_name, token = line.split(" ", 2) - return cls(stream_name, instance_name, int(token)) + stream_name, instance_name, prev_token, new_token = line.split(" ", 3) + return cls(stream_name, instance_name, int(prev_token), int(new_token)) def to_line(self): - return " ".join((self.stream_name, self.instance_name, str(self.token))) + return " ".join( + ( + self.stream_name, + self.instance_name, + str(self.prev_token), + str(self.new_token), + ) + ) class ErrorCommand(_SimpleCommand): diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index b323841f73..95e5502bf2 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -101,8 +101,9 @@ class ReplicationCommandHandler: self._streams_to_replicate = [] # type: List[Stream] for stream in self._streams.values(): - if stream.NAME == CachesStream.NAME: - # All workers can write to the cache invalidation stream. + if hs.config.redis.redis_enabled and stream.NAME == CachesStream.NAME: + # All workers can write to the cache invalidation stream when + # using redis. self._streams_to_replicate.append(stream) continue @@ -251,10 +252,9 @@ class ReplicationCommandHandler: using TCP. """ if hs.config.redis.redis_enabled: - import txredisapi - from synapse.replication.tcp.redis import ( RedisDirectTcpReplicationClientFactory, + lazyConnection, ) logger.info( @@ -271,7 +271,8 @@ class ReplicationCommandHandler: # connection after SUBSCRIBE is called). # First create the connection for sending commands. - outbound_redis_connection = txredisapi.lazyConnection( + outbound_redis_connection = lazyConnection( + reactor=hs.get_reactor(), host=hs.config.redis_host, port=hs.config.redis_port, password=hs.config.redis.redis_password, @@ -313,11 +314,14 @@ class ReplicationCommandHandler: # We respond with current position of all streams this instance # replicates. for stream in self.get_streams_to_replicate(): + # Note that we use the current token as the prev token here (rather + # than stream.last_token), as we can't be sure that there have been + # no rows written between last token and the current token (since we + # might be racing with the replication sending bg process). + current_token = stream.current_token(self._instance_name) self.send_command( PositionCommand( - stream.NAME, - self._instance_name, - stream.current_token(self._instance_name), + stream.NAME, self._instance_name, current_token, current_token, ) ) @@ -511,16 +515,16 @@ class ReplicationCommandHandler: # If the position token matches our current token then we're up to # date and there's nothing to do. Otherwise, fetch all updates # between then and now. - missing_updates = cmd.token != current_token + missing_updates = cmd.prev_token != current_token while missing_updates: logger.info( "Fetching replication rows for '%s' between %i and %i", stream_name, current_token, - cmd.token, + cmd.new_token, ) (updates, current_token, missing_updates) = await stream.get_updates_since( - cmd.instance_name, current_token, cmd.token + cmd.instance_name, current_token, cmd.new_token ) # TODO: add some tests for this @@ -536,11 +540,11 @@ class ReplicationCommandHandler: [stream.parse_row(row) for row in rows], ) - logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token) + logger.info("Caught up with stream '%s' to %i", stream_name, cmd.new_token) # We've now caught up to position sent to us, notify handler. await self._replication_data_handler.on_position( - cmd.stream_name, cmd.instance_name, cmd.token + cmd.stream_name, cmd.instance_name, cmd.new_token ) self._streams_by_connection.setdefault(conn, set()).add(stream_name) diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index 0b0d204e64..a509e599c2 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -51,10 +51,11 @@ import fcntl import logging import struct from inspect import isawaitable -from typing import TYPE_CHECKING, List +from typing import TYPE_CHECKING, List, Optional from prometheus_client import Counter +from twisted.internet import task from twisted.protocols.basic import LineOnlyReceiver from twisted.python.failure import Failure @@ -152,9 +153,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): self.last_received_command = self.clock.time_msec() self.last_sent_command = 0 - self.time_we_closed = None # When we requested the connection be closed + # When we requested the connection be closed + self.time_we_closed = None # type: Optional[int] - self.received_ping = False # Have we reecived a ping from the other side + self.received_ping = False # Have we received a ping from the other side self.state = ConnectionStates.CONNECTING @@ -165,7 +167,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): self.pending_commands = [] # type: List[Command] # The LoopingCall for sending pings. - self._send_ping_loop = None + self._send_ping_loop = None # type: Optional[task.LoopingCall] # a logcontext which we use for processing incoming commands. We declare it as a # background process so that the CPU stats get reported to prometheus. diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py index f225e533de..de19705c1f 100644 --- a/synapse/replication/tcp/redis.py +++ b/synapse/replication/tcp/redis.py @@ -15,7 +15,7 @@ import logging from inspect import isawaitable -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import txredisapi @@ -228,3 +228,41 @@ class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory): p.password = self.password return p + + +def lazyConnection( + reactor, + host: str = "localhost", + port: int = 6379, + dbid: Optional[int] = None, + reconnect: bool = True, + charset: str = "utf-8", + password: Optional[str] = None, + connectTimeout: Optional[int] = None, + replyTimeout: Optional[int] = None, + convertNumbers: bool = True, +) -> txredisapi.RedisProtocol: + """Equivalent to `txredisapi.lazyConnection`, except allows specifying a + reactor. + """ + + isLazy = True + poolsize = 1 + + uuid = "%s:%d" % (host, port) + factory = txredisapi.RedisFactory( + uuid, + dbid, + poolsize, + isLazy, + txredisapi.ConnectionHandler, + charset, + password, + replyTimeout, + convertNumbers, + ) + factory.continueTrying = reconnect + for x in range(poolsize): + reactor.connectTCP(host, port, factory, connectTimeout) + + return factory.handler diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index 687984e7a8..666c13fdb7 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -23,7 +23,9 @@ from prometheus_client import Counter from twisted.internet.protocol import Factory from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.replication.tcp.commands import PositionCommand from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol +from synapse.replication.tcp.streams import EventsStream from synapse.util.metrics import Measure stream_updates_counter = Counter( @@ -84,6 +86,23 @@ class ReplicationStreamer: # Set of streams to replicate. self.streams = self.command_handler.get_streams_to_replicate() + # If we have streams then we must have redis enabled or on master + assert ( + not self.streams + or hs.config.redis.redis_enabled + or not hs.config.worker.worker_app + ) + + # If we are replicating an event stream we want to periodically check if + # we should send updated POSITIONs. We do this as a looping call rather + # explicitly poking when the position advances (without new data to + # replicate) to reduce replication traffic (otherwise each writer would + # likely send a POSITION for each new event received over replication). + # + # Note that if the position hasn't advanced then we won't send anything. + if any(EventsStream.NAME == s.NAME for s in self.streams): + self.clock.looping_call(self.on_notifier_poke, 1000) + def on_notifier_poke(self): """Checks if there is actually any new data and sends it to the connections if there are. @@ -91,7 +110,7 @@ class ReplicationStreamer: This should get called each time new data is available, even if it is currently being executed, so that nothing gets missed """ - if not self.command_handler.connected(): + if not self.command_handler.connected() or not self.streams: # Don't bother if nothing is listening. We still need to advance # the stream tokens otherwise they'll fall behind forever for stream in self.streams: @@ -136,6 +155,8 @@ class ReplicationStreamer: self._replication_torture_level / 1000.0 ) + last_token = stream.last_token + logger.debug( "Getting stream: %s: %s -> %s", stream.NAME, @@ -159,6 +180,30 @@ class ReplicationStreamer: ) stream_updates_counter.labels(stream.NAME).inc(len(updates)) + else: + # The token has advanced but there is no data to + # send, so we send a `POSITION` to inform other + # workers of the updated position. + if stream.NAME == EventsStream.NAME: + # XXX: We only do this for the EventStream as it + # turns out that e.g. account data streams share + # their "current token" with each other, meaning + # that it is *not* safe to send a POSITION. + logger.info( + "Sending position: %s -> %s", + stream.NAME, + current_token, + ) + self.command_handler.send_command( + PositionCommand( + stream.NAME, + self._instance_name, + last_token, + current_token, + ) + ) + continue + # Some streams return multiple rows with the same stream IDs, # we need to make sure they get sent out in batches. We do # this by setting the current token to all but the last of diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 54dccd15a6..61b282ab2d 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -240,13 +240,18 @@ class BackfillStream(Stream): ROW_TYPE = BackfillStreamRow def __init__(self, hs): - store = hs.get_datastore() + self.store = hs.get_datastore() super().__init__( hs.get_instance_name(), - current_token_without_instance(store.get_current_backfill_token), - store.get_all_new_backfill_event_rows, + self._current_token, + self.store.get_all_new_backfill_event_rows, ) + def _current_token(self, instance_name: str) -> int: + # The backfill stream over replication operates on *positive* numbers, + # which means we need to negate it. + return -self.store._backfill_id_gen.get_current_token_for_writer(instance_name) + class PresenceStream(Stream): PresenceStreamRow = namedtuple( diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index ccc7ca30d8..82e9e0d64e 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -155,7 +155,7 @@ class EventsStream(Stream): # now we fetch up to that many rows from the events table event_rows = await self._store.get_all_new_forward_event_rows( - from_token, current_token, target_row_count + instance_name, from_token, current_token, target_row_count ) # type: List[Tuple] # we rely on get_all_new_forward_event_rows strictly honouring the limit, so @@ -180,7 +180,7 @@ class EventsStream(Stream): upper_limit, state_rows_limited, ) = await self._store.get_all_updated_current_state_deltas( - from_token, upper_limit, target_row_count + instance_name, from_token, upper_limit, target_row_count ) limited = limited or state_rows_limited @@ -189,7 +189,7 @@ class EventsStream(Stream): # not to bother with the limit. ex_outliers_rows = await self._store.get_ex_outlier_stream_rows( - from_token, upper_limit + instance_name, from_token, upper_limit ) # type: List[Tuple] # we now need to turn the raw database rows returned into tuples suitable |