diff options
Diffstat (limited to 'synapse/replication/tcp')
-rw-r--r-- | synapse/replication/tcp/client.py | 20 | ||||
-rw-r--r-- | synapse/replication/tcp/protocol.py | 74 | ||||
-rw-r--r-- | synapse/replication/tcp/streams/_base.py | 7 |
3 files changed, 89 insertions, 12 deletions
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 563ce0fc53..fead78388c 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -16,10 +16,17 @@ """ import logging +from typing import Dict from twisted.internet import defer from twisted.internet.protocol import ReconnectingClientFactory +from synapse.replication.slave.storage._base import BaseSlavedStore +from synapse.replication.tcp.protocol import ( + AbstractReplicationClientHandler, + ClientReplicationStreamProtocol, +) + from .commands import ( FederationAckCommand, InvalidateCacheCommand, @@ -27,7 +34,6 @@ from .commands import ( UserIpCommand, UserSyncCommand, ) -from .protocol import ClientReplicationStreamProtocol logger = logging.getLogger(__name__) @@ -42,7 +48,7 @@ class ReplicationClientFactory(ReconnectingClientFactory): maxDelay = 30 # Try at least once every N seconds - def __init__(self, hs, client_name, handler): + def __init__(self, hs, client_name, handler: AbstractReplicationClientHandler): self.client_name = client_name self.handler = handler self.server_name = hs.config.server_name @@ -68,13 +74,13 @@ class ReplicationClientFactory(ReconnectingClientFactory): ReconnectingClientFactory.clientConnectionFailed(self, connector, reason) -class ReplicationClientHandler(object): +class ReplicationClientHandler(AbstractReplicationClientHandler): """A base handler that can be passed to the ReplicationClientFactory. By default proxies incoming replication data to the SlaveStore. """ - def __init__(self, store): + def __init__(self, store: BaseSlavedStore): self.store = store # The current connection. None if we are currently (re)connecting @@ -138,11 +144,13 @@ class ReplicationClientHandler(object): if d: d.callback(data) - def get_streams_to_replicate(self): + def get_streams_to_replicate(self) -> Dict[str, int]: """Called when a new connection has been established and we need to subscribe to streams. - Returns a dictionary of stream name to token. + Returns: + map from stream name to the most recent update we have for + that stream (ie, the point we want to start replicating from) """ args = self.store.stream_positions() user_account_data = args.pop("user_account_data", None) diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index b64f3f44b5..afaf002fe6 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -48,7 +48,7 @@ indicate which side is sending, these are *not* included on the wire:: > ERROR server stopping * connection closed by server * """ - +import abc import fcntl import logging import struct @@ -65,6 +65,7 @@ from twisted.python.failure import Failure from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.util import Clock from synapse.util.stringutils import random_string from .commands import ( @@ -558,11 +559,80 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): self.streamer.lost_connection(self) +class AbstractReplicationClientHandler(metaclass=abc.ABCMeta): + """ + The interface for the handler that should be passed to + ClientReplicationStreamProtocol + """ + + @abc.abstractmethod + def on_rdata(self, stream_name, token, rows): + """Called to handle a batch of replication data with a given stream token. + + Args: + stream_name (str): name of the replication stream for this batch of rows + token (int): stream token for this batch of rows + rows (list): a list of Stream.ROW_TYPE objects as returned by + Stream.parse_row. + + Returns: + Deferred|None + """ + raise NotImplementedError() + + @abc.abstractmethod + def on_position(self, stream_name, token): + """Called when we get new position data.""" + raise NotImplementedError() + + @abc.abstractmethod + def on_sync(self, data): + """Called when get a new SYNC command.""" + raise NotImplementedError() + + @abc.abstractmethod + def get_streams_to_replicate(self): + """Called when a new connection has been established and we need to + subscribe to streams. + + Returns: + map from stream name to the most recent update we have for + that stream (ie, the point we want to start replicating from) + """ + raise NotImplementedError() + + @abc.abstractmethod + def get_currently_syncing_users(self): + """Get the list of currently syncing users (if any). This is called + when a connection has been established and we need to send the + currently syncing users.""" + raise NotImplementedError() + + @abc.abstractmethod + def update_connection(self, connection): + """Called when a connection has been established (or lost with None). + """ + raise NotImplementedError() + + @abc.abstractmethod + def finished_connecting(self): + """Called when we have successfully subscribed and caught up to all + streams we're interested in. + """ + raise NotImplementedError() + + class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): VALID_INBOUND_COMMANDS = VALID_SERVER_COMMANDS VALID_OUTBOUND_COMMANDS = VALID_CLIENT_COMMANDS - def __init__(self, client_name, server_name, clock, handler): + def __init__( + self, + client_name: str, + server_name: str, + clock: Clock, + handler: AbstractReplicationClientHandler, + ): BaseReplicationStreamProtocol.__init__(self, clock) self.client_name = client_name diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 9e45429d49..8512923eae 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -88,8 +88,7 @@ TagAccountDataStreamRow = namedtuple( "TagAccountDataStreamRow", ("user_id", "room_id", "data") # str # str # dict ) AccountDataStreamRow = namedtuple( - "AccountDataStream", - ("user_id", "room_id", "data_type", "data"), # str # str # str # dict + "AccountDataStream", ("user_id", "room_id", "data_type") # str # str # str ) GroupsStreamRow = namedtuple( "GroupsStreamRow", @@ -421,8 +420,8 @@ class AccountDataStream(Stream): results = list(room_results) results.extend( - (stream_id, user_id, None, account_data_type, content) - for stream_id, user_id, account_data_type, content in global_results + (stream_id, user_id, None, account_data_type) + for stream_id, user_id, account_data_type in global_results ) return results |