summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/9518.misc1
-rw-r--r--synapse/http/federation/matrix_federation_agent.py18
-rw-r--r--synapse/http/matrixfederationclient.py6
-rw-r--r--synapse/http/server.py29
-rw-r--r--synapse/http/site.py35
-rw-r--r--synapse/logging/_remote.py6
-rw-r--r--synapse/metrics/__init__.py11
-rw-r--r--synapse/module_api/__init__.py4
-rw-r--r--synapse/push/httppusher.py5
-rw-r--r--synapse/replication/tcp/client.py4
-rw-r--r--synapse/server.py3
-rw-r--r--tests/rest/client/v1/test_login.py35
12 files changed, 96 insertions, 61 deletions
diff --git a/changelog.d/9518.misc b/changelog.d/9518.misc
new file mode 100644
index 0000000000..14c7b78dd9
--- /dev/null
+++ b/changelog.d/9518.misc
@@ -0,0 +1 @@
+Fix incorrect type hints.
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index 2e83fa6773..b07aa59c08 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 import logging
 import urllib.parse
-from typing import List, Optional
+from typing import Any, Generator, List, Optional
 
 from netaddr import AddrFormatError, IPAddress, IPSet
 from zope.interface import implementer
@@ -116,7 +116,7 @@ class MatrixFederationAgent:
         uri: bytes,
         headers: Optional[Headers] = None,
         bodyProducer: Optional[IBodyProducer] = None,
-    ) -> defer.Deferred:
+    ) -> Generator[defer.Deferred, Any, defer.Deferred]:
         """
         Args:
             method: HTTP method: GET/POST/etc
@@ -177,17 +177,17 @@ class MatrixFederationAgent:
         # We need to make sure the host header is set to the netloc of the
         # server and that a user-agent is provided.
         if headers is None:
-            headers = Headers()
+            request_headers = Headers()
         else:
-            headers = headers.copy()
+            request_headers = headers.copy()
 
-        if not headers.hasHeader(b"host"):
-            headers.addRawHeader(b"host", parsed_uri.netloc)
-        if not headers.hasHeader(b"user-agent"):
-            headers.addRawHeader(b"user-agent", self.user_agent)
+        if not request_headers.hasHeader(b"host"):
+            request_headers.addRawHeader(b"host", parsed_uri.netloc)
+        if not request_headers.hasHeader(b"user-agent"):
+            request_headers.addRawHeader(b"user-agent", self.user_agent)
 
         res = yield make_deferred_yieldable(
-            self._agent.request(method, uri, headers, bodyProducer)
+            self._agent.request(method, uri, request_headers, bodyProducer)
         )
 
         return res
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index cde42e9f5e..0f107714ea 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -1049,14 +1049,14 @@ def check_content_type_is_json(headers: Headers) -> None:
         RequestSendFailed: if the Content-Type header is missing or isn't JSON
 
     """
-    c_type = headers.getRawHeaders(b"Content-Type")
-    if c_type is None:
+    content_type_headers = headers.getRawHeaders(b"Content-Type")
+    if content_type_headers is None:
         raise RequestSendFailed(
             RuntimeError("No Content-Type header received from remote server"),
             can_retry=False,
         )
 
