summary refs log tree commit diff
path: root/synapse/federation
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation')
-rw-r--r--synapse/federation/federation_base.py77
-rw-r--r--synapse/federation/federation_client.py217
-rw-r--r--synapse/federation/federation_server.py234
-rw-r--r--synapse/federation/persistence.py15
-rw-r--r--synapse/federation/send_queue.py145
-rw-r--r--synapse/federation/sender/__init__.py70
-rw-r--r--synapse/federation/sender/per_destination_queue.py45
-rw-r--r--synapse/federation/sender/transaction_manager.py38
-rw-r--r--synapse/federation/transport/client.py294
-rw-r--r--synapse/federation/transport/server.py245
-rw-r--r--synapse/federation/units.py33
11 files changed, 644 insertions, 769 deletions
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index fc5cfb7d83..1e925b19e7 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -44,8 +44,9 @@ class FederationBase(object):
         self._clock = hs.get_clock()
 
     @defer.inlineCallbacks
-    def _check_sigs_and_hash_and_fetch(self, origin, pdus, room_version,
-                                       outlier=False, include_none=False):
+    def _check_sigs_and_hash_and_fetch(
+        self, origin, pdus, room_version, outlier=False, include_none=False
+    ):
         """Takes a list of PDUs and checks the signatures and hashs of each
         one. If a PDU fails its signature check then we check if we have it in
         the database and if not then request if from the originating server of
@@ -79,9 +80,7 @@ class FederationBase(object):
             if not res:
                 # Check local db.
                 res = yield self.store.get_event(
-                    pdu.event_id,
-                    allow_rejected=True,
-                    allow_none=True,
+                    pdu.event_id, allow_rejected=True, allow_none=True
                 )
 
             if not res and pdu.origin != origin:
@@ -98,23 +97,16 @@ class FederationBase(object):
 
             if not res:
                 logger.warn(
-                    "Failed to find copy of %s with valid signature",
-                    pdu.event_id,
+                    "Failed to find copy of %s with valid signature", pdu.event_id
                 )
 
             defer.returnValue(res)
 
         handle = logcontext.preserve_fn(handle_check_result)
-        deferreds2 = [
-            handle(pdu, deferred)
-            for pdu, deferred in zip(pdus, deferreds)
-        ]
+        deferreds2 = [handle(pdu, deferred) for pdu, deferred in zip(pdus, deferreds)]
 
         valid_pdus = yield logcontext.make_deferred_yieldable(
-            defer.gatherResults(
-                deferreds2,
-                consumeErrors=True,
-            )
+            defer.gatherResults(deferreds2, consumeErrors=True)
         ).addErrback(unwrapFirstError)
 
         if include_none:
@@ -124,7 +116,7 @@ class FederationBase(object):
 
     def _check_sigs_and_hash(self, room_version, pdu):
         return logcontext.make_deferred_yieldable(
-            self._check_sigs_and_hashes(room_version, [pdu])[0],
+            self._check_sigs_and_hashes(room_version, [pdu])[0]
         )
 
     def _check_sigs_and_hashes(self, room_version, pdus):
@@ -159,11 +151,9 @@ class FederationBase(object):
                     # received event was probably a redacted copy (but we then use our
                     # *actual* redacted copy to be on the safe side.)
                     redacted_event = prune_event(pdu)
-                    if (
-                        set(redacted_event.keys()) == set(pdu.keys()) and
-                        set(six.iterkeys(redacted_event.content))
-                            == set(six.iterkeys(pdu.content))
-                    ):
+                    if set(redacted_event.keys()) == set(pdu.keys()) and set(
+                        six.iterkeys(redacted_event.content)
+                    ) == set(six.iterkeys(pdu.content)):
                         logger.info(
                             "Event %s seems to have been redacted; using our redacted "
                             "copy",
@@ -172,14 +162,15 @@ class FederationBase(object):
                     else:
                         logger.warning(
                             "Event %s content has been tampered, redacting",
-                            pdu.event_id, pdu.get_pdu_json(),
+                            pdu.event_id,
                         )
                     return redacted_event
 
                 if self.spam_checker.check_event_for_spam(pdu):
                     logger.warn(
                         "Event contains spam, redacting %s: %s",
-                        pdu.event_id, pdu.get_pdu_json()
+                        pdu.event_id,
+                        pdu.get_pdu_json(),
                     )
                     return prune_event(pdu)
 
@@ -190,23 +181,24 @@ class FederationBase(object):
             with logcontext.PreserveLoggingContext(ctx):
                 logger.warn(
                     "Signature check failed for %s: %s",
-                    pdu.event_id, failure.getErrorMessage(),
+                    pdu.event_id,
+                    failure.getErrorMessage(),
                 )
             return failure
 
         for deferred, pdu in zip(deferreds, pdus):
             deferred.addCallbacks(
-                callback, errback,
-                callbackArgs=[pdu],
-                errbackArgs=[pdu],
+                callback, errback, callbackArgs=[pdu], errbackArgs=[pdu]
             )
 
         return deferreds
 
 
-class PduToCheckSig(namedtuple("PduToCheckSig", [
-    "pdu", "redacted_pdu_json", "sender_domain", "deferreds",
-])):
+class PduToCheckSig(
+    namedtuple(
+        "PduToCheckSig", ["pdu", "redacted_pdu_json", "sender_domain", "deferreds"]
+    )
+):
     pass
 
 
@@ -260,10 +252,7 @@ def _check_sigs_on_pdus(keyring, room_version, pdus):
 
     # First we check that the sender event is signed by the sender's domain
     # (except if its a 3pid invite, in which case it may be sent by any server)
-    pdus_to_check_sender = [
-        p for p in pdus_to_check
-        if not _is_invite_via_3pid(p.pdu)
-    ]
+    pdus_to_check_sender = [p for p in pdus_to_check if not _is_invite_via_3pid(p.pdu)]
 
     more_deferreds = keyring.verify_json_objects_for_server(
         [
@@ -297,7 +286,8 @@ def _check_sigs_on_pdus(keyring, room_version, pdus):
     # (ie, the room version uses old-style non-hash event IDs).
     if v.event_format == EventFormatVersions.V1:
         pdus_to_check_event_id = [
-            p for p in pdus_to_check
+            p
+            for p in pdus_to_check
             if p.sender_domain != get_domain_from_id(p.pdu.event_id)
         ]
 
@@ -315,10 +305,8 @@ def _check_sigs_on_pdus(keyring, room_version, pdus):
 
         def event_err(e, pdu_to_check):
             errmsg = (
-                "event id %s: unable to verify signature for event id domain: %s" % (
-                    pdu_to_check.pdu.event_id,
-                    e.getErrorMessage(),
-                )
+                "event id %s: unable to verify signature for event id domain: %s"
+                % (pdu_to_check.pdu.event_id, e.getErrorMessage())
             )
             # XX as above: not really sure if these are the right codes
             raise SynapseError(400, errmsg, Codes.UNAUTHORIZED)
@@ -368,21 +356,18 @@ def event_from_pdu_json(pdu_json, event_format_version, outlier=False):
     """
     # we could probably enforce a bunch of other fields here (room_id, sender,
     # origin, etc etc)
-    assert_params_in_dict(pdu_json, ('type', 'depth'))
+    assert_params_in_dict(pdu_json, ("type", "depth"))
 
-    depth = pdu_json['depth']
+    depth = pdu_json["depth"]
     if not isinstance(depth, six.integer_types):
-        raise SynapseError(400, "Depth %r not an intger" % (depth, ),
-                           Codes.BAD_JSON)
+        raise SynapseError(400, "Depth %r not an intger" % (depth,), Codes.BAD_JSON)
 
     if depth < 0:
         raise SynapseError(400, "Depth too small", Codes.BAD_JSON)
     elif depth > MAX_DEPTH:
         raise SynapseError(400, "Depth too large", Codes.BAD_JSON)
 
-    event = event_type_from_format_version(event_format_version)(
-        pdu_json,
-    )
+    event = event_type_from_format_version(event_format_version)(pdu_json)
 
     event.internal_metadata.outlier = outlier
 
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 70573746d6..3883eb525e 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -57,6 +57,7 @@ class InvalidResponseError(RuntimeError):
     """Helper for _try_destination_list: indicates that the server returned a response
     we couldn't parse
     """
+
     pass
 
 
@@ -65,9 +66,7 @@ class FederationClient(FederationBase):
         super(FederationClient, self).__init__(hs)
 
         self.pdu_destination_tried = {}
-        self._clock.looping_call(
-            self._clear_tried_cache, 60 * 1000,
-        )
+        self._clock.looping_call(self._clear_tried_cache, 60 * 1000)
         self.state = hs.get_state_handler()
         self.transport_layer = hs.get_federation_transport_client()
 
@@ -99,8 +98,14 @@ class FederationClient(FederationBase):
                 self.pdu_destination_tried[event_id] = destination_dict
 
     @log_function
-    def make_query(self, destination, query_type, args,
-                   retry_on_dns_fail=False, ignore_backoff=False):
+    def make_query(
+        self,
+        destination,
+        query_type,
+        args,
+        retry_on_dns_fail=False,
+        ignore_backoff=False,
+    ):
         """Sends a federation Query to a remote homeserver of the given type
         and arguments.
 
@@ -120,7 +125,10 @@ class FederationClient(FederationBase):
         sent_queries_counter.labels(query_type).inc()
 
         return self.transport_layer.make_query(
-            destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail,
+            destination,
+            query_type,
+            args,
+            retry_on_dns_fail=retry_on_dns_fail,
             ignore_backoff=ignore_backoff,
         )
 
@@ -137,9 +145,7 @@ class FederationClient(FederationBase):
             response
         """
         sent_queries_counter.labels("client_device_keys").inc()
-        return self.transport_layer.query_client_keys(
-            destination, content, timeout
-        )
+        return self.transport_layer.query_client_keys(destination, content, timeout)
 
     @log_function
     def query_user_devices(self, destination, user_id, timeout=30000):
@@ -147,9 +153,7 @@ class FederationClient(FederationBase):
         server.
         """
         sent_queries_counter.labels("user_devices").inc()
-        return self.transport_layer.query_user_devices(
-            destination, user_id, timeout
-        )
+        return self.transport_layer.query_user_devices(destination, user_id, timeout)
 
     @log_function
     def claim_client_keys(self, destination, content, timeout):
@@ -164,9 +168,7 @@ class FederationClient(FederationBase):
             response
         """
         sent_queries_counter.labels("client_one_time_keys").inc()
-        return self.transport_layer.claim_client_keys(
-            destination, content, timeout
-        )
+        return self.transport_layer.claim_client_keys(destination, content, timeout)
 
     @defer.inlineCallbacks
     @log_function
@@ -191,7 +193,8 @@ class FederationClient(FederationBase):
             return
 
         transaction_data = yield self.transport_layer.backfill(
-            dest, room_id, extremities, limit)
+            dest, room_id, extremities, limit
+        )
 
         logger.debug("backfill transaction_data=%s", repr(transaction_data))
 
@@ -204,17 +207,19 @@ class FederationClient(FederationBase):
         ]
 
         # FIXME: We should handle signature failures more gracefully.
-        pdus[:] = yield logcontext.make_deferred_yieldable(defer.gatherResults(
-            self._check_sigs_and_hashes(room_version, pdus),
-            consumeErrors=True,
-        ).addErrback(unwrapFirstError))
+        pdus[:] = yield logcontext.make_deferred_yieldable(
+            defer.gatherResults(
+                self._check_sigs_and_hashes(room_version, pdus), consumeErrors=True
+            ).addErrback(unwrapFirstError)
+        )
 
         defer.returnValue(pdus)
 
     @defer.inlineCallbacks
     @log_function
-    def get_pdu(self, destinations, event_id, room_version, outlier=False,
-                timeout=None):
+    def get_pdu(
+        self, destinations, event_id, room_version, outlier=False, timeout=None
+    ):
         """Requests the PDU with given origin and ID from the remote home
         servers.
 
@@ -255,7 +260,7 @@ class FederationClient(FederationBase):
 
             try:
                 transaction_data = yield self.transport_layer.get_event(
-                    destination, event_id, timeout=timeout,
+                    destination, event_id, timeout=timeout
                 )
 
                 logger.debug(
@@ -282,8 +287,7 @@ class FederationClient(FederationBase):
 
             except SynapseError as e:
                 logger.info(
-                    "Failed to get PDU %s from %s because %s",
-                    event_id, destination, e,
+                    "Failed to get PDU %s from %s because %s", event_id, destination, e
                 )
                 continue
             except NotRetryingDestination as e:
@@ -296,8 +300,7 @@ class FederationClient(FederationBase):
                 pdu_attempts[destination] = now
 
                 logger.info(
-                    "Failed to get PDU %s from %s because %s",
-                    event_id, destination, e,
+                    "Failed to get PDU %s from %s because %s", event_id, destination, e
                 )
                 continue
 
@@ -326,7 +329,7 @@ class FederationClient(FederationBase):
             # we have most of the state and auth_chain already.
             # However, this may 404 if the other side has an old synapse.
             result = yield self.transport_layer.get_room_state_ids(
-                destination, room_id, event_id=event_id,
+                destination, room_id, event_id=event_id
             )
 
             state_event_ids = result["pdu_ids"]
@@ -340,12 +343,10 @@ class FederationClient(FederationBase):
                 logger.warning(
                     "Failed to fetch missing state/auth events for %s: %s",
                     room_id,
-                    failed_to_fetch
+                    failed_to_fetch,
                 )
 
-            event_map = {
-                ev.event_id: ev for ev in fetched_events
-            }
+            event_map = {ev.event_id: ev for ev in fetched_events}
 
             pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map]
             auth_chain = [
@@ -362,15 +363,14 @@ class FederationClient(FederationBase):
                 raise e
 
         result = yield self.transport_layer.get_room_state(
-            destination, room_id, event_id=event_id,
+            destination, room_id, event_id=event_id
         )
 
         room_version = yield self.store.get_room_version(room_id)
         format_ver = room_version_to_event_format(room_version)
 
         pdus = [
-            event_from_pdu_json(p, format_ver, outlier=True)
-            for p in result["pdus"]
+            event_from_pdu_json(p, format_ver, outlier=True) for p in result["pdus"]
         ]
 
         auth_chain = [
@@ -378,9 +378,9 @@ class FederationClient(FederationBase):
             for p in result.get("auth_chain", [])
         ]
 
-        seen_events = yield self.store.get_events([
-            ev.event_id for ev in itertools.chain(pdus, auth_chain)
-        ])
+        seen_events = yield self.store.get_events(
+            [ev.event_id for ev in itertools.chain(pdus, auth_chain)]
+        )
 
         signed_pdus = yield self._check_sigs_and_hash_and_fetch(
             destination,
@@ -442,7 +442,7 @@ class FederationClient(FederationBase):
         batch_size = 20
         missing_events = list(missing_events)
         for i in range(0, len(missing_events), batch_size):
-            batch = set(missing_events[i:i + batch_size])
+            batch = set(missing_events[i : i + batch_size])
 
             deferreds = [
                 run_in_background(
@@ -470,21 +470,17 @@ class FederationClient(FederationBase):
     @defer.inlineCallbacks
     @log_function
     def get_event_auth(self, destination, room_id, event_id):
-        res = yield self.transport_layer.get_event_auth(
-            destination, room_id, event_id,
-        )
+        res = yield self.transport_layer.get_event_auth(destination, room_id, event_id)
 
         room_version = yield self.store.get_room_version(room_id)
         format_ver = room_version_to_event_format(room_version)
 
         auth_chain = [
-            event_from_pdu_json(p, format_ver, outlier=True)
-            for p in res["auth_chain"]
+            event_from_pdu_json(p, format_ver, outlier=True) for p in res["auth_chain"]
         ]
 
         signed_auth = yield self._check_sigs_and_hash_and_fetch(
-            destination, auth_chain,
-            outlier=True, room_version=room_version,
+            destination, auth_chain, outlier=True, room_version=room_version
         )
 
         signed_auth.sort(key=lambda e: e.depth)
@@ -527,28 +523,26 @@ class FederationClient(FederationBase):
                 res = yield callback(destination)
                 defer.returnValue(res)
             except InvalidResponseError as e:
-                logger.warn(
-                    "Failed to %s via %s: %s",
-                    description, destination, e,
-                )
+                logger.warn("Failed to %s via %s: %s", description, destination, e)
             except HttpResponseException as e:
                 if not 500 <= e.code < 600:
                     raise e.to_synapse_error()
                 else:
                     logger.warn(
                         "Failed to %s via %s: %i %s",
-                        description, destination, e.code, e.args[0],
+                        description,
+                        destination,
+                        e.code,
+                        e.args[0],
                     )
             except Exception:
-                logger.warn(
-                    "Failed to %s via %s",
-                    description, destination, exc_info=1,
-                )
+                logger.warn("Failed to %s via %s", description, destination, exc_info=1)
 
-        raise RuntimeError("Failed to %s via any server" % (description, ))
+        raise RuntimeError("Failed to %s via any server" % (description,))
 
-    def make_membership_event(self, destinations, room_id, user_id, membership,
-                              content, params):
+    def make_membership_event(
+        self, destinations, room_id, user_id, membership, content, params
+    ):
         """
         Creates an m.room.member event, with context, without participating in the room.
 
