diff options
-rw-r--r-- | synapse/api/auth/msc3861_delegated.py | 42 | ||||
-rw-r--r-- | tests/handlers/test_oauth_delegation.py | 35 | ||||
-rw-r--r-- | tests/test_utils/__init__.py | 4 |
3 files changed, 74 insertions, 7 deletions
diff --git a/synapse/api/auth/msc3861_delegated.py b/synapse/api/auth/msc3861_delegated.py index b84dce2563..82c66691da 100644 --- a/synapse/api/auth/msc3861_delegated.py +++ b/synapse/api/auth/msc3861_delegated.py @@ -27,9 +27,11 @@ from twisted.web.http_headers import Headers from synapse.api.auth.base import BaseAuth from synapse.api.errors import ( AuthError, + HttpResponseException, InvalidClientTokenError, OAuthInsufficientScopeError, StoreError, + SynapseError, ) from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable @@ -117,6 +119,21 @@ class MSC3861DelegatedAuth(BaseAuth): return metadata async def _introspect_token(self, token: str) -> IntrospectionToken: + """ + Send a token to the introspection endpoint and returns the introspection response + + Parameters: + token: The token to introspect + + Raises: + HttpResponseException: If the introspection endpoint returns a non-2xx response + ValueError: If the introspection endpoint returns an invalid JSON response + JSONDecodeError: If the introspection endpoint returns a non-JSON response + Exception: If the HTTP request fails + + Returns: + The introspection response + """ metadata = await self._issuer_metadata.get() introspection_endpoint = metadata.get("introspection_endpoint") raw_headers: Dict[str, str] = { @@ -136,7 +153,7 @@ class MSC3861DelegatedAuth(BaseAuth): # Do the actual request # We're not using the SimpleHttpClient util methods as we don't want to - # check the HTTP status code and we do the body encoding ourself. + # check the HTTP status code, and we do the body encoding ourselves. response = await self._http_client.request( method="POST", uri=uri, @@ -145,10 +162,21 @@ class MSC3861DelegatedAuth(BaseAuth): ) resp_body = await make_deferred_yieldable(readBody(response)) - # TODO: Let's not worry about 5xx errors & co. for now and just try - # decoding that as JSON. We should also do some validation of the - # response + + if response.code < 200 or response.code >= 300: + raise HttpResponseException( + response.code, + response.phrase.decode("ascii", errors="replace"), + resp_body, + ) + resp = json_decoder.decode(resp_body.decode("utf-8")) + + if not isinstance(resp, dict): + raise ValueError( + "The introspection endpoint returned an invalid JSON response." + ) + return IntrospectionToken(**resp) async def is_server_admin(self, requester: Requester) -> bool: @@ -196,7 +224,11 @@ class MSC3861DelegatedAuth(BaseAuth): scope=["urn:synapse:admin:*"], ) - introspection_result = await self._introspect_token(token) + try: + introspection_result = await self._introspect_token(token) + except Exception: + logger.exception("Failed to introspect token") + raise SynapseError(503, "Unable to introspect the access token") logger.info(f"Introspection result: {introspection_result!r}") diff --git a/tests/handlers/test_oauth_delegation.py b/tests/handlers/test_oauth_delegation.py index b79c43a424..16ce2c069d 100644 --- a/tests/handlers/test_oauth_delegation.py +++ b/tests/handlers/test_oauth_delegation.py @@ -30,6 +30,7 @@ from synapse.api.errors import ( Codes, InvalidClientTokenError, OAuthInsufficientScopeError, + SynapseError, ) from synapse.rest import admin from synapse.rest.client import account, devices, keys, login, logout, register @@ -405,6 +406,40 @@ class MSC3861OAuthDelegation(HomeserverTestCase): ) self.assertEqual(requester.device_id, DEVICE) + def test_unavailable_introspection_endpoint(self) -> None: + """The handler should return an internal server error.""" + request = Mock(args={}) + request.args[b"access_token"] = [b"mockAccessToken"] + request.requestHeaders.getRawHeaders = mock_getRawHeaders() + + # The introspection endpoint is returning an error. + self.http_client.request = simple_async_mock( + return_value=FakeResponse(code=500, body=b"Internal Server Error") + ) + error = self.get_failure(self.auth.get_user_by_req(request), SynapseError) + self.assertEqual(error.value.code, 503) + + # The introspection endpoint request fails. + self.http_client.request = simple_async_mock(raises=Exception()) + error = self.get_failure(self.auth.get_user_by_req(request), SynapseError) + self.assertEqual(error.value.code, 503) + + # The introspection endpoint does not return a JSON object. + self.http_client.request = simple_async_mock( + return_value=FakeResponse.json( + code=200, payload=["this is an array", "not an object"] + ) + ) + error = self.get_failure(self.auth.get_user_by_req(request), SynapseError) + self.assertEqual(error.value.code, 503) + + # The introspection endpoint does not return valid JSON. + self.http_client.request = simple_async_mock( + return_value=FakeResponse(code=200, body=b"this is not valid JSON") + ) + error = self.get_failure(self.auth.get_user_by_req(request), SynapseError) + self.assertEqual(error.value.code, 503) + def make_device_keys(self, user_id: str, device_id: str) -> JsonDict: # We only generate a master key to simplify the test. master_signing_key = generate_signing_key(device_id) diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py index e5dae670a7..c8cc841d95 100644 --- a/tests/test_utils/__init__.py +++ b/tests/test_utils/__init__.py @@ -33,7 +33,7 @@ from twisted.web.http import RESPONSES from twisted.web.http_headers import Headers from twisted.web.iweb import IResponse -from synapse.types import JsonDict +from synapse.types import JsonSerializable if TYPE_CHECKING: from sys import UnraisableHookArgs @@ -145,7 +145,7 @@ class FakeResponse: # type: ignore[misc] protocol.connectionLost(Failure(ResponseDone())) @classmethod - def json(cls, *, code: int = 200, payload: JsonDict) -> "FakeResponse": + def json(cls, *, code: int = 200, payload: JsonSerializable) -> "FakeResponse": headers = Headers({"Content-Type": ["application/json"]}) body = json.dumps(payload).encode("utf-8") return cls(code=code, body=body, headers=headers) |