diff --git a/tests/server.py b/tests/server.py
index ce017ca0f6..df3f1564c9 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -43,6 +43,7 @@ from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed
from twisted.internet.error import DNSLookupError
from twisted.internet.interfaces import (
IAddress,
+ IConsumer,
IHostnameResolver,
IProtocol,
IPullProducer,
@@ -53,11 +54,7 @@ from twisted.internet.interfaces import (
ITransport,
)
from twisted.python.failure import Failure
-from twisted.test.proto_helpers import (
- AccumulatingProtocol,
- MemoryReactor,
- MemoryReactorClock,
-)
+from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
from twisted.web.http_headers import Headers
from twisted.web.resource import IResource
from twisted.web.server import Request, Site
@@ -96,6 +93,7 @@ class TimedOutException(Exception):
"""
+@implementer(IConsumer)
@attr.s(auto_attribs=True)
class FakeChannel:
"""
@@ -104,7 +102,7 @@ class FakeChannel:
"""
site: Union[Site, "FakeSite"]
- _reactor: MemoryReactor
+ _reactor: MemoryReactorClock
result: dict = attr.Factory(dict)
_ip: str = "127.0.0.1"
_producer: Optional[Union[IPullProducer, IPushProducer]] = None
@@ -122,7 +120,7 @@ class FakeChannel:
self._request = request
@property
- def json_body(self):
+ def json_body(self) -> JsonDict:
return json.loads(self.text_body)
@property
@@ -140,7 +138,7 @@ class FakeChannel:
return self.result.get("done", False)
@property
- def code(self):
+ def code(self) -> int:
if not self.result:
raise Exception("No result yet.")
return int(self.result["code"])
@@ -160,7 +158,7 @@ class FakeChannel:
self.result["reason"] = reason
self.result["headers"] = headers
- def write(self, content):
+ def write(self, content: bytes) -> None:
assert isinstance(content, bytes), "Should be bytes! " + repr(content)
if "body" not in self.result:
@@ -168,11 +166,16 @@ class FakeChannel:
self.result["body"] += content
- def registerProducer(self, producer, streaming):
+ # Type ignore: mypy doesn't like the fact that producer isn't an IProducer.
+ def registerProducer( # type: ignore[override]
+ self,
+ producer: Union[IPullProducer, IPushProducer],
+ streaming: bool,
+ ) -> None:
self._producer = producer
self.producerStreaming = streaming
- def _produce():
+ def _produce() -> None:
if self._producer:
self._producer.resumeProducing()
self._reactor.callLater(0.1, _produce)
@@ -180,31 +183,32 @@ class FakeChannel:
if not streaming:
self._reactor.callLater(0.0, _produce)
- def unregisterProducer(self):
+ def unregisterProducer(self) -> None:
if self._producer is None:
return
self._producer = None
- def requestDone(self, _self):
+ def requestDone(self, _self: Request) -> None:
self.result["done"] = True
if isinstance(_self, SynapseRequest):
+ assert _self.logcontext is not None
self.resource_usage = _self.logcontext.get_resource_usage()
- def getPeer(self):
+ def getPeer(self) -> IAddress:
# We give an address so that getClientAddress/getClientIP returns a non null entry,
# causing us to record the MAU
return address.IPv4Address("TCP", self._ip, 3423)
- def getHost(self):
+ def getHost(self) -> IAddress:
# this is called by Request.__init__ to configure Request.host.
return address.IPv4Address("TCP", "127.0.0.1", 8888)
- def isSecure(self):
+ def isSecure(self) -> bool:
return False
@property
- def transport(self):
+ def transport(self) -> "FakeChannel":
return self
def await_result(self, timeout_ms: int = 1000) -> None:
|