@@ -584,14 +578,14 @@ class FederationClient(FederationBase):
         valid_memberships = {Membership.JOIN, Membership.LEAVE}
         if membership not in valid_memberships:
             raise RuntimeError(
-                "make_membership_event called with membership='%s', must be one of %s" %
-                (membership, ",".join(valid_memberships))
+                "make_membership_event called with membership='%s', must be one of %s"
+                % (membership, ",".join(valid_memberships))
             )
 
         @defer.inlineCallbacks
         def send_request(destination):
             ret = yield self.transport_layer.make_membership_event(
-                destination, room_id, user_id, membership, params,
+                destination, room_id, user_id, membership, params
             )
 
             # Note: If not supplied, the room version may be either v1 or v2,
@@ -614,16 +608,17 @@ class FederationClient(FederationBase):
                 pdu_dict["prev_state"] = []
 
             ev = builder.create_local_event_from_event_dict(
-                self._clock, self.hostname, self.signing_key,
-                format_version=event_format, event_dict=pdu_dict,
+                self._clock,
+                self.hostname,
+                self.signing_key,
+                format_version=event_format,
+                event_dict=pdu_dict,
             )
 
-            defer.returnValue(
-                (destination, ev, event_format)
-            )
+            defer.returnValue((destination, ev, event_format))
 
         return self._try_destination_list(
-            "make_" + membership, destinations, send_request,
+            "make_" + membership, destinations, send_request
         )
 
     def send_join(self, destinations, pdu, event_format_version):
@@ -655,9 +650,7 @@ class FederationClient(FederationBase):
                     create_event = e
                     break
             else:
-                raise InvalidResponseError(
-                    "no %s in auth chain" % (EventTypes.Create,),
-                )
+                raise InvalidResponseError("no %s in auth chain" % (EventTypes.Create,))
 
             # the room version should be sane.
             room_version = create_event.content.get("room_version", "1")
@@ -665,9 +658,8 @@ class FederationClient(FederationBase):
                 # This shouldn't be possible, because the remote server should have
                 # rejected the join attempt during make_join.
                 raise InvalidResponseError(
-                    "room appears to have unsupported version %s" % (
-                        room_version,
-                    ))
+                    "room appears to have unsupported version %s" % (room_version,)
+                )
 
         @defer.inlineCallbacks
         def send_request(destination):
@@ -691,10 +683,7 @@ class FederationClient(FederationBase):
                 for p in content.get("auth_chain", [])
             ]
 
-            pdus = {
-                p.event_id: p
-                for p in itertools.chain(state, auth_chain)
-            }
+            pdus = {p.event_id: p for p in itertools.chain(state, auth_chain)}
 
             room_version = None
             for e in state:
@@ -710,15 +699,13 @@ class FederationClient(FederationBase):
                 raise SynapseError(400, "No create event in state")
 
             valid_pdus = yield self._check_sigs_and_hash_and_fetch(
-                destination, list(pdus.values()),
+                destination,
+                list(pdus.values()),
                 outlier=True,
                 room_version=room_version,
             )
 
-            valid_pdus_map = {
-                p.event_id: p
-                for p in valid_pdus
-            }
+            valid_pdus_map = {p.event_id: p for p in valid_pdus}
 
             # NB: We *need* to copy to ensure that we don't have multiple
             # references being passed on, as that causes... issues.
@@ -741,11 +728,14 @@ class FederationClient(FederationBase):
 
             check_authchain_validity(signed_auth)
 
-            defer.returnValue({
-                "state": signed_state,
-                "auth_chain": signed_auth,
-                "origin": destination,
-            })
+            defer.returnValue(
+                {
+                    "state": signed_state,
+                    "auth_chain": signed_auth,
+                    "origin": destination,
+                }
+            )
+
         return self._try_destination_list("send_join", destinations, send_request)
 
     @defer.inlineCallbacks
@@ -854,6 +844,7 @@ class FederationClient(FederationBase):
 
             Fails with a ``RuntimeError`` if no servers were reachable.
         """
+
         @defer.inlineCallbacks
         def send_request(destination):
             time_now = self._clock.time_msec()
@@ -869,14 +860,23 @@ class FederationClient(FederationBase):
 
         return self._try_destination_list("send_leave", destinations, send_request)
 
-    def get_public_rooms(self, destination, limit=None, since_token=None,
-                         search_filter=None, include_all_networks=False,
-                         third_party_instance_id=None):
+    def get_public_rooms(
+        self,
+        destination,
+        limit=None,
+        since_token=None,
+        search_filter=None,
+        include_all_networks=False,
+        third_party_instance_id=None,
+    ):
         if destination == self.server_name:
             return
 
         return self.transport_layer.get_public_rooms(
-            destination, limit, since_token, search_filter,
+            destination,
+            limit,
+            since_token,
+            search_filter,
             include_all_networks=include_all_networks,
             third_party_instance_id=third_party_instance_id,
         )
@@ -891,9 +891,7 @@ class FederationClient(FederationBase):
         """
         time_now = self._clock.time_msec()
 
