summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--mypy.ini3
-rw-r--r--synapse/handlers/register.py13
-rw-r--r--synapse/replication/http/_base.py32
3 files changed, 37 insertions, 11 deletions
diff --git a/mypy.ini b/mypy.ini
index c63b01fe99..8023ccc3e8 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -99,6 +99,9 @@ disallow_untyped_defs = True
 [mypy-synapse.rest.*]
 disallow_untyped_defs = True
 
+[mypy-synapse.replication.http._base]
+disallow_untyped_defs = True
+
 [mypy-synapse.state.*]
 disallow_untyped_defs = True
 
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 441af7a848..5c08961ce5 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -15,7 +15,17 @@
 """Contains functions for registering clients."""
 
 import logging
-from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Awaitable,
+    Callable,
+    Iterable,
+    List,
+    Mapping,
+    Optional,
+    Tuple,
+)
 
 from prometheus_client import Counter
 from typing_extensions import TypedDict
@@ -103,6 +113,7 @@ class RegistrationHandler(BaseHandler):
 
         self.spam_checker = hs.get_spam_checker()
 
+        self._register_device_client: Callable[..., Awaitable[Mapping[str, Any]]]
         if hs.config.worker.worker_app:
             self._register_client = ReplicationRegisterServlet.make_client(hs)
             self._register_device_client = RegisterDeviceReplicationServlet.make_client(
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index f1b78d09f9..684bd009cb 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -17,14 +17,18 @@ import logging
 import re
 import urllib
 from inspect import signature
-from typing import TYPE_CHECKING, Dict, List, Tuple
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Tuple
 
 from prometheus_client import Counter, Gauge
 
+from twisted.web.http import Request
+
 from synapse.api.errors import HttpResponseException, SynapseError
 from synapse.http import RequestTimedOutError
+from synapse.http.server import HttpServer
 from synapse.logging import opentracing
 from synapse.logging.opentracing import trace
+from synapse.types import JsonDict
 from synapse.util.caches.response_cache import ResponseCache
 from synapse.util.stringutils import random_string
 
@@ -113,10 +117,11 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
         if hs.config.worker.worker_replication_secret:
             self._replication_secret = hs.config.worker.worker_replication_secret
 
-    def _check_auth(self, request) -> None:
+    def _check_auth(self, request: Request) -> None:
         # Get the authorization header.
         auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
-
+        if auth_headers is None:
+            raise RuntimeError("No Authorization header.")
         if len(auth_headers) > 1:
             raise RuntimeError("Too many Authorization headers.")
         parts = auth_headers[0].split(b" ")
@@ -129,7 +134,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
         raise RuntimeError("Invalid Authorization header.")
 
     @abc.abstractmethod
-    async def _serialize_payload(**kwargs):
+    async def _serialize_payload(**kwargs: str) -> Dict[str, Any]:
         """Static method that is called when creating a request.
 
         Concrete implementations should have explicit parameters (rather than
@@ -144,7 +149,9 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
         return {}
 
     @abc.abstractmethod
-    async def _handle_request(self, request, **kwargs):
+    async def _handle_request(
+        self, request: Request, **kwargs: str
+    ) -> Tuple[int, JsonDict]:
         """Handle incoming request.
 
         This is called with the request object and PATH_ARGS.
@@ -156,7 +163,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
         pass
 
     @classmethod
-    def make_client(cls, hs):
+    def make_client(cls, hs: HomeServer) -> Callable[..., Awaitable[JsonDict]]:
         """Create a client that makes requests.
 
         Returns a callable that accepts the same parameters as
@@ -183,7 +190,9 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
 
         @trace(opname="outgoing_replication_request")
         @outgoing_gauge.track_inprogress()
-        async def send_request(*, instance_name="master", **kwargs):
+        async def send_request(
+            *, instance_name: str = "master", **kwargs: str
+        ) -> JsonDict:
             if instance_name == local_instance_name:
                 raise Exception("Trying to send HTTP request to self")
             if instance_name == "master":
@@ -207,6 +216,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
                 txn_id = random_string(10)
                 url_args.append(txn_id)
 
+            request_func: Callable[..., Awaitable[JsonDict]]
             if cls.METHOD == "POST":
                 request_func = client.post_json_get_json
             elif cls.METHOD == "PUT":
@@ -264,7 +274,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
 
         return send_request
 
-    def register(self, http_server):
+    def register(self, http_server: HttpServer) -> None:
         """Called by the server to register this as a handler to the
         appropriate path.
         """
@@ -285,7 +295,9 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
             self.__class__.__name__,
         )
 
-    async def _check_auth_and_handle(self, request, **kwargs):
+    async def _check_auth_and_handle(
+        self, request: Request, **kwargs: str
+    ) -> Tuple[int, JsonDict]:
         """Called on new incoming requests when caching is enabled. Checks
         if there is a cached response for the request and returns that,
         otherwise calls `_handle_request` and caches its response.
@@ -301,7 +313,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
             txn_id = kwargs.pop("txn_id")
 
             return await self.response_cache.wrap(
-                txn_id, self._handle_request, request, **kwargs
+                txn_id, self._handle_request, request, cache_context=False, **kwargs
             )
 
         return await self._handle_request(request, **kwargs)