summary refs log tree commit diff
path: root/synapse/replication/tcp
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication/tcp')
-rw-r--r--synapse/replication/tcp/client.py6
-rw-r--r--synapse/replication/tcp/commands.py71
-rw-r--r--synapse/replication/tcp/protocol.py76
-rw-r--r--synapse/replication/tcp/resource.py54
-rw-r--r--synapse/replication/tcp/streams/_base.py160
-rw-r--r--synapse/replication/tcp/streams/events.py32
-rw-r--r--synapse/replication/tcp/streams/federation.py12
7 files changed, 228 insertions, 183 deletions
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 206dc3b397..a44ceb00e7 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -39,6 +39,7 @@ class ReplicationClientFactory(ReconnectingClientFactory):
     Accepts a handler that will be called when new data is available or data
     is required.
     """
+
     maxDelay = 30  # Try at least once every N seconds
 
     def __init__(self, hs, client_name, handler):
@@ -64,9 +65,7 @@ class ReplicationClientFactory(ReconnectingClientFactory):
 
     def clientConnectionFailed(self, connector, reason):
         logger.error("Failed to connect to replication: %r", reason)
-        ReconnectingClientFactory.clientConnectionFailed(
-            self, connector, reason
-        )
+        ReconnectingClientFactory.clientConnectionFailed(self, connector, reason)
 
 
 class ReplicationClientHandler(object):
@@ -74,6 +73,7 @@ class ReplicationClientHandler(object):
 
     By default proxies incoming replication data to the SlaveStore.
     """
+
     def __init__(self, store):
         self.store = store
 
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index 2098c32a77..0ff2a7199f 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -23,9 +23,11 @@ import platform
 
 if platform.python_implementation() == "PyPy":
     import json
+
     _json_encoder = json.JSONEncoder()
 else:
     import simplejson as json
+
     _json_encoder = json.JSONEncoder(namedtuple_as_object=False)
 
 logger = logging.getLogger(__name__)
@@ -41,6 +43,7 @@ class Command(object):
 
     The default implementation creates a command of form `<NAME> <data>`
     """
+
     NAME = None
 
     def __init__(self, data):
@@ -73,6 +76,7 @@ class ServerCommand(Command):
 
         SERVER <server_name>
     """
+
     NAME = "SERVER"
 
 
@@ -99,6 +103,7 @@ class RdataCommand(Command):
         RDATA presence batch ["@bar:example.com", "online", ...]
         RDATA presence 59 ["@baz:example.com", "online", ...]
     """
+
     NAME = "RDATA"
 
     def __init__(self, stream_name, token, row):
@@ -110,17 +115,17 @@ class RdataCommand(Command):
     def from_line(cls, line):
         stream_name, token, row_json = line.split(" ", 2)
         return cls(
-            stream_name,
-            None if token == "batch" else int(token),
-            json.loads(row_json)
+            stream_name, None if token == "batch" else int(token), json.loads(row_json)
         )
 
     def to_line(self):
-        return " ".join((
-            self.stream_name,
-            str(self.token) if self.token is not None else "batch",
-            _json_encoder.encode(self.row),
-        ))
+        return " ".join(
+            (
+                self.stream_name,
+                str(self.token) if self.token is not None else "batch",
+                _json_encoder.encode(self.row),
+            )
+        )
 
     def get_logcontext_id(self):
         return "RDATA-" + self.stream_name
@@ -133,6 +138,7 @@ class PositionCommand(Command):
     Sent to the client after all missing updates for a stream have been sent
     to the client and they're now up to date.
     """
+
     NAME = "POSITION"
 
     def __init__(self, stream_name, token):
@@ -145,19 +151,21 @@ class PositionCommand(Command):
         return cls(stream_name, int(token))
 
     def to_line(self):
-        return " ".join((self.stream_name, str(self.token),))
+        return " ".join((self.stream_name, str(self.token)))
 
 
 class ErrorCommand(Command):
     """Sent by either side if there was an ERROR. The data is a string describing
     the error.
     """
+
     NAME = "ERROR"
 
 
 class PingCommand(Command):
     """Sent by either side as a keep alive. The data is arbitary (often timestamp)
     """
+
     NAME = "PING"
 
 
@@ -165,6 +173,7 @@ class NameCommand(Command):
     """Sent by client to inform the server of the client's identity. The data
     is the name
     """
+
     NAME = "NAME"
 
 
@@ -184,6 +193,7 @@ class ReplicateCommand(Command):
 
         REPLICATE ALL NOW
     """
+
     NAME = "REPLICATE"
 
     def __init__(self, stream_name, token):
@@ -200,7 +210,7 @@ class ReplicateCommand(Command):
         return cls(stream_name, token)
 
     def to_line(self):
-        return " ".join((self.stream_name, str(self.token),))
+        return " ".join((self.stream_name, str(self.token)))
 
     def get_logcontext_id(self):
         return "REPLICATE-" + self.stream_name
@@ -218,6 +228,7 @@ class UserSyncCommand(Command):
 
     Where <state> is either "start" or "stop"
     """
+
     NAME = "USER_SYNC"
 
     def __init__(self, user_id, is_syncing, last_sync_ms):
@@ -235,9 +246,13 @@ class UserSyncCommand(Command):
         return cls(user_id, state == "start", int(last_sync_ms))
 
     def to_line(self):
-        return " ".join((
-            self.user_id, "start" if self.is_syncing else "end", str(self.last_sync_ms),
-        ))
+        return " ".join(
+            (
+                self.user_id,
+                "start" if self.is_syncing else "end",
+                str(self.last_sync_ms),
+            )
+        )
 
 
 class FederationAckCommand(Command):
@@ -251,6 +266,7 @@ class FederationAckCommand(Command):
 
         FEDERATION_ACK <token>
     """
+
     NAME = "FEDERATION_ACK"
 
     def __init__(self, token):
@@ -268,6 +284,7 @@ class SyncCommand(Command):
     """Used for testing. The client protocol implementation allows waiting
     on a SYNC command with a specified data.
     """
+
     NAME = "SYNC"
 
 
@@ -278,6 +295,7 @@ class RemovePusherCommand(Command):
 
         REMOVE_PUSHER <app_id> <push_key> <user_id>
     """
+
     NAME = "REMOVE_PUSHER"
 
     def __init__(self, app_id, push_key, user_id):
@@ -309,6 +327,7 @@ class InvalidateCacheCommand(Command):
 
     Where <keys_json> is a json list.
     """
+
     NAME = "INVALIDATE_CACHE"
 
     def __init__(self, cache_func, keys):
@@ -322,9 +341,7 @@ class InvalidateCacheCommand(Command):
         return cls(cache_func, json.loads(keys_json))
 
     def to_line(self):
-        return " ".join((
-            self.cache_func, _json_encoder.encode(self.keys),
-        ))
+        return " ".join((self.cache_func, _json_encoder.encode(self.keys)))
 
 
 class UserIpCommand(Command):
@@ -334,6 +351,7 @@ class UserIpCommand(Command):
 
         USER_IP <user_id>, <access_token>, <ip>, <device_id>, <last_seen>, <user_agent>
     """
+
     NAME = "USER_IP"
 
     def __init__(self, user_id, access_token, ip, user_agent, device_id, last_seen):
@@ -350,15 +368,22 @@ class UserIpCommand(Command):
 
         access_token, ip, user_agent, device_id, last_seen = json.loads(jsn)
 
-        return cls(
-            user_id, access_token, ip, user_agent, device_id, last_seen
-        )
+        return cls(user_id, access_token, ip, user_agent, device_id, last_seen)
 
     def to_line(self):
