summary refs log tree commit diff
path: root/synapse/replication/http/_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication/http/_base.py')
-rw-r--r--synapse/replication/http/_base.py31
1 files changed, 20 insertions, 11 deletions
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index 585332b244..bc1d28dd19 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -15,16 +15,20 @@
 import abc
 import logging
 import re
-import urllib
+import urllib.parse
 from inspect import signature
 from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Tuple
 
 from prometheus_client import Counter, Gauge
 
+from twisted.web.server import Request
+
 from synapse.api.errors import HttpResponseException, SynapseError
 from synapse.http import RequestTimedOutError
+from synapse.http.server import HttpServer
 from synapse.logging import opentracing
 from synapse.logging.opentracing import trace
+from synapse.types import JsonDict
 from synapse.util.caches.response_cache import ResponseCache
 from synapse.util.stringutils import random_string
 
@@ -113,10 +117,12 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
         if hs.config.worker.worker_replication_secret:
             self._replication_secret = hs.config.worker.worker_replication_secret
 
-    def _check_auth(self, request) -> None:
+    def _check_auth(self, request: Request) -> None:
         # Get the authorization header.
         auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
 
+        if not auth_headers:
+            raise RuntimeError("Missing Authorization header.")
         if len(auth_headers) > 1:
             raise RuntimeError("Too many Authorization headers.")
         parts = auth_headers[0].split(b" ")
@@ -129,7 +135,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
         raise RuntimeError("Invalid Authorization header.")
 
     @abc.abstractmethod
-    async def _serialize_payload(**kwargs):
+    async def _serialize_payload(**kwargs) -> JsonDict:
         """Static method that is called when creating a request.
 
         Concrete implementations should have explicit parameters (rather than
@@ -144,19 +150,20 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
         return {}
 
     @abc.abstractmethod
-    async def _handle_request(self, request, **kwargs):
+    async def _handle_request(
+        self, request: Request, **kwargs: Any
+    ) -> Tuple[int, JsonDict]:
         """Handle incoming request.
 
         This is called with the request object and PATH_ARGS.
 
         Returns:
-            tuple[int, dict]: HTTP status code and a JSON serialisable dict
-            to be used as response body of request.
+            HTTP status code and a JSON serialisable dict to be used as response
+            body of request.
         """
-        pass
 
     @classmethod
-    def make_client(cls, hs: "HomeServer"):
+    def make_client(cls, hs: "HomeServer") -> Callable:
         """Create a client that makes requests.
 
         Returns a callable that accepts the same parameters as
@@ -182,7 +189,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
             )
 
         @trace(opname="outgoing_replication_request")
-        async def send_request(*, instance_name="master", **kwargs):
+        async def send_request(*, instance_name: str = "master", **kwargs: Any) -> Any:
             with outgoing_gauge.track_inprogress():
                 if instance_name == local_instance_name:
                     raise Exception("Trying to send HTTP request to self")
@@ -268,7 +275,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
 
         return send_request
 
-    def register(self, http_server):
+    def register(self, http_server: HttpServer) -> None:
         """Called by the server to register this as a handler to the
         appropriate path.
         """
@@ -289,7 +296,9 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
             self.__class__.__name__,
         )
 
-    async def _check_auth_and_handle(self, request, **kwargs):
+    async def _check_auth_and_handle(
+        self, request: Request, **kwargs: Any
+    ) -> Tuple[int, JsonDict]:
         """Called on new incoming requests when caching is enabled. Checks
         if there is a cached response for the request and returns that,
         otherwise calls `_handle_request` and caches its response.