diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index 38e6d9f536..7dd4a5a367 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -20,6 +20,8 @@ from unittest.mock import Mock
from twisted.internet import defer
import synapse
+from synapse.handlers.auth import load_legacy_password_auth_providers
+from synapse.module_api import ModuleApi
from synapse.rest.client import devices, login
from synapse.types import JsonDict
@@ -36,8 +38,8 @@ ADDITIONAL_LOGIN_FLOWS = [{"type": "uk.half-shot.msc2778.login.application_servi
mock_password_provider = Mock()
-class PasswordOnlyAuthProvider:
- """A password_provider which only implements `check_password`."""
+class LegacyPasswordOnlyAuthProvider:
+ """A legacy password_provider which only implements `check_password`."""
@staticmethod
def parse_config(self):
@@ -50,8 +52,8 @@ class PasswordOnlyAuthProvider:
return mock_password_provider.check_password(*args)
-class CustomAuthProvider:
- """A password_provider which implements a custom login type."""
+class LegacyCustomAuthProvider:
+ """A legacy password_provider which implements a custom login type."""
@staticmethod
def parse_config(self):
@@ -67,7 +69,23 @@ class CustomAuthProvider:
return mock_password_provider.check_auth(*args)
-class PasswordCustomAuthProvider:
+class CustomAuthProvider:
+ """A module which registers password_auth_provider callbacks for a custom login type."""
+
+ @staticmethod
+ def parse_config(self):
+ pass
+
+ def __init__(self, config, api: ModuleApi):
+ api.register_password_auth_provider_callbacks(
+ auth_checkers={("test.login_type", ("test_field",)): self.check_auth},
+ )
+
+ def check_auth(self, *args):
+ return mock_password_provider.check_auth(*args)
+
+
+class LegacyPasswordCustomAuthProvider:
"""A password_provider which implements password login via `check_auth`, as well
as a custom type."""
@@ -85,8 +103,32 @@ class PasswordCustomAuthProvider:
return mock_password_provider.check_auth(*args)
-def providers_config(*providers: Type[Any]) -> dict:
- """Returns a config dict that will enable the given password auth providers"""
+class PasswordCustomAuthProvider:
+ """A module which registers password_auth_provider callbacks for a custom login type.
+ as well as a password login"""
+
+ @staticmethod
+ def parse_config(self):
+ pass
+
+ def __init__(self, config, api: ModuleApi):
+ api.register_password_auth_provider_callbacks(
+ auth_checkers={
+ ("test.login_type", ("test_field",)): self.check_auth,
+ ("m.login.password", ("password",)): self.check_auth,
+ },
+ )
+ pass
+
+ def check_auth(self, *args):
+ return mock_password_provider.check_auth(*args)
+
+ def check_pass(self, *args):
+ return mock_password_provider.check_password(*args)
+
+
+def legacy_providers_config(*providers: Type[Any]) -> dict:
+ """Returns a config dict that will enable the given legacy password auth providers"""
return {
"password_providers": [
{"module": "%s.%s" % (__name__, provider.__qualname__), "config": {}}
@@ -95,6 +137,16 @@ def providers_config(*providers: Type[Any]) -> dict:
}
+def providers_config(*providers: Type[Any]) -> dict:
+ """Returns a config dict that will enable the given modules"""
+ return {
+ "modules": [
+ {"module": "%s.%s" % (__name__, provider.__qualname__), "config": {}}
+ for provider in providers
+ ]
+ }
+
+
class PasswordAuthProviderTests(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets,
@@ -107,8 +159,21 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
mock_password_provider.reset_mock()
super().setUp()
- @override_config(providers_config(PasswordOnlyAuthProvider))
- def test_password_only_auth_provider_login(self):
+ def make_homeserver(self, reactor, clock):
+ hs = self.setup_test_homeserver()
+ # Load the modules into the homeserver
+ module_api = hs.get_module_api()
+ for module, config in hs.config.modules.loaded_modules:
+ module(config=config, api=module_api)
+ load_legacy_password_auth_providers(hs)
+
+ return hs
+
+ @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
+ def test_password_only_auth_progiver_login_legacy(self):
+ self.password_only_auth_provider_login_test_body()
+
+ def password_only_auth_provider_login_test_body(self):
# login flows should only have m.login.password
flows = self._get_login_flows()
self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS)
@@ -138,8 +203,11 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"@ USER🙂NAME :test", " pASS😢word "
)
- @override_config(providers_config(PasswordOnlyAuthProvider))
- def test_password_only_auth_provider_ui_auth(self):
+ @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
+ def test_password_only_auth_provider_ui_auth_legacy(self):
+ self.password_only_auth_provider_ui_auth_test_body()
+
+ def password_only_auth_provider_ui_auth_test_body(self):
"""UI Auth should delegate correctly to the password provider"""
# create the user, otherwise access doesn't work
@@ -172,8 +240,11 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200)
mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
- @override_config(providers_config(PasswordOnlyAuthProvider))
- def test_local_user_fallback_login(self):
+ @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
+ def test_local_user_fallback_login_legacy(self):
+ self.local_user_fallback_login_test_body()
+
+ def local_user_fallback_login_test_body(self):
"""rejected login should fall back to local db"""
self.register_user("localuser", "localpass")
@@ -186,8 +257,11 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual("@localuser:test", channel.json_body["user_id"])
- @override_config(providers_config(PasswordOnlyAuthProvider))
- def test_local_user_fallback_ui_auth(self):
+ @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
+ def test_local_user_fallback_ui_auth_legacy(self):
+ self.local_user_fallback_ui_auth_test_body()
+
+ def local_user_fallback_ui_auth_test_body(self):
"""rejected login should fall back to local db"""
self.register_user("localuser", "localpass")
@@ -223,11 +297,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
@override_config(
{
- **providers_config(PasswordOnlyAuthProvider),
+ **legacy_providers_config(LegacyPasswordOnlyAuthProvider),
"password_config": {"localdb_enabled": False},
}
)
- def test_no_local_user_fallback_login(self):
+ def test_no_local_user_fallback_login_legacy(self):
+ self.no_local_user_fallback_login_test_body()
+
+ def no_local_user_fallback_login_test_body(self):
"""localdb_enabled can block login with the local password"""
self.register_user("localuser", "localpass")
@@ -242,11 +319,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
@override_config(
{
- **providers_config(PasswordOnlyAuthProvider),
+ **legacy_providers_config(LegacyPasswordOnlyAuthProvider),
"password_config": {"localdb_enabled": False},
}
)
- def test_no_local_user_fallback_ui_auth(self):
+ def test_no_local_user_fallback_ui_auth_legacy(self):
+ self.no_local_user_fallback_ui_auth_test_body()
+
+ def no_local_user_fallback_ui_auth_test_body(self):
"""localdb_enabled can block ui auth with the local password"""
self.register_user("localuser", "localpass")
@@ -280,11 +360,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
@override_config(
{
- **providers_config(PasswordOnlyAuthProvider),
+ **legacy_providers_config(LegacyPasswordOnlyAuthProvider),
"password_config": {"enabled": False},
}
)
- def test_password_auth_disabled(self):
+ def test_password_auth_disabled_legacy(self):
+ self.password_auth_disabled_test_body()
+
+ def password_auth_disabled_test_body(self):
"""password auth doesn't work if it's disabled across the board"""
# login flows should be empty
flows = self._get_login_flows()
@@ -295,8 +378,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 400, channel.result)
mock_password_provider.check_password.assert_not_called()
+ @override_config(legacy_providers_config(LegacyCustomAuthProvider))
+ def test_custom_auth_provider_login_legacy(self):
+ self.custom_auth_provider_login_test_body()
+
@override_config(providers_config(CustomAuthProvider))
def test_custom_auth_provider_login(self):
+ self.custom_auth_provider_login_test_body()
+
+ def custom_auth_provider_login_test_body(self):
# login flows should have the custom flow and m.login.password, since we
# haven't disabled local password lookup.
# (password must come first, because reasons)
@@ -312,7 +402,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 400, channel.result)
mock_password_provider.check_auth.assert_not_called()
- mock_password_provider.check_auth.return_value = defer.succeed("@user:bz")
+ mock_password_provider.check_auth.return_value = defer.succeed(
+ ("@user:bz", None)
+ )
channel = self._send_login("test.login_type", "u", test_field="y")
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual("@user:bz", channel.json_body["user_id"])
@@ -325,7 +417,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# in these cases, but at least we can guard against the API changing
# unexpectedly
mock_password_provider.check_auth.return_value = defer.succeed(
- "@ MALFORMED! :bz"
+ ("@ MALFORMED! :bz", None)
)
channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ")
self.assertEqual(channel.code, 200, channel.result)
@@ -334,8 +426,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
" USER🙂NAME ", "test.login_type", {"test_field": " abc "}
)
+ @override_config(legacy_providers_config(LegacyCustomAuthProvider))
+ def test_custom_auth_provider_ui_auth_legacy(self):
+ self.custom_auth_provider_ui_auth_test_body()
+
@override_config(providers_config(CustomAuthProvider))
def test_custom_auth_provider_ui_auth(self):
+ self.custom_auth_provider_ui_auth_test_body()
+
+ def custom_auth_provider_ui_auth_test_body(self):
# register the user and log in twice, to get two devices
self.register_user("localuser", "localpass")
tok1 = self.login("localuser", "localpass")
@@ -367,7 +466,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
mock_password_provider.reset_mock()
# right params, but authing as the wrong user
- mock_password_provider.check_auth.return_value = defer.succeed("@user:bz")
+ mock_password_provider.check_auth.return_value = defer.succeed(
+ ("@user:bz", None)
+ )
body["auth"]["test_field"] = "foo"
channel = self._delete_device(tok1, "dev2", body)
self.assertEqual(channel.code, 403)
@@ -379,7 +480,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# and finally, succeed
mock_password_provider.check_auth.return_value = defer.succeed(
- "@localuser:test"
+ ("@localuser:test", None)
)
channel = self._delete_device(tok1, "dev2", body)
self.assertEqual(channel.code, 200)
@@ -387,8 +488,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"localuser", "test.login_type", {"test_field": "foo"}
)
+ @override_config(legacy_providers_config(LegacyCustomAuthProvider))
+ def test_custom_auth_provider_callback_legacy(self):
+ self.custom_auth_provider_callback_test_body()
+
@override_config(providers_config(CustomAuthProvider))
def test_custom_auth_provider_callback(self):
+ self.custom_auth_provider_callback_test_body()
+
+ def custom_auth_provider_callback_test_body(self):
callback = Mock(return_value=defer.succeed(None))
mock_password_provider.check_auth.return_value = defer.succeed(
@@ -411,9 +519,21 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.assertIn(p, call_args[0])
@override_config(
+ {
+ **legacy_providers_config(LegacyCustomAuthProvider),
+ "password_config": {"enabled": False},
+ }
+ )
+ def test_custom_auth_password_disabled_legacy(self):
+ self.custom_auth_password_disabled_test_body()
+
+ @override_config(
{**providers_config(CustomAuthProvider), "password_config": {"enabled": False}}
)
def test_custom_auth_password_disabled(self):
+ self.custom_auth_password_disabled_test_body()
+
+ def custom_auth_password_disabled_test_body(self):
"""Test login with a custom auth provider where password login is disabled"""
self.register_user("localuser", "localpass")
@@ -427,11 +547,23 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
@override_config(
{
+ **legacy_providers_config(LegacyCustomAuthProvider),
+ "password_config": {"enabled": False, "localdb_enabled": False},
+ }
+ )
+ def test_custom_auth_password_disabled_localdb_enabled_legacy(self):
+ self.custom_auth_password_disabled_localdb_enabled_test_body()
+
+ @override_config(
+ {
**providers_config(CustomAuthProvider),
"password_config": {"enabled": False, "localdb_enabled": False},
}
)
def test_custom_auth_password_disabled_localdb_enabled(self):
+ self.custom_auth_password_disabled_localdb_enabled_test_body()
+
+ def custom_auth_password_disabled_localdb_enabled_test_body(self):
"""Check the localdb_enabled == enabled == False
Regression test for https://github.com/matrix-org/synapse/issues/8914: check
@@ -450,11 +582,23 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
@override_config(
{
+ **legacy_providers_config(LegacyPasswordCustomAuthProvider),
+ "password_config": {"enabled": False},
+ }
+ )
+ def test_password_custom_auth_password_disabled_login_legacy(self):
+ self.password_custom_auth_password_disabled_login_test_body()
+
+ @override_config(
+ {
**providers_config(PasswordCustomAuthProvider),
"password_config": {"enabled": False},
}
)
def test_password_custom_auth_password_disabled_login(self):
+ self.password_custom_auth_password_disabled_login_test_body()
+
+ def password_custom_auth_password_disabled_login_test_body(self):
"""log in with a custom auth provider which implements password, but password
login is disabled"""
self.register_user("localuser", "localpass")
@@ -466,6 +610,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
channel = self._send_password_login("localuser", "localpass")
self.assertEqual(channel.code, 400, channel.result)
mock_password_provider.check_auth.assert_not_called()
+ mock_password_provider.check_password.assert_not_called()
+
+ @override_config(
+ {
+ **legacy_providers_config(LegacyPasswordCustomAuthProvider),
+ "password_config": {"enabled": False},
+ }
+ )
+ def test_password_custom_auth_password_disabled_ui_auth_legacy(self):
+ self.password_custom_auth_password_disabled_ui_auth_test_body()
@override_config(
{
@@ -474,12 +628,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
}
)
def test_password_custom_auth_password_disabled_ui_auth(self):
+ self.password_custom_auth_password_disabled_ui_auth_test_body()
+
+ def password_custom_auth_password_disabled_ui_auth_test_body(self):
"""UI Auth with a custom auth provider which implements password, but password
login is disabled"""
# register the user and log in twice via the test login type to get two devices,
self.register_user("localuser", "localpass")
mock_password_provider.check_auth.return_value = defer.succeed(
- "@localuser:test"
+ ("@localuser:test", None)
)
channel = self._send_login("test.login_type", "localuser", test_field="")
self.assertEqual(channel.code, 200, channel.result)
@@ -516,6 +673,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"Password login has been disabled.", channel.json_body["error"]
)
mock_password_provider.check_auth.assert_not_called()
+ mock_password_provider.check_password.assert_not_called()
mock_password_provider.reset_mock()
# successful auth
@@ -526,6 +684,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
mock_password_provider.check_auth.assert_called_once_with(
"localuser", "test.login_type", {"test_field": "x"}
)
+ mock_password_provider.check_password.assert_not_called()
+
+ @override_config(
+ {
+ **legacy_providers_config(LegacyCustomAuthProvider),
+ "password_config": {"localdb_enabled": False},
+ }
+ )
+ def test_custom_auth_no_local_user_fallback_legacy(self):
+ self.custom_auth_no_local_user_fallback_test_body()
@override_config(
{
@@ -534,6 +702,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
}
)
def test_custom_auth_no_local_user_fallback(self):
+ self.custom_auth_no_local_user_fallback_test_body()
+
+ def custom_auth_no_local_user_fallback_test_body(self):
"""Test login with a custom auth provider where the local db is disabled"""
self.register_user("localuser", "localpass")
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 6ed9e42173..c9e2754b09 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -14,14 +14,13 @@
import hashlib
import hmac
-import json
import os
import urllib.parse
from binascii import unhexlify
from typing import List, Optional
from unittest.mock import Mock, patch
-from parameterized import parameterized
+from parameterized import parameterized, parameterized_class
import synapse.rest.admin
from synapse.api.constants import UserTypes
@@ -104,8 +103,8 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
# 59 seconds
self.reactor.advance(59)
- body = json.dumps({"nonce": nonce})
- channel = self.make_request("POST", self.url, body.encode("utf8"))
+ body = {"nonce": nonce}
+ channel = self.make_request("POST", self.url, body)
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("username must be specified", channel.json_body["error"])
@@ -113,7 +112,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
# 61 seconds
self.reactor.advance(2)
- channel = self.make_request("POST", self.url, body.encode("utf8"))
+ channel = self.make_request("POST", self.url, body)
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("unrecognised nonce", channel.json_body["error"])
@@ -129,18 +128,16 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
want_mac.update(b"notthenonce\x00bob\x00abc123\x00admin")
want_mac = want_mac.hexdigest()
- body = json.dumps(
- {
- "nonce": nonce,
- "username": "bob",
- "password": "abc123",
- "admin": True,
- "mac": want_mac,
- }
- )
- channel = self.make_request("POST", self.url, body.encode("utf8"))
+ body = {
+ "nonce": nonce,
+ "username": "bob",
+ "password": "abc123",
+ "admin": True,
+ "mac": want_mac,
+ }
+ channel = self.make_request("POST", self.url, body)
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual("HMAC incorrect", channel.json_body["error"])
def test_register_correct_nonce(self):
@@ -157,17 +154,15 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
)
want_mac = want_mac.hexdigest()
- body = json.dumps(
- {
- "nonce": nonce,
- "username": "bob",
- "password": "abc123",
- "admin": True,
- "user_type": UserTypes.SUPPORT,
- "mac": want_mac,
- }
- )
- channel = self.make_request("POST", self.url, body.encode("utf8"))
+ body = {
+ "nonce": nonce,
+ "username": "bob",
+ "password": "abc123",
+ "admin": True,
+ "user_type": UserTypes.SUPPORT,
+ "mac": want_mac,
+ }
+ channel = self.make_request("POST", self.url, body)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["user_id"])
@@ -183,22 +178,20 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
want_mac.update(nonce.encode("ascii") + b"\x00bob\x00abc123\x00admin")
want_mac = want_mac.hexdigest()
- body = json.dumps(
- {
- "nonce": nonce,
- "username": "bob",
- "password": "abc123",
- "admin": True,
- "mac": want_mac,
- }
- )
- channel = self.make_request("POST", self.url, body.encode("utf8"))
+ body = {
+ "nonce": nonce,
+ "username": "bob",
+ "password": "abc123",
+ "admin": True,
+ "mac": want_mac,
+ }
+ channel = self.make_request("POST", self.url, body)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["user_id"])
# Now, try and reuse it
- channel = self.make_request("POST", self.url, body.encode("utf8"))
+ channel = self.make_request("POST", self.url, body)
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("unrecognised nonce", channel.json_body["error"])
@@ -218,9 +211,8 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
# Nonce check
#
- # Must be present
- body = json.dumps({})
- channel = self.make_request("POST", self.url, body.encode("utf8"))
+ # Must be an empty body present
+ channel = self.make_request("POST", self.url, {})
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("nonce must be specified", channel.json_body["error"])
@@ -230,29 +222,28 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
#
# Must be present
- body = json.dumps({"nonce": nonce()})
- channel = self.make_request("POST", self.url, body.encode("utf8"))
+ channel = self.make_request("POST", self.url, {"nonce": nonce()})
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("username must be specified", channel.json_body["error"])
# Must be a string
- body = json.dumps({"nonce": nonce(), "username": 1234})
- channel = self.make_request("POST", self.url, body.encode("utf8"))
+ body = {"nonce": nonce(), "username": 1234}
+ channel = self.make_request("POST", self.url, body)
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid username", channel.json_body["error"])
# Must not have null bytes
- body = json.dumps({"nonce": nonce(), "username": "abcd\u0000"})
- channel = self.make_request("POST", self.url, body.encode("utf8"))
+ body = {"nonce": nonce(), "username": "abcd\u0000"}
+ channel = self.make_request("POST", self.url, body)
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid username", channel.json_body["error"])
# Must not have null bytes
- body = json.dumps({"nonce": nonce(), "username": "a" * 1000})
- channel = self.make_request("POST", self.url, body.encode("utf8"))
+ body = {"nonce": nonce(), "username": "a" * 1000}
+ channel = self.make_request("POST", self.url, body)
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid username", channel.json_body["error"])
@@ -262,29 +253,29 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
#
# Must be present
- body = json.dumps({"nonce": nonce(), "username": "a"})
- channel = self.make_request("POST", self.url, body.encode("utf8"))
+ body = {"nonce": nonce(), "username": "a"}
+ channel = self.make_request("POST", self.url, body)
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("password must be specified", channel.json_body["error"])
# Must be a string
- body = json.dumps({"nonce": nonce(), "username": "a", "password": 1234})
- channel = self.make_request("POST", self.url, body.encode("utf8"))
+ body = {"nonce": nonce(), "username": "a", "password": 1234}
+ channel = self.make_request("POST", self.url, body)
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid password", channel.json_body["error"])
# Must not have null bytes
- body = json.dumps({"nonce": nonce(), "username": "a", "password": "abcd\u0000"})
- channel = self.make_request("POST", self.url, body.encode("utf8"))
+ body = {"nonce": nonce(), "username": "a", "password": "abcd\u0000"}
+ channel = self.make_request("POST", self.url, body)
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid password", channel.json_body["error"])
# Super long
- body = json.dumps({"nonce": nonce(), "username": "a", "password": "A" * 1000})
- channel = self.make_request("POST", self.url, body.encode("utf8"))
+ body = {"nonce": nonce(), "username": "a", "password": "A" * 1000}
+ channel = self.make_request("POST", self.url, body)
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid password", channel.json_body["error"])
@@ -294,15 +285,13 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
#
# Invalid user_type
- body = json.dumps(
- {
- "nonce": nonce(),
- "username": "a",
- "password": "1234",
- "user_type": "invalid",
- }
- )
- channel = self.make_request("POST", self.url, body.encode("utf8"))
+ body = {
+ "nonce": nonce(),
+ "username": "a",
+ "password": "1234",
+ "user_type": "invalid",
+ }
+ channel = self.make_request("POST", self.url, body)
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid user type", channel.json_body["error"])
@@ -320,10 +309,14 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
want_mac.update(nonce.encode("ascii") + b"\x00bob1\x00abc123\x00notadmin")
want_mac = want_mac.hexdigest()
- body = json.dumps(
- {"nonce": nonce, "username": "bob1", "password": "abc123", "mac": want_mac}
- )
- channel = self.make_request("POST", self.url, body.encode("utf8"))
+ body = {
+ "nonce": nonce,
+ "username": "bob1",
+ "password": "abc123",
+ "mac": want_mac,
+ }
+
+ channel = self.make_request("POST", self.url, body)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob1:test", channel.json_body["user_id"])
@@ -340,16 +333,14 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
want_mac.update(nonce.encode("ascii") + b"\x00bob2\x00abc123\x00notadmin")
want_mac = want_mac.hexdigest()
- body = json.dumps(
- {
- "nonce": nonce,
- "username": "bob2",
- "displayname": None,
- "password": "abc123",
- "mac": want_mac,
- }
- )
- channel = self.make_request("POST", self.url, body.encode("utf8"))
+ body = {
+ "nonce": nonce,
+ "username": "bob2",
+ "displayname": None,
+ "password": "abc123",
+ "mac": want_mac,
+ }
+ channel = self.make_request("POST", self.url, body)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob2:test", channel.json_body["user_id"])
@@ -366,22 +357,20 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
want_mac.update(nonce.encode("ascii") + b"\x00bob3\x00abc123\x00notadmin")
want_mac = want_mac.hexdigest()
- body = json.dumps(
- {
- "nonce": nonce,
- "username": "bob3",
- "displayname": "",
- "password": "abc123",
- "mac": want_mac,
- }
- )
- channel = self.make_request("POST", self.url, body.encode("utf8"))
+ body = {
+ "nonce": nonce,
+ "username": "bob3",
+ "displayname": "",
+ "password": "abc123",
+ "mac": want_mac,
+ }
+ channel = self.make_request("POST", self.url, body)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob3:test", channel.json_body["user_id"])
channel = self.make_request("GET", "/profile/@bob3:test/displayname")
- self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(404, channel.code, msg=channel.json_body)
# set displayname
channel = self.make_request("GET", self.url)
@@ -391,16 +380,14 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
want_mac.update(nonce.encode("ascii") + b"\x00bob4\x00abc123\x00notadmin")
want_mac = want_mac.hexdigest()
- body = json.dumps(
- {
- "nonce": nonce,
- "username": "bob4",
- "displayname": "Bob's Name",
- "password": "abc123",
- "mac": want_mac,
- }
- )
- channel = self.make_request("POST", self.url, body.encode("utf8"))
+ body = {
+ "nonce": nonce,
+ "username": "bob4",
+ "displayname": "Bob's Name",
+ "password": "abc123",
+ "mac": want_mac,
+ }
+ channel = self.make_request("POST", self.url, body)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob4:test", channel.json_body["user_id"])
@@ -440,17 +427,15 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
)
want_mac = want_mac.hexdigest()
- body = json.dumps(
- {
- "nonce": nonce,
- "username": "bob",
- "password": "abc123",
- "admin": True,
- "user_type": UserTypes.SUPPORT,
- "mac": want_mac,
- }
- )
- channel = self.make_request("POST", self.url, body.encode("utf8"))
+ body = {
+ "nonce": nonce,
+ "username": "bob",
+ "password": "abc123",
+ "admin": True,
+ "user_type": UserTypes.SUPPORT,
+ "mac": want_mac,
+ }
+ channel = self.make_request("POST", self.url, body)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["user_id"])
@@ -993,12 +978,11 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
"""
If parameter `erase` is not boolean, return an error
"""
- body = json.dumps({"erase": "False"})
channel = self.make_request(
"POST",
self.url,
- content=body.encode(encoding="utf_8"),
+ content={"erase": "False"},
access_token=self.admin_user_tok,
)
@@ -2201,7 +2185,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
"""
channel = self.make_request("GET", self.url, b"{}")
- self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_no_admin(self):
@@ -2216,7 +2200,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
access_token=other_user_token,
)
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_user_does_not_exist(self):
@@ -2359,7 +2343,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
"""
channel = self.make_request("GET", self.url, b"{}")
- self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_no_admin(self):
@@ -2374,7 +2358,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
access_token=other_user_token,
)
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_user_does_not_exist(self):
@@ -3073,7 +3057,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
"""Try to login as a user without authentication."""
channel = self.make_request("POST", self.url, b"{}")
- self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_not_admin(self):
@@ -3082,7 +3066,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
"POST", self.url, b"{}", access_token=self.other_user_tok
)
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(403, channel.code, msg=channel.json_body)
def test_send_event(self):
"""Test that sending event as a user works."""
@@ -3127,7 +3111,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
# The puppet token should no longer work
channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
- self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(401, channel.code, msg=channel.json_body)
# .. but the real user's tokens should still work
channel = self.make_request(
@@ -3160,7 +3144,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"GET", "devices", b"{}", access_token=self.other_user_tok
)
- self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(401, channel.code, msg=channel.json_body)
def test_admin_logout_all(self):
"""Tests that the admin user calling `/logout/all` does expire the
@@ -3181,7 +3165,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
# The puppet token should no longer work
channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
- self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(401, channel.code, msg=channel.json_body)
# .. but the real user's tokens should still work
channel = self.make_request(
@@ -3242,6 +3226,13 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
self.helper.join(room_id, user=self.other_user, tok=puppet_token)
+@parameterized_class(
+ ("url_prefix",),
+ [
+ ("/_synapse/admin/v1/whois/%s",),
+ ("/_matrix/client/r0/admin/whois/%s",),
+ ],
+)
class WhoisRestTestCase(unittest.HomeserverTestCase):
servlets = [
@@ -3254,21 +3245,14 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
self.admin_user_tok = self.login("admin", "pass")
self.other_user = self.register_user("user", "pass")
- self.url1 = "/_synapse/admin/v1/whois/%s" % urllib.parse.quote(self.other_user)
- self.url2 = "/_matrix/client/r0/admin/whois/%s" % urllib.parse.quote(
- self.other_user
- )
+ self.url = self.url_prefix % self.other_user
def test_no_auth(self):
"""
Try to get information of an user without authentication.
"""
- channel = self.make_request("GET", self.url1, b"{}")
- self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
-
- channel = self.make_request("GET", self.url2, b"{}")
- self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ channel = self.make_request("GET", self.url, b"{}")
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_not_admin(self):
@@ -3280,38 +3264,21 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"GET",
- self.url1,
- access_token=other_user2_token,
- )
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
-
- channel = self.make_request(
- "GET",
- self.url2,
+ self.url,
access_token=other_user2_token,
)
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_user_is_not_local(self):
"""
Tests that a lookup for a user that is not a local returns a 400
"""
- url1 = "/_synapse/admin/v1/whois/@unknown_person:unknown_domain"
- url2 = "/_matrix/client/r0/admin/whois/@unknown_person:unknown_domain"
-
- channel = self.make_request(
- "GET",
- url1,
- access_token=self.admin_user_tok,
- )
- self.assertEqual(400, channel.code, msg=channel.json_body)
- self.assertEqual("Can only whois a local user", channel.json_body["error"])
+ url = self.url_prefix % "@unknown_person:unknown_domain"
channel = self.make_request(
"GET",
- url2,
+ url,
access_token=self.admin_user_tok,
)
self.assertEqual(400, channel.code, msg=channel.json_body)
@@ -3323,16 +3290,7 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
"""
channel = self.make_request(
"GET",
- self.url1,
- access_token=self.admin_user_tok,
- )
- self.assertEqual(200, channel.code, msg=channel.json_body)
- self.assertEqual(self.other_user, channel.json_body["user_id"])
- self.assertIn("devices", channel.json_body)
-
- channel = self.make_request(
- "GET",
- self.url2,
+ self.url,
access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
@@ -3347,16 +3305,7 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"GET",
- self.url1,
- access_token=other_user_token,
- )
- self.assertEqual(200, channel.code, msg=channel.json_body)
- self.assertEqual(self.other_user, channel.json_body["user_id"])
- self.assertIn("devices", channel.json_body)
-
- channel = self.make_request(
- "GET",
- self.url2,
+ self.url,
access_token=other_user_token,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
@@ -3388,7 +3337,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
Try to get information of an user without authentication.
"""
channel = self.make_request("POST", self.url)
- self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_not_admin(self):
@@ -3398,7 +3347,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
other_user_token = self.login("user", "pass")
channel = self.make_request("POST", self.url, access_token=other_user_token)
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_user_is_not_local(self):
@@ -3447,84 +3396,41 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
% urllib.parse.quote(self.other_user)
)
- def test_no_auth(self):
+ @parameterized.expand(["GET", "POST", "DELETE"])
+ def test_no_auth(self, method: str):
"""
Try to get information of a user without authentication.
"""
- channel = self.make_request("GET", self.url, b"{}")
-
- self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
-
- channel = self.make_request("POST", self.url, b"{}")
+ channel = self.make_request(method, self.url, b"{}")
- self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
-
- channel = self.make_request("DELETE", self.url, b"{}")
-
- self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
- def test_requester_is_no_admin(self):
+ @parameterized.expand(["GET", "POST", "DELETE"])
+ def test_requester_is_no_admin(self, method: str):
"""
If the user is not a server admin, an error is returned.
"""
other_user_token = self.login("user", "pass")
channel = self.make_request(
- "GET",
- self.url,
- access_token=other_user_token,
- )
-
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
-
- channel = self.make_request(
- "POST",
- self.url,
- access_token=other_user_token,
- )
-
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
-
- channel = self.make_request(
- "DELETE",
+ method,
self.url,
access_token=other_user_token,
)
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
- def test_user_does_not_exist(self):
+ @parameterized.expand(["GET", "POST", "DELETE"])
+ def test_user_does_not_exist(self, method: str):
"""
Tests that a lookup for a user that does not exist returns a 404
"""
url = "/_synapse/admin/v1/users/@unknown_person:test/override_ratelimit"
channel = self.make_request(
- "GET",
- url,
- access_token=self.admin_user_tok,
- )
-
- self.assertEqual(404, channel.code, msg=channel.json_body)
- self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
-
- channel = self.make_request(
- "POST",
- url,
- access_token=self.admin_user_tok,
- )
-
- self.assertEqual(404, channel.code, msg=channel.json_body)
- self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
-
- channel = self.make_request(
- "DELETE",
+ method,
url,
access_token=self.admin_user_tok,
)
@@ -3532,7 +3438,14 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
- def test_user_is_not_local(self):
+ @parameterized.expand(
+ [
+ ("GET", "Can only look up local users"),
+ ("POST", "Only local users can be ratelimited"),
+ ("DELETE", "Only local users can be ratelimited"),
+ ]
+ )
+ def test_user_is_not_local(self, method: str, error_msg: str):
"""
Tests that a lookup for a user that is not a local returns a 400
"""
@@ -3541,35 +3454,13 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
)
channel = self.make_request(
- "GET",
- url,
- access_token=self.admin_user_tok,
- )
-
- self.assertEqual(400, channel.code, msg=channel.json_body)
- self.assertEqual("Can only look up local users", channel.json_body["error"])
-
- channel = self.make_request(
- "POST",
- url,
- access_token=self.admin_user_tok,
- )
-
- self.assertEqual(400, channel.code, msg=channel.json_body)
- self.assertEqual(
- "Only local users can be ratelimited", channel.json_body["error"]
- )
-
- channel = self.make_request(
- "DELETE",
+ method,
url,
access_token=self.admin_user_tok,
)
self.assertEqual(400, channel.code, msg=channel.json_body)
- self.assertEqual(
- "Only local users can be ratelimited", channel.json_body["error"]
- )
+ self.assertEqual(error_msg, channel.json_body["error"])
def test_invalid_parameter(self):
"""
diff --git a/tests/rest/media/v1/test_filepath.py b/tests/rest/media/v1/test_filepath.py
new file mode 100644
index 0000000000..09504a485f
--- /dev/null
+++ b/tests/rest/media/v1/test_filepath.py
@@ -0,0 +1,238 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from synapse.rest.media.v1.filepath import MediaFilePaths
+
+from tests import unittest
+
+
+class MediaFilePathsTestCase(unittest.TestCase):
+ def setUp(self):
+ super().setUp()
+
+ self.filepaths = MediaFilePaths("/media_store")
+
+ def test_local_media_filepath(self):
+ """Test local media paths"""
+ self.assertEqual(
+ self.filepaths.local_media_filepath_rel("GerZNDnDZVjsOtardLuwfIBg"),
+ "local_content/Ge/rZ/NDnDZVjsOtardLuwfIBg",
+ )
+ self.assertEqual(
+ self.filepaths.local_media_filepath("GerZNDnDZVjsOtardLuwfIBg"),
+ "/media_store/local_content/Ge/rZ/NDnDZVjsOtardLuwfIBg",
+ )
+
+ def test_local_media_thumbnail(self):
+ """Test local media thumbnail paths"""
+ self.assertEqual(
+ self.filepaths.local_media_thumbnail_rel(
+ "GerZNDnDZVjsOtardLuwfIBg", 800, 600, "image/jpeg", "scale"
+ ),
+ "local_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale",
+ )
+ self.assertEqual(
+ self.filepaths.local_media_thumbnail(
+ "GerZNDnDZVjsOtardLuwfIBg", 800, 600, "image/jpeg", "scale"
+ ),
+ "/media_store/local_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale",
+ )
+
+ def test_local_media_thumbnail_dir(self):
+ """Test local media thumbnail directory paths"""
+ self.assertEqual(
+ self.filepaths.local_media_thumbnail_dir("GerZNDnDZVjsOtardLuwfIBg"),
+ "/media_store/local_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg",
+ )
+
+ def test_remote_media_filepath(self):
+ """Test remote media paths"""
+ self.assertEqual(
+ self.filepaths.remote_media_filepath_rel(
+ "example.com", "GerZNDnDZVjsOtardLuwfIBg"
+ ),
+ "remote_content/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg",
+ )
+ self.assertEqual(
+ self.filepaths.remote_media_filepath(
+ "example.com", "GerZNDnDZVjsOtardLuwfIBg"
+ ),
+ "/media_store/remote_content/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg",
+ )
+
+ def test_remote_media_thumbnail(self):
+ """Test remote media thumbnail paths"""
+ self.assertEqual(
+ self.filepaths.remote_media_thumbnail_rel(
+ "example.com",
+ "GerZNDnDZVjsOtardLuwfIBg",
+ 800,
+ 600,
+ "image/jpeg",
+ "scale",
+ ),
+ "remote_thumbnail/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale",
+ )
+ self.assertEqual(
+ self.filepaths.remote_media_thumbnail(
+ "example.com",
+ "GerZNDnDZVjsOtardLuwfIBg",
+ 800,
+ 600,
+ "image/jpeg",
+ "scale",
+ ),
+ "/media_store/remote_thumbnail/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale",
+ )
+
+ def test_remote_media_thumbnail_legacy(self):
+ """Test old-style remote media thumbnail paths"""
+ self.assertEqual(
+ self.filepaths.remote_media_thumbnail_rel_legacy(
+ "example.com", "GerZNDnDZVjsOtardLuwfIBg", 800, 600, "image/jpeg"
+ ),
+ "remote_thumbnail/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg",
+ )
+
+ def test_remote_media_thumbnail_dir(self):
+ """Test remote media thumbnail directory paths"""
+ self.assertEqual(
+ self.filepaths.remote_media_thumbnail_dir(
+ "example.com", "GerZNDnDZVjsOtardLuwfIBg"
+ ),
+ "/media_store/remote_thumbnail/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg",
+ )
+
+ def test_url_cache_filepath(self):
+ """Test URL cache paths"""
+ self.assertEqual(
+ self.filepaths.url_cache_filepath_rel("2020-01-02_GerZNDnDZVjsOtar"),
+ "url_cache/2020-01-02/GerZNDnDZVjsOtar",
+ )
+ self.assertEqual(
+ self.filepaths.url_cache_filepath("2020-01-02_GerZNDnDZVjsOtar"),
+ "/media_store/url_cache/2020-01-02/GerZNDnDZVjsOtar",
+ )
+
+ def test_url_cache_filepath_legacy(self):
+ """Test old-style URL cache paths"""
+ self.assertEqual(
+ self.filepaths.url_cache_filepath_rel("GerZNDnDZVjsOtardLuwfIBg"),
+ "url_cache/Ge/rZ/NDnDZVjsOtardLuwfIBg",
+ )
+ self.assertEqual(
+ self.filepaths.url_cache_filepath("GerZNDnDZVjsOtardLuwfIBg"),
+ "/media_store/url_cache/Ge/rZ/NDnDZVjsOtardLuwfIBg",
+ )
+
+ def test_url_cache_filepath_dirs_to_delete(self):
+ """Test URL cache cleanup paths"""
+ self.assertEqual(
+ self.filepaths.url_cache_filepath_dirs_to_delete(
+ "2020-01-02_GerZNDnDZVjsOtar"
+ ),
+ ["/media_store/url_cache/2020-01-02"],
+ )
+
+ def test_url_cache_filepath_dirs_to_delete_legacy(self):
+ """Test old-style URL cache cleanup paths"""
+ self.assertEqual(
+ self.filepaths.url_cache_filepath_dirs_to_delete(
+ "GerZNDnDZVjsOtardLuwfIBg"
+ ),
+ [
+ "/media_store/url_cache/Ge/rZ",
+ "/media_store/url_cache/Ge",
+ ],
+ )
+
+ def test_url_cache_thumbnail(self):
+ """Test URL cache thumbnail paths"""
+ self.assertEqual(
+ self.filepaths.url_cache_thumbnail_rel(
+ "2020-01-02_GerZNDnDZVjsOtar", 800, 600, "image/jpeg", "scale"
+ ),
+ "url_cache_thumbnails/2020-01-02/GerZNDnDZVjsOtar/800-600-image-jpeg-scale",
+ )
+ self.assertEqual(
+ self.filepaths.url_cache_thumbnail(
+ "2020-01-02_GerZNDnDZVjsOtar", 800, 600, "image/jpeg", "scale"
+ ),
+ "/media_store/url_cache_thumbnails/2020-01-02/GerZNDnDZVjsOtar/800-600-image-jpeg-scale",
+ )
+
+ def test_url_cache_thumbnail_legacy(self):
+ """Test old-style URL cache thumbnail paths"""
+ self.assertEqual(
+ self.filepaths.url_cache_thumbnail_rel(
+ "GerZNDnDZVjsOtardLuwfIBg", 800, 600, "image/jpeg", "scale"
+ ),
+ "url_cache_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale",
+ )
+ self.assertEqual(
+ self.filepaths.url_cache_thumbnail(
+ "GerZNDnDZVjsOtardLuwfIBg", 800, 600, "image/jpeg", "scale"
+ ),
+ "/media_store/url_cache_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale",
+ )
+
+ def test_url_cache_thumbnail_directory(self):
+ """Test URL cache thumbnail directory paths"""
+ self.assertEqual(
+ self.filepaths.url_cache_thumbnail_directory_rel(
+ "2020-01-02_GerZNDnDZVjsOtar"
+ ),
+ "url_cache_thumbnails/2020-01-02/GerZNDnDZVjsOtar",
+ )
+ self.assertEqual(
+ self.filepaths.url_cache_thumbnail_directory("2020-01-02_GerZNDnDZVjsOtar"),
+ "/media_store/url_cache_thumbnails/2020-01-02/GerZNDnDZVjsOtar",
+ )
+
+ def test_url_cache_thumbnail_directory_legacy(self):
+ """Test old-style URL cache thumbnail directory paths"""
+ self.assertEqual(
+ self.filepaths.url_cache_thumbnail_directory_rel(
+ "GerZNDnDZVjsOtardLuwfIBg"
+ ),
+ "url_cache_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg",
+ )
+ self.assertEqual(
+ self.filepaths.url_cache_thumbnail_directory("GerZNDnDZVjsOtardLuwfIBg"),
+ "/media_store/url_cache_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg",
+ )
+
+ def test_url_cache_thumbnail_dirs_to_delete(self):
+ """Test URL cache thumbnail cleanup paths"""
+ self.assertEqual(
+ self.filepaths.url_cache_thumbnail_dirs_to_delete(
+ "2020-01-02_GerZNDnDZVjsOtar"
+ ),
+ [
+ "/media_store/url_cache_thumbnails/2020-01-02/GerZNDnDZVjsOtar",
+ "/media_store/url_cache_thumbnails/2020-01-02",
+ ],
+ )
+
+ def test_url_cache_thumbnail_dirs_to_delete_legacy(self):
+ """Test old-style URL cache thumbnail cleanup paths"""
+ self.assertEqual(
+ self.filepaths.url_cache_thumbnail_dirs_to_delete(
+ "GerZNDnDZVjsOtardLuwfIBg"
+ ),
+ [
+ "/media_store/url_cache_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg",
+ "/media_store/url_cache_thumbnails/Ge/rZ",
+ "/media_store/url_cache_thumbnails/Ge",
+ ],
+ )
diff --git a/tests/rest/media/v1/test_oembed.py b/tests/rest/media/v1/test_oembed.py
new file mode 100644
index 0000000000..048d0ca44a
--- /dev/null
+++ b/tests/rest/media/v1/test_oembed.py
@@ -0,0 +1,51 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.rest.media.v1.oembed import OEmbedProvider
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util import Clock
+
+from tests.unittest import HomeserverTestCase
+
+
+class OEmbedTests(HomeserverTestCase):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer):
+ self.oembed = OEmbedProvider(homeserver)
+
+ def parse_response(self, response: JsonDict):
+ return self.oembed.parse_oembed_response(
+ "https://test", json.dumps(response).encode("utf-8")
+ )
+
+ def test_version(self):
+ """Accept versions that are similar to 1.0 as a string or int (or missing)."""
+ for version in ("1.0", 1.0, 1):
+ result = self.parse_response({"version": version, "type": "link"})
+ # An empty Open Graph response is an error, ensure the URL is included.
+ self.assertIn("og:url", result.open_graph_result)
+
+ # A missing version should be treated as 1.0.
+ result = self.parse_response({"type": "link"})
+ self.assertIn("og:url", result.open_graph_result)
+
+ # Invalid versions should be rejected.
+ for version in ("2.0", "1", 1.1, 0, None, {}, []):
+ result = self.parse_response({"version": version, "type": "link"})
+ # An empty Open Graph response is an error, ensure the URL is included.
+ self.assertEqual({}, result.open_graph_result)
|