-        return self.user_id + " " + _json_encoder.encode((
-            self.access_token, self.ip, self.user_agent, self.device_id,
-            self.last_seen,
-        ))
+        return (
+            self.user_id
+            + " "
+            + _json_encoder.encode(
+                (
+                    self.access_token,
+                    self.ip,
+                    self.user_agent,
+                    self.device_id,
+                    self.last_seen,
+                )
+            )
+        )
 
 
 # Map of command name to command type.
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index b51590cf8f..97efb835ad 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -84,7 +84,8 @@ from .commands import (
 from .streams import STREAMS_MAP
 
 connection_close_counter = Counter(
-    "synapse_replication_tcp_protocol_close_reason", "", ["reason_type"])
+    "synapse_replication_tcp_protocol_close_reason", "", ["reason_type"]
+)
 
 # A list of all connected protocols. This allows us to send metrics about the
 # connections.
@@ -119,7 +120,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
     It also sends `PING` periodically, and correctly times out remote connections
     (if they send a `PING` command)
     """
-    delimiter = b'\n'
+
+    delimiter = b"\n"
 
     VALID_INBOUND_COMMANDS = []  # Valid commands we expect to receive
     VALID_OUTBOUND_COMMANDS = []  # Valid commans we can send
@@ -183,10 +185,14 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
             if now - self.last_sent_command >= PING_TIME:
                 self.send_command(PingCommand(now))
 
-            if self.received_ping and now - self.last_received_command > PING_TIMEOUT_MS:
+            if (
+                self.received_ping
+                and now - self.last_received_command > PING_TIMEOUT_MS
+            ):
                 logger.info(
                     "[%s] Connection hasn't received command in %r ms. Closing.",
-                    self.id(), now - self.last_received_command
+                    self.id(),
+                    now - self.last_received_command,
                 )
                 self.send_error("ping timeout")
 
@@ -208,7 +214,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         self.last_received_command = self.clock.time_msec()
 
         self.inbound_commands_counter[cmd_name] = (
-            self.inbound_commands_counter[cmd_name] + 1)
+            self.inbound_commands_counter[cmd_name] + 1
+        )
 
         cmd_cls = COMMAND_MAP[cmd_name]
         try:
@@ -224,9 +231,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
 
         # Now lets try and call on_<CMD_NAME> function
         run_as_background_process(
-            "replication-" + cmd.get_logcontext_id(),
-            self.handle_command,
-            cmd,
+            "replication-" + cmd.get_logcontext_id(), self.handle_command, cmd
         )
 
     def handle_command(self, cmd):
@@ -274,8 +279,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
             return
 
         self.outbound_commands_counter[cmd.NAME] = (
-            self.outbound_commands_counter[cmd.NAME] + 1)
-        string = "%s %s" % (cmd.NAME, cmd.to_line(),)
+            self.outbound_commands_counter[cmd.NAME] + 1
+        )
+        string = "%s %s" % (cmd.NAME, cmd.to_line())
         if "\n" in string:
             raise Exception("Unexpected newline in command: %r", string)
 
@@ -283,10 +289,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
 
         if len(encoded_string) > self.MAX_LENGTH:
             raise Exception(
-                "Failed to send command %s as too long (%d > %d)" % (
-                    cmd.NAME,
-                    len(encoded_string), self.MAX_LENGTH,
-                )
+                "Failed to send command %s as too long (%d > %d)"
+                % (cmd.NAME, len(encoded_string), self.MAX_LENGTH)
             )
 
         self.sendLine(encoded_string)
@@ -379,7 +383,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         if self.transport:
             addr = str(self.transport.getPeer())
         return "ReplicationConnection<name=%s,conn_id=%s,addr=%s>" % (
-            self.name, self.conn_id, addr,
+            self.name,
+            self.conn_id,
+            addr,
         )
 
     def id(self):
@@ -422,7 +428,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
 
     def on_USER_SYNC(self, cmd):
         return self.streamer.on_user_sync(
-            self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms,
+            self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
         )
 
     def on_REPLICATE(self, cmd):
@@ -432,10 +438,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
         if stream_name == "ALL":
             # Subscribe to all streams we're publishing to.
             deferreds = [
-                run_in_background(
-                    self.subscribe_to_stream,
-                    stream, token,
-                )
+                run_in_background(self.subscribe_to_stream, stream, token)
                 for stream in iterkeys(self.streamer.streams_by_name)
             ]
 
@@ -449,16 +452,18 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
         return self.streamer.federation_ack(cmd.token)
 
     def on_REMOVE_PUSHER(self, cmd):
-        return self.streamer.on_remove_pusher(
-            cmd.app_id, cmd.push_key, cmd.user_id,
-        )
+        return self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id)
 
     def on_INVALIDATE_CACHE(self, cmd):
         return self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
 
     def on_USER_IP(self, cmd):
         return self.streamer.on_user_ip(
-            cmd.user_id, cmd.access_token, cmd.ip, cmd.user_agent, cmd.device_id,
+            cmd.user_id,
+            cmd.access_token,
+            cmd.ip,
+            cmd.user_agent,
+            cmd.device_id,
             cmd.last_seen,
         )
 
@@ -476,7 +481,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
         try:
             # Get missing updates
             updates, current_token = yield self.streamer.get_stream_updates(
-                stream_name, token,
+                stream_name, token
             )
 
             # Send all the missing updates
@@ -608,8 +613,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
             row = STREAMS_MAP[stream_name].parse_row(cmd.row)
         except Exception:
             logger.exception(
-                "[%s] Failed to parse RDATA: %r %r",
-                self.id(), stream_name, cmd.row
+                "[%s] Failed to parse RDATA: %r %r", self.id(), stream_name, cmd.row
             )
             raise
 
@@ -643,7 +647,9 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
 
         logger.info(
             "[%s] Subscribing to replication stream: %r from %r",
-            self.id(), stream_name, token
+            self.id(),
+            stream_name,
+            token,
         )
 
         self.streams_connecting.add(stream_name)
@@ -661,9 +667,7 @@ pending_commands = LaterGauge(
     "synapse_replication_tcp_protocol_pending_commands",
     "",
     ["name"],
-    lambda: {
-        (p.name,): len(p.pending_commands) for p in connected_connections
-    },
+    lambda: {(p.name,): len(p.pending_commands) for p in connected_connections},
 )
 
 
@@ -678,9 +682,7 @@ transport_send_buffer = LaterGauge(
     "synapse_replication_tcp_protocol_transport_send_buffer",
     "",
     ["name"],
-    lambda: {
-        (p.name,): transport_buffer_size(p) for p in connected_connections
-    },
+    lambda: {(p.name,): transport_buffer_size(p) for p in connected_connections},
 )
 
 
@@ -694,7 +696,7 @@ def transport_kernel_read_buffer_size(protocol, read=True):
             op = SIOCINQ
         else:
             op = SIOCOUTQ
-        size = struct.unpack("I", fcntl.ioctl(fileno, op, '\0\0\0\0'))[0]
+        size = struct.unpack("I", fcntl.ioctl(fileno, op, "\0\0\0\0"))[0]
         return size
     return 0
 
@@ -726,7 +728,7 @@ tcp_inbound_commands = LaterGauge(
     "",
     ["command", "name"],
     lambda: {
-        (k, p.name,): count
+        (k, p.name): count
         for p in connected_connections
         for k, count in iteritems(p.inbound_commands_counter)
     },
@@ -737,7 +739,7 @@ tcp_outbound_commands = LaterGauge(
     "",
     ["command", "name"],
     lambda: {
-        (k, p.name,): count
+        (k, p.name): count
         for p in connected_connections
         for k, count in iteritems(p.outbound_commands_counter)
     },
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index f6a38f5140..d1e98428bc 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -33,13 +33,15 @@ from .protocol import ServerReplicationStreamProtocol
 from .streams import STREAMS_MAP
 from .streams.federation import FederationStream
 
-stream_updates_counter = Counter("synapse_replication_tcp_resource_stream_updates",
-                                 "", ["stream_name"])
+stream_updates_counter = Counter(
+    "synapse_replication_tcp_resource_stream_updates", "", ["stream_name"]
+)
 user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "")
 federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "")
 remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "")
-invalidate_cache_counter = Counter("synapse_replication_tcp_resource_invalidate_cache",
-                                   "")
+invalidate_cache_counter = Counter(
+    "synapse_replication_tcp_resource_invalidate_cache", ""
+)
 user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "")
 
 logger = logging.getLogger(__name__)
@@ -48,6 +50,7 @@ logger = logging.getLogger(__name__)
 class ReplicationStreamProtocolFactory(Factory):
     """Factory for new replication connections.
     """
+
     def __init__(self, hs):
         self.streamer = ReplicationStreamer(hs)
         self.clock = hs.get_clock()
@@ -55,9 +58,7 @@ class ReplicationStreamProtocolFactory(Factory):
 
     def buildProtocol(self, addr):
         return ServerReplicationStreamProtocol(
-            self.server_name,
-            self.clock,
-            self.streamer,
+            self.server_name, self.clock, self.streamer
         )
 
 
@@ -80,29 +81,39 @@ class ReplicationStreamer(object):
         # Current connections.
         self.connections = []
 
-        LaterGauge("synapse_replication_tcp_resource_total_connections", "", [],
-                   lambda: len(self.connections))
+        LaterGauge(
+            "synapse_replication_tcp_resource_total_connections",
+            "",
+            [],
+            lambda: len(self.connections),
+        )
 
         # List of streams that clients can subscribe to.
         # We only support federation stream if federation sending hase been
         # disabled on the master.
         self.streams = [
-            stream(hs) for stream in itervalues(STREAMS_MAP)
+            stream(hs)
+            for stream in itervalues(STREAMS_MAP)
             if stream != FederationStream or not hs.config.send_federation
         ]
 
         self.streams_by_name = {stream.NAME: stream for stream in self.streams}
 
         LaterGauge(
-            "synapse_replication_tcp_resource_connections_per_stream", "",
+            "synapse_replication_tcp_resource_connections_per_stream",
+            "",
             ["stream_name"],
             lambda: {
-                (stream_name,): len([
-                    conn for conn in self.connections
-                    if stream_name in conn.replication_streams
-                ])
+                (stream_name,): len(
+                    [
+                        conn
+                        for conn in self.connections
+                        if stream_name in conn.replication_streams
+                    ]
+                )
                 for stream_name in self.streams_by_name
-            })
+            },
+        )
 
         self.federation_sender = None
         if not hs.config.send_federation:
@@ -179,7 +190,9 @@ class ReplicationStreamer(object):
 
                         logger.debug(
                             "Getting stream: %s: %s -> %s",
-                            stream.NAME, stream.last_token, stream.upto_token
+                            stream.NAME,
+                            stream.last_token,
+                            stream.upto_token,
                         )
                         try:
                             updates, current_token = yield stream.get_updates()
@@ -189,7 +202,8 @@ class ReplicationStreamer(object):
 
                         logger.debug(
                             "Sending %d updates to %d connections",
-                            len(updates), len(self.connections),
+                            len(updates),
+                            len(self.connections),
                         )
 
                         if updates:
@@ -243,7 +257,7 @@ class ReplicationStreamer(object):
         """
         user_sync_counter.inc()
         yield self.presence_handler.update_external_syncs_row(
-            conn_id, user_id, is_syncing, last_sync_ms,
+            conn_id, user_id, is_syncing, last_sync_ms
         )
 
     @measure_func("repl.on_remove_pusher")
