diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index f6a6aed35e..20940c8107 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -22,6 +22,7 @@ from twisted.internet.protocol import Protocol
from twisted.internet.task import LoopingCall
from twisted.web.http import HTTPChannel
from twisted.web.resource import Resource
+from twisted.web.server import Request, Site
from synapse.app.generic_worker import (
GenericWorkerReplicationHandler,
@@ -32,7 +33,10 @@ from synapse.http.site import SynapseRequest, SynapseSite
from synapse.replication.http import ReplicationRestResource
from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
-from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
+from synapse.replication.tcp.resource import (
+ ReplicationStreamProtocolFactory,
+ ServerReplicationStreamProtocol,
+)
from synapse.server import HomeServer
from synapse.util import Clock
@@ -59,7 +63,9 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
# build a replication server
server_factory = ReplicationStreamProtocolFactory(hs)
self.streamer = hs.get_replication_streamer()
- self.server = server_factory.buildProtocol(None)
+ self.server = server_factory.buildProtocol(
+ None
+ ) # type: ServerReplicationStreamProtocol
# Make a new HomeServer object for the worker
self.reactor.lookups["testserv"] = "1.2.3.4"
@@ -155,9 +161,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
request_factory = OneShotRequestFactory()
# Set up the server side protocol
- channel = _PushHTTPChannel(self.reactor)
- channel.requestFactory = request_factory
- channel.site = self.site
+ channel = _PushHTTPChannel(self.reactor, request_factory, self.site)
# Connect client to server and vice versa.
client_to_server_transport = FakeTransport(
@@ -188,8 +192,9 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
fetching updates for given stream.
"""
+ path = request.path # type: bytes # type: ignore
self.assertRegex(
- request.path,
+ path,
br"^/_synapse/replication/get_repl_stream_updates/%s/[^/]+$"
% (stream_name.encode("ascii"),),
)
@@ -390,9 +395,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
request_factory = OneShotRequestFactory()
# Set up the server side protocol
- channel = _PushHTTPChannel(self.reactor)
- channel.requestFactory = request_factory
- channel.site = self._hs_to_site[hs]
+ channel = _PushHTTPChannel(self.reactor, request_factory, self._hs_to_site[hs])
# Connect client to server and vice versa.
client_to_server_transport = FakeTransport(
@@ -475,9 +478,13 @@ class _PushHTTPChannel(HTTPChannel):
makes it very hard to test.
"""
- def __init__(self, reactor: IReactorTime):
+ def __init__(
+ self, reactor: IReactorTime, request_factory: Callable[..., Request], site: Site
+ ):
super().__init__()
self.reactor = reactor
+ self.requestFactory = request_factory
+ self.site = site
self._pull_to_push_producer = None # type: Optional[_PullToPushProducer]
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index 36d1e6bc4a..9f77125fd4 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -105,7 +105,7 @@ class MediaStorageTests(unittest.HomeserverTestCase):
self.assertEqual(test_body, body)
-@attr.s
+@attr.s(slots=True, frozen=True)
class _TestImage:
"""An image for testing thumbnailing with the expected results
@@ -117,13 +117,15 @@ class _TestImage:
test should just check for success.
expected_scaled: The expected bytes from scaled thumbnailing, or None if
test should just check for a valid image returned.
+ expected_found: True if the file should exist on the server, or False if
+ a 404 is expected.
"""
data = attr.ib(type=bytes)
content_type = attr.ib(type=bytes)
extension = attr.ib(type=bytes)
- expected_cropped = attr.ib(type=Optional[bytes])
- expected_scaled = attr.ib(type=Optional[bytes])
+ expected_cropped = attr.ib(type=Optional[bytes], default=None)
+ expected_scaled = attr.ib(type=Optional[bytes], default=None)
expected_found = attr.ib(default=True, type=bool)
@@ -153,6 +155,21 @@ class _TestImage:
),
),
),
+ # small png with transparency.
+ (
+ _TestImage(
+ unhexlify(
+ b"89504e470d0a1a0a0000000d49484452000000010000000101000"
+ b"00000376ef9240000000274524e5300010194fdae0000000a4944"
+ b"4154789c636800000082008177cd72b60000000049454e44ae426"
+ b"082"
+ ),
+ b"image/png",
+ b".png",
+ # Note that we don't check the output since it varies across
+ # different versions of Pillow.
+ ),
+ ),
# small lossless webp
(
_TestImage(
@@ -162,8 +179,6 @@ class _TestImage:
),
b"image/webp",
b".webp",
- None,
- None,
),
),
# an empty file
@@ -172,9 +187,7 @@ class _TestImage:
b"",
b"image/gif",
b".gif",
- None,
- None,
- False,
+ expected_found=False,
),
),
],
diff --git a/tests/server.py b/tests/server.py
index 939a0008ca..863f6da738 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -188,7 +188,7 @@ class FakeSite:
def make_request(
reactor,
- site: Site,
+ site: Union[Site, FakeSite],
method,
path,
content=b"",
diff --git a/tests/test_utils/logging_setup.py b/tests/test_utils/logging_setup.py
index 52ae5c5713..74568b34f8 100644
--- a/tests/test_utils/logging_setup.py
+++ b/tests/test_utils/logging_setup.py
@@ -28,7 +28,7 @@ class ToTwistedHandler(logging.Handler):
def emit(self, record):
log_entry = self.format(record)
log_level = record.levelname.lower().replace("warning", "warn")
- self.tx_log.emit(
+ self.tx_log.emit( # type: ignore
twisted.logger.LogLevel.levelWithName(log_level), "{entry}", entry=log_entry
)
|