summary refs log tree commit diff
path: root/synapse/federation/federation_server.py
diff options
context:
space:
mode:
authorRichard van der Hoff <richard@matrix.org>2019-06-26 22:34:41 +0100
committerRichard van der Hoff <richard@matrix.org>2019-06-26 22:34:41 +0100
commita4daa899ec4cd195fc10936f68df5c78314b366c (patch)
tree35e88ff388b0f7652773a79930b732aa04f16bde /synapse/federation/federation_server.py
parentchangelog (diff)
parentImprove docs on choosing server_name (#5558) (diff)
downloadsynapse-a4daa899ec4cd195fc10936f68df5c78314b366c.tar.xz
Merge branch 'develop' into rav/saml2_client
Diffstat (limited to 'synapse/federation/federation_server.py')
-rw-r--r--synapse/federation/federation_server.py234
1 files changed, 93 insertions, 141 deletions
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 4c28c1dc3c..2e0cebb638 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -69,7 +69,6 @@ received_queries_counter = Counter(
 
 
 class FederationServer(FederationBase):
-
     def __init__(self, hs):
         super(FederationServer, self).__init__(hs)
 
@@ -118,11 +117,13 @@ class FederationServer(FederationBase):
 
         # use a linearizer to ensure that we don't process the same transaction
         # multiple times in parallel.
-        with (yield self._transaction_linearizer.queue(
-                (origin, transaction.transaction_id),
-        )):
+        with (
+            yield self._transaction_linearizer.queue(
+                (origin, transaction.transaction_id)
+            )
+        ):
             result = yield self._handle_incoming_transaction(
-                origin, transaction, request_time,
+                origin, transaction, request_time
             )
 
         defer.returnValue(result)
@@ -144,7 +145,7 @@ class FederationServer(FederationBase):
         if response:
             logger.debug(
                 "[%s] We've already responded to this request",
-                transaction.transaction_id
+                transaction.transaction_id,
             )
             defer.returnValue(response)
             return
@@ -152,18 +153,15 @@ class FederationServer(FederationBase):
         logger.debug("[%s] Transaction is new", transaction.transaction_id)
 
         # Reject if PDU count > 50 and EDU count > 100
-        if (len(transaction.pdus) > 50
-                or (hasattr(transaction, "edus") and len(transaction.edus) > 100)):
+        if len(transaction.pdus) > 50 or (
+            hasattr(transaction, "edus") and len(transaction.edus) > 100
+        ):
 
-            logger.info(
-                "Transaction PDU or EDU count too large. Returning 400",
-            )
+            logger.info("Transaction PDU or EDU count too large. Returning 400")
 
             response = {}
             yield self.transaction_actions.set_response(
-                origin,
-                transaction,
-                400, response
+                origin, transaction, 400, response
             )
             defer.returnValue((400, response))
 
@@ -230,9 +228,7 @@ class FederationServer(FederationBase):
             try:
                 yield self.check_server_matches_acl(origin_host, room_id)
             except AuthError as e:
-                logger.warn(
-                    "Ignoring PDUs for room %s from banned server", room_id,
-                )
+                logger.warn("Ignoring PDUs for room %s from banned server", room_id)
                 for pdu in pdus_by_room[room_id]:
                     event_id = pdu.event_id
                     pdu_results[event_id] = e.error_dict()
@@ -242,9 +238,7 @@ class FederationServer(FederationBase):
                 event_id = pdu.event_id
                 with nested_logging_context(event_id):
                     try:
-                        yield self._handle_received_pdu(
-                            origin, pdu
-                        )
+                        yield self._handle_received_pdu(origin, pdu)
                         pdu_results[event_id] = {}
                     except FederationError as e:
                         logger.warn("Error handling PDU %s: %s", event_id, e)
@@ -259,29 +253,18 @@ class FederationServer(FederationBase):
                         )
 
         yield concurrently_execute(
-            process_pdus_for_room, pdus_by_room.keys(),
-            TRANSACTION_CONCURRENCY_LIMIT,
+            process_pdus_for_room, pdus_by_room.keys(), TRANSACTION_CONCURRENCY_LIMIT
         )
 
         if hasattr(transaction, "edus"):
             for edu in (Edu(**x) for x in transaction.edus):
-                yield self.received_edu(
-                    origin,
-                    edu.edu_type,
-                    edu.content
-                )
+                yield self.received_edu(origin, edu.edu_type, edu.content)
 
-        response = {
-            "pdus": pdu_results,
-        }
+        response = {"pdus": pdu_results}
 
         logger.debug("Returning: %s", str(response))
 
-        yield self.transaction_actions.set_response(
-            origin,
-            transaction,
-            200, response
-        )
+        yield self.transaction_actions.set_response(origin, transaction, 200, response)
         defer.returnValue((200, response))
 
     @defer.inlineCallbacks
@@ -311,7 +294,8 @@ class FederationServer(FederationBase):
             resp = yield self._state_resp_cache.wrap(
                 (room_id, event_id),
                 self._on_context_state_request_compute,
-                room_id, event_id,
+                room_id,
+                event_id,
             )
 
         defer.returnValue((200, resp))
@@ -328,24 +312,17 @@ class FederationServer(FederationBase):
         if not in_room:
             raise AuthError(403, "Host not in room.")
 
-        state_ids = yield self.handler.get_state_ids_for_pdu(
-            room_id, event_id,
-        )
+        state_ids = yield self.handler.get_state_ids_for_pdu(room_id, event_id)
         auth_chain_ids = yield self.store.get_auth_chain_ids(state_ids)
 
-        defer.returnValue((200, {
-            "pdu_ids": state_ids,
-            "auth_chain_ids": auth_chain_ids,
-        }))
+        defer.returnValue(
+            (200, {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids})
+        )
 
     @defer.inlineCallbacks
     def _on_context_state_request_compute(self, room_id, event_id):
-        pdus = yield self.handler.get_state_for_pdu(
-            room_id, event_id,
-        )
-        auth_chain = yield self.store.get_auth_chain(
-            [pdu.event_id for pdu in pdus]
-        )
+        pdus = yield self.handler.get_state_for_pdu(room_id, event_id)
+        auth_chain = yield self.store.get_auth_chain([pdu.event_id for pdu in pdus])
 
         for event in auth_chain:
             # We sign these again because there was a bug where we
@@ -355,14 +332,16 @@ class FederationServer(FederationBase):
                     compute_event_signature(
                         event.get_pdu_json(),
                         self.hs.hostname,
-                        self.hs.config.signing_key[0]
+                        self.hs.config.signing_key[0],
                     )
                 )
 
