diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 6d2513c4e2..cbe9645817 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -15,17 +15,20 @@
"""A replication client for use by synapse workers.
"""
-from twisted.internet import reactor, defer
+import logging
+
+from twisted.internet import defer
from twisted.internet.protocol import ReconnectingClientFactory
from .commands import (
- FederationAckCommand, UserSyncCommand, RemovePusherCommand, InvalidateCacheCommand,
+ FederationAckCommand,
+ InvalidateCacheCommand,
+ RemovePusherCommand,
UserIpCommand,
+ UserSyncCommand,
)
from .protocol import ClientReplicationStreamProtocol
-import logging
-
logger = logging.getLogger(__name__)
@@ -44,7 +47,7 @@ class ReplicationClientFactory(ReconnectingClientFactory):
self.server_name = hs.config.server_name
self._clock = hs.get_clock() # As self.clock is defined in super class
- reactor.addSystemEventTrigger("before", "shutdown", self.stopTrying)
+ hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.stopTrying)
def startedConnecting(self, connector):
logger.info("Connecting to replication: %r", connector.getDestination())
@@ -95,7 +98,7 @@ class ReplicationClientHandler(object):
factory = ReplicationClientFactory(hs, client_name, self)
host = hs.config.worker_replication_host
port = hs.config.worker_replication_port
- reactor.connectTCP(host, port, factory)
+ hs.get_reactor().connectTCP(host, port, factory)
def on_rdata(self, stream_name, token, rows):
"""Called when we get new replication data. By default this just pokes
@@ -104,7 +107,7 @@ class ReplicationClientHandler(object):
Can be overriden in subclasses to handle more.
"""
logger.info("Received rdata %s -> %s", stream_name, token)
- self.store.process_replication_rows(stream_name, token, rows)
+ return self.store.process_replication_rows(stream_name, token, rows)
def on_position(self, stream_name, token):
"""Called when we get new position data. By default this just pokes
@@ -112,7 +115,7 @@ class ReplicationClientHandler(object):
Can be overriden in subclasses to handle more.
"""
- self.store.process_replication_rows(stream_name, token, [])
+ return self.store.process_replication_rows(stream_name, token, [])
def on_sync(self, data):
"""When we received a SYNC we wake up any deferreds that were waiting
@@ -189,7 +192,7 @@ class ReplicationClientHandler(object):
"""Returns a deferred that is resolved when we receive a SYNC command
with given data.
- Used by tests.
+ [Not currently] used by tests.
"""
return self.awaiting_syncs.setdefault(data, defer.Deferred())
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index 12aac3cc6b..327556f6a1 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -19,13 +19,17 @@ allowed to be sent by which side.
"""
import logging
-import simplejson
+import platform
+if platform.python_implementation() == "PyPy":
+ import json
+ _json_encoder = json.JSONEncoder()
+else:
+ import simplejson as json
+ _json_encoder = json.JSONEncoder(namedtuple_as_object=False)
logger = logging.getLogger(__name__)
-_json_encoder = simplejson.JSONEncoder(namedtuple_as_object=False)
-
class Command(object):
"""The base command class.
@@ -55,6 +59,12 @@ class Command(object):
"""
return self.data
+ def get_logcontext_id(self):
+ """Get a suitable string for the logcontext when processing this command"""
+
+ # by default, we just use the command name.
+ return self.NAME
+
class ServerCommand(Command):
"""Sent by the server on new connection and includes the server_name.
@@ -102,7 +112,7 @@ class RdataCommand(Command):
return cls(
stream_name,
None if token == "batch" else int(token),
- simplejson.loads(row_json)
+ json.loads(row_json)
)
def to_line(self):
@@ -112,6 +122,9 @@ class RdataCommand(Command):
_json_encoder.encode(self.row),
))
+ def get_logcontext_id(self):
+ return "RDATA-" + self.stream_name
+
class PositionCommand(Command):
"""Sent by the client to tell the client the stream postition without
@@ -186,6 +199,9 @@ class ReplicateCommand(Command):
def to_line(self):
return " ".join((self.stream_name, str(self.token),))
+ def get_logcontext_id(self):
+ return "REPLICATE-" + self.stream_name
+
class UserSyncCommand(Command):
"""Sent by the client to inform the server that a user has started or
@@ -300,7 +316,7 @@ class InvalidateCacheCommand(Command):
def from_line(cls, line):
cache_func, keys_json = line.split(" ", 1)
- return cls(cache_func, simplejson.loads(keys_json))
+ return cls(cache_func, json.loads(keys_json))
def to_line(self):
return " ".join((
@@ -329,7 +345,7 @@ class UserIpCommand(Command):
def from_line(cls, line):
user_id, jsn = line.split(" ", 1)
- access_token, ip, user_agent, device_id, last_seen = simplejson.loads(jsn)
+ access_token, ip, user_agent, device_id, last_seen = json.loads(jsn)
return cls(
user_id, access_token, ip, user_agent, device_id, last_seen
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index c870475cd1..74e892c104 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -49,29 +49,39 @@ indicate which side is sending, these are *not* included on the wire::
* connection closed by server *
"""
+import fcntl
+import logging
+import struct
+from collections import defaultdict
+
+from six import iteritems, iterkeys
+
+from prometheus_client import Counter
+
from twisted.internet import defer
from twisted.protocols.basic import LineOnlyReceiver
from twisted.python.failure import Failure
-from .commands import (
- COMMAND_MAP, VALID_CLIENT_COMMANDS, VALID_SERVER_COMMANDS,
- ErrorCommand, ServerCommand, RdataCommand, PositionCommand, PingCommand,
- NameCommand, ReplicateCommand, UserSyncCommand, SyncCommand,
-)
-from .streams import STREAMS_MAP
-
from synapse.metrics import LaterGauge
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
from synapse.util.stringutils import random_string
-from prometheus_client import Counter
-
-from collections import defaultdict
-
-from six import iterkeys, iteritems
-
-import logging
-import struct
-import fcntl
+from .commands import (
+ COMMAND_MAP,
+ VALID_CLIENT_COMMANDS,
+ VALID_SERVER_COMMANDS,
+ ErrorCommand,
+ NameCommand,
+ PingCommand,
+ PositionCommand,
+ RdataCommand,
+ ReplicateCommand,
+ ServerCommand,
+ SyncCommand,
+ UserSyncCommand,
+)
+from .streams import STREAMS_MAP
connection_close_counter = Counter(
"synapse_replication_tcp_protocol_close_reason", "", ["reason_type"])
@@ -214,7 +224,11 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
# Now lets try and call on_<CMD_NAME> function
try:
- getattr(self, "on_%s" % (cmd_name,))(cmd)
+ run_as_background_process(
+ "replication-" + cmd.get_logcontext_id(),
+ getattr(self, "on_%s" % (cmd_name,)),
+ cmd,
+ )
except Exception:
logger.exception("[%s] Failed to handle line: %r", self.id(), line)
@@ -379,7 +393,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.name = cmd.data
def on_USER_SYNC(self, cmd):
- self.streamer.on_user_sync(
+ return self.streamer.on_user_sync(
self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms,
)
@@ -389,22 +403,33 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
if stream_name == "ALL":
# Subscribe to all streams we're publishing to.
- for stream in iterkeys(self.streamer.streams_by_name):
- self.subscribe_to_stream(stream, token)
+ deferreds = [
+ run_in_background(
+ self.subscribe_to_stream,
+ stream, token,
+ )
+ for stream in iterkeys(self.streamer.streams_by_name)
+ ]
+
+ return make_deferred_yieldable(
+ defer.gatherResults(deferreds, consumeErrors=True)
+ )
else:
- self.subscribe_to_stream(stream_name, token)
+ return self.subscribe_to_stream(stream_name, token)
def on_FEDERATION_ACK(self, cmd):
- self.streamer.federation_ack(cmd.token)
+ return self.streamer.federation_ack(cmd.token)
def on_REMOVE_PUSHER(self, cmd):
- self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id)
+ return self.streamer.on_remove_pusher(
+ cmd.app_id, cmd.push_key, cmd.user_id,
+ )
def on_INVALIDATE_CACHE(self, cmd):
- self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
+ return self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
def on_USER_IP(self, cmd):
- self.streamer.on_user_ip(
+ return self.streamer.on_user_ip(
cmd.user_id, cmd.access_token, cmd.ip, cmd.user_agent, cmd.device_id,
cmd.last_seen,
)
@@ -534,14 +559,13 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
# Check if this is the last of a batch of updates
rows = self.pending_batches.pop(stream_name, [])
rows.append(row)
-
- self.handler.on_rdata(stream_name, cmd.token, rows)
+ return self.handler.on_rdata(stream_name, cmd.token, rows)
def on_POSITION(self, cmd):
- self.handler.on_position(cmd.stream_name, cmd.token)
+ return self.handler.on_position(cmd.stream_name, cmd.token)
def on_SYNC(self, cmd):
- self.handler.on_sync(cmd.data)
+ return self.handler.on_sync(cmd.data)
def replicate(self, stream_name, token):
"""Send the subscription request to the server
@@ -564,11 +588,13 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
# The following simply registers metrics for the replication connections
pending_commands = LaterGauge(
- "pending_commands", "", ["name", "conn_id"],
+ "synapse_replication_tcp_protocol_pending_commands",
+ "",
+ ["name", "conn_id"],
lambda: {
- (p.name, p.conn_id): len(p.pending_commands)
- for p in connected_connections
- })
+ (p.name, p.conn_id): len(p.pending_commands) for p in connected_connections
+ },
+)
def transport_buffer_size(protocol):
@@ -579,11 +605,13 @@ def transport_buffer_size(protocol):
transport_send_buffer = LaterGauge(
- "synapse_replication_tcp_transport_send_buffer", "", ["name", "conn_id"],
+ "synapse_replication_tcp_protocol_transport_send_buffer",
+ "",
+ ["name", "conn_id"],
lambda: {
- (p.name, p.conn_id): transport_buffer_size(p)
- for p in connected_connections
- })
+ (p.name, p.conn_id): transport_buffer_size(p) for p in connected_connections
+ },
+)
def transport_kernel_read_buffer_size(protocol, read=True):
@@ -602,37 +630,50 @@ def transport_kernel_read_buffer_size(protocol, read=True):
tcp_transport_kernel_send_buffer = LaterGauge(
- "synapse_replication_tcp_transport_kernel_send_buffer", "", ["name", "conn_id"],
+ "synapse_replication_tcp_protocol_transport_kernel_send_buffer",
+ "",
+ ["name", "conn_id"],
lambda: {
(p.name, p.conn_id): transport_kernel_read_buffer_size(p, False)
for p in connected_connections
- })
+ },
+)
tcp_transport_kernel_read_buffer = LaterGauge(
- "synapse_replication_tcp_transport_kernel_read_buffer", "", ["name", "conn_id"],
+ "synapse_replication_tcp_protocol_transport_kernel_read_buffer",
+ "",
+ ["name", "conn_id"],
lambda: {
(p.name, p.conn_id): transport_kernel_read_buffer_size(p, True)
for p in connected_connections
- })
+ },
+)
tcp_inbound_commands = LaterGauge(
- "synapse_replication_tcp_inbound_commands", "", ["command", "name", "conn_id"],
+ "synapse_replication_tcp_protocol_inbound_commands",
+ "",
+ ["command", "name", "conn_id"],
lambda: {
(k[0], p.name, p.conn_id): count
for p in connected_connections
for k, count in iteritems(p.inbound_commands_counter)
- })
+ },
+)
tcp_outbound_commands = LaterGauge(
- "synapse_replication_tcp_outbound_commands", "", ["command", "name", "conn_id"],
+ "synapse_replication_tcp_protocol_outbound_commands",
+ "",
+ ["command", "name", "conn_id"],
lambda: {
(k[0], p.name, p.conn_id): count
for p in connected_connections
for k, count in iteritems(p.outbound_commands_counter)
- })
+ },
+)
# number of updates received for each RDATA stream
-inbound_rdata_count = Counter("synapse_replication_tcp_inbound_rdata_count", "",
- ["stream_name"])
+inbound_rdata_count = Counter(
+ "synapse_replication_tcp_protocol_inbound_rdata_count", "", ["stream_name"]
+)
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 63bd6d2652..fd59f1595f 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -15,19 +15,21 @@
"""The server side of the replication stream.
"""
-from twisted.internet import defer, reactor
-from twisted.internet.protocol import Factory
+import logging
-from .streams import STREAMS_MAP, FederationStream
-from .protocol import ServerReplicationStreamProtocol
+from six import itervalues
-from synapse.util.metrics import Measure, measure_func
-from synapse.metrics import LaterGauge
+from prometheus_client import Counter
-import logging
+from twisted.internet import defer
+from twisted.internet.protocol import Factory
-from prometheus_client import Counter
-from six import itervalues
+from synapse.metrics import LaterGauge
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.util.metrics import Measure, measure_func
+
+from .protocol import ServerReplicationStreamProtocol
+from .streams import STREAMS_MAP, FederationStream
stream_updates_counter = Counter("synapse_replication_tcp_resource_stream_updates",
"", ["stream_name"])
@@ -109,14 +111,13 @@ class ReplicationStreamer(object):
self.is_looping = False
self.pending_updates = False
- reactor.addSystemEventTrigger("before", "shutdown", self.on_shutdown)
+ hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.on_shutdown)
def on_shutdown(self):
# close all connections on shutdown
for conn in self.connections:
conn.send_error("server shutting down")
- @defer.inlineCallbacks
def on_notifier_poke(self):
"""Checks if there is actually any new data and sends it to the
connections if there are.
@@ -131,14 +132,16 @@ class ReplicationStreamer(object):
stream.discard_updates_and_advance()
return
- # If we're in the process of checking for new updates, mark that fact
- # and return
+ self.pending_updates = True
+
if self.is_looping:
- logger.debug("Noitifier poke loop already running")
- self.pending_updates = True
+ logger.debug("Notifier poke loop already running")
return
- self.pending_updates = True
+ run_as_background_process("replication_notifier", self._run_notifier_loop)
+
+ @defer.inlineCallbacks
+ def _run_notifier_loop(self):
self.is_looping = True
try:
diff --git a/synapse/replication/tcp/streams.py b/synapse/replication/tcp/streams.py
index 4c60bf79f9..55fe701c5c 100644
--- a/synapse/replication/tcp/streams.py
+++ b/synapse/replication/tcp/streams.py
@@ -24,11 +24,10 @@ Each stream is defined by the following information:
update_function: The function that returns a list of updates between two tokens
"""
-from twisted.internet import defer
-from collections import namedtuple
-
import logging
+from collections import namedtuple
+from twisted.internet import defer
logger = logging.getLogger(__name__)
|