diff options
-rw-r--r-- | mypy.ini | 3 | ||||
-rw-r--r-- | synapse/handlers/register.py | 13 | ||||
-rw-r--r-- | synapse/replication/http/_base.py | 32 |
3 files changed, 37 insertions, 11 deletions
diff --git a/mypy.ini b/mypy.ini index c63b01fe99..8023ccc3e8 100644 --- a/mypy.ini +++ b/mypy.ini @@ -99,6 +99,9 @@ disallow_untyped_defs = True [mypy-synapse.rest.*] disallow_untyped_defs = True +[mypy-synapse.replication.http._base] +disallow_untyped_defs = True + [mypy-synapse.state.*] disallow_untyped_defs = True diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 441af7a848..5c08961ce5 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -15,7 +15,17 @@ """Contains functions for registering clients.""" import logging -from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Iterable, + List, + Mapping, + Optional, + Tuple, +) from prometheus_client import Counter from typing_extensions import TypedDict @@ -103,6 +113,7 @@ class RegistrationHandler(BaseHandler): self.spam_checker = hs.get_spam_checker() + self._register_device_client: Callable[..., Awaitable[Mapping[str, Any]]] if hs.config.worker.worker_app: self._register_client = ReplicationRegisterServlet.make_client(hs) self._register_device_client = RegisterDeviceReplicationServlet.make_client( diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index f1b78d09f9..684bd009cb 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -17,14 +17,18 @@ import logging import re import urllib from inspect import signature -from typing import TYPE_CHECKING, Dict, List, Tuple +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Tuple from prometheus_client import Counter, Gauge +from twisted.web.http import Request + from synapse.api.errors import HttpResponseException, SynapseError from synapse.http import RequestTimedOutError +from synapse.http.server import HttpServer from synapse.logging import opentracing from synapse.logging.opentracing import trace +from synapse.types import JsonDict from synapse.util.caches.response_cache import ResponseCache from synapse.util.stringutils import random_string @@ -113,10 +117,11 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): if hs.config.worker.worker_replication_secret: self._replication_secret = hs.config.worker.worker_replication_secret - def _check_auth(self, request) -> None: + def _check_auth(self, request: Request) -> None: # Get the authorization header. auth_headers = request.requestHeaders.getRawHeaders(b"Authorization") - + if auth_headers is None: + raise RuntimeError("No Authorization header.") if len(auth_headers) > 1: raise RuntimeError("Too many Authorization headers.") parts = auth_headers[0].split(b" ") @@ -129,7 +134,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): raise RuntimeError("Invalid Authorization header.") @abc.abstractmethod - async def _serialize_payload(**kwargs): + async def _serialize_payload(**kwargs: str) -> Dict[str, Any]: """Static method that is called when creating a request. Concrete implementations should have explicit parameters (rather than @@ -144,7 +149,9 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): return {} @abc.abstractmethod - async def _handle_request(self, request, **kwargs): + async def _handle_request( + self, request: Request, **kwargs: str + ) -> Tuple[int, JsonDict]: """Handle incoming request. This is called with the request object and PATH_ARGS. @@ -156,7 +163,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): pass @classmethod - def make_client(cls, hs): + def make_client(cls, hs: HomeServer) -> Callable[..., Awaitable[JsonDict]]: """Create a client that makes requests. Returns a callable that accepts the same parameters as @@ -183,7 +190,9 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): @trace(opname="outgoing_replication_request") @outgoing_gauge.track_inprogress() - async def send_request(*, instance_name="master", **kwargs): + async def send_request( + *, instance_name: str = "master", **kwargs: str + ) -> JsonDict: if instance_name == local_instance_name: raise Exception("Trying to send HTTP request to self") if instance_name == "master": @@ -207,6 +216,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): txn_id = random_string(10) url_args.append(txn_id) + request_func: Callable[..., Awaitable[JsonDict]] if cls.METHOD == "POST": request_func = client.post_json_get_json elif cls.METHOD == "PUT": @@ -264,7 +274,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): return send_request - def register(self, http_server): + def register(self, http_server: HttpServer) -> None: """Called by the server to register this as a handler to the appropriate path. """ @@ -285,7 +295,9 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): self.__class__.__name__, ) - async def _check_auth_and_handle(self, request, **kwargs): + async def _check_auth_and_handle( + self, request: Request, **kwargs: str + ) -> Tuple[int, JsonDict]: """Called on new incoming requests when caching is enabled. Checks if there is a cached response for the request and returns that, otherwise calls `_handle_request` and caches its response. @@ -301,7 +313,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): txn_id = kwargs.pop("txn_id") return await self.response_cache.wrap( - txn_id, self._handle_request, request, **kwargs + txn_id, self._handle_request, request, cache_context=False, **kwargs ) return await self._handle_request(request, **kwargs) |