diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/handlers/test_password_providers.py | 223 | ||||
-rw-r--r-- | tests/rest/admin/test_user.py | 401 | ||||
-rw-r--r-- | tests/rest/media/v1/test_filepath.py | 238 | ||||
-rw-r--r-- | tests/rest/media/v1/test_oembed.py | 51 |
4 files changed, 632 insertions, 281 deletions
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py index 38e6d9f536..7dd4a5a367 100644 --- a/tests/handlers/test_password_providers.py +++ b/tests/handlers/test_password_providers.py @@ -20,6 +20,8 @@ from unittest.mock import Mock from twisted.internet import defer import synapse +from synapse.handlers.auth import load_legacy_password_auth_providers +from synapse.module_api import ModuleApi from synapse.rest.client import devices, login from synapse.types import JsonDict @@ -36,8 +38,8 @@ ADDITIONAL_LOGIN_FLOWS = [{"type": "uk.half-shot.msc2778.login.application_servi mock_password_provider = Mock() -class PasswordOnlyAuthProvider: - """A password_provider which only implements `check_password`.""" +class LegacyPasswordOnlyAuthProvider: + """A legacy password_provider which only implements `check_password`.""" @staticmethod def parse_config(self): @@ -50,8 +52,8 @@ class PasswordOnlyAuthProvider: return mock_password_provider.check_password(*args) -class CustomAuthProvider: - """A password_provider which implements a custom login type.""" +class LegacyCustomAuthProvider: + """A legacy password_provider which implements a custom login type.""" @staticmethod def parse_config(self): @@ -67,7 +69,23 @@ class CustomAuthProvider: return mock_password_provider.check_auth(*args) -class PasswordCustomAuthProvider: +class CustomAuthProvider: + """A module which registers password_auth_provider callbacks for a custom login type.""" + + @staticmethod + def parse_config(self): + pass + + def __init__(self, config, api: ModuleApi): + api.register_password_auth_provider_callbacks( + auth_checkers={("test.login_type", ("test_field",)): self.check_auth}, + ) + + def check_auth(self, *args): + return mock_password_provider.check_auth(*args) + + +class LegacyPasswordCustomAuthProvider: """A password_provider which implements password login via `check_auth`, as well as a custom type.""" @@ -85,8 +103,32 @@ class PasswordCustomAuthProvider: return mock_password_provider.check_auth(*args) -def providers_config(*providers: Type[Any]) -> dict: - """Returns a config dict that will enable the given password auth providers""" +class PasswordCustomAuthProvider: + """A module which registers password_auth_provider callbacks for a custom login type. + as well as a password login""" + + @staticmethod + def parse_config(self): + pass + + def __init__(self, config, api: ModuleApi): + api.register_password_auth_provider_callbacks( + auth_checkers={ + ("test.login_type", ("test_field",)): self.check_auth, + ("m.login.password", ("password",)): self.check_auth, + }, + ) + pass + + def check_auth(self, *args): + return mock_password_provider.check_auth(*args) + + def check_pass(self, *args): + return mock_password_provider.check_password(*args) + + +def legacy_providers_config(*providers: Type[Any]) -> dict: + """Returns a config dict that will enable the given legacy password auth providers""" return { "password_providers": [ {"module": "%s.%s" % (__name__, provider.__qualname__), "config": {}} @@ -95,6 +137,16 @@ def providers_config(*providers: Type[Any]) -> dict: } +def providers_config(*providers: Type[Any]) -> dict: + """Returns a config dict that will enable the given modules""" + return { + "modules": [ + {"module": "%s.%s" % (__name__, provider.__qualname__), "config": {}} + for provider in providers + ] + } + + class PasswordAuthProviderTests(unittest.HomeserverTestCase): servlets = [ synapse.rest.admin.register_servlets, @@ -107,8 +159,21 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): mock_password_provider.reset_mock() super().setUp() - @override_config(providers_config(PasswordOnlyAuthProvider)) - def test_password_only_auth_provider_login(self): + def make_homeserver(self, reactor, clock): + hs = self.setup_test_homeserver() + # Load the modules into the homeserver + module_api = hs.get_module_api() + for module, config in hs.config.modules.loaded_modules: + module(config=config, api=module_api) + load_legacy_password_auth_providers(hs) + + return hs + + @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider)) + def test_password_only_auth_progiver_login_legacy(self): + self.password_only_auth_provider_login_test_body() + + def password_only_auth_provider_login_test_body(self): # login flows should only have m.login.password flows = self._get_login_flows() self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS) @@ -138,8 +203,11 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): "@ USER🙂NAME :test", " pASS😢word " ) - @override_config(providers_config(PasswordOnlyAuthProvider)) - def test_password_only_auth_provider_ui_auth(self): + @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider)) + def test_password_only_auth_provider_ui_auth_legacy(self): + self.password_only_auth_provider_ui_auth_test_body() + + def password_only_auth_provider_ui_auth_test_body(self): """UI Auth should delegate correctly to the password provider""" # create the user, otherwise access doesn't work @@ -172,8 +240,11 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200) mock_password_provider.check_password.assert_called_once_with("@u:test", "p") - @override_config(providers_config(PasswordOnlyAuthProvider)) - def test_local_user_fallback_login(self): + @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider)) + def test_local_user_fallback_login_legacy(self): + self.local_user_fallback_login_test_body() + + def local_user_fallback_login_test_body(self): """rejected login should fall back to local db""" self.register_user("localuser", "localpass") @@ -186,8 +257,11 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200, channel.result) self.assertEqual("@localuser:test", channel.json_body["user_id"]) - @override_config(providers_config(PasswordOnlyAuthProvider)) - def test_local_user_fallback_ui_auth(self): + @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider)) + def test_local_user_fallback_ui_auth_legacy(self): + self.local_user_fallback_ui_auth_test_body() + + def local_user_fallback_ui_auth_test_body(self): """rejected login should fall back to local db""" self.register_user("localuser", "localpass") @@ -223,11 +297,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): @override_config( { - **providers_config(PasswordOnlyAuthProvider), + **legacy_providers_config(LegacyPasswordOnlyAuthProvider), "password_config": {"localdb_enabled": False}, } ) - def test_no_local_user_fallback_login(self): + def test_no_local_user_fallback_login_legacy(self): + self.no_local_user_fallback_login_test_body() + + def no_local_user_fallback_login_test_body(self): """localdb_enabled can block login with the local password""" self.register_user("localuser", "localpass") @@ -242,11 +319,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): @override_config( { - **providers_config(PasswordOnlyAuthProvider), + **legacy_providers_config(LegacyPasswordOnlyAuthProvider), "password_config": {"localdb_enabled": False}, } ) - def test_no_local_user_fallback_ui_auth(self): + def test_no_local_user_fallback_ui_auth_legacy(self): + self.no_local_user_fallback_ui_auth_test_body() + + def no_local_user_fallback_ui_auth_test_body(self): """localdb_enabled can block ui auth with the local password""" self.register_user("localuser", "localpass") @@ -280,11 +360,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): @override_config( { - **providers_config(PasswordOnlyAuthProvider), + **legacy_providers_config(LegacyPasswordOnlyAuthProvider), "password_config": {"enabled": False}, } ) - def test_password_auth_disabled(self): + def test_password_auth_disabled_legacy(self): + self.password_auth_disabled_test_body() + + def password_auth_disabled_test_body(self): """password auth doesn't work if it's disabled across the board""" # login flows should be empty flows = self._get_login_flows() @@ -295,8 +378,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.assertEqual(channel.code, 400, channel.result) mock_password_provider.check_password.assert_not_called() + @override_config(legacy_providers_config(LegacyCustomAuthProvider)) + def test_custom_auth_provider_login_legacy(self): + self.custom_auth_provider_login_test_body() + @override_config(providers_config(CustomAuthProvider)) def test_custom_auth_provider_login(self): + self.custom_auth_provider_login_test_body() + + def custom_auth_provider_login_test_body(self): # login flows should have the custom flow and m.login.password, since we # haven't disabled local password lookup. # (password must come first, because reasons) @@ -312,7 +402,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.assertEqual(channel.code, 400, channel.result) mock_password_provider.check_auth.assert_not_called() - mock_password_provider.check_auth.return_value = defer.succeed("@user:bz") + mock_password_provider.check_auth.return_value = defer.succeed( + ("@user:bz", None) + ) channel = self._send_login("test.login_type", "u", test_field="y") self.assertEqual(channel.code, 200, channel.result) self.assertEqual("@user:bz", channel.json_body["user_id"]) @@ -325,7 +417,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): # in these cases, but at least we can guard against the API changing # unexpectedly mock_password_provider.check_auth.return_value = defer.succeed( - "@ MALFORMED! :bz" + ("@ MALFORMED! :bz", None) ) channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ") self.assertEqual(channel.code, 200, channel.result) @@ -334,8 +426,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): " USER🙂NAME ", "test.login_type", {"test_field": " abc "} ) + @override_config(legacy_providers_config(LegacyCustomAuthProvider)) + def test_custom_auth_provider_ui_auth_legacy(self): + self.custom_auth_provider_ui_auth_test_body() + @override_config(providers_config(CustomAuthProvider)) def test_custom_auth_provider_ui_auth(self): + self.custom_auth_provider_ui_auth_test_body() + + def custom_auth_provider_ui_auth_test_body(self): # register the user and log in twice, to get two devices self.register_user("localuser", "localpass") tok1 = self.login("localuser", "localpass") @@ -367,7 +466,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): mock_password_provider.reset_mock() # right params, but authing as the wrong user - mock_password_provider.check_auth.return_value = defer.succeed("@user:bz") + mock_password_provider.check_auth.return_value = defer.succeed( + ("@user:bz", None) + ) body["auth"]["test_field"] = "foo" channel = self._delete_device(tok1, "dev2", body) self.assertEqual(channel.code, 403) @@ -379,7 +480,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): # and finally, succeed mock_password_provider.check_auth.return_value = defer.succeed( - "@localuser:test" + ("@localuser:test", None) ) channel = self._delete_device(tok1, "dev2", body) self.assertEqual(channel.code, 200) @@ -387,8 +488,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): "localuser", "test.login_type", {"test_field": "foo"} ) + @override_config(legacy_providers_config(LegacyCustomAuthProvider)) + def test_custom_auth_provider_callback_legacy(self): + self.custom_auth_provider_callback_test_body() + @override_config(providers_config(CustomAuthProvider)) def test_custom_auth_provider_callback(self): + self.custom_auth_provider_callback_test_body() + + def custom_auth_provider_callback_test_body(self): callback = Mock(return_value=defer.succeed(None)) mock_password_provider.check_auth.return_value = defer.succeed( @@ -411,9 +519,21 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.assertIn(p, call_args[0]) @override_config( + { + **legacy_providers_config(LegacyCustomAuthProvider), + "password_config": {"enabled": False}, + } + ) + def test_custom_auth_password_disabled_legacy(self): + self.custom_auth_password_disabled_test_body() + + @override_config( {**providers_config(CustomAuthProvider), "password_config": {"enabled": False}} ) def test_custom_auth_password_disabled(self): + self.custom_auth_password_disabled_test_body() + + def custom_auth_password_disabled_test_body(self): """Test login with a custom auth provider where password login is disabled""" self.register_user("localuser", "localpass") @@ -427,11 +547,23 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): @override_config( { + **legacy_providers_config(LegacyCustomAuthProvider), + "password_config": {"enabled": False, "localdb_enabled": False}, + } + ) + def test_custom_auth_password_disabled_localdb_enabled_legacy(self): + self.custom_auth_password_disabled_localdb_enabled_test_body() + + @override_config( + { **providers_config(CustomAuthProvider), "password_config": {"enabled": False, "localdb_enabled": False}, } ) def test_custom_auth_password_disabled_localdb_enabled(self): + self.custom_auth_password_disabled_localdb_enabled_test_body() + + def custom_auth_password_disabled_localdb_enabled_test_body(self): """Check the localdb_enabled == enabled == False Regression test for https://github.com/matrix-org/synapse/issues/8914: check @@ -450,11 +582,23 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): @override_config( { + **legacy_providers_config(LegacyPasswordCustomAuthProvider), + "password_config": {"enabled": False}, + } + ) + def test_password_custom_auth_password_disabled_login_legacy(self): + self.password_custom_auth_password_disabled_login_test_body() + + @override_config( + { **providers_config(PasswordCustomAuthProvider), "password_config": {"enabled": False}, } ) def test_password_custom_auth_password_disabled_login(self): + self.password_custom_auth_password_disabled_login_test_body() + + def password_custom_auth_password_disabled_login_test_body(self): """log in with a custom auth provider which implements password, but password login is disabled""" self.register_user("localuser", "localpass") @@ -466,6 +610,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): channel = self._send_password_login("localuser", "localpass") self.assertEqual(channel.code, 400, channel.result) mock_password_provider.check_auth.assert_not_called() + mock_password_provider.check_password.assert_not_called() + + @override_config( + { + **legacy_providers_config(LegacyPasswordCustomAuthProvider), + "password_config": {"enabled": False}, + } + ) + def test_password_custom_auth_password_disabled_ui_auth_legacy(self): + self.password_custom_auth_password_disabled_ui_auth_test_body() @override_config( { @@ -474,12 +628,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): } ) def test_password_custom_auth_password_disabled_ui_auth(self): + self.password_custom_auth_password_disabled_ui_auth_test_body() + + def password_custom_auth_password_disabled_ui_auth_test_body(self): """UI Auth with a custom auth provider which implements password, but password login is disabled""" # register the user and log in twice via the test login type to get two devices, self.register_user("localuser", "localpass") mock_password_provider.check_auth.return_value = defer.succeed( - "@localuser:test" + ("@localuser:test", None) ) channel = self._send_login("test.login_type", "localuser", test_field="") self.assertEqual(channel.code, 200, channel.result) @@ -516,6 +673,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): "Password login has been disabled.", channel.json_body["error"] ) mock_password_provider.check_auth.assert_not_called() + mock_password_provider.check_password.assert_not_called() mock_password_provider.reset_mock() # successful auth @@ -526,6 +684,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): mock_password_provider.check_auth.assert_called_once_with( "localuser", "test.login_type", {"test_field": "x"} ) + mock_password_provider.check_password.assert_not_called() + + @override_config( + { + **legacy_providers_config(LegacyCustomAuthProvider), + "password_config": {"localdb_enabled": False}, + } + ) + def test_custom_auth_no_local_user_fallback_legacy(self): + self.custom_auth_no_local_user_fallback_test_body() @override_config( { @@ -534,6 +702,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): } ) def test_custom_auth_no_local_user_fallback(self): + self.custom_auth_no_local_user_fallback_test_body() + + def custom_auth_no_local_user_fallback_test_body(self): """Test login with a custom auth provider where the local db is disabled""" self.register_user("localuser", "localpass") diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 6ed9e42173..c9e2754b09 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -14,14 +14,13 @@ import hashlib import hmac -import json import os import urllib.parse from binascii import unhexlify from typing import List, Optional from unittest.mock import Mock, patch -from parameterized import parameterized +from parameterized import parameterized, parameterized_class import synapse.rest.admin from synapse.api.constants import UserTypes @@ -104,8 +103,8 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): # 59 seconds self.reactor.advance(59) - body = json.dumps({"nonce": nonce}) - channel = self.make_request("POST", self.url, body.encode("utf8")) + body = {"nonce": nonce} + channel = self.make_request("POST", self.url, body) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("username must be specified", channel.json_body["error"]) @@ -113,7 +112,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): # 61 seconds self.reactor.advance(2) - channel = self.make_request("POST", self.url, body.encode("utf8")) + channel = self.make_request("POST", self.url, body) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("unrecognised nonce", channel.json_body["error"]) @@ -129,18 +128,16 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): want_mac.update(b"notthenonce\x00bob\x00abc123\x00admin") want_mac = want_mac.hexdigest() - body = json.dumps( - { - "nonce": nonce, - "username": "bob", - "password": "abc123", - "admin": True, - "mac": want_mac, - } - ) - channel = self.make_request("POST", self.url, body.encode("utf8")) + body = { + "nonce": nonce, + "username": "bob", + "password": "abc123", + "admin": True, + "mac": want_mac, + } + channel = self.make_request("POST", self.url, body) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual("HMAC incorrect", channel.json_body["error"]) def test_register_correct_nonce(self): @@ -157,17 +154,15 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): ) want_mac = want_mac.hexdigest() - body = json.dumps( - { - "nonce": nonce, - "username": "bob", - "password": "abc123", - "admin": True, - "user_type": UserTypes.SUPPORT, - "mac": want_mac, - } - ) - channel = self.make_request("POST", self.url, body.encode("utf8")) + body = { + "nonce": nonce, + "username": "bob", + "password": "abc123", + "admin": True, + "user_type": UserTypes.SUPPORT, + "mac": want_mac, + } + channel = self.make_request("POST", self.url, body) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["user_id"]) @@ -183,22 +178,20 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): want_mac.update(nonce.encode("ascii") + b"\x00bob\x00abc123\x00admin") want_mac = want_mac.hexdigest() - body = json.dumps( - { - "nonce": nonce, - "username": "bob", - "password": "abc123", - "admin": True, - "mac": want_mac, - } - ) - channel = self.make_request("POST", self.url, body.encode("utf8")) + body = { + "nonce": nonce, + "username": "bob", + "password": "abc123", + "admin": True, + "mac": want_mac, + } + channel = self.make_request("POST", self.url, body) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["user_id"]) # Now, try and reuse it - channel = self.make_request("POST", self.url, body.encode("utf8")) + channel = self.make_request("POST", self.url, body) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("unrecognised nonce", channel.json_body["error"]) @@ -218,9 +211,8 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): # Nonce check # - # Must be present - body = json.dumps({}) - channel = self.make_request("POST", self.url, body.encode("utf8")) + # Must be an empty body present + channel = self.make_request("POST", self.url, {}) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("nonce must be specified", channel.json_body["error"]) @@ -230,29 +222,28 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): # # Must be present - body = json.dumps({"nonce": nonce()}) - channel = self.make_request("POST", self.url, body.encode("utf8")) + channel = self.make_request("POST", self.url, {"nonce": nonce()}) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("username must be specified", channel.json_body["error"]) # Must be a string - body = json.dumps({"nonce": nonce(), "username": 1234}) - channel = self.make_request("POST", self.url, body.encode("utf8")) + body = {"nonce": nonce(), "username": 1234} + channel = self.make_request("POST", self.url, body) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid username", channel.json_body["error"]) # Must not have null bytes - body = json.dumps({"nonce": nonce(), "username": "abcd\u0000"}) - channel = self.make_request("POST", self.url, body.encode("utf8")) + body = {"nonce": nonce(), "username": "abcd\u0000"} + channel = self.make_request("POST", self.url, body) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid username", channel.json_body["error"]) # Must not have null bytes - body = json.dumps({"nonce": nonce(), "username": "a" * 1000}) - channel = self.make_request("POST", self.url, body.encode("utf8")) + body = {"nonce": nonce(), "username": "a" * 1000} + channel = self.make_request("POST", self.url, body) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid username", channel.json_body["error"]) @@ -262,29 +253,29 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): # # Must be present - body = json.dumps({"nonce": nonce(), "username": "a"}) - channel = self.make_request("POST", self.url, body.encode("utf8")) + body = {"nonce": nonce(), "username": "a"} + channel = self.make_request("POST", self.url, body) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("password must be specified", channel.json_body["error"]) # Must be a string - body = json.dumps({"nonce": nonce(), "username": "a", "password": 1234}) - channel = self.make_request("POST", self.url, body.encode("utf8")) + body = {"nonce": nonce(), "username": "a", "password": 1234} + channel = self.make_request("POST", self.url, body) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid password", channel.json_body["error"]) # Must not have null bytes - body = json.dumps({"nonce": nonce(), "username": "a", "password": "abcd\u0000"}) - channel = self.make_request("POST", self.url, body.encode("utf8")) + body = {"nonce": nonce(), "username": "a", "password": "abcd\u0000"} + channel = self.make_request("POST", self.url, body) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid password", channel.json_body["error"]) # Super long - body = json.dumps({"nonce": nonce(), "username": "a", "password": "A" * 1000}) - channel = self.make_request("POST", self.url, body.encode("utf8")) + body = {"nonce": nonce(), "username": "a", "password": "A" * 1000} + channel = self.make_request("POST", self.url, body) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid password", channel.json_body["error"]) @@ -294,15 +285,13 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): # # Invalid user_type - body = json.dumps( - { - "nonce": nonce(), - "username": "a", - "password": "1234", - "user_type": "invalid", - } - ) - channel = self.make_request("POST", self.url, body.encode("utf8")) + body = { + "nonce": nonce(), + "username": "a", + "password": "1234", + "user_type": "invalid", + } + channel = self.make_request("POST", self.url, body) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid user type", channel.json_body["error"]) @@ -320,10 +309,14 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): want_mac.update(nonce.encode("ascii") + b"\x00bob1\x00abc123\x00notadmin") want_mac = want_mac.hexdigest() - body = json.dumps( - {"nonce": nonce, "username": "bob1", "password": "abc123", "mac": want_mac} - ) - channel = self.make_request("POST", self.url, body.encode("utf8")) + body = { + "nonce": nonce, + "username": "bob1", + "password": "abc123", + "mac": want_mac, + } + + channel = self.make_request("POST", self.url, body) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob1:test", channel.json_body["user_id"]) @@ -340,16 +333,14 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): want_mac.update(nonce.encode("ascii") + b"\x00bob2\x00abc123\x00notadmin") want_mac = want_mac.hexdigest() - body = json.dumps( - { - "nonce": nonce, - "username": "bob2", - "displayname": None, - "password": "abc123", - "mac": want_mac, - } - ) - channel = self.make_request("POST", self.url, body.encode("utf8")) + body = { + "nonce": nonce, + "username": "bob2", + "displayname": None, + "password": "abc123", + "mac": want_mac, + } + channel = self.make_request("POST", self.url, body) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob2:test", channel.json_body["user_id"]) @@ -366,22 +357,20 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): want_mac.update(nonce.encode("ascii") + b"\x00bob3\x00abc123\x00notadmin") want_mac = want_mac.hexdigest() - body = json.dumps( - { - "nonce": nonce, - "username": "bob3", - "displayname": "", - "password": "abc123", - "mac": want_mac, - } - ) - channel = self.make_request("POST", self.url, body.encode("utf8")) + body = { + "nonce": nonce, + "username": "bob3", + "displayname": "", + "password": "abc123", + "mac": want_mac, + } + channel = self.make_request("POST", self.url, body) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob3:test", channel.json_body["user_id"]) channel = self.make_request("GET", "/profile/@bob3:test/displayname") - self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(404, channel.code, msg=channel.json_body) # set displayname channel = self.make_request("GET", self.url) @@ -391,16 +380,14 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): want_mac.update(nonce.encode("ascii") + b"\x00bob4\x00abc123\x00notadmin") want_mac = want_mac.hexdigest() - body = json.dumps( - { - "nonce": nonce, - "username": "bob4", - "displayname": "Bob's Name", - "password": "abc123", - "mac": want_mac, - } - ) - channel = self.make_request("POST", self.url, body.encode("utf8")) + body = { + "nonce": nonce, + "username": "bob4", + "displayname": "Bob's Name", + "password": "abc123", + "mac": want_mac, + } + channel = self.make_request("POST", self.url, body) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob4:test", channel.json_body["user_id"]) @@ -440,17 +427,15 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): ) want_mac = want_mac.hexdigest() - body = json.dumps( - { - "nonce": nonce, - "username": "bob", - "password": "abc123", - "admin": True, - "user_type": UserTypes.SUPPORT, - "mac": want_mac, - } - ) - channel = self.make_request("POST", self.url, body.encode("utf8")) + body = { + "nonce": nonce, + "username": "bob", + "password": "abc123", + "admin": True, + "user_type": UserTypes.SUPPORT, + "mac": want_mac, + } + channel = self.make_request("POST", self.url, body) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["user_id"]) @@ -993,12 +978,11 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): """ If parameter `erase` is not boolean, return an error """ - body = json.dumps({"erase": "False"}) channel = self.make_request( "POST", self.url, - content=body.encode(encoding="utf_8"), + content={"erase": "False"}, access_token=self.admin_user_tok, ) @@ -2201,7 +2185,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): """ channel = self.make_request("GET", self.url, b"{}") - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_requester_is_no_admin(self): @@ -2216,7 +2200,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): access_token=other_user_token, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_user_does_not_exist(self): @@ -2359,7 +2343,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase): """ channel = self.make_request("GET", self.url, b"{}") - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_requester_is_no_admin(self): @@ -2374,7 +2358,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase): access_token=other_user_token, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_user_does_not_exist(self): @@ -3073,7 +3057,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): """Try to login as a user without authentication.""" channel = self.make_request("POST", self.url, b"{}") - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_not_admin(self): @@ -3082,7 +3066,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): "POST", self.url, b"{}", access_token=self.other_user_tok ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.json_body) def test_send_event(self): """Test that sending event as a user works.""" @@ -3127,7 +3111,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): # The puppet token should no longer work channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(401, channel.code, msg=channel.json_body) # .. but the real user's tokens should still work channel = self.make_request( @@ -3160,7 +3144,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): channel = self.make_request( "GET", "devices", b"{}", access_token=self.other_user_tok ) - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(401, channel.code, msg=channel.json_body) def test_admin_logout_all(self): """Tests that the admin user calling `/logout/all` does expire the @@ -3181,7 +3165,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): # The puppet token should no longer work channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(401, channel.code, msg=channel.json_body) # .. but the real user's tokens should still work channel = self.make_request( @@ -3242,6 +3226,13 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): self.helper.join(room_id, user=self.other_user, tok=puppet_token) +@parameterized_class( + ("url_prefix",), + [ + ("/_synapse/admin/v1/whois/%s",), + ("/_matrix/client/r0/admin/whois/%s",), + ], +) class WhoisRestTestCase(unittest.HomeserverTestCase): servlets = [ @@ -3254,21 +3245,14 @@ class WhoisRestTestCase(unittest.HomeserverTestCase): self.admin_user_tok = self.login("admin", "pass") self.other_user = self.register_user("user", "pass") - self.url1 = "/_synapse/admin/v1/whois/%s" % urllib.parse.quote(self.other_user) - self.url2 = "/_matrix/client/r0/admin/whois/%s" % urllib.parse.quote( - self.other_user - ) + self.url = self.url_prefix % self.other_user def test_no_auth(self): """ Try to get information of an user without authentication. """ - channel = self.make_request("GET", self.url1, b"{}") - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - - channel = self.make_request("GET", self.url2, b"{}") - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + channel = self.make_request("GET", self.url, b"{}") + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_requester_is_not_admin(self): @@ -3280,38 +3264,21 @@ class WhoisRestTestCase(unittest.HomeserverTestCase): channel = self.make_request( "GET", - self.url1, - access_token=other_user2_token, - ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - - channel = self.make_request( - "GET", - self.url2, + self.url, access_token=other_user2_token, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_user_is_not_local(self): """ Tests that a lookup for a user that is not a local returns a 400 """ - url1 = "/_synapse/admin/v1/whois/@unknown_person:unknown_domain" - url2 = "/_matrix/client/r0/admin/whois/@unknown_person:unknown_domain" - - channel = self.make_request( - "GET", - url1, - access_token=self.admin_user_tok, - ) - self.assertEqual(400, channel.code, msg=channel.json_body) - self.assertEqual("Can only whois a local user", channel.json_body["error"]) + url = self.url_prefix % "@unknown_person:unknown_domain" channel = self.make_request( "GET", - url2, + url, access_token=self.admin_user_tok, ) self.assertEqual(400, channel.code, msg=channel.json_body) @@ -3323,16 +3290,7 @@ class WhoisRestTestCase(unittest.HomeserverTestCase): """ channel = self.make_request( "GET", - self.url1, - access_token=self.admin_user_tok, - ) - self.assertEqual(200, channel.code, msg=channel.json_body) - self.assertEqual(self.other_user, channel.json_body["user_id"]) - self.assertIn("devices", channel.json_body) - - channel = self.make_request( - "GET", - self.url2, + self.url, access_token=self.admin_user_tok, ) self.assertEqual(200, channel.code, msg=channel.json_body) @@ -3347,16 +3305,7 @@ class WhoisRestTestCase(unittest.HomeserverTestCase): channel = self.make_request( "GET", - self.url1, - access_token=other_user_token, - ) - self.assertEqual(200, channel.code, msg=channel.json_body) - self.assertEqual(self.other_user, channel.json_body["user_id"]) - self.assertIn("devices", channel.json_body) - - channel = self.make_request( - "GET", - self.url2, + self.url, access_token=other_user_token, ) self.assertEqual(200, channel.code, msg=channel.json_body) @@ -3388,7 +3337,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase): Try to get information of an user without authentication. """ channel = self.make_request("POST", self.url) - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_requester_is_not_admin(self): @@ -3398,7 +3347,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase): other_user_token = self.login("user", "pass") channel = self.make_request("POST", self.url, access_token=other_user_token) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_user_is_not_local(self): @@ -3447,84 +3396,41 @@ class RateLimitTestCase(unittest.HomeserverTestCase): % urllib.parse.quote(self.other_user) ) - def test_no_auth(self): + @parameterized.expand(["GET", "POST", "DELETE"]) + def test_no_auth(self, method: str): """ Try to get information of a user without authentication. """ - channel = self.make_request("GET", self.url, b"{}") - - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - - channel = self.make_request("POST", self.url, b"{}") + channel = self.make_request(method, self.url, b"{}") - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - - channel = self.make_request("DELETE", self.url, b"{}") - - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_requester_is_no_admin(self): + @parameterized.expand(["GET", "POST", "DELETE"]) + def test_requester_is_no_admin(self, method: str): """ If the user is not a server admin, an error is returned. """ other_user_token = self.login("user", "pass") channel = self.make_request( - "GET", - self.url, - access_token=other_user_token, - ) - - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - - channel = self.make_request( - "POST", - self.url, - access_token=other_user_token, - ) - - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - - channel = self.make_request( - "DELETE", + method, self.url, access_token=other_user_token, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_user_does_not_exist(self): + @parameterized.expand(["GET", "POST", "DELETE"]) + def test_user_does_not_exist(self, method: str): """ Tests that a lookup for a user that does not exist returns a 404 """ url = "/_synapse/admin/v1/users/@unknown_person:test/override_ratelimit" channel = self.make_request( - "GET", - url, - access_token=self.admin_user_tok, - ) - - self.assertEqual(404, channel.code, msg=channel.json_body) - self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) - - channel = self.make_request( - "POST", - url, - access_token=self.admin_user_tok, - ) - - self.assertEqual(404, channel.code, msg=channel.json_body) - self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) - - channel = self.make_request( - "DELETE", + method, url, access_token=self.admin_user_tok, ) @@ -3532,7 +3438,14 @@ class RateLimitTestCase(unittest.HomeserverTestCase): self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) - def test_user_is_not_local(self): + @parameterized.expand( + [ + ("GET", "Can only look up local users"), + ("POST", "Only local users can be ratelimited"), + ("DELETE", "Only local users can be ratelimited"), + ] + ) + def test_user_is_not_local(self, method: str, error_msg: str): """ Tests that a lookup for a user that is not a local returns a 400 """ @@ -3541,35 +3454,13 @@ class RateLimitTestCase(unittest.HomeserverTestCase): ) channel = self.make_request( - "GET", - url, - access_token=self.admin_user_tok, - ) - - self.assertEqual(400, channel.code, msg=channel.json_body) - self.assertEqual("Can only look up local users", channel.json_body["error"]) - - channel = self.make_request( - "POST", - url, - access_token=self.admin_user_tok, - ) - - self.assertEqual(400, channel.code, msg=channel.json_body) - self.assertEqual( - "Only local users can be ratelimited", channel.json_body["error"] - ) - - channel = self.make_request( - "DELETE", + method, url, access_token=self.admin_user_tok, ) self.assertEqual(400, channel.code, msg=channel.json_body) - self.assertEqual( - "Only local users can be ratelimited", channel.json_body["error"] - ) + self.assertEqual(error_msg, channel.json_body["error"]) def test_invalid_parameter(self): """ diff --git a/tests/rest/media/v1/test_filepath.py b/tests/rest/media/v1/test_filepath.py new file mode 100644 index 0000000000..09504a485f --- /dev/null +++ b/tests/rest/media/v1/test_filepath.py @@ -0,0 +1,238 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from synapse.rest.media.v1.filepath import MediaFilePaths + +from tests import unittest + + +class MediaFilePathsTestCase(unittest.TestCase): + def setUp(self): + super().setUp() + + self.filepaths = MediaFilePaths("/media_store") + + def test_local_media_filepath(self): + """Test local media paths""" + self.assertEqual( + self.filepaths.local_media_filepath_rel("GerZNDnDZVjsOtardLuwfIBg"), + "local_content/Ge/rZ/NDnDZVjsOtardLuwfIBg", + ) + self.assertEqual( + self.filepaths.local_media_filepath("GerZNDnDZVjsOtardLuwfIBg"), + "/media_store/local_content/Ge/rZ/NDnDZVjsOtardLuwfIBg", + ) + + def test_local_media_thumbnail(self): + """Test local media thumbnail paths""" + self.assertEqual( + self.filepaths.local_media_thumbnail_rel( + "GerZNDnDZVjsOtardLuwfIBg", 800, 600, "image/jpeg", "scale" + ), + "local_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale", + ) + self.assertEqual( + self.filepaths.local_media_thumbnail( + "GerZNDnDZVjsOtardLuwfIBg", 800, 600, "image/jpeg", "scale" + ), + "/media_store/local_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale", + ) + + def test_local_media_thumbnail_dir(self): + """Test local media thumbnail directory paths""" + self.assertEqual( + self.filepaths.local_media_thumbnail_dir("GerZNDnDZVjsOtardLuwfIBg"), + "/media_store/local_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg", + ) + + def test_remote_media_filepath(self): + """Test remote media paths""" + self.assertEqual( + self.filepaths.remote_media_filepath_rel( + "example.com", "GerZNDnDZVjsOtardLuwfIBg" + ), + "remote_content/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg", + ) + self.assertEqual( + self.filepaths.remote_media_filepath( + "example.com", "GerZNDnDZVjsOtardLuwfIBg" + ), + "/media_store/remote_content/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg", + ) + + def test_remote_media_thumbnail(self): + """Test remote media thumbnail paths""" + self.assertEqual( + self.filepaths.remote_media_thumbnail_rel( + "example.com", + "GerZNDnDZVjsOtardLuwfIBg", + 800, + 600, + "image/jpeg", + "scale", + ), + "remote_thumbnail/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale", + ) + self.assertEqual( + self.filepaths.remote_media_thumbnail( + "example.com", + "GerZNDnDZVjsOtardLuwfIBg", + 800, + 600, + "image/jpeg", + "scale", + ), + "/media_store/remote_thumbnail/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale", + ) + + def test_remote_media_thumbnail_legacy(self): + """Test old-style remote media thumbnail paths""" + self.assertEqual( + self.filepaths.remote_media_thumbnail_rel_legacy( + "example.com", "GerZNDnDZVjsOtardLuwfIBg", 800, 600, "image/jpeg" + ), + "remote_thumbnail/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg", + ) + + def test_remote_media_thumbnail_dir(self): + """Test remote media thumbnail directory paths""" + self.assertEqual( + self.filepaths.remote_media_thumbnail_dir( + "example.com", "GerZNDnDZVjsOtardLuwfIBg" + ), + "/media_store/remote_thumbnail/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg", + ) + + def test_url_cache_filepath(self): + """Test URL cache paths""" + self.assertEqual( + self.filepaths.url_cache_filepath_rel("2020-01-02_GerZNDnDZVjsOtar"), + "url_cache/2020-01-02/GerZNDnDZVjsOtar", + ) + self.assertEqual( + self.filepaths.url_cache_filepath("2020-01-02_GerZNDnDZVjsOtar"), + "/media_store/url_cache/2020-01-02/GerZNDnDZVjsOtar", + ) + + def test_url_cache_filepath_legacy(self): + """Test old-style URL cache paths""" + self.assertEqual( + self.filepaths.url_cache_filepath_rel("GerZNDnDZVjsOtardLuwfIBg"), + "url_cache/Ge/rZ/NDnDZVjsOtardLuwfIBg", + ) + self.assertEqual( + self.filepaths.url_cache_filepath("GerZNDnDZVjsOtardLuwfIBg"), + "/media_store/url_cache/Ge/rZ/NDnDZVjsOtardLuwfIBg", + ) + + def test_url_cache_filepath_dirs_to_delete(self): + """Test URL cache cleanup paths""" + self.assertEqual( + self.filepaths.url_cache_filepath_dirs_to_delete( + "2020-01-02_GerZNDnDZVjsOtar" + ), + ["/media_store/url_cache/2020-01-02"], + ) + + def test_url_cache_filepath_dirs_to_delete_legacy(self): + """Test old-style URL cache cleanup paths""" + self.assertEqual( + self.filepaths.url_cache_filepath_dirs_to_delete( + "GerZNDnDZVjsOtardLuwfIBg" + ), + [ + "/media_store/url_cache/Ge/rZ", + "/media_store/url_cache/Ge", + ], + ) + + def test_url_cache_thumbnail(self): + """Test URL cache thumbnail paths""" + self.assertEqual( + self.filepaths.url_cache_thumbnail_rel( + "2020-01-02_GerZNDnDZVjsOtar", 800, 600, "image/jpeg", "scale" + ), + "url_cache_thumbnails/2020-01-02/GerZNDnDZVjsOtar/800-600-image-jpeg-scale", + ) + self.assertEqual( + self.filepaths.url_cache_thumbnail( + "2020-01-02_GerZNDnDZVjsOtar", 800, 600, "image/jpeg", "scale" + ), + "/media_store/url_cache_thumbnails/2020-01-02/GerZNDnDZVjsOtar/800-600-image-jpeg-scale", + ) + + def test_url_cache_thumbnail_legacy(self): + """Test old-style URL cache thumbnail paths""" + self.assertEqual( + self.filepaths.url_cache_thumbnail_rel( + "GerZNDnDZVjsOtardLuwfIBg", 800, 600, "image/jpeg", "scale" + ), + "url_cache_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale", + ) + self.assertEqual( + self.filepaths.url_cache_thumbnail( + "GerZNDnDZVjsOtardLuwfIBg", 800, 600, "image/jpeg", "scale" + ), + "/media_store/url_cache_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale", + ) + + def test_url_cache_thumbnail_directory(self): + """Test URL cache thumbnail directory paths""" + self.assertEqual( + self.filepaths.url_cache_thumbnail_directory_rel( + "2020-01-02_GerZNDnDZVjsOtar" + ), + "url_cache_thumbnails/2020-01-02/GerZNDnDZVjsOtar", + ) + self.assertEqual( + self.filepaths.url_cache_thumbnail_directory("2020-01-02_GerZNDnDZVjsOtar"), + "/media_store/url_cache_thumbnails/2020-01-02/GerZNDnDZVjsOtar", + ) + + def test_url_cache_thumbnail_directory_legacy(self): + """Test old-style URL cache thumbnail directory paths""" + self.assertEqual( + self.filepaths.url_cache_thumbnail_directory_rel( + "GerZNDnDZVjsOtardLuwfIBg" + ), + "url_cache_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg", + ) + self.assertEqual( + self.filepaths.url_cache_thumbnail_directory("GerZNDnDZVjsOtardLuwfIBg"), + "/media_store/url_cache_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg", + ) + + def test_url_cache_thumbnail_dirs_to_delete(self): + """Test URL cache thumbnail cleanup paths""" + self.assertEqual( + self.filepaths.url_cache_thumbnail_dirs_to_delete( + "2020-01-02_GerZNDnDZVjsOtar" + ), + [ + "/media_store/url_cache_thumbnails/2020-01-02/GerZNDnDZVjsOtar", + "/media_store/url_cache_thumbnails/2020-01-02", + ], + ) + + def test_url_cache_thumbnail_dirs_to_delete_legacy(self): + """Test old-style URL cache thumbnail cleanup paths""" + self.assertEqual( + self.filepaths.url_cache_thumbnail_dirs_to_delete( + "GerZNDnDZVjsOtardLuwfIBg" + ), + [ + "/media_store/url_cache_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg", + "/media_store/url_cache_thumbnails/Ge/rZ", + "/media_store/url_cache_thumbnails/Ge", + ], + ) diff --git a/tests/rest/media/v1/test_oembed.py b/tests/rest/media/v1/test_oembed.py new file mode 100644 index 0000000000..048d0ca44a --- /dev/null +++ b/tests/rest/media/v1/test_oembed.py @@ -0,0 +1,51 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +from twisted.test.proto_helpers import MemoryReactor + +from synapse.rest.media.v1.oembed import OEmbedProvider +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util import Clock + +from tests.unittest import HomeserverTestCase + + +class OEmbedTests(HomeserverTestCase): + def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer): + self.oembed = OEmbedProvider(homeserver) + + def parse_response(self, response: JsonDict): + return self.oembed.parse_oembed_response( + "https://test", json.dumps(response).encode("utf-8") + ) + + def test_version(self): + """Accept versions that are similar to 1.0 as a string or int (or missing).""" + for version in ("1.0", 1.0, 1): + result = self.parse_response({"version": version, "type": "link"}) + # An empty Open Graph response is an error, ensure the URL is included. + self.assertIn("og:url", result.open_graph_result) + + # A missing version should be treated as 1.0. + result = self.parse_response({"type": "link"}) + self.assertIn("og:url", result.open_graph_result) + + # Invalid versions should be rejected. + for version in ("2.0", "1", 1.1, 0, None, {}, []): + result = self.parse_response({"version": version, "type": "link"}) + # An empty Open Graph response is an error, ensure the URL is included. + self.assertEqual({}, result.open_graph_result) |