diff options
Diffstat (limited to 'tests/handlers/test_oauth_delegation.py')
-rw-r--r-- | tests/handlers/test_oauth_delegation.py | 129 |
1 files changed, 115 insertions, 14 deletions
diff --git a/tests/handlers/test_oauth_delegation.py b/tests/handlers/test_oauth_delegation.py index a72ecfdc97..a2b8e3d562 100644 --- a/tests/handlers/test_oauth_delegation.py +++ b/tests/handlers/test_oauth_delegation.py @@ -13,7 +13,8 @@ # limitations under the License. from http import HTTPStatus -from typing import Any, Dict, Union +from io import BytesIO +from typing import Any, Dict, Optional, Union from unittest.mock import ANY, AsyncMock, Mock from urllib.parse import parse_qs @@ -25,6 +26,8 @@ from signedjson.key import ( from signedjson.sign import sign_json from twisted.test.proto_helpers import MemoryReactor +from twisted.web.http_headers import Headers +from twisted.web.iweb import IResponse from synapse.api.errors import ( AuthError, @@ -33,23 +36,17 @@ from synapse.api.errors import ( OAuthInsufficientScopeError, SynapseError, ) +from synapse.http.site import SynapseRequest from synapse.rest import admin from synapse.rest.client import account, devices, keys, login, logout, register from synapse.server import HomeServer -from synapse.types import JsonDict +from synapse.types import JsonDict, UserID from synapse.util import Clock +from tests.server import FakeChannel from tests.test_utils import FakeResponse, get_awaitable_result -from tests.unittest import HomeserverTestCase, skip_unless -from tests.utils import mock_getRawHeaders - -try: - import authlib # noqa: F401 - - HAS_AUTHLIB = True -except ImportError: - HAS_AUTHLIB = False - +from tests.unittest import HomeserverTestCase, override_config, skip_unless +from tests.utils import HAS_AUTHLIB, checked_cast, mock_getRawHeaders # These are a few constants that are used as config parameters in the tests. SERVER_NAME = "test" @@ -75,6 +72,7 @@ MATRIX_DEVICE_SCOPE = MATRIX_DEVICE_SCOPE_PREFIX + DEVICE SUBJECT = "abc-def-ghi" USERNAME = "test-user" USER_ID = "@" + USERNAME + ":" + SERVER_NAME +OIDC_ADMIN_USERID = f"@__oidc_admin:{SERVER_NAME}" async def get_json(url: str) -> JsonDict: @@ -134,7 +132,10 @@ class MSC3861OAuthDelegation(HomeserverTestCase): hs = self.setup_test_homeserver(proxied_http_client=self.http_client) - self.auth = hs.get_auth() + # Import this here so that we've checked that authlib is available. + from synapse.api.auth.msc3861_delegated import MSC3861DelegatedAuth + + self.auth = checked_cast(MSC3861DelegatedAuth, hs.get_auth()) return hs @@ -675,7 +676,8 @@ class MSC3861OAuthDelegation(HomeserverTestCase): request.requestHeaders.getRawHeaders = mock_getRawHeaders() requester = self.get_success(self.auth.get_user_by_req(request)) self.assertEqual( - requester.user.to_string(), "@%s:%s" % ("__oidc_admin", SERVER_NAME) + requester.user.to_string(), + OIDC_ADMIN_USERID, ) self.assertEqual(requester.is_guest, False) self.assertEqual(requester.device_id, None) @@ -685,3 +687,102 @@ class MSC3861OAuthDelegation(HomeserverTestCase): # There should be no call to the introspection endpoint self.http_client.request.assert_not_called() + + @override_config({"mau_stats_only": True}) + def test_request_tracking(self) -> None: + """Using an access token should update the client_ips and MAU tables.""" + # To start, there are no MAU users. + store = self.hs.get_datastores().main + mau = self.get_success(store.get_monthly_active_count()) + self.assertEqual(mau, 0) + + known_token = "token-token-GOOD-:)" + + async def mock_http_client_request( + method: str, + uri: str, + data: Optional[bytes] = None, + headers: Optional[Headers] = None, + ) -> IResponse: + """Mocked auth provider response.""" + assert method == "POST" + token = parse_qs(data)[b"token"][0].decode("utf-8") + if token == known_token: + return FakeResponse.json( + code=200, + payload={ + "active": True, + "scope": MATRIX_USER_SCOPE, + "sub": SUBJECT, + "username": USERNAME, + }, + ) + + return FakeResponse.json(code=200, payload={"active": False}) + + self.http_client.request = mock_http_client_request + + EXAMPLE_IPV4_ADDR = "123.123.123.123" + EXAMPLE_USER_AGENT = "httprettygood" + + # First test a known access token + channel = FakeChannel(self.site, self.reactor) + # type-ignore: FakeChannel is a mock of an HTTPChannel, not a proper HTTPChannel + req = SynapseRequest(channel, self.site) # type: ignore[arg-type] + req.client.host = EXAMPLE_IPV4_ADDR + req.requestHeaders.addRawHeader("Authorization", f"Bearer {known_token}") + req.requestHeaders.addRawHeader("User-Agent", EXAMPLE_USER_AGENT) + req.content = BytesIO(b"") + req.requestReceived( + b"GET", + b"/_matrix/client/v3/account/whoami", + b"1.1", + ) + channel.await_result() + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) + self.assertEqual(channel.json_body["user_id"], USER_ID, channel.json_body) + + # Expect to see one MAU entry, from the first request + mau = self.get_success(store.get_monthly_active_count()) + self.assertEqual(mau, 1) + + conn_infos = self.get_success( + store.get_user_ip_and_agents(UserID.from_string(USER_ID)) + ) + self.assertEqual(len(conn_infos), 1, conn_infos) + conn_info = conn_infos[0] + self.assertEqual(conn_info["access_token"], known_token) + self.assertEqual(conn_info["ip"], EXAMPLE_IPV4_ADDR) + self.assertEqual(conn_info["user_agent"], EXAMPLE_USER_AGENT) + + # Now test MAS making a request using the special __oidc_admin token + MAS_IPV4_ADDR = "127.0.0.1" + MAS_USER_AGENT = "masmasmas" + + channel = FakeChannel(self.site, self.reactor) + req = SynapseRequest(channel, self.site) # type: ignore[arg-type] + req.client.host = MAS_IPV4_ADDR + req.requestHeaders.addRawHeader( + "Authorization", f"Bearer {self.auth._admin_token}" + ) + req.requestHeaders.addRawHeader("User-Agent", MAS_USER_AGENT) + req.content = BytesIO(b"") + req.requestReceived( + b"GET", + b"/_matrix/client/v3/account/whoami", + b"1.1", + ) + channel.await_result() + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) + self.assertEqual( + channel.json_body["user_id"], OIDC_ADMIN_USERID, channel.json_body + ) + + # Still expect to see one MAU entry, from the first request + mau = self.get_success(store.get_monthly_active_count()) + self.assertEqual(mau, 1) + + conn_infos = self.get_success( + store.get_user_ip_and_agents(UserID.from_string(OIDC_ADMIN_USERID)) + ) + self.assertEqual(conn_infos, []) |