-        defer.returnValue({
-            "pdus": [pdu.get_pdu_json() for pdu in pdus],
-            "auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],
-        })
+        defer.returnValue(
+            {
+                "pdus": [pdu.get_pdu_json() for pdu in pdus],
+                "auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],
+            }
+        )
 
     @defer.inlineCallbacks
     @log_function
@@ -370,9 +349,7 @@ class FederationServer(FederationBase):
         pdu = yield self.handler.get_persisted_pdu(origin, event_id)
 
         if pdu:
-            defer.returnValue(
-                (200, self._transaction_from_pdus([pdu]).get_dict())
-            )
+            defer.returnValue((200, self._transaction_from_pdus([pdu]).get_dict()))
         else:
             defer.returnValue((404, ""))
 
@@ -394,10 +371,9 @@ class FederationServer(FederationBase):
 
         pdu = yield self.handler.on_make_join_request(room_id, user_id)
         time_now = self._clock.time_msec()
-        defer.returnValue({
-            "event": pdu.get_pdu_json(time_now),
-            "room_version": room_version,
-        })
+        defer.returnValue(
+            {"event": pdu.get_pdu_json(time_now), "room_version": room_version}
+        )
 
     @defer.inlineCallbacks
     def on_invite_request(self, origin, content, room_version):
@@ -431,12 +407,17 @@ class FederationServer(FederationBase):
         logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures)
         res_pdus = yield self.handler.on_send_join_request(origin, pdu)
         time_now = self._clock.time_msec()
-        defer.returnValue((200, {
-            "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]],
-            "auth_chain": [
-                p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]
-            ],
-        }))
+        defer.returnValue(
+            (
+                200,
+                {
+                    "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]],
+                    "auth_chain": [
+                        p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]
+                    ],
+                },
+            )
+        )
 
     @defer.inlineCallbacks
     def on_make_leave_request(self, origin, room_id, user_id):