@@ -272,7 +286,7 @@ class ReplicationStreamer(object):
         """
         user_ip_cache_counter.inc()
         yield self.store.insert_client_ip(
-            user_id, access_token, ip, user_agent, device_id, last_seen,
+            user_id, access_token, ip, user_agent, device_id, last_seen
         )
         yield self._server_notices_sender.on_user_ip(user_id)
 
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index b6ce7a7bee..7ef67a5a73 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -26,78 +26,75 @@ logger = logging.getLogger(__name__)
 
 MAX_EVENTS_BEHIND = 10000
 
-BackfillStreamRow = namedtuple("BackfillStreamRow", (
-    "event_id",  # str
-    "room_id",  # str
-    "type",  # str
-    "state_key",  # str, optional
-    "redacts",  # str, optional
-    "relates_to",  # str, optional
-))
-PresenceStreamRow = namedtuple("PresenceStreamRow", (
-    "user_id",  # str
-    "state",  # str
-    "last_active_ts",  # int
-    "last_federation_update_ts",  # int
-    "last_user_sync_ts",  # int
-    "status_msg",   # str
-    "currently_active",  # bool
-))
-TypingStreamRow = namedtuple("TypingStreamRow", (
-    "room_id",  # str
-    "user_ids",  # list(str)
-))
-ReceiptsStreamRow = namedtuple("ReceiptsStreamRow", (
-    "room_id",  # str
-    "receipt_type",  # str
-    "user_id",  # str
-    "event_id",  # str
-    "data",  # dict
-))
-PushRulesStreamRow = namedtuple("PushRulesStreamRow", (
-    "user_id",  # str
-))
-PushersStreamRow = namedtuple("PushersStreamRow", (
-    "user_id",  # str
-    "app_id",  # str
-    "pushkey",  # str
-    "deleted",  # bool
-))
-CachesStreamRow = namedtuple("CachesStreamRow", (
-    "cache_func",  # str
-    "keys",  # list(str)
-    "invalidation_ts",  # int
-))
-PublicRoomsStreamRow = namedtuple("PublicRoomsStreamRow", (
-    "room_id",  # str
-    "visibility",  # str
-    "appservice_id",  # str, optional
-    "network_id",  # str, optional
-))
-DeviceListsStreamRow = namedtuple("DeviceListsStreamRow", (
-    "user_id",  # str
-    "destination",  # str
-))
-ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", (
-    "entity",  # str
-))
-TagAccountDataStreamRow = namedtuple("TagAccountDataStreamRow", (
-    "user_id",  # str
-    "room_id",  # str
-    "data",  # dict
-))
-AccountDataStreamRow = namedtuple("AccountDataStream", (
-    "user_id",  # str
-    "room_id",  # str
-    "data_type",  # str
-    "data",  # dict
-))
-GroupsStreamRow = namedtuple("GroupsStreamRow", (
-    "group_id",  # str
-    "user_id",  # str
-    "type",  # str
-    "content",  # dict
-))
+BackfillStreamRow = namedtuple(
+    "BackfillStreamRow",
+    (
+        "event_id",  # str
+        "room_id",  # str
+        "type",  # str
+        "state_key",  # str, optional
+        "redacts",  # str, optional
+        "relates_to",  # str, optional
+    ),
+)
+PresenceStreamRow = namedtuple(
+    "PresenceStreamRow",
+    (
+        "user_id",  # str
+        "state",  # str
+        "last_active_ts",  # int
+        "last_federation_update_ts",  # int
+        "last_user_sync_ts",  # int
+        "status_msg",  # str
+        "currently_active",  # bool
+    ),
+)
+TypingStreamRow = namedtuple(
+    "TypingStreamRow", ("room_id", "user_ids")  # str  # list(str)
+)
+ReceiptsStreamRow = namedtuple(
+    "ReceiptsStreamRow",
+    (
+        "room_id",  # str
+        "receipt_type",  # str
+        "user_id",  # str
+        "event_id",  # str
+        "data",  # dict
+    ),
+)
+PushRulesStreamRow = namedtuple("PushRulesStreamRow", ("user_id",))  # str
+PushersStreamRow = namedtuple(
+    "PushersStreamRow",
+    ("user_id", "app_id", "pushkey", "deleted"),  # str  # str  # str  # bool
+)
+CachesStreamRow = namedtuple(
+    "CachesStreamRow",
+    ("cache_func", "keys", "invalidation_ts"),  # str  # list(str)  # int
+)
+PublicRoomsStreamRow = namedtuple(
+    "PublicRoomsStreamRow",
+    (
+        "room_id",  # str
+        "visibility",  # str
+        "appservice_id",  # str, optional
+        "network_id",  # str, optional
+    ),
+)
+DeviceListsStreamRow = namedtuple(
+    "DeviceListsStreamRow", ("user_id", "destination")  # str  # str
+)
+ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",))  # str
+TagAccountDataStreamRow = namedtuple(
+    "TagAccountDataStreamRow", ("user_id", "room_id", "data")  # str  # str  # dict
+)
+AccountDataStreamRow = namedtuple(
+    "AccountDataStream",
+    ("user_id", "room_id", "data_type", "data"),  # str  # str  # str  # dict
+)
+GroupsStreamRow = namedtuple(
+    "GroupsStreamRow",
+    ("group_id", "user_id", "type", "content"),  # str  # str  # str  # dict
+)
 
 
 class Stream(object):
@@ -106,6 +103,7 @@ class Stream(object):
     Provides a `get_updates()` function that returns new updates since the last
     time it was called up until the point `advance_current_token` was called.
     """
