summary refs log tree commit diff
diff options
context:
space:
mode:
authorDavid Robertson <davidr@element.io>2022-08-21 22:27:04 +0100
committerDavid Robertson <davidr@element.io>2022-08-21 22:27:04 +0100
commit48ae00e5bd6889b13d53d48aae731b50ac6eec3e (patch)
treeab977f2eb233d21dffbb5edb96aba42d223db831
parentannotate getResourceFor (diff)
downloadsynapse-48ae00e5bd6889b13d53d48aae731b50ac6eec3e.tar.xz
Annotate ThreadedMemoryReactorClock
-rw-r--r--tests/server.py52
1 files changed, 39 insertions, 13 deletions
diff --git a/tests/server.py b/tests/server.py
index 1e6b160409..0a90c850f4 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -22,12 +22,14 @@ import warnings
 from collections import deque
 from io import SEEK_END, BytesIO
 from typing import (
+    Any,
     Callable,
     Dict,
     Iterable,
     List,
     MutableMapping,
     Optional,
+    Sequence,
     Tuple,
     Type,
     Union,
@@ -35,8 +37,7 @@ from typing import (
 from unittest.mock import Mock
 
 import attr
-from twisted.web.iweb import IRequest
-from typing_extensions import Deque
+from typing_extensions import Deque, ParamSpec
 from zope.interface import implementer
 
 from twisted.internet import address, threads, udp
@@ -45,19 +46,23 @@ from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed
 from twisted.internet.error import DNSLookupError
 from twisted.internet.interfaces import (
     IAddress,
+    IConnector,
     IConsumer,
     IHostnameResolver,
     IProtocol,
     IPullProducer,
     IPushProducer,
+    IReactorFromThreads,
     IReactorPluggableNameResolver,
     IReactorTime,
     IResolverSimple,
     ITransport,
 )
+from twisted.internet.protocol import ClientFactory, DatagramProtocol
 from twisted.python.failure import Failure
 from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
 from twisted.web.http_headers import Headers
+from twisted.web.iweb import IRequest
 from twisted.web.resource import IResource
 from twisted.web.server import Request, Site
 
@@ -91,6 +96,7 @@ logger = logging.getLogger(__name__)
 
 # the type of thing that can be passed into `make_request` in the headers list
 CustomHeaderType = Tuple[Union[str, bytes], Union[str, bytes]]
+P = ParamSpec("P")
 
 
 class TimedOutException(Exception):
@@ -392,17 +398,17 @@ def make_request(
     return channel
 
 
-@implementer(IReactorPluggableNameResolver)
+@implementer(IReactorPluggableNameResolver, IReactorFromThreads)
 class ThreadedMemoryReactorClock(MemoryReactorClock):
     """
     A MemoryReactorClock that supports callFromThread.
     """
 
-    def __init__(self):
+    def __init__(self) -> None:
         self.threadpool = ThreadPool(self)
 
-        self._tcp_callbacks: Dict[Tuple[str, int], Callable] = {}
-        self._udp = []
+        self._tcp_callbacks: Dict[Tuple[str, int], Callable[[], None]] = {}
+        self._udp: List[udp.Port] = []
         self.lookups: Dict[str, str] = {}
         self._thread_callbacks: Deque[Callable[[], None]] = deque()
 
@@ -410,7 +416,9 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
 
         @implementer(IResolverSimple)
         class FakeResolver:
-            def getHostByName(self, name, timeout=None):
+            def getHostByName(
+                self, name: str, timeout: Sequence[int] = ()
+            ) -> "Deferred[str]":
                 if name not in lookups:
                     return fail(DNSLookupError("OH NO: unknown %s" % (name,)))
                 return succeed(lookups[name])
@@ -421,13 +429,22 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
     def installNameResolver(self, resolver: IHostnameResolver) -> IHostnameResolver:
         raise NotImplementedError()
 
-    def listenUDP(self, port, protocol, interface="", maxPacketSize=8196):
+    def listenUDP(
+        self,
+        port: int,
+        protocol: DatagramProtocol,
+        interface: str = "",
+        maxPacketSize: int = 8196,
+    ) -> udp.Port:
         p = udp.Port(port, protocol, interface, maxPacketSize, self)
         p.startListening()
         self._udp.append(p)
         return p
 
-    def callFromThread(self, callback, *args, **kwargs):
+    # Type-ignore: IReactorFromThreads doesn't use paramspec here.
+    def callFromThread(  # type: ignore[override]
+        self, callback: Callable[P, Any], *args: P.args, **kwargs: P.kwargs
+    ) -> None:
         """
         Make the callback fire in the next reactor iteration.
         """
@@ -436,10 +453,12 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
         # separate queue.
         self._thread_callbacks.append(cb)
 
-    def getThreadPool(self):
+    def getThreadPool(self) -> "ThreadPool":
         return self.threadpool
 
-    def add_tcp_client_callback(self, host: str, port: int, callback: Callable):
+    def add_tcp_client_callback(
+        self, host: str, port: int, callback: Callable[[], None]
+    ) -> None:
         """Add a callback that will be invoked when we receive a connection
         attempt to the given IP/port using `connectTCP`.
 
@@ -448,7 +467,14 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
         """
         self._tcp_callbacks[(host, port)] = callback
 
-    def connectTCP(self, host: str, port: int, factory, timeout=30, bindAddress=None):
+    def connectTCP(
+        self,
+        host: str,
+        port: int,
+        factory: ClientFactory,
+        timeout: float = 30,
+        bindAddress: Optional[Tuple[str, int]] = None,
+    ) -> IConnector:
         """Fake L{IReactorTCP.connectTCP}."""
 
         conn = super().connectTCP(
@@ -461,7 +487,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
 
         return conn
 
-    def advance(self, amount):
+    def advance(self, amount: float) -> None:
         # first advance our reactor's time, and run any "callLater" callbacks that
         # makes ready
         super().advance(amount)