summary refs log tree commit diff
path: root/synapse/replication/tcp
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication/tcp')
-rw-r--r--synapse/replication/tcp/client.py15
-rw-r--r--synapse/replication/tcp/commands.py28
-rw-r--r--synapse/replication/tcp/protocol.py141
-rw-r--r--synapse/replication/tcp/resource.py66
-rw-r--r--synapse/replication/tcp/streams.py25
5 files changed, 168 insertions, 107 deletions
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 6d2513c4e2..e592ab57bf 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -15,17 +15,20 @@
 """A replication client for use by synapse workers.
 """
 
-from twisted.internet import reactor, defer
+import logging
+
+from twisted.internet import defer
 from twisted.internet.protocol import ReconnectingClientFactory
 
 from .commands import (
-    FederationAckCommand, UserSyncCommand, RemovePusherCommand, InvalidateCacheCommand,
+    FederationAckCommand,
+    InvalidateCacheCommand,
+    RemovePusherCommand,
     UserIpCommand,
+    UserSyncCommand,
 )
 from .protocol import ClientReplicationStreamProtocol
 
-import logging
-
 logger = logging.getLogger(__name__)
 
 
@@ -44,7 +47,7 @@ class ReplicationClientFactory(ReconnectingClientFactory):
         self.server_name = hs.config.server_name
         self._clock = hs.get_clock()  # As self.clock is defined in super class
 
-        reactor.addSystemEventTrigger("before", "shutdown", self.stopTrying)
+        hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.stopTrying)
 
     def startedConnecting(self, connector):
         logger.info("Connecting to replication: %r", connector.getDestination())
@@ -95,7 +98,7 @@ class ReplicationClientHandler(object):
         factory = ReplicationClientFactory(hs, client_name, self)
         host = hs.config.worker_replication_host
         port = hs.config.worker_replication_port
-        reactor.connectTCP(host, port, factory)
+        hs.get_reactor().connectTCP(host, port, factory)
 
     def on_rdata(self, stream_name, token, rows):
         """Called when we get new replication data. By default this just pokes
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index a009214e43..f3908df642 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -19,8 +19,14 @@ allowed to be sent by which side.
 """
 
 import logging
-import ujson as json
+import platform
 
+if platform.python_implementation() == "PyPy":
+    import json
+    _json_encoder = json.JSONEncoder()
+else:
+    import simplejson as json
+    _json_encoder = json.JSONEncoder(namedtuple_as_object=False)
 
 logger = logging.getLogger(__name__)
 
@@ -107,7 +113,7 @@ class RdataCommand(Command):
         return " ".join((
             self.stream_name,
             str(self.token) if self.token is not None else "batch",
-            json.dumps(self.row),
+            _json_encoder.encode(self.row),
         ))
 
 
@@ -301,7 +307,9 @@ class InvalidateCacheCommand(Command):
         return cls(cache_func, json.loads(keys_json))
 
     def to_line(self):
-        return " ".join((self.cache_func, json.dumps(self.keys)))
+        return " ".join((
+            self.cache_func, _json_encoder.encode(self.keys),
+        ))
 
 
 class UserIpCommand(Command):
@@ -323,14 +331,18 @@ class UserIpCommand(Command):
 
     @classmethod
     def from_line(cls, line):
-        user_id, access_token, ip, device_id, last_seen, user_agent = line.split(" ", 5)
+        user_id, jsn = line.split(" ", 1)
+
+        access_token, ip, user_agent, device_id, last_seen = json.loads(jsn)
 
-        return cls(user_id, access_token, ip, user_agent, device_id, int(last_seen))
+        return cls(
+            user_id, access_token, ip, user_agent, device_id, last_seen
+        )
 
     def to_line(self):
