diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py
index 1c7c6ec0c8..a37818fe9a 100644
--- a/synapse/app/admin_cmd.py
+++ b/synapse/app/admin_cmd.py
@@ -43,7 +43,6 @@ from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.replication.slave.storage.room import RoomStore
-from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.server import HomeServer
from synapse.util.logcontext import LoggingContext
from synapse.util.versionstring import get_version_string
@@ -79,17 +78,6 @@ class AdminCmdServer(HomeServer):
def start_listening(self, listeners):
pass
- def build_tcp_replication(self):
- return AdminCmdReplicationHandler(self)
-
-
-class AdminCmdReplicationHandler(ReplicationClientHandler):
- async def on_rdata(self, stream_name, token, rows):
- pass
-
- def get_streams_to_replicate(self):
- return {}
-
@defer.inlineCallbacks
def export_data_command(hs, args):
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 174bef360f..dcd0709a02 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -64,7 +64,7 @@ from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.replication.slave.storage.room import RoomStore
from synapse.replication.slave.storage.transactions import SlavedTransactionStore
-from synapse.replication.tcp.client import ReplicationClientHandler
+from synapse.replication.tcp.client import ReplicationDataHandler
from synapse.replication.tcp.commands import ClearUserSyncsCommand
from synapse.replication.tcp.streams import (
AccountDataStream,
@@ -603,7 +603,7 @@ class GenericWorkerServer(HomeServer):
def remove_pusher(self, app_id, push_key, user_id):
self.get_tcp_replication().send_remove_pusher(app_id, push_key, user_id)
- def build_tcp_replication(self):
+ def build_replication_data_handler(self):
return GenericWorkerReplicationHandler(self)
def build_presence_handler(self):
@@ -613,7 +613,7 @@ class GenericWorkerServer(HomeServer):
return GenericWorkerTyping(self)
-class GenericWorkerReplicationHandler(ReplicationClientHandler):
+class GenericWorkerReplicationHandler(ReplicationDataHandler):
def __init__(self, hs):
super(GenericWorkerReplicationHandler, self).__init__(hs.get_datastore())
@@ -644,9 +644,6 @@ class GenericWorkerReplicationHandler(ReplicationClientHandler):
args.update(self.send_handler.stream_positions())
return args
- def get_currently_syncing_users(self):
- return self.presence_handler.get_currently_syncing_users()
-
async def process_and_notify(self, stream_name, token, rows):
try:
if self.send_handler:
diff --git a/synapse/replication/tcp/__init__.py b/synapse/replication/tcp/__init__.py
index 81c2ea7ee9..523a1358d4 100644
--- a/synapse/replication/tcp/__init__.py
+++ b/synapse/replication/tcp/__init__.py
@@ -20,11 +20,31 @@ Further details can be found in docs/tcp_replication.rst
Structure of the module:
- * client.py - the client classes used for workers to connect to master
+ * handler.py - the classes used to handle sending/receiving commands to
+ replication
* command.py - the definitions of all the valid commands
- * protocol.py - contains bot the client and server protocol implementations,
- these should not be used directly
- * resource.py - the server classes that accepts and handle client connections
- * streams.py - the definitons of all the valid streams
+ * protocol.py - the TCP protocol classes
+ * resource.py - handles streaming stream updates to replications
+ * streams/ - the definitons of all the valid streams
+
+The general interaction of the classes are:
+
+ +---------------------+
+ | ReplicationStreamer |
+ +---------------------+
+ |
+ v
+ +---------------------------+ +----------------------+
+ | ReplicationCommandHandler |---->|ReplicationDataHandler|
+ +---------------------------+ +----------------------+
+ | ^
+ v |
+ +-------------+
+ | Protocols |
+ | (TCP/redis) |
+ +-------------+
+
+Where the ReplicationDataHandler (or subclasses) handles incoming stream
+updates.
"""
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index e86d9805f1..700ae79158 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -16,26 +16,16 @@
"""
import logging
-from typing import Dict, List, Optional
+from typing import TYPE_CHECKING, Dict
-from twisted.internet import defer
from twisted.internet.protocol import ReconnectingClientFactory
from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.replication.tcp.protocol import (
- AbstractReplicationClientHandler,
- ClientReplicationStreamProtocol,
-)
-
-from .commands import (
- Command,
- FederationAckCommand,
- InvalidateCacheCommand,
- RemoteServerUpCommand,
- RemovePusherCommand,
- UserIpCommand,
- UserSyncCommand,
-)
+from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+ from synapse.replication.tcp.handler import ReplicationCommandHandler
logger = logging.getLogger(__name__)
@@ -44,16 +34,20 @@ class ReplicationClientFactory(ReconnectingClientFactory):
"""Factory for building connections to the master. Will reconnect if the
connection is lost.
- Accepts a handler that will be called when new data is available or data
- is required.
+ Accepts a handler that is passed to `ClientReplicationStreamProtocol`.
"""
initialDelay = 0.1
maxDelay = 1 # Try at least once every N seconds
- def __init__(self, hs, client_name, handler: AbstractReplicationClientHandler):
+ def __init__(
+ self,
+ hs: "HomeServer",
+ client_name: str,
+ command_handler: "ReplicationCommandHandler",
+ ):
self.client_name = client_name
- self.handler = handler
+ self.command_handler = command_handler
self.server_name = hs.config.server_name
self.hs = hs
self._clock = hs.get_clock() # As self.clock is defined in super class
@@ -66,7 +60,11 @@ class ReplicationClientFactory(ReconnectingClientFactory):
def buildProtocol(self, addr):
logger.info("Connected to replication: %r", addr)
return ClientReplicationStreamProtocol(
- self.hs, self.client_name, self.server_name, self._clock, self.handler,
+ self.hs,
+ self.client_name,
+ self.server_name,
+ self._clock,
+ self.command_handler,
)
def clientConnectionLost(self, connector, reason):
@@ -78,41 +76,17 @@ class ReplicationClientFactory(ReconnectingClientFactory):
ReconnectingClientFactory.clientConnectionFailed(self, connector, reason)
-class ReplicationClientHandler(AbstractReplicationClientHandler):
- """A base handler that can be passed to the ReplicationClientFactory.
+class ReplicationDataHandler:
+ """Handles incoming stream updates from replication.
- By default proxies incoming replication data to the SlaveStore.
+ This instance notifies the slave data store about updates. Can be subclassed
+ to handle updates in additional ways.
"""
def __init__(self, store: BaseSlavedStore):
self.store = store
- # The current connection. None if we are currently (re)connecting
- self.connection = None
-
- # Any pending commands to be sent once a new connection has been
- # established
- self.pending_commands = [] # type: List[Command]
-
- # Map from string -> deferred, to wake up when receiveing a SYNC with
- # the given string.
- # Used for tests.
- self.awaiting_syncs = {} # type: Dict[str, defer.Deferred]
-
- # The factory used to create connections.
- self.factory = None # type: Optional[ReplicationClientFactory]
-
- def start_replication(self, hs):
- """Helper method to start a replication connection to the remote server
- using TCP.
- """
- client_name = hs.config.worker_name
- self.factory = ReplicationClientFactory(hs, client_name, self)
- host = hs.config.worker_replication_host
- port = hs.config.worker_replication_port
- hs.get_reactor().connectTCP(host, port, self.factory)
-
- async def on_rdata(self, stream_name, token, rows):
+ async def on_rdata(self, stream_name: str, token: int, rows: list):
"""Called to handle a batch of replication data with a given stream token.
By default this just pokes the slave store. Can be overridden in subclasses to
@@ -124,30 +98,8 @@ class ReplicationClientHandler(AbstractReplicationClientHandler):
rows (list): a list of Stream.ROW_TYPE objects as returned by
Stream.parse_row.
"""
- logger.debug("Received rdata %s -> %s", stream_name, token)
self.store.process_replication_rows(stream_name, token, rows)
- async def on_position(self, stream_name, token):
- """Called when we get new position data. By default this just pokes
- the slave store.
-
- Can be overriden in subclasses to handle more.
- """
- self.store.process_replication_rows(stream_name, token, [])
-
- def on_sync(self, data):
- """When we received a SYNC we wake up any deferreds that were waiting
- for the sync with the given data.
-
- Used by tests.
- """
- d = self.awaiting_syncs.pop(data, None)
- if d:
- d.callback(data)
-
- def on_remote_server_up(self, server: str):
- """Called when get a new REMOTE_SERVER_UP command."""
-
def get_streams_to_replicate(self) -> Dict[str, int]:
"""Called when a new connection has been established and we need to
subscribe to streams.
@@ -163,85 +115,10 @@ class ReplicationClientHandler(AbstractReplicationClientHandler):
args["account_data"] = user_account_data
elif room_account_data:
args["account_data"] = room_account_data
-
return args
- def get_currently_syncing_users(self):
- """Get the list of currently syncing users (if any). This is called
- when a connection has been established and we need to send the
- currently syncing users. (Overriden by the synchrotron's only)
- """
- return []
-
- def send_command(self, cmd):
- """Send a command to master (when we get establish a connection if we
- don't have one already.)
- """
- if self.connection:
- self.connection.send_command(cmd)
- else:
- logger.warning("Queuing command as not connected: %r", cmd.NAME)
- self.pending_commands.append(cmd)
-
- def send_federation_ack(self, token):
- """Ack data for the federation stream. This allows the master to drop
- data stored purely in memory.
- """
- self.send_command(FederationAckCommand(token))
-
- def send_user_sync(self, instance_id, user_id, is_syncing, last_sync_ms):
- """Poke the master that a user has started/stopped syncing.
- """
- self.send_command(
- UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms)
- )
-
- def send_remove_pusher(self, app_id, push_key, user_id):
- """Poke the master to remove a pusher for a user
- """
- cmd = RemovePusherCommand(app_id, push_key, user_id)
- self.send_command(cmd)
-
- def send_invalidate_cache(self, cache_func, keys):
- """Poke the master to invalidate a cache.
- """
- cmd = InvalidateCacheCommand(cache_func.__name__, keys)
- self.send_command(cmd)
-
- def send_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen):
- """Tell the master that the user made a request.
- """
- cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen)
- self.send_command(cmd)
-
- def send_remote_server_up(self, server: str):
- self.send_command(RemoteServerUpCommand(server))
-
- def await_sync(self, data):
- """Returns a deferred that is resolved when we receive a SYNC command
- with given data.
-
- [Not currently] used by tests.
- """
- return self.awaiting_syncs.setdefault(data, defer.Deferred())
-
- def update_connection(self, connection):
- """Called when a connection has been established (or lost with None).
- """
- self.connection = connection
- if connection:
- for cmd in self.pending_commands:
- connection.send_command(cmd)
- self.pending_commands = []
-
- def finished_connecting(self):
- """Called when we have successfully subscribed and caught up to all
- streams we're interested in.
- """
- logger.info("Finished connecting to server")
+ async def on_position(self, stream_name: str, token: int):
+ self.store.process_replication_rows(stream_name, token, [])
- # We don't reset the delay any earlier as otherwise if there is a
- # problem during start up we'll end up tight looping connecting to the
- # server.
- if self.factory:
- self.factory.resetDelay()
+ def on_remote_server_up(self, server: str):
+ """Called when get a new REMOTE_SERVER_UP command."""
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
new file mode 100644
index 0000000000..12a1cfd6d1
--- /dev/null
+++ b/synapse/replication/tcp/handler.py
@@ -0,0 +1,252 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import Any, Callable, Dict, List, Optional, Set
+
+from prometheus_client import Counter
+
+from synapse.replication.tcp.client import ReplicationClientFactory
+from synapse.replication.tcp.commands import (
+ Command,
+ FederationAckCommand,
+ InvalidateCacheCommand,
+ PositionCommand,
+ RdataCommand,
+ RemoteServerUpCommand,
+ RemovePusherCommand,
+ SyncCommand,
+ UserIpCommand,
+ UserSyncCommand,
+)
+from synapse.replication.tcp.streams import STREAMS_MAP, Stream
+from synapse.util.async_helpers import Linearizer
+
+logger = logging.getLogger(__name__)
+
+
+# number of updates received for each RDATA stream
+inbound_rdata_count = Counter(
+ "synapse_replication_tcp_protocol_inbound_rdata_count", "", ["stream_name"]
+)
+
+
+class ReplicationCommandHandler:
+ """Handles incoming commands from replication as well as sending commands
+ back out to connections.
+ """
+
+ def __init__(self, hs):
+ self._replication_data_handler = hs.get_replication_data_handler()
+ self._presence_handler = hs.get_presence_handler()
+
+ # Set of streams that we've caught up with.
+ self._streams_connected = set() # type: Set[str]
+
+ self._streams = {
+ stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
+ } # type: Dict[str, Stream]
+
+ self._position_linearizer = Linearizer("replication_position")
+
+ # Map of stream to batched updates. See RdataCommand for info on how
+ # batching works.
+ self._pending_batches = {} # type: Dict[str, List[Any]]
+
+ # The factory used to create connections.
+ self._factory = None # type: Optional[ReplicationClientFactory]
+
+ # The current connection. None if we are currently (re)connecting
+ self._connection = None
+
+ def start_replication(self, hs):
+ """Helper method to start a replication connection to the remote server
+ using TCP.
+ """
+ client_name = hs.config.worker_name
+ self._factory = ReplicationClientFactory(hs, client_name, self)
+ host = hs.config.worker_replication_host
+ port = hs.config.worker_replication_port
+ hs.get_reactor().connectTCP(host, port, self._factory)
+
+ async def on_RDATA(self, cmd: RdataCommand):
+ stream_name = cmd.stream_name
+ inbound_rdata_count.labels(stream_name).inc()
+
+ try:
+ row = STREAMS_MAP[stream_name].parse_row(cmd.row)
+ except Exception:
+ logger.exception("Failed to parse RDATA: %r %r", stream_name, cmd.row)
+ raise
+
+ if cmd.token is None or stream_name not in self._streams_connected:
+ # I.e. either this is part of a batch of updates for this stream (in
+ # which case batch until we get an update for the stream with a non
+ # None token) or we're currently connecting so we queue up rows.
+ self._pending_batches.setdefault(stream_name, []).append(row)
+ else:
+ # Check if this is the last of a batch of updates
+ rows = self._pending_batches.pop(stream_name, [])
+ rows.append(row)
+ await self.on_rdata(stream_name, cmd.token, rows)
+
+ async def on_rdata(self, stream_name: str, token: int, rows: list):
+ """Called to handle a batch of replication data with a given stream token.
+
+ Args:
+ stream_name: name of the replication stream for this batch of rows
+ token: stream token for this batch of rows
+ rows: a list of Stream.ROW_TYPE objects as returned by
+ Stream.parse_row.
+ """
+ logger.debug("Received rdata %s -> %s", stream_name, token)
+ await self._replication_data_handler.on_rdata(stream_name, token, rows)
+
+ async def on_POSITION(self, cmd: PositionCommand):
+ stream = self._streams.get(cmd.stream_name)
+ if not stream:
+ logger.error("Got POSITION for unknown stream: %s", cmd.stream_name)
+ return
+
+ # We protect catching up with a linearizer in case the replication
+ # connection reconnects under us.
+ with await self._position_linearizer.queue(cmd.stream_name):
+ # We're about to go and catch up with the stream, so mark as connecting
+ # to stop RDATA being handled at the same time by removing stream from
+ # list of connected streams. We also clear any batched up RDATA from
+ # before we got the POSITION.
+ self._streams_connected.discard(cmd.stream_name)
+ self._pending_batches.clear()
+
+ # Find where we previously streamed up to.
+ current_token = self._replication_data_handler.get_streams_to_replicate().get(
+ cmd.stream_name
+ )
+ if current_token is None:
+ logger.warning(
+ "Got POSITION for stream we're not subscribed to: %s",
+ cmd.stream_name,
+ )
+ return
+
+ # Fetch all updates between then and now.
+ limited = True
+ while limited:
+ updates, current_token, limited = await stream.get_updates_since(
+ current_token, cmd.token
+ )
+ if updates:
+ await self.on_rdata(
+ cmd.stream_name,
+ current_token,
+ [stream.parse_row(update[1]) for update in updates],
+ )
+
+ # We've now caught up to position sent to us, notify handler.
+ await self._replication_data_handler.on_position(cmd.stream_name, cmd.token)
+
+ # Handle any RDATA that came in while we were catching up.
+ rows = self._pending_batches.pop(cmd.stream_name, [])
+ if rows:
+ await self._replication_data_handler.on_rdata(
+ cmd.stream_name, rows[-1].token, rows
+ )
+
+ self._streams_connected.add(cmd.stream_name)
+
+ async def on_SYNC(self, cmd: SyncCommand):
+ pass
+
+ async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
+ """"Called when get a new REMOTE_SERVER_UP command."""
+ self._replication_data_handler.on_remote_server_up(cmd.data)
+
+ def get_currently_syncing_users(self):
+ """Get the list of currently syncing users (if any). This is called
+ when a connection has been established and we need to send the
+ currently syncing users.
+ """
+ return self._presence_handler.get_currently_syncing_users()
+
+ def update_connection(self, connection):
+ """Called when a connection has been established (or lost with None).
+ """
+ self._connection = connection
+
+ def finished_connecting(self):
+ """Called when we have successfully subscribed and caught up to all
+ streams we're interested in.
+ """
+ logger.info("Finished connecting to server")
+
+ # We don't reset the delay any earlier as otherwise if there is a
+ # problem during start up we'll end up tight looping connecting to the
+ # server.
+ if self._factory:
+ self._factory.resetDelay()
+
+ def send_command(self, cmd: Command):
+ """Send a command to master (when we get establish a connection if we
+ don't have one already.)
+ """
+ if self._connection:
+ self._connection.send_command(cmd)
+ else:
+ logger.warning("Dropping command as not connected: %r", cmd.NAME)
+
+ def send_federation_ack(self, token: int):
+ """Ack data for the federation stream. This allows the master to drop
+ data stored purely in memory.
+ """
+ self.send_command(FederationAckCommand(token))
+
+ def send_user_sync(
+ self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int
+ ):
+ """Poke the master that a user has started/stopped syncing.
+ """
+ self.send_command(
+ UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms)
+ )
+
+ def send_remove_pusher(self, app_id: str, push_key: str, user_id: str):
+ """Poke the master to remove a pusher for a user
+ """
+ cmd = RemovePusherCommand(app_id, push_key, user_id)
+ self.send_command(cmd)
+
+ def send_invalidate_cache(self, cache_func: Callable, keys: tuple):
+ """Poke the master to invalidate a cache.
+ """
+ cmd = InvalidateCacheCommand(cache_func.__name__, keys)
+ self.send_command(cmd)
+
+ def send_user_ip(
+ self,
+ user_id: str,
+ access_token: str,
+ ip: str,
+ user_agent: str,
+ device_id: str,
+ last_seen: int,
+ ):
+ """Tell the master that the user made a request.
+ """
+ cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen)
+ self.send_command(cmd)
+
+ def send_remote_server_up(self, server: str):
+ self.send_command(RemoteServerUpCommand(server))
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index dae246825f..f2a37f568e 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -46,12 +46,11 @@ indicate which side is sending, these are *not* included on the wire::
> ERROR server stopping
* connection closed by server *
"""
-import abc
import fcntl
import logging
import struct
from collections import defaultdict
-from typing import Any, DefaultDict, Dict, List, Set
+from typing import TYPE_CHECKING, DefaultDict, List
from six import iteritems
@@ -78,13 +77,12 @@ from synapse.replication.tcp.commands import (
SyncCommand,
UserSyncCommand,
)
-from synapse.replication.tcp.streams import STREAMS_MAP, Stream
from synapse.types import Collection
from synapse.util import Clock
from synapse.util.stringutils import random_string
-MYPY = False
-if MYPY:
+if TYPE_CHECKING:
+ from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.server import HomeServer
@@ -475,71 +473,6 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.streamer.lost_connection(self)
-class AbstractReplicationClientHandler(metaclass=abc.ABCMeta):
- """
- The interface for the handler that should be passed to
- ClientReplicationStreamProtocol
- """
-
- @abc.abstractmethod
- async def on_rdata(self, stream_name, token, rows):
- """Called to handle a batch of replication data with a given stream token.
-
- Args:
- stream_name (str): name of the replication stream for this batch of rows
- token (int): stream token for this batch of rows
- rows (list): a list of Stream.ROW_TYPE objects as returned by
- Stream.parse_row.
- """
- raise NotImplementedError()
-
- @abc.abstractmethod
- async def on_position(self, stream_name, token):
- """Called when we get new position data."""
- raise NotImplementedError()
-
- @abc.abstractmethod
- def on_sync(self, data):
- """Called when get a new SYNC command."""
- raise NotImplementedError()
-
- @abc.abstractmethod
- async def on_remote_server_up(self, server: str):
- """Called when get a new REMOTE_SERVER_UP command."""
- raise NotImplementedError()
-
- @abc.abstractmethod
- def get_streams_to_replicate(self):
- """Called when a new connection has been established and we need to
- subscribe to streams.
-
- Returns:
- map from stream name to the most recent update we have for
- that stream (ie, the point we want to start replicating from)
- """
- raise NotImplementedError()
-
- @abc.abstractmethod
- def get_currently_syncing_users(self):
- """Get the list of currently syncing users (if any). This is called
- when a connection has been established and we need to send the
- currently syncing users."""
- raise NotImplementedError()
-
- @abc.abstractmethod
- def update_connection(self, connection):
- """Called when a connection has been established (or lost with None).
- """
- raise NotImplementedError()
-
- @abc.abstractmethod
- def finished_connecting(self):
- """Called when we have successfully subscribed and caught up to all
- streams we're interested in.
- """
- raise NotImplementedError()
-
-
class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
VALID_INBOUND_COMMANDS = VALID_SERVER_COMMANDS
VALID_OUTBOUND_COMMANDS = VALID_CLIENT_COMMANDS
@@ -550,7 +483,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
client_name: str,
server_name: str,
clock: Clock,
- handler: AbstractReplicationClientHandler,
+ command_handler: "ReplicationCommandHandler",
):
BaseReplicationStreamProtocol.__init__(self, clock)
@@ -558,20 +491,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.client_name = client_name
self.server_name = server_name
- self.handler = handler
-
- self.streams = {
- stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
- } # type: Dict[str, Stream]
-
- # Set of stream names that have been subscribe to, but haven't yet
- # caught up with. This is used to track when the client has been fully
- # connected to the remote.
- self.streams_connecting = set(STREAMS_MAP) # type: Set[str]
-
- # Map of stream to batched updates. See RdataCommand for info on how
- # batching works.
- self.pending_batches = {} # type: Dict[str, List[Any]]
+ self.handler = command_handler
def connectionMade(self):
self.send_command(NameCommand(self.client_name))
@@ -589,89 +509,39 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
# We've now finished connecting to so inform the client handler
self.handler.update_connection(self)
+ self.handler.finished_connecting()
- async def on_SERVER(self, cmd):
- if cmd.data != self.server_name:
- logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
- self.send_error("Wrong remote")
-
- async def on_RDATA(self, cmd):
- stream_name = cmd.stream_name
- inbound_rdata_count.labels(stream_name).inc()
-
- try:
- row = STREAMS_MAP[stream_name].parse_row(cmd.row)
- except Exception:
- logger.exception(
- "[%s] Failed to parse RDATA: %r %r", self.id(), stream_name, cmd.row
- )
- raise
-
- if cmd.token is None or stream_name in self.streams_connecting:
- # I.e. this is part of a batch of updates for this stream. Batch
- # until we get an update for the stream with a non None token
- self.pending_batches.setdefault(stream_name, []).append(row)
- else:
- # Check if this is the last of a batch of updates
- rows = self.pending_batches.pop(stream_name, [])
- rows.append(row)
- await self.handler.on_rdata(stream_name, cmd.token, rows)
-
- async def on_POSITION(self, cmd: PositionCommand):
- stream = self.streams.get(cmd.stream_name)
- if not stream:
- logger.error("Got POSITION for unknown stream: %s", cmd.stream_name)
- return
-
- # Find where we previously streamed up to.
- current_token = self.handler.get_streams_to_replicate().get(cmd.stream_name)
- if current_token is None:
- logger.warning(
- "Got POSITION for stream we're not subscribed to: %s", cmd.stream_name
- )
- return
-
- # Fetch all updates between then and now.
- limited = True
- while limited:
- updates, current_token, limited = await stream.get_updates_since(
- current_token, cmd.token
- )
-
- # Check if the connection was closed underneath us, if so we bail
- # rather than risk having concurrent catch ups going on.
- if self.state == ConnectionStates.CLOSED:
- return
-
- if updates:
- await self.handler.on_rdata(
- cmd.stream_name,
- current_token,
- [stream.parse_row(update[1]) for update in updates],
- )
+ async def handle_command(self, cmd: Command):
+ """Handle a command we have received over the replication stream.
- # We've now caught up to position sent to us, notify handler.
- await self.handler.on_position(cmd.stream_name, cmd.token)
+ Delegates to `command_handler.on_<COMMAND>`, which must return an
+ awaitable.
- self.streams_connecting.discard(cmd.stream_name)
- if not self.streams_connecting:
- self.handler.finished_connecting()
+ Args:
+ cmd: received command
+ """
+ handled = False
- # Check if the connection was closed underneath us, if so we bail
- # rather than risk having concurrent catch ups going on.
- if self.state == ConnectionStates.CLOSED:
- return
+ # First call any command handlers on this instance. These are for TCP
+ # specific handling.
+ cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None)
+ if cmd_func:
+ await cmd_func(cmd)
+ handled = True
- # Handle any RDATA that came in while we were catching up.
- rows = self.pending_batches.pop(cmd.stream_name, [])
- if rows:
- await self.handler.on_rdata(cmd.stream_name, rows[-1].token, rows)
+ # Then call out to the handler.
+ cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None)
+ if cmd_func:
+ await cmd_func(cmd)
+ handled = True
- async def on_SYNC(self, cmd):
- self.handler.on_sync(cmd.data)
+ if not handled:
+ logger.warning("Unhandled command: %r", cmd)
- async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
- self.handler.on_remote_server_up(cmd.data)
+ async def on_SERVER(self, cmd):
+ if cmd.data != self.server_name:
+ logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
+ self.send_error("Wrong remote")
def replicate(self):
"""Send the subscription request to the server
@@ -768,8 +638,3 @@ tcp_outbound_commands = LaterGauge(
for k, count in iteritems(p.outbound_commands_counter)
},
)
-
-# number of updates received for each RDATA stream
-inbound_rdata_count = Counter(
- "synapse_replication_tcp_protocol_inbound_rdata_count", "", ["stream_name"]
-)
diff --git a/synapse/server.py b/synapse/server.py
index 9228e1c892..9d273c980c 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -87,6 +87,8 @@ from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.notifier import Notifier
from synapse.push.action_generator import ActionGenerator
from synapse.push.pusherpool import PusherPool
+from synapse.replication.tcp.client import ReplicationDataHandler
+from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.resource import ReplicationStreamer
from synapse.rest.media.v1.media_repository import (
MediaRepository,
@@ -206,6 +208,7 @@ class HomeServer(object):
"password_policy_handler",
"storage",
"replication_streamer",
+ "replication_data_handler",
]
REQUIRED_ON_MASTER_STARTUP = ["user_directory_handler", "stats_handler"]
@@ -468,7 +471,7 @@ class HomeServer(object):
return ReadMarkerHandler(self)
def build_tcp_replication(self):
- raise NotImplementedError()
+ return ReplicationCommandHandler(self)
def build_action_generator(self):
return ActionGenerator(self)
@@ -562,6 +565,9 @@ class HomeServer(object):
def build_replication_streamer(self) -> ReplicationStreamer:
return ReplicationStreamer(self)
+ def build_replication_data_handler(self):
+ return ReplicationDataHandler(self.get_datastore())
+
def remove_pusher(self, app_id, push_key, user_id):
return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
diff --git a/synapse/server.pyi b/synapse/server.pyi
index 9d1dfa71e7..9013e9bac9 100644
--- a/synapse/server.pyi
+++ b/synapse/server.pyi
@@ -19,6 +19,7 @@ import synapse.handlers.set_password
import synapse.http.client
import synapse.notifier
import synapse.replication.tcp.client
+import synapse.replication.tcp.handler
import synapse.rest.media.v1.media_repository
import synapse.server_notices.server_notices_manager
import synapse.server_notices.server_notices_sender
@@ -106,7 +107,11 @@ class HomeServer(object):
pass
def get_tcp_replication(
self,
- ) -> synapse.replication.tcp.client.ReplicationClientHandler:
+ ) -> synapse.replication.tcp.handler.ReplicationCommandHandler:
+ pass
+ def get_replication_data_handler(
+ self,
+ ) -> synapse.replication.tcp.client.ReplicationDataHandler:
pass
def get_federation_registry(
self,
|