summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/replication/test_multi_media_repo.py277
-rw-r--r--tests/server.py2
-rw-r--r--tests/test_utils/__init__.py34
-rw-r--r--tests/unittest.py6
-rw-r--r--tests/util/caches/test_descriptors.py60
5 files changed, 375 insertions, 4 deletions
diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py
new file mode 100644

index 0000000000..77c261dbf7 --- /dev/null +++ b/tests/replication/test_multi_media_repo.py
@@ -0,0 +1,277 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import os +from binascii import unhexlify +from typing import Tuple + +from twisted.internet.protocol import Factory +from twisted.protocols.tls import TLSMemoryBIOFactory +from twisted.web.http import HTTPChannel +from twisted.web.server import Request + +from synapse.rest import admin +from synapse.rest.client.v1 import login +from synapse.server import HomeServer + +from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file +from tests.replication._base import BaseMultiWorkerStreamTestCase +from tests.server import FakeChannel, FakeTransport + +logger = logging.getLogger(__name__) + +test_server_connection_factory = None + + +class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase): + """Checks running multiple media repos work correctly. + """ + + servlets = [ + admin.register_servlets_for_client_rest_resource, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.user_id = self.register_user("user", "pass") + self.access_token = self.login("user", "pass") + + self.reactor.lookups["example.com"] = "127.0.0.2" + + def default_config(self): + conf = super().default_config() + conf["federation_custom_ca_list"] = [get_test_ca_cert_file()] + return conf + + def _get_media_req( + self, hs: HomeServer, target: str, media_id: str + ) -> Tuple[FakeChannel, Request]: + """Request some remote media from the given HS by calling the download + API. + + This then triggers an outbound request from the HS to the target. + + Returns: + The channel for the *client* request and the *outbound* request for + the media which the caller should respond to. + """ + + request, channel = self.make_request( + "GET", + "/{}/{}".format(target, media_id), + shorthand=False, + access_token=self.access_token, + ) + request.render(hs.get_media_repository_resource().children[b"download"]) + self.pump() + + clients = self.reactor.tcpClients + self.assertGreaterEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients.pop() + + # build the test server + server_tls_protocol = _build_test_server(get_connection_factory()) + + # now, tell the client protocol factory to build the client protocol (it will be a + # _WrappingProtocol, around a TLSMemoryBIOProtocol, around an + # HTTP11ClientProtocol) and wire the output of said protocol up to the server via + # a FakeTransport. + # + # Normally this would be done by the TCP socket code in Twisted, but we are + # stubbing that out here. + client_protocol = client_factory.buildProtocol(None) + client_protocol.makeConnection( + FakeTransport(server_tls_protocol, self.reactor, client_protocol) + ) + + # tell the server tls protocol to send its stuff back to the client, too + server_tls_protocol.makeConnection( + FakeTransport(client_protocol, self.reactor, server_tls_protocol) + ) + + # fish the test server back out of the server-side TLS protocol. + http_server = server_tls_protocol.wrappedProtocol + + # give the reactor a pump to get the TLS juices flowing. + self.reactor.pump((0.1,)) + + self.assertEqual(len(http_server.requests), 1) + request = http_server.requests[0] + + self.assertEqual(request.method, b"GET") + self.assertEqual( + request.path, + "/_matrix/media/r0/download/{}/{}".format(target, media_id).encode("utf-8"), + ) + self.assertEqual( + request.requestHeaders.getRawHeaders(b"host"), [target.encode("utf-8")] + ) + + return channel, request + + def test_basic(self): + """Test basic fetching of remote media from a single worker. + """ + hs1 = self.make_worker_hs("synapse.app.generic_worker") + + channel, request = self._get_media_req(hs1, "example.com:443", "ABC123") + + request.setResponseCode(200) + request.responseHeaders.setRawHeaders(b"Content-Type", [b"text/plain"]) + request.write(b"Hello!") + request.finish() + + self.pump(0.1) + + self.assertEqual(channel.code, 200) + self.assertEqual(channel.result["body"], b"Hello!") + + def test_download_simple_file_race(self): + """Test that fetching remote media from two different processes at the + same time works. + """ + hs1 = self.make_worker_hs("synapse.app.generic_worker") + hs2 = self.make_worker_hs("synapse.app.generic_worker") + + start_count = self._count_remote_media() + + # Make two requests without responding to the outbound media requests. + channel1, request1 = self._get_media_req(hs1, "example.com:443", "ABC123") + channel2, request2 = self._get_media_req(hs2, "example.com:443", "ABC123") + + # Respond to the first outbound media request and check that the client + # request is successful + request1.setResponseCode(200) + request1.responseHeaders.setRawHeaders(b"Content-Type", [b"text/plain"]) + request1.write(b"Hello!") + request1.finish() + + self.pump(0.1) + + self.assertEqual(channel1.code, 200, channel1.result["body"]) + self.assertEqual(channel1.result["body"], b"Hello!") + + # Now respond to the second with the same content. + request2.setResponseCode(200) + request2.responseHeaders.setRawHeaders(b"Content-Type", [b"text/plain"]) + request2.write(b"Hello!") + request2.finish() + + self.pump(0.1) + + self.assertEqual(channel2.code, 200, channel2.result["body"]) + self.assertEqual(channel2.result["body"], b"Hello!") + + # We expect only one new file to have been persisted. + self.assertEqual(start_count + 1, self._count_remote_media()) + + def test_download_image_race(self): + """Test that fetching remote *images* from two different processes at + the same time works. + + This checks that races generating thumbnails are handled correctly. + """ + hs1 = self.make_worker_hs("synapse.app.generic_worker") + hs2 = self.make_worker_hs("synapse.app.generic_worker") + + start_count = self._count_remote_thumbnails() + + channel1, request1 = self._get_media_req(hs1, "example.com:443", "PIC1") + channel2, request2 = self._get_media_req(hs2, "example.com:443", "PIC1") + + png_data = unhexlify( + b"89504e470d0a1a0a0000000d4948445200000001000000010806" + b"0000001f15c4890000000a49444154789c63000100000500010d" + b"0a2db40000000049454e44ae426082" + ) + + request1.setResponseCode(200) + request1.responseHeaders.setRawHeaders(b"Content-Type", [b"image/png"]) + request1.write(png_data) + request1.finish() + + self.pump(0.1) + + self.assertEqual(channel1.code, 200, channel1.result["body"]) + self.assertEqual(channel1.result["body"], png_data) + + request2.setResponseCode(200) + request2.responseHeaders.setRawHeaders(b"Content-Type", [b"image/png"]) + request2.write(png_data) + request2.finish() + + self.pump(0.1) + + self.assertEqual(channel2.code, 200, channel2.result["body"]) + self.assertEqual(channel2.result["body"], png_data) + + # We expect only three new thumbnails to have been persisted. + self.assertEqual(start_count + 3, self._count_remote_thumbnails()) + + def _count_remote_media(self) -> int: + """Count the number of files in our remote media directory. + """ + path = os.path.join( + self.hs.get_media_repository().primary_base_path, "remote_content" + ) + return sum(len(files) for _, _, files in os.walk(path)) + + def _count_remote_thumbnails(self) -> int: + """Count the number of files in our remote thumbnails directory. + """ + path = os.path.join( + self.hs.get_media_repository().primary_base_path, "remote_thumbnail" + ) + return sum(len(files) for _, _, files in os.walk(path)) + + +def get_connection_factory(): + # this needs to happen once, but not until we are ready to run the first test + global test_server_connection_factory + if test_server_connection_factory is None: + test_server_connection_factory = TestServerTLSConnectionFactory( + sanlist=[b"DNS:example.com"] + ) + return test_server_connection_factory + + +def _build_test_server(connection_creator): + """Construct a test server + + This builds an HTTP channel, wrapped with a TLSMemoryBIOProtocol + + Args: + connection_creator (IOpenSSLServerConnectionCreator): thing to build + SSL connections + sanlist (list[bytes]): list of the SAN entries for the cert returned + by the server + + Returns: + TLSMemoryBIOProtocol + """ + server_factory = Factory.forProtocol(HTTPChannel) + # Request.finish expects the factory to have a 'log' method. + server_factory.log = _log_request + + server_tls_factory = TLSMemoryBIOFactory( + connection_creator, isClient=False, wrappedFactory=server_factory + ) + + return server_tls_factory.buildProtocol(None) + + +def _log_request(request): + """Implements Factory.log, which is expected by Request.finish""" + logger.info("Completed request %s", request) diff --git a/tests/server.py b/tests/server.py
index b97003fa5a..3dd2cfc072 100644 --- a/tests/server.py +++ b/tests/server.py
@@ -46,7 +46,7 @@ class FakeChannel: site = attr.ib(type=Site) _reactor = attr.ib() - result = attr.ib(default=attr.Factory(dict)) + result = attr.ib(type=dict, default=attr.Factory(dict)) _producer = None @property diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index a298cc0fd3..d232b72264 100644 --- a/tests/test_utils/__init__.py +++ b/tests/test_utils/__init__.py
@@ -17,8 +17,10 @@ """ Utilities for running the unit tests """ +import sys +import warnings from asyncio import Future -from typing import Any, Awaitable, TypeVar +from typing import Any, Awaitable, Callable, TypeVar TV = TypeVar("TV") @@ -48,3 +50,33 @@ def make_awaitable(result: Any) -> Awaitable[Any]: future = Future() # type: ignore future.set_result(result) return future + + +def setup_awaitable_errors() -> Callable[[], None]: + """ + Convert warnings from a non-awaited coroutines into errors. + """ + warnings.simplefilter("error", RuntimeWarning) + + # unraisablehook was added in Python 3.8. + if not hasattr(sys, "unraisablehook"): + return lambda: None + + # State shared between unraisablehook and check_for_unraisable_exceptions. + unraisable_exceptions = [] + orig_unraisablehook = sys.unraisablehook # type: ignore + + def unraisablehook(unraisable): + unraisable_exceptions.append(unraisable.exc_value) + + def cleanup(): + """ + A method to be used as a clean-up that fails a test-case if there are any new unraisable exceptions. + """ + sys.unraisablehook = orig_unraisablehook # type: ignore + if unraisable_exceptions: + raise unraisable_exceptions.pop() + + sys.unraisablehook = unraisablehook # type: ignore + + return cleanup diff --git a/tests/unittest.py b/tests/unittest.py
index 257f465897..08cf9b10c5 100644 --- a/tests/unittest.py +++ b/tests/unittest.py
@@ -54,7 +54,7 @@ from tests.server import ( render, setup_test_homeserver, ) -from tests.test_utils import event_injection +from tests.test_utils import event_injection, setup_awaitable_errors from tests.test_utils.logging_setup import setup_logging from tests.utils import default_config, setupdb @@ -119,6 +119,10 @@ class TestCase(unittest.TestCase): logging.getLogger().setLevel(level) + # Trial messes with the warnings configuration, thus this has to be + # done in the context of an individual TestCase. + self.addCleanup(setup_awaitable_errors()) + return orig() @around(self) diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 2ad08f541b..cf1e3203a4 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py
@@ -29,13 +29,46 @@ from synapse.logging.context import ( make_deferred_yieldable, ) from synapse.util.caches import descriptors -from synapse.util.caches.descriptors import cached +from synapse.util.caches.descriptors import cached, lru_cache from tests import unittest +from tests.test_utils import get_awaitable_result logger = logging.getLogger(__name__) +class LruCacheDecoratorTestCase(unittest.TestCase): + def test_base(self): + class Cls: + def __init__(self): + self.mock = mock.Mock() + + @lru_cache() + def fn(self, arg1, arg2): + return self.mock(arg1, arg2) + + obj = Cls() + obj.mock.return_value = "fish" + r = obj.fn(1, 2) + self.assertEqual(r, "fish") + obj.mock.assert_called_once_with(1, 2) + obj.mock.reset_mock() + + # a call with different params should call the mock again + obj.mock.return_value = "chips" + r = obj.fn(1, 3) + self.assertEqual(r, "chips") + obj.mock.assert_called_once_with(1, 3) + obj.mock.reset_mock() + + # the two values should now be cached + r = obj.fn(1, 2) + self.assertEqual(r, "fish") + r = obj.fn(1, 3) + self.assertEqual(r, "chips") + obj.mock.assert_not_called() + + def run_on_reactor(): d = defer.Deferred() reactor.callLater(0, d.callback, 0) @@ -362,6 +395,31 @@ class DescriptorTestCase(unittest.TestCase): d = obj.fn(1) self.failureResultOf(d, SynapseError) + def test_invalidate_cascade(self): + """Invalidations should cascade up through cache contexts""" + + class Cls: + @cached(cache_context=True) + async def func1(self, key, cache_context): + return await self.func2(key, on_invalidate=cache_context.invalidate) + + @cached(cache_context=True) + async def func2(self, key, cache_context): + return self.func3(key, on_invalidate=cache_context.invalidate) + + @lru_cache(cache_context=True) + def func3(self, key, cache_context): + self.invalidate = cache_context.invalidate + return 42 + + obj = Cls() + + top_invalidate = mock.Mock() + r = get_awaitable_result(obj.func1("k1", on_invalidate=top_invalidate)) + self.assertEqual(r, 42) + obj.invalidate() + top_invalidate.assert_called_once() + class CacheDecoratorTestCase(unittest.HomeserverTestCase): """More tests for @cached