@@ -447,10 +428,9 @@ class FederationServer(FederationBase):
         room_version = yield self.store.get_room_version(room_id)
 
         time_now = self._clock.time_msec()
-        defer.returnValue({
-            "event": pdu.get_pdu_json(time_now),
-            "room_version": room_version,
-        })
+        defer.returnValue(
+            {"event": pdu.get_pdu_json(time_now), "room_version": room_version}
+        )
 
     @defer.inlineCallbacks
     def on_send_leave_request(self, origin, content, room_id):
@@ -475,9 +455,7 @@ class FederationServer(FederationBase):
 
             time_now = self._clock.time_msec()
             auth_pdus = yield self.handler.on_event_auth(event_id)
-            res = {
-                "auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus],
-            }
+            res = {"auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus]}
         defer.returnValue((200, res))
 
     @defer.inlineCallbacks
@@ -508,12 +486,11 @@ class FederationServer(FederationBase):
             format_ver = room_version_to_event_format(room_version)
 
             auth_chain = [
-                event_from_pdu_json(e, format_ver)
-                for e in content["auth_chain"]
+                event_from_pdu_json(e, format_ver) for e in content["auth_chain"]
             ]
 
             signed_auth = yield self._check_sigs_and_hash_and_fetch(
-                origin, auth_chain, outlier=True, room_version=room_version,
+                origin, auth_chain, outlier=True, room_version=room_version
             )
 
             ret = yield self.handler.on_query_auth(
@@ -527,17 +504,12 @@ class FederationServer(FederationBase):
 
             time_now = self._clock.time_msec()
             send_content = {
-                "auth_chain": [
-                    e.get_pdu_json(time_now)
-                    for e in ret["auth_chain"]
-                ],
+                "auth_chain": [e.get_pdu_json(time_now) for e in ret["auth_chain"]],
                 "rejects": ret.get("rejects", []),
                 "missing": ret.get("missing", []),
             }
 
-        defer.returnValue(
-            (200, send_content)
-        )
+        defer.returnValue((200, send_content))
 
     @log_function
     def on_query_client_keys(self, origin, content):
@@ -566,20 +538,23 @@ class FederationServer(FederationBase):
 
         logger.info(
             "Claimed one-time-keys: %s",
-            ",".join((
-                "%s for %s:%s" % (key_id, user_id, device_id)
-                for user_id, user_keys in iteritems(json_result)
-                for device_id, device_keys in iteritems(user_keys)
-                for key_id, _ in iteritems(device_keys)
-            )),
+            ",".join(
+                (
+                    "%s for %s:%s" % (key_id, user_id, device_id)
+                    for user_id, user_keys in iteritems(json_result)
+                    for device_id, device_keys in iteritems(user_keys)
+                    for key_id, _ in iteritems(device_keys)
+                )
+            ),
         )
 
         defer.returnValue({"one_time_keys": json_result})
 
     @defer.inlineCallbacks
     @log_function
-    def on_get_missing_events(self, origin, room_id, earliest_events,
-                              latest_events, limit):
+    def on_get_missing_events(
+        self, origin, room_id, earliest_events, latest_events, limit
+    ):
         with (yield self._server_linearizer.queue((origin, room_id))):
             origin_host, _ = parse_server_name(origin)
             yield self.check_server_matches_acl(origin_host, room_id)
@@ -587,11 +562,13 @@ class FederationServer(FederationBase):
             logger.info(
                 "on_get_missing_events: earliest_events: %r, latest_events: %r,"
                 " limit: %d",
-                earliest_events, latest_events, limit,
+                earliest_events,
+                latest_events,
+                limit,
             )
 
             missing_events = yield self.handler.on_get_missing_events(
-                origin, room_id, earliest_events, latest_events, limit,
+                origin, room_id, earliest_events, latest_events, limit
             )
 
             if len(missing_events) < 5:
@@ -603,9 +580,9 @@ class FederationServer(FederationBase):
 
             time_now = self._clock.time_msec()
 
-        defer.returnValue({
-            "events": [ev.get_pdu_json(time_now) for ev in missing_events],
-        })
+        defer.returnValue(
+            {"events": [ev.get_pdu_json(time_now) for ev in missing_events]}
+        )
 
     @log_function
     def on_openid_userinfo(self, token):
@@ -666,22 +643,17 @@ class FederationServer(FederationBase):
             # origin. See bug #1893. This is also true for some third party
             # invites).
             if not (
-                pdu.type == 'm.room.member' and
-                pdu.content and
-                pdu.content.get("membership", None) in (
-                    Membership.JOIN, Membership.INVITE,
-                )
+                pdu.type == "m.room.member"
+                and pdu.content
+                and pdu.content.get("membership", None)
+                in (Membership.JOIN, Membership.INVITE)
             ):
                 logger.info(
-                    "Discarding PDU %s from invalid origin %s",
-                    pdu.event_id, origin
+                    "Discarding PDU %s from invalid origin %s", pdu.event_id, origin
                 )
                 return
             else:
-                logger.info(
-                    "Accepting join PDU %s from %s",
-                    pdu.event_id, origin
-                )
+                logger.info("Accepting join PDU %s from %s", pdu.event_id, origin)
 
         # We've already checked that we know the room version by this point
         room_version = yield self.store.get_room_version(pdu.room_id)
@@ -690,33 +662,19 @@ class FederationServer(FederationBase):
         try:
             pdu = yield self._check_sigs_and_hash(room_version, pdu)
         except SynapseError as e:
-            raise FederationError(
-                "ERROR",
-                e.code,
-                e.msg,
-                affected=pdu.event_id,
-            )
+            raise FederationError("ERROR", e.code, e.msg, affected=pdu.event_id)
 
-        yield self.handler.on_receive_pdu(
-            origin, pdu, sent_to_us_directly=True,
-        )
+        yield self.handler.on_receive_pdu(origin, pdu, sent_to_us_directly=True)
 
     def __str__(self):
         return "<ReplicationLayer(%s)>" % self.server_name
 
     @defer.inlineCallbacks
     def exchange_third_party_invite(
-            self,
-            sender_user_id,
-            target_user_id,
-            room_id,
-            signed,
+        self, sender_user_id, target_user_id, room_id, signed
     ):
         ret = yield self.handler.exchange_third_party_invite(
-            sender_user_id,
-            target_user_id,
-            room_id,
-            signed,
+            sender_user_id, target_user_id, room_id, signed
         )
         defer.returnValue(ret)
 
@@ -771,7 +729,7 @@ def server_matches_acl_event(server_name, acl_event):
         allow_ip_literals = True
     if not allow_ip_literals:
         # check for ipv6 literals. These start with '['.
-        if server_name[0] == '[':
+        if server_name[0] == "[":
             return False
 
         # check for ipv4 literals. We can just lift the routine from twisted.
@@ -805,7 +763,9 @@ def server_matches_acl_event(server_name, acl_event):
 
 def _acl_entry_matches(server_name, acl_entry):
     if not isinstance(acl_entry, six.string_types):
-        logger.warn("Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry))
+        logger.warn(
+            "Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry)
+        )
         return False
     regex = glob_to_regex(acl_entry)
     return regex.match(server_name)
@@ -815,6 +775,7 @@ class FederationHandlerRegistry(object):
     """Allows classes to register themselves as handlers for a given EDU or
     query type for incoming federation traffic.
     """
+
     def __init__(self):
         self.edu_handlers = {}
         self.query_handlers = {}
@@ -848,9 +809,7 @@ class FederationHandlerRegistry(object):
                 on and the result used as the response to the query request.
         """
         if query_type in self.query_handlers:
-            raise KeyError(
-                "Already have a Query handler for %s" % (query_type,)
-            )
+            raise KeyError("Already have a Query handler for %s" % (query_type,))
 
         logger.info("Registering federation query handler for %r", query_type)
 
@@ -905,14 +864,10 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry):
         handler = self.edu_handlers.get(edu_type)
         if handler:
             return super(ReplicationFederationHandlerRegistry, self).on_edu(
-                edu_type, origin, content,
+                edu_type, origin, content
             )
 
-        return self._send_edu(
-            edu_type=edu_type,
-            origin=origin,
-            content=content,
-        )
+        return self._send_edu(edu_type=edu_type, origin=origin, content=content)
 
     def on_query(self, query_type, args):
         """Overrides FederationHandlerRegistry
@@ -921,7 +876,4 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry):
         if handler:
             return handler(args)
 
-        return self._get_query_client(
-            query_type=query_type,
-            args=args,
-        )
+        return self._get_query_client(query_type=query_type, args=args)