summary refs log tree commit diff
path: root/synapse/replication/tcp/protocol.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication/tcp/protocol.py')
-rw-r--r--synapse/replication/tcp/protocol.py76
1 files changed, 39 insertions, 37 deletions
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index b51590cf8f..97efb835ad 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -84,7 +84,8 @@ from .commands import (
 from .streams import STREAMS_MAP
 
 connection_close_counter = Counter(
-    "synapse_replication_tcp_protocol_close_reason", "", ["reason_type"])
+    "synapse_replication_tcp_protocol_close_reason", "", ["reason_type"]
+)
 
 # A list of all connected protocols. This allows us to send metrics about the
 # connections.
@@ -119,7 +120,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
     It also sends `PING` periodically, and correctly times out remote connections
     (if they send a `PING` command)
     """
-    delimiter = b'\n'
+
+    delimiter = b"\n"
 
     VALID_INBOUND_COMMANDS = []  # Valid commands we expect to receive
     VALID_OUTBOUND_COMMANDS = []  # Valid commans we can send
@@ -183,10 +185,14 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
             if now - self.last_sent_command >= PING_TIME:
                 self.send_command(PingCommand(now))
 
-            if self.received_ping and now - self.last_received_command > PING_TIMEOUT_MS:
+            if (
+                self.received_ping
+                and now - self.last_received_command > PING_TIMEOUT_MS
+            ):
                 logger.info(
                     "[%s] Connection hasn't received command in %r ms. Closing.",
-                    self.id(), now - self.last_received_command
+                    self.id(),
+                    now - self.last_received_command,
                 )
                 self.send_error("ping timeout")
 
@@ -208,7 +214,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         self.last_received_command = self.clock.time_msec()
 
         self.inbound_commands_counter[cmd_name] = (
-            self.inbound_commands_counter[cmd_name] + 1)
+            self.inbound_commands_counter[cmd_name] + 1
+        )
 
         cmd_cls = COMMAND_MAP[cmd_name]
         try:
@@ -224,9 +231,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
 
         # Now lets try and call on_<CMD_NAME> function
         run_as_background_process(
-            "replication-" + cmd.get_logcontext_id(),
-            self.handle_command,
-            cmd,
+            "replication-" + cmd.get_logcontext_id(), self.handle_command, cmd
         )
 
     def handle_command(self, cmd):
@@ -274,8 +279,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
             return
 
         self.outbound_commands_counter[cmd.NAME] = (
-            self.outbound_commands_counter[cmd.NAME] + 1)
-        string = "%s %s" % (cmd.NAME, cmd.to_line(),)
+            self.outbound_commands_counter[cmd.NAME] + 1
+        )
+        string = "%s %s" % (cmd.NAME, cmd.to_line())
         if "\n" in string:
             raise Exception("Unexpected newline in command: %r", string)
 
@@ -283,10 +289,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
 
         if len(encoded_string) > self.MAX_LENGTH:
             raise Exception(
-                "Failed to send command %s as too long (%d > %d)" % (
-                    cmd.NAME,
-                    len(encoded_string), self.MAX_LENGTH,
-                )
+                "Failed to send command %s as too long (%d > %d)"
+                % (cmd.NAME, len(encoded_string), self.MAX_LENGTH)
             )
 
         self.sendLine(encoded_string)
@@ -379,7 +383,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         if self.transport:
             addr = str(self.transport.getPeer())
         return "ReplicationConnection<name=%s,conn_id=%s,addr=%s>" % (
-            self.name, self.conn_id, addr,
+            self.name,
+            self.conn_id,
+            addr,
         )
 
     def id(self):
@@ -422,7 +428,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
 
     def on_USER_SYNC(self, cmd):
         return self.streamer.on_user_sync(
-            self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms,
+            self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
         )
 
     def on_REPLICATE(self, cmd):
@@ -432,10 +438,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
         if stream_name == "ALL":
             # Subscribe to all streams we're publishing to.
             deferreds = [
-                run_in_background(
-                    self.subscribe_to_stream,
-                    stream, token,
-                )
+                run_in_background(self.subscribe_to_stream, stream, token)
                 for stream in iterkeys(self.streamer.streams_by_name)
             ]
 
@@ -449,16 +452,18 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
         return self.streamer.federation_ack(cmd.token)
 
     def on_REMOVE_PUSHER(self, cmd):
-        return self.streamer.on_remove_pusher(
-            cmd.app_id, cmd.push_key, cmd.user_id,
-        )
+        return self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id)
 
     def on_INVALIDATE_CACHE(self, cmd):
         return self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
 
     def on_USER_IP(self, cmd):
         return self.streamer.on_user_ip(
-            cmd.user_id, cmd.access_token, cmd.ip, cmd.user_agent, cmd.device_id,
+            cmd.user_id,
+            cmd.access_token,
+            cmd.ip,
+            cmd.user_agent,
+            cmd.device_id,
             cmd.last_seen,
         )
 
@@ -476,7 +481,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
         try:
             # Get missing updates
             updates, current_token = yield self.streamer.get_stream_updates(
-                stream_name, token,
+                stream_name, token
             )
 
             # Send all the missing updates
@@ -608,8 +613,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
             row = STREAMS_MAP[stream_name].parse_row(cmd.row)
         except Exception:
             logger.exception(
-                "[%s] Failed to parse RDATA: %r %r",
-                self.id(), stream_name, cmd.row
+                "[%s] Failed to parse RDATA: %r %r", self.id(), stream_name, cmd.row
             )
             raise
 
@@ -643,7 +647,9 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
 
         logger.info(
             "[%s] Subscribing to replication stream: %r from %r",
-            self.id(), stream_name, token
+            self.id(),
+            stream_name,
+            token,
         )
 
         self.streams_connecting.add(stream_name)
@@ -661,9 +667,7 @@ pending_commands = LaterGauge(
     "synapse_replication_tcp_protocol_pending_commands",
     "",
     ["name"],
-    lambda: {
-        (p.name,): len(p.pending_commands) for p in connected_connections
-    },
+    lambda: {(p.name,): len(p.pending_commands) for p in connected_connections},
 )
 
 
@@ -678,9 +682,7 @@ transport_send_buffer = LaterGauge(
     "synapse_replication_tcp_protocol_transport_send_buffer",
     "",
     ["name"],
-    lambda: {
-        (p.name,): transport_buffer_size(p) for p in connected_connections
-    },
+    lambda: {(p.name,): transport_buffer_size(p) for p in connected_connections},
 )
 
 
@@ -694,7 +696,7 @@ def transport_kernel_read_buffer_size(protocol, read=True):
             op = SIOCINQ
         else:
             op = SIOCOUTQ
-        size = struct.unpack("I", fcntl.ioctl(fileno, op, '\0\0\0\0'))[0]
+        size = struct.unpack("I", fcntl.ioctl(fileno, op, "\0\0\0\0"))[0]
         return size
     return 0
 
@@ -726,7 +728,7 @@ tcp_inbound_commands = LaterGauge(
     "",
     ["command", "name"],
     lambda: {
-        (k, p.name,): count
+        (k, p.name): count
         for p in connected_connections
         for k, count in iteritems(p.inbound_commands_counter)
     },
@@ -737,7 +739,7 @@ tcp_outbound_commands = LaterGauge(
     "",
     ["command", "name"],
     lambda: {
-        (k, p.name,): count
+        (k, p.name): count
         for p in connected_connections
         for k, count in iteritems(p.outbound_commands_counter)
     },