summary refs log tree commit diff
path: root/synapse/replication/tcp
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication/tcp')
-rw-r--r--synapse/replication/tcp/client.py59
-rw-r--r--synapse/replication/tcp/commands.py130
-rw-r--r--synapse/replication/tcp/protocol.py259
-rw-r--r--synapse/replication/tcp/resource.py110
-rw-r--r--synapse/replication/tcp/streams/__init__.py1
-rw-r--r--synapse/replication/tcp/streams/_base.py285
-rw-r--r--synapse/replication/tcp/streams/events.py57
-rw-r--r--synapse/replication/tcp/streams/federation.py16
8 files changed, 562 insertions, 355 deletions
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py

index 206dc3b397..02ab5b66ea 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py
@@ -16,18 +16,26 @@ """ import logging +from typing import Dict, List, Optional from twisted.internet import defer from twisted.internet.protocol import ReconnectingClientFactory +from synapse.replication.slave.storage._base import BaseSlavedStore +from synapse.replication.tcp.protocol import ( + AbstractReplicationClientHandler, + ClientReplicationStreamProtocol, +) + from .commands import ( + Command, FederationAckCommand, InvalidateCacheCommand, + RemoteServerUpCommand, RemovePusherCommand, UserIpCommand, UserSyncCommand, ) -from .protocol import ClientReplicationStreamProtocol logger = logging.getLogger(__name__) @@ -39,9 +47,11 @@ class ReplicationClientFactory(ReconnectingClientFactory): Accepts a handler that will be called when new data is available or data is required. """ - maxDelay = 30 # Try at least once every N seconds - def __init__(self, hs, client_name, handler): + initialDelay = 0.1 + maxDelay = 1 # Try at least once every N seconds + + def __init__(self, hs, client_name, handler: AbstractReplicationClientHandler): self.client_name = client_name self.handler = handler self.server_name = hs.config.server_name @@ -64,17 +74,16 @@ class ReplicationClientFactory(ReconnectingClientFactory): def clientConnectionFailed(self, connector, reason): logger.error("Failed to connect to replication: %r", reason) - ReconnectingClientFactory.clientConnectionFailed( - self, connector, reason - ) + ReconnectingClientFactory.clientConnectionFailed(self, connector, reason) -class ReplicationClientHandler(object): +class ReplicationClientHandler(AbstractReplicationClientHandler): """A base handler that can be passed to the ReplicationClientFactory. By default proxies incoming replication data to the SlaveStore. """ - def __init__(self, store): + + def __init__(self, store: BaseSlavedStore): self.store = store # The current connection. None if we are currently (re)connecting @@ -82,15 +91,15 @@ class ReplicationClientHandler(object): # Any pending commands to be sent once a new connection has been # established - self.pending_commands = [] + self.pending_commands = [] # type: List[Command] # Map from string -> deferred, to wake up when receiveing a SYNC with # the given string. # Used for tests. - self.awaiting_syncs = {} + self.awaiting_syncs = {} # type: Dict[str, defer.Deferred] # The factory used to create connections. - self.factory = None + self.factory = None # type: Optional[ReplicationClientFactory] def start_replication(self, hs): """Helper method to start a replication connection to the remote server @@ -102,7 +111,7 @@ class ReplicationClientHandler(object): port = hs.config.worker_replication_port hs.get_reactor().connectTCP(host, port, self.factory) - def on_rdata(self, stream_name, token, rows): + async def on_rdata(self, stream_name, token, rows): """Called to handle a batch of replication data with a given stream token. By default this just pokes the slave store. Can be overridden in subclasses to @@ -113,20 +122,17 @@ class ReplicationClientHandler(object): token (int): stream token for this batch of rows rows (list): a list of Stream.ROW_TYPE objects as returned by Stream.parse_row. - - Returns: - Deferred|None """ logger.debug("Received rdata %s -> %s", stream_name, token) - return self.store.process_replication_rows(stream_name, token, rows) + self.store.process_replication_rows(stream_name, token, rows) - def on_position(self, stream_name, token): + async def on_position(self, stream_name, token): """Called when we get new position data. By default this just pokes the slave store. Can be overriden in subclasses to handle more. """ - return self.store.process_replication_rows(stream_name, token, []) + self.store.process_replication_rows(stream_name, token, []) def on_sync(self, data): """When we received a SYNC we wake up any deferreds that were waiting @@ -138,11 +144,16 @@ class ReplicationClientHandler(object): if d: d.callback(data) - def get_streams_to_replicate(self): + def on_remote_server_up(self, server: str): + """Called when get a new REMOTE_SERVER_UP command.""" + + def get_streams_to_replicate(self) -> Dict[str, int]: """Called when a new connection has been established and we need to subscribe to streams. - Returns a dictionary of stream name to token. + Returns: + map from stream name to the most recent update we have for + that stream (ie, the point we want to start replicating from) """ args = self.store.stream_positions() user_account_data = args.pop("user_account_data", None) @@ -168,7 +179,7 @@ class ReplicationClientHandler(object): if self.connection: self.connection.send_command(cmd) else: - logger.warn("Queuing command as not connected: %r", cmd.NAME) + logger.warning("Queuing command as not connected: %r", cmd.NAME) self.pending_commands.append(cmd) def send_federation_ack(self, token): @@ -200,6 +211,9 @@ class ReplicationClientHandler(object): cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen) self.send_command(cmd) + def send_remote_server_up(self, server: str): + self.send_command(RemoteServerUpCommand(server)) + def await_sync(self, data): """Returns a deferred that is resolved when we receive a SYNC command with given data. @@ -226,4 +240,5 @@ class ReplicationClientHandler(object): # We don't reset the delay any earlier as otherwise if there is a # problem during start up we'll end up tight looping connecting to the # server. - self.factory.resetDelay() + if self.factory: + self.factory.resetDelay() diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index 2098c32a77..451671412d 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py
@@ -20,13 +20,16 @@ allowed to be sent by which side. import logging import platform +from typing import Tuple, Type if platform.python_implementation() == "PyPy": import json + _json_encoder = json.JSONEncoder() else: - import simplejson as json - _json_encoder = json.JSONEncoder(namedtuple_as_object=False) + import simplejson as json # type: ignore[no-redef] # noqa: F821 + + _json_encoder = json.JSONEncoder(namedtuple_as_object=False) # type: ignore[call-arg] # noqa: F821 logger = logging.getLogger(__name__) @@ -41,7 +44,8 @@ class Command(object): The default implementation creates a command of form `<NAME> <data>` """ - NAME = None + + NAME = None # type: str def __init__(self, data): self.data = data @@ -73,6 +77,7 @@ class ServerCommand(Command): SERVER <server_name> """ + NAME = "SERVER" @@ -99,6 +104,7 @@ class RdataCommand(Command): RDATA presence batch ["@bar:example.com", "online", ...] RDATA presence 59 ["@baz:example.com", "online", ...] """ + NAME = "RDATA" def __init__(self, stream_name, token, row): @@ -110,17 +116,17 @@ class RdataCommand(Command): def from_line(cls, line): stream_name, token, row_json = line.split(" ", 2) return cls( - stream_name, - None if token == "batch" else int(token), - json.loads(row_json) + stream_name, None if token == "batch" else int(token), json.loads(row_json) ) def to_line(self): - return " ".join(( - self.stream_name, - str(self.token) if self.token is not None else "batch", - _json_encoder.encode(self.row), - )) + return " ".join( + ( + self.stream_name, + str(self.token) if self.token is not None else "batch", + _json_encoder.encode(self.row), + ) + ) def get_logcontext_id(self): return "RDATA-" + self.stream_name @@ -133,6 +139,7 @@ class PositionCommand(Command): Sent to the client after all missing updates for a stream have been sent to the client and they're now up to date. """ + NAME = "POSITION" def __init__(self, stream_name, token): @@ -145,19 +152,21 @@ class PositionCommand(Command): return cls(stream_name, int(token)) def to_line(self): - return " ".join((self.stream_name, str(self.token),)) + return " ".join((self.stream_name, str(self.token))) class ErrorCommand(Command): """Sent by either side if there was an ERROR. The data is a string describing the error. """ + NAME = "ERROR" class PingCommand(Command): """Sent by either side as a keep alive. The data is arbitary (often timestamp) """ + NAME = "PING" @@ -165,6 +174,7 @@ class NameCommand(Command): """Sent by client to inform the server of the client's identity. The data is the name """ + NAME = "NAME" @@ -184,6 +194,7 @@ class ReplicateCommand(Command): REPLICATE ALL NOW """ + NAME = "REPLICATE" def __init__(self, stream_name, token): @@ -200,7 +211,7 @@ class ReplicateCommand(Command): return cls(stream_name, token) def to_line(self): - return " ".join((self.stream_name, str(self.token),)) + return " ".join((self.stream_name, str(self.token))) def get_logcontext_id(self): return "REPLICATE-" + self.stream_name @@ -218,6 +229,7 @@ class UserSyncCommand(Command): Where <state> is either "start" or "stop" """ + NAME = "USER_SYNC" def __init__(self, user_id, is_syncing, last_sync_ms): @@ -235,9 +247,13 @@ class UserSyncCommand(Command): return cls(user_id, state == "start", int(last_sync_ms)) def to_line(self): - return " ".join(( - self.user_id, "start" if self.is_syncing else "end", str(self.last_sync_ms), - )) + return " ".join( + ( + self.user_id, + "start" if self.is_syncing else "end", + str(self.last_sync_ms), + ) + ) class FederationAckCommand(Command): @@ -251,6 +267,7 @@ class FederationAckCommand(Command): FEDERATION_ACK <token> """ + NAME = "FEDERATION_ACK" def __init__(self, token): @@ -268,6 +285,7 @@ class SyncCommand(Command): """Used for testing. The client protocol implementation allows waiting on a SYNC command with a specified data. """ + NAME = "SYNC" @@ -278,6 +296,7 @@ class RemovePusherCommand(Command): REMOVE_PUSHER <app_id> <push_key> <user_id> """ + NAME = "REMOVE_PUSHER" def __init__(self, app_id, push_key, user_id): @@ -309,6 +328,7 @@ class InvalidateCacheCommand(Command): Where <keys_json> is a json list. """ + NAME = "INVALIDATE_CACHE" def __init__(self, cache_func, keys): @@ -322,9 +342,7 @@ class InvalidateCacheCommand(Command): return cls(cache_func, json.loads(keys_json)) def to_line(self): - return " ".join(( - self.cache_func, _json_encoder.encode(self.keys), - )) + return " ".join((self.cache_func, _json_encoder.encode(self.keys))) class UserIpCommand(Command): @@ -334,6 +352,7 @@ class UserIpCommand(Command): USER_IP <user_id>, <access_token>, <ip>, <device_id>, <last_seen>, <user_agent> """ + NAME = "USER_IP" def __init__(self, user_id, access_token, ip, user_agent, device_id, last_seen): @@ -350,36 +369,57 @@ class UserIpCommand(Command): access_token, ip, user_agent, device_id, last_seen = json.loads(jsn) - return cls( - user_id, access_token, ip, user_agent, device_id, last_seen - ) + return cls(user_id, access_token, ip, user_agent, device_id, last_seen) def to_line(self): - return self.user_id + " " + _json_encoder.encode(( - self.access_token, self.ip, self.user_agent, self.device_id, - self.last_seen, - )) + return ( + self.user_id + + " " + + _json_encoder.encode( + ( + self.access_token, + self.ip, + self.user_agent, + self.device_id, + self.last_seen, + ) + ) + ) + + +class RemoteServerUpCommand(Command): + """Sent when a worker has detected that a remote server is no longer + "down" and retry timings should be reset. + + If sent from a client the server will relay to all other workers. + + Format:: + + REMOTE_SERVER_UP <server> + """ + NAME = "REMOTE_SERVER_UP" + + +_COMMANDS = ( + ServerCommand, + RdataCommand, + PositionCommand, + ErrorCommand, + PingCommand, + NameCommand, + ReplicateCommand, + UserSyncCommand, + FederationAckCommand, + SyncCommand, + RemovePusherCommand, + InvalidateCacheCommand, + UserIpCommand, + RemoteServerUpCommand, +) # type: Tuple[Type[Command], ...] # Map of command name to command type. -COMMAND_MAP = { - cmd.NAME: cmd - for cmd in ( - ServerCommand, - RdataCommand, - PositionCommand, - ErrorCommand, - PingCommand, - NameCommand, - ReplicateCommand, - UserSyncCommand, - FederationAckCommand, - SyncCommand, - RemovePusherCommand, - InvalidateCacheCommand, - UserIpCommand, - ) -} +COMMAND_MAP = {cmd.NAME: cmd for cmd in _COMMANDS} # The commands the server is allowed to send VALID_SERVER_COMMANDS = ( @@ -389,6 +429,7 @@ VALID_SERVER_COMMANDS = ( ErrorCommand.NAME, PingCommand.NAME, SyncCommand.NAME, + RemoteServerUpCommand.NAME, ) # The commands the client is allowed to send @@ -402,4 +443,5 @@ VALID_CLIENT_COMMANDS = ( InvalidateCacheCommand.NAME, UserIpCommand.NAME, ErrorCommand.NAME, + RemoteServerUpCommand.NAME, ) diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index b51590cf8f..d185cc0c8f 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py
@@ -48,11 +48,12 @@ indicate which side is sending, these are *not* included on the wire:: > ERROR server stopping * connection closed by server * """ - +import abc import fcntl import logging import struct from collections import defaultdict +from typing import Any, DefaultDict, Dict, List, Set, Tuple from six import iteritems, iterkeys @@ -62,29 +63,33 @@ from twisted.internet import defer from twisted.protocols.basic import LineOnlyReceiver from twisted.python.failure import Failure +from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.util.logcontext import make_deferred_yieldable, run_in_background -from synapse.util.stringutils import random_string - -from .commands import ( +from synapse.replication.tcp.commands import ( COMMAND_MAP, VALID_CLIENT_COMMANDS, VALID_SERVER_COMMANDS, + Command, ErrorCommand, NameCommand, PingCommand, PositionCommand, RdataCommand, + RemoteServerUpCommand, ReplicateCommand, ServerCommand, SyncCommand, UserSyncCommand, ) -from .streams import STREAMS_MAP +from synapse.replication.tcp.streams import STREAMS_MAP +from synapse.types import Collection +from synapse.util import Clock +from synapse.util.stringutils import random_string connection_close_counter = Counter( - "synapse_replication_tcp_protocol_close_reason", "", ["reason_type"]) + "synapse_replication_tcp_protocol_close_reason", "", ["reason_type"] +) # A list of all connected protocols. This allows us to send metrics about the # connections. @@ -119,10 +124,14 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): It also sends `PING` periodically, and correctly times out remote connections (if they send a `PING` command) """ - delimiter = b'\n' - VALID_INBOUND_COMMANDS = [] # Valid commands we expect to receive - VALID_OUTBOUND_COMMANDS = [] # Valid commans we can send + delimiter = b"\n" + + # Valid commands we expect to receive + VALID_INBOUND_COMMANDS = [] # type: Collection[str] + + # Valid commands we can send + VALID_OUTBOUND_COMMANDS = [] # type: Collection[str] max_line_buffer = 10000 @@ -141,13 +150,13 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): self.conn_id = random_string(5) # To dedupe in case of name clashes. # List of pending commands to send once we've established the connection - self.pending_commands = [] + self.pending_commands = [] # type: List[Command] # The LoopingCall for sending pings. self._send_ping_loop = None - self.inbound_commands_counter = defaultdict(int) - self.outbound_commands_counter = defaultdict(int) + self.inbound_commands_counter = defaultdict(int) # type: DefaultDict[str, int] + self.outbound_commands_counter = defaultdict(int) # type: DefaultDict[str, int] def connectionMade(self): logger.info("[%s] Connection established", self.id()) @@ -183,10 +192,14 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): if now - self.last_sent_command >= PING_TIME: self.send_command(PingCommand(now)) - if self.received_ping and now - self.last_received_command > PING_TIMEOUT_MS: + if ( + self.received_ping + and now - self.last_received_command > PING_TIMEOUT_MS + ): logger.info( "[%s] Connection hasn't received command in %r ms. Closing.", - self.id(), now - self.last_received_command + self.id(), + now - self.last_received_command, ) self.send_error("ping timeout") @@ -208,7 +221,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): self.last_received_command = self.clock.time_msec() self.inbound_commands_counter[cmd_name] = ( - self.inbound_commands_counter[cmd_name] + 1) + self.inbound_commands_counter[cmd_name] + 1 + ) cmd_cls = COMMAND_MAP[cmd_name] try: @@ -224,27 +238,22 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): # Now lets try and call on_<CMD_NAME> function run_as_background_process( - "replication-" + cmd.get_logcontext_id(), - self.handle_command, - cmd, + "replication-" + cmd.get_logcontext_id(), self.handle_command, cmd ) - def handle_command(self, cmd): + async def handle_command(self, cmd: Command): """Handle a command we have received over the replication stream. - By default delegates to on_<COMMAND> + By default delegates to on_<COMMAND>, which should return an awaitable. Args: - cmd (synapse.replication.tcp.commands.Command): received command - - Returns: - Deferred + cmd: received command """ handler = getattr(self, "on_%s" % (cmd.NAME,)) - return handler(cmd) + await handler(cmd) def close(self): - logger.warn("[%s] Closing connection", self.id()) + logger.warning("[%s] Closing connection", self.id()) self.time_we_closed = self.clock.time_msec() self.transport.loseConnection() self.on_connection_closed() @@ -274,8 +283,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): return self.outbound_commands_counter[cmd.NAME] = ( - self.outbound_commands_counter[cmd.NAME] + 1) - string = "%s %s" % (cmd.NAME, cmd.to_line(),) + self.outbound_commands_counter[cmd.NAME] + 1 + ) + string = "%s %s" % (cmd.NAME, cmd.to_line()) if "\n" in string: raise Exception("Unexpected newline in command: %r", string) @@ -283,10 +293,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): if len(encoded_string) > self.MAX_LENGTH: raise Exception( - "Failed to send command %s as too long (%d > %d)" % ( - cmd.NAME, - len(encoded_string), self.MAX_LENGTH, - ) + "Failed to send command %s as too long (%d > %d)" + % (cmd.NAME, len(encoded_string), self.MAX_LENGTH) ) self.sendLine(encoded_string) @@ -315,10 +323,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): for cmd in pending: self.send_command(cmd) - def on_PING(self, line): + async def on_PING(self, line): self.received_ping = True - def on_ERROR(self, cmd): + async def on_ERROR(self, cmd): logger.error("[%s] Remote reported error: %r", self.id(), cmd.data) def pauseProducing(self): @@ -379,7 +387,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): if self.transport: addr = str(self.transport.getPeer()) return "ReplicationConnection<name=%s,conn_id=%s,addr=%s>" % ( - self.name, self.conn_id, addr, + self.name, + self.conn_id, + addr, ) def id(self): @@ -402,68 +412,69 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): self.streamer = streamer # The streams the client has subscribed to and is up to date with - self.replication_streams = set() + self.replication_streams = set() # type: Set[str] # The streams the client is currently subscribing to. - self.connecting_streams = set() + self.connecting_streams = set() # type: Set[str] # Map from stream name to list of updates to send once we've finished # subscribing the client to the stream. - self.pending_rdata = {} + self.pending_rdata = {} # type: Dict[str, List[Tuple[int, Any]]] def connectionMade(self): self.send_command(ServerCommand(self.server_name)) BaseReplicationStreamProtocol.connectionMade(self) self.streamer.new_connection(self) - def on_NAME(self, cmd): + async def on_NAME(self, cmd): logger.info("[%s] Renamed to %r", self.id(), cmd.data) self.name = cmd.data - def on_USER_SYNC(self, cmd): - return self.streamer.on_user_sync( - self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms, + async def on_USER_SYNC(self, cmd): + await self.streamer.on_user_sync( + self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms ) - def on_REPLICATE(self, cmd): + async def on_REPLICATE(self, cmd): stream_name = cmd.stream_name token = cmd.token if stream_name == "ALL": # Subscribe to all streams we're publishing to. deferreds = [ - run_in_background( - self.subscribe_to_stream, - stream, token, - ) + run_in_background(self.subscribe_to_stream, stream, token) for stream in iterkeys(self.streamer.streams_by_name) ] - return make_deferred_yieldable( + await make_deferred_yieldable( defer.gatherResults(deferreds, consumeErrors=True) ) else: - return self.subscribe_to_stream(stream_name, token) + await self.subscribe_to_stream(stream_name, token) - def on_FEDERATION_ACK(self, cmd): - return self.streamer.federation_ack(cmd.token) + async def on_FEDERATION_ACK(self, cmd): + self.streamer.federation_ack(cmd.token) - def on_REMOVE_PUSHER(self, cmd): - return self.streamer.on_remove_pusher( - cmd.app_id, cmd.push_key, cmd.user_id, - ) + async def on_REMOVE_PUSHER(self, cmd): + await self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id) - def on_INVALIDATE_CACHE(self, cmd): - return self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys) + async def on_INVALIDATE_CACHE(self, cmd): + await self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys) - def on_USER_IP(self, cmd): - return self.streamer.on_user_ip( - cmd.user_id, cmd.access_token, cmd.ip, cmd.user_agent, cmd.device_id, + async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand): + self.streamer.on_remote_server_up(cmd.data) + + async def on_USER_IP(self, cmd): + await self.streamer.on_user_ip( + cmd.user_id, + cmd.access_token, + cmd.ip, + cmd.user_agent, + cmd.device_id, cmd.last_seen, ) - @defer.inlineCallbacks - def subscribe_to_stream(self, stream_name, token): + async def subscribe_to_stream(self, stream_name, token): """Subscribe the remote to a stream. This invloves checking if they've missed anything and sending those @@ -475,8 +486,8 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): try: # Get missing updates - updates, current_token = yield self.streamer.get_stream_updates( - stream_name, token, + updates, current_token = await self.streamer.get_stream_updates( + stream_name, token ) # Send all the missing updates @@ -548,16 +559,90 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): def send_sync(self, data): self.send_command(SyncCommand(data)) + def send_remote_server_up(self, server: str): + self.send_command(RemoteServerUpCommand(server)) + def on_connection_closed(self): BaseReplicationStreamProtocol.on_connection_closed(self) self.streamer.lost_connection(self) +class AbstractReplicationClientHandler(metaclass=abc.ABCMeta): + """ + The interface for the handler that should be passed to + ClientReplicationStreamProtocol + """ + + @abc.abstractmethod + async def on_rdata(self, stream_name, token, rows): + """Called to handle a batch of replication data with a given stream token. + + Args: + stream_name (str): name of the replication stream for this batch of rows + token (int): stream token for this batch of rows + rows (list): a list of Stream.ROW_TYPE objects as returned by + Stream.parse_row. + """ + raise NotImplementedError() + + @abc.abstractmethod + async def on_position(self, stream_name, token): + """Called when we get new position data.""" + raise NotImplementedError() + + @abc.abstractmethod + def on_sync(self, data): + """Called when get a new SYNC command.""" + raise NotImplementedError() + + @abc.abstractmethod + async def on_remote_server_up(self, server: str): + """Called when get a new REMOTE_SERVER_UP command.""" + raise NotImplementedError() + + @abc.abstractmethod + def get_streams_to_replicate(self): + """Called when a new connection has been established and we need to + subscribe to streams. + + Returns: + map from stream name to the most recent update we have for + that stream (ie, the point we want to start replicating from) + """ + raise NotImplementedError() + + @abc.abstractmethod + def get_currently_syncing_users(self): + """Get the list of currently syncing users (if any). This is called + when a connection has been established and we need to send the + currently syncing users.""" + raise NotImplementedError() + + @abc.abstractmethod + def update_connection(self, connection): + """Called when a connection has been established (or lost with None). + """ + raise NotImplementedError() + + @abc.abstractmethod + def finished_connecting(self): + """Called when we have successfully subscribed and caught up to all + streams we're interested in. + """ + raise NotImplementedError() + + class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): VALID_INBOUND_COMMANDS = VALID_SERVER_COMMANDS VALID_OUTBOUND_COMMANDS = VALID_CLIENT_COMMANDS - def __init__(self, client_name, server_name, clock, handler): + def __init__( + self, + client_name: str, + server_name: str, + clock: Clock, + handler: AbstractReplicationClientHandler, + ): BaseReplicationStreamProtocol.__init__(self, clock) self.client_name = client_name @@ -567,11 +652,11 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): # Set of stream names that have been subscribe to, but haven't yet # caught up with. This is used to track when the client has been fully # connected to the remote. - self.streams_connecting = set() + self.streams_connecting = set() # type: Set[str] # Map of stream to batched updates. See RdataCommand for info on how # batching works. - self.pending_batches = {} + self.pending_batches = {} # type: Dict[str, Any] def connectionMade(self): self.send_command(NameCommand(self.client_name)) @@ -595,12 +680,12 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): if not self.streams_connecting: self.handler.finished_connecting() - def on_SERVER(self, cmd): + async def on_SERVER(self, cmd): if cmd.data != self.server_name: logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data) self.send_error("Wrong remote") - def on_RDATA(self, cmd): + async def on_RDATA(self, cmd): stream_name = cmd.stream_name inbound_rdata_count.labels(stream_name).inc() @@ -608,8 +693,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): row = STREAMS_MAP[stream_name].parse_row(cmd.row) except Exception: logger.exception( - "[%s] Failed to parse RDATA: %r %r", - self.id(), stream_name, cmd.row + "[%s] Failed to parse RDATA: %r %r", self.id(), stream_name, cmd.row ) raise @@ -621,19 +705,22 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): # Check if this is the last of a batch of updates rows = self.pending_batches.pop(stream_name, []) rows.append(row) - return self.handler.on_rdata(stream_name, cmd.token, rows) + await self.handler.on_rdata(stream_name, cmd.token, rows) - def on_POSITION(self, cmd): + async def on_POSITION(self, cmd): # When we get a `POSITION` command it means we've finished getting # missing updates for the given stream, and are now up to date. self.streams_connecting.discard(cmd.stream_name) if not self.streams_connecting: self.handler.finished_connecting() - return self.handler.on_position(cmd.stream_name, cmd.token) + await self.handler.on_position(cmd.stream_name, cmd.token) - def on_SYNC(self, cmd): - return self.handler.on_sync(cmd.data) + async def on_SYNC(self, cmd): + self.handler.on_sync(cmd.data) + + async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand): + self.handler.on_remote_server_up(cmd.data) def replicate(self, stream_name, token): """Send the subscription request to the server @@ -643,7 +730,9 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): logger.info( "[%s] Subscribing to replication stream: %r from %r", - self.id(), stream_name, token + self.id(), + stream_name, + token, ) self.streams_connecting.add(stream_name) @@ -661,9 +750,7 @@ pending_commands = LaterGauge( "synapse_replication_tcp_protocol_pending_commands", "", ["name"], - lambda: { - (p.name,): len(p.pending_commands) for p in connected_connections - }, + lambda: {(p.name,): len(p.pending_commands) for p in connected_connections}, ) @@ -678,9 +765,7 @@ transport_send_buffer = LaterGauge( "synapse_replication_tcp_protocol_transport_send_buffer", "", ["name"], - lambda: { - (p.name,): transport_buffer_size(p) for p in connected_connections - }, + lambda: {(p.name,): transport_buffer_size(p) for p in connected_connections}, ) @@ -694,7 +779,7 @@ def transport_kernel_read_buffer_size(protocol, read=True): op = SIOCINQ else: op = SIOCOUTQ - size = struct.unpack("I", fcntl.ioctl(fileno, op, '\0\0\0\0'))[0] + size = struct.unpack("I", fcntl.ioctl(fileno, op, b"\0\0\0\0"))[0] return size return 0 @@ -726,7 +811,7 @@ tcp_inbound_commands = LaterGauge( "", ["command", "name"], lambda: { - (k, p.name,): count + (k, p.name): count for p in connected_connections for k, count in iteritems(p.inbound_commands_counter) }, @@ -737,7 +822,7 @@ tcp_outbound_commands = LaterGauge( "", ["command", "name"], lambda: { - (k, p.name,): count + (k, p.name): count for p in connected_connections for k, count in iteritems(p.outbound_commands_counter) }, diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index f6a38f5140..ce9d1fae12 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py
@@ -17,12 +17,12 @@ import logging import random +from typing import Any, List from six import itervalues from prometheus_client import Counter -from twisted.internet import defer from twisted.internet.protocol import Factory from synapse.metrics import LaterGauge @@ -33,13 +33,15 @@ from .protocol import ServerReplicationStreamProtocol from .streams import STREAMS_MAP from .streams.federation import FederationStream -stream_updates_counter = Counter("synapse_replication_tcp_resource_stream_updates", - "", ["stream_name"]) +stream_updates_counter = Counter( + "synapse_replication_tcp_resource_stream_updates", "", ["stream_name"] +) user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "") federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "") remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "") -invalidate_cache_counter = Counter("synapse_replication_tcp_resource_invalidate_cache", - "") +invalidate_cache_counter = Counter( + "synapse_replication_tcp_resource_invalidate_cache", "" +) user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "") logger = logging.getLogger(__name__) @@ -48,6 +50,7 @@ logger = logging.getLogger(__name__) class ReplicationStreamProtocolFactory(Factory): """Factory for new replication connections. """ + def __init__(self, hs): self.streamer = ReplicationStreamer(hs) self.clock = hs.get_clock() @@ -55,9 +58,7 @@ class ReplicationStreamProtocolFactory(Factory): def buildProtocol(self, addr): return ServerReplicationStreamProtocol( - self.server_name, - self.clock, - self.streamer, + self.server_name, self.clock, self.streamer ) @@ -78,37 +79,48 @@ class ReplicationStreamer(object): self._replication_torture_level = hs.config.replication_torture_level # Current connections. - self.connections = [] + self.connections = [] # type: List[ServerReplicationStreamProtocol] - LaterGauge("synapse_replication_tcp_resource_total_connections", "", [], - lambda: len(self.connections)) + LaterGauge( + "synapse_replication_tcp_resource_total_connections", + "", + [], + lambda: len(self.connections), + ) # List of streams that clients can subscribe to. # We only support federation stream if federation sending hase been # disabled on the master. self.streams = [ - stream(hs) for stream in itervalues(STREAMS_MAP) + stream(hs) + for stream in itervalues(STREAMS_MAP) if stream != FederationStream or not hs.config.send_federation ] self.streams_by_name = {stream.NAME: stream for stream in self.streams} LaterGauge( - "synapse_replication_tcp_resource_connections_per_stream", "", + "synapse_replication_tcp_resource_connections_per_stream", + "", ["stream_name"], lambda: { - (stream_name,): len([ - conn for conn in self.connections - if stream_name in conn.replication_streams - ]) + (stream_name,): len( + [ + conn + for conn in self.connections + if stream_name in conn.replication_streams + ] + ) for stream_name in self.streams_by_name - }) + }, + ) self.federation_sender = None if not hs.config.send_federation: self.federation_sender = hs.get_federation_sender() self.notifier.add_replication_callback(self.on_notifier_poke) + self.notifier.add_remote_server_up_callback(self.send_remote_server_up) # Keeps track of whether we are currently checking for updates self.is_looping = False @@ -143,8 +155,7 @@ class ReplicationStreamer(object): run_as_background_process("replication_notifier", self._run_notifier_loop) - @defer.inlineCallbacks - def _run_notifier_loop(self): + async def _run_notifier_loop(self): self.is_looping = True try: @@ -173,23 +184,26 @@ class ReplicationStreamer(object): continue if self._replication_torture_level: - yield self.clock.sleep( + await self.clock.sleep( self._replication_torture_level / 1000.0 ) logger.debug( "Getting stream: %s: %s -> %s", - stream.NAME, stream.last_token, stream.upto_token + stream.NAME, + stream.last_token, + stream.upto_token, ) try: - updates, current_token = yield stream.get_updates() + updates, current_token = await stream.get_updates() except Exception: logger.info("Failed to handle stream %s", stream.NAME) raise logger.debug( "Sending %d updates to %d connections", - len(updates), len(self.connections), + len(updates), + len(self.connections), ) if updates: @@ -218,7 +232,7 @@ class ReplicationStreamer(object): self.is_looping = False @measure_func("repl.get_stream_updates") - def get_stream_updates(self, stream_name, token): + async def get_stream_updates(self, stream_name, token): """For a given stream get all updates since token. This is called when a client first subscribes to a stream. """ @@ -226,7 +240,7 @@ class ReplicationStreamer(object): if not stream: raise Exception("unknown stream %s", stream_name) - return stream.get_updates_since(token) + return await stream.get_updates_since(token) @measure_func("repl.federation_ack") def federation_ack(self, token): @@ -237,44 +251,54 @@ class ReplicationStreamer(object): self.federation_sender.federation_ack(token) @measure_func("repl.on_user_sync") - @defer.inlineCallbacks - def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms): + async def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms): """A client has started/stopped syncing on a worker. """ user_sync_counter.inc() - yield self.presence_handler.update_external_syncs_row( - conn_id, user_id, is_syncing, last_sync_ms, + await self.presence_handler.update_external_syncs_row( + conn_id, user_id, is_syncing, last_sync_ms ) @measure_func("repl.on_remove_pusher") - @defer.inlineCallbacks - def on_remove_pusher(self, app_id, push_key, user_id): + async def on_remove_pusher(self, app_id, push_key, user_id): """A client has asked us to remove a pusher """ remove_pusher_counter.inc() - yield self.store.delete_pusher_by_app_id_pushkey_user_id( + await self.store.delete_pusher_by_app_id_pushkey_user_id( app_id=app_id, pushkey=push_key, user_id=user_id ) self.notifier.on_new_replication_data() @measure_func("repl.on_invalidate_cache") - def on_invalidate_cache(self, cache_func, keys): + async def on_invalidate_cache(self, cache_func: str, keys: List[Any]): """The client has asked us to invalidate a cache """ invalidate_cache_counter.inc() - getattr(self.store, cache_func).invalidate(tuple(keys)) + + # We invalidate the cache locally, but then also stream that to other + # workers. + await self.store.invalidate_cache_and_stream(cache_func, tuple(keys)) @measure_func("repl.on_user_ip") - @defer.inlineCallbacks - def on_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen): + async def on_user_ip( + self, user_id, access_token, ip, user_agent, device_id, last_seen + ): """The client saw a user request """ user_ip_cache_counter.inc() - yield self.store.insert_client_ip( - user_id, access_token, ip, user_agent, device_id, last_seen, + await self.store.insert_client_ip( + user_id, access_token, ip, user_agent, device_id, last_seen ) - yield self._server_notices_sender.on_user_ip(user_id) + await self._server_notices_sender.on_user_ip(user_id) + + @measure_func("repl.on_remote_server_up") + def on_remote_server_up(self, server: str): + self.notifier.notify_remote_server_up(server) + + def send_remote_server_up(self, server: str): + for conn in self.connections: + conn.send_remote_server_up(server) def send_sync_to_all_connections(self, data): """Sends a SYNC command to all clients. @@ -299,7 +323,11 @@ class ReplicationStreamer(object): # We need to tell the presence handler that the connection has been # lost so that it can handle any ongoing syncs on that connection. - self.presence_handler.update_external_syncs_clear(connection.conn_id) + run_as_background_process( + "update_external_syncs_clear", + self.presence_handler.update_external_syncs_clear, + connection.conn_id, + ) def _batch_updates(updates): diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py
index 634f636dc9..5f52264e84 100644 --- a/synapse/replication/tcp/streams/__init__.py +++ b/synapse/replication/tcp/streams/__init__.py
@@ -45,5 +45,6 @@ STREAMS_MAP = { _base.TagAccountDataStream, _base.AccountDataStream, _base.GroupServerStream, + _base.UserSignatureStream, ) } diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index b6ce7a7bee..208e8a667b 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py
@@ -14,90 +14,101 @@ # See the License for the specific language governing permissions and # limitations under the License. - import itertools import logging from collections import namedtuple +from typing import Any, List, Optional -from twisted.internet import defer +import attr logger = logging.getLogger(__name__) -MAX_EVENTS_BEHIND = 10000 - -BackfillStreamRow = namedtuple("BackfillStreamRow", ( - "event_id", # str - "room_id", # str - "type", # str - "state_key", # str, optional - "redacts", # str, optional - "relates_to", # str, optional -)) -PresenceStreamRow = namedtuple("PresenceStreamRow", ( - "user_id", # str - "state", # str - "last_active_ts", # int - "last_federation_update_ts", # int - "last_user_sync_ts", # int - "status_msg", # str - "currently_active", # bool -)) -TypingStreamRow = namedtuple("TypingStreamRow", ( - "room_id", # str - "user_ids", # list(str) -)) -ReceiptsStreamRow = namedtuple("ReceiptsStreamRow", ( - "room_id", # str - "receipt_type", # str - "user_id", # str - "event_id", # str - "data", # dict -)) -PushRulesStreamRow = namedtuple("PushRulesStreamRow", ( - "user_id", # str -)) -PushersStreamRow = namedtuple("PushersStreamRow", ( - "user_id", # str - "app_id", # str - "pushkey", # str - "deleted", # bool -)) -CachesStreamRow = namedtuple("CachesStreamRow", ( - "cache_func", # str - "keys", # list(str) - "invalidation_ts", # int -)) -PublicRoomsStreamRow = namedtuple("PublicRoomsStreamRow", ( - "room_id", # str - "visibility", # str - "appservice_id", # str, optional - "network_id", # str, optional -)) -DeviceListsStreamRow = namedtuple("DeviceListsStreamRow", ( - "user_id", # str - "destination", # str -)) -ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ( - "entity", # str -)) -TagAccountDataStreamRow = namedtuple("TagAccountDataStreamRow", ( - "user_id", # str - "room_id", # str - "data", # dict -)) -AccountDataStreamRow = namedtuple("AccountDataStream", ( - "user_id", # str - "room_id", # str - "data_type", # str - "data", # dict -)) -GroupsStreamRow = namedtuple("GroupsStreamRow", ( - "group_id", # str - "user_id", # str - "type", # str - "content", # dict -)) +MAX_EVENTS_BEHIND = 500000 + +BackfillStreamRow = namedtuple( + "BackfillStreamRow", + ( + "event_id", # str + "room_id", # str + "type", # str + "state_key", # str, optional + "redacts", # str, optional + "relates_to", # str, optional + ), +) +PresenceStreamRow = namedtuple( + "PresenceStreamRow", + ( + "user_id", # str + "state", # str + "last_active_ts", # int + "last_federation_update_ts", # int + "last_user_sync_ts", # int + "status_msg", # str + "currently_active", # bool + ), +) +TypingStreamRow = namedtuple( + "TypingStreamRow", ("room_id", "user_ids") # str # list(str) +) +ReceiptsStreamRow = namedtuple( + "ReceiptsStreamRow", + ( + "room_id", # str + "receipt_type", # str + "user_id", # str + "event_id", # str + "data", # dict + ), +) +PushRulesStreamRow = namedtuple("PushRulesStreamRow", ("user_id",)) # str +PushersStreamRow = namedtuple( + "PushersStreamRow", + ("user_id", "app_id", "pushkey", "deleted"), # str # str # str # bool +) + + +@attr.s +class CachesStreamRow: + """Stream to inform workers they should invalidate their cache. + + Attributes: + cache_func: Name of the cached function. + keys: The entry in the cache to invalidate. If None then will + invalidate all. + invalidation_ts: Timestamp of when the invalidation took place. + """ + + cache_func = attr.ib(type=str) + keys = attr.ib(type=Optional[List[Any]]) + invalidation_ts = attr.ib(type=int) + + +PublicRoomsStreamRow = namedtuple( + "PublicRoomsStreamRow", + ( + "room_id", # str + "visibility", # str + "appservice_id", # str, optional + "network_id", # str, optional + ), +) +DeviceListsStreamRow = namedtuple( + "DeviceListsStreamRow", ("user_id", "destination") # str # str +) +ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",)) # str +TagAccountDataStreamRow = namedtuple( + "TagAccountDataStreamRow", ("user_id", "room_id", "data") # str # str # dict +) +AccountDataStreamRow = namedtuple( + "AccountDataStream", ("user_id", "room_id", "data_type") # str # str # str +) +GroupsStreamRow = namedtuple( + "GroupsStreamRow", + ("group_id", "user_id", "type", "content"), # str # str # str # dict +) +UserSignatureStreamRow = namedtuple("UserSignatureStreamRow", ("user_id")) # str class Stream(object): @@ -106,8 +117,10 @@ class Stream(object): Provides a `get_updates()` function that returns new updates since the last time it was called up until the point `advance_current_token` was called. """ - NAME = None # The name of the stream - ROW_TYPE = None # The type of the row. Used by the default impl of parse_row. + + NAME = None # type: str # The name of the stream + # The type of the row. Used by the default impl of parse_row. + ROW_TYPE = None # type: Any _LIMITED = True # Whether the update function takes a limit @classmethod @@ -145,8 +158,7 @@ class Stream(object): self.upto_token = self.current_token() self.last_token = self.upto_token - @defer.inlineCallbacks - def get_updates(self): + async def get_updates(self): """Gets all updates since the last time this function was called (or since the stream was constructed if it hadn't been called before), until the `upto_token` @@ -157,13 +169,12 @@ class Stream(object): list of ``(token, row)`` entries. ``row`` will be json-serialised and sent over the replication steam. """ - updates, current_token = yield self.get_updates_since(self.last_token) + updates, current_token = await self.get_updates_since(self.last_token) self.last_token = current_token - defer.returnValue((updates, current_token)) + return updates, current_token - @defer.inlineCallbacks - def get_updates_since(self, from_token): + async def get_updates_since(self, from_token): """Like get_updates except allows specifying from when we should stream updates @@ -174,27 +185,25 @@ class Stream(object): sent over the replication steam. """ if from_token in ("NOW", "now"): - defer.returnValue(([], self.upto_token)) + return [], self.upto_token current_token = self.upto_token from_token = int(from_token) if from_token == current_token: - defer.returnValue(([], current_token)) + return [], current_token + logger.info("get_updates_since: %s", self.__class__) if self._LIMITED: - rows = yield self.update_function( - from_token, current_token, - limit=MAX_EVENTS_BEHIND + 1, + rows = await self.update_function( + from_token, current_token, limit=MAX_EVENTS_BEHIND + 1 ) # never turn more than MAX_EVENTS_BEHIND + 1 into updates. rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1) else: - rows = yield self.update_function( - from_token, current_token, - ) + rows = await self.update_function(from_token, current_token) updates = [(row[0], row[1:]) for row in rows] @@ -203,7 +212,7 @@ class Stream(object): if self._LIMITED and len(updates) >= MAX_EVENTS_BEHIND: raise Exception("stream %s has fallen behind" % (self.NAME)) - defer.returnValue((updates, current_token)) + return updates, current_token def current_token(self): """Gets the current token of the underlying streams. Should be provided @@ -230,13 +239,14 @@ class BackfillStream(Stream): """We fetched some old events and either we had never seen that event before or it went from being an outlier to not. """ + NAME = "backfill" ROW_TYPE = BackfillStreamRow def __init__(self, hs): store = hs.get_datastore() - self.current_token = store.get_current_backfill_token - self.update_function = store.get_all_new_backfill_event_rows + self.current_token = store.get_current_backfill_token # type: ignore + self.update_function = store.get_all_new_backfill_event_rows # type: ignore super(BackfillStream, self).__init__(hs) @@ -250,8 +260,8 @@ class PresenceStream(Stream): store = hs.get_datastore() presence_handler = hs.get_presence_handler() - self.current_token = store.get_current_presence_token - self.update_function = presence_handler.get_all_presence_updates + self.current_token = store.get_current_presence_token # type: ignore + self.update_function = presence_handler.get_all_presence_updates # type: ignore super(PresenceStream, self).__init__(hs) @@ -264,8 +274,8 @@ class TypingStream(Stream): def __init__(self, hs): typing_handler = hs.get_typing_handler() - self.current_token = typing_handler.get_current_token - self.update_function = typing_handler.get_all_typing_updates + self.current_token = typing_handler.get_current_token # type: ignore + self.update_function = typing_handler.get_all_typing_updates # type: ignore super(TypingStream, self).__init__(hs) @@ -277,8 +287,8 @@ class ReceiptsStream(Stream): def __init__(self, hs): store = hs.get_datastore() - self.current_token = store.get_max_receipt_stream_id - self.update_function = store.get_all_updated_receipts + self.current_token = store.get_max_receipt_stream_id # type: ignore + self.update_function = store.get_all_updated_receipts # type: ignore super(ReceiptsStream, self).__init__(hs) @@ -286,6 +296,7 @@ class ReceiptsStream(Stream): class PushRulesStream(Stream): """A user has changed their push rules """ + NAME = "push_rules" ROW_TYPE = PushRulesStreamRow @@ -297,23 +308,23 @@ class PushRulesStream(Stream): push_rules_token, _ = self.store.get_push_rules_stream_token() return push_rules_token - @defer.inlineCallbacks - def update_function(self, from_token, to_token, limit): - rows = yield self.store.get_all_push_rule_updates(from_token, to_token, limit) - defer.returnValue([(row[0], row[2]) for row in rows]) + async def update_function(self, from_token, to_token, limit): + rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit) + return [(row[0], row[2]) for row in rows] class PushersStream(Stream): """A user has added/changed/removed a pusher """ + NAME = "pushers" ROW_TYPE = PushersStreamRow def __init__(self, hs): store = hs.get_datastore() - self.current_token = store.get_pushers_stream_token - self.update_function = store.get_all_updated_pushers_rows + self.current_token = store.get_pushers_stream_token # type: ignore + self.update_function = store.get_all_updated_pushers_rows # type: ignore super(PushersStream, self).__init__(hs) @@ -322,14 +333,15 @@ class CachesStream(Stream): """A cache was invalidated on the master and no other stream would invalidate the cache on the workers """ + NAME = "caches" ROW_TYPE = CachesStreamRow def __init__(self, hs): store = hs.get_datastore() - self.current_token = store.get_cache_stream_token - self.update_function = store.get_all_updated_caches + self.current_token = store.get_cache_stream_token # type: ignore + self.update_function = store.get_all_updated_caches # type: ignore super(CachesStream, self).__init__(hs) @@ -337,14 +349,15 @@ class CachesStream(Stream): class PublicRoomsStream(Stream): """The public rooms list changed """ + NAME = "public_rooms" ROW_TYPE = PublicRoomsStreamRow def __init__(self, hs): store = hs.get_datastore() - self.current_token = store.get_current_public_room_stream_id - self.update_function = store.get_all_new_public_rooms + self.current_token = store.get_current_public_room_stream_id # type: ignore + self.update_function = store.get_all_new_public_rooms # type: ignore super(PublicRoomsStream, self).__init__(hs) @@ -352,6 +365,7 @@ class PublicRoomsStream(Stream): class DeviceListsStream(Stream): """Someone added/changed/removed a device """ + NAME = "device_lists" _LIMITED = False ROW_TYPE = DeviceListsStreamRow @@ -359,8 +373,8 @@ class DeviceListsStream(Stream): def __init__(self, hs): store = hs.get_datastore() - self.current_token = store.get_device_stream_token - self.update_function = store.get_all_device_list_changes_for_remotes + self.current_token = store.get_device_stream_token # type: ignore + self.update_function = store.get_all_device_list_changes_for_remotes # type: ignore super(DeviceListsStream, self).__init__(hs) @@ -368,14 +382,15 @@ class DeviceListsStream(Stream): class ToDeviceStream(Stream): """New to_device messages for a client """ + NAME = "to_device" ROW_TYPE = ToDeviceStreamRow def __init__(self, hs): store = hs.get_datastore() - self.current_token = store.get_to_device_stream_token - self.update_function = store.get_all_new_device_messages + self.current_token = store.get_to_device_stream_token # type: ignore + self.update_function = store.get_all_new_device_messages # type: ignore super(ToDeviceStream, self).__init__(hs) @@ -383,14 +398,15 @@ class ToDeviceStream(Stream): class TagAccountDataStream(Stream): """Someone added/removed a tag for a room """ + NAME = "tag_account_data" ROW_TYPE = TagAccountDataStreamRow def __init__(self, hs): store = hs.get_datastore() - self.current_token = store.get_max_account_data_stream_id - self.update_function = store.get_all_updated_tags + self.current_token = store.get_max_account_data_stream_id # type: ignore + self.update_function = store.get_all_updated_tags # type: ignore super(TagAccountDataStream, self).__init__(hs) @@ -398,29 +414,29 @@ class TagAccountDataStream(Stream): class AccountDataStream(Stream): """Global or per room account data was changed """ + NAME = "account_data" ROW_TYPE = AccountDataStreamRow def __init__(self, hs): self.store = hs.get_datastore() - self.current_token = self.store.get_max_account_data_stream_id + self.current_token = self.store.get_max_account_data_stream_id # type: ignore super(AccountDataStream, self).__init__(hs) - @defer.inlineCallbacks - def update_function(self, from_token, to_token, limit): - global_results, room_results = yield self.store.get_all_updated_account_data( + async def update_function(self, from_token, to_token, limit): + global_results, room_results = await self.store.get_all_updated_account_data( from_token, from_token, to_token, limit ) results = list(room_results) results.extend( - (stream_id, user_id, None, account_data_type, content,) - for stream_id, user_id, account_data_type, content in global_results + (stream_id, user_id, None, account_data_type) + for stream_id, user_id, account_data_type in global_results ) - defer.returnValue(results) + return results class GroupServerStream(Stream): @@ -430,7 +446,24 @@ class GroupServerStream(Stream): def __init__(self, hs): store = hs.get_datastore() - self.current_token = store.get_group_stream_token - self.update_function = store.get_all_groups_changes + self.current_token = store.get_group_stream_token # type: ignore + self.update_function = store.get_all_groups_changes # type: ignore super(GroupServerStream, self).__init__(hs) + + +class UserSignatureStream(Stream): + """A user has signed their own device with their user-signing key + """ + + NAME = "user_signature" + _LIMITED = False + ROW_TYPE = UserSignatureStreamRow + + def __init__(self, hs): + store = hs.get_datastore() + + self.current_token = store.get_device_stream_token # type: ignore + self.update_function = store.get_all_user_signature_changes_for_remotes # type: ignore + + super(UserSignatureStream, self).__init__(hs) diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index f1290d022a..b3afabb8cd 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py
@@ -13,12 +13,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import heapq +from typing import Tuple, Type import attr -from twisted.internet import defer - from ._base import Stream @@ -52,6 +52,7 @@ data part are: @attr.s(slots=True, frozen=True) class EventsStreamRow(object): """A parsed row from the events replication stream""" + type = attr.ib() # str: the TypeId of one of the *EventsStreamRows data = attr.ib() # BaseEventsStreamRow @@ -62,7 +63,8 @@ class BaseEventsStreamRow(object): Specifies how to identify, serialize and deserialize the different types. """ - TypeId = None # Unique string that ids the type. Must be overriden in sub classes. + # Unique string that ids the type. Must be overriden in sub classes. + TypeId = None # type: str @classmethod def from_data(cls, data): @@ -80,11 +82,11 @@ class BaseEventsStreamRow(object): class EventsStreamEventRow(BaseEventsStreamRow): TypeId = "ev" - event_id = attr.ib() # str - room_id = attr.ib() # str - type = attr.ib() # str - state_key = attr.ib() # str, optional - redacts = attr.ib() # str, optional + event_id = attr.ib() # str + room_id = attr.ib() # str + type = attr.ib() # str + state_key = attr.ib() # str, optional + redacts = attr.ib() # str, optional relates_to = attr.ib() # str, optional @@ -92,53 +94,50 @@ class EventsStreamEventRow(BaseEventsStreamRow): class EventsStreamCurrentStateRow(BaseEventsStreamRow): TypeId = "state" - room_id = attr.ib() # str - type = attr.ib() # str + room_id = attr.ib() # str + type = attr.ib() # str state_key = attr.ib() # str - event_id = attr.ib() # str, optional + event_id = attr.ib() # str, optional -TypeToRow = { - Row.TypeId: Row - for Row in ( - EventsStreamEventRow, - EventsStreamCurrentStateRow, - ) -} +_EventRows = ( + EventsStreamEventRow, + EventsStreamCurrentStateRow, +) # type: Tuple[Type[BaseEventsStreamRow], ...] + +TypeToRow = {Row.TypeId: Row for Row in _EventRows} class EventsStream(Stream): """We received a new event, or an event went from being an outlier to not """ + NAME = "events" def __init__(self, hs): self._store = hs.get_datastore() - self.current_token = self._store.get_current_events_token + self.current_token = self._store.get_current_events_token # type: ignore super(EventsStream, self).__init__(hs) - @defer.inlineCallbacks - def update_function(self, from_token, current_token, limit=None): - event_rows = yield self._store.get_all_new_forward_event_rows( - from_token, current_token, limit, + async def update_function(self, from_token, current_token, limit=None): + event_rows = await self._store.get_all_new_forward_event_rows( + from_token, current_token, limit ) event_updates = ( - (row[0], EventsStreamEventRow.TypeId, row[1:]) - for row in event_rows + (row[0], EventsStreamEventRow.TypeId, row[1:]) for row in event_rows ) - state_rows = yield self._store.get_all_updated_current_state_deltas( + state_rows = await self._store.get_all_updated_current_state_deltas( from_token, current_token, limit ) state_updates = ( - (row[0], EventsStreamCurrentStateRow.TypeId, row[1:]) - for row in state_rows + (row[0], EventsStreamCurrentStateRow.TypeId, row[1:]) for row in state_rows ) all_updates = heapq.merge(event_updates, state_updates) - defer.returnValue(all_updates) + return all_updates @classmethod def parse_row(cls, row): diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py
index 9aa43aa8d2..615f3dc9ac 100644 --- a/synapse/replication/tcp/streams/federation.py +++ b/synapse/replication/tcp/streams/federation.py
@@ -17,23 +17,27 @@ from collections import namedtuple from ._base import Stream -FederationStreamRow = namedtuple("FederationStreamRow", ( - "type", # str, the type of data as defined in the BaseFederationRows - "data", # dict, serialization of a federation.send_queue.BaseFederationRow -)) +FederationStreamRow = namedtuple( + "FederationStreamRow", + ( + "type", # str, the type of data as defined in the BaseFederationRows + "data", # dict, serialization of a federation.send_queue.BaseFederationRow + ), +) class FederationStream(Stream): """Data to be sent over federation. Only available when master has federation sending disabled. """ + NAME = "federation" ROW_TYPE = FederationStreamRow def __init__(self, hs): federation_sender = hs.get_federation_sender() - self.current_token = federation_sender.get_current_token - self.update_function = federation_sender.get_replication_rows + self.current_token = federation_sender.get_current_token # type: ignore + self.update_function = federation_sender.get_replication_rows # type: ignore super(FederationStream, self).__init__(hs)