From b5efcb577e2c9b8b38cb86f87cf65fa93eb2566b Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 26 Mar 2021 16:49:46 +0000 Subject: Make it possible to use dmypy (#9692) Running `dmypy run` will do a `mypy` check while spinning up a daemon that makes rerunning `dmypy run` a lot faster. `dmypy` doesn't support `follow_imports = silent` and has `local_partial_types` enabled, so this PR enables those options and fixes the issues that were newly raised. Note that `local_partial_types` will be enabled by default in upcoming mypy releases. --- changelog.d/9692.misc | 1 + mypy.ini | 3 ++- synapse/api/auth.py | 5 +++++ synapse/config/cache.py | 6 ++++-- synapse/handlers/oidc_handler.py | 3 +++ synapse/logging/opentracing.py | 2 +- synapse/replication/tcp/protocol.py | 2 +- synapse/rest/admin/rooms.py | 3 +++ synapse/rest/admin/users.py | 3 +++ synapse/rest/client/v2_alpha/sync.py | 3 +++ synapse/rest/media/v1/preview_url_resource.py | 2 ++ synapse/rest/synapse/client/pick_username.py | 3 +++ synapse/util/caches/__init__.py | 4 ++-- tests/replication/tcp/streams/test_typing.py | 1 + tests/replication/test_multi_media_repo.py | 4 ++-- tests/server.py | 28 +++++++++++++++++++-------- 16 files changed, 56 insertions(+), 17 deletions(-) create mode 100644 changelog.d/9692.misc diff --git a/changelog.d/9692.misc b/changelog.d/9692.misc new file mode 100644 index 0000000000..d02002586e --- /dev/null +++ b/changelog.d/9692.misc @@ -0,0 +1 @@ +Make it possible to use `dmypy`. diff --git a/mypy.ini b/mypy.ini index 709a8d07a5..3ae5d45787 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,12 +1,13 @@ [mypy] namespace_packages = True plugins = mypy_zope:plugin, scripts-dev/mypy_synapse_plugin.py -follow_imports = silent +follow_imports = normal check_untyped_defs = True show_error_codes = True show_traceback = True mypy_path = stubs warn_unreachable = True +local_partial_types = True # To find all folders that pass mypy you run: # diff --git a/synapse/api/auth.py b/synapse/api/auth.py index e10e33fd23..7d9930ae7b 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -558,6 +558,9 @@ class Auth: Returns: bool: False if no access_token was given, True otherwise. """ + # This will always be set by the time Twisted calls us. + assert request.args is not None + query_params = request.args.get(b"access_token") auth_headers = request.requestHeaders.getRawHeaders(b"Authorization") return bool(query_params) or bool(auth_headers) @@ -574,6 +577,8 @@ class Auth: MissingClientTokenError: If there isn't a single access_token in the request """ + # This will always be set by the time Twisted calls us. + assert request.args is not None auth_headers = request.requestHeaders.getRawHeaders(b"Authorization") query_params = request.args.get(b"access_token") diff --git a/synapse/config/cache.py b/synapse/config/cache.py index 8e03f14005..4e8abbf88a 100644 --- a/synapse/config/cache.py +++ b/synapse/config/cache.py @@ -24,7 +24,7 @@ from ._base import Config, ConfigError _CACHE_PREFIX = "SYNAPSE_CACHE_FACTOR" # Map from canonicalised cache name to cache. -_CACHES = {} +_CACHES = {} # type: Dict[str, Callable[[float], None]] # a lock on the contents of _CACHES _CACHES_LOCK = threading.Lock() @@ -59,7 +59,9 @@ def _canonicalise_cache_name(cache_name: str) -> str: return cache_name.lower() -def add_resizable_cache(cache_name: str, cache_resize_callback: Callable): +def add_resizable_cache( + cache_name: str, cache_resize_callback: Callable[[float], None] +): """Register a cache that's size can dynamically change Args: diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index bc3630e9e9..6624212d6f 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -149,6 +149,9 @@ class OidcHandler: Args: request: the incoming request from the browser. """ + # This will always be set by the time Twisted calls us. + assert request.args is not None + # The provider might redirect with an error. # In that case, just display it as-is. if b"error" in request.args: diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index 10bd4a1461..c6e6335740 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -262,7 +262,7 @@ logger = logging.getLogger(__name__) # Block everything by default # A regex which matches the server_names to expose traces for. # None means 'block everything'. -_homeserver_whitelist = None +_homeserver_whitelist = None # type: Optional[re.Pattern[str]] # Util methods diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index 825900f64c..e829add257 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -104,7 +104,7 @@ tcp_outbound_commands_counter = Counter( # A list of all connected protocols. This allows us to send metrics about the # connections. -connected_connections = [] +connected_connections = [] # type: List[BaseReplicationStreamProtocol] logger = logging.getLogger(__name__) diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 263d8ec076..cfe1bebb91 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -390,6 +390,9 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet): async def on_POST( self, request: SynapseRequest, room_identifier: str ) -> Tuple[int, JsonDict]: + # This will always be set by the time Twisted calls us. + assert request.args is not None + requester = await self.auth.get_user_by_req(request) await assert_user_is_admin(self.auth, requester.user) diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index aaa56a7024..309bd2771b 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -833,6 +833,9 @@ class UserMediaRestServlet(RestServlet): async def on_GET( self, request: SynapseRequest, user_id: str ) -> Tuple[int, JsonDict]: + # This will always be set by the time Twisted calls us. + assert request.args is not None + await assert_requester_is_admin(self.auth, request) if not self.is_mine(UserID.from_string(user_id)): diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index a0db0a054b..3481770c83 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -91,6 +91,9 @@ class SyncRestServlet(RestServlet): self._event_serializer = hs.get_event_client_serializer() async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + # This will always be set by the time Twisted calls us. + assert request.args is not None + if b"from" in request.args: # /events used to use 'from', but /sync uses 'since'. # Lets be helpful and whine if we see a 'from'. diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index e590a0deab..c4ed9dfdb4 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -187,6 +187,8 @@ class PreviewUrlResource(DirectServeJsonResource): respond_with_json(request, 200, {}, send_cors=True) async def _async_render_GET(self, request: SynapseRequest) -> None: + # This will always be set by the time Twisted calls us. + assert request.args is not None # XXX: if get_user_by_req fails, what should we do in an async render? requester = await self.auth.get_user_by_req(request) diff --git a/synapse/rest/synapse/client/pick_username.py b/synapse/rest/synapse/client/pick_username.py index 51acaa9a92..d9ffe84489 100644 --- a/synapse/rest/synapse/client/pick_username.py +++ b/synapse/rest/synapse/client/pick_username.py @@ -104,6 +104,9 @@ class AccountDetailsResource(DirectServeHtmlResource): respond_with_html(request, 200, html) async def _async_render_POST(self, request: SynapseRequest): + # This will always be set by the time Twisted calls us. + assert request.args is not None + try: session_id = get_username_mapping_session_cookie_from_request(request) except SynapseError as e: diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index f968706334..48f64eeb38 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -25,8 +25,8 @@ from synapse.config.cache import add_resizable_cache logger = logging.getLogger(__name__) -caches_by_name = {} -collectors_by_name = {} # type: Dict +caches_by_name = {} # type: Dict[str, Sized] +collectors_by_name = {} # type: Dict[str, CacheMetric] cache_size = Gauge("synapse_util_caches_cache:size", "", ["name"]) cache_hits = Gauge("synapse_util_caches_cache:hits", "", ["name"]) diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py index 5acfb3e53e..ca49d4dd3a 100644 --- a/tests/replication/tcp/streams/test_typing.py +++ b/tests/replication/tcp/streams/test_typing.py @@ -69,6 +69,7 @@ class TypingStreamTestCase(BaseStreamTestCase): self.assert_request_is_get_repl_stream_updates(request, "typing") # The from token should be the token from the last RDATA we got. + assert request.args is not None self.assertEqual(int(request.args[b"from_token"][0]), token) self.test_handler.on_rdata.assert_called_once() diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py index 7ff11cde10..b0800f9840 100644 --- a/tests/replication/test_multi_media_repo.py +++ b/tests/replication/test_multi_media_repo.py @@ -15,7 +15,7 @@ import logging import os from binascii import unhexlify -from typing import Tuple +from typing import Optional, Tuple from twisted.internet.protocol import Factory from twisted.protocols.tls import TLSMemoryBIOFactory @@ -32,7 +32,7 @@ from tests.server import FakeChannel, FakeSite, FakeTransport, make_request logger = logging.getLogger(__name__) -test_server_connection_factory = None +test_server_connection_factory = None # type: Optional[TestServerTLSConnectionFactory] class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase): diff --git a/tests/server.py b/tests/server.py index 57cc4ac605..b535a5d886 100644 --- a/tests/server.py +++ b/tests/server.py @@ -2,7 +2,7 @@ import json import logging from collections import deque from io import SEEK_END, BytesIO -from typing import Callable, Iterable, MutableMapping, Optional, Tuple, Union +from typing import Callable, Dict, Iterable, MutableMapping, Optional, Tuple, Union import attr from typing_extensions import Deque @@ -13,8 +13,11 @@ from twisted.internet._resolver import SimpleResolverComplexifier from twisted.internet.defer import Deferred, fail, succeed from twisted.internet.error import DNSLookupError from twisted.internet.interfaces import ( + IHostnameResolver, + IProtocol, + IPullProducer, + IPushProducer, IReactorPluggableNameResolver, - IReactorTCP, IResolverSimple, ITransport, ) @@ -45,11 +48,11 @@ class FakeChannel: wire). """ - site = attr.ib(type=Site) + site = attr.ib(type=Union[Site, "FakeSite"]) _reactor = attr.ib() result = attr.ib(type=dict, default=attr.Factory(dict)) _ip = attr.ib(type=str, default="127.0.0.1") - _producer = None + _producer = None # type: Optional[Union[IPullProducer, IPushProducer]] @property def json_body(self): @@ -159,7 +162,11 @@ class FakeChannel: Any cookines found are added to the given dict """ - for h in self.headers.getRawHeaders("Set-Cookie"): + headers = self.headers.getRawHeaders("Set-Cookie") + if not headers: + return + + for h in headers: parts = h.split(";") k, v = parts[0].split("=", maxsplit=1) cookies[k] = v @@ -311,8 +318,8 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): self._tcp_callbacks = {} self._udp = [] - lookups = self.lookups = {} - self._thread_callbacks = deque() # type: Deque[Callable[[], None]]() + lookups = self.lookups = {} # type: Dict[str, str] + self._thread_callbacks = deque() # type: Deque[Callable[[], None]] @implementer(IResolverSimple) class FakeResolver: @@ -324,6 +331,9 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): self.nameResolver = SimpleResolverComplexifier(FakeResolver()) super().__init__() + def installNameResolver(self, resolver: IHostnameResolver) -> IHostnameResolver: + raise NotImplementedError() + def listenUDP(self, port, protocol, interface="", maxPacketSize=8196): p = udp.Port(port, protocol, interface, maxPacketSize, self) p.startListening() @@ -621,7 +631,9 @@ class FakeTransport: self.disconnected = True -def connect_client(reactor: IReactorTCP, client_id: int) -> AccumulatingProtocol: +def connect_client( + reactor: ThreadedMemoryReactorClock, client_id: int +) -> Tuple[IProtocol, AccumulatingProtocol]: """ Connect a client to a fake TCP transport. -- cgit 1.4.1