summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/api/test_ratelimiting.py4
-rw-r--r--tests/handlers/test_oauth_delegation.py10
-rw-r--r--tests/handlers/test_oidc.py44
-rw-r--r--tests/handlers/test_user_directory.py61
-rw-r--r--tests/push/test_http.py78
-rw-r--r--tests/rest/admin/test_scheduled_tasks.py192
-rw-r--r--tests/rest/client/test_login.py20
-rw-r--r--tests/test_utils/oidc.py19
8 files changed, 413 insertions, 15 deletions
diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py

index a59e168db1..1a1cbde74e 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py
@@ -220,9 +220,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): self.assertIn("test_id_1", limiter.actions) - self.get_success_or_raise( - limiter.can_do_action(None, key="test_id_2", _time_now_s=10) - ) + self.reactor.advance(60) self.assertNotIn("test_id_1", limiter.actions) diff --git a/tests/handlers/test_oauth_delegation.py b/tests/handlers/test_oauth_delegation.py
index 034a1594d9..934bfee0bc 100644 --- a/tests/handlers/test_oauth_delegation.py +++ b/tests/handlers/test_oauth_delegation.py
@@ -147,6 +147,16 @@ class MSC3861OAuthDelegation(HomeserverTestCase): return hs + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: + # Provision the user and the device we use in the tests. + store = homeserver.get_datastores().main + self.get_success(store.register_user(USER_ID)) + self.get_success( + store.store_device(USER_ID, DEVICE, initial_device_display_name=None) + ) + def _assertParams(self) -> None: """Assert that the request parameters are correct.""" params = parse_qs(self.http_client.request.call_args[1]["data"].decode("utf-8")) diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index a7cead83d0..e5f31d57ca 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py
@@ -1033,6 +1033,50 @@ class OidcHandlerTestCase(HomeserverTestCase): { "oidc_config": { **DEFAULT_CONFIG, + "redirect_uri": TEST_REDIRECT_URI, + } + } + ) + def test_code_exchange_ignores_access_token(self) -> None: + """ + Code exchange completes successfully and doesn't validate the `at_hash` + (access token hash) field of an ID token when the access token isn't + going to be used. + + The access token won't be used in this test because Synapse (currently) + only needs it to fetch a user's metadata if it isn't included in the ID + token itself. + + Because we have included "openid" in the requested scopes for this IdP + (see `SCOPES`), user metadata is be included in the ID token. Thus the + access token isn't needed, and it's unnecessary for Synapse to validate + the access token. + + This is a regression test for a situation where an upstream identity + provider was providing an invalid `at_hash` value, which Synapse errored + on, yet Synapse wasn't using the access token for anything. + """ + # Exchange the code against the fake IdP. + userinfo = { + "sub": "foo", + "username": "foo", + "phone": "1234567", + } + with self.fake_server.id_token_override( + { + "at_hash": "invalid-hash", + } + ): + request, _ = self.start_authorization(userinfo) + self.get_success(self.handler.handle_oidc_callback(request)) + + # If no error was rendered, then we have success. + self.render_error.assert_not_called() + + @override_config( + { + "oidc_config": { + **DEFAULT_CONFIG, "user_mapping_provider": { "module": __name__ + ".TestMappingProviderExtra" }, diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index a75095a79f..a9e9d7d7ea 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py
@@ -992,6 +992,67 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): [self.assertIn(user, local_users) for user in received_user_id_ordering[:3]] [self.assertIn(user, remote_users) for user in received_user_id_ordering[3:]] + @override_config( + { + "user_directory": { + "enabled": True, + "search_all_users": True, + "exclude_remote_users": True, + } + } + ) + def test_exclude_remote_users(self) -> None: + """Tests that only local users are returned when + user_directory.exclude_remote_users is True. + """ + + # Create a room and few users to test the directory with + searching_user = self.register_user("searcher", "password") + searching_user_tok = self.login("searcher", "password") + + room_id = self.helper.create_room_as( + searching_user, + room_version=RoomVersions.V1.identifier, + tok=searching_user_tok, + ) + + # Create a few local users and join them to the room + local_user_1 = self.register_user("user_xxxxx", "password") + local_user_2 = self.register_user("user_bbbbb", "password") + local_user_3 = self.register_user("user_zzzzz", "password") + + self._add_user_to_room(room_id, RoomVersions.V1, local_user_1) + self._add_user_to_room(room_id, RoomVersions.V1, local_user_2) + self._add_user_to_room(room_id, RoomVersions.V1, local_user_3) + + # Create a few "remote" users and join them to the room + remote_user_1 = "@user_aaaaa:remote_server" + remote_user_2 = "@user_yyyyy:remote_server" + remote_user_3 = "@user_ccccc:remote_server" + self._add_user_to_room(room_id, RoomVersions.V1, remote_user_1) + self._add_user_to_room(room_id, RoomVersions.V1, remote_user_2) + self._add_user_to_room(room_id, RoomVersions.V1, remote_user_3) + + local_users = [local_user_1, local_user_2, local_user_3] + remote_users = [remote_user_1, remote_user_2, remote_user_3] + + # The local searching user searches for the term "user", which other users have + # in their user id + results = self.get_success( + self.handler.search_users(searching_user, "user", 20) + )["results"] + received_user_ids = [result["user_id"] for result in results] + + for user in local_users: + self.assertIn( + user, received_user_ids, f"Local user {user} not found in results" + ) + + for user in remote_users: + self.assertNotIn( + user, received_user_ids, f"Remote user {user} should not be in results" + ) + def _add_user_to_room( self, room_id: str, diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index 5c235bbe53..b42fd284b6 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py
@@ -1167,3 +1167,81 @@ class HTTPPusherTests(HomeserverTestCase): self.assertEqual( self.push_attempts[0][2]["notification"]["counts"]["unread"], 1 ) + + def test_push_backoff(self) -> None: + """ + The HTTP pusher will backoff correctly if it fails to contact the pusher. + """ + + # Register the user who gets notified + user_id = self.register_user("user", "pass") + access_token = self.login("user", "pass") + + # Register the user who sends the message + other_user_id = self.register_user("otheruser", "pass") + other_access_token = self.login("otheruser", "pass") + + # Register the pusher + user_tuple = self.get_success( + self.hs.get_datastores().main.get_user_by_access_token(access_token) + ) + assert user_tuple is not None + device_id = user_tuple.device_id + + self.get_success( + self.hs.get_pusherpool().add_or_update_pusher( + user_id=user_id, + device_id=device_id, + kind="http", + app_id="m.http", + app_display_name="HTTP Push Notifications", + device_display_name="pushy push", + pushkey="a@example.com", + lang=None, + data={"url": "http://example.com/_matrix/push/v1/notify"}, + ) + ) + + # Create a room with the other user + room = self.helper.create_room_as(user_id, tok=access_token) + self.helper.join(room=room, user=other_user_id, tok=other_access_token) + + # The other user sends some messages + self.helper.send(room, body="Message 1", tok=other_access_token) + + # One push was attempted to be sent + self.assertEqual(len(self.push_attempts), 1) + self.assertEqual( + self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify" + ) + self.assertEqual( + self.push_attempts[0][2]["notification"]["content"]["body"], "Message 1" + ) + self.push_attempts[0][0].callback({}) + self.pump() + + # Send another message, this time it fails + self.helper.send(room, body="Message 2", tok=other_access_token) + self.assertEqual(len(self.push_attempts), 2) + self.push_attempts[1][0].errback(Exception("couldn't connect")) + self.pump() + + # Sending yet another message doesn't trigger a push immediately + self.helper.send(room, body="Message 3", tok=other_access_token) + self.pump() + self.assertEqual(len(self.push_attempts), 2) + + # .. but waiting for a bit will cause more pushes + self.reactor.advance(10) + self.assertEqual(len(self.push_attempts), 3) + self.assertEqual( + self.push_attempts[2][2]["notification"]["content"]["body"], "Message 2" + ) + self.push_attempts[2][0].callback({}) + self.pump() + + self.assertEqual(len(self.push_attempts), 4) + self.assertEqual( + self.push_attempts[3][2]["notification"]["content"]["body"], "Message 3" + ) + self.push_attempts[3][0].callback({}) diff --git a/tests/rest/admin/test_scheduled_tasks.py b/tests/rest/admin/test_scheduled_tasks.py new file mode 100644
index 0000000000..9654e9322b --- /dev/null +++ b/tests/rest/admin/test_scheduled_tasks.py
@@ -0,0 +1,192 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2025 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# <https://www.gnu.org/licenses/agpl-3.0.html>. +# +# +# +from typing import Mapping, Optional, Tuple + +from twisted.test.proto_helpers import MemoryReactor + +import synapse.rest.admin +from synapse.api.errors import Codes +from synapse.rest.client import login +from synapse.server import HomeServer +from synapse.types import JsonMapping, ScheduledTask, TaskStatus +from synapse.util import Clock + +from tests import unittest + + +class ScheduledTasksAdminApiTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + self._task_scheduler = hs.get_task_scheduler() + + # create and schedule a few tasks + async def _test_task( + task: ScheduledTask, + ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]: + return TaskStatus.ACTIVE, None, None + + async def _finished_test_task( + task: ScheduledTask, + ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]: + return TaskStatus.COMPLETE, None, None + + async def _failed_test_task( + task: ScheduledTask, + ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]: + return TaskStatus.FAILED, None, "Everything failed" + + self._task_scheduler.register_action(_test_task, "test_task") + self.get_success( + self._task_scheduler.schedule_task("test_task", resource_id="test") + ) + + self._task_scheduler.register_action(_finished_test_task, "finished_test_task") + self.get_success( + self._task_scheduler.schedule_task( + "finished_test_task", resource_id="finished_task" + ) + ) + + self._task_scheduler.register_action(_failed_test_task, "failed_test_task") + self.get_success( + self._task_scheduler.schedule_task( + "failed_test_task", resource_id="failed_task" + ) + ) + + def check_scheduled_tasks_response(self, scheduled_tasks: Mapping) -> list: + result = [] + for task in scheduled_tasks: + if task["resource_id"] == "test": + self.assertEqual(task["status"], TaskStatus.ACTIVE) + self.assertEqual(task["action"], "test_task") + result.append(task) + if task["resource_id"] == "finished_task": + self.assertEqual(task["status"], TaskStatus.COMPLETE) + self.assertEqual(task["action"], "finished_test_task") + result.append(task) + if task["resource_id"] == "failed_task": + self.assertEqual(task["status"], TaskStatus.FAILED) + self.assertEqual(task["action"], "failed_test_task") + result.append(task) + + return result + + def test_requester_is_not_admin(self) -> None: + """ + If the user is not a server admin, an error 403 is returned. + """ + + self.register_user("user", "pass", admin=False) + other_user_tok = self.login("user", "pass") + + channel = self.make_request( + "GET", + "/_synapse/admin/v1/scheduled_tasks", + content={}, + access_token=other_user_tok, + ) + + self.assertEqual(403, channel.code, msg=channel.json_body) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_scheduled_tasks(self) -> None: + """ + Test that endpoint returns scheduled tasks. + """ + + channel = self.make_request( + "GET", + "/_synapse/admin/v1/scheduled_tasks", + content={}, + access_token=self.admin_user_tok, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + scheduled_tasks = channel.json_body["scheduled_tasks"] + + # make sure we got back all the scheduled tasks + found_tasks = self.check_scheduled_tasks_response(scheduled_tasks) + self.assertEqual(len(found_tasks), 3) + + def test_filtering_scheduled_tasks(self) -> None: + """ + Test that filtering the scheduled tasks response via query params works as expected. + """ + # filter via job_status + channel = self.make_request( + "GET", + "/_synapse/admin/v1/scheduled_tasks?job_status=active", + content={}, + access_token=self.admin_user_tok, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + scheduled_tasks = channel.json_body["scheduled_tasks"] + found_tasks = self.check_scheduled_tasks_response(scheduled_tasks) + + # only the active task should have been returned + self.assertEqual(len(found_tasks), 1) + self.assertEqual(found_tasks[0]["status"], "active") + + # filter via action_name + channel = self.make_request( + "GET", + "/_synapse/admin/v1/scheduled_tasks?action_name=test_task", + content={}, + access_token=self.admin_user_tok, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + scheduled_tasks = channel.json_body["scheduled_tasks"] + + # only test_task should have been returned + found_tasks = self.check_scheduled_tasks_response(scheduled_tasks) + self.assertEqual(len(found_tasks), 1) + self.assertEqual(found_tasks[0]["action"], "test_task") + + # filter via max_timestamp + channel = self.make_request( + "GET", + "/_synapse/admin/v1/scheduled_tasks?max_timestamp=0", + content={}, + access_token=self.admin_user_tok, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + scheduled_tasks = channel.json_body["scheduled_tasks"] + found_tasks = self.check_scheduled_tasks_response(scheduled_tasks) + + # none should have been returned + self.assertEqual(len(found_tasks), 0) + + # filter via resource id + channel = self.make_request( + "GET", + "/_synapse/admin/v1/scheduled_tasks?resource_id=failed_task", + content={}, + access_token=self.admin_user_tok, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + scheduled_tasks = channel.json_body["scheduled_tasks"] + found_tasks = self.check_scheduled_tasks_response(scheduled_tasks) + + # only the task with the matching resource id should have been returned + self.assertEqual(len(found_tasks), 1) + self.assertEqual(found_tasks[0]["resource_id"], "failed_task") diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py
index d7148917d0..c5c6604667 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py
@@ -1262,18 +1262,18 @@ class JWTTestCase(unittest.HomeserverTestCase): channel = self.jwt_login({"sub": "kermit", "iss": "invalid"}) self.assertEqual(channel.code, 403, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") - self.assertEqual( + self.assertRegex( channel.json_body["error"], - 'JWT validation failed: invalid_claim: Invalid claim "iss"', + r"^JWT validation failed: invalid_claim: Invalid claim [\"']iss[\"']$", ) # Not providing an issuer. channel = self.jwt_login({"sub": "kermit"}) self.assertEqual(channel.code, 403, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") - self.assertEqual( + self.assertRegex( channel.json_body["error"], - 'JWT validation failed: missing_claim: Missing "iss" claim', + r"^JWT validation failed: missing_claim: Missing [\"']iss[\"'] claim$", ) def test_login_iss_no_config(self) -> None: @@ -1294,18 +1294,18 @@ class JWTTestCase(unittest.HomeserverTestCase): channel = self.jwt_login({"sub": "kermit", "aud": "invalid"}) self.assertEqual(channel.code, 403, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") - self.assertEqual( + self.assertRegex( channel.json_body["error"], - 'JWT validation failed: invalid_claim: Invalid claim "aud"', + r"^JWT validation failed: invalid_claim: Invalid claim [\"']aud[\"']$", ) # Not providing an audience. channel = self.jwt_login({"sub": "kermit"}) self.assertEqual(channel.code, 403, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") - self.assertEqual( + self.assertRegex( channel.json_body["error"], - 'JWT validation failed: missing_claim: Missing "aud" claim', + r"^JWT validation failed: missing_claim: Missing [\"']aud[\"'] claim$", ) def test_login_aud_no_config(self) -> None: @@ -1313,9 +1313,9 @@ class JWTTestCase(unittest.HomeserverTestCase): channel = self.jwt_login({"sub": "kermit", "aud": "invalid"}) self.assertEqual(channel.code, 403, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") - self.assertEqual( + self.assertRegex( channel.json_body["error"], - 'JWT validation failed: invalid_claim: Invalid claim "aud"', + r"^JWT validation failed: invalid_claim: Invalid claim [\"']aud[\"']$", ) def test_login_default_sub(self) -> None: diff --git a/tests/test_utils/oidc.py b/tests/test_utils/oidc.py
index 6c4be1c1f8..5bf5e5cb0c 100644 --- a/tests/test_utils/oidc.py +++ b/tests/test_utils/oidc.py
@@ -20,7 +20,9 @@ # +import base64 import json +from hashlib import sha256 from typing import Any, ContextManager, Dict, List, Optional, Tuple from unittest.mock import Mock, patch from urllib.parse import parse_qs @@ -154,10 +156,23 @@ class FakeOidcServer: json_payload = json.dumps(payload) return jws.serialize_compact(protected, json_payload, self._key).decode("utf-8") - def generate_id_token(self, grant: FakeAuthorizationGrant) -> str: + def generate_id_token( + self, grant: FakeAuthorizationGrant, access_token: str + ) -> str: + # Generate a hash of the access token for the optional + # `at_hash` field in an ID Token. + # + # 3.1.3.6. ID Token, https://openid.net/specs/openid-connect-core-1_0.html#CodeIDToken + at_hash = ( + base64.urlsafe_b64encode(sha256(access_token.encode("ascii")).digest()[:16]) + .rstrip(b"=") + .decode("ascii") + ) + now = int(self._clock.time()) id_token = { **grant.userinfo, + "at_hash": at_hash, "iss": self.issuer, "aud": grant.client_id, "iat": now, @@ -243,7 +258,7 @@ class FakeOidcServer: } if "openid" in grant.scope: - token["id_token"] = self.generate_id_token(grant) + token["id_token"] = self.generate_id_token(grant, access_token) return dict(token)