diff --git a/changelog.d/15084.misc b/changelog.d/15084.misc
new file mode 100644
index 0000000000..93ceaeafc9
--- /dev/null
+++ b/changelog.d/15084.misc
@@ -0,0 +1 @@
+Improve type hints.
diff --git a/mypy.ini b/mypy.ini
index ff6e04b12f..94562d0bce 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -31,8 +31,6 @@ exclude = (?x)
|synapse/storage/databases/__init__.py
|synapse/storage/databases/main/cache.py
|synapse/storage/schema/
-
- |tests/server.py
)$
[mypy-synapse.federation.transport.client]
diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py
index febcc1499d..e2a3bad065 100644
--- a/tests/appservice/test_scheduler.py
+++ b/tests/appservice/test_scheduler.py
@@ -11,12 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import TYPE_CHECKING, List, Optional, Sequence, Tuple, cast
+from typing import List, Optional, Sequence, Tuple, cast
from unittest.mock import Mock
from typing_extensions import TypeAlias
from twisted.internet import defer
+from twisted.test.proto_helpers import MemoryReactor
from synapse.appservice import (
ApplicationService,
@@ -40,9 +41,6 @@ from tests.test_utils import simple_async_mock
from ..utils import MockClock
-if TYPE_CHECKING:
- from twisted.internet.testing import MemoryReactor
-
class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
def setUp(self) -> None:
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index d27422515c..eb7f53fee5 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -30,7 +30,7 @@ from twisted.internet.interfaces import (
IOpenSSLClientConnectionCreator,
IProtocolFactory,
)
-from twisted.internet.protocol import Factory
+from twisted.internet.protocol import Factory, Protocol
from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
from twisted.web._newclient import ResponseNeverReceived
from twisted.web.client import Agent
@@ -466,7 +466,8 @@ class MatrixFederationAgentTests(unittest.TestCase):
else:
assert isinstance(proxy_server_transport, FakeTransport)
client_protocol = proxy_server_transport.other
- c2s_transport = client_protocol.transport
+ assert isinstance(client_protocol, Protocol)
+ c2s_transport = checked_cast(FakeTransport, client_protocol.transport)
c2s_transport.other = server_ssl_protocol
self.reactor.advance(0)
diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py
index 22fdc7f5f2..cc175052ac 100644
--- a/tests/http/test_proxyagent.py
+++ b/tests/http/test_proxyagent.py
@@ -28,7 +28,7 @@ from twisted.internet.endpoints import (
_WrappingProtocol,
)
from twisted.internet.interfaces import IProtocol, IProtocolFactory
-from twisted.internet.protocol import Factory
+from twisted.internet.protocol import Factory, Protocol
from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
from twisted.web.http import HTTPChannel
@@ -644,7 +644,8 @@ class MatrixFederationAgentTests(TestCase):
else:
assert isinstance(proxy_server_transport, FakeTransport)
client_protocol = proxy_server_transport.other
- c2s_transport = client_protocol.transport
+ assert isinstance(client_protocol, Protocol)
+ c2s_transport = checked_cast(FakeTransport, client_protocol.transport)
c2s_transport.other = server_ssl_protocol
self.reactor.advance(0)
diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py
index f4e1e7de43..a144610078 100644
--- a/tests/rest/client/test_auth.py
+++ b/tests/rest/client/test_auth.py
@@ -34,7 +34,7 @@ from synapse.util import Clock
from tests import unittest
from tests.handlers.test_oidc import HAS_OIDC
from tests.rest.client.utils import TEST_OIDC_CONFIG, TEST_OIDC_ISSUER
-from tests.server import FakeChannel, make_request
+from tests.server import FakeChannel
from tests.unittest import override_config, skip_unless
@@ -1322,16 +1322,8 @@ class OidcBackchannelLogoutTests(unittest.HomeserverTestCase):
channel = self.submit_logout_token(logout_token)
self.assertEqual(channel.code, 200)
- # Now try to exchange the login token
- channel = make_request(
- self.hs.get_reactor(),
- self.site,
- "POST",
- "/login",
- content={"type": "m.login.token", "token": login_token},
- )
- # It should have failed
- self.assertEqual(channel.code, 403)
+ # Now try to exchange the login token, it should fail.
+ self.helper.login_via_token(login_token, 403)
@override_config(
{
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index 8d6f2b6ff9..9532e5ddc1 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -36,6 +36,7 @@ from urllib.parse import urlencode
import attr
from typing_extensions import Literal
+from twisted.test.proto_helpers import MemoryReactorClock
from twisted.web.resource import Resource
from twisted.web.server import Site
@@ -67,6 +68,7 @@ class RestHelper:
"""
hs: HomeServer
+ reactor: MemoryReactorClock
site: Site
auth_user_id: Optional[str]
@@ -142,7 +144,7 @@ class RestHelper:
path = path + "?access_token=%s" % tok
channel = make_request(
- self.hs.get_reactor(),
+ self.reactor,
self.site,
"POST",
path,
@@ -216,7 +218,7 @@ class RestHelper:
data["reason"] = reason
channel = make_request(
- self.hs.get_reactor(),
+ self.reactor,
self.site,
"POST",
path,
@@ -313,7 +315,7 @@ class RestHelper:
data.update(extra_data or {})
channel = make_request(
- self.hs.get_reactor(),
+ self.reactor,
self.site,
"PUT",
path,
@@ -394,7 +396,7 @@ class RestHelper:
path = path + "?access_token=%s" % tok
channel = make_request(
- self.hs.get_reactor(),
+ self.reactor,
self.site,
"PUT",
path,
@@ -433,7 +435,7 @@ class RestHelper:
path = path + f"?access_token={tok}"
channel = make_request(
- self.hs.get_reactor(),
+ self.reactor,
self.site,
"GET",
path,
@@ -488,7 +490,7 @@ class RestHelper:
if body is not None:
content = json.dumps(body).encode("utf8")
- channel = make_request(self.hs.get_reactor(), self.site, method, path, content)
+ channel = make_request(self.reactor, self.site, method, path, content)
assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
expect_code,
@@ -573,8 +575,8 @@ class RestHelper:
image_length = len(image_data)
path = "/_matrix/media/r0/upload?filename=%s" % (filename,)
channel = make_request(
- self.hs.get_reactor(),
- FakeSite(resource, self.hs.get_reactor()),
+ self.reactor,
+ FakeSite(resource, self.reactor),
"POST",
path,
content=image_data,
@@ -603,7 +605,7 @@ class RestHelper:
expect_code: The return code to expect from attempting the whoami request
"""
channel = make_request(
- self.hs.get_reactor(),
+ self.reactor,
self.site,
"GET",
"account/whoami",
@@ -642,7 +644,7 @@ class RestHelper:
) -> Tuple[JsonDict, FakeAuthorizationGrant]:
"""Log in (as a new user) via OIDC
- Returns the result of the final token login.
+ Returns the result of the final token login and the fake authorization grant.
Requires that "oidc_config" in the homeserver config be set appropriately
(TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a
@@ -672,10 +674,28 @@ class RestHelper:
assert m, channel.text_body
login_token = m.group(1)
- # finally, submit the matrix login token to the login API, which gives us our
- # matrix access token and device id.
+ return self.login_via_token(login_token, expected_status), grant
+
+ def login_via_token(
+ self,
+ login_token: str,
+ expected_status: int = 200,
+ ) -> JsonDict:
+ """Submit the matrix login token to the login API, which gives us our
+ matrix access token and device id.Log in (as a new user) via OIDC
+
+ Returns the result of the token login.
+
+ Requires that "oidc_config" in the homeserver config be set appropriately
+ (TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a
+ "public_base_url".
+
+ Also requires the login servlet and the OIDC callback resource to be mounted at
+ the normal places.
+ """
+
channel = make_request(
- self.hs.get_reactor(),
+ self.reactor,
self.site,
"POST",
"/login",
@@ -684,7 +704,7 @@ class RestHelper:
assert (
channel.code == expected_status
), f"unexpected status in response: {channel.code}"
- return channel.json_body, grant
+ return channel.json_body
def auth_via_oidc(
self,
@@ -805,7 +825,7 @@ class RestHelper:
with fake_serer.patch_homeserver(hs=self.hs):
# now hit the callback URI with the right params and a made-up code
channel = make_request(
- self.hs.get_reactor(),
+ self.reactor,
self.site,
"GET",
callback_uri,
@@ -849,7 +869,7 @@ class RestHelper:
# is the easiest way of figuring out what the Host header ought to be set to
# to keep Synapse happy.
channel = make_request(
- self.hs.get_reactor(),
+ self.reactor,
self.site,
"GET",
uri,
@@ -867,7 +887,7 @@ class RestHelper:
location = get_location(channel)
parts = urllib.parse.urlsplit(location)
channel = make_request(
- self.hs.get_reactor(),
+ self.reactor,
self.site,
"GET",
urllib.parse.urlunsplit(("", "") + parts[2:]),
@@ -900,9 +920,7 @@ class RestHelper:
+ urllib.parse.urlencode({"session": ui_auth_session_id})
)
# hit the redirect url (which will issue a cookie and state)
- channel = make_request(
- self.hs.get_reactor(), self.site, "GET", sso_redirect_endpoint
- )
+ channel = make_request(self.reactor, self.site, "GET", sso_redirect_endpoint)
# that should serve a confirmation page
assert channel.code == HTTPStatus.OK, channel.text_body
channel.extract_cookies(cookies)
diff --git a/tests/server.py b/tests/server.py
index 237bcad8ba..5de9722766 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -22,20 +22,25 @@ import warnings
from collections import deque
from io import SEEK_END, BytesIO
from typing import (
+ Any,
+ Awaitable,
Callable,
Dict,
Iterable,
List,
MutableMapping,
Optional,
+ Sequence,
Tuple,
Type,
+ TypeVar,
Union,
+ cast,
)
from unittest.mock import Mock
import attr
-from typing_extensions import Deque
+from typing_extensions import Deque, ParamSpec
from zope.interface import implementer
from twisted.internet import address, threads, udp
@@ -44,8 +49,10 @@ from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed
from twisted.internet.error import DNSLookupError
from twisted.internet.interfaces import (
IAddress,
+ IConnector,
IConsumer,
IHostnameResolver,
+ IProducer,
IProtocol,
IPullProducer,
IPushProducer,
@@ -54,6 +61,8 @@ from twisted.internet.interfaces import (
IResolverSimple,
ITransport,
)
+from twisted.internet.protocol import ClientFactory, DatagramProtocol
+from twisted.python import threadpool
from twisted.python.failure import Failure
from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
from twisted.web.http_headers import Headers
@@ -61,6 +70,7 @@ from twisted.web.resource import IResource
from twisted.web.server import Request, Site
from synapse.config.database import DatabaseConnectionConfig
+from synapse.config.homeserver import HomeServerConfig
from synapse.events.presence_router import load_legacy_presence_router
from synapse.events.spamcheck import load_legacy_spam_checkers
from synapse.events.third_party_rules import load_legacy_third_party_event_rules
@@ -88,6 +98,9 @@ from tests.utils import (
logger = logging.getLogger(__name__)
+R = TypeVar("R")
+P = ParamSpec("P")
+
# the type of thing that can be passed into `make_request` in the headers list
CustomHeaderType = Tuple[Union[str, bytes], Union[str, bytes]]
@@ -98,12 +111,14 @@ class TimedOutException(Exception):
"""
-@implementer(IConsumer)
+@implementer(ITransport, IPushProducer, IConsumer)
@attr.s(auto_attribs=True)
class FakeChannel:
"""
A fake Twisted Web Channel (the part that interfaces with the
wire).
+
+ See twisted.web.http.HTTPChannel.
"""
site: Union[Site, "FakeSite"]
@@ -142,7 +157,7 @@ class FakeChannel:
Raises an exception if the request has not yet completed.
"""
- if not self.is_finished:
+ if not self.is_finished():
raise Exception("Request not yet completed")
return self.result["body"].decode("utf8")
@@ -165,27 +180,36 @@ class FakeChannel:
h.addRawHeader(*i)
return h
- def writeHeaders(self, version, code, reason, headers):
+ def writeHeaders(
+ self, version: bytes, code: bytes, reason: bytes, headers: Headers
+ ) -> None:
self.result["version"] = version
self.result["code"] = code
self.result["reason"] = reason
self.result["headers"] = headers
- def write(self, content: bytes) -> None:
- assert isinstance(content, bytes), "Should be bytes! " + repr(content)
+ def write(self, data: bytes) -> None:
+ assert isinstance(data, bytes), "Should be bytes! " + repr(data)
if "body" not in self.result:
self.result["body"] = b""
- self.result["body"] += content
+ self.result["body"] += data
+
+ def writeSequence(self, data: Iterable[bytes]) -> None:
+ for x in data:
+ self.write(x)
+
+ def loseConnection(self) -> None:
+ self.unregisterProducer()
+ self.transport.loseConnection()
# Type ignore: mypy doesn't like the fact that producer isn't an IProducer.
- def registerProducer( # type: ignore[override]
- self,
- producer: Union[IPullProducer, IPushProducer],
- streaming: bool,
- ) -> None:
- self._producer = producer
+ def registerProducer(self, producer: IProducer, streaming: bool) -> None:
+ # TODO This should ensure that the IProducer is an IPushProducer or
+ # IPullProducer, unfortunately twisted.protocols.basic.FileSender does
+ # implement those, but doesn't declare it.
+ self._producer = cast(Union[IPushProducer, IPullProducer], producer)
self.producerStreaming = streaming
def _produce() -> None:
@@ -202,6 +226,16 @@ class FakeChannel:
self._producer = None
+ def stopProducing(self) -> None:
+ if self._producer is not None:
+ self._producer.stopProducing()
+
+ def pauseProducing(self) -> None:
+ raise NotImplementedError()
+
+ def resumeProducing(self) -> None:
+ raise NotImplementedError()
+
def requestDone(self, _self: Request) -> None:
self.result["done"] = True
if isinstance(_self, SynapseRequest):
@@ -281,12 +315,12 @@ class FakeSite:
self.reactor = reactor
self.experimental_cors_msc3886 = experimental_cors_msc3886
- def getResourceFor(self, request):
+ def getResourceFor(self, request: Request) -> IResource:
return self._resource
def make_request(
- reactor,
+ reactor: MemoryReactorClock,
site: Union[Site, FakeSite],
method: Union[bytes, str],
path: Union[bytes, str],
@@ -409,19 +443,21 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
A MemoryReactorClock that supports callFromThread.
"""
- def __init__(self):
+ def __init__(self) -> None:
self.threadpool = ThreadPool(self)
self._tcp_callbacks: Dict[Tuple[str, int], Callable] = {}
- self._udp = []
+ self._udp: List[udp.Port] = []
self.lookups: Dict[str, str] = {}
- self._thread_callbacks: Deque[Callable[[], None]] = deque()
+ self._thread_callbacks: Deque[Callable[..., R]] = deque()
lookups = self.lookups
@implementer(IResolverSimple)
class FakeResolver:
- def getHostByName(self, name, timeout=None):
+ def getHostByName(
+ self, name: str, timeout: Optional[Sequence[int]] = None
+ ) -> "Deferred[str]":
if name not in lookups:
return fail(DNSLookupError("OH NO: unknown %s" % (name,)))
return succeed(lookups[name])
@@ -432,25 +468,44 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
def installNameResolver(self, resolver: IHostnameResolver) -> IHostnameResolver:
raise NotImplementedError()
- def listenUDP(self, port, protocol, interface="", maxPacketSize=8196):
+ def listenUDP(
+ self,
+ port: int,
+ protocol: DatagramProtocol,
+ interface: str = "",
+ maxPacketSize: int = 8196,
+ ) -> udp.Port:
p = udp.Port(port, protocol, interface, maxPacketSize, self)
p.startListening()
self._udp.append(p)
return p
- def callFromThread(self, callback, *args, **kwargs):
+ def callFromThread(
+ self, callable: Callable[..., Any], *args: object, **kwargs: object
+ ) -> None:
"""
Make the callback fire in the next reactor iteration.
"""
- cb = lambda: callback(*args, **kwargs)
+ cb = lambda: callable(*args, **kwargs)
# it's not safe to call callLater() here, so we append the callback to a
# separate queue.
self._thread_callbacks.append(cb)
- def getThreadPool(self):
- return self.threadpool
+ def callInThread(
+ self, callable: Callable[..., Any], *args: object, **kwargs: object
+ ) -> None:
+ raise NotImplementedError()
+
+ def suggestThreadPoolSize(self, size: int) -> None:
+ raise NotImplementedError()
+
+ def getThreadPool(self) -> "threadpool.ThreadPool":
+ # Cast to match super-class.
+ return cast(threadpool.ThreadPool, self.threadpool)
- def add_tcp_client_callback(self, host: str, port: int, callback: Callable):
+ def add_tcp_client_callback(
+ self, host: str, port: int, callback: Callable[[], None]
+ ) -> None:
"""Add a callback that will be invoked when we receive a connection
attempt to the given IP/port using `connectTCP`.
@@ -459,7 +514,14 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
"""
self._tcp_callbacks[(host, port)] = callback
- def connectTCP(self, host: str, port: int, factory, timeout=30, bindAddress=None):
+ def connectTCP(
+ self,
+ host: str,
+ port: int,
+ factory: ClientFactory,
+ timeout: float = 30,
+ bindAddress: Optional[Tuple[str, int]] = None,
+ ) -> IConnector:
"""Fake L{IReactorTCP.connectTCP}."""
conn = super().connectTCP(
@@ -472,7 +534,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
return conn
- def advance(self, amount):
+ def advance(self, amount: float) -> None:
# first advance our reactor's time, and run any "callLater" callbacks that
# makes ready
super().advance(amount)
@@ -500,25 +562,33 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
class ThreadPool:
"""
Threadless thread pool.
+
+ See twisted.python.threadpool.ThreadPool
"""
- def __init__(self, reactor):
+ def __init__(self, reactor: IReactorTime):
self._reactor = reactor
- def start(self):
+ def start(self) -> None:
pass
- def stop(self):
+ def stop(self) -> None:
pass
- def callInThreadWithCallback(self, onResult, function, *args, **kwargs):
- def _(res):
+ def callInThreadWithCallback(
+ self,
+ onResult: Callable[[bool, Union[Failure, R]], None],
+ function: Callable[P, R],
+ *args: P.args,
+ **kwargs: P.kwargs,
+ ) -> "Deferred[None]":
+ def _(res: Any) -> None:
if isinstance(res, Failure):
onResult(False, res)
else:
onResult(True, res)
- d = Deferred()
+ d: "Deferred[None]" = Deferred()
d.addCallback(lambda x: function(*args, **kwargs))
d.addBoth(_)
self._reactor.callLater(0, d.callback, True)
@@ -535,7 +605,9 @@ def _make_test_homeserver_synchronous(server: HomeServer) -> None:
for database in server.get_datastores().databases:
pool = database._db_pool
- def runWithConnection(func, *args, **kwargs):
+ def runWithConnection(
+ func: Callable[..., R], *args: Any, **kwargs: Any
+ ) -> Awaitable[R]:
return threads.deferToThreadPool(
pool._reactor,
pool.threadpool,
@@ -545,20 +617,23 @@ def _make_test_homeserver_synchronous(server: HomeServer) -> None:
**kwargs,
)
- def runInteraction(interaction, *args, **kwargs):
+ def runInteraction(
+ desc: str, func: Callable[..., R], *args: Any, **kwargs: Any
+ ) -> Awaitable[R]:
return threads.deferToThreadPool(
pool._reactor,
pool.threadpool,
pool._runInteraction,
- interaction,
+ desc,
+ func,
*args,
**kwargs,
)
- pool.runWithConnection = runWithConnection
- pool.runInteraction = runInteraction
+ pool.runWithConnection = runWithConnection # type: ignore[assignment]
+ pool.runInteraction = runInteraction # type: ignore[assignment]
# Replace the thread pool with a threadless 'thread' pool
- pool.threadpool = ThreadPool(clock._reactor)
+ pool.threadpool = ThreadPool(clock._reactor) # type: ignore[assignment]
pool.running = True
# We've just changed the Databases to run DB transactions on the same
@@ -573,7 +648,7 @@ def get_clock() -> Tuple[ThreadedMemoryReactorClock, Clock]:
@implementer(ITransport)
-@attr.s(cmp=False)
+@attr.s(cmp=False, auto_attribs=True)
class FakeTransport:
"""
A twisted.internet.interfaces.ITransport implementation which sends all its data
@@ -588,48 +663,50 @@ class FakeTransport:
If you want bidirectional communication, you'll need two instances.
"""
- other = attr.ib()
+ other: IProtocol
"""The Protocol object which will receive any data written to this transport.
-
- :type: twisted.internet.interfaces.IProtocol
"""
- _reactor = attr.ib()
+ _reactor: IReactorTime
"""Test reactor
-
- :type: twisted.internet.interfaces.IReactorTime
"""
- _protocol = attr.ib(default=None)
+ _protocol: Optional[IProtocol] = None
"""The Protocol which is producing data for this transport. Optional, but if set
will get called back for connectionLost() notifications etc.
"""
- _peer_address: Optional[IAddress] = attr.ib(default=None)
+ _peer_address: IAddress = attr.Factory(
+ lambda: address.IPv4Address("TCP", "127.0.0.1", 5678)
+ )
"""The value to be returned by getPeer"""
- _host_address: Optional[IAddress] = attr.ib(default=None)
+ _host_address: IAddress = attr.Factory(
+ lambda: address.IPv4Address("TCP", "127.0.0.1", 1234)
+ )
"""The value to be returned by getHost"""
disconnecting = False
disconnected = False
connected = True
- buffer = attr.ib(default=b"")
- producer = attr.ib(default=None)
- autoflush = attr.ib(default=True)
+ buffer: bytes = b""
+ producer: Optional[IPushProducer] = None
+ autoflush: bool = True
- def getPeer(self) -> Optional[IAddress]:
+ def getPeer(self) -> IAddress:
return self._peer_address
- def getHost(self) -> Optional[IAddress]:
+ def getHost(self) -> IAddress:
return self._host_address
- def loseConnection(self, reason=None):
+ def loseConnection(self) -> None:
if not self.disconnecting:
- logger.info("FakeTransport: loseConnection(%s)", reason)
+ logger.info("FakeTransport: loseConnection()")
self.disconnecting = True
if self._protocol:
- self._protocol.connectionLost(reason)
+ self._protocol.connectionLost(
+ Failure(RuntimeError("FakeTransport.loseConnection()"))
+ )
# if we still have data to write, delay until that is done
if self.buffer:
@@ -640,38 +717,38 @@ class FakeTransport:
self.connected = False
self.disconnected = True
- def abortConnection(self):
+ def abortConnection(self) -> None:
logger.info("FakeTransport: abortConnection()")
if not self.disconnecting:
self.disconnecting = True
if self._protocol:
- self._protocol.connectionLost(None)
+ self._protocol.connectionLost(None) # type: ignore[arg-type]
self.disconnected = True
- def pauseProducing(self):
+ def pauseProducing(self) -> None:
if not self.producer:
return
self.producer.pauseProducing()
- def resumeProducing(self):
+ def resumeProducing(self) -> None:
if not self.producer:
return
self.producer.resumeProducing()
- def unregisterProducer(self):
+ def unregisterProducer(self) -> None:
if not self.producer:
return
self.producer = None
- def registerProducer(self, producer, streaming):
+ def registerProducer(self, producer: IPushProducer, streaming: bool) -> None:
self.producer = producer
self.producerStreaming = streaming
- def _produce():
+ def _produce() -> None:
if not self.producer:
# we've been unregistered
return
@@ -683,7 +760,7 @@ class FakeTransport:
if not streaming:
self._reactor.callLater(0.0, _produce)
- def write(self, byt):
+ def write(self, byt: bytes) -> None:
if self.disconnecting:
raise Exception("Writing to disconnecting FakeTransport")
@@ -695,11 +772,11 @@ class FakeTransport:
if self.autoflush:
self._reactor.callLater(0.0, self.flush)
- def writeSequence(self, seq):
+ def writeSequence(self, seq: Iterable[bytes]) -> None:
for x in seq:
self.write(x)
- def flush(self, maxbytes=None):
+ def flush(self, maxbytes: Optional[int] = None) -> None:
if not self.buffer:
# nothing to do. Don't write empty buffers: it upsets the
# TLSMemoryBIOProtocol
@@ -750,17 +827,17 @@ def connect_client(
class TestHomeServer(HomeServer):
- DATASTORE_CLASS = DataStore
+ DATASTORE_CLASS = DataStore # type: ignore[assignment]
def setup_test_homeserver(
- cleanup_func,
- name="test",
- config=None,
- reactor=None,
+ cleanup_func: Callable[[Callable[[], None]], None],
+ name: str = "test",
+ config: Optional[HomeServerConfig] = None,
+ reactor: Optional[ISynapseReactor] = None,
homeserver_to_use: Type[HomeServer] = TestHomeServer,
- **kwargs,
-):
+ **kwargs: Any,
+) -> HomeServer:
"""
Setup a homeserver suitable for running tests against. Keyword arguments
are passed to the Homeserver constructor.
@@ -775,13 +852,14 @@ def setup_test_homeserver(
HomeserverTestCase.
"""
if reactor is None:
- from twisted.internet import reactor
+ from twisted.internet import reactor as _reactor
+
+ reactor = cast(ISynapseReactor, _reactor)
if config is None:
config = default_config(name, parse=True)
config.caches.resize_all_caches()
- config.ldap_enabled = False
if "clock" not in kwargs:
kwargs["clock"] = MockClock()
@@ -832,6 +910,8 @@ def setup_test_homeserver(
# Create the database before we actually try and connect to it, based off
# the template database we generate in setupdb()
if isinstance(db_engine, PostgresEngine):
+ import psycopg2.extensions
+
db_conn = db_engine.module.connect(
database=POSTGRES_BASE_DB,
user=POSTGRES_USER,
@@ -839,6 +919,7 @@ def setup_test_homeserver(
port=POSTGRES_PORT,
password=POSTGRES_PASSWORD,
)
+ assert isinstance(db_conn, psycopg2.extensions.connection)
db_conn.autocommit = True
cur = db_conn.cursor()
cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
@@ -867,14 +948,15 @@ def setup_test_homeserver(
hs.setup_background_tasks()
if isinstance(db_engine, PostgresEngine):
- database = hs.get_datastores().databases[0]
+ database_pool = hs.get_datastores().databases[0]
# We need to do cleanup on PostgreSQL
- def cleanup():
+ def cleanup() -> None:
import psycopg2
+ import psycopg2.extensions
# Close all the db pools
- database._db_pool.close()
+ database_pool._db_pool.close()
dropped = False
@@ -886,6 +968,7 @@ def setup_test_homeserver(
port=POSTGRES_PORT,
password=POSTGRES_PASSWORD,
)
+ assert isinstance(db_conn, psycopg2.extensions.connection)
db_conn.autocommit = True
cur = db_conn.cursor()
@@ -918,23 +1001,23 @@ def setup_test_homeserver(
# Need to let the HS build an auth handler and then mess with it
# because AuthHandler's constructor requires the HS, so we can't make one
# beforehand and pass it in to the HS's constructor (chicken / egg)
- async def hash(p):
+ async def hash(p: str) -> str:
return hashlib.md5(p.encode("utf8")).hexdigest()
- hs.get_auth_handler().hash = hash
+ hs.get_auth_handler().hash = hash # type: ignore[assignment]
- async def validate_hash(p, h):
+ async def validate_hash(p: str, h: str) -> bool:
return hashlib.md5(p.encode("utf8")).hexdigest() == h
- hs.get_auth_handler().validate_hash = validate_hash
+ hs.get_auth_handler().validate_hash = validate_hash # type: ignore[assignment]
# Make the threadpool and database transactions synchronous for testing.
_make_test_homeserver_synchronous(hs)
# Load any configured modules into the homeserver
module_api = hs.get_module_api()
- for module, config in hs.config.modules.loaded_modules:
- module(config=config, api=module_api)
+ for module, module_config in hs.config.modules.loaded_modules:
+ module(config=module_config, api=module_api)
load_legacy_spam_checkers(hs)
load_legacy_third_party_event_rules(hs)
diff --git a/tests/unittest.py b/tests/unittest.py
index c1cb5933fa..b21e7f1221 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -45,7 +45,7 @@ from typing_extensions import Concatenate, ParamSpec, Protocol
from twisted.internet.defer import Deferred, ensureDeferred
from twisted.python.failure import Failure
from twisted.python.threadpool import ThreadPool
-from twisted.test.proto_helpers import MemoryReactor
+from twisted.test.proto_helpers import MemoryReactor, MemoryReactorClock
from twisted.trial import unittest
from twisted.web.resource import Resource
from twisted.web.server import Request
@@ -82,7 +82,7 @@ from tests.server import (
)
from tests.test_utils import event_injection, setup_awaitable_errors
from tests.test_utils.logging_setup import setup_logging
-from tests.utils import default_config, setupdb
+from tests.utils import checked_cast, default_config, setupdb
setupdb()
setup_logging()
@@ -296,7 +296,12 @@ class HomeserverTestCase(TestCase):
from tests.rest.client.utils import RestHelper
- self.helper = RestHelper(self.hs, self.site, getattr(self, "user_id", None))
+ self.helper = RestHelper(
+ self.hs,
+ checked_cast(MemoryReactorClock, self.hs.get_reactor()),
+ self.site,
+ getattr(self, "user_id", None),
+ )
if hasattr(self, "user_id"):
if self.hijack_auth:
|