-        send_content = {
-            "auth_chain": [e.get_pdu_json(time_now) for e in local_auth],
-        }
+        send_content = {"auth_chain": [e.get_pdu_json(time_now) for e in local_auth]}
 
         code, content = yield self.transport_layer.send_query_auth(
             destination=destination,
@@ -905,13 +903,10 @@ class FederationClient(FederationBase):
         room_version = yield self.store.get_room_version(room_id)
         format_ver = room_version_to_event_format(room_version)
 
-        auth_chain = [
-            event_from_pdu_json(e, format_ver)
-            for e in content["auth_chain"]
-        ]
+        auth_chain = [event_from_pdu_json(e, format_ver) for e in content["auth_chain"]]
 
         signed_auth = yield self._check_sigs_and_hash_and_fetch(
-            destination, auth_chain, outlier=True, room_version=room_version,
+            destination, auth_chain, outlier=True, room_version=room_version
         )
 
         signed_auth.sort(key=lambda e: e.depth)
@@ -925,8 +920,16 @@ class FederationClient(FederationBase):
         defer.returnValue(ret)
 
     @defer.inlineCallbacks
-    def get_missing_events(self, destination, room_id, earliest_events_ids,
-                           latest_events, limit, min_depth, timeout):
+    def get_missing_events(
+        self,
+        destination,
+        room_id,
+        earliest_events_ids,
+        latest_events,
+        limit,
+        min_depth,
+        timeout,
+    ):
         """Tries to fetch events we are missing. This is called when we receive
         an event without having received all of its ancestors.
 
@@ -957,12 +960,11 @@ class FederationClient(FederationBase):
             format_ver = room_version_to_event_format(room_version)
 
             events = [
-                event_from_pdu_json(e, format_ver)
-                for e in content.get("events", [])
+                event_from_pdu_json(e, format_ver) for e in content.get("events", [])
             ]
 
             signed_events = yield self._check_sigs_and_hash_and_fetch(
-                destination, events, outlier=False, room_version=room_version,
+                destination, events, outlier=False, room_version=room_version
             )
         except HttpResponseException as e:
             if not e.code == 400:
@@ -982,17 +984,14 @@ class FederationClient(FederationBase):
 
             try:
                 yield self.transport_layer.exchange_third_party_invite(
-                    destination=destination,
-                    room_id=room_id,
-                    event_dict=event_dict,
+                    destination=destination, room_id=room_id, event_dict=event_dict
                 )
                 defer.returnValue(None)
             except CodeMessageException:
                 raise
             except Exception as e:
                 logger.exception(
-                    "Failed to send_third_party_invite via %s: %s",
-                    destination, str(e)
+                    "Failed to send_third_party_invite via %s: %s", destination, str(e)
                 )
 
         raise RuntimeError("Failed to send to any server.")
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 4c28c1dc3c..2e0cebb638 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -69,7 +69,6 @@ received_queries_counter = Counter(
 
 
 class FederationServer(FederationBase):
-
     def __init__(self, hs):
         super(FederationServer, self).__init__(hs)
 
@@ -118,11 +117,13 @@ class FederationServer(FederationBase):
 
         # use a linearizer to ensure that we don't process the same transaction
         # multiple times in parallel.
-        with (yield self._transaction_linearizer.queue(
-                (origin, transaction.transaction_id),
-        )):
+        with (
+            yield self._transaction_linearizer.queue(
+                (origin, transaction.transaction_id)
+            )
+        ):
             result = yield self._handle_incoming_transaction(
-                origin, transaction, request_time,
+                origin, transaction, request_time
             )
 
         defer.returnValue(result)
@@ -144,7 +145,7 @@ class FederationServer(FederationBase):
         if response:
             logger.debug(
                 "[%s] We've already responded to this request",
-                transaction.transaction_id
+                transaction.transaction_id,
             )
             defer.returnValue(response)
             return
@@ -152,18 +153,15 @@ class FederationServer(FederationBase):
         logger.debug("[%s] Transaction is new", transaction.transaction_id)
 
         # Reject if PDU count > 50 and EDU count > 100
-        if (len(transaction.pdus) > 50
-                or (hasattr(transaction, "edus") and len(transaction.edus) > 100)):
+        if len(transaction.pdus) > 50 or (
+            hasattr(transaction, "edus") and len(transaction.edus) > 100
+        ):
 
-            logger.info(
-                "Transaction PDU or EDU count too large. Returning 400",
-            )
+            logger.info("Transaction PDU or EDU count too large. Returning 400")
 
             response = {}
             yield self.transaction_actions.set_response(
-                origin,
-                transaction,
-                400, response
+                origin, transaction, 400, response
             )
             defer.returnValue((400, response))
 
@@ -230,9 +228,7 @@ class FederationServer(FederationBase):
             try:
                 yield self.check_server_matches_acl(origin_host, room_id)
             except AuthError as e:
-                logger.warn(
-                    "Ignoring PDUs for room %s from banned server", room_id,
-                )
+                logger.warn("Ignoring PDUs for room %s from banned server", room_id)
                 for pdu in pdus_by_room[room_id]:
                     event_id = pdu.event_id
                     pdu_results[event_id] = e.error_dict()
@@ -242,9 +238,7 @@ class FederationServer(FederationBase):
                 event_id = pdu.event_id
                 with nested_logging_context(event_id):
                     try:
-                        yield self._handle_received_pdu(
-                            origin, pdu
-                        )
+                        yield self._handle_received_pdu(origin, pdu)
                         pdu_results[event_id] = {}
                     except FederationError as e:
                         logger.warn("Error handling PDU %s: %s", event_id, e)
@@ -259,29 +253,18 @@ class FederationServer(FederationBase):
                         )
 
         yield concurrently_execute(
-            process_pdus_for_room, pdus_by_room.keys(),
-            TRANSACTION_CONCURRENCY_LIMIT,
+            process_pdus_for_room, pdus_by_room.keys(), TRANSACTION_CONCURRENCY_LIMIT
         )
 
         if hasattr(transaction, "edus"):
             for edu in (Edu(**x) for x in transaction.edus):
-                yield self.received_edu(
-                    origin,
-                    edu.edu_type,
-                    edu.content
-                )
+                yield self.received_edu(origin, edu.edu_type, edu.content)
 
-        response = {
-            "pdus": pdu_results,
-        }
+        response = {"pdus": pdu_results}
 
         logger.debug("Returning: %s", str(response))
 
-        yield self.transaction_actions.set_response(
-            origin,
-            transaction,
-            200, response
-        )
+        yield self.transaction_actions.set_response(origin, transaction, 200, response)
         defer.returnValue((200, response))
 
     @defer.inlineCallbacks
@@ -311,7 +294,8 @@ class FederationServer(FederationBase):
             resp = yield self._state_resp_cache.wrap(
                 (room_id, event_id),
                 self._on_context_state_request_compute,
-                room_id, event_id,
+                room_id,
+                event_id,
             )
 
         defer.returnValue((200, resp))
@@ -328,24 +312,17 @@ class FederationServer(FederationBase):
         if not in_room:
             raise AuthError(403, "Host not in room.")
 
-        state_ids = yield self.handler.get_state_ids_for_pdu(
-            room_id, event_id,
-        )
+        state_ids = yield self.handler.get_state_ids_for_pdu(room_id, event_id)
         auth_chain_ids = yield self.store.get_auth_chain_ids(state_ids)
 
-        defer.returnValue((200, {
-            "pdu_ids": state_ids,
-            "auth_chain_ids": auth_chain_ids,
-        }))
+        defer.returnValue(
+            (200, {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids})
+        )
 
     @defer.inlineCallbacks
     def _on_context_state_request_compute(self, room_id, event_id):
-        pdus = yield self.handler.get_state_for_pdu(
-            room_id, event_id,
-        )
-        auth_chain = yield self.store.get_auth_chain(
-            [pdu.event_id for pdu in pdus]
-        )
+        pdus = yield self.handler.get_state_for_pdu(room_id, event_id)
+        auth_chain = yield self.store.get_auth_chain([pdu.event_id for pdu in pdus])
 
         for event in auth_chain:
             # We sign these again because there was a bug where we
@@ -355,14 +332,16 @@ class FederationServer(FederationBase):
                     compute_event_signature(
                         event.get_pdu_json(),
                         self.hs.hostname,
-                        self.hs.config.signing_key[0]
+                        self.hs.config.signing_key[0],
                     )
                 )
 
-        defer.returnValue({
-            "pdus": [pdu.get_pdu_json() for pdu in pdus],
-            "auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],
-        })
+        defer.returnValue(
+            {
+                "pdus": [pdu.get_pdu_json() for pdu in pdus],
+                "auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],
+            }
+        )
 
     @defer.inlineCallbacks
     @log_function
@@ -370,9 +349,7 @@ class FederationServer(FederationBase):
         pdu = yield self.handler.get_persisted_pdu(origin, event_id)
 
         if pdu:
-            defer.returnValue(
-                (200, self._transaction_from_pdus([pdu]).get_dict())
-            )
+            defer.returnValue((200, self._transaction_from_pdus([pdu]).get_dict()))
         else:
             defer.returnValue((404, ""))
 
@@ -394,10 +371,9 @@ class FederationServer(FederationBase):
 
         pdu = yield self.handler.on_make_join_request(room_id, user_id)
         time_now = self._clock.time_msec()
-        defer.returnValue({
-            "event": pdu.get_pdu_json(time_now),
-            "room_version": room_version,
-        })
+        defer.returnValue(
+            {"event": pdu.get_pdu_json(time_now), "room_version": room_version}
+        )
 
     @defer.inlineCallbacks
     def on_invite_request(self, origin, content, room_version):
@@ -431,12 +407,17 @@ class FederationServer(FederationBase):
         logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures)
         res_pdus = yield self.handler.on_send_join_request(origin, pdu)
         time_now = self._clock.time_msec()
-        defer.returnValue((200, {
-            "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]],
-            "auth_chain": [
-                p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]
-            ],
-        }))
+        defer.returnValue(
+            (
+                200,
+                {
+                    "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]],
+                    "auth_chain": [
+                        p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]
+                    ],
+                },
+            )
+        )
 
     @defer.inlineCallbacks
     def on_make_leave_request(self, origin, room_id, user_id):
@@ -447,10 +428,9 @@ class FederationServer(FederationBase):
         room_version = yield self.store.get_room_version(room_id)
 
         time_now = self._clock.time_msec()
-        defer.returnValue({
-            "event": pdu.get_pdu_json(time_now),
-            "room_version": room_version,
-        })
+        defer.returnValue(
+            {"event": pdu.get_pdu_json(time_now), "room_version": room_version}
+        )
 
     @defer.inlineCallbacks
     def on_send_leave_request(self, origin, content, room_id):
@@ -475,9 +455,7 @@ class FederationServer(FederationBase):
 
             time_now = self._clock.time_msec()
             auth_pdus = yield self.handler.on_event_auth(event_id)
-            res = {
-                "auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus],
-            }
+            res = {"auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus]}
         defer.returnValue((200, res))
 
     @defer.inlineCallbacks
@@ -508,12 +486,11 @@ class FederationServer(FederationBase):
             format_ver = room_version_to_event_format(room_version)
 
             auth_chain = [
-                event_from_pdu_json(e, format_ver)
-                for e in content["auth_chain"]
+                event_from_pdu_json(e, format_ver) for e in content["auth_chain"]
             ]
 
             signed_auth = yield self._check_sigs_and_hash_and_fetch(
-                origin, auth_chain, outlier=True, room_version=room_version,
+                origin, auth_chain, outlier=True, room_version=room_version
             )
 
             ret = yield self.handler.on_query_auth(
@@ -527,17 +504,12 @@ class FederationServer(FederationBase):
 
             time_now = self._clock.time_msec()
             send_content = {
-                "auth_chain": [
-                    e.get_pdu_json(time_now)
-                    for e in ret["auth_chain"]
-                ],
+                "auth_chain": [e.get_pdu_json(time_now) for e in ret["auth_chain"]],
                 "rejects": ret.get("rejects", []),
                 "missing": ret.get("missing", []),
             }
 
-        defer.returnValue(
-            (200, send_content)
-        )
+        defer.returnValue((200, send_content))
 
     @log_function
     def on_query_client_keys(self, origin, content):
@@ -566,20 +538,23 @@ class FederationServer(FederationBase):
 
         logger.info(
             "Claimed one-time-keys: %s",
-            ",".join((
-                "%s for %s:%s" % (key_id, user_id, device_id)
-                for user_id, user_keys in iteritems(json_result)
-                for device_id, device_keys in iteritems(user_keys)
-                for key_id, _ in iteritems(device_keys)
-            )),
+            ",".join(
+                (
+                    "%s for %s:%s" % (key_id, user_id, device_id)
+                    for user_id, user_keys in iteritems(json_result)
+                    for device_id, device_keys in iteritems(user_keys)
+                    for key_id, _ in iteritems(device_keys)
+                )
+            ),
         )
 
         defer.returnValue({"one_time_keys": json_result})
 
     @defer.inlineCallbacks
     @log_function
-    def on_get_missing_events(self, origin, room_id, earliest_events,
-                              latest_events, limit):
+    def on_get_missing_events(
+        self, origin, room_id, earliest_events, latest_events, limit
+    ):
         with (yield self._server_linearizer.queue((origin, room_id))):
             origin_host, _ = parse_server_name(origin)
             yield self.check_server_matches_acl(origin_host, room_id)
@@ -587,11 +562,13 @@ class FederationServer(FederationBase):
             logger.info(
                 "on_get_missing_events: earliest_events: %r, latest_events: %r,"
                 " limit: %d",
-                earliest_events, latest_events, limit,
+                earliest_events,
+                latest_events,
+                limit,
             )
 
             missing_events = yield self.handler.on_get_missing_events(
-                origin, room_id, earliest_events, latest_events, limit,
+                origin, room_id, earliest_events, latest_events, limit
             )
 
             if len(missing_events) < 5:
@@ -603,9 +580,9 @@ class FederationServer(FederationBase):
 
             time_now = self._clock.time_msec()
 
-        defer.returnValue({
-            "events": [ev.get_pdu_json(time_now) for ev in missing_events],
-        })
+        defer.returnValue(
+            {"events": [ev.get_pdu_json(time_now) for ev in missing_events]}
+        )
 
     @log_function
     def on_openid_userinfo(self, token):
@@ -666,22 +643,17 @@ class FederationServer(FederationBase):
             # origin. See bug #1893. This is also true for some third party
             # invites).
             if not (
-                pdu.type == 'm.room.member' and
-                pdu.content and
-                pdu.content.get("membership", None) in (
-                    Membership.JOIN, Membership.INVITE,
-                )
+                pdu.type == "m.room.member"
+                and pdu.content
+                and pdu.content.get("membership", None)
+                in (Membership.JOIN, Membership.INVITE)
             ):
                 logger.info(
-                    "Discarding PDU %s from invalid origin %s",
-                    pdu.event_id, origin
+                    "Discarding PDU %s from invalid origin %s", pdu.event_id, origin
                 )
                 return
             else:
-                logger.info(
-                    "Accepting join PDU %s from %s",
-                    pdu.event_id, origin
-                )
+                logger.info("Accepting join PDU %s from %s", pdu.event_id, origin)
 
         # We've already checked that we know the room version by this point
         room_version = yield self.store.get_room_version(pdu.room_id)
@@ -690,33 +662,19 @@ class FederationServer(FederationBase):
         try:
             pdu = yield self._check_sigs_and_hash(room_version, pdu)
         except SynapseError as e:
-            raise FederationError(
-                "ERROR",
-                e.code,
-                e.msg,
-                affected=pdu.event_id,
-            )
+            raise FederationError("ERROR", e.code, e.msg, affected=pdu.event_id)
 
-        yield self.handler.on_receive_pdu(
-            origin, pdu, sent_to_us_directly=True,
-        )
+        yield self.handler.on_receive_pdu(origin, pdu, sent_to_us_directly=True)
 
     def __str__(self):
         return "<ReplicationLayer(%s)>" % self.server_name
 
     @defer.inlineCallbacks
     def exchange_third_party_invite(
-            self,
-            sender_user_id,
-            target_user_id,
-            room_id,
-            signed,
+        self, sender_user_id, target_user_id, room_id, signed
     ):
         ret = yield self.handler.exchange_third_party_invite(
-            sender_user_id,
-            target_user_id,
-            room_id,
-            signed,
+            sender_user_id, target_user_id, room_id, signed
         )
         defer.returnValue(ret)
 
@@ -771,7 +729,7 @@ def server_matches_acl_event(server_name, acl_event):
         allow_ip_literals = True
     if not allow_ip_literals:
         # check for ipv6 literals. These start with '['.
-        if server_name[0] == '[':
+        if server_name[0] == "[":
             return False
 
         # check for ipv4 literals. We can just lift the routine from twisted.
@@ -805,7 +763,9 @@ def server_matches_acl_event(server_name, acl_event):
 
 def _acl_entry_matches(server_name, acl_entry):
     if not isinstance(acl_entry, six.string_types):
-        logger.warn("Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry))
+        logger.warn(
+            "Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry)
+        )
         return False
     regex = glob_to_regex(acl_entry)
     return regex.match(server_name)
@@ -815,6 +775,7 @@ class FederationHandlerRegistry(object):
     """Allows classes to register themselves as handlers for a given EDU or
     query type for incoming federation traffic.
     """
+
     def __init__(self):
         self.edu_handlers = {}
         self.query_handlers = {}
@@ -848,9 +809,7 @@ class FederationHandlerRegistry(object):
                 on and the result used as the response to the query request.
         """
         if query_type in self.query_handlers:
-            raise KeyError(
-                "Already have a Query handler for %s" % (query_type,)
-            )
+            raise KeyError("Already have a Query handler for %s" % (query_type,))
 
         logger.info("Registering federation query handler for %r", query_type)
 
@@ -905,14 +864,10 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry):
         handler = self.edu_handlers.get(edu_type)
         if handler:
             return super(ReplicationFederationHandlerRegistry, self).on_edu(
-                edu_type, origin, content,
+                edu_type, origin, content
             )
 
-        return self._send_edu(
-            edu_type=edu_type,
-            origin=origin,
-            content=content,
-        )
+        return self._send_edu(edu_type=edu_type, origin=origin, content=content)
 
     def on_query(self, query_type, args):
         """Overrides FederationHandlerRegistry
@@ -921,7 +876,4 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry):
         if handler:
             return handler(args)
 
-        return self._get_query_client(
-            query_type=query_type,
-            args=args,
-        )
+        return self._get_query_client(query_type=query_type, args=args)
diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py
index 74ffd13b4f..7535f79203 100644
--- a/synapse/federation/persistence.py
+++ b/synapse/federation/persistence.py
@@ -46,12 +46,9 @@ class TransactionActions(object):
             response code and response body.
         """
         if not transaction.transaction_id:
-            raise RuntimeError("Cannot persist a transaction with no "
-                               "transaction_id")
+            raise RuntimeError("Cannot persist a transaction with no " "transaction_id")
 
-        return self.store.get_received_txn_response(
-            transaction.transaction_id, origin
-        )
+        return self.store.get_received_txn_response(transaction.transaction_id, origin)
 
     @log_function
     def set_response(self, origin, transaction, code, response):
