diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index b64f3f44b5..afaf002fe6 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -48,7 +48,7 @@ indicate which side is sending, these are *not* included on the wire::
> ERROR server stopping
* connection closed by server *
"""
-
+import abc
import fcntl
import logging
import struct
@@ -65,6 +65,7 @@ from twisted.python.failure import Failure
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.util import Clock
from synapse.util.stringutils import random_string
from .commands import (
@@ -558,11 +559,80 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.streamer.lost_connection(self)
+class AbstractReplicationClientHandler(metaclass=abc.ABCMeta):
+ """
+ The interface for the handler that should be passed to
+ ClientReplicationStreamProtocol
+ """
+
+ @abc.abstractmethod
+ def on_rdata(self, stream_name, token, rows):
+ """Called to handle a batch of replication data with a given stream token.
+
+ Args:
+ stream_name (str): name of the replication stream for this batch of rows
+ token (int): stream token for this batch of rows
+ rows (list): a list of Stream.ROW_TYPE objects as returned by
+ Stream.parse_row.
+
+ Returns:
+ Deferred|None
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def on_position(self, stream_name, token):
+ """Called when we get new position data."""
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def on_sync(self, data):
+ """Called when get a new SYNC command."""
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def get_streams_to_replicate(self):
+ """Called when a new connection has been established and we need to
+ subscribe to streams.
+
+ Returns:
+ map from stream name to the most recent update we have for
+ that stream (ie, the point we want to start replicating from)
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def get_currently_syncing_users(self):
+ """Get the list of currently syncing users (if any). This is called
+ when a connection has been established and we need to send the
+ currently syncing users."""
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def update_connection(self, connection):
+ """Called when a connection has been established (or lost with None).
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def finished_connecting(self):
+ """Called when we have successfully subscribed and caught up to all
+ streams we're interested in.
+ """
+ raise NotImplementedError()
+
+
class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
VALID_INBOUND_COMMANDS = VALID_SERVER_COMMANDS
VALID_OUTBOUND_COMMANDS = VALID_CLIENT_COMMANDS
- def __init__(self, client_name, server_name, clock, handler):
+ def __init__(
+ self,
+ client_name: str,
+ server_name: str,
+ clock: Clock,
+ handler: AbstractReplicationClientHandler,
+ ):
BaseReplicationStreamProtocol.__init__(self, clock)
self.client_name = client_name
|