summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/api/auth/msc3861_delegated.py42
-rw-r--r--tests/handlers/test_oauth_delegation.py35
-rw-r--r--tests/test_utils/__init__.py4
3 files changed, 74 insertions, 7 deletions
diff --git a/synapse/api/auth/msc3861_delegated.py b/synapse/api/auth/msc3861_delegated.py
index b84dce2563..82c66691da 100644
--- a/synapse/api/auth/msc3861_delegated.py
+++ b/synapse/api/auth/msc3861_delegated.py
@@ -27,9 +27,11 @@ from twisted.web.http_headers import Headers
 from synapse.api.auth.base import BaseAuth
 from synapse.api.errors import (
     AuthError,
+    HttpResponseException,
     InvalidClientTokenError,
     OAuthInsufficientScopeError,
     StoreError,
+    SynapseError,
 )
 from synapse.http.site import SynapseRequest
 from synapse.logging.context import make_deferred_yieldable
@@ -117,6 +119,21 @@ class MSC3861DelegatedAuth(BaseAuth):
         return metadata
 
     async def _introspect_token(self, token: str) -> IntrospectionToken:
+        """
+        Send a token to the introspection endpoint and returns the introspection response
+
+        Parameters:
+            token: The token to introspect
+
+        Raises:
+            HttpResponseException: If the introspection endpoint returns a non-2xx response
+            ValueError: If the introspection endpoint returns an invalid JSON response
+            JSONDecodeError: If the introspection endpoint returns a non-JSON response
+            Exception: If the HTTP request fails
+
+        Returns:
+            The introspection response
+        """
         metadata = await self._issuer_metadata.get()
         introspection_endpoint = metadata.get("introspection_endpoint")
         raw_headers: Dict[str, str] = {
@@ -136,7 +153,7 @@ class MSC3861DelegatedAuth(BaseAuth):
 
         # Do the actual request
         # We're not using the SimpleHttpClient util methods as we don't want to
-        # check the HTTP status code and we do the body encoding ourself.
+        # check the HTTP status code, and we do the body encoding ourselves.
         response = await self._http_client.request(
             method="POST",
             uri=uri,
@@ -145,10 +162,21 @@ class MSC3861DelegatedAuth(BaseAuth):
         )
 
         resp_body = await make_deferred_yieldable(readBody(response))
-        # TODO: Let's not worry about 5xx errors & co. for now and just try
-        # decoding that as JSON. We should also do some validation of the
-        # response
+
+        if response.code < 200 or response.code >= 300:
+            raise HttpResponseException(
+                response.code,
+                response.phrase.decode("ascii", errors="replace"),
+                resp_body,
+            )
+
         resp = json_decoder.decode(resp_body.decode("utf-8"))
+
+        if not isinstance(resp, dict):
+            raise ValueError(
+                "The introspection endpoint returned an invalid JSON response."
+            )
+
         return IntrospectionToken(**resp)
 
     async def is_server_admin(self, requester: Requester) -> bool:
@@ -196,7 +224,11 @@ class MSC3861DelegatedAuth(BaseAuth):
                 scope=["urn:synapse:admin:*"],
             )
 
-        introspection_result = await self._introspect_token(token)
+        try:
+            introspection_result = await self._introspect_token(token)
+        except Exception:
+            logger.exception("Failed to introspect token")
+            raise SynapseError(503, "Unable to introspect the access token")
 
         logger.info(f"Introspection result: {introspection_result!r}")
 
diff --git a/tests/handlers/test_oauth_delegation.py b/tests/handlers/test_oauth_delegation.py
index b79c43a424..16ce2c069d 100644
--- a/tests/handlers/test_oauth_delegation.py
+++ b/tests/handlers/test_oauth_delegation.py
@@ -30,6 +30,7 @@ from synapse.api.errors import (
     Codes,
     InvalidClientTokenError,
     OAuthInsufficientScopeError,
+    SynapseError,
 )
 from synapse.rest import admin
 from synapse.rest.client import account, devices, keys, login, logout, register
@@ -405,6 +406,40 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
         )
         self.assertEqual(requester.device_id, DEVICE)
 
+    def test_unavailable_introspection_endpoint(self) -> None:
+        """The handler should return an internal server error."""
+        request = Mock(args={})
+        request.args[b"access_token"] = [b"mockAccessToken"]
+        request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+
+        # The introspection endpoint is returning an error.
+        self.http_client.request = simple_async_mock(
+            return_value=FakeResponse(code=500, body=b"Internal Server Error")
+        )
+        error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
+        self.assertEqual(error.value.code, 503)
+
+        # The introspection endpoint request fails.
+        self.http_client.request = simple_async_mock(raises=Exception())
+        error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
+        self.assertEqual(error.value.code, 503)
+
+        # The introspection endpoint does not return a JSON object.
+        self.http_client.request = simple_async_mock(
+            return_value=FakeResponse.json(
+                code=200, payload=["this is an array", "not an object"]
+            )
+        )
+        error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
+        self.assertEqual(error.value.code, 503)
+
+        # The introspection endpoint does not return valid JSON.
+        self.http_client.request = simple_async_mock(
+            return_value=FakeResponse(code=200, body=b"this is not valid JSON")
+        )
+        error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
+        self.assertEqual(error.value.code, 503)
+
     def make_device_keys(self, user_id: str, device_id: str) -> JsonDict:
         # We only generate a master key to simplify the test.
         master_signing_key = generate_signing_key(device_id)
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index e5dae670a7..c8cc841d95 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -33,7 +33,7 @@ from twisted.web.http import RESPONSES
 from twisted.web.http_headers import Headers
 from twisted.web.iweb import IResponse
 
-from synapse.types import JsonDict
+from synapse.types import JsonSerializable
 
 if TYPE_CHECKING:
     from sys import UnraisableHookArgs
@@ -145,7 +145,7 @@ class FakeResponse:  # type: ignore[misc]
         protocol.connectionLost(Failure(ResponseDone()))
 
     @classmethod
-    def json(cls, *, code: int = 200, payload: JsonDict) -> "FakeResponse":
+    def json(cls, *, code: int = 200, payload: JsonSerializable) -> "FakeResponse":
         headers = Headers({"Content-Type": ["application/json"]})
         body = json.dumps(payload).encode("utf-8")
         return cls(code=code, body=body, headers=headers)