@@ -61,14 +58,10 @@ class TransactionActions(object):
             Deferred
         """
         if not transaction.transaction_id:
-            raise RuntimeError("Cannot persist a transaction with no "
-                               "transaction_id")
+            raise RuntimeError("Cannot persist a transaction with no " "transaction_id")
 
         return self.store.set_received_txn_response(
-            transaction.transaction_id,
-            origin,
-            code,
-            response,
+            transaction.transaction_id, origin, code, response
         )
 
     @defer.inlineCallbacks
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index 0240b339b0..454456a52d 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -77,12 +77,22 @@ class FederationRemoteSendQueue(object):
         # lambda binds to the queue rather than to the name of the queue which
         # changes. ARGH.
         def register(name, queue):
-            LaterGauge("synapse_federation_send_queue_%s_size" % (queue_name,),
-                       "", [], lambda: len(queue))
+            LaterGauge(
+                "synapse_federation_send_queue_%s_size" % (queue_name,),
+                "",
+                [],
+                lambda: len(queue),
+            )
 
         for queue_name in [
-            "presence_map", "presence_changed", "keyed_edu", "keyed_edu_changed",
-            "edus", "device_messages", "pos_time", "presence_destinations",
+            "presence_map",
+            "presence_changed",
+            "keyed_edu",
+            "keyed_edu_changed",
+            "edus",
+            "device_messages",
+            "pos_time",
+            "presence_destinations",
         ]:
             register(queue_name, getattr(self, queue_name))
 
@@ -121,9 +131,7 @@ class FederationRemoteSendQueue(object):
                 del self.presence_changed[key]
 
             user_ids = set(
-                user_id
-                for uids in self.presence_changed.values()
-                for user_id in uids
+                user_id for uids in self.presence_changed.values() for user_id in uids
             )
 
             keys = self.presence_destinations.keys()
@@ -285,19 +293,21 @@ class FederationRemoteSendQueue(object):
         ]
 
         for (key, user_id) in dest_user_ids:
-            rows.append((key, PresenceRow(
-                state=self.presence_map[user_id],
-            )))
+            rows.append((key, PresenceRow(state=self.presence_map[user_id])))
 
         # Fetch presence to send to destinations
         i = self.presence_destinations.bisect_right(from_token)
         j = self.presence_destinations.bisect_right(to_token) + 1
 
         for pos, (user_id, dests) in self.presence_destinations.items()[i:j]:
-            rows.append((pos, PresenceDestinationsRow(
-                state=self.presence_map[user_id],
-                destinations=list(dests),
-            )))
+            rows.append(
+                (
+                    pos,
+                    PresenceDestinationsRow(
+                        state=self.presence_map[user_id], destinations=list(dests)
+                    ),
+                )
+            )
 
         # Fetch changes keyed edus
         i = self.keyed_edu_changed.bisect_right(from_token)
@@ -308,10 +318,14 @@ class FederationRemoteSendQueue(object):
         keyed_edus = {v: k for k, v in self.keyed_edu_changed.items()[i:j]}
 
         for ((destination, edu_key), pos) in iteritems(keyed_edus):
-            rows.append((pos, KeyedEduRow(
-                key=edu_key,
-                edu=self.keyed_edu[(destination, edu_key)],
-            )))
+            rows.append(
+                (
+                    pos,
+                    KeyedEduRow(
+                        key=edu_key, edu=self.keyed_edu[(destination, edu_key)]
+                    ),
+                )
+            )
 
         # Fetch changed edus
         i = self.edus.bisect_right(from_token)
@@ -327,9 +341,7 @@ class FederationRemoteSendQueue(object):
         device_messages = {v: k for k, v in self.device_messages.items()[i:j]}
 
         for (destination, pos) in iteritems(device_messages):
-            rows.append((pos, DeviceRow(
-                destination=destination,
-            )))
+            rows.append((pos, DeviceRow(destination=destination)))
 
         # Sort rows based on pos
         rows.sort()
@@ -377,16 +389,14 @@ class BaseFederationRow(object):
         raise NotImplementedError()
 
 
-class PresenceRow(BaseFederationRow, namedtuple("PresenceRow", (
-    "state",  # UserPresenceState
-))):
+class PresenceRow(
+    BaseFederationRow, namedtuple("PresenceRow", ("state",))  # UserPresenceState
+):
     TypeId = "p"
 
     @staticmethod
     def from_data(data):
-        return PresenceRow(
-            state=UserPresenceState.from_dict(data)
-        )
+        return PresenceRow(state=UserPresenceState.from_dict(data))
 
     def to_data(self):
         return self.state.as_dict()
@@ -395,33 +405,35 @@ class PresenceRow(BaseFederationRow, namedtuple("PresenceRow", (
         buff.presence.append(self.state)
 
 
-class PresenceDestinationsRow(BaseFederationRow, namedtuple("PresenceDestinationsRow", (
-    "state",  # UserPresenceState
-    "destinations",  # list[str]
-))):
+class PresenceDestinationsRow(
+    BaseFederationRow,
+    namedtuple(
+        "PresenceDestinationsRow",
+        ("state", "destinations"),  # UserPresenceState  # list[str]
+    ),
+):
     TypeId = "pd"
 
     @staticmethod
     def from_data(data):
         return PresenceDestinationsRow(
-            state=UserPresenceState.from_dict(data["state"]),
-            destinations=data["dests"],
+            state=UserPresenceState.from_dict(data["state"]), destinations=data["dests"]
         )
 
     def to_data(self):
-        return {
-            "state": self.state.as_dict(),
-            "dests": self.destinations,
-        }
+        return {"state": self.state.as_dict(), "dests": self.destinations}
 
     def add_to_buffer(self, buff):
         buff.presence_destinations.append((self.state, self.destinations))
 
 
-class KeyedEduRow(BaseFederationRow, namedtuple("KeyedEduRow", (
-    "key",  # tuple(str) - the edu key passed to send_edu
-    "edu",  # Edu
-))):
+class KeyedEduRow(
+    BaseFederationRow,
+    namedtuple(
+        "KeyedEduRow",
+        ("key", "edu"),  # tuple(str) - the edu key passed to send_edu  # Edu
+    ),
+):
     """Streams EDUs that have an associated key that is ued to clobber. For example,
     typing EDUs clobber based on room_id.
     """
@@ -430,28 +442,19 @@ class KeyedEduRow(BaseFederationRow, namedtuple("KeyedEduRow", (
 
     @staticmethod
     def from_data(data):
-        return KeyedEduRow(
-            key=tuple(data["key"]),
-            edu=Edu(**data["edu"]),
-        )
+        return KeyedEduRow(key=tuple(data["key"]), edu=Edu(**data["edu"]))
 
     def to_data(self):
-        return {
-            "key": self.key,
-            "edu": self.edu.get_internal_dict(),
-        }
+        return {"key": self.key, "edu": self.edu.get_internal_dict()}
 
     def add_to_buffer(self, buff):
-        buff.keyed_edus.setdefault(
-            self.edu.destination, {}
-        )[self.key] = self.edu
+        buff.keyed_edus.setdefault(self.edu.destination, {})[self.key] = self.edu
 
 
-class EduRow(BaseFederationRow, namedtuple("EduRow", (
-    "edu",  # Edu
-))):
+class EduRow(BaseFederationRow, namedtuple("EduRow", ("edu",))):  # Edu
     """Streams EDUs that don't have keys. See KeyedEduRow
     """
+
     TypeId = "e"
 
     @staticmethod
@@ -465,13 +468,12 @@ class EduRow(BaseFederationRow, namedtuple("EduRow", (
         buff.edus.setdefault(self.edu.destination, []).append(self.edu)
 
 
-class DeviceRow(BaseFederationRow, namedtuple("DeviceRow", (
-    "destination",  # str
-))):
+class DeviceRow(BaseFederationRow, namedtuple("DeviceRow", ("destination",))):  # str
     """Streams the fact that either a) there is pending to device messages for
     users on the remote, or b) a local users device has changed and needs to
     be sent to the remote.
     """
+
     TypeId = "d"
 
     @staticmethod
@@ -487,23 +489,20 @@ class DeviceRow(BaseFederationRow, namedtuple("DeviceRow", (
 
 TypeToRow = {
     Row.TypeId: Row
-    for Row in (
-        PresenceRow,
-        PresenceDestinationsRow,
-        KeyedEduRow,
-        EduRow,
-        DeviceRow,
-    )
+    for Row in (PresenceRow, PresenceDestinationsRow, KeyedEduRow, EduRow, DeviceRow)
 }
 
 
-ParsedFederationStreamData = namedtuple("ParsedFederationStreamData", (
-    "presence",  # list(UserPresenceState)
-    "presence_destinations",  # list of tuples of UserPresenceState and destinations
-    "keyed_edus",  # dict of destination -> { key -> Edu }
-    "edus",  # dict of destination -> [Edu]
-    "device_destinations",  # set of destinations
-))
+ParsedFederationStreamData = namedtuple(
+    "ParsedFederationStreamData",
+    (
+        "presence",  # list(UserPresenceState)
+        "presence_destinations",  # list of tuples of UserPresenceState and destinations
+        "keyed_edus",  # dict of destination -> { key -> Edu }
+        "edus",  # dict of destination -> [Edu]
+        "device_destinations",  # set of destinations
+    ),
+)
 
 
 def process_rows_for_federation(transaction_queue, rows):
@@ -542,7 +541,7 @@ def process_rows_for_federation(transaction_queue, rows):
 
     for state, destinations in buff.presence_destinations:
         transaction_queue.send_presence_to_destinations(
-            states=[state], destinations=destinations,
+            states=[state], destinations=destinations
         )
 
     for destination, edu_map in iteritems(buff.keyed_edus):
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 4f0f939102..766c5a37cd 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -44,8 +44,8 @@ sent_pdus_destination_dist_count = Counter(
 )
 
 sent_pdus_destination_dist_total = Counter(
-    "synapse_federation_client_sent_pdu_destinations:total", ""
-    "Total number of PDUs queued for sending across all destinations",
+    "synapse_federation_client_sent_pdu_destinations:total",
+    "" "Total number of PDUs queued for sending across all destinations",
 )
 
 
@@ -63,14 +63,15 @@ class FederationSender(object):
         self._transaction_manager = TransactionManager(hs)
 
         # map from destination to PerDestinationQueue
-        self._per_destination_queues = {}   # type: dict[str, PerDestinationQueue]
+        self._per_destination_queues = {}  # type: dict[str, PerDestinationQueue]
 
         LaterGauge(
             "synapse_federation_transaction_queue_pending_destinations",
             "",
             [],
             lambda: sum(
-                1 for d in self._per_destination_queues.values()
+                1
+                for d in self._per_destination_queues.values()
                 if d.transmission_loop_running
             ),
         )
@@ -108,8 +109,9 @@ class FederationSender(object):
         # awaiting a call to flush_read_receipts_for_room. The presence of an entry
         # here for a given room means that we are rate-limiting RR flushes to that room,
         # and that there is a pending call to _flush_rrs_for_room in the system.
-        self._queues_awaiting_rr_flush_by_room = {
-        }   # type: dict[str, set[PerDestinationQueue]]
+        self._queues_awaiting_rr_flush_by_room = (
+            {}
+        )  # type: dict[str, set[PerDestinationQueue]]
 
         self._rr_txn_interval_per_room_ms = (
             1000.0 / hs.get_config().federation_rr_transactions_per_room_per_second
@@ -141,8 +143,7 @@ class FederationSender(object):
 
         # fire off a processing loop in the background
         run_as_background_process(
-            "process_event_queue_for_federation",
-            self._process_event_queue_loop,
+            "process_event_queue_for_federation", self._process_event_queue_loop
         )
 
     @defer.inlineCallbacks
@@ -152,7 +153,7 @@ class FederationSender(object):
             while True:
                 last_token = yield self.store.get_federation_out_pos("events")
                 next_token, events = yield self.store.get_all_new_events_stream(
-                    last_token, self._last_poked_id, limit=100,
+                    last_token, self._last_poked_id, limit=100
                 )
 
                 logger.debug("Handling %s -> %s", last_token, next_token)
@@ -168,6 +169,9 @@ class FederationSender(object):
                     if not is_mine and send_on_behalf_of is None:
                         return
 
+                    if not event.internal_metadata.should_proactively_send():
+                        return
+
                     try:
                         # Get the state from before the event.
                         # We need to make sure that this is the state from before
@@ -176,7 +180,7 @@ class FederationSender(object):
                         # banned then it won't receive the event because it won't
                         # be in the room after the ban.
                         destinations = yield self.state.get_current_hosts_in_room(
-                            event.room_id, latest_event_ids=event.prev_event_ids(),
+                            event.room_id, latest_event_ids=event.prev_event_ids()
                         )
                     except Exception:
                         logger.exception(
@@ -206,37 +210,40 @@ class FederationSender(object):
                 for event in events:
                     events_by_room.setdefault(event.room_id, []).append(event)
 
-                yield logcontext.make_deferred_yieldable(defer.gatherResults(
-                    [
-                        logcontext.run_in_background(handle_room_events, evs)
-                        for evs in itervalues(events_by_room)
-                    ],
-                    consumeErrors=True
-                ))
-
-                yield self.store.update_federation_out_pos(
-                    "events", next_token
+                yield logcontext.make_deferred_yieldable(
+                    defer.gatherResults(
+                        [
+                            logcontext.run_in_background(handle_room_events, evs)
+                            for evs in itervalues(events_by_room)
+                        ],
+                        consumeErrors=True,
+                    )
                 )
 
+                yield self.store.update_federation_out_pos("events", next_token)
+
                 if events:
                     now = self.clock.time_msec()
                     ts = yield self.store.get_received_ts(events[-1].event_id)
 
                     synapse.metrics.event_processing_lag.labels(
-                        "federation_sender").set(now - ts)
+                        "federation_sender"
+                    ).set(now - ts)
                     synapse.metrics.event_processing_last_ts.labels(
-                        "federation_sender").set(ts)
+                        "federation_sender"
+                    ).set(ts)
 
                     events_processed_counter.inc(len(events))
 
-                    event_processing_loop_room_count.labels(
-                        "federation_sender"
-                    ).inc(len(events_by_room))
+                    event_processing_loop_room_count.labels("federation_sender").inc(
+                        len(events_by_room)
+                    )
 
                 event_processing_loop_counter.labels("federation_sender").inc()
 
                 synapse.metrics.event_processing_positions.labels(
-                    "federation_sender").set(next_token)
+                    "federation_sender"
+                ).set(next_token)
 
         finally:
             self._is_processing = False
@@ -309,9 +316,7 @@ class FederationSender(object):
         if not domains:
             return
 
-        queues_pending_flush = self._queues_awaiting_rr_flush_by_room.get(
-            room_id
-        )
+        queues_pending_flush = self._queues_awaiting_rr_flush_by_room.get(room_id)
 
         # if there is no flush yet scheduled, we will send out these receipts with
         # immediate flushes, and schedule the next flush for this room.
@@ -374,10 +379,9 @@ class FederationSender(object):
         # updates in quick succession are correctly handled.
         # We only want to send presence for our own users, so lets always just
         # filter here just in case.
-        self.pending_presence.update({
-            state.user_id: state for state in states
-            if self.is_mine_id(state.user_id)
-        })
+        self.pending_presence.update(
+            {state.user_id: state for state in states if self.is_mine_id(state.user_id)}
+        )
 
         # We then handle the new pending presence in batches, first figuring
         # out the destinations we need to send each state to and then poking it
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index 564c57203d..9aab12c0d3 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -189,11 +189,21 @@ class PerDestinationQueue(object):
 
             pending_pdus = []
             while True:
-                device_message_edus, device_stream_id, dev_list_id = (
-                    # We have to keep 2 free slots for presence and rr_edus
-                    yield self._get_new_device_messages(MAX_EDUS_PER_TRANSACTION - 2)
+                # We have to keep 2 free slots for presence and rr_edus
+                limit = MAX_EDUS_PER_TRANSACTION - 2
+
+                device_update_edus, dev_list_id = (
+                    yield self._get_device_update_edus(limit)
+                )
+
+                limit -= len(device_update_edus)
+
+                to_device_edus, device_stream_id = (
+                    yield self._get_to_device_message_edus(limit)
                 )
 
+                pending_edus = device_update_edus + to_device_edus
+
                 # BEGIN CRITICAL SECTION
                 #
                 # In order to avoid a race condition, we need to make sure that
@@ -208,10 +218,6 @@ class PerDestinationQueue(object):
                 # We can only include at most 50 PDUs per transactions
                 pending_pdus, self._pending_pdus = pending_pdus[:50], pending_pdus[50:]
 
-                pending_edus = []
-
-                # We can only include at most 100 EDUs per transactions
-                # rr_edus and pending_presence take at most one slot each
                 pending_edus.extend(self._get_rr_edus(force_flush=False))
                 pending_presence = self._pending_presence
                 self._pending_presence = {}
@@ -232,7 +238,6 @@ class PerDestinationQueue(object):
                         )
                     )
 
-                pending_edus.extend(device_message_edus)
                 pending_edus.extend(
                     self._pop_pending_edus(MAX_EDUS_PER_TRANSACTION - len(pending_edus))
                 )
@@ -272,10 +277,13 @@ class PerDestinationQueue(object):
                         sent_edus_by_type.labels(edu.edu_type).inc()
                     # Remove the acknowledged device messages from the database
                     # Only bother if we actually sent some device messages
-                    if device_message_edus:
+                    if to_device_edus:
                         yield self._store.delete_device_msgs_for_remote(
                             self._destination, device_stream_id
                         )
+
+                    # also mark the device updates as sent
+                    if device_update_edus:
                         logger.info(
                             "Marking as sent %r %r", self._destination, dev_list_id
                         )
@@ -347,12 +355,12 @@ class PerDestinationQueue(object):
         return pending_edus
 
     @defer.inlineCallbacks
-    def _get_new_device_messages(self, limit):
+    def _get_device_update_edus(self, limit):
         last_device_list = self._last_device_list_stream_id
 
         # Retrieve list of new device updates to send to the destination
         now_stream_id, results = yield self._store.get_devices_by_remote(
-            self._destination, last_device_list, limit=limit,
+            self._destination, last_device_list, limit=limit
         )
         edus = [
             Edu(
@@ -366,15 +374,16 @@ class PerDestinationQueue(object):
 
         assert len(edus) <= limit, "get_devices_by_remote returned too many EDUs"
 
+        defer.returnValue((edus, now_stream_id))
+
+    @defer.inlineCallbacks
+    def _get_to_device_message_edus(self, limit):
         last_device_stream_id = self._last_device_stream_id
         to_device_stream_id = self._store.get_to_device_stream_token()
         contents, stream_id = yield self._store.get_new_device_msgs_for_remote(
-            self._destination,
-            last_device_stream_id,
-            to_device_stream_id,
-            limit - len(edus),
+            self._destination, last_device_stream_id, to_device_stream_id, limit
         )
-        edus.extend(
+        edus = [
             Edu(
                 origin=self._server_name,
                 destination=self._destination,
@@ -382,6 +391,6 @@ class PerDestinationQueue(object):
                 content=content,
             )
             for content in contents
-        )
+        ]
 
-        defer.returnValue((edus, stream_id, now_stream_id))
+        defer.returnValue((edus, stream_id))
diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py
index 35e6b8ff5b..c987bb9a0d 100644
--- a/synapse/federation/sender/transaction_manager.py
+++ b/synapse/federation/sender/transaction_manager.py
@@ -29,9 +29,10 @@ class TransactionManager(object):
 
     shared between PerDestinationQueue objects
     """
