diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index 6e36e73f0d..cdb0048122 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -18,7 +18,7 @@ import pymacaroons
from twisted.test.proto_helpers import MemoryReactor
-from synapse.api.auth import Auth
+from synapse.api.auth.internal import InternalAuth
from synapse.api.auth_blocking import AuthBlocking
from synapse.api.constants import UserTypes
from synapse.api.errors import (
@@ -48,7 +48,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
# have been called by the HomeserverTestCase machinery.
hs.datastores.main = self.store # type: ignore[union-attr]
hs.get_auth_handler().store = self.store
- self.auth = Auth(hs)
+ self.auth = InternalAuth(hs)
# AuthBlocking reads from the hs' config on initialization. We need to
# modify its config instead of the hs'
@@ -426,6 +426,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
access_token_id=None,
device_id="FOOBAR",
is_guest=False,
+ scope=set(),
shadow_banned=False,
app_service=appservice,
authenticated_entity="@appservice:server",
@@ -456,6 +457,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
access_token_id=None,
device_id="FOOBAR",
is_guest=False,
+ scope=set(),
shadow_banned=False,
app_service=appservice,
authenticated_entity="@appservice:server",
diff --git a/tests/config/test_oauth_delegation.py b/tests/config/test_oauth_delegation.py
new file mode 100644
index 0000000000..f57c813a58
--- /dev/null
+++ b/tests/config/test_oauth_delegation.py
@@ -0,0 +1,257 @@
+# Copyright 2023 Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from unittest.mock import Mock
+
+from synapse.config import ConfigError
+from synapse.config.homeserver import HomeServerConfig
+from synapse.module_api import ModuleApi
+from synapse.types import JsonDict
+
+from tests.server import get_clock, setup_test_homeserver
+from tests.unittest import TestCase, skip_unless
+from tests.utils import default_config
+
+try:
+ import authlib # noqa: F401
+
+ HAS_AUTHLIB = True
+except ImportError:
+ HAS_AUTHLIB = False
+
+
+# These are a few constants that are used as config parameters in the tests.
+SERVER_NAME = "test"
+ISSUER = "https://issuer/"
+CLIENT_ID = "test-client-id"
+CLIENT_SECRET = "test-client-secret"
+BASE_URL = "https://synapse/"
+
+
+class CustomAuthModule:
+ """A module which registers a password auth provider."""
+
+ @staticmethod
+ def parse_config(config: JsonDict) -> None:
+ pass
+
+ def __init__(self, config: None, api: ModuleApi):
+ api.register_password_auth_provider_callbacks(
+ auth_checkers={("m.login.password", ("password",)): Mock()},
+ )
+
+
+@skip_unless(HAS_AUTHLIB, "requires authlib")
+class MSC3861OAuthDelegation(TestCase):
+ """Test that the Homeserver fails to initialize if the config is invalid."""
+
+ def setUp(self) -> None:
+ self.config_dict: JsonDict = {
+ **default_config("test"),
+ "public_baseurl": BASE_URL,
+ "enable_registration": False,
+ "experimental_features": {
+ "msc3861": {
+ "enabled": True,
+ "issuer": ISSUER,
+ "client_id": CLIENT_ID,
+ "client_auth_method": "client_secret_post",
+ "client_secret": CLIENT_SECRET,
+ }
+ },
+ }
+
+ def parse_config(self) -> HomeServerConfig:
+ config = HomeServerConfig()
+ config.parse_config_dict(self.config_dict, "", "")
+ return config
+
+ def test_client_secret_post_works(self) -> None:
+ self.config_dict["experimental_features"]["msc3861"].update(
+ client_auth_method="client_secret_post",
+ client_secret=CLIENT_SECRET,
+ )
+
+ self.parse_config()
+
+ def test_client_secret_post_requires_client_secret(self) -> None:
+ self.config_dict["experimental_features"]["msc3861"].update(
+ client_auth_method="client_secret_post",
+ client_secret=None,
+ )
+
+ with self.assertRaises(ConfigError):
+ self.parse_config()
+
+ def test_client_secret_basic_works(self) -> None:
+ self.config_dict["experimental_features"]["msc3861"].update(
+ client_auth_method="client_secret_basic",
+ client_secret=CLIENT_SECRET,
+ )
+
+ self.parse_config()
+
+ def test_client_secret_basic_requires_client_secret(self) -> None:
+ self.config_dict["experimental_features"]["msc3861"].update(
+ client_auth_method="client_secret_basic",
+ client_secret=None,
+ )
+
+ with self.assertRaises(ConfigError):
+ self.parse_config()
+
+ def test_client_secret_jwt_works(self) -> None:
+ self.config_dict["experimental_features"]["msc3861"].update(
+ client_auth_method="client_secret_jwt",
+ client_secret=CLIENT_SECRET,
+ )
+
+ self.parse_config()
+
+ def test_client_secret_jwt_requires_client_secret(self) -> None:
+ self.config_dict["experimental_features"]["msc3861"].update(
+ client_auth_method="client_secret_jwt",
+ client_secret=None,
+ )
+
+ with self.assertRaises(ConfigError):
+ self.parse_config()
+
+ def test_invalid_client_auth_method(self) -> None:
+ self.config_dict["experimental_features"]["msc3861"].update(
+ client_auth_method="invalid",
+ )
+
+ with self.assertRaises(ConfigError):
+ self.parse_config()
+
+ def test_private_key_jwt_requires_jwk(self) -> None:
+ self.config_dict["experimental_features"]["msc3861"].update(
+ client_auth_method="private_key_jwt",
+ )
+
+ with self.assertRaises(ConfigError):
+ self.parse_config()
+
+ def test_private_key_jwt_works(self) -> None:
+ self.config_dict["experimental_features"]["msc3861"].update(
+ client_auth_method="private_key_jwt",
+ jwk={
+ "p": "-frVdP_tZ-J_nIR6HNMDq1N7aunwm51nAqNnhqIyuA8ikx7LlQED1tt2LD3YEvYyW8nxE2V95HlCRZXQPMiRJBFOsbmYkzl2t-MpavTaObB_fct_JqcRtdXddg4-_ihdjRDwUOreq_dpWh6MIKsC3UyekfkHmeEJg5YpOTL15j8",
+ "kty": "RSA",
+ "q": "oFw-Enr_YozQB1ab-kawn4jY3yHi8B1nSmYT0s8oTCflrmps5BFJfCkHL5ij3iY15z0o2m0N-jjB1oSJ98O4RayEEYNQlHnTNTl0kRIWzpoqblHUIxVcahIpP_xTovBJzwi8XXoLGqHOOMA-r40LSyVgP2Ut8D9qBwV6_UfT0LU",
+ "d": "WFkDPYo4b4LIS64D_QtQfGGuAObPvc3HFfp9VZXyq3SJR58XZRHE0jqtlEMNHhOTgbMYS3w8nxPQ_qVzY-5hs4fIanwvB64mAoOGl0qMHO65DTD_WsGFwzYClJPBVniavkLE2Hmpu8IGe6lGliN8vREC6_4t69liY-XcN_ECboVtC2behKkLOEASOIMuS7YcKAhTJFJwkl1dqDlliEn5A4u4xy7nuWQz3juB1OFdKlwGA5dfhDNglhoLIwNnkLsUPPFO-WB5ZNEW35xxHOToxj4bShvDuanVA6mJPtTKjz0XibjB36bj_nF_j7EtbE2PdGJ2KevAVgElR4lqS4ISgQ",
+ "e": "AQAB",
+ "kid": "test",
+ "qi": "cPfNk8l8W5exVNNea4d7QZZ8Qr8LgHghypYAxz8PQh1fNa8Ya1SNUDVzC2iHHhszxxA0vB9C7jGze8dBrvnzWYF1XvQcqNIVVgHhD57R1Nm3dj2NoHIKe0Cu4bCUtP8xnZQUN4KX7y4IIcgRcBWG1hT6DEYZ4BxqicnBXXNXAUI",
+ "dp": "dKlMHvslV1sMBQaKWpNb3gPq0B13TZhqr3-E2_8sPlvJ3fD8P4CmwwnOn50JDuhY3h9jY5L06sBwXjspYISVv8hX-ndMLkEeF3lrJeA5S70D8rgakfZcPIkffm3tlf1Ok3v5OzoxSv3-67Df4osMniyYwDUBCB5Oq1tTx77xpU8",
+ "dq": "S4ooU1xNYYcjl9FcuJEEMqKsRrAXzzSKq6laPTwIp5dDwt2vXeAm1a4eDHXC-6rUSZGt5PbqVqzV4s-cjnJMI8YYkIdjNg4NSE1Ac_YpeDl3M3Colb5CQlU7yUB7xY2bt0NOOFp9UJZYJrOo09mFMGjy5eorsbitoZEbVqS3SuE",
+ "n": "nJbYKqFwnURKimaviyDFrNLD3gaKR1JW343Qem25VeZxoMq1665RHVoO8n1oBm4ClZdjIiZiVdpyqzD5-Ow12YQgQEf1ZHP3CCcOQQhU57Rh5XvScTe5IxYVkEW32IW2mp_CJ6WfjYpfeL4azarVk8H3Vr59d1rSrKTVVinVdZer9YLQyC_rWAQNtHafPBMrf6RYiNGV9EiYn72wFIXlLlBYQ9Fx7bfe1PaL6qrQSsZP3_rSpuvVdLh1lqGeCLR0pyclA9uo5m2tMyCXuuGQLbA_QJm5xEc7zd-WFdux2eXF045oxnSZ_kgQt-pdN7AxGWOVvwoTf9am6mSkEdv6iw",
+ },
+ )
+ self.parse_config()
+
+ def test_registration_cannot_be_enabled(self) -> None:
+ self.config_dict["enable_registration"] = True
+ with self.assertRaises(ConfigError):
+ self.parse_config()
+
+ def test_password_config_cannot_be_enabled(self) -> None:
+ self.config_dict["password_config"] = {"enabled": True}
+ with self.assertRaises(ConfigError):
+ self.parse_config()
+
+ def test_oidc_sso_cannot_be_enabled(self) -> None:
+ self.config_dict["oidc_providers"] = [
+ {
+ "idp_id": "microsoft",
+ "idp_name": "Microsoft",
+ "issuer": "https://login.microsoftonline.com/<tenant id>/v2.0",
+ "client_id": "<client id>",
+ "client_secret": "<client secret>",
+ "scopes": ["openid", "profile"],
+ "authorization_endpoint": "https://login.microsoftonline.com/<tenant id>/oauth2/v2.0/authorize",
+ "token_endpoint": "https://login.microsoftonline.com/<tenant id>/oauth2/v2.0/token",
+ "userinfo_endpoint": "https://graph.microsoft.com/oidc/userinfo",
+ }
+ ]
+
+ with self.assertRaises(ConfigError):
+ self.parse_config()
+
+ def test_cas_sso_cannot_be_enabled(self) -> None:
+ self.config_dict["cas_config"] = {
+ "enabled": True,
+ "server_url": "https://cas-server.com",
+ "displayname_attribute": "name",
+ "required_attributes": {"userGroup": "staff", "department": "None"},
+ }
+
+ with self.assertRaises(ConfigError):
+ self.parse_config()
+
+ def test_auth_providers_cannot_be_enabled(self) -> None:
+ self.config_dict["modules"] = [
+ {
+ "module": f"{__name__}.{CustomAuthModule.__qualname__}",
+ "config": {},
+ }
+ ]
+
+ # This requires actually setting up an HS, as the module will be run on setup,
+ # which should raise as the module tries to register an auth provider
+ config = self.parse_config()
+ reactor, clock = get_clock()
+ with self.assertRaises(ConfigError):
+ setup_test_homeserver(
+ self.addCleanup, reactor=reactor, clock=clock, config=config
+ )
+
+ def test_jwt_auth_cannot_be_enabled(self) -> None:
+ self.config_dict["jwt_config"] = {
+ "enabled": True,
+ "secret": "my-secret-token",
+ "algorithm": "HS256",
+ }
+
+ with self.assertRaises(ConfigError):
+ self.parse_config()
+
+ def test_login_via_existing_session_cannot_be_enabled(self) -> None:
+ self.config_dict["login_via_existing_session"] = {"enabled": True}
+ with self.assertRaises(ConfigError):
+ self.parse_config()
+
+ def test_captcha_cannot_be_enabled(self) -> None:
+ self.config_dict.update(
+ enable_registration_captcha=True,
+ recaptcha_public_key="test",
+ recaptcha_private_key="test",
+ )
+ with self.assertRaises(ConfigError):
+ self.parse_config()
+
+ def test_refreshable_tokens_cannot_be_enabled(self) -> None:
+ self.config_dict.update(
+ refresh_token_lifetime="24h",
+ refreshable_access_token_lifetime="10m",
+ nonrefreshable_access_token_lifetime="24h",
+ )
+ with self.assertRaises(ConfigError):
+ self.parse_config()
+
+ def test_session_lifetime_cannot_be_set(self) -> None:
+ self.config_dict["session_lifetime"] = "24h"
+ with self.assertRaises(ConfigError):
+ self.parse_config()
diff --git a/tests/handlers/test_oauth_delegation.py b/tests/handlers/test_oauth_delegation.py
new file mode 100644
index 0000000000..6309d7b36e
--- /dev/null
+++ b/tests/handlers/test_oauth_delegation.py
@@ -0,0 +1,664 @@
+# Copyright 2022 Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from http import HTTPStatus
+from typing import Any, Dict, Union
+from unittest.mock import ANY, Mock
+from urllib.parse import parse_qs
+
+from signedjson.key import (
+ encode_verify_key_base64,
+ generate_signing_key,
+ get_verify_key,
+)
+from signedjson.sign import sign_json
+
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.api.errors import (
+ AuthError,
+ Codes,
+ InvalidClientTokenError,
+ OAuthInsufficientScopeError,
+ SynapseError,
+)
+from synapse.rest import admin
+from synapse.rest.client import account, devices, keys, login, logout, register
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util import Clock
+
+from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock
+from tests.unittest import HomeserverTestCase, skip_unless
+from tests.utils import mock_getRawHeaders
+
+try:
+ import authlib # noqa: F401
+
+ HAS_AUTHLIB = True
+except ImportError:
+ HAS_AUTHLIB = False
+
+
+# These are a few constants that are used as config parameters in the tests.
+SERVER_NAME = "test"
+ISSUER = "https://issuer/"
+CLIENT_ID = "test-client-id"
+CLIENT_SECRET = "test-client-secret"
+BASE_URL = "https://synapse/"
+SCOPES = ["openid"]
+
+AUTHORIZATION_ENDPOINT = ISSUER + "authorize"
+TOKEN_ENDPOINT = ISSUER + "token"
+USERINFO_ENDPOINT = ISSUER + "userinfo"
+WELL_KNOWN = ISSUER + ".well-known/openid-configuration"
+JWKS_URI = ISSUER + ".well-known/jwks.json"
+INTROSPECTION_ENDPOINT = ISSUER + "introspect"
+
+SYNAPSE_ADMIN_SCOPE = "urn:synapse:admin:*"
+MATRIX_USER_SCOPE = "urn:matrix:org.matrix.msc2967.client:api:*"
+MATRIX_GUEST_SCOPE = "urn:matrix:org.matrix.msc2967.client:api:guest"
+MATRIX_DEVICE_SCOPE_PREFIX = "urn:matrix:org.matrix.msc2967.client:device:"
+DEVICE = "AABBCCDD"
+MATRIX_DEVICE_SCOPE = MATRIX_DEVICE_SCOPE_PREFIX + DEVICE
+SUBJECT = "abc-def-ghi"
+USERNAME = "test-user"
+USER_ID = "@" + USERNAME + ":" + SERVER_NAME
+
+
+async def get_json(url: str) -> JsonDict:
+ # Mock get_json calls to handle jwks & oidc discovery endpoints
+ if url == WELL_KNOWN:
+ # Minimal discovery document, as defined in OpenID.Discovery
+ # https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata
+ return {
+ "issuer": ISSUER,
+ "authorization_endpoint": AUTHORIZATION_ENDPOINT,
+ "token_endpoint": TOKEN_ENDPOINT,
+ "jwks_uri": JWKS_URI,
+ "userinfo_endpoint": USERINFO_ENDPOINT,
+ "introspection_endpoint": INTROSPECTION_ENDPOINT,
+ "response_types_supported": ["code"],
+ "subject_types_supported": ["public"],
+ "id_token_signing_alg_values_supported": ["RS256"],
+ }
+ elif url == JWKS_URI:
+ return {"keys": []}
+
+ return {}
+
+
+@skip_unless(HAS_AUTHLIB, "requires authlib")
+class MSC3861OAuthDelegation(HomeserverTestCase):
+ servlets = [
+ account.register_servlets,
+ devices.register_servlets,
+ keys.register_servlets,
+ register.register_servlets,
+ login.register_servlets,
+ logout.register_servlets,
+ admin.register_servlets,
+ ]
+
+ def default_config(self) -> Dict[str, Any]:
+ config = super().default_config()
+ config["public_baseurl"] = BASE_URL
+ config["disable_registration"] = True
+ config["experimental_features"] = {
+ "msc3861": {
+ "enabled": True,
+ "issuer": ISSUER,
+ "client_id": CLIENT_ID,
+ "client_auth_method": "client_secret_post",
+ "client_secret": CLIENT_SECRET,
+ }
+ }
+ return config
+
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+ self.http_client = Mock(spec=["get_json"])
+ self.http_client.get_json.side_effect = get_json
+ self.http_client.user_agent = b"Synapse Test"
+
+ hs = self.setup_test_homeserver(proxied_http_client=self.http_client)
+
+ self.auth = hs.get_auth()
+
+ return hs
+
+ def _assertParams(self) -> None:
+ """Assert that the request parameters are correct."""
+ params = parse_qs(self.http_client.request.call_args[1]["data"].decode("utf-8"))
+ self.assertEqual(params["token"], ["mockAccessToken"])
+ self.assertEqual(params["client_id"], [CLIENT_ID])
+ self.assertEqual(params["client_secret"], [CLIENT_SECRET])
+
+ def test_inactive_token(self) -> None:
+ """The handler should return a 403 where the token is inactive."""
+
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse.json(
+ code=200,
+ payload={"active": False},
+ )
+ )
+ request = Mock(args={})
+ request.args[b"access_token"] = [b"mockAccessToken"]
+ request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+ self.get_failure(self.auth.get_user_by_req(request), InvalidClientTokenError)
+ self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
+ self.http_client.request.assert_called_once_with(
+ method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
+ )
+ self._assertParams()
+
+ def test_active_no_scope(self) -> None:
+ """The handler should return a 403 where no scope is given."""
+
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse.json(
+ code=200,
+ payload={"active": True},
+ )
+ )
+ request = Mock(args={})
+ request.args[b"access_token"] = [b"mockAccessToken"]
+ request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+ self.get_failure(self.auth.get_user_by_req(request), InvalidClientTokenError)
+ self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
+ self.http_client.request.assert_called_once_with(
+ method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
+ )
+ self._assertParams()
+
+ def test_active_user_no_subject(self) -> None:
+ """The handler should return a 500 when no subject is present."""
+
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse.json(
+ code=200,
+ payload={"active": True, "scope": " ".join([MATRIX_USER_SCOPE])},
+ )
+ )
+ request = Mock(args={})
+ request.args[b"access_token"] = [b"mockAccessToken"]
+ request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+ self.get_failure(self.auth.get_user_by_req(request), InvalidClientTokenError)
+ self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
+ self.http_client.request.assert_called_once_with(
+ method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
+ )
+ self._assertParams()
+
+ def test_active_no_user_scope(self) -> None:
+ """The handler should return a 500 when no subject is present."""
+
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse.json(
+ code=200,
+ payload={
+ "active": True,
+ "sub": SUBJECT,
+ "scope": " ".join([MATRIX_DEVICE_SCOPE]),
+ },
+ )
+ )
+ request = Mock(args={})
+ request.args[b"access_token"] = [b"mockAccessToken"]
+ request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+ self.get_failure(self.auth.get_user_by_req(request), InvalidClientTokenError)
+ self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
+ self.http_client.request.assert_called_once_with(
+ method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
+ )
+ self._assertParams()
+
+ def test_active_admin_not_user(self) -> None:
+ """The handler should raise when the scope has admin right but not user."""
+
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse.json(
+ code=200,
+ payload={
+ "active": True,
+ "sub": SUBJECT,
+ "scope": " ".join([SYNAPSE_ADMIN_SCOPE]),
+ "username": USERNAME,
+ },
+ )
+ )
+ request = Mock(args={})
+ request.args[b"access_token"] = [b"mockAccessToken"]
+ request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+ self.get_failure(self.auth.get_user_by_req(request), InvalidClientTokenError)
+ self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
+ self.http_client.request.assert_called_once_with(
+ method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
+ )
+ self._assertParams()
+
+ def test_active_admin(self) -> None:
+ """The handler should return a requester with admin rights."""
+
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse.json(
+ code=200,
+ payload={
+ "active": True,
+ "sub": SUBJECT,
+ "scope": " ".join([SYNAPSE_ADMIN_SCOPE, MATRIX_USER_SCOPE]),
+ "username": USERNAME,
+ },
+ )
+ )
+ request = Mock(args={})
+ request.args[b"access_token"] = [b"mockAccessToken"]
+ request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+ requester = self.get_success(self.auth.get_user_by_req(request))
+ self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
+ self.http_client.request.assert_called_once_with(
+ method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
+ )
+ self._assertParams()
+ self.assertEqual(requester.user.to_string(), "@%s:%s" % (USERNAME, SERVER_NAME))
+ self.assertEqual(requester.is_guest, False)
+ self.assertEqual(requester.device_id, None)
+ self.assertEqual(
+ get_awaitable_result(self.auth.is_server_admin(requester)), True
+ )
+
+ def test_active_admin_highest_privilege(self) -> None:
+ """The handler should resolve to the most permissive scope."""
+
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse.json(
+ code=200,
+ payload={
+ "active": True,
+ "sub": SUBJECT,
+ "scope": " ".join(
+ [SYNAPSE_ADMIN_SCOPE, MATRIX_USER_SCOPE, MATRIX_GUEST_SCOPE]
+ ),
+ "username": USERNAME,
+ },
+ )
+ )
+ request = Mock(args={})
+ request.args[b"access_token"] = [b"mockAccessToken"]
+ request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+ requester = self.get_success(self.auth.get_user_by_req(request))
+ self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
+ self.http_client.request.assert_called_once_with(
+ method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
+ )
+ self._assertParams()
+ self.assertEqual(requester.user.to_string(), "@%s:%s" % (USERNAME, SERVER_NAME))
+ self.assertEqual(requester.is_guest, False)
+ self.assertEqual(requester.device_id, None)
+ self.assertEqual(
+ get_awaitable_result(self.auth.is_server_admin(requester)), True
+ )
+
+ def test_active_user(self) -> None:
+ """The handler should return a requester with normal user rights."""
+
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse.json(
+ code=200,
+ payload={
+ "active": True,
+ "sub": SUBJECT,
+ "scope": " ".join([MATRIX_USER_SCOPE]),
+ "username": USERNAME,
+ },
+ )
+ )
+ request = Mock(args={})
+ request.args[b"access_token"] = [b"mockAccessToken"]
+ request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+ requester = self.get_success(self.auth.get_user_by_req(request))
+ self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
+ self.http_client.request.assert_called_once_with(
+ method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
+ )
+ self._assertParams()
+ self.assertEqual(requester.user.to_string(), "@%s:%s" % (USERNAME, SERVER_NAME))
+ self.assertEqual(requester.is_guest, False)
+ self.assertEqual(requester.device_id, None)
+ self.assertEqual(
+ get_awaitable_result(self.auth.is_server_admin(requester)), False
+ )
+
+ def test_active_user_with_device(self) -> None:
+ """The handler should return a requester with normal user rights and a device ID."""
+
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse.json(
+ code=200,
+ payload={
+ "active": True,
+ "sub": SUBJECT,
+ "scope": " ".join([MATRIX_USER_SCOPE, MATRIX_DEVICE_SCOPE]),
+ "username": USERNAME,
+ },
+ )
+ )
+ request = Mock(args={})
+ request.args[b"access_token"] = [b"mockAccessToken"]
+ request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+ requester = self.get_success(self.auth.get_user_by_req(request))
+ self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
+ self.http_client.request.assert_called_once_with(
+ method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
+ )
+ self._assertParams()
+ self.assertEqual(requester.user.to_string(), "@%s:%s" % (USERNAME, SERVER_NAME))
+ self.assertEqual(requester.is_guest, False)
+ self.assertEqual(
+ get_awaitable_result(self.auth.is_server_admin(requester)), False
+ )
+ self.assertEqual(requester.device_id, DEVICE)
+
+ def test_multiple_devices(self) -> None:
+ """The handler should raise an error if multiple devices are found in the scope."""
+
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse.json(
+ code=200,
+ payload={
+ "active": True,
+ "sub": SUBJECT,
+ "scope": " ".join(
+ [
+ MATRIX_USER_SCOPE,
+ f"{MATRIX_DEVICE_SCOPE_PREFIX}AABBCC",
+ f"{MATRIX_DEVICE_SCOPE_PREFIX}DDEEFF",
+ ]
+ ),
+ "username": USERNAME,
+ },
+ )
+ )
+ request = Mock(args={})
+ request.args[b"access_token"] = [b"mockAccessToken"]
+ request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+ self.get_failure(self.auth.get_user_by_req(request), AuthError)
+
+ def test_active_guest_not_allowed(self) -> None:
+ """The handler should return an insufficient scope error."""
+
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse.json(
+ code=200,
+ payload={
+ "active": True,
+ "sub": SUBJECT,
+ "scope": " ".join([MATRIX_GUEST_SCOPE, MATRIX_DEVICE_SCOPE]),
+ "username": USERNAME,
+ },
+ )
+ )
+ request = Mock(args={})
+ request.args[b"access_token"] = [b"mockAccessToken"]
+ request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+ error = self.get_failure(
+ self.auth.get_user_by_req(request), OAuthInsufficientScopeError
+ )
+ self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
+ self.http_client.request.assert_called_once_with(
+ method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
+ )
+ self._assertParams()
+ self.assertEqual(
+ getattr(error.value, "headers", {})["WWW-Authenticate"],
+ 'Bearer error="insufficient_scope", scope="urn:matrix:org.matrix.msc2967.client:api:*"',
+ )
+
+ def test_active_guest_allowed(self) -> None:
+ """The handler should return a requester with guest user rights and a device ID."""
+
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse.json(
+ code=200,
+ payload={
+ "active": True,
+ "sub": SUBJECT,
+ "scope": " ".join([MATRIX_GUEST_SCOPE, MATRIX_DEVICE_SCOPE]),
+ "username": USERNAME,
+ },
+ )
+ )
+ request = Mock(args={})
+ request.args[b"access_token"] = [b"mockAccessToken"]
+ request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+ requester = self.get_success(
+ self.auth.get_user_by_req(request, allow_guest=True)
+ )
+ self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
+ self.http_client.request.assert_called_once_with(
+ method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
+ )
+ self._assertParams()
+ self.assertEqual(requester.user.to_string(), "@%s:%s" % (USERNAME, SERVER_NAME))
+ self.assertEqual(requester.is_guest, True)
+ self.assertEqual(
+ get_awaitable_result(self.auth.is_server_admin(requester)), False
+ )
+ self.assertEqual(requester.device_id, DEVICE)
+
+ def test_unavailable_introspection_endpoint(self) -> None:
+ """The handler should return an internal server error."""
+ request = Mock(args={})
+ request.args[b"access_token"] = [b"mockAccessToken"]
+ request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+
+ # The introspection endpoint is returning an error.
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse(code=500, body=b"Internal Server Error")
+ )
+ error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
+ self.assertEqual(error.value.code, 503)
+
+ # The introspection endpoint request fails.
+ self.http_client.request = simple_async_mock(raises=Exception())
+ error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
+ self.assertEqual(error.value.code, 503)
+
+ # The introspection endpoint does not return a JSON object.
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse.json(
+ code=200, payload=["this is an array", "not an object"]
+ )
+ )
+ error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
+ self.assertEqual(error.value.code, 503)
+
+ # The introspection endpoint does not return valid JSON.
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse(code=200, body=b"this is not valid JSON")
+ )
+ error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
+ self.assertEqual(error.value.code, 503)
+
+ def make_device_keys(self, user_id: str, device_id: str) -> JsonDict:
+ # We only generate a master key to simplify the test.
+ master_signing_key = generate_signing_key(device_id)
+ master_verify_key = encode_verify_key_base64(get_verify_key(master_signing_key))
+
+ return {
+ "master_key": sign_json(
+ {
+ "user_id": user_id,
+ "usage": ["master"],
+ "keys": {"ed25519:" + master_verify_key: master_verify_key},
+ },
+ user_id,
+ master_signing_key,
+ ),
+ }
+
+ def test_cross_signing(self) -> None:
+ """Try uploading device keys with OAuth delegation enabled."""
+
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse.json(
+ code=200,
+ payload={
+ "active": True,
+ "sub": SUBJECT,
+ "scope": " ".join([MATRIX_USER_SCOPE, MATRIX_DEVICE_SCOPE]),
+ "username": USERNAME,
+ },
+ )
+ )
+ keys_upload_body = self.make_device_keys(USER_ID, DEVICE)
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/v3/keys/device_signing/upload",
+ keys_upload_body,
+ access_token="mockAccessToken",
+ )
+
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/v3/keys/device_signing/upload",
+ keys_upload_body,
+ access_token="mockAccessToken",
+ )
+
+ self.assertEqual(channel.code, HTTPStatus.NOT_IMPLEMENTED, channel.json_body)
+
+ def expect_unauthorized(
+ self, method: str, path: str, content: Union[bytes, str, JsonDict] = ""
+ ) -> None:
+ channel = self.make_request(method, path, content, shorthand=False)
+
+ self.assertEqual(channel.code, 401, channel.json_body)
+
+ def expect_unrecognized(
+ self, method: str, path: str, content: Union[bytes, str, JsonDict] = ""
+ ) -> None:
+ channel = self.make_request(method, path, content)
+
+ self.assertEqual(channel.code, 404, channel.json_body)
+ self.assertEqual(
+ channel.json_body["errcode"], Codes.UNRECOGNIZED, channel.json_body
+ )
+
+ def test_uia_endpoints(self) -> None:
+ """Test that endpoints that were removed in MSC2964 are no longer available."""
+
+ # This is just an endpoint that should remain visible (but requires auth):
+ self.expect_unauthorized("GET", "/_matrix/client/v3/devices")
+
+ # This remains usable, but will require a uia scope:
+ self.expect_unauthorized(
+ "POST", "/_matrix/client/v3/keys/device_signing/upload"
+ )
+
+ def test_3pid_endpoints(self) -> None:
+ """Test that 3pid account management endpoints that were removed in MSC2964 are no longer available."""
+
+ # Remains and requires auth:
+ self.expect_unauthorized("GET", "/_matrix/client/v3/account/3pid")
+ self.expect_unauthorized(
+ "POST",
+ "/_matrix/client/v3/account/3pid/bind",
+ {
+ "client_secret": "foo",
+ "id_access_token": "bar",
+ "id_server": "foo",
+ "sid": "bar",
+ },
+ )
+ self.expect_unauthorized("POST", "/_matrix/client/v3/account/3pid/unbind", {})
+
+ # These are gone:
+ self.expect_unrecognized(
+ "POST", "/_matrix/client/v3/account/3pid"
+ ) # deprecated
+ self.expect_unrecognized("POST", "/_matrix/client/v3/account/3pid/add")
+ self.expect_unrecognized("POST", "/_matrix/client/v3/account/3pid/delete")
+ self.expect_unrecognized(
+ "POST", "/_matrix/client/v3/account/3pid/email/requestToken"
+ )
+ self.expect_unrecognized(
+ "POST", "/_matrix/client/v3/account/3pid/msisdn/requestToken"
+ )
+
+ def test_account_management_endpoints_removed(self) -> None:
+ """Test that account management endpoints that were removed in MSC2964 are no longer available."""
+ self.expect_unrecognized("POST", "/_matrix/client/v3/account/deactivate")
+ self.expect_unrecognized("POST", "/_matrix/client/v3/account/password")
+ self.expect_unrecognized(
+ "POST", "/_matrix/client/v3/account/password/email/requestToken"
+ )
+ self.expect_unrecognized(
+ "POST", "/_matrix/client/v3/account/password/msisdn/requestToken"
+ )
+
+ def test_registration_endpoints_removed(self) -> None:
+ """Test that registration endpoints that were removed in MSC2964 are no longer available."""
+ self.expect_unrecognized(
+ "GET", "/_matrix/client/v1/register/m.login.registration_token/validity"
+ )
+ # This is still available for AS registrations
+ # self.expect_unrecognized("POST", "/_matrix/client/v3/register")
+ self.expect_unrecognized("GET", "/_matrix/client/v3/register/available")
+ self.expect_unrecognized(
+ "POST", "/_matrix/client/v3/register/email/requestToken"
+ )
+ self.expect_unrecognized(
+ "POST", "/_matrix/client/v3/register/msisdn/requestToken"
+ )
+
+ def test_session_management_endpoints_removed(self) -> None:
+ """Test that session management endpoints that were removed in MSC2964 are no longer available."""
+ self.expect_unrecognized("GET", "/_matrix/client/v3/login")
+ self.expect_unrecognized("POST", "/_matrix/client/v3/login")
+ self.expect_unrecognized("GET", "/_matrix/client/v3/login/sso/redirect")
+ self.expect_unrecognized("POST", "/_matrix/client/v3/logout")
+ self.expect_unrecognized("POST", "/_matrix/client/v3/logout/all")
+ self.expect_unrecognized("POST", "/_matrix/client/v3/refresh")
+ self.expect_unrecognized("GET", "/_matrix/static/client/login")
+
+ def test_device_management_endpoints_removed(self) -> None:
+ """Test that device management endpoints that were removed in MSC2964 are no longer available."""
+ self.expect_unrecognized("POST", "/_matrix/client/v3/delete_devices")
+ self.expect_unrecognized("DELETE", "/_matrix/client/v3/devices/{DEVICE}")
+
+ def test_openid_endpoints_removed(self) -> None:
+ """Test that OpenID id_token endpoints that were removed in MSC2964 are no longer available."""
+ self.expect_unrecognized(
+ "POST", "/_matrix/client/v3/user/{USERNAME}/openid/request_token"
+ )
+
+ def test_admin_api_endpoints_removed(self) -> None:
+ """Test that admin API endpoints that were removed in MSC2964 are no longer available."""
+ self.expect_unrecognized("GET", "/_synapse/admin/v1/registration_tokens")
+ self.expect_unrecognized("POST", "/_synapse/admin/v1/registration_tokens/new")
+ self.expect_unrecognized("GET", "/_synapse/admin/v1/registration_tokens/abcd")
+ self.expect_unrecognized("PUT", "/_synapse/admin/v1/registration_tokens/abcd")
+ self.expect_unrecognized(
+ "DELETE", "/_synapse/admin/v1/registration_tokens/abcd"
+ )
+ self.expect_unrecognized("POST", "/_synapse/admin/v1/reset_password/foo")
+ self.expect_unrecognized("POST", "/_synapse/admin/v1/users/foo/login")
+ self.expect_unrecognized("GET", "/_synapse/admin/v1/register")
+ self.expect_unrecognized("POST", "/_synapse/admin/v1/register")
+ self.expect_unrecognized("GET", "/_synapse/admin/v1/users/foo/admin")
+ self.expect_unrecognized("PUT", "/_synapse/admin/v1/users/foo/admin")
+ self.expect_unrecognized("POST", "/_synapse/admin/v1/account_validity/validity")
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index 73822b07a5..8d8584609b 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -17,7 +17,7 @@ from unittest.mock import Mock
from twisted.test.proto_helpers import MemoryReactor
-from synapse.api.auth import Auth
+from synapse.api.auth.internal import InternalAuth
from synapse.api.constants import UserTypes
from synapse.api.errors import (
CodeMessageException,
@@ -683,7 +683,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
request = Mock(args={})
request.args[b"access_token"] = [token.encode("ascii")]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- auth = Auth(self.hs)
+ auth = InternalAuth(self.hs)
requester = self.get_success(auth.get_user_by_req(request))
self.assertTrue(requester.shadow_banned)
diff --git a/tests/media/test_html_preview.py b/tests/media/test_html_preview.py
index e7da75db3e..ea84bb3d3d 100644
--- a/tests/media/test_html_preview.py
+++ b/tests/media/test_html_preview.py
@@ -24,7 +24,7 @@ from tests import unittest
try:
import lxml
except ImportError:
- lxml = None
+ lxml = None # type: ignore[assignment]
class SummarizeTestCase(unittest.TestCase):
@@ -160,6 +160,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase):
"""
tree = decode_body(html, "http://example.com/test.html")
+ assert tree is not None
og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
@@ -176,6 +177,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase):
"""
tree = decode_body(html, "http://example.com/test.html")
+ assert tree is not None
og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
@@ -195,6 +197,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase):
"""
tree = decode_body(html, "http://example.com/test.html")
+ assert tree is not None
og = parse_html_to_open_graph(tree)
self.assertEqual(
@@ -217,6 +220,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase):
"""
tree = decode_body(html, "http://example.com/test.html")
+ assert tree is not None
og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
@@ -231,6 +235,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase):
"""
tree = decode_body(html, "http://example.com/test.html")
+ assert tree is not None
og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
@@ -246,6 +251,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase):
"""
tree = decode_body(html, "http://example.com/test.html")
+ assert tree is not None
og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": "Title", "og:description": "Title"})
@@ -261,6 +267,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase):
"""
tree = decode_body(html, "http://example.com/test.html")
+ assert tree is not None
og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."})
@@ -281,6 +288,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase):
"""
tree = decode_body(html, "http://example.com/test.html")
+ assert tree is not None
og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": "Title", "og:description": "Finally!"})
@@ -296,6 +304,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase):
"""
tree = decode_body(html, "http://example.com/test.html")
+ assert tree is not None
og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
@@ -324,6 +333,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase):
<head><title>Foo</title></head><body>Some text.</body></html>
""".strip()
tree = decode_body(html, "http://example.com/test.html")
+ assert tree is not None
og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
@@ -338,6 +348,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase):
</html>
"""
tree = decode_body(html, "http://example.com/test.html", "invalid-encoding")
+ assert tree is not None
og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
@@ -353,6 +364,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase):
</html>
"""
tree = decode_body(html, "http://example.com/test.html")
+ assert tree is not None
og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": "ÿÿ Foo", "og:description": "Some text."})
@@ -367,6 +379,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase):
</html>
"""
tree = decode_body(html, "http://example.com/test.html")
+ assert tree is not None
og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": "ó", "og:description": "Some text."})
@@ -380,6 +393,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase):
</html>
"""
tree = decode_body(html, "http://example.com/test.html")
+ assert tree is not None
og = parse_html_to_open_graph(tree)
self.assertEqual(
og,
@@ -401,6 +415,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase):
</html>
"""
tree = decode_body(html, "http://example.com/test.html")
+ assert tree is not None
og = parse_html_to_open_graph(tree)
self.assertEqual(
og,
@@ -419,6 +434,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase):
with a cheeky SVG</svg></u> and <strong>some</strong> tail text</b></a>
"""
tree = decode_body(html, "http://example.com/test.html")
+ assert tree is not None
og = parse_html_to_open_graph(tree)
self.assertEqual(
og,
diff --git a/tests/media/test_oembed.py b/tests/media/test_oembed.py
index c8bf8421da..3bc19cb1cc 100644
--- a/tests/media/test_oembed.py
+++ b/tests/media/test_oembed.py
@@ -28,7 +28,7 @@ from tests.unittest import HomeserverTestCase
try:
import lxml
except ImportError:
- lxml = None
+ lxml = None # type: ignore[assignment]
class OEmbedTests(HomeserverTestCase):
diff --git a/tests/media/test_url_previewer.py b/tests/media/test_url_previewer.py
index 3c4c7d6765..46ecde5344 100644
--- a/tests/media/test_url_previewer.py
+++ b/tests/media/test_url_previewer.py
@@ -24,7 +24,7 @@ from tests.unittest import override_config
try:
import lxml
except ImportError:
- lxml = None
+ lxml = None # type: ignore[assignment]
class URLPreviewTests(unittest.HomeserverTestCase):
diff --git a/tests/rest/admin/test_jwks.py b/tests/rest/admin/test_jwks.py
new file mode 100644
index 0000000000..a9a6191c73
--- /dev/null
+++ b/tests/rest/admin/test_jwks.py
@@ -0,0 +1,106 @@
+# Copyright 2023 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict
+
+from twisted.web.resource import Resource
+
+from synapse.rest.synapse.client import build_synapse_client_resource_tree
+
+from tests.unittest import HomeserverTestCase, override_config, skip_unless
+
+try:
+ import authlib # noqa: F401
+
+ HAS_AUTHLIB = True
+except ImportError:
+ HAS_AUTHLIB = False
+
+
+@skip_unless(HAS_AUTHLIB, "requires authlib")
+class JWKSTestCase(HomeserverTestCase):
+ """Test /_synapse/jwks JWKS data."""
+
+ def create_resource_dict(self) -> Dict[str, Resource]:
+ d = super().create_resource_dict()
+ d.update(build_synapse_client_resource_tree(self.hs))
+ return d
+
+ def test_empty_jwks(self) -> None:
+ """Test that the JWKS endpoint is not present by default."""
+ channel = self.make_request("GET", "/_synapse/jwks")
+ self.assertEqual(404, channel.code, channel.result)
+
+ @override_config(
+ {
+ "disable_registration": True,
+ "experimental_features": {
+ "msc3861": {
+ "enabled": True,
+ "issuer": "https://issuer/",
+ "client_id": "test-client-id",
+ "client_auth_method": "client_secret_post",
+ "client_secret": "secret",
+ },
+ },
+ }
+ )
+ def test_empty_jwks_for_msc3861_client_secret_post(self) -> None:
+ """Test that the JWKS endpoint is empty when plain auth is used."""
+ channel = self.make_request("GET", "/_synapse/jwks")
+ self.assertEqual(200, channel.code, channel.result)
+ self.assertEqual({"keys": []}, channel.json_body)
+
+ @override_config(
+ {
+ "disable_registration": True,
+ "experimental_features": {
+ "msc3861": {
+ "enabled": True,
+ "issuer": "https://issuer/",
+ "client_id": "test-client-id",
+ "client_auth_method": "private_key_jwt",
+ "jwk": {
+ "p": "-frVdP_tZ-J_nIR6HNMDq1N7aunwm51nAqNnhqIyuA8ikx7LlQED1tt2LD3YEvYyW8nxE2V95HlCRZXQPMiRJBFOsbmYkzl2t-MpavTaObB_fct_JqcRtdXddg4-_ihdjRDwUOreq_dpWh6MIKsC3UyekfkHmeEJg5YpOTL15j8",
+ "kty": "RSA",
+ "q": "oFw-Enr_YozQB1ab-kawn4jY3yHi8B1nSmYT0s8oTCflrmps5BFJfCkHL5ij3iY15z0o2m0N-jjB1oSJ98O4RayEEYNQlHnTNTl0kRIWzpoqblHUIxVcahIpP_xTovBJzwi8XXoLGqHOOMA-r40LSyVgP2Ut8D9qBwV6_UfT0LU",
+ "d": "WFkDPYo4b4LIS64D_QtQfGGuAObPvc3HFfp9VZXyq3SJR58XZRHE0jqtlEMNHhOTgbMYS3w8nxPQ_qVzY-5hs4fIanwvB64mAoOGl0qMHO65DTD_WsGFwzYClJPBVniavkLE2Hmpu8IGe6lGliN8vREC6_4t69liY-XcN_ECboVtC2behKkLOEASOIMuS7YcKAhTJFJwkl1dqDlliEn5A4u4xy7nuWQz3juB1OFdKlwGA5dfhDNglhoLIwNnkLsUPPFO-WB5ZNEW35xxHOToxj4bShvDuanVA6mJPtTKjz0XibjB36bj_nF_j7EtbE2PdGJ2KevAVgElR4lqS4ISgQ",
+ "e": "AQAB",
+ "kid": "test",
+ "qi": "cPfNk8l8W5exVNNea4d7QZZ8Qr8LgHghypYAxz8PQh1fNa8Ya1SNUDVzC2iHHhszxxA0vB9C7jGze8dBrvnzWYF1XvQcqNIVVgHhD57R1Nm3dj2NoHIKe0Cu4bCUtP8xnZQUN4KX7y4IIcgRcBWG1hT6DEYZ4BxqicnBXXNXAUI",
+ "dp": "dKlMHvslV1sMBQaKWpNb3gPq0B13TZhqr3-E2_8sPlvJ3fD8P4CmwwnOn50JDuhY3h9jY5L06sBwXjspYISVv8hX-ndMLkEeF3lrJeA5S70D8rgakfZcPIkffm3tlf1Ok3v5OzoxSv3-67Df4osMniyYwDUBCB5Oq1tTx77xpU8",
+ "dq": "S4ooU1xNYYcjl9FcuJEEMqKsRrAXzzSKq6laPTwIp5dDwt2vXeAm1a4eDHXC-6rUSZGt5PbqVqzV4s-cjnJMI8YYkIdjNg4NSE1Ac_YpeDl3M3Colb5CQlU7yUB7xY2bt0NOOFp9UJZYJrOo09mFMGjy5eorsbitoZEbVqS3SuE",
+ "n": "nJbYKqFwnURKimaviyDFrNLD3gaKR1JW343Qem25VeZxoMq1665RHVoO8n1oBm4ClZdjIiZiVdpyqzD5-Ow12YQgQEf1ZHP3CCcOQQhU57Rh5XvScTe5IxYVkEW32IW2mp_CJ6WfjYpfeL4azarVk8H3Vr59d1rSrKTVVinVdZer9YLQyC_rWAQNtHafPBMrf6RYiNGV9EiYn72wFIXlLlBYQ9Fx7bfe1PaL6qrQSsZP3_rSpuvVdLh1lqGeCLR0pyclA9uo5m2tMyCXuuGQLbA_QJm5xEc7zd-WFdux2eXF045oxnSZ_kgQt-pdN7AxGWOVvwoTf9am6mSkEdv6iw",
+ },
+ },
+ },
+ }
+ )
+ def test_key_returned_for_msc3861_client_secret_post(self) -> None:
+ """Test that the JWKS includes public part of JWK for private_key_jwt auth is used."""
+ channel = self.make_request("GET", "/_synapse/jwks")
+ self.assertEqual(200, channel.code, channel.result)
+ self.assertEqual(
+ {
+ "keys": [
+ {
+ "kty": "RSA",
+ "e": "AQAB",
+ "kid": "test",
+ "n": "nJbYKqFwnURKimaviyDFrNLD3gaKR1JW343Qem25VeZxoMq1665RHVoO8n1oBm4ClZdjIiZiVdpyqzD5-Ow12YQgQEf1ZHP3CCcOQQhU57Rh5XvScTe5IxYVkEW32IW2mp_CJ6WfjYpfeL4azarVk8H3Vr59d1rSrKTVVinVdZer9YLQyC_rWAQNtHafPBMrf6RYiNGV9EiYn72wFIXlLlBYQ9Fx7bfe1PaL6qrQSsZP3_rSpuvVdLh1lqGeCLR0pyclA9uo5m2tMyCXuuGQLbA_QJm5xEc7zd-WFdux2eXF045oxnSZ_kgQt-pdN7AxGWOVvwoTf9am6mSkEdv6iw",
+ }
+ ]
+ },
+ channel.json_body,
+ )
diff --git a/tests/rest/client/test_capabilities.py b/tests/rest/client/test_capabilities.py
index c16e8d43f4..cf23430f6a 100644
--- a/tests/rest/client/test_capabilities.py
+++ b/tests/rest/client/test_capabilities.py
@@ -186,3 +186,31 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
self.assertGreater(len(details["support"]), 0)
for room_version in details["support"]:
self.assertTrue(room_version in KNOWN_ROOM_VERSIONS, str(room_version))
+
+ def test_get_get_token_login_fields_when_disabled(self) -> None:
+ """By default login via an existing session is disabled."""
+ access_token = self.get_success(
+ self.auth_handler.create_access_token_for_user_id(
+ self.user, device_id=None, valid_until_ms=None
+ )
+ )
+
+ channel = self.make_request("GET", self.url, access_token=access_token)
+ capabilities = channel.json_body["capabilities"]
+
+ self.assertEqual(channel.code, HTTPStatus.OK)
+ self.assertFalse(capabilities["m.get_login_token"]["enabled"])
+
+ @override_config({"login_via_existing_session": {"enabled": True}})
+ def test_get_get_token_login_fields_when_enabled(self) -> None:
+ access_token = self.get_success(
+ self.auth_handler.create_access_token_for_user_id(
+ self.user, device_id=None, valid_until_ms=None
+ )
+ )
+
+ channel = self.make_request("GET", self.url, access_token=access_token)
+ capabilities = channel.json_body["capabilities"]
+
+ self.assertEqual(channel.code, HTTPStatus.OK)
+ self.assertTrue(capabilities["m.get_login_token"]["enabled"])
diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py
index dc32982e22..f3c3bc69a9 100644
--- a/tests/rest/client/test_login.py
+++ b/tests/rest/client/test_login.py
@@ -446,6 +446,29 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
ApprovalNoticeMedium.NONE, channel.json_body["approval_notice_medium"]
)
+ def test_get_login_flows_with_login_via_existing_disabled(self) -> None:
+ """GET /login should return m.login.token without get_login_token"""
+ channel = self.make_request("GET", "/_matrix/client/r0/login")
+ self.assertEqual(channel.code, 200, channel.result)
+
+ flows = {flow["type"]: flow for flow in channel.json_body["flows"]}
+ self.assertNotIn("m.login.token", flows)
+
+ @override_config({"login_via_existing_session": {"enabled": True}})
+ def test_get_login_flows_with_login_via_existing_enabled(self) -> None:
+ """GET /login should return m.login.token with get_login_token true"""
+ channel = self.make_request("GET", "/_matrix/client/r0/login")
+ self.assertEqual(channel.code, 200, channel.result)
+
+ self.assertCountEqual(
+ channel.json_body["flows"],
+ [
+ {"type": "m.login.token", "get_login_token": True},
+ {"type": "m.login.password"},
+ {"type": "m.login.application_service"},
+ ],
+ )
+
@skip_unless(has_saml2 and HAS_OIDC, "Requires SAML2 and OIDC")
class MultiSSOTestCase(unittest.HomeserverTestCase):
diff --git a/tests/rest/client/test_login_token_request.py b/tests/rest/client/test_login_token_request.py
index b8187db982..f05e619aa8 100644
--- a/tests/rest/client/test_login_token_request.py
+++ b/tests/rest/client/test_login_token_request.py
@@ -15,14 +15,14 @@
from twisted.test.proto_helpers import MemoryReactor
from synapse.rest import admin
-from synapse.rest.client import login, login_token_request
+from synapse.rest.client import login, login_token_request, versions
from synapse.server import HomeServer
from synapse.util import Clock
from tests import unittest
from tests.unittest import override_config
-endpoint = "/_matrix/client/unstable/org.matrix.msc3882/login/token"
+GET_TOKEN_ENDPOINT = "/_matrix/client/v1/login/get_token"
class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
@@ -30,6 +30,7 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
login.register_servlets,
admin.register_servlets,
login_token_request.register_servlets,
+ versions.register_servlets, # TODO: remove once unstable revision 0 support is removed
]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
@@ -46,26 +47,26 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
self.password = "password"
def test_disabled(self) -> None:
- channel = self.make_request("POST", endpoint, {}, access_token=None)
+ channel = self.make_request("POST", GET_TOKEN_ENDPOINT, {}, access_token=None)
self.assertEqual(channel.code, 404)
self.register_user(self.user, self.password)
token = self.login(self.user, self.password)
- channel = self.make_request("POST", endpoint, {}, access_token=token)
+ channel = self.make_request("POST", GET_TOKEN_ENDPOINT, {}, access_token=token)
self.assertEqual(channel.code, 404)
- @override_config({"experimental_features": {"msc3882_enabled": True}})
+ @override_config({"login_via_existing_session": {"enabled": True}})
def test_require_auth(self) -> None:
- channel = self.make_request("POST", endpoint, {}, access_token=None)
+ channel = self.make_request("POST", GET_TOKEN_ENDPOINT, {}, access_token=None)
self.assertEqual(channel.code, 401)
- @override_config({"experimental_features": {"msc3882_enabled": True}})
+ @override_config({"login_via_existing_session": {"enabled": True}})
def test_uia_on(self) -> None:
user_id = self.register_user(self.user, self.password)
token = self.login(self.user, self.password)
- channel = self.make_request("POST", endpoint, {}, access_token=token)
+ channel = self.make_request("POST", GET_TOKEN_ENDPOINT, {}, access_token=token)
self.assertEqual(channel.code, 401)
self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
@@ -80,9 +81,9 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
},
}
- channel = self.make_request("POST", endpoint, uia, access_token=token)
+ channel = self.make_request("POST", GET_TOKEN_ENDPOINT, uia, access_token=token)
self.assertEqual(channel.code, 200)
- self.assertEqual(channel.json_body["expires_in"], 300)
+ self.assertEqual(channel.json_body["expires_in_ms"], 300000)
login_token = channel.json_body["login_token"]
@@ -95,15 +96,15 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.json_body["user_id"], user_id)
@override_config(
- {"experimental_features": {"msc3882_enabled": True, "msc3882_ui_auth": False}}
+ {"login_via_existing_session": {"enabled": True, "require_ui_auth": False}}
)
def test_uia_off(self) -> None:
user_id = self.register_user(self.user, self.password)
token = self.login(self.user, self.password)
- channel = self.make_request("POST", endpoint, {}, access_token=token)
+ channel = self.make_request("POST", GET_TOKEN_ENDPOINT, {}, access_token=token)
self.assertEqual(channel.code, 200)
- self.assertEqual(channel.json_body["expires_in"], 300)
+ self.assertEqual(channel.json_body["expires_in_ms"], 300000)
login_token = channel.json_body["login_token"]
@@ -117,10 +118,10 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
@override_config(
{
- "experimental_features": {
- "msc3882_enabled": True,
- "msc3882_ui_auth": False,
- "msc3882_token_timeout": "15s",
+ "login_via_existing_session": {
+ "enabled": True,
+ "require_ui_auth": False,
+ "token_timeout": "15s",
}
}
)
@@ -128,6 +129,40 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
self.register_user(self.user, self.password)
token = self.login(self.user, self.password)
- channel = self.make_request("POST", endpoint, {}, access_token=token)
+ channel = self.make_request("POST", GET_TOKEN_ENDPOINT, {}, access_token=token)
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body["expires_in_ms"], 15000)
+
+ @override_config(
+ {
+ "login_via_existing_session": {
+ "enabled": True,
+ "require_ui_auth": False,
+ "token_timeout": "15s",
+ }
+ }
+ )
+ def test_unstable_support(self) -> None:
+ # TODO: remove support for unstable MSC3882 is no longer needed
+
+ # check feature is advertised in versions response:
+ channel = self.make_request(
+ "GET", "/_matrix/client/versions", {}, access_token=None
+ )
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(
+ channel.json_body["unstable_features"]["org.matrix.msc3882"], True
+ )
+
+ self.register_user(self.user, self.password)
+ token = self.login(self.user, self.password)
+
+ # check feature is available via the unstable endpoint and returns an expires_in value in seconds
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/unstable/org.matrix.msc3882/login/token",
+ {},
+ access_token=token,
+ )
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body["expires_in"], 15)
diff --git a/tests/rest/media/test_url_preview.py b/tests/rest/media/test_url_preview.py
index 170fb0534a..05d5e39cab 100644
--- a/tests/rest/media/test_url_preview.py
+++ b/tests/rest/media/test_url_preview.py
@@ -40,7 +40,7 @@ from tests.test_utils import SMALL_PNG
try:
import lxml
except ImportError:
- lxml = None
+ lxml = None # type: ignore[assignment]
class URLPreviewTests(unittest.HomeserverTestCase):
diff --git a/tests/rest/test_well_known.py b/tests/rest/test_well_known.py
index 2091b08d89..377243a170 100644
--- a/tests/rest/test_well_known.py
+++ b/tests/rest/test_well_known.py
@@ -17,6 +17,13 @@ from synapse.rest.well_known import well_known_resource
from tests import unittest
+try:
+ import authlib # noqa: F401
+
+ HAS_AUTHLIB = True
+except ImportError:
+ HAS_AUTHLIB = False
+
class WellKnownTests(unittest.HomeserverTestCase):
def create_test_resource(self) -> Resource:
@@ -96,3 +103,37 @@ class WellKnownTests(unittest.HomeserverTestCase):
"GET", "/.well-known/matrix/server", shorthand=False
)
self.assertEqual(channel.code, 404)
+
+ @unittest.skip_unless(HAS_AUTHLIB, "requires authlib")
+ @unittest.override_config(
+ {
+ "public_baseurl": "https://homeserver", # this is only required so that client well known is served
+ "experimental_features": {
+ "msc3861": {
+ "enabled": True,
+ "issuer": "https://issuer",
+ "account_management_url": "https://my-account.issuer",
+ "client_id": "id",
+ "client_auth_method": "client_secret_post",
+ "client_secret": "secret",
+ },
+ },
+ "disable_registration": True,
+ }
+ )
+ def test_client_well_known_msc3861_oauth_delegation(self) -> None:
+ channel = self.make_request(
+ "GET", "/.well-known/matrix/client", shorthand=False
+ )
+
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(
+ channel.json_body,
+ {
+ "m.homeserver": {"base_url": "https://homeserver/"},
+ "org.matrix.msc2965.authentication": {
+ "issuer": "https://issuer",
+ "account": "https://my-account.issuer",
+ },
+ },
+ )
diff --git a/tests/server.py b/tests/server.py
index 7296f0a552..a12c3e3b9a 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -642,7 +642,7 @@ def _make_test_homeserver_synchronous(server: HomeServer) -> None:
pool.runWithConnection = runWithConnection # type: ignore[assignment]
pool.runInteraction = runInteraction # type: ignore[assignment]
# Replace the thread pool with a threadless 'thread' pool
- pool.threadpool = ThreadPool(clock._reactor) # type: ignore[assignment]
+ pool.threadpool = ThreadPool(clock._reactor)
pool.running = True
# We've just changed the Databases to run DB transactions on the same
diff --git a/tests/test_state.py b/tests/test_state.py
index ddf59916b1..7a49b87953 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -28,7 +28,7 @@ from unittest.mock import Mock
from twisted.internet import defer
-from synapse.api.auth import Auth
+from synapse.api.auth.internal import InternalAuth
from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase, make_event_from_dict
@@ -240,7 +240,7 @@ class StateTestCase(unittest.TestCase):
hs.get_macaroon_generator.return_value = MacaroonGenerator(
clock, "tesths", b"verysecret"
)
- hs.get_auth.return_value = Auth(hs)
+ hs.get_auth.return_value = InternalAuth(hs)
hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs)
hs.get_storage_controllers.return_value = storage_controllers
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index e5dae670a7..c8cc841d95 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -33,7 +33,7 @@ from twisted.web.http import RESPONSES
from twisted.web.http_headers import Headers
from twisted.web.iweb import IResponse
-from synapse.types import JsonDict
+from synapse.types import JsonSerializable
if TYPE_CHECKING:
from sys import UnraisableHookArgs
@@ -145,7 +145,7 @@ class FakeResponse: # type: ignore[misc]
protocol.connectionLost(Failure(ResponseDone()))
@classmethod
- def json(cls, *, code: int = 200, payload: JsonDict) -> "FakeResponse":
+ def json(cls, *, code: int = 200, payload: JsonSerializable) -> "FakeResponse":
headers = Headers({"Content-Type": ["application/json"]})
body = json.dumps(payload).encode("utf-8")
return cls(code=code, body=body, headers=headers)
|