summary refs log tree commit diff
path: root/tests/server.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--tests/server.py71
1 files changed, 61 insertions, 10 deletions
diff --git a/tests/server.py b/tests/server.py

index 3e377585ce..f01708b77f 100644 --- a/tests/server.py +++ b/tests/server.py
@@ -58,6 +58,7 @@ import twisted from twisted.enterprise import adbapi from twisted.internet import address, tcp, threads, udp from twisted.internet._resolver import SimpleResolverComplexifier +from twisted.internet.address import IPv4Address, IPv6Address from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed from twisted.internet.error import DNSLookupError from twisted.internet.interfaces import ( @@ -73,6 +74,7 @@ from twisted.internet.interfaces import ( IReactorPluggableNameResolver, IReactorTime, IResolverSimple, + ITCPTransport, ITransport, ) from twisted.internet.protocol import ClientFactory, DatagramProtocol, Factory @@ -223,9 +225,9 @@ class FakeChannel: new_headers.addRawHeader(k, v) headers = new_headers - assert isinstance( - headers, Headers - ), f"headers are of the wrong type: {headers!r}" + assert isinstance(headers, Headers), ( + f"headers are of the wrong type: {headers!r}" + ) self.result["headers"] = headers @@ -341,7 +343,6 @@ class FakeSite: self, resource: IResource, reactor: IReactorTime, - experimental_cors_msc3886: bool = False, ): """ @@ -350,7 +351,6 @@ class FakeSite: """ self._resource = resource self.reactor = reactor - self.experimental_cors_msc3886 = experimental_cors_msc3886 def getResourceFor(self, request: Request) -> IResource: return self._resource @@ -780,7 +780,7 @@ def get_clock() -> Tuple[ThreadedMemoryReactorClock, Clock]: return clock, hs_clock -@implementer(ITransport) +@implementer(ITCPTransport) @attr.s(cmp=False, auto_attribs=True) class FakeTransport: """ @@ -809,12 +809,12 @@ class FakeTransport: will get called back for connectionLost() notifications etc. """ - _peer_address: IAddress = attr.Factory( + _peer_address: Union[IPv4Address, IPv6Address] = attr.Factory( lambda: address.IPv4Address("TCP", "127.0.0.1", 5678) ) """The value to be returned by getPeer""" - _host_address: IAddress = attr.Factory( + _host_address: Union[IPv4Address, IPv6Address] = attr.Factory( lambda: address.IPv4Address("TCP", "127.0.0.1", 1234) ) """The value to be returned by getHost""" @@ -826,10 +826,10 @@ class FakeTransport: producer: Optional[IPushProducer] = None autoflush: bool = True - def getPeer(self) -> IAddress: + def getPeer(self) -> Union[IPv4Address, IPv6Address]: return self._peer_address - def getHost(self) -> IAddress: + def getHost(self) -> Union[IPv4Address, IPv6Address]: return self._host_address def loseConnection(self) -> None: @@ -939,6 +939,51 @@ class FakeTransport: logger.info("FakeTransport: Buffer now empty, completing disconnect") self.disconnected = True + ## ITCPTransport methods. ## + + def loseWriteConnection(self) -> None: + """ + Half-close the write side of a TCP connection. + + If the protocol instance this is attached to provides + IHalfCloseableProtocol, it will get notified when the operation is + done. When closing write connection, as with loseConnection this will + only happen when buffer has emptied and there is no registered + producer. + """ + raise NotImplementedError() + + def getTcpNoDelay(self) -> bool: + """ + Return if C{TCP_NODELAY} is enabled. + """ + return False + + def setTcpNoDelay(self, enabled: bool) -> None: + """ + Enable/disable C{TCP_NODELAY}. + + Enabling C{TCP_NODELAY} turns off Nagle's algorithm. Small packets are + sent sooner, possibly at the expense of overall throughput. + """ + # Ignore setting this. + + def getTcpKeepAlive(self) -> bool: + """ + Return if C{SO_KEEPALIVE} is enabled. + """ + return False + + def setTcpKeepAlive(self, enabled: bool) -> None: + """ + Enable/disable C{SO_KEEPALIVE}. + + Enabling C{SO_KEEPALIVE} sends packets periodically when the connection + is otherwise idle, usually once every two hours. They are intended + to allow detection of lost peers in a non-infinite amount of time. + """ + # Ignore setting this. + def connect_client( reactor: ThreadedMemoryReactorClock, client_id: int @@ -1166,6 +1211,12 @@ def setup_test_homeserver( hs.get_auth_handler().validate_hash = validate_hash # type: ignore[assignment] + # We need to replace the media threadpool with the fake test threadpool. + def thread_pool() -> threadpool.ThreadPool: + return reactor.getThreadPool() + + hs.get_media_sender_thread_pool = thread_pool # type: ignore[method-assign] + # Load any configured modules into the homeserver module_api = hs.get_module_api() for module, module_config in hs.config.modules.loaded_modules: