diff options
Diffstat (limited to 'synapse/replication/tcp/handler.py')
-rw-r--r-- | synapse/replication/tcp/handler.py | 177 |
1 files changed, 159 insertions, 18 deletions
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 12a1cfd6d1..8ec0119697 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -19,8 +19,10 @@ from typing import Any, Callable, Dict, List, Optional, Set from prometheus_client import Counter +from synapse.metrics import LaterGauge from synapse.replication.tcp.client import ReplicationClientFactory from synapse.replication.tcp.commands import ( + ClearUserSyncsCommand, Command, FederationAckCommand, InvalidateCacheCommand, @@ -28,10 +30,12 @@ from synapse.replication.tcp.commands import ( RdataCommand, RemoteServerUpCommand, RemovePusherCommand, + ReplicateCommand, SyncCommand, UserIpCommand, UserSyncCommand, ) +from synapse.replication.tcp.protocol import AbstractConnection from synapse.replication.tcp.streams import STREAMS_MAP, Stream from synapse.util.async_helpers import Linearizer @@ -42,6 +46,13 @@ logger = logging.getLogger(__name__) inbound_rdata_count = Counter( "synapse_replication_tcp_protocol_inbound_rdata_count", "", ["stream_name"] ) +user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "") +federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "") +remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "") +invalidate_cache_counter = Counter( + "synapse_replication_tcp_resource_invalidate_cache", "" +) +user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "") class ReplicationCommandHandler: @@ -52,6 +63,10 @@ class ReplicationCommandHandler: def __init__(self, hs): self._replication_data_handler = hs.get_replication_data_handler() self._presence_handler = hs.get_presence_handler() + self._store = hs.get_datastore() + self._notifier = hs.get_notifier() + self._clock = hs.get_clock() + self._instance_id = hs.get_instance_id() # Set of streams that we've caught up with. self._streams_connected = set() # type: Set[str] @@ -69,8 +84,26 @@ class ReplicationCommandHandler: # The factory used to create connections. self._factory = None # type: Optional[ReplicationClientFactory] - # The current connection. None if we are currently (re)connecting - self._connection = None + # The currently connected connections. + self._connections = [] # type: List[AbstractConnection] + + LaterGauge( + "synapse_replication_tcp_resource_total_connections", + "", + [], + lambda: len(self._connections), + ) + + self._is_master = hs.config.worker_app is None + + self._federation_sender = None + if self._is_master and not hs.config.send_federation: + self._federation_sender = hs.get_federation_sender() + + self._server_notices_sender = None + if self._is_master: + self._server_notices_sender = hs.get_server_notices_sender() + self._notifier.add_remote_server_up_callback(self.send_remote_server_up) def start_replication(self, hs): """Helper method to start a replication connection to the remote server @@ -82,6 +115,70 @@ class ReplicationCommandHandler: port = hs.config.worker_replication_port hs.get_reactor().connectTCP(host, port, self._factory) + async def on_REPLICATE(self, cmd: ReplicateCommand): + # We only want to announce positions by the writer of the streams. + # Currently this is just the master process. + if not self._is_master: + return + + for stream_name, stream in self._streams.items(): + current_token = stream.current_token() + self.send_command(PositionCommand(stream_name, current_token)) + + async def on_USER_SYNC(self, cmd: UserSyncCommand): + user_sync_counter.inc() + + if self._is_master: + await self._presence_handler.update_external_syncs_row( + cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms + ) + + async def on_CLEAR_USER_SYNC(self, cmd: ClearUserSyncsCommand): + if self._is_master: + await self._presence_handler.update_external_syncs_clear(cmd.instance_id) + + async def on_FEDERATION_ACK(self, cmd: FederationAckCommand): + federation_ack_counter.inc() + + if self._federation_sender: + self._federation_sender.federation_ack(cmd.token) + + async def on_REMOVE_PUSHER(self, cmd: RemovePusherCommand): + remove_pusher_counter.inc() + + if self._is_master: + await self._store.delete_pusher_by_app_id_pushkey_user_id( + app_id=cmd.app_id, pushkey=cmd.push_key, user_id=cmd.user_id + ) + + self._notifier.on_new_replication_data() + + async def on_INVALIDATE_CACHE(self, cmd: InvalidateCacheCommand): + invalidate_cache_counter.inc() + + if self._is_master: + # We invalidate the cache locally, but then also stream that to other + # workers. + await self._store.invalidate_cache_and_stream( + cmd.cache_func, tuple(cmd.keys) + ) + + async def on_USER_IP(self, cmd: UserIpCommand): + user_ip_cache_counter.inc() + + if self._is_master: + await self._store.insert_client_ip( + cmd.user_id, + cmd.access_token, + cmd.ip, + cmd.user_agent, + cmd.device_id, + cmd.last_seen, + ) + + if self._server_notices_sender: + await self._server_notices_sender.on_user_ip(cmd.user_id) + async def on_RDATA(self, cmd: RdataCommand): stream_name = cmd.stream_name inbound_rdata_count.labels(stream_name).inc() @@ -174,6 +271,9 @@ class ReplicationCommandHandler: """"Called when get a new REMOTE_SERVER_UP command.""" self._replication_data_handler.on_remote_server_up(cmd.data) + if self._is_master: + self._notifier.notify_remote_server_up(cmd.data) + def get_currently_syncing_users(self): """Get the list of currently syncing users (if any). This is called when a connection has been established and we need to send the @@ -181,29 +281,63 @@ class ReplicationCommandHandler: """ return self._presence_handler.get_currently_syncing_users() - def update_connection(self, connection): - """Called when a connection has been established (or lost with None). + def new_connection(self, connection: AbstractConnection): + """Called when we have a new connection. """ - self._connection = connection + self._connections.append(connection) + + # If we are connected to replication as a client (rather than a server) + # we need to reset the reconnection delay on the client factory (which + # is used to do exponential back off when the connection drops). + # + # Ideally we would reset the delay when we've "fully established" the + # connection (for some definition thereof) to stop us from tightlooping + # on reconnection if something fails after this point and we drop the + # connection. Unfortunately, we don't really have a better definition of + # "fully established" than the connection being established. + if self._factory: + self._factory.resetDelay() + + # Tell the server if we have any users currently syncing (should only + # happen on synchrotrons) + currently_syncing = self.get_currently_syncing_users() + now = self._clock.time_msec() + for user_id in currently_syncing: + connection.send_command( + UserSyncCommand(self._instance_id, user_id, True, now) + ) - def finished_connecting(self): - """Called when we have successfully subscribed and caught up to all - streams we're interested in. + def lost_connection(self, connection: AbstractConnection): + """Called when a connection is closed/lost. """ - logger.info("Finished connecting to server") + try: + self._connections.remove(connection) + except ValueError: + pass - # We don't reset the delay any earlier as otherwise if there is a - # problem during start up we'll end up tight looping connecting to the - # server. - if self._factory: - self._factory.resetDelay() + def connected(self) -> bool: + """Do we have any replication connections open? + + Is used by e.g. `ReplicationStreamer` to no-op if nothing is connected. + """ + return bool(self._connections) def send_command(self, cmd: Command): - """Send a command to master (when we get establish a connection if we - don't have one already.) + """Send a command to all connected connections. """ - if self._connection: - self._connection.send_command(cmd) + if self._connections: + for connection in self._connections: + try: + connection.send_command(cmd) + except Exception: + # We probably want to catch some types of exceptions here + # and log them as warnings (e.g. connection gone), but I + # can't find what those exception types they would be. + logger.exception( + "Failed to write command %s to connection %s", + cmd.NAME, + connection, + ) else: logger.warning("Dropping command as not connected: %r", cmd.NAME) @@ -250,3 +384,10 @@ class ReplicationCommandHandler: def send_remote_server_up(self, server: str): self.send_command(RemoteServerUpCommand(server)) + + def stream_update(self, stream_name: str, token: str, data: Any): + """Called when a new update is available to stream to clients. + + We need to check if the client is interested in the stream or not + """ + self.send_command(RdataCommand(stream_name, token, data)) |