diff --git a/mypy.ini b/mypy.ini
index c63b01fe99..8023ccc3e8 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -99,6 +99,9 @@ disallow_untyped_defs = True
[mypy-synapse.rest.*]
disallow_untyped_defs = True
+[mypy-synapse.replication.http._base]
+disallow_untyped_defs = True
+
[mypy-synapse.state.*]
disallow_untyped_defs = True
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 441af7a848..5c08961ce5 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -15,7 +15,17 @@
"""Contains functions for registering clients."""
import logging
-from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Awaitable,
+ Callable,
+ Iterable,
+ List,
+ Mapping,
+ Optional,
+ Tuple,
+)
from prometheus_client import Counter
from typing_extensions import TypedDict
@@ -103,6 +113,7 @@ class RegistrationHandler(BaseHandler):
self.spam_checker = hs.get_spam_checker()
+ self._register_device_client: Callable[..., Awaitable[Mapping[str, Any]]]
if hs.config.worker.worker_app:
self._register_client = ReplicationRegisterServlet.make_client(hs)
self._register_device_client = RegisterDeviceReplicationServlet.make_client(
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index f1b78d09f9..684bd009cb 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -17,14 +17,18 @@ import logging
import re
import urllib
from inspect import signature
-from typing import TYPE_CHECKING, Dict, List, Tuple
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Tuple
from prometheus_client import Counter, Gauge
+from twisted.web.http import Request
+
from synapse.api.errors import HttpResponseException, SynapseError
from synapse.http import RequestTimedOutError
+from synapse.http.server import HttpServer
from synapse.logging import opentracing
from synapse.logging.opentracing import trace
+from synapse.types import JsonDict
from synapse.util.caches.response_cache import ResponseCache
from synapse.util.stringutils import random_string
@@ -113,10 +117,11 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
if hs.config.worker.worker_replication_secret:
self._replication_secret = hs.config.worker.worker_replication_secret
- def _check_auth(self, request) -> None:
+ def _check_auth(self, request: Request) -> None:
# Get the authorization header.
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
-
+ if auth_headers is None:
+ raise RuntimeError("No Authorization header.")
if len(auth_headers) > 1:
raise RuntimeError("Too many Authorization headers.")
parts = auth_headers[0].split(b" ")
@@ -129,7 +134,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
raise RuntimeError("Invalid Authorization header.")
@abc.abstractmethod
- async def _serialize_payload(**kwargs):
+ async def _serialize_payload(**kwargs: str) -> Dict[str, Any]:
"""Static method that is called when creating a request.
Concrete implementations should have explicit parameters (rather than
@@ -144,7 +149,9 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
return {}
@abc.abstractmethod
- async def _handle_request(self, request, **kwargs):
+ async def _handle_request(
+ self, request: Request, **kwargs: str
+ ) -> Tuple[int, JsonDict]:
"""Handle incoming request.
This is called with the request object and PATH_ARGS.
@@ -156,7 +163,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
pass
@classmethod
- def make_client(cls, hs):
+ def make_client(cls, hs: HomeServer) -> Callable[..., Awaitable[JsonDict]]:
"""Create a client that makes requests.
Returns a callable that accepts the same parameters as
@@ -183,7 +190,9 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
@trace(opname="outgoing_replication_request")
@outgoing_gauge.track_inprogress()
- async def send_request(*, instance_name="master", **kwargs):
+ async def send_request(
+ *, instance_name: str = "master", **kwargs: str
+ ) -> JsonDict:
if instance_name == local_instance_name:
raise Exception("Trying to send HTTP request to self")
if instance_name == "master":
@@ -207,6 +216,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
txn_id = random_string(10)
url_args.append(txn_id)
+ request_func: Callable[..., Awaitable[JsonDict]]
if cls.METHOD == "POST":
request_func = client.post_json_get_json
elif cls.METHOD == "PUT":
@@ -264,7 +274,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
return send_request
- def register(self, http_server):
+ def register(self, http_server: HttpServer) -> None:
"""Called by the server to register this as a handler to the
appropriate path.
"""
@@ -285,7 +295,9 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
self.__class__.__name__,
)
- async def _check_auth_and_handle(self, request, **kwargs):
+ async def _check_auth_and_handle(
+ self, request: Request, **kwargs: str
+ ) -> Tuple[int, JsonDict]:
"""Called on new incoming requests when caching is enabled. Checks
if there is a cached response for the request and returns that,
otherwise calls `_handle_request` and caches its response.
@@ -301,7 +313,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
txn_id = kwargs.pop("txn_id")
return await self.response_cache.wrap(
- txn_id, self._handle_request, request, **kwargs
+ txn_id, self._handle_request, request, cache_context=False, **kwargs
)
return await self._handle_request(request, **kwargs)
|