diff options
Diffstat (limited to 'tests/server.py')
-rw-r--r-- | tests/server.py | 38 |
1 files changed, 21 insertions, 17 deletions
diff --git a/tests/server.py b/tests/server.py index ce017ca0f6..df3f1564c9 100644 --- a/tests/server.py +++ b/tests/server.py @@ -43,6 +43,7 @@ from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed from twisted.internet.error import DNSLookupError from twisted.internet.interfaces import ( IAddress, + IConsumer, IHostnameResolver, IProtocol, IPullProducer, @@ -53,11 +54,7 @@ from twisted.internet.interfaces import ( ITransport, ) from twisted.python.failure import Failure -from twisted.test.proto_helpers import ( - AccumulatingProtocol, - MemoryReactor, - MemoryReactorClock, -) +from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock from twisted.web.http_headers import Headers from twisted.web.resource import IResource from twisted.web.server import Request, Site @@ -96,6 +93,7 @@ class TimedOutException(Exception): """ +@implementer(IConsumer) @attr.s(auto_attribs=True) class FakeChannel: """ @@ -104,7 +102,7 @@ class FakeChannel: """ site: Union[Site, "FakeSite"] - _reactor: MemoryReactor + _reactor: MemoryReactorClock result: dict = attr.Factory(dict) _ip: str = "127.0.0.1" _producer: Optional[Union[IPullProducer, IPushProducer]] = None @@ -122,7 +120,7 @@ class FakeChannel: self._request = request @property - def json_body(self): + def json_body(self) -> JsonDict: return json.loads(self.text_body) @property @@ -140,7 +138,7 @@ class FakeChannel: return self.result.get("done", False) @property - def code(self): + def code(self) -> int: if not self.result: raise Exception("No result yet.") return int(self.result["code"]) @@ -160,7 +158,7 @@ class FakeChannel: self.result["reason"] = reason self.result["headers"] = headers - def write(self, content): + def write(self, content: bytes) -> None: assert isinstance(content, bytes), "Should be bytes! " + repr(content) if "body" not in self.result: @@ -168,11 +166,16 @@ class FakeChannel: self.result["body"] += content - def registerProducer(self, producer, streaming): + # Type ignore: mypy doesn't like the fact that producer isn't an IProducer. + def registerProducer( # type: ignore[override] + self, + producer: Union[IPullProducer, IPushProducer], + streaming: bool, + ) -> None: self._producer = producer self.producerStreaming = streaming - def _produce(): + def _produce() -> None: if self._producer: self._producer.resumeProducing() self._reactor.callLater(0.1, _produce) @@ -180,31 +183,32 @@ class FakeChannel: if not streaming: self._reactor.callLater(0.0, _produce) - def unregisterProducer(self): + def unregisterProducer(self) -> None: if self._producer is None: return self._producer = None - def requestDone(self, _self): + def requestDone(self, _self: Request) -> None: self.result["done"] = True if isinstance(_self, SynapseRequest): + assert _self.logcontext is not None self.resource_usage = _self.logcontext.get_resource_usage() - def getPeer(self): + def getPeer(self) -> IAddress: # We give an address so that getClientAddress/getClientIP returns a non null entry, # causing us to record the MAU return address.IPv4Address("TCP", self._ip, 3423) - def getHost(self): + def getHost(self) -> IAddress: # this is called by Request.__init__ to configure Request.host. return address.IPv4Address("TCP", "127.0.0.1", 8888) - def isSecure(self): + def isSecure(self) -> bool: return False @property - def transport(self): + def transport(self) -> "FakeChannel": return self def await_result(self, timeout_ms: int = 1000) -> None: |