diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index f6a38f5140..ce9d1fae12 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -17,12 +17,12 @@
import logging
import random
+from typing import Any, List
from six import itervalues
from prometheus_client import Counter
-from twisted.internet import defer
from twisted.internet.protocol import Factory
from synapse.metrics import LaterGauge
@@ -33,13 +33,15 @@ from .protocol import ServerReplicationStreamProtocol
from .streams import STREAMS_MAP
from .streams.federation import FederationStream
-stream_updates_counter = Counter("synapse_replication_tcp_resource_stream_updates",
- "", ["stream_name"])
+stream_updates_counter = Counter(
+ "synapse_replication_tcp_resource_stream_updates", "", ["stream_name"]
+)
user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "")
federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "")
remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "")
-invalidate_cache_counter = Counter("synapse_replication_tcp_resource_invalidate_cache",
- "")
+invalidate_cache_counter = Counter(
+ "synapse_replication_tcp_resource_invalidate_cache", ""
+)
user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "")
logger = logging.getLogger(__name__)
@@ -48,6 +50,7 @@ logger = logging.getLogger(__name__)
class ReplicationStreamProtocolFactory(Factory):
"""Factory for new replication connections.
"""
+
def __init__(self, hs):
self.streamer = ReplicationStreamer(hs)
self.clock = hs.get_clock()
@@ -55,9 +58,7 @@ class ReplicationStreamProtocolFactory(Factory):
def buildProtocol(self, addr):
return ServerReplicationStreamProtocol(
- self.server_name,
- self.clock,
- self.streamer,
+ self.server_name, self.clock, self.streamer
)
@@ -78,37 +79,48 @@ class ReplicationStreamer(object):
self._replication_torture_level = hs.config.replication_torture_level
# Current connections.
- self.connections = []
+ self.connections = [] # type: List[ServerReplicationStreamProtocol]
- LaterGauge("synapse_replication_tcp_resource_total_connections", "", [],
- lambda: len(self.connections))
+ LaterGauge(
+ "synapse_replication_tcp_resource_total_connections",
+ "",
+ [],
+ lambda: len(self.connections),
+ )
# List of streams that clients can subscribe to.
# We only support federation stream if federation sending has been
# disabled on the master.
self.streams = [
- stream(hs) for stream in itervalues(STREAMS_MAP)
+ stream(hs)
+ for stream in itervalues(STREAMS_MAP)
if stream != FederationStream or not hs.config.send_federation
]
self.streams_by_name = {stream.NAME: stream for stream in self.streams}
LaterGauge(
- "synapse_replication_tcp_resource_connections_per_stream", "",
+ "synapse_replication_tcp_resource_connections_per_stream",
+ "",
["stream_name"],
lambda: {
- (stream_name,): len([
- conn for conn in self.connections
- if stream_name in conn.replication_streams
- ])
+ (stream_name,): len(
+ [
+ conn
+ for conn in self.connections
+ if stream_name in conn.replication_streams
+ ]
+ )
for stream_name in self.streams_by_name
- })
+ },
+ )
self.federation_sender = None
if not hs.config.send_federation:
self.federation_sender = hs.get_federation_sender()
self.notifier.add_replication_callback(self.on_notifier_poke)
+ self.notifier.add_remote_server_up_callback(self.send_remote_server_up)
# Keeps track of whether we are currently checking for updates
self.is_looping = False
@@ -143,8 +155,7 @@ class ReplicationStreamer(object):
run_as_background_process("replication_notifier", self._run_notifier_loop)
- @defer.inlineCallbacks
- def _run_notifier_loop(self):
+ async def _run_notifier_loop(self):
self.is_looping = True
try:
@@ -173,23 +184,26 @@ class ReplicationStreamer(object):
continue
if self._replication_torture_level:
- yield self.clock.sleep(
+ await self.clock.sleep(
self._replication_torture_level / 1000.0
)
logger.debug(
"Getting stream: %s: %s -> %s",
- stream.NAME, stream.last_token, stream.upto_token
+ stream.NAME,
+ stream.last_token,
+ stream.upto_token,
)
try:
- updates, current_token = yield stream.get_updates()
+ updates, current_token = await stream.get_updates()
except Exception:
logger.info("Failed to handle stream %s", stream.NAME)
raise
logger.debug(
"Sending %d updates to %d connections",
- len(updates), len(self.connections),
+ len(updates),
+ len(self.connections),
)
if updates:
@@ -218,7 +232,7 @@ class ReplicationStreamer(object):
self.is_looping = False
@measure_func("repl.get_stream_updates")
- def get_stream_updates(self, stream_name, token):
+ async def get_stream_updates(self, stream_name, token):
"""For a given stream get all updates since token. This is called when
a client first subscribes to a stream.
"""
@@ -226,7 +240,7 @@ class ReplicationStreamer(object):
if not stream:
raise Exception("unknown stream %s", stream_name)
- return stream.get_updates_since(token)
+ return await stream.get_updates_since(token)
@measure_func("repl.federation_ack")
def federation_ack(self, token):
@@ -237,44 +251,54 @@ class ReplicationStreamer(object):
self.federation_sender.federation_ack(token)
@measure_func("repl.on_user_sync")
- @defer.inlineCallbacks
- def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms):
+ async def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms):
"""A client has started/stopped syncing on a worker.
"""
user_sync_counter.inc()
- yield self.presence_handler.update_external_syncs_row(
- conn_id, user_id, is_syncing, last_sync_ms,
+ await self.presence_handler.update_external_syncs_row(
+ conn_id, user_id, is_syncing, last_sync_ms
)
@measure_func("repl.on_remove_pusher")
- @defer.inlineCallbacks
- def on_remove_pusher(self, app_id, push_key, user_id):
+ async def on_remove_pusher(self, app_id, push_key, user_id):
"""A client has asked us to remove a pusher
"""
remove_pusher_counter.inc()
- yield self.store.delete_pusher_by_app_id_pushkey_user_id(
+ await self.store.delete_pusher_by_app_id_pushkey_user_id(
app_id=app_id, pushkey=push_key, user_id=user_id
)
self.notifier.on_new_replication_data()
@measure_func("repl.on_invalidate_cache")
- def on_invalidate_cache(self, cache_func, keys):
+ async def on_invalidate_cache(self, cache_func: str, keys: List[Any]):
"""The client has asked us to invalidate a cache
"""
invalidate_cache_counter.inc()
- getattr(self.store, cache_func).invalidate(tuple(keys))
+
+ # We invalidate the cache locally, but then also stream that to other
+ # workers.
+ await self.store.invalidate_cache_and_stream(cache_func, tuple(keys))
@measure_func("repl.on_user_ip")
- @defer.inlineCallbacks
- def on_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen):
+ async def on_user_ip(
+ self, user_id, access_token, ip, user_agent, device_id, last_seen
+ ):
"""The client saw a user request
"""
user_ip_cache_counter.inc()
- yield self.store.insert_client_ip(
- user_id, access_token, ip, user_agent, device_id, last_seen,
+ await self.store.insert_client_ip(
+ user_id, access_token, ip, user_agent, device_id, last_seen
)
- yield self._server_notices_sender.on_user_ip(user_id)
+ await self._server_notices_sender.on_user_ip(user_id)
+
+ @measure_func("repl.on_remote_server_up")
+ def on_remote_server_up(self, server: str):
+ self.notifier.notify_remote_server_up(server)
+
+ def send_remote_server_up(self, server: str):
+ for conn in self.connections:
+ conn.send_remote_server_up(server)
def send_sync_to_all_connections(self, data):
"""Sends a SYNC command to all clients.
@@ -299,7 +323,11 @@ class ReplicationStreamer(object):
# We need to tell the presence handler that the connection has been
# lost so that it can handle any ongoing syncs on that connection.
- self.presence_handler.update_external_syncs_clear(connection.conn_id)
+ run_as_background_process(
+ "update_external_syncs_clear",
+ self.presence_handler.update_external_syncs_clear,
+ connection.conn_id,
+ )
def _batch_updates(updates):
|