diff --git a/synapse/replication/tcp/__init__.py b/synapse/replication/tcp/__init__.py
index 81c2ea7ee9..523a1358d4 100644
--- a/synapse/replication/tcp/__init__.py
+++ b/synapse/replication/tcp/__init__.py
@@ -20,11 +20,31 @@ Further details can be found in docs/tcp_replication.rst
Structure of the module:
- * client.py - the client classes used for workers to connect to master
+ * handler.py - the classes used to handle sending/receiving commands to
+ replication
* command.py - the definitions of all the valid commands
- * protocol.py - contains bot the client and server protocol implementations,
- these should not be used directly
- * resource.py - the server classes that accepts and handle client connections
- * streams.py - the definitons of all the valid streams
+ * protocol.py - the TCP protocol classes
+ * resource.py - handles streaming stream updates out to replication connections
+ * streams/ - the definitions of all the valid streams
+
+The general interaction of the classes is:
+
+ +---------------------+
+ | ReplicationStreamer |
+ +---------------------+
+ |
+ v
+ +---------------------------+ +----------------------+
+ | ReplicationCommandHandler |---->|ReplicationDataHandler|
+ +---------------------------+ +----------------------+
+ | ^
+ v |
+ +-------------+
+ | Protocols |
+ | (TCP/redis) |
+ +-------------+
+
+Where the ReplicationDataHandler (or subclasses) handles incoming stream
+updates.
"""
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 02ab5b66ea..508ad1b720 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -14,47 +14,55 @@
# limitations under the License.
"""A replication client for use by synapse workers.
"""
-
+import heapq
import logging
-from typing import Dict, List, Optional
+from typing import TYPE_CHECKING, Dict, List, Tuple
-from twisted.internet import defer
+from twisted.internet.defer import Deferred
from twisted.internet.protocol import ReconnectingClientFactory
-from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.replication.tcp.protocol import (
- AbstractReplicationClientHandler,
- ClientReplicationStreamProtocol,
+from synapse.api.constants import EventTypes
+from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
+from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
+from synapse.replication.tcp.streams.events import (
+ EventsStream,
+ EventsStreamEventRow,
+ EventsStreamRow,
)
+from synapse.util.async_helpers import timeout_deferred
+from synapse.util.metrics import Measure
-from .commands import (
- Command,
- FederationAckCommand,
- InvalidateCacheCommand,
- RemoteServerUpCommand,
- RemovePusherCommand,
- UserIpCommand,
- UserSyncCommand,
-)
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+ from synapse.replication.tcp.handler import ReplicationCommandHandler
logger = logging.getLogger(__name__)
-class ReplicationClientFactory(ReconnectingClientFactory):
+# How long we allow callers to wait for replication updates before timing out.
+_WAIT_FOR_REPLICATION_TIMEOUT_SECONDS = 30
+
+
+class DirectTcpReplicationClientFactory(ReconnectingClientFactory):
"""Factory for building connections to the master. Will reconnect if the
connection is lost.
- Accepts a handler that will be called when new data is available or data
- is required.
+ Accepts a handler that is passed to `ClientReplicationStreamProtocol`.
"""
initialDelay = 0.1
maxDelay = 1 # Try at least once every N seconds
- def __init__(self, hs, client_name, handler: AbstractReplicationClientHandler):
+ def __init__(
+ self,
+ hs: "HomeServer",
+ client_name: str,
+ command_handler: "ReplicationCommandHandler",
+ ):
self.client_name = client_name
- self.handler = handler
+ self.command_handler = command_handler
self.server_name = hs.config.server_name
+ self.hs = hs
self._clock = hs.get_clock() # As self.clock is defined in super class
hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.stopTrying)
@@ -65,7 +73,11 @@ class ReplicationClientFactory(ReconnectingClientFactory):
def buildProtocol(self, addr):
logger.info("Connected to replication: %r", addr)
return ClientReplicationStreamProtocol(
- self.client_name, self.server_name, self._clock, self.handler
+ self.hs,
+ self.client_name,
+ self.server_name,
+ self._clock,
+ self.command_handler,
)
def clientConnectionLost(self, connector, reason):
@@ -77,168 +89,136 @@ class ReplicationClientFactory(ReconnectingClientFactory):
ReconnectingClientFactory.clientConnectionFailed(self, connector, reason)
-class ReplicationClientHandler(AbstractReplicationClientHandler):
- """A base handler that can be passed to the ReplicationClientFactory.
+class ReplicationDataHandler:
+ """Handles incoming stream updates from replication.
- By default proxies incoming replication data to the SlaveStore.
+ This instance notifies the slave data store about updates. Can be subclassed
+ to handle updates in additional ways.
"""
- def __init__(self, store: BaseSlavedStore):
- self.store = store
-
- # The current connection. None if we are currently (re)connecting
- self.connection = None
-
- # Any pending commands to be sent once a new connection has been
- # established
- self.pending_commands = [] # type: List[Command]
-
- # Map from string -> deferred, to wake up when receiveing a SYNC with
- # the given string.
- # Used for tests.
- self.awaiting_syncs = {} # type: Dict[str, defer.Deferred]
-
- # The factory used to create connections.
- self.factory = None # type: Optional[ReplicationClientFactory]
-
- def start_replication(self, hs):
- """Helper method to start a replication connection to the remote server
- using TCP.
- """
- client_name = hs.config.worker_name
- self.factory = ReplicationClientFactory(hs, client_name, self)
- host = hs.config.worker_replication_host
- port = hs.config.worker_replication_port
- hs.get_reactor().connectTCP(host, port, self.factory)
-
- async def on_rdata(self, stream_name, token, rows):
+ def __init__(self, hs: "HomeServer"):
+ self.store = hs.get_datastore()
+ self.pusher_pool = hs.get_pusherpool()
+ self.notifier = hs.get_notifier()
+ self._reactor = hs.get_reactor()
+ self._clock = hs.get_clock()
+ self._streams = hs.get_replication_streams()
+ self._instance_name = hs.get_instance_name()
+
+ # Map from stream to list of deferreds waiting for the stream to
+ # arrive at a particular position. The lists are sorted by stream position.
+ self._streams_to_waiters = (
+ {}
+ ) # type: Dict[str, List[Tuple[int, Deferred[None]]]]
+
+ async def on_rdata(
+ self, stream_name: str, instance_name: str, token: int, rows: list
+ ):
"""Called to handle a batch of replication data with a given stream token.
By default this just pokes the slave store. Can be overridden in subclasses to
handle more.
Args:
- stream_name (str): name of the replication stream for this batch of rows
- token (int): stream token for this batch of rows
- rows (list): a list of Stream.ROW_TYPE objects as returned by
- Stream.parse_row.
- """
- logger.debug("Received rdata %s -> %s", stream_name, token)
- self.store.process_replication_rows(stream_name, token, rows)
-
- async def on_position(self, stream_name, token):
- """Called when we get new position data. By default this just pokes
- the slave store.
-
- Can be overriden in subclasses to handle more.
+ stream_name: name of the replication stream for this batch of rows
+ instance_name: the instance that wrote the rows.
+ token: stream token for this batch of rows
+ rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row.
"""
- self.store.process_replication_rows(stream_name, token, [])
-
- def on_sync(self, data):
- """When we received a SYNC we wake up any deferreds that were waiting
- for the sync with the given data.
-
- Used by tests.
- """
- d = self.awaiting_syncs.pop(data, None)
- if d:
- d.callback(data)
+ self.store.process_replication_rows(stream_name, instance_name, token, rows)
+
+ if stream_name == EventsStream.NAME:
+ # We shouldn't get multiple rows per token for events stream, so
+ # we don't need to optimise this for multiple rows.
+ for row in rows:
+ if row.type != EventsStreamEventRow.TypeId:
+ continue
+ assert isinstance(row, EventsStreamRow)
+
+ event = await self.store.get_event(
+ row.data.event_id, allow_rejected=True
+ )
+ if event.rejected_reason:
+ continue
+
+ extra_users = () # type: Tuple[str, ...]
+ if event.type == EventTypes.Member:
+ extra_users = (event.state_key,)
+ max_token = self.store.get_room_max_stream_ordering()
+ self.notifier.on_new_room_event(event, token, max_token, extra_users)
+
+ await self.pusher_pool.on_new_notifications(token, token)
+
+ # Notify any waiting deferreds. The list is ordered by position so we
+ # just iterate through the list until we reach a position that is
+ # greater than the received row position.
+ waiting_list = self._streams_to_waiters.get(stream_name, [])
+
+ # Index of the first item with a position after the current token, i.e. we
+ # have called all deferreds before this index. If not overwritten by the
+ # loop below, it means either a) there are no items in the list, so this is
+ # a no-op, or b) all items in the list were called and the list should be
+ # cleared. Setting it to `len(list)` works for both cases.
+ index_of_first_deferred_not_called = len(waiting_list)
+
+ for idx, (position, deferred) in enumerate(waiting_list):
+ if position <= token:
+ try:
+ with PreserveLoggingContext():
+ deferred.callback(None)
+ except Exception:
+ # The deferred has been cancelled or timed out.
+ pass
+ else:
+ # The list is sorted by position so we don't need to continue
+ # checking any further entries in the list.
+ index_of_first_deferred_not_called = idx
+ break
+
+ # Drop all entries in the waiting list that were called in the above
+ # loop. (This maintains the order so no need to resort)
+ waiting_list[:] = waiting_list[index_of_first_deferred_not_called:]
+
+ async def on_position(self, stream_name: str, instance_name: str, token: int):
+ self.store.process_replication_rows(stream_name, instance_name, token, [])
def on_remote_server_up(self, server: str):
"""Called when get a new REMOTE_SERVER_UP command."""
- def get_streams_to_replicate(self) -> Dict[str, int]:
- """Called when a new connection has been established and we need to
- subscribe to streams.
-
- Returns:
- map from stream name to the most recent update we have for
- that stream (ie, the point we want to start replicating from)
+ async def wait_for_stream_position(
+ self, instance_name: str, stream_name: str, position: int
+ ):
+ """Wait until this instance has received updates up to and including
+ the given stream position.
"""
- args = self.store.stream_positions()
- user_account_data = args.pop("user_account_data", None)
- room_account_data = args.pop("room_account_data", None)
- if user_account_data:
- args["account_data"] = user_account_data
- elif room_account_data:
- args["account_data"] = room_account_data
-
- return args
-
- def get_currently_syncing_users(self):
- """Get the list of currently syncing users (if any). This is called
- when a connection has been established and we need to send the
- currently syncing users. (Overriden by the synchrotron's only)
- """
- return []
- def send_command(self, cmd):
- """Send a command to master (when we get establish a connection if we
- don't have one already.)
- """
- if self.connection:
- self.connection.send_command(cmd)
- else:
- logger.warning("Queuing command as not connected: %r", cmd.NAME)
- self.pending_commands.append(cmd)
-
- def send_federation_ack(self, token):
- """Ack data for the federation stream. This allows the master to drop
- data stored purely in memory.
- """
- self.send_command(FederationAckCommand(token))
-
- def send_user_sync(self, user_id, is_syncing, last_sync_ms):
- """Poke the master that a user has started/stopped syncing.
- """
- self.send_command(UserSyncCommand(user_id, is_syncing, last_sync_ms))
-
- def send_remove_pusher(self, app_id, push_key, user_id):
- """Poke the master to remove a pusher for a user
- """
- cmd = RemovePusherCommand(app_id, push_key, user_id)
- self.send_command(cmd)
-
- def send_invalidate_cache(self, cache_func, keys):
- """Poke the master to invalidate a cache.
- """
- cmd = InvalidateCacheCommand(cache_func.__name__, keys)
- self.send_command(cmd)
-
- def send_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen):
- """Tell the master that the user made a request.
- """
- cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen)
- self.send_command(cmd)
-
- def send_remote_server_up(self, server: str):
- self.send_command(RemoteServerUpCommand(server))
-
- def await_sync(self, data):
- """Returns a deferred that is resolved when we receive a SYNC command
- with given data.
+ if instance_name == self._instance_name:
+ # We don't get told about updates written by this process, and
+ # anyway in that case we don't need to wait.
+ return
+
+ current_position = self._streams[stream_name].current_token(self._instance_name)
+ if position <= current_position:
+ # We're already past the position
+ return
+
+ # Create a new deferred that times out after N seconds, as we don't want
+ # to wedge here forever.
+ deferred = Deferred()
+ deferred = timeout_deferred(
+ deferred, _WAIT_FOR_REPLICATION_TIMEOUT_SECONDS, self._reactor
+ )
- [Not currently] used by tests.
- """
- return self.awaiting_syncs.setdefault(data, defer.Deferred())
+ waiting_list = self._streams_to_waiters.setdefault(stream_name, [])
- def update_connection(self, connection):
- """Called when a connection has been established (or lost with None).
- """
- self.connection = connection
- if connection:
- for cmd in self.pending_commands:
- connection.send_command(cmd)
- self.pending_commands = []
-
- def finished_connecting(self):
- """Called when we have successfully subscribed and caught up to all
- streams we're interested in.
- """
- logger.info("Finished connecting to server")
+ # We insert into the list using heapq as it is more efficient than
+ # pushing then resorting each time.
+ heapq.heappush(waiting_list, (position, deferred))
- # We don't reset the delay any earlier as otherwise if there is a
- # problem during start up we'll end up tight looping connecting to the
- # server.
- if self.factory:
- self.factory.resetDelay()
+ # We measure here to get in flight counts and average waiting time.
+ with Measure(self._clock, "repl.wait_for_stream_position"):
+ logger.info("Waiting for repl stream %r to reach %s", stream_name, position)
+ await make_deferred_yieldable(deferred)
+ logger.info(
+ "Finished waiting for repl stream %r to reach %s", stream_name, position
+ )
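
To make the waiting-list bookkeeping above easier to follow, here is a minimal standalone sketch of the same pattern (not Synapse code: it uses plain callbacks rather than deferreds, and adds a tie-breaker so that entries with equal positions never compare the callbacks themselves):

import heapq
import itertools
from typing import Callable, Dict, List, Tuple


class StreamWaiters:
    """Per-stream callbacks, kept ordered by the position they are waiting for."""

    def __init__(self) -> None:
        # stream name -> heap of (position, tie-breaker, callback)
        self._waiters = {}  # type: Dict[str, List[Tuple[int, int, Callable[[], None]]]]
        self._counter = itertools.count()

    def wait_for(self, stream_name: str, position: int, callback: Callable[[], None]) -> None:
        # heappush keeps the list ordered by position, which is cheaper than
        # appending and re-sorting on every insert.
        heapq.heappush(
            self._waiters.setdefault(stream_name, []),
            (position, next(self._counter), callback),
        )

    def on_new_token(self, stream_name: str, token: int) -> None:
        # Wake everything waiting for a position at or before the new token.
        waiting_list = self._waiters.get(stream_name, [])
        while waiting_list and waiting_list[0][0] <= token:
            _, _, callback = heapq.heappop(waiting_list)
            callback()
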
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index 451671412d..c04f622816 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -17,7 +17,7 @@
The VALID_SERVER_COMMANDS and VALID_CLIENT_COMMANDS define which commands are
allowed to be sent by which side.
"""
-
+import abc
import logging
import platform
from typing import Tuple, Type
@@ -34,34 +34,29 @@ else:
logger = logging.getLogger(__name__)
-class Command(object):
+class Command(metaclass=abc.ABCMeta):
"""The base command class.
All subclasses must set the NAME variable which equates to the name of the
command on the wire.
A full command line on the wire is constructed from `NAME + " " + to_line()`
-
- The default implementation creates a command of form `<NAME> <data>`
"""
NAME = None # type: str
- def __init__(self, data):
- self.data = data
-
@classmethod
+ @abc.abstractmethod
def from_line(cls, line):
"""Deserialises a line from the wire into this command. `line` does not
include the command.
"""
- return cls(line)
- def to_line(self):
+ @abc.abstractmethod
+ def to_line(self) -> str:
"""Serialises the comamnd for the wire. Does not include the command
prefix.
"""
- return self.data
def get_logcontext_id(self):
"""Get a suitable string for the logcontext when processing this command"""
@@ -70,7 +65,21 @@ class Command(object):
return self.NAME
-class ServerCommand(Command):
+class _SimpleCommand(Command):
+ """An implementation of Command whose argument is just a 'data' string."""
+
+ def __init__(self, data):
+ self.data = data
+
+ @classmethod
+ def from_line(cls, line):
+ return cls(line)
+
+ def to_line(self) -> str:
+ return self.data
+
+
+class ServerCommand(_SimpleCommand):
"""Sent by the server on new connection and includes the server_name.
Format::
@@ -86,7 +95,7 @@ class RdataCommand(Command):
Format::
- RDATA <stream_name> <token> <row_json>
+ RDATA <stream_name> <instance_name> <token> <row_json>
The `<token>` may either be a numeric stream id OR "batch". The latter case
is used to support sending multiple updates with the same stream ID. This
@@ -96,33 +105,40 @@ class RdataCommand(Command):
The client should batch all incoming RDATA with a token of "batch" (per
stream_name) until it sees an RDATA with a numeric stream ID.
+ The `<instance_name>` is the source of the new data (usually "master").
+
`<token>` of "batch" maps to the instance variable `token` being None.
An example of a batched series of RDATA::
- RDATA presence batch ["@foo:example.com", "online", ...]
- RDATA presence batch ["@bar:example.com", "online", ...]
- RDATA presence 59 ["@baz:example.com", "online", ...]
+ RDATA presence master batch ["@foo:example.com", "online", ...]
+ RDATA presence master batch ["@bar:example.com", "online", ...]
+ RDATA presence master 59 ["@baz:example.com", "online", ...]
"""
NAME = "RDATA"
- def __init__(self, stream_name, token, row):
+ def __init__(self, stream_name, instance_name, token, row):
self.stream_name = stream_name
+ self.instance_name = instance_name
self.token = token
self.row = row
@classmethod
def from_line(cls, line):
- stream_name, token, row_json = line.split(" ", 2)
+ stream_name, instance_name, token, row_json = line.split(" ", 3)
return cls(
- stream_name, None if token == "batch" else int(token), json.loads(row_json)
+ stream_name,
+ instance_name,
+ None if token == "batch" else int(token),
+ json.loads(row_json),
)
def to_line(self):
return " ".join(
(
self.stream_name,
+ self.instance_name,
str(self.token) if self.token is not None else "batch",
_json_encoder.encode(self.row),
)
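
For concreteness, a round trip through the new RDATA wire format (the row contents here are invented purely for illustration, and the exact JSON spacing produced by `to_line()` may differ):

from synapse.replication.tcp.commands import RdataCommand

cmd = RdataCommand("typing", "master", 59, ["!room:example.com", ["@alice:example.com"]])
line = cmd.to_line()
# roughly: 'typing master 59 ["!room:example.com", ["@alice:example.com"]]'

parsed = RdataCommand.from_line(line)
assert parsed.stream_name == "typing"
assert parsed.instance_name == "master"
assert parsed.token == 59

# A batched update uses the literal token "batch", which parses back as None:
assert RdataCommand.from_line('typing master batch ["!room:example.com", []]').token is None
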
@@ -136,26 +152,34 @@ class PositionCommand(Command):
"""Sent by the server to tell the client the stream postition without
needing to send an RDATA.
- Sent to the client after all missing updates for a stream have been sent
- to the client and they're now up to date.
+ Format::
+
+ POSITION <stream_name> <instance_name> <token>
+
+ On receipt of a POSITION command clients should check if they have missed
+ any updates, and if so then fetch them out of band.
+
+ The `<instance_name>` is the process that sent the command and is the source
+ of the stream.
"""
NAME = "POSITION"
- def __init__(self, stream_name, token):
+ def __init__(self, stream_name, instance_name, token):
self.stream_name = stream_name
+ self.instance_name = instance_name
self.token = token
@classmethod
def from_line(cls, line):
- stream_name, token = line.split(" ", 1)
- return cls(stream_name, int(token))
+ stream_name, instance_name, token = line.split(" ", 2)
+ return cls(stream_name, instance_name, int(token))
def to_line(self):
- return " ".join((self.stream_name, str(self.token)))
+ return " ".join((self.stream_name, self.instance_name, str(self.token)))
-class ErrorCommand(Command):
+class ErrorCommand(_SimpleCommand):
"""Sent by either side if there was an ERROR. The data is a string describing
the error.
"""
@@ -163,14 +187,14 @@ class ErrorCommand(Command):
NAME = "ERROR"
-class PingCommand(Command):
+class PingCommand(_SimpleCommand):
"""Sent by either side as a keep alive. The data is arbitary (often timestamp)
"""
NAME = "PING"
-class NameCommand(Command):
+class NameCommand(_SimpleCommand):
"""Sent by client to inform the server of the client's identity. The data
is the name
"""
@@ -179,76 +203,63 @@ class NameCommand(Command):
class ReplicateCommand(Command):
- """Sent by the client to subscribe to the stream.
+ """Sent by the client to subscribe to streams.
Format::
- REPLICATE <stream_name> <token>
-
- Where <token> may be either:
- * a numeric stream_id to stream updates from
- * "NOW" to stream all subsequent updates.
-
- The <stream_name> can be "ALL" to subscribe to all known streams, in which
- case the <token> must be set to "NOW", i.e.::
-
- REPLICATE ALL NOW
+ REPLICATE
"""
NAME = "REPLICATE"
- def __init__(self, stream_name, token):
- self.stream_name = stream_name
- self.token = token
+ def __init__(self):
+ pass
@classmethod
def from_line(cls, line):
- stream_name, token = line.split(" ", 1)
- if token in ("NOW", "now"):
- token = "NOW"
- else:
- token = int(token)
- return cls(stream_name, token)
+ return cls()
def to_line(self):
- return " ".join((self.stream_name, str(self.token)))
-
- def get_logcontext_id(self):
- return "REPLICATE-" + self.stream_name
+ return ""
class UserSyncCommand(Command):
"""Sent by the client to inform the server that a user has started or
- stopped syncing. Used to calculate presence on the master.
+ stopped syncing on this process.
+
+ This is used by the process handling presence (typically the master) to
+ calculate who is online and who is not.
Includes a timestamp of when the last user sync was.
Format::
- USER_SYNC <user_id> <state> <last_sync_ms>
+ USER_SYNC <instance_id> <user_id> <state> <last_sync_ms>
- Where <state> is either "start" or "stop"
+ Where <state> is either "start" or "end"
"""
NAME = "USER_SYNC"
- def __init__(self, user_id, is_syncing, last_sync_ms):
+ def __init__(self, instance_id, user_id, is_syncing, last_sync_ms):
+ self.instance_id = instance_id
self.user_id = user_id
self.is_syncing = is_syncing
self.last_sync_ms = last_sync_ms
@classmethod
def from_line(cls, line):
- user_id, state, last_sync_ms = line.split(" ", 2)
+ instance_id, user_id, state, last_sync_ms = line.split(" ", 3)
if state not in ("start", "end"):
raise Exception("Invalid USER_SYNC state %r" % (state,))
- return cls(user_id, state == "start", int(last_sync_ms))
+ return cls(instance_id, user_id, state == "start", int(last_sync_ms))
def to_line(self):
return " ".join(
(
+ self.instance_id,
self.user_id,
"start" if self.is_syncing else "end",
str(self.last_sync_ms),
@@ -256,6 +267,30 @@ class UserSyncCommand(Command):
)
+class ClearUserSyncsCommand(Command):
+ """Sent by the client to inform the server that it should drop all
+ information about syncing users sent by the client.
+
+ Mainly used when the client is about to shut down.
+
+ Format::
+
+ CLEAR_USER_SYNC <instance_id>
+ """
+
+ NAME = "CLEAR_USER_SYNC"
+
+ def __init__(self, instance_id):
+ self.instance_id = instance_id
+
+ @classmethod
+ def from_line(cls, line):
+ return cls(line)
+
+ def to_line(self):
+ return self.instance_id
+
+
class FederationAckCommand(Command):
"""Sent by the client when it has processed up to a given point in the
federation stream. This allows the master to drop in-memory caches of the
@@ -281,14 +316,6 @@ class FederationAckCommand(Command):
return str(self.token)
-class SyncCommand(Command):
- """Used for testing. The client protocol implementation allows waiting
- on a SYNC command with a specified data.
- """
-
- NAME = "SYNC"
-
-
class RemovePusherCommand(Command):
"""Sent by the client to request the master remove the given pusher.
@@ -314,37 +341,6 @@ class RemovePusherCommand(Command):
return " ".join((self.app_id, self.push_key, self.user_id))
-class InvalidateCacheCommand(Command):
- """Sent by the client to invalidate an upstream cache.
-
- THIS IS NOT RELIABLE, AND SHOULD *NOT* BE USED ACCEPT FOR THINGS THAT ARE
- NOT DISASTROUS IF WE DROP ON THE FLOOR.
-
- Mainly used to invalidate destination retry timing caches.
-
- Format::
-
- INVALIDATE_CACHE <cache_func> <keys_json>
-
- Where <keys_json> is a json list.
- """
-
- NAME = "INVALIDATE_CACHE"
-
- def __init__(self, cache_func, keys):
- self.cache_func = cache_func
- self.keys = keys
-
- @classmethod
- def from_line(cls, line):
- cache_func, keys_json = line.split(" ", 1)
-
- return cls(cache_func, json.loads(keys_json))
-
- def to_line(self):
- return " ".join((self.cache_func, _json_encoder.encode(self.keys)))
-
-
class UserIpCommand(Command):
"""Sent periodically when a worker sees activity from a client.
@@ -387,7 +383,7 @@ class UserIpCommand(Command):
)
-class RemoteServerUpCommand(Command):
+class RemoteServerUpCommand(_SimpleCommand):
"""Sent when a worker has detected that a remote server is no longer
"down" and retry timings should be reset.
@@ -411,11 +407,10 @@ _COMMANDS = (
ReplicateCommand,
UserSyncCommand,
FederationAckCommand,
- SyncCommand,
RemovePusherCommand,
- InvalidateCacheCommand,
UserIpCommand,
RemoteServerUpCommand,
+ ClearUserSyncsCommand,
) # type: Tuple[Type[Command], ...]
# Map of command name to command type.
@@ -428,7 +423,6 @@ VALID_SERVER_COMMANDS = (
PositionCommand.NAME,
ErrorCommand.NAME,
PingCommand.NAME,
- SyncCommand.NAME,
RemoteServerUpCommand.NAME,
)
@@ -438,10 +432,28 @@ VALID_CLIENT_COMMANDS = (
ReplicateCommand.NAME,
PingCommand.NAME,
UserSyncCommand.NAME,
+ ClearUserSyncsCommand.NAME,
FederationAckCommand.NAME,
RemovePusherCommand.NAME,
- InvalidateCacheCommand.NAME,
UserIpCommand.NAME,
ErrorCommand.NAME,
RemoteServerUpCommand.NAME,
)
+
+
+def parse_command_from_line(line: str) -> Command:
+ """Parses a command from a received line.
+
+ The line should already be stripped of whitespace and checked to be non-blank.
+ """
+
+ idx = line.find(" ")
+ if idx >= 0:
+ cmd_name = line[:idx]
+ rest_of_line = line[idx + 1 :]
+ else:
+ cmd_name = line
+ rest_of_line = ""
+
+ cmd_cls = COMMAND_MAP[cmd_name]
+ return cmd_cls.from_line(rest_of_line)
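
A brief usage sketch of the new helper (the PING payload is just an arbitrary timestamp):

from synapse.replication.tcp.commands import (
    PingCommand,
    ReplicateCommand,
    parse_command_from_line,
)

ping = parse_command_from_line("PING 1490197665618")
assert isinstance(ping, PingCommand)
assert ping.data == "1490197665618"

# Commands with no arguments, such as the now argument-less REPLICATE, parse
# from an empty remainder and serialise back to an empty string.
replicate = parse_command_from_line("REPLICATE")
assert isinstance(replicate, ReplicateCommand)
assert replicate.to_line() == ""
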
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
new file mode 100644
index 0000000000..cbcf46f3ae
--- /dev/null
+++ b/synapse/replication/tcp/handler.py
@@ -0,0 +1,596 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar
+
+from prometheus_client import Counter
+
+from twisted.internet.protocol import ReconnectingClientFactory
+
+from synapse.metrics import LaterGauge
+from synapse.replication.tcp.client import DirectTcpReplicationClientFactory
+from synapse.replication.tcp.commands import (
+ ClearUserSyncsCommand,
+ Command,
+ FederationAckCommand,
+ PositionCommand,
+ RdataCommand,
+ RemoteServerUpCommand,
+ RemovePusherCommand,
+ ReplicateCommand,
+ UserIpCommand,
+ UserSyncCommand,
+)
+from synapse.replication.tcp.protocol import AbstractConnection
+from synapse.replication.tcp.streams import (
+ STREAMS_MAP,
+ BackfillStream,
+ CachesStream,
+ EventsStream,
+ FederationStream,
+ Stream,
+)
+from synapse.util.async_helpers import Linearizer
+
+logger = logging.getLogger(__name__)
+
+
+# number of updates received for each RDATA stream
+inbound_rdata_count = Counter(
+ "synapse_replication_tcp_protocol_inbound_rdata_count", "", ["stream_name"]
+)
+user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "")
+federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "")
+remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "")
+invalidate_cache_counter = Counter(
+ "synapse_replication_tcp_resource_invalidate_cache", ""
+)
+user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "")
+
+
+class ReplicationCommandHandler:
+ """Handles incoming commands from replication as well as sending commands
+ back out to connections.
+ """
+
+ def __init__(self, hs):
+ self._replication_data_handler = hs.get_replication_data_handler()
+ self._presence_handler = hs.get_presence_handler()
+ self._store = hs.get_datastore()
+ self._notifier = hs.get_notifier()
+ self._clock = hs.get_clock()
+ self._instance_id = hs.get_instance_id()
+ self._instance_name = hs.get_instance_name()
+
+ self._streams = {
+ stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
+ } # type: Dict[str, Stream]
+
+ # List of streams that this instance is the source of
+ self._streams_to_replicate = [] # type: List[Stream]
+
+ for stream in self._streams.values():
+ if stream.NAME == CachesStream.NAME:
+ # All workers can write to the cache invalidation stream.
+ self._streams_to_replicate.append(stream)
+ continue
+
+ if isinstance(stream, (EventsStream, BackfillStream)):
+ # Only add EventsStream and BackfillStream as a source on the
+ # instance in charge of event persistence.
+ if hs.config.worker.writers.events == hs.get_instance_name():
+ self._streams_to_replicate.append(stream)
+
+ continue
+
+ # Only add any other streams if we're on master.
+ if hs.config.worker_app is not None:
+ continue
+
+ if stream.NAME == FederationStream.NAME and hs.config.send_federation:
+ # We only support federation stream if federation sending
+ # has been disabled on the master.
+ continue
+
+ self._streams_to_replicate.append(stream)
+
+ self._position_linearizer = Linearizer(
+ "replication_position", clock=self._clock
+ )
+
+ # Map of stream to batched updates. See RdataCommand for info on how
+ # batching works.
+ self._pending_batches = {} # type: Dict[str, List[Any]]
+
+ # The factory used to create connections.
+ self._factory = None # type: Optional[ReconnectingClientFactory]
+
+ # The currently connected connections. (The list of places we need to send
+ # outgoing replication commands to.)
+ self._connections = [] # type: List[AbstractConnection]
+
+ # For each connection, the incoming streams that are coming from that connection
+ self._streams_by_connection = {} # type: Dict[AbstractConnection, Set[str]]
+
+ LaterGauge(
+ "synapse_replication_tcp_resource_total_connections",
+ "",
+ [],
+ lambda: len(self._connections),
+ )
+
+ self._is_master = hs.config.worker_app is None
+
+ self._federation_sender = None
+ if self._is_master and not hs.config.send_federation:
+ self._federation_sender = hs.get_federation_sender()
+
+ self._server_notices_sender = None
+ if self._is_master:
+ self._server_notices_sender = hs.get_server_notices_sender()
+
+ def start_replication(self, hs):
+ """Helper method to start a replication connection to the remote server
+ using TCP.
+ """
+ if hs.config.redis.redis_enabled:
+ from synapse.replication.tcp.redis import (
+ RedisDirectTcpReplicationClientFactory,
+ )
+ import txredisapi
+
+ logger.info(
+ "Connecting to redis (host=%r port=%r)",
+ hs.config.redis_host,
+ hs.config.redis_port,
+ )
+
+ # First let's ensure that we have a ReplicationStreamer started.
+ hs.get_replication_streamer()
+
+ # We need two connections to redis, one for the subscription stream and
+ # one to send commands to (as you can't send further redis commands to a
+ # connection after SUBSCRIBE is called).
+
+ # First create the connection for sending commands.
+ outbound_redis_connection = txredisapi.lazyConnection(
+ host=hs.config.redis_host,
+ port=hs.config.redis_port,
+ password=hs.config.redis.redis_password,
+ reconnect=True,
+ )
+
+ # Now create the factory/connection for the subscription stream.
+ self._factory = RedisDirectTcpReplicationClientFactory(
+ hs, outbound_redis_connection
+ )
+ hs.get_reactor().connectTCP(
+ hs.config.redis.redis_host, hs.config.redis.redis_port, self._factory,
+ )
+ else:
+ client_name = hs.get_instance_name()
+ self._factory = DirectTcpReplicationClientFactory(hs, client_name, self)
+ host = hs.config.worker_replication_host
+ port = hs.config.worker_replication_port
+ hs.get_reactor().connectTCP(host, port, self._factory)
+
+ def get_streams(self) -> Dict[str, Stream]:
+ """Get a map of all streams, keyed by stream name.
+ """
+ return self._streams
+
+ def get_streams_to_replicate(self) -> List[Stream]:
+ """Get a list of streams that this instance replicates.
+ """
+ return self._streams_to_replicate
+
+ async def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
+ self.send_positions_to_connection(conn)
+
+ def send_positions_to_connection(self, conn: AbstractConnection):
+ """Send the current position of every stream this process is a source of
+ down the given connection.
+ """
+
+ # We respond with current position of all streams this instance
+ # replicates.
+ for stream in self.get_streams_to_replicate():
+ self.send_command(
+ PositionCommand(
+ stream.NAME,
+ self._instance_name,
+ stream.current_token(self._instance_name),
+ )
+ )
+
+ async def on_USER_SYNC(self, conn: AbstractConnection, cmd: UserSyncCommand):
+ user_sync_counter.inc()
+
+ if self._is_master:
+ await self._presence_handler.update_external_syncs_row(
+ cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
+ )
+
+ async def on_CLEAR_USER_SYNC(
+ self, conn: AbstractConnection, cmd: ClearUserSyncsCommand
+ ):
+ if self._is_master:
+ await self._presence_handler.update_external_syncs_clear(cmd.instance_id)
+
+ async def on_FEDERATION_ACK(
+ self, conn: AbstractConnection, cmd: FederationAckCommand
+ ):
+ federation_ack_counter.inc()
+
+ if self._federation_sender:
+ self._federation_sender.federation_ack(cmd.token)
+
+ async def on_REMOVE_PUSHER(
+ self, conn: AbstractConnection, cmd: RemovePusherCommand
+ ):
+ remove_pusher_counter.inc()
+
+ if self._is_master:
+ await self._store.delete_pusher_by_app_id_pushkey_user_id(
+ app_id=cmd.app_id, pushkey=cmd.push_key, user_id=cmd.user_id
+ )
+
+ self._notifier.on_new_replication_data()
+
+ async def on_USER_IP(self, conn: AbstractConnection, cmd: UserIpCommand):
+ user_ip_cache_counter.inc()
+
+ if self._is_master:
+ await self._store.insert_client_ip(
+ cmd.user_id,
+ cmd.access_token,
+ cmd.ip,
+ cmd.user_agent,
+ cmd.device_id,
+ cmd.last_seen,
+ )
+
+ if self._server_notices_sender:
+ await self._server_notices_sender.on_user_ip(cmd.user_id)
+
+ async def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
+ if cmd.instance_name == self._instance_name:
+ # Ignore RDATA that are just our own echoes
+ return
+
+ stream_name = cmd.stream_name
+ inbound_rdata_count.labels(stream_name).inc()
+
+ try:
+ row = STREAMS_MAP[stream_name].parse_row(cmd.row)
+ except Exception:
+ logger.exception("Failed to parse RDATA: %r %r", stream_name, cmd.row)
+ raise
+
+ # We linearize here for two reasons:
+ # 1. so we don't try and concurrently handle multiple rows for the
+ # same stream, and
+ # 2. so we don't race with getting a POSITION command and fetching
+ # missing RDATA.
+ with await self._position_linearizer.queue(cmd.stream_name):
+ # make sure that we've processed a POSITION for this stream *on this
+ # connection*. (A POSITION on another connection is no good, as there
+ # is no guarantee that we have seen all the intermediate updates.)
+ sbc = self._streams_by_connection.get(conn)
+ if not sbc or stream_name not in sbc:
+ # Let's drop the row for now, on the assumption we'll receive a
+ # `POSITION` soon and we'll catch up correctly then.
+ logger.debug(
+ "Discarding RDATA for unconnected stream %s -> %s",
+ stream_name,
+ cmd.token,
+ )
+ return
+
+ if cmd.token is None:
+ # I.e. this is part of a batch of updates for this stream (in
+ # which case batch until we get an update for the stream with a
+ # non-None token).
+ self._pending_batches.setdefault(stream_name, []).append(row)
+ else:
+ # Check if this is the last of a batch of updates
+ rows = self._pending_batches.pop(stream_name, [])
+ rows.append(row)
+ await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows)
+
+ async def on_rdata(
+ self, stream_name: str, instance_name: str, token: int, rows: list
+ ):
+ """Called to handle a batch of replication data with a given stream token.
+
+ Args:
+ stream_name: name of the replication stream for this batch of rows
+ instance_name: the instance that wrote the rows.
+ token: stream token for this batch of rows
+ rows: a list of Stream.ROW_TYPE objects as returned by
+ Stream.parse_row.
+ """
+ logger.debug("Received rdata %s (%s) -> %s", stream_name, instance_name, token)
+ await self._replication_data_handler.on_rdata(
+ stream_name, instance_name, token, rows
+ )
+
+ async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
+ if cmd.instance_name == self._instance_name:
+ # Ignore POSITION that are just our own echoes
+ return
+
+ logger.info("Handling '%s %s'", cmd.NAME, cmd.to_line())
+
+ stream_name = cmd.stream_name
+ stream = self._streams.get(stream_name)
+ if not stream:
+ logger.error("Got POSITION for unknown stream: %s", stream_name)
+ return
+
+ # We protect catching up with a linearizer in case the replication
+ # connection reconnects under us.
+ with await self._position_linearizer.queue(stream_name):
+ # We're about to go and catch up with the stream, so remove from set
+ # of connected streams.
+ for streams in self._streams_by_connection.values():
+ streams.discard(stream_name)
+
+ # We clear the pending batches for the stream as the fetching of the
+ # missing updates below will fetch all rows in the batch.
+ self._pending_batches.pop(stream_name, [])
+
+ # Find where we previously streamed up to.
+ current_token = stream.current_token(cmd.instance_name)
+
+ # If the position token matches our current token then we're up to
+ # date and there's nothing to do. Otherwise, fetch all updates
+ # between then and now.
+ missing_updates = cmd.token != current_token
+ while missing_updates:
+ logger.info(
+ "Fetching replication rows for '%s' between %i and %i",
+ stream_name,
+ current_token,
+ cmd.token,
+ )
+ (
+ updates,
+ current_token,
+ missing_updates,
+ ) = await stream.get_updates_since(
+ cmd.instance_name, current_token, cmd.token
+ )
+
+ # TODO: add some tests for this
+
+ # Some streams return multiple rows with the same stream IDs,
+ # which need to be processed in batches.
+
+ for token, rows in _batch_updates(updates):
+ await self.on_rdata(
+ stream_name,
+ cmd.instance_name,
+ token,
+ [stream.parse_row(row) for row in rows],
+ )
+
+ logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token)
+
+ # We've now caught up to position sent to us, notify handler.
+ await self._replication_data_handler.on_position(
+ cmd.stream_name, cmd.instance_name, cmd.token
+ )
+
+ self._streams_by_connection.setdefault(conn, set()).add(stream_name)
+
+ async def on_REMOTE_SERVER_UP(
+ self, conn: AbstractConnection, cmd: RemoteServerUpCommand
+ ):
+ """Called when we get a new REMOTE_SERVER_UP command."""
+ self._replication_data_handler.on_remote_server_up(cmd.data)
+
+ self._notifier.notify_remote_server_up(cmd.data)
+
+ # We relay to all other connections to ensure every instance gets the
+ # notification.
+ #
+ # When configured to use redis we'll always only have one connection and
+ # so this is a no-op (all instances will have already received the same
+ # REMOTE_SERVER_UP command).
+ #
+ # For direct TCP connections this will relay to all other connections
+ # connected to us. When on master this will correctly fan out to all
+ # other direct TCP clients and on workers there'll only be the one
+ # connection to master.
+ #
+ # (The logic here should also be sound if we have a mix of Redis and
+ # direct TCP connections so long as there is only one traffic route
+ # between two instances, but that is not currently supported).
+ self.send_command(cmd, ignore_conn=conn)
+
+ def new_connection(self, connection: AbstractConnection):
+ """Called when we have a new connection.
+ """
+ self._connections.append(connection)
+
+ # If we are connected to replication as a client (rather than a server)
+ # we need to reset the reconnection delay on the client factory (which
+ # is used to do exponential back off when the connection drops).
+ #
+ # Ideally we would reset the delay when we've "fully established" the
+ # connection (for some definition thereof) to stop us from tightlooping
+ # on reconnection if something fails after this point and we drop the
+ # connection. Unfortunately, we don't really have a better definition of
+ # "fully established" than the connection being established.
+ if self._factory:
+ self._factory.resetDelay()
+
+ # Tell the other end if we have any users currently syncing.
+ currently_syncing = (
+ self._presence_handler.get_currently_syncing_users_for_replication()
+ )
+
+ now = self._clock.time_msec()
+ for user_id in currently_syncing:
+ connection.send_command(
+ UserSyncCommand(self._instance_id, user_id, True, now)
+ )
+
+ def lost_connection(self, connection: AbstractConnection):
+ """Called when a connection is closed/lost.
+ """
+ # we no longer need _streams_by_connection for this connection.
+ streams = self._streams_by_connection.pop(connection, None)
+ if streams:
+ logger.info(
+ "Lost replication connection; streams now disconnected: %s", streams
+ )
+ try:
+ self._connections.remove(connection)
+ except ValueError:
+ pass
+
+ def connected(self) -> bool:
+ """Do we have any replication connections open?
+
+ Is used by e.g. `ReplicationStreamer` to no-op if nothing is connected.
+ """
+ return bool(self._connections)
+
+ def send_command(
+ self, cmd: Command, ignore_conn: Optional[AbstractConnection] = None
+ ):
+ """Send a command to all connected connections.
+
+ Args:
+ cmd
+ ignore_conn: If set don't send command to the given connection.
+ Used when relaying commands from one connection to all others.
+ """
+ if self._connections:
+ for connection in self._connections:
+ if connection == ignore_conn:
+ continue
+
+ try:
+ connection.send_command(cmd)
+ except Exception:
+ # We probably want to catch some types of exceptions here
+ # and log them as warnings (e.g. connection gone), but I
+ # can't find what those exception types would be.
+ logger.exception(
+ "Failed to write command %s to connection %s",
+ cmd.NAME,
+ connection,
+ )
+ else:
+ logger.warning("Dropping command as not connected: %r", cmd.NAME)
+
+ def send_federation_ack(self, token: int):
+ """Ack data for the federation stream. This allows the master to drop
+ data stored purely in memory.
+ """
+ self.send_command(FederationAckCommand(token))
+
+ def send_user_sync(
+ self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int
+ ):
+ """Poke the master that a user has started/stopped syncing.
+ """
+ self.send_command(
+ UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms)
+ )
+
+ def send_remove_pusher(self, app_id: str, push_key: str, user_id: str):
+ """Poke the master to remove a pusher for a user
+ """
+ cmd = RemovePusherCommand(app_id, push_key, user_id)
+ self.send_command(cmd)
+
+ def send_user_ip(
+ self,
+ user_id: str,
+ access_token: str,
+ ip: str,
+ user_agent: str,
+ device_id: str,
+ last_seen: int,
+ ):
+ """Tell the master that the user made a request.
+ """
+ cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen)
+ self.send_command(cmd)
+
+ def send_remote_server_up(self, server: str):
+ self.send_command(RemoteServerUpCommand(server))
+
+ def stream_update(self, stream_name: str, token: str, data: Any):
+ """Called when a new update is available to stream to clients.
+
+ We need to check if the client is interested in the stream or not
+ """
+ self.send_command(RdataCommand(stream_name, self._instance_name, token, data))
+
+
+UpdateToken = TypeVar("UpdateToken")
+UpdateRow = TypeVar("UpdateRow")
+
+
+def _batch_updates(
+ updates: Iterable[Tuple[UpdateToken, UpdateRow]]
+) -> Iterator[Tuple[UpdateToken, List[UpdateRow]]]:
+ """Collect stream updates with the same token together
+
+ Given a series of updates returned by Stream.get_updates_since(), collects
+ the updates which share the same stream_id together.
+
+ For example:
+
+ [(1, a), (1, b), (2, c), (3, d), (3, e)]
+
+ becomes:
+
+ [
+ (1, [a, b]),
+ (2, [c]),
+ (3, [d, e]),
+ ]
+ """
+
+ update_iter = iter(updates)
+
+ first_update = next(update_iter, None)
+ if first_update is None:
+ # empty input
+ return
+
+ current_batch_token = first_update[0]
+ current_batch = [first_update[1]]
+
+ for token, row in update_iter:
+ if token != current_batch_token:
+ # different token to the previous row: flush the previous
+ # batch and start anew
+ yield current_batch_token, current_batch
+ current_batch_token = token
+ current_batch = []
+
+ current_batch.append(row)
+
+ # flush the final batch
+ yield current_batch_token, current_batch
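
A quick, illustrative check of `_batch_updates` against the docstring example above:

from synapse.replication.tcp.handler import _batch_updates

updates = [(1, "a"), (1, "b"), (2, "c"), (3, "d"), (3, "e")]
assert list(_batch_updates(updates)) == [(1, ["a", "b"]), (2, ["c"]), (3, ["d", "e"])]
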
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index d185cc0c8f..4198eece71 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -35,9 +35,7 @@ indicate which side is sending, these are *not* included on the wire::
> PING 1490197665618
< NAME synapse.app.appservice
< PING 1490197665618
- < REPLICATE events 1
- < REPLICATE backfill 1
- < REPLICATE caches 1
+ < REPLICATE
> POSITION events 1
> POSITION backfill 1
> POSITION caches 1
@@ -52,45 +50,51 @@ import abc
import fcntl
import logging
import struct
-from collections import defaultdict
-from typing import Any, DefaultDict, Dict, List, Set, Tuple
-
-from six import iteritems, iterkeys
+from typing import TYPE_CHECKING, List
from prometheus_client import Counter
-from twisted.internet import defer
from twisted.protocols.basic import LineOnlyReceiver
from twisted.python.failure import Failure
-from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.tcp.commands import (
- COMMAND_MAP,
VALID_CLIENT_COMMANDS,
VALID_SERVER_COMMANDS,
Command,
ErrorCommand,
NameCommand,
PingCommand,
- PositionCommand,
- RdataCommand,
- RemoteServerUpCommand,
ReplicateCommand,
ServerCommand,
- SyncCommand,
- UserSyncCommand,
+ parse_command_from_line,
)
-from synapse.replication.tcp.streams import STREAMS_MAP
from synapse.types import Collection
from synapse.util import Clock
from synapse.util.stringutils import random_string
+if TYPE_CHECKING:
+ from synapse.replication.tcp.handler import ReplicationCommandHandler
+ from synapse.server import HomeServer
+
+
connection_close_counter = Counter(
"synapse_replication_tcp_protocol_close_reason", "", ["reason_type"]
)
+tcp_inbound_commands_counter = Counter(
+ "synapse_replication_tcp_protocol_inbound_commands",
+ "Number of commands received from replication, by command and name of process connected to",
+ ["command", "name"],
+)
+
+tcp_outbound_commands_counter = Counter(
+ "synapse_replication_tcp_protocol_outbound_commands",
+ "Number of commands sent to replication, by command and name of process connected to",
+ ["command", "name"],
+)
+
# A list of all connected protocols. This allows us to send metrics about the
# connections.
connected_connections = []
@@ -119,7 +123,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
are only sent by the server.
On receiving a new command it calls `on_<COMMAND_NAME>` with the parsed
- command.
+ command before delegating to `ReplicationCommandHandler.on_<COMMAND_NAME>`.
It also sends `PING` periodically, and correctly times out remote connections
(if they send a `PING` command)
@@ -135,8 +139,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
max_line_buffer = 10000
- def __init__(self, clock):
+ def __init__(self, clock: Clock, handler: "ReplicationCommandHandler"):
self.clock = clock
+ self.command_handler = handler
self.last_received_command = self.clock.time_msec()
self.last_sent_command = 0
@@ -155,9 +160,6 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
# The LoopingCall for sending pings.
self._send_ping_loop = None
- self.inbound_commands_counter = defaultdict(int) # type: DefaultDict[str, int]
- self.outbound_commands_counter = defaultdict(int) # type: DefaultDict[str, int]
-
def connectionMade(self):
logger.info("[%s] Connection established", self.id())
@@ -176,6 +178,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
# can time us out.
self.send_command(PingCommand(self.clock.time_msec()))
+ self.command_handler.new_connection(self)
+
def send_ping(self):
"""Periodically sends a ping and checks if we should close the connection
due to the other side timing out.
@@ -203,38 +207,30 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
)
self.send_error("ping timeout")
- def lineReceived(self, line):
+ def lineReceived(self, line: bytes):
"""Called when we've received a line
"""
if line.strip() == "":
# Ignore blank lines
return
- line = line.decode("utf-8")
- cmd_name, rest_of_line = line.split(" ", 1)
+ linestr = line.decode("utf-8")
- if cmd_name not in self.VALID_INBOUND_COMMANDS:
- logger.error("[%s] invalid command %s", self.id(), cmd_name)
- self.send_error("invalid command: %s", cmd_name)
+ try:
+ cmd = parse_command_from_line(linestr)
+ except Exception as e:
+ logger.exception("[%s] failed to parse line: %r", self.id(), linestr)
+ self.send_error("failed to parse line: %r (%r):" % (e, linestr))
return
- self.last_received_command = self.clock.time_msec()
+ if cmd.NAME not in self.VALID_INBOUND_COMMANDS:
+ logger.error("[%s] invalid command %s", self.id(), cmd.NAME)
+ self.send_error("invalid command: %s", cmd.NAME)
+ return
- self.inbound_commands_counter[cmd_name] = (
- self.inbound_commands_counter[cmd_name] + 1
- )
+ self.last_received_command = self.clock.time_msec()
- cmd_cls = COMMAND_MAP[cmd_name]
- try:
- cmd = cmd_cls.from_line(rest_of_line)
- except Exception as e:
- logger.exception(
- "[%s] failed to parse line %r: %r", self.id(), cmd_name, rest_of_line
- )
- self.send_error(
- "failed to parse line for %r: %r (%r):" % (cmd_name, e, rest_of_line)
- )
- return
+ tcp_inbound_commands_counter.labels(cmd.NAME, self.name).inc()
# Now lets try and call on_<CMD_NAME> function
run_as_background_process(
@@ -244,13 +240,31 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
async def handle_command(self, cmd: Command):
"""Handle a command we have received over the replication stream.
- By default delegates to on_<COMMAND>, which should return an awaitable.
+ First calls `self.on_<COMMAND>` if it exists, then calls
+ `self.command_handler.on_<COMMAND>` if it exists. This allows for
+ protocol level handling of commands (e.g. PINGs), before delegating to
+ the handler.
Args:
cmd: received command
"""
- handler = getattr(self, "on_%s" % (cmd.NAME,))
- await handler(cmd)
+ handled = False
+
+ # First call any command handlers on this instance. These are for TCP
+ # specific handling.
+ cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None)
+ if cmd_func:
+ await cmd_func(cmd)
+ handled = True
+
+ # Then call out to the handler.
+ cmd_func = getattr(self.command_handler, "on_%s" % (cmd.NAME,), None)
+ if cmd_func:
+ await cmd_func(self, cmd)
+ handled = True
+
+ if not handled:
+ logger.warning("Unhandled command: %r", cmd)
def close(self):
logger.warning("[%s] Closing connection", self.id())
@@ -282,9 +296,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
self._queue_command(cmd)
return
- self.outbound_commands_counter[cmd.NAME] = (
- self.outbound_commands_counter[cmd.NAME] + 1
- )
+ tcp_outbound_commands_counter.labels(cmd.NAME, self.name).inc()
+
string = "%s %s" % (cmd.NAME, cmd.to_line())
if "\n" in string:
raise Exception("Unexpected newline in command: %r", string)
@@ -379,6 +392,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
self.state = ConnectionStates.CLOSED
self.pending_commands = []
+ self.command_handler.lost_connection(self)
+
if self.transport:
self.transport.unregisterProducer()
@@ -405,232 +420,21 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
VALID_INBOUND_COMMANDS = VALID_CLIENT_COMMANDS
VALID_OUTBOUND_COMMANDS = VALID_SERVER_COMMANDS
- def __init__(self, server_name, clock, streamer):
- BaseReplicationStreamProtocol.__init__(self, clock) # Old style class
+ def __init__(
+ self, server_name: str, clock: Clock, handler: "ReplicationCommandHandler"
+ ):
+ super().__init__(clock, handler)
self.server_name = server_name
- self.streamer = streamer
-
- # The streams the client has subscribed to and is up to date with
- self.replication_streams = set() # type: Set[str]
-
- # The streams the client is currently subscribing to.
- self.connecting_streams = set() # type: Set[str]
-
- # Map from stream name to list of updates to send once we've finished
- # subscribing the client to the stream.
- self.pending_rdata = {} # type: Dict[str, List[Tuple[int, Any]]]
def connectionMade(self):
self.send_command(ServerCommand(self.server_name))
- BaseReplicationStreamProtocol.connectionMade(self)
- self.streamer.new_connection(self)
+ super().connectionMade()
async def on_NAME(self, cmd):
logger.info("[%s] Renamed to %r", self.id(), cmd.data)
self.name = cmd.data
- async def on_USER_SYNC(self, cmd):
- await self.streamer.on_user_sync(
- self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
- )
-
- async def on_REPLICATE(self, cmd):
- stream_name = cmd.stream_name
- token = cmd.token
-
- if stream_name == "ALL":
- # Subscribe to all streams we're publishing to.
- deferreds = [
- run_in_background(self.subscribe_to_stream, stream, token)
- for stream in iterkeys(self.streamer.streams_by_name)
- ]
-
- await make_deferred_yieldable(
- defer.gatherResults(deferreds, consumeErrors=True)
- )
- else:
- await self.subscribe_to_stream(stream_name, token)
-
- async def on_FEDERATION_ACK(self, cmd):
- self.streamer.federation_ack(cmd.token)
-
- async def on_REMOVE_PUSHER(self, cmd):
- await self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id)
-
- async def on_INVALIDATE_CACHE(self, cmd):
- await self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
-
- async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
- self.streamer.on_remote_server_up(cmd.data)
-
- async def on_USER_IP(self, cmd):
- await self.streamer.on_user_ip(
- cmd.user_id,
- cmd.access_token,
- cmd.ip,
- cmd.user_agent,
- cmd.device_id,
- cmd.last_seen,
- )
-
- async def subscribe_to_stream(self, stream_name, token):
- """Subscribe the remote to a stream.
-
- This invloves checking if they've missed anything and sending those
- updates down if they have. During that time new updates for the stream
- are queued and sent once we've sent down any missed updates.
- """
- self.replication_streams.discard(stream_name)
- self.connecting_streams.add(stream_name)
-
- try:
- # Get missing updates
- updates, current_token = await self.streamer.get_stream_updates(
- stream_name, token
- )
-
- # Send all the missing updates
- for update in updates:
- token, row = update[0], update[1]
- self.send_command(RdataCommand(stream_name, token, row))
-
- # We send a POSITION command to ensure that they have an up to
- # date token (especially useful if we didn't send any updates
- # above)
- self.send_command(PositionCommand(stream_name, current_token))
-
- # Now we can send any updates that came in while we were subscribing
- pending_rdata = self.pending_rdata.pop(stream_name, [])
- updates = []
- for token, update in pending_rdata:
- # If the token is null, it is part of a batch update. Batches
- # are multiple updates that share a single token. To denote
- # this, the token is set to None for all tokens in the batch
- # except for the last. If we find a None token, we keep looking
- # through tokens until we find one that is not None and then
- # process all previous updates in the batch as if they had the
- # final token.
- if token is None:
- # Store this update as part of a batch
- updates.append(update)
- continue
-
- if token <= current_token:
- # This update or batch of updates is older than
- # current_token, dismiss it
- updates = []
- continue
-
- updates.append(update)
-
- # Send all updates that are part of this batch with the
- # found token
- for update in updates:
- self.send_command(RdataCommand(stream_name, token, update))
-
- # Clear stored updates
- updates = []
-
- # They're now fully subscribed
- self.replication_streams.add(stream_name)
- except Exception as e:
- logger.exception("[%s] Failed to handle REPLICATE command", self.id())
- self.send_error("failed to handle replicate: %r", e)
- finally:
- self.connecting_streams.discard(stream_name)
-
- def stream_update(self, stream_name, token, data):
- """Called when a new update is available to stream to clients.
-
- We need to check if the client is interested in the stream or not
- """
- if stream_name in self.replication_streams:
- # The client is subscribed to the stream
- self.send_command(RdataCommand(stream_name, token, data))
- elif stream_name in self.connecting_streams:
- # The client is being subscribed to the stream
- logger.debug("[%s] Queuing RDATA %r %r", self.id(), stream_name, token)
- self.pending_rdata.setdefault(stream_name, []).append((token, data))
- else:
- # The client isn't subscribed
- logger.debug("[%s] Dropping RDATA %r %r", self.id(), stream_name, token)
-
- def send_sync(self, data):
- self.send_command(SyncCommand(data))
-
- def send_remote_server_up(self, server: str):
- self.send_command(RemoteServerUpCommand(server))
-
- def on_connection_closed(self):
- BaseReplicationStreamProtocol.on_connection_closed(self)
- self.streamer.lost_connection(self)
-
-
-class AbstractReplicationClientHandler(metaclass=abc.ABCMeta):
- """
- The interface for the handler that should be passed to
- ClientReplicationStreamProtocol
- """
-
- @abc.abstractmethod
- async def on_rdata(self, stream_name, token, rows):
- """Called to handle a batch of replication data with a given stream token.
-
- Args:
- stream_name (str): name of the replication stream for this batch of rows
- token (int): stream token for this batch of rows
- rows (list): a list of Stream.ROW_TYPE objects as returned by
- Stream.parse_row.
- """
- raise NotImplementedError()
-
- @abc.abstractmethod
- async def on_position(self, stream_name, token):
- """Called when we get new position data."""
- raise NotImplementedError()
-
- @abc.abstractmethod
- def on_sync(self, data):
- """Called when get a new SYNC command."""
- raise NotImplementedError()
-
- @abc.abstractmethod
- async def on_remote_server_up(self, server: str):
- """Called when get a new REMOTE_SERVER_UP command."""
- raise NotImplementedError()
-
- @abc.abstractmethod
- def get_streams_to_replicate(self):
- """Called when a new connection has been established and we need to
- subscribe to streams.
-
- Returns:
- map from stream name to the most recent update we have for
- that stream (ie, the point we want to start replicating from)
- """
- raise NotImplementedError()
-
- @abc.abstractmethod
- def get_currently_syncing_users(self):
- """Get the list of currently syncing users (if any). This is called
- when a connection has been established and we need to send the
- currently syncing users."""
- raise NotImplementedError()
-
- @abc.abstractmethod
- def update_connection(self, connection):
- """Called when a connection has been established (or lost with None).
- """
- raise NotImplementedError()
-
- @abc.abstractmethod
- def finished_connecting(self):
- """Called when we have successfully subscribed and caught up to all
- streams we're interested in.
- """
- raise NotImplementedError()
-
class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
VALID_INBOUND_COMMANDS = VALID_SERVER_COMMANDS
@@ -638,110 +442,51 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
def __init__(
self,
+ hs: "HomeServer",
client_name: str,
server_name: str,
clock: Clock,
- handler: AbstractReplicationClientHandler,
+ command_handler: "ReplicationCommandHandler",
):
- BaseReplicationStreamProtocol.__init__(self, clock)
+ super().__init__(clock, command_handler)
self.client_name = client_name
self.server_name = server_name
- self.handler = handler
-
- # Set of stream names that have been subscribe to, but haven't yet
- # caught up with. This is used to track when the client has been fully
- # connected to the remote.
- self.streams_connecting = set() # type: Set[str]
-
- # Map of stream to batched updates. See RdataCommand for info on how
- # batching works.
- self.pending_batches = {} # type: Dict[str, Any]
def connectionMade(self):
self.send_command(NameCommand(self.client_name))
- BaseReplicationStreamProtocol.connectionMade(self)
+ super().connectionMade()
# Once we've connected subscribe to the necessary streams
- for stream_name, token in iteritems(self.handler.get_streams_to_replicate()):
- self.replicate(stream_name, token)
-
- # Tell the server if we have any users currently syncing (should only
- # happen on synchrotrons)
- currently_syncing = self.handler.get_currently_syncing_users()
- now = self.clock.time_msec()
- for user_id in currently_syncing:
- self.send_command(UserSyncCommand(user_id, True, now))
-
- # We've now finished connecting to so inform the client handler
- self.handler.update_connection(self)
-
- # This will happen if we don't actually subscribe to any streams
- if not self.streams_connecting:
- self.handler.finished_connecting()
+ self.replicate()
async def on_SERVER(self, cmd):
if cmd.data != self.server_name:
logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
self.send_error("Wrong remote")
- async def on_RDATA(self, cmd):
- stream_name = cmd.stream_name
- inbound_rdata_count.labels(stream_name).inc()
-
- try:
- row = STREAMS_MAP[stream_name].parse_row(cmd.row)
- except Exception:
- logger.exception(
- "[%s] Failed to parse RDATA: %r %r", self.id(), stream_name, cmd.row
- )
- raise
-
- if cmd.token is None:
- # I.e. this is part of a batch of updates for this stream. Batch
- # until we get an update for the stream with a non None token
- self.pending_batches.setdefault(stream_name, []).append(row)
- else:
- # Check if this is the last of a batch of updates
- rows = self.pending_batches.pop(stream_name, [])
- rows.append(row)
- await self.handler.on_rdata(stream_name, cmd.token, rows)
-
- async def on_POSITION(self, cmd):
- # When we get a `POSITION` command it means we've finished getting
- # missing updates for the given stream, and are now up to date.
- self.streams_connecting.discard(cmd.stream_name)
- if not self.streams_connecting:
- self.handler.finished_connecting()
+ def replicate(self):
+ """Send the subscription request to the server
+ """
+ logger.info("[%s] Subscribing to replication streams", self.id())
- await self.handler.on_position(cmd.stream_name, cmd.token)
+ self.send_command(ReplicateCommand())
- async def on_SYNC(self, cmd):
- self.handler.on_sync(cmd.data)
- async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
- self.handler.on_remote_server_up(cmd.data)
+class AbstractConnection(abc.ABC):
+ """An interface for replication connections.
+ """
- def replicate(self, stream_name, token):
- """Send the subscription request to the server
+ @abc.abstractmethod
+ def send_command(self, cmd: Command):
+ """Send the command down the connection
"""
- if stream_name not in STREAMS_MAP:
- raise Exception("Invalid stream name %r" % (stream_name,))
-
- logger.info(
- "[%s] Subscribing to replication stream: %r from %r",
- self.id(),
- stream_name,
- token,
- )
-
- self.streams_connecting.add(stream_name)
+ pass
- self.send_command(ReplicateCommand(stream_name, token))
- def on_connection_closed(self):
- BaseReplicationStreamProtocol.on_connection_closed(self)
- self.handler.update_connection(None)
+# This tells python that `BaseReplicationStreamProtocol` implements the
+# interface.
+AbstractConnection.register(BaseReplicationStreamProtocol)
# The following simply registers metrics for the replication connections
@@ -804,31 +549,3 @@ tcp_transport_kernel_read_buffer = LaterGauge(
for p in connected_connections
},
)
-
-
-tcp_inbound_commands = LaterGauge(
- "synapse_replication_tcp_protocol_inbound_commands",
- "",
- ["command", "name"],
- lambda: {
- (k, p.name): count
- for p in connected_connections
- for k, count in iteritems(p.inbound_commands_counter)
- },
-)
-
-tcp_outbound_commands = LaterGauge(
- "synapse_replication_tcp_protocol_outbound_commands",
- "",
- ["command", "name"],
- lambda: {
- (k, p.name): count
- for p in connected_connections
- for k, count in iteritems(p.outbound_commands_counter)
- },
-)
-
-# number of updates received for each RDATA stream
-inbound_rdata_count = Counter(
- "synapse_replication_tcp_protocol_inbound_rdata_count", "", ["stream_name"]
-)
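The `AbstractConnection.register(BaseReplicationStreamProtocol)` call above uses Python's virtual-subclass mechanism: the protocol class never inherits from `AbstractConnection`, yet `isinstance`/`issubclass` checks against the interface still succeed. A minimal sketch of the idea, with purely illustrative class names:

    import abc

    class Connection(abc.ABC):
        @abc.abstractmethod
        def send_command(self, cmd):
            """Send the command down the connection."""

    class TcpConnection:
        # Note: no inheritance from Connection.
        def send_command(self, cmd):
            print("sending", cmd)

    # Register TcpConnection as a virtual subclass; isinstance() and
    # issubclass() checks now pass despite the lack of inheritance.
    Connection.register(TcpConnection)

    assert issubclass(TcpConnection, Connection)
    assert isinstance(TcpConnection(), Connection)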
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
new file mode 100644
index 0000000000..e776b63183
--- /dev/null
+++ b/synapse/replication/tcp/redis.py
@@ -0,0 +1,215 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import TYPE_CHECKING
+
+import txredisapi
+
+from synapse.logging.context import make_deferred_yieldable
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.replication.tcp.commands import (
+ Command,
+ ReplicateCommand,
+ parse_command_from_line,
+)
+from synapse.replication.tcp.protocol import (
+ AbstractConnection,
+ tcp_inbound_commands_counter,
+ tcp_outbound_commands_counter,
+)
+
+if TYPE_CHECKING:
+ from synapse.replication.tcp.handler import ReplicationCommandHandler
+ from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
+ """Connection to redis subscribed to replication stream.
+
+ This class fulfils two functions:
+
+ (a) it implements the twisted Protocol API, where it handles the SUBSCRIBEd redis
+ connection, parsing *incoming* messages into replication commands, and passing them
+ to `ReplicationCommandHandler`
+
+ (b) it implements the AbstractConnection API, where it sends *outgoing* commands
+ onto outbound_redis_connection.
+
+ Due to the vagaries of `txredisapi` we don't want to have a custom
+ constructor, so instead we expect the defined attributes below to be set
+ immediately after initialisation.
+
+ Attributes:
+ handler: The command handler to handle incoming commands.
+ stream_name: The *redis* stream name to subscribe to and publish from
+ (not anything to do with Synapse replication streams).
+ outbound_redis_connection: The connection to redis to use to send
+ commands.
+ """
+
+ handler = None # type: ReplicationCommandHandler
+ stream_name = None # type: str
+ outbound_redis_connection = None # type: txredisapi.RedisProtocol
+
+ def connectionMade(self):
+ logger.info("Connected to redis")
+ super().connectionMade()
+ run_as_background_process("subscribe-replication", self._send_subscribe)
+
+ async def _send_subscribe(self):
+ # it's important to make sure that we only send the REPLICATE command once we
+ # have successfully subscribed to the stream - otherwise we might miss the
+ # POSITION response sent back by the other end.
+ logger.info("Sending redis SUBSCRIBE for %s", self.stream_name)
+ await make_deferred_yieldable(self.subscribe(self.stream_name))
+ logger.info(
+ "Successfully subscribed to redis stream, sending REPLICATE command"
+ )
+ self.handler.new_connection(self)
+ await self._async_send_command(ReplicateCommand())
+ logger.info("REPLICATE successfully sent")
+
+ # We send out our positions when there is a new connection in case the
+ # other side missed updates. We do this for Redis connections as the
+        # other side won't know we've connected and so won't issue a REPLICATE.
+ self.handler.send_positions_to_connection(self)
+
+ def messageReceived(self, pattern: str, channel: str, message: str):
+ """Received a message from redis.
+ """
+
+ if message.strip() == "":
+ # Ignore blank lines
+ return
+
+ try:
+ cmd = parse_command_from_line(message)
+ except Exception:
+ logger.exception(
+ "Failed to parse replication line: %r", message,
+ )
+ return
+
+ # We use "redis" as the name here as we don't have 1:1 connections to
+ # remote instances.
+ tcp_inbound_commands_counter.labels(cmd.NAME, "redis").inc()
+
+        # Now let's try to call the on_<CMD_NAME> function
+ run_as_background_process(
+ "replication-" + cmd.get_logcontext_id(), self.handle_command, cmd
+ )
+
+ async def handle_command(self, cmd: Command):
+ """Handle a command we have received over the replication stream.
+
+ By default delegates to on_<COMMAND>, which should return an awaitable.
+
+ Args:
+ cmd: received command
+ """
+ handled = False
+
+        # First call any command handlers on this instance. These are for
+        # redis-specific handling.
+ cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None)
+ if cmd_func:
+ await cmd_func(cmd)
+ handled = True
+
+ # Then call out to the handler.
+ cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None)
+ if cmd_func:
+ await cmd_func(self, cmd)
+ handled = True
+
+ if not handled:
+ logger.warning("Unhandled command: %r", cmd)
+
+ def connectionLost(self, reason):
+ logger.info("Lost connection to redis")
+ super().connectionLost(reason)
+ self.handler.lost_connection(self)
+
+ def send_command(self, cmd: Command):
+        """Send a command if the connection has been established.
+
+ Args:
+ cmd (Command)
+ """
+ run_as_background_process("send-cmd", self._async_send_command, cmd)
+
+ async def _async_send_command(self, cmd: Command):
+ """Encode a replication command and send it over our outbound connection"""
+ string = "%s %s" % (cmd.NAME, cmd.to_line())
+ if "\n" in string:
+            raise Exception("Unexpected newline in command: %r" % (string,))
+
+ encoded_string = string.encode("utf-8")
+
+ # We use "redis" as the name here as we don't have 1:1 connections to
+ # remote instances.
+ tcp_outbound_commands_counter.labels(cmd.NAME, "redis").inc()
+
+ await make_deferred_yieldable(
+ self.outbound_redis_connection.publish(self.stream_name, encoded_string)
+ )
+
+
+class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
+ """This is a reconnecting factory that connects to redis and immediately
+ subscribes to a stream.
+
+ Args:
+ hs
+ outbound_redis_connection: A connection to redis that will be used to
+            send outbound commands (this is separate from the redis connection
+ used to subscribe).
+ """
+
+ maxDelay = 5
+ continueTrying = True
+ protocol = RedisSubscriber
+
+ def __init__(
+ self, hs: "HomeServer", outbound_redis_connection: txredisapi.RedisProtocol
+ ):
+
+ super().__init__()
+
+        # This sets the password on the RedisFactory base class (as the
+        # SubscriberFactory constructor doesn't pass it through).
+ self.password = hs.config.redis.redis_password
+
+ self.handler = hs.get_tcp_replication()
+ self.stream_name = hs.hostname
+
+ self.outbound_redis_connection = outbound_redis_connection
+
+ def buildProtocol(self, addr):
+ p = super().buildProtocol(addr) # type: RedisSubscriber
+
+        # We do this here rather than in the constructor of `RedisSubscriber`,
+        # since doing so would mean overriding `buildProtocol` entirely, and the
+        # base method does more than just instantiate the protocol.
+ p.handler = self.handler
+ p.outbound_redis_connection = self.outbound_redis_connection
+ p.stream_name = self.stream_name
+ p.password = self.password
+
+ return p
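The redis transport reuses the same line format as the TCP protocol: the command name, a space, then the command's serialised payload, with newlines forbidden (see `_async_send_command` above). A rough sketch of the round trip, assuming the command helpers behave as they are used in this file:

    from synapse.replication.tcp.commands import ReplicateCommand, parse_command_from_line

    cmd = ReplicateCommand()

    # Encode as it would be published to the redis channel.
    line = "%s %s" % (cmd.NAME, cmd.to_line())
    assert "\n" not in line

    # Decode as messageReceived does on the subscribing side.
    decoded = parse_command_from_line(line)
    assert decoded.NAME == cmd.NAME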
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index ce9d1fae12..41569305df 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -17,32 +17,18 @@
import logging
import random
-from typing import Any, List
-
-from six import itervalues
from prometheus_client import Counter
from twisted.internet.protocol import Factory
-from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.util.metrics import Measure, measure_func
-
-from .protocol import ServerReplicationStreamProtocol
-from .streams import STREAMS_MAP
-from .streams.federation import FederationStream
+from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol
+from synapse.util.metrics import Measure
stream_updates_counter = Counter(
"synapse_replication_tcp_resource_stream_updates", "", ["stream_name"]
)
-user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "")
-federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "")
-remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "")
-invalidate_cache_counter = Counter(
- "synapse_replication_tcp_resource_invalidate_cache", ""
-)
-user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "")
logger = logging.getLogger(__name__)
@@ -52,13 +38,23 @@ class ReplicationStreamProtocolFactory(Factory):
"""
def __init__(self, hs):
- self.streamer = ReplicationStreamer(hs)
+ self.command_handler = hs.get_tcp_replication()
self.clock = hs.get_clock()
self.server_name = hs.config.server_name
+ # If we've created a `ReplicationStreamProtocolFactory` then we're
+ # almost certainly registering a replication listener, so let's ensure
+ # that we've started a `ReplicationStreamer` instance to actually push
+ # data.
+ #
+        # (This is a bit of a weird place to do this, but the alternatives,
+        # such as putting this in `HomeServer.setup()`, require either passing
+        # the listener config again or always starting a `ReplicationStreamer`.)
+ hs.get_replication_streamer()
+
def buildProtocol(self, addr):
return ServerReplicationStreamProtocol(
- self.server_name, self.clock, self.streamer
+ self.server_name, self.clock, self.command_handler
)
@@ -71,67 +67,22 @@ class ReplicationStreamer(object):
def __init__(self, hs):
self.store = hs.get_datastore()
- self.presence_handler = hs.get_presence_handler()
self.clock = hs.get_clock()
self.notifier = hs.get_notifier()
- self._server_notices_sender = hs.get_server_notices_sender()
+ self._instance_name = hs.get_instance_name()
self._replication_torture_level = hs.config.replication_torture_level
- # Current connections.
- self.connections = [] # type: List[ServerReplicationStreamProtocol]
-
- LaterGauge(
- "synapse_replication_tcp_resource_total_connections",
- "",
- [],
- lambda: len(self.connections),
- )
-
- # List of streams that clients can subscribe to.
- # We only support federation stream if federation sending hase been
- # disabled on the master.
- self.streams = [
- stream(hs)
- for stream in itervalues(STREAMS_MAP)
- if stream != FederationStream or not hs.config.send_federation
- ]
-
- self.streams_by_name = {stream.NAME: stream for stream in self.streams}
-
- LaterGauge(
- "synapse_replication_tcp_resource_connections_per_stream",
- "",
- ["stream_name"],
- lambda: {
- (stream_name,): len(
- [
- conn
- for conn in self.connections
- if stream_name in conn.replication_streams
- ]
- )
- for stream_name in self.streams_by_name
- },
- )
-
- self.federation_sender = None
- if not hs.config.send_federation:
- self.federation_sender = hs.get_federation_sender()
-
self.notifier.add_replication_callback(self.on_notifier_poke)
- self.notifier.add_remote_server_up_callback(self.send_remote_server_up)
# Keeps track of whether we are currently checking for updates
self.is_looping = False
self.pending_updates = False
- hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.on_shutdown)
+ self.command_handler = hs.get_tcp_replication()
- def on_shutdown(self):
- # close all connections on shutdown
- for conn in self.connections:
- conn.send_error("server shutting down")
+ # Set of streams to replicate.
+ self.streams = self.command_handler.get_streams_to_replicate()
def on_notifier_poke(self):
"""Checks if there is actually any new data and sends it to the
@@ -140,7 +91,7 @@ class ReplicationStreamer(object):
This should get called each time new data is available, even if it
is currently being executed, so that nothing gets missed
"""
- if not self.connections:
+ if not self.command_handler.connected():
# Don't bother if nothing is listening. We still need to advance
            # the stream tokens otherwise they'll fall behind forever
for stream in self.streams:
@@ -166,11 +117,6 @@ class ReplicationStreamer(object):
self.pending_updates = False
with Measure(self.clock, "repl.stream.get_updates"):
- # First we tell the streams that they should update their
- # current tokens.
- for stream in self.streams:
- stream.advance_current_token()
-
all_streams = self.streams
if self._replication_torture_level is not None:
@@ -180,7 +126,9 @@ class ReplicationStreamer(object):
random.shuffle(all_streams)
for stream in all_streams:
- if stream.last_token == stream.upto_token:
+ if stream.last_token == stream.current_token(
+ self._instance_name
+ ):
continue
if self._replication_torture_level:
@@ -192,18 +140,17 @@ class ReplicationStreamer(object):
"Getting stream: %s: %s -> %s",
stream.NAME,
stream.last_token,
- stream.upto_token,
+ stream.current_token(self._instance_name),
)
try:
- updates, current_token = await stream.get_updates()
+ updates, current_token, limited = await stream.get_updates()
+ self.pending_updates |= limited
except Exception:
logger.info("Failed to handle stream %s", stream.NAME)
raise
logger.debug(
- "Sending %d updates to %d connections",
- len(updates),
- len(self.connections),
+ "Sending %d updates", len(updates),
)
if updates:
@@ -219,116 +166,19 @@ class ReplicationStreamer(object):
# token. See RdataCommand for more details.
batched_updates = _batch_updates(updates)
- for conn in self.connections:
- for token, row in batched_updates:
- try:
- conn.stream_update(stream.NAME, token, row)
- except Exception:
- logger.exception("Failed to replicate")
+ for token, row in batched_updates:
+ try:
+ self.command_handler.stream_update(
+ stream.NAME, token, row
+ )
+ except Exception:
+ logger.exception("Failed to replicate")
logger.debug("No more pending updates, breaking poke loop")
finally:
self.pending_updates = False
self.is_looping = False
- @measure_func("repl.get_stream_updates")
- async def get_stream_updates(self, stream_name, token):
- """For a given stream get all updates since token. This is called when
- a client first subscribes to a stream.
- """
- stream = self.streams_by_name.get(stream_name, None)
- if not stream:
- raise Exception("unknown stream %s", stream_name)
-
- return await stream.get_updates_since(token)
-
- @measure_func("repl.federation_ack")
- def federation_ack(self, token):
- """We've received an ack for federation stream from a client.
- """
- federation_ack_counter.inc()
- if self.federation_sender:
- self.federation_sender.federation_ack(token)
-
- @measure_func("repl.on_user_sync")
- async def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms):
- """A client has started/stopped syncing on a worker.
- """
- user_sync_counter.inc()
- await self.presence_handler.update_external_syncs_row(
- conn_id, user_id, is_syncing, last_sync_ms
- )
-
- @measure_func("repl.on_remove_pusher")
- async def on_remove_pusher(self, app_id, push_key, user_id):
- """A client has asked us to remove a pusher
- """
- remove_pusher_counter.inc()
- await self.store.delete_pusher_by_app_id_pushkey_user_id(
- app_id=app_id, pushkey=push_key, user_id=user_id
- )
-
- self.notifier.on_new_replication_data()
-
- @measure_func("repl.on_invalidate_cache")
- async def on_invalidate_cache(self, cache_func: str, keys: List[Any]):
- """The client has asked us to invalidate a cache
- """
- invalidate_cache_counter.inc()
-
- # We invalidate the cache locally, but then also stream that to other
- # workers.
- await self.store.invalidate_cache_and_stream(cache_func, tuple(keys))
-
- @measure_func("repl.on_user_ip")
- async def on_user_ip(
- self, user_id, access_token, ip, user_agent, device_id, last_seen
- ):
- """The client saw a user request
- """
- user_ip_cache_counter.inc()
- await self.store.insert_client_ip(
- user_id, access_token, ip, user_agent, device_id, last_seen
- )
- await self._server_notices_sender.on_user_ip(user_id)
-
- @measure_func("repl.on_remote_server_up")
- def on_remote_server_up(self, server: str):
- self.notifier.notify_remote_server_up(server)
-
- def send_remote_server_up(self, server: str):
- for conn in self.connections:
- conn.send_remote_server_up(server)
-
- def send_sync_to_all_connections(self, data):
- """Sends a SYNC command to all clients.
-
- Used in tests.
- """
- for conn in self.connections:
- conn.send_sync(data)
-
- def new_connection(self, connection):
- """A new client connection has been established
- """
- self.connections.append(connection)
-
- def lost_connection(self, connection):
- """A client connection has been lost
- """
- try:
- self.connections.remove(connection)
- except ValueError:
- pass
-
- # We need to tell the presence handler that the connection has been
- # lost so that it can handle any ongoing syncs on that connection.
- run_as_background_process(
- "update_external_syncs_clear",
- self.presence_handler.update_external_syncs_clear,
- connection.conn_id,
- )
-
def _batch_updates(updates):
"""Takes a list of updates of form [(token, row)] and sets the token to
diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py
index 5f52264e84..d1a61c3314 100644
--- a/synapse/replication/tcp/streams/__init__.py
+++ b/synapse/replication/tcp/streams/__init__.py
@@ -25,26 +25,63 @@ Each stream is defined by the following information:
update_function: The function that returns a list of updates between two tokens
"""
-from . import _base, events, federation
+from synapse.replication.tcp.streams._base import (
+ AccountDataStream,
+ BackfillStream,
+ CachesStream,
+ DeviceListsStream,
+ GroupServerStream,
+ PresenceStream,
+ PublicRoomsStream,
+ PushersStream,
+ PushRulesStream,
+ ReceiptsStream,
+ Stream,
+ TagAccountDataStream,
+ ToDeviceStream,
+ TypingStream,
+ UserSignatureStream,
+)
+from synapse.replication.tcp.streams.events import EventsStream
+from synapse.replication.tcp.streams.federation import FederationStream
STREAMS_MAP = {
stream.NAME: stream
for stream in (
- events.EventsStream,
- _base.BackfillStream,
- _base.PresenceStream,
- _base.TypingStream,
- _base.ReceiptsStream,
- _base.PushRulesStream,
- _base.PushersStream,
- _base.CachesStream,
- _base.PublicRoomsStream,
- _base.DeviceListsStream,
- _base.ToDeviceStream,
- federation.FederationStream,
- _base.TagAccountDataStream,
- _base.AccountDataStream,
- _base.GroupServerStream,
- _base.UserSignatureStream,
+ EventsStream,
+ BackfillStream,
+ PresenceStream,
+ TypingStream,
+ ReceiptsStream,
+ PushRulesStream,
+ PushersStream,
+ CachesStream,
+ PublicRoomsStream,
+ DeviceListsStream,
+ ToDeviceStream,
+ FederationStream,
+ TagAccountDataStream,
+ AccountDataStream,
+ GroupServerStream,
+ UserSignatureStream,
)
}
+
+__all__ = [
+ "STREAMS_MAP",
+ "Stream",
+ "BackfillStream",
+ "PresenceStream",
+ "TypingStream",
+ "ReceiptsStream",
+ "PushRulesStream",
+ "PushersStream",
+ "CachesStream",
+ "PublicRoomsStream",
+ "DeviceListsStream",
+ "ToDeviceStream",
+ "TagAccountDataStream",
+ "AccountDataStream",
+ "GroupServerStream",
+ "UserSignatureStream",
+]
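A short sketch of how this registry is consumed on the receiving side: given a stream name from an RDATA command, look up the stream class and parse the raw row into its `ROW_TYPE` (the values below are made up for illustration):

    from synapse.replication.tcp.streams import STREAMS_MAP

    stream_name, raw_row = "to_device", ["@alice:example.com"]

    stream_class = STREAMS_MAP[stream_name]
    row = stream_class.parse_row(raw_row)
    # row is now a ToDeviceStreamRow(entity='@alice:example.com')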
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 208e8a667b..d42aaff055 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -14,117 +14,84 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import itertools
+import heapq
import logging
from collections import namedtuple
-from typing import Any, List, Optional
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Awaitable,
+ Callable,
+ List,
+ Optional,
+ Tuple,
+ TypeVar,
+)
import attr
-logger = logging.getLogger(__name__)
+from synapse.replication.http.streams import ReplicationGetStreamUpdates
+if TYPE_CHECKING:
+ import synapse.server
-MAX_EVENTS_BEHIND = 500000
+logger = logging.getLogger(__name__)
-BackfillStreamRow = namedtuple(
- "BackfillStreamRow",
- (
- "event_id", # str
- "room_id", # str
- "type", # str
- "state_key", # str, optional
- "redacts", # str, optional
- "relates_to", # str, optional
- ),
-)
-PresenceStreamRow = namedtuple(
- "PresenceStreamRow",
- (
- "user_id", # str
- "state", # str
- "last_active_ts", # int
- "last_federation_update_ts", # int
- "last_user_sync_ts", # int
- "status_msg", # str
- "currently_active", # bool
- ),
-)
-TypingStreamRow = namedtuple(
- "TypingStreamRow", ("room_id", "user_ids") # str # list(str)
-)
-ReceiptsStreamRow = namedtuple(
- "ReceiptsStreamRow",
- (
- "room_id", # str
- "receipt_type", # str
- "user_id", # str
- "event_id", # str
- "data", # dict
- ),
-)
-PushRulesStreamRow = namedtuple("PushRulesStreamRow", ("user_id",)) # str
-PushersStreamRow = namedtuple(
- "PushersStreamRow",
- ("user_id", "app_id", "pushkey", "deleted"), # str # str # str # bool
-)
+# the number of rows to request from an update_function.
+_STREAM_UPDATE_TARGET_ROW_COUNT = 100
-@attr.s
-class CachesStreamRow:
- """Stream to inform workers they should invalidate their cache.
+# Some type aliases to make things a bit easier.
- Attributes:
- cache_func: Name of the cached function.
- keys: The entry in the cache to invalidate. If None then will
- invalidate all.
- invalidation_ts: Timestamp of when the invalidation took place.
- """
+# A stream position token
+Token = int
- cache_func = attr.ib(type=str)
- keys = attr.ib(type=Optional[List[Any]])
- invalidation_ts = attr.ib(type=int)
+# The type of a stream update row, after JSON deserialisation, but before
+# parsing with Stream.parse_row (which turns it into a `ROW_TYPE`). Normally it's
+# just a row from a database query, though this is dependent on the stream in question.
+#
+StreamRow = TypeVar("StreamRow", bound=Tuple)
+# The type returned by the update_function of a stream, as well as get_updates(),
+# get_updates_since, etc.
+#
+# It consists of a triplet `(updates, new_last_token, limited)`, where:
+# * `updates` is a list of `(token, row)` entries.
+# * `new_last_token` is the new position in stream.
+# * `limited` is whether there are more updates to fetch.
+#
+StreamUpdateResult = Tuple[List[Tuple[Token, StreamRow]], Token, bool]
-PublicRoomsStreamRow = namedtuple(
- "PublicRoomsStreamRow",
- (
- "room_id", # str
- "visibility", # str
- "appservice_id", # str, optional
- "network_id", # str, optional
- ),
-)
-DeviceListsStreamRow = namedtuple(
- "DeviceListsStreamRow", ("user_id", "destination") # str # str
-)
-ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",)) # str
-TagAccountDataStreamRow = namedtuple(
- "TagAccountDataStreamRow", ("user_id", "room_id", "data") # str # str # dict
-)
-AccountDataStreamRow = namedtuple(
- "AccountDataStream", ("user_id", "room_id", "data_type") # str # str # str
-)
-GroupsStreamRow = namedtuple(
- "GroupsStreamRow",
- ("group_id", "user_id", "type", "content"), # str # str # str # dict
-)
-UserSignatureStreamRow = namedtuple("UserSignatureStreamRow", ("user_id")) # str
+# The type of an update_function for a stream
+#
+# The arguments are:
+#
+# * instance_name: the writer of the stream
+# * from_token: the previous stream token: the starting point for fetching the
+# updates
+# * to_token: the new stream token: the point to get updates up to
+# * target_row_count: a target for the number of rows to be returned.
+#
+# The update_function is expected to return up to _approximately_ target_row_count rows.
+# If there are more updates available, it should set `limited` in the result, and
+# it will be called again to get the next batch.
+#
+UpdateFunction = Callable[[str, Token, Token, int], Awaitable[StreamUpdateResult]]
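To make the contract concrete, here is a minimal, purely hypothetical `update_function` that satisfies the signature described above, paging through an in-memory set of rows:

    async def example_update_function(instance_name, from_token, upto_token, limit):
        # illustrative data: rows previously written to the stream, keyed by position
        all_rows = {1: ("a",), 2: ("b",), 3: ("c",)}

        updates = [
            (token, row)
            for token, row in sorted(all_rows.items())
            if from_token < token <= upto_token
        ][:limit]

        # if we filled the budget before reaching upto_token, report `limited` and
        # clamp the returned position so the caller knows to ask again
        limited = bool(updates) and len(updates) == limit and updates[-1][0] < upto_token
        new_last_token = updates[-1][0] if limited else upto_token
        return updates, new_last_token, limited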
class Stream(object):
"""Base class for the streams.
Provides a `get_updates()` function that returns new updates since the last
- time it was called up until the point `advance_current_token` was called.
+ time it was called.
"""
NAME = None # type: str # The name of the stream
# The type of the row. Used by the default impl of parse_row.
ROW_TYPE = None # type: Any
- _LIMITED = True # Whether the update function takes a limit
@classmethod
- def parse_row(cls, row):
+ def parse_row(cls, row: StreamRow):
"""Parse a row received over replication
By default, assumes that the row data is an array object and passes its contents
@@ -138,101 +105,138 @@ class Stream(object):
"""
return cls.ROW_TYPE(*row)
- def __init__(self, hs):
- # The token from which we last asked for updates
- self.last_token = self.current_token()
-
- # The token that we will get updates up to
- self.upto_token = self.current_token()
+ def __init__(
+ self,
+ local_instance_name: str,
+ current_token_function: Callable[[str], Token],
+ update_function: UpdateFunction,
+ ):
+ """Instantiate a Stream
+
+ `current_token_function` and `update_function` are callbacks which
+ should be implemented by subclasses.
+
+ `current_token_function` takes an instance name, which is a writer to
+ the stream, and returns the position in the stream of the writer (as
+ viewed from the current process). On the writer process this is where
+ the writer has successfully written up to, whereas on other processes
+ this is the position which we have received updates up to over
+ replication. (Note that most streams have a single writer and so their
+ implementations ignore the instance name passed in).
+
+ `update_function` is called to get updates for this stream between a
+ pair of stream tokens. See the `UpdateFunction` type definition for more
+ info.
- def advance_current_token(self):
- """Updates `upto_token` to "now", which updates up until which point
- get_updates[_since] will fetch rows till.
+ Args:
+ local_instance_name: The instance name of the current process
+ current_token_function: callback to get the current token, as above
+            update_function: callback to get stream updates, as above
"""
- self.upto_token = self.current_token()
+ self.local_instance_name = local_instance_name
+ self.current_token = current_token_function
+ self.update_function = update_function
+
+ # The token from which we last asked for updates
+ self.last_token = self.current_token(self.local_instance_name)
def discard_updates_and_advance(self):
"""Called when the stream should advance but the updates would be discarded,
e.g. when there are no currently connected workers.
"""
- self.upto_token = self.current_token()
- self.last_token = self.upto_token
+ self.last_token = self.current_token(self.local_instance_name)
- async def get_updates(self):
+ async def get_updates(self) -> StreamUpdateResult:
"""Gets all updates since the last time this function was called (or
- since the stream was constructed if it hadn't been called before),
- until the `upto_token`
+ since the stream was constructed if it hadn't been called before).
Returns:
- Deferred[Tuple[List[Tuple[int, Any]], int]:
- Resolves to a pair ``(updates, current_token)``, where ``updates`` is a
- list of ``(token, row)`` entries. ``row`` will be json-serialised and
- sent over the replication steam.
+ A triplet `(updates, new_last_token, limited)`, where `updates` is
+ a list of `(token, row)` entries, `new_last_token` is the new
+ position in stream, and `limited` is whether there are more updates
+ to fetch.
"""
- updates, current_token = await self.get_updates_since(self.last_token)
+ current_token = self.current_token(self.local_instance_name)
+ updates, current_token, limited = await self.get_updates_since(
+ self.local_instance_name, self.last_token, current_token
+ )
self.last_token = current_token
- return updates, current_token
+ return updates, current_token, limited
- async def get_updates_since(self, from_token):
+ async def get_updates_since(
+ self, instance_name: str, from_token: Token, upto_token: Token
+ ) -> StreamUpdateResult:
"""Like get_updates except allows specifying from when we should
stream updates
Returns:
- Deferred[Tuple[List[Tuple[int, Any]], int]:
- Resolves to a pair ``(updates, current_token)``, where ``updates`` is a
- list of ``(token, row)`` entries. ``row`` will be json-serialised and
- sent over the replication steam.
+ A triplet `(updates, new_last_token, limited)`, where `updates` is
+ a list of `(token, row)` entries, `new_last_token` is the new
+ position in stream, and `limited` is whether there are more updates
+ to fetch.
"""
- if from_token in ("NOW", "now"):
- return [], self.upto_token
-
- current_token = self.upto_token
from_token = int(from_token)
- if from_token == current_token:
- return [], current_token
+ if from_token == upto_token:
+ return [], upto_token, False
- logger.info("get_updates_since: %s", self.__class__)
- if self._LIMITED:
- rows = await self.update_function(
- from_token, current_token, limit=MAX_EVENTS_BEHIND + 1
- )
+ updates, upto_token, limited = await self.update_function(
+ instance_name, from_token, upto_token, _STREAM_UPDATE_TARGET_ROW_COUNT,
+ )
+ return updates, upto_token, limited
- # never turn more than MAX_EVENTS_BEHIND + 1 into updates.
- rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1)
- else:
- rows = await self.update_function(from_token, current_token)
+def current_token_without_instance(
+ current_token: Callable[[], int]
+) -> Callable[[str], int]:
+ """Takes a current token callback function for a single writer stream
+ that doesn't take an instance name parameter and wraps it in a function that
+ does accept an instance name parameter but ignores it.
+ """
+ return lambda instance_name: current_token()
+
+
+def db_query_to_update_function(
+ query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
+) -> UpdateFunction:
+ """Wraps a db query function which returns a list of rows to make it
+ suitable for use as an `update_function` for the Stream class
+ """
+
+ async def update_function(instance_name, from_token, upto_token, limit):
+ rows = await query_function(from_token, upto_token, limit)
updates = [(row[0], row[1:]) for row in rows]
+ limited = False
+ if len(updates) >= limit:
+ upto_token = updates[-1][0]
+ limited = True
- # check we didn't get more rows than the limit.
- # doing it like this allows the update_function to be a generator.
- if self._LIMITED and len(updates) >= MAX_EVENTS_BEHIND:
- raise Exception("stream %s has fallen behind" % (self.NAME))
+ return updates, upto_token, limited
- return updates, current_token
+ return update_function
- def current_token(self):
- """Gets the current token of the underlying streams. Should be provided
- by the sub classes
- Returns:
- int
- """
- raise NotImplementedError()
+def make_http_update_function(hs, stream_name: str) -> UpdateFunction:
+ """Makes a suitable function for use as an `update_function` that queries
+ the master process for updates.
+ """
- def update_function(self, from_token, current_token, limit=None):
- """Get updates between from_token and to_token. If Stream._LIMITED is
- True then limit is provided, otherwise it's not.
+ client = ReplicationGetStreamUpdates.make_client(hs)
- Returns:
- Deferred(list(tuple)): the first entry in the tuple is the token for
- that update, and the rest of the tuple gets used to construct
- a ``ROW_TYPE`` instance
- """
- raise NotImplementedError()
+ async def update_function(
+ instance_name: str, from_token: int, upto_token: int, limit: int
+ ) -> StreamUpdateResult:
+ result = await client(
+ instance_name=instance_name,
+ stream_name=stream_name,
+ from_token=from_token,
+ upto_token=upto_token,
+ )
+ return result["updates"], result["upto_token"], result["limited"]
+
+ return update_function
class BackfillStream(Stream):
@@ -240,93 +244,170 @@ class BackfillStream(Stream):
or it went from being an outlier to not.
"""
+ BackfillStreamRow = namedtuple(
+ "BackfillStreamRow",
+ (
+ "event_id", # str
+ "room_id", # str
+ "type", # str
+ "state_key", # str, optional
+ "redacts", # str, optional
+ "relates_to", # str, optional
+ ),
+ )
+
NAME = "backfill"
ROW_TYPE = BackfillStreamRow
def __init__(self, hs):
store = hs.get_datastore()
- self.current_token = store.get_current_backfill_token # type: ignore
- self.update_function = store.get_all_new_backfill_event_rows # type: ignore
-
- super(BackfillStream, self).__init__(hs)
+ super().__init__(
+ hs.get_instance_name(),
+ current_token_without_instance(store.get_current_backfill_token),
+ db_query_to_update_function(store.get_all_new_backfill_event_rows),
+ )
class PresenceStream(Stream):
+ PresenceStreamRow = namedtuple(
+ "PresenceStreamRow",
+ (
+ "user_id", # str
+ "state", # str
+ "last_active_ts", # int
+ "last_federation_update_ts", # int
+ "last_user_sync_ts", # int
+ "status_msg", # str
+ "currently_active", # bool
+ ),
+ )
+
NAME = "presence"
- _LIMITED = False
ROW_TYPE = PresenceStreamRow
def __init__(self, hs):
store = hs.get_datastore()
- presence_handler = hs.get_presence_handler()
- self.current_token = store.get_current_presence_token # type: ignore
- self.update_function = presence_handler.get_all_presence_updates # type: ignore
+ if hs.config.worker_app is None:
+ # on the master, query the presence handler
+ presence_handler = hs.get_presence_handler()
+ update_function = db_query_to_update_function(
+ presence_handler.get_all_presence_updates
+ )
+ else:
+ # Query master process
+ update_function = make_http_update_function(hs, self.NAME)
- super(PresenceStream, self).__init__(hs)
+ super().__init__(
+ hs.get_instance_name(),
+ current_token_without_instance(store.get_current_presence_token),
+ update_function,
+ )
class TypingStream(Stream):
+ TypingStreamRow = namedtuple(
+ "TypingStreamRow", ("room_id", "user_ids") # str # list(str)
+ )
+
NAME = "typing"
- _LIMITED = False
ROW_TYPE = TypingStreamRow
def __init__(self, hs):
typing_handler = hs.get_typing_handler()
- self.current_token = typing_handler.get_current_token # type: ignore
- self.update_function = typing_handler.get_all_typing_updates # type: ignore
+ if hs.config.worker_app is None:
+ # on the master, query the typing handler
+ update_function = db_query_to_update_function(
+ typing_handler.get_all_typing_updates
+ )
+ else:
+ # Query master process
+ update_function = make_http_update_function(hs, self.NAME)
- super(TypingStream, self).__init__(hs)
+ super().__init__(
+ hs.get_instance_name(),
+ current_token_without_instance(typing_handler.get_current_token),
+ update_function,
+ )
class ReceiptsStream(Stream):
+ ReceiptsStreamRow = namedtuple(
+ "ReceiptsStreamRow",
+ (
+ "room_id", # str
+ "receipt_type", # str
+ "user_id", # str
+ "event_id", # str
+ "data", # dict
+ ),
+ )
+
NAME = "receipts"
ROW_TYPE = ReceiptsStreamRow
def __init__(self, hs):
store = hs.get_datastore()
-
- self.current_token = store.get_max_receipt_stream_id # type: ignore
- self.update_function = store.get_all_updated_receipts # type: ignore
-
- super(ReceiptsStream, self).__init__(hs)
+ super().__init__(
+ hs.get_instance_name(),
+ current_token_without_instance(store.get_max_receipt_stream_id),
+ db_query_to_update_function(store.get_all_updated_receipts),
+ )
class PushRulesStream(Stream):
"""A user has changed their push rules
"""
+ PushRulesStreamRow = namedtuple("PushRulesStreamRow", ("user_id",)) # str
+
NAME = "push_rules"
ROW_TYPE = PushRulesStreamRow
def __init__(self, hs):
self.store = hs.get_datastore()
- super(PushRulesStream, self).__init__(hs)
+ super(PushRulesStream, self).__init__(
+ hs.get_instance_name(), self._current_token, self._update_function
+ )
- def current_token(self):
+ def _current_token(self, instance_name: str) -> int:
push_rules_token, _ = self.store.get_push_rules_stream_token()
return push_rules_token
- async def update_function(self, from_token, to_token, limit):
+ async def _update_function(
+ self, instance_name: str, from_token: Token, to_token: Token, limit: int
+ ):
rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit)
- return [(row[0], row[2]) for row in rows]
+
+ limited = False
+ if len(rows) == limit:
+ to_token = rows[-1][0]
+ limited = True
+
+ return [(row[0], (row[2],)) for row in rows], to_token, limited
class PushersStream(Stream):
"""A user has added/changed/removed a pusher
"""
+ PushersStreamRow = namedtuple(
+ "PushersStreamRow",
+ ("user_id", "app_id", "pushkey", "deleted"), # str # str # str # bool
+ )
+
NAME = "pushers"
ROW_TYPE = PushersStreamRow
def __init__(self, hs):
store = hs.get_datastore()
- self.current_token = store.get_pushers_stream_token # type: ignore
- self.update_function = store.get_all_updated_pushers_rows # type: ignore
-
- super(PushersStream, self).__init__(hs)
+ super().__init__(
+ hs.get_instance_name(),
+ current_token_without_instance(store.get_pushers_stream_token),
+ db_query_to_update_function(store.get_all_updated_pushers_rows),
+ )
class CachesStream(Stream):
@@ -334,136 +415,229 @@ class CachesStream(Stream):
the cache on the workers
"""
+ @attr.s
+ class CachesStreamRow:
+ """Stream to inform workers they should invalidate their cache.
+
+ Attributes:
+ cache_func: Name of the cached function.
+ keys: The entry in the cache to invalidate. If None then will
+ invalidate all.
+ invalidation_ts: Timestamp of when the invalidation took place.
+ """
+
+ cache_func = attr.ib(type=str)
+ keys = attr.ib(type=Optional[List[Any]])
+ invalidation_ts = attr.ib(type=int)
+
NAME = "caches"
ROW_TYPE = CachesStreamRow
def __init__(self, hs):
- store = hs.get_datastore()
+ self.store = hs.get_datastore()
+ super().__init__(
+ hs.get_instance_name(),
+ self.store.get_cache_stream_token,
+ self._update_function,
+ )
- self.current_token = store.get_cache_stream_token # type: ignore
- self.update_function = store.get_all_updated_caches # type: ignore
+ async def _update_function(
+ self, instance_name: str, from_token: int, upto_token: int, limit: int
+ ):
+ rows = await self.store.get_all_updated_caches(
+ instance_name, from_token, upto_token, limit
+ )
+ updates = [(row[0], row[1:]) for row in rows]
+ limited = False
+ if len(updates) >= limit:
+ upto_token = updates[-1][0]
+ limited = True
- super(CachesStream, self).__init__(hs)
+ return updates, upto_token, limited
class PublicRoomsStream(Stream):
"""The public rooms list changed
"""
+ PublicRoomsStreamRow = namedtuple(
+ "PublicRoomsStreamRow",
+ (
+ "room_id", # str
+ "visibility", # str
+ "appservice_id", # str, optional
+ "network_id", # str, optional
+ ),
+ )
+
NAME = "public_rooms"
ROW_TYPE = PublicRoomsStreamRow
def __init__(self, hs):
store = hs.get_datastore()
-
- self.current_token = store.get_current_public_room_stream_id # type: ignore
- self.update_function = store.get_all_new_public_rooms # type: ignore
-
- super(PublicRoomsStream, self).__init__(hs)
+ super().__init__(
+ hs.get_instance_name(),
+ current_token_without_instance(store.get_current_public_room_stream_id),
+ db_query_to_update_function(store.get_all_new_public_rooms),
+ )
class DeviceListsStream(Stream):
- """Someone added/changed/removed a device
+ """Either a user has updated their devices or a remote server needs to be
+ told about a device update.
"""
+ @attr.s
+ class DeviceListsStreamRow:
+ entity = attr.ib(type=str)
+
NAME = "device_lists"
- _LIMITED = False
ROW_TYPE = DeviceListsStreamRow
def __init__(self, hs):
store = hs.get_datastore()
-
- self.current_token = store.get_device_stream_token # type: ignore
- self.update_function = store.get_all_device_list_changes_for_remotes # type: ignore
-
- super(DeviceListsStream, self).__init__(hs)
+ super().__init__(
+ hs.get_instance_name(),
+ current_token_without_instance(store.get_device_stream_token),
+ db_query_to_update_function(store.get_all_device_list_changes_for_remotes),
+ )
class ToDeviceStream(Stream):
"""New to_device messages for a client
"""
+ ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",)) # str
+
NAME = "to_device"
ROW_TYPE = ToDeviceStreamRow
def __init__(self, hs):
store = hs.get_datastore()
-
- self.current_token = store.get_to_device_stream_token # type: ignore
- self.update_function = store.get_all_new_device_messages # type: ignore
-
- super(ToDeviceStream, self).__init__(hs)
+ super().__init__(
+ hs.get_instance_name(),
+ current_token_without_instance(store.get_to_device_stream_token),
+ db_query_to_update_function(store.get_all_new_device_messages),
+ )
class TagAccountDataStream(Stream):
"""Someone added/removed a tag for a room
"""
+ TagAccountDataStreamRow = namedtuple(
+ "TagAccountDataStreamRow", ("user_id", "room_id", "data") # str # str # dict
+ )
+
NAME = "tag_account_data"
ROW_TYPE = TagAccountDataStreamRow
def __init__(self, hs):
store = hs.get_datastore()
-
- self.current_token = store.get_max_account_data_stream_id # type: ignore
- self.update_function = store.get_all_updated_tags # type: ignore
-
- super(TagAccountDataStream, self).__init__(hs)
+ super().__init__(
+ hs.get_instance_name(),
+ current_token_without_instance(store.get_max_account_data_stream_id),
+ db_query_to_update_function(store.get_all_updated_tags),
+ )
class AccountDataStream(Stream):
"""Global or per room account data was changed
"""
+ AccountDataStreamRow = namedtuple(
+ "AccountDataStream",
+ ("user_id", "room_id", "data_type"), # str # Optional[str] # str
+ )
+
NAME = "account_data"
ROW_TYPE = AccountDataStreamRow
- def __init__(self, hs):
+ def __init__(self, hs: "synapse.server.HomeServer"):
self.store = hs.get_datastore()
+ super().__init__(
+ hs.get_instance_name(),
+ current_token_without_instance(self.store.get_max_account_data_stream_id),
+ self._update_function,
+ )
- self.current_token = self.store.get_max_account_data_stream_id # type: ignore
+ async def _update_function(
+ self, instance_name: str, from_token: int, to_token: int, limit: int
+ ) -> StreamUpdateResult:
+ limited = False
+ global_results = await self.store.get_updated_global_account_data(
+ from_token, to_token, limit
+ )
- super(AccountDataStream, self).__init__(hs)
+ # if the global results hit the limit, we'll need to limit the room results to
+ # the same stream token.
+ if len(global_results) >= limit:
+ to_token = global_results[-1][0]
+ limited = True
- async def update_function(self, from_token, to_token, limit):
- global_results, room_results = await self.store.get_all_updated_account_data(
- from_token, from_token, to_token, limit
+ room_results = await self.store.get_updated_room_account_data(
+ from_token, to_token, limit
)
- results = list(room_results)
- results.extend(
- (stream_id, user_id, None, account_data_type)
+ # likewise, if the room results hit the limit, limit the global results to
+ # the same stream token.
+ if len(room_results) >= limit:
+ to_token = room_results[-1][0]
+ limited = True
+
+ # convert the global results to the right format, and limit them to the to_token
+ # at the same time
+ global_rows = (
+ (stream_id, (user_id, None, account_data_type))
for stream_id, user_id, account_data_type in global_results
+ if stream_id <= to_token
)
- return results
+ # we know that the room_results are already limited to `to_token` so no need
+ # for a check on `stream_id` here.
+ room_rows = (
+ (stream_id, (user_id, room_id, account_data_type))
+ for stream_id, user_id, room_id, account_data_type in room_results
+ )
+
+ # we need to return a sorted list, so merge them together.
+ updates = list(heapq.merge(room_rows, global_rows))
+ return updates, to_token, limited
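`heapq.merge` is used here because both generators are already sorted by stream id, so merging them lazily yields a combined, sorted list without a separate sort. A tiny illustration with made-up rows:

    import heapq

    global_rows = [(1, ("@a:hs", None, "m.push_rules")), (4, ("@b:hs", None, "m.direct"))]
    room_rows = [(2, ("@a:hs", "!r:hs", "m.tag")), (3, ("@b:hs", "!r:hs", "m.tag"))]

    merged = list(heapq.merge(room_rows, global_rows))
    assert [stream_id for stream_id, _ in merged] == [1, 2, 3, 4]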
class GroupServerStream(Stream):
+ GroupsStreamRow = namedtuple(
+ "GroupsStreamRow",
+ ("group_id", "user_id", "type", "content"), # str # str # str # dict
+ )
+
NAME = "groups"
ROW_TYPE = GroupsStreamRow
def __init__(self, hs):
store = hs.get_datastore()
-
- self.current_token = store.get_group_stream_token # type: ignore
- self.update_function = store.get_all_groups_changes # type: ignore
-
- super(GroupServerStream, self).__init__(hs)
+ super().__init__(
+ hs.get_instance_name(),
+ current_token_without_instance(store.get_group_stream_token),
+ db_query_to_update_function(store.get_all_groups_changes),
+ )
class UserSignatureStream(Stream):
"""A user has signed their own device with their user-signing key
"""
+ UserSignatureStreamRow = namedtuple("UserSignatureStreamRow", ("user_id")) # str
+
NAME = "user_signature"
- _LIMITED = False
ROW_TYPE = UserSignatureStreamRow
def __init__(self, hs):
store = hs.get_datastore()
-
- self.current_token = store.get_device_stream_token # type: ignore
- self.update_function = store.get_all_user_signature_changes_for_remotes # type: ignore
-
- super(UserSignatureStream, self).__init__(hs)
+ super().__init__(
+ hs.get_instance_name(),
+ current_token_without_instance(store.get_device_stream_token),
+ db_query_to_update_function(
+ store.get_all_user_signature_changes_for_remotes
+ ),
+ )
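Putting the pieces together: the streamer's poke loop in resource.py above keeps fetching while `limited` comes back True, since each limited call clamps the returned token and leaves more rows to fetch. A simplified, hypothetical sketch of that interaction with a single stream:

    async def drain_stream(stream, send_update):
        """Hypothetical helper: push all pending updates for a single stream."""
        limited = True
        while limited:
            updates, new_token, limited = await stream.get_updates()
            for token, row in updates:
                send_update(stream.NAME, token, row)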
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index b3afabb8cd..f370390331 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -15,11 +15,12 @@
# limitations under the License.
import heapq
-from typing import Tuple, Type
+from typing import Iterable, List, Tuple, Type
import attr
-from ._base import Stream
+from ._base import Stream, StreamUpdateResult, Token, current_token_without_instance
"""Handling of the 'events' replication stream
@@ -116,28 +117,107 @@ class EventsStream(Stream):
def __init__(self, hs):
self._store = hs.get_datastore()
- self.current_token = self._store.get_current_events_token # type: ignore
+ super().__init__(
+ hs.get_instance_name(),
+ current_token_without_instance(self._store.get_current_events_token),
+ self._update_function,
+ )
- super(EventsStream, self).__init__(hs)
+ async def _update_function(
+ self,
+ instance_name: str,
+ from_token: Token,
+ current_token: Token,
+ target_row_count: int,
+ ) -> StreamUpdateResult:
+
+ # the events stream merges together three separate sources:
+ # * new events
+ # * current_state changes
+ # * events which were previously outliers, but have now been de-outliered.
+ #
+ # The merge operation is complicated by the fact that we only have a single
+ # "stream token" which is supposed to indicate how far we have got through
+ # all three streams. It's therefore no good to return rows 1-1000 from the
+ # "new events" table if the state_deltas are limited to rows 1-100 by the
+ # target_row_count.
+ #
+ # In other words: we must pick a new upper limit, and must return *all* rows
+ # up to that point for each of the three sources.
+ #
+ # Start by trying to split the target_row_count up. We expect to have a
+ # negligible number of ex-outliers, and a rough approximation based on recent
+ # traffic on sw1v.org shows that there are approximately the same number of
+ # event rows between a given pair of stream ids as there are state
+ # updates, so let's split our target_row_count among those two types. The target
+ # is only an approximation - it doesn't matter if we end up going a bit over it.
+
+ target_row_count //= 2
+
+ # now we fetch up to that many rows from the events table
- async def update_function(self, from_token, current_token, limit=None):
event_rows = await self._store.get_all_new_forward_event_rows(
- from_token, current_token, limit
- )
- event_updates = (
- (row[0], EventsStreamEventRow.TypeId, row[1:]) for row in event_rows
+ from_token, current_token, target_row_count
+ ) # type: List[Tuple]
+
+ # we rely on get_all_new_forward_event_rows strictly honouring the limit, so
+ # that we know it is safe to just take upper_limit = event_rows[-1][0].
+ assert (
+ len(event_rows) <= target_row_count
+ ), "get_all_new_forward_event_rows did not honour row limit"
+
+ # if we hit the limit on event_updates, there's no point in going beyond the
+ # last stream_id in the batch for the other sources.
+
+ if len(event_rows) == target_row_count:
+ limited = True
+ upper_limit = event_rows[-1][0] # type: int
+ else:
+ limited = False
+ upper_limit = current_token
+
+ # next up is the state delta table.
+ (
+ state_rows,
+ upper_limit,
+ state_rows_limited,
+ ) = await self._store.get_all_updated_current_state_deltas(
+ from_token, upper_limit, target_row_count
)
- state_rows = await self._store.get_all_updated_current_state_deltas(
- from_token, current_token, limit
- )
- state_updates = (
- (row[0], EventsStreamCurrentStateRow.TypeId, row[1:]) for row in state_rows
- )
+ limited = limited or state_rows_limited
- all_updates = heapq.merge(event_updates, state_updates)
+ # finally, fetch the ex-outliers rows. We assume there are few enough of these
+ # not to bother with the limit.
- return all_updates
+ ex_outliers_rows = await self._store.get_ex_outlier_stream_rows(
+ from_token, upper_limit
+ ) # type: List[Tuple]
+
+ # we now need to turn the raw database rows returned into tuples suitable
+ # for the replication protocol (basically, we add an identifier to
+ # distinguish the row type). At the same time, we can limit the event_rows
+ # to the max stream_id from state_rows.
+
+ event_updates = (
+ (stream_id, (EventsStreamEventRow.TypeId, rest))
+ for (stream_id, *rest) in event_rows
+ if stream_id <= upper_limit
+ ) # type: Iterable[Tuple[int, Tuple]]
+
+ state_updates = (
+ (stream_id, (EventsStreamCurrentStateRow.TypeId, rest))
+ for (stream_id, *rest) in state_rows
+ ) # type: Iterable[Tuple[int, Tuple]]
+
+ ex_outliers_updates = (
+ (stream_id, (EventsStreamEventRow.TypeId, rest))
+ for (stream_id, *rest) in ex_outliers_rows
+ ) # type: Iterable[Tuple[int, Tuple]]
+
+ # we need to return a sorted list, so merge them together.
+ updates = list(heapq.merge(event_updates, state_updates, ex_outliers_updates))
+ return updates, upper_limit, limited
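The key invariant in the merge above is that, once any source hits its row budget, the other sources must be clamped to the same upper bound, so that the single returned stream token describes a consistent point across all three. A toy illustration with made-up numbers:

    target_row_count = 2

    event_rows = [(10, "ev1"), (12, "ev2")]        # hit the budget, so limited
    upper_limit = event_rows[-1][0]                # 12: don't go past this point

    state_rows = [(11, "state1"), (13, "state2")]  # 13 lies beyond upper_limit
    state_rows = [r for r in state_rows if r[0] <= upper_limit]

    # every returned row is <= 12, so "position 12" is consistent for both sources
    assert all(stream_id <= upper_limit for stream_id, _ in event_rows + state_rows)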
@classmethod
def parse_row(cls, row):
diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py
index 615f3dc9ac..9bcd13b009 100644
--- a/synapse/replication/tcp/streams/federation.py
+++ b/synapse/replication/tcp/streams/federation.py
@@ -15,14 +15,10 @@
# limitations under the License.
from collections import namedtuple
-from ._base import Stream
-
-FederationStreamRow = namedtuple(
- "FederationStreamRow",
- (
- "type", # str, the type of data as defined in the BaseFederationRows
- "data", # dict, serialization of a federation.send_queue.BaseFederationRow
- ),
+from synapse.replication.tcp.streams._base import (
+ Stream,
+ current_token_without_instance,
+ make_http_update_function,
)
@@ -31,13 +27,47 @@ class FederationStream(Stream):
sending disabled.
"""
+ FederationStreamRow = namedtuple(
+ "FederationStreamRow",
+ (
+ "type", # str, the type of data as defined in the BaseFederationRows
+ "data", # dict, serialization of a federation.send_queue.BaseFederationRow
+ ),
+ )
+
NAME = "federation"
ROW_TYPE = FederationStreamRow
def __init__(self, hs):
- federation_sender = hs.get_federation_sender()
+ if hs.config.worker_app is None:
+ # master process: get updates from the FederationRemoteSendQueue.
+ # (if the master is configured to send federation itself, federation_sender
+ # will be a real FederationSender, which has stubs for current_token and
+ # get_replication_rows.)
+ federation_sender = hs.get_federation_sender()
+ current_token = current_token_without_instance(
+ federation_sender.get_current_token
+ )
+ update_function = federation_sender.get_replication_rows
+
+ elif hs.should_send_federation():
+ # federation sender: Query master process
+ update_function = make_http_update_function(hs, self.NAME)
+ current_token = self._stub_current_token
+
+ else:
+ # other worker: stub out the update function (we're not interested in
+ # any updates so when we get a POSITION we do nothing)
+ update_function = self._stub_update_function
+ current_token = self._stub_current_token
+
+ super().__init__(hs.get_instance_name(), current_token, update_function)
- self.current_token = federation_sender.get_current_token # type: ignore
- self.update_function = federation_sender.get_replication_rows # type: ignore
+ @staticmethod
+ def _stub_current_token(instance_name: str) -> int:
+ # dummy current-token method for use on workers
+ return 0
- super(FederationStream, self).__init__(hs)
+ @staticmethod
+ async def _stub_update_function(instance_name, from_token, upto_token, limit):
+ return [], upto_token, False
|