summary refs log tree commit diff
path: root/synapse/http/federation
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/http/federation')
-rw-r--r--synapse/http/federation/matrix_federation_agent.py98
-rw-r--r--synapse/http/federation/srv_resolver.py37
2 files changed, 75 insertions, 60 deletions
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index b4cbe97b41..414cde0777 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -48,7 +48,7 @@ WELL_KNOWN_INVALID_CACHE_PERIOD = 1 * 3600
 WELL_KNOWN_MAX_CACHE_PERIOD = 48 * 3600
 
 logger = logging.getLogger(__name__)
-well_known_cache = TTLCache('well-known')
+well_known_cache = TTLCache("well-known")
 
 
 @implementer(IAgent)
@@ -78,7 +78,9 @@ class MatrixFederationAgent(object):
     """
 
     def __init__(
-        self, reactor, tls_client_options_factory,
+        self,
+        reactor,
+        tls_client_options_factory,
         _well_known_tls_policy=None,
         _srv_resolver=None,
         _well_known_cache=well_known_cache,
@@ -100,9 +102,9 @@ class MatrixFederationAgent(object):
         if _well_known_tls_policy is not None:
             # the param is called 'contextFactory', but actually passing a
             # contextfactory is deprecated, and it expects an IPolicyForHTTPS.
-            agent_args['contextFactory'] = _well_known_tls_policy
+            agent_args["contextFactory"] = _well_known_tls_policy
         _well_known_agent = RedirectAgent(
-            Agent(self._reactor, pool=self._pool, **agent_args),
+            Agent(self._reactor, pool=self._pool, **agent_args)
         )
         self._well_known_agent = _well_known_agent
 
@@ -149,7 +151,7 @@ class MatrixFederationAgent(object):
             tls_options = None
         else:
             tls_options = self._tls_client_options_factory.get_options(
-                res.tls_server_name.decode("ascii"),
+                res.tls_server_name.decode("ascii")
             )
 
         # make sure that the Host header is set correctly
@@ -158,14 +160,14 @@ class MatrixFederationAgent(object):
         else:
             headers = headers.copy()
 
-        if not headers.hasHeader(b'host'):
-            headers.addRawHeader(b'host', res.host_header)
+        if not headers.hasHeader(b"host"):
+            headers.addRawHeader(b"host", res.host_header)
 
         class EndpointFactory(object):
             @staticmethod
             def endpointForURI(_uri):
                 ep = LoggingHostnameEndpoint(
-                    self._reactor, res.target_host, res.target_port,
+                    self._reactor, res.target_host, res.target_port
                 )
                 if tls_options is not None:
                     ep = wrapClientTLS(tls_options, ep)
@@ -203,21 +205,25 @@ class MatrixFederationAgent(object):
             port = parsed_uri.port
             if port == -1:
                 port = 8448
-            defer.returnValue(_RoutingResult(
-                host_header=parsed_uri.netloc,
-                tls_server_name=parsed_uri.host,
-                target_host=parsed_uri.host,
-                target_port=port,
-            ))
+            defer.returnValue(
+                _RoutingResult(
+                    host_header=parsed_uri.netloc,
+                    tls_server_name=parsed_uri.host,
+                    target_host=parsed_uri.host,
+                    target_port=port,
+                )
+            )
 
         if parsed_uri.port != -1:
             # there is an explicit port
-            defer.returnValue(_RoutingResult(
-                host_header=parsed_uri.netloc,
-                tls_server_name=parsed_uri.host,
-                target_host=parsed_uri.host,
-                target_port=parsed_uri.port,
-            ))
+            defer.returnValue(
+                _RoutingResult(
+                    host_header=parsed_uri.netloc,
+                    tls_server_name=parsed_uri.host,
+                    target_host=parsed_uri.host,
+                    target_port=parsed_uri.port,
+                )
+            )
 
         if lookup_well_known:
             # try a .well-known lookup
@@ -229,8 +235,8 @@ class MatrixFederationAgent(object):
 
                 # parse the server name in the .well-known response into host/port.
                 # (This code is lifted from twisted.web.client.URI.fromBytes).
-                if b':' in well_known_server:
-                    well_known_host, well_known_port = well_known_server.rsplit(b':', 1)
+                if b":" in well_known_server:
+                    well_known_host, well_known_port = well_known_server.rsplit(b":", 1)
                     try:
                         well_known_port = int(well_known_port)
                     except ValueError:
@@ -264,21 +270,27 @@ class MatrixFederationAgent(object):
             port = 8448
             logger.debug(
                 "No SRV record for %s, using %s:%i",
-                parsed_uri.host.decode("ascii"), target_host.decode("ascii"), port,
+                parsed_uri.host.decode("ascii"),
+                target_host.decode("ascii"),
+                port,
             )
         else:
             target_host, port = pick_server_from_list(server_list)
             logger.debug(
                 "Picked %s:%i from SRV records for %s",
-                target_host.decode("ascii"), port, parsed_uri.host.decode("ascii"),
+                target_host.decode("ascii"),
+                port,
+                parsed_uri.host.decode("ascii"),
             )
 
-        defer.returnValue(_RoutingResult(
-            host_header=parsed_uri.netloc,
-            tls_server_name=parsed_uri.host,
-            target_host=target_host,
-            target_port=port,
-        ))
+        defer.returnValue(
+            _RoutingResult(
+                host_header=parsed_uri.netloc,
+                tls_server_name=parsed_uri.host,
+                target_host=target_host,
+                target_port=port,
+            )
+        )
 
     @defer.inlineCallbacks
     def _get_well_known(self, server_name):
@@ -318,18 +330,18 @@ class MatrixFederationAgent(object):
                  - None if there was no .well-known file.
                  - INVALID_WELL_KNOWN if the .well-known was invalid
         """
-        uri = b"https://%s/.well-known/matrix/server" % (server_name, )
+        uri = b"https://%s/.well-known/matrix/server" % (server_name,)
         uri_str = uri.decode("ascii")
         logger.info("Fetching %s", uri_str)
         try:
             response = yield make_deferred_yieldable(
-                self._well_known_agent.request(b"GET", uri),
+                self._well_known_agent.request(b"GET", uri)
             )
             body = yield make_deferred_yieldable(readBody(response))
             if response.code != 200:
-                raise Exception("Non-200 response %s" % (response.code, ))
+                raise Exception("Non-200 response %s" % (response.code,))
 
-            parsed_body = json.loads(body.decode('utf-8'))
+            parsed_body = json.loads(body.decode("utf-8"))
             logger.info("Response from .well-known: %s", parsed_body)
             if not isinstance(parsed_body, dict):
                 raise Exception("not a dict")
@@ -347,8 +359,7 @@ class MatrixFederationAgent(object):
         result = parsed_body["m.server"].encode("ascii")
 
         cache_period = _cache_period_from_headers(
-            response.headers,
-            time_now=self._reactor.seconds,
+            response.headers, time_now=self._reactor.seconds
         )
         if cache_period is None:
             cache_period = WELL_KNOWN_DEFAULT_CACHE_PERIOD
@@ -364,6 +375,7 @@ class MatrixFederationAgent(object):
 @implementer(IStreamClientEndpoint)
 class LoggingHostnameEndpoint(object):
     """A wrapper for HostnameEndpint which logs when it connects"""
+
     def __init__(self, reactor, host, port, *args, **kwargs):
         self.host = host
         self.port = port
@@ -377,17 +389,17 @@ class LoggingHostnameEndpoint(object):
 def _cache_period_from_headers(headers, time_now=time.time):
     cache_controls = _parse_cache_control(headers)
 
-    if b'no-store' in cache_controls:
+    if b"no-store" in cache_controls:
         return 0
 
