diff options
Diffstat (limited to 'synapse/replication/tcp')
-rw-r--r-- | synapse/replication/tcp/client.py | 21 | ||||
-rw-r--r-- | synapse/replication/tcp/commands.py | 28 | ||||
-rw-r--r-- | synapse/replication/tcp/protocol.py | 135 | ||||
-rw-r--r-- | synapse/replication/tcp/resource.py | 35 | ||||
-rw-r--r-- | synapse/replication/tcp/streams.py | 5 |
5 files changed, 143 insertions, 81 deletions
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 6d2513c4e2..cbe9645817 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -15,17 +15,20 @@ """A replication client for use by synapse workers. """ -from twisted.internet import reactor, defer +import logging + +from twisted.internet import defer from twisted.internet.protocol import ReconnectingClientFactory from .commands import ( - FederationAckCommand, UserSyncCommand, RemovePusherCommand, InvalidateCacheCommand, + FederationAckCommand, + InvalidateCacheCommand, + RemovePusherCommand, UserIpCommand, + UserSyncCommand, ) from .protocol import ClientReplicationStreamProtocol -import logging - logger = logging.getLogger(__name__) @@ -44,7 +47,7 @@ class ReplicationClientFactory(ReconnectingClientFactory): self.server_name = hs.config.server_name self._clock = hs.get_clock() # As self.clock is defined in super class - reactor.addSystemEventTrigger("before", "shutdown", self.stopTrying) + hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.stopTrying) def startedConnecting(self, connector): logger.info("Connecting to replication: %r", connector.getDestination()) @@ -95,7 +98,7 @@ class ReplicationClientHandler(object): factory = ReplicationClientFactory(hs, client_name, self) host = hs.config.worker_replication_host port = hs.config.worker_replication_port - reactor.connectTCP(host, port, factory) + hs.get_reactor().connectTCP(host, port, factory) def on_rdata(self, stream_name, token, rows): """Called when we get new replication data. By default this just pokes @@ -104,7 +107,7 @@ class ReplicationClientHandler(object): Can be overriden in subclasses to handle more. """ logger.info("Received rdata %s -> %s", stream_name, token) - self.store.process_replication_rows(stream_name, token, rows) + return self.store.process_replication_rows(stream_name, token, rows) def on_position(self, stream_name, token): """Called when we get new position data. By default this just pokes @@ -112,7 +115,7 @@ class ReplicationClientHandler(object): Can be overriden in subclasses to handle more. """ - self.store.process_replication_rows(stream_name, token, []) + return self.store.process_replication_rows(stream_name, token, []) def on_sync(self, data): """When we received a SYNC we wake up any deferreds that were waiting @@ -189,7 +192,7 @@ class ReplicationClientHandler(object): """Returns a deferred that is resolved when we receive a SYNC command with given data. - Used by tests. + [Not currently] used by tests. """ return self.awaiting_syncs.setdefault(data, defer.Deferred()) diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index 12aac3cc6b..327556f6a1 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -19,13 +19,17 @@ allowed to be sent by which side. """ import logging -import simplejson +import platform +if platform.python_implementation() == "PyPy": + import json + _json_encoder = json.JSONEncoder() +else: + import simplejson as json + _json_encoder = json.JSONEncoder(namedtuple_as_object=False) logger = logging.getLogger(__name__) -_json_encoder = simplejson.JSONEncoder(namedtuple_as_object=False) - class Command(object): """The base command class. @@ -55,6 +59,12 @@ class Command(object): """ return self.data + def get_logcontext_id(self): + """Get a suitable string for the logcontext when processing this command""" + + # by default, we just use the command name. + return self.NAME + class ServerCommand(Command): """Sent by the server on new connection and includes the server_name. @@ -102,7 +112,7 @@ class RdataCommand(Command): return cls( stream_name, None if token == "batch" else int(token), - simplejson.loads(row_json) + json.loads(row_json) ) def to_line(self): @@ -112,6 +122,9 @@ class RdataCommand(Command): _json_encoder.encode(self.row), )) + def get_logcontext_id(self): + return "RDATA-" + self.stream_name + class PositionCommand(Command): """Sent by the client to tell the client the stream postition without @@ -186,6 +199,9 @@ class ReplicateCommand(Command): def to_line(self): return " ".join((self.stream_name, str(self.token),)) + def get_logcontext_id(self): + return "REPLICATE-" + self.stream_name + class UserSyncCommand(Command): """Sent by the client to inform the server that a user has started or @@ -300,7 +316,7 @@ class InvalidateCacheCommand(Command): def from_line(cls, line): cache_func, keys_json = line.split(" ", 1) - return cls(cache_func, simplejson.loads(keys_json)) + return cls(cache_func, json.loads(keys_json)) def to_line(self): return " ".join(( @@ -329,7 +345,7 @@ class UserIpCommand(Command): def from_line(cls, line): user_id, jsn = line.split(" ", 1) - access_token, ip, user_agent, device_id, last_seen = simplejson.loads(jsn) + access_token, ip, user_agent, device_id, last_seen = json.loads(jsn) return cls( user_id, access_token, ip, user_agent, device_id, last_seen diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index c870475cd1..74e892c104 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -49,29 +49,39 @@ indicate which side is sending, these are *not* included on the wire:: * connection closed by server * """ +import fcntl +import logging +import struct +from collections import defaultdict + +from six import iteritems, iterkeys + +from prometheus_client import Counter + from twisted.internet import defer from twisted.protocols.basic import LineOnlyReceiver from twisted.python.failure import Failure -from .commands import ( - COMMAND_MAP, VALID_CLIENT_COMMANDS, VALID_SERVER_COMMANDS, - ErrorCommand, ServerCommand, RdataCommand, PositionCommand, PingCommand, - NameCommand, ReplicateCommand, UserSyncCommand, SyncCommand, -) -from .streams import STREAMS_MAP - from synapse.metrics import LaterGauge +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.util.logcontext import make_deferred_yieldable, run_in_background from synapse.util.stringutils import random_string -from prometheus_client import Counter - -from collections import defaultdict - -from six import iterkeys, iteritems - -import logging -import struct -import fcntl +from .commands import ( + COMMAND_MAP, + VALID_CLIENT_COMMANDS, + VALID_SERVER_COMMANDS, + ErrorCommand, + NameCommand, + PingCommand, + PositionCommand, + RdataCommand, + ReplicateCommand, + ServerCommand, + SyncCommand, + UserSyncCommand, +) +from .streams import STREAMS_MAP connection_close_counter = Counter( "synapse_replication_tcp_protocol_close_reason", "", ["reason_type"]) @@ -214,7 +224,11 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): # Now lets try and call on_<CMD_NAME> function try: - getattr(self, "on_%s" % (cmd_name,))(cmd) + run_as_background_process( + "replication-" + cmd.get_logcontext_id(), + getattr(self, "on_%s" % (cmd_name,)), + cmd, + ) except Exception: logger.exception("[%s] Failed to handle line: %r", self.id(), line) @@ -379,7 +393,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): self.name = cmd.data def on_USER_SYNC(self, cmd): - self.streamer.on_user_sync( + return self.streamer.on_user_sync( self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms, ) @@ -389,22 +403,33 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): if stream_name == "ALL": # Subscribe to all streams we're publishing to. - for stream in iterkeys(self.streamer.streams_by_name): - self.subscribe_to_stream(stream, token) + deferreds = [ + run_in_background( + self.subscribe_to_stream, + stream, token, + ) + for stream in iterkeys(self.streamer.streams_by_name) + ] + + return make_deferred_yieldable( + defer.gatherResults(deferreds, consumeErrors=True) + ) else: - self.subscribe_to_stream(stream_name, token) + return self.subscribe_to_stream(stream_name, token) def on_FEDERATION_ACK(self, cmd): - self.streamer.federation_ack(cmd.token) + return self.streamer.federation_ack(cmd.token) def on_REMOVE_PUSHER(self, cmd): - self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id) + return self.streamer.on_remove_pusher( + cmd.app_id, cmd.push_key, cmd.user_id, + ) def on_INVALIDATE_CACHE(self, cmd): - self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys) + return self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys) def on_USER_IP(self, cmd): - self.streamer.on_user_ip( + return self.streamer.on_user_ip( cmd.user_id, cmd.access_token, cmd.ip, cmd.user_agent, cmd.device_id, cmd.last_seen, ) @@ -534,14 +559,13 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): # Check if this is the last of a batch of updates rows = self.pending_batches.pop(stream_name, []) rows.append(row) - - self.handler.on_rdata(stream_name, cmd.token, rows) + return self.handler.on_rdata(stream_name, cmd.token, rows) def on_POSITION(self, cmd): - self.handler.on_position(cmd.stream_name, cmd.token) + return self.handler.on_position(cmd.stream_name, cmd.token) def on_SYNC(self, cmd): - self.handler.on_sync(cmd.data) + return self.handler.on_sync(cmd.data) def replicate(self, stream_name, token): """Send the subscription request to the server @@ -564,11 +588,13 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): # The following simply registers metrics for the replication connections pending_commands = LaterGauge( - "pending_commands", "", ["name", "conn_id"], + "synapse_replication_tcp_protocol_pending_commands", + "", + ["name", "conn_id"], lambda: { - (p.name, p.conn_id): len(p.pending_commands) - for p in connected_connections - }) + (p.name, p.conn_id): len(p.pending_commands) for p in connected_connections + }, +) def transport_buffer_size(protocol): @@ -579,11 +605,13 @@ def transport_buffer_size(protocol): transport_send_buffer = LaterGauge( - "synapse_replication_tcp_transport_send_buffer", "", ["name", "conn_id"], + "synapse_replication_tcp_protocol_transport_send_buffer", + "", + ["name", "conn_id"], lambda: { - (p.name, p.conn_id): transport_buffer_size(p) - for p in connected_connections - }) + (p.name, p.conn_id): transport_buffer_size(p) for p in connected_connections + }, +) def transport_kernel_read_buffer_size(protocol, read=True): @@ -602,37 +630,50 @@ def transport_kernel_read_buffer_size(protocol, read=True): tcp_transport_kernel_send_buffer = LaterGauge( - "synapse_replication_tcp_transport_kernel_send_buffer", "", ["name", "conn_id"], + "synapse_replication_tcp_protocol_transport_kernel_send_buffer", + "", + ["name", "conn_id"], lambda: { (p.name, p.conn_id): transport_kernel_read_buffer_size(p, False) for p in connected_connections - }) + }, +) tcp_transport_kernel_read_buffer = LaterGauge( - "synapse_replication_tcp_transport_kernel_read_buffer", "", ["name", "conn_id"], + "synapse_replication_tcp_protocol_transport_kernel_read_buffer", + "", + ["name", "conn_id"], lambda: { (p.name, p.conn_id): transport_kernel_read_buffer_size(p, True) for p in connected_connections - }) + }, +) tcp_inbound_commands = LaterGauge( - "synapse_replication_tcp_inbound_commands", "", ["command", "name", "conn_id"], + "synapse_replication_tcp_protocol_inbound_commands", + "", + ["command", "name", "conn_id"], lambda: { (k[0], p.name, p.conn_id): count for p in connected_connections for k, count in iteritems(p.inbound_commands_counter) - }) + }, +) tcp_outbound_commands = LaterGauge( - "synapse_replication_tcp_outbound_commands", "", ["command", "name", "conn_id"], + "synapse_replication_tcp_protocol_outbound_commands", + "", + ["command", "name", "conn_id"], lambda: { (k[0], p.name, p.conn_id): count for p in connected_connections for k, count in iteritems(p.outbound_commands_counter) - }) + }, +) # number of updates received for each RDATA stream -inbound_rdata_count = Counter("synapse_replication_tcp_inbound_rdata_count", "", - ["stream_name"]) +inbound_rdata_count = Counter( + "synapse_replication_tcp_protocol_inbound_rdata_count", "", ["stream_name"] +) diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index 63bd6d2652..fd59f1595f 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -15,19 +15,21 @@ """The server side of the replication stream. """ -from twisted.internet import defer, reactor -from twisted.internet.protocol import Factory +import logging -from .streams import STREAMS_MAP, FederationStream -from .protocol import ServerReplicationStreamProtocol +from six import itervalues -from synapse.util.metrics import Measure, measure_func -from synapse.metrics import LaterGauge +from prometheus_client import Counter -import logging +from twisted.internet import defer +from twisted.internet.protocol import Factory -from prometheus_client import Counter -from six import itervalues +from synapse.metrics import LaterGauge +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.util.metrics import Measure, measure_func + +from .protocol import ServerReplicationStreamProtocol +from .streams import STREAMS_MAP, FederationStream stream_updates_counter = Counter("synapse_replication_tcp_resource_stream_updates", "", ["stream_name"]) @@ -109,14 +111,13 @@ class ReplicationStreamer(object): self.is_looping = False self.pending_updates = False - reactor.addSystemEventTrigger("before", "shutdown", self.on_shutdown) + hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.on_shutdown) def on_shutdown(self): # close all connections on shutdown for conn in self.connections: conn.send_error("server shutting down") - @defer.inlineCallbacks def on_notifier_poke(self): """Checks if there is actually any new data and sends it to the connections if there are. @@ -131,14 +132,16 @@ class ReplicationStreamer(object): stream.discard_updates_and_advance() return - # If we're in the process of checking for new updates, mark that fact - # and return + self.pending_updates = True + if self.is_looping: - logger.debug("Noitifier poke loop already running") - self.pending_updates = True + logger.debug("Notifier poke loop already running") return - self.pending_updates = True + run_as_background_process("replication_notifier", self._run_notifier_loop) + + @defer.inlineCallbacks + def _run_notifier_loop(self): self.is_looping = True try: diff --git a/synapse/replication/tcp/streams.py b/synapse/replication/tcp/streams.py index 4c60bf79f9..55fe701c5c 100644 --- a/synapse/replication/tcp/streams.py +++ b/synapse/replication/tcp/streams.py @@ -24,11 +24,10 @@ Each stream is defined by the following information: update_function: The function that returns a list of updates between two tokens """ -from twisted.internet import defer -from collections import namedtuple - import logging +from collections import namedtuple +from twisted.internet import defer logger = logging.getLogger(__name__) |