author     Richard van der Hoff <richard@matrix.org>  2019-06-24 10:00:13 +0100
committer  Richard van der Hoff <richard@matrix.org>  2019-06-24 10:00:13 +0100
commit     5097aee740b542407e5bb13d19a3e3e6c2227316 (patch)
tree       09a03650256e09cd0b5df59dbf2d7bb2ba14df6c /synapse/replication
parent     changelog (diff)
parent     Improve help and cmdline option names for --generate-config options (#5512) (diff)
download   synapse-5097aee740b542407e5bb13d19a3e3e6c2227316.tar.xz
Merge branch 'develop' into rav/cleanup_metrics
Diffstat (limited to 'synapse/replication')
-rw-r--r--  synapse/replication/http/_base.py  22
-rw-r--r--  synapse/replication/http/federation.py  51
-rw-r--r--  synapse/replication/http/login.py  7
-rw-r--r--  synapse/replication/http/membership.py  34
-rw-r--r--  synapse/replication/http/register.py  17
-rw-r--r--  synapse/replication/http/send_event.py  13
-rw-r--r--  synapse/replication/slave/storage/_base.py  2
-rw-r--r--  synapse/replication/slave/storage/account_data.py  17
-rw-r--r--  synapse/replication/slave/storage/appservice.py  5
-rw-r--r--  synapse/replication/slave/storage/client_ips.py  4
-rw-r--r--  synapse/replication/slave/storage/deviceinbox.py  6
-rw-r--r--  synapse/replication/slave/storage/devices.py  14
-rw-r--r--  synapse/replication/slave/storage/events.py  77
-rw-r--r--  synapse/replication/slave/storage/groups.py  9
-rw-r--r--  synapse/replication/slave/storage/presence.py  8
-rw-r--r--  synapse/replication/slave/storage/push_rule.py  6
-rw-r--r--  synapse/replication/slave/storage/pushers.py  4
-rw-r--r--  synapse/replication/slave/storage/receipts.py  1
-rw-r--r--  synapse/replication/slave/storage/room.py  4
-rw-r--r--  synapse/replication/tcp/client.py  6
-rw-r--r--  synapse/replication/tcp/commands.py  71
-rw-r--r--  synapse/replication/tcp/protocol.py  76
-rw-r--r--  synapse/replication/tcp/resource.py  54
-rw-r--r--  synapse/replication/tcp/streams/_base.py  160
-rw-r--r--  synapse/replication/tcp/streams/events.py  32
-rw-r--r--  synapse/replication/tcp/streams/federation.py  12
26 files changed, 357 insertions(+), 355 deletions(-)
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index 0a432a16fa..fe482e279f 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -83,8 +83,7 @@ class ReplicationEndpoint(object):
     def __init__(self, hs):
         if self.CACHE:
             self.response_cache = ResponseCache(
-                hs, "repl." + self.NAME,
-                timeout_ms=30 * 60 * 1000,
+                hs, "repl." + self.NAME, timeout_ms=30 * 60 * 1000
             )
 
         assert self.METHOD in ("PUT", "POST", "GET")
@@ -134,8 +133,7 @@ class ReplicationEndpoint(object):
             data = yield cls._serialize_payload(**kwargs)
 
             url_args = [
-                urllib.parse.quote(kwargs[name], safe='')
-                for name in cls.PATH_ARGS
+                urllib.parse.quote(kwargs[name], safe="") for name in cls.PATH_ARGS
             ]
 
             if cls.CACHE:
@@ -156,7 +154,10 @@ class ReplicationEndpoint(object):
                 )
 
             uri = "http://%s:%s/_synapse/replication/%s/%s" % (
-                host, port, cls.NAME, "/".join(url_args)
+                host,
+                port,
+                cls.NAME,
+                "/".join(url_args),
             )
 
             try:
@@ -202,10 +203,7 @@ class ReplicationEndpoint(object):
             url_args.append("txn_id")
 
         args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args)
-        pattern = re.compile("^/_synapse/replication/%s/%s$" % (
-            self.NAME,
-            args
-        ))
+        pattern = re.compile("^/_synapse/replication/%s/%s$" % (self.NAME, args))
 
         http_server.register_paths(method, [pattern], handler)
 
@@ -219,8 +217,4 @@ class ReplicationEndpoint(object):
 
         assert self.CACHE
 
-        return self.response_cache.wrap(
-            txn_id,
-            self._handle_request,
-            request, **kwargs
-        )
+        return self.response_cache.wrap(txn_id, self._handle_request, request, **kwargs)
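
Taken together, the _base.py hunks above spell out the ReplicationEndpoint contract: a subclass declares NAME, PATH_ARGS, METHOD and CACHE; _serialize_payload turns call arguments into a JSON body; and _handle_request services the resulting request at /_synapse/replication/<NAME>/<args>, with cached endpoints keyed on a txn_id for 30 minutes. A minimal sketch of a conforming subclass, assuming synapse's parse_json_object_from_request helper (the endpoint name, arguments and handler body are illustrative, not part of this commit):

    from twisted.internet import defer

    from synapse.http.servlet import parse_json_object_from_request
    from synapse.replication.http._base import ReplicationEndpoint

    class ReplicationExampleServlet(ReplicationEndpoint):
        """Hypothetical endpoint sketching the contract above."""

        NAME = "example"
        PATH_ARGS = ("room_id",)  # becomes /_synapse/replication/example/<room_id>
        METHOD = "POST"
        CACHE = True  # responses are cached against a txn_id for 30 minutes

        @staticmethod
        def _serialize_payload(room_id, reason):
            # Arguments not named in PATH_ARGS travel in the JSON body.
            return {"reason": reason}

        @defer.inlineCallbacks
        def _handle_request(self, request, room_id):
            content = parse_json_object_from_request(request)
            yield do_something(room_id, content["reason"])  # hypothetical helper
            defer.returnValue((200, {}))
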
diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py
index 0f0a07c422..61eafbe708 100644
--- a/synapse/replication/http/federation.py
+++ b/synapse/replication/http/federation.py
@@ -68,18 +68,17 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
         for event, context in event_and_contexts:
             serialized_context = yield context.serialize(event, store)
 
-            event_payloads.append({
-                "event": event.get_pdu_json(),
-                "event_format_version": event.format_version,
-                "internal_metadata": event.internal_metadata.get_dict(),
-                "rejected_reason": event.rejected_reason,
-                "context": serialized_context,
-            })
-
-        payload = {
-            "events": event_payloads,
-            "backfilled": backfilled,
-        }
+            event_payloads.append(
+                {
+                    "event": event.get_pdu_json(),
+                    "event_format_version": event.format_version,
+                    "internal_metadata": event.internal_metadata.get_dict(),
+                    "rejected_reason": event.rejected_reason,
+                    "context": serialized_context,
+                }
+            )
+
+        payload = {"events": event_payloads, "backfilled": backfilled}
 
         defer.returnValue(payload)
 
@@ -103,18 +102,15 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
                 event = EventType(event_dict, internal_metadata, rejected_reason)
 
                 context = yield EventContext.deserialize(
-                    self.store, event_payload["context"],
+                    self.store, event_payload["context"]
                 )
 
                 event_and_contexts.append((event, context))
 
-        logger.info(
-            "Got %d events from federation",
-            len(event_and_contexts),
-        )
+        logger.info("Got %d events from federation", len(event_and_contexts))
 
         yield self.federation_handler.persist_events_and_notify(
-            event_and_contexts, backfilled,
+            event_and_contexts, backfilled
         )
 
         defer.returnValue((200, {}))
@@ -146,10 +142,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
 
     @staticmethod
     def _serialize_payload(edu_type, origin, content):
-        return {
-            "origin": origin,
-            "content": content,
-        }
+        return {"origin": origin, "content": content}
 
     @defer.inlineCallbacks
     def _handle_request(self, request, edu_type):
@@ -159,10 +152,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
             origin = content["origin"]
             edu_content = content["content"]
 
-        logger.info(
-            "Got %r edu from %s",
-            edu_type, origin,
-        )
+        logger.info("Got %r edu from %s", edu_type, origin)
 
         result = yield self.registry.on_edu(edu_type, origin, edu_content)
 
@@ -201,9 +191,7 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint):
             query_type (str)
             args (dict): The arguments received for the given query type
         """
-        return {
-            "args": args,
-        }
+        return {"args": args}
 
     @defer.inlineCallbacks
     def _handle_request(self, request, query_type):
@@ -212,10 +200,7 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint):
 
             args = content["args"]
 
-        logger.info(
-            "Got %r query",
-            query_type,
-        )
+        logger.info("Got %r query", query_type)
 
         result = yield self.registry.on_query(query_type, args)
 
diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py
index 63bc0405ea..7c1197e5dd 100644
--- a/synapse/replication/http/login.py
+++ b/synapse/replication/http/login.py
@@ -61,13 +61,10 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
         is_guest = content["is_guest"]
 
         device_id, access_token = yield self.registration_handler.register_device(
-            user_id, device_id, initial_display_name, is_guest,
+            user_id, device_id, initial_display_name, is_guest
         )
 
-        defer.returnValue((200, {
-            "device_id": device_id,
-            "access_token": access_token,
-        }))
+        defer.returnValue((200, {"device_id": device_id, "access_token": access_token}))
 
 
 def register_servlets(hs, http_server):
diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py
index 81a2b204c7..0a76a3762f 100644
--- a/synapse/replication/http/membership.py
+++ b/synapse/replication/http/membership.py
@@ -40,7 +40,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
     """
 
     NAME = "remote_join"
-    PATH_ARGS = ("room_id", "user_id",)
+    PATH_ARGS = ("room_id", "user_id")
 
     def __init__(self, hs):
         super(ReplicationRemoteJoinRestServlet, self).__init__(hs)
@@ -50,8 +50,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
         self.clock = hs.get_clock()
 
     @staticmethod
-    def _serialize_payload(requester, room_id, user_id, remote_room_hosts,
-                           content):
+    def _serialize_payload(requester, room_id, user_id, remote_room_hosts, content):
         """
         Args:
             requester(Requester)
@@ -78,16 +77,10 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
         if requester.user:
             request.authenticated_entity = requester.user.to_string()
 
-        logger.info(
-            "remote_join: %s into room: %s",
-            user_id, room_id,
-        )
+        logger.info("remote_join: %s into room: %s", user_id, room_id)
 
         yield self.federation_handler.do_invite_join(
-            remote_room_hosts,
-            room_id,
-            user_id,
-            event_content,
+            remote_room_hosts, room_id, user_id, event_content
         )
 
         defer.returnValue((200, {}))
@@ -107,7 +100,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
     """
 
     NAME = "remote_reject_invite"
-    PATH_ARGS = ("room_id", "user_id",)
+    PATH_ARGS = ("room_id", "user_id")
 
     def __init__(self, hs):
         super(ReplicationRemoteRejectInviteRestServlet, self).__init__(hs)
@@ -141,16 +134,11 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
         if requester.user:
             request.authenticated_entity = requester.user.to_string()
 
-        logger.info(
-            "remote_reject_invite: %s out of room: %s",
-            user_id, room_id,
-        )
+        logger.info("remote_reject_invite: %s out of room: %s", user_id, room_id)
 
         try:
             event = yield self.federation_handler.do_remotely_reject_invite(
-                remote_room_hosts,
-                room_id,
-                user_id,
+                remote_room_hosts, room_id, user_id
             )
             ret = event.get_pdu_json()
         except Exception as e:
@@ -162,9 +150,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
             #
             logger.warn("Failed to reject invite: %s", e)
 
-            yield self.store.locally_reject_invite(
-                user_id, room_id
-            )
+            yield self.store.locally_reject_invite(user_id, room_id)
             ret = {}
 
         defer.returnValue((200, ret))
@@ -228,7 +214,7 @@ class ReplicationRegister3PIDGuestRestServlet(ReplicationEndpoint):
         logger.info("get_or_register_3pid_guest: %r", content)
 
         ret = yield self.registeration_handler.get_or_register_3pid_guest(
-            medium, address, inviter_user_id,
+            medium, address, inviter_user_id
         )
 
         defer.returnValue((200, ret))
@@ -264,7 +250,7 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
             user_id (str)
             change (str): Either "joined" or "left"
         """
-        assert change in ("joined", "left",)
+        assert change in ("joined", "left")
 
         return {}
 
diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py
index 912a5ac341..f81a0f1b8f 100644
--- a/synapse/replication/http/register.py
+++ b/synapse/replication/http/register.py
@@ -37,8 +37,16 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
 
     @staticmethod
     def _serialize_payload(
-        user_id, token, password_hash, was_guest, make_guest, appservice_id,
-        create_profile_with_displayname, admin, user_type, address,
+        user_id,
+        token,
+        password_hash,
+        was_guest,
+        make_guest,
+        appservice_id,
+        create_profile_with_displayname,
+        admin,
+        user_type,
+        address,
     ):
         """
         Args:
@@ -85,7 +93,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
             create_profile_with_displayname=content["create_profile_with_displayname"],
             admin=content["admin"],
             user_type=content["user_type"],
-            address=content["address"]
+            address=content["address"],
         )
 
         defer.returnValue((200, {}))
@@ -104,8 +112,7 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint):
         self.registration_handler = hs.get_registration_handler()
 
     @staticmethod
-    def _serialize_payload(user_id, auth_result, access_token, bind_email,
-                           bind_msisdn):
+    def _serialize_payload(user_id, auth_result, access_token, bind_email, bind_msisdn):
         """
         Args:
             user_id (str): The user ID that consented
diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py
index 3635015eda..034763fe99 100644
--- a/synapse/replication/http/send_event.py
+++ b/synapse/replication/http/send_event.py
@@ -45,6 +45,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
             "extra_users": [],
         }
     """
+
     NAME = "send_event"
     PATH_ARGS = ("event_id",)
 
@@ -57,8 +58,9 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
 
     @staticmethod
     @defer.inlineCallbacks
-    def _serialize_payload(event_id, store, event, context, requester,
-                           ratelimit, extra_users):
+    def _serialize_payload(
+        event_id, store, event, context, requester, ratelimit, extra_users
+    ):
         """
         Args:
             event_id (str)
@@ -108,14 +110,11 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
             request.authenticated_entity = requester.user.to_string()
 
         logger.info(
-            "Got event to send with ID: %s into room: %s",
-            event.event_id, event.room_id,
+            "Got event to send with ID: %s into room: %s", event.event_id, event.room_id
         )
 
         yield self.event_creation_handler.persist_and_notify_client_event(
-            requester, event, context,
-            ratelimit=ratelimit,
-            extra_users=extra_users,
+            requester, event, context, ratelimit=ratelimit, extra_users=extra_users
         )
 
         defer.returnValue((200, {}))
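
On the worker side, the send_request closure shown in the _base.py hunks performs the actual HTTP call; in the full module it is exposed via a make_client classmethod (assumed here, since it does not appear in this diff). A hedged sketch of invoking this endpoint, with keyword arguments matching the _serialize_payload signature above:

    send_event_to_master = ReplicationSendEventRestServlet.make_client(hs)

    yield send_event_to_master(
        event_id=event.event_id,
        store=store,
        event=event,
        context=context,
        requester=requester,
        ratelimit=True,
        extra_users=[],
    )
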
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index 817d1f67f9..182cb2a1d8 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -37,7 +37,7 @@ class BaseSlavedStore(SQLBaseStore):
         super(BaseSlavedStore, self).__init__(db_conn, hs)
         if isinstance(self.database_engine, PostgresEngine):
             self._cache_id_gen = SlavedIdTracker(
-                db_conn, "cache_invalidation_stream", "stream_id",
+                db_conn, "cache_invalidation_stream", "stream_id"
             )
         else:
             self._cache_id_gen = None
diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py
index d9ba6d69b1..3c44d1d48d 100644
--- a/synapse/replication/slave/storage/account_data.py
+++ b/synapse/replication/slave/storage/account_data.py
@@ -21,10 +21,9 @@ from synapse.storage.tags import TagsWorkerStore
 
 
 class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore):
-
     def __init__(self, db_conn, hs):
         self._account_data_id_gen = SlavedIdTracker(
-            db_conn, "account_data_max_stream_id", "stream_id",
+            db_conn, "account_data_max_stream_id", "stream_id"
         )
 
         super(SlavedAccountDataStore, self).__init__(db_conn, hs)
@@ -45,24 +44,20 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved
             self._account_data_id_gen.advance(token)
             for row in rows:
                 self.get_tags_for_user.invalidate((row.user_id,))
-                self._account_data_stream_cache.entity_has_changed(
-                    row.user_id, token
-                )
+                self._account_data_stream_cache.entity_has_changed(row.user_id, token)
         elif stream_name == "account_data":
             self._account_data_id_gen.advance(token)
             for row in rows:
                 if not row.room_id:
                     self.get_global_account_data_by_type_for_user.invalidate(
-                        (row.data_type, row.user_id,)
+                        (row.data_type, row.user_id)
                     )
                 self.get_account_data_for_user.invalidate((row.user_id,))
-                self.get_account_data_for_room.invalidate((row.user_id, row.room_id,))
+                self.get_account_data_for_room.invalidate((row.user_id, row.room_id))
                 self.get_account_data_for_room_and_type.invalidate(
-                    (row.user_id, row.room_id, row.data_type,),
-                )
-                self._account_data_stream_cache.entity_has_changed(
-                    row.user_id, token
+                    (row.user_id, row.room_id, row.data_type)
                 )
+                self._account_data_stream_cache.entity_has_changed(row.user_id, token)
         return super(SlavedAccountDataStore, self).process_replication_rows(
             stream_name, token, rows
         )
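
The same three-step pattern recurs in every slaved store touched by this commit: advance the SlavedIdTracker to the new token, invalidate whatever caches the replicated rows affect, then chain to the superclass so other mixins can handle their own streams. In outline (the stream name, tracker and cache attributes below are placeholders, not real synapse names):

    class SlavedExampleStore(BaseSlavedStore):
        def process_replication_rows(self, stream_name, token, rows):
            if stream_name == "example":
                # Record the new high-water mark for this stream ...
                self._example_id_gen.advance(token)
                for row in rows:
                    # ... and drop cache entries the rows invalidate.
                    self.get_example_for_user.invalidate((row.user_id,))
                    self._example_stream_cache.entity_has_changed(row.user_id, token)
            return super(SlavedExampleStore, self).process_replication_rows(
                stream_name, token, rows
            )
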
diff --git a/synapse/replication/slave/storage/appservice.py b/synapse/replication/slave/storage/appservice.py
index b53a4c6bd1..cda12ea70d 100644
--- a/synapse/replication/slave/storage/appservice.py
+++ b/synapse/replication/slave/storage/appservice.py
@@ -20,6 +20,7 @@ from synapse.storage.appservice import (
 )
 
 
-class SlavedApplicationServiceStore(ApplicationServiceTransactionWorkerStore,
-                                    ApplicationServiceWorkerStore):
+class SlavedApplicationServiceStore(
+    ApplicationServiceTransactionWorkerStore, ApplicationServiceWorkerStore
+):
     pass
diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py
index 5b8521c770..14ced32333 100644
--- a/synapse/replication/slave/storage/client_ips.py
+++ b/synapse/replication/slave/storage/client_ips.py
@@ -25,9 +25,7 @@ class SlavedClientIpStore(BaseSlavedStore):
         super(SlavedClientIpStore, self).__init__(db_conn, hs)
 
         self.client_ip_last_seen = Cache(
-            name="client_ip_last_seen",
-            keylen=4,
-            max_entries=50000 * CACHE_SIZE_FACTOR,
+            name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR
         )
 
     def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id):
diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
index 4d59778863..284fd30d89 100644
--- a/synapse/replication/slave/storage/deviceinbox.py
+++ b/synapse/replication/slave/storage/deviceinbox.py
@@ -24,15 +24,15 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
     def __init__(self, db_conn, hs):
         super(SlavedDeviceInboxStore, self).__init__(db_conn, hs)
         self._device_inbox_id_gen = SlavedIdTracker(
-            db_conn, "device_max_stream_id", "stream_id",
+            db_conn, "device_max_stream_id", "stream_id"
         )
         self._device_inbox_stream_cache = StreamChangeCache(
             "DeviceInboxStreamChangeCache",
-            self._device_inbox_id_gen.get_current_token()
+            self._device_inbox_id_gen.get_current_token(),
         )
         self._device_federation_outbox_stream_cache = StreamChangeCache(
             "DeviceFederationOutboxStreamChangeCache",
-            self._device_inbox_id_gen.get_current_token()
+            self._device_inbox_id_gen.get_current_token(),
         )
 
         self._last_device_delete_cache = ExpiringCache(
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index 16c9a162c5..d9300fce33 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -27,14 +27,14 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
         self.hs = hs
 
         self._device_list_id_gen = SlavedIdTracker(
-            db_conn, "device_lists_stream", "stream_id",
+            db_conn, "device_lists_stream", "stream_id"
         )
         device_list_max = self._device_list_id_gen.get_current_token()
         self._device_list_stream_cache = StreamChangeCache(
-            "DeviceListStreamChangeCache", device_list_max,
+            "DeviceListStreamChangeCache", device_list_max
         )
         self._device_list_federation_stream_cache = StreamChangeCache(
-            "DeviceListFederationStreamChangeCache", device_list_max,
+            "DeviceListFederationStreamChangeCache", device_list_max
         )
 
     def stream_positions(self):
@@ -46,17 +46,13 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
         if stream_name == "device_lists":
             self._device_list_id_gen.advance(token)
             for row in rows:
-                self._invalidate_caches_for_devices(
-                    token, row.user_id, row.destination,
-                )
+                self._invalidate_caches_for_devices(token, row.user_id, row.destination)
         return super(SlavedDeviceStore, self).process_replication_rows(
             stream_name, token, rows
         )
 
     def _invalidate_caches_for_devices(self, token, user_id, destination):
-        self._device_list_stream_cache.entity_has_changed(
-            user_id, token
-        )
+        self._device_list_stream_cache.entity_has_changed(user_id, token)
 
         if destination:
             self._device_list_federation_stream_cache.entity_has_changed(
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index a3952506c1..ab5937e638 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -45,21 +45,20 @@ logger = logging.getLogger(__name__)
 # the method descriptor on the DataStore and chuck them into our class.
 
 
-class SlavedEventStore(EventFederationWorkerStore,
-                       RoomMemberWorkerStore,
-                       EventPushActionsWorkerStore,
-                       StreamWorkerStore,
-                       StateGroupWorkerStore,
-                       EventsWorkerStore,
-                       SignatureWorkerStore,
-                       UserErasureWorkerStore,
-                       RelationsWorkerStore,
-                       BaseSlavedStore):
-
+class SlavedEventStore(
+    EventFederationWorkerStore,
+    RoomMemberWorkerStore,
+    EventPushActionsWorkerStore,
+    StreamWorkerStore,
+    StateGroupWorkerStore,
+    EventsWorkerStore,
+    SignatureWorkerStore,
+    UserErasureWorkerStore,
+    RelationsWorkerStore,
+    BaseSlavedStore,
+):
     def __init__(self, db_conn, hs):
-        self._stream_id_gen = SlavedIdTracker(
-            db_conn, "events", "stream_ordering",
-        )
+        self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering")
         self._backfill_id_gen = SlavedIdTracker(
             db_conn, "events", "stream_ordering", step=-1
         )
@@ -90,8 +89,13 @@ class SlavedEventStore(EventFederationWorkerStore,
             self._backfill_id_gen.advance(-token)
             for row in rows:
                 self.invalidate_caches_for_event(
-                    -token, row.event_id, row.room_id, row.type, row.state_key,
-                    row.redacts, row.relates_to,
+                    -token,
+                    row.event_id,
+                    row.room_id,
+                    row.type,
+                    row.state_key,
+                    row.redacts,
+                    row.relates_to,
                     backfilled=True,
                 )
         return super(SlavedEventStore, self).process_replication_rows(
@@ -103,41 +107,48 @@ class SlavedEventStore(EventFederationWorkerStore,
 
         if row.type == EventsStreamEventRow.TypeId:
             self.invalidate_caches_for_event(
-                token, data.event_id, data.room_id, data.type, data.state_key,
-                data.redacts, data.relates_to,
+                token,
+                data.event_id,
+                data.room_id,
+                data.type,
+                data.state_key,
+                data.redacts,
+                data.relates_to,
                 backfilled=False,
             )
         elif row.type == EventsStreamCurrentStateRow.TypeId:
             if data.type == EventTypes.Member:
                 self.get_rooms_for_user_with_stream_ordering.invalidate(
-                    (data.state_key, ),
+                    (data.state_key,)
                 )
         else:
-            raise Exception("Unknown events stream row type %s" % (row.type, ))
-
-    def invalidate_caches_for_event(self, stream_ordering, event_id, room_id,
-                                    etype, state_key, redacts, relates_to,
-                                    backfilled):
+            raise Exception("Unknown events stream row type %s" % (row.type,))
+
+    def invalidate_caches_for_event(
+        self,
+        stream_ordering,
+        event_id,
+        room_id,
+        etype,
+        state_key,
+        redacts,
+        relates_to,
+        backfilled,
+    ):
         self._invalidate_get_event_cache(event_id)
 
         self.get_latest_event_ids_in_room.invalidate((room_id,))
 
-        self.get_unread_event_push_actions_by_room_for_user.invalidate_many(
-            (room_id,)
-        )
+        self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,))
 
         if not backfilled:
-            self._events_stream_cache.entity_has_changed(
-                room_id, stream_ordering
-            )
+            self._events_stream_cache.entity_has_changed(room_id, stream_ordering)
 
         if redacts:
             self._invalidate_get_event_cache(redacts)
 
         if etype == EventTypes.Member:
-            self._membership_stream_cache.entity_has_changed(
-                state_key, stream_ordering
-            )
+            self._membership_stream_cache.entity_has_changed(state_key, stream_ordering)
             self.get_invited_rooms_for_user.invalidate((state_key,))
 
         if relates_to:
diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py
index e933b170bb..28a46edd28 100644
--- a/synapse/replication/slave/storage/groups.py
+++ b/synapse/replication/slave/storage/groups.py
@@ -27,10 +27,11 @@ class SlavedGroupServerStore(BaseSlavedStore):
         self.hs = hs
 
         self._group_updates_id_gen = SlavedIdTracker(
-            db_conn, "local_group_updates", "stream_id",
+            db_conn, "local_group_updates", "stream_id"
         )
         self._group_updates_stream_cache = StreamChangeCache(
-            "_group_updates_stream_cache", self._group_updates_id_gen.get_current_token(),
+            "_group_updates_stream_cache",
+            self._group_updates_id_gen.get_current_token(),
         )
 
     get_groups_changes_for_user = __func__(DataStore.get_groups_changes_for_user)
@@ -46,9 +47,7 @@ class SlavedGroupServerStore(BaseSlavedStore):
         if stream_name == "groups":
             self._group_updates_id_gen.advance(token)
             for row in rows:
-                self._group_updates_stream_cache.entity_has_changed(
-                    row.user_id, token
-                )
+                self._group_updates_stream_cache.entity_has_changed(row.user_id, token)
 
         return super(SlavedGroupServerStore, self).process_replication_rows(
             stream_name, token, rows
diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py
index 0ec1db25ce..82d808af4c 100644
--- a/synapse/replication/slave/storage/presence.py
+++ b/synapse/replication/slave/storage/presence.py
@@ -24,9 +24,7 @@ from ._slaved_id_tracker import SlavedIdTracker
 class SlavedPresenceStore(BaseSlavedStore):
     def __init__(self, db_conn, hs):
         super(SlavedPresenceStore, self).__init__(db_conn, hs)
-        self._presence_id_gen = SlavedIdTracker(
-            db_conn, "presence_stream", "stream_id",
-        )
+        self._presence_id_gen = SlavedIdTracker(db_conn, "presence_stream", "stream_id")
 
         self._presence_on_startup = self._get_active_presence(db_conn)
 
@@ -55,9 +53,7 @@ class SlavedPresenceStore(BaseSlavedStore):
         if stream_name == "presence":
             self._presence_id_gen.advance(token)
             for row in rows:
-                self.presence_stream_cache.entity_has_changed(
-                    row.user_id, token
-                )
+                self.presence_stream_cache.entity_has_changed(row.user_id, token)
                 self._get_presence_for_user.invalidate((row.user_id,))
         return super(SlavedPresenceStore, self).process_replication_rows(
             stream_name, token, rows
diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index 45fc913c52..af7012702e 100644
--- a/synapse/replication/slave/storage/push_rule.py
+++ b/synapse/replication/slave/storage/push_rule.py
@@ -23,7 +23,7 @@ from .events import SlavedEventStore
 class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
     def __init__(self, db_conn, hs):
         self._push_rules_stream_id_gen = SlavedIdTracker(
-            db_conn, "push_rules_stream", "stream_id",
+            db_conn, "push_rules_stream", "stream_id"
         )
         super(SlavedPushRuleStore, self).__init__(db_conn, hs)
 
@@ -47,9 +47,7 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
             for row in rows:
                 self.get_push_rules_for_user.invalidate((row.user_id,))
                 self.get_push_rules_enabled_for_user.invalidate((row.user_id,))
-                self.push_rules_stream_cache.entity_has_changed(
-                    row.user_id, token
-                )
+                self.push_rules_stream_cache.entity_has_changed(row.user_id, token)
         return super(SlavedPushRuleStore, self).process_replication_rows(
             stream_name, token, rows
         )
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index 3b2213c0d4..8eeb267d61 100644
--- a/synapse/replication/slave/storage/pushers.py
+++ b/synapse/replication/slave/storage/pushers.py
@@ -21,12 +21,10 @@ from ._slaved_id_tracker import SlavedIdTracker
 
 
 class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
-
     def __init__(self, db_conn, hs):
         super(SlavedPusherStore, self).__init__(db_conn, hs)
         self._pushers_id_gen = SlavedIdTracker(
-            db_conn, "pushers", "id",
-            extra_tables=[("deleted_pushers", "stream_id")],
+            db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
         )
 
     def stream_positions(self):
diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py
index ed12342f40..91afa5a72b 100644
--- a/synapse/replication/slave/storage/receipts.py
+++ b/synapse/replication/slave/storage/receipts.py
@@ -29,7 +29,6 @@ from ._slaved_id_tracker import SlavedIdTracker
 
 
 class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
-
     def __init__(self, db_conn, hs):
         # We instantiate this first as the ReceiptsWorkerStore constructor
         # needs to be able to call get_max_receipt_stream_id
diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py
index 0cb474928c..f68b3378e3 100644
--- a/synapse/replication/slave/storage/room.py
+++ b/synapse/replication/slave/storage/room.py
@@ -38,6 +38,4 @@ class RoomStore(RoomWorkerStore, BaseSlavedStore):
         if stream_name == "public_rooms":
             self._public_room_id_gen.advance(token)
 
-        return super(RoomStore, self).process_replication_rows(
-            stream_name, token, rows
-        )
+        return super(RoomStore, self).process_replication_rows(stream_name, token, rows)
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 206dc3b397..a44ceb00e7 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -39,6 +39,7 @@ class ReplicationClientFactory(ReconnectingClientFactory):
     Accepts a handler that will be called when new data is available or data
     is required.
     """
+
     maxDelay = 30  # Try at least once every N seconds
 
     def __init__(self, hs, client_name, handler):
@@ -64,9 +65,7 @@ class ReplicationClientFactory(ReconnectingClientFactory):
 
     def clientConnectionFailed(self, connector, reason):
         logger.error("Failed to connect to replication: %r", reason)
-        ReconnectingClientFactory.clientConnectionFailed(
-            self, connector, reason
-        )
+        ReconnectingClientFactory.clientConnectionFailed(self, connector, reason)
 
 
 class ReplicationClientHandler(object):
@@ -74,6 +73,7 @@ class ReplicationClientHandler(object):
 
     By default proxies incoming replication data to the SlaveStore.
     """
+
     def __init__(self, store):
         self.store = store
 
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index 2098c32a77..0ff2a7199f 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -23,9 +23,11 @@ import platform
 
 if platform.python_implementation() == "PyPy":
     import json
+
     _json_encoder = json.JSONEncoder()
 else:
     import simplejson as json
+
     _json_encoder = json.JSONEncoder(namedtuple_as_object=False)
 
 logger = logging.getLogger(__name__)
@@ -41,6 +43,7 @@ class Command(object):
 
     The default implementation creates a command of form `<NAME> <data>`
     """
+
     NAME = None
 
     def __init__(self, data):
@@ -73,6 +76,7 @@ class ServerCommand(Command):
 
         SERVER <server_name>
     """
+
     NAME = "SERVER"
 
 
@@ -99,6 +103,7 @@ class RdataCommand(Command):
         RDATA presence batch ["@bar:example.com", "online", ...]
         RDATA presence 59 ["@baz:example.com", "online", ...]
     """
+
     NAME = "RDATA"
 
     def __init__(self, stream_name, token, row):
@@ -110,17 +115,17 @@ class RdataCommand(Command):
     def from_line(cls, line):
         stream_name, token, row_json = line.split(" ", 2)
         return cls(
-            stream_name,
-            None if token == "batch" else int(token),
-            json.loads(row_json)
+            stream_name, None if token == "batch" else int(token), json.loads(row_json)
         )
 
     def to_line(self):
-        return " ".join((
-            self.stream_name,
-            str(self.token) if self.token is not None else "batch",
-            _json_encoder.encode(self.row),
-        ))
+        return " ".join(
+            (
+                self.stream_name,
+                str(self.token) if self.token is not None else "batch",
+                _json_encoder.encode(self.row),
+            )
+        )
 
     def get_logcontext_id(self):
         return "RDATA-" + self.stream_name
@@ -133,6 +138,7 @@ class PositionCommand(Command):
     Sent to the client after all missing updates for a stream have been sent
     to the client and they're now up to date.
     """
+
     NAME = "POSITION"
 
     def __init__(self, stream_name, token):
@@ -145,19 +151,21 @@ class PositionCommand(Command):
         return cls(stream_name, int(token))
 
     def to_line(self):
-        return " ".join((self.stream_name, str(self.token),))
+        return " ".join((self.stream_name, str(self.token)))
 
 
 class ErrorCommand(Command):
     """Sent by either side if there was an ERROR. The data is a string describing
     the error.
     """
+
     NAME = "ERROR"
 
 
 class PingCommand(Command):
     """Sent by either side as a keep alive. The data is arbitary (often timestamp)
     """
+
     NAME = "PING"
 
 
@@ -165,6 +173,7 @@ class NameCommand(Command):
     """Sent by client to inform the server of the client's identity. The data
     is the name
     """
+
     NAME = "NAME"
 
 
@@ -184,6 +193,7 @@ class ReplicateCommand(Command):
 
         REPLICATE ALL NOW
     """
+
     NAME = "REPLICATE"
 
     def __init__(self, stream_name, token):
@@ -200,7 +210,7 @@ class ReplicateCommand(Command):
         return cls(stream_name, token)
 
     def to_line(self):
-        return " ".join((self.stream_name, str(self.token),))
+        return " ".join((self.stream_name, str(self.token)))
 
     def get_logcontext_id(self):
         return "REPLICATE-" + self.stream_name
@@ -218,6 +228,7 @@ class UserSyncCommand(Command):
 
     Where <state> is either "start" or "stop"
     """
+
     NAME = "USER_SYNC"
 
     def __init__(self, user_id, is_syncing, last_sync_ms):
@@ -235,9 +246,13 @@ class UserSyncCommand(Command):
         return cls(user_id, state == "start", int(last_sync_ms))
 
     def to_line(self):
-        return " ".join((
-            self.user_id, "start" if self.is_syncing else "end", str(self.last_sync_ms),
-        ))
+        return " ".join(
+            (
+                self.user_id,
+                "start" if self.is_syncing else "end",
+                str(self.last_sync_ms),
+            )
+        )
 
 
 class FederationAckCommand(Command):
@@ -251,6 +266,7 @@ class FederationAckCommand(Command):
 
         FEDERATION_ACK <token>
     """
+
     NAME = "FEDERATION_ACK"
 
     def __init__(self, token):
@@ -268,6 +284,7 @@ class SyncCommand(Command):
     """Used for testing. The client protocol implementation allows waiting
     on a SYNC command with a specified data.
     """
+
     NAME = "SYNC"
 
 
@@ -278,6 +295,7 @@ class RemovePusherCommand(Command):
 
         REMOVE_PUSHER <app_id> <push_key> <user_id>
     """
+
     NAME = "REMOVE_PUSHER"
 
     def __init__(self, app_id, push_key, user_id):
@@ -309,6 +327,7 @@ class InvalidateCacheCommand(Command):
 
     Where <keys_json> is a json list.
     """
+
     NAME = "INVALIDATE_CACHE"
 
     def __init__(self, cache_func, keys):
@@ -322,9 +341,7 @@ class InvalidateCacheCommand(Command):
         return cls(cache_func, json.loads(keys_json))
 
     def to_line(self):
-        return " ".join((
-            self.cache_func, _json_encoder.encode(self.keys),
-        ))
+        return " ".join((self.cache_func, _json_encoder.encode(self.keys)))
 
 
 class UserIpCommand(Command):
@@ -334,6 +351,7 @@ class UserIpCommand(Command):
 
         USER_IP <user_id>, <access_token>, <ip>, <device_id>, <last_seen>, <user_agent>
     """
+
     NAME = "USER_IP"
 
     def __init__(self, user_id, access_token, ip, user_agent, device_id, last_seen):
@@ -350,15 +368,22 @@ class UserIpCommand(Command):
 
         access_token, ip, user_agent, device_id, last_seen = json.loads(jsn)
 
-        return cls(
-            user_id, access_token, ip, user_agent, device_id, last_seen
-        )
+        return cls(user_id, access_token, ip, user_agent, device_id, last_seen)
 
     def to_line(self):
-        return self.user_id + " " + _json_encoder.encode((
-            self.access_token, self.ip, self.user_agent, self.device_id,
-            self.last_seen,
-        ))
+        return (
+            self.user_id
+            + " "
+            + _json_encoder.encode(
+                (
+                    self.access_token,
+                    self.ip,
+                    self.user_agent,
+                    self.device_id,
+                    self.last_seen,
+                )
+            )
+        )
 
 
 # Map of command name to command type.
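
The from_line/to_line pairs above define the wire format, and the comment ending this file introduces COMMAND_MAP, the name-to-class registry. Decoding an inbound line therefore reduces to something like the following sketch (error handling omitted):

    def parse_command_line(line):
        # Wire format is "<NAME> <data>", per the Command docstring.
        cmd_name, _, rest = line.partition(" ")
        cmd_cls = COMMAND_MAP[cmd_name]
        return cmd_cls.from_line(rest)

    cmd = parse_command_line("POSITION events 153")
    # -> PositionCommand(stream_name="events", token=153)
    assert cmd.to_line() == "events 153"
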
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index b51590cf8f..97efb835ad 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -84,7 +84,8 @@ from .commands import (
 from .streams import STREAMS_MAP
 
 connection_close_counter = Counter(
-    "synapse_replication_tcp_protocol_close_reason", "", ["reason_type"])
+    "synapse_replication_tcp_protocol_close_reason", "", ["reason_type"]
+)
 
 # A list of all connected protocols. This allows us to send metrics about the
 # connections.
@@ -119,7 +120,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
     It also sends `PING` periodically, and correctly times out remote connections
     (if they send a `PING` command)
     """
-    delimiter = b'\n'
+
+    delimiter = b"\n"
 
     VALID_INBOUND_COMMANDS = []  # Valid commands we expect to receive
 VALID_OUTBOUND_COMMANDS = []  # Valid commands we can send
@@ -183,10 +185,14 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
             if now - self.last_sent_command >= PING_TIME:
                 self.send_command(PingCommand(now))
 
-            if self.received_ping and now - self.last_received_command > PING_TIMEOUT_MS:
+            if (
+                self.received_ping
+                and now - self.last_received_command > PING_TIMEOUT_MS
+            ):
                 logger.info(
                     "[%s] Connection hasn't received command in %r ms. Closing.",
-                    self.id(), now - self.last_received_command
+                    self.id(),
+                    now - self.last_received_command,
                 )
                 self.send_error("ping timeout")
 
@@ -208,7 +214,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         self.last_received_command = self.clock.time_msec()
 
         self.inbound_commands_counter[cmd_name] = (
-            self.inbound_commands_counter[cmd_name] + 1)
+            self.inbound_commands_counter[cmd_name] + 1
+        )
 
         cmd_cls = COMMAND_MAP[cmd_name]
         try:
@@ -224,9 +231,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
 
         # Now lets try and call on_<CMD_NAME> function
         run_as_background_process(
-            "replication-" + cmd.get_logcontext_id(),
-            self.handle_command,
-            cmd,
+            "replication-" + cmd.get_logcontext_id(), self.handle_command, cmd
         )
 
     def handle_command(self, cmd):
@@ -274,8 +279,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
             return
 
         self.outbound_commands_counter[cmd.NAME] = (
-            self.outbound_commands_counter[cmd.NAME] + 1)
-        string = "%s %s" % (cmd.NAME, cmd.to_line(),)
+            self.outbound_commands_counter[cmd.NAME] + 1
+        )
+        string = "%s %s" % (cmd.NAME, cmd.to_line())
         if "\n" in string:
             raise Exception("Unexpected newline in command: %r", string)
 
@@ -283,10 +289,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
 
         if len(encoded_string) > self.MAX_LENGTH:
             raise Exception(
-                "Failed to send command %s as too long (%d > %d)" % (
-                    cmd.NAME,
-                    len(encoded_string), self.MAX_LENGTH,
-                )
+                "Failed to send command %s as too long (%d > %d)"
+                % (cmd.NAME, len(encoded_string), self.MAX_LENGTH)
             )
 
         self.sendLine(encoded_string)
@@ -379,7 +383,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         if self.transport:
             addr = str(self.transport.getPeer())
         return "ReplicationConnection<name=%s,conn_id=%s,addr=%s>" % (
-            self.name, self.conn_id, addr,
+            self.name,
+            self.conn_id,
+            addr,
         )
 
     def id(self):
@@ -422,7 +428,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
 
     def on_USER_SYNC(self, cmd):
         return self.streamer.on_user_sync(
-            self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms,
+            self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
         )
 
     def on_REPLICATE(self, cmd):
@@ -432,10 +438,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
         if stream_name == "ALL":
             # Subscribe to all streams we're publishing to.
             deferreds = [
-                run_in_background(
-                    self.subscribe_to_stream,
-                    stream, token,
-                )
+                run_in_background(self.subscribe_to_stream, stream, token)
                 for stream in iterkeys(self.streamer.streams_by_name)
             ]
 
@@ -449,16 +452,18 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
         return self.streamer.federation_ack(cmd.token)
 
     def on_REMOVE_PUSHER(self, cmd):
-        return self.streamer.on_remove_pusher(
-            cmd.app_id, cmd.push_key, cmd.user_id,
-        )
+        return self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id)
 
     def on_INVALIDATE_CACHE(self, cmd):
         return self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
 
     def on_USER_IP(self, cmd):
         return self.streamer.on_user_ip(
-            cmd.user_id, cmd.access_token, cmd.ip, cmd.user_agent, cmd.device_id,
+            cmd.user_id,
+            cmd.access_token,
+            cmd.ip,
+            cmd.user_agent,
+            cmd.device_id,
             cmd.last_seen,
         )
 
@@ -476,7 +481,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
         try:
             # Get missing updates
             updates, current_token = yield self.streamer.get_stream_updates(
-                stream_name, token,
+                stream_name, token
             )
 
             # Send all the missing updates
@@ -608,8 +613,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
             row = STREAMS_MAP[stream_name].parse_row(cmd.row)
         except Exception:
             logger.exception(
-                "[%s] Failed to parse RDATA: %r %r",
-                self.id(), stream_name, cmd.row
+                "[%s] Failed to parse RDATA: %r %r", self.id(), stream_name, cmd.row
             )
             raise
 
@@ -643,7 +647,9 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
 
         logger.info(
             "[%s] Subscribing to replication stream: %r from %r",
-            self.id(), stream_name, token
+            self.id(),
+            stream_name,
+            token,
         )
 
         self.streams_connecting.add(stream_name)
@@ -661,9 +667,7 @@ pending_commands = LaterGauge(
     "synapse_replication_tcp_protocol_pending_commands",
     "",
     ["name"],
-    lambda: {
-        (p.name,): len(p.pending_commands) for p in connected_connections
-    },
+    lambda: {(p.name,): len(p.pending_commands) for p in connected_connections},
 )
 
 
@@ -678,9 +682,7 @@ transport_send_buffer = LaterGauge(
     "synapse_replication_tcp_protocol_transport_send_buffer",
     "",
     ["name"],
-    lambda: {
-        (p.name,): transport_buffer_size(p) for p in connected_connections
-    },
+    lambda: {(p.name,): transport_buffer_size(p) for p in connected_connections},
 )
 
 
@@ -694,7 +696,7 @@ def transport_kernel_read_buffer_size(protocol, read=True):
             op = SIOCINQ
         else:
             op = SIOCOUTQ
-        size = struct.unpack("I", fcntl.ioctl(fileno, op, '\0\0\0\0'))[0]
+        size = struct.unpack("I", fcntl.ioctl(fileno, op, "\0\0\0\0"))[0]
         return size
     return 0
 
@@ -726,7 +728,7 @@ tcp_inbound_commands = LaterGauge(
     "",
     ["command", "name"],
     lambda: {
-        (k, p.name,): count
+        (k, p.name): count
         for p in connected_connections
         for k, count in iteritems(p.inbound_commands_counter)
     },
@@ -737,7 +739,7 @@ tcp_outbound_commands = LaterGauge(
     "",
     ["command", "name"],
     lambda: {
-        (k, p.name,): count
+        (k, p.name): count
         for p in connected_connections
         for k, count in iteritems(p.outbound_commands_counter)
     },
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index f6a38f5140..d1e98428bc 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -33,13 +33,15 @@ from .protocol import ServerReplicationStreamProtocol
 from .streams import STREAMS_MAP
 from .streams.federation import FederationStream
 
-stream_updates_counter = Counter("synapse_replication_tcp_resource_stream_updates",
-                                 "", ["stream_name"])
+stream_updates_counter = Counter(
+    "synapse_replication_tcp_resource_stream_updates", "", ["stream_name"]
+)
 user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "")
 federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "")
 remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "")
-invalidate_cache_counter = Counter("synapse_replication_tcp_resource_invalidate_cache",
-                                   "")
+invalidate_cache_counter = Counter(
+    "synapse_replication_tcp_resource_invalidate_cache", ""
+)
 user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "")
 
 logger = logging.getLogger(__name__)
@@ -48,6 +50,7 @@ logger = logging.getLogger(__name__)
 class ReplicationStreamProtocolFactory(Factory):
     """Factory for new replication connections.
     """
+
     def __init__(self, hs):
         self.streamer = ReplicationStreamer(hs)
         self.clock = hs.get_clock()
@@ -55,9 +58,7 @@ class ReplicationStreamProtocolFactory(Factory):
 
     def buildProtocol(self, addr):
         return ServerReplicationStreamProtocol(
-            self.server_name,
-            self.clock,
-            self.streamer,
+            self.server_name, self.clock, self.streamer
         )
 
 
@@ -80,29 +81,39 @@ class ReplicationStreamer(object):
         # Current connections.
         self.connections = []
 
-        LaterGauge("synapse_replication_tcp_resource_total_connections", "", [],
-                   lambda: len(self.connections))
+        LaterGauge(
+            "synapse_replication_tcp_resource_total_connections",
+            "",
+            [],
+            lambda: len(self.connections),
+        )
 
         # List of streams that clients can subscribe to.
         # We only support federation stream if federation sending has been
         # disabled on the master.
         self.streams = [
-            stream(hs) for stream in itervalues(STREAMS_MAP)
+            stream(hs)
+            for stream in itervalues(STREAMS_MAP)
             if stream != FederationStream or not hs.config.send_federation
         ]
 
         self.streams_by_name = {stream.NAME: stream for stream in self.streams}
 
         LaterGauge(
-            "synapse_replication_tcp_resource_connections_per_stream", "",
+            "synapse_replication_tcp_resource_connections_per_stream",
+            "",
             ["stream_name"],
             lambda: {
-                (stream_name,): len([
-                    conn for conn in self.connections
-                    if stream_name in conn.replication_streams
-                ])
+                (stream_name,): len(
+                    [
+                        conn
+                        for conn in self.connections
+                        if stream_name in conn.replication_streams
+                    ]
+                )
                 for stream_name in self.streams_by_name
-            })
+            },
+        )
 
         self.federation_sender = None
         if not hs.config.send_federation:
@@ -179,7 +190,9 @@ class ReplicationStreamer(object):
 
                         logger.debug(
                             "Getting stream: %s: %s -> %s",
-                            stream.NAME, stream.last_token, stream.upto_token
+                            stream.NAME,
+                            stream.last_token,
+                            stream.upto_token,
                         )
                         try:
                             updates, current_token = yield stream.get_updates()
@@ -189,7 +202,8 @@ class ReplicationStreamer(object):
 
                         logger.debug(
                             "Sending %d updates to %d connections",
-                            len(updates), len(self.connections),
+                            len(updates),
+                            len(self.connections),
                         )
 
                         if updates:
@@ -243,7 +257,7 @@ class ReplicationStreamer(object):
         """
         user_sync_counter.inc()
         yield self.presence_handler.update_external_syncs_row(
-            conn_id, user_id, is_syncing, last_sync_ms,
+            conn_id, user_id, is_syncing, last_sync_ms
         )
 
     @measure_func("repl.on_remove_pusher")
@@ -272,7 +286,7 @@ class ReplicationStreamer(object):
         """
         user_ip_cache_counter.inc()
         yield self.store.insert_client_ip(
-            user_id, access_token, ip, user_agent, device_id, last_seen,
+            user_id, access_token, ip, user_agent, device_id, last_seen
         )
         yield self._server_notices_sender.on_user_ip(user_id)
 
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index b6ce7a7bee..7ef67a5a73 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -26,78 +26,75 @@ logger = logging.getLogger(__name__)
 
 MAX_EVENTS_BEHIND = 10000
 
-BackfillStreamRow = namedtuple("BackfillStreamRow", (
-    "event_id",  # str
-    "room_id",  # str
-    "type",  # str
-    "state_key",  # str, optional
-    "redacts",  # str, optional
-    "relates_to",  # str, optional
-))
-PresenceStreamRow = namedtuple("PresenceStreamRow", (
-    "user_id",  # str
-    "state",  # str
-    "last_active_ts",  # int
-    "last_federation_update_ts",  # int
-    "last_user_sync_ts",  # int
-    "status_msg",   # str
-    "currently_active",  # bool
-))
-TypingStreamRow = namedtuple("TypingStreamRow", (
-    "room_id",  # str
-    "user_ids",  # list(str)
-))
-ReceiptsStreamRow = namedtuple("ReceiptsStreamRow", (
-    "room_id",  # str
-    "receipt_type",  # str
-    "user_id",  # str
-    "event_id",  # str
-    "data",  # dict
-))
-PushRulesStreamRow = namedtuple("PushRulesStreamRow", (
-    "user_id",  # str
-))
-PushersStreamRow = namedtuple("PushersStreamRow", (
-    "user_id",  # str
-    "app_id",  # str
-    "pushkey",  # str
-    "deleted",  # bool
-))
-CachesStreamRow = namedtuple("CachesStreamRow", (
-    "cache_func",  # str
-    "keys",  # list(str)
-    "invalidation_ts",  # int
-))
-PublicRoomsStreamRow = namedtuple("PublicRoomsStreamRow", (
-    "room_id",  # str
-    "visibility",  # str
-    "appservice_id",  # str, optional
-    "network_id",  # str, optional
-))
-DeviceListsStreamRow = namedtuple("DeviceListsStreamRow", (
-    "user_id",  # str
-    "destination",  # str
-))
-ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", (
-    "entity",  # str
-))
-TagAccountDataStreamRow = namedtuple("TagAccountDataStreamRow", (
-    "user_id",  # str
-    "room_id",  # str
-    "data",  # dict
-))
-AccountDataStreamRow = namedtuple("AccountDataStream", (
-    "user_id",  # str
-    "room_id",  # str
-    "data_type",  # str
-    "data",  # dict
-))
-GroupsStreamRow = namedtuple("GroupsStreamRow", (
-    "group_id",  # str
-    "user_id",  # str
-    "type",  # str
-    "content",  # dict
-))
+BackfillStreamRow = namedtuple(
+    "BackfillStreamRow",
+    (
+        "event_id",  # str
+        "room_id",  # str
+        "type",  # str
+        "state_key",  # str, optional
+        "redacts",  # str, optional
+        "relates_to",  # str, optional
+    ),
+)
+PresenceStreamRow = namedtuple(
+    "PresenceStreamRow",
+    (
+        "user_id",  # str
+        "state",  # str
+        "last_active_ts",  # int
+        "last_federation_update_ts",  # int
+        "last_user_sync_ts",  # int
+        "status_msg",  # str
+        "currently_active",  # bool
+    ),
+)
+TypingStreamRow = namedtuple(
+    "TypingStreamRow", ("room_id", "user_ids")  # str  # list(str)
+)
+ReceiptsStreamRow = namedtuple(
+    "ReceiptsStreamRow",
+    (
+        "room_id",  # str
+        "receipt_type",  # str
+        "user_id",  # str
+        "event_id",  # str
+        "data",  # dict
+    ),
+)
+PushRulesStreamRow = namedtuple("PushRulesStreamRow", ("user_id",))  # str
+PushersStreamRow = namedtuple(
+    "PushersStreamRow",
+    ("user_id", "app_id", "pushkey", "deleted"),  # str  # str  # str  # bool
+)
+CachesStreamRow = namedtuple(
+    "CachesStreamRow",
+    ("cache_func", "keys", "invalidation_ts"),  # str  # list(str)  # int
+)
+PublicRoomsStreamRow = namedtuple(
+    "PublicRoomsStreamRow",
+    (
+        "room_id",  # str
+        "visibility",  # str
+        "appservice_id",  # str, optional
+        "network_id",  # str, optional
+    ),
+)
+DeviceListsStreamRow = namedtuple(
+    "DeviceListsStreamRow", ("user_id", "destination")  # str  # str
+)
+ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",))  # str
+TagAccountDataStreamRow = namedtuple(
+    "TagAccountDataStreamRow", ("user_id", "room_id", "data")  # str  # str  # dict
+)
+AccountDataStreamRow = namedtuple(
+    "AccountDataStream",
+    ("user_id", "room_id", "data_type", "data"),  # str  # str  # str  # dict
+)
+GroupsStreamRow = namedtuple(
+    "GroupsStreamRow",
+    ("group_id", "user_id", "type", "content"),  # str  # str  # str  # dict
+)
 
 
 class Stream(object):
@@ -106,6 +103,7 @@ class Stream(object):
     Provides a `get_updates()` function that returns new updates since the last
     time it was called up until the point `advance_current_token` was called.
     """
+
     NAME = None  # The name of the stream
     ROW_TYPE = None  # The type of the row. Used by the default impl of parse_row.
     _LIMITED = True  # Whether the update function takes a limit
@@ -185,16 +183,13 @@ class Stream(object):
 
         if self._LIMITED:
             rows = yield self.update_function(
-                from_token, current_token,
-                limit=MAX_EVENTS_BEHIND + 1,
+                from_token, current_token, limit=MAX_EVENTS_BEHIND + 1
             )
 
             # never turn more than MAX_EVENTS_BEHIND + 1 into updates.
             rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1)
         else:
-            rows = yield self.update_function(
-                from_token, current_token,
-            )
+            rows = yield self.update_function(from_token, current_token)
 
         updates = [(row[0], row[1:]) for row in rows]
 
@@ -230,6 +225,7 @@ class BackfillStream(Stream):
     """We fetched some old events and either we had never seen that event before
     or it went from being an outlier to not.
     """
+
     NAME = "backfill"
     ROW_TYPE = BackfillStreamRow
 
@@ -286,6 +282,7 @@ class ReceiptsStream(Stream):
 class PushRulesStream(Stream):
     """A user has changed their push rules
     """
+
     NAME = "push_rules"
     ROW_TYPE = PushRulesStreamRow
 
@@ -306,6 +303,7 @@ class PushRulesStream(Stream):
 class PushersStream(Stream):
     """A user has added/changed/removed a pusher
     """
+
     NAME = "pushers"
     ROW_TYPE = PushersStreamRow
 
@@ -322,6 +320,7 @@ class CachesStream(Stream):
     """A cache was invalidated on the master and no other stream would invalidate
     the cache on the workers
     """
+
     NAME = "caches"
     ROW_TYPE = CachesStreamRow
 
@@ -337,6 +336,7 @@ class CachesStream(Stream):
 class PublicRoomsStream(Stream):
     """The public rooms list changed
     """
+
     NAME = "public_rooms"
     ROW_TYPE = PublicRoomsStreamRow
 
@@ -352,6 +352,7 @@ class PublicRoomsStream(Stream):
 class DeviceListsStream(Stream):
     """Someone added/changed/removed a device
     """
+
     NAME = "device_lists"
     _LIMITED = False
     ROW_TYPE = DeviceListsStreamRow
@@ -368,6 +369,7 @@ class DeviceListsStream(Stream):
 class ToDeviceStream(Stream):
     """New to_device messages for a client
     """
+
     NAME = "to_device"
     ROW_TYPE = ToDeviceStreamRow
 
@@ -383,6 +385,7 @@ class ToDeviceStream(Stream):
 class TagAccountDataStream(Stream):
     """Someone added/removed a tag for a room
     """
+
     NAME = "tag_account_data"
     ROW_TYPE = TagAccountDataStreamRow
 
@@ -398,6 +401,7 @@ class TagAccountDataStream(Stream):
 class AccountDataStream(Stream):
     """Global or per room account data was changed
     """
+
     NAME = "account_data"
     ROW_TYPE = AccountDataStreamRow
 
@@ -416,7 +420,7 @@ class AccountDataStream(Stream):
 
         results = list(room_results)
         results.extend(
-            (stream_id, user_id, None, account_data_type, content,)
+            (stream_id, user_id, None, account_data_type, content)
             for stream_id, user_id, account_data_type, content in global_results
         )
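
Most of the _base.py diff is black exploding the namedtuple row definitions, but the middle hunk reflows the interesting part of the polling logic: a _LIMITED stream asks its update function for MAX_EVENTS_BEHIND + 1 rows, so the extra row (if present) tells the caller it has fallen too far behind to catch up row by row. A minimal sketch of that logic, with a toy update function standing in for the real database query and the Twisted deferreds elided:

import itertools

MAX_EVENTS_BEHIND = 10000

def get_updates_since(update_function, from_token, current_token, limited=True):
    """Sketch of the limiting logic in the hunk above (deferreds elided)."""
    if limited:
        rows = update_function(from_token, current_token, limit=MAX_EVENTS_BEHIND + 1)
        # never turn more than MAX_EVENTS_BEHIND + 1 rows into updates
        rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1)
    else:
        rows = update_function(from_token, current_token)
    # each row is (stream_id, *fields); peel the id off the payload
    return [(row[0], row[1:]) for row in rows]

def toy_update_function(from_token, current_token, limit=None):
    return ((i, "payload-%d" % i) for i in range(from_token + 1, current_token + 1))

print(get_updates_since(toy_update_function, 0, 3))
# [(1, ('payload-1',)), (2, ('payload-2',)), (3, ('payload-3',))]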
 
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index f1290d022a..3d0694bb11 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -52,6 +52,7 @@ data part are:
 @attr.s(slots=True, frozen=True)
 class EventsStreamRow(object):
     """A parsed row from the events replication stream"""
+
     type = attr.ib()  # str: the TypeId of one of the *EventsStreamRows
     data = attr.ib()  # BaseEventsStreamRow
 
@@ -80,11 +81,11 @@ class BaseEventsStreamRow(object):
 class EventsStreamEventRow(BaseEventsStreamRow):
     TypeId = "ev"
 
-    event_id = attr.ib()    # str
-    room_id = attr.ib()     # str
-    type = attr.ib()        # str
-    state_key = attr.ib()   # str, optional
-    redacts = attr.ib()     # str, optional
+    event_id = attr.ib()  # str
+    room_id = attr.ib()  # str
+    type = attr.ib()  # str
+    state_key = attr.ib()  # str, optional
+    redacts = attr.ib()  # str, optional
     relates_to = attr.ib()  # str, optional
 
 
@@ -92,24 +93,21 @@ class EventsStreamEventRow(BaseEventsStreamRow):
 class EventsStreamCurrentStateRow(BaseEventsStreamRow):
     TypeId = "state"
 
-    room_id = attr.ib()    # str
-    type = attr.ib()       # str
+    room_id = attr.ib()  # str
+    type = attr.ib()  # str
     state_key = attr.ib()  # str
-    event_id = attr.ib()   # str, optional
+    event_id = attr.ib()  # str, optional
 
 
 TypeToRow = {
-    Row.TypeId: Row
-    for Row in (
-        EventsStreamEventRow,
-        EventsStreamCurrentStateRow,
-    )
+    Row.TypeId: Row for Row in (EventsStreamEventRow, EventsStreamCurrentStateRow)
 }
 
 
 class EventsStream(Stream):
     """We received a new event, or an event went from being an outlier to not
     """
+
     NAME = "events"
 
     def __init__(self, hs):
@@ -121,19 +119,17 @@ class EventsStream(Stream):
     @defer.inlineCallbacks
     def update_function(self, from_token, current_token, limit=None):
         event_rows = yield self._store.get_all_new_forward_event_rows(
-            from_token, current_token, limit,
+            from_token, current_token, limit
         )
         event_updates = (
-            (row[0], EventsStreamEventRow.TypeId, row[1:])
-            for row in event_rows
+            (row[0], EventsStreamEventRow.TypeId, row[1:]) for row in event_rows
         )
 
         state_rows = yield self._store.get_all_updated_current_state_deltas(
             from_token, current_token, limit
         )
         state_updates = (
-            (row[0], EventsStreamCurrentStateRow.TypeId, row[1:])
-            for row in state_rows
+            (row[0], EventsStreamCurrentStateRow.TypeId, row[1:]) for row in state_rows
         )
 
         all_updates = heapq.merge(event_updates, state_updates)
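
The update_function above interleaves two token-ordered sources, forward event rows and current-state deltas, with heapq.merge; this works because each source is already sorted by its leading stream id and tuples compare element by element. A standalone illustration with invented sample data:

import heapq

# Two already-sorted sources of (stream_id, type_id, payload), shaped like the
# generators built in the hunk above; the ids are disjoint, so comparison
# never has to look past the first tuple element.
event_updates = ((i, "ev", ("$event%d:example.com" % i,)) for i in (1, 4, 6))
state_updates = ((i, "state", ("!room:example.com",)) for i in (2, 3, 5))

# heapq.merge lazily interleaves the two iterators, preserving global
# stream id order without materialising either source.
for stream_id, type_id, data in heapq.merge(event_updates, state_updates):
    print(stream_id, type_id, data)
# prints stream ids 1 through 6 in ascending order, tagged "ev" or "state"
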
diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py
index 9aa43aa8d2..dc2484109d 100644
--- a/synapse/replication/tcp/streams/federation.py
+++ b/synapse/replication/tcp/streams/federation.py
@@ -17,16 +17,20 @@ from collections import namedtuple
 
 from ._base import Stream
 
-FederationStreamRow = namedtuple("FederationStreamRow", (
-    "type",  # str, the type of data as defined in the BaseFederationRows
-    "data",  # dict, serialization of a federation.send_queue.BaseFederationRow
-))
+FederationStreamRow = namedtuple(
+    "FederationStreamRow",
+    (
+        "type",  # str, the type of data as defined in the BaseFederationRows
+        "data",  # dict, serialization of a federation.send_queue.BaseFederationRow
+    ),
+)
 
 
 class FederationStream(Stream):
     """Data to be sent over federation. Only available when master has federation
     sending disabled.
     """
+
     NAME = "federation"
     ROW_TYPE = FederationStreamRow
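
FederationStream declares no parse_row of its own, so it inherits the base class behaviour, which amounts to splatting the decoded JSON list back into ROW_TYPE. A hedged round-trip sketch (the wire framing is simplified and the helper below is not Synapse's actual code):

import json
from collections import namedtuple

# Same two-field shape as the FederationStreamRow defined above.
FederationStreamRow = namedtuple("FederationStreamRow", ("type", "data"))

def parse_row(row_json: str) -> FederationStreamRow:
    # The default Stream.parse_row amounts to ROW_TYPE(*decoded_row).
    return FederationStreamRow(*json.loads(row_json))

row = parse_row('["presence", {"user_id": "@alice:example.com"}]')
assert row.type == "presence" and row.data["user_id"] == "@alice:example.com"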