+
     def __init__(self, hs):
         self._server_name = hs.hostname
-        self.clock = hs.get_clock()   # nb must be called this for @measure_func
+        self.clock = hs.get_clock()  # nb must be called this for @measure_func
         self._store = hs.get_datastore()
         self._transaction_actions = TransactionActions(self._store)
         self._transport_layer = hs.get_federation_transport_client()
@@ -55,9 +56,9 @@ class TransactionManager(object):
         txn_id = str(self._next_txn_id)
 
         logger.debug(
-            "TX [%s] {%s} Attempting new transaction"
-            " (pdus: %d, edus: %d)",
-            destination, txn_id,
+            "TX [%s] {%s} Attempting new transaction" " (pdus: %d, edus: %d)",
+            destination,
+            txn_id,
             len(pdus),
             len(edus),
         )
@@ -79,9 +80,9 @@ class TransactionManager(object):
 
         logger.debug("TX [%s] Persisted transaction", destination)
         logger.info(
-            "TX [%s] {%s} Sending transaction [%s],"
-            " (PDUs: %d, EDUs: %d)",
-            destination, txn_id,
+            "TX [%s] {%s} Sending transaction [%s]," " (PDUs: %d, EDUs: %d)",
+            destination,
+            txn_id,
             transaction.transaction_id,
             len(pdus),
             len(edus),
@@ -112,20 +113,12 @@ class TransactionManager(object):
             response = e.response
 
             if e.code in (401, 404, 429) or 500 <= e.code:
-                logger.info(
-                    "TX [%s] {%s} got %d response",
-                    destination, txn_id, code
-                )
+                logger.info("TX [%s] {%s} got %d response", destination, txn_id, code)
                 raise e
 
-        logger.info(
-            "TX [%s] {%s} got %d response",
-            destination, txn_id, code
-        )
+        logger.info("TX [%s] {%s} got %d response", destination, txn_id, code)
 
-        yield self._transaction_actions.delivered(
-            transaction, code, response
-        )
+        yield self._transaction_actions.delivered(transaction, code, response)
 
         logger.debug("TX [%s] {%s} Marked as delivered", destination, txn_id)
 
@@ -134,13 +127,18 @@ class TransactionManager(object):
                 if "error" in r:
                     logger.warn(
                         "TX [%s] {%s} Remote returned error for %s: %s",
-                        destination, txn_id, e_id, r,
+                        destination,
+                        txn_id,
+                        e_id,
+                        r,
                     )
         else:
             for p in pdus:
                 logger.warn(
                     "TX [%s] {%s} Failed to send event %s",
-                    destination, txn_id, p.event_id,
+                    destination,
+                    txn_id,
+                    p.event_id,
                 )
             success = False
 
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index e424c40fdf..aecd142309 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -48,12 +48,13 @@ class TransportLayerClient(object):
         Returns:
             Deferred: Results in a dict received from the remote homeserver.
         """
-        logger.debug("get_room_state dest=%s, room=%s",
-                     destination, room_id)
+        logger.debug("get_room_state dest=%s, room=%s", destination, room_id)
 
         path = _create_v1_path("/state/%s", room_id)
         return self.client.get_json(
-            destination, path=path, args={"event_id": event_id},
+            destination,
+            path=path,
+            args={"event_id": event_id},
             try_trailing_slash_on_400=True,
         )
 
@@ -71,12 +72,13 @@ class TransportLayerClient(object):
         Returns:
             Deferred: Results in a dict received from the remote homeserver.
         """
-        logger.debug("get_room_state_ids dest=%s, room=%s",
-                     destination, room_id)
+        logger.debug("get_room_state_ids dest=%s, room=%s", destination, room_id)
 
         path = _create_v1_path("/state_ids/%s", room_id)
         return self.client.get_json(
-            destination, path=path, args={"event_id": event_id},
+            destination,
+            path=path,
+            args={"event_id": event_id},
             try_trailing_slash_on_400=True,
         )
 
@@ -94,13 +96,11 @@ class TransportLayerClient(object):
         Returns:
             Deferred: Results in a dict received from the remote homeserver.
         """
-        logger.debug("get_pdu dest=%s, event_id=%s",
-                     destination, event_id)
+        logger.debug("get_pdu dest=%s, event_id=%s", destination, event_id)
 
         path = _create_v1_path("/event/%s", event_id)
         return self.client.get_json(
-            destination, path=path, timeout=timeout,
-            try_trailing_slash_on_400=True,
+            destination, path=path, timeout=timeout, try_trailing_slash_on_400=True
         )
 
     @log_function
@@ -119,7 +119,10 @@ class TransportLayerClient(object):
         """
         logger.debug(
             "backfill dest=%s, room_id=%s, event_tuples=%s, limit=%s",
-            destination, room_id, repr(event_tuples), str(limit)
+            destination,
+            room_id,
+            repr(event_tuples),
+            str(limit),
         )
 
         if not event_tuples:
@@ -128,16 +131,10 @@ class TransportLayerClient(object):
 
         path = _create_v1_path("/backfill/%s", room_id)
 
-        args = {
-            "v": event_tuples,
-            "limit": [str(limit)],
-        }
+        args = {"v": event_tuples, "limit": [str(limit)]}
 
         return self.client.get_json(
-            destination,
-            path=path,
-            args=args,
-            try_trailing_slash_on_400=True,
+            destination, path=path, args=args, try_trailing_slash_on_400=True
         )
 
     @defer.inlineCallbacks
@@ -163,7 +160,8 @@ class TransportLayerClient(object):
         """
         logger.debug(
             "send_data dest=%s, txid=%s",
-            transaction.destination, transaction.transaction_id
+            transaction.destination,
+            transaction.transaction_id,
         )
 
         if transaction.destination == self.server_name:
@@ -189,8 +187,9 @@ class TransportLayerClient(object):
 
     @defer.inlineCallbacks
     @log_function
-    def make_query(self, destination, query_type, args, retry_on_dns_fail,
-                   ignore_backoff=False):
+    def make_query(
+        self, destination, query_type, args, retry_on_dns_fail, ignore_backoff=False
+    ):
         path = _create_v1_path("/query/%s", query_type)
 
         content = yield self.client.get_json(
@@ -235,8 +234,8 @@ class TransportLayerClient(object):
         valid_memberships = {Membership.JOIN, Membership.LEAVE}
         if membership not in valid_memberships:
             raise RuntimeError(
-                "make_membership_event called with membership='%s', must be one of %s" %
-                (membership, ",".join(valid_memberships))
+                "make_membership_event called with membership='%s', must be one of %s"
+                % (membership, ",".join(valid_memberships))
             )
         path = _create_v1_path("/make_%s/%s/%s", membership, room_id, user_id)
 
@@ -268,9 +267,7 @@ class TransportLayerClient(object):
         path = _create_v1_path("/send_join/%s/%s", room_id, event_id)
 
         response = yield self.client.put_json(
-            destination=destination,
-            path=path,
-            data=content,
+            destination=destination, path=path, data=content
         )
 
         defer.returnValue(response)
@@ -284,7 +281,6 @@ class TransportLayerClient(object):
             destination=destination,
             path=path,
             data=content,
-
             # we want to do our best to send this through. The problem is
             # that if it fails, we won't retry it later, so if the remote
             # server was just having a momentary blip, the room will be out of
@@ -300,10 +296,7 @@ class TransportLayerClient(object):
         path = _create_v1_path("/invite/%s/%s", room_id, event_id)
 
         response = yield self.client.put_json(
-            destination=destination,
-            path=path,
-            data=content,
-            ignore_backoff=True,
+            destination=destination, path=path, data=content, ignore_backoff=True
         )
 
         defer.returnValue(response)
@@ -314,26 +307,27 @@ class TransportLayerClient(object):
         path = _create_v2_path("/invite/%s/%s", room_id, event_id)
 
         response = yield self.client.put_json(
-            destination=destination,
-            path=path,
-            data=content,
-            ignore_backoff=True,
+            destination=destination, path=path, data=content, ignore_backoff=True
         )
 
         defer.returnValue(response)
 
     @defer.inlineCallbacks
     @log_function
-    def get_public_rooms(self, remote_server, limit, since_token,
-                         search_filter=None, include_all_networks=False,
-                         third_party_instance_id=None):
+    def get_public_rooms(
+        self,
+        remote_server,
+        limit,
+        since_token,
+        search_filter=None,
+        include_all_networks=False,
+        third_party_instance_id=None,
+    ):
         path = _create_v1_path("/publicRooms")
 
-        args = {
-            "include_all_networks": "true" if include_all_networks else "false",
-        }
+        args = {"include_all_networks": "true" if include_all_networks else "false"}
         if third_party_instance_id:
-            args["third_party_instance_id"] = third_party_instance_id,
+            args["third_party_instance_id"] = (third_party_instance_id,)
         if limit:
             args["limit"] = [str(limit)]
         if since_token:
@@ -342,10 +336,7 @@ class TransportLayerClient(object):
         # TODO(erikj): Actually send the search_filter across federation.
 
         response = yield self.client.get_json(
-            destination=remote_server,
-            path=path,
-            args=args,
-            ignore_backoff=True,
+            destination=remote_server, path=path, args=args, ignore_backoff=True
         )
 
         defer.returnValue(response)
@@ -353,12 +344,10 @@ class TransportLayerClient(object):
     @defer.inlineCallbacks
     @log_function
     def exchange_third_party_invite(self, destination, room_id, event_dict):
-        path = _create_v1_path("/exchange_third_party_invite/%s", room_id,)
+        path = _create_v1_path("/exchange_third_party_invite/%s", room_id)
 
         response = yield self.client.put_json(
-            destination=destination,
-            path=path,
-            data=event_dict,
+            destination=destination, path=path, data=event_dict
         )
 
         defer.returnValue(response)
@@ -368,10 +357,7 @@ class TransportLayerClient(object):
     def get_event_auth(self, destination, room_id, event_id):
         path = _create_v1_path("/event_auth/%s/%s", room_id, event_id)
 
-        content = yield self.client.get_json(
-            destination=destination,
-            path=path,
-        )
+        content = yield self.client.get_json(destination=destination, path=path)
 
         defer.returnValue(content)
 
@@ -381,9 +367,7 @@ class TransportLayerClient(object):
         path = _create_v1_path("/query_auth/%s/%s", room_id, event_id)
 
         content = yield self.client.post_json(
-            destination=destination,
-            path=path,
-            data=content,
+            destination=destination, path=path, data=content
         )
 
         defer.returnValue(content)
@@ -416,10 +400,7 @@ class TransportLayerClient(object):
         path = _create_v1_path("/user/keys/query")
 
         content = yield self.client.post_json(
-            destination=destination,
-            path=path,
-            data=query_content,
-            timeout=timeout,
+            destination=destination, path=path, data=query_content, timeout=timeout
         )
         defer.returnValue(content)
 
@@ -443,9 +424,7 @@ class TransportLayerClient(object):
         path = _create_v1_path("/user/devices/%s", user_id)
 
         content = yield self.client.get_json(
-            destination=destination,
-            path=path,
-            timeout=timeout,
+            destination=destination, path=path, timeout=timeout
         )
         defer.returnValue(content)
 
@@ -479,18 +458,23 @@ class TransportLayerClient(object):
         path = _create_v1_path("/user/keys/claim")
 
         content = yield self.client.post_json(
-            destination=destination,
-            path=path,
-            data=query_content,
-            timeout=timeout,
+            destination=destination, path=path, data=query_content, timeout=timeout
         )
         defer.returnValue(content)
 
     @defer.inlineCallbacks
     @log_function
-    def get_missing_events(self, destination, room_id, earliest_events,
-                           latest_events, limit, min_depth, timeout):
-        path = _create_v1_path("/get_missing_events/%s", room_id,)
+    def get_missing_events(
+        self,
+        destination,
+        room_id,
+        earliest_events,
+        latest_events,
+        limit,
+        min_depth,
+        timeout,
+    ):
+        path = _create_v1_path("/get_missing_events/%s", room_id)
 
         content = yield self.client.post_json(
             destination=destination,
@@ -510,7 +494,7 @@ class TransportLayerClient(object):
     def get_group_profile(self, destination, group_id, requester_user_id):
         """Get a group profile
         """
-        path = _create_v1_path("/groups/%s/profile", group_id,)
+        path = _create_v1_path("/groups/%s/profile", group_id)
 
         return self.client.get_json(
             destination=destination,
@@ -529,7 +513,7 @@ class TransportLayerClient(object):
             requester_user_id (str)
             content (dict): The new profile of the group
         """
-        path = _create_v1_path("/groups/%s/profile", group_id,)
+        path = _create_v1_path("/groups/%s/profile", group_id)
 
         return self.client.post_json(
             destination=destination,
@@ -543,7 +527,7 @@ class TransportLayerClient(object):
     def get_group_summary(self, destination, group_id, requester_user_id):
         """Get a group summary
         """
-        path = _create_v1_path("/groups/%s/summary", group_id,)
+        path = _create_v1_path("/groups/%s/summary", group_id)
 
         return self.client.get_json(
             destination=destination,
@@ -556,7 +540,7 @@ class TransportLayerClient(object):
     def get_rooms_in_group(self, destination, group_id, requester_user_id):
         """Get all rooms in a group
         """
-        path = _create_v1_path("/groups/%s/rooms", group_id,)
+        path = _create_v1_path("/groups/%s/rooms", group_id)
 
         return self.client.get_json(
             destination=destination,
@@ -565,11 +549,12 @@ class TransportLayerClient(object):
             ignore_backoff=True,
         )
 
-    def add_room_to_group(self, destination, group_id, requester_user_id, room_id,
-                          content):
+    def add_room_to_group(
+        self, destination, group_id, requester_user_id, room_id, content
+    ):
         """Add a room to a group
         """
-        path = _create_v1_path("/groups/%s/room/%s", group_id, room_id,)
+        path = _create_v1_path("/groups/%s/room/%s", group_id, room_id)
 
         return self.client.post_json(
             destination=destination,
@@ -579,13 +564,13 @@ class TransportLayerClient(object):
             ignore_backoff=True,
         )
 
-    def update_room_in_group(self, destination, group_id, requester_user_id, room_id,
-                             config_key, content):
+    def update_room_in_group(
+        self, destination, group_id, requester_user_id, room_id, config_key, content
+    ):
         """Update room in group
         """
         path = _create_v1_path(
-            "/groups/%s/room/%s/config/%s",
-            group_id, room_id, config_key,
+            "/groups/%s/room/%s/config/%s", group_id, room_id, config_key
         )
 
         return self.client.post_json(
@@ -599,7 +584,7 @@ class TransportLayerClient(object):
     def remove_room_from_group(self, destination, group_id, requester_user_id, room_id):
         """Remove a room from a group
         """
-        path = _create_v1_path("/groups/%s/room/%s", group_id, room_id,)
+        path = _create_v1_path("/groups/%s/room/%s", group_id, room_id)
 
         return self.client.delete_json(
             destination=destination,
@@ -612,7 +597,7 @@ class TransportLayerClient(object):
     def get_users_in_group(self, destination, group_id, requester_user_id):
         """Get users in a group
         """
-        path = _create_v1_path("/groups/%s/users", group_id,)
+        path = _create_v1_path("/groups/%s/users", group_id)
 
         return self.client.get_json(
             destination=destination,
@@ -625,7 +610,7 @@ class TransportLayerClient(object):
     def get_invited_users_in_group(self, destination, group_id, requester_user_id):
         """Get users that have been invited to a group
         """
-        path = _create_v1_path("/groups/%s/invited_users", group_id,)
+        path = _create_v1_path("/groups/%s/invited_users", group_id)
 
         return self.client.get_json(
             destination=destination,
@@ -638,16 +623,10 @@ class TransportLayerClient(object):
     def accept_group_invite(self, destination, group_id, user_id, content):
         """Accept a group invite
         """
-        path = _create_v1_path(
-            "/groups/%s/users/%s/accept_invite",
-            group_id, user_id,
-        )
+        path = _create_v1_path("/groups/%s/users/%s/accept_invite", group_id, user_id)
 
         return self.client.post_json(
-            destination=destination,
-            path=path,
-            data=content,
-            ignore_backoff=True,
+            destination=destination, path=path, data=content, ignore_backoff=True
         )
 
     @log_function
@@ -657,14 +636,13 @@ class TransportLayerClient(object):
         path = _create_v1_path("/groups/%s/users/%s/join", group_id, user_id)
 
         return self.client.post_json(
-            destination=destination,
-            path=path,
-            data=content,
-            ignore_backoff=True,
+            destination=destination, path=path, data=content, ignore_backoff=True
         )
 
     @log_function
-    def invite_to_group(self, destination, group_id, user_id, requester_user_id, content):
+    def invite_to_group(
+        self, destination, group_id, user_id, requester_user_id, content
+    ):
         """Invite a user to a group
         """
         path = _create_v1_path("/groups/%s/users/%s/invite", group_id, user_id)
@@ -686,15 +664,13 @@ class TransportLayerClient(object):
         path = _create_v1_path("/groups/local/%s/users/%s/invite", group_id, user_id)
 
         return self.client.post_json(
-            destination=destination,
-            path=path,
-            data=content,
-            ignore_backoff=True,
+            destination=destination, path=path, data=content, ignore_backoff=True
         )
 
     @log_function
-    def remove_user_from_group(self, destination, group_id, requester_user_id,
-                               user_id, content):
+    def remove_user_from_group(
+        self, destination, group_id, requester_user_id, user_id, content
+    ):
         """Remove a user fron a group
         """
         path = _create_v1_path("/groups/%s/users/%s/remove", group_id, user_id)
@@ -708,8 +684,9 @@ class TransportLayerClient(object):
         )
 
     @log_function
-    def remove_user_from_group_notification(self, destination, group_id, user_id,
-                                            content):
+    def remove_user_from_group_notification(
+        self, destination, group_id, user_id, content
+    ):
         """Sent by group server to inform a user's server that they have been
         kicked from the group.
         """
@@ -717,10 +694,7 @@ class TransportLayerClient(object):
         path = _create_v1_path("/groups/local/%s/users/%s/remove", group_id, user_id)
 
         return self.client.post_json(
-            destination=destination,
-            path=path,
-            data=content,
-            ignore_backoff=True,
+            destination=destination, path=path, data=content, ignore_backoff=True
         )
 
     @log_function
@@ -732,24 +706,24 @@ class TransportLayerClient(object):
         path = _create_v1_path("/groups/%s/renew_attestation/%s", group_id, user_id)
 
         return self.client.post_json(
-            destination=destination,
-            path=path,
-            data=content,
-            ignore_backoff=True,
+            destination=destination, path=path, data=content, ignore_backoff=True
         )
 
     @log_function
-    def update_group_summary_room(self, destination, group_id, user_id, room_id,
-                                  category_id, content):
+    def update_group_summary_room(
+        self, destination, group_id, user_id, room_id, category_id, content
+    ):
         """Update a room entry in a group summary
         """
         if category_id:
             path = _create_v1_path(
                 "/groups/%s/summary/categories/%s/rooms/%s",
-                group_id, category_id, room_id,
+                group_id,
+                category_id,
+                room_id,
             )
         else:
-            path = _create_v1_path("/groups/%s/summary/rooms/%s", group_id, room_id,)
+            path = _create_v1_path("/groups/%s/summary/rooms/%s", group_id, room_id)
 
         return self.client.post_json(
             destination=destination,
@@ -760,17 +734,20 @@ class TransportLayerClient(object):
         )
 
     @log_function
-    def delete_group_summary_room(self, destination, group_id, user_id, room_id,
-                                  category_id):
+    def delete_group_summary_room(
+        self, destination, group_id, user_id, room_id, category_id
+    ):
         """Delete a room entry in a group summary
         """
         if category_id:
             path = _create_v1_path(
                 "/groups/%s/summary/categories/%s/rooms/%s",
-                group_id, category_id, room_id,
+                group_id,
+                category_id,
+                room_id,
             )
         else:
-            path = _create_v1_path("/groups/%s/summary/rooms/%s", group_id, room_id,)
+            path = _create_v1_path("/groups/%s/summary/rooms/%s", group_id, room_id)
 
         return self.client.delete_json(
             destination=destination,
@@ -783,7 +760,7 @@ class TransportLayerClient(object):
     def get_group_categories(self, destination, group_id, requester_user_id):
         """Get all categories in a group
         """
-        path = _create_v1_path("/groups/%s/categories", group_id,)
+        path = _create_v1_path("/groups/%s/categories", group_id)
 
         return self.client.get_json(
             destination=destination,
@@ -796,7 +773,7 @@ class TransportLayerClient(object):
     def get_group_category(self, destination, group_id, requester_user_id, category_id):
         """Get category info in a group
         """
-        path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id,)
+        path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id)
 
         return self.client.get_json(
             destination=destination,
@@ -806,11 +783,12 @@ class TransportLayerClient(object):
         )
 
     @log_function
-    def update_group_category(self, destination, group_id, requester_user_id, category_id,
-                              content):
+    def update_group_category(
+        self, destination, group_id, requester_user_id, category_id, content
+    ):
         """Update a category in a group
         """
-        path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id,)
+        path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id)
 
         return self.client.post_json(
             destination=destination,
@@ -821,11 +799,12 @@ class TransportLayerClient(object):
         )
 
     @log_function
-    def delete_group_category(self, destination, group_id, requester_user_id,
-                              category_id):
+    def delete_group_category(
+        self, destination, group_id, requester_user_id, category_id
+    ):
         """Delete a category in a group
         """
-        path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id,)
+        path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id)
 
         return self.client.delete_json(
             destination=destination,
@@ -838,7 +817,7 @@ class TransportLayerClient(object):
     def get_group_roles(self, destination, group_id, requester_user_id):
         """Get all roles in a group
         """
-        path = _create_v1_path("/groups/%s/roles", group_id,)
+        path = _create_v1_path("/groups/%s/roles", group_id)
 
         return self.client.get_json(
             destination=destination,
@@ -851,7 +830,7 @@ class TransportLayerClient(object):
     def get_group_role(self, destination, group_id, requester_user_id, role_id):
         """Get a roles info
         """
-        path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id,)
+        path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id)
 
         return self.client.get_json(
             destination=destination,
@@ -861,11 +840,12 @@ class TransportLayerClient(object):
         )
 
     @log_function
-    def update_group_role(self, destination, group_id, requester_user_id, role_id,
-                          content):
+    def update_group_role(
+        self, destination, group_id, requester_user_id, role_id, content
+    ):
         """Update a role in a group
         """
-        path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id,)
+        path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id)
 
         return self.client.post_json(
             destination=destination,
@@ -879,7 +859,7 @@ class TransportLayerClient(object):
     def delete_group_role(self, destination, group_id, requester_user_id, role_id):
         """Delete a role in a group
         """
-        path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id,)
+        path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id)
 
         return self.client.delete_json(
             destination=destination,
@@ -889,17 +869,17 @@ class TransportLayerClient(object):
         )
 
     @log_function