+
     NAME = None  # The name of the stream
     ROW_TYPE = None  # The type of the row. Used by the default impl of parse_row.
     _LIMITED = True  # Whether the update function takes a limit
@@ -185,16 +183,13 @@ class Stream(object):
 
         if self._LIMITED:
             rows = yield self.update_function(
-                from_token, current_token,
-                limit=MAX_EVENTS_BEHIND + 1,
+                from_token, current_token, limit=MAX_EVENTS_BEHIND + 1
             )
 
             # never turn more than MAX_EVENTS_BEHIND + 1 into updates.
             rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1)
         else:
-            rows = yield self.update_function(
-                from_token, current_token,
-            )
+            rows = yield self.update_function(from_token, current_token)
 
         updates = [(row[0], row[1:]) for row in rows]
 
@@ -230,6 +225,7 @@ class BackfillStream(Stream):
     """We fetched some old events and either we had never seen that event before
     or it went from being an outlier to not.
     """
+
     NAME = "backfill"
     ROW_TYPE = BackfillStreamRow
 
@@ -286,6 +282,7 @@ class ReceiptsStream(Stream):
 class PushRulesStream(Stream):
     """A user has changed their push rules
     """
+
     NAME = "push_rules"
     ROW_TYPE = PushRulesStreamRow
 
