summary refs log tree commit diff
path: root/synapse/replication/tcp
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication/tcp')
-rw-r--r--synapse/replication/tcp/client.py21
-rw-r--r--synapse/replication/tcp/commands.py28
-rw-r--r--synapse/replication/tcp/protocol.py135
-rw-r--r--synapse/replication/tcp/resource.py35
-rw-r--r--synapse/replication/tcp/streams.py5
5 files changed, 143 insertions, 81 deletions
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 6d2513c4e2..cbe9645817 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -15,17 +15,20 @@
 """A replication client for use by synapse workers.
 """
 
-from twisted.internet import reactor, defer
+import logging
+
+from twisted.internet import defer
 from twisted.internet.protocol import ReconnectingClientFactory
 
 from .commands import (
-    FederationAckCommand, UserSyncCommand, RemovePusherCommand, InvalidateCacheCommand,
+    FederationAckCommand,
+    InvalidateCacheCommand,
+    RemovePusherCommand,
     UserIpCommand,
+    UserSyncCommand,
 )
 from .protocol import ClientReplicationStreamProtocol
 
-import logging
-
 logger = logging.getLogger(__name__)
 
 
@@ -44,7 +47,7 @@ class ReplicationClientFactory(ReconnectingClientFactory):
         self.server_name = hs.config.server_name
         self._clock = hs.get_clock()  # As self.clock is defined in super class
 
-        reactor.addSystemEventTrigger("before", "shutdown", self.stopTrying)
+        hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.stopTrying)
 
     def startedConnecting(self, connector):
         logger.info("Connecting to replication: %r", connector.getDestination())
@@ -95,7 +98,7 @@ class ReplicationClientHandler(object):
         factory = ReplicationClientFactory(hs, client_name, self)
         host = hs.config.worker_replication_host
         port = hs.config.worker_replication_port
-        reactor.connectTCP(host, port, factory)
+        hs.get_reactor().connectTCP(host, port, factory)
 
     def on_rdata(self, stream_name, token, rows):
         """Called when we get new replication data. By default this just pokes
@@ -104,7 +107,7 @@ class ReplicationClientHandler(object):
         Can be overriden in subclasses to handle more.
         """
         logger.info("Received rdata %s -> %s", stream_name, token)
-        self.store.process_replication_rows(stream_name, token, rows)
+        return self.store.process_replication_rows(stream_name, token, rows)
 
     def on_position(self, stream_name, token):
         """Called when we get new position data. By default this just pokes
@@ -112,7 +115,7 @@ class ReplicationClientHandler(object):
 
         Can be overriden in subclasses to handle more.
         """
-        self.store.process_replication_rows(stream_name, token, [])
+        return self.store.process_replication_rows(stream_name, token, [])
 
     def on_sync(self, data):
         """When we received a SYNC we wake up any deferreds that were waiting
@@ -189,7 +192,7 @@ class ReplicationClientHandler(object):
         """Returns a deferred that is resolved when we receive a SYNC command
         with given data.
 
-        Used by tests.
+        [Not currently] used by tests.
         """
         return self.awaiting_syncs.setdefault(data, defer.Deferred())
 
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index 12aac3cc6b..327556f6a1 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -19,13 +19,17 @@ allowed to be sent by which side.
 """
 
 import logging
-import simplejson
+import platform
 
+if platform.python_implementation() == "PyPy":
+    import json
+    _json_encoder = json.JSONEncoder()
+else:
+    import simplejson as json
+    _json_encoder = json.JSONEncoder(namedtuple_as_object=False)
 
 logger = logging.getLogger(__name__)
 
-_json_encoder = simplejson.JSONEncoder(namedtuple_as_object=False)
-
 
 class Command(object):
     """The base command class.
@@ -55,6 +59,12 @@ class Command(object):
         """
         return self.data
 
+    def get_logcontext_id(self):
+        """Get a suitable string for the logcontext when processing this command"""
+
+        # by default, we just use the command name.
+        return self.NAME
+
 
 class ServerCommand(Command):
     """Sent by the server on new connection and includes the server_name.
@@ -102,7 +112,7 @@ class RdataCommand(Command):
         return cls(
             stream_name,
             None if token == "batch" else int(token),
-            simplejson.loads(row_json)
+            json.loads(row_json)
         )
 
     def to_line(self):
@@ -112,6 +122,9 @@ class RdataCommand(Command):
             _json_encoder.encode(self.row),
         ))
 
+    def get_logcontext_id(self):
+        return "RDATA-" + self.stream_name
+
 
 class PositionCommand(Command):
     """Sent by the client to tell the client the stream postition without
@@ -186,6 +199,9 @@ class ReplicateCommand(Command):
     def to_line(self):
         return " ".join((self.stream_name, str(self.token),))
 
+    def get_logcontext_id(self):
+        return "REPLICATE-" + self.stream_name
+
 
 class UserSyncCommand(Command):
     """Sent by the client to inform the server that a user has started or
@@ -300,7 +316,7 @@ class InvalidateCacheCommand(Command):
     def from_line(cls, line):
         cache_func, keys_json = line.split(" ", 1)
 
-        return cls(cache_func, simplejson.loads(keys_json))
+        return cls(cache_func, json.loads(keys_json))
 
     def to_line(self):
         return " ".join((
@@ -329,7 +345,7 @@ class UserIpCommand(Command):
     def from_line(cls, line):
         user_id, jsn = line.split(" ", 1)
 
-        access_token, ip, user_agent, device_id, last_seen = simplejson.loads(jsn)
+        access_token, ip, user_agent, device_id, last_seen = json.loads(jsn)
 
         return cls(
             user_id, access_token, ip, user_agent, device_id, last_seen
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index c870475cd1..74e892c104 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -49,29 +49,39 @@ indicate which side is sending, these are *not* included on the wire::
     * connection closed by server *
 """
 
+import fcntl
+import logging
+import struct
+from collections import defaultdict
+
+from six import iteritems, iterkeys
+
+from prometheus_client import Counter
+
 from twisted.internet import defer
 from twisted.protocols.basic import LineOnlyReceiver
 from twisted.python.failure import Failure
 
-from .commands import (
-    COMMAND_MAP, VALID_CLIENT_COMMANDS, VALID_SERVER_COMMANDS,
-    ErrorCommand, ServerCommand, RdataCommand, PositionCommand, PingCommand,
-    NameCommand, ReplicateCommand, UserSyncCommand, SyncCommand,
-)
-from .streams import STREAMS_MAP
-
 from synapse.metrics import LaterGauge
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
 from synapse.util.stringutils import random_string
 
-from prometheus_client import Counter
-
-from collections import defaultdict
-
-from six import iterkeys, iteritems
-
-import logging
-import struct
-import fcntl
+from .commands import (
+    COMMAND_MAP,
+    VALID_CLIENT_COMMANDS,
+    VALID_SERVER_COMMANDS,
+    ErrorCommand,
+    NameCommand,
+    PingCommand,
+    PositionCommand,
+    RdataCommand,
+    ReplicateCommand,
+    ServerCommand,
+    SyncCommand,
+    UserSyncCommand,
+)
+from .streams import STREAMS_MAP
 
 connection_close_counter = Counter(
     "synapse_replication_tcp_protocol_close_reason", "", ["reason_type"])
@@ -214,7 +224,11 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
 
         # Now lets try and call on_<CMD_NAME> function
         try:
-            getattr(self, "on_%s" % (cmd_name,))(cmd)
+            run_as_background_process(
+                "replication-" + cmd.get_logcontext_id(),
+                getattr(self, "on_%s" % (cmd_name,)),
+                cmd,
+            )
         except Exception:
             logger.exception("[%s] Failed to handle line: %r", self.id(), line)
 
@@ -379,7 +393,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
         self.name = cmd.data
 
     def on_USER_SYNC(self, cmd):
