summary refs log tree commit diff
path: root/synapse/http/federation/matrix_federation_agent.py
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2020-11-25 07:07:21 -0500
committerGitHub <noreply@github.com>2020-11-25 07:07:21 -0500
commitf38676d16143e399b654504486cf8cbecad12a5d (patch)
tree91c4d01eaec5d525b5899242a6c66fa505621330 /synapse/http/federation/matrix_federation_agent.py
parentClarify documentation of the admin list media API (#8795) (diff)
downloadsynapse-f38676d16143e399b654504486cf8cbecad12a5d.tar.xz
Add type hints to matrix federation client / agent. (#8806)
Diffstat (limited to 'synapse/http/federation/matrix_federation_agent.py')
-rw-r--r--synapse/http/federation/matrix_federation_agent.py100
1 files changed, 61 insertions, 39 deletions
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index 83d6196d4a..e77f9587d0 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -12,21 +12,25 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import logging
-import urllib
-from typing import List
+import urllib.parse
+from typing import List, Optional
 
 from netaddr import AddrFormatError, IPAddress
 from zope.interface import implementer
 
 from twisted.internet import defer
 from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
-from twisted.internet.interfaces import IStreamClientEndpoint
-from twisted.web.client import Agent, HTTPConnectionPool
+from twisted.internet.interfaces import (
+    IProtocolFactory,
+    IReactorCore,
+    IStreamClientEndpoint,
+)
+from twisted.web.client import URI, Agent, HTTPConnectionPool
 from twisted.web.http_headers import Headers
-from twisted.web.iweb import IAgent, IAgentEndpointFactory
+from twisted.web.iweb import IAgent, IAgentEndpointFactory, IBodyProducer
 
+from synapse.crypto.context_factory import FederationPolicyForHTTPS
 from synapse.http.federation.srv_resolver import Server, SrvResolver
 from synapse.http.federation.well_known_resolver import WellKnownResolver
 from synapse.logging.context import make_deferred_yieldable, run_in_background
@@ -44,30 +48,30 @@ class MatrixFederationAgent:
     Doesn't implement any retries. (Those are done in MatrixFederationHttpClient.)
 
     Args:
-        reactor (IReactor): twisted reactor to use for underlying requests
+        reactor: twisted reactor to use for underlying requests
 
-        tls_client_options_factory (FederationPolicyForHTTPS|None):
+        tls_client_options_factory:
             factory to use for fetching client tls options, or none to disable TLS.
 
-        user_agent (bytes):
+        user_agent:
             The user agent header to use for federation requests.
 
-        _srv_resolver (SrvResolver|None):
-            SRVResolver impl to use for looking up SRV records. None to use a default
-            implementation.
+        _srv_resolver:
+            SrvResolver implementation to use for looking up SRV records. None
+            to use a default implementation.
 
-        _well_known_resolver (WellKnownResolver|None):
+        _well_known_resolver:
             WellKnownResolver to use to perform well-known lookups. None to use a
             default implementation.
     """
 
     def __init__(
         self,
-        reactor,
-        tls_client_options_factory,
-        user_agent,
-        _srv_resolver=None,
-        _well_known_resolver=None,
+        reactor: IReactorCore,
+        tls_client_options_factory: Optional[FederationPolicyForHTTPS],
+        user_agent: bytes,
+        _srv_resolver: Optional[SrvResolver] = None,
+        _well_known_resolver: Optional[WellKnownResolver] = None,
     ):
         self._reactor = reactor
         self._clock = Clock(reactor)
@@ -99,15 +103,20 @@ class MatrixFederationAgent:
         self._well_known_resolver = _well_known_resolver
 
     @defer.inlineCallbacks
-    def request(self, method, uri, headers=None, bodyProducer=None):
+    def request(
+        self,
+        method: bytes,
+        uri: bytes,
+        headers: Optional[Headers] = None,
+        bodyProducer: Optional[IBodyProducer] = None,
+    ) -> defer.Deferred:
         """
         Args:
-            method (bytes): HTTP method: GET/POST/etc
-            uri (bytes): Absolute URI to be retrieved
-            headers (twisted.web.http_headers.Headers|None):
-                HTTP headers to send with the request, or None to
-                send no extra headers.
-            bodyProducer (twisted.web.iweb.IBodyProducer|None):
+            method: HTTP method: GET/POST/etc
+            uri: Absolute URI to be retrieved
+            headers:
+                HTTP headers to send with the request, or None to send no extra headers.
+            bodyProducer:
                 An object which can generate bytes to make up the
                 body of this request (for example, the properly encoded contents of
                 a file for a file upload).  Or None if the request is to have
@@ -123,6 +132,9 @@ class MatrixFederationAgent:
         # explicit port.
         parsed_uri = urllib.parse.urlparse(uri)
 
+        # There must be a valid hostname.
+        assert parsed_uri.hostname
+
         # If this is a matrix:// URI check if the server has delegated matrix
         # traffic using well-known delegation.
         #
@@ -179,7 +191,12 @@ class MatrixHostnameEndpointFactory:
     """Factory for MatrixHostnameEndpoint for parsing to an Agent.
     """
 
-    def __init__(self, reactor, tls_client_options_factory, srv_resolver):
+    def __init__(
+        self,
+        reactor: IReactorCore,
+        tls_client_options_factory: Optional[FederationPolicyForHTTPS],
+        srv_resolver: Optional[SrvResolver],
+    ):
         self._reactor = reactor
         self._tls_client_options_factory = tls_client_options_factory
 
@@ -203,15 +220,20 @@ class MatrixHostnameEndpoint:
     resolution (i.e. via SRV). Does not check for well-known delegation.
 
     Args:
-        reactor (IReactor)
-        tls_client_options_factory (ClientTLSOptionsFactory|None):
+        reactor: twisted reactor to use for underlying requests
+        tls_client_options_factory:
             factory to use for fetching client tls options, or none to disable TLS.
-        srv_resolver (SrvResolver): The SRV resolver to use
-        parsed_uri (twisted.web.client.URI): The parsed URI that we're wanting
-            to connect to.
+        srv_resolver: The SRV resolver to use
+        parsed_uri: The parsed URI that we're wanting to connect to.
     """
 
-    def __init__(self, reactor, tls_client_options_factory, srv_resolver, parsed_uri):
+    def __init__(
+        self,
+        reactor: IReactorCore,
+        tls_client_options_factory: Optional[FederationPolicyForHTTPS],
+        srv_resolver: SrvResolver,
+        parsed_uri: URI,
+    ):
         self._reactor = reactor
 
         self._parsed_uri = parsed_uri
@@ -231,13 +253,13 @@ class MatrixHostnameEndpoint:
 
         self._srv_resolver = srv_resolver
 
-    def connect(self, protocol_factory):
+    def connect(self, protocol_factory: IProtocolFactory) -> defer.Deferred:
         """Implements IStreamClientEndpoint interface
         """
 
         return run_in_background(self._do_connect, protocol_factory)
 
-    async def _do_connect(self, protocol_factory):
+    async def _do_connect(self, protocol_factory: IProtocolFactory) -> None:
         first_exception = None
 
         server_list = await self._resolve_server()
@@ -303,20 +325,20 @@ class MatrixHostnameEndpoint:
         return [Server(host, 8448)]
 
 
-def _is_ip_literal(host):
+def _is_ip_literal(host: bytes) -> bool:
     """Test if the given host name is either an IPv4 or IPv6 literal.
 
     Args:
-        host (bytes)
+        host: The host name to check
 
     Returns:
-        bool
+        True if the hostname is an IP address literal.
     """
 
-    host = host.decode("ascii")
+    host_str = host.decode("ascii")
 
     try:
-        IPAddress(host)
+        IPAddress(host_str)
         return True
     except AddrFormatError:
         return False