@@ -306,6 +303,7 @@ class PushRulesStream(Stream):
 class PushersStream(Stream):
     """A user has added/changed/removed a pusher
     """
+
     NAME = "pushers"
     ROW_TYPE = PushersStreamRow
 
@@ -322,6 +320,7 @@ class CachesStream(Stream):
     """A cache was invalidated on the master and no other stream would invalidate
     the cache on the workers
     """
+
     NAME = "caches"
     ROW_TYPE = CachesStreamRow
 
@@ -337,6 +336,7 @@ class CachesStream(Stream):
 class PublicRoomsStream(Stream):
     """The public rooms list changed
     """
+
     NAME = "public_rooms"
     ROW_TYPE = PublicRoomsStreamRow
 
@@ -352,6 +352,7 @@ class PublicRoomsStream(Stream):
 class DeviceListsStream(Stream):
     """Someone added/changed/removed a device
     """
+
     NAME = "device_lists"
     _LIMITED = False
     ROW_TYPE = DeviceListsStreamRow
@@ -368,6 +369,7 @@ class DeviceListsStream(Stream):
 class ToDeviceStream(Stream):
     """New to_device messages for a client
     """
+
     NAME = "to_device"
     ROW_TYPE = ToDeviceStreamRow
 
@@ -383,6 +385,7 @@ class ToDeviceStream(Stream):
 class TagAccountDataStream(Stream):
     """Someone added/removed a tag for a room
     """