-        self.streamer.on_user_sync(
+        return self.streamer.on_user_sync(
             self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms,
         )
 
@@ -389,22 +403,33 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
 
         if stream_name == "ALL":
             # Subscribe to all streams we're publishing to.
-            for stream in iterkeys(self.streamer.streams_by_name):
-                self.subscribe_to_stream(stream, token)
+            deferreds = [
+                run_in_background(
+                    self.subscribe_to_stream,
+                    stream, token,
+                )
+                for stream in iterkeys(self.streamer.streams_by_name)
+            ]
+
+            return make_deferred_yieldable(
+                defer.gatherResults(deferreds, consumeErrors=True)
+            )
         else:
-            self.subscribe_to_stream(stream_name, token)
+            return self.subscribe_to_stream(stream_name, token)
 
     def on_FEDERATION_ACK(self, cmd):
-        self.streamer.federation_ack(cmd.token)
+        return self.streamer.federation_ack(cmd.token)
 
     def on_REMOVE_PUSHER(self, cmd):
-        self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id)
+        return self.streamer.on_remove_pusher(
+            cmd.app_id, cmd.push_key, cmd.user_id,
+        )
 
     def on_INVALIDATE_CACHE(self, cmd):
-        self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
+        return self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
 
     def on_USER_IP(self, cmd):
