summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/appservice/test_scheduler.py6
-rw-r--r--tests/http/federation/test_matrix_federation_agent.py5
-rw-r--r--tests/http/test_proxyagent.py5
-rw-r--r--tests/rest/client/test_auth.py14
-rw-r--r--tests/rest/client/utils.py58
-rw-r--r--tests/server.py253
-rw-r--r--tests/unittest.py11
7 files changed, 225 insertions, 127 deletions
diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py
index febcc1499d..e2a3bad065 100644
--- a/tests/appservice/test_scheduler.py
+++ b/tests/appservice/test_scheduler.py
@@ -11,12 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import TYPE_CHECKING, List, Optional, Sequence, Tuple, cast
+from typing import List, Optional, Sequence, Tuple, cast
 from unittest.mock import Mock
 
 from typing_extensions import TypeAlias
 
 from twisted.internet import defer
+from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.appservice import (
     ApplicationService,
@@ -40,9 +41,6 @@ from tests.test_utils import simple_async_mock
 
 from ..utils import MockClock
 
-if TYPE_CHECKING:
-    from twisted.internet.testing import MemoryReactor
-
 
 class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
     def setUp(self) -> None:
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index d27422515c..eb7f53fee5 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -30,7 +30,7 @@ from twisted.internet.interfaces import (
     IOpenSSLClientConnectionCreator,
     IProtocolFactory,
 )
-from twisted.internet.protocol import Factory
+from twisted.internet.protocol import Factory, Protocol
 from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
 from twisted.web._newclient import ResponseNeverReceived
 from twisted.web.client import Agent
@@ -466,7 +466,8 @@ class MatrixFederationAgentTests(unittest.TestCase):
         else:
             assert isinstance(proxy_server_transport, FakeTransport)
             client_protocol = proxy_server_transport.other
-            c2s_transport = client_protocol.transport
+            assert isinstance(client_protocol, Protocol)
+            c2s_transport = checked_cast(FakeTransport, client_protocol.transport)
             c2s_transport.other = server_ssl_protocol
 
         self.reactor.advance(0)
diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py
index 22fdc7f5f2..cc175052ac 100644
--- a/tests/http/test_proxyagent.py
+++ b/tests/http/test_proxyagent.py
@@ -28,7 +28,7 @@ from twisted.internet.endpoints import (
     _WrappingProtocol,
 )
 from twisted.internet.interfaces import IProtocol, IProtocolFactory
-from twisted.internet.protocol import Factory
+from twisted.internet.protocol import Factory, Protocol
 from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
 from twisted.web.http import HTTPChannel
 
@@ -644,7 +644,8 @@ class MatrixFederationAgentTests(TestCase):
         else:
             assert isinstance(proxy_server_transport, FakeTransport)
             client_protocol = proxy_server_transport.other
-            c2s_transport = client_protocol.transport
+            assert isinstance(client_protocol, Protocol)
+            c2s_transport = checked_cast(FakeTransport, client_protocol.transport)
             c2s_transport.other = server_ssl_protocol
 
         self.reactor.advance(0)
diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py
index f4e1e7de43..a144610078 100644
--- a/tests/rest/client/test_auth.py
+++ b/tests/rest/client/test_auth.py
@@ -34,7 +34,7 @@ from synapse.util import Clock
 from tests import unittest
 from tests.handlers.test_oidc import HAS_OIDC
 from tests.rest.client.utils import TEST_OIDC_CONFIG, TEST_OIDC_ISSUER
-from tests.server import FakeChannel, make_request
+from tests.server import FakeChannel
 from tests.unittest import override_config, skip_unless
 
 
@@ -1322,16 +1322,8 @@ class OidcBackchannelLogoutTests(unittest.HomeserverTestCase):
         channel = self.submit_logout_token(logout_token)
         self.assertEqual(channel.code, 200)
 
-        # Now try to exchange the login token
-        channel = make_request(
-            self.hs.get_reactor(),
-            self.site,
-            "POST",
-            "/login",
-            content={"type": "m.login.token", "token": login_token},
-        )
-        # It should have failed
-        self.assertEqual(channel.code, 403)
+        # Now try to exchange the login token, it should fail.
+        self.helper.login_via_token(login_token, 403)
 
     @override_config(
         {
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index 8d6f2b6ff9..9532e5ddc1 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -36,6 +36,7 @@ from urllib.parse import urlencode
 import attr
 from typing_extensions import Literal
 
+from twisted.test.proto_helpers import MemoryReactorClock
 from twisted.web.resource import Resource
 from twisted.web.server import Site
 
@@ -67,6 +68,7 @@ class RestHelper:
     """
 
     hs: HomeServer
+    reactor: MemoryReactorClock
     site: Site
     auth_user_id: Optional[str]
 
@@ -142,7 +144,7 @@ class RestHelper:
             path = path + "?access_token=%s" % tok
 
         channel = make_request(
-            self.hs.get_reactor(),
+            self.reactor,
             self.site,
             "POST",
             path,
@@ -216,7 +218,7 @@ class RestHelper:
             data["reason"] = reason
 
         channel = make_request(
-            self.hs.get_reactor(),
+            self.reactor,
             self.site,
             "POST",
             path,
@@ -313,7 +315,7 @@ class RestHelper:
         data.update(extra_data or {})
 
         channel = make_request(
-            self.hs.get_reactor(),
+            self.reactor,
             self.site,
             "PUT",
             path,
@@ -394,7 +396,7 @@ class RestHelper:
             path = path + "?access_token=%s" % tok
 
         channel = make_request(
-            self.hs.get_reactor(),
+            self.reactor,
             self.site,
             "PUT",
             path,
@@ -433,7 +435,7 @@ class RestHelper:
             path = path + f"?access_token={tok}"
 
         channel = make_request(
-            self.hs.get_reactor(),
+            self.reactor,
             self.site,
             "GET",
             path,
@@ -488,7 +490,7 @@ class RestHelper:
         if body is not None:
             content = json.dumps(body).encode("utf8")
 
-        channel = make_request(self.hs.get_reactor(), self.site, method, path, content)
+        channel = make_request(self.reactor, self.site, method, path, content)
 
         assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
             expect_code,
@@ -573,8 +575,8 @@ class RestHelper:
         image_length = len(image_data)
         path = "/_matrix/media/r0/upload?filename=%s" % (filename,)
         channel = make_request(
-            self.hs.get_reactor(),
-            FakeSite(resource, self.hs.get_reactor()),
+            self.reactor,
+            FakeSite(resource, self.reactor),
             "POST",
             path,
             content=image_data,
@@ -603,7 +605,7 @@ class RestHelper:
             expect_code: The return code to expect from attempting the whoami request
         """
         channel = make_request(
-            self.hs.get_reactor(),
+            self.reactor,
             self.site,
             "GET",
             "account/whoami",
@@ -642,7 +644,7 @@ class RestHelper:
     ) -> Tuple[JsonDict, FakeAuthorizationGrant]:
         """Log in (as a new user) via OIDC
 
-        Returns the result of the final token login.
+        Returns the result of the final token login and the fake authorization grant.
 
         Requires that "oidc_config" in the homeserver config be set appropriately
         (TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a
@@ -672,10 +674,28 @@ class RestHelper:
         assert m, channel.text_body
         login_token = m.group(1)
 
-        # finally, submit the matrix login token to the login API, which gives us our
-        # matrix access token and device id.
+        return self.login_via_token(login_token, expected_status), grant
+
+    def login_via_token(
+        self,
+        login_token: str,
+        expected_status: int = 200,
+    ) -> JsonDict:
+        """Submit the matrix login token to the login API, which gives us our
+        matrix access token and device id.Log in (as a new user) via OIDC
+
+        Returns the result of the token login.
+
+        Requires that "oidc_config" in the homeserver config be set appropriately
+        (TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a
+        "public_base_url".
+
+        Also requires the login servlet and the OIDC callback resource to be mounted at
+        the normal places.
+        """
+
         channel = make_request(
-            self.hs.get_reactor(),
+            self.reactor,
             self.site,
             "POST",
             "/login",
@@ -684,7 +704,7 @@ class RestHelper:
         assert (
             channel.code == expected_status
         ), f"unexpected status in response: {channel.code}"
-        return channel.json_body, grant
+        return channel.json_body
 
     def auth_via_oidc(
         self,
@@ -805,7 +825,7 @@ class RestHelper:
         with fake_serer.patch_homeserver(hs=self.hs):
             # now hit the callback URI with the right params and a made-up code
             channel = make_request(
-                self.hs.get_reactor(),
+                self.reactor,
                 self.site,
                 "GET",
                 callback_uri,
@@ -849,7 +869,7 @@ class RestHelper:
         # is the easiest way of figuring out what the Host header ought to be set to
         # to keep Synapse happy.
         channel = make_request(
-            self.hs.get_reactor(),
+            self.reactor,
             self.site,
             "GET",
             uri,
@@ -867,7 +887,7 @@ class RestHelper:
         location = get_location(channel)
         parts = urllib.parse.urlsplit(location)
         channel = make_request(
-            self.hs.get_reactor(),
+            self.reactor,
             self.site,
             "GET",
             urllib.parse.urlunsplit(("", "") + parts[2:]),
@@ -900,9 +920,7 @@ class RestHelper:
             + urllib.parse.urlencode({"session": ui_auth_session_id})
         )
         # hit the redirect url (which will issue a cookie and state)
-        channel = make_request(
-            self.hs.get_reactor(), self.site, "GET", sso_redirect_endpoint
-        )
+        channel = make_request(self.reactor, self.site, "GET", sso_redirect_endpoint)
         # that should serve a confirmation page
         assert channel.code == HTTPStatus.OK, channel.text_body
         channel.extract_cookies(cookies)
diff --git a/tests/server.py b/tests/server.py
index 237bcad8ba..5de9722766 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -22,20 +22,25 @@ import warnings
 from collections import deque
 from io import SEEK_END, BytesIO
 from typing import (
+    Any,
+    Awaitable,
     Callable,
     Dict,
     Iterable,
     List,
     MutableMapping,
     Optional,
+    Sequence,
     Tuple,
     Type,
+    TypeVar,
     Union,
+    cast,
 )
 from unittest.mock import Mock
 
 import attr
-from typing_extensions import Deque
+from typing_extensions import Deque, ParamSpec
 from zope.interface import implementer
 
 from twisted.internet import address, threads, udp
@@ -44,8 +49,10 @@ from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed
 from twisted.internet.error import DNSLookupError
 from twisted.internet.interfaces import (
     IAddress,
+    IConnector,
     IConsumer,
     IHostnameResolver,
+    IProducer,
     IProtocol,
     IPullProducer,
     IPushProducer,
@@ -54,6 +61,8 @@ from twisted.internet.interfaces import (
     IResolverSimple,
     ITransport,
 )
+from twisted.internet.protocol import ClientFactory, DatagramProtocol
+from twisted.python import threadpool
 from twisted.python.failure import Failure
 from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
 from twisted.web.http_headers import Headers
@@ -61,6 +70,7 @@ from twisted.web.resource import IResource
 from twisted.web.server import Request, Site
 
 from synapse.config.database import DatabaseConnectionConfig
+from synapse.config.homeserver import HomeServerConfig
 from synapse.events.presence_router import load_legacy_presence_router
 from synapse.events.spamcheck import load_legacy_spam_checkers
 from synapse.events.third_party_rules import load_legacy_third_party_event_rules
@@ -88,6 +98,9 @@ from tests.utils import (
 
 logger = logging.getLogger(__name__)
 
+R = TypeVar("R")
+P = ParamSpec("P")
+
 # the type of thing that can be passed into `make_request` in the headers list
 CustomHeaderType = Tuple[Union[str, bytes], Union[str, bytes]]
 
@@ -98,12 +111,14 @@ class TimedOutException(Exception):
     """
 
 
-@implementer(IConsumer)
+@implementer(ITransport, IPushProducer, IConsumer)
 @attr.s(auto_attribs=True)
 class FakeChannel:
     """
     A fake Twisted Web Channel (the part that interfaces with the
     wire).
+
+    See twisted.web.http.HTTPChannel.
     """
 
     site: Union[Site, "FakeSite"]
@@ -142,7 +157,7 @@ class FakeChannel:
 
         Raises an exception if the request has not yet completed.
         """
-        if not self.is_finished:
+        if not self.is_finished():
             raise Exception("Request not yet completed")
         return self.result["body"].decode("utf8")
 
@@ -165,27 +180,36 @@ class FakeChannel:
             h.addRawHeader(*i)
         return h
 
-    def writeHeaders(self, version, code, reason, headers):
+    def writeHeaders(
+        self, version: bytes, code: bytes, reason: bytes, headers: Headers
+    ) -> None:
         self.result["version"] = version
         self.result["code"] = code
         self.result["reason"] = reason
         self.result["headers"] = headers
 
-    def write(self, content: bytes) -> None:
-        assert isinstance(content, bytes), "Should be bytes! " + repr(content)
+    def write(self, data: bytes) -> None:
+        assert isinstance(data, bytes), "Should be bytes! " + repr(data)
 
         if "body" not in self.result:
             self.result["body"] = b""
 
-        self.result["body"] += content
+        self.result["body"] += data
+
+    def writeSequence(self, data: Iterable[bytes]) -> None:
+        for x in data:
+            self.write(x)
+
+    def loseConnection(self) -> None:
+        self.unregisterProducer()
+        self.transport.loseConnection()
 
     # Type ignore: mypy doesn't like the fact that producer isn't an IProducer.
-    def registerProducer(  # type: ignore[override]
-        self,
-        producer: Union[IPullProducer, IPushProducer],
-        streaming: bool,
-    ) -> None:
-        self._producer = producer
+    def registerProducer(self, producer: IProducer, streaming: bool) -> None:
+        # TODO This should ensure that the IProducer is an IPushProducer or
+        # IPullProducer, unfortunately twisted.protocols.basic.FileSender does
+        # implement those, but doesn't declare it.
+        self._producer = cast(Union[IPushProducer, IPullProducer], producer)
         self.producerStreaming = streaming
 
         def _produce() -> None:
@@ -202,6 +226,16 @@ class FakeChannel:
 
         self._producer = None
 
+    def stopProducing(self) -> None:
+        if self._producer is not None:
+            self._producer.stopProducing()
+
+    def pauseProducing(self) -> None:
+        raise NotImplementedError()
+
+    def resumeProducing(self) -> None:
+        raise NotImplementedError()
+
     def requestDone(self, _self: Request) -> None:
         self.result["done"] = True
         if isinstance(_self, SynapseRequest):
@@ -281,12 +315,12 @@ class FakeSite:
         self.reactor = reactor
         self.experimental_cors_msc3886 = experimental_cors_msc3886
 
-    def getResourceFor(self, request):
+    def getResourceFor(self, request: Request) -> IResource:
         return self._resource
 
 
 def make_request(
-    reactor,
+    reactor: MemoryReactorClock,
     site: Union[Site, FakeSite],
     method: Union[bytes, str],
     path: Union[bytes, str],
@@ -409,19 +443,21 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
     A MemoryReactorClock that supports callFromThread.
     """
 
-    def __init__(self):
+    def __init__(self) -> None:
         self.threadpool = ThreadPool(self)
 
         self._tcp_callbacks: Dict[Tuple[str, int], Callable] = {}
-        self._udp = []
+        self._udp: List[udp.Port] = []
         self.lookups: Dict[str, str] = {}
-        self._thread_callbacks: Deque[Callable[[], None]] = deque()
+        self._thread_callbacks: Deque[Callable[..., R]] = deque()
 
         lookups = self.lookups
 
         @implementer(IResolverSimple)
         class FakeResolver:
-            def getHostByName(self, name, timeout=None):
+            def getHostByName(
+                self, name: str, timeout: Optional[Sequence[int]] = None
+            ) -> "Deferred[str]":
                 if name not in lookups:
                     return fail(DNSLookupError("OH NO: unknown %s" % (name,)))
                 return succeed(lookups[name])
@@ -432,25 +468,44 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
     def installNameResolver(self, resolver: IHostnameResolver) -> IHostnameResolver:
         raise NotImplementedError()
 
-    def listenUDP(self, port, protocol, interface="", maxPacketSize=8196):
+    def listenUDP(
+        self,
+        port: int,
+        protocol: DatagramProtocol,
+        interface: str = "",
+        maxPacketSize: int = 8196,
+    ) -> udp.Port:
         p = udp.Port(port, protocol, interface, maxPacketSize, self)
         p.startListening()
         self._udp.append(p)
         return p
 
-    def callFromThread(self, callback, *args, **kwargs):
+    def callFromThread(
+        self, callable: Callable[..., Any], *args: object, **kwargs: object
+    ) -> None:
         """
         Make the callback fire in the next reactor iteration.
         """
-        cb = lambda: callback(*args, **kwargs)
+        cb = lambda: callable(*args, **kwargs)
         # it's not safe to call callLater() here, so we append the callback to a
         # separate queue.
         self._thread_callbacks.append(cb)
 
-    def getThreadPool(self):
-        return self.threadpool
+    def callInThread(
+        self, callable: Callable[..., Any], *args: object, **kwargs: object
+    ) -> None:
+        raise NotImplementedError()
+
+    def suggestThreadPoolSize(self, size: int) -> None:
+        raise NotImplementedError()
+
+    def getThreadPool(self) -> "threadpool.ThreadPool":
+        # Cast to match super-class.
+        return cast(threadpool.ThreadPool, self.threadpool)
 
-    def add_tcp_client_callback(self, host: str, port: int, callback: Callable):
+    def add_tcp_client_callback(
+        self, host: str, port: int, callback: Callable[[], None]
+    ) -> None:
         """Add a callback that will be invoked when we receive a connection
         attempt to the given IP/port using `connectTCP`.
 
@@ -459,7 +514,14 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
         """
         self._tcp_callbacks[(host, port)] = callback
 
-    def connectTCP(self, host: str, port: int, factory, timeout=30, bindAddress=None):
+    def connectTCP(
+        self,
+        host: str,
+        port: int,
+        factory: ClientFactory,
+        timeout: float = 30,
+        bindAddress: Optional[Tuple[str, int]] = None,
+    ) -> IConnector:
         """Fake L{IReactorTCP.connectTCP}."""
 
         conn = super().connectTCP(
@@ -472,7 +534,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
 
         return conn
 
-    def advance(self, amount):
+    def advance(self, amount: float) -> None:
         # first advance our reactor's time, and run any "callLater" callbacks that
         # makes ready
         super().advance(amount)
@@ -500,25 +562,33 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
 class ThreadPool:
     """
     Threadless thread pool.
+
+    See twisted.python.threadpool.ThreadPool
     """
 
-    def __init__(self, reactor):
+    def __init__(self, reactor: IReactorTime):
         self._reactor = reactor
 
-    def start(self):
+    def start(self) -> None:
         pass
 
-    def stop(self):
+    def stop(self) -> None:
         pass
 
-    def callInThreadWithCallback(self, onResult, function, *args, **kwargs):
-        def _(res):
+    def callInThreadWithCallback(
+        self,
+        onResult: Callable[[bool, Union[Failure, R]], None],
+        function: Callable[P, R],
+        *args: P.args,
+        **kwargs: P.kwargs,
+    ) -> "Deferred[None]":
+        def _(res: Any) -> None:
             if isinstance(res, Failure):
                 onResult(False, res)
             else:
                 onResult(True, res)
 
-        d = Deferred()
+        d: "Deferred[None]" = Deferred()
         d.addCallback(lambda x: function(*args, **kwargs))
         d.addBoth(_)
         self._reactor.callLater(0, d.callback, True)
@@ -535,7 +605,9 @@ def _make_test_homeserver_synchronous(server: HomeServer) -> None:
     for database in server.get_datastores().databases:
         pool = database._db_pool
 
-        def runWithConnection(func, *args, **kwargs):
+        def runWithConnection(
+            func: Callable[..., R], *args: Any, **kwargs: Any
+        ) -> Awaitable[R]:
             return threads.deferToThreadPool(
                 pool._reactor,
                 pool.threadpool,
@@ -545,20 +617,23 @@ def _make_test_homeserver_synchronous(server: HomeServer) -> None:
                 **kwargs,
             )
 
-        def runInteraction(interaction, *args, **kwargs):
+        def runInteraction(
+            desc: str, func: Callable[..., R], *args: Any, **kwargs: Any
+        ) -> Awaitable[R]:
             return threads.deferToThreadPool(
                 pool._reactor,
                 pool.threadpool,
                 pool._runInteraction,
-                interaction,
+                desc,
+                func,
                 *args,
                 **kwargs,
             )
 
-        pool.runWithConnection = runWithConnection
-        pool.runInteraction = runInteraction
+        pool.runWithConnection = runWithConnection  # type: ignore[assignment]
+        pool.runInteraction = runInteraction  # type: ignore[assignment]
         # Replace the thread pool with a threadless 'thread' pool
-        pool.threadpool = ThreadPool(clock._reactor)
+        pool.threadpool = ThreadPool(clock._reactor)  # type: ignore[assignment]
         pool.running = True
 
     # We've just changed the Databases to run DB transactions on the same
@@ -573,7 +648,7 @@ def get_clock() -> Tuple[ThreadedMemoryReactorClock, Clock]:
 
 
 @implementer(ITransport)
-@attr.s(cmp=False)
+@attr.s(cmp=False, auto_attribs=True)
 class FakeTransport:
     """
     A twisted.internet.interfaces.ITransport implementation which sends all its data
@@ -588,48 +663,50 @@ class FakeTransport:
     If you want bidirectional communication, you'll need two instances.
     """
 
-    other = attr.ib()
+    other: IProtocol
     """The Protocol object which will receive any data written to this transport.
-
-    :type: twisted.internet.interfaces.IProtocol
     """
 
-    _reactor = attr.ib()
+    _reactor: IReactorTime
     """Test reactor
-
-    :type: twisted.internet.interfaces.IReactorTime
     """
 
-    _protocol = attr.ib(default=None)
+    _protocol: Optional[IProtocol] = None
     """The Protocol which is producing data for this transport. Optional, but if set
     will get called back for connectionLost() notifications etc.
     """
 
-    _peer_address: Optional[IAddress] = attr.ib(default=None)
+    _peer_address: IAddress = attr.Factory(
+        lambda: address.IPv4Address("TCP", "127.0.0.1", 5678)
+    )
     """The value to be returned by getPeer"""
 
-    _host_address: Optional[IAddress] = attr.ib(default=None)
+    _host_address: IAddress = attr.Factory(
+        lambda: address.IPv4Address("TCP", "127.0.0.1", 1234)
+    )
     """The value to be returned by getHost"""
 
     disconnecting = False
     disconnected = False
     connected = True
-    buffer = attr.ib(default=b"")
-    producer = attr.ib(default=None)
-    autoflush = attr.ib(default=True)
+    buffer: bytes = b""
+    producer: Optional[IPushProducer] = None
+    autoflush: bool = True
 
-    def getPeer(self) -> Optional[IAddress]:
+    def getPeer(self) -> IAddress:
         return self._peer_address
 
-    def getHost(self) -> Optional[IAddress]:
+    def getHost(self) -> IAddress:
         return self._host_address
 
-    def loseConnection(self, reason=None):
+    def loseConnection(self) -> None:
         if not self.disconnecting:
-            logger.info("FakeTransport: loseConnection(%s)", reason)
+            logger.info("FakeTransport: loseConnection()")
             self.disconnecting = True
             if self._protocol:
-                self._protocol.connectionLost(reason)
+                self._protocol.connectionLost(
+                    Failure(RuntimeError("FakeTransport.loseConnection()"))
+                )
 
             # if we still have data to write, delay until that is done
             if self.buffer:
@@ -640,38 +717,38 @@ class FakeTransport:
                 self.connected = False
                 self.disconnected = True
 
-    def abortConnection(self):
+    def abortConnection(self) -> None:
         logger.info("FakeTransport: abortConnection()")
 
         if not self.disconnecting:
             self.disconnecting = True
             if self._protocol:
-                self._protocol.connectionLost(None)
+                self._protocol.connectionLost(None)  # type: ignore[arg-type]
 
         self.disconnected = True
 
-    def pauseProducing(self):
+    def pauseProducing(self) -> None:
         if not self.producer:
             return
 
         self.producer.pauseProducing()
 
-    def resumeProducing(self):
+    def resumeProducing(self) -> None:
         if not self.producer:
             return
         self.producer.resumeProducing()
 
-    def unregisterProducer(self):
+    def unregisterProducer(self) -> None:
         if not self.producer:
             return
 
         self.producer = None
 
-    def registerProducer(self, producer, streaming):
+    def registerProducer(self, producer: IPushProducer, streaming: bool) -> None:
         self.producer = producer
         self.producerStreaming = streaming
 
-        def _produce():
+        def _produce() -> None:
             if not self.producer:
                 # we've been unregistered
                 return
@@ -683,7 +760,7 @@ class FakeTransport:
         if not streaming:
             self._reactor.callLater(0.0, _produce)
 
-    def write(self, byt):
+    def write(self, byt: bytes) -> None:
         if self.disconnecting:
             raise Exception("Writing to disconnecting FakeTransport")
 
@@ -695,11 +772,11 @@ class FakeTransport:
         if self.autoflush:
             self._reactor.callLater(0.0, self.flush)
 
-    def writeSequence(self, seq):
+    def writeSequence(self, seq: Iterable[bytes]) -> None:
         for x in seq:
             self.write(x)
 
-    def flush(self, maxbytes=None):
+    def flush(self, maxbytes: Optional[int] = None) -> None:
         if not self.buffer:
             # nothing to do. Don't write empty buffers: it upsets the
             # TLSMemoryBIOProtocol
@@ -750,17 +827,17 @@ def connect_client(
 
 
 class TestHomeServer(HomeServer):
-    DATASTORE_CLASS = DataStore
+    DATASTORE_CLASS = DataStore  # type: ignore[assignment]
 
 
 def setup_test_homeserver(
-    cleanup_func,
-    name="test",
-    config=None,
-    reactor=None,
+    cleanup_func: Callable[[Callable[[], None]], None],
+    name: str = "test",
+    config: Optional[HomeServerConfig] = None,
+    reactor: Optional[ISynapseReactor] = None,
     homeserver_to_use: Type[HomeServer] = TestHomeServer,
-    **kwargs,
-):
+    **kwargs: Any,
+) -> HomeServer:
     """
     Setup a homeserver suitable for running tests against.  Keyword arguments
     are passed to the Homeserver constructor.
@@ -775,13 +852,14 @@ def setup_test_homeserver(
     HomeserverTestCase.
     """
     if reactor is None:
-        from twisted.internet import reactor
+        from twisted.internet import reactor as _reactor
+
+        reactor = cast(ISynapseReactor, _reactor)
 
     if config is None:
         config = default_config(name, parse=True)
 
     config.caches.resize_all_caches()
-    config.ldap_enabled = False
 
     if "clock" not in kwargs:
         kwargs["clock"] = MockClock()
@@ -832,6 +910,8 @@ def setup_test_homeserver(
     # Create the database before we actually try and connect to it, based off
     # the template database we generate in setupdb()
     if isinstance(db_engine, PostgresEngine):
+        import psycopg2.extensions
+
         db_conn = db_engine.module.connect(
             database=POSTGRES_BASE_DB,
             user=POSTGRES_USER,
@@ -839,6 +919,7 @@ def setup_test_homeserver(
             port=POSTGRES_PORT,
             password=POSTGRES_PASSWORD,
         )
+        assert isinstance(db_conn, psycopg2.extensions.connection)
         db_conn.autocommit = True
         cur = db_conn.cursor()
         cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
@@ -867,14 +948,15 @@ def setup_test_homeserver(
         hs.setup_background_tasks()
 
     if isinstance(db_engine, PostgresEngine):
-        database = hs.get_datastores().databases[0]
+        database_pool = hs.get_datastores().databases[0]
 
         # We need to do cleanup on PostgreSQL
-        def cleanup():
+        def cleanup() -> None:
             import psycopg2
+            import psycopg2.extensions
 
             # Close all the db pools
-            database._db_pool.close()
+            database_pool._db_pool.close()
 
             dropped = False
 
@@ -886,6 +968,7 @@ def setup_test_homeserver(
                 port=POSTGRES_PORT,
                 password=POSTGRES_PASSWORD,
             )
+            assert isinstance(db_conn, psycopg2.extensions.connection)
             db_conn.autocommit = True
             cur = db_conn.cursor()
 
@@ -918,23 +1001,23 @@ def setup_test_homeserver(
     # Need to let the HS build an auth handler and then mess with it
     # because AuthHandler's constructor requires the HS, so we can't make one
     # beforehand and pass it in to the HS's constructor (chicken / egg)
-    async def hash(p):
+    async def hash(p: str) -> str:
         return hashlib.md5(p.encode("utf8")).hexdigest()
 
-    hs.get_auth_handler().hash = hash
+    hs.get_auth_handler().hash = hash  # type: ignore[assignment]
 
-    async def validate_hash(p, h):
+    async def validate_hash(p: str, h: str) -> bool:
         return hashlib.md5(p.encode("utf8")).hexdigest() == h
 
-    hs.get_auth_handler().validate_hash = validate_hash
+    hs.get_auth_handler().validate_hash = validate_hash  # type: ignore[assignment]
 
     # Make the threadpool and database transactions synchronous for testing.
     _make_test_homeserver_synchronous(hs)
 
     # Load any configured modules into the homeserver
     module_api = hs.get_module_api()
-    for module, config in hs.config.modules.loaded_modules:
-        module(config=config, api=module_api)
+    for module, module_config in hs.config.modules.loaded_modules:
+        module(config=module_config, api=module_api)
 
     load_legacy_spam_checkers(hs)
     load_legacy_third_party_event_rules(hs)
diff --git a/tests/unittest.py b/tests/unittest.py
index c1cb5933fa..b21e7f1221 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -45,7 +45,7 @@ from typing_extensions import Concatenate, ParamSpec, Protocol
 from twisted.internet.defer import Deferred, ensureDeferred
 from twisted.python.failure import Failure
 from twisted.python.threadpool import ThreadPool
-from twisted.test.proto_helpers import MemoryReactor
+from twisted.test.proto_helpers import MemoryReactor, MemoryReactorClock
 from twisted.trial import unittest
 from twisted.web.resource import Resource
 from twisted.web.server import Request
@@ -82,7 +82,7 @@ from tests.server import (
 )
 from tests.test_utils import event_injection, setup_awaitable_errors
 from tests.test_utils.logging_setup import setup_logging
-from tests.utils import default_config, setupdb
+from tests.utils import checked_cast, default_config, setupdb
 
 setupdb()
 setup_logging()
@@ -296,7 +296,12 @@ class HomeserverTestCase(TestCase):
 
         from tests.rest.client.utils import RestHelper
 
-        self.helper = RestHelper(self.hs, self.site, getattr(self, "user_id", None))
+        self.helper = RestHelper(
+            self.hs,
+            checked_cast(MemoryReactorClock, self.hs.get_reactor()),
+            self.site,
+            getattr(self, "user_id", None),
+        )
 
         if hasattr(self, "user_id"):
             if self.hijack_auth: