summary refs log tree commit diff
path: root/synapse/http/federation
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/http/federation')
-rw-r--r--synapse/http/federation/matrix_federation_agent.py28
-rw-r--r--synapse/http/federation/srv_resolver.py10
-rw-r--r--synapse/http/federation/well_known_resolver.py22
3 files changed, 37 insertions, 23 deletions
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index f5f917f5ae..369bf9c2fc 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -15,6 +15,7 @@
 
 import logging
 import urllib
+from typing import List
 
 from netaddr import AddrFormatError, IPAddress
 from zope.interface import implementer
@@ -48,6 +49,9 @@ class MatrixFederationAgent(object):
         tls_client_options_factory (FederationPolicyForHTTPS|None):
             factory to use for fetching client tls options, or none to disable TLS.
 
+        user_agent (bytes):
+            The user agent header to use for federation requests.
+
         _srv_resolver (SrvResolver|None):
             SRVResolver impl to use for looking up SRV records. None to use a default
             implementation.
@@ -61,6 +65,7 @@ class MatrixFederationAgent(object):
         self,
         reactor,
         tls_client_options_factory,
+        user_agent,
         _srv_resolver=None,
         _well_known_resolver=None,
     ):
@@ -78,6 +83,7 @@ class MatrixFederationAgent(object):
             ),
             pool=self._pool,
         )
+        self.user_agent = user_agent
 
         if _well_known_resolver is None:
             _well_known_resolver = WellKnownResolver(
@@ -87,6 +93,7 @@ class MatrixFederationAgent(object):
                     pool=self._pool,
                     contextFactory=tls_client_options_factory,
                 ),
+                user_agent=self.user_agent,
             )
 
         self._well_known_resolver = _well_known_resolver
@@ -149,7 +156,7 @@ class MatrixFederationAgent(object):
             parsed_uri = urllib.parse.urlparse(uri)
 
         # We need to make sure the host header is set to the netloc of the
-        # server.
+        # server and that a user-agent is provided.
         if headers is None:
             headers = Headers()
         else:
@@ -157,6 +164,8 @@ class MatrixFederationAgent(object):
 
         if not headers.hasHeader(b"host"):
             headers.addRawHeader(b"host", parsed_uri.netloc)
+        if not headers.hasHeader(b"user-agent"):
+            headers.addRawHeader(b"user-agent", self.user_agent)
 
         res = yield make_deferred_yieldable(
             self._agent.request(method, uri, headers, bodyProducer)
@@ -228,22 +237,21 @@ class MatrixHostnameEndpoint(object):
 
         return run_in_background(self._do_connect, protocol_factory)
 
-    @defer.inlineCallbacks
-    def _do_connect(self, protocol_factory):
+    async def _do_connect(self, protocol_factory):
         first_exception = None
 
-        server_list = yield self._resolve_server()
+        server_list = await self._resolve_server()
 
         for server in server_list:
             host = server.host
             port = server.port
 
             try:
-                logger.info("Connecting to %s:%i", host.decode("ascii"), port)
+                logger.debug("Connecting to %s:%i", host.decode("ascii"), port)
                 endpoint = HostnameEndpoint(self._reactor, host, port)
                 if self._tls_options:
                     endpoint = wrapClientTLS(self._tls_options, endpoint)
-                result = yield make_deferred_yieldable(
+                result = await make_deferred_yieldable(
                     endpoint.connect(protocol_factory)
                 )
 
@@ -263,13 +271,9 @@ class MatrixHostnameEndpoint(object):
         # to try and if that doesn't work then we'll have an exception.
         raise Exception("Failed to resolve server %r" % (self._parsed_uri.netloc,))
 
-    @defer.inlineCallbacks
-    def _resolve_server(self):
+    async def _resolve_server(self) -> List[Server]:
         """Resolves the server name to a list of hosts and ports to attempt to
         connect to.
-
-        Returns:
-            Deferred[list[Server]]
         """
 
         if self._parsed_uri.scheme != b"matrix":
@@ -290,7 +294,7 @@ class MatrixHostnameEndpoint(object):
         if port or _is_ip_literal(host):
             return [Server(host, port or 8448)]
 
-        server_list = yield self._srv_resolver.resolve_service(b"_matrix._tcp." + host)
+        server_list = await self._srv_resolver.resolve_service(b"_matrix._tcp." + host)
 
         if server_list:
             return server_list
diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py
index 021b233a7d..2ede90a9b1 100644
--- a/synapse/http/federation/srv_resolver.py
+++ b/synapse/http/federation/srv_resolver.py
@@ -17,10 +17,10 @@
 import logging
 import random
 import time
+from typing import List
 
 import attr
 
-from twisted.internet import defer
 from twisted.internet.error import ConnectError
 from twisted.names import client, dns
 from twisted.names.error import DNSNameError, DomainError
@@ -113,16 +113,14 @@ class SrvResolver(object):
         self._cache = cache
         self._get_time = get_time
 
-    @defer.inlineCallbacks
-    def resolve_service(self, service_name):
+    async def resolve_service(self, service_name: bytes) -> List[Server]:
         """Look up a SRV record
 
         Args:
             service_name (bytes): record to look up
 
         Returns:
-            Deferred[list[Server]]:
-                a list of the SRV records, or an empty list if none found
+            a list of the SRV records, or an empty list if none found
         """
         now = int(self._get_time())
 
@@ -136,7 +134,7 @@ class SrvResolver(object):
                 return _sort_server_list(servers)
 
         try:
-            answers, _, _ = yield make_deferred_yieldable(
+            answers, _, _ = await make_deferred_yieldable(
                 self._dns_client.lookupService(service_name)
             )
         except DNSNameError:
diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py
index 7ddfad286d..f794315deb 100644
--- a/synapse/http/federation/well_known_resolver.py
+++ b/synapse/http/federation/well_known_resolver.py
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import json
 import logging
 import random
 import time
@@ -23,9 +22,10 @@ import attr
 from twisted.internet import defer
 from twisted.web.client import RedirectAgent, readBody
 from twisted.web.http import stringToDatetime
+from twisted.web.http_headers import Headers
 
 from synapse.logging.context import make_deferred_yieldable
-from synapse.util import Clock
+from synapse.util import Clock, json_decoder
 from synapse.util.caches.ttlcache import TTLCache
 from synapse.util.metrics import Measure
 
@@ -78,7 +78,12 @@ class WellKnownResolver(object):
     """
 
     def __init__(
-        self, reactor, agent, well_known_cache=None, had_well_known_cache=None
+        self,
+        reactor,
+        agent,
+        user_agent,
+        well_known_cache=None,
+        had_well_known_cache=None,
     ):
         self._reactor = reactor
         self._clock = Clock(reactor)
@@ -92,6 +97,7 @@ class WellKnownResolver(object):
         self._well_known_cache = well_known_cache
         self._had_valid_well_known_cache = had_well_known_cache
         self._well_known_agent = RedirectAgent(agent)
+        self.user_agent = user_agent
 
     @defer.inlineCallbacks
     def get_well_known(self, server_name):
@@ -174,7 +180,7 @@ class WellKnownResolver(object):
             if response.code != 200:
                 raise Exception("Non-200 response %s" % (response.code,))
 
-            parsed_body = json.loads(body.decode("utf-8"))
+            parsed_body = json_decoder.decode(body.decode("utf-8"))
             logger.info("Response from .well-known: %s", parsed_body)
 
             result = parsed_body["m.server"].encode("ascii")
@@ -227,6 +233,10 @@ class WellKnownResolver(object):
         uri = b"https://%s/.well-known/matrix/server" % (server_name,)
         uri_str = uri.decode("ascii")
 
+        headers = {
+            b"User-Agent": [self.user_agent],
+        }
+
         i = 0
         while True:
             i += 1
@@ -234,7 +244,9 @@ class WellKnownResolver(object):
             logger.info("Fetching %s", uri_str)
             try:
                 response = yield make_deferred_yieldable(
-                    self._well_known_agent.request(b"GET", uri)
+                    self._well_known_agent.request(
+                        b"GET", uri, headers=Headers(headers)
+                    )
                 )
                 body = yield make_deferred_yieldable(readBody(response))