diff options
Diffstat (limited to 'synapse/replication/tcp')
-rw-r--r-- | synapse/replication/tcp/client.py | 15 | ||||
-rw-r--r-- | synapse/replication/tcp/commands.py | 28 | ||||
-rw-r--r-- | synapse/replication/tcp/protocol.py | 141 | ||||
-rw-r--r-- | synapse/replication/tcp/resource.py | 66 | ||||
-rw-r--r-- | synapse/replication/tcp/streams.py | 25 |
5 files changed, 168 insertions, 107 deletions
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 6d2513c4e2..e592ab57bf 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -15,17 +15,20 @@ """A replication client for use by synapse workers. """ -from twisted.internet import reactor, defer +import logging + +from twisted.internet import defer from twisted.internet.protocol import ReconnectingClientFactory from .commands import ( - FederationAckCommand, UserSyncCommand, RemovePusherCommand, InvalidateCacheCommand, + FederationAckCommand, + InvalidateCacheCommand, + RemovePusherCommand, UserIpCommand, + UserSyncCommand, ) from .protocol import ClientReplicationStreamProtocol -import logging - logger = logging.getLogger(__name__) @@ -44,7 +47,7 @@ class ReplicationClientFactory(ReconnectingClientFactory): self.server_name = hs.config.server_name self._clock = hs.get_clock() # As self.clock is defined in super class - reactor.addSystemEventTrigger("before", "shutdown", self.stopTrying) + hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.stopTrying) def startedConnecting(self, connector): logger.info("Connecting to replication: %r", connector.getDestination()) @@ -95,7 +98,7 @@ class ReplicationClientHandler(object): factory = ReplicationClientFactory(hs, client_name, self) host = hs.config.worker_replication_host port = hs.config.worker_replication_port - reactor.connectTCP(host, port, factory) + hs.get_reactor().connectTCP(host, port, factory) def on_rdata(self, stream_name, token, rows): """Called when we get new replication data. By default this just pokes diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index a009214e43..f3908df642 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -19,8 +19,14 @@ allowed to be sent by which side. """ import logging -import ujson as json +import platform +if platform.python_implementation() == "PyPy": + import json + _json_encoder = json.JSONEncoder() +else: + import simplejson as json + _json_encoder = json.JSONEncoder(namedtuple_as_object=False) logger = logging.getLogger(__name__) @@ -107,7 +113,7 @@ class RdataCommand(Command): return " ".join(( self.stream_name, str(self.token) if self.token is not None else "batch", - json.dumps(self.row), + _json_encoder.encode(self.row), )) @@ -301,7 +307,9 @@ class InvalidateCacheCommand(Command): return cls(cache_func, json.loads(keys_json)) def to_line(self): - return " ".join((self.cache_func, json.dumps(self.keys))) + return " ".join(( + self.cache_func, _json_encoder.encode(self.keys), + )) class UserIpCommand(Command): @@ -323,14 +331,18 @@ class UserIpCommand(Command): @classmethod def from_line(cls, line): - user_id, access_token, ip, device_id, last_seen, user_agent = line.split(" ", 5) + user_id, jsn = line.split(" ", 1) + + access_token, ip, user_agent, device_id, last_seen = json.loads(jsn) - return cls(user_id, access_token, ip, user_agent, device_id, int(last_seen)) + return cls( + user_id, access_token, ip, user_agent, device_id, last_seen + ) def to_line(self): - return " ".join(( - self.user_id, self.access_token, self.ip, self.device_id, - str(self.last_seen), self.user_agent, + return self.user_id + " " + _json_encoder.encode(( + self.access_token, self.ip, self.user_agent, self.device_id, + self.last_seen, )) diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index 062272f8dd..dec5ac0913 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -49,32 +49,40 @@ indicate which side is sending, these are *not* included on the wire:: * connection closed by server * """ +import fcntl +import logging +import struct +from collections import defaultdict + +from six import iteritems, iterkeys + +from prometheus_client import Counter + from twisted.internet import defer from twisted.protocols.basic import LineOnlyReceiver from twisted.python.failure import Failure -from commands import ( - COMMAND_MAP, VALID_CLIENT_COMMANDS, VALID_SERVER_COMMANDS, - ErrorCommand, ServerCommand, RdataCommand, PositionCommand, PingCommand, - NameCommand, ReplicateCommand, UserSyncCommand, SyncCommand, -) -from streams import STREAMS_MAP - +from synapse.metrics import LaterGauge from synapse.util.stringutils import random_string -from synapse.metrics.metric import CounterMetric - -import logging -import synapse.metrics -import struct -import fcntl - -metrics = synapse.metrics.get_metrics_for(__name__) - -connection_close_counter = metrics.register_counter( - "close_reason", labels=["reason_type"], +from .commands import ( + COMMAND_MAP, + VALID_CLIENT_COMMANDS, + VALID_SERVER_COMMANDS, + ErrorCommand, + NameCommand, + PingCommand, + PositionCommand, + RdataCommand, + ReplicateCommand, + ServerCommand, + SyncCommand, + UserSyncCommand, ) +from .streams import STREAMS_MAP +connection_close_counter = Counter( + "synapse_replication_tcp_protocol_close_reason", "", ["reason_type"]) # A list of all connected protocols. This allows us to send metrics about the # connections. @@ -136,12 +144,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): # The LoopingCall for sending pings. self._send_ping_loop = None - self.inbound_commands_counter = CounterMetric( - "inbound_commands", labels=["command"], - ) - self.outbound_commands_counter = CounterMetric( - "outbound_commands", labels=["command"], - ) + self.inbound_commands_counter = defaultdict(int) + self.outbound_commands_counter = defaultdict(int) def connectionMade(self): logger.info("[%s] Connection established", self.id()) @@ -201,7 +205,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): self.last_received_command = self.clock.time_msec() - self.inbound_commands_counter.inc(cmd_name) + self.inbound_commands_counter[cmd_name] = ( + self.inbound_commands_counter[cmd_name] + 1) cmd_cls = COMMAND_MAP[cmd_name] try: @@ -244,15 +249,15 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): becoming full. """ if self.state == ConnectionStates.CLOSED: - logger.info("[%s] Not sending, connection closed", self.id()) + logger.debug("[%s] Not sending, connection closed", self.id()) return if do_buffer and self.state != ConnectionStates.ESTABLISHED: self._queue_command(cmd) return - self.outbound_commands_counter.inc(cmd.NAME) - + self.outbound_commands_counter[cmd.NAME] = ( + self.outbound_commands_counter[cmd.NAME] + 1) string = "%s %s" % (cmd.NAME, cmd.to_line(),) if "\n" in string: raise Exception("Unexpected newline in command: %r", string) @@ -264,7 +269,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): def _queue_command(self, cmd): """Queue the command until the connection is ready to write to again. """ - logger.info("[%s] Queing as conn %r, cmd: %r", self.id(), self.state, cmd) + logger.debug("[%s] Queing as conn %r, cmd: %r", self.id(), self.state, cmd) self.pending_commands.append(cmd) if len(self.pending_commands) > self.max_line_buffer: @@ -317,9 +322,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): def connectionLost(self, reason): logger.info("[%s] Replication connection closed: %r", self.id(), reason) if isinstance(reason, Failure): - connection_close_counter.inc(reason.type.__name__) + connection_close_counter.labels(reason.type.__name__).inc() else: - connection_close_counter.inc(reason.__class__.__name__) + connection_close_counter.labels(reason.__class__.__name__).inc() try: # Remove us from list of connections to be monitored @@ -392,7 +397,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): if stream_name == "ALL": # Subscribe to all streams we're publishing to. - for stream in self.streamer.streams_by_name.iterkeys(): + for stream in iterkeys(self.streamer.streams_by_name): self.subscribe_to_stream(stream, token) else: self.subscribe_to_stream(stream_name, token) @@ -498,7 +503,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): BaseReplicationStreamProtocol.connectionMade(self) # Once we've connected subscribe to the necessary streams - for stream_name, token in self.handler.get_streams_to_replicate().iteritems(): + for stream_name, token in iteritems(self.handler.get_streams_to_replicate()): self.replicate(stream_name, token) # Tell the server if we have any users currently syncing (should only @@ -517,25 +522,28 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): self.send_error("Wrong remote") def on_RDATA(self, cmd): + stream_name = cmd.stream_name + inbound_rdata_count.labels(stream_name).inc() + try: - row = STREAMS_MAP[cmd.stream_name].ROW_TYPE(*cmd.row) + row = STREAMS_MAP[stream_name].ROW_TYPE(*cmd.row) except Exception: logger.exception( "[%s] Failed to parse RDATA: %r %r", - self.id(), cmd.stream_name, cmd.row + self.id(), stream_name, cmd.row ) raise if cmd.token is None: # I.e. this is part of a batch of updates for this stream. Batch # until we get an update for the stream with a non None token - self.pending_batches.setdefault(cmd.stream_name, []).append(row) + self.pending_batches.setdefault(stream_name, []).append(row) else: # Check if this is the last of a batch of updates - rows = self.pending_batches.pop(cmd.stream_name, []) + rows = self.pending_batches.pop(stream_name, []) rows.append(row) - self.handler.on_rdata(cmd.stream_name, cmd.token, rows) + self.handler.on_rdata(stream_name, cmd.token, rows) def on_POSITION(self, cmd): self.handler.on_position(cmd.stream_name, cmd.token) @@ -563,13 +571,13 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): # The following simply registers metrics for the replication connections -metrics.register_callback( - "pending_commands", +pending_commands = LaterGauge( + "synapse_replication_tcp_protocol_pending_commands", + "", + ["name", "conn_id"], lambda: { - (p.name, p.conn_id): len(p.pending_commands) - for p in connected_connections + (p.name, p.conn_id): len(p.pending_commands) for p in connected_connections }, - labels=["name", "conn_id"], ) @@ -580,13 +588,13 @@ def transport_buffer_size(protocol): return 0 -metrics.register_callback( - "transport_send_buffer", +transport_send_buffer = LaterGauge( + "synapse_replication_tcp_protocol_transport_send_buffer", + "", + ["name", "conn_id"], lambda: { - (p.name, p.conn_id): transport_buffer_size(p) - for p in connected_connections + (p.name, p.conn_id): transport_buffer_size(p) for p in connected_connections }, - labels=["name", "conn_id"], ) @@ -605,42 +613,51 @@ def transport_kernel_read_buffer_size(protocol, read=True): return 0 -metrics.register_callback( - "transport_kernel_send_buffer", +tcp_transport_kernel_send_buffer = LaterGauge( + "synapse_replication_tcp_protocol_transport_kernel_send_buffer", + "", + ["name", "conn_id"], lambda: { (p.name, p.conn_id): transport_kernel_read_buffer_size(p, False) for p in connected_connections }, - labels=["name", "conn_id"], ) -metrics.register_callback( - "transport_kernel_read_buffer", +tcp_transport_kernel_read_buffer = LaterGauge( + "synapse_replication_tcp_protocol_transport_kernel_read_buffer", + "", + ["name", "conn_id"], lambda: { (p.name, p.conn_id): transport_kernel_read_buffer_size(p, True) for p in connected_connections }, - labels=["name", "conn_id"], ) -metrics.register_callback( - "inbound_commands", +tcp_inbound_commands = LaterGauge( + "synapse_replication_tcp_protocol_inbound_commands", + "", + ["command", "name", "conn_id"], lambda: { (k[0], p.name, p.conn_id): count for p in connected_connections - for k, count in p.inbound_commands_counter.counts.iteritems() + for k, count in iteritems(p.inbound_commands_counter) }, - labels=["command", "name", "conn_id"], ) -metrics.register_callback( - "outbound_commands", +tcp_outbound_commands = LaterGauge( + "synapse_replication_tcp_protocol_outbound_commands", + "", + ["command", "name", "conn_id"], lambda: { (k[0], p.name, p.conn_id): count for p in connected_connections - for k, count in p.outbound_commands_counter.counts.iteritems() + for k, count in iteritems(p.outbound_commands_counter) }, - labels=["command", "name", "conn_id"], +) + +# number of updates received for each RDATA stream +inbound_rdata_count = Counter( + "synapse_replication_tcp_protocol_inbound_rdata_count", "", ["stream_name"] ) diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index 3ea3ca5a6f..611fb66e1d 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -15,27 +15,29 @@ """The server side of the replication stream. """ -from twisted.internet import defer, reactor -from twisted.internet.protocol import Factory +import logging -from streams import STREAMS_MAP, FederationStream -from protocol import ServerReplicationStreamProtocol +from six import itervalues -from synapse.util.metrics import Measure, measure_func +from prometheus_client import Counter -import logging -import synapse.metrics +from twisted.internet import defer +from twisted.internet.protocol import Factory +from synapse.metrics import LaterGauge +from synapse.util.metrics import Measure, measure_func + +from .protocol import ServerReplicationStreamProtocol +from .streams import STREAMS_MAP, FederationStream -metrics = synapse.metrics.get_metrics_for(__name__) -stream_updates_counter = metrics.register_counter( - "stream_updates", labels=["stream_name"] -) -user_sync_counter = metrics.register_counter("user_sync") -federation_ack_counter = metrics.register_counter("federation_ack") -remove_pusher_counter = metrics.register_counter("remove_pusher") -invalidate_cache_counter = metrics.register_counter("invalidate_cache") -user_ip_cache_counter = metrics.register_counter("user_ip_cache") +stream_updates_counter = Counter("synapse_replication_tcp_resource_stream_updates", + "", ["stream_name"]) +user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "") +federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "") +remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "") +invalidate_cache_counter = Counter("synapse_replication_tcp_resource_invalidate_cache", + "") +user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "") logger = logging.getLogger(__name__) @@ -69,33 +71,34 @@ class ReplicationStreamer(object): self.presence_handler = hs.get_presence_handler() self.clock = hs.get_clock() self.notifier = hs.get_notifier() + self._server_notices_sender = hs.get_server_notices_sender() # Current connections. self.connections = [] - metrics.register_callback("total_connections", lambda: len(self.connections)) + LaterGauge("synapse_replication_tcp_resource_total_connections", "", [], + lambda: len(self.connections)) # List of streams that clients can subscribe to. # We only support federation stream if federation sending hase been # disabled on the master. self.streams = [ - stream(hs) for stream in STREAMS_MAP.itervalues() + stream(hs) for stream in itervalues(STREAMS_MAP) if stream != FederationStream or not hs.config.send_federation ] self.streams_by_name = {stream.NAME: stream for stream in self.streams} - metrics.register_callback( - "connections_per_stream", + LaterGauge( + "synapse_replication_tcp_resource_connections_per_stream", "", + ["stream_name"], lambda: { (stream_name,): len([ conn for conn in self.connections if stream_name in conn.replication_streams ]) for stream_name in self.streams_by_name - }, - labels=["stream_name"], - ) + }) self.federation_sender = None if not hs.config.send_federation: @@ -107,7 +110,7 @@ class ReplicationStreamer(object): self.is_looping = False self.pending_updates = False - reactor.addSystemEventTrigger("before", "shutdown", self.on_shutdown) + hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.on_shutdown) def on_shutdown(self): # close all connections on shutdown @@ -160,7 +163,11 @@ class ReplicationStreamer(object): "Getting stream: %s: %s -> %s", stream.NAME, stream.last_token, stream.upto_token ) - updates, current_token = yield stream.get_updates() + try: + updates, current_token = yield stream.get_updates() + except Exception: + logger.info("Failed to handle stream %s", stream.NAME) + raise logger.debug( "Sending %d updates to %d connections", @@ -171,7 +178,7 @@ class ReplicationStreamer(object): logger.info( "Streaming: %s -> %s", stream.NAME, updates[-1][0] ) - stream_updates_counter.inc_by(len(updates), stream.NAME) + stream_updates_counter.labels(stream.NAME).inc(len(updates)) # Some streams return multiple rows with the same stream IDs, # we need to make sure they get sent out in batches. We do @@ -212,11 +219,12 @@ class ReplicationStreamer(object): self.federation_sender.federation_ack(token) @measure_func("repl.on_user_sync") + @defer.inlineCallbacks def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms): """A client has started/stopped syncing on a worker. """ user_sync_counter.inc() - self.presence_handler.update_external_syncs_row( + yield self.presence_handler.update_external_syncs_row( conn_id, user_id, is_syncing, last_sync_ms, ) @@ -240,13 +248,15 @@ class ReplicationStreamer(object): getattr(self.store, cache_func).invalidate(tuple(keys)) @measure_func("repl.on_user_ip") + @defer.inlineCallbacks def on_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen): """The client saw a user request """ user_ip_cache_counter.inc() - self.store.insert_client_ip( + yield self.store.insert_client_ip( user_id, access_token, ip, user_agent, device_id, last_seen, ) + yield self._server_notices_sender.on_user_ip(user_id) def send_sync_to_all_connections(self, data): """Sends a SYNC command to all clients. diff --git a/synapse/replication/tcp/streams.py b/synapse/replication/tcp/streams.py index fbafe12cc2..55fe701c5c 100644 --- a/synapse/replication/tcp/streams.py +++ b/synapse/replication/tcp/streams.py @@ -24,11 +24,10 @@ Each stream is defined by the following information: update_function: The function that returns a list of updates between two tokens """ -from twisted.internet import defer -from collections import namedtuple - import logging +from collections import namedtuple +from twisted.internet import defer logger = logging.getLogger(__name__) @@ -118,6 +117,12 @@ CurrentStateDeltaStreamRow = namedtuple("CurrentStateDeltaStream", ( "state_key", # str "event_id", # str, optional )) +GroupsStreamRow = namedtuple("GroupsStreamRow", ( + "group_id", # str + "user_id", # str + "type", # str + "content", # dict +)) class Stream(object): @@ -464,6 +469,19 @@ class CurrentStateDeltaStream(Stream): super(CurrentStateDeltaStream, self).__init__(hs) +class GroupServerStream(Stream): + NAME = "groups" + ROW_TYPE = GroupsStreamRow + + def __init__(self, hs): + store = hs.get_datastore() + + self.current_token = store.get_group_stream_token + self.update_function = store.get_all_groups_changes + + super(GroupServerStream, self).__init__(hs) + + STREAMS_MAP = { stream.NAME: stream for stream in ( @@ -482,5 +500,6 @@ STREAMS_MAP = { TagAccountDataStream, AccountDataStream, CurrentStateDeltaStream, + GroupServerStream, ) } |