diff options
author | David Robertson <davidr@element.io> | 2022-08-21 22:27:04 +0100 |
---|---|---|
committer | David Robertson <davidr@element.io> | 2022-08-21 22:27:04 +0100 |
commit | 48ae00e5bd6889b13d53d48aae731b50ac6eec3e (patch) | |
tree | ab977f2eb233d21dffbb5edb96aba42d223db831 | |
parent | annotate getResourceFor (diff) | |
download | synapse-48ae00e5bd6889b13d53d48aae731b50ac6eec3e.tar.xz |
Annotate ThreadedMemoryReactorClock
-rw-r--r-- | tests/server.py | 52 |
1 files changed, 39 insertions, 13 deletions
diff --git a/tests/server.py b/tests/server.py index 1e6b160409..0a90c850f4 100644 --- a/tests/server.py +++ b/tests/server.py @@ -22,12 +22,14 @@ import warnings from collections import deque from io import SEEK_END, BytesIO from typing import ( + Any, Callable, Dict, Iterable, List, MutableMapping, Optional, + Sequence, Tuple, Type, Union, @@ -35,8 +37,7 @@ from typing import ( from unittest.mock import Mock import attr -from twisted.web.iweb import IRequest -from typing_extensions import Deque +from typing_extensions import Deque, ParamSpec from zope.interface import implementer from twisted.internet import address, threads, udp @@ -45,19 +46,23 @@ from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed from twisted.internet.error import DNSLookupError from twisted.internet.interfaces import ( IAddress, + IConnector, IConsumer, IHostnameResolver, IProtocol, IPullProducer, IPushProducer, + IReactorFromThreads, IReactorPluggableNameResolver, IReactorTime, IResolverSimple, ITransport, ) +from twisted.internet.protocol import ClientFactory, DatagramProtocol from twisted.python.failure import Failure from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock from twisted.web.http_headers import Headers +from twisted.web.iweb import IRequest from twisted.web.resource import IResource from twisted.web.server import Request, Site @@ -91,6 +96,7 @@ logger = logging.getLogger(__name__) # the type of thing that can be passed into `make_request` in the headers list CustomHeaderType = Tuple[Union[str, bytes], Union[str, bytes]] +P = ParamSpec("P") class TimedOutException(Exception): @@ -392,17 +398,17 @@ def make_request( return channel -@implementer(IReactorPluggableNameResolver) +@implementer(IReactorPluggableNameResolver, IReactorFromThreads) class ThreadedMemoryReactorClock(MemoryReactorClock): """ A MemoryReactorClock that supports callFromThread. """ - def __init__(self): + def __init__(self) -> None: self.threadpool = ThreadPool(self) - self._tcp_callbacks: Dict[Tuple[str, int], Callable] = {} - self._udp = [] + self._tcp_callbacks: Dict[Tuple[str, int], Callable[[], None]] = {} + self._udp: List[udp.Port] = [] self.lookups: Dict[str, str] = {} self._thread_callbacks: Deque[Callable[[], None]] = deque() @@ -410,7 +416,9 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): @implementer(IResolverSimple) class FakeResolver: - def getHostByName(self, name, timeout=None): + def getHostByName( + self, name: str, timeout: Sequence[int] = () + ) -> "Deferred[str]": if name not in lookups: return fail(DNSLookupError("OH NO: unknown %s" % (name,))) return succeed(lookups[name]) @@ -421,13 +429,22 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): def installNameResolver(self, resolver: IHostnameResolver) -> IHostnameResolver: raise NotImplementedError() - def listenUDP(self, port, protocol, interface="", maxPacketSize=8196): + def listenUDP( + self, + port: int, + protocol: DatagramProtocol, + interface: str = "", + maxPacketSize: int = 8196, + ) -> udp.Port: p = udp.Port(port, protocol, interface, maxPacketSize, self) p.startListening() self._udp.append(p) return p - def callFromThread(self, callback, *args, **kwargs): + # Type-ignore: IReactorFromThreads doesn't use paramspec here. + def callFromThread( # type: ignore[override] + self, callback: Callable[P, Any], *args: P.args, **kwargs: P.kwargs + ) -> None: """ Make the callback fire in the next reactor iteration. """ @@ -436,10 +453,12 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): # separate queue. self._thread_callbacks.append(cb) - def getThreadPool(self): + def getThreadPool(self) -> "ThreadPool": return self.threadpool - def add_tcp_client_callback(self, host: str, port: int, callback: Callable): + def add_tcp_client_callback( + self, host: str, port: int, callback: Callable[[], None] + ) -> None: """Add a callback that will be invoked when we receive a connection attempt to the given IP/port using `connectTCP`. @@ -448,7 +467,14 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): """ self._tcp_callbacks[(host, port)] = callback - def connectTCP(self, host: str, port: int, factory, timeout=30, bindAddress=None): + def connectTCP( + self, + host: str, + port: int, + factory: ClientFactory, + timeout: float = 30, + bindAddress: Optional[Tuple[str, int]] = None, + ) -> IConnector: """Fake L{IReactorTCP.connectTCP}.""" conn = super().connectTCP( @@ -461,7 +487,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): return conn - def advance(self, amount): + def advance(self, amount: float) -> None: # first advance our reactor's time, and run any "callLater" callbacks that # makes ready super().advance(amount) |