diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index 7e4570f990..144e49d0fd 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -22,7 +22,7 @@ from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
import synapse.storage
-from synapse.api.constants import EduTypes
+from synapse.api.constants import EduTypes, EventTypes
from synapse.appservice import (
ApplicationService,
TransactionOneTimeKeyCounts,
@@ -36,7 +36,7 @@ from synapse.util import Clock
from synapse.util.stringutils import random_string
from tests import unittest
-from tests.test_utils import make_awaitable, simple_async_mock
+from tests.test_utils import event_injection, make_awaitable, simple_async_mock
from tests.unittest import override_config
from tests.utils import MockClock
@@ -390,15 +390,16 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
receipts.register_servlets,
]
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
+ self.hs = hs
# Mock the ApplicationServiceScheduler's _TransactionController's send method so that
# we can track any outgoing ephemeral events
self.send_mock = simple_async_mock()
- hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock
+ hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock # type: ignore[assignment]
# Mock out application services, and allow defining our own in tests
self._services: List[ApplicationService] = []
- self.hs.get_datastores().main.get_app_services = Mock(
+ self.hs.get_datastores().main.get_app_services = Mock( # type: ignore[assignment]
return_value=self._services
)
@@ -416,6 +417,157 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
"exclusive_as_user", "password", self.exclusive_as_user_device_id
)
+ def _notify_interested_services(self):
+ # This is normally set in `notify_interested_services` but we need to call the
+ # internal async version so the reactor gets pushed to completion.
+ self.hs.get_application_service_handler().current_max += 1
+ self.get_success(
+ self.hs.get_application_service_handler()._notify_interested_services(
+ RoomStreamToken(
+ None, self.hs.get_application_service_handler().current_max
+ )
+ )
+ )
+
+ @parameterized.expand(
+ [
+ ("@local_as_user:test", True),
+ # Defining remote users in an application service user namespace regex is a
+ # footgun since the appservice might assume that it'll receive all events
+ # sent by that remote user, but it will only receive events in rooms that
+ # are shared with a local user. So we just remove this footgun possibility
+ # entirely and we won't notify the application service based on remote
+ # users.
+ ("@remote_as_user:remote", False),
+ ]
+ )
+ def test_match_interesting_room_members(
+ self, interesting_user: str, should_notify: bool
+ ):
+ """
+ Test to make sure that a interesting user (local or remote) in the room is
+ notified as expected when someone else in the room sends a message.
+ """
+ # Register an application service that's interested in the `interesting_user`
+ interested_appservice = self._register_application_service(
+ namespaces={
+ ApplicationService.NS_USERS: [
+ {
+ "regex": interesting_user,
+ "exclusive": False,
+ },
+ ],
+ },
+ )
+
+ # Create a room
+ alice = self.register_user("alice", "pass")
+ alice_access_token = self.login("alice", "pass")
+ room_id = self.helper.create_room_as(room_creator=alice, tok=alice_access_token)
+
+ # Join the interesting user to the room
+ self.get_success(
+ event_injection.inject_member_event(
+ self.hs, room_id, interesting_user, "join"
+ )
+ )
+ # Kick the appservice into checking this membership event to get the event out
+ # of the way
+ self._notify_interested_services()
+ # We don't care about the interesting user join event (this test is making sure
+ # the next thing works)
+ self.send_mock.reset_mock()
+
+ # Send a message from an uninteresting user
+ self.helper.send_event(
+ room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "message from uninteresting user",
+ },
+ tok=alice_access_token,
+ )
+ # Kick the appservice into checking this new event
+ self._notify_interested_services()
+
+ if should_notify:
+ self.send_mock.assert_called_once()
+ (
+ service,
+ events,
+ _ephemeral,
+ _to_device_messages,
+ _otks,
+ _fbks,
+ _device_list_summary,
+ ) = self.send_mock.call_args[0]
+
+ # Even though the message came from an uninteresting user, it should still
+ # notify us because the interesting user is joined to the room where the
+ # message was sent.
+ self.assertEqual(service, interested_appservice)
+ self.assertEqual(events[0]["type"], "m.room.message")
+ self.assertEqual(events[0]["sender"], alice)
+ else:
+ self.send_mock.assert_not_called()
+
+ def test_application_services_receive_events_sent_by_interesting_local_user(self):
+ """
+ Test to make sure that a messages sent from a local user can be interesting and
+ picked up by the appservice.
+ """
+ # Register an application service that's interested in all local users
+ interested_appservice = self._register_application_service(
+ namespaces={
+ ApplicationService.NS_USERS: [
+ {
+ "regex": ".*",
+ "exclusive": False,
+ },
+ ],
+ },
+ )
+
+ # Create a room
+ alice = self.register_user("alice", "pass")
+ alice_access_token = self.login("alice", "pass")
+ room_id = self.helper.create_room_as(room_creator=alice, tok=alice_access_token)
+
+ # We don't care about interesting events before this (this test is making sure
+ # the next thing works)
+ self.send_mock.reset_mock()
+
+ # Send a message from the interesting local user
+ self.helper.send_event(
+ room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "message from interesting local user",
+ },
+ tok=alice_access_token,
+ )
+ # Kick the appservice into checking this new event
+ self._notify_interested_services()
+
+ self.send_mock.assert_called_once()
+ (
+ service,
+ events,
+ _ephemeral,
+ _to_device_messages,
+ _otks,
+ _fbks,
+ _device_list_summary,
+ ) = self.send_mock.call_args[0]
+
+ # Events sent from an interesting local user should also be picked up as
+ # interesting to the appservice.
+ self.assertEqual(service, interested_appservice)
+ self.assertEqual(events[0]["type"], "m.room.message")
+ self.assertEqual(events[0]["sender"], alice)
+
def test_sending_read_receipt_batches_to_application_services(self):
"""Tests that a large batch of read receipts are sent correctly to
interested application services.
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index 7106799d44..036dbbc45b 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Optional
from unittest.mock import Mock
import pymacaroons
@@ -19,6 +20,7 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.api.errors import AuthError, ResourceLimitError
from synapse.rest import admin
+from synapse.rest.client import login
from synapse.server import HomeServer
from synapse.util import Clock
@@ -29,6 +31,7 @@ from tests.test_utils import make_awaitable
class AuthTestCase(unittest.HomeserverTestCase):
servlets = [
admin.register_servlets,
+ login.register_servlets,
]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
@@ -46,6 +49,23 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.user1 = self.register_user("a_user", "pass")
+ def token_login(self, token: str) -> Optional[str]:
+ body = {
+ "type": "m.login.token",
+ "token": token,
+ }
+
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/v3/login",
+ body,
+ )
+
+ if channel.code == 200:
+ return channel.json_body["user_id"]
+
+ return None
+
def test_macaroon_caveats(self) -> None:
token = self.macaroon_generator.generate_guest_access_token("a_user")
macaroon = pymacaroons.Macaroon.deserialize(token)
@@ -73,49 +93,62 @@ class AuthTestCase(unittest.HomeserverTestCase):
v.satisfy_general(verify_guest)
v.verify(macaroon, self.hs.config.key.macaroon_secret_key)
- def test_short_term_login_token_gives_user_id(self) -> None:
- token = self.macaroon_generator.generate_short_term_login_token(
- self.user1, "", duration_in_ms=5000
+ def test_login_token_gives_user_id(self) -> None:
+ token = self.get_success(
+ self.auth_handler.create_login_token_for_user_id(
+ self.user1,
+ duration_ms=(5 * 1000),
+ )
)
- res = self.get_success(self.auth_handler.validate_short_term_login_token(token))
+
+ res = self.get_success(self.auth_handler.consume_login_token(token))
self.assertEqual(self.user1, res.user_id)
- self.assertEqual("", res.auth_provider_id)
+ self.assertEqual(None, res.auth_provider_id)
- # when we advance the clock, the token should be rejected
- self.reactor.advance(6)
- self.get_failure(
- self.auth_handler.validate_short_term_login_token(token),
- AuthError,
+ def test_login_token_reuse_fails(self) -> None:
+ token = self.get_success(
+ self.auth_handler.create_login_token_for_user_id(
+ self.user1,
+ duration_ms=(5 * 1000),
+ )
)
- def test_short_term_login_token_gives_auth_provider(self) -> None:
- token = self.macaroon_generator.generate_short_term_login_token(
- self.user1, auth_provider_id="my_idp"
- )
- res = self.get_success(self.auth_handler.validate_short_term_login_token(token))
- self.assertEqual(self.user1, res.user_id)
- self.assertEqual("my_idp", res.auth_provider_id)
+ self.get_success(self.auth_handler.consume_login_token(token))
- def test_short_term_login_token_cannot_replace_user_id(self) -> None:
- token = self.macaroon_generator.generate_short_term_login_token(
- self.user1, "", duration_in_ms=5000
+ self.get_failure(
+ self.auth_handler.consume_login_token(token),
+ AuthError,
)
- macaroon = pymacaroons.Macaroon.deserialize(token)
- res = self.get_success(
- self.auth_handler.validate_short_term_login_token(macaroon.serialize())
+ def test_login_token_expires(self) -> None:
+ token = self.get_success(
+ self.auth_handler.create_login_token_for_user_id(
+ self.user1,
+ duration_ms=(5 * 1000),
+ )
)
- self.assertEqual(self.user1, res.user_id)
-
- # add another "user_id" caveat, which might allow us to override the
- # user_id.
- macaroon.add_first_party_caveat("user_id = b_user")
+ # when we advance the clock, the token should be rejected
+ self.reactor.advance(6)
self.get_failure(
- self.auth_handler.validate_short_term_login_token(macaroon.serialize()),
+ self.auth_handler.consume_login_token(token),
AuthError,
)
+ def test_login_token_gives_auth_provider(self) -> None:
+ token = self.get_success(
+ self.auth_handler.create_login_token_for_user_id(
+ self.user1,
+ auth_provider_id="my_idp",
+ auth_provider_session_id="11-22-33-44",
+ duration_ms=(5 * 1000),
+ )
+ )
+ res = self.get_success(self.auth_handler.consume_login_token(token))
+ self.assertEqual(self.user1, res.user_id)
+ self.assertEqual("my_idp", res.auth_provider_id)
+ self.assertEqual("11-22-33-44", res.auth_provider_session_id)
+
def test_mau_limits_disabled(self) -> None:
self.auth_blocking._limit_usage_by_mau = False
# Ensure does not throw exception
@@ -125,12 +158,12 @@ class AuthTestCase(unittest.HomeserverTestCase):
)
)
- self.get_success(
- self.auth_handler.validate_short_term_login_token(
- self._get_macaroon().serialize()
- )
+ token = self.get_success(
+ self.auth_handler.create_login_token_for_user_id(self.user1)
)
+ self.assertIsNotNone(self.token_login(token))
+
def test_mau_limits_exceeded_large(self) -> None:
self.auth_blocking._limit_usage_by_mau = True
self.hs.get_datastores().main.get_monthly_active_count = Mock(
@@ -147,12 +180,10 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.hs.get_datastores().main.get_monthly_active_count = Mock(
return_value=make_awaitable(self.large_number_of_users)
)
- self.get_failure(
- self.auth_handler.validate_short_term_login_token(
- self._get_macaroon().serialize()
- ),
- ResourceLimitError,
+ token = self.get_success(
+ self.auth_handler.create_login_token_for_user_id(self.user1)
)
+ self.assertIsNone(self.token_login(token))
def test_mau_limits_parity(self) -> None:
# Ensure we're not at the unix epoch.
@@ -171,12 +202,10 @@ class AuthTestCase(unittest.HomeserverTestCase):
),
ResourceLimitError,
)
- self.get_failure(
- self.auth_handler.validate_short_term_login_token(
- self._get_macaroon().serialize()
- ),
- ResourceLimitError,
+ token = self.get_success(
+ self.auth_handler.create_login_token_for_user_id(self.user1)
)
+ self.assertIsNone(self.token_login(token))
# If in monthly active cohort
self.hs.get_datastores().main.user_last_seen_monthly_active = Mock(
@@ -187,11 +216,10 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.user1, device_id=None, valid_until_ms=None
)
)
- self.get_success(
- self.auth_handler.validate_short_term_login_token(
- self._get_macaroon().serialize()
- )
+ token = self.get_success(
+ self.auth_handler.create_login_token_for_user_id(self.user1)
)
+ self.assertIsNotNone(self.token_login(token))
def test_mau_limits_not_exceeded(self) -> None:
self.auth_blocking._limit_usage_by_mau = True
@@ -209,14 +237,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.hs.get_datastores().main.get_monthly_active_count = Mock(
return_value=make_awaitable(self.small_number_of_users)
)
- self.get_success(
- self.auth_handler.validate_short_term_login_token(
- self._get_macaroon().serialize()
- )
- )
-
- def _get_macaroon(self) -> pymacaroons.Macaroon:
- token = self.macaroon_generator.generate_short_term_login_token(
- self.user1, "", duration_in_ms=5000
+ token = self.get_success(
+ self.auth_handler.create_login_token_for_user_id(self.user1)
)
- return pymacaroons.Macaroon.deserialize(token)
+ self.assertIsNotNone(self.token_login(token))
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index e6cd3af7b7..5955410524 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -11,9 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import json
import os
-from typing import Any, Dict
+from typing import Any, Dict, Tuple
from unittest.mock import ANY, Mock, patch
from urllib.parse import parse_qs, urlparse
@@ -22,12 +21,15 @@ import pymacaroons
from twisted.test.proto_helpers import MemoryReactor
from synapse.handlers.sso import MappingException
+from synapse.http.site import SynapseRequest
from synapse.server import HomeServer
-from synapse.types import JsonDict, UserID
+from synapse.types import UserID
from synapse.util import Clock
-from synapse.util.macaroons import OidcSessionData, get_value_from_macaroon
+from synapse.util.macaroons import get_value_from_macaroon
+from synapse.util.stringutils import random_string
from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock
+from tests.test_utils.oidc import FakeAuthorizationGrant, FakeOidcServer
from tests.unittest import HomeserverTestCase, override_config
try:
@@ -46,12 +48,6 @@ BASE_URL = "https://synapse/"
CALLBACK_URL = BASE_URL + "_synapse/client/oidc/callback"
SCOPES = ["openid"]
-AUTHORIZATION_ENDPOINT = ISSUER + "authorize"
-TOKEN_ENDPOINT = ISSUER + "token"
-USERINFO_ENDPOINT = ISSUER + "userinfo"
-WELL_KNOWN = ISSUER + ".well-known/openid-configuration"
-JWKS_URI = ISSUER + ".well-known/jwks.json"
-
# config for common cases
DEFAULT_CONFIG = {
"enabled": True,
@@ -66,9 +62,9 @@ DEFAULT_CONFIG = {
EXPLICIT_ENDPOINT_CONFIG = {
**DEFAULT_CONFIG,
"discover": False,
- "authorization_endpoint": AUTHORIZATION_ENDPOINT,
- "token_endpoint": TOKEN_ENDPOINT,
- "jwks_uri": JWKS_URI,
+ "authorization_endpoint": ISSUER + "authorize",
+ "token_endpoint": ISSUER + "token",
+ "jwks_uri": ISSUER + "jwks",
}
@@ -102,27 +98,6 @@ class TestMappingProviderFailures(TestMappingProvider):
}
-async def get_json(url: str) -> JsonDict:
- # Mock get_json calls to handle jwks & oidc discovery endpoints
- if url == WELL_KNOWN:
- # Minimal discovery document, as defined in OpenID.Discovery
- # https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata
- return {
- "issuer": ISSUER,
- "authorization_endpoint": AUTHORIZATION_ENDPOINT,
- "token_endpoint": TOKEN_ENDPOINT,
- "jwks_uri": JWKS_URI,
- "userinfo_endpoint": USERINFO_ENDPOINT,
- "response_types_supported": ["code"],
- "subject_types_supported": ["public"],
- "id_token_signing_alg_values_supported": ["RS256"],
- }
- elif url == JWKS_URI:
- return {"keys": []}
-
- return {}
-
-
def _key_file_path() -> str:
"""path to a file containing the private half of a test key"""
@@ -159,11 +134,11 @@ class OidcHandlerTestCase(HomeserverTestCase):
return config
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
- self.http_client = Mock(spec=["get_json"])
- self.http_client.get_json.side_effect = get_json
- self.http_client.user_agent = b"Synapse Test"
+ self.fake_server = FakeOidcServer(clock=clock, issuer=ISSUER)
- hs = self.setup_test_homeserver(proxied_http_client=self.http_client)
+ hs = self.setup_test_homeserver()
+ self.hs_patcher = self.fake_server.patch_homeserver(hs=hs)
+ self.hs_patcher.start()
self.handler = hs.get_oidc_handler()
self.provider = self.handler._providers["oidc"]
@@ -175,18 +150,51 @@ class OidcHandlerTestCase(HomeserverTestCase):
# Reduce the number of attempts when generating MXIDs.
sso_handler._MAP_USERNAME_RETRIES = 3
+ auth_handler = hs.get_auth_handler()
+ # Mock the complete SSO login method.
+ self.complete_sso_login = simple_async_mock()
+ auth_handler.complete_sso_login = self.complete_sso_login # type: ignore[assignment]
+
return hs
+ def tearDown(self) -> None:
+ self.hs_patcher.stop()
+ return super().tearDown()
+
+ def reset_mocks(self):
+ """Reset all the Mocks."""
+ self.fake_server.reset_mocks()
+ self.render_error.reset_mock()
+ self.complete_sso_login.reset_mock()
+
def metadata_edit(self, values):
"""Modify the result that will be returned by the well-known query"""
- async def patched_get_json(uri):
- res = await get_json(uri)
- if uri == WELL_KNOWN:
- res.update(values)
- return res
+ metadata = self.fake_server.get_metadata()
+ metadata.update(values)
+ return patch.object(self.fake_server, "get_metadata", return_value=metadata)
- return patch.object(self.http_client, "get_json", patched_get_json)
+ def start_authorization(
+ self,
+ userinfo: dict,
+ client_redirect_url: str = "http://client/redirect",
+ scope: str = "openid",
+ with_sid: bool = False,
+ ) -> Tuple[SynapseRequest, FakeAuthorizationGrant]:
+ """Start an authorization request, and get the callback request back."""
+ nonce = random_string(10)
+ state = random_string(10)
+
+ code, grant = self.fake_server.start_authorization(
+ userinfo=userinfo,
+ scope=scope,
+ client_id=self.provider._client_auth.client_id,
+ redirect_uri=self.provider._callback_url,
+ nonce=nonce,
+ with_sid=with_sid,
+ )
+ session = self._generate_oidc_session_token(state, nonce, client_redirect_url)
+ return _build_callback_request(code, state, session), grant
def assertRenderedError(self, error, error_description=None):
self.render_error.assert_called_once()
@@ -210,52 +218,54 @@ class OidcHandlerTestCase(HomeserverTestCase):
"""The handler should discover the endpoints from OIDC discovery document."""
# This would throw if some metadata were invalid
metadata = self.get_success(self.provider.load_metadata())
- self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
+ self.fake_server.get_metadata_handler.assert_called_once()
- self.assertEqual(metadata.issuer, ISSUER)
- self.assertEqual(metadata.authorization_endpoint, AUTHORIZATION_ENDPOINT)
- self.assertEqual(metadata.token_endpoint, TOKEN_ENDPOINT)
- self.assertEqual(metadata.jwks_uri, JWKS_URI)
- # FIXME: it seems like authlib does not have that defined in its metadata models
- # self.assertEqual(metadata.userinfo_endpoint, USERINFO_ENDPOINT)
+ self.assertEqual(metadata.issuer, self.fake_server.issuer)
+ self.assertEqual(
+ metadata.authorization_endpoint,
+ self.fake_server.authorization_endpoint,
+ )
+ self.assertEqual(metadata.token_endpoint, self.fake_server.token_endpoint)
+ self.assertEqual(metadata.jwks_uri, self.fake_server.jwks_uri)
+ # It seems like authlib does not have that defined in its metadata models
+ self.assertEqual(
+ metadata.get("userinfo_endpoint"),
+ self.fake_server.userinfo_endpoint,
+ )
# subsequent calls should be cached
- self.http_client.reset_mock()
+ self.reset_mocks()
self.get_success(self.provider.load_metadata())
- self.http_client.get_json.assert_not_called()
+ self.fake_server.get_metadata_handler.assert_not_called()
@override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
def test_no_discovery(self) -> None:
"""When discovery is disabled, it should not try to load from discovery document."""
self.get_success(self.provider.load_metadata())
- self.http_client.get_json.assert_not_called()
+ self.fake_server.get_metadata_handler.assert_not_called()
- @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
+ @override_config({"oidc_config": DEFAULT_CONFIG})
def test_load_jwks(self) -> None:
"""JWKS loading is done once (then cached) if used."""
jwks = self.get_success(self.provider.load_jwks())
- self.http_client.get_json.assert_called_once_with(JWKS_URI)
- self.assertEqual(jwks, {"keys": []})
+ self.fake_server.get_jwks_handler.assert_called_once()
+ self.assertEqual(jwks, self.fake_server.get_jwks())
# subsequent calls should be cached…
- self.http_client.reset_mock()
+ self.reset_mocks()
self.get_success(self.provider.load_jwks())
- self.http_client.get_json.assert_not_called()
+ self.fake_server.get_jwks_handler.assert_not_called()
# …unless forced
- self.http_client.reset_mock()
+ self.reset_mocks()
self.get_success(self.provider.load_jwks(force=True))
- self.http_client.get_json.assert_called_once_with(JWKS_URI)
+ self.fake_server.get_jwks_handler.assert_called_once()
- # Throw if the JWKS uri is missing
- original = self.provider.load_metadata
-
- async def patched_load_metadata():
- m = (await original()).copy()
- m.update({"jwks_uri": None})
- return m
-
- with patch.object(self.provider, "load_metadata", patched_load_metadata):
+ with self.metadata_edit({"jwks_uri": None}):
+ # If we don't do this, the load_metadata call will throw because of the
+ # missing jwks_uri
+ self.provider._user_profile_method = "userinfo_endpoint"
+ self.get_success(self.provider.load_metadata(force=True))
self.get_failure(self.provider.load_jwks(force=True), RuntimeError)
@override_config({"oidc_config": DEFAULT_CONFIG})
@@ -359,7 +369,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.provider.handle_redirect_request(req, b"http://client/redirect")
)
)
- auth_endpoint = urlparse(AUTHORIZATION_ENDPOINT)
+ auth_endpoint = urlparse(self.fake_server.authorization_endpoint)
self.assertEqual(url.scheme, auth_endpoint.scheme)
self.assertEqual(url.netloc, auth_endpoint.netloc)
@@ -424,48 +434,34 @@ class OidcHandlerTestCase(HomeserverTestCase):
with self.assertRaises(AttributeError):
_ = mapping_provider.get_extra_attributes
- token = {
- "type": "bearer",
- "id_token": "id_token",
- "access_token": "access_token",
- }
username = "bar"
userinfo = {
"sub": "foo",
"username": username,
}
expected_user_id = "@%s:%s" % (username, self.hs.hostname)
- self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment]
- self.provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
- self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
- auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock()
- code = "code"
- state = "state"
- nonce = "nonce"
client_redirect_url = "http://client/redirect"
- ip_address = "10.0.0.1"
- session = self._generate_oidc_session_token(state, nonce, client_redirect_url)
- request = _build_callback_request(code, state, session, ip_address=ip_address)
-
+ request, _ = self.start_authorization(
+ userinfo, client_redirect_url=client_redirect_url
+ )
self.get_success(self.handler.handle_oidc_callback(request))
- auth_handler.complete_sso_login.assert_called_once_with(
+ self.complete_sso_login.assert_called_once_with(
expected_user_id,
- "oidc",
+ self.provider.idp_id,
request,
client_redirect_url,
None,
new_user=True,
auth_provider_session_id=None,
)
- self.provider._exchange_code.assert_called_once_with(code)
- self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce)
- self.provider._fetch_userinfo.assert_not_called()
+ self.fake_server.post_token_handler.assert_called_once()
+ self.fake_server.get_userinfo_handler.assert_not_called()
self.render_error.assert_not_called()
# Handle mapping errors
+ request, _ = self.start_authorization(userinfo)
with patch.object(
self.provider,
"_remote_id_from_userinfo",
@@ -475,81 +471,63 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertRenderedError("mapping_error")
# Handle ID token errors
- self.provider._parse_id_token = simple_async_mock(raises=Exception()) # type: ignore[assignment]
- self.get_success(self.handler.handle_oidc_callback(request))
+ request, _ = self.start_authorization(userinfo)
+ with self.fake_server.id_token_override({"iss": "https://bad.issuer/"}):
+ self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_token")
- auth_handler.complete_sso_login.reset_mock()
- self.provider._exchange_code.reset_mock()
- self.provider._parse_id_token.reset_mock()
- self.provider._fetch_userinfo.reset_mock()
+ self.reset_mocks()
# With userinfo fetching
self.provider._user_profile_method = "userinfo_endpoint"
- token = {
- "type": "bearer",
- "access_token": "access_token",
- }
- self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment]
+ # Without the "openid" scope, the FakeProvider does not generate an id_token
+ request, _ = self.start_authorization(userinfo, scope="")
self.get_success(self.handler.handle_oidc_callback(request))
- auth_handler.complete_sso_login.assert_called_once_with(
+ self.complete_sso_login.assert_called_once_with(
expected_user_id,
- "oidc",
+ self.provider.idp_id,
request,
- client_redirect_url,
+ ANY,
None,
new_user=False,
auth_provider_session_id=None,
)
- self.provider._exchange_code.assert_called_once_with(code)
- self.provider._parse_id_token.assert_not_called()
- self.provider._fetch_userinfo.assert_called_once_with(token)
+ self.fake_server.post_token_handler.assert_called_once()
+ self.fake_server.get_userinfo_handler.assert_called_once()
self.render_error.assert_not_called()
+ self.reset_mocks()
+
# With an ID token, userinfo fetching and sid in the ID token
self.provider._user_profile_method = "userinfo_endpoint"
- token = {
- "type": "bearer",
- "access_token": "access_token",
- "id_token": "id_token",
- }
- id_token = {
- "sid": "abcdefgh",
- }
- self.provider._parse_id_token = simple_async_mock(return_value=id_token) # type: ignore[assignment]
- self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment]
- auth_handler.complete_sso_login.reset_mock()
- self.provider._fetch_userinfo.reset_mock()
+ request, grant = self.start_authorization(userinfo, with_sid=True)
+ self.assertIsNotNone(grant.sid)
self.get_success(self.handler.handle_oidc_callback(request))
- auth_handler.complete_sso_login.assert_called_once_with(
+ self.complete_sso_login.assert_called_once_with(
expected_user_id,
- "oidc",
+ self.provider.idp_id,
request,
- client_redirect_url,
+ ANY,
None,
new_user=False,
- auth_provider_session_id=id_token["sid"],
+ auth_provider_session_id=grant.sid,
)
- self.provider._exchange_code.assert_called_once_with(code)
- self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce)
- self.provider._fetch_userinfo.assert_called_once_with(token)
+ self.fake_server.post_token_handler.assert_called_once()
+ self.fake_server.get_userinfo_handler.assert_called_once()
self.render_error.assert_not_called()
# Handle userinfo fetching error
- self.provider._fetch_userinfo = simple_async_mock(raises=Exception()) # type: ignore[assignment]
- self.get_success(self.handler.handle_oidc_callback(request))
+ request, _ = self.start_authorization(userinfo)
+ with self.fake_server.buggy_endpoint(userinfo=True):
+ self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("fetch_error")
- # Handle code exchange failure
- from synapse.handlers.oidc import OidcError
-
- self.provider._exchange_code = simple_async_mock( # type: ignore[assignment]
- raises=OidcError("invalid_request")
- )
- self.get_success(self.handler.handle_oidc_callback(request))
- self.assertRenderedError("invalid_request")
+ request, _ = self.start_authorization(userinfo)
+ with self.fake_server.buggy_endpoint(token=True):
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("server_error")
@override_config({"oidc_config": DEFAULT_CONFIG})
def test_callback_session(self) -> None:
@@ -599,18 +577,22 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
def test_exchange_code(self) -> None:
"""Code exchange behaves correctly and handles various error scenarios."""
- token = {"type": "bearer"}
- token_json = json.dumps(token).encode("utf-8")
- self.http_client.request = simple_async_mock(
- return_value=FakeResponse(code=200, phrase=b"OK", body=token_json)
+ token = {
+ "type": "Bearer",
+ "access_token": "aabbcc",
+ }
+
+ self.fake_server.post_token_handler.side_effect = None
+ self.fake_server.post_token_handler.return_value = FakeResponse.json(
+ payload=token
)
code = "code"
ret = self.get_success(self.provider._exchange_code(code))
- kwargs = self.http_client.request.call_args[1]
+ kwargs = self.fake_server.request.call_args[1]
self.assertEqual(ret, token)
self.assertEqual(kwargs["method"], "POST")
- self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT)
+ self.assertEqual(kwargs["uri"], self.fake_server.token_endpoint)
args = parse_qs(kwargs["data"].decode("utf-8"))
self.assertEqual(args["grant_type"], ["authorization_code"])
@@ -620,12 +602,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertEqual(args["redirect_uri"], [CALLBACK_URL])
# Test error handling
- self.http_client.request = simple_async_mock(
- return_value=FakeResponse(
- code=400,
- phrase=b"Bad Request",
- body=b'{"error": "foo", "error_description": "bar"}',
- )
+ self.fake_server.post_token_handler.return_value = FakeResponse.json(
+ code=400, payload={"error": "foo", "error_description": "bar"}
)
from synapse.handlers.oidc import OidcError
@@ -634,46 +612,30 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertEqual(exc.value.error_description, "bar")
# Internal server error with no JSON body
- self.http_client.request = simple_async_mock(
- return_value=FakeResponse(
- code=500,
- phrase=b"Internal Server Error",
- body=b"Not JSON",
- )
+ self.fake_server.post_token_handler.return_value = FakeResponse(
+ code=500, body=b"Not JSON"
)
exc = self.get_failure(self.provider._exchange_code(code), OidcError)
self.assertEqual(exc.value.error, "server_error")
# Internal server error with JSON body
- self.http_client.request = simple_async_mock(
- return_value=FakeResponse(
- code=500,
- phrase=b"Internal Server Error",
- body=b'{"error": "internal_server_error"}',
- )
+ self.fake_server.post_token_handler.return_value = FakeResponse.json(
+ code=500, payload={"error": "internal_server_error"}
)
exc = self.get_failure(self.provider._exchange_code(code), OidcError)
self.assertEqual(exc.value.error, "internal_server_error")
# 4xx error without "error" field
- self.http_client.request = simple_async_mock(
- return_value=FakeResponse(
- code=400,
- phrase=b"Bad request",
- body=b"{}",
- )
+ self.fake_server.post_token_handler.return_value = FakeResponse.json(
+ code=400, payload={}
)
exc = self.get_failure(self.provider._exchange_code(code), OidcError)
self.assertEqual(exc.value.error, "server_error")
# 2xx error with "error" field
- self.http_client.request = simple_async_mock(
- return_value=FakeResponse(
- code=200,
- phrase=b"OK",
- body=b'{"error": "some_error"}',
- )
+ self.fake_server.post_token_handler.return_value = FakeResponse.json(
+ code=200, payload={"error": "some_error"}
)
exc = self.get_failure(self.provider._exchange_code(code), OidcError)
self.assertEqual(exc.value.error, "some_error")
@@ -697,11 +659,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
"""Test that code exchange works with a JWK client secret."""
from authlib.jose import jwt
- token = {"type": "bearer"}
- self.http_client.request = simple_async_mock(
- return_value=FakeResponse(
- code=200, phrase=b"OK", body=json.dumps(token).encode("utf-8")
- )
+ token = {
+ "type": "Bearer",
+ "access_token": "aabbcc",
+ }
+
+ self.fake_server.post_token_handler.side_effect = None
+ self.fake_server.post_token_handler.return_value = FakeResponse.json(
+ payload=token
)
code = "code"
@@ -714,9 +679,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertEqual(ret, token)
# the request should have hit the token endpoint
- kwargs = self.http_client.request.call_args[1]
+ kwargs = self.fake_server.request.call_args[1]
self.assertEqual(kwargs["method"], "POST")
- self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT)
+ self.assertEqual(kwargs["uri"], self.fake_server.token_endpoint)
# the client secret provided to the should be a jwt which can be checked with
# the public key
@@ -750,11 +715,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
def test_exchange_code_no_auth(self) -> None:
"""Test that code exchange works with no client secret."""
- token = {"type": "bearer"}
- self.http_client.request = simple_async_mock(
- return_value=FakeResponse(
- code=200, phrase=b"OK", body=json.dumps(token).encode("utf-8")
- )
+ token = {
+ "type": "Bearer",
+ "access_token": "aabbcc",
+ }
+
+ self.fake_server.post_token_handler.side_effect = None
+ self.fake_server.post_token_handler.return_value = FakeResponse.json(
+ payload=token
)
code = "code"
ret = self.get_success(self.provider._exchange_code(code))
@@ -762,9 +730,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertEqual(ret, token)
# the request should have hit the token endpoint
- kwargs = self.http_client.request.call_args[1]
+ kwargs = self.fake_server.request.call_args[1]
self.assertEqual(kwargs["method"], "POST")
- self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT)
+ self.assertEqual(kwargs["uri"], self.fake_server.token_endpoint)
# check the POSTed data
args = parse_qs(kwargs["data"].decode("utf-8"))
@@ -787,37 +755,19 @@ class OidcHandlerTestCase(HomeserverTestCase):
"""
Login while using a mapping provider that implements get_extra_attributes.
"""
- token = {
- "type": "bearer",
- "id_token": "id_token",
- "access_token": "access_token",
- }
userinfo = {
"sub": "foo",
"username": "foo",
"phone": "1234567",
}
- self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment]
- self.provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
- auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock()
-
- state = "state"
- client_redirect_url = "http://client/redirect"
- session = self._generate_oidc_session_token(
- state=state,
- nonce="nonce",
- client_redirect_url=client_redirect_url,
- )
- request = _build_callback_request("code", state, session)
-
+ request, _ = self.start_authorization(userinfo)
self.get_success(self.handler.handle_oidc_callback(request))
- auth_handler.complete_sso_login.assert_called_once_with(
+ self.complete_sso_login.assert_called_once_with(
"@foo:test",
- "oidc",
+ self.provider.idp_id,
request,
- client_redirect_url,
+ ANY,
{"phone": "1234567"},
new_user=True,
auth_provider_session_id=None,
@@ -826,41 +776,40 @@ class OidcHandlerTestCase(HomeserverTestCase):
@override_config({"oidc_config": DEFAULT_CONFIG})
def test_map_userinfo_to_user(self) -> None:
"""Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
- auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock()
-
userinfo: dict = {
"sub": "test_user",
"username": "test_user",
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
- auth_handler.complete_sso_login.assert_called_once_with(
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_called_once_with(
"@test_user:test",
- "oidc",
- ANY,
+ self.provider.idp_id,
+ request,
ANY,
None,
new_user=True,
auth_provider_session_id=None,
)
- auth_handler.complete_sso_login.reset_mock()
+ self.reset_mocks()
# Some providers return an integer ID.
userinfo = {
"sub": 1234,
"username": "test_user_2",
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
- auth_handler.complete_sso_login.assert_called_once_with(
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_called_once_with(
"@test_user_2:test",
- "oidc",
- ANY,
+ self.provider.idp_id,
+ request,
ANY,
None,
new_user=True,
auth_provider_session_id=None,
)
- auth_handler.complete_sso_login.reset_mock()
+ self.reset_mocks()
# Test if the mxid is already taken
store = self.hs.get_datastores().main
@@ -869,8 +818,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
store.register_user(user_id=user3.to_string(), password_hash=None)
)
userinfo = {"sub": "test3", "username": "test_user_3"}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
- auth_handler.complete_sso_login.assert_not_called()
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_not_called()
self.assertRenderedError(
"mapping_error",
"Mapping provider does not support de-duplicating Matrix IDs",
@@ -885,38 +835,37 @@ class OidcHandlerTestCase(HomeserverTestCase):
store.register_user(user_id=user.to_string(), password_hash=None)
)
- auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock()
-
# Map a user via SSO.
userinfo = {
"sub": "test",
"username": "test_user",
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
- auth_handler.complete_sso_login.assert_called_once_with(
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_called_once_with(
user.to_string(),
- "oidc",
- ANY,
+ self.provider.idp_id,
+ request,
ANY,
None,
new_user=False,
auth_provider_session_id=None,
)
- auth_handler.complete_sso_login.reset_mock()
+ self.reset_mocks()
# Subsequent calls should map to the same mxid.
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
- auth_handler.complete_sso_login.assert_called_once_with(
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_called_once_with(
user.to_string(),
- "oidc",
- ANY,
+ self.provider.idp_id,
+ request,
ANY,
None,
new_user=False,
auth_provider_session_id=None,
)
- auth_handler.complete_sso_login.reset_mock()
+ self.reset_mocks()
# Note that a second SSO user can be mapped to the same Matrix ID. (This
# requires a unique sub, but something that maps to the same matrix ID,
@@ -927,17 +876,18 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "test1",
"username": "test_user",
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
- auth_handler.complete_sso_login.assert_called_once_with(
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_called_once_with(
user.to_string(),
- "oidc",
- ANY,
+ self.provider.idp_id,
+ request,
ANY,
None,
new_user=False,
auth_provider_session_id=None,
)
- auth_handler.complete_sso_login.reset_mock()
+ self.reset_mocks()
# Register some non-exact matching cases.
user2 = UserID.from_string("@TEST_user_2:test")
@@ -954,8 +904,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "test2",
"username": "TEST_USER_2",
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
- auth_handler.complete_sso_login.assert_not_called()
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_not_called()
args = self.assertRenderedError("mapping_error")
self.assertTrue(
args[2].startswith(
@@ -969,11 +920,12 @@ class OidcHandlerTestCase(HomeserverTestCase):
store.register_user(user_id=user2.to_string(), password_hash=None)
)
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
- auth_handler.complete_sso_login.assert_called_once_with(
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_called_once_with(
"@TEST_USER_2:test",
- "oidc",
- ANY,
+ self.provider.idp_id,
+ request,
ANY,
None,
new_user=False,
@@ -983,9 +935,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
@override_config({"oidc_config": DEFAULT_CONFIG})
def test_map_userinfo_to_invalid_localpart(self) -> None:
"""If the mapping provider generates an invalid localpart it should be rejected."""
- self.get_success(
- _make_callback_with_userinfo(self.hs, {"sub": "test2", "username": "föö"})
- )
+ userinfo = {"sub": "test2", "username": "föö"}
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("mapping_error", "localpart is invalid: föö")
@override_config(
@@ -1000,9 +952,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
def test_map_userinfo_to_user_retries(self) -> None:
"""The mapping provider can retry generating an MXID if the MXID is already in use."""
- auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock()
-
store = self.hs.get_datastores().main
self.get_success(
store.register_user(user_id="@test_user:test", password_hash=None)
@@ -1011,19 +960,20 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "test",
"username": "test_user",
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
# test_user is already taken, so test_user1 gets registered instead.
- auth_handler.complete_sso_login.assert_called_once_with(
+ self.complete_sso_login.assert_called_once_with(
"@test_user1:test",
- "oidc",
- ANY,
+ self.provider.idp_id,
+ request,
ANY,
None,
new_user=True,
auth_provider_session_id=None,
)
- auth_handler.complete_sso_login.reset_mock()
+ self.reset_mocks()
# Register all of the potential mxids for a particular OIDC username.
self.get_success(
@@ -1039,8 +989,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "tester",
"username": "tester",
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
- auth_handler.complete_sso_login.assert_not_called()
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_not_called()
self.assertRenderedError(
"mapping_error", "Unable to generate a Matrix ID from the SSO response"
)
@@ -1052,7 +1003,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "tester",
"username": "",
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("mapping_error", "localpart is invalid: ")
@override_config(
@@ -1071,7 +1023,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "tester",
"username": None,
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("mapping_error", "localpart is invalid: ")
@override_config(
@@ -1084,16 +1037,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
def test_attribute_requirements(self) -> None:
"""The required attributes must be met from the OIDC userinfo response."""
- auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock()
-
# userinfo lacking "test": "foobar" attribute should fail.
userinfo = {
"sub": "tester",
"username": "tester",
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
- auth_handler.complete_sso_login.assert_not_called()
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_not_called()
# userinfo with "test": "foobar" attribute should succeed.
userinfo = {
@@ -1101,13 +1052,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
"username": "tester",
"test": "foobar",
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
# check that the auth handler got called as expected
- auth_handler.complete_sso_login.assert_called_once_with(
+ self.complete_sso_login.assert_called_once_with(
"@tester:test",
- "oidc",
- ANY,
+ self.provider.idp_id,
+ request,
ANY,
None,
new_user=True,
@@ -1124,21 +1076,20 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
def test_attribute_requirements_contains(self) -> None:
"""Test that auth succeeds if userinfo attribute CONTAINS required value"""
- auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock()
# userinfo with "test": ["foobar", "foo", "bar"] attribute should succeed.
userinfo = {
"sub": "tester",
"username": "tester",
"test": ["foobar", "foo", "bar"],
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
# check that the auth handler got called as expected
- auth_handler.complete_sso_login.assert_called_once_with(
+ self.complete_sso_login.assert_called_once_with(
"@tester:test",
- "oidc",
- ANY,
+ self.provider.idp_id,
+ request,
ANY,
None,
new_user=True,
@@ -1158,16 +1109,15 @@ class OidcHandlerTestCase(HomeserverTestCase):
Test that auth fails if attributes exist but don't match,
or are non-string values.
"""
- auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock()
# userinfo with "test": "not_foobar" attribute should fail
userinfo: dict = {
"sub": "tester",
"username": "tester",
"test": "not_foobar",
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
- auth_handler.complete_sso_login.assert_not_called()
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_not_called()
# userinfo with "test": ["foo", "bar"] attribute should fail
userinfo = {
@@ -1175,8 +1125,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
"username": "tester",
"test": ["foo", "bar"],
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
- auth_handler.complete_sso_login.assert_not_called()
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_not_called()
# userinfo with "test": False attribute should fail
# this is largely just to ensure we don't crash here
@@ -1185,8 +1136,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
"username": "tester",
"test": False,
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
- auth_handler.complete_sso_login.assert_not_called()
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_not_called()
# userinfo with "test": None attribute should fail
# a value of None breaks the OIDC spec, but it's important to not crash here
@@ -1195,8 +1147,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
"username": "tester",
"test": None,
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
- auth_handler.complete_sso_login.assert_not_called()
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_not_called()
# userinfo with "test": 1 attribute should fail
# this is largely just to ensure we don't crash here
@@ -1205,8 +1158,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
"username": "tester",
"test": 1,
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
- auth_handler.complete_sso_login.assert_not_called()
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_not_called()
# userinfo with "test": 3.14 attribute should fail
# this is largely just to ensure we don't crash here
@@ -1215,8 +1169,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
"username": "tester",
"test": 3.14,
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
- auth_handler.complete_sso_login.assert_not_called()
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_not_called()
def _generate_oidc_session_token(
self,
@@ -1230,7 +1185,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
return self.handler._macaroon_generator.generate_oidc_session_token(
state=state,
session_data=OidcSessionData(
- idp_id="oidc",
+ idp_id=self.provider.idp_id,
nonce=nonce,
client_redirect_url=client_redirect_url,
ui_auth_session_id=ui_auth_session_id,
@@ -1238,41 +1193,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
-async def _make_callback_with_userinfo(
- hs: HomeServer, userinfo: dict, client_redirect_url: str = "http://client/redirect"
-) -> None:
- """Mock up an OIDC callback with the given userinfo dict
-
- We'll pull out the OIDC handler from the homeserver, stub out a couple of methods,
- and poke in the userinfo dict as if it were the response to an OIDC userinfo call.
-
- Args:
- hs: the HomeServer impl to send the callback to.
- userinfo: the OIDC userinfo dict
- client_redirect_url: the URL to redirect to on success.
- """
-
- handler = hs.get_oidc_handler()
- provider = handler._providers["oidc"]
- provider._exchange_code = simple_async_mock(return_value={"id_token": ""}) # type: ignore[assignment]
- provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
- provider._fetch_userinfo = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
-
- state = "state"
- session = handler._macaroon_generator.generate_oidc_session_token(
- state=state,
- session_data=OidcSessionData(
- idp_id="oidc",
- nonce="nonce",
- client_redirect_url=client_redirect_url,
- ui_auth_session_id="",
- ),
- )
- request = _build_callback_request("code", state, session)
-
- await handler.handle_oidc_callback(request)
-
-
def _build_callback_request(
code: str,
state: str,
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index f88c725a42..675aa023ac 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -14,6 +14,8 @@
from typing import Any, Awaitable, Callable, Dict
from unittest.mock import Mock
+from parameterized import parameterized
+
from twisted.test.proto_helpers import MemoryReactor
import synapse.types
@@ -327,6 +329,53 @@ class ProfileTestCase(unittest.HomeserverTestCase):
)
self.assertFalse(res)
+ @unittest.override_config(
+ {"server_name": "test:8888", "allowed_avatar_mimetypes": ["image/png"]}
+ )
+ def test_avatar_constraint_on_local_server_with_port(self):
+ """Test that avatar metadata is correctly fetched when the media is on a local
+ server and the server has an explicit port.
+
+ (This was previously a bug)
+ """
+ local_server_name = self.hs.config.server.server_name
+ media_id = "local"
+ local_mxc = f"mxc://{local_server_name}/{media_id}"
+
+ # mock up the existence of the avatar file
+ self._setup_local_files({media_id: {"mimetype": "image/png"}})
+
+ # and now check that check_avatar_size_and_mime_type is happy
+ self.assertTrue(
+ self.get_success(self.handler.check_avatar_size_and_mime_type(local_mxc))
+ )
+
+ @parameterized.expand([("remote",), ("remote:1234",)])
+ @unittest.override_config({"allowed_avatar_mimetypes": ["image/png"]})
+ def test_check_avatar_on_remote_server(self, remote_server_name: str) -> None:
+ """Test that avatar metadata is correctly fetched from a remote server"""
+ media_id = "remote"
+ remote_mxc = f"mxc://{remote_server_name}/{media_id}"
+
+ # if the media is remote, check_avatar_size_and_mime_type just checks the
+ # media cache, so we don't need to instantiate a real remote server. It is
+ # sufficient to poke an entry into the db.
+ self.get_success(
+ self.hs.get_datastores().main.store_cached_remote_media(
+ media_id=media_id,
+ media_type="image/png",
+ media_length=50,
+ origin=remote_server_name,
+ time_now_ms=self.clock.time_msec(),
+ upload_name=None,
+ filesystem_id="xyz",
+ )
+ )
+
+ self.assertTrue(
+ self.get_success(self.handler.check_avatar_size_and_mime_type(remote_mxc))
+ )
+
def _setup_local_files(self, names_and_props: Dict[str, Dict[str, Any]]):
"""Stores metadata about files in the database.
|