+
     NAME = "tag_account_data"
     ROW_TYPE = TagAccountDataStreamRow
 
@@ -398,6 +401,7 @@ class TagAccountDataStream(Stream):
 class AccountDataStream(Stream):
     """Global or per room account data was changed
     """
+
     NAME = "account_data"
     ROW_TYPE = AccountDataStreamRow
 
@@ -416,7 +420,7 @@ class AccountDataStream(Stream):
 
         results = list(room_results)
         results.extend(
-            (stream_id, user_id, None, account_data_type, content,)
+            (stream_id, user_id, None, account_data_type, content)
             for stream_id, user_id, account_data_type, content in global_results
         )
 
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index f1290d022a..3d0694bb11 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -52,6 +52,7 @@ data part are:
 @attr.s(slots=True, frozen=True)
 class EventsStreamRow(object):
     """A parsed row from the events replication stream"""
+
     type = attr.ib()  # str: the TypeId of one of the *EventsStreamRows
     data = attr.ib()  # BaseEventsStreamRow
 
@@ -80,11 +81,11 @@ class BaseEventsStreamRow(object):
 class EventsStreamEventRow(BaseEventsStreamRow):
     TypeId = "ev"
 
-    event_id = attr.ib()    # str
-    room_id = attr.ib()     # str
-    type = attr.ib()        # str
-    state_key = attr.ib()   # str, optional
-    redacts = attr.ib()     # str, optional
+    event_id = attr.ib()  # str
+    room_id = attr.ib()  # str
+    type = attr.ib()  # str
+    state_key = attr.ib()  # str, optional
+    redacts = attr.ib()  # str, optional
     relates_to = attr.ib()  # str, optional
 
 