-    c_type = c_type[0].decode("ascii")  # only the first header
+    c_type = content_type_headers[0].decode("ascii")  # only the first header
     val, options = cgi.parse_header(c_type)
     if val != "application/json":
         raise RequestSendFailed(
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 845db9b78d..fa89260850 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -21,6 +21,7 @@ import logging
 import types
 import urllib
 from http import HTTPStatus
+from inspect import isawaitable
 from io import BytesIO
 from typing import (
     Any,
@@ -30,6 +31,7 @@ from typing import (
     Iterable,
     Iterator,
     List,
+    Optional,
     Pattern,
     Tuple,
     Union,
@@ -79,10 +81,12 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
     """Sends a JSON error response to clients."""
 
     if f.check(SynapseError):
-        error_code = f.value.code
-        error_dict = f.value.error_dict()
+        # mypy doesn't understand that f.check asserts the type.
+        exc = f.value  # type: SynapseError  # type: ignore
+        error_code = exc.code
+        error_dict = exc.error_dict()
 
-        logger.info("%s SynapseError: %s - %s", request, error_code, f.value.msg)
+        logger.info("%s SynapseError: %s - %s", request, error_code, exc.msg)
     else:
         error_code = 500
         error_dict = {"error": "Internal server error", "errcode": Codes.UNKNOWN}
@@ -91,7 +95,7 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
             "Failed handle request via %r: %r",
             request.request_metrics.name,
             request,
-            exc_info=(f.type, f.value, f.getTracebackObject()),
+            exc_info=(f.type, f.value, f.getTracebackObject()),  # type: ignore
         )
 
     # Only respond with an error response if we haven't already started writing,
@@ -128,7 +132,8 @@ def return_html_error(
             `{msg}` placeholders), or a jinja2 template
     """
     if f.check(CodeMessageException):
-        cme = f.value
+        # mypy doesn't understand that f.check asserts the type.
+        cme = f.value  # type: CodeMessageException  # type: ignore
         code = cme.code
         msg = cme.msg
 
@@ -142,7 +147,7 @@ def return_html_error(
             logger.error(
                 "Failed handle request %r",
                 request,
-                exc_info=(f.type, f.value, f.getTracebackObject()),
+                exc_info=(f.type, f.value, f.getTracebackObject()),  # type: ignore
             )
     else:
         code = HTTPStatus.INTERNAL_SERVER_ERROR
@@ -151,7 +156,7 @@ def return_html_error(
         logger.error(
             "Failed handle request %r",
             request,
-            exc_info=(f.type, f.value, f.getTracebackObject()),
+            exc_info=(f.type, f.value, f.getTracebackObject()),  # type: ignore
         )
 
     if isinstance(error_template, str):
@@ -278,7 +283,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
             raw_callback_return = method_handler(request)
 
             # Is it synchronous? We'll allow this for now.
-            if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)):
+            if isawaitable(raw_callback_return):
                 callback_return = await raw_callback_return
             else:
                 callback_return = raw_callback_return  # type: ignore
@@ -399,8 +404,10 @@ class JsonResource(DirectServeJsonResource):
             A tuple of the callback to use, the name of the servlet, and the
             key word arguments to pass to the callback
         """
+        # At this point the path must be bytes.
+        request_path_bytes = request.path  # type: bytes  # type: ignore
+        request_path = request_path_bytes.decode("ascii")
         # Treat HEAD requests as GET requests.
-        request_path = request.path.decode("ascii")
         request_method = request.method
         if request_method == b"HEAD":
             request_method = b"GET"
@@ -551,7 +558,7 @@ class _ByteProducer:
         request: Request,
         iterator: Iterator[bytes],
     ):
-        self._request = request
+        self._request = request  # type: Optional[Request]
         self._iterator = iterator
         self._paused = False
 
@@ -563,7 +570,7 @@ class _ByteProducer:
         """
         Send a list of bytes as a chunk of a response.
         """
-        if not data:
+        if not data or not self._request:
             return
         self._request.write(b"".join(data))
 
diff --git a/synapse/http/site.py b/synapse/http/site.py
index 30153237e3..47754aff43 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -14,7 +14,7 @@
 import contextlib
 import logging
 import time
-from typing import Optional, Union
+from typing import Optional, Type, Union
 
 import attr
 from zope.interface import implementer
@@ -57,7 +57,7 @@ class SynapseRequest(Request):
 
     def __init__(self, channel, *args, **kw):
         Request.__init__(self, channel, *args, **kw)
-        self.site = channel.site
+        self.site = channel.site  # type: SynapseSite
         self._channel = channel  # this is used by the tests
         self.start_time = 0.0
 
@@ -96,25 +96,34 @@ class SynapseRequest(Request):
     def get_request_id(self):
         return "%s-%i" % (self.get_method(), self.request_seq)
 
-    def get_redacted_uri(self):
-        uri = self.uri
+    def get_redacted_uri(self) -> str:
+        """Gets the redacted URI associated with the request (or placeholder if the URI
+        has not yet been received).
+
+        Note: This is necessary as the placeholder value in twisted is str
+        rather than bytes, so we need to sanitise `self.uri`.
+
+        Returns:
+            The redacted URI as a string.
+        """
+        uri = self.uri  # type: Union[bytes, str]
         if isinstance(uri, bytes):
-            uri = self.uri.decode("ascii", errors="replace")
+            uri = uri.decode("ascii", errors="replace")
         return redact_uri(uri)
 
-    def get_method(self):
-        """Gets the method associated with the request (or placeholder if not
-        method has yet been received).
+    def get_method(self) -> str:
+        """Gets the method associated with the request (or placeholder if method
+        has not yet been received).
 
         Note: This is necessary as the placeholder value in twisted is str
         rather than bytes, so we need to sanitise `self.method`.
 
         Returns:
-            str
+            The request method as a string.
         """
-        method = self.method
+        method = self.method  # type: Union[bytes, str]
         if isinstance(method, bytes):
-            method = self.method.decode("ascii")
+            return self.method.decode("ascii")
         return method
 
     def render(self, resrc):
@@ -432,7 +441,9 @@ class SynapseSite(Site):
 
         assert config.http_options is not None
         proxied = config.http_options.x_forwarded
-        self.requestFactory = XForwardedForRequest if proxied else SynapseRequest
+        self.requestFactory = (
+            XForwardedForRequest if proxied else SynapseRequest
+        )  # type: Type[Request]
         self.access_logger = logging.getLogger(logger_name)
         self.server_version_string = server_version_string.encode("ascii")
 
diff --git a/synapse/logging/_remote.py b/synapse/logging/_remote.py
index f8e9112b56..174ca7be5a 100644
--- a/synapse/logging/_remote.py
+++ b/synapse/logging/_remote.py
@@ -32,7 +32,7 @@ from twisted.internet.endpoints import (
     TCP4ClientEndpoint,
     TCP6ClientEndpoint,
 )
-from twisted.internet.interfaces import IPushProducer, ITransport
+from twisted.internet.interfaces import IPushProducer, IStreamClientEndpoint, ITransport
 from twisted.internet.protocol import Factory, Protocol
 from twisted.python.failure import Failure
 
@@ -121,7 +121,9 @@ class RemoteHandler(logging.Handler):
         try:
             ip = ip_address(self.host)
             if isinstance(ip, IPv4Address):
-                endpoint = TCP4ClientEndpoint(_reactor, self.host, self.port)
+                endpoint = TCP4ClientEndpoint(
+                    _reactor, self.host, self.port
+                )  # type: IStreamClientEndpoint
             elif isinstance(ip, IPv6Address):
                 endpoint = TCP6ClientEndpoint(_reactor, self.host, self.port)
             else:
diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py
index a8cb49d5b4..3b499efc07 100644
--- a/synapse/metrics/__init__.py
+++ b/synapse/metrics/__init__.py
@@ -527,7 +527,7 @@ class ReactorLastSeenMetric:
 REGISTRY.register(ReactorLastSeenMetric())
 
 
-def runUntilCurrentTimer(func):
+def runUntilCurrentTimer(reactor, func):
     @functools.wraps(func)
     def f(*args, **kwargs):
         now = reactor.seconds()
@@ -590,13 +590,14 @@ def runUntilCurrentTimer(func):
 
 try:
     # Ensure the reactor has all the attributes we expect
-    reactor.runUntilCurrent
-    reactor._newTimedCalls
-    reactor.threadCallQueue
+    reactor.seconds  # type: ignore
+    reactor.runUntilCurrent  # type: ignore
+    reactor._newTimedCalls  # type: ignore
+    reactor.threadCallQueue  # type: ignore
 
     # runUntilCurrent is called when we have pending calls. It is called once
     # per iteratation after fd polling.
-    reactor.runUntilCurrent = runUntilCurrentTimer(reactor.runUntilCurrent)
+    reactor.runUntilCurrent = runUntilCurrentTimer(reactor, reactor.runUntilCurrent)  # type: ignore
 
     # We manually run the GC each reactor tick so that we can get some metrics
     # about time spent doing GC,
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 2e3b311c4a..db2d400b7e 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Iterable, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Generator, Iterable, Optional, Tuple
 
 from twisted.internet import defer
 
@@ -307,7 +307,7 @@ class ModuleApi:
     @defer.inlineCallbacks
     def get_state_events_in_room(
         self, room_id: str, types: Iterable[Tuple[str, Optional[str]]]
-    ) -> defer.Deferred:
+    ) -> Generator[defer.Deferred, Any, defer.Deferred]:
         """Gets current state events for the given room.
 
         (This is exposed for compatibility with the old SpamCheckerApi. We should
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index f4d7e199e9..eb6de8ba72 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -15,11 +15,12 @@
 # limitations under the License.
 import logging
 import urllib.parse
-from typing import TYPE_CHECKING, Any, Dict, Iterable, Union
+from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Union
 
 from prometheus_client import Counter
 
 from twisted.internet.error import AlreadyCalled, AlreadyCancelled
+from twisted.internet.interfaces import IDelayedCall
 
 from synapse.api.constants import EventTypes
 from synapse.events import EventBase
@@ -71,7 +72,7 @@ class HttpPusher(Pusher):
         self.data = pusher_config.data
         self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
         self.failing_since = pusher_config.failing_since
-        self.timed_call = None
+        self.timed_call = None  # type: Optional[IDelayedCall]
         self._is_processing = False
         self._group_unread_count_by_room = hs.config.push_group_unread_count_by_room
         self._pusherpool = hs.get_pusherpool()
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 2618eb1e53..3455839d67 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -108,9 +108,7 @@ class ReplicationDataHandler:
 
         # Map from stream to list of deferreds waiting for the stream to
         # arrive at a particular position. The lists are sorted by stream position.
-        self._streams_to_waiters = (
-            {}
-        )  # type: Dict[str, List[Tuple[int, Deferred[None]]]]
+        self._streams_to_waiters = {}  # type: Dict[str, List[Tuple[int, Deferred]]]
 
     async def on_rdata(
         self, stream_name: str, instance_name: str, token: int, rows: list
diff --git a/synapse/server.py b/synapse/server.py
index 1d4370e0ba..afd7cd72e7 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -38,6 +38,7 @@ from typing import (
 
 import twisted.internet.base
 import twisted.internet.tcp
+from twisted.internet import defer
 from twisted.mail.smtp import sendmail
 from twisted.web.iweb import IPolicyForHTTPS
 
@@ -403,7 +404,7 @@ class HomeServer(metaclass=abc.ABCMeta):
         return RoomShutdownHandler(self)
 
     @cache_in_self
-    def get_sendmail(self) -> sendmail:
+    def get_sendmail(self) -> Callable[..., defer.Deferred]:
         return sendmail
 
     @cache_in_self
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index 744d8d0941..20af3285bd 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -522,7 +522,9 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
             shorthand=False,
         )
         self.assertEqual(channel.code, 302, channel.result)
-        cas_uri = channel.headers.getRawHeaders("Location")[0]
+        location_headers = channel.headers.getRawHeaders("Location")
+        assert location_headers
+        cas_uri = location_headers[0]
         cas_uri_path, cas_uri_query = cas_uri.split("?", 1)
 
         # it should redirect us to the login page of the cas server
@@ -545,7 +547,9 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
             + "&idp=saml",
         )
         self.assertEqual(channel.code, 302, channel.result)
-        saml_uri = channel.headers.getRawHeaders("Location")[0]
+        location_headers = channel.headers.getRawHeaders("Location")
+        assert location_headers
+        saml_uri = location_headers[0]
         saml_uri_path, saml_uri_query = saml_uri.split("?", 1)
 
         # it should redirect us to the login page of the SAML server
@@ -567,17 +571,21 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
             + "&idp=oidc",
         )
         self.assertEqual(channel.code, 302, channel.result)
-        oidc_uri = channel.headers.getRawHeaders("Location")[0]
+        location_headers = channel.headers.getRawHeaders("Location")
+        assert location_headers
+        oidc_uri = location_headers[0]
         oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)
 
         # it should redirect us to the auth page of the OIDC server
         self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)
 
         # ... and should have set a cookie including the redirect url
