diff options
Diffstat (limited to 'synapse/replication/tcp')
-rw-r--r-- | synapse/replication/tcp/client.py | 20 | ||||
-rw-r--r-- | synapse/replication/tcp/protocol.py | 74 | ||||
-rw-r--r-- | synapse/replication/tcp/streams/__init__.py | 1 | ||||
-rw-r--r-- | synapse/replication/tcp/streams/_base.py | 18 |
4 files changed, 105 insertions, 8 deletions
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 563ce0fc53..fead78388c 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -16,10 +16,17 @@ """ import logging +from typing import Dict from twisted.internet import defer from twisted.internet.protocol import ReconnectingClientFactory +from synapse.replication.slave.storage._base import BaseSlavedStore +from synapse.replication.tcp.protocol import ( + AbstractReplicationClientHandler, + ClientReplicationStreamProtocol, +) + from .commands import ( FederationAckCommand, InvalidateCacheCommand, @@ -27,7 +34,6 @@ from .commands import ( UserIpCommand, UserSyncCommand, ) -from .protocol import ClientReplicationStreamProtocol logger = logging.getLogger(__name__) @@ -42,7 +48,7 @@ class ReplicationClientFactory(ReconnectingClientFactory): maxDelay = 30 # Try at least once every N seconds - def __init__(self, hs, client_name, handler): + def __init__(self, hs, client_name, handler: AbstractReplicationClientHandler): self.client_name = client_name self.handler = handler self.server_name = hs.config.server_name @@ -68,13 +74,13 @@ class ReplicationClientFactory(ReconnectingClientFactory): ReconnectingClientFactory.clientConnectionFailed(self, connector, reason) -class ReplicationClientHandler(object): +class ReplicationClientHandler(AbstractReplicationClientHandler): """A base handler that can be passed to the ReplicationClientFactory. By default proxies incoming replication data to the SlaveStore. """ - def __init__(self, store): + def __init__(self, store: BaseSlavedStore): self.store = store # The current connection. None if we are currently (re)connecting @@ -138,11 +144,13 @@ class ReplicationClientHandler(object): if d: d.callback(data) - def get_streams_to_replicate(self): + def get_streams_to_replicate(self) -> Dict[str, int]: """Called when a new connection has been established and we need to subscribe to streams. - Returns a dictionary of stream name to token. + Returns: + map from stream name to the most recent update we have for + that stream (ie, the point we want to start replicating from) """ args = self.store.stream_positions() user_account_data = args.pop("user_account_data", None) diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index b64f3f44b5..afaf002fe6 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -48,7 +48,7 @@ indicate which side is sending, these are *not* included on the wire:: > ERROR server stopping * connection closed by server * """ - +import abc import fcntl import logging import struct @@ -65,6 +65,7 @@ from twisted.python.failure import Failure from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.util import Clock from synapse.util.stringutils import random_string from .commands import ( @@ -558,11 +559,80 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): self.streamer.lost_connection(self) +class AbstractReplicationClientHandler(metaclass=abc.ABCMeta): + """ + The interface for the handler that should be passed to + ClientReplicationStreamProtocol + """ + + @abc.abstractmethod + def on_rdata(self, stream_name, token, rows): + """Called to handle a batch of replication data with a given stream token. + + Args: + stream_name (str): name of the replication stream for this batch of rows + token (int): stream token for this batch of rows + rows (list): a list of Stream.ROW_TYPE objects as returned by + Stream.parse_row. + + Returns: + Deferred|None + """ + raise NotImplementedError() + + @abc.abstractmethod + def on_position(self, stream_name, token): + """Called when we get new position data.""" + raise NotImplementedError() + + @abc.abstractmethod + def on_sync(self, data): + """Called when get a new SYNC command.""" + raise NotImplementedError() + + @abc.abstractmethod + def get_streams_to_replicate(self): + """Called when a new connection has been established and we need to + subscribe to streams. + + Returns: + map from stream name to the most recent update we have for + that stream (ie, the point we want to start replicating from) + """ + raise NotImplementedError() + + @abc.abstractmethod + def get_currently_syncing_users(self): + """Get the list of currently syncing users (if any). This is called + when a connection has been established and we need to send the + currently syncing users.""" + raise NotImplementedError() + + @abc.abstractmethod + def update_connection(self, connection): + """Called when a connection has been established (or lost with None). + """ + raise NotImplementedError() + + @abc.abstractmethod + def finished_connecting(self): + """Called when we have successfully subscribed and caught up to all + streams we're interested in. + """ + raise NotImplementedError() + + class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): VALID_INBOUND_COMMANDS = VALID_SERVER_COMMANDS VALID_OUTBOUND_COMMANDS = VALID_CLIENT_COMMANDS - def __init__(self, client_name, server_name, clock, handler): + def __init__( + self, + client_name: str, + server_name: str, + clock: Clock, + handler: AbstractReplicationClientHandler, + ): BaseReplicationStreamProtocol.__init__(self, clock) self.client_name = client_name diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py index 634f636dc9..5f52264e84 100644 --- a/synapse/replication/tcp/streams/__init__.py +++ b/synapse/replication/tcp/streams/__init__.py @@ -45,5 +45,6 @@ STREAMS_MAP = { _base.TagAccountDataStream, _base.AccountDataStream, _base.GroupServerStream, + _base.UserSignatureStream, ) } diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index f03111c259..9e45429d49 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -95,6 +95,7 @@ GroupsStreamRow = namedtuple( "GroupsStreamRow", ("group_id", "user_id", "type", "content"), # str # str # str # dict ) +UserSignatureStreamRow = namedtuple("UserSignatureStreamRow", ("user_id")) # str class Stream(object): @@ -438,3 +439,20 @@ class GroupServerStream(Stream): self.update_function = store.get_all_groups_changes super(GroupServerStream, self).__init__(hs) + + +class UserSignatureStream(Stream): + """A user has signed their own device with their user-signing key + """ + + NAME = "user_signature" + _LIMITED = False + ROW_TYPE = UserSignatureStreamRow + + def __init__(self, hs): + store = hs.get_datastore() + + self.current_token = store.get_device_stream_token + self.update_function = store.get_all_user_signature_changes_for_remotes + + super(UserSignatureStream, self).__init__(hs) |