diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index febd40b656..192073c520 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -201,7 +201,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
"""Ensure a piece of media is quarantined when trying to access it."""
channel = make_request(
self.reactor,
- FakeSite(self.download_resource),
+ FakeSite(self.download_resource, self.reactor),
"GET",
server_and_media_id,
shorthand=False,
@@ -271,7 +271,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
# Attempt to access the media
channel = make_request(
self.reactor,
- FakeSite(self.download_resource),
+ FakeSite(self.download_resource, self.reactor),
"GET",
server_name_and_media_id,
shorthand=False,
@@ -458,7 +458,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
# Attempt to access each piece of media
channel = make_request(
self.reactor,
- FakeSite(self.download_resource),
+ FakeSite(self.download_resource, self.reactor),
"GET",
server_and_media_id_2,
shorthand=False,
diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py
index 2f02934e72..ce30a19213 100644
--- a/tests/rest/admin/test_media.py
+++ b/tests/rest/admin/test_media.py
@@ -43,7 +43,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
- self.filepaths = MediaFilePaths(hs.config.media_store_path)
+ self.filepaths = MediaFilePaths(hs.config.media.media_store_path)
def test_no_auth(self):
"""
@@ -125,7 +125,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
# Attempt to access media
channel = make_request(
self.reactor,
- FakeSite(download_resource),
+ FakeSite(download_resource, self.reactor),
"GET",
server_and_media_id,
shorthand=False,
@@ -164,7 +164,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
# Attempt to access media
channel = make_request(
self.reactor,
- FakeSite(download_resource),
+ FakeSite(download_resource, self.reactor),
"GET",
server_and_media_id,
shorthand=False,
@@ -200,7 +200,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
- self.filepaths = MediaFilePaths(hs.config.media_store_path)
+ self.filepaths = MediaFilePaths(hs.config.media.media_store_path)
self.url = "/_synapse/admin/v1/media/%s/delete" % self.server_name
def test_no_auth(self):
@@ -525,7 +525,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
channel = make_request(
self.reactor,
- FakeSite(download_resource),
+ FakeSite(download_resource, self.reactor),
"GET",
server_and_media_id,
shorthand=False,
diff --git a/tests/rest/admin/test_registration_tokens.py b/tests/rest/admin/test_registration_tokens.py
index 4927321e5a..9bac423ae0 100644
--- a/tests/rest/admin/test_registration_tokens.py
+++ b/tests/rest/admin/test_registration_tokens.py
@@ -95,8 +95,10 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
def test_create_specifying_fields(self):
"""Create a token specifying the value of all fields."""
+ # As many of the allowed characters as possible with length <= 64
+ token = "adefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789._~-"
data = {
- "token": "abcd",
+ "token": token,
"uses_allowed": 1,
"expiry_time": self.clock.time_msec() + 1000000,
}
@@ -109,7 +111,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(channel.json_body["token"], "abcd")
+ self.assertEqual(channel.json_body["token"], token)
self.assertEqual(channel.json_body["uses_allowed"], 1)
self.assertEqual(channel.json_body["expiry_time"], data["expiry_time"])
self.assertEqual(channel.json_body["pending"], 0)
@@ -193,7 +195,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
"""Check right error is raised when server can't generate unique token."""
# Create all possible single character tokens
tokens = []
- for c in string.ascii_letters + string.digits + "-_":
+ for c in string.ascii_letters + string.digits + "._~-":
tokens.append(
{
"token": c,
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index e798513ac1..0fa55e03b4 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -47,7 +47,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.event_creation_handler = hs.get_event_creation_handler()
- hs.config.user_consent_version = "1"
+ hs.config.consent.user_consent_version = "1"
consent_uri_builder = Mock()
consent_uri_builder.build_user_consent_uri.return_value = "http://example.com"
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index cc3f16c62a..ee3ae9cce4 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -2473,7 +2473,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.media_repo = hs.get_media_repository_resource()
- self.filepaths = MediaFilePaths(hs.config.media_store_path)
+ self.filepaths = MediaFilePaths(hs.config.media.media_store_path)
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
@@ -2973,7 +2973,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
# Try to access a media and to create `last_access_ts`
channel = make_request(
self.reactor,
- FakeSite(download_resource),
+ FakeSite(download_resource, self.reactor),
"GET",
server_and_media_id,
shorthand=False,
diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py
index b946fca8b3..9e9e953cf4 100644
--- a/tests/rest/client/test_account.py
+++ b/tests/rest/client/test_account.py
@@ -312,7 +312,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
# Load the password reset confirmation page
channel = make_request(
self.reactor,
- FakeSite(self.submit_token_resource),
+ FakeSite(self.submit_token_resource, self.reactor),
"GET",
path,
shorthand=False,
@@ -326,7 +326,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
# Confirm the password reset
channel = make_request(
self.reactor,
- FakeSite(self.submit_token_resource),
+ FakeSite(self.submit_token_resource, self.reactor),
"POST",
path,
content=b"",
diff --git a/tests/rest/client/test_consent.py b/tests/rest/client/test_consent.py
index 65c58ce70a..84d092ca82 100644
--- a/tests/rest/client/test_consent.py
+++ b/tests/rest/client/test_consent.py
@@ -61,7 +61,11 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
"""You can observe the terms form without specifying a user"""
resource = consent_resource.ConsentResource(self.hs)
channel = make_request(
- self.reactor, FakeSite(resource), "GET", "/consent?v=1", shorthand=False
+ self.reactor,
+ FakeSite(resource, self.reactor),
+ "GET",
+ "/consent?v=1",
+ shorthand=False,
)
self.assertEqual(channel.code, 200)
@@ -83,7 +87,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
)
channel = make_request(
self.reactor,
- FakeSite(resource),
+ FakeSite(resource, self.reactor),
"GET",
consent_uri,
access_token=access_token,
@@ -98,7 +102,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
# POST to the consent page, saying we've agreed
channel = make_request(
self.reactor,
- FakeSite(resource),
+ FakeSite(resource, self.reactor),
"POST",
consent_uri + "&v=" + version,
access_token=access_token,
@@ -110,7 +114,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
# changed
channel = make_request(
self.reactor,
- FakeSite(resource),
+ FakeSite(resource, self.reactor),
"GET",
consent_uri,
access_token=access_token,
diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py
index f5c195a075..371615a015 100644
--- a/tests/rest/client/test_login.py
+++ b/tests/rest/client/test_login.py
@@ -97,7 +97,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
self.hs.config.enable_registration = True
self.hs.config.registrations_require_3pid = []
self.hs.config.auto_join_rooms = []
- self.hs.config.enable_registration_captcha = False
+ self.hs.config.captcha.enable_registration_captcha = False
return self.hs
@@ -815,9 +815,9 @@ class JWTTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
self.hs = self.setup_test_homeserver()
- self.hs.config.jwt_enabled = True
- self.hs.config.jwt_secret = self.jwt_secret
- self.hs.config.jwt_algorithm = self.jwt_algorithm
+ self.hs.config.jwt.jwt_enabled = True
+ self.hs.config.jwt.jwt_secret = self.jwt_secret
+ self.hs.config.jwt.jwt_algorithm = self.jwt_algorithm
return self.hs
def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str:
@@ -1023,9 +1023,9 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
self.hs = self.setup_test_homeserver()
- self.hs.config.jwt_enabled = True
- self.hs.config.jwt_secret = self.jwt_pubkey
- self.hs.config.jwt_algorithm = "RS256"
+ self.hs.config.jwt.jwt_enabled = True
+ self.hs.config.jwt.jwt_secret = self.jwt_pubkey
+ self.hs.config.jwt.jwt_algorithm = "RS256"
return self.hs
def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str:
diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py
index 9f3ab2c985..72a5a11b46 100644
--- a/tests/rest/client/test_register.py
+++ b/tests/rest/client/test_register.py
@@ -146,7 +146,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.json_body["errcode"], "M_FORBIDDEN")
def test_POST_guest_registration(self):
- self.hs.config.macaroon_secret_key = "test"
+ self.hs.config.key.macaroon_secret_key = "test"
self.hs.config.allow_guest_access = True
channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index ef847f0f5f..30bdaa9c27 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -18,7 +18,7 @@
"""Tests REST events for /rooms paths."""
import json
-from typing import Iterable
+from typing import Dict, Iterable, List, Optional
from unittest.mock import Mock, call
from urllib import parse as urlparse
@@ -30,7 +30,7 @@ from synapse.api.errors import Codes, HttpResponseException
from synapse.handlers.pagination import PurgeStatus
from synapse.rest import admin
from synapse.rest.client import account, directory, login, profile, room, sync
-from synapse.types import JsonDict, RoomAlias, UserID, create_requester
+from synapse.types import JsonDict, Requester, RoomAlias, UserID, create_requester
from synapse.util.stringutils import random_string
from tests import unittest
@@ -669,6 +669,121 @@ class RoomsCreateTestCase(RoomBase):
channel = self.make_request("POST", "/createRoom", content)
self.assertEqual(200, channel.code)
+ def test_spamchecker_invites(self):
+ """Tests the user_may_create_room_with_invites spam checker callback."""
+
+ # Mock do_3pid_invite, so we don't fail from failing to send a 3PID invite to an
+ # IS.
+ async def do_3pid_invite(
+ room_id: str,
+ inviter: UserID,
+ medium: str,
+ address: str,
+ id_server: str,
+ requester: Requester,
+ txn_id: Optional[str],
+ id_access_token: Optional[str] = None,
+ ) -> int:
+ return 0
+
+ do_3pid_invite_mock = Mock(side_effect=do_3pid_invite)
+ self.hs.get_room_member_handler().do_3pid_invite = do_3pid_invite_mock
+
+ # Add a mock callback for user_may_create_room_with_invites. Make it allow any
+ # room creation request for now.
+ return_value = True
+
+ async def user_may_create_room_with_invites(
+ user: str,
+ invites: List[str],
+ threepid_invites: List[Dict[str, str]],
+ ) -> bool:
+ return return_value
+
+ callback_mock = Mock(side_effect=user_may_create_room_with_invites)
+ self.hs.get_spam_checker()._user_may_create_room_with_invites_callbacks.append(
+ callback_mock,
+ )
+
+ # The MXIDs we'll try to invite.
+ invited_mxids = [
+ "@alice1:red",
+ "@alice2:red",
+ "@alice3:red",
+ "@alice4:red",
+ ]
+
+ # The 3PIDs we'll try to invite.
+ invited_3pids = [
+ {
+ "id_server": "example.com",
+ "id_access_token": "sometoken",
+ "medium": "email",
+ "address": "alice1@example.com",
+ },
+ {
+ "id_server": "example.com",
+ "id_access_token": "sometoken",
+ "medium": "email",
+ "address": "alice2@example.com",
+ },
+ {
+ "id_server": "example.com",
+ "id_access_token": "sometoken",
+ "medium": "email",
+ "address": "alice3@example.com",
+ },
+ ]
+
+ # Create a room and invite the Matrix users, and check that it succeeded.
+ channel = self.make_request(
+ "POST",
+ "/createRoom",
+ json.dumps({"invite": invited_mxids}).encode("utf8"),
+ )
+ self.assertEqual(200, channel.code)
+
+ # Check that the callback was called with the right arguments.
+ expected_call_args = ((self.user_id, invited_mxids, []),)
+ self.assertEquals(
+ callback_mock.call_args,
+ expected_call_args,
+ callback_mock.call_args,
+ )
+
+ # Create a room and invite the 3PIDs, and check that it succeeded.
+ channel = self.make_request(
+ "POST",
+ "/createRoom",
+ json.dumps({"invite_3pid": invited_3pids}).encode("utf8"),
+ )
+ self.assertEqual(200, channel.code)
+
+ # Check that do_3pid_invite was called the right amount of time
+ self.assertEquals(do_3pid_invite_mock.call_count, len(invited_3pids))
+
+ # Check that the callback was called with the right arguments.
+ expected_call_args = ((self.user_id, [], invited_3pids),)
+ self.assertEquals(
+ callback_mock.call_args,
+ expected_call_args,
+ callback_mock.call_args,
+ )
+
+ # Now deny any room creation.
+ return_value = False
+
+ # Create a room and invite the 3PIDs, and check that it failed.
+ channel = self.make_request(
+ "POST",
+ "/createRoom",
+ json.dumps({"invite_3pid": invited_3pids}).encode("utf8"),
+ )
+ self.assertEqual(403, channel.code)
+
+ # Check that do_3pid_invite wasn't called this time.
+ self.assertEquals(do_3pid_invite_mock.call_count, len(invited_3pids))
+
class RoomTopicTestCase(RoomBase):
"""Tests /rooms/$room_id/topic REST events."""
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index c56e45fc10..3075d3f288 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -383,7 +383,7 @@ class RestHelper:
path = "/_matrix/media/r0/upload?filename=%s" % (filename,)
channel = make_request(
self.hs.get_reactor(),
- FakeSite(resource),
+ FakeSite(resource, self.hs.get_reactor()),
"POST",
path,
content=image_data,
diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py
index a75c0ea3f0..4672a68596 100644
--- a/tests/rest/key/v2/test_remote_key_resource.py
+++ b/tests/rest/key/v2/test_remote_key_resource.py
@@ -84,7 +84,7 @@ class RemoteKeyResourceTestCase(BaseRemoteKeyResourceTestCase):
Checks that the response is a 200 and returns the decoded json body.
"""
channel = FakeChannel(self.site, self.reactor)
- req = SynapseRequest(channel)
+ req = SynapseRequest(channel, self.site)
req.content = BytesIO(b"")
req.requestReceived(
b"GET",
@@ -183,7 +183,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
)
channel = FakeChannel(self.site, self.reactor)
- req = SynapseRequest(channel)
+ req = SynapseRequest(channel, self.site)
req.content = BytesIO(encode_canonical_json(data))
req.requestReceived(
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index 9ea1c2bf25..4ae00755c9 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -53,7 +53,7 @@ class MediaStorageTests(unittest.HomeserverTestCase):
self.primary_base_path = os.path.join(self.test_dir, "primary")
self.secondary_base_path = os.path.join(self.test_dir, "secondary")
- hs.config.media_store_path = self.primary_base_path
+ hs.config.media.media_store_path = self.primary_base_path
storage_providers = [FileStorageProviderBackend(hs, self.secondary_base_path)]
@@ -252,7 +252,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
channel = make_request(
self.reactor,
- FakeSite(self.download_resource),
+ FakeSite(self.download_resource, self.reactor),
"GET",
self.media_id,
shorthand=False,
@@ -384,7 +384,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
params = "?width=32&height=32&method=scale"
channel = make_request(
self.reactor,
- FakeSite(self.thumbnail_resource),
+ FakeSite(self.thumbnail_resource, self.reactor),
"GET",
self.media_id + params,
shorthand=False,
@@ -413,7 +413,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
channel = make_request(
self.reactor,
- FakeSite(self.thumbnail_resource),
+ FakeSite(self.thumbnail_resource, self.reactor),
"GET",
self.media_id + params,
shorthand=False,
@@ -433,7 +433,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
params = "?width=32&height=32&method=" + method
channel = make_request(
self.reactor,
- FakeSite(self.thumbnail_resource),
+ FakeSite(self.thumbnail_resource, self.reactor),
"GET",
self.media_id + params,
shorthand=False,
diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py
index 9d13899584..4d09b5d07e 100644
--- a/tests/rest/media/v1/test_url_preview.py
+++ b/tests/rest/media/v1/test_url_preview.py
@@ -21,6 +21,7 @@ from twisted.internet.error import DNSLookupError
from twisted.test.proto_helpers import AccumulatingProtocol
from synapse.config.oembed import OEmbedEndpointConfig
+from synapse.util.stringutils import parse_and_validate_mxc_uri
from tests import unittest
from tests.server import FakeTransport
@@ -620,11 +621,12 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertIn(b"/matrixdotorg", server.data)
self.assertEqual(channel.code, 200)
- self.assertIsNone(channel.json_body["og:title"])
- self.assertTrue(channel.json_body["og:image"].startswith("mxc://"))
- self.assertEqual(channel.json_body["og:image:height"], 1)
- self.assertEqual(channel.json_body["og:image:width"], 1)
- self.assertEqual(channel.json_body["og:image:type"], "image/png")
+ body = channel.json_body
+ self.assertEqual(body["og:url"], "http://twitter.com/matrixdotorg/status/12345")
+ self.assertTrue(body["og:image"].startswith("mxc://"))
+ self.assertEqual(body["og:image:height"], 1)
+ self.assertEqual(body["og:image:width"], 1)
+ self.assertEqual(body["og:image:type"], "image/png")
def test_oembed_rich(self):
"""Test an oEmbed endpoint which returns HTML content via the 'rich' type."""
@@ -633,6 +635,8 @@ class URLPreviewTests(unittest.HomeserverTestCase):
result = {
"version": "1.0",
"type": "rich",
+ # Note that this provides the author, not the title.
+ "author_name": "Alice",
"html": "<div>Content Preview</div>",
}
end_content = json.dumps(result).encode("utf-8")
@@ -660,9 +664,14 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.pump()
self.assertEqual(channel.code, 200)
+ body = channel.json_body
self.assertEqual(
- channel.json_body,
- {"og:title": None, "og:description": "Content Preview"},
+ body,
+ {
+ "og:url": "http://twitter.com/matrixdotorg/status/12345",
+ "og:title": "Alice",
+ "og:description": "Content Preview",
+ },
)
def test_oembed_format(self):
@@ -705,7 +714,140 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertIn(b"format=json", server.data)
self.assertEqual(channel.code, 200)
+ body = channel.json_body
self.assertEqual(
- channel.json_body,
- {"og:title": None, "og:description": "Content Preview"},
+ body,
+ {
+ "og:url": "http://www.hulu.com/watch/12345",
+ "og:description": "Content Preview",
+ },
+ )
+
+ def _download_image(self):
+ """Downloads an image into the URL cache.
+
+ Returns:
+ A (host, media_id) tuple representing the MXC URI of the image.
+ """
+ self.lookups["cdn.twitter.com"] = [(IPv4Address, "10.1.2.3")]
+
+ channel = self.make_request(
+ "GET",
+ "preview_url?url=http://cdn.twitter.com/matrixdotorg",
+ shorthand=False,
+ await_result=False,
+ )
+ self.pump()
+
+ client = self.reactor.tcpClients[0][2].buildProtocol(None)
+ server = AccumulatingProtocol()
+ server.makeConnection(FakeTransport(client, self.reactor))
+ client.makeConnection(FakeTransport(server, self.reactor))
+ client.dataReceived(
+ b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\nContent-Type: image/png\r\n\r\n"
+ % (len(SMALL_PNG),)
+ + SMALL_PNG
+ )
+
+ self.pump()
+ self.assertEqual(channel.code, 200)
+ body = channel.json_body
+ mxc_uri = body["og:image"]
+ host, _port, media_id = parse_and_validate_mxc_uri(mxc_uri)
+ self.assertIsNone(_port)
+ return host, media_id
+
+ def test_storage_providers_exclude_files(self):
+ """Test that files are not stored in or fetched from storage providers."""
+ host, media_id = self._download_image()
+
+ rel_file_path = self.preview_url.filepaths.url_cache_filepath_rel(media_id)
+ media_store_path = os.path.join(self.media_store_path, rel_file_path)
+ storage_provider_path = os.path.join(self.storage_path, rel_file_path)
+
+ # Check storage
+ self.assertTrue(os.path.isfile(media_store_path))
+ self.assertFalse(
+ os.path.isfile(storage_provider_path),
+ "URL cache file was unexpectedly stored in a storage provider",
+ )
+
+ # Check fetching
+ channel = self.make_request(
+ "GET",
+ f"download/{host}/{media_id}",
+ shorthand=False,
+ await_result=False,
+ )
+ self.pump()
+ self.assertEqual(channel.code, 200)
+
+ # Move cached file into the storage provider
+ os.makedirs(os.path.dirname(storage_provider_path), exist_ok=True)
+ os.rename(media_store_path, storage_provider_path)
+
+ channel = self.make_request(
+ "GET",
+ f"download/{host}/{media_id}",
+ shorthand=False,
+ await_result=False,
+ )
+ self.pump()
+ self.assertEqual(
+ channel.code,
+ 404,
+ "URL cache file was unexpectedly retrieved from a storage provider",
+ )
+
+ def test_storage_providers_exclude_thumbnails(self):
+ """Test that thumbnails are not stored in or fetched from storage providers."""
+ host, media_id = self._download_image()
+
+ rel_thumbnail_path = (
+ self.preview_url.filepaths.url_cache_thumbnail_directory_rel(media_id)
+ )
+ media_store_thumbnail_path = os.path.join(
+ self.media_store_path, rel_thumbnail_path
+ )
+ storage_provider_thumbnail_path = os.path.join(
+ self.storage_path, rel_thumbnail_path
+ )
+
+ # Check storage
+ self.assertTrue(os.path.isdir(media_store_thumbnail_path))
+ self.assertFalse(
+ os.path.isdir(storage_provider_thumbnail_path),
+ "URL cache thumbnails were unexpectedly stored in a storage provider",
+ )
+
+ # Check fetching
+ channel = self.make_request(
+ "GET",
+ f"thumbnail/{host}/{media_id}?width=32&height=32&method=scale",
+ shorthand=False,
+ await_result=False,
+ )
+ self.pump()
+ self.assertEqual(channel.code, 200)
+
+ # Remove the original, otherwise thumbnails will regenerate
+ rel_file_path = self.preview_url.filepaths.url_cache_filepath_rel(media_id)
+ media_store_path = os.path.join(self.media_store_path, rel_file_path)
+ os.remove(media_store_path)
+
+ # Move cached thumbnails into the storage provider
+ os.makedirs(os.path.dirname(storage_provider_thumbnail_path), exist_ok=True)
+ os.rename(media_store_thumbnail_path, storage_provider_thumbnail_path)
+
+ channel = self.make_request(
+ "GET",
+ f"thumbnail/{host}/{media_id}?width=32&height=32&method=scale",
+ shorthand=False,
+ await_result=False,
+ )
+ self.pump()
+ self.assertEqual(
+ channel.code,
+ 404,
+ "URL cache thumbnail was unexpectedly retrieved from a storage provider",
)
|