diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index c8056b0c0c..444eb7b7f4 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -16,6 +16,7 @@
import abc
import logging
import re
+from typing import Dict, List, Tuple
from six import raise_from
from six.moves import urllib
@@ -78,9 +79,8 @@ class ReplicationEndpoint(object):
__metaclass__ = abc.ABCMeta
- NAME = abc.abstractproperty()
- PATH_ARGS = abc.abstractproperty()
-
+ NAME = abc.abstractproperty() # type: str # type: ignore
+ PATH_ARGS = abc.abstractproperty() # type: Tuple[str, ...] # type: ignore
METHOD = "POST"
CACHE = True
RETRY_ON_TIMEOUT = True
@@ -171,7 +171,7 @@ class ReplicationEndpoint(object):
# have a good idea that the request has either succeeded or failed on
# the master, and so whether we should clean up or not.
while True:
- headers = {}
+ headers = {} # type: Dict[bytes, List[bytes]]
inject_active_span_byte_dict(headers, None, check_destination=False)
try:
result = yield request_func(uri, data, headers=headers)
@@ -207,7 +207,7 @@ class ReplicationEndpoint(object):
method = self.METHOD
if self.CACHE:
- handler = self._cached_handler
+ handler = self._cached_handler # type: ignore
url_args.append("txn_id")
args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args)
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index b91a528245..f45cbd37a0 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
-from typing import Dict
+from typing import Dict, Optional
import six
@@ -41,7 +41,7 @@ class BaseSlavedStore(SQLBaseStore):
if isinstance(self.database_engine, PostgresEngine):
self._cache_id_gen = SlavedIdTracker(
db_conn, "cache_invalidation_stream", "stream_id"
- )
+ ) # type: Optional[SlavedIdTracker]
else:
self._cache_id_gen = None
@@ -62,14 +62,20 @@ class BaseSlavedStore(SQLBaseStore):
def process_replication_rows(self, stream_name, token, rows):
if stream_name == "caches":
- self._cache_id_gen.advance(token)
+ if self._cache_id_gen:
+ self._cache_id_gen.advance(token)
for row in rows:
if row.cache_func == CURRENT_STATE_CACHE_NAME:
+ if row.keys is None:
+ raise Exception(
+ "Can't send an 'invalidate all' for current state cache"
+ )
+
room_id = row.keys[0]
members_changed = set(row.keys[1:])
self._invalidate_state_caches(room_id, members_changed)
else:
- self._attempt_to_invalidate_cache(row.cache_func, tuple(row.keys))
+ self._attempt_to_invalidate_cache(row.cache_func, row.keys)
def _invalidate_cache_and_stream(self, txn, cache_func, keys):
txn.call_after(cache_func.invalidate, keys)
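
The `row.keys is None` check above leans on a convention this diff makes explicit further down (see the `CachesStreamRow` attrs class in `synapse/replication/tcp/streams/_base.py`): `keys=None` means "invalidate the whole cache", which is not expressible for the current-state cache. A rough sketch of the two row shapes, with illustrative values:

```python
# Illustrative only; CachesStreamRow is defined later in this diff.
targeted = CachesStreamRow(
    cache_func="get_rooms_for_user",
    keys=["@alice:example.com"],  # invalidate a single entry
    invalidation_ts=1580000000000,
)
wholesale = CachesStreamRow(
    cache_func="get_rooms_for_user",
    keys=None,  # invalidate everything; rejected above for current state
    invalidation_ts=1580000000000,
)
```
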
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index 29f35b9915..3aa6cb8b96 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -152,7 +152,7 @@ class SlavedEventStore(
if etype == EventTypes.Member:
self._membership_stream_cache.entity_has_changed(state_key, stream_ordering)
- self.get_invited_rooms_for_user.invalidate((state_key,))
+ self.get_invited_rooms_for_local_user.invalidate((state_key,))
if relates_to:
self.get_relations_for_event.invalidate_many((relates_to,))
diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py
index f552e7c972..ad8f0c15a9 100644
--- a/synapse/replication/slave/storage/presence.py
+++ b/synapse/replication/slave/storage/presence.py
@@ -29,7 +29,7 @@ class SlavedPresenceStore(BaseSlavedStore):
self._presence_on_startup = self._get_active_presence(db_conn)
- self.presence_stream_cache = self.presence_stream_cache = StreamChangeCache(
+ self.presence_stream_cache = StreamChangeCache(
"PresenceStreamChangeCache", self._presence_id_gen.get_current_token()
)
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index bbcb84646c..fc06a7b053 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -16,7 +16,7 @@
"""
import logging
-from typing import Dict
+from typing import Dict, List, Optional
from twisted.internet import defer
from twisted.internet.protocol import ReconnectingClientFactory
@@ -28,6 +28,7 @@ from synapse.replication.tcp.protocol import (
)
from .commands import (
+ Command,
FederationAckCommand,
InvalidateCacheCommand,
RemovePusherCommand,
@@ -89,15 +90,15 @@ class ReplicationClientHandler(AbstractReplicationClientHandler):
# Any pending commands to be sent once a new connection has been
# established
- self.pending_commands = []
+ self.pending_commands = [] # type: List[Command]
# Map from string -> deferred, to wake up when receiving a SYNC with
# the given string.
# Used for tests.
- self.awaiting_syncs = {}
+ self.awaiting_syncs = {} # type: Dict[str, defer.Deferred]
# The factory used to create connections.
- self.factory = None
+ self.factory = None # type: Optional[ReplicationClientFactory]
def start_replication(self, hs):
"""Helper method to start a replication connection to the remote server
@@ -109,7 +110,7 @@ class ReplicationClientHandler(AbstractReplicationClientHandler):
port = hs.config.worker_replication_port
hs.get_reactor().connectTCP(host, port, self.factory)
- def on_rdata(self, stream_name, token, rows):
+ async def on_rdata(self, stream_name, token, rows):
"""Called to handle a batch of replication data with a given stream token.
By default this just pokes the slave store. Can be overridden in subclasses to
@@ -120,20 +121,17 @@ class ReplicationClientHandler(AbstractReplicationClientHandler):
token (int): stream token for this batch of rows
rows (list): a list of Stream.ROW_TYPE objects as returned by
Stream.parse_row.
-
- Returns:
- Deferred|None
"""
logger.debug("Received rdata %s -> %s", stream_name, token)
- return self.store.process_replication_rows(stream_name, token, rows)
+ self.store.process_replication_rows(stream_name, token, rows)
- def on_position(self, stream_name, token):
+ async def on_position(self, stream_name, token):
"""Called when we get new position data. By default this just pokes
the slave store.
Can be overridden in subclasses to handle more.
"""
- return self.store.process_replication_rows(stream_name, token, [])
+ self.store.process_replication_rows(stream_name, token, [])
def on_sync(self, data):
"""When we received a SYNC we wake up any deferreds that were waiting
@@ -145,6 +143,9 @@ class ReplicationClientHandler(AbstractReplicationClientHandler):
if d:
d.callback(data)
+ def on_remote_server_up(self, server: str):
+ """Called when get a new REMOTE_SERVER_UP command."""
+
def get_streams_to_replicate(self) -> Dict[str, int]:
"""Called when a new connection has been established and we need to
subscribe to streams.
@@ -235,4 +236,5 @@ class ReplicationClientHandler(AbstractReplicationClientHandler):
# We don't reset the delay any earlier as otherwise if there is a
# problem during start up we'll end up tight looping connecting to the
# server.
- self.factory.resetDelay()
+ if self.factory:
+ self.factory.resetDelay()
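
Since `on_rdata` and `on_position` are now coroutines, anything overriding them should be `async` as well, because the protocol awaits the result. A hypothetical override, just to show the shape:

```python
class ExampleReplicationHandler(ReplicationClientHandler):
    # Hypothetical subclass: handlers are coroutines now that the
    # protocol awaits them.
    async def on_rdata(self, stream_name, token, rows):
        await super().on_rdata(stream_name, token, rows)
        # per-stream post-processing would go here
```
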
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index 0ff2a7199f..451671412d 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -20,15 +20,16 @@ allowed to be sent by which side.
import logging
import platform
+from typing import Tuple, Type
if platform.python_implementation() == "PyPy":
import json
_json_encoder = json.JSONEncoder()
else:
- import simplejson as json
+ import simplejson as json # type: ignore[no-redef] # noqa: F821
- _json_encoder = json.JSONEncoder(namedtuple_as_object=False)
+ _json_encoder = json.JSONEncoder(namedtuple_as_object=False) # type: ignore[call-arg] # noqa: F821
logger = logging.getLogger(__name__)
@@ -44,7 +45,7 @@ class Command(object):
The default implementation creates a command of form `<NAME> <data>`
"""
- NAME = None
+ NAME = None # type: str
def __init__(self, data):
self.data = data
@@ -386,25 +387,39 @@ class UserIpCommand(Command):
)
+class RemoteServerUpCommand(Command):
+ """Sent when a worker has detected that a remote server is no longer
+ "down" and retry timings should be reset.
+
+ If sent from a client the server will relay to all other workers.
+
+ Format::
+
+ REMOTE_SERVER_UP <server>
+ """
+
+ NAME = "REMOTE_SERVER_UP"
+
+
+_COMMANDS = (
+ ServerCommand,
+ RdataCommand,
+ PositionCommand,
+ ErrorCommand,
+ PingCommand,
+ NameCommand,
+ ReplicateCommand,
+ UserSyncCommand,
+ FederationAckCommand,
+ SyncCommand,
+ RemovePusherCommand,
+ InvalidateCacheCommand,
+ UserIpCommand,
+ RemoteServerUpCommand,
+) # type: Tuple[Type[Command], ...]
+
# Map of command name to command type.
-COMMAND_MAP = {
- cmd.NAME: cmd
- for cmd in (
- ServerCommand,
- RdataCommand,
- PositionCommand,
- ErrorCommand,
- PingCommand,
- NameCommand,
- ReplicateCommand,
- UserSyncCommand,
- FederationAckCommand,
- SyncCommand,
- RemovePusherCommand,
- InvalidateCacheCommand,
- UserIpCommand,
- )
-}
+COMMAND_MAP = {cmd.NAME: cmd for cmd in _COMMANDS}
# The commands the server is allowed to send
VALID_SERVER_COMMANDS = (
@@ -414,6 +429,7 @@ VALID_SERVER_COMMANDS = (
ErrorCommand.NAME,
PingCommand.NAME,
SyncCommand.NAME,
+ RemoteServerUpCommand.NAME,
)
# The commands the client is allowed to send
@@ -427,4 +443,5 @@ VALID_CLIENT_COMMANDS = (
InvalidateCacheCommand.NAME,
UserIpCommand.NAME,
ErrorCommand.NAME,
+ RemoteServerUpCommand.NAME,
)
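
The new command carries only the server name as its payload, so the default `to_line`/`from_line` helpers on the `Command` base class handle serialisation. A quick round-trip sketch, assuming those defaults:

```python
# Round-trip sketch using the base Command helpers (data is the raw
# argument string after the command name on the wire).
cmd = RemoteServerUpCommand("other.example.com")
wire = "%s %s" % (cmd.NAME, cmd.to_line())
# wire == "REMOTE_SERVER_UP other.example.com"

parsed = COMMAND_MAP["REMOTE_SERVER_UP"].from_line("other.example.com")
assert parsed.data == "other.example.com"
```
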
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index afaf002fe6..131e5acb09 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -53,6 +53,7 @@ import fcntl
import logging
import struct
from collections import defaultdict
+from typing import Any, DefaultDict, Dict, List, Set, Tuple
from six import iteritems, iterkeys
@@ -65,24 +66,26 @@ from twisted.python.failure import Failure
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.util import Clock
-from synapse.util.stringutils import random_string
-
-from .commands import (
+from synapse.replication.tcp.commands import (
COMMAND_MAP,
VALID_CLIENT_COMMANDS,
VALID_SERVER_COMMANDS,
+ Command,
ErrorCommand,
NameCommand,
PingCommand,
PositionCommand,
RdataCommand,
+ RemoteServerUpCommand,
ReplicateCommand,
ServerCommand,
SyncCommand,
UserSyncCommand,
)
-from .streams import STREAMS_MAP
+from synapse.replication.tcp.streams import STREAMS_MAP
+from synapse.types import Collection
+from synapse.util import Clock
+from synapse.util.stringutils import random_string
connection_close_counter = Counter(
"synapse_replication_tcp_protocol_close_reason", "", ["reason_type"]
@@ -124,8 +127,11 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
delimiter = b"\n"
- VALID_INBOUND_COMMANDS = [] # Valid commands we expect to receive
- VALID_OUTBOUND_COMMANDS = [] # Valid commans we can send
+ # Valid commands we expect to receive
+ VALID_INBOUND_COMMANDS = [] # type: Collection[str]
+
+ # Valid commands we can send
+ VALID_OUTBOUND_COMMANDS = [] # type: Collection[str]
max_line_buffer = 10000
@@ -144,13 +150,13 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
self.conn_id = random_string(5) # To dedupe in case of name clashes.
# List of pending commands to send once we've established the connection
- self.pending_commands = []
+ self.pending_commands = [] # type: List[Command]
# The LoopingCall for sending pings.
self._send_ping_loop = None
- self.inbound_commands_counter = defaultdict(int)
- self.outbound_commands_counter = defaultdict(int)
+ self.inbound_commands_counter = defaultdict(int) # type: DefaultDict[str, int]
+ self.outbound_commands_counter = defaultdict(int) # type: DefaultDict[str, int]
def connectionMade(self):
logger.info("[%s] Connection established", self.id())
@@ -235,19 +241,16 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
"replication-" + cmd.get_logcontext_id(), self.handle_command, cmd
)
- def handle_command(self, cmd):
+ async def handle_command(self, cmd: Command):
"""Handle a command we have received over the replication stream.
- By default delegates to on_<COMMAND>
+ By default delegates to on_<COMMAND>, which should return an awaitable.
Args:
- cmd (synapse.replication.tcp.commands.Command): received command
-
- Returns:
- Deferred
+ cmd: received command
"""
handler = getattr(self, "on_%s" % (cmd.NAME,))
- return handler(cmd)
+ await handler(cmd)
def close(self):
logger.warning("[%s] Closing connection", self.id())
@@ -320,10 +323,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
for cmd in pending:
self.send_command(cmd)
- def on_PING(self, line):
+ async def on_PING(self, line):
self.received_ping = True
- def on_ERROR(self, cmd):
+ async def on_ERROR(self, cmd):
logger.error("[%s] Remote reported error: %r", self.id(), cmd.data)
def pauseProducing(self):
@@ -409,30 +412,30 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.streamer = streamer
# The streams the client has subscribed to and is up to date with
- self.replication_streams = set()
+ self.replication_streams = set() # type: Set[str]
# The streams the client is currently subscribing to.
- self.connecting_streams = set()
+ self.connecting_streams = set() # type: Set[str]
# Map from stream name to list of updates to send once we've finished
# subscribing the client to the stream.
- self.pending_rdata = {}
+ self.pending_rdata = {} # type: Dict[str, List[Tuple[int, Any]]]
def connectionMade(self):
self.send_command(ServerCommand(self.server_name))
BaseReplicationStreamProtocol.connectionMade(self)
self.streamer.new_connection(self)
- def on_NAME(self, cmd):
+ async def on_NAME(self, cmd):
logger.info("[%s] Renamed to %r", self.id(), cmd.data)
self.name = cmd.data
- def on_USER_SYNC(self, cmd):
- return self.streamer.on_user_sync(
+ async def on_USER_SYNC(self, cmd):
+ await self.streamer.on_user_sync(
self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
)
- def on_REPLICATE(self, cmd):
+ async def on_REPLICATE(self, cmd):
stream_name = cmd.stream_name
token = cmd.token
@@ -443,23 +446,26 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
for stream in iterkeys(self.streamer.streams_by_name)
]
- return make_deferred_yieldable(
+ await make_deferred_yieldable(
defer.gatherResults(deferreds, consumeErrors=True)
)
else:
- return self.subscribe_to_stream(stream_name, token)
+ await self.subscribe_to_stream(stream_name, token)
- def on_FEDERATION_ACK(self, cmd):
- return self.streamer.federation_ack(cmd.token)
+ async def on_FEDERATION_ACK(self, cmd):
+ self.streamer.federation_ack(cmd.token)
- def on_REMOVE_PUSHER(self, cmd):
- return self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id)
+ async def on_REMOVE_PUSHER(self, cmd):
+ await self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id)
- def on_INVALIDATE_CACHE(self, cmd):
- return self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
+ async def on_INVALIDATE_CACHE(self, cmd):
+ self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
- def on_USER_IP(self, cmd):
- return self.streamer.on_user_ip(
+ async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
+ self.streamer.on_remote_server_up(cmd.data)
+
+ async def on_USER_IP(self, cmd):
+ self.streamer.on_user_ip(
cmd.user_id,
cmd.access_token,
cmd.ip,
@@ -468,8 +474,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
cmd.last_seen,
)
- @defer.inlineCallbacks
- def subscribe_to_stream(self, stream_name, token):
+ async def subscribe_to_stream(self, stream_name, token):
"""Subscribe the remote to a stream.
This involves checking if they've missed anything and sending those
@@ -481,7 +486,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
try:
# Get missing updates
- updates, current_token = yield self.streamer.get_stream_updates(
+ updates, current_token = await self.streamer.get_stream_updates(
stream_name, token
)
@@ -554,6 +559,9 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
def send_sync(self, data):
self.send_command(SyncCommand(data))
+ def send_remote_server_up(self, server: str):
+ self.send_command(RemoteServerUpCommand(server))
+
def on_connection_closed(self):
BaseReplicationStreamProtocol.on_connection_closed(self)
self.streamer.lost_connection(self)
@@ -566,7 +574,7 @@ class AbstractReplicationClientHandler(metaclass=abc.ABCMeta):
"""
@abc.abstractmethod
- def on_rdata(self, stream_name, token, rows):
+ async def on_rdata(self, stream_name, token, rows):
"""Called to handle a batch of replication data with a given stream token.
Args:
@@ -574,14 +582,11 @@ class AbstractReplicationClientHandler(metaclass=abc.ABCMeta):
token (int): stream token for this batch of rows
rows (list): a list of Stream.ROW_TYPE objects as returned by
Stream.parse_row.
-
- Returns:
- Deferred|None
"""
raise NotImplementedError()
@abc.abstractmethod
- def on_position(self, stream_name, token):
+ async def on_position(self, stream_name, token):
"""Called when we get new position data."""
raise NotImplementedError()
@@ -591,6 +596,11 @@ class AbstractReplicationClientHandler(metaclass=abc.ABCMeta):
raise NotImplementedError()
@abc.abstractmethod
+ async def on_remote_server_up(self, server: str):
+ """Called when get a new REMOTE_SERVER_UP command."""
+ raise NotImplementedError()
+
+ @abc.abstractmethod
def get_streams_to_replicate(self):
"""Called when a new connection has been established and we need to
subscribe to streams.
@@ -642,11 +652,11 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
# Set of stream names that have been subscribe to, but haven't yet
# caught up with. This is used to track when the client has been fully
# connected to the remote.
- self.streams_connecting = set()
+ self.streams_connecting = set() # type: Set[str]
# Map of stream to batched updates. See RdataCommand for info on how
# batching works.
- self.pending_batches = {}
+ self.pending_batches = {} # type: Dict[str, Any]
def connectionMade(self):
self.send_command(NameCommand(self.client_name))
@@ -670,12 +680,12 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
if not self.streams_connecting:
self.handler.finished_connecting()
- def on_SERVER(self, cmd):
+ async def on_SERVER(self, cmd):
if cmd.data != self.server_name:
logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
self.send_error("Wrong remote")
- def on_RDATA(self, cmd):
+ async def on_RDATA(self, cmd):
stream_name = cmd.stream_name
inbound_rdata_count.labels(stream_name).inc()
@@ -695,19 +705,22 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
# Check if this is the last of a batch of updates
rows = self.pending_batches.pop(stream_name, [])
rows.append(row)
- return self.handler.on_rdata(stream_name, cmd.token, rows)
+ await self.handler.on_rdata(stream_name, cmd.token, rows)
- def on_POSITION(self, cmd):
+ async def on_POSITION(self, cmd):
# When we get a `POSITION` command it means we've finished getting
# missing updates for the given stream, and are now up to date.
self.streams_connecting.discard(cmd.stream_name)
if not self.streams_connecting:
self.handler.finished_connecting()
- return self.handler.on_position(cmd.stream_name, cmd.token)
+ await self.handler.on_position(cmd.stream_name, cmd.token)
+
+ async def on_SYNC(self, cmd):
+ self.handler.on_sync(cmd.data)
- def on_SYNC(self, cmd):
- return self.handler.on_sync(cmd.data)
+ async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
+ self.handler.on_remote_server_up(cmd.data)
def replicate(self, stream_name, token):
"""Send the subscription request to the server
@@ -766,7 +779,7 @@ def transport_kernel_read_buffer_size(protocol, read=True):
op = SIOCINQ
else:
op = SIOCOUTQ
- size = struct.unpack("I", fcntl.ioctl(fileno, op, "\0\0\0\0"))[0]
+ size = struct.unpack("I", fcntl.ioctl(fileno, op, b"\0\0\0\0"))[0]
return size
return 0
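
All of the `on_*` handlers above are now coroutines because `handle_command` resolves them by name and awaits the result. Any new command added to the protocol follows the same convention; a hypothetical handler, for illustration:

```python
class ExampleServerProtocol(ServerReplicationStreamProtocol):
    # Hypothetical: a command named EXAMPLE would be dispatched here by
    # handle_command via getattr(self, "on_EXAMPLE"), then awaited.
    async def on_EXAMPLE(self, cmd):
        logger.info("EXAMPLE command with data %r", cmd.data)
```
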
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index d1e98428bc..6ebf944f66 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -17,12 +17,12 @@
import logging
import random
+from typing import List
from six import itervalues
from prometheus_client import Counter
-from twisted.internet import defer
from twisted.internet.protocol import Factory
from synapse.metrics import LaterGauge
@@ -79,7 +79,7 @@ class ReplicationStreamer(object):
self._replication_torture_level = hs.config.replication_torture_level
# Current connections.
- self.connections = []
+ self.connections = [] # type: List[ServerReplicationStreamProtocol]
LaterGauge(
"synapse_replication_tcp_resource_total_connections",
@@ -120,6 +120,7 @@ class ReplicationStreamer(object):
self.federation_sender = hs.get_federation_sender()
self.notifier.add_replication_callback(self.on_notifier_poke)
+ self.notifier.add_remote_server_up_callback(self.send_remote_server_up)
# Keeps track of whether we are currently checking for updates
self.is_looping = False
@@ -154,8 +155,7 @@ class ReplicationStreamer(object):
run_as_background_process("replication_notifier", self._run_notifier_loop)
- @defer.inlineCallbacks
- def _run_notifier_loop(self):
+ async def _run_notifier_loop(self):
self.is_looping = True
try:
@@ -184,7 +184,7 @@ class ReplicationStreamer(object):
continue
if self._replication_torture_level:
- yield self.clock.sleep(
+ await self.clock.sleep(
self._replication_torture_level / 1000.0
)
@@ -195,7 +195,7 @@ class ReplicationStreamer(object):
stream.upto_token,
)
try:
- updates, current_token = yield stream.get_updates()
+ updates, current_token = await stream.get_updates()
except Exception:
logger.info("Failed to handle stream %s", stream.NAME)
raise
@@ -232,7 +232,7 @@ class ReplicationStreamer(object):
self.is_looping = False
@measure_func("repl.get_stream_updates")
- def get_stream_updates(self, stream_name, token):
+ async def get_stream_updates(self, stream_name, token):
"""For a given stream get all updates since token. This is called when
a client first subscribes to a stream.
"""
@@ -240,7 +240,7 @@ class ReplicationStreamer(object):
if not stream:
raise Exception("unknown stream %s", stream_name)
- return stream.get_updates_since(token)
+ return await stream.get_updates_since(token)
@measure_func("repl.federation_ack")
def federation_ack(self, token):
@@ -251,22 +251,20 @@ class ReplicationStreamer(object):
self.federation_sender.federation_ack(token)
@measure_func("repl.on_user_sync")
- @defer.inlineCallbacks
- def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms):
+ async def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms):
"""A client has started/stopped syncing on a worker.
"""
user_sync_counter.inc()
- yield self.presence_handler.update_external_syncs_row(
+ await self.presence_handler.update_external_syncs_row(
conn_id, user_id, is_syncing, last_sync_ms
)
@measure_func("repl.on_remove_pusher")
- @defer.inlineCallbacks
- def on_remove_pusher(self, app_id, push_key, user_id):
+ async def on_remove_pusher(self, app_id, push_key, user_id):
"""A client has asked us to remove a pusher
"""
remove_pusher_counter.inc()
- yield self.store.delete_pusher_by_app_id_pushkey_user_id(
+ await self.store.delete_pusher_by_app_id_pushkey_user_id(
app_id=app_id, pushkey=push_key, user_id=user_id
)
@@ -280,15 +278,24 @@ class ReplicationStreamer(object):
getattr(self.store, cache_func).invalidate(tuple(keys))
@measure_func("repl.on_user_ip")
- @defer.inlineCallbacks
- def on_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen):
+ async def on_user_ip(
+ self, user_id, access_token, ip, user_agent, device_id, last_seen
+ ):
"""The client saw a user request
"""
user_ip_cache_counter.inc()
- yield self.store.insert_client_ip(
+ await self.store.insert_client_ip(
user_id, access_token, ip, user_agent, device_id, last_seen
)
- yield self._server_notices_sender.on_user_ip(user_id)
+ await self._server_notices_sender.on_user_ip(user_id)
+
+ @measure_func("repl.on_remote_server_up")
+ def on_remote_server_up(self, server: str):
+ self.notifier.notify_remote_server_up(server)
+
+ def send_remote_server_up(self, server: str):
+ for conn in self.connections:
+ conn.send_remote_server_up(server)
def send_sync_to_all_connections(self, data):
"""Sends a SYNC command to all clients.
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 8512923eae..a8d568b14a 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -14,12 +14,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
import itertools
import logging
from collections import namedtuple
+from typing import Any, List, Optional
-from twisted.internet import defer
+import attr
logger = logging.getLogger(__name__)
@@ -67,10 +67,24 @@ PushersStreamRow = namedtuple(
"PushersStreamRow",
("user_id", "app_id", "pushkey", "deleted"), # str # str # str # bool
)
-CachesStreamRow = namedtuple(
- "CachesStreamRow",
- ("cache_func", "keys", "invalidation_ts"), # str # list(str) # int
-)
+
+
+@attr.s
+class CachesStreamRow:
+ """Stream to inform workers they should invalidate their cache.
+
+ Attributes:
+ cache_func: Name of the cached function.
+ keys: The entry in the cache to invalidate. If None then the
+ whole cache will be invalidated.
+ invalidation_ts: Timestamp of when the invalidation took place.
+ """
+
+ cache_func = attr.ib(type=str)
+ keys = attr.ib(type=Optional[List[Any]])
+ invalidation_ts = attr.ib(type=int)
+
+
PublicRoomsStreamRow = namedtuple(
"PublicRoomsStreamRow",
(
@@ -104,8 +118,9 @@ class Stream(object):
time it was called up until the point `advance_current_token` was called.
"""
- NAME = None # The name of the stream
- ROW_TYPE = None # The type of the row. Used by the default impl of parse_row.
+ NAME = None # type: str # The name of the stream
+ # The type of the row. Used by the default impl of parse_row.
+ ROW_TYPE = None # type: Any
_LIMITED = True # Whether the update function takes a limit
@classmethod
@@ -143,8 +158,7 @@ class Stream(object):
self.upto_token = self.current_token()
self.last_token = self.upto_token
- @defer.inlineCallbacks
- def get_updates(self):
+ async def get_updates(self):
"""Gets all updates since the last time this function was called (or
since the stream was constructed if it hadn't been called before),
until the `upto_token`
@@ -155,13 +169,12 @@ class Stream(object):
list of ``(token, row)`` entries. ``row`` will be json-serialised and
sent over the replication steam.
"""
- updates, current_token = yield self.get_updates_since(self.last_token)
+ updates, current_token = await self.get_updates_since(self.last_token)
self.last_token = current_token
return updates, current_token
- @defer.inlineCallbacks
- def get_updates_since(self, from_token):
+ async def get_updates_since(self, from_token):
"""Like get_updates except allows specifying from when we should
stream updates
@@ -181,15 +194,16 @@ class Stream(object):
if from_token == current_token:
return [], current_token
+ logger.debug("get_updates_since: %s", self.__class__)
if self._LIMITED:
- rows = yield self.update_function(
+ rows = await self.update_function(
from_token, current_token, limit=MAX_EVENTS_BEHIND + 1
)
# never turn more than MAX_EVENTS_BEHIND + 1 into updates.
rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1)
else:
- rows = yield self.update_function(from_token, current_token)
+ rows = await self.update_function(from_token, current_token)
updates = [(row[0], row[1:]) for row in rows]
@@ -231,8 +245,8 @@ class BackfillStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
- self.current_token = store.get_current_backfill_token
- self.update_function = store.get_all_new_backfill_event_rows
+ self.current_token = store.get_current_backfill_token # type: ignore
+ self.update_function = store.get_all_new_backfill_event_rows # type: ignore
super(BackfillStream, self).__init__(hs)
@@ -246,8 +260,8 @@ class PresenceStream(Stream):
store = hs.get_datastore()
presence_handler = hs.get_presence_handler()
- self.current_token = store.get_current_presence_token
- self.update_function = presence_handler.get_all_presence_updates
+ self.current_token = store.get_current_presence_token # type: ignore
+ self.update_function = presence_handler.get_all_presence_updates # type: ignore
super(PresenceStream, self).__init__(hs)
@@ -260,8 +274,8 @@ class TypingStream(Stream):
def __init__(self, hs):
typing_handler = hs.get_typing_handler()
- self.current_token = typing_handler.get_current_token
- self.update_function = typing_handler.get_all_typing_updates
+ self.current_token = typing_handler.get_current_token # type: ignore
+ self.update_function = typing_handler.get_all_typing_updates # type: ignore
super(TypingStream, self).__init__(hs)
@@ -273,8 +287,8 @@ class ReceiptsStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
- self.current_token = store.get_max_receipt_stream_id
- self.update_function = store.get_all_updated_receipts
+ self.current_token = store.get_max_receipt_stream_id # type: ignore
+ self.update_function = store.get_all_updated_receipts # type: ignore
super(ReceiptsStream, self).__init__(hs)
@@ -294,9 +308,8 @@ class PushRulesStream(Stream):
push_rules_token, _ = self.store.get_push_rules_stream_token()
return push_rules_token
- @defer.inlineCallbacks
- def update_function(self, from_token, to_token, limit):
- rows = yield self.store.get_all_push_rule_updates(from_token, to_token, limit)
+ async def update_function(self, from_token, to_token, limit):
+ rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit)
return [(row[0], row[2]) for row in rows]
@@ -310,8 +323,8 @@ class PushersStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
- self.current_token = store.get_pushers_stream_token
- self.update_function = store.get_all_updated_pushers_rows
+ self.current_token = store.get_pushers_stream_token # type: ignore
+ self.update_function = store.get_all_updated_pushers_rows # type: ignore
super(PushersStream, self).__init__(hs)
@@ -327,8 +340,8 @@ class CachesStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
- self.current_token = store.get_cache_stream_token
- self.update_function = store.get_all_updated_caches
+ self.current_token = store.get_cache_stream_token # type: ignore
+ self.update_function = store.get_all_updated_caches # type: ignore
super(CachesStream, self).__init__(hs)
@@ -343,8 +356,8 @@ class PublicRoomsStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
- self.current_token = store.get_current_public_room_stream_id
- self.update_function = store.get_all_new_public_rooms
+ self.current_token = store.get_current_public_room_stream_id # type: ignore
+ self.update_function = store.get_all_new_public_rooms # type: ignore
super(PublicRoomsStream, self).__init__(hs)
@@ -360,8 +373,8 @@ class DeviceListsStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
- self.current_token = store.get_device_stream_token
- self.update_function = store.get_all_device_list_changes_for_remotes
+ self.current_token = store.get_device_stream_token # type: ignore
+ self.update_function = store.get_all_device_list_changes_for_remotes # type: ignore
super(DeviceListsStream, self).__init__(hs)
@@ -376,8 +389,8 @@ class ToDeviceStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
- self.current_token = store.get_to_device_stream_token
- self.update_function = store.get_all_new_device_messages
+ self.current_token = store.get_to_device_stream_token # type: ignore
+ self.update_function = store.get_all_new_device_messages # type: ignore
super(ToDeviceStream, self).__init__(hs)
@@ -392,8 +405,8 @@ class TagAccountDataStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
- self.current_token = store.get_max_account_data_stream_id
- self.update_function = store.get_all_updated_tags
+ self.current_token = store.get_max_account_data_stream_id # type: ignore
+ self.update_function = store.get_all_updated_tags # type: ignore
super(TagAccountDataStream, self).__init__(hs)
@@ -408,13 +421,12 @@ class AccountDataStream(Stream):
def __init__(self, hs):
self.store = hs.get_datastore()
- self.current_token = self.store.get_max_account_data_stream_id
+ self.current_token = self.store.get_max_account_data_stream_id # type: ignore
super(AccountDataStream, self).__init__(hs)
- @defer.inlineCallbacks
- def update_function(self, from_token, to_token, limit):
- global_results, room_results = yield self.store.get_all_updated_account_data(
+ async def update_function(self, from_token, to_token, limit):
+ global_results, room_results = await self.store.get_all_updated_account_data(
from_token, from_token, to_token, limit
)
@@ -434,8 +446,8 @@ class GroupServerStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
- self.current_token = store.get_group_stream_token
- self.update_function = store.get_all_groups_changes
+ self.current_token = store.get_group_stream_token # type: ignore
+ self.update_function = store.get_all_groups_changes # type: ignore
super(GroupServerStream, self).__init__(hs)
@@ -451,7 +463,7 @@ class UserSignatureStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
- self.current_token = store.get_device_stream_token
- self.update_function = store.get_all_user_signature_changes_for_remotes
+ self.current_token = store.get_device_stream_token # type: ignore
+ self.update_function = store.get_all_user_signature_changes_for_remotes # type: ignore
super(UserSignatureStream, self).__init__(hs)
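
The recurring `# type: ignore` comments in this file exist because each stream subclass assigns bound store methods over attributes that the base `Stream` class cannot type precisely. A condensed sketch of the pattern, with hypothetical names:

```python
class ExampleStream(Stream):
    """Hypothetical stream showing the subclass pattern used above."""

    NAME = "example"
    ROW_TYPE = ExampleStreamRow  # hypothetical row type

    def __init__(self, hs):
        store = hs.get_datastore()
        # Bound methods are assigned over the base-class attributes;
        # mypy cannot reconcile this, hence the ignores.
        self.current_token = store.get_example_stream_token  # type: ignore
        self.update_function = store.get_all_example_changes  # type: ignore
        super(ExampleStream, self).__init__(hs)
```
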
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index d97669c886..b3afabb8cd 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -13,12 +13,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
import heapq
+from typing import Tuple, Type
import attr
-from twisted.internet import defer
-
from ._base import Stream
@@ -63,7 +63,8 @@ class BaseEventsStreamRow(object):
Specifies how to identify, serialize and deserialize the different types.
"""
- TypeId = None # Unique string that ids the type. Must be overriden in sub classes.
+ # Unique string that ids the type. Must be overridden in subclasses.
+ TypeId = None # type: str
@classmethod
def from_data(cls, data):
@@ -99,9 +100,12 @@ class EventsStreamCurrentStateRow(BaseEventsStreamRow):
event_id = attr.ib() # str, optional
-TypeToRow = {
- Row.TypeId: Row for Row in (EventsStreamEventRow, EventsStreamCurrentStateRow)
-}
+_EventRows = (
+ EventsStreamEventRow,
+ EventsStreamCurrentStateRow,
+) # type: Tuple[Type[BaseEventsStreamRow], ...]
+
+TypeToRow = {Row.TypeId: Row for Row in _EventRows}
class EventsStream(Stream):
@@ -112,20 +116,19 @@ class EventsStream(Stream):
def __init__(self, hs):
self._store = hs.get_datastore()
- self.current_token = self._store.get_current_events_token
+ self.current_token = self._store.get_current_events_token # type: ignore
super(EventsStream, self).__init__(hs)
- @defer.inlineCallbacks
- def update_function(self, from_token, current_token, limit=None):
- event_rows = yield self._store.get_all_new_forward_event_rows(
+ async def update_function(self, from_token, current_token, limit=None):
+ event_rows = await self._store.get_all_new_forward_event_rows(
from_token, current_token, limit
)
event_updates = (
(row[0], EventsStreamEventRow.TypeId, row[1:]) for row in event_rows
)
- state_rows = yield self._store.get_all_updated_current_state_deltas(
+ state_rows = await self._store.get_all_updated_current_state_deltas(
from_token, current_token, limit
)
state_updates = (
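
`TypeToRow` gives the events stream a way to multiplex heterogeneous row types over one stream: each wire row is a `(type_id, data)` pair, and the map picks the concrete class to deserialise with. A short sketch, assuming the `TypeId` values defined on the row classes in this module:

```python
def parse_events_stream_row(type_id, data):
    # e.g. type_id "ev" -> EventsStreamEventRow,
    #      type_id "state" -> EventsStreamCurrentStateRow
    row_cls = TypeToRow[type_id]
    return row_cls.from_data(data)
```
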
diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py
index dc2484109d..615f3dc9ac 100644
--- a/synapse/replication/tcp/streams/federation.py
+++ b/synapse/replication/tcp/streams/federation.py
@@ -37,7 +37,7 @@ class FederationStream(Stream):
def __init__(self, hs):
federation_sender = hs.get_federation_sender()
- self.current_token = federation_sender.get_current_token
- self.update_function = federation_sender.get_replication_rows
+ self.current_token = federation_sender.get_current_token # type: ignore
+ self.update_function = federation_sender.get_replication_rows # type: ignore
super(FederationStream, self).__init__(hs)
|