@@ -92,24 +93,21 @@ class EventsStreamEventRow(BaseEventsStreamRow):
 class EventsStreamCurrentStateRow(BaseEventsStreamRow):
     TypeId = "state"
 
-    room_id = attr.ib()    # str
-    type = attr.ib()       # str
+    room_id = attr.ib()  # str
+    type = attr.ib()  # str
     state_key = attr.ib()  # str
-    event_id = attr.ib()   # str, optional
+    event_id = attr.ib()  # str, optional
 
 
 TypeToRow = {
-    Row.TypeId: Row
-    for Row in (
-        EventsStreamEventRow,
-        EventsStreamCurrentStateRow,
-    )
+    Row.TypeId: Row for Row in (EventsStreamEventRow, EventsStreamCurrentStateRow)
 }
 
 
 class EventsStream(Stream):
     """We received a new event, or an event went from being an outlier to not
     """
+
     NAME = "events"
 
     def __init__(self, hs):
@@ -121,19 +119,17 @@ class EventsStream(Stream):
     @defer.inlineCallbacks
     def update_function(self, from_token, current_token, limit=None):
         event_rows = yield self._store.get_all_new_forward_event_rows(
-            from_token, current_token, limit,
+            from_token, current_token, limit
         )
         event_updates = (
-            (row[0], EventsStreamEventRow.TypeId, row[1:])
-            for row in event_rows
+            (row[0], EventsStreamEventRow.TypeId, row[1:]) for row in event_rows
         )
 
         state_rows = yield self._store.get_all_updated_current_state_deltas(
             from_token, current_token, limit
         )
         state_updates = (
-            (row[0], EventsStreamCurrentStateRow.TypeId, row[1:])
-            for row in state_rows
+            (row[0], EventsStreamCurrentStateRow.TypeId, row[1:]) for row in state_rows
         )
 
         all_updates = heapq.merge(event_updates, state_updates)
diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py
index 9aa43aa8d2..dc2484109d 100644
--- a/synapse/replication/tcp/streams/federation.py
+++ b/synapse/replication/tcp/streams/federation.py
@@ -17,16 +17,20 @@ from collections import namedtuple
 
 from ._base import Stream
 
-FederationStreamRow = namedtuple("FederationStreamRow", (
-    "type",  # str, the type of data as defined in the BaseFederationRows
-    "data",  # dict, serialization of a federation.send_queue.BaseFederationRow
-))
+FederationStreamRow = namedtuple(
+    "FederationStreamRow",
+    (
+        "type",  # str, the type of data as defined in the BaseFederationRows
+        "data",  # dict, serialization of a federation.send_queue.BaseFederationRow
+    ),
+)
 
 
 class FederationStream(Stream):
     """Data to be sent over federation. Only available when master has federation
     sending disabled.
     """
+
     NAME = "federation"
     ROW_TYPE = FederationStreamRow