diff --git a/tests/server.py b/tests/server.py
index 3e377585ce..f01708b77f 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -58,6 +58,7 @@ import twisted
from twisted.enterprise import adbapi
from twisted.internet import address, tcp, threads, udp
from twisted.internet._resolver import SimpleResolverComplexifier
+from twisted.internet.address import IPv4Address, IPv6Address
from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed
from twisted.internet.error import DNSLookupError
from twisted.internet.interfaces import (
@@ -73,6 +74,7 @@ from twisted.internet.interfaces import (
IReactorPluggableNameResolver,
IReactorTime,
IResolverSimple,
+ ITCPTransport,
ITransport,
)
from twisted.internet.protocol import ClientFactory, DatagramProtocol, Factory
@@ -223,9 +225,9 @@ class FakeChannel:
new_headers.addRawHeader(k, v)
headers = new_headers
- assert isinstance(
- headers, Headers
- ), f"headers are of the wrong type: {headers!r}"
+ assert isinstance(headers, Headers), (
+ f"headers are of the wrong type: {headers!r}"
+ )
self.result["headers"] = headers
@@ -341,7 +343,6 @@ class FakeSite:
self,
resource: IResource,
reactor: IReactorTime,
- experimental_cors_msc3886: bool = False,
):
"""
@@ -350,7 +351,6 @@ class FakeSite:
"""
self._resource = resource
self.reactor = reactor
- self.experimental_cors_msc3886 = experimental_cors_msc3886
def getResourceFor(self, request: Request) -> IResource:
return self._resource
@@ -780,7 +780,7 @@ def get_clock() -> Tuple[ThreadedMemoryReactorClock, Clock]:
return clock, hs_clock
-@implementer(ITransport)
+@implementer(ITCPTransport)
@attr.s(cmp=False, auto_attribs=True)
class FakeTransport:
"""
@@ -809,12 +809,12 @@ class FakeTransport:
will get called back for connectionLost() notifications etc.
"""
- _peer_address: IAddress = attr.Factory(
+ _peer_address: Union[IPv4Address, IPv6Address] = attr.Factory(
lambda: address.IPv4Address("TCP", "127.0.0.1", 5678)
)
"""The value to be returned by getPeer"""
- _host_address: IAddress = attr.Factory(
+ _host_address: Union[IPv4Address, IPv6Address] = attr.Factory(
lambda: address.IPv4Address("TCP", "127.0.0.1", 1234)
)
"""The value to be returned by getHost"""
@@ -826,10 +826,10 @@ class FakeTransport:
producer: Optional[IPushProducer] = None
autoflush: bool = True
- def getPeer(self) -> IAddress:
+ def getPeer(self) -> Union[IPv4Address, IPv6Address]:
return self._peer_address
- def getHost(self) -> IAddress:
+ def getHost(self) -> Union[IPv4Address, IPv6Address]:
return self._host_address
def loseConnection(self) -> None:
@@ -939,6 +939,51 @@ class FakeTransport:
logger.info("FakeTransport: Buffer now empty, completing disconnect")
self.disconnected = True
+ ## ITCPTransport methods. ##
+
+ def loseWriteConnection(self) -> None:
+ """
+ Half-close the write side of a TCP connection.
+
+ If the protocol instance this is attached to provides
+ IHalfCloseableProtocol, it will get notified when the operation is
+ done. When closing write connection, as with loseConnection this will
+ only happen when buffer has emptied and there is no registered
+ producer.
+ """
+ raise NotImplementedError()
+
+ def getTcpNoDelay(self) -> bool:
+ """
+ Return if C{TCP_NODELAY} is enabled.
+ """
+ return False
+
+ def setTcpNoDelay(self, enabled: bool) -> None:
+ """
+ Enable/disable C{TCP_NODELAY}.
+
+ Enabling C{TCP_NODELAY} turns off Nagle's algorithm. Small packets are
+ sent sooner, possibly at the expense of overall throughput.
+ """
+ # Ignore setting this.
+
+ def getTcpKeepAlive(self) -> bool:
+ """
+ Return if C{SO_KEEPALIVE} is enabled.
+ """
+ return False
+
+ def setTcpKeepAlive(self, enabled: bool) -> None:
+ """
+ Enable/disable C{SO_KEEPALIVE}.
+
+ Enabling C{SO_KEEPALIVE} sends packets periodically when the connection
+ is otherwise idle, usually once every two hours. They are intended
+ to allow detection of lost peers in a non-infinite amount of time.
+ """
+ # Ignore setting this.
+
def connect_client(
reactor: ThreadedMemoryReactorClock, client_id: int
@@ -1166,6 +1211,12 @@ def setup_test_homeserver(
hs.get_auth_handler().validate_hash = validate_hash # type: ignore[assignment]
+ # We need to replace the media threadpool with the fake test threadpool.
+ def thread_pool() -> threadpool.ThreadPool:
+ return reactor.getThreadPool()
+
+ hs.get_media_sender_thread_pool = thread_pool # type: ignore[method-assign]
+
# Load any configured modules into the homeserver
module_api = hs.get_module_api()
for module, module_config in hs.config.modules.loaded_modules:
|