diff options
Diffstat (limited to 'synapse/http/federation')
-rw-r--r-- | synapse/http/federation/matrix_federation_agent.py | 98 | ||||
-rw-r--r-- | synapse/http/federation/srv_resolver.py | 37 |
2 files changed, 75 insertions, 60 deletions
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py index b4cbe97b41..414cde0777 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py @@ -48,7 +48,7 @@ WELL_KNOWN_INVALID_CACHE_PERIOD = 1 * 3600 WELL_KNOWN_MAX_CACHE_PERIOD = 48 * 3600 logger = logging.getLogger(__name__) -well_known_cache = TTLCache('well-known') +well_known_cache = TTLCache("well-known") @implementer(IAgent) @@ -78,7 +78,9 @@ class MatrixFederationAgent(object): """ def __init__( - self, reactor, tls_client_options_factory, + self, + reactor, + tls_client_options_factory, _well_known_tls_policy=None, _srv_resolver=None, _well_known_cache=well_known_cache, @@ -100,9 +102,9 @@ class MatrixFederationAgent(object): if _well_known_tls_policy is not None: # the param is called 'contextFactory', but actually passing a # contextfactory is deprecated, and it expects an IPolicyForHTTPS. - agent_args['contextFactory'] = _well_known_tls_policy + agent_args["contextFactory"] = _well_known_tls_policy _well_known_agent = RedirectAgent( - Agent(self._reactor, pool=self._pool, **agent_args), + Agent(self._reactor, pool=self._pool, **agent_args) ) self._well_known_agent = _well_known_agent @@ -149,7 +151,7 @@ class MatrixFederationAgent(object): tls_options = None else: tls_options = self._tls_client_options_factory.get_options( - res.tls_server_name.decode("ascii"), + res.tls_server_name.decode("ascii") ) # make sure that the Host header is set correctly @@ -158,14 +160,14 @@ class MatrixFederationAgent(object): else: headers = headers.copy() - if not headers.hasHeader(b'host'): - headers.addRawHeader(b'host', res.host_header) + if not headers.hasHeader(b"host"): + headers.addRawHeader(b"host", res.host_header) class EndpointFactory(object): @staticmethod def endpointForURI(_uri): ep = LoggingHostnameEndpoint( - self._reactor, res.target_host, res.target_port, + self._reactor, res.target_host, res.target_port ) if tls_options is not None: ep = wrapClientTLS(tls_options, ep) @@ -203,21 +205,25 @@ class MatrixFederationAgent(object): port = parsed_uri.port if port == -1: port = 8448 - defer.returnValue(_RoutingResult( - host_header=parsed_uri.netloc, - tls_server_name=parsed_uri.host, - target_host=parsed_uri.host, - target_port=port, - )) + defer.returnValue( + _RoutingResult( + host_header=parsed_uri.netloc, + tls_server_name=parsed_uri.host, + target_host=parsed_uri.host, + target_port=port, + ) + ) if parsed_uri.port != -1: # there is an explicit port - defer.returnValue(_RoutingResult( - host_header=parsed_uri.netloc, - tls_server_name=parsed_uri.host, - target_host=parsed_uri.host, - target_port=parsed_uri.port, - )) + defer.returnValue( + _RoutingResult( + host_header=parsed_uri.netloc, + tls_server_name=parsed_uri.host, + target_host=parsed_uri.host, + target_port=parsed_uri.port, + ) + ) if lookup_well_known: # try a .well-known lookup @@ -229,8 +235,8 @@ class MatrixFederationAgent(object): # parse the server name in the .well-known response into host/port. # (This code is lifted from twisted.web.client.URI.fromBytes). - if b':' in well_known_server: - well_known_host, well_known_port = well_known_server.rsplit(b':', 1) + if b":" in well_known_server: + well_known_host, well_known_port = well_known_server.rsplit(b":", 1) try: well_known_port = int(well_known_port) except ValueError: @@ -264,21 +270,27 @@ class MatrixFederationAgent(object): port = 8448 logger.debug( "No SRV record for %s, using %s:%i", - parsed_uri.host.decode("ascii"), target_host.decode("ascii"), port, + parsed_uri.host.decode("ascii"), + target_host.decode("ascii"), + port, ) else: target_host, port = pick_server_from_list(server_list) logger.debug( "Picked %s:%i from SRV records for %s", - target_host.decode("ascii"), port, parsed_uri.host.decode("ascii"), + target_host.decode("ascii"), + port, + parsed_uri.host.decode("ascii"), ) - defer.returnValue(_RoutingResult( - host_header=parsed_uri.netloc, - tls_server_name=parsed_uri.host, - target_host=target_host, - target_port=port, - )) + defer.returnValue( + _RoutingResult( + host_header=parsed_uri.netloc, + tls_server_name=parsed_uri.host, + target_host=target_host, + target_port=port, + ) + ) @defer.inlineCallbacks def _get_well_known(self, server_name): @@ -318,18 +330,18 @@ class MatrixFederationAgent(object): - None if there was no .well-known file. - INVALID_WELL_KNOWN if the .well-known was invalid """ - uri = b"https://%s/.well-known/matrix/server" % (server_name, ) + uri = b"https://%s/.well-known/matrix/server" % (server_name,) uri_str = uri.decode("ascii") logger.info("Fetching %s", uri_str) try: response = yield make_deferred_yieldable( - self._well_known_agent.request(b"GET", uri), + self._well_known_agent.request(b"GET", uri) ) body = yield make_deferred_yieldable(readBody(response)) if response.code != 200: - raise Exception("Non-200 response %s" % (response.code, )) + raise Exception("Non-200 response %s" % (response.code,)) - parsed_body = json.loads(body.decode('utf-8')) + parsed_body = json.loads(body.decode("utf-8")) logger.info("Response from .well-known: %s", parsed_body) if not isinstance(parsed_body, dict): raise Exception("not a dict") @@ -347,8 +359,7 @@ class MatrixFederationAgent(object): result = parsed_body["m.server"].encode("ascii") cache_period = _cache_period_from_headers( - response.headers, - time_now=self._reactor.seconds, + response.headers, time_now=self._reactor.seconds ) if cache_period is None: cache_period = WELL_KNOWN_DEFAULT_CACHE_PERIOD @@ -364,6 +375,7 @@ class MatrixFederationAgent(object): @implementer(IStreamClientEndpoint) class LoggingHostnameEndpoint(object): """A wrapper for HostnameEndpint which logs when it connects""" + def __init__(self, reactor, host, port, *args, **kwargs): self.host = host self.port = port @@ -377,17 +389,17 @@ class LoggingHostnameEndpoint(object): def _cache_period_from_headers(headers, time_now=time.time): cache_controls = _parse_cache_control(headers) - if b'no-store' in cache_controls: + if b"no-store" in cache_controls: return 0 - if b'max-age' in cache_controls: + if b"max-age" in cache_controls: try: - max_age = int(cache_controls[b'max-age']) + max_age = int(cache_controls[b"max-age"]) return max_age except ValueError: pass - expires = headers.getRawHeaders(b'expires') + expires = headers.getRawHeaders(b"expires") if expires is not None: try: expires_date = stringToDatetime(expires[-1]) @@ -403,9 +415,9 @@ def _cache_period_from_headers(headers, time_now=time.time): def _parse_cache_control(headers): cache_controls = {} - for hdr in headers.getRawHeaders(b'cache-control', []): - for directive in hdr.split(b','): - splits = [x.strip() for x in directive.split(b'=', 1)] + for hdr in headers.getRawHeaders(b"cache-control", []): + for directive in hdr.split(b","): + splits = [x.strip() for x in directive.split(b"=", 1)] k = splits[0].lower() v = splits[1] if len(splits) > 1 else None cache_controls[k] = v diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py index 71830c549d..1f22f78a75 100644 --- a/synapse/http/federation/srv_resolver.py +++ b/synapse/http/federation/srv_resolver.py @@ -45,6 +45,7 @@ class Server(object): expires (int): when the cache should expire this record - in *seconds* since the epoch """ + host = attr.ib() port = attr.ib() priority = attr.ib(default=0) @@ -79,9 +80,7 @@ def pick_server_from_list(server_list): return s.host, s.port # this should be impossible. - raise RuntimeError( - "pick_server_from_list got to end of eligible server list.", - ) + raise RuntimeError("pick_server_from_list got to end of eligible server list.") class SrvResolver(object): @@ -95,6 +94,7 @@ class SrvResolver(object): cache (dict): cache object get_time (callable): clock implementation. Should return seconds since the epoch """ + def __init__(self, dns_client=client, cache=SERVER_CACHE, get_time=time.time): self._dns_client = dns_client self._cache = cache @@ -124,7 +124,7 @@ class SrvResolver(object): try: answers, _, _ = yield make_deferred_yieldable( - self._dns_client.lookupService(service_name), + self._dns_client.lookupService(service_name) ) except DNSNameError: # TODO: cache this. We can get the SOA out of the exception, and use @@ -136,17 +136,18 @@ class SrvResolver(object): cache_entry = self._cache.get(service_name, None) if cache_entry: logger.warn( - "Failed to resolve %r, falling back to cache. %r", - service_name, e + "Failed to resolve %r, falling back to cache. %r", service_name, e ) defer.returnValue(list(cache_entry)) else: raise e - if (len(answers) == 1 - and answers[0].type == dns.SRV - and answers[0].payload - and answers[0].payload.target == dns.Name(b'.')): + if ( + len(answers) == 1 + and answers[0].type == dns.SRV + and answers[0].payload + and answers[0].payload.target == dns.Name(b".") + ): raise ConnectError("Service %s unavailable" % service_name) servers = [] @@ -157,13 +158,15 @@ class SrvResolver(object): payload = answer.payload - servers.append(Server( - host=payload.target.name, - port=payload.port, - priority=payload.priority, - weight=payload.weight, - expires=now + answer.ttl, - )) + servers.append( + Server( + host=payload.target.name, + port=payload.port, + priority=payload.priority, + weight=payload.weight, + expires=now + answer.ttl, + ) + ) self._cache[service_name] = list(servers) defer.returnValue(servers) |