-        self.streamer.on_user_ip(
+        return self.streamer.on_user_ip(
             cmd.user_id, cmd.access_token, cmd.ip, cmd.user_agent, cmd.device_id,
             cmd.last_seen,
         )
@@ -534,14 +559,13 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
             # Check if this is the last of a batch of updates
             rows = self.pending_batches.pop(stream_name, [])
             rows.append(row)
-
-            self.handler.on_rdata(stream_name, cmd.token, rows)
+            return self.handler.on_rdata(stream_name, cmd.token, rows)
 
     def on_POSITION(self, cmd):
-        self.handler.on_position(cmd.stream_name, cmd.token)
+        return self.handler.on_position(cmd.stream_name, cmd.token)
 
     def on_SYNC(self, cmd):
-        self.handler.on_sync(cmd.data)
+        return self.handler.on_sync(cmd.data)
 
     def replicate(self, stream_name, token):
         """Send the subscription request to the server
@@ -564,11 +588,13 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
 # The following simply registers metrics for the replication connections
 
 pending_commands = LaterGauge(
-    "pending_commands", "", ["name", "conn_id"],
+    "synapse_replication_tcp_protocol_pending_commands",
+    "",
+    ["name", "conn_id"],
     lambda: {
-        (p.name, p.conn_id): len(p.pending_commands)
-        for p in connected_connections
-    })
+        (p.name, p.conn_id): len(p.pending_commands) for p in connected_connections
+    },
+)
 
 
 def transport_buffer_size(protocol):
@@ -579,11 +605,13 @@ def transport_buffer_size(protocol):
 
 
 transport_send_buffer = LaterGauge(
-    "synapse_replication_tcp_transport_send_buffer", "", ["name", "conn_id"],
+    "synapse_replication_tcp_protocol_transport_send_buffer",
+    "",
+    ["name", "conn_id"],
     lambda: {
-        (p.name, p.conn_id): transport_buffer_size(p)
-        for p in connected_connections
-    })
+        (p.name, p.conn_id): transport_buffer_size(p) for p in connected_connections
+    },
+)
 
 
 def transport_kernel_read_buffer_size(protocol, read=True):
@@ -602,37 +630,50 @@ def transport_kernel_read_buffer_size(protocol, read=True):
 
 
 tcp_transport_kernel_send_buffer = LaterGauge(
-    "synapse_replication_tcp_transport_kernel_send_buffer", "", ["name", "conn_id"],
+    "synapse_replication_tcp_protocol_transport_kernel_send_buffer",
+    "",
+    ["name", "conn_id"],
     lambda: {
         (p.name, p.conn_id): transport_kernel_read_buffer_size(p, False)
         for p in connected_connections
-    })
+    },
+)
 
 
 tcp_transport_kernel_read_buffer = LaterGauge(
-    "synapse_replication_tcp_transport_kernel_read_buffer", "", ["name", "conn_id"],
+    "synapse_replication_tcp_protocol_transport_kernel_read_buffer",
+    "",
+    ["name", "conn_id"],
     lambda: {
         (p.name, p.conn_id): transport_kernel_read_buffer_size(p, True)
         for p in connected_connections
-    })
+    },
+)
 
 
 tcp_inbound_commands = LaterGauge(
-    "synapse_replication_tcp_inbound_commands", "", ["command", "name", "conn_id"],
+    "synapse_replication_tcp_protocol_inbound_commands",
+    "",
+    ["command", "name", "conn_id"],
     lambda: {
         (k[0], p.name, p.conn_id): count
         for p in connected_connections
         for k, count in iteritems(p.inbound_commands_counter)
-    })
+    },
+)
 
 tcp_outbound_commands = LaterGauge(
-    "synapse_replication_tcp_outbound_commands", "", ["command", "name", "conn_id"],
+    "synapse_replication_tcp_protocol_outbound_commands",
+    "",
+    ["command", "name", "conn_id"],
     lambda: {
         (k[0], p.name, p.conn_id): count
         for p in connected_connections
         for k, count in iteritems(p.outbound_commands_counter)
-    })
+    },
+)
 
 # number of updates received for each RDATA stream
-inbound_rdata_count = Counter("synapse_replication_tcp_inbound_rdata_count", "",
-                              ["stream_name"])
+inbound_rdata_count = Counter(
+    "synapse_replication_tcp_protocol_inbound_rdata_count", "", ["stream_name"]
+)
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 63bd6d2652..fd59f1595f 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -15,19 +15,21 @@
 """The server side of the replication stream.
 """
 
-from twisted.internet import defer, reactor
-from twisted.internet.protocol import Factory
+import logging
 
-from .streams import STREAMS_MAP, FederationStream
-from .protocol import ServerReplicationStreamProtocol
+from six import itervalues
 
-from synapse.util.metrics import Measure, measure_func
-from synapse.metrics import LaterGauge
+from prometheus_client import Counter
 
-import logging
+from twisted.internet import defer
+from twisted.internet.protocol import Factory
 
-from prometheus_client import Counter
-from six import itervalues
+from synapse.metrics import LaterGauge
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.util.metrics import Measure, measure_func
+
+from .protocol import ServerReplicationStreamProtocol
+from .streams import STREAMS_MAP, FederationStream
 
 stream_updates_counter = Counter("synapse_replication_tcp_resource_stream_updates",
                                  "", ["stream_name"])
@@ -109,14 +111,13 @@ class ReplicationStreamer(object):
         self.is_looping = False
         self.pending_updates = False
 
-        reactor.addSystemEventTrigger("before", "shutdown", self.on_shutdown)
+        hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.on_shutdown)
 
     def on_shutdown(self):
         # close all connections on shutdown
         for conn in self.connections:
             conn.send_error("server shutting down")
 
-    @defer.inlineCallbacks
     def on_notifier_poke(self):
         """Checks if there is actually any new data and sends it to the
         connections if there are.
@@ -131,14 +132,16 @@ class ReplicationStreamer(object):
                 stream.discard_updates_and_advance()
             return
 
-        # If we're in the process of checking for new updates, mark that fact
-        # and return
+        self.pending_updates = True
+
         if self.is_looping:
-            logger.debug("Noitifier poke loop already running")
-            self.pending_updates = True
+            logger.debug("Notifier poke loop already running")
             return
 
-        self.pending_updates = True
+        run_as_background_process("replication_notifier", self._run_notifier_loop)
+
+    @defer.inlineCallbacks
+    def _run_notifier_loop(self):
         self.is_looping = True
 
         try:
diff --git a/synapse/replication/tcp/streams.py b/synapse/replication/tcp/streams.py
index 4c60bf79f9..55fe701c5c 100644
--- a/synapse/replication/tcp/streams.py
+++ b/synapse/replication/tcp/streams.py
@@ -24,11 +24,10 @@ Each stream is defined by the following information:
     update_function:    The function that returns a list of updates between two tokens
 """
 
-from twisted.internet import defer
-from collections import namedtuple
-
 import logging
+from collections import namedtuple
 
+from twisted.internet import defer
 
 logger = logging.getLogger(__name__)