diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 206dc3b397..02ab5b66ea 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -16,18 +16,26 @@
"""
import logging
+from typing import Dict, List, Optional
from twisted.internet import defer
from twisted.internet.protocol import ReconnectingClientFactory
+from synapse.replication.slave.storage._base import BaseSlavedStore
+from synapse.replication.tcp.protocol import (
+ AbstractReplicationClientHandler,
+ ClientReplicationStreamProtocol,
+)
+
from .commands import (
+ Command,
FederationAckCommand,
InvalidateCacheCommand,
+ RemoteServerUpCommand,
RemovePusherCommand,
UserIpCommand,
UserSyncCommand,
)
-from .protocol import ClientReplicationStreamProtocol
logger = logging.getLogger(__name__)
@@ -39,9 +47,11 @@ class ReplicationClientFactory(ReconnectingClientFactory):
Accepts a handler that will be called when new data is available or data
is required.
"""
- maxDelay = 30 # Try at least once every N seconds
- def __init__(self, hs, client_name, handler):
+ initialDelay = 0.1
+ maxDelay = 1 # Try at least once every N seconds
+
+ def __init__(self, hs, client_name, handler: AbstractReplicationClientHandler):
self.client_name = client_name
self.handler = handler
self.server_name = hs.config.server_name
@@ -64,17 +74,16 @@ class ReplicationClientFactory(ReconnectingClientFactory):
def clientConnectionFailed(self, connector, reason):
logger.error("Failed to connect to replication: %r", reason)
- ReconnectingClientFactory.clientConnectionFailed(
- self, connector, reason
- )
+ ReconnectingClientFactory.clientConnectionFailed(self, connector, reason)
-class ReplicationClientHandler(object):
+class ReplicationClientHandler(AbstractReplicationClientHandler):
"""A base handler that can be passed to the ReplicationClientFactory.
By default proxies incoming replication data to the SlaveStore.
"""
- def __init__(self, store):
+
+ def __init__(self, store: BaseSlavedStore):
self.store = store
# The current connection. None if we are currently (re)connecting
@@ -82,15 +91,15 @@ class ReplicationClientHandler(object):
# Any pending commands to be sent once a new connection has been
# established
- self.pending_commands = []
+ self.pending_commands = [] # type: List[Command]
# Map from string -> deferred, to wake up when receiving a SYNC with
# the given string.
# Used for tests.
- self.awaiting_syncs = {}
+ self.awaiting_syncs = {} # type: Dict[str, defer.Deferred]
# The factory used to create connections.
- self.factory = None
+ self.factory = None # type: Optional[ReplicationClientFactory]
def start_replication(self, hs):
"""Helper method to start a replication connection to the remote server
@@ -102,7 +111,7 @@ class ReplicationClientHandler(object):
port = hs.config.worker_replication_port
hs.get_reactor().connectTCP(host, port, self.factory)
- def on_rdata(self, stream_name, token, rows):
+ async def on_rdata(self, stream_name, token, rows):
"""Called to handle a batch of replication data with a given stream token.
By default this just pokes the slave store. Can be overridden in subclasses to
@@ -113,20 +122,17 @@ class ReplicationClientHandler(object):
token (int): stream token for this batch of rows
rows (list): a list of Stream.ROW_TYPE objects as returned by
Stream.parse_row.
-
- Returns:
- Deferred|None
"""
logger.debug("Received rdata %s -> %s", stream_name, token)
- return self.store.process_replication_rows(stream_name, token, rows)
+ self.store.process_replication_rows(stream_name, token, rows)
- def on_position(self, stream_name, token):
+ async def on_position(self, stream_name, token):
"""Called when we get new position data. By default this just pokes
the slave store.
Can be overridden in subclasses to handle more.
"""
- return self.store.process_replication_rows(stream_name, token, [])
+ self.store.process_replication_rows(stream_name, token, [])
def on_sync(self, data):
"""When we received a SYNC we wake up any deferreds that were waiting
@@ -138,11 +144,16 @@ class ReplicationClientHandler(object):
if d:
d.callback(data)
- def get_streams_to_replicate(self):
+ def on_remote_server_up(self, server: str):
+ """Called when get a new REMOTE_SERVER_UP command."""
+
+ def get_streams_to_replicate(self) -> Dict[str, int]:
"""Called when a new connection has been established and we need to
subscribe to streams.
- Returns a dictionary of stream name to token.
+ Returns:
+ map from stream name to the most recent update we have for
+ that stream (ie, the point we want to start replicating from)
"""
args = self.store.stream_positions()
user_account_data = args.pop("user_account_data", None)
@@ -168,7 +179,7 @@ class ReplicationClientHandler(object):
if self.connection:
self.connection.send_command(cmd)
else:
- logger.warn("Queuing command as not connected: %r", cmd.NAME)
+ logger.warning("Queuing command as not connected: %r", cmd.NAME)
self.pending_commands.append(cmd)
def send_federation_ack(self, token):
@@ -200,6 +211,9 @@ class ReplicationClientHandler(object):
cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen)
self.send_command(cmd)
+ def send_remote_server_up(self, server: str):
+ self.send_command(RemoteServerUpCommand(server))
+
def await_sync(self, data):
"""Returns a deferred that is resolved when we receive a SYNC command
with given data.
@@ -226,4 +240,5 @@ class ReplicationClientHandler(object):
# We don't reset the delay any earlier as otherwise if there is a
# problem during start up we'll end up tight looping connecting to the
# server.
- self.factory.resetDelay()
+ if self.factory:
+ self.factory.resetDelay()
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index 2098c32a77..451671412d 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -20,13 +20,16 @@ allowed to be sent by which side.
import logging
import platform
+from typing import Tuple, Type
if platform.python_implementation() == "PyPy":
import json
+
_json_encoder = json.JSONEncoder()
else:
- import simplejson as json
- _json_encoder = json.JSONEncoder(namedtuple_as_object=False)
+ import simplejson as json # type: ignore[no-redef] # noqa: F821
+
+ _json_encoder = json.JSONEncoder(namedtuple_as_object=False) # type: ignore[call-arg] # noqa: F821
logger = logging.getLogger(__name__)
@@ -41,7 +44,8 @@ class Command(object):
The default implementation creates a command of form `<NAME> <data>`
"""
- NAME = None
+
+ NAME = None # type: str
def __init__(self, data):
self.data = data
@@ -73,6 +77,7 @@ class ServerCommand(Command):
SERVER <server_name>
"""
+
NAME = "SERVER"
@@ -99,6 +104,7 @@ class RdataCommand(Command):
RDATA presence batch ["@bar:example.com", "online", ...]
RDATA presence 59 ["@baz:example.com", "online", ...]
"""
+
NAME = "RDATA"
def __init__(self, stream_name, token, row):
@@ -110,17 +116,17 @@ class RdataCommand(Command):
def from_line(cls, line):
stream_name, token, row_json = line.split(" ", 2)
return cls(
- stream_name,
- None if token == "batch" else int(token),
- json.loads(row_json)
+ stream_name, None if token == "batch" else int(token), json.loads(row_json)
)
def to_line(self):
- return " ".join((
- self.stream_name,
- str(self.token) if self.token is not None else "batch",
- _json_encoder.encode(self.row),
- ))
+ return " ".join(
+ (
+ self.stream_name,
+ str(self.token) if self.token is not None else "batch",
+ _json_encoder.encode(self.row),
+ )
+ )
def get_logcontext_id(self):
return "RDATA-" + self.stream_name
@@ -133,6 +139,7 @@ class PositionCommand(Command):
Sent to the client after all missing updates for a stream have been sent
to the client and they're now up to date.
"""
+
NAME = "POSITION"
def __init__(self, stream_name, token):
@@ -145,19 +152,21 @@ class PositionCommand(Command):
return cls(stream_name, int(token))
def to_line(self):
- return " ".join((self.stream_name, str(self.token),))
+ return " ".join((self.stream_name, str(self.token)))
class ErrorCommand(Command):
"""Sent by either side if there was an ERROR. The data is a string describing
the error.
"""
+
NAME = "ERROR"
class PingCommand(Command):
"""Sent by either side as a keep alive. The data is arbitary (often timestamp)
"""
+
NAME = "PING"
@@ -165,6 +174,7 @@ class NameCommand(Command):
"""Sent by client to inform the server of the client's identity. The data
is the name
"""
+
NAME = "NAME"
@@ -184,6 +194,7 @@ class ReplicateCommand(Command):
REPLICATE ALL NOW
"""
+
NAME = "REPLICATE"
def __init__(self, stream_name, token):
@@ -200,7 +211,7 @@ class ReplicateCommand(Command):
return cls(stream_name, token)
def to_line(self):
- return " ".join((self.stream_name, str(self.token),))
+ return " ".join((self.stream_name, str(self.token)))
def get_logcontext_id(self):
return "REPLICATE-" + self.stream_name
@@ -218,6 +229,7 @@ class UserSyncCommand(Command):
Where <state> is either "start" or "stop"
"""
+
NAME = "USER_SYNC"
def __init__(self, user_id, is_syncing, last_sync_ms):
@@ -235,9 +247,13 @@ class UserSyncCommand(Command):
return cls(user_id, state == "start", int(last_sync_ms))
def to_line(self):
- return " ".join((
- self.user_id, "start" if self.is_syncing else "end", str(self.last_sync_ms),
- ))
+ return " ".join(
+ (
+ self.user_id,
+ "start" if self.is_syncing else "end",
+ str(self.last_sync_ms),
+ )
+ )
class FederationAckCommand(Command):
@@ -251,6 +267,7 @@ class FederationAckCommand(Command):
FEDERATION_ACK <token>
"""
+
NAME = "FEDERATION_ACK"
def __init__(self, token):
@@ -268,6 +285,7 @@ class SyncCommand(Command):
"""Used for testing. The client protocol implementation allows waiting
on a SYNC command with a specified data.
"""
+
NAME = "SYNC"
@@ -278,6 +296,7 @@ class RemovePusherCommand(Command):
REMOVE_PUSHER <app_id> <push_key> <user_id>
"""
+
NAME = "REMOVE_PUSHER"
def __init__(self, app_id, push_key, user_id):
@@ -309,6 +328,7 @@ class InvalidateCacheCommand(Command):
Where <keys_json> is a json list.
"""
+
NAME = "INVALIDATE_CACHE"
def __init__(self, cache_func, keys):
@@ -322,9 +342,7 @@ class InvalidateCacheCommand(Command):
return cls(cache_func, json.loads(keys_json))
def to_line(self):
- return " ".join((
- self.cache_func, _json_encoder.encode(self.keys),
- ))
+ return " ".join((self.cache_func, _json_encoder.encode(self.keys)))
class UserIpCommand(Command):
@@ -334,6 +352,7 @@ class UserIpCommand(Command):
USER_IP <user_id>, <access_token>, <ip>, <device_id>, <last_seen>, <user_agent>
"""
+
NAME = "USER_IP"
def __init__(self, user_id, access_token, ip, user_agent, device_id, last_seen):
@@ -350,36 +369,57 @@ class UserIpCommand(Command):
access_token, ip, user_agent, device_id, last_seen = json.loads(jsn)
- return cls(
- user_id, access_token, ip, user_agent, device_id, last_seen
- )
+ return cls(user_id, access_token, ip, user_agent, device_id, last_seen)
def to_line(self):
- return self.user_id + " " + _json_encoder.encode((
- self.access_token, self.ip, self.user_agent, self.device_id,
- self.last_seen,
- ))
+ return (
+ self.user_id
+ + " "
+ + _json_encoder.encode(
+ (
+ self.access_token,
+ self.ip,
+ self.user_agent,
+ self.device_id,
+ self.last_seen,
+ )
+ )
+ )
+
+
+class RemoteServerUpCommand(Command):
+ """Sent when a worker has detected that a remote server is no longer
+ "down" and retry timings should be reset.
+
+ If sent from a client the server will relay to all other workers.
+
+ Format::
+
+ REMOTE_SERVER_UP <server>
+ """
+ NAME = "REMOTE_SERVER_UP"
+
+
+_COMMANDS = (
+ ServerCommand,
+ RdataCommand,
+ PositionCommand,
+ ErrorCommand,
+ PingCommand,
+ NameCommand,
+ ReplicateCommand,
+ UserSyncCommand,
+ FederationAckCommand,
+ SyncCommand,
+ RemovePusherCommand,
+ InvalidateCacheCommand,
+ UserIpCommand,
+ RemoteServerUpCommand,
+) # type: Tuple[Type[Command], ...]
# Map of command name to command type.
-COMMAND_MAP = {
- cmd.NAME: cmd
- for cmd in (
- ServerCommand,
- RdataCommand,
- PositionCommand,
- ErrorCommand,
- PingCommand,
- NameCommand,
- ReplicateCommand,
- UserSyncCommand,
- FederationAckCommand,
- SyncCommand,
- RemovePusherCommand,
- InvalidateCacheCommand,
- UserIpCommand,
- )
-}
+COMMAND_MAP = {cmd.NAME: cmd for cmd in _COMMANDS}
# The commands the server is allowed to send
VALID_SERVER_COMMANDS = (
@@ -389,6 +429,7 @@ VALID_SERVER_COMMANDS = (
ErrorCommand.NAME,
PingCommand.NAME,
SyncCommand.NAME,
+ RemoteServerUpCommand.NAME,
)
# The commands the client is allowed to send
@@ -402,4 +443,5 @@ VALID_CLIENT_COMMANDS = (
InvalidateCacheCommand.NAME,
UserIpCommand.NAME,
ErrorCommand.NAME,
+ RemoteServerUpCommand.NAME,
)
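All command classes share the `<NAME> <data>` wire format, and `COMMAND_MAP` dispatches an incoming line to the right class. A hedged round-trip sketch follows; `ToyCommand` re-implements the documented convention rather than using the real `Command` base class, whose `from_line`/`to_line` bodies fall outside this diff.

```python
# Sketch of the `<NAME> <data>` wire convention described in commands.py.
# ToyCommand is a stand-in; it is not the real Command base class.
from typing import Dict, Type


class ToyCommand:
    NAME = "REMOTE_SERVER_UP"

    def __init__(self, data: str) -> None:
        self.data = data

    @classmethod
    def from_line(cls, line: str) -> "ToyCommand":
        return cls(line)

    def to_line(self) -> str:
        return self.data


TOY_COMMAND_MAP = {ToyCommand.NAME: ToyCommand}  # type: Dict[str, Type[ToyCommand]]

# Serialize: the protocol writes one "NAME data" string per line.
cmd = ToyCommand("matrix.example.com")
line = "%s %s" % (cmd.NAME, cmd.to_line())

# Parse: split off the command name, then dispatch through the map.
name, rest = line.split(" ", 1)
parsed = TOY_COMMAND_MAP[name].from_line(rest)
assert parsed.data == "matrix.example.com"
```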
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index b51590cf8f..d185cc0c8f 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -48,11 +48,12 @@ indicate which side is sending, these are *not* included on the wire::
> ERROR server stopping
* connection closed by server *
"""
-
+import abc
import fcntl
import logging
import struct
from collections import defaultdict
+from typing import Any, DefaultDict, Dict, List, Set, Tuple
from six import iteritems, iterkeys
@@ -62,29 +63,33 @@ from twisted.internet import defer
from twisted.protocols.basic import LineOnlyReceiver
from twisted.python.failure import Failure
+from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.util.logcontext import make_deferred_yieldable, run_in_background
-from synapse.util.stringutils import random_string
-
-from .commands import (
+from synapse.replication.tcp.commands import (
COMMAND_MAP,
VALID_CLIENT_COMMANDS,
VALID_SERVER_COMMANDS,
+ Command,
ErrorCommand,
NameCommand,
PingCommand,
PositionCommand,
RdataCommand,
+ RemoteServerUpCommand,
ReplicateCommand,
ServerCommand,
SyncCommand,
UserSyncCommand,
)
-from .streams import STREAMS_MAP
+from synapse.replication.tcp.streams import STREAMS_MAP
+from synapse.types import Collection
+from synapse.util import Clock
+from synapse.util.stringutils import random_string
connection_close_counter = Counter(
- "synapse_replication_tcp_protocol_close_reason", "", ["reason_type"])
+ "synapse_replication_tcp_protocol_close_reason", "", ["reason_type"]
+)
# A list of all connected protocols. This allows us to send metrics about the
# connections.
@@ -119,10 +124,14 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
It also sends `PING` periodically, and correctly times out remote connections
(if they send a `PING` command)
"""
- delimiter = b'\n'
- VALID_INBOUND_COMMANDS = [] # Valid commands we expect to receive
- VALID_OUTBOUND_COMMANDS = [] # Valid commans we can send
+ delimiter = b"\n"
+
+ # Valid commands we expect to receive
+ VALID_INBOUND_COMMANDS = [] # type: Collection[str]
+
+ # Valid commands we can send
+ VALID_OUTBOUND_COMMANDS = [] # type: Collection[str]
max_line_buffer = 10000
@@ -141,13 +150,13 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
self.conn_id = random_string(5) # To dedupe in case of name clashes.
# List of pending commands to send once we've established the connection
- self.pending_commands = []
+ self.pending_commands = [] # type: List[Command]
# The LoopingCall for sending pings.
self._send_ping_loop = None
- self.inbound_commands_counter = defaultdict(int)
- self.outbound_commands_counter = defaultdict(int)
+ self.inbound_commands_counter = defaultdict(int) # type: DefaultDict[str, int]
+ self.outbound_commands_counter = defaultdict(int) # type: DefaultDict[str, int]
def connectionMade(self):
logger.info("[%s] Connection established", self.id())
@@ -183,10 +192,14 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
if now - self.last_sent_command >= PING_TIME:
self.send_command(PingCommand(now))
- if self.received_ping and now - self.last_received_command > PING_TIMEOUT_MS:
+ if (
+ self.received_ping
+ and now - self.last_received_command > PING_TIMEOUT_MS
+ ):
logger.info(
"[%s] Connection hasn't received command in %r ms. Closing.",
- self.id(), now - self.last_received_command
+ self.id(),
+ now - self.last_received_command,
)
self.send_error("ping timeout")
@@ -208,7 +221,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
self.last_received_command = self.clock.time_msec()
self.inbound_commands_counter[cmd_name] = (
- self.inbound_commands_counter[cmd_name] + 1)
+ self.inbound_commands_counter[cmd_name] + 1
+ )
cmd_cls = COMMAND_MAP[cmd_name]
try:
@@ -224,27 +238,22 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
# Now let's try and call the on_<CMD_NAME> function
run_as_background_process(
- "replication-" + cmd.get_logcontext_id(),
- self.handle_command,
- cmd,
+ "replication-" + cmd.get_logcontext_id(), self.handle_command, cmd
)
- def handle_command(self, cmd):
+ async def handle_command(self, cmd: Command):
"""Handle a command we have received over the replication stream.
- By default delegates to on_<COMMAND>
+ By default delegates to on_<COMMAND>, which should return an awaitable.
Args:
- cmd (synapse.replication.tcp.commands.Command): received command
-
- Returns:
- Deferred
+ cmd: received command
"""
handler = getattr(self, "on_%s" % (cmd.NAME,))
- return handler(cmd)
+ await handler(cmd)
def close(self):
- logger.warn("[%s] Closing connection", self.id())
+ logger.warning("[%s] Closing connection", self.id())
self.time_we_closed = self.clock.time_msec()
self.transport.loseConnection()
self.on_connection_closed()
@@ -274,8 +283,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
return
self.outbound_commands_counter[cmd.NAME] = (
- self.outbound_commands_counter[cmd.NAME] + 1)
- string = "%s %s" % (cmd.NAME, cmd.to_line(),)
+ self.outbound_commands_counter[cmd.NAME] + 1
+ )
+ string = "%s %s" % (cmd.NAME, cmd.to_line())
if "\n" in string:
raise Exception("Unexpected newline in command: %r", string)
@@ -283,10 +293,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
if len(encoded_string) > self.MAX_LENGTH:
raise Exception(
- "Failed to send command %s as too long (%d > %d)" % (
- cmd.NAME,
- len(encoded_string), self.MAX_LENGTH,
- )
+ "Failed to send command %s as too long (%d > %d)"
+ % (cmd.NAME, len(encoded_string), self.MAX_LENGTH)
)
self.sendLine(encoded_string)
@@ -315,10 +323,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
for cmd in pending:
self.send_command(cmd)
- def on_PING(self, line):
+ async def on_PING(self, line):
self.received_ping = True
- def on_ERROR(self, cmd):
+ async def on_ERROR(self, cmd):
logger.error("[%s] Remote reported error: %r", self.id(), cmd.data)
def pauseProducing(self):
@@ -379,7 +387,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
if self.transport:
addr = str(self.transport.getPeer())
return "ReplicationConnection<name=%s,conn_id=%s,addr=%s>" % (
- self.name, self.conn_id, addr,
+ self.name,
+ self.conn_id,
+ addr,
)
def id(self):
@@ -402,68 +412,69 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.streamer = streamer
# The streams the client has subscribed to and is up to date with
- self.replication_streams = set()
+ self.replication_streams = set() # type: Set[str]
# The streams the client is currently subscribing to.
- self.connecting_streams = set()
+ self.connecting_streams = set() # type: Set[str]
# Map from stream name to list of updates to send once we've finished
# subscribing the client to the stream.
- self.pending_rdata = {}
+ self.pending_rdata = {} # type: Dict[str, List[Tuple[int, Any]]]
def connectionMade(self):
self.send_command(ServerCommand(self.server_name))
BaseReplicationStreamProtocol.connectionMade(self)
self.streamer.new_connection(self)
- def on_NAME(self, cmd):
+ async def on_NAME(self, cmd):
logger.info("[%s] Renamed to %r", self.id(), cmd.data)
self.name = cmd.data
- def on_USER_SYNC(self, cmd):
- return self.streamer.on_user_sync(
- self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms,
+ async def on_USER_SYNC(self, cmd):
+ await self.streamer.on_user_sync(
+ self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
)
- def on_REPLICATE(self, cmd):
+ async def on_REPLICATE(self, cmd):
stream_name = cmd.stream_name
token = cmd.token
if stream_name == "ALL":
# Subscribe to all streams we're publishing to.
deferreds = [
- run_in_background(
- self.subscribe_to_stream,
- stream, token,
- )
+ run_in_background(self.subscribe_to_stream, stream, token)
for stream in iterkeys(self.streamer.streams_by_name)
]
- return make_deferred_yieldable(
+ await make_deferred_yieldable(
defer.gatherResults(deferreds, consumeErrors=True)
)
else:
- return self.subscribe_to_stream(stream_name, token)
+ await self.subscribe_to_stream(stream_name, token)
- def on_FEDERATION_ACK(self, cmd):
- return self.streamer.federation_ack(cmd.token)
+ async def on_FEDERATION_ACK(self, cmd):
+ self.streamer.federation_ack(cmd.token)
- def on_REMOVE_PUSHER(self, cmd):
- return self.streamer.on_remove_pusher(
- cmd.app_id, cmd.push_key, cmd.user_id,
- )
+ async def on_REMOVE_PUSHER(self, cmd):
+ await self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id)
- def on_INVALIDATE_CACHE(self, cmd):
- return self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
+ async def on_INVALIDATE_CACHE(self, cmd):
+ await self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
- def on_USER_IP(self, cmd):
- return self.streamer.on_user_ip(
- cmd.user_id, cmd.access_token, cmd.ip, cmd.user_agent, cmd.device_id,
+ async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
+ self.streamer.on_remote_server_up(cmd.data)
+
+ async def on_USER_IP(self, cmd):
+ await self.streamer.on_user_ip(
+ cmd.user_id,
+ cmd.access_token,
+ cmd.ip,
+ cmd.user_agent,
+ cmd.device_id,
cmd.last_seen,
)
- @defer.inlineCallbacks
- def subscribe_to_stream(self, stream_name, token):
+ async def subscribe_to_stream(self, stream_name, token):
"""Subscribe the remote to a stream.
This involves checking if they've missed anything and sending those
@@ -475,8 +486,8 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
try:
# Get missing updates
- updates, current_token = yield self.streamer.get_stream_updates(
- stream_name, token,
+ updates, current_token = await self.streamer.get_stream_updates(
+ stream_name, token
)
# Send all the missing updates
@@ -548,16 +559,90 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
def send_sync(self, data):
self.send_command(SyncCommand(data))
+ def send_remote_server_up(self, server: str):
+ self.send_command(RemoteServerUpCommand(server))
+
def on_connection_closed(self):
BaseReplicationStreamProtocol.on_connection_closed(self)
self.streamer.lost_connection(self)
+class AbstractReplicationClientHandler(metaclass=abc.ABCMeta):
+ """
+ The interface for the handler that should be passed to
+ ClientReplicationStreamProtocol
+ """
+
+ @abc.abstractmethod
+ async def on_rdata(self, stream_name, token, rows):
+ """Called to handle a batch of replication data with a given stream token.
+
+ Args:
+ stream_name (str): name of the replication stream for this batch of rows
+ token (int): stream token for this batch of rows
+ rows (list): a list of Stream.ROW_TYPE objects as returned by
+ Stream.parse_row.
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ async def on_position(self, stream_name, token):
+ """Called when we get new position data."""
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def on_sync(self, data):
+ """Called when get a new SYNC command."""
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ async def on_remote_server_up(self, server: str):
+ """Called when get a new REMOTE_SERVER_UP command."""
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def get_streams_to_replicate(self):
+ """Called when a new connection has been established and we need to
+ subscribe to streams.
+
+ Returns:
+ map from stream name to the most recent update we have for
+ that stream (ie, the point we want to start replicating from)
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def get_currently_syncing_users(self):
+ """Get the list of currently syncing users (if any). This is called
+ when a connection has been established and we need to send the
+ currently syncing users."""
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def update_connection(self, connection):
+ """Called when a connection has been established (or lost with None).
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def finished_connecting(self):
+ """Called when we have successfully subscribed and caught up to all
+ streams we're interested in.
+ """
+ raise NotImplementedError()
+
+
class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
VALID_INBOUND_COMMANDS = VALID_SERVER_COMMANDS
VALID_OUTBOUND_COMMANDS = VALID_CLIENT_COMMANDS
- def __init__(self, client_name, server_name, clock, handler):
+ def __init__(
+ self,
+ client_name: str,
+ server_name: str,
+ clock: Clock,
+ handler: AbstractReplicationClientHandler,
+ ):
BaseReplicationStreamProtocol.__init__(self, clock)
self.client_name = client_name
@@ -567,11 +652,11 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
# Set of stream names that have been subscribe to, but haven't yet
# caught up with. This is used to track when the client has been fully
# connected to the remote.
- self.streams_connecting = set()
+ self.streams_connecting = set() # type: Set[str]
# Map of stream to batched updates. See RdataCommand for info on how
# batching works.
- self.pending_batches = {}
+ self.pending_batches = {} # type: Dict[str, Any]
def connectionMade(self):
self.send_command(NameCommand(self.client_name))
@@ -595,12 +680,12 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
if not self.streams_connecting:
self.handler.finished_connecting()
- def on_SERVER(self, cmd):
+ async def on_SERVER(self, cmd):
if cmd.data != self.server_name:
logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
self.send_error("Wrong remote")
- def on_RDATA(self, cmd):
+ async def on_RDATA(self, cmd):
stream_name = cmd.stream_name
inbound_rdata_count.labels(stream_name).inc()
@@ -608,8 +693,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
row = STREAMS_MAP[stream_name].parse_row(cmd.row)
except Exception:
logger.exception(
- "[%s] Failed to parse RDATA: %r %r",
- self.id(), stream_name, cmd.row
+ "[%s] Failed to parse RDATA: %r %r", self.id(), stream_name, cmd.row
)
raise
@@ -621,19 +705,22 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
# Check if this is the last of a batch of updates
rows = self.pending_batches.pop(stream_name, [])
rows.append(row)
- return self.handler.on_rdata(stream_name, cmd.token, rows)
+ await self.handler.on_rdata(stream_name, cmd.token, rows)
- def on_POSITION(self, cmd):
+ async def on_POSITION(self, cmd):
# When we get a `POSITION` command it means we've finished getting
# missing updates for the given stream, and are now up to date.
self.streams_connecting.discard(cmd.stream_name)
if not self.streams_connecting:
self.handler.finished_connecting()
- return self.handler.on_position(cmd.stream_name, cmd.token)
+ await self.handler.on_position(cmd.stream_name, cmd.token)
- def on_SYNC(self, cmd):
- return self.handler.on_sync(cmd.data)
+ async def on_SYNC(self, cmd):
+ self.handler.on_sync(cmd.data)
+
+ async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
+ self.handler.on_remote_server_up(cmd.data)
def replicate(self, stream_name, token):
"""Send the subscription request to the server
@@ -643,7 +730,9 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
logger.info(
"[%s] Subscribing to replication stream: %r from %r",
- self.id(), stream_name, token
+ self.id(),
+ stream_name,
+ token,
)
self.streams_connecting.add(stream_name)
@@ -661,9 +750,7 @@ pending_commands = LaterGauge(
"synapse_replication_tcp_protocol_pending_commands",
"",
["name"],
- lambda: {
- (p.name,): len(p.pending_commands) for p in connected_connections
- },
+ lambda: {(p.name,): len(p.pending_commands) for p in connected_connections},
)
@@ -678,9 +765,7 @@ transport_send_buffer = LaterGauge(
"synapse_replication_tcp_protocol_transport_send_buffer",
"",
["name"],
- lambda: {
- (p.name,): transport_buffer_size(p) for p in connected_connections
- },
+ lambda: {(p.name,): transport_buffer_size(p) for p in connected_connections},
)
@@ -694,7 +779,7 @@ def transport_kernel_read_buffer_size(protocol, read=True):
op = SIOCINQ
else:
op = SIOCOUTQ
- size = struct.unpack("I", fcntl.ioctl(fileno, op, '\0\0\0\0'))[0]
+ size = struct.unpack("I", fcntl.ioctl(fileno, op, b"\0\0\0\0"))[0]
return size
return 0
@@ -726,7 +811,7 @@ tcp_inbound_commands = LaterGauge(
"",
["command", "name"],
lambda: {
- (k, p.name,): count
+ (k, p.name): count
for p in connected_connections
for k, count in iteritems(p.inbound_commands_counter)
},
@@ -737,7 +822,7 @@ tcp_outbound_commands = LaterGauge(
"",
["command", "name"],
lambda: {
- (k, p.name,): count
+ (k, p.name): count
for p in connected_connections
for k, count in iteritems(p.outbound_commands_counter)
},
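The new `AbstractReplicationClientHandler` makes the client-side contract explicit: `handle_command` dispatches each inbound command to an `on_<NAME>` coroutine on the handler. A minimal sketch of a conforming implementation is below; the ABC here is a standalone copy of the interface's shape, showing only two of its methods, and `LoggingHandler` is hypothetical.

```python
# Sketch of implementing the AbstractReplicationClientHandler contract.
# Only the method names and signatures come from the diff above.
import abc
from typing import Dict


class AbstractHandler(metaclass=abc.ABCMeta):
    @abc.abstractmethod
    async def on_rdata(self, stream_name, token, rows):
        raise NotImplementedError()

    @abc.abstractmethod
    def get_streams_to_replicate(self) -> Dict[str, int]:
        raise NotImplementedError()


class LoggingHandler(AbstractHandler):
    """Logs rows instead of writing them to a slaved store."""

    async def on_rdata(self, stream_name, token, rows):
        print("got %d rows for %s at %s" % (len(rows), stream_name, token))

    def get_streams_to_replicate(self) -> Dict[str, int]:
        # Map of stream name -> token we want to resume from.
        return {"caches": 0}


handler = LoggingHandler()
print(handler.get_streams_to_replicate())
```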
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index f6a38f5140..ce9d1fae12 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -17,12 +17,12 @@
import logging
import random
+from typing import Any, List
from six import itervalues
from prometheus_client import Counter
-from twisted.internet import defer
from twisted.internet.protocol import Factory
from synapse.metrics import LaterGauge
@@ -33,13 +33,15 @@ from .protocol import ServerReplicationStreamProtocol
from .streams import STREAMS_MAP
from .streams.federation import FederationStream
-stream_updates_counter = Counter("synapse_replication_tcp_resource_stream_updates",
- "", ["stream_name"])
+stream_updates_counter = Counter(
+ "synapse_replication_tcp_resource_stream_updates", "", ["stream_name"]
+)
user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "")
federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "")
remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "")
-invalidate_cache_counter = Counter("synapse_replication_tcp_resource_invalidate_cache",
- "")
+invalidate_cache_counter = Counter(
+ "synapse_replication_tcp_resource_invalidate_cache", ""
+)
user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "")
logger = logging.getLogger(__name__)
@@ -48,6 +50,7 @@ logger = logging.getLogger(__name__)
class ReplicationStreamProtocolFactory(Factory):
"""Factory for new replication connections.
"""
+
def __init__(self, hs):
self.streamer = ReplicationStreamer(hs)
self.clock = hs.get_clock()
@@ -55,9 +58,7 @@ class ReplicationStreamProtocolFactory(Factory):
def buildProtocol(self, addr):
return ServerReplicationStreamProtocol(
- self.server_name,
- self.clock,
- self.streamer,
+ self.server_name, self.clock, self.streamer
)
@@ -78,37 +79,48 @@ class ReplicationStreamer(object):
self._replication_torture_level = hs.config.replication_torture_level
# Current connections.
- self.connections = []
+ self.connections = [] # type: List[ServerReplicationStreamProtocol]
- LaterGauge("synapse_replication_tcp_resource_total_connections", "", [],
- lambda: len(self.connections))
+ LaterGauge(
+ "synapse_replication_tcp_resource_total_connections",
+ "",
+ [],
+ lambda: len(self.connections),
+ )
# List of streams that clients can subscribe to.
# We only support the federation stream if federation sending has been
# disabled on the master.
self.streams = [
- stream(hs) for stream in itervalues(STREAMS_MAP)
+ stream(hs)
+ for stream in itervalues(STREAMS_MAP)
if stream != FederationStream or not hs.config.send_federation
]
self.streams_by_name = {stream.NAME: stream for stream in self.streams}
LaterGauge(
- "synapse_replication_tcp_resource_connections_per_stream", "",
+ "synapse_replication_tcp_resource_connections_per_stream",
+ "",
["stream_name"],
lambda: {
- (stream_name,): len([
- conn for conn in self.connections
- if stream_name in conn.replication_streams
- ])
+ (stream_name,): len(
+ [
+ conn
+ for conn in self.connections
+ if stream_name in conn.replication_streams
+ ]
+ )
for stream_name in self.streams_by_name
- })
+ },
+ )
self.federation_sender = None
if not hs.config.send_federation:
self.federation_sender = hs.get_federation_sender()
self.notifier.add_replication_callback(self.on_notifier_poke)
+ self.notifier.add_remote_server_up_callback(self.send_remote_server_up)
# Keeps track of whether we are currently checking for updates
self.is_looping = False
@@ -143,8 +155,7 @@ class ReplicationStreamer(object):
run_as_background_process("replication_notifier", self._run_notifier_loop)
- @defer.inlineCallbacks
- def _run_notifier_loop(self):
+ async def _run_notifier_loop(self):
self.is_looping = True
try:
@@ -173,23 +184,26 @@ class ReplicationStreamer(object):
continue
if self._replication_torture_level:
- yield self.clock.sleep(
+ await self.clock.sleep(
self._replication_torture_level / 1000.0
)
logger.debug(
"Getting stream: %s: %s -> %s",
- stream.NAME, stream.last_token, stream.upto_token
+ stream.NAME,
+ stream.last_token,
+ stream.upto_token,
)
try:
- updates, current_token = yield stream.get_updates()
+ updates, current_token = await stream.get_updates()
except Exception:
logger.info("Failed to handle stream %s", stream.NAME)
raise
logger.debug(
"Sending %d updates to %d connections",
- len(updates), len(self.connections),
+ len(updates),
+ len(self.connections),
)
if updates:
@@ -218,7 +232,7 @@ class ReplicationStreamer(object):
self.is_looping = False
@measure_func("repl.get_stream_updates")
- def get_stream_updates(self, stream_name, token):
+ async def get_stream_updates(self, stream_name, token):
"""For a given stream get all updates since token. This is called when
a client first subscribes to a stream.
"""
@@ -226,7 +240,7 @@ class ReplicationStreamer(object):
if not stream:
raise Exception("unknown stream %s", stream_name)
- return stream.get_updates_since(token)
+ return await stream.get_updates_since(token)
@measure_func("repl.federation_ack")
def federation_ack(self, token):
@@ -237,44 +251,54 @@ class ReplicationStreamer(object):
self.federation_sender.federation_ack(token)
@measure_func("repl.on_user_sync")
- @defer.inlineCallbacks
- def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms):
+ async def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms):
"""A client has started/stopped syncing on a worker.
"""
user_sync_counter.inc()
- yield self.presence_handler.update_external_syncs_row(
- conn_id, user_id, is_syncing, last_sync_ms,
+ await self.presence_handler.update_external_syncs_row(
+ conn_id, user_id, is_syncing, last_sync_ms
)
@measure_func("repl.on_remove_pusher")
- @defer.inlineCallbacks
- def on_remove_pusher(self, app_id, push_key, user_id):
+ async def on_remove_pusher(self, app_id, push_key, user_id):
"""A client has asked us to remove a pusher
"""
remove_pusher_counter.inc()
- yield self.store.delete_pusher_by_app_id_pushkey_user_id(
+ await self.store.delete_pusher_by_app_id_pushkey_user_id(
app_id=app_id, pushkey=push_key, user_id=user_id
)
self.notifier.on_new_replication_data()
@measure_func("repl.on_invalidate_cache")
- def on_invalidate_cache(self, cache_func, keys):
+ async def on_invalidate_cache(self, cache_func: str, keys: List[Any]):
"""The client has asked us to invalidate a cache
"""
invalidate_cache_counter.inc()
- getattr(self.store, cache_func).invalidate(tuple(keys))
+
+ # We invalidate the cache locally, but then also stream that to other
+ # workers.
+ await self.store.invalidate_cache_and_stream(cache_func, tuple(keys))
@measure_func("repl.on_user_ip")
- @defer.inlineCallbacks
- def on_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen):
+ async def on_user_ip(
+ self, user_id, access_token, ip, user_agent, device_id, last_seen
+ ):
"""The client saw a user request
"""
user_ip_cache_counter.inc()
- yield self.store.insert_client_ip(
- user_id, access_token, ip, user_agent, device_id, last_seen,
+ await self.store.insert_client_ip(
+ user_id, access_token, ip, user_agent, device_id, last_seen
)
- yield self._server_notices_sender.on_user_ip(user_id)
+ await self._server_notices_sender.on_user_ip(user_id)
+
+ @measure_func("repl.on_remote_server_up")
+ def on_remote_server_up(self, server: str):
+ self.notifier.notify_remote_server_up(server)
+
+ def send_remote_server_up(self, server: str):
+ for conn in self.connections:
+ conn.send_remote_server_up(server)
def send_sync_to_all_connections(self, data):
"""Sends a SYNC command to all clients.
@@ -299,7 +323,11 @@ class ReplicationStreamer(object):
# We need to tell the presence handler that the connection has been
# lost so that it can handle any ongoing syncs on that connection.
- self.presence_handler.update_external_syncs_clear(connection.conn_id)
+ run_as_background_process(
+ "update_external_syncs_clear",
+ self.presence_handler.update_external_syncs_clear,
+ connection.conn_id,
+ )
def _batch_updates(updates):
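The `lost_connection` change above wraps the now-async `update_external_syncs_clear` in `run_as_background_process`, since the caller is synchronous and must not block. A standalone sketch of that fire-and-forget pattern follows, with plain asyncio standing in for Synapse's helper; the function names here are illustrative.

```python
# Sketch of the fire-and-forget pattern used in lost_connection: a sync
# caller schedules a coroutine without awaiting it. Synapse uses
# run_as_background_process for this; plain asyncio stands in here.
import asyncio


async def update_external_syncs_clear(conn_id: str) -> None:
    print("clearing syncs for", conn_id)


def lost_connection(conn_id: str) -> None:
    # Schedule the coroutine on the running loop and return immediately.
    loop = asyncio.get_event_loop()
    loop.create_task(update_external_syncs_clear(conn_id))


async def main() -> None:
    lost_connection("conn-1")
    await asyncio.sleep(0)  # give the scheduled task a chance to run


asyncio.run(main())
```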
diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py
index 634f636dc9..5f52264e84 100644
--- a/synapse/replication/tcp/streams/__init__.py
+++ b/synapse/replication/tcp/streams/__init__.py
@@ -45,5 +45,6 @@ STREAMS_MAP = {
_base.TagAccountDataStream,
_base.AccountDataStream,
_base.GroupServerStream,
+ _base.UserSignatureStream,
)
}
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index b6ce7a7bee..208e8a667b 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -14,90 +14,101 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
import itertools
import logging
from collections import namedtuple
+from typing import Any, List, Optional
-from twisted.internet import defer
+import attr
logger = logging.getLogger(__name__)
-MAX_EVENTS_BEHIND = 10000
-
-BackfillStreamRow = namedtuple("BackfillStreamRow", (
- "event_id", # str
- "room_id", # str
- "type", # str
- "state_key", # str, optional
- "redacts", # str, optional
- "relates_to", # str, optional
-))
-PresenceStreamRow = namedtuple("PresenceStreamRow", (
- "user_id", # str
- "state", # str
- "last_active_ts", # int
- "last_federation_update_ts", # int
- "last_user_sync_ts", # int
- "status_msg", # str
- "currently_active", # bool
-))
-TypingStreamRow = namedtuple("TypingStreamRow", (
- "room_id", # str
- "user_ids", # list(str)
-))
-ReceiptsStreamRow = namedtuple("ReceiptsStreamRow", (
- "room_id", # str
- "receipt_type", # str
- "user_id", # str
- "event_id", # str
- "data", # dict
-))
-PushRulesStreamRow = namedtuple("PushRulesStreamRow", (
- "user_id", # str
-))
-PushersStreamRow = namedtuple("PushersStreamRow", (
- "user_id", # str
- "app_id", # str
- "pushkey", # str
- "deleted", # bool
-))
-CachesStreamRow = namedtuple("CachesStreamRow", (
- "cache_func", # str
- "keys", # list(str)
- "invalidation_ts", # int
-))
-PublicRoomsStreamRow = namedtuple("PublicRoomsStreamRow", (
- "room_id", # str
- "visibility", # str
- "appservice_id", # str, optional
- "network_id", # str, optional
-))
-DeviceListsStreamRow = namedtuple("DeviceListsStreamRow", (
- "user_id", # str
- "destination", # str
-))
-ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", (
- "entity", # str
-))
-TagAccountDataStreamRow = namedtuple("TagAccountDataStreamRow", (
- "user_id", # str
- "room_id", # str
- "data", # dict
-))
-AccountDataStreamRow = namedtuple("AccountDataStream", (
- "user_id", # str
- "room_id", # str
- "data_type", # str
- "data", # dict
-))
-GroupsStreamRow = namedtuple("GroupsStreamRow", (
- "group_id", # str
- "user_id", # str
- "type", # str
- "content", # dict
-))
+MAX_EVENTS_BEHIND = 500000
+
+BackfillStreamRow = namedtuple(
+ "BackfillStreamRow",
+ (
+ "event_id", # str
+ "room_id", # str
+ "type", # str
+ "state_key", # str, optional
+ "redacts", # str, optional
+ "relates_to", # str, optional
+ ),
+)
+PresenceStreamRow = namedtuple(
+ "PresenceStreamRow",
+ (
+ "user_id", # str
+ "state", # str
+ "last_active_ts", # int
+ "last_federation_update_ts", # int
+ "last_user_sync_ts", # int
+ "status_msg", # str
+ "currently_active", # bool
+ ),
+)
+TypingStreamRow = namedtuple(
+ "TypingStreamRow", ("room_id", "user_ids") # str # list(str)
+)
+ReceiptsStreamRow = namedtuple(
+ "ReceiptsStreamRow",
+ (
+ "room_id", # str
+ "receipt_type", # str
+ "user_id", # str
+ "event_id", # str
+ "data", # dict
+ ),
+)
+PushRulesStreamRow = namedtuple("PushRulesStreamRow", ("user_id",)) # str
+PushersStreamRow = namedtuple(
+ "PushersStreamRow",
+ ("user_id", "app_id", "pushkey", "deleted"), # str # str # str # bool
+)
+
+
+@attr.s
+class CachesStreamRow:
+ """Stream to inform workers they should invalidate their cache.
+
+ Attributes:
+ cache_func: Name of the cached function.
+ keys: The entry in the cache to invalidate. If None then the
+ whole cache will be invalidated.
+ invalidation_ts: Timestamp of when the invalidation took place.
+ """
+
+ cache_func = attr.ib(type=str)
+ keys = attr.ib(type=Optional[List[Any]])
+ invalidation_ts = attr.ib(type=int)
+
+
+PublicRoomsStreamRow = namedtuple(
+ "PublicRoomsStreamRow",
+ (
+ "room_id", # str
+ "visibility", # str
+ "appservice_id", # str, optional
+ "network_id", # str, optional
+ ),
+)
+DeviceListsStreamRow = namedtuple(
+ "DeviceListsStreamRow", ("user_id", "destination") # str # str
+)
+ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",)) # str
+TagAccountDataStreamRow = namedtuple(
+ "TagAccountDataStreamRow", ("user_id", "room_id", "data") # str # str # dict
+)
+AccountDataStreamRow = namedtuple(
+ "AccountDataStream", ("user_id", "room_id", "data_type") # str # str # str
+)
+GroupsStreamRow = namedtuple(
+ "GroupsStreamRow",
+ ("group_id", "user_id", "type", "content"), # str # str # str # dict
+)
+UserSignatureStreamRow = namedtuple("UserSignatureStreamRow", ("user_id",)) # str
class Stream(object):
@@ -106,8 +117,10 @@ class Stream(object):
Provides a `get_updates()` function that returns new updates since the last
time it was called up until the point `advance_current_token` was called.
"""
- NAME = None # The name of the stream
- ROW_TYPE = None # The type of the row. Used by the default impl of parse_row.
+
+ NAME = None # type: str # The name of the stream
+ # The type of the row. Used by the default impl of parse_row.
+ ROW_TYPE = None # type: Any
_LIMITED = True # Whether the update function takes a limit
@classmethod
@@ -145,8 +158,7 @@ class Stream(object):
self.upto_token = self.current_token()
self.last_token = self.upto_token
- @defer.inlineCallbacks
- def get_updates(self):
+ async def get_updates(self):
"""Gets all updates since the last time this function was called (or
since the stream was constructed if it hadn't been called before),
until the `upto_token`
@@ -157,13 +169,12 @@ class Stream(object):
list of ``(token, row)`` entries. ``row`` will be json-serialised and
sent over the replication stream.
"""
- updates, current_token = yield self.get_updates_since(self.last_token)
+ updates, current_token = await self.get_updates_since(self.last_token)
self.last_token = current_token
- defer.returnValue((updates, current_token))
+ return updates, current_token
- @defer.inlineCallbacks
- def get_updates_since(self, from_token):
+ async def get_updates_since(self, from_token):
"""Like get_updates except allows specifying from when we should
stream updates
@@ -174,27 +185,25 @@ class Stream(object):
sent over the replication stream.
"""
if from_token in ("NOW", "now"):
- defer.returnValue(([], self.upto_token))
+ return [], self.upto_token
current_token = self.upto_token
from_token = int(from_token)
if from_token == current_token:
- defer.returnValue(([], current_token))
+ return [], current_token
+ logger.info("get_updates_since: %s", self.__class__)
if self._LIMITED:
- rows = yield self.update_function(
- from_token, current_token,
- limit=MAX_EVENTS_BEHIND + 1,
+ rows = await self.update_function(
+ from_token, current_token, limit=MAX_EVENTS_BEHIND + 1
)
# never turn more than MAX_EVENTS_BEHIND + 1 into updates.
rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1)
else:
- rows = yield self.update_function(
- from_token, current_token,
- )
+ rows = await self.update_function(from_token, current_token)
updates = [(row[0], row[1:]) for row in rows]
@@ -203,7 +212,7 @@ class Stream(object):
if self._LIMITED and len(updates) >= MAX_EVENTS_BEHIND:
raise Exception("stream %s has fallen behind" % (self.NAME))
- defer.returnValue((updates, current_token))
+ return updates, current_token
def current_token(self):
"""Gets the current token of the underlying streams. Should be provided
@@ -230,13 +239,14 @@ class BackfillStream(Stream):
"""We fetched some old events and either we had never seen that event before
or it went from being an outlier to not.
"""
+
NAME = "backfill"
ROW_TYPE = BackfillStreamRow
def __init__(self, hs):
store = hs.get_datastore()
- self.current_token = store.get_current_backfill_token
- self.update_function = store.get_all_new_backfill_event_rows
+ self.current_token = store.get_current_backfill_token # type: ignore
+ self.update_function = store.get_all_new_backfill_event_rows # type: ignore
super(BackfillStream, self).__init__(hs)
@@ -250,8 +260,8 @@ class PresenceStream(Stream):
store = hs.get_datastore()
presence_handler = hs.get_presence_handler()
- self.current_token = store.get_current_presence_token
- self.update_function = presence_handler.get_all_presence_updates
+ self.current_token = store.get_current_presence_token # type: ignore
+ self.update_function = presence_handler.get_all_presence_updates # type: ignore
super(PresenceStream, self).__init__(hs)
@@ -264,8 +274,8 @@ class TypingStream(Stream):
def __init__(self, hs):
typing_handler = hs.get_typing_handler()
- self.current_token = typing_handler.get_current_token
- self.update_function = typing_handler.get_all_typing_updates
+ self.current_token = typing_handler.get_current_token # type: ignore
+ self.update_function = typing_handler.get_all_typing_updates # type: ignore
super(TypingStream, self).__init__(hs)
@@ -277,8 +287,8 @@ class ReceiptsStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
- self.current_token = store.get_max_receipt_stream_id
- self.update_function = store.get_all_updated_receipts
+ self.current_token = store.get_max_receipt_stream_id # type: ignore
+ self.update_function = store.get_all_updated_receipts # type: ignore
super(ReceiptsStream, self).__init__(hs)
@@ -286,6 +296,7 @@ class ReceiptsStream(Stream):
class PushRulesStream(Stream):
"""A user has changed their push rules
"""
+
NAME = "push_rules"
ROW_TYPE = PushRulesStreamRow
@@ -297,23 +308,23 @@ class PushRulesStream(Stream):
push_rules_token, _ = self.store.get_push_rules_stream_token()
return push_rules_token
- @defer.inlineCallbacks
- def update_function(self, from_token, to_token, limit):
- rows = yield self.store.get_all_push_rule_updates(from_token, to_token, limit)
- defer.returnValue([(row[0], row[2]) for row in rows])
+ async def update_function(self, from_token, to_token, limit):
+ rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit)
+ return [(row[0], row[2]) for row in rows]
class PushersStream(Stream):
"""A user has added/changed/removed a pusher
"""
+
NAME = "pushers"
ROW_TYPE = PushersStreamRow
def __init__(self, hs):
store = hs.get_datastore()
- self.current_token = store.get_pushers_stream_token
- self.update_function = store.get_all_updated_pushers_rows
+ self.current_token = store.get_pushers_stream_token # type: ignore
+ self.update_function = store.get_all_updated_pushers_rows # type: ignore
super(PushersStream, self).__init__(hs)
@@ -322,14 +333,15 @@ class CachesStream(Stream):
"""A cache was invalidated on the master and no other stream would invalidate
the cache on the workers
"""
+
NAME = "caches"
ROW_TYPE = CachesStreamRow
def __init__(self, hs):
store = hs.get_datastore()
- self.current_token = store.get_cache_stream_token
- self.update_function = store.get_all_updated_caches
+ self.current_token = store.get_cache_stream_token # type: ignore
+ self.update_function = store.get_all_updated_caches # type: ignore
super(CachesStream, self).__init__(hs)
@@ -337,14 +349,15 @@ class CachesStream(Stream):
class PublicRoomsStream(Stream):
"""The public rooms list changed
"""
+
NAME = "public_rooms"
ROW_TYPE = PublicRoomsStreamRow
def __init__(self, hs):
store = hs.get_datastore()
- self.current_token = store.get_current_public_room_stream_id
- self.update_function = store.get_all_new_public_rooms
+ self.current_token = store.get_current_public_room_stream_id # type: ignore
+ self.update_function = store.get_all_new_public_rooms # type: ignore
super(PublicRoomsStream, self).__init__(hs)
@@ -352,6 +365,7 @@ class PublicRoomsStream(Stream):
class DeviceListsStream(Stream):
"""Someone added/changed/removed a device
"""
+
NAME = "device_lists"
_LIMITED = False
ROW_TYPE = DeviceListsStreamRow
@@ -359,8 +373,8 @@ class DeviceListsStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
- self.current_token = store.get_device_stream_token
- self.update_function = store.get_all_device_list_changes_for_remotes
+ self.current_token = store.get_device_stream_token # type: ignore
+ self.update_function = store.get_all_device_list_changes_for_remotes # type: ignore
super(DeviceListsStream, self).__init__(hs)
@@ -368,14 +382,15 @@ class DeviceListsStream(Stream):
class ToDeviceStream(Stream):
"""New to_device messages for a client
"""
+
NAME = "to_device"
ROW_TYPE = ToDeviceStreamRow
def __init__(self, hs):
store = hs.get_datastore()
- self.current_token = store.get_to_device_stream_token
- self.update_function = store.get_all_new_device_messages
+ self.current_token = store.get_to_device_stream_token # type: ignore
+ self.update_function = store.get_all_new_device_messages # type: ignore
super(ToDeviceStream, self).__init__(hs)
@@ -383,14 +398,15 @@ class ToDeviceStream(Stream):
class TagAccountDataStream(Stream):
"""Someone added/removed a tag for a room
"""
+
NAME = "tag_account_data"
ROW_TYPE = TagAccountDataStreamRow
def __init__(self, hs):
store = hs.get_datastore()
- self.current_token = store.get_max_account_data_stream_id
- self.update_function = store.get_all_updated_tags
+ self.current_token = store.get_max_account_data_stream_id # type: ignore
+ self.update_function = store.get_all_updated_tags # type: ignore
super(TagAccountDataStream, self).__init__(hs)
@@ -398,29 +414,29 @@ class TagAccountDataStream(Stream):
class AccountDataStream(Stream):
"""Global or per room account data was changed
"""
+
NAME = "account_data"
ROW_TYPE = AccountDataStreamRow
def __init__(self, hs):
self.store = hs.get_datastore()
- self.current_token = self.store.get_max_account_data_stream_id
+ self.current_token = self.store.get_max_account_data_stream_id # type: ignore
super(AccountDataStream, self).__init__(hs)
- @defer.inlineCallbacks
- def update_function(self, from_token, to_token, limit):
- global_results, room_results = yield self.store.get_all_updated_account_data(
+ async def update_function(self, from_token, to_token, limit):
+ global_results, room_results = await self.store.get_all_updated_account_data(
from_token, from_token, to_token, limit
)
results = list(room_results)
results.extend(
- (stream_id, user_id, None, account_data_type, content,)
- for stream_id, user_id, account_data_type, content in global_results
+ (stream_id, user_id, None, account_data_type)
+ for stream_id, user_id, account_data_type in global_results
)
- defer.returnValue(results)
+ return results
class GroupServerStream(Stream):
@@ -430,7 +446,24 @@ class GroupServerStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
- self.current_token = store.get_group_stream_token
- self.update_function = store.get_all_groups_changes
+ self.current_token = store.get_group_stream_token # type: ignore
+ self.update_function = store.get_all_groups_changes # type: ignore
super(GroupServerStream, self).__init__(hs)
+
+
+class UserSignatureStream(Stream):
+ """A user has signed their own device with their user-signing key
+ """
+
+ NAME = "user_signature"
+ _LIMITED = False
+ ROW_TYPE = UserSignatureStreamRow
+
+ def __init__(self, hs):
+ store = hs.get_datastore()
+
+ self.current_token = store.get_device_stream_token # type: ignore
+ self.update_function = store.get_all_user_signature_changes_for_remotes # type: ignore
+
+ super(UserSignatureStream, self).__init__(hs)
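The streams in _base.py are now async end to end: a stream supplies `current_token()` and an async `update_function()`, and the base class's `get_updates_since()` drives them to produce `(updates, current_token)`. Below is a minimal standalone sketch of that pattern with toy data; `ToyStream` is not a real Synapse stream and omits the `_LIMITED`/`MAX_EVENTS_BEHIND` handling.

```python
# Standalone sketch of the Stream pattern after the async conversion.
import asyncio
from typing import List, Tuple


class ToyStream:
    NAME = "toy"

    def __init__(self) -> None:
        self._rows = [(1, "a"), (2, "b"), (3, "c")]

    def current_token(self) -> int:
        return self._rows[-1][0]

    async def update_function(self, from_token: int, to_token: int) -> List[tuple]:
        return [r for r in self._rows if from_token < r[0] <= to_token]

    async def get_updates_since(self, from_token: int) -> Tuple[list, int]:
        # Mirrors Stream.get_updates_since: short-circuit when up to date,
        # otherwise fetch rows and reshape them as (token, row) pairs.
        current = self.current_token()
        if from_token == current:
            return [], current
        rows = await self.update_function(from_token, current)
        return [(row[0], row[1:]) for row in rows], current


async def main() -> None:
    stream = ToyStream()
    updates, token = await stream.get_updates_since(1)
    print(updates, token)  # [(2, ('b',)), (3, ('c',))] 3


asyncio.run(main())
```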
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index f1290d022a..b3afabb8cd 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -13,12 +13,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
import heapq
+from typing import Tuple, Type
import attr
-from twisted.internet import defer
-
from ._base import Stream
@@ -52,6 +52,7 @@ data part are:
@attr.s(slots=True, frozen=True)
class EventsStreamRow(object):
"""A parsed row from the events replication stream"""
+
type = attr.ib() # str: the TypeId of one of the *EventsStreamRows
data = attr.ib() # BaseEventsStreamRow
@@ -62,7 +63,8 @@ class BaseEventsStreamRow(object):
Specifies how to identify, serialize and deserialize the different types.
"""
- TypeId = None # Unique string that ids the type. Must be overriden in sub classes.
+ # Unique string that ids the type. Must be overridden in subclasses.
+ TypeId = None # type: str
@classmethod
def from_data(cls, data):
@@ -80,11 +82,11 @@ class BaseEventsStreamRow(object):
class EventsStreamEventRow(BaseEventsStreamRow):
TypeId = "ev"
- event_id = attr.ib() # str
- room_id = attr.ib() # str
- type = attr.ib() # str
- state_key = attr.ib() # str, optional
- redacts = attr.ib() # str, optional
+ event_id = attr.ib() # str
+ room_id = attr.ib() # str
+ type = attr.ib() # str
+ state_key = attr.ib() # str, optional
+ redacts = attr.ib() # str, optional
relates_to = attr.ib() # str, optional
@@ -92,53 +94,50 @@ class EventsStreamEventRow(BaseEventsStreamRow):
class EventsStreamCurrentStateRow(BaseEventsStreamRow):
TypeId = "state"
- room_id = attr.ib() # str
- type = attr.ib() # str
+ room_id = attr.ib() # str
+ type = attr.ib() # str
state_key = attr.ib() # str
- event_id = attr.ib() # str, optional
+ event_id = attr.ib() # str, optional
-TypeToRow = {
- Row.TypeId: Row
- for Row in (
- EventsStreamEventRow,
- EventsStreamCurrentStateRow,
- )
-}
+_EventRows = (
+ EventsStreamEventRow,
+ EventsStreamCurrentStateRow,
+) # type: Tuple[Type[BaseEventsStreamRow], ...]
+
+TypeToRow = {Row.TypeId: Row for Row in _EventRows}
class EventsStream(Stream):
"""We received a new event, or an event went from being an outlier to not
"""
+
NAME = "events"
def __init__(self, hs):
self._store = hs.get_datastore()
- self.current_token = self._store.get_current_events_token
+ self.current_token = self._store.get_current_events_token # type: ignore
super(EventsStream, self).__init__(hs)
- @defer.inlineCallbacks
- def update_function(self, from_token, current_token, limit=None):
- event_rows = yield self._store.get_all_new_forward_event_rows(
- from_token, current_token, limit,
+ async def update_function(self, from_token, current_token, limit=None):
+ event_rows = await self._store.get_all_new_forward_event_rows(
+ from_token, current_token, limit
)
event_updates = (
- (row[0], EventsStreamEventRow.TypeId, row[1:])
- for row in event_rows
+ (row[0], EventsStreamEventRow.TypeId, row[1:]) for row in event_rows
)
- state_rows = yield self._store.get_all_updated_current_state_deltas(
+ state_rows = await self._store.get_all_updated_current_state_deltas(
from_token, current_token, limit
)
state_updates = (
- (row[0], EventsStreamCurrentStateRow.TypeId, row[1:])
- for row in state_rows
+ (row[0], EventsStreamCurrentStateRow.TypeId, row[1:]) for row in state_rows
)
all_updates = heapq.merge(event_updates, state_updates)
- defer.returnValue(all_updates)
+ return all_updates
@classmethod
def parse_row(cls, row):
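EventsStream rows are polymorphic: each wire row is a `(type_id, data)` pair, and `TypeToRow` maps the type id back to the right row class in `parse_row`. A hedged sketch of that dispatch is below; `from_data`'s body falls outside this diff, so constructing the row from a list is an assumption, and the toy classes are illustrative only.

```python
# Sketch of the TypeToRow dispatch used by EventsStream.parse_row.
from typing import Tuple, Type


class ToyEventRow:
    TypeId = "ev"

    def __init__(self, event_id: str, room_id: str) -> None:
        self.event_id = event_id
        self.room_id = room_id

    @classmethod
    def from_data(cls, data) -> "ToyEventRow":
        return cls(*data)


class ToyStateRow:
    TypeId = "state"

    def __init__(self, room_id: str, event_id: str) -> None:
        self.room_id = room_id
        self.event_id = event_id

    @classmethod
    def from_data(cls, data) -> "ToyStateRow":
        return cls(*data)


_ROWS = (ToyEventRow, ToyStateRow)  # type: Tuple[Type, ...]
TYPE_TO_ROW = {row.TypeId: row for row in _ROWS}


def parse_row(row):
    # row arrives as (type_id, data); the map picks the row class.
    typ, data = row
    return TYPE_TO_ROW[typ].from_data(data)


parsed = parse_row(("ev", ["$evt:example.com", "!room:example.com"]))
assert parsed.room_id == "!room:example.com"
```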
diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py
index 9aa43aa8d2..615f3dc9ac 100644
--- a/synapse/replication/tcp/streams/federation.py
+++ b/synapse/replication/tcp/streams/federation.py
@@ -17,23 +17,27 @@ from collections import namedtuple
from ._base import Stream
-FederationStreamRow = namedtuple("FederationStreamRow", (
- "type", # str, the type of data as defined in the BaseFederationRows
- "data", # dict, serialization of a federation.send_queue.BaseFederationRow
-))
+FederationStreamRow = namedtuple(
+ "FederationStreamRow",
+ (
+ "type", # str, the type of data as defined in the BaseFederationRows
+ "data", # dict, serialization of a federation.send_queue.BaseFederationRow
+ ),
+)
class FederationStream(Stream):
"""Data to be sent over federation. Only available when master has federation
sending disabled.
"""
+
NAME = "federation"
ROW_TYPE = FederationStreamRow
def __init__(self, hs):
federation_sender = hs.get_federation_sender()
- self.current_token = federation_sender.get_current_token
- self.update_function = federation_sender.get_replication_rows
+ self.current_token = federation_sender.get_current_token # type: ignore
+ self.update_function = federation_sender.get_replication_rows # type: ignore
super(FederationStream, self).__init__(hs)
|