-    def update_group_summary_user(self, destination, group_id, requester_user_id,
-                                  user_id, role_id, content):
+    def update_group_summary_user(
+        self, destination, group_id, requester_user_id, user_id, role_id, content
+    ):
         """Update a users entry in a group
         """
         if role_id:
             path = _create_v1_path(
-                "/groups/%s/summary/roles/%s/users/%s",
-                group_id, role_id, user_id,
+                "/groups/%s/summary/roles/%s/users/%s", group_id, role_id, user_id
             )
         else:
-            path = _create_v1_path("/groups/%s/summary/users/%s", group_id, user_id,)
+            path = _create_v1_path("/groups/%s/summary/users/%s", group_id, user_id)
 
         return self.client.post_json(
             destination=destination,
@@ -910,11 +890,10 @@ class TransportLayerClient(object):
         )
 
     @log_function
-    def set_group_join_policy(self, destination, group_id, requester_user_id,
-                              content):
+    def set_group_join_policy(self, destination, group_id, requester_user_id, content):
         """Sets the join policy for a group
         """
-        path = _create_v1_path("/groups/%s/settings/m.join_policy", group_id,)
+        path = _create_v1_path("/groups/%s/settings/m.join_policy", group_id)
 
         return self.client.put_json(
             destination=destination,
@@ -925,17 +904,17 @@ class TransportLayerClient(object):
         )
 
     @log_function
-    def delete_group_summary_user(self, destination, group_id, requester_user_id,
-                                  user_id, role_id):
+    def delete_group_summary_user(
+        self, destination, group_id, requester_user_id, user_id, role_id
+    ):
         """Delete a users entry in a group
         """
         if role_id:
             path = _create_v1_path(
-                "/groups/%s/summary/roles/%s/users/%s",
-                group_id, role_id, user_id,
+                "/groups/%s/summary/roles/%s/users/%s", group_id, role_id, user_id
             )
         else:
-            path = _create_v1_path("/groups/%s/summary/users/%s", group_id, user_id,)
+            path = _create_v1_path("/groups/%s/summary/users/%s", group_id, user_id)
 
         return self.client.delete_json(
             destination=destination,
@@ -953,10 +932,7 @@ class TransportLayerClient(object):
         content = {"user_ids": user_ids}
 
         return self.client.post_json(
-            destination=destination,
-            path=path,
-            data=content,
-            ignore_backoff=True,
+            destination=destination, path=path, data=content, ignore_backoff=True
         )
 
 
@@ -975,9 +951,8 @@ def _create_v1_path(path, *args):
     Returns:
         str
     """
-    return (
-        FEDERATION_V1_PREFIX
-        + path % tuple(urllib.parse.quote(arg, "") for arg in args)
+    return FEDERATION_V1_PREFIX + path % tuple(
+        urllib.parse.quote(arg, "") for arg in args
     )
 
 
@@ -996,7 +971,6 @@ def _create_v2_path(path, *args):
     Returns:
         str
     """
-    return (
-        FEDERATION_V2_PREFIX
-        + path % tuple(urllib.parse.quote(arg, "") for arg in args)
+    return FEDERATION_V2_PREFIX + path % tuple(
+        urllib.parse.quote(arg, "") for arg in args
     )
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 949a5fb2aa..955f0f4308 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -66,8 +66,7 @@ class TransportLayerServer(JsonResource):
 
         self.authenticator = Authenticator(hs)
         self.ratelimiter = FederationRateLimiter(
-            self.clock,
-            config=hs.config.rc_federation,
+            self.clock, config=hs.config.rc_federation
         )
 
         self.register_servlets()
@@ -84,11 +83,13 @@ class TransportLayerServer(JsonResource):
 
 class AuthenticationError(SynapseError):
     """There was a problem authenticating the request"""
+
     pass
 
 
 class NoAuthenticationError(AuthenticationError):
     """The request had no authentication information"""
+
     pass
 
 
@@ -105,8 +106,8 @@ class Authenticator(object):
     def authenticate_request(self, request, content):
         now = self._clock.time_msec()
         json_request = {
-            "method": request.method.decode('ascii'),
-            "uri": request.uri.decode('ascii'),
+            "method": request.method.decode("ascii"),
+            "uri": request.uri.decode("ascii"),
             "destination": self.server_name,
             "signatures": {},
         }
@@ -120,7 +121,7 @@ class Authenticator(object):
 
         if not auth_headers:
             raise NoAuthenticationError(
-                401, "Missing Authorization headers", Codes.UNAUTHORIZED,
+                401, "Missing Authorization headers", Codes.UNAUTHORIZED
             )
 
         for auth in auth_headers:
@@ -130,14 +131,14 @@ class Authenticator(object):
                 json_request["signatures"].setdefault(origin, {})[key] = sig
 
         if (
-            self.federation_domain_whitelist is not None and
-            origin not in self.federation_domain_whitelist
+            self.federation_domain_whitelist is not None
+            and origin not in self.federation_domain_whitelist
         ):
             raise FederationDeniedError(origin)
 
         if not json_request["signatures"]:
             raise NoAuthenticationError(
-                401, "Missing Authorization headers", Codes.UNAUTHORIZED,
+                401, "Missing Authorization headers", Codes.UNAUTHORIZED
             )
 
         yield self.keyring.verify_json_for_server(
@@ -177,12 +178,12 @@ def _parse_auth_header(header_bytes):
         AuthenticationError if the header could not be parsed
     """
     try:
-        header_str = header_bytes.decode('utf-8')
+        header_str = header_bytes.decode("utf-8")
         params = header_str.split(" ")[1].split(",")
         param_dict = dict(kv.split("=") for kv in params)
 
         def strip_quotes(value):
-            if value.startswith("\""):
+            if value.startswith('"'):
                 return value[1:-1]
             else:
                 return value
@@ -198,11 +199,11 @@ def _parse_auth_header(header_bytes):
     except Exception as e:
         logger.warn(
             "Error parsing auth header '%s': %s",
-            header_bytes.decode('ascii', 'replace'),
+            header_bytes.decode("ascii", "replace"),
             e,
         )
         raise AuthenticationError(
-            400, "Malformed Authorization header", Codes.UNAUTHORIZED,
+            400, "Malformed Authorization header", Codes.UNAUTHORIZED
         )
 
 
@@ -242,6 +243,7 @@ class BaseFederationServlet(object):
             Exception: other exceptions will be caught, logged, and a 500 will be
                 returned.
     """
+
     REQUIRE_AUTH = True
 
     PREFIX = FEDERATION_V1_PREFIX  # Allows specifying the API version
@@ -293,9 +295,7 @@ class BaseFederationServlet(object):
                         origin, content, request.args, *args, **kwargs
                     )
             else:
-                response = yield func(
-                    origin, content, request.args, *args, **kwargs
-                )
+                response = yield func(origin, content, request.args, *args, **kwargs)
 
             defer.returnValue(response)
 
@@ -343,14 +343,12 @@ class FederationSendServlet(BaseFederationServlet):
         try:
             transaction_data = content
 
-            logger.debug(
-                "Decoded %s: %s",
-                transaction_id, str(transaction_data)
-            )
+            logger.debug("Decoded %s: %s", transaction_id, str(transaction_data))
 
             logger.info(
                 "Received txn %s from %s. (PDUs: %d, EDUs: %d)",
-                transaction_id, origin,
+                transaction_id,
+                origin,
                 len(transaction_data.get("pdus", [])),
                 len(transaction_data.get("edus", [])),
             )
@@ -361,8 +359,7 @@ class FederationSendServlet(BaseFederationServlet):
             # Add some extra data to the transaction dict that isn't included
             # in the request body.
             transaction_data.update(
-                transaction_id=transaction_id,
-                destination=self.server_name
+                transaction_id=transaction_id, destination=self.server_name
             )
 
         except Exception as e:
@@ -372,7 +369,7 @@ class FederationSendServlet(BaseFederationServlet):
 
         try:
             code, response = yield self.handler.on_incoming_transaction(
-                origin, transaction_data,
+                origin, transaction_data
             )
         except Exception:
             logger.exception("on_incoming_transaction failed")
@@ -416,7 +413,7 @@ class FederationBackfillServlet(BaseFederationServlet):
     PATH = "/backfill/(?P<context>[^/]*)/?"
 
     def on_GET(self, origin, content, query, context):
-        versions = [x.decode('ascii') for x in query[b"v"]]
+        versions = [x.decode("ascii") for x in query[b"v"]]
         limit = parse_integer_from_args(query, "limit", None)
 
         if not limit:
@@ -432,7 +429,7 @@ class FederationQueryServlet(BaseFederationServlet):
     def on_GET(self, origin, content, query, query_type):
         return self.handler.on_query_request(
             query_type,
-            {k.decode('utf8'): v[0].decode("utf-8") for k, v in query.items()}
+            {k.decode("utf8"): v[0].decode("utf-8") for k, v in query.items()},
         )
 
 
@@ -456,15 +453,14 @@ class FederationMakeJoinServlet(BaseFederationServlet):
             Deferred[(int, object)|None]: either (response code, response object) to
                  return a JSON response, or None if the request has already been handled.
         """
-        versions = query.get(b'ver')
+        versions = query.get(b"ver")
         if versions is not None:
             supported_versions = [v.decode("utf-8") for v in versions]
         else:
             supported_versions = ["1"]
 
         content = yield self.handler.on_make_join_request(
-            origin, context, user_id,
-            supported_versions=supported_versions,
+            origin, context, user_id, supported_versions=supported_versions
         )
         defer.returnValue((200, content))
 
@@ -474,9 +470,7 @@ class FederationMakeLeaveServlet(BaseFederationServlet):
 
     @defer.inlineCallbacks
     def on_GET(self, origin, content, query, context, user_id):
-        content = yield self.handler.on_make_leave_request(
-            origin, context, user_id,
-        )
+        content = yield self.handler.on_make_leave_request(origin, context, user_id)
         defer.returnValue((200, content))
 
 
@@ -517,7 +511,7 @@ class FederationV1InviteServlet(BaseFederationServlet):
         # state resolution algorithm, and we don't use that for processing
         # invites
         content = yield self.handler.on_invite_request(
-            origin, content, room_version=RoomVersions.V1.identifier,
+            origin, content, room_version=RoomVersions.V1.identifier
         )
 
         # V1 federation API is defined to return a content of `[200, {...}]`
@@ -545,7 +539,7 @@ class FederationV2InviteServlet(BaseFederationServlet):
         event.setdefault("unsigned", {})["invite_room_state"] = invite_room_state
 
         content = yield self.handler.on_invite_request(
-            origin, event, room_version=room_version,
+            origin, event, room_version=room_version
         )
         defer.returnValue((200, content))
 
@@ -629,8 +623,10 @@ class On3pidBindServlet(BaseFederationServlet):
             for invite in content["invites"]:
                 try:
                     if "signed" not in invite or "token" not in invite["signed"]:
-                        message = ("Rejecting received notification of third-"
-                                   "party invite without signed: %s" % (invite,))
+                        message = (
+                            "Rejecting received notification of third-"
+                            "party invite without signed: %s" % (invite,)
+                        )
                         logger.info(message)
                         raise SynapseError(400, message)
                     yield self.handler.exchange_third_party_invite(
@@ -671,18 +667,23 @@ class OpenIdUserInfo(BaseFederationServlet):
     def on_GET(self, origin, content, query):
         token = query.get(b"access_token", [None])[0]
         if token is None:
-            defer.returnValue((401, {
-                "errcode": "M_MISSING_TOKEN", "error": "Access Token required"
-            }))
+            defer.returnValue(
+                (401, {"errcode": "M_MISSING_TOKEN", "error": "Access Token required"})
+            )
             return
 
-        user_id = yield self.handler.on_openid_userinfo(token.decode('ascii'))
+        user_id = yield self.handler.on_openid_userinfo(token.decode("ascii"))
 
         if user_id is None:
-            defer.returnValue((401, {
-                "errcode": "M_UNKNOWN_TOKEN",
-                "error": "Access Token unknown or expired"
-            }))
+            defer.returnValue(
+                (
+                    401,
+                    {
+                        "errcode": "M_UNKNOWN_TOKEN",
+                        "error": "Access Token unknown or expired",
+                    },
+                )
+            )
 
         defer.returnValue((200, {"sub": user_id}))
 
@@ -720,15 +721,15 @@ class PublicRoomList(BaseFederationServlet):
 
     PATH = "/publicRooms"
 
-    def __init__(self, handler, authenticator, ratelimiter, server_name, deny_access):
+    def __init__(self, handler, authenticator, ratelimiter, server_name, allow_access):
         super(PublicRoomList, self).__init__(
-            handler, authenticator, ratelimiter, server_name,
+            handler, authenticator, ratelimiter, server_name
         )
-        self.deny_access = deny_access
+        self.allow_access = allow_access
 
     @defer.inlineCallbacks
     def on_GET(self, origin, content, query):
-        if self.deny_access:
+        if not self.allow_access:
             raise FederationDeniedError(origin)
 
         limit = parse_integer_from_args(query, "limit", 0)
@@ -748,9 +749,7 @@ class PublicRoomList(BaseFederationServlet):
             network_tuple = ThirdPartyInstanceID(None, None)
 
         data = yield self.handler.get_local_public_room_list(
-            limit, since_token,
-            network_tuple=network_tuple,
-            from_federation=True,
+            limit, since_token, network_tuple=network_tuple, from_federation=True
         )
         defer.returnValue((200, data))
 
@@ -761,17 +760,18 @@ class FederationVersionServlet(BaseFederationServlet):
     REQUIRE_AUTH = False
 
     def on_GET(self, origin, content, query):
-        return defer.succeed((200, {
-            "server": {
-                "name": "Synapse",
-                "version": get_version_string(synapse)
-            },
-        }))
+        return defer.succeed(
+            (
+                200,
+                {"server": {"name": "Synapse", "version": get_version_string(synapse)}},
+            )
+        )
 
 
 class FederationGroupsProfileServlet(BaseFederationServlet):
     """Get/set the basic profile of a group on behalf of a user
     """
+
     PATH = "/groups/(?P<group_id>[^/]*)/profile"
 
     @defer.inlineCallbacks
@@ -780,9 +780,7 @@ class FederationGroupsProfileServlet(BaseFederationServlet):
         if get_domain_from_id(requester_user_id) != origin:
             raise SynapseError(403, "requester_user_id doesn't match origin")
 
-        new_content = yield self.handler.get_group_profile(
-            group_id, requester_user_id
-        )
+        new_content = yield self.handler.get_group_profile(group_id, requester_user_id)
 
         defer.returnValue((200, new_content))
 
@@ -808,9 +806,7 @@ class FederationGroupsSummaryServlet(BaseFederationServlet):
         if get_domain_from_id(requester_user_id) != origin:
             raise SynapseError(403, "requester_user_id doesn't match origin")
 
-        new_content = yield self.handler.get_group_summary(
-            group_id, requester_user_id
-        )
+        new_content = yield self.handler.get_group_summary(group_id, requester_user_id)
 
         defer.returnValue((200, new_content))
 
@@ -818,6 +814,7 @@ class FederationGroupsSummaryServlet(BaseFederationServlet):
 class FederationGroupsRoomsServlet(BaseFederationServlet):
     """Get the rooms in a group on behalf of a user
     """
+
     PATH = "/groups/(?P<group_id>[^/]*)/rooms"
 
     @defer.inlineCallbacks
@@ -826,9 +823,7 @@ class FederationGroupsRoomsServlet(BaseFederationServlet):
         if get_domain_from_id(requester_user_id) != origin:
             raise SynapseError(403, "requester_user_id doesn't match origin")
 
-        new_content = yield self.handler.get_rooms_in_group(
-            group_id, requester_user_id
-        )
+        new_content = yield self.handler.get_rooms_in_group(group_id, requester_user_id)
 
         defer.returnValue((200, new_content))
 
@@ -836,6 +831,7 @@ class FederationGroupsRoomsServlet(BaseFederationServlet):
 class FederationGroupsAddRoomsServlet(BaseFederationServlet):
     """Add/remove room from group
     """
+
     PATH = "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)"
 
     @defer.inlineCallbacks
@@ -857,7 +853,7 @@ class FederationGroupsAddRoomsServlet(BaseFederationServlet):
             raise SynapseError(403, "requester_user_id doesn't match origin")
 
         new_content = yield self.handler.remove_room_from_group(
-            group_id, requester_user_id, room_id,
+            group_id, requester_user_id, room_id
         )
 
         defer.returnValue((200, new_content))
@@ -866,6 +862,7 @@ class FederationGroupsAddRoomsServlet(BaseFederationServlet):
 class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet):
     """Update room config in group
     """
+
     PATH = (
         "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)"
         "/config/(?P<config_key>[^/]*)"
@@ -878,7 +875,7 @@ class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet):
             raise SynapseError(403, "requester_user_id doesn't match origin")
 
         result = yield self.groups_handler.update_room_in_group(
-            group_id, requester_user_id, room_id, config_key, content,
+            group_id, requester_user_id, room_id, config_key, content
         )
 
         defer.returnValue((200, result))
@@ -887,6 +884,7 @@ class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet):
 class FederationGroupsUsersServlet(BaseFederationServlet):
     """Get the users in a group on behalf of a user
     """
+
     PATH = "/groups/(?P<group_id>[^/]*)/users"
 
     @defer.inlineCallbacks
@@ -895,9 +893,7 @@ class FederationGroupsUsersServlet(BaseFederationServlet):
         if get_domain_from_id(requester_user_id) != origin:
             raise SynapseError(403, "requester_user_id doesn't match origin")
 
-        new_content = yield self.handler.get_users_in_group(
-            group_id, requester_user_id
-        )
+        new_content = yield self.handler.get_users_in_group(group_id, requester_user_id)
 
         defer.returnValue((200, new_content))
 
@@ -905,6 +901,7 @@ class FederationGroupsUsersServlet(BaseFederationServlet):
 class FederationGroupsInvitedUsersServlet(BaseFederationServlet):
     """Get the users that have been invited to a group
     """
+
     PATH = "/groups/(?P<group_id>[^/]*)/invited_users"
 
     @defer.inlineCallbacks
@@ -923,6 +920,7 @@ class FederationGroupsInvitedUsersServlet(BaseFederationServlet):
 class FederationGroupsInviteServlet(BaseFederationServlet):
     """Ask a group server to invite someone to the group
     """
+
     PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite"
 
     @defer.inlineCallbacks
@@ -932,7 +930,7 @@ class FederationGroupsInviteServlet(BaseFederationServlet):
             raise SynapseError(403, "requester_user_id doesn't match origin")
 
         new_content = yield self.handler.invite_to_group(
-            group_id, user_id, requester_user_id, content,
+            group_id, user_id, requester_user_id, content
         )
 
         defer.returnValue((200, new_content))
@@ -941,6 +939,7 @@ class FederationGroupsInviteServlet(BaseFederationServlet):
 class FederationGroupsAcceptInviteServlet(BaseFederationServlet):
     """Accept an invitation from the group server
     """
+
     PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/accept_invite"
 
     @defer.inlineCallbacks
@@ -948,9 +947,7 @@ class FederationGroupsAcceptInviteServlet(BaseFederationServlet):
         if get_domain_from_id(user_id) != origin:
             raise SynapseError(403, "user_id doesn't match origin")
 
-        new_content = yield self.handler.accept_invite(
-            group_id, user_id, content,
-        )
+        new_content = yield self.handler.accept_invite(group_id, user_id, content)
 
         defer.returnValue((200, new_content))
 
@@ -958,6 +955,7 @@ class FederationGroupsAcceptInviteServlet(BaseFederationServlet):
 class FederationGroupsJoinServlet(BaseFederationServlet):
     """Attempt to join a group
     """
+
     PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/join"
 
     @defer.inlineCallbacks
@@ -965,9 +963,7 @@ class FederationGroupsJoinServlet(BaseFederationServlet):
         if get_domain_from_id(user_id) != origin:
             raise SynapseError(403, "user_id doesn't match origin")
 
-        new_content = yield self.handler.join_group(
-            group_id, user_id, content,
-        )
+        new_content = yield self.handler.join_group(group_id, user_id, content)
 
         defer.returnValue((200, new_content))
 
@@ -975,6 +971,7 @@ class FederationGroupsJoinServlet(BaseFederationServlet):
 class FederationGroupsRemoveUserServlet(BaseFederationServlet):
     """Leave or kick a user from the group
     """
+
     PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove"
 
     @defer.inlineCallbacks
@@ -984,7 +981,7 @@ class FederationGroupsRemoveUserServlet(BaseFederationServlet):
             raise SynapseError(403, "requester_user_id doesn't match origin")
 
         new_content = yield self.handler.remove_user_from_group(
-            group_id, user_id, requester_user_id, content,
+            group_id, user_id, requester_user_id, content
         )
 
         defer.returnValue((200, new_content))
@@ -993,6 +990,7 @@ class FederationGroupsRemoveUserServlet(BaseFederationServlet):
 class FederationGroupsLocalInviteServlet(BaseFederationServlet):
     """A group server has invited a local user
     """
+
     PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite"
 
     @defer.inlineCallbacks
@@ -1000,9 +998,7 @@ class FederationGroupsLocalInviteServlet(BaseFederationServlet):
         if get_domain_from_id(group_id) != origin:
             raise SynapseError(403, "group_id doesn't match origin")
 
-        new_content = yield self.handler.on_invite(
-            group_id, user_id, content,
-        )
+        new_content = yield self.handler.on_invite(group_id, user_id, content)
 
         defer.returnValue((200, new_content))
 
@@ -1010,6 +1006,7 @@ class FederationGroupsLocalInviteServlet(BaseFederationServlet):
 class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet):
     """A group server has removed a local user
     """
+
     PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove"
 
     @defer.inlineCallbacks
@@ -1018,7 +1015,7 @@ class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet):
             raise SynapseError(403, "user_id doesn't match origin")
 
         new_content = yield self.handler.user_removed_from_group(
-            group_id, user_id, content,
+            group_id, user_id, content
         )
 
         defer.returnValue((200, new_content))
@@ -1027,6 +1024,7 @@ class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet):
 class FederationGroupsRenewAttestaionServlet(BaseFederationServlet):
     """A group or user's server renews their attestation
     """
+
     PATH = "/groups/(?P<group_id>[^/]*)/renew_attestation/(?P<user_id>[^/]*)"
 
     @defer.inlineCallbacks
@@ -1047,6 +1045,7 @@ class FederationGroupsSummaryRoomsServlet(BaseFederationServlet):
         - /groups/:group/summary/rooms/:room_id
         - /groups/:group/summary/categories/:category/rooms/:room_id
     """
+
     PATH = (
         "/groups/(?P<group_id>[^/]*)/summary"
         "(/categories/(?P<category_id>[^/]+))?"
@@ -1063,7 +1062,8 @@ class FederationGroupsSummaryRoomsServlet(BaseFederationServlet):
             raise SynapseError(400, "category_id cannot be empty string")
 
         resp = yield self.handler.update_group_summary_room(
-            group_id, requester_user_id,
+            group_id,
+            requester_user_id,
             room_id=room_id,
             category_id=category_id,
             content=content,
@@ -1081,9 +1081,7 @@ class FederationGroupsSummaryRoomsServlet(BaseFederationServlet):
             raise SynapseError(400, "category_id cannot be empty string")
 
         resp = yield self.handler.delete_group_summary_room(
-            group_id, requester_user_id,
-            room_id=room_id,
-            category_id=category_id,
+            group_id, requester_user_id, room_id=room_id, category_id=category_id
         )
 
         defer.returnValue((200, resp))
@@ -1092,9 +1090,8 @@ class FederationGroupsSummaryRoomsServlet(BaseFederationServlet):
 class FederationGroupsCategoriesServlet(BaseFederationServlet):
     """Get all categories for a group
     """
-    PATH = (
-        "/groups/(?P<group_id>[^/]*)/categories/?"
-    )
+
+    PATH = "/groups/(?P<group_id>[^/]*)/categories/?"
 
     @defer.inlineCallbacks
     def on_GET(self, origin, content, query, group_id):
@@ -1102,9 +1099,7 @@ class FederationGroupsCategoriesServlet(BaseFederationServlet):
         if get_domain_from_id(requester_user_id) != origin:
             raise SynapseError(403, "requester_user_id doesn't match origin")
 
-        resp = yield self.handler.get_group_categories(
-            group_id, requester_user_id,
-        )
+        resp = yield self.handler.get_group_categories(group_id, requester_user_id)
 
         defer.returnValue((200, resp))
 
@@ -1112,9 +1107,8 @@ class FederationGroupsCategoriesServlet(BaseFederationServlet):
 class FederationGroupsCategoryServlet(BaseFederationServlet):
     """Add/remove/get a category in a group
     """
-    PATH = (
-        "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)"
-    )
+
+    PATH = "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)"
 
     @defer.inlineCallbacks
     def on_GET(self, origin, content, query, group_id, category_id):
@@ -1138,7 +1132,7 @@ class FederationGroupsCategoryServlet(BaseFederationServlet):
             raise SynapseError(400, "category_id cannot be empty string")
 
         resp = yield self.handler.upsert_group_category(
-            group_id, requester_user_id, category_id, content,
+            group_id, requester_user_id, category_id, content
         )
 
         defer.returnValue((200, resp))
@@ -1153,7 +1147,7 @@ class FederationGroupsCategoryServlet(BaseFederationServlet):
             raise SynapseError(400, "category_id cannot be empty string")
 
         resp = yield self.handler.delete_group_category(
-            group_id, requester_user_id, category_id,
+            group_id, requester_user_id, category_id
         )
 
         defer.returnValue((200, resp))
@@ -1162,9 +1156,8 @@ class FederationGroupsCategoryServlet(BaseFederationServlet):
 class FederationGroupsRolesServlet(BaseFederationServlet):
     """Get roles in a group
     """
-    PATH = (
-        "/groups/(?P<group_id>[^/]*)/roles/?"
-    )
+
+    PATH = "/groups/(?P<group_id>[^/]*)/roles/?"
 
     @defer.inlineCallbacks
     def on_GET(self, origin, content, query, group_id):
@@ -1172,9 +1165,7 @@ class FederationGroupsRolesServlet(BaseFederationServlet):
         if get_domain_from_id(requester_user_id) != origin:
             raise SynapseError(403, "requester_user_id doesn't match origin")
 
-        resp = yield self.handler.get_group_roles(
-            group_id, requester_user_id,
-        )
+        resp = yield self.handler.get_group_roles(group_id, requester_user_id)
 
         defer.returnValue((200, resp))
 
@@ -1182,9 +1173,8 @@ class FederationGroupsRolesServlet(BaseFederationServlet):
 class FederationGroupsRoleServlet(BaseFederationServlet):
     """Add/remove/get a role in a group
     """
-    PATH = (
-        "/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)"
-    )
+
+    PATH = "/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)"
 
     @defer.inlineCallbacks
     def on_GET(self, origin, content, query, group_id, role_id):
@@ -1192,9 +1182,7 @@ class FederationGroupsRoleServlet(BaseFederationServlet):
         if get_domain_from_id(requester_user_id) != origin:
             raise SynapseError(403, "requester_user_id doesn't match origin")
 
-        resp = yield self.handler.get_group_role(
-            group_id, requester_user_id, role_id
-        )
+        resp = yield self.handler.get_group_role(group_id, requester_user_id, role_id)
 
         defer.returnValue((200, resp))
 
@@ -1208,7 +1196,7 @@ class FederationGroupsRoleServlet(BaseFederationServlet):
             raise SynapseError(400, "role_id cannot be empty string")
 
         resp = yield self.handler.update_group_role(
-            group_id, requester_user_id, role_id, content,
+            group_id, requester_user_id, role_id, content
         )
 
         defer.returnValue((200, resp))
@@ -1223,7 +1211,7 @@ class FederationGroupsRoleServlet(BaseFederationServlet):
             raise SynapseError(400, "role_id cannot be empty string")
 
         resp = yield self.handler.delete_group_role(
-            group_id, requester_user_id, role_id,
+            group_id, requester_user_id, role_id
         )
 
         defer.returnValue((200, resp))
@@ -1236,6 +1224,7 @@ class FederationGroupsSummaryUsersServlet(BaseFederationServlet):
         - /groups/:group/summary/users/:user_id
         - /groups/:group/summary/roles/:role/users/:user_id
     """
+
     PATH = (
         "/groups/(?P<group_id>[^/]*)/summary"
         "(/roles/(?P<role_id>[^/]+))?"
@@ -1252,7 +1241,8 @@ class FederationGroupsSummaryUsersServlet(BaseFederationServlet):
             raise SynapseError(400, "role_id cannot be empty string")
 
         resp = yield self.handler.update_group_summary_user(
-            group_id, requester_user_id,
+            group_id,
+            requester_user_id,
             user_id=user_id,
             role_id=role_id,
             content=content,
@@ -1270,9 +1260,7 @@ class FederationGroupsSummaryUsersServlet(BaseFederationServlet):
             raise SynapseError(400, "role_id cannot be empty string")
 
         resp = yield self.handler.delete_group_summary_user(
-            group_id, requester_user_id,
-            user_id=user_id,
-            role_id=role_id,
+            group_id, requester_user_id, user_id=user_id, role_id=role_id
         )
 
         defer.returnValue((200, resp))
@@ -1281,14 +1269,13 @@ class FederationGroupsSummaryUsersServlet(BaseFederationServlet):
 class FederationGroupsBulkPublicisedServlet(BaseFederationServlet):
     """Get roles in a group
     """
-    PATH = (
-        "/get_groups_publicised"
-    )
+
+    PATH = "/get_groups_publicised"
 
     @defer.inlineCallbacks
     def on_POST(self, origin, content, query):
         resp = yield self.handler.bulk_get_publicised_groups(
-            content["user_ids"], proxy=False,
+            content["user_ids"], proxy=False
         )
 
         defer.returnValue((200, resp))
@@ -1297,6 +1284,7 @@ class FederationGroupsBulkPublicisedServlet(BaseFederationServlet):
 class FederationGroupsSettingJoinPolicyServlet(BaseFederationServlet):
     """Sets whether a group is joinable without an invite or knock
     """
+
     PATH = "/groups/(?P<group_id>[^/]*)/settings/m.join_policy"
 
     @defer.inlineCallbacks
@@ -1317,6 +1305,7 @@ class RoomComplexityServlet(BaseFederationServlet):
     Indicates to other servers how complex (and therefore likely
     resource-intensive) a public room this server knows about is.
     """
+
     PATH = "/rooms/(?P<room_id>[^/]*)/complexity"
     PREFIX = FEDERATION_UNSTABLE_PREFIX
 
@@ -1325,9 +1314,7 @@ class RoomComplexityServlet(BaseFederationServlet):
 
         store = self.handler.hs.get_datastore()
 
-        is_public = yield store.is_room_world_readable_or_publicly_joinable(
-            room_id
-        )
+        is_public = yield store.is_room_world_readable_or_publicly_joinable(room_id)
 
         if not is_public:
             raise SynapseError(404, "Room not found", errcode=Codes.INVALID_PARAM)
@@ -1362,13 +1349,9 @@ FEDERATION_SERVLET_CLASSES = (
     RoomComplexityServlet,
 )
 
-OPENID_SERVLET_CLASSES = (
-    OpenIdUserInfo,
-)
+OPENID_SERVLET_CLASSES = (OpenIdUserInfo,)
 
-ROOM_LIST_CLASSES = (
-    PublicRoomList,
-)
+ROOM_LIST_CLASSES = (PublicRoomList,)
 
 GROUP_SERVER_SERVLET_CLASSES = (
     FederationGroupsProfileServlet,
@@ -1399,9 +1382,7 @@ GROUP_LOCAL_SERVLET_CLASSES = (
 )
 
 
-GROUP_ATTESTATION_SERVLET_CLASSES = (
-    FederationGroupsRenewAttestaionServlet,
-)
+GROUP_ATTESTATION_SERVLET_CLASSES = (FederationGroupsRenewAttestaionServlet,)
 
 DEFAULT_SERVLET_GROUPS = (
     "federation",
@@ -1455,7 +1436,7 @@ def register_servlets(hs, resource, authenticator, ratelimiter, servlet_groups=N
                 authenticator=authenticator,
                 ratelimiter=ratelimiter,
                 server_name=hs.hostname,
-                deny_access=hs.config.restrict_public_rooms_to_local_users,
+                allow_access=hs.config.allow_public_rooms_over_federation,
             ).register(resource)
 
     if "group_server" in servlet_groups:
diff --git a/synapse/federation/units.py b/synapse/federation/units.py
index 025a79c022..14aad8f09d 100644
--- a/synapse/federation/units.py
+++ b/synapse/federation/units.py
@@ -32,21 +32,11 @@ class Edu(JsonEncodedObject):
     internal ID or previous references graph.
     """
 
-    valid_keys = [
-        "origin",
-        "destination",
-        "edu_type",
-        "content",
-    ]
+    valid_keys = ["origin", "destination", "edu_type", "content"]
 
-    required_keys = [
-        "edu_type",
-    ]
+    required_keys = ["edu_type"]
 
-    internal_keys = [
-        "origin",
-        "destination",
-    ]
+    internal_keys = ["origin", "destination"]
 
 
 class Transaction(JsonEncodedObject):
@@ -75,10 +65,7 @@ class Transaction(JsonEncodedObject):
         "edus",
     ]
 
-    internal_keys = [
-        "transaction_id",
-        "destination",
-    ]
+    internal_keys = ["transaction_id", "destination"]
 
     required_keys = [
         "transaction_id",
@@ -98,9 +85,7 @@ class Transaction(JsonEncodedObject):
             del kwargs["edus"]
 
         super(Transaction, self).__init__(
-            transaction_id=transaction_id,
-            pdus=pdus,
-            **kwargs
+            transaction_id=transaction_id, pdus=pdus, **kwargs
         )
 
     @staticmethod
@@ -109,13 +94,9 @@ class Transaction(JsonEncodedObject):
         transaction_id and origin_server_ts keys.
         """
         if "origin_server_ts" not in kwargs:
-            raise KeyError(
-                "Require 'origin_server_ts' to construct a Transaction"
-            )
+            raise KeyError("Require 'origin_server_ts' to construct a Transaction")
         if "transaction_id" not in kwargs:
-            raise KeyError(
-                "Require 'transaction_id' to construct a Transaction"
-            )
+            raise KeyError("Require 'transaction_id' to construct a Transaction")
 
         kwargs["pdus"] = [p.get_pdu_json() for p in pdus]