diff options
Diffstat (limited to 'tests')
24 files changed, 1337 insertions, 365 deletions
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index bc75ddd3e9..dfcfaf79b6 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -19,6 +19,7 @@ import pymacaroons from twisted.test.proto_helpers import MemoryReactor from synapse.api.auth import Auth +from synapse.api.auth_blocking import AuthBlocking from synapse.api.constants import UserTypes from synapse.api.errors import ( AuthError, @@ -49,7 +50,7 @@ class AuthTestCase(unittest.HomeserverTestCase): # AuthBlocking reads from the hs' config on initialization. We need to # modify its config instead of the hs' - self.auth_blocking = self.auth._auth_blocking + self.auth_blocking = AuthBlocking(hs) self.test_user = "@foo:bar" self.test_token = b"_test_token_" @@ -312,9 +313,7 @@ class AuthTestCase(unittest.HomeserverTestCase): self.assertEqual(self.store.insert_client_ip.call_count, 2) def test_get_user_from_macaroon(self): - self.store.get_user_by_access_token = simple_async_mock( - TokenLookupResult(user_id="@baldrick:matrix.org", device_id="device") - ) + self.store.get_user_by_access_token = simple_async_mock(None) user_id = "@baldrick:matrix.org" macaroon = pymacaroons.Macaroon( @@ -322,17 +321,14 @@ class AuthTestCase(unittest.HomeserverTestCase): identifier="key", key=self.hs.config.key.macaroon_secret_key, ) + # "Legacy" macaroons should not work for regular users not in the database macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("type = access") macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) - user_info = self.get_success( - self.auth.get_user_by_access_token(macaroon.serialize()) + serialized = macaroon.serialize() + self.get_failure( + self.auth.get_user_by_access_token(serialized), InvalidClientTokenError ) - self.assertEqual(user_id, user_info.user_id) - - # TODO: device_id should come from the macaroon, but currently comes - # from the db. - self.assertEqual(user_info.device_id, "device") def test_get_guest_user_from_macaroon(self): self.store.get_user_by_id = simple_async_mock({"is_guest": True}) @@ -362,20 +358,22 @@ class AuthTestCase(unittest.HomeserverTestCase): small_number_of_users = 1 # Ensure no error thrown - self.get_success(self.auth.check_auth_blocking()) + self.get_success(self.auth_blocking.check_auth_blocking()) self.auth_blocking._limit_usage_by_mau = True self.store.get_monthly_active_count = simple_async_mock(lots_of_users) - e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError) + e = self.get_failure( + self.auth_blocking.check_auth_blocking(), ResourceLimitError + ) self.assertEqual(e.value.admin_contact, self.hs.config.server.admin_contact) self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEqual(e.value.code, 403) # Ensure does not throw an error self.store.get_monthly_active_count = simple_async_mock(small_number_of_users) - self.get_success(self.auth.check_auth_blocking()) + self.get_success(self.auth_blocking.check_auth_blocking()) def test_blocking_mau__depending_on_user_type(self): self.auth_blocking._max_mau_value = 50 @@ -383,15 +381,18 @@ class AuthTestCase(unittest.HomeserverTestCase): self.store.get_monthly_active_count = simple_async_mock(100) # Support users allowed - self.get_success(self.auth.check_auth_blocking(user_type=UserTypes.SUPPORT)) + self.get_success( + self.auth_blocking.check_auth_blocking(user_type=UserTypes.SUPPORT) + ) self.store.get_monthly_active_count = simple_async_mock(100) # Bots not allowed self.get_failure( - self.auth.check_auth_blocking(user_type=UserTypes.BOT), ResourceLimitError + self.auth_blocking.check_auth_blocking(user_type=UserTypes.BOT), + ResourceLimitError, ) self.store.get_monthly_active_count = simple_async_mock(100) # Real users not allowed - self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError) + self.get_failure(self.auth_blocking.check_auth_blocking(), ResourceLimitError) def test_blocking_mau__appservice_requester_allowed_when_not_tracking_ips(self): self.auth_blocking._max_mau_value = 50 @@ -419,7 +420,7 @@ class AuthTestCase(unittest.HomeserverTestCase): app_service=appservice, authenticated_entity="@appservice:server", ) - self.get_success(self.auth.check_auth_blocking(requester=requester)) + self.get_success(self.auth_blocking.check_auth_blocking(requester=requester)) def test_blocking_mau__appservice_requester_disallowed_when_tracking_ips(self): self.auth_blocking._max_mau_value = 50 @@ -448,7 +449,8 @@ class AuthTestCase(unittest.HomeserverTestCase): authenticated_entity="@appservice:server", ) self.get_failure( - self.auth.check_auth_blocking(requester=requester), ResourceLimitError + self.auth_blocking.check_auth_blocking(requester=requester), + ResourceLimitError, ) def test_reserved_threepid(self): @@ -459,18 +461,21 @@ class AuthTestCase(unittest.HomeserverTestCase): unknown_threepid = {"medium": "email", "address": "unreserved@server.com"} self.auth_blocking._mau_limits_reserved_threepids = [threepid] - self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError) + self.get_failure(self.auth_blocking.check_auth_blocking(), ResourceLimitError) self.get_failure( - self.auth.check_auth_blocking(threepid=unknown_threepid), ResourceLimitError + self.auth_blocking.check_auth_blocking(threepid=unknown_threepid), + ResourceLimitError, ) - self.get_success(self.auth.check_auth_blocking(threepid=threepid)) + self.get_success(self.auth_blocking.check_auth_blocking(threepid=threepid)) def test_hs_disabled(self): self.auth_blocking._hs_disabled = True self.auth_blocking._hs_disabled_message = "Reason for being disabled" - e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError) + e = self.get_failure( + self.auth_blocking.check_auth_blocking(), ResourceLimitError + ) self.assertEqual(e.value.admin_contact, self.hs.config.server.admin_contact) self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEqual(e.value.code, 403) @@ -485,7 +490,9 @@ class AuthTestCase(unittest.HomeserverTestCase): self.auth_blocking._hs_disabled = True self.auth_blocking._hs_disabled_message = "Reason for being disabled" - e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError) + e = self.get_failure( + self.auth_blocking.check_auth_blocking(), ResourceLimitError + ) self.assertEqual(e.value.admin_contact, self.hs.config.server.admin_contact) self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEqual(e.value.code, 403) @@ -495,4 +502,4 @@ class AuthTestCase(unittest.HomeserverTestCase): user = "@user:server" self.auth_blocking._server_notices_mxid = user self.auth_blocking._hs_disabled_message = "Reason for being disabled" - self.get_success(self.auth.check_auth_blocking(user)) + self.get_success(self.auth_blocking.check_auth_blocking(user)) diff --git a/tests/federation/transport/server/test__base.py b/tests/federation/transport/server/test__base.py index e63885c1c9..d33e86db4c 100644 --- a/tests/federation/transport/server/test__base.py +++ b/tests/federation/transport/server/test__base.py @@ -24,7 +24,7 @@ from synapse.types import JsonDict from synapse.util.ratelimitutils import FederationRateLimiter from tests import unittest -from tests.http.server._base import EndpointCancellationTestHelperMixin +from tests.http.server._base import test_disconnect class CancellableFederationServlet(BaseFederationServlet): @@ -54,9 +54,7 @@ class CancellableFederationServlet(BaseFederationServlet): return HTTPStatus.OK, {"result": True} -class BaseFederationServletCancellationTests( - unittest.FederatingHomeserverTestCase, EndpointCancellationTestHelperMixin -): +class BaseFederationServletCancellationTests(unittest.FederatingHomeserverTestCase): """Tests for `BaseFederationServlet` cancellation.""" skip = "`BaseFederationServlet` does not support cancellation yet." @@ -86,7 +84,7 @@ class BaseFederationServletCancellationTests( # request won't be processed. self.pump() - self._test_disconnect( + test_disconnect( self.reactor, channel, expect_cancellation=True, @@ -106,7 +104,7 @@ class BaseFederationServletCancellationTests( # request won't be processed. self.pump() - self._test_disconnect( + test_disconnect( self.reactor, channel, expect_cancellation=False, diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index 67a7829769..7106799d44 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -38,7 +38,7 @@ class AuthTestCase(unittest.HomeserverTestCase): # MAU tests # AuthBlocking reads from the hs' config on initialization. We need to # modify its config instead of the hs' - self.auth_blocking = hs.get_auth()._auth_blocking + self.auth_blocking = hs.get_auth_blocking() self.auth_blocking._max_mau_value = 50 self.small_number_of_users = 1 diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index 01ea7d2a42..b8b465d35b 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -154,7 +154,7 @@ class DeviceTestCase(unittest.HomeserverTestCase): self._record_users() # delete the device - self.get_success(self.handler.delete_device(user1, "abc")) + self.get_success(self.handler.delete_devices(user1, ["abc"])) # check the device was deleted self.get_failure(self.handler.get_device(user1, "abc"), NotFoundError) @@ -179,7 +179,7 @@ class DeviceTestCase(unittest.HomeserverTestCase): ) # delete the device - self.get_success(self.handler.delete_device(user1, "abc")) + self.get_success(self.handler.delete_devices(user1, ["abc"])) # check that the device_inbox was deleted res = self.get_success( diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index 1231aed944..e6cd3af7b7 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -25,7 +25,7 @@ from synapse.handlers.sso import MappingException from synapse.server import HomeServer from synapse.types import JsonDict, UserID from synapse.util import Clock -from synapse.util.macaroons import get_value_from_macaroon +from synapse.util.macaroons import OidcSessionData, get_value_from_macaroon from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock from tests.unittest import HomeserverTestCase, override_config @@ -1227,7 +1227,7 @@ class OidcHandlerTestCase(HomeserverTestCase): ) -> str: from synapse.handlers.oidc import OidcSessionData - return self.handler._token_generator.generate_oidc_session_token( + return self.handler._macaroon_generator.generate_oidc_session_token( state=state, session_data=OidcSessionData( idp_id="oidc", @@ -1251,7 +1251,6 @@ async def _make_callback_with_userinfo( userinfo: the OIDC userinfo dict client_redirect_url: the URL to redirect to on success. """ - from synapse.handlers.oidc import OidcSessionData handler = hs.get_oidc_handler() provider = handler._providers["oidc"] @@ -1260,7 +1259,7 @@ async def _make_callback_with_userinfo( provider._fetch_userinfo = simple_async_mock(return_value=userinfo) # type: ignore[assignment] state = "state" - session = handler._token_generator.generate_oidc_session_token( + session = handler._macaroon_generator.generate_oidc_session_token( state=state, session_data=OidcSessionData( idp_id="oidc", diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index b6ba19c739..23f35d5bf5 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -699,7 +699,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): """ if localpart is None: raise SynapseError(400, "Request must include user id") - await self.hs.get_auth().check_auth_blocking() + await self.hs.get_auth_blocking().check_auth_blocking() need_register = True try: diff --git a/tests/handlers/test_room_summary.py b/tests/handlers/test_room_summary.py index 0546655690..aa650756e4 100644 --- a/tests/handlers/test_room_summary.py +++ b/tests/handlers/test_room_summary.py @@ -178,7 +178,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): result_room_ids.append(result_room["room_id"]) result_children_ids.append( [ - (cs["room_id"], cs["state_key"]) + (result_room["room_id"], cs["state_key"]) for cs in result_room["children_state"] ] ) diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py index ecd78fa369..05f9ec3c51 100644 --- a/tests/handlers/test_stats.py +++ b/tests/handlers/test_stats.py @@ -46,16 +46,9 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.get_success( self.store.db_pool.simple_insert( "background_updates", - {"update_name": "populate_stats_prepare", "progress_json": "{}"}, - ) - ) - self.get_success( - self.store.db_pool.simple_insert( - "background_updates", { "update_name": "populate_stats_process_rooms", "progress_json": "{}", - "depends_on": "populate_stats_prepare", }, ) ) @@ -69,16 +62,6 @@ class StatsRoomTests(unittest.HomeserverTestCase): }, ) ) - self.get_success( - self.store.db_pool.simple_insert( - "background_updates", - { - "update_name": "populate_stats_cleanup", - "progress_json": "{}", - "depends_on": "populate_stats_process_users", - }, - ) - ) async def get_all_room_state(self): return await self.store.db_pool.simple_select_list( @@ -533,7 +516,6 @@ class StatsRoomTests(unittest.HomeserverTestCase): { "update_name": "populate_stats_process_rooms", "progress_json": "{}", - "depends_on": "populate_stats_prepare", }, ) ) @@ -547,16 +529,6 @@ class StatsRoomTests(unittest.HomeserverTestCase): }, ) ) - self.get_success( - self.store.db_pool.simple_insert( - "background_updates", - { - "update_name": "populate_stats_cleanup", - "progress_json": "{}", - "depends_on": "populate_stats_process_users", - }, - ) - ) self.wait_for_background_updates() diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index db3302a4c7..ecc7cc6461 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -45,7 +45,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): # AuthBlocking reads from the hs' config on initialization. We need to # modify its config instead of the hs' - self.auth_blocking = self.hs.get_auth()._auth_blocking + self.auth_blocking = self.hs.get_auth_blocking() def test_wait_for_sync_for_user_auth_blocking(self): user_id1 = "@user1:test" diff --git a/tests/http/server/_base.py b/tests/http/server/_base.py index b9f1a381aa..994d8880b0 100644 --- a/tests/http/server/_base.py +++ b/tests/http/server/_base.py @@ -12,89 +12,543 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect +import itertools +import logging from http import HTTPStatus -from typing import Any, Callable, Optional, Union +from typing import ( + Any, + Callable, + ContextManager, + Dict, + List, + Optional, + Set, + Tuple, + TypeVar, + Union, +) from unittest import mock +from unittest.mock import Mock +from twisted.internet.defer import Deferred from twisted.internet.error import ConnectionDone +from twisted.python.failure import Failure +from twisted.test.proto_helpers import MemoryReactorClock +from twisted.web.server import Site from synapse.http.server import ( HTTP_STATUS_REQUEST_CANCELLED, respond_with_html_bytes, respond_with_json, ) +from synapse.http.site import SynapseRequest +from synapse.logging.context import LoggingContext, make_deferred_yieldable from synapse.types import JsonDict -from tests import unittest -from tests.server import FakeChannel, ThreadedMemoryReactorClock +from tests.server import FakeChannel, make_request +from tests.unittest import logcontext_clean +logger = logging.getLogger(__name__) -class EndpointCancellationTestHelperMixin(unittest.TestCase): - """Provides helper methods for testing cancellation of endpoints.""" - def _test_disconnect( - self, - reactor: ThreadedMemoryReactorClock, - channel: FakeChannel, - expect_cancellation: bool, - expected_body: Union[bytes, JsonDict], - expected_code: Optional[int] = None, - ) -> None: - """Disconnects an in-flight request and checks the response. +T = TypeVar("T") - Args: - reactor: The twisted reactor running the request handler. - channel: The `FakeChannel` for the request. - expect_cancellation: `True` if request processing is expected to be - cancelled, `False` if the request should run to completion. - expected_body: The expected response for the request. - expected_code: The expected status code for the request. Defaults to `200` - or `499` depending on `expect_cancellation`. - """ - # Determine the expected status code. - if expected_code is None: - if expect_cancellation: - expected_code = HTTP_STATUS_REQUEST_CANCELLED - else: - expected_code = HTTPStatus.OK - - request = channel.request - self.assertFalse( - channel.is_finished(), + +def test_disconnect( + reactor: MemoryReactorClock, + channel: FakeChannel, + expect_cancellation: bool, + expected_body: Union[bytes, JsonDict], + expected_code: Optional[int] = None, +) -> None: + """Disconnects an in-flight request and checks the response. + + Args: + reactor: The twisted reactor running the request handler. + channel: The `FakeChannel` for the request. + expect_cancellation: `True` if request processing is expected to be cancelled, + `False` if the request should run to completion. + expected_body: The expected response for the request. + expected_code: The expected status code for the request. Defaults to `200` or + `499` depending on `expect_cancellation`. + """ + # Determine the expected status code. + if expected_code is None: + if expect_cancellation: + expected_code = HTTP_STATUS_REQUEST_CANCELLED + else: + expected_code = HTTPStatus.OK + + request = channel.request + if channel.is_finished(): + raise AssertionError( "Request finished before we could disconnect - " - "was `await_result=False` passed to `make_request`?", + "ensure `await_result=False` is passed to `make_request`.", ) - # We're about to disconnect the request. This also disconnects the channel, so - # we have to rely on mocks to extract the response. - respond_method: Callable[..., Any] - if isinstance(expected_body, bytes): - respond_method = respond_with_html_bytes + # We're about to disconnect the request. This also disconnects the channel, so we + # have to rely on mocks to extract the response. + respond_method: Callable[..., Any] + if isinstance(expected_body, bytes): + respond_method = respond_with_html_bytes + else: + respond_method = respond_with_json + + with mock.patch( + f"synapse.http.server.{respond_method.__name__}", wraps=respond_method + ) as respond_mock: + # Disconnect the request. + request.connectionLost(reason=ConnectionDone()) + + if expect_cancellation: + # An immediate cancellation is expected. + respond_mock.assert_called_once() else: - respond_method = respond_with_json + respond_mock.assert_not_called() - with mock.patch( - f"synapse.http.server.{respond_method.__name__}", wraps=respond_method - ) as respond_mock: - # Disconnect the request. - request.connectionLost(reason=ConnectionDone()) + # The handler is expected to run to completion. + reactor.advance(1.0) + respond_mock.assert_called_once() - if expect_cancellation: - # An immediate cancellation is expected. - respond_mock.assert_called_once() - args, _kwargs = respond_mock.call_args - code, body = args[1], args[2] - self.assertEqual(code, expected_code) - self.assertEqual(request.code, expected_code) - self.assertEqual(body, expected_body) - else: - respond_mock.assert_not_called() - - # The handler is expected to run to completion. - reactor.pump([1.0]) + args, _kwargs = respond_mock.call_args + code, body = args[1], args[2] + + if code != expected_code: + raise AssertionError( + f"{code} != {expected_code} : " + "Request did not finish with the expected status code." + ) + + if request.code != expected_code: + raise AssertionError( + f"{request.code} != {expected_code} : " + "Request did not finish with the expected status code." + ) + + if body != expected_body: + raise AssertionError( + f"{body!r} != {expected_body!r} : " + "Request did not finish with the expected status code." + ) + + +@logcontext_clean +def make_request_with_cancellation_test( + test_name: str, + reactor: MemoryReactorClock, + site: Site, + method: str, + path: str, + content: Union[bytes, str, JsonDict] = b"", +) -> FakeChannel: + """Performs a request repeatedly, disconnecting at successive `await`s, until + one completes. + + Fails if: + * A logging context is lost during cancellation. + * A logging context get restarted after it is marked as finished, eg. if + a request's logging context is used by some processing started by the + request, but the request neglects to cancel that processing or wait for it + to complete. + + Note that "Re-starting finished log context" errors get raised within the + request handling code and may or may not get caught. These errors will + likely manifest as a different logging context error at a later point. When + debugging logging context failures, setting a breakpoint in + `logcontext_error` can prove useful. + * A request gets stuck, possibly due to a previous cancellation. + * The request does not return a 499 when the client disconnects. + This implies that a `CancelledError` was swallowed somewhere. + + It is up to the caller to verify that the request returns the correct data when + it finally runs to completion. + + Note that this function can only cover a single code path and does not guarantee + that an endpoint is compatible with cancellation on every code path. + To allow inspection of the code path that is being tested, this function will + log the stack trace at every `await` that gets cancelled. To view these log + lines, `trial` can be run with the `SYNAPSE_TEST_LOG_LEVEL=INFO` environment + variable, which will include the log lines in `_trial_temp/test.log`. + Alternatively, `_log_for_request` can be modified to write to `sys.stdout`. + + Args: + test_name: The name of the test, which will be logged. + reactor: The twisted reactor running the request handler. + site: The twisted `Site` to use to render the request. + method: The HTTP request method ("verb"). + path: The HTTP path, suitably URL encoded (e.g. escaped UTF-8 & spaces and + such). + content: The body of the request. + + Returns: + The `FakeChannel` object which stores the result of the final request that + runs to completion. + """ + # To process a request, a coroutine run is created for the async method handling + # the request. That method may then start other coroutine runs, wrapped in + # `Deferred`s. + # + # We would like to trigger a cancellation at the first `await`, re-run the + # request and cancel at the second `await`, and so on. By patching + # `Deferred.__next__`, we can intercept `await`s, track which ones we have or + # have not seen, and force them to block when they wouldn't have. + + # The set of previously seen `await`s. + # Each element is a stringified stack trace. + seen_awaits: Set[Tuple[str, ...]] = set() + + _log_for_request( + 0, f"Running make_request_with_cancellation_test for {test_name}..." + ) + + for request_number in itertools.count(1): + deferred_patch = Deferred__next__Patch(seen_awaits, request_number) + + try: + with mock.patch( + "synapse.http.server.respond_with_json", wraps=respond_with_json + ) as respond_mock: + with deferred_patch.patch(): + # Start the request. + channel = make_request( + reactor, site, method, path, content, await_result=False + ) + request = channel.request + + # Run the request until we see a new `await` which we have not + # yet cancelled at, or it completes. + while not respond_mock.called and not deferred_patch.new_await_seen: + previous_awaits_seen = deferred_patch.awaits_seen + + reactor.advance(0.0) + + if deferred_patch.awaits_seen == previous_awaits_seen: + # We didn't see any progress. Try advancing the clock. + reactor.advance(1.0) + + if deferred_patch.awaits_seen == previous_awaits_seen: + # We still didn't see any progress. The request might be + # stuck. + raise AssertionError( + "Request appears to be stuck, possibly due to a " + "previous cancelled request" + ) + + if respond_mock.called: + # The request ran to completion and we are done with testing it. + + # `respond_with_json` writes the response asynchronously, so we + # might have to give the reactor a kick before the channel gets + # the response. + deferred_patch.unblock_awaits() + channel.await_result() + + return channel + + # Disconnect the client and wait for the response. + request.connectionLost(reason=ConnectionDone()) + + _log_for_request(request_number, "--- disconnected ---") + + # Advance the reactor just enough to get a response. + # We don't want to advance the reactor too far, because we can only + # detect re-starts of finished logging contexts after we set the + # finished flag below. + for _ in range(2): + # We may need to pump the reactor to allow `delay_cancellation`s to + # finish. + if not respond_mock.called: + reactor.advance(0.0) + + # Try advancing the clock if that didn't work. + if not respond_mock.called: + reactor.advance(1.0) + + # `delay_cancellation`s may be waiting for processing that we've + # forced to block. Try unblocking them, followed by another round of + # pumping the reactor. + if not respond_mock.called: + deferred_patch.unblock_awaits() + + # Mark the request's logging context as finished. If it gets + # activated again, an `AssertionError` will be raised and bubble up + # through request handling code. This `AssertionError` may or may not be + # caught. Eventually some other code will deactivate the logging + # context which will raise a different `AssertionError` because + # resource usage won't have been correctly tracked. + if isinstance(request, SynapseRequest) and request.logcontext: + request.logcontext.finished = True + + # Check that the request finished with a 499, + # ie. the `CancelledError` wasn't swallowed. respond_mock.assert_called_once() - args, _kwargs = respond_mock.call_args - code, body = args[1], args[2] - self.assertEqual(code, expected_code) - self.assertEqual(request.code, expected_code) - self.assertEqual(body, expected_body) + + if request.code != HTTP_STATUS_REQUEST_CANCELLED: + raise AssertionError( + f"{request.code} != {HTTP_STATUS_REQUEST_CANCELLED} : " + "Cancelled request did not finish with the correct status code." + ) + finally: + # Unblock any processing that might be shared between requests, if we + # haven't already done so. + deferred_patch.unblock_awaits() + + assert False, "unreachable" # noqa: B011 + + +class Deferred__next__Patch: + """A `Deferred.__next__` patch that will intercept `await`s and force them + to block once it sees a new `await`. + + When done with the patch, `unblock_awaits()` must be called to clean up after any + `await`s that were forced to block, otherwise processing shared between multiple + requests, such as database queries started by `@cached`, will become permanently + stuck. + + Usage: + seen_awaits = set() + deferred_patch = Deferred__next__Patch(seen_awaits, 1) + try: + with deferred_patch.patch(): + # do things + ... + finally: + deferred_patch.unblock_awaits() + """ + + def __init__(self, seen_awaits: Set[Tuple[str, ...]], request_number: int): + """ + Args: + seen_awaits: The set of stack traces of `await`s that have been previously + seen. When the `Deferred.__next__` patch sees a new `await`, it will add + it to the set. + request_number: The request number to log against. + """ + self._request_number = request_number + self._seen_awaits = seen_awaits + + self._original_Deferred___next__ = Deferred.__next__ + + # The number of `await`s on `Deferred`s we have seen so far. + self.awaits_seen = 0 + + # Whether we have seen a new `await` not in `seen_awaits`. + self.new_await_seen = False + + # To force `await`s on resolved `Deferred`s to block, we make up a new + # unresolved `Deferred` and return it out of `Deferred.__next__` / + # `coroutine.send()`. We have to resolve it later, in case the `await`ing + # coroutine is part of some shared processing, such as `@cached`. + self._to_unblock: Dict[Deferred, Union[object, Failure]] = {} + + # The last stack we logged. + self._previous_stack: List[inspect.FrameInfo] = [] + + def patch(self) -> ContextManager[Mock]: + """Returns a context manager which patches `Deferred.__next__`.""" + + def Deferred___next__( + deferred: "Deferred[T]", value: object = None + ) -> "Deferred[T]": + """Intercepts `await`s on `Deferred`s and rigs them to block once we have + seen enough of them. + + `Deferred.__next__` will normally: + * return `self` if the `Deferred` is unresolved, in which case + `coroutine.send()` will return the `Deferred`, and + `_defer.inlineCallbacks` will stop running the coroutine until the + `Deferred` is resolved. + * raise a `StopIteration(result)`, containing the result of the `await`. + * raise another exception, which will come out of the `await`. + """ + self.awaits_seen += 1 + + stack = _get_stack(skip_frames=1) + stack_hash = _hash_stack(stack) + + if stack_hash not in self._seen_awaits: + # Block at the current `await` onwards. + self._seen_awaits.add(stack_hash) + self.new_await_seen = True + + if not self.new_await_seen: + # This `await` isn't interesting. Let it proceed normally. + + # Don't log the stack. It's been seen before in a previous run. + self._previous_stack = stack + + return self._original_Deferred___next__(deferred, value) + + # We want to block at the current `await`. + if deferred.called and not deferred.paused: + # This `Deferred` already has a result. + # We return a new, unresolved, `Deferred` for `_inlineCallbacks` to wait + # on. This blocks the coroutine that did this `await`. + # We queue it up for unblocking later. + new_deferred: "Deferred[T]" = Deferred() + self._to_unblock[new_deferred] = deferred.result + + _log_await_stack( + stack, + self._previous_stack, + self._request_number, + "force-blocked await", + ) + self._previous_stack = stack + + return make_deferred_yieldable(new_deferred) + + # This `Deferred` does not have a result yet. + # The `await` will block normally, so we don't have to do anything. + _log_await_stack( + stack, + self._previous_stack, + self._request_number, + "blocking await", + ) + self._previous_stack = stack + + return self._original_Deferred___next__(deferred, value) + + return mock.patch.object(Deferred, "__next__", new=Deferred___next__) + + def unblock_awaits(self) -> None: + """Unblocks any shared processing that we forced to block. + + Must be called when done, otherwise processing shared between multiple requests, + such as database queries started by `@cached`, will become permanently stuck. + """ + to_unblock = self._to_unblock + self._to_unblock = {} + for deferred, result in to_unblock.items(): + deferred.callback(result) + + +def _log_for_request(request_number: int, message: str) -> None: + """Logs a message for an iteration of `make_request_with_cancellation_test`.""" + # We want consistent alignment when logging stack traces, so ensure the logging + # context has a fixed width name. + with LoggingContext(name=f"request-{request_number:<2}"): + logger.info(message) + + +def _log_await_stack( + stack: List[inspect.FrameInfo], + previous_stack: List[inspect.FrameInfo], + request_number: int, + note: str, +) -> None: + """Logs the stack for an `await` in `make_request_with_cancellation_test`. + + Only logs the part of the stack that has changed since the previous call. + + Example output looks like: + ``` + delay_cancellation:750 (synapse/util/async_helpers.py:750) + DatabasePool._runInteraction:768 (synapse/storage/database.py:768) + > *blocked on await* at DatabasePool.runWithConnection:891 (synapse/storage/database.py:891) + ``` + + Args: + stack: The stack to log, as returned by `_get_stack()`. + previous_stack: The previous stack logged, with callers appearing before + callees. + request_number: The request number to log against. + note: A note to attach to the last stack frame, eg. "blocked on await". + """ + for i, frame_info in enumerate(stack[:-1]): + # Skip any frames in common with the previous logging. + if i < len(previous_stack) and frame_info == previous_stack[i]: + continue + + frame = _format_stack_frame(frame_info) + message = f"{' ' * i}{frame}" + _log_for_request(request_number, message) + + # Always print the final frame with the `await`. + # If the frame with the `await` started another coroutine run, we may have already + # printed a deeper stack which includes our final frame. We want to log where all + # `await`s happen, so we reprint the frame in this case. + i = len(stack) - 1 + frame_info = stack[i] + frame = _format_stack_frame(frame_info) + message = f"{' ' * i}> *{note}* at {frame}" + _log_for_request(request_number, message) + + +def _format_stack_frame(frame_info: inspect.FrameInfo) -> str: + """Returns a string representation of a stack frame. + + Used for debug logging. + + Returns: + A string, formatted like + "JsonResource._async_render:559 (synapse/http/server.py:559)". + """ + method_name = _get_stack_frame_method_name(frame_info) + + return ( + f"{method_name}:{frame_info.lineno} ({frame_info.filename}:{frame_info.lineno})" + ) + + +def _get_stack(skip_frames: int) -> List[inspect.FrameInfo]: + """Captures the stack for a request. + + Skips any twisted frames and stops at `JsonResource.wrapped_async_request_handler`. + + Used for debug logging. + + Returns: + A list of `inspect.FrameInfo`s, with callers appearing before callees. + """ + stack = [] + + skip_frames += 1 # Also skip `get_stack` itself. + + for frame_info in inspect.stack()[skip_frames:]: + # Skip any twisted `inlineCallbacks` gunk. + if "/twisted/" in frame_info.filename: + continue + + # Exclude the reactor frame, upwards. + method_name = _get_stack_frame_method_name(frame_info) + if method_name == "ThreadedMemoryReactorClock.advance": + break + + stack.append(frame_info) + + # Stop at `JsonResource`'s `wrapped_async_request_handler`, which is the entry + # point for request handling. + if frame_info.function == "wrapped_async_request_handler": + break + + return stack[::-1] + + +def _get_stack_frame_method_name(frame_info: inspect.FrameInfo) -> str: + """Returns the name of a stack frame's method. + + eg. "JsonResource._async_render". + """ + method_name = frame_info.function + + # Prefix the class name for instance methods. + frame_self = frame_info.frame.f_locals.get("self") + if frame_self: + method = getattr(frame_self, method_name, None) + if method: + method_name = method.__qualname__ + else: + # We couldn't find the method on `self`. + # Make something up. It's useful to know which class "contains" a + # function anyway. + method_name = f"{type(frame_self).__name__} {method_name}" + + return method_name + + +def _hash_stack(stack: List[inspect.FrameInfo]): + """Turns a stack into a hashable value that can be put into a set.""" + return tuple(_format_stack_frame(frame) for frame in stack) diff --git a/tests/http/test_fedclient.py b/tests/http/test_matrixfederationclient.py index 006dbab093..be9eaf34e8 100644 --- a/tests/http/test_fedclient.py +++ b/tests/http/test_matrixfederationclient.py @@ -617,3 +617,17 @@ class FederationClientTests(HomeserverTestCase): self.assertIsInstance(f.value, RequestSendFailed) self.assertTrue(transport.disconnecting) + + def test_build_auth_headers_rejects_falsey_destinations(self) -> None: + with self.assertRaises(ValueError): + self.cl.build_auth_headers(None, b"GET", b"https://example.com") + with self.assertRaises(ValueError): + self.cl.build_auth_headers(b"", b"GET", b"https://example.com") + with self.assertRaises(ValueError): + self.cl.build_auth_headers( + None, b"GET", b"https://example.com", destination_is=b"" + ) + with self.assertRaises(ValueError): + self.cl.build_auth_headers( + b"", b"GET", b"https://example.com", destination_is=b"" + ) diff --git a/tests/http/test_servlet.py b/tests/http/test_servlet.py index b3655d7b44..bb966c80c6 100644 --- a/tests/http/test_servlet.py +++ b/tests/http/test_servlet.py @@ -30,7 +30,7 @@ from synapse.server import HomeServer from synapse.types import JsonDict from tests import unittest -from tests.http.server._base import EndpointCancellationTestHelperMixin +from tests.http.server._base import test_disconnect def make_request(content): @@ -108,9 +108,7 @@ class CancellableRestServlet(RestServlet): return HTTPStatus.OK, {"result": True} -class TestRestServletCancellation( - unittest.HomeserverTestCase, EndpointCancellationTestHelperMixin -): +class TestRestServletCancellation(unittest.HomeserverTestCase): """Tests for `RestServlet` cancellation.""" servlets = [ @@ -120,7 +118,7 @@ class TestRestServletCancellation( def test_cancellable_disconnect(self) -> None: """Test that handlers with the `@cancellable` flag can be cancelled.""" channel = self.make_request("GET", "/sleep", await_result=False) - self._test_disconnect( + test_disconnect( self.reactor, channel, expect_cancellation=True, @@ -130,7 +128,7 @@ class TestRestServletCancellation( def test_uncancellable_disconnect(self) -> None: """Test that handlers without the `@cancellable` flag cannot be cancelled.""" channel = self.make_request("POST", "/sleep", await_result=False) - self._test_disconnect( + test_disconnect( self.reactor, channel, expect_cancellation=False, diff --git a/tests/replication/http/test__base.py b/tests/replication/http/test__base.py index a5ab093a27..822a957c3a 100644 --- a/tests/replication/http/test__base.py +++ b/tests/replication/http/test__base.py @@ -25,7 +25,7 @@ from synapse.server import HomeServer from synapse.types import JsonDict from tests import unittest -from tests.http.server._base import EndpointCancellationTestHelperMixin +from tests.http.server._base import test_disconnect class CancellableReplicationEndpoint(ReplicationEndpoint): @@ -69,9 +69,7 @@ class UncancellableReplicationEndpoint(ReplicationEndpoint): return HTTPStatus.OK, {"result": True} -class ReplicationEndpointCancellationTestCase( - unittest.HomeserverTestCase, EndpointCancellationTestHelperMixin -): +class ReplicationEndpointCancellationTestCase(unittest.HomeserverTestCase): """Tests for `ReplicationEndpoint` cancellation.""" def create_test_resource(self): @@ -87,7 +85,7 @@ class ReplicationEndpointCancellationTestCase( """Test that handlers with the `@cancellable` flag can be cancelled.""" path = f"{REPLICATION_PREFIX}/{CancellableReplicationEndpoint.NAME}/" channel = self.make_request("POST", path, await_result=False) - self._test_disconnect( + test_disconnect( self.reactor, channel, expect_cancellation=True, @@ -98,7 +96,7 @@ class ReplicationEndpointCancellationTestCase( """Test that handlers without the `@cancellable` flag cannot be cancelled.""" path = f"{REPLICATION_PREFIX}/{UncancellableReplicationEndpoint.NAME}/" channel = self.make_request("POST", path, await_result=False) - self._test_disconnect( + test_disconnect( self.reactor, channel, expect_cancellation=False, diff --git a/tests/rest/client/test_profile.py b/tests/rest/client/test_profile.py index 77c3ced42e..29bed0e872 100644 --- a/tests/rest/client/test_profile.py +++ b/tests/rest/client/test_profile.py @@ -13,6 +13,8 @@ # limitations under the License. """Tests REST events for /profile paths.""" +import urllib.parse +from http import HTTPStatus from typing import Any, Dict, Optional from twisted.test.proto_helpers import MemoryReactor @@ -49,6 +51,12 @@ class ProfileTestCase(unittest.HomeserverTestCase): res = self._get_displayname() self.assertEqual(res, "owner") + def test_get_displayname_rejects_bad_username(self) -> None: + channel = self.make_request( + "GET", f"/profile/{urllib.parse.quote('@alice:')}/displayname" + ) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) + def test_set_displayname(self) -> None: channel = self.make_request( "PUT", diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 62e4db23ef..aa84906548 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -728,6 +728,7 @@ class RelationsTestCase(BaseRelationsTestCase): class RelationPaginationTestCase(BaseRelationsTestCase): + @unittest.override_config({"experimental_features": {"msc3715_enabled": True}}) def test_basic_paginate_relations(self) -> None: """Tests that calling pagination API correctly the latest relations.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index f523d89b8f..35c59ee9e0 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -18,10 +18,13 @@ """Tests REST events for /rooms paths.""" import json -from typing import Any, Dict, Iterable, List, Optional +from typing import Any, Dict, Iterable, List, Optional, Union from unittest.mock import Mock, call from urllib import parse as urlparse +# `Literal` appears with Python 3.8. +from typing_extensions import Literal + from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin @@ -42,6 +45,7 @@ from synapse.util import Clock from synapse.util.stringutils import random_string from tests import unittest +from tests.http.server._base import make_request_with_cancellation_test from tests.test_utils import make_awaitable PATH_PREFIX = b"/_matrix/client/api/v1" @@ -471,6 +475,49 @@ class RoomPermissionsTestCase(RoomBase): ) +class RoomStateTestCase(RoomBase): + """Tests /rooms/$room_id/state.""" + + user_id = "@sid1:red" + + def test_get_state_cancellation(self) -> None: + """Test cancellation of a `/rooms/$room_id/state` request.""" + room_id = self.helper.create_room_as(self.user_id) + channel = make_request_with_cancellation_test( + "test_state_cancellation", + self.reactor, + self.site, + "GET", + "/rooms/%s/state" % room_id, + ) + + self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertCountEqual( + [state_event["type"] for state_event in channel.json_body], + { + "m.room.create", + "m.room.power_levels", + "m.room.join_rules", + "m.room.member", + "m.room.history_visibility", + }, + ) + + def test_get_state_event_cancellation(self) -> None: + """Test cancellation of a `/rooms/$room_id/state/$event_type` request.""" + room_id = self.helper.create_room_as(self.user_id) + channel = make_request_with_cancellation_test( + "test_state_cancellation", + self.reactor, + self.site, + "GET", + "/rooms/%s/state/m.room.member/%s" % (room_id, self.user_id), + ) + + self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(channel.json_body, {"membership": "join"}) + + class RoomsMemberListTestCase(RoomBase): """Tests /rooms/$room_id/members/list REST events.""" @@ -591,6 +638,62 @@ class RoomsMemberListTestCase(RoomBase): channel = self.make_request("GET", room_path) self.assertEqual(200, channel.code, msg=channel.result["body"]) + def test_get_member_list_cancellation(self) -> None: + """Test cancellation of a `/rooms/$room_id/members` request.""" + room_id = self.helper.create_room_as(self.user_id) + channel = make_request_with_cancellation_test( + "test_get_member_list_cancellation", + self.reactor, + self.site, + "GET", + "/rooms/%s/members" % room_id, + ) + + self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(len(channel.json_body["chunk"]), 1) + self.assertLessEqual( + { + "content": {"membership": "join"}, + "room_id": room_id, + "sender": self.user_id, + "state_key": self.user_id, + "type": "m.room.member", + "user_id": self.user_id, + }.items(), + channel.json_body["chunk"][0].items(), + ) + + def test_get_member_list_with_at_token_cancellation(self) -> None: + """Test cancellation of a `/rooms/$room_id/members?at=<sync token>` request.""" + room_id = self.helper.create_room_as(self.user_id) + + # first sync to get an at token + channel = self.make_request("GET", "/sync") + self.assertEqual(200, channel.code) + sync_token = channel.json_body["next_batch"] + + channel = make_request_with_cancellation_test( + "test_get_member_list_with_at_token_cancellation", + self.reactor, + self.site, + "GET", + "/rooms/%s/members?at=%s" % (room_id, sync_token), + ) + + self.assertEqual(200, channel.code, msg=channel.result["body"]) + self.assertEqual(len(channel.json_body["chunk"]), 1) + self.assertLessEqual( + { + "content": {"membership": "join"}, + "room_id": room_id, + "sender": self.user_id, + "state_key": self.user_id, + "type": "m.room.member", + "user_id": self.user_id, + }.items(), + channel.json_body["chunk"][0].items(), + ) + class RoomsCreateTestCase(RoomBase): """Tests /rooms and /rooms/$room_id REST events.""" @@ -677,9 +780,11 @@ class RoomsCreateTestCase(RoomBase): channel = self.make_request("POST", "/createRoom", content) self.assertEqual(200, channel.code) - def test_spam_checker_may_join_room(self) -> None: + def test_spam_checker_may_join_room_deprecated(self) -> None: """Tests that the user_may_join_room spam checker callback is correctly bypassed when creating a new room. + + In this test, we use the deprecated API in which callbacks return a bool. """ async def user_may_join_room( @@ -701,6 +806,32 @@ class RoomsCreateTestCase(RoomBase): self.assertEqual(join_mock.call_count, 0) + def test_spam_checker_may_join_room(self) -> None: + """Tests that the user_may_join_room spam checker callback is correctly bypassed + when creating a new room. + + In this test, we use the more recent API in which callbacks return a `Union[Codes, Literal["NOT_SPAM"]]`. + """ + + async def user_may_join_room( + mxid: str, + room_id: str, + is_invite: bool, + ) -> Codes: + return Codes.CONSENT_NOT_GIVEN + + join_mock = Mock(side_effect=user_may_join_room) + self.hs.get_spam_checker()._user_may_join_room_callbacks.append(join_mock) + + channel = self.make_request( + "POST", + "/createRoom", + {}, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + self.assertEqual(join_mock.call_count, 0) + class RoomTopicTestCase(RoomBase): """Tests /rooms/$room_id/topic REST events.""" @@ -911,9 +1042,11 @@ class RoomJoinTestCase(RoomBase): self.room2 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1) self.room3 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1) - def test_spam_checker_may_join_room(self) -> None: + def test_spam_checker_may_join_room_deprecated(self) -> None: """Tests that the user_may_join_room spam checker callback is correctly called and blocks room joins when needed. + + This test uses the deprecated API, in which callbacks return booleans. """ # Register a dummy callback. Make it allow all room joins for now. @@ -926,6 +1059,8 @@ class RoomJoinTestCase(RoomBase): ) -> bool: return return_value + # `spec` argument is needed for this function mock to have `__qualname__`, which + # is needed for `Measure` metrics buried in SpamChecker. callback_mock = Mock(side_effect=user_may_join_room, spec=lambda *x: None) self.hs.get_spam_checker()._user_may_join_room_callbacks.append(callback_mock) @@ -968,6 +1103,67 @@ class RoomJoinTestCase(RoomBase): return_value = False self.helper.join(self.room3, self.user2, expect_code=403, tok=self.tok2) + def test_spam_checker_may_join_room(self) -> None: + """Tests that the user_may_join_room spam checker callback is correctly called + and blocks room joins when needed. + + This test uses the latest API to this day, in which callbacks return `NOT_SPAM` or `Codes`. + """ + + # Register a dummy callback. Make it allow all room joins for now. + return_value: Union[Literal["NOT_SPAM"], Codes] = synapse.module_api.NOT_SPAM + + async def user_may_join_room( + userid: str, + room_id: str, + is_invited: bool, + ) -> Union[Literal["NOT_SPAM"], Codes]: + return return_value + + # `spec` argument is needed for this function mock to have `__qualname__`, which + # is needed for `Measure` metrics buried in SpamChecker. + callback_mock = Mock(side_effect=user_may_join_room, spec=lambda *x: None) + self.hs.get_spam_checker()._user_may_join_room_callbacks.append(callback_mock) + + # Join a first room, without being invited to it. + self.helper.join(self.room1, self.user2, tok=self.tok2) + + # Check that the callback was called with the right arguments. + expected_call_args = ( + ( + self.user2, + self.room1, + False, + ), + ) + self.assertEqual( + callback_mock.call_args, + expected_call_args, + callback_mock.call_args, + ) + + # Join a second room, this time with an invite for it. + self.helper.invite(self.room2, self.user1, self.user2, tok=self.tok1) + self.helper.join(self.room2, self.user2, tok=self.tok2) + + # Check that the callback was called with the right arguments. + expected_call_args = ( + ( + self.user2, + self.room2, + True, + ), + ) + self.assertEqual( + callback_mock.call_args, + expected_call_args, + callback_mock.call_args, + ) + + # Now make the callback deny all room joins, and check that a join actually fails. + return_value = Codes.CONSENT_NOT_GIVEN + self.helper.join(self.room3, self.user2, expect_code=403, tok=self.tok2) + class RoomJoinRatelimitTestCase(RoomBase): user_id = "@sid1:red" @@ -2845,9 +3041,14 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase): self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) - def test_threepid_invite_spamcheck(self) -> None: + def test_threepid_invite_spamcheck_deprecated(self) -> None: + """ + Test allowing/blocking threepid invites with a spam-check module. + + In this test, we use the deprecated API in which callbacks return a bool. + """ # Mock a few functions to prevent the test from failing due to failing to talk to - # a remote IS. We keep the mock for _mock_make_and_store_3pid_invite around so we + # a remote IS. We keep the mock for make_and_store_3pid_invite around so we # can check its call_count later on during the test. make_invite_mock = Mock(return_value=make_awaitable(0)) self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock @@ -2901,3 +3102,67 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase): # Also check that it stopped before calling _make_and_store_3pid_invite. make_invite_mock.assert_called_once() + + def test_threepid_invite_spamcheck(self) -> None: + """ + Test allowing/blocking threepid invites with a spam-check module. + + In this test, we use the more recent API in which callbacks return a `Union[Codes, Literal["NOT_SPAM"]]`.""" + # Mock a few functions to prevent the test from failing due to failing to talk to + # a remote IS. We keep the mock for make_and_store_3pid_invite around so we + # can check its call_count later on during the test. + make_invite_mock = Mock(return_value=make_awaitable(0)) + self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock + self.hs.get_identity_handler().lookup_3pid = Mock( + return_value=make_awaitable(None), + ) + + # Add a mock to the spamchecker callbacks for user_may_send_3pid_invite. Make it + # allow everything for now. + # `spec` argument is needed for this function mock to have `__qualname__`, which + # is needed for `Measure` metrics buried in SpamChecker. + mock = Mock( + return_value=make_awaitable(synapse.module_api.NOT_SPAM), + spec=lambda *x: None, + ) + self.hs.get_spam_checker()._user_may_send_3pid_invite_callbacks.append(mock) + + # Send a 3PID invite into the room and check that it succeeded. + email_to_invite = "teresa@example.com" + channel = self.make_request( + method="POST", + path="/rooms/" + self.room_id + "/invite", + content={ + "id_server": "example.com", + "id_access_token": "sometoken", + "medium": "email", + "address": email_to_invite, + }, + access_token=self.tok, + ) + self.assertEqual(channel.code, 200) + + # Check that the callback was called with the right params. + mock.assert_called_with(self.user_id, "email", email_to_invite, self.room_id) + + # Check that the call to send the invite was made. + make_invite_mock.assert_called_once() + + # Now change the return value of the callback to deny any invite and test that + # we can't send the invite. + mock.return_value = make_awaitable(Codes.CONSENT_NOT_GIVEN) + channel = self.make_request( + method="POST", + path="/rooms/" + self.room_id + "/invite", + content={ + "id_server": "example.com", + "id_access_token": "sometoken", + "medium": "email", + "address": email_to_invite, + }, + access_token=self.tok, + ) + self.assertEqual(channel.code, 403) + + # Also check that it stopped before calling _make_and_store_3pid_invite. + make_invite_mock.assert_called_once() diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 07e29788e5..e07ae78fc4 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -96,7 +96,9 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): def test_maybe_send_server_notice_to_user_remove_blocked_notice(self): """Test when user has blocked notice, but should have it removed""" - self._rlsn._auth.check_auth_blocking = Mock(return_value=make_awaitable(None)) + self._rlsn._auth_blocking.check_auth_blocking = Mock( + return_value=make_awaitable(None) + ) mock_event = Mock( type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} ) @@ -112,7 +114,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): """ Test when user has blocked notice, but notice ought to be there (NOOP) """ - self._rlsn._auth.check_auth_blocking = Mock( + self._rlsn._auth_blocking.check_auth_blocking = Mock( return_value=make_awaitable(None), side_effect=ResourceLimitError(403, "foo"), ) @@ -132,7 +134,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): """ Test when user does not have blocked notice, but should have one """ - self._rlsn._auth.check_auth_blocking = Mock( + self._rlsn._auth_blocking.check_auth_blocking = Mock( return_value=make_awaitable(None), side_effect=ResourceLimitError(403, "foo"), ) @@ -145,7 +147,9 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): """ Test when user does not have blocked notice, nor should they (NOOP) """ - self._rlsn._auth.check_auth_blocking = Mock(return_value=make_awaitable(None)) + self._rlsn._auth_blocking.check_auth_blocking = Mock( + return_value=make_awaitable(None) + ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) @@ -156,7 +160,9 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): Test when user is not part of the MAU cohort - this should not ever happen - but ... """ - self._rlsn._auth.check_auth_blocking = Mock(return_value=make_awaitable(None)) + self._rlsn._auth_blocking.check_auth_blocking = Mock( + return_value=make_awaitable(None) + ) self._rlsn._store.user_last_seen_monthly_active = Mock( return_value=make_awaitable(None) ) @@ -170,7 +176,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): Test that when server is over MAU limit and alerting is suppressed, then an alert message is not sent into the room """ - self._rlsn._auth.check_auth_blocking = Mock( + self._rlsn._auth_blocking.check_auth_blocking = Mock( return_value=make_awaitable(None), side_effect=ResourceLimitError( 403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER @@ -185,7 +191,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): """ Test that when a server is disabled, that MAU limit alerting is ignored. """ - self._rlsn._auth.check_auth_blocking = Mock( + self._rlsn._auth_blocking.check_auth_blocking = Mock( return_value=make_awaitable(None), side_effect=ResourceLimitError( 403, "foo", limit_type=LimitBlockingTypes.HS_DISABLED @@ -202,7 +208,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): When the room is already in a blocked state, test that when alerting is suppressed that the room is returned to an unblocked state. """ - self._rlsn._auth.check_auth_blocking = Mock( + self._rlsn._auth_blocking.check_auth_blocking = Mock( return_value=make_awaitable(None), side_effect=ResourceLimitError( 403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py index 8370a27195..78b83d97b6 100644 --- a/tests/state/test_v2.py +++ b/tests/state/test_v2.py @@ -13,7 +13,17 @@ # limitations under the License. import itertools -from typing import List +from typing import ( + Collection, + Dict, + Iterable, + List, + Mapping, + Optional, + Set, + Tuple, + TypeVar, +) import attr @@ -22,13 +32,13 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, JoinRules, Membership from synapse.api.room_versions import RoomVersions from synapse.event_auth import auth_types_for_event -from synapse.events import make_event_from_dict +from synapse.events import EventBase, make_event_from_dict from synapse.state.v2 import ( _get_auth_chain_difference, lexicographical_topological_sort, resolve_events_with_store, ) -from synapse.types import EventID +from synapse.types import EventID, StateMap from tests import unittest @@ -48,7 +58,7 @@ ORIGIN_SERVER_TS = 0 class FakeClock: - def sleep(self, msec): + def sleep(self, msec: float) -> "defer.Deferred[None]": return defer.succeed(None) @@ -60,7 +70,14 @@ class FakeEvent: as domain. """ - def __init__(self, id, sender, type, state_key, content): + def __init__( + self, + id: str, + sender: str, + type: str, + state_key: Optional[str], + content: Mapping[str, object], + ): self.node_id = id self.event_id = EventID(id, "example.com").to_string() self.sender = sender @@ -69,12 +86,12 @@ class FakeEvent: self.content = content self.room_id = ROOM_ID - def to_event(self, auth_events, prev_events): + def to_event(self, auth_events: List[str], prev_events: List[str]) -> EventBase: """Given the auth_events and prev_events, convert to a Frozen Event Args: - auth_events (list[str]): list of event_ids - prev_events (list[str]): list of event_ids + auth_events: list of event_ids + prev_events: list of event_ids Returns: FrozenEvent @@ -164,7 +181,7 @@ INITIAL_EDGES = ["START", "IMZ", "IMC", "IMB", "IJR", "IPOWER", "IMA", "CREATE"] class StateTestCase(unittest.TestCase): - def test_ban_vs_pl(self): + def test_ban_vs_pl(self) -> None: events = [ FakeEvent( id="PA", @@ -202,7 +219,7 @@ class StateTestCase(unittest.TestCase): self.do_check(events, edges, expected_state_ids) - def test_join_rule_evasion(self): + def test_join_rule_evasion(self) -> None: events = [ FakeEvent( id="JR", @@ -226,7 +243,7 @@ class StateTestCase(unittest.TestCase): self.do_check(events, edges, expected_state_ids) - def test_offtopic_pl(self): + def test_offtopic_pl(self) -> None: events = [ FakeEvent( id="PA", @@ -257,7 +274,7 @@ class StateTestCase(unittest.TestCase): self.do_check(events, edges, expected_state_ids) - def test_topic_basic(self): + def test_topic_basic(self) -> None: events = [ FakeEvent( id="T1", sender=ALICE, type=EventTypes.Topic, state_key="", content={} @@ -297,7 +314,7 @@ class StateTestCase(unittest.TestCase): self.do_check(events, edges, expected_state_ids) - def test_topic_reset(self): + def test_topic_reset(self) -> None: events = [ FakeEvent( id="T1", sender=ALICE, type=EventTypes.Topic, state_key="", content={} @@ -327,7 +344,7 @@ class StateTestCase(unittest.TestCase): self.do_check(events, edges, expected_state_ids) - def test_topic(self): + def test_topic(self) -> None: events = [ FakeEvent( id="T1", sender=ALICE, type=EventTypes.Topic, state_key="", content={} @@ -380,7 +397,7 @@ class StateTestCase(unittest.TestCase): self.do_check(events, edges, expected_state_ids) - def test_mainline_sort(self): + def test_mainline_sort(self) -> None: """Tests that the mainline ordering works correctly.""" events = [ @@ -434,22 +451,26 @@ class StateTestCase(unittest.TestCase): self.do_check(events, edges, expected_state_ids) - def do_check(self, events, edges, expected_state_ids): + def do_check( + self, + events: List[FakeEvent], + edges: List[List[str]], + expected_state_ids: List[str], + ) -> None: """Take a list of events and edges and calculate the state of the graph at END, and asserts it matches `expected_state_ids` Args: - events (list[FakeEvent]) - edges (list[list[str]]): A list of chains of event edges, e.g. + events + edges: A list of chains of event edges, e.g. `[[A, B, C]]` are edges A->B and B->C. - expected_state_ids (list[str]): The expected state at END, (excluding + expected_state_ids: The expected state at END, (excluding the keys that haven't changed since START). """ # We want to sort the events into topological order for processing. - graph = {} + graph: Dict[str, Set[str]] = {} - # node_id -> FakeEvent - fake_event_map = {} + fake_event_map: Dict[str, FakeEvent] = {} for ev in itertools.chain(INITIAL_EVENTS, events): graph[ev.node_id] = set() @@ -462,10 +483,8 @@ class StateTestCase(unittest.TestCase): for a, b in pairwise(edge_list): graph[a].add(b) - # event_id -> FrozenEvent - event_map = {} - # node_id -> state - state_at_event = {} + event_map: Dict[str, EventBase] = {} + state_at_event: Dict[str, StateMap[str]] = {} # We copy the map as the sort consumes the graph graph_copy = {k: set(v) for k, v in graph.items()} @@ -496,7 +515,16 @@ class StateTestCase(unittest.TestCase): if fake_event.state_key is not None: state_after[(fake_event.type, fake_event.state_key)] = event_id - auth_types = set(auth_types_for_event(RoomVersions.V6, fake_event)) + # This type ignore is a bit sad. Things we have tried: + # 1. Define a `GenericEvent` Protocol satisfied by FakeEvent, EventBase and + # EventBuilder. But this is Hard because the relevant attributes are + # DictProperty[T] descriptors on EventBase but normal Ts on FakeEvent. + # 2. Define a `GenericEvent` Protocol describing `FakeEvent` only, and + # change this function to accept Union[Event, EventBase, EventBuilder]. + # This seems reasonable to me, but mypy isn't happy. I think that's + # a mypy bug, see https://github.com/python/mypy/issues/5570 + # Instead, resort to a type-ignore. + auth_types = set(auth_types_for_event(RoomVersions.V6, fake_event)) # type: ignore[arg-type] auth_events = [] for key in auth_types: @@ -530,8 +558,14 @@ class StateTestCase(unittest.TestCase): class LexicographicalTestCase(unittest.TestCase): - def test_simple(self): - graph = {"l": {"o"}, "m": {"n", "o"}, "n": {"o"}, "o": set(), "p": {"o"}} + def test_simple(self) -> None: + graph: Dict[str, Set[str]] = { + "l": {"o"}, + "m": {"n", "o"}, + "n": {"o"}, + "o": set(), + "p": {"o"}, + } res = list(lexicographical_topological_sort(graph, key=lambda x: x)) @@ -539,7 +573,7 @@ class LexicographicalTestCase(unittest.TestCase): class SimpleParamStateTestCase(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: # We build up a simple DAG. event_map = {} @@ -627,7 +661,7 @@ class SimpleParamStateTestCase(unittest.TestCase): ] } - def test_event_map_none(self): + def test_event_map_none(self) -> None: # Test that we correctly handle passing `None` as the event_map state_d = resolve_events_with_store( @@ -649,7 +683,7 @@ class AuthChainDifferenceTestCase(unittest.TestCase): events. """ - def test_simple(self): + def test_simple(self) -> None: # Test getting the auth difference for a simple chain with a single # unpersisted event: # @@ -695,7 +729,7 @@ class AuthChainDifferenceTestCase(unittest.TestCase): self.assertEqual(difference, {c.event_id}) - def test_multiple_unpersisted_chain(self): + def test_multiple_unpersisted_chain(self) -> None: # Test getting the auth difference for a simple chain with multiple # unpersisted events: # @@ -752,7 +786,7 @@ class AuthChainDifferenceTestCase(unittest.TestCase): self.assertEqual(difference, {d.event_id, c.event_id}) - def test_unpersisted_events_different_sets(self): + def test_unpersisted_events_different_sets(self) -> None: # Test getting the auth difference for with multiple unpersisted events # in different branches: # @@ -820,7 +854,10 @@ class AuthChainDifferenceTestCase(unittest.TestCase): self.assertEqual(difference, {d.event_id, e.event_id}) -def pairwise(iterable): +T = TypeVar("T") + + +def pairwise(iterable: Iterable[T]) -> Iterable[Tuple[T, T]]: "s -> (s0,s1), (s1,s2), (s2, s3), ..." a, b = itertools.tee(iterable) next(b, None) @@ -829,24 +866,26 @@ def pairwise(iterable): @attr.s class TestStateResolutionStore: - event_map = attr.ib() + event_map: Dict[str, EventBase] = attr.ib() - def get_events(self, event_ids, allow_rejected=False): + def get_events( + self, event_ids: Collection[str], allow_rejected: bool = False + ) -> "defer.Deferred[Dict[str, EventBase]]": """Get events from the database Args: - event_ids (list): The event_ids of the events to fetch - allow_rejected (bool): If True return rejected events. + event_ids: The event_ids of the events to fetch + allow_rejected: If True return rejected events. Returns: - Deferred[dict[str, FrozenEvent]]: Dict from event_id to event. + Dict from event_id to event. """ return defer.succeed( {eid: self.event_map[eid] for eid in event_ids if eid in self.event_map} ) - def _get_auth_chain(self, event_ids: List[str]) -> List[str]: + def _get_auth_chain(self, event_ids: Iterable[str]) -> List[str]: """Gets the full auth chain for a set of events (including rejected events). @@ -880,7 +919,9 @@ class TestStateResolutionStore: return list(result) - def get_auth_chain_difference(self, room_id, auth_sets): + def get_auth_chain_difference( + self, room_id: str, auth_sets: List[Set[str]] + ) -> "defer.Deferred[Set[str]]": chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets] common = set(chains[0]).intersection(*chains[1:]) diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py index e2c506e5a4..229ecd84a6 100644 --- a/tests/test_event_auth.py +++ b/tests/test_event_auth.py @@ -15,10 +15,12 @@ import unittest from typing import Optional +from parameterized import parameterized + from synapse import event_auth from synapse.api.constants import EventContentFields from synapse.api.errors import AuthError -from synapse.api.room_versions import RoomVersions +from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions from synapse.events import EventBase, make_event_from_dict from synapse.types import JsonDict, get_domain_from_id @@ -30,38 +32,39 @@ class EventAuthTestCase(unittest.TestCase): """ creator = "@creator:example.com" auth_events = [ - _create_event(creator), - _join_event(creator), + _create_event(RoomVersions.V9, creator), + _join_event(RoomVersions.V9, creator), ] # creator should be able to send state event_auth.check_auth_rules_for_event( - RoomVersions.V9, - _random_state_event(creator), + _random_state_event(RoomVersions.V9, creator), auth_events, ) # ... but a rejected join_rules event should cause it to be rejected - rejected_join_rules = _join_rules_event(creator, "public") + rejected_join_rules = _join_rules_event( + RoomVersions.V9, + creator, + "public", + ) rejected_join_rules.rejected_reason = "stinky" auth_events.append(rejected_join_rules) self.assertRaises( AuthError, event_auth.check_auth_rules_for_event, - RoomVersions.V9, - _random_state_event(creator), + _random_state_event(RoomVersions.V9, creator), auth_events, ) # ... even if there is *also* a good join rules - auth_events.append(_join_rules_event(creator, "public")) + auth_events.append(_join_rules_event(RoomVersions.V9, creator, "public")) self.assertRaises( AuthError, event_auth.check_auth_rules_for_event, - RoomVersions.V9, - _random_state_event(creator), + _random_state_event(RoomVersions.V9, creator), auth_events, ) @@ -73,15 +76,14 @@ class EventAuthTestCase(unittest.TestCase): creator = "@creator:example.com" joiner = "@joiner:example.com" auth_events = [ - _create_event(creator), - _join_event(creator), - _join_event(joiner), + _create_event(RoomVersions.V1, creator), + _join_event(RoomVersions.V1, creator), + _join_event(RoomVersions.V1, joiner), ] # creator should be able to send state event_auth.check_auth_rules_for_event( - RoomVersions.V1, - _random_state_event(creator), + _random_state_event(RoomVersions.V1, creator), auth_events, ) @@ -89,8 +91,7 @@ class EventAuthTestCase(unittest.TestCase): self.assertRaises( AuthError, event_auth.check_auth_rules_for_event, - RoomVersions.V1, - _random_state_event(joiner), + _random_state_event(RoomVersions.V1, joiner), auth_events, ) @@ -104,28 +105,28 @@ class EventAuthTestCase(unittest.TestCase): king = "@joiner2:example.com" auth_events = [ - _create_event(creator), - _join_event(creator), + _create_event(RoomVersions.V1, creator), + _join_event(RoomVersions.V1, creator), _power_levels_event( - creator, {"state_default": "30", "users": {pleb: "29", king: "30"}} + RoomVersions.V1, + creator, + {"state_default": "30", "users": {pleb: "29", king: "30"}}, ), - _join_event(pleb), - _join_event(king), + _join_event(RoomVersions.V1, pleb), + _join_event(RoomVersions.V1, king), ] # pleb should not be able to send state self.assertRaises( AuthError, event_auth.check_auth_rules_for_event, - RoomVersions.V1, - _random_state_event(pleb), + _random_state_event(RoomVersions.V1, pleb), auth_events, ), # king should be able to send state event_auth.check_auth_rules_for_event( - RoomVersions.V1, - _random_state_event(king), + _random_state_event(RoomVersions.V1, king), auth_events, ) @@ -134,37 +135,33 @@ class EventAuthTestCase(unittest.TestCase): creator = "@creator:example.com" other = "@other:example.com" auth_events = [ - _create_event(creator), - _join_event(creator), + _create_event(RoomVersions.V1, creator), + _join_event(RoomVersions.V1, creator), ] # creator should be able to send aliases event_auth.check_auth_rules_for_event( - RoomVersions.V1, - _alias_event(creator), + _alias_event(RoomVersions.V1, creator), auth_events, ) # Reject an event with no state key. with self.assertRaises(AuthError): event_auth.check_auth_rules_for_event( - RoomVersions.V1, - _alias_event(creator, state_key=""), + _alias_event(RoomVersions.V1, creator, state_key=""), auth_events, ) # If the domain of the sender does not match the state key, reject. with self.assertRaises(AuthError): event_auth.check_auth_rules_for_event( - RoomVersions.V1, - _alias_event(creator, state_key="test.com"), + _alias_event(RoomVersions.V1, creator, state_key="test.com"), auth_events, ) # Note that the member does *not* need to be in the room. event_auth.check_auth_rules_for_event( - RoomVersions.V1, - _alias_event(other), + _alias_event(RoomVersions.V1, other), auth_events, ) @@ -173,38 +170,35 @@ class EventAuthTestCase(unittest.TestCase): creator = "@creator:example.com" other = "@other:example.com" auth_events = [ - _create_event(creator), - _join_event(creator), + _create_event(RoomVersions.V6, creator), + _join_event(RoomVersions.V6, creator), ] # creator should be able to send aliases event_auth.check_auth_rules_for_event( - RoomVersions.V6, - _alias_event(creator), + _alias_event(RoomVersions.V6, creator), auth_events, ) # No particular checks are done on the state key. event_auth.check_auth_rules_for_event( - RoomVersions.V6, - _alias_event(creator, state_key=""), + _alias_event(RoomVersions.V6, creator, state_key=""), auth_events, ) event_auth.check_auth_rules_for_event( - RoomVersions.V6, - _alias_event(creator, state_key="test.com"), + _alias_event(RoomVersions.V6, creator, state_key="test.com"), auth_events, ) # Per standard auth rules, the member must be in the room. with self.assertRaises(AuthError): event_auth.check_auth_rules_for_event( - RoomVersions.V6, - _alias_event(other), + _alias_event(RoomVersions.V6, other), auth_events, ) - def test_msc2209(self): + @parameterized.expand([(RoomVersions.V1, True), (RoomVersions.V6, False)]) + def test_notifications(self, room_version: RoomVersion, allow_modification: bool): """ Notifications power levels get checked due to MSC2209. """ @@ -212,28 +206,26 @@ class EventAuthTestCase(unittest.TestCase): pleb = "@joiner:example.com" auth_events = [ - _create_event(creator), - _join_event(creator), + _create_event(room_version, creator), + _join_event(room_version, creator), _power_levels_event( - creator, {"state_default": "30", "users": {pleb: "30"}} + room_version, creator, {"state_default": "30", "users": {pleb: "30"}} ), - _join_event(pleb), + _join_event(room_version, pleb), ] - # pleb should be able to modify the notifications power level. - event_auth.check_auth_rules_for_event( - RoomVersions.V1, - _power_levels_event(pleb, {"notifications": {"room": 100}}), - auth_events, + pl_event = _power_levels_event( + room_version, pleb, {"notifications": {"room": 100}} ) - # But an MSC2209 room rejects this change. - with self.assertRaises(AuthError): - event_auth.check_auth_rules_for_event( - RoomVersions.V6, - _power_levels_event(pleb, {"notifications": {"room": 100}}), - auth_events, - ) + # on room V1, pleb should be able to modify the notifications power level. + if allow_modification: + event_auth.check_auth_rules_for_event(pl_event, auth_events) + + else: + # But an MSC2209 room rejects this change. + with self.assertRaises(AuthError): + event_auth.check_auth_rules_for_event(pl_event, auth_events) def test_join_rules_public(self): """ @@ -243,58 +235,60 @@ class EventAuthTestCase(unittest.TestCase): pleb = "@joiner:example.com" auth_events = { - ("m.room.create", ""): _create_event(creator), - ("m.room.member", creator): _join_event(creator), - ("m.room.join_rules", ""): _join_rules_event(creator, "public"), + ("m.room.create", ""): _create_event(RoomVersions.V6, creator), + ("m.room.member", creator): _join_event(RoomVersions.V6, creator), + ("m.room.join_rules", ""): _join_rules_event( + RoomVersions.V6, creator, "public" + ), } # Check join. event_auth.check_auth_rules_for_event( - RoomVersions.V6, - _join_event(pleb), + _join_event(RoomVersions.V6, pleb), auth_events.values(), ) # A user cannot be force-joined to a room. with self.assertRaises(AuthError): event_auth.check_auth_rules_for_event( - RoomVersions.V6, - _member_event(pleb, "join", sender=creator), + _member_event(RoomVersions.V6, pleb, "join", sender=creator), auth_events.values(), ) # Banned should be rejected. - auth_events[("m.room.member", pleb)] = _member_event(pleb, "ban") + auth_events[("m.room.member", pleb)] = _member_event( + RoomVersions.V6, pleb, "ban" + ) with self.assertRaises(AuthError): event_auth.check_auth_rules_for_event( - RoomVersions.V6, - _join_event(pleb), + _join_event(RoomVersions.V6, pleb), auth_events.values(), ) # A user who left can re-join. - auth_events[("m.room.member", pleb)] = _member_event(pleb, "leave") + auth_events[("m.room.member", pleb)] = _member_event( + RoomVersions.V6, pleb, "leave" + ) event_auth.check_auth_rules_for_event( - RoomVersions.V6, - _join_event(pleb), + _join_event(RoomVersions.V6, pleb), auth_events.values(), ) # A user can send a join if they're in the room. - auth_events[("m.room.member", pleb)] = _member_event(pleb, "join") + auth_events[("m.room.member", pleb)] = _member_event( + RoomVersions.V6, pleb, "join" + ) event_auth.check_auth_rules_for_event( - RoomVersions.V6, - _join_event(pleb), + _join_event(RoomVersions.V6, pleb), auth_events.values(), ) # A user can accept an invite. auth_events[("m.room.member", pleb)] = _member_event( - pleb, "invite", sender=creator + RoomVersions.V6, pleb, "invite", sender=creator ) event_auth.check_auth_rules_for_event( - RoomVersions.V6, - _join_event(pleb), + _join_event(RoomVersions.V6, pleb), auth_events.values(), ) @@ -306,64 +300,88 @@ class EventAuthTestCase(unittest.TestCase): pleb = "@joiner:example.com" auth_events = { - ("m.room.create", ""): _create_event(creator), - ("m.room.member", creator): _join_event(creator), - ("m.room.join_rules", ""): _join_rules_event(creator, "invite"), + ("m.room.create", ""): _create_event(RoomVersions.V6, creator), + ("m.room.member", creator): _join_event(RoomVersions.V6, creator), + ("m.room.join_rules", ""): _join_rules_event( + RoomVersions.V6, creator, "invite" + ), } # A join without an invite is rejected. with self.assertRaises(AuthError): event_auth.check_auth_rules_for_event( - RoomVersions.V6, - _join_event(pleb), + _join_event(RoomVersions.V6, pleb), auth_events.values(), ) # A user cannot be force-joined to a room. with self.assertRaises(AuthError): event_auth.check_auth_rules_for_event( - RoomVersions.V6, - _member_event(pleb, "join", sender=creator), + _member_event(RoomVersions.V6, pleb, "join", sender=creator), auth_events.values(), ) # Banned should be rejected. - auth_events[("m.room.member", pleb)] = _member_event(pleb, "ban") + auth_events[("m.room.member", pleb)] = _member_event( + RoomVersions.V6, pleb, "ban" + ) with self.assertRaises(AuthError): event_auth.check_auth_rules_for_event( - RoomVersions.V6, - _join_event(pleb), + _join_event(RoomVersions.V6, pleb), auth_events.values(), ) # A user who left cannot re-join. - auth_events[("m.room.member", pleb)] = _member_event(pleb, "leave") + auth_events[("m.room.member", pleb)] = _member_event( + RoomVersions.V6, pleb, "leave" + ) with self.assertRaises(AuthError): event_auth.check_auth_rules_for_event( - RoomVersions.V6, - _join_event(pleb), + _join_event(RoomVersions.V6, pleb), auth_events.values(), ) # A user can send a join if they're in the room. - auth_events[("m.room.member", pleb)] = _member_event(pleb, "join") + auth_events[("m.room.member", pleb)] = _member_event( + RoomVersions.V6, pleb, "join" + ) event_auth.check_auth_rules_for_event( - RoomVersions.V6, - _join_event(pleb), + _join_event(RoomVersions.V6, pleb), auth_events.values(), ) # A user can accept an invite. auth_events[("m.room.member", pleb)] = _member_event( - pleb, "invite", sender=creator + RoomVersions.V6, pleb, "invite", sender=creator ) event_auth.check_auth_rules_for_event( - RoomVersions.V6, - _join_event(pleb), + _join_event(RoomVersions.V6, pleb), auth_events.values(), ) - def test_join_rules_msc3083_restricted(self): + def test_join_rules_restricted_old_room(self) -> None: + """Old room versions should reject joins to restricted rooms""" + creator = "@creator:example.com" + pleb = "@joiner:example.com" + + auth_events = { + ("m.room.create", ""): _create_event(RoomVersions.V6, creator), + ("m.room.member", creator): _join_event(RoomVersions.V6, creator), + ("m.room.power_levels", ""): _power_levels_event( + RoomVersions.V6, creator, {"invite": 0} + ), + ("m.room.join_rules", ""): _join_rules_event( + RoomVersions.V6, creator, "restricted" + ), + } + + with self.assertRaises(AuthError): + event_auth.check_auth_rules_for_event( + _join_event(RoomVersions.V6, pleb), + auth_events.values(), + ) + + def test_join_rules_msc3083_restricted(self) -> None: """ Test joining a restricted room from MSC3083. @@ -377,29 +395,25 @@ class EventAuthTestCase(unittest.TestCase): pleb = "@joiner:example.com" auth_events = { - ("m.room.create", ""): _create_event(creator), - ("m.room.member", creator): _join_event(creator), - ("m.room.power_levels", ""): _power_levels_event(creator, {"invite": 0}), - ("m.room.join_rules", ""): _join_rules_event(creator, "restricted"), + ("m.room.create", ""): _create_event(RoomVersions.V8, creator), + ("m.room.member", creator): _join_event(RoomVersions.V8, creator), + ("m.room.power_levels", ""): _power_levels_event( + RoomVersions.V8, creator, {"invite": 0} + ), + ("m.room.join_rules", ""): _join_rules_event( + RoomVersions.V8, creator, "restricted" + ), } - # Older room versions don't understand this join rule - with self.assertRaises(AuthError): - event_auth.check_auth_rules_for_event( - RoomVersions.V6, - _join_event(pleb), - auth_events.values(), - ) - # A properly formatted join event should work. authorised_join_event = _join_event( + RoomVersions.V8, pleb, additional_content={ EventContentFields.AUTHORISING_USER: "@creator:example.com" }, ) event_auth.check_auth_rules_for_event( - RoomVersions.V8, authorised_join_event, auth_events.values(), ) @@ -408,14 +422,16 @@ class EventAuthTestCase(unittest.TestCase): # are done properly). pl_auth_events = auth_events.copy() pl_auth_events[("m.room.power_levels", "")] = _power_levels_event( - creator, {"invite": 100, "users": {"@inviter:foo.test": 150}} + RoomVersions.V8, + creator, + {"invite": 100, "users": {"@inviter:foo.test": 150}}, ) pl_auth_events[("m.room.member", "@inviter:foo.test")] = _join_event( - "@inviter:foo.test" + RoomVersions.V8, "@inviter:foo.test" ) event_auth.check_auth_rules_for_event( - RoomVersions.V8, _join_event( + RoomVersions.V8, pleb, additional_content={ EventContentFields.AUTHORISING_USER: "@inviter:foo.test" @@ -427,20 +443,21 @@ class EventAuthTestCase(unittest.TestCase): # A join which is missing an authorised server is rejected. with self.assertRaises(AuthError): event_auth.check_auth_rules_for_event( - RoomVersions.V8, - _join_event(pleb), + _join_event(RoomVersions.V8, pleb), auth_events.values(), ) # An join authorised by a user who is not in the room is rejected. pl_auth_events = auth_events.copy() pl_auth_events[("m.room.power_levels", "")] = _power_levels_event( - creator, {"invite": 100, "users": {"@other:example.com": 150}} + RoomVersions.V8, + creator, + {"invite": 100, "users": {"@other:example.com": 150}}, ) with self.assertRaises(AuthError): event_auth.check_auth_rules_for_event( - RoomVersions.V8, _join_event( + RoomVersions.V8, pleb, additional_content={ EventContentFields.AUTHORISING_USER: "@other:example.com" @@ -453,8 +470,8 @@ class EventAuthTestCase(unittest.TestCase): # *would* be valid, but is sent be a different user.) with self.assertRaises(AuthError): event_auth.check_auth_rules_for_event( - RoomVersions.V8, _member_event( + RoomVersions.V8, pleb, "join", sender=creator, @@ -466,39 +483,41 @@ class EventAuthTestCase(unittest.TestCase): ) # Banned should be rejected. - auth_events[("m.room.member", pleb)] = _member_event(pleb, "ban") + auth_events[("m.room.member", pleb)] = _member_event( + RoomVersions.V8, pleb, "ban" + ) with self.assertRaises(AuthError): event_auth.check_auth_rules_for_event( - RoomVersions.V8, authorised_join_event, auth_events.values(), ) # A user who left can re-join. - auth_events[("m.room.member", pleb)] = _member_event(pleb, "leave") + auth_events[("m.room.member", pleb)] = _member_event( + RoomVersions.V8, pleb, "leave" + ) event_auth.check_auth_rules_for_event( - RoomVersions.V8, authorised_join_event, auth_events.values(), ) # A user can send a join if they're in the room. (This doesn't need to # be authorised since the user is already joined.) - auth_events[("m.room.member", pleb)] = _member_event(pleb, "join") + auth_events[("m.room.member", pleb)] = _member_event( + RoomVersions.V8, pleb, "join" + ) event_auth.check_auth_rules_for_event( - RoomVersions.V8, - _join_event(pleb), + _join_event(RoomVersions.V8, pleb), auth_events.values(), ) # A user can accept an invite. (This doesn't need to be authorised since # the user was invited.) auth_events[("m.room.member", pleb)] = _member_event( - pleb, "invite", sender=creator + RoomVersions.V8, pleb, "invite", sender=creator ) event_auth.check_auth_rules_for_event( - RoomVersions.V8, - _join_event(pleb), + _join_event(RoomVersions.V8, pleb), auth_events.values(), ) @@ -508,20 +527,25 @@ class EventAuthTestCase(unittest.TestCase): TEST_ROOM_ID = "!test:room" -def _create_event(user_id: str) -> EventBase: +def _create_event( + room_version: RoomVersion, + user_id: str, +) -> EventBase: return make_event_from_dict( { "room_id": TEST_ROOM_ID, - "event_id": _get_event_id(), + **_maybe_get_event_id_dict_for_room_version(room_version), "type": "m.room.create", "state_key": "", "sender": user_id, "content": {"creator": user_id}, - } + }, + room_version=room_version, ) def _member_event( + room_version: RoomVersion, user_id: str, membership: str, sender: Optional[str] = None, @@ -530,79 +554,102 @@ def _member_event( return make_event_from_dict( { "room_id": TEST_ROOM_ID, - "event_id": _get_event_id(), + **_maybe_get_event_id_dict_for_room_version(room_version), "type": "m.room.member", "sender": sender or user_id, "state_key": user_id, "content": {"membership": membership, **(additional_content or {})}, "prev_events": [], - } + }, + room_version=room_version, ) -def _join_event(user_id: str, additional_content: Optional[dict] = None) -> EventBase: - return _member_event(user_id, "join", additional_content=additional_content) +def _join_event( + room_version: RoomVersion, + user_id: str, + additional_content: Optional[dict] = None, +) -> EventBase: + return _member_event( + room_version, + user_id, + "join", + additional_content=additional_content, + ) -def _power_levels_event(sender: str, content: JsonDict) -> EventBase: +def _power_levels_event( + room_version: RoomVersion, + sender: str, + content: JsonDict, +) -> EventBase: return make_event_from_dict( { "room_id": TEST_ROOM_ID, - "event_id": _get_event_id(), + **_maybe_get_event_id_dict_for_room_version(room_version), "type": "m.room.power_levels", "sender": sender, "state_key": "", "content": content, - } + }, + room_version=room_version, ) -def _alias_event(sender: str, **kwargs) -> EventBase: +def _alias_event(room_version: RoomVersion, sender: str, **kwargs) -> EventBase: data = { "room_id": TEST_ROOM_ID, - "event_id": _get_event_id(), + **_maybe_get_event_id_dict_for_room_version(room_version), "type": "m.room.aliases", "sender": sender, "state_key": get_domain_from_id(sender), "content": {"aliases": []}, } data.update(**kwargs) - return make_event_from_dict(data) + return make_event_from_dict(data, room_version=room_version) -def _random_state_event(sender: str) -> EventBase: +def _random_state_event(room_version: RoomVersion, sender: str) -> EventBase: return make_event_from_dict( { "room_id": TEST_ROOM_ID, - "event_id": _get_event_id(), + **_maybe_get_event_id_dict_for_room_version(room_version), "type": "test.state", "sender": sender, "state_key": "", "content": {"membership": "join"}, - } + }, + room_version=room_version, ) -def _join_rules_event(sender: str, join_rule: str) -> EventBase: +def _join_rules_event( + room_version: RoomVersion, sender: str, join_rule: str +) -> EventBase: return make_event_from_dict( { "room_id": TEST_ROOM_ID, - "event_id": _get_event_id(), + **_maybe_get_event_id_dict_for_room_version(room_version), "type": "m.room.join_rules", "sender": sender, "state_key": "", "content": { "join_rule": join_rule, }, - } + }, + room_version=room_version, ) event_count = 0 -def _get_event_id() -> str: +def _maybe_get_event_id_dict_for_room_version(room_version: RoomVersion) -> dict: + """If this room version needs it, generate an event id""" + if room_version.event_format != EventFormatVersions.V1: + return {} + global event_count c = event_count event_count += 1 - return "!%i:example.com" % (c,) + return {"event_id": "!%i:example.com" % (c,)} diff --git a/tests/test_server.py b/tests/test_server.py index 0f1eb43cbc..847432f791 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -34,7 +34,7 @@ from synapse.types import JsonDict from synapse.util import Clock from tests import unittest -from tests.http.server._base import EndpointCancellationTestHelperMixin +from tests.http.server._base import test_disconnect from tests.server import ( FakeSite, ThreadedMemoryReactorClock, @@ -407,7 +407,7 @@ class CancellableDirectServeHtmlResource(DirectServeHtmlResource): return HTTPStatus.OK, b"ok" -class DirectServeJsonResourceCancellationTests(EndpointCancellationTestHelperMixin): +class DirectServeJsonResourceCancellationTests(unittest.TestCase): """Tests for `DirectServeJsonResource` cancellation.""" def setUp(self): @@ -421,7 +421,7 @@ class DirectServeJsonResourceCancellationTests(EndpointCancellationTestHelperMix channel = make_request( self.reactor, self.site, "GET", "/sleep", await_result=False ) - self._test_disconnect( + test_disconnect( self.reactor, channel, expect_cancellation=True, @@ -433,7 +433,7 @@ class DirectServeJsonResourceCancellationTests(EndpointCancellationTestHelperMix channel = make_request( self.reactor, self.site, "POST", "/sleep", await_result=False ) - self._test_disconnect( + test_disconnect( self.reactor, channel, expect_cancellation=False, @@ -441,7 +441,7 @@ class DirectServeJsonResourceCancellationTests(EndpointCancellationTestHelperMix ) -class DirectServeHtmlResourceCancellationTests(EndpointCancellationTestHelperMixin): +class DirectServeHtmlResourceCancellationTests(unittest.TestCase): """Tests for `DirectServeHtmlResource` cancellation.""" def setUp(self): @@ -455,7 +455,7 @@ class DirectServeHtmlResourceCancellationTests(EndpointCancellationTestHelperMix channel = make_request( self.reactor, self.site, "GET", "/sleep", await_result=False ) - self._test_disconnect( + test_disconnect( self.reactor, channel, expect_cancellation=True, @@ -467,6 +467,6 @@ class DirectServeHtmlResourceCancellationTests(EndpointCancellationTestHelperMix channel = make_request( self.reactor, self.site, "POST", "/sleep", await_result=False ) - self._test_disconnect( + test_disconnect( self.reactor, channel, expect_cancellation=False, expected_body=b"ok" ) diff --git a/tests/test_state.py b/tests/test_state.py index 95f81bebae..b005dd8d0f 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Collection, Dict, List, Optional +from typing import Collection, Dict, List, Optional, cast from unittest.mock import Mock from twisted.internet import defer @@ -22,6 +22,8 @@ from synapse.api.room_versions import RoomVersions from synapse.events import make_event_from_dict from synapse.events.snapshot import EventContext from synapse.state import StateHandler, StateResolutionHandler +from synapse.util import Clock +from synapse.util.macaroons import MacaroonGenerator from tests import unittest @@ -190,13 +192,18 @@ class StateTestCase(unittest.TestCase): "get_clock", "get_state_resolution_handler", "get_account_validity_handler", + "get_macaroon_generator", "hostname", ] ) + clock = cast(Clock, MockClock()) hs.config = default_config("tesths", True) hs.get_datastores.return_value = Mock(main=self.dummy_store) hs.get_state_handler.return_value = None - hs.get_clock.return_value = MockClock() + hs.get_clock.return_value = clock + hs.get_macaroon_generator.return_value = MacaroonGenerator( + clock, "tesths", b"verysecret" + ) hs.get_auth.return_value = Auth(hs) hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs) hs.get_storage_controllers.return_value = storage_controllers diff --git a/tests/test_types.py b/tests/test_types.py index 0b10dae848..d8d82a517e 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -26,10 +26,21 @@ class UserIDTestCase(unittest.HomeserverTestCase): self.assertEqual("test", user.domain) self.assertEqual(True, self.hs.is_mine(user)) - def test_pase_empty(self): + def test_parse_rejects_empty_id(self): with self.assertRaises(SynapseError): UserID.from_string("") + def test_parse_rejects_missing_sigil(self): + with self.assertRaises(SynapseError): + UserID.from_string("alice:example.com") + + def test_parse_rejects_missing_separator(self): + with self.assertRaises(SynapseError): + UserID.from_string("@alice.example.com") + + def test_validation_rejects_missing_domain(self): + self.assertFalse(UserID.is_valid("@alice:")) + def test_build(self): user = UserID("5678efgh", "my.domain") diff --git a/tests/unittest.py b/tests/unittest.py index e7f255b4fa..c645dd3563 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -315,7 +315,7 @@ class HomeserverTestCase(TestCase): "is_guest": False, } - async def get_user_by_req(request, allow_guest=False, rights="access"): + async def get_user_by_req(request, allow_guest=False): assert self.helper.auth_user_id is not None return create_requester( UserID.from_string(self.helper.auth_user_id), diff --git a/tests/util/test_macaroons.py b/tests/util/test_macaroons.py new file mode 100644 index 0000000000..32125f7bb7 --- /dev/null +++ b/tests/util/test_macaroons.py @@ -0,0 +1,146 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pymacaroons.exceptions import MacaroonVerificationFailedException + +from synapse.util.macaroons import MacaroonGenerator, OidcSessionData + +from tests.server import get_clock +from tests.unittest import TestCase + + +class MacaroonGeneratorTestCase(TestCase): + def setUp(self): + self.reactor, hs_clock = get_clock() + self.macaroon_generator = MacaroonGenerator(hs_clock, "tesths", b"verysecret") + self.other_macaroon_generator = MacaroonGenerator( + hs_clock, "tesths", b"anothersecretkey" + ) + + def test_guest_access_token(self): + """Test the generation and verification of guest access tokens""" + token = self.macaroon_generator.generate_guest_access_token("@user:tesths") + user_id = self.macaroon_generator.verify_guest_token(token) + self.assertEqual(user_id, "@user:tesths") + + # Raises with another secret key + with self.assertRaises(MacaroonVerificationFailedException): + self.other_macaroon_generator.verify_guest_token(token) + + # Check that an old access token without the guest caveat does not work + macaroon = self.macaroon_generator._generate_base_macaroon("access") + macaroon.add_first_party_caveat(f"user_id = {user_id}") + macaroon.add_first_party_caveat("nonce = 0123456789abcdef") + token = macaroon.serialize() + + with self.assertRaises(MacaroonVerificationFailedException): + self.macaroon_generator.verify_guest_token(token) + + def test_delete_pusher_token(self): + """Test the generation and verification of delete_pusher tokens""" + token = self.macaroon_generator.generate_delete_pusher_token( + "@user:tesths", "m.mail", "john@example.com" + ) + user_id = self.macaroon_generator.verify_delete_pusher_token( + token, "m.mail", "john@example.com" + ) + self.assertEqual(user_id, "@user:tesths") + + # Raises with another secret key + with self.assertRaises(MacaroonVerificationFailedException): + self.other_macaroon_generator.verify_delete_pusher_token( + token, "m.mail", "john@example.com" + ) + + # Raises when verifying for another pushkey + with self.assertRaises(MacaroonVerificationFailedException): + self.macaroon_generator.verify_delete_pusher_token( + token, "m.mail", "other@example.com" + ) + + # Raises when verifying for another app_id + with self.assertRaises(MacaroonVerificationFailedException): + self.macaroon_generator.verify_delete_pusher_token( + token, "somethingelse", "john@example.com" + ) + + # Check that an old token without the app_id and pushkey still works + macaroon = self.macaroon_generator._generate_base_macaroon("delete_pusher") + macaroon.add_first_party_caveat("user_id = @user:tesths") + token = macaroon.serialize() + user_id = self.macaroon_generator.verify_delete_pusher_token( + token, "m.mail", "john@example.com" + ) + self.assertEqual(user_id, "@user:tesths") + + def test_short_term_login_token(self): + """Test the generation and verification of short-term login tokens""" + token = self.macaroon_generator.generate_short_term_login_token( + user_id="@user:tesths", + auth_provider_id="oidc", + auth_provider_session_id="sid", + duration_in_ms=2 * 60 * 1000, + ) + + info = self.macaroon_generator.verify_short_term_login_token(token) + self.assertEqual(info.user_id, "@user:tesths") + self.assertEqual(info.auth_provider_id, "oidc") + self.assertEqual(info.auth_provider_session_id, "sid") + + # Raises with another secret key + with self.assertRaises(MacaroonVerificationFailedException): + self.other_macaroon_generator.verify_short_term_login_token(token) + + # Wait a minute + self.reactor.pump([60]) + # Shouldn't raise + self.macaroon_generator.verify_short_term_login_token(token) + # Wait another minute + self.reactor.pump([60]) + # Should raise since it expired + with self.assertRaises(MacaroonVerificationFailedException): + self.macaroon_generator.verify_short_term_login_token(token) + + def test_oidc_session_token(self): + """Test the generation and verification of OIDC session cookies""" + state = "arandomstate" + session_data = OidcSessionData( + idp_id="oidc", + nonce="nonce", + client_redirect_url="https://example.com/", + ui_auth_session_id="", + ) + token = self.macaroon_generator.generate_oidc_session_token( + state, session_data, duration_in_ms=2 * 60 * 1000 + ).encode("utf-8") + info = self.macaroon_generator.verify_oidc_session_token(token, state) + self.assertEqual(session_data, info) + + # Raises with another secret key + with self.assertRaises(MacaroonVerificationFailedException): + self.other_macaroon_generator.verify_oidc_session_token(token, state) + + # Should raise with another state + with self.assertRaises(MacaroonVerificationFailedException): + self.macaroon_generator.verify_oidc_session_token(token, "anotherstate") + + # Wait a minute + self.reactor.pump([60]) + # Shouldn't raise + self.macaroon_generator.verify_oidc_session_token(token, state) + # Wait another minute + self.reactor.pump([60]) + # Should raise since it expired + with self.assertRaises(MacaroonVerificationFailedException): + self.macaroon_generator.verify_oidc_session_token(token, state) |