summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/9591.misc1
-rw-r--r--synapse/api/auth.py2
-rw-r--r--synapse/federation/federation_server.py8
-rw-r--r--synapse/handlers/oidc_handler.py9
-rw-r--r--synapse/http/client.py9
-rw-r--r--synapse/logging/_remote.py23
-rw-r--r--synapse/push/emailpusher.py4
-rw-r--r--synapse/replication/tcp/handler.py44
-rw-r--r--synapse/replication/tcp/protocol.py24
-rw-r--r--synapse/replication/tcp/redis.py8
-rw-r--r--synapse/rest/admin/_base.py15
-rw-r--r--synapse/rest/admin/media.py29
-rw-r--r--synapse/rest/client/v2_alpha/groups.py105
-rw-r--r--synapse/rest/media/v1/config_resource.py3
-rw-r--r--synapse/rest/media/v1/preview_url_resource.py3
-rw-r--r--synapse/rest/media/v1/upload_resource.py3
-rw-r--r--synapse/server.py8
-rw-r--r--tests/replication/test_federation_ack.py8
18 files changed, 187 insertions, 119 deletions
diff --git a/changelog.d/9591.misc b/changelog.d/9591.misc
new file mode 100644
index 0000000000..14c7b78dd9
--- /dev/null
+++ b/changelog.d/9591.misc
@@ -0,0 +1 @@
+Fix incorrect type hints.
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 968cf6f174..e10e33fd23 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -164,7 +164,7 @@ class Auth:
 
     async def get_user_by_req(
         self,
-        request: Request,
+        request: SynapseRequest,
         allow_guest: bool = False,
         rights: str = "access",
         allow_expired: bool = False,
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index db6e49dbca..9839d3d016 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -880,7 +880,9 @@ class FederationHandlerRegistry:
         self.edu_handlers = (
             {}
         )  # type: Dict[str, Callable[[str, dict], Awaitable[None]]]
-        self.query_handlers = {}  # type: Dict[str, Callable[[dict], Awaitable[None]]]
+        self.query_handlers = (
+            {}
+        )  # type: Dict[str, Callable[[dict], Awaitable[JsonDict]]]
 
         # Map from type to instance names that we should route EDU handling to.
         # We randomly choose one instance from the list to route to for each new
@@ -914,7 +916,7 @@ class FederationHandlerRegistry:
         self.edu_handlers[edu_type] = handler
 
     def register_query_handler(
-        self, query_type: str, handler: Callable[[dict], defer.Deferred]
+        self, query_type: str, handler: Callable[[dict], Awaitable[JsonDict]]
     ):
         """Sets the handler callable that will be used to handle an incoming
         federation query of the given type.
@@ -987,7 +989,7 @@ class FederationHandlerRegistry:
         # Oh well, let's just log and move on.
         logger.warning("No handler registered for EDU type %s", edu_type)
 
-    async def on_query(self, query_type: str, args: dict):
+    async def on_query(self, query_type: str, args: dict) -> JsonDict:
         handler = self.query_handlers.get(query_type)
         if handler:
             return await handler(args)
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 825fadb76f..f5d1821127 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -34,6 +34,7 @@ from pymacaroons.exceptions import (
 from typing_extensions import TypedDict
 
 from twisted.web.client import readBody
+from twisted.web.http_headers import Headers
 
 from synapse.config import ConfigError
 from synapse.config.oidc_config import (
@@ -538,7 +539,7 @@ class OidcProvider:
         """
         metadata = await self.load_metadata()
         token_endpoint = metadata.get("token_endpoint")
-        headers = {
+        raw_headers = {
             "Content-Type": "application/x-www-form-urlencoded",
             "User-Agent": self._http_client.user_agent,
             "Accept": "application/json",
@@ -552,10 +553,10 @@ class OidcProvider:
         body = urlencode(args, True)
 
         # Fill the body/headers with credentials
-        uri, headers, body = self._client_auth.prepare(
-            method="POST", uri=token_endpoint, headers=headers, body=body
+        uri, raw_headers, body = self._client_auth.prepare(
+            method="POST", uri=token_endpoint, headers=raw_headers, body=body
         )
-        headers = {k: [v] for (k, v) in headers.items()}
+        headers = Headers({k: [v] for (k, v) in raw_headers.items()})
 
         # Do the actual request
         # We're not using the SimpleHttpClient util methods as we don't want to
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 8f3da486b3..d4ab3a2732 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -57,7 +57,13 @@ from twisted.web.client import (
 )
 from twisted.web.http import PotentialDataLoss
 from twisted.web.http_headers import Headers
-from twisted.web.iweb import UNKNOWN_LENGTH, IAgent, IBodyProducer, IResponse
+from twisted.web.iweb import (
+    UNKNOWN_LENGTH,
+    IAgent,
+    IBodyProducer,
+    IPolicyForHTTPS,
+    IResponse,
+)
 
 from synapse.api.errors import Codes, HttpResponseException, SynapseError
 from synapse.http import QuieterFileBodyProducer, RequestTimedOutError, redact_uri
@@ -870,6 +876,7 @@ def encode_query_args(args: Optional[Mapping[str, Union[str, List[str]]]]) -> by
     return query_str.encode("utf8")
 
 
+@implementer(IPolicyForHTTPS)
 class InsecureInterceptableContextFactory(ssl.ContextFactory):
     """
     Factory for PyOpenSSL SSL contexts which accepts any certificate for any domain.
diff --git a/synapse/logging/_remote.py b/synapse/logging/_remote.py
index 174ca7be5a..643492ceaf 100644
--- a/synapse/logging/_remote.py
+++ b/synapse/logging/_remote.py
@@ -32,8 +32,9 @@ from twisted.internet.endpoints import (
     TCP4ClientEndpoint,
     TCP6ClientEndpoint,
 )
-from twisted.internet.interfaces import IPushProducer, IStreamClientEndpoint, ITransport
+from twisted.internet.interfaces import IPushProducer, IStreamClientEndpoint
 from twisted.internet.protocol import Factory, Protocol
+from twisted.internet.tcp import Connection
 from twisted.python.failure import Failure
 
 logger = logging.getLogger(__name__)
@@ -52,7 +53,9 @@ class LogProducer:
         format: A callable to format the log record to a string.
     """
 
-    transport = attr.ib(type=ITransport)
+    # This is essentially ITCPTransport, but that is missing certain fields
+    # (connected and registerProducer) which are part of the implementation.
+    transport = attr.ib(type=Connection)
     _format = attr.ib(type=Callable[[logging.LogRecord], str])
     _buffer = attr.ib(type=deque)
     _paused = attr.ib(default=False, type=bool, init=False)
@@ -149,8 +152,6 @@ class RemoteHandler(logging.Handler):
         if self._connection_waiter:
             return
 
-        self._connection_waiter = self._service.whenConnected(failAfterFailures=1)
-
         def fail(failure: Failure) -> None:
             # If the Deferred was cancelled (e.g. during shutdown) do not try to
             # reconnect (this will cause an infinite loop of errors).
@@ -163,9 +164,13 @@ class RemoteHandler(logging.Handler):
             self._connect()
 
         def writer(result: Protocol) -> None:
+            # Force recognising transport as a Connection and not the more
+            # generic ITransport.
+            transport = result.transport  # type: Connection  # type: ignore
+
             # We have a connection. If we already have a producer, and its
             # transport is the same, just trigger a resumeProducing.
-            if self._producer and result.transport is self._producer.transport:
+            if self._producer and transport is self._producer.transport:
                 self._producer.resumeProducing()
                 self._connection_waiter = None
                 return
@@ -177,14 +182,16 @@ class RemoteHandler(logging.Handler):
             # Make a new producer and start it.
             self._producer = LogProducer(
                 buffer=self._buffer,
-                transport=result.transport,
+                transport=transport,
                 format=self.format,
             )
-            result.transport.registerProducer(self._producer, True)
+            transport.registerProducer(self._producer, True)
             self._producer.resumeProducing()
             self._connection_waiter = None
 
-        self._connection_waiter.addCallbacks(writer, fail)
+        deferred = self._service.whenConnected(failAfterFailures=1)  # type: Deferred
+        deferred.addCallbacks(writer, fail)
+        self._connection_waiter = deferred
 
     def _handle_pressure(self) -> None:
         """
diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py
index 5fec2aaf5d..3dc06a79e8 100644
--- a/synapse/push/emailpusher.py
+++ b/synapse/push/emailpusher.py
@@ -16,8 +16,8 @@
 import logging
 from typing import TYPE_CHECKING, Dict, List, Optional
 
-from twisted.internet.base import DelayedCall
 from twisted.internet.error import AlreadyCalled, AlreadyCancelled
+from twisted.internet.interfaces import IDelayedCall
 
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.push import Pusher, PusherConfig, ThrottleParams
@@ -66,7 +66,7 @@ class EmailPusher(Pusher):
 
         self.store = self.hs.get_datastore()
         self.email = pusher_config.pushkey
-        self.timed_call = None  # type: Optional[DelayedCall]
+        self.timed_call = None  # type: Optional[IDelayedCall]
         self.throttle_params = {}  # type: Dict[str, ThrottleParams]
         self._inited = False
 
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index a7245da152..ee909f3fc5 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -48,7 +48,7 @@ from synapse.replication.tcp.commands import (
     UserIpCommand,
     UserSyncCommand,
 )
-from synapse.replication.tcp.protocol import AbstractConnection
+from synapse.replication.tcp.protocol import IReplicationConnection
 from synapse.replication.tcp.streams import (
     STREAMS_MAP,
     AccountDataStream,
@@ -82,7 +82,7 @@ user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache"
 
 # the type of the entries in _command_queues_by_stream
 _StreamCommandQueue = Deque[
-    Tuple[Union[RdataCommand, PositionCommand], AbstractConnection]
+    Tuple[Union[RdataCommand, PositionCommand], IReplicationConnection]
 ]
 
 
@@ -174,7 +174,7 @@ class ReplicationCommandHandler:
 
         # The currently connected connections. (The list of places we need to send
         # outgoing replication commands to.)
-        self._connections = []  # type: List[AbstractConnection]
+        self._connections = []  # type: List[IReplicationConnection]
 
         LaterGauge(
             "synapse_replication_tcp_resource_total_connections",
@@ -197,7 +197,7 @@ class ReplicationCommandHandler:
 
         # For each connection, the incoming stream names that have received a POSITION
         # from that connection.
-        self._streams_by_connection = {}  # type: Dict[AbstractConnection, Set[str]]
+        self._streams_by_connection = {}  # type: Dict[IReplicationConnection, Set[str]]
 
         LaterGauge(
             "synapse_replication_tcp_command_queue",
@@ -220,7 +220,7 @@ class ReplicationCommandHandler:
             self._server_notices_sender = hs.get_server_notices_sender()
 
     def _add_command_to_stream_queue(
-        self, conn: AbstractConnection, cmd: Union[RdataCommand, PositionCommand]
+        self, conn: IReplicationConnection, cmd: Union[RdataCommand, PositionCommand]
     ) -> None:
         """Queue the given received command for processing
 
@@ -267,7 +267,7 @@ class ReplicationCommandHandler:
     async def _process_command(
         self,
         cmd: Union[PositionCommand, RdataCommand],
-        conn: AbstractConnection,
+        conn: IReplicationConnection,
         stream_name: str,
     ) -> None:
         if isinstance(cmd, PositionCommand):
@@ -321,10 +321,10 @@ class ReplicationCommandHandler:
         """Get a list of streams that this instances replicates."""
         return self._streams_to_replicate
 
-    def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
+    def on_REPLICATE(self, conn: IReplicationConnection, cmd: ReplicateCommand):
         self.send_positions_to_connection(conn)
 
-    def send_positions_to_connection(self, conn: AbstractConnection):
+    def send_positions_to_connection(self, conn: IReplicationConnection):
         """Send current position of all streams this process is source of to
         the connection.
         """
@@ -347,7 +347,7 @@ class ReplicationCommandHandler:
             )
 
     def on_USER_SYNC(
-        self, conn: AbstractConnection, cmd: UserSyncCommand
+        self, conn: IReplicationConnection, cmd: UserSyncCommand
     ) -> Optional[Awaitable[None]]:
         user_sync_counter.inc()
 
@@ -359,21 +359,23 @@ class ReplicationCommandHandler:
             return None
 
     def on_CLEAR_USER_SYNC(
-        self, conn: AbstractConnection, cmd: ClearUserSyncsCommand
+        self, conn: IReplicationConnection, cmd: ClearUserSyncsCommand
     ) -> Optional[Awaitable[None]]:
         if self._is_master:
             return self._presence_handler.update_external_syncs_clear(cmd.instance_id)
         else:
             return None
 
-    def on_FEDERATION_ACK(self, conn: AbstractConnection, cmd: FederationAckCommand):
+    def on_FEDERATION_ACK(
+        self, conn: IReplicationConnection, cmd: FederationAckCommand
+    ):
         federation_ack_counter.inc()
 
         if self._federation_sender:
             self._federation_sender.federation_ack(cmd.instance_name, cmd.token)
 
     def on_USER_IP(
-        self, conn: AbstractConnection, cmd: UserIpCommand
+        self, conn: IReplicationConnection, cmd: UserIpCommand
     ) -> Optional[Awaitable[None]]:
         user_ip_cache_counter.inc()
 
@@ -395,7 +397,7 @@ class ReplicationCommandHandler:
         assert self._server_notices_sender is not None
         await self._server_notices_sender.on_user_ip(cmd.user_id)
 
-    def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
+    def on_RDATA(self, conn: IReplicationConnection, cmd: RdataCommand):
         if cmd.instance_name == self._instance_name:
             # Ignore RDATA that are just our own echoes
             return
@@ -412,7 +414,7 @@ class ReplicationCommandHandler:
         self._add_command_to_stream_queue(conn, cmd)
 
     async def _process_rdata(
-        self, stream_name: str, conn: AbstractConnection, cmd: RdataCommand
+        self, stream_name: str, conn: IReplicationConnection, cmd: RdataCommand
     ) -> None:
         """Process an RDATA command
 
@@ -486,7 +488,7 @@ class ReplicationCommandHandler:
             stream_name, instance_name, token, rows
         )
 
-    def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
+    def on_POSITION(self, conn: IReplicationConnection, cmd: PositionCommand):
         if cmd.instance_name == self._instance_name:
             # Ignore POSITION that are just our own echoes
             return
@@ -496,7 +498,7 @@ class ReplicationCommandHandler:
         self._add_command_to_stream_queue(conn, cmd)
 
     async def _process_position(
-        self, stream_name: str, conn: AbstractConnection, cmd: PositionCommand
+        self, stream_name: str, conn: IReplicationConnection, cmd: PositionCommand
     ) -> None:
         """Process a POSITION command
 
@@ -553,7 +555,9 @@ class ReplicationCommandHandler:
 
         self._streams_by_connection.setdefault(conn, set()).add(stream_name)
 
-    def on_REMOTE_SERVER_UP(self, conn: AbstractConnection, cmd: RemoteServerUpCommand):
+    def on_REMOTE_SERVER_UP(
+        self, conn: IReplicationConnection, cmd: RemoteServerUpCommand
+    ):
         """"Called when get a new REMOTE_SERVER_UP command."""
         self._replication_data_handler.on_remote_server_up(cmd.data)
 
@@ -576,7 +580,7 @@ class ReplicationCommandHandler:
         # between two instances, but that is not currently supported).
         self.send_command(cmd, ignore_conn=conn)
 
-    def new_connection(self, connection: AbstractConnection):
+    def new_connection(self, connection: IReplicationConnection):
         """Called when we have a new connection."""
         self._connections.append(connection)
 
@@ -603,7 +607,7 @@ class ReplicationCommandHandler:
                 UserSyncCommand(self._instance_id, user_id, True, now)
             )
 
-    def lost_connection(self, connection: AbstractConnection):
+    def lost_connection(self, connection: IReplicationConnection):
         """Called when a connection is closed/lost."""
         # we no longer need _streams_by_connection for this connection.
         streams = self._streams_by_connection.pop(connection, None)
@@ -624,7 +628,7 @@ class ReplicationCommandHandler:
         return bool(self._connections)
 
     def send_command(
-        self, cmd: Command, ignore_conn: Optional[AbstractConnection] = None
+        self, cmd: Command, ignore_conn: Optional[IReplicationConnection] = None
     ):
         """Send a command to all connected connections.
 
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index e0b4ad314d..8e4734b59c 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -46,7 +46,6 @@ indicate which side is sending, these are *not* included on the wire::
     > ERROR server stopping
     * connection closed by server *
 """
-import abc
 import fcntl
 import logging
 import struct
@@ -54,6 +53,7 @@ from inspect import isawaitable
 from typing import TYPE_CHECKING, List, Optional
 
 from prometheus_client import Counter
+from zope.interface import Interface, implementer
 
 from twisted.internet import task
 from twisted.protocols.basic import LineOnlyReceiver
@@ -121,6 +121,14 @@ class ConnectionStates:
     CLOSED = "closed"
 
 
+class IReplicationConnection(Interface):
+    """An interface for replication connections."""
+
+    def send_command(cmd: Command):
+        """Send the command down the connection"""
+
+
+@implementer(IReplicationConnection)
 class BaseReplicationStreamProtocol(LineOnlyReceiver):
     """Base replication protocol shared between client and server.
 
@@ -495,20 +503,6 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
         self.send_command(ReplicateCommand())
 
 
-class AbstractConnection(abc.ABC):
-    """An interface for replication connections."""
-
-    @abc.abstractmethod
-    def send_command(self, cmd: Command):
-        """Send the command down the connection"""
-        pass
-
-
-# This tells python that `BaseReplicationStreamProtocol` implements the
-# interface.
-AbstractConnection.register(BaseReplicationStreamProtocol)
-
-
 # The following simply registers metrics for the replication connections
 
 pending_commands = LaterGauge(
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index 574eaea1eb..7cccde097d 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -19,6 +19,7 @@ from typing import TYPE_CHECKING, Generic, Optional, Type, TypeVar, cast
 
 import attr
 import txredisapi
+from zope.interface import implementer
 
 from twisted.internet.address import IPv4Address, IPv6Address
 from twisted.internet.interfaces import IAddress, IConnector
@@ -36,7 +37,7 @@ from synapse.replication.tcp.commands import (
     parse_command_from_line,
 )
 from synapse.replication.tcp.protocol import (
-    AbstractConnection,
+    IReplicationConnection,
     tcp_inbound_commands_counter,
     tcp_outbound_commands_counter,
 )
@@ -66,7 +67,8 @@ class ConstantProperty(Generic[T, V]):
         pass
 
 
-class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
+@implementer(IReplicationConnection)
+class RedisSubscriber(txredisapi.SubscriberProtocol):
     """Connection to redis subscribed to replication stream.
 
     This class fulfils two functions:
@@ -75,7 +77,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
     connection, parsing *incoming* messages into replication commands, and passing them
     to `ReplicationCommandHandler`
 
-    (b) it implements the AbstractConnection API, where it sends *outgoing* commands
+    (b) it implements the IReplicationConnection API, where it sends *outgoing* commands
     onto outbound_redis_connection.
 
     Due to the vagaries of `txredisapi` we don't want to have a custom
diff --git a/synapse/rest/admin/_base.py b/synapse/rest/admin/_base.py
index e09234c644..7681e55b58 100644
--- a/synapse/rest/admin/_base.py
+++ b/synapse/rest/admin/_base.py
@@ -15,10 +15,9 @@
 
 import re
 
-import twisted.web.server
-
-import synapse.api.auth
+from synapse.api.auth import Auth
 from synapse.api.errors import AuthError
+from synapse.http.site import SynapseRequest
 from synapse.types import UserID
 
 
@@ -37,13 +36,11 @@ def admin_patterns(path_regex: str, version: str = "v1"):
     return patterns
 
 
-async def assert_requester_is_admin(
-    auth: synapse.api.auth.Auth, request: twisted.web.server.Request
-) -> None:
+async def assert_requester_is_admin(auth: Auth, request: SynapseRequest) -> None:
     """Verify that the requester is an admin user
 
     Args:
-        auth: api.auth.Auth singleton
+        auth: Auth singleton
         request: incoming request
 
     Raises:
@@ -53,11 +50,11 @@ async def assert_requester_is_admin(
     await assert_user_is_admin(auth, requester.user)
 
 
-async def assert_user_is_admin(auth: synapse.api.auth.Auth, user_id: UserID) -> None:
+async def assert_user_is_admin(auth: Auth, user_id: UserID) -> None:
     """Verify that the given user is an admin user
 
     Args:
-        auth: api.auth.Auth singleton
+        auth: Auth singleton
         user_id: user to check
 
     Raises:
diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py
index 511c859f64..7fcc48a9d7 100644
--- a/synapse/rest/admin/media.py
+++ b/synapse/rest/admin/media.py
@@ -17,10 +17,9 @@
 import logging
 from typing import TYPE_CHECKING, Tuple
 
-from twisted.web.server import Request
-
 from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
 from synapse.http.servlet import RestServlet, parse_boolean, parse_integer
+from synapse.http.site import SynapseRequest
 from synapse.rest.admin._base import (
     admin_patterns,
     assert_requester_is_admin,
@@ -50,7 +49,9 @@ class QuarantineMediaInRoom(RestServlet):
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
 
-    async def on_POST(self, request: Request, room_id: str) -> Tuple[int, JsonDict]:
+    async def on_POST(
+        self, request: SynapseRequest, room_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         await assert_user_is_admin(self.auth, requester.user)
 
@@ -75,7 +76,9 @@ class QuarantineMediaByUser(RestServlet):
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
 
-    async def on_POST(self, request: Request, user_id: str) -> Tuple[int, JsonDict]:
+    async def on_POST(
+        self, request: SynapseRequest, user_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         await assert_user_is_admin(self.auth, requester.user)
 
@@ -103,7 +106,7 @@ class QuarantineMediaByID(RestServlet):
         self.auth = hs.get_auth()
 
     async def on_POST(
-        self, request: Request, server_name: str, media_id: str
+        self, request: SynapseRequest, server_name: str, media_id: str
     ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         await assert_user_is_admin(self.auth, requester.user)
@@ -127,7 +130,9 @@ class ProtectMediaByID(RestServlet):
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
 
-    async def on_POST(self, request: Request, media_id: str) -> Tuple[int, JsonDict]:
+    async def on_POST(
+        self, request: SynapseRequest, media_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         await assert_user_is_admin(self.auth, requester.user)
 
@@ -148,7 +153,9 @@ class ListMediaInRoom(RestServlet):
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
 
-    async def on_GET(self, request: Request, room_id: str) -> Tuple[int, JsonDict]:
+    async def on_GET(
+        self, request: SynapseRequest, room_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         is_admin = await self.auth.is_server_admin(requester.user)
         if not is_admin:
@@ -166,7 +173,7 @@ class PurgeMediaCacheRestServlet(RestServlet):
         self.media_repository = hs.get_media_repository()
         self.auth = hs.get_auth()
 
-    async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
+    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
         before_ts = parse_integer(request, "before_ts", required=True)
@@ -189,7 +196,7 @@ class DeleteMediaByID(RestServlet):
         self.media_repository = hs.get_media_repository()
 
     async def on_DELETE(
-        self, request: Request, server_name: str, media_id: str
+        self, request: SynapseRequest, server_name: str, media_id: str
     ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
@@ -218,7 +225,9 @@ class DeleteMediaByDateSize(RestServlet):
         self.server_name = hs.hostname
         self.media_repository = hs.get_media_repository()
 
-    async def on_POST(self, request: Request, server_name: str) -> Tuple[int, JsonDict]:
+    async def on_POST(
+        self, request: SynapseRequest, server_name: str
+    ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
         before_ts = parse_integer(request, "before_ts", required=True)
diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py
index 7aea4cebf5..5901432fad 100644
--- a/synapse/rest/client/v2_alpha/groups.py
+++ b/synapse/rest/client/v2_alpha/groups.py
@@ -32,6 +32,7 @@ from synapse.http.servlet import (
     assert_params_in_dict,
     parse_json_object_from_request,
 )
+from synapse.http.site import SynapseRequest
 from synapse.types import GroupID, JsonDict
 
 from ._base import client_patterns
@@ -70,7 +71,9 @@ class GroupServlet(RestServlet):
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
+    async def on_GET(
+        self, request: SynapseRequest, group_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
 
@@ -81,7 +84,9 @@ class GroupServlet(RestServlet):
         return 200, group_description
 
     @_validate_group_id
-    async def on_POST(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
+    async def on_POST(
+        self, request: SynapseRequest, group_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
@@ -111,7 +116,9 @@ class GroupSummaryServlet(RestServlet):
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
+    async def on_GET(
+        self, request: SynapseRequest, group_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
 
@@ -144,7 +151,11 @@ class GroupSummaryRoomsCatServlet(RestServlet):
 
     @_validate_group_id
     async def on_PUT(
-        self, request: Request, group_id: str, category_id: Optional[str], room_id: str
+        self,
+        request: SynapseRequest,
+        group_id: str,
+        category_id: Optional[str],
+        room_id: str,
     ):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -176,7 +187,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):
 
     @_validate_group_id
     async def on_DELETE(
-        self, request: Request, group_id: str, category_id: str, room_id: str
+        self, request: SynapseRequest, group_id: str, category_id: str, room_id: str
     ):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -206,7 +217,7 @@ class GroupCategoryServlet(RestServlet):
 
     @_validate_group_id
     async def on_GET(
-        self, request: Request, group_id: str, category_id: str
+        self, request: SynapseRequest, group_id: str, category_id: str
     ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
@@ -219,7 +230,7 @@ class GroupCategoryServlet(RestServlet):
 
     @_validate_group_id
     async def on_PUT(
-        self, request: Request, group_id: str, category_id: str
+        self, request: SynapseRequest, group_id: str, category_id: str
     ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -247,7 +258,7 @@ class GroupCategoryServlet(RestServlet):
 
     @_validate_group_id
     async def on_DELETE(
-        self, request: Request, group_id: str, category_id: str
+        self, request: SynapseRequest, group_id: str, category_id: str
     ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -274,7 +285,9 @@ class GroupCategoriesServlet(RestServlet):
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
+    async def on_GET(
+        self, request: SynapseRequest, group_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
 
@@ -298,7 +311,7 @@ class GroupRoleServlet(RestServlet):
 
     @_validate_group_id
     async def on_GET(
-        self, request: Request, group_id: str, role_id: str
+        self, request: SynapseRequest, group_id: str, role_id: str
     ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
@@ -311,7 +324,7 @@ class GroupRoleServlet(RestServlet):
 
     @_validate_group_id
     async def on_PUT(
-        self, request: Request, group_id: str, role_id: str
+        self, request: SynapseRequest, group_id: str, role_id: str
     ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -339,7 +352,7 @@ class GroupRoleServlet(RestServlet):
 
     @_validate_group_id
     async def on_DELETE(
-        self, request: Request, group_id: str, role_id: str
+        self, request: SynapseRequest, group_id: str, role_id: str
     ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -366,7 +379,9 @@ class GroupRolesServlet(RestServlet):
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
+    async def on_GET(
+        self, request: SynapseRequest, group_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
 
@@ -399,7 +414,11 @@ class GroupSummaryUsersRoleServlet(RestServlet):
 
     @_validate_group_id
     async def on_PUT(
-        self, request: Request, group_id: str, role_id: Optional[str], user_id: str
+        self,
+        request: SynapseRequest,
+        group_id: str,
+        role_id: Optional[str],
+        user_id: str,
     ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -431,7 +450,7 @@ class GroupSummaryUsersRoleServlet(RestServlet):
 
     @_validate_group_id
     async def on_DELETE(
-        self, request: Request, group_id: str, role_id: str, user_id: str
+        self, request: SynapseRequest, group_id: str, role_id: str, user_id: str
     ):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -458,7 +477,9 @@ class GroupRoomServlet(RestServlet):
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
+    async def on_GET(
+        self, request: SynapseRequest, group_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
 
@@ -481,7 +502,9 @@ class GroupUsersServlet(RestServlet):
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
+    async def on_GET(
+        self, request: SynapseRequest, group_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
 
@@ -504,7 +527,9 @@ class GroupInvitedUsersServlet(RestServlet):
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
+    async def on_GET(
+        self, request: SynapseRequest, group_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
@@ -526,7 +551,9 @@ class GroupSettingJoinPolicyServlet(RestServlet):
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
+    async def on_PUT(
+        self, request: SynapseRequest, group_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
@@ -554,7 +581,7 @@ class GroupCreateServlet(RestServlet):
         self.groups_handler = hs.get_groups_local_handler()
         self.server_name = hs.hostname
 
-    async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
+    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
@@ -598,7 +625,7 @@ class GroupAdminRoomsServlet(RestServlet):
 
     @_validate_group_id
     async def on_PUT(
-        self, request: Request, group_id: str, room_id: str
+        self, request: SynapseRequest, group_id: str, room_id: str
     ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -615,7 +642,7 @@ class GroupAdminRoomsServlet(RestServlet):
 
     @_validate_group_id
     async def on_DELETE(
-        self, request: Request, group_id: str, room_id: str
+        self, request: SynapseRequest, group_id: str, room_id: str
     ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -646,7 +673,7 @@ class GroupAdminRoomsConfigServlet(RestServlet):
 
     @_validate_group_id
     async def on_PUT(
-        self, request: Request, group_id: str, room_id: str, config_key: str
+        self, request: SynapseRequest, group_id: str, room_id: str, config_key: str
     ):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -678,7 +705,9 @@ class GroupAdminUsersInviteServlet(RestServlet):
         self.is_mine_id = hs.is_mine_id
 
     @_validate_group_id
-    async def on_PUT(self, request: Request, group_id, user_id) -> Tuple[int, JsonDict]:
+    async def on_PUT(
+        self, request: SynapseRequest, group_id, user_id
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
@@ -708,7 +737,9 @@ class GroupAdminUsersKickServlet(RestServlet):
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_PUT(self, request: Request, group_id, user_id) -> Tuple[int, JsonDict]:
+    async def on_PUT(
+        self, request: SynapseRequest, group_id, user_id
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
@@ -735,7 +766,9 @@ class GroupSelfLeaveServlet(RestServlet):
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
+    async def on_PUT(
+        self, request: SynapseRequest, group_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
@@ -762,7 +795,9 @@ class GroupSelfJoinServlet(RestServlet):
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
+    async def on_PUT(
+        self, request: SynapseRequest, group_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
@@ -789,7 +824,9 @@ class GroupSelfAcceptInviteServlet(RestServlet):
         self.groups_handler = hs.get_groups_local_handler()
 
     @_validate_group_id
-    async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
+    async def on_PUT(
+        self, request: SynapseRequest, group_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
@@ -816,7 +853,9 @@ class GroupSelfUpdatePublicityServlet(RestServlet):
         self.store = hs.get_datastore()
 
     @_validate_group_id
-    async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
+    async def on_PUT(
+        self, request: SynapseRequest, group_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
@@ -839,7 +878,9 @@ class PublicisedGroupsForUserServlet(RestServlet):
         self.store = hs.get_datastore()
         self.groups_handler = hs.get_groups_local_handler()
 
-    async def on_GET(self, request: Request, user_id: str) -> Tuple[int, JsonDict]:
+    async def on_GET(
+        self, request: SynapseRequest, user_id: str
+    ) -> Tuple[int, JsonDict]:
         await self.auth.get_user_by_req(request, allow_guest=True)
 
         result = await self.groups_handler.get_publicised_groups_for_user(user_id)
@@ -859,7 +900,7 @@ class PublicisedGroupsForUsersServlet(RestServlet):
         self.store = hs.get_datastore()
         self.groups_handler = hs.get_groups_local_handler()
 
-    async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
+    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         await self.auth.get_user_by_req(request, allow_guest=True)
 
         content = parse_json_object_from_request(request)
@@ -881,7 +922,7 @@ class GroupsForUserServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
-    async def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
+    async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
 
diff --git a/synapse/rest/media/v1/config_resource.py b/synapse/rest/media/v1/config_resource.py
index 9039662f7e..1eff98ef14 100644
--- a/synapse/rest/media/v1/config_resource.py
+++ b/synapse/rest/media/v1/config_resource.py
@@ -20,6 +20,7 @@ from typing import TYPE_CHECKING
 from twisted.web.server import Request
 
 from synapse.http.server import DirectServeJsonResource, respond_with_json
+from synapse.http.site import SynapseRequest
 
 if TYPE_CHECKING:
     from synapse.app.homeserver import HomeServer
@@ -35,7 +36,7 @@ class MediaConfigResource(DirectServeJsonResource):
         self.auth = hs.get_auth()
         self.limits_dict = {"m.upload.size": config.max_upload_size}
 
-    async def _async_render_GET(self, request: Request) -> None:
+    async def _async_render_GET(self, request: SynapseRequest) -> None:
         await self.auth.get_user_by_req(request)
         respond_with_json(request, 200, self.limits_dict, send_cors=True)
 
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index a074e807dc..b8895aeaa9 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -39,6 +39,7 @@ from synapse.http.server import (
     respond_with_json_bytes,
 )
 from synapse.http.servlet import parse_integer, parse_string
+from synapse.http.site import SynapseRequest
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.rest.media.v1._base import get_filename_from_headers
@@ -185,7 +186,7 @@ class PreviewUrlResource(DirectServeJsonResource):
         request.setHeader(b"Allow", b"OPTIONS, GET")
         respond_with_json(request, 200, {}, send_cors=True)
 
-    async def _async_render_GET(self, request: Request) -> None:
+    async def _async_render_GET(self, request: SynapseRequest) -> None:
 
         # XXX: if get_user_by_req fails, what should we do in an async render?
         requester = await self.auth.get_user_by_req(request)
diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py
index 5e104fac40..ae5aef2f7f 100644
--- a/synapse/rest/media/v1/upload_resource.py
+++ b/synapse/rest/media/v1/upload_resource.py
@@ -22,6 +22,7 @@ from twisted.web.server import Request
 from synapse.api.errors import Codes, SynapseError
 from synapse.http.server import DirectServeJsonResource, respond_with_json
 from synapse.http.servlet import parse_string
+from synapse.http.site import SynapseRequest
 from synapse.rest.media.v1.media_storage import SpamMediaException
 
 if TYPE_CHECKING:
@@ -49,7 +50,7 @@ class UploadResource(DirectServeJsonResource):
     async def _async_render_OPTIONS(self, request: Request) -> None:
         respond_with_json(request, 200, {}, send_cors=True)
 
-    async def _async_render_POST(self, request: Request) -> None:
+    async def _async_render_POST(self, request: SynapseRequest) -> None:
         requester = await self.auth.get_user_by_req(request)
         # TODO: The checks here are a bit late. The content will have
         # already been uploaded to a tmp file at this point
diff --git a/synapse/server.py b/synapse/server.py
index 369cc88026..48ac87a124 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -351,11 +351,9 @@ class HomeServer(metaclass=abc.ABCMeta):
 
     @cache_in_self
     def get_http_client_context_factory(self) -> IPolicyForHTTPS:
-        return (
-            InsecureInterceptableContextFactory()
-            if self.config.use_insecure_ssl_client_just_for_testing_do_not_use
-            else RegularPolicyForHTTPS()
-        )
+        if self.config.use_insecure_ssl_client_just_for_testing_do_not_use:
+            return InsecureInterceptableContextFactory()
+        return RegularPolicyForHTTPS()
 
     @cache_in_self
     def get_simple_http_client(self) -> SimpleHttpClient:
diff --git a/tests/replication/test_federation_ack.py b/tests/replication/test_federation_ack.py
index f235f1bd83..0d9e3bb11d 100644
--- a/tests/replication/test_federation_ack.py
+++ b/tests/replication/test_federation_ack.py
@@ -17,7 +17,7 @@ import mock
 
 from synapse.app.generic_worker import GenericWorkerServer
 from synapse.replication.tcp.commands import FederationAckCommand
-from synapse.replication.tcp.protocol import AbstractConnection
+from synapse.replication.tcp.protocol import IReplicationConnection
 from synapse.replication.tcp.streams.federation import FederationStream
 
 from tests.unittest import HomeserverTestCase
@@ -51,8 +51,10 @@ class FederationAckTestCase(HomeserverTestCase):
         """
         rch = self.hs.get_tcp_replication()
 
-        # wire up the ReplicationCommandHandler to a mock connection
-        mock_connection = mock.Mock(spec=AbstractConnection)
+        # wire up the ReplicationCommandHandler to a mock connection, which needs
+        # to implement IReplicationConnection. (Note that Mock doesn't understand
+        # interfaces, but casing an interface to a list gives the attributes.)
+        mock_connection = mock.Mock(spec=list(IReplicationConnection))
         rch.new_connection(mock_connection)
 
         # tell it it received an RDATA row