diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index f5f917f5ae..83d6196d4a 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -15,6 +15,7 @@
import logging
import urllib
+from typing import List
from netaddr import AddrFormatError, IPAddress
from zope.interface import implementer
@@ -35,7 +36,7 @@ logger = logging.getLogger(__name__)
@implementer(IAgent)
-class MatrixFederationAgent(object):
+class MatrixFederationAgent:
"""An Agent-like thing which provides a `request` method which correctly
handles resolving matrix server names when using matrix://. Handles standard
https URIs as normal.
@@ -48,6 +49,9 @@ class MatrixFederationAgent(object):
tls_client_options_factory (FederationPolicyForHTTPS|None):
factory to use for fetching client tls options, or none to disable TLS.
+ user_agent (bytes):
+ The user agent header to use for federation requests.
+
_srv_resolver (SrvResolver|None):
SRVResolver impl to use for looking up SRV records. None to use a default
implementation.
@@ -61,6 +65,7 @@ class MatrixFederationAgent(object):
self,
reactor,
tls_client_options_factory,
+ user_agent,
_srv_resolver=None,
_well_known_resolver=None,
):
@@ -78,6 +83,7 @@ class MatrixFederationAgent(object):
),
pool=self._pool,
)
+ self.user_agent = user_agent
if _well_known_resolver is None:
_well_known_resolver = WellKnownResolver(
@@ -87,6 +93,7 @@ class MatrixFederationAgent(object):
pool=self._pool,
contextFactory=tls_client_options_factory,
),
+ user_agent=self.user_agent,
)
self._well_known_resolver = _well_known_resolver
@@ -127,8 +134,8 @@ class MatrixFederationAgent(object):
and not _is_ip_literal(parsed_uri.hostname)
and not parsed_uri.port
):
- well_known_result = yield self._well_known_resolver.get_well_known(
- parsed_uri.hostname
+ well_known_result = yield defer.ensureDeferred(
+ self._well_known_resolver.get_well_known(parsed_uri.hostname)
)
delegated_server = well_known_result.delegated_server
@@ -149,7 +156,7 @@ class MatrixFederationAgent(object):
parsed_uri = urllib.parse.urlparse(uri)
# We need to make sure the host header is set to the netloc of the
- # server.
+ # server and that a user-agent is provided.
if headers is None:
headers = Headers()
else:
@@ -157,6 +164,8 @@ class MatrixFederationAgent(object):
if not headers.hasHeader(b"host"):
headers.addRawHeader(b"host", parsed_uri.netloc)
+ if not headers.hasHeader(b"user-agent"):
+ headers.addRawHeader(b"user-agent", self.user_agent)
res = yield make_deferred_yieldable(
self._agent.request(method, uri, headers, bodyProducer)
@@ -166,7 +175,7 @@ class MatrixFederationAgent(object):
@implementer(IAgentEndpointFactory)
-class MatrixHostnameEndpointFactory(object):
+class MatrixHostnameEndpointFactory:
"""Factory for MatrixHostnameEndpoint for parsing to an Agent.
"""
@@ -189,7 +198,7 @@ class MatrixHostnameEndpointFactory(object):
@implementer(IStreamClientEndpoint)
-class MatrixHostnameEndpoint(object):
+class MatrixHostnameEndpoint:
"""An endpoint that resolves matrix:// URLs using Matrix server name
resolution (i.e. via SRV). Does not check for well-known delegation.
@@ -228,22 +237,21 @@ class MatrixHostnameEndpoint(object):
return run_in_background(self._do_connect, protocol_factory)
- @defer.inlineCallbacks
- def _do_connect(self, protocol_factory):
+ async def _do_connect(self, protocol_factory):
first_exception = None
- server_list = yield self._resolve_server()
+ server_list = await self._resolve_server()
for server in server_list:
host = server.host
port = server.port
try:
- logger.info("Connecting to %s:%i", host.decode("ascii"), port)
+ logger.debug("Connecting to %s:%i", host.decode("ascii"), port)
endpoint = HostnameEndpoint(self._reactor, host, port)
if self._tls_options:
endpoint = wrapClientTLS(self._tls_options, endpoint)
- result = yield make_deferred_yieldable(
+ result = await make_deferred_yieldable(
endpoint.connect(protocol_factory)
)
@@ -263,13 +271,9 @@ class MatrixHostnameEndpoint(object):
# to try and if that doesn't work then we'll have an exception.
raise Exception("Failed to resolve server %r" % (self._parsed_uri.netloc,))
- @defer.inlineCallbacks
- def _resolve_server(self):
+ async def _resolve_server(self) -> List[Server]:
"""Resolves the server name to a list of hosts and ports to attempt to
connect to.
-
- Returns:
- Deferred[list[Server]]
"""
if self._parsed_uri.scheme != b"matrix":
@@ -290,7 +294,7 @@ class MatrixHostnameEndpoint(object):
if port or _is_ip_literal(host):
return [Server(host, port or 8448)]
- server_list = yield self._srv_resolver.resolve_service(b"_matrix._tcp." + host)
+ server_list = await self._srv_resolver.resolve_service(b"_matrix._tcp." + host)
if server_list:
return server_list
diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py
index 021b233a7d..d9620032d2 100644
--- a/synapse/http/federation/srv_resolver.py
+++ b/synapse/http/federation/srv_resolver.py
@@ -17,10 +17,10 @@
import logging
import random
import time
+from typing import List
import attr
-from twisted.internet import defer
from twisted.internet.error import ConnectError
from twisted.names import client, dns
from twisted.names.error import DNSNameError, DomainError
@@ -33,7 +33,7 @@ SERVER_CACHE = {}
@attr.s(slots=True, frozen=True)
-class Server(object):
+class Server:
"""
Our record of an individual server which can be tried to reach a destination.
@@ -96,7 +96,7 @@ def _sort_server_list(server_list):
return results
-class SrvResolver(object):
+class SrvResolver:
"""Interface to the dns client to do SRV lookups, with result caching.
The default resolver in twisted.names doesn't do any caching (it has a CacheResolver,
@@ -113,16 +113,14 @@ class SrvResolver(object):
self._cache = cache
self._get_time = get_time
- @defer.inlineCallbacks
- def resolve_service(self, service_name):
+ async def resolve_service(self, service_name: bytes) -> List[Server]:
"""Look up a SRV record
Args:
service_name (bytes): record to look up
Returns:
- Deferred[list[Server]]:
- a list of the SRV records, or an empty list if none found
+ a list of the SRV records, or an empty list if none found
"""
now = int(self._get_time())
@@ -136,7 +134,7 @@ class SrvResolver(object):
return _sort_server_list(servers)
try:
- answers, _, _ = yield make_deferred_yieldable(
+ answers, _, _ = await make_deferred_yieldable(
self._dns_client.lookupService(service_name)
)
except DNSNameError:
diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py
index 7ddfad286d..e6f067ca29 100644
--- a/synapse/http/federation/well_known_resolver.py
+++ b/synapse/http/federation/well_known_resolver.py
@@ -13,19 +13,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import json
import logging
import random
import time
+from typing import Callable, Dict, Optional, Tuple
import attr
from twisted.internet import defer
from twisted.web.client import RedirectAgent, readBody
from twisted.web.http import stringToDatetime
+from twisted.web.http_headers import Headers
+from twisted.web.iweb import IResponse
from synapse.logging.context import make_deferred_yieldable
-from synapse.util import Clock
+from synapse.util import Clock, json_decoder
from synapse.util.caches.ttlcache import TTLCache
from synapse.util.metrics import Measure
@@ -69,16 +71,21 @@ _had_valid_well_known_cache = TTLCache("had-valid-well-known")
@attr.s(slots=True, frozen=True)
-class WellKnownLookupResult(object):
+class WellKnownLookupResult:
delegated_server = attr.ib()
-class WellKnownResolver(object):
+class WellKnownResolver:
"""Handles well-known lookups for matrix servers.
"""
def __init__(
- self, reactor, agent, well_known_cache=None, had_well_known_cache=None
+ self,
+ reactor,
+ agent,
+ user_agent,
+ well_known_cache=None,
+ had_well_known_cache=None,
):
self._reactor = reactor
self._clock = Clock(reactor)
@@ -92,16 +99,16 @@ class WellKnownResolver(object):
self._well_known_cache = well_known_cache
self._had_valid_well_known_cache = had_well_known_cache
self._well_known_agent = RedirectAgent(agent)
+ self.user_agent = user_agent
- @defer.inlineCallbacks
- def get_well_known(self, server_name):
+ async def get_well_known(self, server_name: bytes) -> WellKnownLookupResult:
"""Attempt to fetch and parse a .well-known file for the given server
Args:
- server_name (bytes): name of the server, from the requested url
+ server_name: name of the server, from the requested url
Returns:
- Deferred[WellKnownLookupResult]: The result of the lookup
+ The result of the lookup
"""
try:
prev_result, expiry, ttl = self._well_known_cache.get_with_expiry(
@@ -118,7 +125,9 @@ class WellKnownResolver(object):
# requests for the same server in parallel?
try:
with Measure(self._clock, "get_well_known"):
- result, cache_period = yield self._fetch_well_known(server_name)
+ result, cache_period = await self._fetch_well_known(
+ server_name
+ ) # type: Tuple[Optional[bytes], float]
except _FetchWellKnownFailure as e:
if prev_result and e.temporary:
@@ -147,18 +156,17 @@ class WellKnownResolver(object):
return WellKnownLookupResult(delegated_server=result)
- @defer.inlineCallbacks
- def _fetch_well_known(self, server_name):
+ async def _fetch_well_known(self, server_name: bytes) -> Tuple[bytes, float]:
"""Actually fetch and parse a .well-known, without checking the cache
Args:
- server_name (bytes): name of the server, from the requested url
+ server_name: name of the server, from the requested url
Raises:
_FetchWellKnownFailure if we fail to lookup a result
Returns:
- Deferred[Tuple[bytes,int]]: The lookup result and cache period.
+ The lookup result and cache period.
"""
had_valid_well_known = self._had_valid_well_known_cache.get(server_name, False)
@@ -166,7 +174,7 @@ class WellKnownResolver(object):
# We do this in two steps to differentiate between possibly transient
# errors (e.g. can't connect to host, 503 response) and more permenant
# errors (such as getting a 404 response).
- response, body = yield self._make_well_known_request(
+ response, body = await self._make_well_known_request(
server_name, retry=had_valid_well_known
)
@@ -174,7 +182,7 @@ class WellKnownResolver(object):
if response.code != 200:
raise Exception("Non-200 response %s" % (response.code,))
- parsed_body = json.loads(body.decode("utf-8"))
+ parsed_body = json_decoder.decode(body.decode("utf-8"))
logger.info("Response from .well-known: %s", parsed_body)
result = parsed_body["m.server"].encode("ascii")
@@ -209,34 +217,40 @@ class WellKnownResolver(object):
return result, cache_period
- @defer.inlineCallbacks
- def _make_well_known_request(self, server_name, retry):
+ async def _make_well_known_request(
+ self, server_name: bytes, retry: bool
+ ) -> Tuple[IResponse, bytes]:
"""Make the well known request.
This will retry the request if requested and it fails (with unable
to connect or receives a 5xx error).
Args:
- server_name (bytes)
- retry (bool): Whether to retry the request if it fails.
+ server_name: name of the server, from the requested url
+ retry: Whether to retry the request if it fails.
Returns:
- Deferred[tuple[IResponse, bytes]] Returns the response object and
- body. Response may be a non-200 response.
+ Returns the response object and body. Response may be a non-200 response.
"""
uri = b"https://%s/.well-known/matrix/server" % (server_name,)
uri_str = uri.decode("ascii")
+ headers = {
+ b"User-Agent": [self.user_agent],
+ }
+
i = 0
while True:
i += 1
logger.info("Fetching %s", uri_str)
try:
- response = yield make_deferred_yieldable(
- self._well_known_agent.request(b"GET", uri)
+ response = await make_deferred_yieldable(
+ self._well_known_agent.request(
+ b"GET", uri, headers=Headers(headers)
+ )
)
- body = yield make_deferred_yieldable(readBody(response))
+ body = await make_deferred_yieldable(readBody(response))
if 500 <= response.code < 600:
raise Exception("Non-200 response %s" % (response.code,))
@@ -253,21 +267,24 @@ class WellKnownResolver(object):
logger.info("Error fetching %s: %s. Retrying", uri_str, e)
# Sleep briefly in the hopes that they come back up
- yield self._clock.sleep(0.5)
+ await self._clock.sleep(0.5)
-def _cache_period_from_headers(headers, time_now=time.time):
+def _cache_period_from_headers(
+ headers: Headers, time_now: Callable[[], float] = time.time
+) -> Optional[float]:
cache_controls = _parse_cache_control(headers)
if b"no-store" in cache_controls:
return 0
if b"max-age" in cache_controls:
- try:
- max_age = int(cache_controls[b"max-age"])
- return max_age
- except ValueError:
- pass
+ max_age = cache_controls[b"max-age"]
+ if max_age:
+ try:
+ return int(max_age)
+ except ValueError:
+ pass
expires = headers.getRawHeaders(b"expires")
if expires is not None:
@@ -283,7 +300,7 @@ def _cache_period_from_headers(headers, time_now=time.time):
return None
-def _parse_cache_control(headers):
+def _parse_cache_control(headers: Headers) -> Dict[bytes, Optional[bytes]]:
cache_controls = {}
for hdr in headers.getRawHeaders(b"cache-control", []):
for directive in hdr.split(b","):
|