diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 6d2513c4e2..e592ab57bf 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -15,17 +15,20 @@
"""A replication client for use by synapse workers.
"""
-from twisted.internet import reactor, defer
+import logging
+
+from twisted.internet import defer
from twisted.internet.protocol import ReconnectingClientFactory
from .commands import (
- FederationAckCommand, UserSyncCommand, RemovePusherCommand, InvalidateCacheCommand,
+ FederationAckCommand,
+ InvalidateCacheCommand,
+ RemovePusherCommand,
UserIpCommand,
+ UserSyncCommand,
)
from .protocol import ClientReplicationStreamProtocol
-import logging
-
logger = logging.getLogger(__name__)
@@ -44,7 +47,7 @@ class ReplicationClientFactory(ReconnectingClientFactory):
self.server_name = hs.config.server_name
self._clock = hs.get_clock() # As self.clock is defined in super class
- reactor.addSystemEventTrigger("before", "shutdown", self.stopTrying)
+ hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.stopTrying)
def startedConnecting(self, connector):
logger.info("Connecting to replication: %r", connector.getDestination())
@@ -95,7 +98,7 @@ class ReplicationClientHandler(object):
factory = ReplicationClientFactory(hs, client_name, self)
host = hs.config.worker_replication_host
port = hs.config.worker_replication_port
- reactor.connectTCP(host, port, factory)
+ hs.get_reactor().connectTCP(host, port, factory)
def on_rdata(self, stream_name, token, rows):
"""Called when we get new replication data. By default this just pokes
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index a009214e43..f3908df642 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -19,8 +19,14 @@ allowed to be sent by which side.
"""
import logging
-import ujson as json
+import platform
+if platform.python_implementation() == "PyPy":
+ import json
+ _json_encoder = json.JSONEncoder()
+else:
+ import simplejson as json
+ _json_encoder = json.JSONEncoder(namedtuple_as_object=False)
logger = logging.getLogger(__name__)
@@ -107,7 +113,7 @@ class RdataCommand(Command):
return " ".join((
self.stream_name,
str(self.token) if self.token is not None else "batch",
- json.dumps(self.row),
+ _json_encoder.encode(self.row),
))
@@ -301,7 +307,9 @@ class InvalidateCacheCommand(Command):
return cls(cache_func, json.loads(keys_json))
def to_line(self):
- return " ".join((self.cache_func, json.dumps(self.keys)))
+ return " ".join((
+ self.cache_func, _json_encoder.encode(self.keys),
+ ))
class UserIpCommand(Command):
@@ -323,14 +331,18 @@ class UserIpCommand(Command):
@classmethod
def from_line(cls, line):
- user_id, access_token, ip, device_id, last_seen, user_agent = line.split(" ", 5)
+ user_id, jsn = line.split(" ", 1)
+
+ access_token, ip, user_agent, device_id, last_seen = json.loads(jsn)
- return cls(user_id, access_token, ip, user_agent, device_id, int(last_seen))
+ return cls(
+ user_id, access_token, ip, user_agent, device_id, last_seen
+ )
def to_line(self):
- return " ".join((
- self.user_id, self.access_token, self.ip, self.device_id,
- str(self.last_seen), self.user_agent,
+ return self.user_id + " " + _json_encoder.encode((
+ self.access_token, self.ip, self.user_agent, self.device_id,
+ self.last_seen,
))
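
Two things happen in commands.py: the JSON implementation is chosen per interpreter (stdlib `json` on PyPy, where it is fast; `simplejson` elsewhere), and `UserIpCommand` switches from six space-separated fields to `user_id` plus one JSON array, which survives IPs and user agents that contain spaces. The `namedtuple_as_object=False` flag is the subtle part: stream rows are namedtuples, and simplejson would otherwise encode them as objects rather than the positional arrays the wire format expects. A quick illustration (the `Row` type is made up for the example):

    from collections import namedtuple
    import simplejson

    Row = namedtuple("Row", ("user_id", "ts"))
    row = Row("@alice:example.com", 12345)

    # Default simplejson behaviour: namedtuples become JSON objects.
    simplejson.JSONEncoder().encode(row)
    # '{"user_id": "@alice:example.com", "ts": 12345}'

    # With namedtuple_as_object=False they stay positional arrays,
    # matching what the stdlib json module produces on PyPy.
    simplejson.JSONEncoder(namedtuple_as_object=False).encode(row)
    # '["@alice:example.com", 12345]'
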
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index 062272f8dd..dec5ac0913 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -49,32 +49,40 @@ indicate which side is sending, these are *not* included on the wire::
* connection closed by server *
"""
+import fcntl
+import logging
+import struct
+from collections import defaultdict
+
+from six import iteritems, iterkeys
+
+from prometheus_client import Counter
+
from twisted.internet import defer
from twisted.protocols.basic import LineOnlyReceiver
from twisted.python.failure import Failure
-from commands import (
- COMMAND_MAP, VALID_CLIENT_COMMANDS, VALID_SERVER_COMMANDS,
- ErrorCommand, ServerCommand, RdataCommand, PositionCommand, PingCommand,
- NameCommand, ReplicateCommand, UserSyncCommand, SyncCommand,
-)
-from streams import STREAMS_MAP
-
+from synapse.metrics import LaterGauge
from synapse.util.stringutils import random_string
-from synapse.metrics.metric import CounterMetric
-
-import logging
-import synapse.metrics
-import struct
-import fcntl
-
-metrics = synapse.metrics.get_metrics_for(__name__)
-
-connection_close_counter = metrics.register_counter(
- "close_reason", labels=["reason_type"],
+from .commands import (
+ COMMAND_MAP,
+ VALID_CLIENT_COMMANDS,
+ VALID_SERVER_COMMANDS,
+ ErrorCommand,
+ NameCommand,
+ PingCommand,
+ PositionCommand,
+ RdataCommand,
+ ReplicateCommand,
+ ServerCommand,
+ SyncCommand,
+ UserSyncCommand,
)
+from .streams import STREAMS_MAP
+connection_close_counter = Counter(
+ "synapse_replication_tcp_protocol_close_reason", "", ["reason_type"])
# A list of all connected protocols. This allows us to send metrics about the
# connections.
@@ -136,12 +144,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
# The LoopingCall for sending pings.
self._send_ping_loop = None
- self.inbound_commands_counter = CounterMetric(
- "inbound_commands", labels=["command"],
- )
- self.outbound_commands_counter = CounterMetric(
- "outbound_commands", labels=["command"],
- )
+ self.inbound_commands_counter = defaultdict(int)
+ self.outbound_commands_counter = defaultdict(int)
def connectionMade(self):
logger.info("[%s] Connection established", self.id())
@@ -201,7 +205,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
self.last_received_command = self.clock.time_msec()
- self.inbound_commands_counter.inc(cmd_name)
+ self.inbound_commands_counter[cmd_name] = (
+ self.inbound_commands_counter[cmd_name] + 1)
cmd_cls = COMMAND_MAP[cmd_name]
try:
@@ -244,15 +249,15 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
becoming full.
"""
if self.state == ConnectionStates.CLOSED:
- logger.info("[%s] Not sending, connection closed", self.id())
+ logger.debug("[%s] Not sending, connection closed", self.id())
return
if do_buffer and self.state != ConnectionStates.ESTABLISHED:
self._queue_command(cmd)
return
- self.outbound_commands_counter.inc(cmd.NAME)
-
+ self.outbound_commands_counter[cmd.NAME] = (
+ self.outbound_commands_counter[cmd.NAME] + 1)
string = "%s %s" % (cmd.NAME, cmd.to_line(),)
if "\n" in string:
raise Exception("Unexpected newline in command: %r", string)
@@ -264,7 +269,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
def _queue_command(self, cmd):
"""Queue the command until the connection is ready to write to again.
"""
- logger.info("[%s] Queing as conn %r, cmd: %r", self.id(), self.state, cmd)
+        logger.debug("[%s] Queuing as conn %r, cmd: %r", self.id(), self.state, cmd)
self.pending_commands.append(cmd)
if len(self.pending_commands) > self.max_line_buffer:
@@ -317,9 +322,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
def connectionLost(self, reason):
logger.info("[%s] Replication connection closed: %r", self.id(), reason)
if isinstance(reason, Failure):
- connection_close_counter.inc(reason.type.__name__)
+ connection_close_counter.labels(reason.type.__name__).inc()
else:
- connection_close_counter.inc(reason.__class__.__name__)
+ connection_close_counter.labels(reason.__class__.__name__).inc()
try:
# Remove us from list of connections to be monitored
@@ -392,7 +397,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
if stream_name == "ALL":
# Subscribe to all streams we're publishing to.
- for stream in self.streamer.streams_by_name.iterkeys():
+ for stream in iterkeys(self.streamer.streams_by_name):
self.subscribe_to_stream(stream, token)
else:
self.subscribe_to_stream(stream_name, token)
@@ -498,7 +503,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
BaseReplicationStreamProtocol.connectionMade(self)
# Once we've connected subscribe to the necessary streams
- for stream_name, token in self.handler.get_streams_to_replicate().iteritems():
+ for stream_name, token in iteritems(self.handler.get_streams_to_replicate()):
self.replicate(stream_name, token)
# Tell the server if we have any users currently syncing (should only
@@ -517,25 +522,28 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.send_error("Wrong remote")
def on_RDATA(self, cmd):
+ stream_name = cmd.stream_name
+ inbound_rdata_count.labels(stream_name).inc()
+
try:
- row = STREAMS_MAP[cmd.stream_name].ROW_TYPE(*cmd.row)
+ row = STREAMS_MAP[stream_name].ROW_TYPE(*cmd.row)
except Exception:
logger.exception(
"[%s] Failed to parse RDATA: %r %r",
- self.id(), cmd.stream_name, cmd.row
+ self.id(), stream_name, cmd.row
)
raise
if cmd.token is None:
# I.e. this is part of a batch of updates for this stream. Batch
# until we get an update for the stream with a non None token
- self.pending_batches.setdefault(cmd.stream_name, []).append(row)
+ self.pending_batches.setdefault(stream_name, []).append(row)
else:
# Check if this is the last of a batch of updates
- rows = self.pending_batches.pop(cmd.stream_name, [])
+ rows = self.pending_batches.pop(stream_name, [])
rows.append(row)
- self.handler.on_rdata(cmd.stream_name, cmd.token, rows)
+ self.handler.on_rdata(stream_name, cmd.token, rows)
def on_POSITION(self, cmd):
self.handler.on_position(cmd.stream_name, cmd.token)
@@ -563,13 +571,13 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
# The following simply registers metrics for the replication connections
-metrics.register_callback(
- "pending_commands",
+pending_commands = LaterGauge(
+ "synapse_replication_tcp_protocol_pending_commands",
+ "",
+ ["name", "conn_id"],
lambda: {
- (p.name, p.conn_id): len(p.pending_commands)
- for p in connected_connections
+ (p.name, p.conn_id): len(p.pending_commands) for p in connected_connections
},
- labels=["name", "conn_id"],
)
@@ -580,13 +588,13 @@ def transport_buffer_size(protocol):
return 0
-metrics.register_callback(
- "transport_send_buffer",
+transport_send_buffer = LaterGauge(
+ "synapse_replication_tcp_protocol_transport_send_buffer",
+ "",
+ ["name", "conn_id"],
lambda: {
- (p.name, p.conn_id): transport_buffer_size(p)
- for p in connected_connections
+ (p.name, p.conn_id): transport_buffer_size(p) for p in connected_connections
},
- labels=["name", "conn_id"],
)
@@ -605,42 +613,51 @@ def transport_kernel_read_buffer_size(protocol, read=True):
return 0
-metrics.register_callback(
- "transport_kernel_send_buffer",
+tcp_transport_kernel_send_buffer = LaterGauge(
+ "synapse_replication_tcp_protocol_transport_kernel_send_buffer",
+ "",
+ ["name", "conn_id"],
lambda: {
(p.name, p.conn_id): transport_kernel_read_buffer_size(p, False)
for p in connected_connections
},
- labels=["name", "conn_id"],
)
-metrics.register_callback(
- "transport_kernel_read_buffer",
+tcp_transport_kernel_read_buffer = LaterGauge(
+ "synapse_replication_tcp_protocol_transport_kernel_read_buffer",
+ "",
+ ["name", "conn_id"],
lambda: {
(p.name, p.conn_id): transport_kernel_read_buffer_size(p, True)
for p in connected_connections
},
- labels=["name", "conn_id"],
)
-metrics.register_callback(
- "inbound_commands",
+tcp_inbound_commands = LaterGauge(
+ "synapse_replication_tcp_protocol_inbound_commands",
+ "",
+ ["command", "name", "conn_id"],
lambda: {
-        (k[0], p.name, p.conn_id): count
+        (k, p.name, p.conn_id): count
for p in connected_connections
- for k, count in p.inbound_commands_counter.counts.iteritems()
+ for k, count in iteritems(p.inbound_commands_counter)
},
- labels=["command", "name", "conn_id"],
)
-metrics.register_callback(
- "outbound_commands",
+tcp_outbound_commands = LaterGauge(
+ "synapse_replication_tcp_protocol_outbound_commands",
+ "",
+ ["command", "name", "conn_id"],
lambda: {
-        (k[0], p.name, p.conn_id): count
+        (k, p.name, p.conn_id): count
for p in connected_connections
- for k, count in p.outbound_commands_counter.counts.iteritems()
+ for k, count in iteritems(p.outbound_commands_counter)
},
- labels=["command", "name", "conn_id"],
+)
+
+# number of updates received for each RDATA stream
+inbound_rdata_count = Counter(
+ "synapse_replication_tcp_protocol_inbound_rdata_count", "", ["stream_name"]
)
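
The metrics migration in protocol.py follows two prometheus_client patterns: labelled `Counter`s are bumped with `.labels(value).inc()`, and `LaterGauge` (a synapse helper) evaluates a callback at scrape time rather than storing state. Note that the old per-counter keys were label tuples while the new defaultdicts are keyed by plain command strings, hence the `k[0]` -> `k` change in the gauge callbacks above. The counter half reduces to a standalone sketch like this (the metric name below is illustrative, not one of the names registered above):

    from prometheus_client import Counter, generate_latest

    close_counter = Counter(
        "example_close_reason",                 # metric name
        "Connections closed, by failure type",  # help text
        ["reason_type"],                        # label names
    )
    close_counter.labels("ConnectionDone").inc()
    close_counter.labels("ConnectionLost").inc(2)

    # Dump the default registry in exposition format; both label values appear.
    print(generate_latest().decode("ascii"))
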
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 3ea3ca5a6f..611fb66e1d 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -15,27 +15,29 @@
"""The server side of the replication stream.
"""
-from twisted.internet import defer, reactor
-from twisted.internet.protocol import Factory
+import logging
-from streams import STREAMS_MAP, FederationStream
-from protocol import ServerReplicationStreamProtocol
+from six import itervalues
-from synapse.util.metrics import Measure, measure_func
+from prometheus_client import Counter
-import logging
-import synapse.metrics
+from twisted.internet import defer
+from twisted.internet.protocol import Factory
+from synapse.metrics import LaterGauge
+from synapse.util.metrics import Measure, measure_func
+
+from .protocol import ServerReplicationStreamProtocol
+from .streams import STREAMS_MAP, FederationStream
-metrics = synapse.metrics.get_metrics_for(__name__)
-stream_updates_counter = metrics.register_counter(
- "stream_updates", labels=["stream_name"]
-)
-user_sync_counter = metrics.register_counter("user_sync")
-federation_ack_counter = metrics.register_counter("federation_ack")
-remove_pusher_counter = metrics.register_counter("remove_pusher")
-invalidate_cache_counter = metrics.register_counter("invalidate_cache")
-user_ip_cache_counter = metrics.register_counter("user_ip_cache")
+stream_updates_counter = Counter("synapse_replication_tcp_resource_stream_updates",
+ "", ["stream_name"])
+user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "")
+federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "")
+remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "")
+invalidate_cache_counter = Counter("synapse_replication_tcp_resource_invalidate_cache",
+ "")
+user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "")
logger = logging.getLogger(__name__)
@@ -69,33 +71,34 @@ class ReplicationStreamer(object):
self.presence_handler = hs.get_presence_handler()
self.clock = hs.get_clock()
self.notifier = hs.get_notifier()
+ self._server_notices_sender = hs.get_server_notices_sender()
# Current connections.
self.connections = []
- metrics.register_callback("total_connections", lambda: len(self.connections))
+ LaterGauge("synapse_replication_tcp_resource_total_connections", "", [],
+ lambda: len(self.connections))
# List of streams that clients can subscribe to.
        # We only support federation stream if federation sending has been
# disabled on the master.
self.streams = [
- stream(hs) for stream in STREAMS_MAP.itervalues()
+ stream(hs) for stream in itervalues(STREAMS_MAP)
if stream != FederationStream or not hs.config.send_federation
]
self.streams_by_name = {stream.NAME: stream for stream in self.streams}
- metrics.register_callback(
- "connections_per_stream",
+ LaterGauge(
+ "synapse_replication_tcp_resource_connections_per_stream", "",
+ ["stream_name"],
lambda: {
(stream_name,): len([
conn for conn in self.connections
if stream_name in conn.replication_streams
])
for stream_name in self.streams_by_name
- },
- labels=["stream_name"],
- )
+ })
self.federation_sender = None
if not hs.config.send_federation:
@@ -107,7 +110,7 @@ class ReplicationStreamer(object):
self.is_looping = False
self.pending_updates = False
- reactor.addSystemEventTrigger("before", "shutdown", self.on_shutdown)
+ hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.on_shutdown)
def on_shutdown(self):
# close all connections on shutdown
@@ -160,7 +163,11 @@ class ReplicationStreamer(object):
"Getting stream: %s: %s -> %s",
stream.NAME, stream.last_token, stream.upto_token
)
- updates, current_token = yield stream.get_updates()
+ try:
+ updates, current_token = yield stream.get_updates()
+ except Exception:
+ logger.info("Failed to handle stream %s", stream.NAME)
+ raise
logger.debug(
"Sending %d updates to %d connections",
@@ -171,7 +178,7 @@ class ReplicationStreamer(object):
logger.info(
"Streaming: %s -> %s", stream.NAME, updates[-1][0]
)
- stream_updates_counter.inc_by(len(updates), stream.NAME)
+ stream_updates_counter.labels(stream.NAME).inc(len(updates))
# Some streams return multiple rows with the same stream IDs,
# we need to make sure they get sent out in batches. We do
@@ -212,11 +219,12 @@ class ReplicationStreamer(object):
self.federation_sender.federation_ack(token)
@measure_func("repl.on_user_sync")
+ @defer.inlineCallbacks
def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms):
"""A client has started/stopped syncing on a worker.
"""
user_sync_counter.inc()
- self.presence_handler.update_external_syncs_row(
+ yield self.presence_handler.update_external_syncs_row(
conn_id, user_id, is_syncing, last_sync_ms,
)
@@ -240,13 +248,15 @@ class ReplicationStreamer(object):
getattr(self.store, cache_func).invalidate(tuple(keys))
@measure_func("repl.on_user_ip")
+ @defer.inlineCallbacks
def on_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen):
"""The client saw a user request
"""
user_ip_cache_counter.inc()
- self.store.insert_client_ip(
+ yield self.store.insert_client_ip(
user_id, access_token, ip, user_agent, device_id, last_seen,
)
+ yield self._server_notices_sender.on_user_ip(user_id)
def send_sync_to_all_connections(self, data):
"""Sends a SYNC command to all clients.
diff --git a/synapse/replication/tcp/streams.py b/synapse/replication/tcp/streams.py
index fbafe12cc2..55fe701c5c 100644
--- a/synapse/replication/tcp/streams.py
+++ b/synapse/replication/tcp/streams.py
@@ -24,11 +24,10 @@ Each stream is defined by the following information:
update_function: The function that returns a list of updates between two tokens
"""
-from twisted.internet import defer
-from collections import namedtuple
-
import logging
+from collections import namedtuple
+from twisted.internet import defer
logger = logging.getLogger(__name__)
@@ -118,6 +117,12 @@ CurrentStateDeltaStreamRow = namedtuple("CurrentStateDeltaStream", (
"state_key", # str
"event_id", # str, optional
))
+GroupsStreamRow = namedtuple("GroupsStreamRow", (
+ "group_id", # str
+ "user_id", # str
+ "type", # str
+ "content", # dict
+))
class Stream(object):
@@ -464,6 +469,19 @@ class CurrentStateDeltaStream(Stream):
super(CurrentStateDeltaStream, self).__init__(hs)
+class GroupServerStream(Stream):
+ NAME = "groups"
+ ROW_TYPE = GroupsStreamRow
+
+ def __init__(self, hs):
+ store = hs.get_datastore()
+
+ self.current_token = store.get_group_stream_token
+ self.update_function = store.get_all_groups_changes
+
+ super(GroupServerStream, self).__init__(hs)
+
+
STREAMS_MAP = {
stream.NAME: stream
for stream in (
@@ -482,5 +500,6 @@ STREAMS_MAP = {
TagAccountDataStream,
AccountDataStream,
CurrentStateDeltaStream,
+ GroupServerStream,
)
}
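
The groups stream follows the recipe this module already uses for every stream: declare a `ROW_TYPE` namedtuple, point `current_token` and `update_function` at the matching store methods, and list the class in `STREAMS_MAP`. On the worker side, `on_RDATA` in protocol.py then rehydrates each wire row positionally, roughly like this (the payload values are invented):

    from collections import namedtuple

    GroupsStreamRow = namedtuple("GroupsStreamRow", (
        "group_id", "user_id", "type", "content",
    ))

    # What ClientReplicationStreamProtocol.on_RDATA effectively does with a
    # decoded RDATA line for the "groups" stream:
    wire_row = ["+grp:example.com", "@alice:example.com", "membership", {"joined": True}]
    row = GroupsStreamRow(*wire_row)
    assert row.content == {"joined": True}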