-        cookies = dict(
-            h.split(";")[0].split("=", maxsplit=1)
-            for h in channel.headers.getRawHeaders("Set-Cookie")
-        )
+        cookie_headers = channel.headers.getRawHeaders("Set-Cookie")
+        assert cookie_headers
+        cookies = {}  # type: Dict[str, str]
+        for h in cookie_headers:
+            key, value = h.split(";")[0].split("=", maxsplit=1)
+            cookies[key] = value
 
         oidc_session_cookie = cookies["oidc_session"]
         macaroon = pymacaroons.Macaroon.deserialize(oidc_session_cookie)
@@ -590,9 +598,9 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
 
         # that should serve a confirmation page
         self.assertEqual(channel.code, 200, channel.result)
-        self.assertTrue(
-            channel.headers.getRawHeaders("Content-Type")[-1].startswith("text/html")
-        )
+        content_type_headers = channel.headers.getRawHeaders("Content-Type")
+        assert content_type_headers
+        self.assertTrue(content_type_headers[-1].startswith("text/html"))
         p = TestHtmlParser()
         p.feed(channel.text_body)
         p.close()
@@ -806,6 +814,7 @@ class CASTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(channel.code, 302)
         location_headers = channel.headers.getRawHeaders("Location")
+        assert location_headers
         self.assertEqual(location_headers[0][: len(redirect_url)], redirect_url)
 
     @override_config({"sso": {"client_whitelist": ["https://legit-site.com/"]}})
@@ -1248,7 +1257,9 @@ class UsernamePickerTestCase(HomeserverTestCase):
 
         # that should redirect to the username picker
         self.assertEqual(channel.code, 302, channel.result)
-        picker_url = channel.headers.getRawHeaders("Location")[0]
+        location_headers = channel.headers.getRawHeaders("Location")
+        assert location_headers
+        picker_url = location_headers[0]
         self.assertEqual(picker_url, "/_synapse/client/pick_username/account_details")
 
         # ... with a username_mapping_session cookie
@@ -1291,6 +1302,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
         )
         self.assertEqual(chan.code, 302, chan.result)
         location_headers = chan.headers.getRawHeaders("Location")
+        assert location_headers
 
         # send a request to the completion page, which should 302 to the client redirectUrl
         chan = self.make_request(
@@ -1300,6 +1312,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
         )
         self.assertEqual(chan.code, 302, chan.result)
         location_headers = chan.headers.getRawHeaders("Location")
+        assert location_headers
 
         # ensure that the returned location matches the requested redirect URL
         path, query = location_headers[0].split("?", 1)