-    if b'max-age' in cache_controls:
+    if b"max-age" in cache_controls:
         try:
-            max_age = int(cache_controls[b'max-age'])
+            max_age = int(cache_controls[b"max-age"])
             return max_age
         except ValueError:
             pass
 
-    expires = headers.getRawHeaders(b'expires')
+    expires = headers.getRawHeaders(b"expires")
     if expires is not None:
         try:
             expires_date = stringToDatetime(expires[-1])
@@ -403,9 +415,9 @@ def _cache_period_from_headers(headers, time_now=time.time):
 
 def _parse_cache_control(headers):
     cache_controls = {}
-    for hdr in headers.getRawHeaders(b'cache-control', []):
-        for directive in hdr.split(b','):
-            splits = [x.strip() for x in directive.split(b'=', 1)]
+    for hdr in headers.getRawHeaders(b"cache-control", []):
+        for directive in hdr.split(b","):
+            splits = [x.strip() for x in directive.split(b"=", 1)]
             k = splits[0].lower()
             v = splits[1] if len(splits) > 1 else None
             cache_controls[k] = v
diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py
index 71830c549d..1f22f78a75 100644
--- a/synapse/http/federation/srv_resolver.py
+++ b/synapse/http/federation/srv_resolver.py
@@ -45,6 +45,7 @@ class Server(object):
         expires (int): when the cache should expire this record - in *seconds* since
             the epoch
     """
+
     host = attr.ib()
     port = attr.ib()
     priority = attr.ib(default=0)
@@ -79,9 +80,7 @@ def pick_server_from_list(server_list):
             return s.host, s.port
 
     # this should be impossible.
-    raise RuntimeError(
-        "pick_server_from_list got to end of eligible server list.",
-    )
+    raise RuntimeError("pick_server_from_list got to end of eligible server list.")
 
 
 class SrvResolver(object):
@@ -95,6 +94,7 @@ class SrvResolver(object):
         cache (dict): cache object
         get_time (callable): clock implementation. Should return seconds since the epoch
     """
+
     def __init__(self, dns_client=client, cache=SERVER_CACHE, get_time=time.time):
         self._dns_client = dns_client
         self._cache = cache
@@ -124,7 +124,7 @@ class SrvResolver(object):
 
         try:
             answers, _, _ = yield make_deferred_yieldable(
-                self._dns_client.lookupService(service_name),
+                self._dns_client.lookupService(service_name)
             )
         except DNSNameError:
             # TODO: cache this. We can get the SOA out of the exception, and use
@@ -136,17 +136,18 @@ class SrvResolver(object):
             cache_entry = self._cache.get(service_name, None)
             if cache_entry:
                 logger.warn(
-                    "Failed to resolve %r, falling back to cache. %r",
-                    service_name, e
+                    "Failed to resolve %r, falling back to cache. %r", service_name, e
                 )
                 defer.returnValue(list(cache_entry))
             else:
                 raise e
 
-        if (len(answers) == 1
-                and answers[0].type == dns.SRV
-                and answers[0].payload
-                and answers[0].payload.target == dns.Name(b'.')):
+        if (
+            len(answers) == 1
+            and answers[0].type == dns.SRV
+            and answers[0].payload
+            and answers[0].payload.target == dns.Name(b".")
+        ):
             raise ConnectError("Service %s unavailable" % service_name)
 
         servers = []
@@ -157,13 +158,15 @@ class SrvResolver(object):
 
             payload = answer.payload
 
-            servers.append(Server(
-                host=payload.target.name,
-                port=payload.port,
-                priority=payload.priority,
-                weight=payload.weight,
-                expires=now + answer.ttl,
-            ))
+            servers.append(
+                Server(
+                    host=payload.target.name,
+                    port=payload.port,
+                    priority=payload.priority,
+                    weight=payload.weight,
+                    expires=now + answer.ttl,
+                )
+            )
 
         self._cache[service_name] = list(servers)
         defer.returnValue(servers)