-        return " ".join((
-            self.user_id, self.access_token, self.ip, self.device_id,
-            str(self.last_seen), self.user_agent,
+        return self.user_id + " " + _json_encoder.encode((
+            self.access_token, self.ip, self.user_agent, self.device_id,
+            self.last_seen,
         ))
 
 
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index 062272f8dd..dec5ac0913 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -49,32 +49,40 @@ indicate which side is sending, these are *not* included on the wire::
     * connection closed by server *
 """
 
+import fcntl
+import logging
+import struct
+from collections import defaultdict
+
+from six import iteritems, iterkeys
+
+from prometheus_client import Counter
+
 from twisted.internet import defer
 from twisted.protocols.basic import LineOnlyReceiver
 from twisted.python.failure import Failure
 
-from commands import (
-    COMMAND_MAP, VALID_CLIENT_COMMANDS, VALID_SERVER_COMMANDS,
-    ErrorCommand, ServerCommand, RdataCommand, PositionCommand, PingCommand,
-    NameCommand, ReplicateCommand, UserSyncCommand, SyncCommand,
-)
-from streams import STREAMS_MAP
-
+from synapse.metrics import LaterGauge
 from synapse.util.stringutils import random_string
-from synapse.metrics.metric import CounterMetric
-
-import logging
-import synapse.metrics
-import struct
-import fcntl
 
-
-metrics = synapse.metrics.get_metrics_for(__name__)
-
-connection_close_counter = metrics.register_counter(
-    "close_reason", labels=["reason_type"],
+from .commands import (
+    COMMAND_MAP,
+    VALID_CLIENT_COMMANDS,
+    VALID_SERVER_COMMANDS,
+    ErrorCommand,
+    NameCommand,
+    PingCommand,
+    PositionCommand,
+    RdataCommand,
+    ReplicateCommand,
+    ServerCommand,
+    SyncCommand,
+    UserSyncCommand,
 )
+from .streams import STREAMS_MAP
 
+connection_close_counter = Counter(
+    "synapse_replication_tcp_protocol_close_reason", "", ["reason_type"])
 
 # A list of all connected protocols. This allows us to send metrics about the
 # connections.
@@ -136,12 +144,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         # The LoopingCall for sending pings.
         self._send_ping_loop = None
 
-        self.inbound_commands_counter = CounterMetric(
-            "inbound_commands", labels=["command"],
-        )
-        self.outbound_commands_counter = CounterMetric(
-            "outbound_commands", labels=["command"],
-        )
+        self.inbound_commands_counter = defaultdict(int)
+        self.outbound_commands_counter = defaultdict(int)
 
     def connectionMade(self):
         logger.info("[%s] Connection established", self.id())
@@ -201,7 +205,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
 
         self.last_received_command = self.clock.time_msec()
 
-        self.inbound_commands_counter.inc(cmd_name)
+        self.inbound_commands_counter[cmd_name] = (
+            self.inbound_commands_counter[cmd_name] + 1)
 
         cmd_cls = COMMAND_MAP[cmd_name]
         try:
@@ -244,15 +249,15 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
                 becoming full.
         """
         if self.state == ConnectionStates.CLOSED:
-            logger.info("[%s] Not sending, connection closed", self.id())
+            logger.debug("[%s] Not sending, connection closed", self.id())
             return
 
         if do_buffer and self.state != ConnectionStates.ESTABLISHED:
             self._queue_command(cmd)
             return
 
-        self.outbound_commands_counter.inc(cmd.NAME)
-
+        self.outbound_commands_counter[cmd.NAME] = (
+            self.outbound_commands_counter[cmd.NAME] + 1)
         string = "%s %s" % (cmd.NAME, cmd.to_line(),)
         if "\n" in string:
             raise Exception("Unexpected newline in command: %r", string)
@@ -264,7 +269,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
     def _queue_command(self, cmd):
         """Queue the command until the connection is ready to write to again.
         """
-        logger.info("[%s] Queing as conn %r, cmd: %r", self.id(), self.state, cmd)
+        logger.debug("[%s] Queing as conn %r, cmd: %r", self.id(), self.state, cmd)
         self.pending_commands.append(cmd)
 
         if len(self.pending_commands) > self.max_line_buffer:
@@ -317,9 +322,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
     def connectionLost(self, reason):
         logger.info("[%s] Replication connection closed: %r", self.id(), reason)
         if isinstance(reason, Failure):
-            connection_close_counter.inc(reason.type.__name__)
+            connection_close_counter.labels(reason.type.__name__).inc()
         else:
-            connection_close_counter.inc(reason.__class__.__name__)
+            connection_close_counter.labels(reason.__class__.__name__).inc()
 
         try:
             # Remove us from list of connections to be monitored
@@ -392,7 +397,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
 
         if stream_name == "ALL":
             # Subscribe to all streams we're publishing to.
-            for stream in self.streamer.streams_by_name.iterkeys():
+            for stream in iterkeys(self.streamer.streams_by_name):
                 self.subscribe_to_stream(stream, token)
         else:
             self.subscribe_to_stream(stream_name, token)
@@ -498,7 +503,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
         BaseReplicationStreamProtocol.connectionMade(self)
 
         # Once we've connected subscribe to the necessary streams
-        for stream_name, token in self.handler.get_streams_to_replicate().iteritems():
+        for stream_name, token in iteritems(self.handler.get_streams_to_replicate()):
             self.replicate(stream_name, token)
 
         # Tell the server if we have any users currently syncing (should only
@@ -517,25 +522,28 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
             self.send_error("Wrong remote")
 
     def on_RDATA(self, cmd):
+        stream_name = cmd.stream_name
+        inbound_rdata_count.labels(stream_name).inc()
+
         try:
-            row = STREAMS_MAP[cmd.stream_name].ROW_TYPE(*cmd.row)
+            row = STREAMS_MAP[stream_name].ROW_TYPE(*cmd.row)
         except Exception:
             logger.exception(
                 "[%s] Failed to parse RDATA: %r %r",
-                self.id(), cmd.stream_name, cmd.row
+                self.id(), stream_name, cmd.row
             )
             raise
 
         if cmd.token is None:
             # I.e. this is part of a batch of updates for this stream. Batch
             # until we get an update for the stream with a non None token
-            self.pending_batches.setdefault(cmd.stream_name, []).append(row)
+            self.pending_batches.setdefault(stream_name, []).append(row)
         else:
             # Check if this is the last of a batch of updates
-            rows = self.pending_batches.pop(cmd.stream_name, [])
+            rows = self.pending_batches.pop(stream_name, [])
             rows.append(row)
 
-            self.handler.on_rdata(cmd.stream_name, cmd.token, rows)
+            self.handler.on_rdata(stream_name, cmd.token, rows)
 
     def on_POSITION(self, cmd):
         self.handler.on_position(cmd.stream_name, cmd.token)
@@ -563,13 +571,13 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
 
 # The following simply registers metrics for the replication connections
 
-metrics.register_callback(
-    "pending_commands",
+pending_commands = LaterGauge(
+    "synapse_replication_tcp_protocol_pending_commands",
+    "",
+    ["name", "conn_id"],
     lambda: {
-        (p.name, p.conn_id): len(p.pending_commands)
-        for p in connected_connections
+        (p.name, p.conn_id): len(p.pending_commands) for p in connected_connections
     },
-    labels=["name", "conn_id"],
 )
 
 
@@ -580,13 +588,13 @@ def transport_buffer_size(protocol):
     return 0
 
 
-metrics.register_callback(
-    "transport_send_buffer",
+transport_send_buffer = LaterGauge(
+    "synapse_replication_tcp_protocol_transport_send_buffer",
+    "",
+    ["name", "conn_id"],
     lambda: {
-        (p.name, p.conn_id): transport_buffer_size(p)
-        for p in connected_connections
+        (p.name, p.conn_id): transport_buffer_size(p) for p in connected_connections
     },
-    labels=["name", "conn_id"],
 )
 
 
@@ -605,42 +613,51 @@ def transport_kernel_read_buffer_size(protocol, read=True):
     return 0
 
 
-metrics.register_callback(
-    "transport_kernel_send_buffer",
+tcp_transport_kernel_send_buffer = LaterGauge(
+    "synapse_replication_tcp_protocol_transport_kernel_send_buffer",
+    "",
+    ["name", "conn_id"],
     lambda: {
         (p.name, p.conn_id): transport_kernel_read_buffer_size(p, False)
         for p in connected_connections
     },
-    labels=["name", "conn_id"],
 )
 
 
-metrics.register_callback(
-    "transport_kernel_read_buffer",
+tcp_transport_kernel_read_buffer = LaterGauge(
+    "synapse_replication_tcp_protocol_transport_kernel_read_buffer",
+    "",
+    ["name", "conn_id"],
     lambda: {
         (p.name, p.conn_id): transport_kernel_read_buffer_size(p, True)
         for p in connected_connections
     },
-    labels=["name", "conn_id"],
 )
 
 
-metrics.register_callback(
-    "inbound_commands",
+tcp_inbound_commands = LaterGauge(
+    "synapse_replication_tcp_protocol_inbound_commands",
+    "",
+    ["command", "name", "conn_id"],
     lambda: {
         (k[0], p.name, p.conn_id): count
         for p in connected_connections
-        for k, count in p.inbound_commands_counter.counts.iteritems()
+        for k, count in iteritems(p.inbound_commands_counter)
     },
-    labels=["command", "name", "conn_id"],
 )
 
-metrics.register_callback(
-    "outbound_commands",
+tcp_outbound_commands = LaterGauge(
+    "synapse_replication_tcp_protocol_outbound_commands",
+    "",
+    ["command", "name", "conn_id"],
     lambda: {
         (k[0], p.name, p.conn_id): count
         for p in connected_connections
-        for k, count in p.outbound_commands_counter.counts.iteritems()
+        for k, count in iteritems(p.outbound_commands_counter)
     },
-    labels=["command", "name", "conn_id"],
+)
+
+# number of updates received for each RDATA stream
+inbound_rdata_count = Counter(
+    "synapse_replication_tcp_protocol_inbound_rdata_count", "", ["stream_name"]
 )
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 3ea3ca5a6f..611fb66e1d 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -15,27 +15,29 @@
 """The server side of the replication stream.
 """
 
-from twisted.internet import defer, reactor
-from twisted.internet.protocol import Factory
+import logging
 
-from streams import STREAMS_MAP, FederationStream
-from protocol import ServerReplicationStreamProtocol
+from six import itervalues
 
-from synapse.util.metrics import Measure, measure_func
+from prometheus_client import Counter
 
-import logging
-import synapse.metrics
+from twisted.internet import defer
+from twisted.internet.protocol import Factory
 
+from synapse.metrics import LaterGauge
+from synapse.util.metrics import Measure, measure_func
+
+from .protocol import ServerReplicationStreamProtocol
+from .streams import STREAMS_MAP, FederationStream
 
-metrics = synapse.metrics.get_metrics_for(__name__)
-stream_updates_counter = metrics.register_counter(
-    "stream_updates", labels=["stream_name"]
-)
-user_sync_counter = metrics.register_counter("user_sync")
-federation_ack_counter = metrics.register_counter("federation_ack")
-remove_pusher_counter = metrics.register_counter("remove_pusher")
-invalidate_cache_counter = metrics.register_counter("invalidate_cache")
-user_ip_cache_counter = metrics.register_counter("user_ip_cache")
+stream_updates_counter = Counter("synapse_replication_tcp_resource_stream_updates",
+                                 "", ["stream_name"])
+user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "")
+federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "")
+remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "")
+invalidate_cache_counter = Counter("synapse_replication_tcp_resource_invalidate_cache",
+                                   "")
+user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "")
 
 logger = logging.getLogger(__name__)
 
@@ -69,33 +71,34 @@ class ReplicationStreamer(object):
         self.presence_handler = hs.get_presence_handler()
         self.clock = hs.get_clock()
         self.notifier = hs.get_notifier()
+        self._server_notices_sender = hs.get_server_notices_sender()
 
         # Current connections.
         self.connections = []
 
-        metrics.register_callback("total_connections", lambda: len(self.connections))
+        LaterGauge("synapse_replication_tcp_resource_total_connections", "", [],
+                   lambda: len(self.connections))
 
         # List of streams that clients can subscribe to.
         # We only support federation stream if federation sending hase been
         # disabled on the master.
         self.streams = [
-            stream(hs) for stream in STREAMS_MAP.itervalues()
+            stream(hs) for stream in itervalues(STREAMS_MAP)
             if stream != FederationStream or not hs.config.send_federation
         ]
 
         self.streams_by_name = {stream.NAME: stream for stream in self.streams}
 
-        metrics.register_callback(
-            "connections_per_stream",
+        LaterGauge(
+            "synapse_replication_tcp_resource_connections_per_stream", "",
+            ["stream_name"],
             lambda: {
                 (stream_name,): len([
                     conn for conn in self.connections
                     if stream_name in conn.replication_streams
                 ])
                 for stream_name in self.streams_by_name
-            },
-            labels=["stream_name"],
-        )
+            })
 
         self.federation_sender = None
         if not hs.config.send_federation:
@@ -107,7 +110,7 @@ class ReplicationStreamer(object):
         self.is_looping = False
         self.pending_updates = False
 
-        reactor.addSystemEventTrigger("before", "shutdown", self.on_shutdown)
+        hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.on_shutdown)
 
     def on_shutdown(self):
         # close all connections on shutdown
@@ -160,7 +163,11 @@ class ReplicationStreamer(object):
                             "Getting stream: %s: %s -> %s",
                             stream.NAME, stream.last_token, stream.upto_token
                         )
-                        updates, current_token = yield stream.get_updates()
+                        try:
+                            updates, current_token = yield stream.get_updates()
+                        except Exception:
+                            logger.info("Failed to handle stream %s", stream.NAME)
+                            raise
 
                         logger.debug(
                             "Sending %d updates to %d connections",
@@ -171,7 +178,7 @@ class ReplicationStreamer(object):
                             logger.info(
                                 "Streaming: %s -> %s", stream.NAME, updates[-1][0]
                             )
-                            stream_updates_counter.inc_by(len(updates), stream.NAME)
+                            stream_updates_counter.labels(stream.NAME).inc(len(updates))
 
                         # Some streams return multiple rows with the same stream IDs,
                         # we need to make sure they get sent out in batches. We do
@@ -212,11 +219,12 @@ class ReplicationStreamer(object):
             self.federation_sender.federation_ack(token)
 
     @measure_func("repl.on_user_sync")
+    @defer.inlineCallbacks
     def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms):
         """A client has started/stopped syncing on a worker.
         """
         user_sync_counter.inc()
-        self.presence_handler.update_external_syncs_row(
+        yield self.presence_handler.update_external_syncs_row(
             conn_id, user_id, is_syncing, last_sync_ms,
         )
 
@@ -240,13 +248,15 @@ class ReplicationStreamer(object):
         getattr(self.store, cache_func).invalidate(tuple(keys))
 
     @measure_func("repl.on_user_ip")
+    @defer.inlineCallbacks
     def on_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen):
         """The client saw a user request
         """
         user_ip_cache_counter.inc()
-        self.store.insert_client_ip(
+        yield self.store.insert_client_ip(
             user_id, access_token, ip, user_agent, device_id, last_seen,
         )
+        yield self._server_notices_sender.on_user_ip(user_id)
 
     def send_sync_to_all_connections(self, data):
         """Sends a SYNC command to all clients.
diff --git a/synapse/replication/tcp/streams.py b/synapse/replication/tcp/streams.py
index fbafe12cc2..55fe701c5c 100644
--- a/synapse/replication/tcp/streams.py
+++ b/synapse/replication/tcp/streams.py
@@ -24,11 +24,10 @@ Each stream is defined by the following information:
     update_function:    The function that returns a list of updates between two tokens
 """
 
-from twisted.internet import defer
-from collections import namedtuple
-
 import logging
+from collections import namedtuple
 
+from twisted.internet import defer
 
 logger = logging.getLogger(__name__)
 
@@ -118,6 +117,12 @@ CurrentStateDeltaStreamRow = namedtuple("CurrentStateDeltaStream", (
     "state_key",  # str
     "event_id",  # str, optional
 ))
+GroupsStreamRow = namedtuple("GroupsStreamRow", (
+    "group_id",  # str
+    "user_id",  # str
+    "type",  # str
+    "content",  # dict
+))
 
 
 class Stream(object):
@@ -464,6 +469,19 @@ class CurrentStateDeltaStream(Stream):
         super(CurrentStateDeltaStream, self).__init__(hs)
 
 
+class GroupServerStream(Stream):
+    NAME = "groups"
+    ROW_TYPE = GroupsStreamRow
+
+    def __init__(self, hs):
+        store = hs.get_datastore()
+
+        self.current_token = store.get_group_stream_token
+        self.update_function = store.get_all_groups_changes
+
+        super(GroupServerStream, self).__init__(hs)
+
+
 STREAMS_MAP = {
     stream.NAME: stream
     for stream in (
@@ -482,5 +500,6 @@ STREAMS_MAP = {
         TagAccountDataStream,
         AccountDataStream,
         CurrentStateDeltaStream,
+        GroupServerStream,
     )
 }