diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 206dc3b397..02ab5b66ea 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -16,18 +16,26 @@
"""
import logging
+from typing import Dict, List, Optional
from twisted.internet import defer
from twisted.internet.protocol import ReconnectingClientFactory
+from synapse.replication.slave.storage._base import BaseSlavedStore
+from synapse.replication.tcp.protocol import (
+ AbstractReplicationClientHandler,
+ ClientReplicationStreamProtocol,
+)
+
from .commands import (
+ Command,
FederationAckCommand,
InvalidateCacheCommand,
+ RemoteServerUpCommand,
RemovePusherCommand,
UserIpCommand,
UserSyncCommand,
)
-from .protocol import ClientReplicationStreamProtocol
logger = logging.getLogger(__name__)
@@ -39,9 +47,11 @@ class ReplicationClientFactory(ReconnectingClientFactory):
Accepts a handler that will be called when new data is available or data
is required.
"""
- maxDelay = 30 # Try at least once every N seconds
- def __init__(self, hs, client_name, handler):
+ initialDelay = 0.1
+ maxDelay = 1 # Try at least once every N seconds
+
+ def __init__(self, hs, client_name, handler: AbstractReplicationClientHandler):
self.client_name = client_name
self.handler = handler
self.server_name = hs.config.server_name
@@ -64,17 +74,16 @@ class ReplicationClientFactory(ReconnectingClientFactory):
def clientConnectionFailed(self, connector, reason):
logger.error("Failed to connect to replication: %r", reason)
- ReconnectingClientFactory.clientConnectionFailed(
- self, connector, reason
- )
+ ReconnectingClientFactory.clientConnectionFailed(self, connector, reason)
-class ReplicationClientHandler(object):
+class ReplicationClientHandler(AbstractReplicationClientHandler):
"""A base handler that can be passed to the ReplicationClientFactory.
By default proxies incoming replication data to the SlaveStore.
"""
- def __init__(self, store):
+
+ def __init__(self, store: BaseSlavedStore):
self.store = store
# The current connection. None if we are currently (re)connecting
@@ -82,15 +91,15 @@ class ReplicationClientHandler(object):
# Any pending commands to be sent once a new connection has been
# established
- self.pending_commands = []
+ self.pending_commands = [] # type: List[Command]
# Map from string -> deferred, to wake up when receiveing a SYNC with
# the given string.
# Used for tests.
- self.awaiting_syncs = {}
+ self.awaiting_syncs = {} # type: Dict[str, defer.Deferred]
# The factory used to create connections.
- self.factory = None
+ self.factory = None # type: Optional[ReplicationClientFactory]
def start_replication(self, hs):
"""Helper method to start a replication connection to the remote server
@@ -102,7 +111,7 @@ class ReplicationClientHandler(object):
port = hs.config.worker_replication_port
hs.get_reactor().connectTCP(host, port, self.factory)
- def on_rdata(self, stream_name, token, rows):
+ async def on_rdata(self, stream_name, token, rows):
"""Called to handle a batch of replication data with a given stream token.
By default this just pokes the slave store. Can be overridden in subclasses to
@@ -113,20 +122,17 @@ class ReplicationClientHandler(object):
token (int): stream token for this batch of rows
rows (list): a list of Stream.ROW_TYPE objects as returned by
Stream.parse_row.
-
- Returns:
- Deferred|None
"""
logger.debug("Received rdata %s -> %s", stream_name, token)
- return self.store.process_replication_rows(stream_name, token, rows)
+ self.store.process_replication_rows(stream_name, token, rows)
- def on_position(self, stream_name, token):
+ async def on_position(self, stream_name, token):
"""Called when we get new position data. By default this just pokes
the slave store.
Can be overriden in subclasses to handle more.
"""
- return self.store.process_replication_rows(stream_name, token, [])
+ self.store.process_replication_rows(stream_name, token, [])
def on_sync(self, data):
"""When we received a SYNC we wake up any deferreds that were waiting
@@ -138,11 +144,16 @@ class ReplicationClientHandler(object):
if d:
d.callback(data)
- def get_streams_to_replicate(self):
+ def on_remote_server_up(self, server: str):
+ """Called when get a new REMOTE_SERVER_UP command."""
+
+ def get_streams_to_replicate(self) -> Dict[str, int]:
"""Called when a new connection has been established and we need to
subscribe to streams.
- Returns a dictionary of stream name to token.
+ Returns:
+ map from stream name to the most recent update we have for
+ that stream (ie, the point we want to start replicating from)
"""
args = self.store.stream_positions()
user_account_data = args.pop("user_account_data", None)
@@ -168,7 +179,7 @@ class ReplicationClientHandler(object):
if self.connection:
self.connection.send_command(cmd)
else:
- logger.warn("Queuing command as not connected: %r", cmd.NAME)
+ logger.warning("Queuing command as not connected: %r", cmd.NAME)
self.pending_commands.append(cmd)
def send_federation_ack(self, token):
@@ -200,6 +211,9 @@ class ReplicationClientHandler(object):
cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen)
self.send_command(cmd)
+ def send_remote_server_up(self, server: str):
+ self.send_command(RemoteServerUpCommand(server))
+
def await_sync(self, data):
"""Returns a deferred that is resolved when we receive a SYNC command
with given data.
@@ -226,4 +240,5 @@ class ReplicationClientHandler(object):
# We don't reset the delay any earlier as otherwise if there is a
# problem during start up we'll end up tight looping connecting to the
# server.
- self.factory.resetDelay()
+ if self.factory:
+ self.factory.resetDelay()
|