diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index 6e36e73f0d..cdb0048122 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -18,7 +18,7 @@ import pymacaroons
from twisted.test.proto_helpers import MemoryReactor
-from synapse.api.auth import Auth
+from synapse.api.auth.internal import InternalAuth
from synapse.api.auth_blocking import AuthBlocking
from synapse.api.constants import UserTypes
from synapse.api.errors import (
@@ -48,7 +48,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
# have been called by the HomeserverTestCase machinery.
hs.datastores.main = self.store # type: ignore[union-attr]
hs.get_auth_handler().store = self.store
- self.auth = Auth(hs)
+ self.auth = InternalAuth(hs)
# AuthBlocking reads from the hs' config on initialization. We need to
# modify its config instead of the hs'
@@ -426,6 +426,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
access_token_id=None,
device_id="FOOBAR",
is_guest=False,
+ scope=set(),
shadow_banned=False,
app_service=appservice,
authenticated_entity="@appservice:server",
@@ -456,6 +457,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
access_token_id=None,
device_id="FOOBAR",
is_guest=False,
+ scope=set(),
shadow_banned=False,
app_service=appservice,
authenticated_entity="@appservice:server",
diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index aa6af5ad7b..868f0c6995 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -35,7 +35,6 @@ from tests.events.test_utils import MockEvent
user_id = UserID.from_string("@test_user:test")
user2_id = UserID.from_string("@test_user2:test")
-user_localpart = "test_user"
class FilteringTestCase(unittest.HomeserverTestCase):
@@ -449,9 +448,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
]
user_filter = self.get_success(
- self.filtering.get_user_filter(
- user_localpart=user_localpart, filter_id=filter_id
- )
+ self.filtering.get_user_filter(user_id=user_id, filter_id=filter_id)
)
results = self.get_success(user_filter.filter_presence(presence_states))
@@ -479,9 +476,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
]
user_filter = self.get_success(
- self.filtering.get_user_filter(
- user_localpart=user_localpart + "2", filter_id=filter_id
- )
+ self.filtering.get_user_filter(user_id=user2_id, filter_id=filter_id)
)
results = self.get_success(user_filter.filter_presence(presence_states))
@@ -498,9 +493,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
events = [event]
user_filter = self.get_success(
- self.filtering.get_user_filter(
- user_localpart=user_localpart, filter_id=filter_id
- )
+ self.filtering.get_user_filter(user_id=user_id, filter_id=filter_id)
)
results = self.get_success(user_filter.filter_room_state(events=events))
@@ -519,9 +512,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
events = [event]
user_filter = self.get_success(
- self.filtering.get_user_filter(
- user_localpart=user_localpart, filter_id=filter_id
- )
+ self.filtering.get_user_filter(user_id=user_id, filter_id=filter_id)
)
results = self.get_success(user_filter.filter_room_state(events))
@@ -603,9 +594,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
user_filter_json,
(
self.get_success(
- self.datastore.get_user_filter(
- user_localpart=user_localpart, filter_id=0
- )
+ self.datastore.get_user_filter(user_id=user_id, filter_id=0)
)
),
)
@@ -620,9 +609,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
)
filter = self.get_success(
- self.filtering.get_user_filter(
- user_localpart=user_localpart, filter_id=filter_id
- )
+ self.filtering.get_user_filter(user_id=user_id, filter_id=filter_id)
)
self.assertEqual(filter.get_filter_json(), user_filter_json)
diff --git a/tests/config/test_oauth_delegation.py b/tests/config/test_oauth_delegation.py
new file mode 100644
index 0000000000..f57c813a58
--- /dev/null
+++ b/tests/config/test_oauth_delegation.py
@@ -0,0 +1,257 @@
+# Copyright 2023 Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from unittest.mock import Mock
+
+from synapse.config import ConfigError
+from synapse.config.homeserver import HomeServerConfig
+from synapse.module_api import ModuleApi
+from synapse.types import JsonDict
+
+from tests.server import get_clock, setup_test_homeserver
+from tests.unittest import TestCase, skip_unless
+from tests.utils import default_config
+
+try:
+ import authlib # noqa: F401
+
+ HAS_AUTHLIB = True
+except ImportError:
+ HAS_AUTHLIB = False
+
+
+# These are a few constants that are used as config parameters in the tests.
+SERVER_NAME = "test"
+ISSUER = "https://issuer/"
+CLIENT_ID = "test-client-id"
+CLIENT_SECRET = "test-client-secret"
+BASE_URL = "https://synapse/"
+
+
+class CustomAuthModule:
+ """A module which registers a password auth provider."""
+
+ @staticmethod
+ def parse_config(config: JsonDict) -> None:
+ pass
+
+ def __init__(self, config: None, api: ModuleApi):
+ api.register_password_auth_provider_callbacks(
+ auth_checkers={("m.login.password", ("password",)): Mock()},
+ )
+
+
+@skip_unless(HAS_AUTHLIB, "requires authlib")
+class MSC3861OAuthDelegation(TestCase):
+ """Test that the Homeserver fails to initialize if the config is invalid."""
+
+ def setUp(self) -> None:
+ self.config_dict: JsonDict = {
+ **default_config("test"),
+ "public_baseurl": BASE_URL,
+ "enable_registration": False,
+ "experimental_features": {
+ "msc3861": {
+ "enabled": True,
+ "issuer": ISSUER,
+ "client_id": CLIENT_ID,
+ "client_auth_method": "client_secret_post",
+ "client_secret": CLIENT_SECRET,
+ }
+ },
+ }
+
+ def parse_config(self) -> HomeServerConfig:
+ config = HomeServerConfig()
+ config.parse_config_dict(self.config_dict, "", "")
+ return config
+
+ def test_client_secret_post_works(self) -> None:
+ self.config_dict["experimental_features"]["msc3861"].update(
+ client_auth_method="client_secret_post",
+ client_secret=CLIENT_SECRET,
+ )
+
+ self.parse_config()
+
+ def test_client_secret_post_requires_client_secret(self) -> None:
+ self.config_dict["experimental_features"]["msc3861"].update(
+ client_auth_method="client_secret_post",
+ client_secret=None,
+ )
+
+ with self.assertRaises(ConfigError):
+ self.parse_config()
+
+ def test_client_secret_basic_works(self) -> None:
+ self.config_dict["experimental_features"]["msc3861"].update(
+ client_auth_method="client_secret_basic",
+ client_secret=CLIENT_SECRET,
+ )
+
+ self.parse_config()
+
+ def test_client_secret_basic_requires_client_secret(self) -> None:
+ self.config_dict["experimental_features"]["msc3861"].update(
+ client_auth_method="client_secret_basic",
+ client_secret=None,
+ )
+
+ with self.assertRaises(ConfigError):
+ self.parse_config()
+
+ def test_client_secret_jwt_works(self) -> None:
+ self.config_dict["experimental_features"]["msc3861"].update(
+ client_auth_method="client_secret_jwt",
+ client_secret=CLIENT_SECRET,
+ )
+
+ self.parse_config()
+
+ def test_client_secret_jwt_requires_client_secret(self) -> None:
+ self.config_dict["experimental_features"]["msc3861"].update(
+ client_auth_method="client_secret_jwt",
+ client_secret=None,
+ )
+
+ with self.assertRaises(ConfigError):
+ self.parse_config()
+
+ def test_invalid_client_auth_method(self) -> None:
+ self.config_dict["experimental_features"]["msc3861"].update(
+ client_auth_method="invalid",
+ )
+
+ with self.assertRaises(ConfigError):
+ self.parse_config()
+
+ def test_private_key_jwt_requires_jwk(self) -> None:
+ self.config_dict["experimental_features"]["msc3861"].update(
+ client_auth_method="private_key_jwt",
+ )
+
+ with self.assertRaises(ConfigError):
+ self.parse_config()
+
+ def test_private_key_jwt_works(self) -> None:
+ self.config_dict["experimental_features"]["msc3861"].update(
+ client_auth_method="private_key_jwt",
+ jwk={
+ "p": "-frVdP_tZ-J_nIR6HNMDq1N7aunwm51nAqNnhqIyuA8ikx7LlQED1tt2LD3YEvYyW8nxE2V95HlCRZXQPMiRJBFOsbmYkzl2t-MpavTaObB_fct_JqcRtdXddg4-_ihdjRDwUOreq_dpWh6MIKsC3UyekfkHmeEJg5YpOTL15j8",
+ "kty": "RSA",
+ "q": "oFw-Enr_YozQB1ab-kawn4jY3yHi8B1nSmYT0s8oTCflrmps5BFJfCkHL5ij3iY15z0o2m0N-jjB1oSJ98O4RayEEYNQlHnTNTl0kRIWzpoqblHUIxVcahIpP_xTovBJzwi8XXoLGqHOOMA-r40LSyVgP2Ut8D9qBwV6_UfT0LU",
+ "d": "WFkDPYo4b4LIS64D_QtQfGGuAObPvc3HFfp9VZXyq3SJR58XZRHE0jqtlEMNHhOTgbMYS3w8nxPQ_qVzY-5hs4fIanwvB64mAoOGl0qMHO65DTD_WsGFwzYClJPBVniavkLE2Hmpu8IGe6lGliN8vREC6_4t69liY-XcN_ECboVtC2behKkLOEASOIMuS7YcKAhTJFJwkl1dqDlliEn5A4u4xy7nuWQz3juB1OFdKlwGA5dfhDNglhoLIwNnkLsUPPFO-WB5ZNEW35xxHOToxj4bShvDuanVA6mJPtTKjz0XibjB36bj_nF_j7EtbE2PdGJ2KevAVgElR4lqS4ISgQ",
+ "e": "AQAB",
+ "kid": "test",
+ "qi": "cPfNk8l8W5exVNNea4d7QZZ8Qr8LgHghypYAxz8PQh1fNa8Ya1SNUDVzC2iHHhszxxA0vB9C7jGze8dBrvnzWYF1XvQcqNIVVgHhD57R1Nm3dj2NoHIKe0Cu4bCUtP8xnZQUN4KX7y4IIcgRcBWG1hT6DEYZ4BxqicnBXXNXAUI",
+ "dp": "dKlMHvslV1sMBQaKWpNb3gPq0B13TZhqr3-E2_8sPlvJ3fD8P4CmwwnOn50JDuhY3h9jY5L06sBwXjspYISVv8hX-ndMLkEeF3lrJeA5S70D8rgakfZcPIkffm3tlf1Ok3v5OzoxSv3-67Df4osMniyYwDUBCB5Oq1tTx77xpU8",
+ "dq": "S4ooU1xNYYcjl9FcuJEEMqKsRrAXzzSKq6laPTwIp5dDwt2vXeAm1a4eDHXC-6rUSZGt5PbqVqzV4s-cjnJMI8YYkIdjNg4NSE1Ac_YpeDl3M3Colb5CQlU7yUB7xY2bt0NOOFp9UJZYJrOo09mFMGjy5eorsbitoZEbVqS3SuE",
+ "n": "nJbYKqFwnURKimaviyDFrNLD3gaKR1JW343Qem25VeZxoMq1665RHVoO8n1oBm4ClZdjIiZiVdpyqzD5-Ow12YQgQEf1ZHP3CCcOQQhU57Rh5XvScTe5IxYVkEW32IW2mp_CJ6WfjYpfeL4azarVk8H3Vr59d1rSrKTVVinVdZer9YLQyC_rWAQNtHafPBMrf6RYiNGV9EiYn72wFIXlLlBYQ9Fx7bfe1PaL6qrQSsZP3_rSpuvVdLh1lqGeCLR0pyclA9uo5m2tMyCXuuGQLbA_QJm5xEc7zd-WFdux2eXF045oxnSZ_kgQt-pdN7AxGWOVvwoTf9am6mSkEdv6iw",
+ },
+ )
+ self.parse_config()
+
+ def test_registration_cannot_be_enabled(self) -> None:
+ self.config_dict["enable_registration"] = True
+ with self.assertRaises(ConfigError):
+ self.parse_config()
+
+ def test_password_config_cannot_be_enabled(self) -> None:
+ self.config_dict["password_config"] = {"enabled": True}
+ with self.assertRaises(ConfigError):
+ self.parse_config()
+
+ def test_oidc_sso_cannot_be_enabled(self) -> None:
+ self.config_dict["oidc_providers"] = [
+ {
+ "idp_id": "microsoft",
+ "idp_name": "Microsoft",
+ "issuer": "https://login.microsoftonline.com/<tenant id>/v2.0",
+ "client_id": "<client id>",
+ "client_secret": "<client secret>",
+ "scopes": ["openid", "profile"],
+ "authorization_endpoint": "https://login.microsoftonline.com/<tenant id>/oauth2/v2.0/authorize",
+ "token_endpoint": "https://login.microsoftonline.com/<tenant id>/oauth2/v2.0/token",
+ "userinfo_endpoint": "https://graph.microsoft.com/oidc/userinfo",
+ }
+ ]
+
+ with self.assertRaises(ConfigError):
+ self.parse_config()
+
+ def test_cas_sso_cannot_be_enabled(self) -> None:
+ self.config_dict["cas_config"] = {
+ "enabled": True,
+ "server_url": "https://cas-server.com",
+ "displayname_attribute": "name",
+ "required_attributes": {"userGroup": "staff", "department": "None"},
+ }
+
+ with self.assertRaises(ConfigError):
+ self.parse_config()
+
+ def test_auth_providers_cannot_be_enabled(self) -> None:
+ self.config_dict["modules"] = [
+ {
+ "module": f"{__name__}.{CustomAuthModule.__qualname__}",
+ "config": {},
+ }
+ ]
+
+ # This requires actually setting up an HS, as the module will be run on setup,
+ # which should raise as the module tries to register an auth provider
+ config = self.parse_config()
+ reactor, clock = get_clock()
+ with self.assertRaises(ConfigError):
+ setup_test_homeserver(
+ self.addCleanup, reactor=reactor, clock=clock, config=config
+ )
+
+ def test_jwt_auth_cannot_be_enabled(self) -> None:
+ self.config_dict["jwt_config"] = {
+ "enabled": True,
+ "secret": "my-secret-token",
+ "algorithm": "HS256",
+ }
+
+ with self.assertRaises(ConfigError):
+ self.parse_config()
+
+ def test_login_via_existing_session_cannot_be_enabled(self) -> None:
+ self.config_dict["login_via_existing_session"] = {"enabled": True}
+ with self.assertRaises(ConfigError):
+ self.parse_config()
+
+ def test_captcha_cannot_be_enabled(self) -> None:
+ self.config_dict.update(
+ enable_registration_captcha=True,
+ recaptcha_public_key="test",
+ recaptcha_private_key="test",
+ )
+ with self.assertRaises(ConfigError):
+ self.parse_config()
+
+ def test_refreshable_tokens_cannot_be_enabled(self) -> None:
+ self.config_dict.update(
+ refresh_token_lifetime="24h",
+ refreshable_access_token_lifetime="10m",
+ nonrefreshable_access_token_lifetime="24h",
+ )
+ with self.assertRaises(ConfigError):
+ self.parse_config()
+
+ def test_session_lifetime_cannot_be_set(self) -> None:
+ self.config_dict["session_lifetime"] = "24h"
+ with self.assertRaises(ConfigError):
+ self.parse_config()
diff --git a/tests/handlers/test_oauth_delegation.py b/tests/handlers/test_oauth_delegation.py
new file mode 100644
index 0000000000..6309d7b36e
--- /dev/null
+++ b/tests/handlers/test_oauth_delegation.py
@@ -0,0 +1,664 @@
+# Copyright 2022 Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from http import HTTPStatus
+from typing import Any, Dict, Union
+from unittest.mock import ANY, Mock
+from urllib.parse import parse_qs
+
+from signedjson.key import (
+ encode_verify_key_base64,
+ generate_signing_key,
+ get_verify_key,
+)
+from signedjson.sign import sign_json
+
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.api.errors import (
+ AuthError,
+ Codes,
+ InvalidClientTokenError,
+ OAuthInsufficientScopeError,
+ SynapseError,
+)
+from synapse.rest import admin
+from synapse.rest.client import account, devices, keys, login, logout, register
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util import Clock
+
+from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock
+from tests.unittest import HomeserverTestCase, skip_unless
+from tests.utils import mock_getRawHeaders
+
+try:
+ import authlib # noqa: F401
+
+ HAS_AUTHLIB = True
+except ImportError:
+ HAS_AUTHLIB = False
+
+
+# These are a few constants that are used as config parameters in the tests.
+SERVER_NAME = "test"
+ISSUER = "https://issuer/"
+CLIENT_ID = "test-client-id"
+CLIENT_SECRET = "test-client-secret"
+BASE_URL = "https://synapse/"
+SCOPES = ["openid"]
+
+AUTHORIZATION_ENDPOINT = ISSUER + "authorize"
+TOKEN_ENDPOINT = ISSUER + "token"
+USERINFO_ENDPOINT = ISSUER + "userinfo"
+WELL_KNOWN = ISSUER + ".well-known/openid-configuration"
+JWKS_URI = ISSUER + ".well-known/jwks.json"
+INTROSPECTION_ENDPOINT = ISSUER + "introspect"
+
+SYNAPSE_ADMIN_SCOPE = "urn:synapse:admin:*"
+MATRIX_USER_SCOPE = "urn:matrix:org.matrix.msc2967.client:api:*"
+MATRIX_GUEST_SCOPE = "urn:matrix:org.matrix.msc2967.client:api:guest"
+MATRIX_DEVICE_SCOPE_PREFIX = "urn:matrix:org.matrix.msc2967.client:device:"
+DEVICE = "AABBCCDD"
+MATRIX_DEVICE_SCOPE = MATRIX_DEVICE_SCOPE_PREFIX + DEVICE
+SUBJECT = "abc-def-ghi"
+USERNAME = "test-user"
+USER_ID = "@" + USERNAME + ":" + SERVER_NAME
+
+
+async def get_json(url: str) -> JsonDict:
+ # Mock get_json calls to handle jwks & oidc discovery endpoints
+ if url == WELL_KNOWN:
+ # Minimal discovery document, as defined in OpenID.Discovery
+ # https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata
+ return {
+ "issuer": ISSUER,
+ "authorization_endpoint": AUTHORIZATION_ENDPOINT,
+ "token_endpoint": TOKEN_ENDPOINT,
+ "jwks_uri": JWKS_URI,
+ "userinfo_endpoint": USERINFO_ENDPOINT,
+ "introspection_endpoint": INTROSPECTION_ENDPOINT,
+ "response_types_supported": ["code"],
+ "subject_types_supported": ["public"],
+ "id_token_signing_alg_values_supported": ["RS256"],
+ }
+ elif url == JWKS_URI:
+ return {"keys": []}
+
+ return {}
+
+
+@skip_unless(HAS_AUTHLIB, "requires authlib")
+class MSC3861OAuthDelegation(HomeserverTestCase):
+ servlets = [
+ account.register_servlets,
+ devices.register_servlets,
+ keys.register_servlets,
+ register.register_servlets,
+ login.register_servlets,
+ logout.register_servlets,
+ admin.register_servlets,
+ ]
+
+ def default_config(self) -> Dict[str, Any]:
+ config = super().default_config()
+ config["public_baseurl"] = BASE_URL
+ config["disable_registration"] = True
+ config["experimental_features"] = {
+ "msc3861": {
+ "enabled": True,
+ "issuer": ISSUER,
+ "client_id": CLIENT_ID,
+ "client_auth_method": "client_secret_post",
+ "client_secret": CLIENT_SECRET,
+ }
+ }
+ return config
+
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+ self.http_client = Mock(spec=["get_json"])
+ self.http_client.get_json.side_effect = get_json
+ self.http_client.user_agent = b"Synapse Test"
+
+ hs = self.setup_test_homeserver(proxied_http_client=self.http_client)
+
+ self.auth = hs.get_auth()
+
+ return hs
+
+ def _assertParams(self) -> None:
+ """Assert that the request parameters are correct."""
+ params = parse_qs(self.http_client.request.call_args[1]["data"].decode("utf-8"))
+ self.assertEqual(params["token"], ["mockAccessToken"])
+ self.assertEqual(params["client_id"], [CLIENT_ID])
+ self.assertEqual(params["client_secret"], [CLIENT_SECRET])
+
+ def test_inactive_token(self) -> None:
+ """The handler should return a 403 where the token is inactive."""
+
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse.json(
+ code=200,
+ payload={"active": False},
+ )
+ )
+ request = Mock(args={})
+ request.args[b"access_token"] = [b"mockAccessToken"]
+ request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+ self.get_failure(self.auth.get_user_by_req(request), InvalidClientTokenError)
+ self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
+ self.http_client.request.assert_called_once_with(
+ method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
+ )
+ self._assertParams()
+
+ def test_active_no_scope(self) -> None:
+ """The handler should return a 403 where no scope is given."""
+
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse.json(
+ code=200,
+ payload={"active": True},
+ )
+ )
+ request = Mock(args={})
+ request.args[b"access_token"] = [b"mockAccessToken"]
+ request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+ self.get_failure(self.auth.get_user_by_req(request), InvalidClientTokenError)
+ self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
+ self.http_client.request.assert_called_once_with(
+ method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
+ )
+ self._assertParams()
+
+ def test_active_user_no_subject(self) -> None:
+ """The handler should return a 500 when no subject is present."""
+
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse.json(
+ code=200,
+ payload={"active": True, "scope": " ".join([MATRIX_USER_SCOPE])},
+ )
+ )
+ request = Mock(args={})
+ request.args[b"access_token"] = [b"mockAccessToken"]
+ request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+ self.get_failure(self.auth.get_user_by_req(request), InvalidClientTokenError)
+ self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
+ self.http_client.request.assert_called_once_with(
+ method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
+ )
+ self._assertParams()
+
+ def test_active_no_user_scope(self) -> None:
+ """The handler should return a 500 when no subject is present."""
+
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse.json(
+ code=200,
+ payload={
+ "active": True,
+ "sub": SUBJECT,
+ "scope": " ".join([MATRIX_DEVICE_SCOPE]),
+ },
+ )
+ )
+ request = Mock(args={})
+ request.args[b"access_token"] = [b"mockAccessToken"]
+ request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+ self.get_failure(self.auth.get_user_by_req(request), InvalidClientTokenError)
+ self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
+ self.http_client.request.assert_called_once_with(
+ method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
+ )
+ self._assertParams()
+
+ def test_active_admin_not_user(self) -> None:
+ """The handler should raise when the scope has admin right but not user."""
+
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse.json(
+ code=200,
+ payload={
+ "active": True,
+ "sub": SUBJECT,
+ "scope": " ".join([SYNAPSE_ADMIN_SCOPE]),
+ "username": USERNAME,
+ },
+ )
+ )
+ request = Mock(args={})
+ request.args[b"access_token"] = [b"mockAccessToken"]
+ request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+ self.get_failure(self.auth.get_user_by_req(request), InvalidClientTokenError)
+ self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
+ self.http_client.request.assert_called_once_with(
+ method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
+ )
+ self._assertParams()
+
+ def test_active_admin(self) -> None:
+ """The handler should return a requester with admin rights."""
+
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse.json(
+ code=200,
+ payload={
+ "active": True,
+ "sub": SUBJECT,
+ "scope": " ".join([SYNAPSE_ADMIN_SCOPE, MATRIX_USER_SCOPE]),
+ "username": USERNAME,
+ },
+ )
+ )
+ request = Mock(args={})
+ request.args[b"access_token"] = [b"mockAccessToken"]
+ request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+ requester = self.get_success(self.auth.get_user_by_req(request))
+ self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
+ self.http_client.request.assert_called_once_with(
+ method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
+ )
+ self._assertParams()
+ self.assertEqual(requester.user.to_string(), "@%s:%s" % (USERNAME, SERVER_NAME))
+ self.assertEqual(requester.is_guest, False)
+ self.assertEqual(requester.device_id, None)
+ self.assertEqual(
+ get_awaitable_result(self.auth.is_server_admin(requester)), True
+ )
+
+ def test_active_admin_highest_privilege(self) -> None:
+ """The handler should resolve to the most permissive scope."""
+
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse.json(
+ code=200,
+ payload={
+ "active": True,
+ "sub": SUBJECT,
+ "scope": " ".join(
+ [SYNAPSE_ADMIN_SCOPE, MATRIX_USER_SCOPE, MATRIX_GUEST_SCOPE]
+ ),
+ "username": USERNAME,
+ },
+ )
+ )
+ request = Mock(args={})
+ request.args[b"access_token"] = [b"mockAccessToken"]
+ request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+ requester = self.get_success(self.auth.get_user_by_req(request))
+ self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
+ self.http_client.request.assert_called_once_with(
+ method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
+ )
+ self._assertParams()
+ self.assertEqual(requester.user.to_string(), "@%s:%s" % (USERNAME, SERVER_NAME))
+ self.assertEqual(requester.is_guest, False)
+ self.assertEqual(requester.device_id, None)
+ self.assertEqual(
+ get_awaitable_result(self.auth.is_server_admin(requester)), True
+ )
+
+ def test_active_user(self) -> None:
+ """The handler should return a requester with normal user rights."""
+
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse.json(
+ code=200,
+ payload={
+ "active": True,
+ "sub": SUBJECT,
+ "scope": " ".join([MATRIX_USER_SCOPE]),
+ "username": USERNAME,
+ },
+ )
+ )
+ request = Mock(args={})
+ request.args[b"access_token"] = [b"mockAccessToken"]
+ request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+ requester = self.get_success(self.auth.get_user_by_req(request))
+ self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
+ self.http_client.request.assert_called_once_with(
+ method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
+ )
+ self._assertParams()
+ self.assertEqual(requester.user.to_string(), "@%s:%s" % (USERNAME, SERVER_NAME))
+ self.assertEqual(requester.is_guest, False)
+ self.assertEqual(requester.device_id, None)
+ self.assertEqual(
+ get_awaitable_result(self.auth.is_server_admin(requester)), False
+ )
+
+ def test_active_user_with_device(self) -> None:
+ """The handler should return a requester with normal user rights and a device ID."""
+
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse.json(
+ code=200,
+ payload={
+ "active": True,
+ "sub": SUBJECT,
+ "scope": " ".join([MATRIX_USER_SCOPE, MATRIX_DEVICE_SCOPE]),
+ "username": USERNAME,
+ },
+ )
+ )
+ request = Mock(args={})
+ request.args[b"access_token"] = [b"mockAccessToken"]
+ request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+ requester = self.get_success(self.auth.get_user_by_req(request))
+ self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
+ self.http_client.request.assert_called_once_with(
+ method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
+ )
+ self._assertParams()
+ self.assertEqual(requester.user.to_string(), "@%s:%s" % (USERNAME, SERVER_NAME))
+ self.assertEqual(requester.is_guest, False)
+ self.assertEqual(
+ get_awaitable_result(self.auth.is_server_admin(requester)), False
+ )
+ self.assertEqual(requester.device_id, DEVICE)
+
+ def test_multiple_devices(self) -> None:
+ """The handler should raise an error if multiple devices are found in the scope."""
+
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse.json(
+ code=200,
+ payload={
+ "active": True,
+ "sub": SUBJECT,
+ "scope": " ".join(
+ [
+ MATRIX_USER_SCOPE,
+ f"{MATRIX_DEVICE_SCOPE_PREFIX}AABBCC",
+ f"{MATRIX_DEVICE_SCOPE_PREFIX}DDEEFF",
+ ]
+ ),
+ "username": USERNAME,
+ },
+ )
+ )
+ request = Mock(args={})
+ request.args[b"access_token"] = [b"mockAccessToken"]
+ request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+ self.get_failure(self.auth.get_user_by_req(request), AuthError)
+
+ def test_active_guest_not_allowed(self) -> None:
+ """The handler should return an insufficient scope error."""
+
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse.json(
+ code=200,
+ payload={
+ "active": True,
+ "sub": SUBJECT,
+ "scope": " ".join([MATRIX_GUEST_SCOPE, MATRIX_DEVICE_SCOPE]),
+ "username": USERNAME,
+ },
+ )
+ )
+ request = Mock(args={})
+ request.args[b"access_token"] = [b"mockAccessToken"]
+ request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+ error = self.get_failure(
+ self.auth.get_user_by_req(request), OAuthInsufficientScopeError
+ )
+ self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
+ self.http_client.request.assert_called_once_with(
+ method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
+ )
+ self._assertParams()
+ self.assertEqual(
+ getattr(error.value, "headers", {})["WWW-Authenticate"],
+ 'Bearer error="insufficient_scope", scope="urn:matrix:org.matrix.msc2967.client:api:*"',
+ )
+
+ def test_active_guest_allowed(self) -> None:
+ """The handler should return a requester with guest user rights and a device ID."""
+
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse.json(
+ code=200,
+ payload={
+ "active": True,
+ "sub": SUBJECT,
+ "scope": " ".join([MATRIX_GUEST_SCOPE, MATRIX_DEVICE_SCOPE]),
+ "username": USERNAME,
+ },
+ )
+ )
+ request = Mock(args={})
+ request.args[b"access_token"] = [b"mockAccessToken"]
+ request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+ requester = self.get_success(
+ self.auth.get_user_by_req(request, allow_guest=True)
+ )
+ self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
+ self.http_client.request.assert_called_once_with(
+ method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
+ )
+ self._assertParams()
+ self.assertEqual(requester.user.to_string(), "@%s:%s" % (USERNAME, SERVER_NAME))
+ self.assertEqual(requester.is_guest, True)
+ self.assertEqual(
+ get_awaitable_result(self.auth.is_server_admin(requester)), False
+ )
+ self.assertEqual(requester.device_id, DEVICE)
+
+ def test_unavailable_introspection_endpoint(self) -> None:
+ """The handler should return an internal server error."""
+ request = Mock(args={})
+ request.args[b"access_token"] = [b"mockAccessToken"]
+ request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+
+ # The introspection endpoint is returning an error.
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse(code=500, body=b"Internal Server Error")
+ )
+ error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
+ self.assertEqual(error.value.code, 503)
+
+ # The introspection endpoint request fails.
+ self.http_client.request = simple_async_mock(raises=Exception())
+ error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
+ self.assertEqual(error.value.code, 503)
+
+ # The introspection endpoint does not return a JSON object.
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse.json(
+ code=200, payload=["this is an array", "not an object"]
+ )
+ )
+ error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
+ self.assertEqual(error.value.code, 503)
+
+ # The introspection endpoint does not return valid JSON.
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse(code=200, body=b"this is not valid JSON")
+ )
+ error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
+ self.assertEqual(error.value.code, 503)
+
+ def make_device_keys(self, user_id: str, device_id: str) -> JsonDict:
+ # We only generate a master key to simplify the test.
+ master_signing_key = generate_signing_key(device_id)
+ master_verify_key = encode_verify_key_base64(get_verify_key(master_signing_key))
+
+ return {
+ "master_key": sign_json(
+ {
+ "user_id": user_id,
+ "usage": ["master"],
+ "keys": {"ed25519:" + master_verify_key: master_verify_key},
+ },
+ user_id,
+ master_signing_key,
+ ),
+ }
+
+ def test_cross_signing(self) -> None:
+ """Try uploading device keys with OAuth delegation enabled."""
+
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse.json(
+ code=200,
+ payload={
+ "active": True,
+ "sub": SUBJECT,
+ "scope": " ".join([MATRIX_USER_SCOPE, MATRIX_DEVICE_SCOPE]),
+ "username": USERNAME,
+ },
+ )
+ )
+ keys_upload_body = self.make_device_keys(USER_ID, DEVICE)
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/v3/keys/device_signing/upload",
+ keys_upload_body,
+ access_token="mockAccessToken",
+ )
+
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/v3/keys/device_signing/upload",
+ keys_upload_body,
+ access_token="mockAccessToken",
+ )
+
+ self.assertEqual(channel.code, HTTPStatus.NOT_IMPLEMENTED, channel.json_body)
+
+ def expect_unauthorized(
+ self, method: str, path: str, content: Union[bytes, str, JsonDict] = ""
+ ) -> None:
+ channel = self.make_request(method, path, content, shorthand=False)
+
+ self.assertEqual(channel.code, 401, channel.json_body)
+
+ def expect_unrecognized(
+ self, method: str, path: str, content: Union[bytes, str, JsonDict] = ""
+ ) -> None:
+ channel = self.make_request(method, path, content)
+
+ self.assertEqual(channel.code, 404, channel.json_body)
+ self.assertEqual(
+ channel.json_body["errcode"], Codes.UNRECOGNIZED, channel.json_body
+ )
+
+ def test_uia_endpoints(self) -> None:
+ """Test that endpoints that were removed in MSC2964 are no longer available."""
+
+ # This is just an endpoint that should remain visible (but requires auth):
+ self.expect_unauthorized("GET", "/_matrix/client/v3/devices")
+
+ # This remains usable, but will require a uia scope:
+ self.expect_unauthorized(
+ "POST", "/_matrix/client/v3/keys/device_signing/upload"
+ )
+
+ def test_3pid_endpoints(self) -> None:
+ """Test that 3pid account management endpoints that were removed in MSC2964 are no longer available."""
+
+ # Remains and requires auth:
+ self.expect_unauthorized("GET", "/_matrix/client/v3/account/3pid")
+ self.expect_unauthorized(
+ "POST",
+ "/_matrix/client/v3/account/3pid/bind",
+ {
+ "client_secret": "foo",
+ "id_access_token": "bar",
+ "id_server": "foo",
+ "sid": "bar",
+ },
+ )
+ self.expect_unauthorized("POST", "/_matrix/client/v3/account/3pid/unbind", {})
+
+ # These are gone:
+ self.expect_unrecognized(
+ "POST", "/_matrix/client/v3/account/3pid"
+ ) # deprecated
+ self.expect_unrecognized("POST", "/_matrix/client/v3/account/3pid/add")
+ self.expect_unrecognized("POST", "/_matrix/client/v3/account/3pid/delete")
+ self.expect_unrecognized(
+ "POST", "/_matrix/client/v3/account/3pid/email/requestToken"
+ )
+ self.expect_unrecognized(
+ "POST", "/_matrix/client/v3/account/3pid/msisdn/requestToken"
+ )
+
+ def test_account_management_endpoints_removed(self) -> None:
+ """Test that account management endpoints that were removed in MSC2964 are no longer available."""
+ self.expect_unrecognized("POST", "/_matrix/client/v3/account/deactivate")
+ self.expect_unrecognized("POST", "/_matrix/client/v3/account/password")
+ self.expect_unrecognized(
+ "POST", "/_matrix/client/v3/account/password/email/requestToken"
+ )
+ self.expect_unrecognized(
+ "POST", "/_matrix/client/v3/account/password/msisdn/requestToken"
+ )
+
+ def test_registration_endpoints_removed(self) -> None:
+ """Test that registration endpoints that were removed in MSC2964 are no longer available."""
+ self.expect_unrecognized(
+ "GET", "/_matrix/client/v1/register/m.login.registration_token/validity"
+ )
+ # This is still available for AS registrations
+ # self.expect_unrecognized("POST", "/_matrix/client/v3/register")
+ self.expect_unrecognized("GET", "/_matrix/client/v3/register/available")
+ self.expect_unrecognized(
+ "POST", "/_matrix/client/v3/register/email/requestToken"
+ )
+ self.expect_unrecognized(
+ "POST", "/_matrix/client/v3/register/msisdn/requestToken"
+ )
+
+ def test_session_management_endpoints_removed(self) -> None:
+ """Test that session management endpoints that were removed in MSC2964 are no longer available."""
+ self.expect_unrecognized("GET", "/_matrix/client/v3/login")
+ self.expect_unrecognized("POST", "/_matrix/client/v3/login")
+ self.expect_unrecognized("GET", "/_matrix/client/v3/login/sso/redirect")
+ self.expect_unrecognized("POST", "/_matrix/client/v3/logout")
+ self.expect_unrecognized("POST", "/_matrix/client/v3/logout/all")
+ self.expect_unrecognized("POST", "/_matrix/client/v3/refresh")
+ self.expect_unrecognized("GET", "/_matrix/static/client/login")
+
+ def test_device_management_endpoints_removed(self) -> None:
+ """Test that device management endpoints that were removed in MSC2964 are no longer available."""
+ self.expect_unrecognized("POST", "/_matrix/client/v3/delete_devices")
+ self.expect_unrecognized("DELETE", "/_matrix/client/v3/devices/{DEVICE}")
+
+ def test_openid_endpoints_removed(self) -> None:
+ """Test that OpenID id_token endpoints that were removed in MSC2964 are no longer available."""
+ self.expect_unrecognized(
+ "POST", "/_matrix/client/v3/user/{USERNAME}/openid/request_token"
+ )
+
+ def test_admin_api_endpoints_removed(self) -> None:
+ """Test that admin API endpoints that were removed in MSC2964 are no longer available."""
+ self.expect_unrecognized("GET", "/_synapse/admin/v1/registration_tokens")
+ self.expect_unrecognized("POST", "/_synapse/admin/v1/registration_tokens/new")
+ self.expect_unrecognized("GET", "/_synapse/admin/v1/registration_tokens/abcd")
+ self.expect_unrecognized("PUT", "/_synapse/admin/v1/registration_tokens/abcd")
+ self.expect_unrecognized(
+ "DELETE", "/_synapse/admin/v1/registration_tokens/abcd"
+ )
+ self.expect_unrecognized("POST", "/_synapse/admin/v1/reset_password/foo")
+ self.expect_unrecognized("POST", "/_synapse/admin/v1/users/foo/login")
+ self.expect_unrecognized("GET", "/_synapse/admin/v1/register")
+ self.expect_unrecognized("POST", "/_synapse/admin/v1/register")
+ self.expect_unrecognized("GET", "/_synapse/admin/v1/users/foo/admin")
+ self.expect_unrecognized("PUT", "/_synapse/admin/v1/users/foo/admin")
+ self.expect_unrecognized("POST", "/_synapse/admin/v1/account_validity/validity")
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 64a9a22afe..196ceb0b82 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -80,11 +80,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(
- (
- self.get_success(
- self.store.get_profile_displayname(self.frank.localpart)
- )
- ),
+ (self.get_success(self.store.get_profile_displayname(self.frank))),
"Frank Jr.",
)
@@ -96,11 +92,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(
- (
- self.get_success(
- self.store.get_profile_displayname(self.frank.localpart)
- )
- ),
+ (self.get_success(self.store.get_profile_displayname(self.frank))),
"Frank",
)
@@ -112,7 +104,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
)
self.assertIsNone(
- self.get_success(self.store.get_profile_displayname(self.frank.localpart))
+ self.get_success(self.store.get_profile_displayname(self.frank))
)
def test_set_my_name_if_disabled(self) -> None:
@@ -122,11 +114,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.get_success(self.store.set_profile_displayname(self.frank, "Frank"))
self.assertEqual(
- (
- self.get_success(
- self.store.get_profile_displayname(self.frank.localpart)
- )
- ),
+ (self.get_success(self.store.get_profile_displayname(self.frank))),
"Frank",
)
@@ -201,7 +189,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(
- (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
+ (self.get_success(self.store.get_profile_avatar_url(self.frank))),
"http://my.server/pic.gif",
)
@@ -215,7 +203,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(
- (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
+ (self.get_success(self.store.get_profile_avatar_url(self.frank))),
"http://my.server/me.png",
)
@@ -229,7 +217,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
)
self.assertIsNone(
- (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
+ (self.get_success(self.store.get_profile_avatar_url(self.frank))),
)
def test_set_my_avatar_if_disabled(self) -> None:
@@ -241,7 +229,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(
- (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
+ (self.get_success(self.store.get_profile_avatar_url(self.frank))),
"http://my.server/me.png",
)
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index 73822b07a5..8d8584609b 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -17,7 +17,7 @@ from unittest.mock import Mock
from twisted.test.proto_helpers import MemoryReactor
-from synapse.api.auth import Auth
+from synapse.api.auth.internal import InternalAuth
from synapse.api.constants import UserTypes
from synapse.api.errors import (
CodeMessageException,
@@ -683,7 +683,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
request = Mock(args={})
request.args[b"access_token"] = [token.encode("ascii")]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- auth = Auth(self.hs)
+ auth = InternalAuth(self.hs)
requester = self.get_success(auth.get_user_by_req(request))
self.assertTrue(requester.shadow_banned)
diff --git a/tests/media/test_html_preview.py b/tests/media/test_html_preview.py
index e7da75db3e..ea84bb3d3d 100644
--- a/tests/media/test_html_preview.py
+++ b/tests/media/test_html_preview.py
@@ -24,7 +24,7 @@ from tests import unittest
try:
import lxml
except ImportError:
- lxml = None
+ lxml = None # type: ignore[assignment]
class SummarizeTestCase(unittest.TestCase):
@@ -160,6 +160,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase):
"""
tree = decode_body(html, "http://example.com/test.html")
+ assert tree is not None
og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
@@ -176,6 +177,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase):
"""
tree = decode_body(html, "http://example.com/test.html")
+ assert tree is not None
og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
@@ -195,6 +197,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase):
"""
tree = decode_body(html, "http://example.com/test.html")
+ assert tree is not None
og = parse_html_to_open_graph(tree)
self.assertEqual(
@@ -217,6 +220,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase):
"""
tree = decode_body(html, "http://example.com/test.html")
+ assert tree is not None
og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
@@ -231,6 +235,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase):
"""
tree = decode_body(html, "http://example.com/test.html")
+ assert tree is not None
og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
@@ -246,6 +251,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase):
"""
tree = decode_body(html, "http://example.com/test.html")
+ assert tree is not None
og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": "Title", "og:description": "Title"})
@@ -261,6 +267,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase):
"""
tree = decode_body(html, "http://example.com/test.html")
+ assert tree is not None
og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."})
@@ -281,6 +288,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase):
"""
tree = decode_body(html, "http://example.com/test.html")
+ assert tree is not None
og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": "Title", "og:description": "Finally!"})
@@ -296,6 +304,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase):
"""
tree = decode_body(html, "http://example.com/test.html")
+ assert tree is not None
og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
@@ -324,6 +333,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase):
<head><title>Foo</title></head><body>Some text.</body></html>
""".strip()
tree = decode_body(html, "http://example.com/test.html")
+ assert tree is not None
og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
@@ -338,6 +348,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase):
</html>
"""
tree = decode_body(html, "http://example.com/test.html", "invalid-encoding")
+ assert tree is not None
og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
@@ -353,6 +364,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase):
</html>
"""
tree = decode_body(html, "http://example.com/test.html")
+ assert tree is not None
og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": "ÿÿ Foo", "og:description": "Some text."})
@@ -367,6 +379,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase):
</html>
"""
tree = decode_body(html, "http://example.com/test.html")
+ assert tree is not None
og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": "ó", "og:description": "Some text."})
@@ -380,6 +393,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase):
</html>
"""
tree = decode_body(html, "http://example.com/test.html")
+ assert tree is not None
og = parse_html_to_open_graph(tree)
self.assertEqual(
og,
@@ -401,6 +415,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase):
</html>
"""
tree = decode_body(html, "http://example.com/test.html")
+ assert tree is not None
og = parse_html_to_open_graph(tree)
self.assertEqual(
og,
@@ -419,6 +434,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase):
with a cheeky SVG</svg></u> and <strong>some</strong> tail text</b></a>
"""
tree = decode_body(html, "http://example.com/test.html")
+ assert tree is not None
og = parse_html_to_open_graph(tree)
self.assertEqual(
og,
diff --git a/tests/media/test_oembed.py b/tests/media/test_oembed.py
index c8bf8421da..3bc19cb1cc 100644
--- a/tests/media/test_oembed.py
+++ b/tests/media/test_oembed.py
@@ -28,7 +28,7 @@ from tests.unittest import HomeserverTestCase
try:
import lxml
except ImportError:
- lxml = None
+ lxml = None # type: ignore[assignment]
class OEmbedTests(HomeserverTestCase):
diff --git a/tests/media/test_url_previewer.py b/tests/media/test_url_previewer.py
index 3c4c7d6765..46ecde5344 100644
--- a/tests/media/test_url_previewer.py
+++ b/tests/media/test_url_previewer.py
@@ -24,7 +24,7 @@ from tests.unittest import override_config
try:
import lxml
except ImportError:
- lxml = None
+ lxml = None # type: ignore[assignment]
class URLPreviewTests(unittest.HomeserverTestCase):
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index bff7114cd8..b3310abe1b 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -28,7 +28,7 @@ from synapse.module_api import ModuleApi
from synapse.rest import admin
from synapse.rest.client import login, notifications, presence, profile, room
from synapse.server import HomeServer
-from synapse.types import JsonDict, create_requester
+from synapse.types import JsonDict, UserID, create_requester
from synapse.util import Clock
from tests.events.test_presence_router import send_presence_update, sync_presence
@@ -103,7 +103,9 @@ class ModuleApiTestCase(BaseModuleApiTestCase):
self.assertEqual(email["added_at"], 0)
# Check that the displayname was assigned
- displayname = self.get_success(self.store.get_profile_displayname("bob"))
+ displayname = self.get_success(
+ self.store.get_profile_displayname(UserID.from_string("@bob:test"))
+ )
self.assertEqual(displayname, "Bobberino")
def test_can_register_admin_user(self) -> None:
diff --git a/tests/push/test_bulk_push_rule_evaluator.py b/tests/push/test_bulk_push_rule_evaluator.py
index 9501096a77..1e06f86071 100644
--- a/tests/push/test_bulk_push_rule_evaluator.py
+++ b/tests/push/test_bulk_push_rule_evaluator.py
@@ -228,7 +228,6 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
)
return len(result) > 0
- @override_config({"experimental_features": {"msc3952_intentional_mentions": True}})
def test_user_mentions(self) -> None:
"""Test the behavior of an event which includes invalid user mentions."""
bulk_evaluator = BulkPushRuleEvaluator(self.hs)
@@ -237,9 +236,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
self.assertFalse(self._create_and_process(bulk_evaluator))
# An empty mentions field should not notify.
self.assertFalse(
- self._create_and_process(
- bulk_evaluator, {EventContentFields.MSC3952_MENTIONS: {}}
- )
+ self._create_and_process(bulk_evaluator, {EventContentFields.MENTIONS: {}})
)
# Non-dict mentions should be ignored.
@@ -253,7 +250,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
for mentions in (None, True, False, 1, "foo", []):
self.assertFalse(
self._create_and_process(
- bulk_evaluator, {EventContentFields.MSC3952_MENTIONS: mentions}
+ bulk_evaluator, {EventContentFields.MENTIONS: mentions}
)
)
@@ -262,7 +259,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
self.assertFalse(
self._create_and_process(
bulk_evaluator,
- {EventContentFields.MSC3952_MENTIONS: {"user_ids": mentions}},
+ {EventContentFields.MENTIONS: {"user_ids": mentions}},
)
)
@@ -270,14 +267,14 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
self.assertTrue(
self._create_and_process(
bulk_evaluator,
- {EventContentFields.MSC3952_MENTIONS: {"user_ids": [self.alice]}},
+ {EventContentFields.MENTIONS: {"user_ids": [self.alice]}},
)
)
self.assertTrue(
self._create_and_process(
bulk_evaluator,
{
- EventContentFields.MSC3952_MENTIONS: {
+ EventContentFields.MENTIONS: {
"user_ids": ["@another:test", self.alice]
}
},
@@ -288,11 +285,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
self.assertTrue(
self._create_and_process(
bulk_evaluator,
- {
- EventContentFields.MSC3952_MENTIONS: {
- "user_ids": [self.alice, self.alice]
- }
- },
+ {EventContentFields.MENTIONS: {"user_ids": [self.alice, self.alice]}},
)
)
@@ -307,7 +300,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
self._create_and_process(
bulk_evaluator,
{
- EventContentFields.MSC3952_MENTIONS: {
+ EventContentFields.MENTIONS: {
"user_ids": [None, True, False, {}, []]
}
},
@@ -317,7 +310,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
self._create_and_process(
bulk_evaluator,
{
- EventContentFields.MSC3952_MENTIONS: {
+ EventContentFields.MENTIONS: {
"user_ids": [None, True, False, {}, [], self.alice]
}
},
@@ -331,12 +324,11 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
{
"body": self.alice,
"msgtype": "m.text",
- EventContentFields.MSC3952_MENTIONS: {},
+ EventContentFields.MENTIONS: {},
},
)
)
- @override_config({"experimental_features": {"msc3952_intentional_mentions": True}})
def test_room_mentions(self) -> None:
"""Test the behavior of an event which includes invalid room mentions."""
bulk_evaluator = BulkPushRuleEvaluator(self.hs)
@@ -344,7 +336,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
# Room mentions from those without power should not notify.
self.assertFalse(
self._create_and_process(
- bulk_evaluator, {EventContentFields.MSC3952_MENTIONS: {"room": True}}
+ bulk_evaluator, {EventContentFields.MENTIONS: {"room": True}}
)
)
@@ -358,7 +350,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
)
self.assertTrue(
self._create_and_process(
- bulk_evaluator, {EventContentFields.MSC3952_MENTIONS: {"room": True}}
+ bulk_evaluator, {EventContentFields.MENTIONS: {"room": True}}
)
)
@@ -374,7 +366,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
self.assertFalse(
self._create_and_process(
bulk_evaluator,
- {EventContentFields.MSC3952_MENTIONS: {"room": mentions}},
+ {EventContentFields.MENTIONS: {"room": mentions}},
)
)
@@ -385,7 +377,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
{
"body": "@room",
"msgtype": "m.text",
- EventContentFields.MSC3952_MENTIONS: {},
+ EventContentFields.MENTIONS: {},
},
)
)
diff --git a/tests/rest/admin/test_jwks.py b/tests/rest/admin/test_jwks.py
new file mode 100644
index 0000000000..a9a6191c73
--- /dev/null
+++ b/tests/rest/admin/test_jwks.py
@@ -0,0 +1,106 @@
+# Copyright 2023 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict
+
+from twisted.web.resource import Resource
+
+from synapse.rest.synapse.client import build_synapse_client_resource_tree
+
+from tests.unittest import HomeserverTestCase, override_config, skip_unless
+
+try:
+ import authlib # noqa: F401
+
+ HAS_AUTHLIB = True
+except ImportError:
+ HAS_AUTHLIB = False
+
+
+@skip_unless(HAS_AUTHLIB, "requires authlib")
+class JWKSTestCase(HomeserverTestCase):
+ """Test /_synapse/jwks JWKS data."""
+
+ def create_resource_dict(self) -> Dict[str, Resource]:
+ d = super().create_resource_dict()
+ d.update(build_synapse_client_resource_tree(self.hs))
+ return d
+
+ def test_empty_jwks(self) -> None:
+ """Test that the JWKS endpoint is not present by default."""
+ channel = self.make_request("GET", "/_synapse/jwks")
+ self.assertEqual(404, channel.code, channel.result)
+
+ @override_config(
+ {
+ "disable_registration": True,
+ "experimental_features": {
+ "msc3861": {
+ "enabled": True,
+ "issuer": "https://issuer/",
+ "client_id": "test-client-id",
+ "client_auth_method": "client_secret_post",
+ "client_secret": "secret",
+ },
+ },
+ }
+ )
+ def test_empty_jwks_for_msc3861_client_secret_post(self) -> None:
+ """Test that the JWKS endpoint is empty when plain auth is used."""
+ channel = self.make_request("GET", "/_synapse/jwks")
+ self.assertEqual(200, channel.code, channel.result)
+ self.assertEqual({"keys": []}, channel.json_body)
+
+ @override_config(
+ {
+ "disable_registration": True,
+ "experimental_features": {
+ "msc3861": {
+ "enabled": True,
+ "issuer": "https://issuer/",
+ "client_id": "test-client-id",
+ "client_auth_method": "private_key_jwt",
+ "jwk": {
+ "p": "-frVdP_tZ-J_nIR6HNMDq1N7aunwm51nAqNnhqIyuA8ikx7LlQED1tt2LD3YEvYyW8nxE2V95HlCRZXQPMiRJBFOsbmYkzl2t-MpavTaObB_fct_JqcRtdXddg4-_ihdjRDwUOreq_dpWh6MIKsC3UyekfkHmeEJg5YpOTL15j8",
+ "kty": "RSA",
+ "q": "oFw-Enr_YozQB1ab-kawn4jY3yHi8B1nSmYT0s8oTCflrmps5BFJfCkHL5ij3iY15z0o2m0N-jjB1oSJ98O4RayEEYNQlHnTNTl0kRIWzpoqblHUIxVcahIpP_xTovBJzwi8XXoLGqHOOMA-r40LSyVgP2Ut8D9qBwV6_UfT0LU",
+ "d": "WFkDPYo4b4LIS64D_QtQfGGuAObPvc3HFfp9VZXyq3SJR58XZRHE0jqtlEMNHhOTgbMYS3w8nxPQ_qVzY-5hs4fIanwvB64mAoOGl0qMHO65DTD_WsGFwzYClJPBVniavkLE2Hmpu8IGe6lGliN8vREC6_4t69liY-XcN_ECboVtC2behKkLOEASOIMuS7YcKAhTJFJwkl1dqDlliEn5A4u4xy7nuWQz3juB1OFdKlwGA5dfhDNglhoLIwNnkLsUPPFO-WB5ZNEW35xxHOToxj4bShvDuanVA6mJPtTKjz0XibjB36bj_nF_j7EtbE2PdGJ2KevAVgElR4lqS4ISgQ",
+ "e": "AQAB",
+ "kid": "test",
+ "qi": "cPfNk8l8W5exVNNea4d7QZZ8Qr8LgHghypYAxz8PQh1fNa8Ya1SNUDVzC2iHHhszxxA0vB9C7jGze8dBrvnzWYF1XvQcqNIVVgHhD57R1Nm3dj2NoHIKe0Cu4bCUtP8xnZQUN4KX7y4IIcgRcBWG1hT6DEYZ4BxqicnBXXNXAUI",
+ "dp": "dKlMHvslV1sMBQaKWpNb3gPq0B13TZhqr3-E2_8sPlvJ3fD8P4CmwwnOn50JDuhY3h9jY5L06sBwXjspYISVv8hX-ndMLkEeF3lrJeA5S70D8rgakfZcPIkffm3tlf1Ok3v5OzoxSv3-67Df4osMniyYwDUBCB5Oq1tTx77xpU8",
+ "dq": "S4ooU1xNYYcjl9FcuJEEMqKsRrAXzzSKq6laPTwIp5dDwt2vXeAm1a4eDHXC-6rUSZGt5PbqVqzV4s-cjnJMI8YYkIdjNg4NSE1Ac_YpeDl3M3Colb5CQlU7yUB7xY2bt0NOOFp9UJZYJrOo09mFMGjy5eorsbitoZEbVqS3SuE",
+ "n": "nJbYKqFwnURKimaviyDFrNLD3gaKR1JW343Qem25VeZxoMq1665RHVoO8n1oBm4ClZdjIiZiVdpyqzD5-Ow12YQgQEf1ZHP3CCcOQQhU57Rh5XvScTe5IxYVkEW32IW2mp_CJ6WfjYpfeL4azarVk8H3Vr59d1rSrKTVVinVdZer9YLQyC_rWAQNtHafPBMrf6RYiNGV9EiYn72wFIXlLlBYQ9Fx7bfe1PaL6qrQSsZP3_rSpuvVdLh1lqGeCLR0pyclA9uo5m2tMyCXuuGQLbA_QJm5xEc7zd-WFdux2eXF045oxnSZ_kgQt-pdN7AxGWOVvwoTf9am6mSkEdv6iw",
+ },
+ },
+ },
+ }
+ )
+ def test_key_returned_for_msc3861_client_secret_post(self) -> None:
+ """Test that the JWKS includes public part of JWK for private_key_jwt auth is used."""
+ channel = self.make_request("GET", "/_synapse/jwks")
+ self.assertEqual(200, channel.code, channel.result)
+ self.assertEqual(
+ {
+ "keys": [
+ {
+ "kty": "RSA",
+ "e": "AQAB",
+ "kid": "test",
+ "n": "nJbYKqFwnURKimaviyDFrNLD3gaKR1JW343Qem25VeZxoMq1665RHVoO8n1oBm4ClZdjIiZiVdpyqzD5-Ow12YQgQEf1ZHP3CCcOQQhU57Rh5XvScTe5IxYVkEW32IW2mp_CJ6WfjYpfeL4azarVk8H3Vr59d1rSrKTVVinVdZer9YLQyC_rWAQNtHafPBMrf6RYiNGV9EiYn72wFIXlLlBYQ9Fx7bfe1PaL6qrQSsZP3_rSpuvVdLh1lqGeCLR0pyclA9uo5m2tMyCXuuGQLbA_QJm5xEc7zd-WFdux2eXF045oxnSZ_kgQt-pdN7AxGWOVvwoTf9am6mSkEdv6iw",
+ }
+ ]
+ },
+ channel.json_body,
+ )
diff --git a/tests/rest/client/test_capabilities.py b/tests/rest/client/test_capabilities.py
index c16e8d43f4..cf23430f6a 100644
--- a/tests/rest/client/test_capabilities.py
+++ b/tests/rest/client/test_capabilities.py
@@ -186,3 +186,31 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
self.assertGreater(len(details["support"]), 0)
for room_version in details["support"]:
self.assertTrue(room_version in KNOWN_ROOM_VERSIONS, str(room_version))
+
+ def test_get_get_token_login_fields_when_disabled(self) -> None:
+ """By default login via an existing session is disabled."""
+ access_token = self.get_success(
+ self.auth_handler.create_access_token_for_user_id(
+ self.user, device_id=None, valid_until_ms=None
+ )
+ )
+
+ channel = self.make_request("GET", self.url, access_token=access_token)
+ capabilities = channel.json_body["capabilities"]
+
+ self.assertEqual(channel.code, HTTPStatus.OK)
+ self.assertFalse(capabilities["m.get_login_token"]["enabled"])
+
+ @override_config({"login_via_existing_session": {"enabled": True}})
+ def test_get_get_token_login_fields_when_enabled(self) -> None:
+ access_token = self.get_success(
+ self.auth_handler.create_access_token_for_user_id(
+ self.user, device_id=None, valid_until_ms=None
+ )
+ )
+
+ channel = self.make_request("GET", self.url, access_token=access_token)
+ capabilities = channel.json_body["capabilities"]
+
+ self.assertEqual(channel.code, HTTPStatus.OK)
+ self.assertTrue(capabilities["m.get_login_token"]["enabled"])
diff --git a/tests/rest/client/test_filter.py b/tests/rest/client/test_filter.py
index 9faa9de050..a2d5d340be 100644
--- a/tests/rest/client/test_filter.py
+++ b/tests/rest/client/test_filter.py
@@ -46,7 +46,9 @@ class FilterTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body, {"filter_id": "0"})
filter = self.get_success(
- self.store.get_user_filter(user_localpart="apple", filter_id=0)
+ self.store.get_user_filter(
+ user_id=UserID.from_string(FilterTestCase.user_id), filter_id=0
+ )
)
self.pump()
self.assertEqual(filter, self.EXAMPLE_FILTER)
diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py
index dc32982e22..f3c3bc69a9 100644
--- a/tests/rest/client/test_login.py
+++ b/tests/rest/client/test_login.py
@@ -446,6 +446,29 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
ApprovalNoticeMedium.NONE, channel.json_body["approval_notice_medium"]
)
+ def test_get_login_flows_with_login_via_existing_disabled(self) -> None:
+ """GET /login should return m.login.token without get_login_token"""
+ channel = self.make_request("GET", "/_matrix/client/r0/login")
+ self.assertEqual(channel.code, 200, channel.result)
+
+ flows = {flow["type"]: flow for flow in channel.json_body["flows"]}
+ self.assertNotIn("m.login.token", flows)
+
+ @override_config({"login_via_existing_session": {"enabled": True}})
+ def test_get_login_flows_with_login_via_existing_enabled(self) -> None:
+ """GET /login should return m.login.token with get_login_token true"""
+ channel = self.make_request("GET", "/_matrix/client/r0/login")
+ self.assertEqual(channel.code, 200, channel.result)
+
+ self.assertCountEqual(
+ channel.json_body["flows"],
+ [
+ {"type": "m.login.token", "get_login_token": True},
+ {"type": "m.login.password"},
+ {"type": "m.login.application_service"},
+ ],
+ )
+
@skip_unless(has_saml2 and HAS_OIDC, "Requires SAML2 and OIDC")
class MultiSSOTestCase(unittest.HomeserverTestCase):
diff --git a/tests/rest/client/test_login_token_request.py b/tests/rest/client/test_login_token_request.py
index b8187db982..f05e619aa8 100644
--- a/tests/rest/client/test_login_token_request.py
+++ b/tests/rest/client/test_login_token_request.py
@@ -15,14 +15,14 @@
from twisted.test.proto_helpers import MemoryReactor
from synapse.rest import admin
-from synapse.rest.client import login, login_token_request
+from synapse.rest.client import login, login_token_request, versions
from synapse.server import HomeServer
from synapse.util import Clock
from tests import unittest
from tests.unittest import override_config
-endpoint = "/_matrix/client/unstable/org.matrix.msc3882/login/token"
+GET_TOKEN_ENDPOINT = "/_matrix/client/v1/login/get_token"
class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
@@ -30,6 +30,7 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
login.register_servlets,
admin.register_servlets,
login_token_request.register_servlets,
+ versions.register_servlets, # TODO: remove once unstable revision 0 support is removed
]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
@@ -46,26 +47,26 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
self.password = "password"
def test_disabled(self) -> None:
- channel = self.make_request("POST", endpoint, {}, access_token=None)
+ channel = self.make_request("POST", GET_TOKEN_ENDPOINT, {}, access_token=None)
self.assertEqual(channel.code, 404)
self.register_user(self.user, self.password)
token = self.login(self.user, self.password)
- channel = self.make_request("POST", endpoint, {}, access_token=token)
+ channel = self.make_request("POST", GET_TOKEN_ENDPOINT, {}, access_token=token)
self.assertEqual(channel.code, 404)
- @override_config({"experimental_features": {"msc3882_enabled": True}})
+ @override_config({"login_via_existing_session": {"enabled": True}})
def test_require_auth(self) -> None:
- channel = self.make_request("POST", endpoint, {}, access_token=None)
+ channel = self.make_request("POST", GET_TOKEN_ENDPOINT, {}, access_token=None)
self.assertEqual(channel.code, 401)
- @override_config({"experimental_features": {"msc3882_enabled": True}})
+ @override_config({"login_via_existing_session": {"enabled": True}})
def test_uia_on(self) -> None:
user_id = self.register_user(self.user, self.password)
token = self.login(self.user, self.password)
- channel = self.make_request("POST", endpoint, {}, access_token=token)
+ channel = self.make_request("POST", GET_TOKEN_ENDPOINT, {}, access_token=token)
self.assertEqual(channel.code, 401)
self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
@@ -80,9 +81,9 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
},
}
- channel = self.make_request("POST", endpoint, uia, access_token=token)
+ channel = self.make_request("POST", GET_TOKEN_ENDPOINT, uia, access_token=token)
self.assertEqual(channel.code, 200)
- self.assertEqual(channel.json_body["expires_in"], 300)
+ self.assertEqual(channel.json_body["expires_in_ms"], 300000)
login_token = channel.json_body["login_token"]
@@ -95,15 +96,15 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.json_body["user_id"], user_id)
@override_config(
- {"experimental_features": {"msc3882_enabled": True, "msc3882_ui_auth": False}}
+ {"login_via_existing_session": {"enabled": True, "require_ui_auth": False}}
)
def test_uia_off(self) -> None:
user_id = self.register_user(self.user, self.password)
token = self.login(self.user, self.password)
- channel = self.make_request("POST", endpoint, {}, access_token=token)
+ channel = self.make_request("POST", GET_TOKEN_ENDPOINT, {}, access_token=token)
self.assertEqual(channel.code, 200)
- self.assertEqual(channel.json_body["expires_in"], 300)
+ self.assertEqual(channel.json_body["expires_in_ms"], 300000)
login_token = channel.json_body["login_token"]
@@ -117,10 +118,10 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
@override_config(
{
- "experimental_features": {
- "msc3882_enabled": True,
- "msc3882_ui_auth": False,
- "msc3882_token_timeout": "15s",
+ "login_via_existing_session": {
+ "enabled": True,
+ "require_ui_auth": False,
+ "token_timeout": "15s",
}
}
)
@@ -128,6 +129,40 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
self.register_user(self.user, self.password)
token = self.login(self.user, self.password)
- channel = self.make_request("POST", endpoint, {}, access_token=token)
+ channel = self.make_request("POST", GET_TOKEN_ENDPOINT, {}, access_token=token)
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body["expires_in_ms"], 15000)
+
+ @override_config(
+ {
+ "login_via_existing_session": {
+ "enabled": True,
+ "require_ui_auth": False,
+ "token_timeout": "15s",
+ }
+ }
+ )
+ def test_unstable_support(self) -> None:
+ # TODO: remove support for unstable MSC3882 is no longer needed
+
+ # check feature is advertised in versions response:
+ channel = self.make_request(
+ "GET", "/_matrix/client/versions", {}, access_token=None
+ )
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(
+ channel.json_body["unstable_features"]["org.matrix.msc3882"], True
+ )
+
+ self.register_user(self.user, self.password)
+ token = self.login(self.user, self.password)
+
+ # check feature is available via the unstable endpoint and returns an expires_in value in seconds
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/unstable/org.matrix.msc3882/login/token",
+ {},
+ access_token=token,
+ )
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body["expires_in"], 15)
diff --git a/tests/rest/client/test_redactions.py b/tests/rest/client/test_redactions.py
index 84a60c0b07..b43e95292c 100644
--- a/tests/rest/client/test_redactions.py
+++ b/tests/rest/client/test_redactions.py
@@ -217,9 +217,9 @@ class RedactionsTestCase(HomeserverTestCase):
self._redact_event(self.mod_access_token, self.room_id, msg_id)
@override_config({"experimental_features": {"msc3912_enabled": True}})
- def test_redact_relations(self) -> None:
- """Tests that we can redact the relations of an event at the same time as the
- event itself.
+ def test_redact_relations_with_types(self) -> None:
+ """Tests that we can redact the relations of an event of specific types
+ at the same time as the event itself.
"""
# Send a root event.
res = self.helper.send_event(
@@ -318,6 +318,104 @@ class RedactionsTestCase(HomeserverTestCase):
self.assertNotIn("redacted_because", event_dict, event_dict)
@override_config({"experimental_features": {"msc3912_enabled": True}})
+ def test_redact_all_relations(self) -> None:
+ """Tests that we can redact all the relations of an event at the same time as the
+ event itself.
+ """
+ # Send a root event.
+ res = self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={"msgtype": "m.text", "body": "hello"},
+ tok=self.mod_access_token,
+ )
+ root_event_id = res["event_id"]
+
+ # Send an edit to this root event.
+ res = self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={
+ "body": " * hello world",
+ "m.new_content": {
+ "body": "hello world",
+ "msgtype": "m.text",
+ },
+ "m.relates_to": {
+ "event_id": root_event_id,
+ "rel_type": RelationTypes.REPLACE,
+ },
+ "msgtype": "m.text",
+ },
+ tok=self.mod_access_token,
+ )
+ edit_event_id = res["event_id"]
+
+ # Also send a threaded message whose root is the same as the edit's.
+ res = self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "message 1",
+ "m.relates_to": {
+ "event_id": root_event_id,
+ "rel_type": RelationTypes.THREAD,
+ },
+ },
+ tok=self.mod_access_token,
+ )
+ threaded_event_id = res["event_id"]
+
+ # Also send a reaction, again with the same root.
+ res = self.helper.send_event(
+ room_id=self.room_id,
+ type=EventTypes.Reaction,
+ content={
+ "m.relates_to": {
+ "rel_type": RelationTypes.ANNOTATION,
+ "event_id": root_event_id,
+ "key": "👍",
+ }
+ },
+ tok=self.mod_access_token,
+ )
+ reaction_event_id = res["event_id"]
+
+ # Redact the root event, specifying that we also want to delete all events that
+ # relate to it.
+ self._redact_event(
+ self.mod_access_token,
+ self.room_id,
+ root_event_id,
+ with_relations=["*"],
+ )
+
+ # Check that the root event got redacted.
+ event_dict = self.helper.get_event(
+ self.room_id, root_event_id, self.mod_access_token
+ )
+ self.assertIn("redacted_because", event_dict, event_dict)
+
+ # Check that the edit got redacted.
+ event_dict = self.helper.get_event(
+ self.room_id, edit_event_id, self.mod_access_token
+ )
+ self.assertIn("redacted_because", event_dict, event_dict)
+
+ # Check that the threaded message got redacted.
+ event_dict = self.helper.get_event(
+ self.room_id, threaded_event_id, self.mod_access_token
+ )
+ self.assertIn("redacted_because", event_dict, event_dict)
+
+ # Check that the reaction got redacted.
+ event_dict = self.helper.get_event(
+ self.room_id, reaction_event_id, self.mod_access_token
+ )
+ self.assertIn("redacted_because", event_dict, event_dict)
+
+ @override_config({"experimental_features": {"msc3912_enabled": True}})
def test_redact_relations_no_perms(self) -> None:
"""Tests that, when redacting a message along with its relations, if not all
the related messages can be redacted because of insufficient permissions, the
diff --git a/tests/rest/media/test_url_preview.py b/tests/rest/media/test_url_preview.py
index 170fb0534a..05d5e39cab 100644
--- a/tests/rest/media/test_url_preview.py
+++ b/tests/rest/media/test_url_preview.py
@@ -40,7 +40,7 @@ from tests.test_utils import SMALL_PNG
try:
import lxml
except ImportError:
- lxml = None
+ lxml = None # type: ignore[assignment]
class URLPreviewTests(unittest.HomeserverTestCase):
diff --git a/tests/rest/test_well_known.py b/tests/rest/test_well_known.py
index 2091b08d89..377243a170 100644
--- a/tests/rest/test_well_known.py
+++ b/tests/rest/test_well_known.py
@@ -17,6 +17,13 @@ from synapse.rest.well_known import well_known_resource
from tests import unittest
+try:
+ import authlib # noqa: F401
+
+ HAS_AUTHLIB = True
+except ImportError:
+ HAS_AUTHLIB = False
+
class WellKnownTests(unittest.HomeserverTestCase):
def create_test_resource(self) -> Resource:
@@ -96,3 +103,37 @@ class WellKnownTests(unittest.HomeserverTestCase):
"GET", "/.well-known/matrix/server", shorthand=False
)
self.assertEqual(channel.code, 404)
+
+ @unittest.skip_unless(HAS_AUTHLIB, "requires authlib")
+ @unittest.override_config(
+ {
+ "public_baseurl": "https://homeserver", # this is only required so that client well known is served
+ "experimental_features": {
+ "msc3861": {
+ "enabled": True,
+ "issuer": "https://issuer",
+ "account_management_url": "https://my-account.issuer",
+ "client_id": "id",
+ "client_auth_method": "client_secret_post",
+ "client_secret": "secret",
+ },
+ },
+ "disable_registration": True,
+ }
+ )
+ def test_client_well_known_msc3861_oauth_delegation(self) -> None:
+ channel = self.make_request(
+ "GET", "/.well-known/matrix/client", shorthand=False
+ )
+
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(
+ channel.json_body,
+ {
+ "m.homeserver": {"base_url": "https://homeserver/"},
+ "org.matrix.msc2965.authentication": {
+ "issuer": "https://issuer",
+ "account": "https://my-account.issuer",
+ },
+ },
+ )
diff --git a/tests/server.py b/tests/server.py
index 7296f0a552..a12c3e3b9a 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -642,7 +642,7 @@ def _make_test_homeserver_synchronous(server: HomeServer) -> None:
pool.runWithConnection = runWithConnection # type: ignore[assignment]
pool.runInteraction = runInteraction # type: ignore[assignment]
# Replace the thread pool with a threadless 'thread' pool
- pool.threadpool = ThreadPool(clock._reactor) # type: ignore[assignment]
+ pool.threadpool = ThreadPool(clock._reactor)
pool.running = True
# We've just changed the Databases to run DB transactions on the same
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index f9cf0fcb82..fe5bb77913 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
from twisted.test.proto_helpers import MemoryReactor
from synapse.server import HomeServer
@@ -35,18 +36,14 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase):
self.assertEqual(
"Frank",
- (
- self.get_success(
- self.store.get_profile_displayname(self.u_frank.localpart)
- )
- ),
+ (self.get_success(self.store.get_profile_displayname(self.u_frank))),
)
# test set to None
self.get_success(self.store.set_profile_displayname(self.u_frank, None))
self.assertIsNone(
- self.get_success(self.store.get_profile_displayname(self.u_frank.localpart))
+ self.get_success(self.store.get_profile_displayname(self.u_frank))
)
def test_avatar_url(self) -> None:
@@ -58,18 +55,14 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase):
self.assertEqual(
"http://my.site/here",
- (
- self.get_success(
- self.store.get_profile_avatar_url(self.u_frank.localpart)
- )
- ),
+ (self.get_success(self.store.get_profile_avatar_url(self.u_frank))),
)
# test set to None
self.get_success(self.store.set_profile_avatar_url(self.u_frank, None))
self.assertIsNone(
- self.get_success(self.store.get_profile_avatar_url(self.u_frank.localpart))
+ self.get_success(self.store.get_profile_avatar_url(self.u_frank))
)
def test_profiles_bg_migration(self) -> None:
diff --git a/tests/test_state.py b/tests/test_state.py
index ddf59916b1..7a49b87953 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -28,7 +28,7 @@ from unittest.mock import Mock
from twisted.internet import defer
-from synapse.api.auth import Auth
+from synapse.api.auth.internal import InternalAuth
from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase, make_event_from_dict
@@ -240,7 +240,7 @@ class StateTestCase(unittest.TestCase):
hs.get_macaroon_generator.return_value = MacaroonGenerator(
clock, "tesths", b"verysecret"
)
- hs.get_auth.return_value = Auth(hs)
+ hs.get_auth.return_value = InternalAuth(hs)
hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs)
hs.get_storage_controllers.return_value = storage_controllers
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index e5dae670a7..c8cc841d95 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -33,7 +33,7 @@ from twisted.web.http import RESPONSES
from twisted.web.http_headers import Headers
from twisted.web.iweb import IResponse
-from synapse.types import JsonDict
+from synapse.types import JsonSerializable
if TYPE_CHECKING:
from sys import UnraisableHookArgs
@@ -145,7 +145,7 @@ class FakeResponse: # type: ignore[misc]
protocol.connectionLost(Failure(ResponseDone()))
@classmethod
- def json(cls, *, code: int = 200, payload: JsonDict) -> "FakeResponse":
+ def json(cls, *, code: int = 200, payload: JsonSerializable) -> "FakeResponse":
headers = Headers({"Content-Type": ["application/json"]})
body = json.dumps(payload).encode("utf-8")
return cls(code=code, body=body, headers=headers)
diff --git a/tests/test_utils/logging_setup.py b/tests/test_utils/logging_setup.py
index c37f205ed0..199bb06a81 100644
--- a/tests/test_utils/logging_setup.py
+++ b/tests/test_utils/logging_setup.py
@@ -53,4 +53,16 @@ def setup_logging() -> None:
log_level = os.environ.get("SYNAPSE_TEST_LOG_LEVEL", "ERROR")
root_logger.setLevel(log_level)
+ # In order to not add noise by default (since we only log ERROR messages for trial
+ # tests as configured above), we only enable this for developers for looking for
+ # more INFO or DEBUG.
+ if root_logger.isEnabledFor(logging.INFO):
+ # Log when events are (maybe unexpectedly) filtered out of responses in tests. It's
+ # just nice to be able to look at the CI log and figure out why an event isn't being
+ # returned.
+ logging.getLogger("synapse.visibility.filtered_event_debug").setLevel(
+ logging.DEBUG
+ )
+
+ # Blow away the pyo3-log cache so that it reloads the configuration.
reset_logging_config()
|