diff options
Diffstat (limited to 'synapse/replication/tcp/client.py')
-rw-r--r-- | synapse/replication/tcp/client.py | 22 |
1 files changed, 15 insertions, 7 deletions
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index a44ceb00e7..fead78388c 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -16,10 +16,17 @@ """ import logging +from typing import Dict from twisted.internet import defer from twisted.internet.protocol import ReconnectingClientFactory +from synapse.replication.slave.storage._base import BaseSlavedStore +from synapse.replication.tcp.protocol import ( + AbstractReplicationClientHandler, + ClientReplicationStreamProtocol, +) + from .commands import ( FederationAckCommand, InvalidateCacheCommand, @@ -27,7 +34,6 @@ from .commands import ( UserIpCommand, UserSyncCommand, ) -from .protocol import ClientReplicationStreamProtocol logger = logging.getLogger(__name__) @@ -42,7 +48,7 @@ class ReplicationClientFactory(ReconnectingClientFactory): maxDelay = 30 # Try at least once every N seconds - def __init__(self, hs, client_name, handler): + def __init__(self, hs, client_name, handler: AbstractReplicationClientHandler): self.client_name = client_name self.handler = handler self.server_name = hs.config.server_name @@ -68,13 +74,13 @@ class ReplicationClientFactory(ReconnectingClientFactory): ReconnectingClientFactory.clientConnectionFailed(self, connector, reason) -class ReplicationClientHandler(object): +class ReplicationClientHandler(AbstractReplicationClientHandler): """A base handler that can be passed to the ReplicationClientFactory. By default proxies incoming replication data to the SlaveStore. """ - def __init__(self, store): + def __init__(self, store: BaseSlavedStore): self.store = store # The current connection. None if we are currently (re)connecting @@ -138,11 +144,13 @@ class ReplicationClientHandler(object): if d: d.callback(data) - def get_streams_to_replicate(self): + def get_streams_to_replicate(self) -> Dict[str, int]: """Called when a new connection has been established and we need to subscribe to streams. - Returns a dictionary of stream name to token. + Returns: + map from stream name to the most recent update we have for + that stream (ie, the point we want to start replicating from) """ args = self.store.stream_positions() user_account_data = args.pop("user_account_data", None) @@ -168,7 +176,7 @@ class ReplicationClientHandler(object): if self.connection: self.connection.send_command(cmd) else: - logger.warn("Queuing command as not connected: %r", cmd.NAME) + logger.warning("Queuing command as not connected: %r", cmd.NAME) self.pending_commands.append(cmd) def send_federation_ack(self, token): |