diff options
Diffstat (limited to 'tests/handlers')
-rw-r--r-- | tests/handlers/test_admin.py | 9 | ||||
-rw-r--r-- | tests/handlers/test_appservice.py | 33 | ||||
-rw-r--r-- | tests/handlers/test_auth.py | 133 | ||||
-rw-r--r-- | tests/handlers/test_cas.py | 60 | ||||
-rw-r--r-- | tests/handlers/test_device.py | 4 | ||||
-rw-r--r-- | tests/handlers/test_directory.py | 20 | ||||
-rw-r--r-- | tests/handlers/test_e2e_keys.py | 238 | ||||
-rw-r--r-- | tests/handlers/test_e2e_room_keys.py | 347 | ||||
-rw-r--r-- | tests/handlers/test_federation.py | 54 | ||||
-rw-r--r-- | tests/handlers/test_message.py | 13 | ||||
-rw-r--r-- | tests/handlers/test_oidc.py | 158 | ||||
-rw-r--r-- | tests/handlers/test_password_providers.py | 11 | ||||
-rw-r--r-- | tests/handlers/test_presence.py | 3 | ||||
-rw-r--r-- | tests/handlers/test_profile.py | 125 | ||||
-rw-r--r-- | tests/handlers/test_saml.py | 64 | ||||
-rw-r--r-- | tests/handlers/test_typing.py | 12 | ||||
-rw-r--r-- | tests/handlers/test_user_directory.py | 12 |
17 files changed, 649 insertions, 647 deletions
diff --git a/tests/handlers/test_admin.py b/tests/handlers/test_admin.py index 5c2b4de1a6..a01fdd0839 100644 --- a/tests/handlers/test_admin.py +++ b/tests/handlers/test_admin.py @@ -44,8 +44,7 @@ class ExfiltrateData(unittest.HomeserverTestCase): self.token2 = self.login("user2", "password") def test_single_public_joined_room(self): - """Test that we write *all* events for a public room - """ + """Test that we write *all* events for a public room""" room_id = self.helper.create_room_as( self.user1, tok=self.token1, is_public=True ) @@ -116,8 +115,7 @@ class ExfiltrateData(unittest.HomeserverTestCase): self.assertEqual(counter[(EventTypes.Member, self.user2)], 1) def test_single_left_room(self): - """Tests that we don't see events in the room after we leave. - """ + """Tests that we don't see events in the room after we leave.""" room_id = self.helper.create_room_as(self.user1, tok=self.token1) self.helper.send(room_id, body="Hello!", tok=self.token1) self.helper.join(room_id, self.user2, tok=self.token2) @@ -190,8 +188,7 @@ class ExfiltrateData(unittest.HomeserverTestCase): self.assertEqual(counter[(EventTypes.Member, self.user2)], 3) def test_invite(self): - """Tests that pending invites get handled correctly. - """ + """Tests that pending invites get handled correctly.""" room_id = self.helper.create_room_as(self.user1, tok=self.token1) self.helper.send(room_id, body="Hello!", tok=self.token1) self.helper.invite(room_id, self.user1, self.user2, tok=self.token1) diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index 53763cd0f9..d5d3fdd99a 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -35,8 +35,8 @@ class AppServiceHandlerTestCase(unittest.TestCase): self.mock_scheduler = Mock() hs = Mock() hs.get_datastore.return_value = self.mock_store - self.mock_store.get_received_ts.return_value = defer.succeed(0) - self.mock_store.set_appservice_last_pos.return_value = defer.succeed(None) + self.mock_store.get_received_ts.return_value = make_awaitable(0) + self.mock_store.set_appservice_last_pos.return_value = make_awaitable(None) hs.get_application_service_api.return_value = self.mock_as_api hs.get_application_service_scheduler.return_value = self.mock_scheduler hs.get_clock.return_value = MockClock() @@ -50,16 +50,16 @@ class AppServiceHandlerTestCase(unittest.TestCase): self._mkservice(is_interested=False), ] - self.mock_as_api.query_user.return_value = defer.succeed(True) + self.mock_as_api.query_user.return_value = make_awaitable(True) self.mock_store.get_app_services.return_value = services - self.mock_store.get_user_by_id.return_value = defer.succeed([]) + self.mock_store.get_user_by_id.return_value = make_awaitable([]) event = Mock( sender="@someone:anywhere", type="m.room.message", room_id="!foo:bar" ) self.mock_store.get_new_events_for_appservice.side_effect = [ - defer.succeed((0, [event])), - defer.succeed((0, [])), + make_awaitable((0, [event])), + make_awaitable((0, [])), ] self.handler.notify_interested_services(RoomStreamToken(None, 0)) @@ -72,13 +72,13 @@ class AppServiceHandlerTestCase(unittest.TestCase): services = [self._mkservice(is_interested=True)] services[0].is_interested_in_user.return_value = True self.mock_store.get_app_services.return_value = services - self.mock_store.get_user_by_id.return_value = defer.succeed(None) + self.mock_store.get_user_by_id.return_value = make_awaitable(None) event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar") - self.mock_as_api.query_user.return_value = defer.succeed(True) + self.mock_as_api.query_user.return_value = make_awaitable(True) self.mock_store.get_new_events_for_appservice.side_effect = [ - defer.succeed((0, [event])), - defer.succeed((0, [])), + make_awaitable((0, [event])), + make_awaitable((0, [])), ] self.handler.notify_interested_services(RoomStreamToken(None, 0)) @@ -90,13 +90,13 @@ class AppServiceHandlerTestCase(unittest.TestCase): services = [self._mkservice(is_interested=True)] services[0].is_interested_in_user.return_value = True self.mock_store.get_app_services.return_value = services - self.mock_store.get_user_by_id.return_value = defer.succeed({"name": user_id}) + self.mock_store.get_user_by_id.return_value = make_awaitable({"name": user_id}) event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar") - self.mock_as_api.query_user.return_value = defer.succeed(True) + self.mock_as_api.query_user.return_value = make_awaitable(True) self.mock_store.get_new_events_for_appservice.side_effect = [ - defer.succeed((0, [event])), - defer.succeed((0, [])), + make_awaitable((0, [event])), + make_awaitable((0, [])), ] self.handler.notify_interested_services(RoomStreamToken(None, 0)) @@ -106,7 +106,6 @@ class AppServiceHandlerTestCase(unittest.TestCase): "query_user called when it shouldn't have been.", ) - @defer.inlineCallbacks def test_query_room_alias_exists(self): room_alias_str = "#foo:bar" room_alias = Mock() @@ -127,8 +126,8 @@ class AppServiceHandlerTestCase(unittest.TestCase): Mock(room_id=room_id, servers=servers) ) - result = yield defer.ensureDeferred( - self.handler.query_room_alias_exists(room_alias) + result = self.successResultOf( + defer.ensureDeferred(self.handler.query_room_alias_exists(room_alias)) ) self.mock_as_api.query_alias.assert_called_once_with( diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index e24ce81284..0e42013bb9 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -16,28 +16,21 @@ from mock import Mock import pymacaroons -from twisted.internet import defer - -import synapse -import synapse.api.errors -from synapse.api.errors import ResourceLimitError +from synapse.api.errors import AuthError, ResourceLimitError from tests import unittest from tests.test_utils import make_awaitable -from tests.utils import setup_test_homeserver -class AuthTestCase(unittest.TestCase): - @defer.inlineCallbacks - def setUp(self): - self.hs = yield setup_test_homeserver(self.addCleanup) - self.auth_handler = self.hs.get_auth_handler() - self.macaroon_generator = self.hs.get_macaroon_generator() +class AuthTestCase(unittest.HomeserverTestCase): + def prepare(self, reactor, clock, hs): + self.auth_handler = hs.get_auth_handler() + self.macaroon_generator = hs.get_macaroon_generator() # MAU tests # AuthBlocking reads from the hs' config on initialization. We need to # modify its config instead of the hs' - self.auth_blocking = self.hs.get_auth()._auth_blocking + self.auth_blocking = hs.get_auth()._auth_blocking self.auth_blocking._max_mau_value = 50 self.small_number_of_users = 1 @@ -52,8 +45,6 @@ class AuthTestCase(unittest.TestCase): self.fail("some_user was not in %s" % macaroon.inspect()) def test_macaroon_caveats(self): - self.hs.get_clock().now = 5000 - token = self.macaroon_generator.generate_access_token("a_user") macaroon = pymacaroons.Macaroon.deserialize(token) @@ -76,29 +67,25 @@ class AuthTestCase(unittest.TestCase): v.satisfy_general(verify_nonce) v.verify(macaroon, self.hs.config.macaroon_secret_key) - @defer.inlineCallbacks def test_short_term_login_token_gives_user_id(self): - self.hs.get_clock().now = 1000 - token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000) - user_id = yield defer.ensureDeferred( + user_id = self.get_success( self.auth_handler.validate_short_term_login_token_and_get_user_id(token) ) self.assertEqual("a_user", user_id) # when we advance the clock, the token should be rejected - self.hs.get_clock().now = 6000 - with self.assertRaises(synapse.api.errors.AuthError): - yield defer.ensureDeferred( - self.auth_handler.validate_short_term_login_token_and_get_user_id(token) - ) + self.reactor.advance(6) + self.get_failure( + self.auth_handler.validate_short_term_login_token_and_get_user_id(token), + AuthError, + ) - @defer.inlineCallbacks def test_short_term_login_token_cannot_replace_user_id(self): token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000) macaroon = pymacaroons.Macaroon.deserialize(token) - user_id = yield defer.ensureDeferred( + user_id = self.get_success( self.auth_handler.validate_short_term_login_token_and_get_user_id( macaroon.serialize() ) @@ -109,102 +96,90 @@ class AuthTestCase(unittest.TestCase): # user_id. macaroon.add_first_party_caveat("user_id = b_user") - with self.assertRaises(synapse.api.errors.AuthError): - yield defer.ensureDeferred( - self.auth_handler.validate_short_term_login_token_and_get_user_id( - macaroon.serialize() - ) - ) + self.get_failure( + self.auth_handler.validate_short_term_login_token_and_get_user_id( + macaroon.serialize() + ), + AuthError, + ) - @defer.inlineCallbacks def test_mau_limits_disabled(self): self.auth_blocking._limit_usage_by_mau = False # Ensure does not throw exception - yield defer.ensureDeferred( + self.get_success( self.auth_handler.get_access_token_for_user_id( "user_a", device_id=None, valid_until_ms=None ) ) - yield defer.ensureDeferred( + self.get_success( self.auth_handler.validate_short_term_login_token_and_get_user_id( self._get_macaroon().serialize() ) ) - @defer.inlineCallbacks def test_mau_limits_exceeded_large(self): self.auth_blocking._limit_usage_by_mau = True self.hs.get_datastore().get_monthly_active_count = Mock( return_value=make_awaitable(self.large_number_of_users) ) - with self.assertRaises(ResourceLimitError): - yield defer.ensureDeferred( - self.auth_handler.get_access_token_for_user_id( - "user_a", device_id=None, valid_until_ms=None - ) - ) + self.get_failure( + self.auth_handler.get_access_token_for_user_id( + "user_a", device_id=None, valid_until_ms=None + ), + ResourceLimitError, + ) self.hs.get_datastore().get_monthly_active_count = Mock( return_value=make_awaitable(self.large_number_of_users) ) - with self.assertRaises(ResourceLimitError): - yield defer.ensureDeferred( - self.auth_handler.validate_short_term_login_token_and_get_user_id( - self._get_macaroon().serialize() - ) - ) + self.get_failure( + self.auth_handler.validate_short_term_login_token_and_get_user_id( + self._get_macaroon().serialize() + ), + ResourceLimitError, + ) - @defer.inlineCallbacks def test_mau_limits_parity(self): + # Ensure we're not at the unix epoch. + self.reactor.advance(1) self.auth_blocking._limit_usage_by_mau = True - # If not in monthly active cohort + # Set the server to be at the edge of too many users. self.hs.get_datastore().get_monthly_active_count = Mock( return_value=make_awaitable(self.auth_blocking._max_mau_value) ) - with self.assertRaises(ResourceLimitError): - yield defer.ensureDeferred( - self.auth_handler.get_access_token_for_user_id( - "user_a", device_id=None, valid_until_ms=None - ) - ) - self.hs.get_datastore().get_monthly_active_count = Mock( - return_value=make_awaitable(self.auth_blocking._max_mau_value) + # If not in monthly active cohort + self.get_failure( + self.auth_handler.get_access_token_for_user_id( + "user_a", device_id=None, valid_until_ms=None + ), + ResourceLimitError, ) - with self.assertRaises(ResourceLimitError): - yield defer.ensureDeferred( - self.auth_handler.validate_short_term_login_token_and_get_user_id( - self._get_macaroon().serialize() - ) - ) + self.get_failure( + self.auth_handler.validate_short_term_login_token_and_get_user_id( + self._get_macaroon().serialize() + ), + ResourceLimitError, + ) + # If in monthly active cohort self.hs.get_datastore().user_last_seen_monthly_active = Mock( - return_value=make_awaitable(self.hs.get_clock().time_msec()) + return_value=make_awaitable(self.clock.time_msec()) ) - self.hs.get_datastore().get_monthly_active_count = Mock( - return_value=make_awaitable(self.auth_blocking._max_mau_value) - ) - yield defer.ensureDeferred( + self.get_success( self.auth_handler.get_access_token_for_user_id( "user_a", device_id=None, valid_until_ms=None ) ) - self.hs.get_datastore().user_last_seen_monthly_active = Mock( - return_value=make_awaitable(self.hs.get_clock().time_msec()) - ) - self.hs.get_datastore().get_monthly_active_count = Mock( - return_value=make_awaitable(self.auth_blocking._max_mau_value) - ) - yield defer.ensureDeferred( + self.get_success( self.auth_handler.validate_short_term_login_token_and_get_user_id( self._get_macaroon().serialize() ) ) - @defer.inlineCallbacks def test_mau_limits_not_exceeded(self): self.auth_blocking._limit_usage_by_mau = True @@ -212,7 +187,7 @@ class AuthTestCase(unittest.TestCase): return_value=make_awaitable(self.small_number_of_users) ) # Ensure does not raise exception - yield defer.ensureDeferred( + self.get_success( self.auth_handler.get_access_token_for_user_id( "user_a", device_id=None, valid_until_ms=None ) @@ -221,7 +196,7 @@ class AuthTestCase(unittest.TestCase): self.hs.get_datastore().get_monthly_active_count = Mock( return_value=make_awaitable(self.small_number_of_users) ) - yield defer.ensureDeferred( + self.get_success( self.auth_handler.validate_short_term_login_token_and_get_user_id( self._get_macaroon().serialize() ) diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py index c37bb6440e..6f992291b8 100644 --- a/tests/handlers/test_cas.py +++ b/tests/handlers/test_cas.py @@ -16,7 +16,7 @@ from mock import Mock from synapse.handlers.cas_handler import CasResponse from tests.test_utils import simple_async_mock -from tests.unittest import HomeserverTestCase +from tests.unittest import HomeserverTestCase, override_config # These are a few constants that are used as config parameters in the tests. BASE_URL = "https://synapse/" @@ -32,6 +32,10 @@ class CasHandlerTestCase(HomeserverTestCase): "server_url": SERVER_URL, "service_url": BASE_URL, } + + # Update this config with what's in the default config so that + # override_config works as expected. + cas_config.update(config.get("cas_config", {})) config["cas_config"] = cas_config return config @@ -62,7 +66,7 @@ class CasHandlerTestCase(HomeserverTestCase): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", request, "redirect_uri", None + "@test_user:test", request, "redirect_uri", None, new_user=True ) def test_map_cas_user_to_existing_user(self): @@ -85,7 +89,7 @@ class CasHandlerTestCase(HomeserverTestCase): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", request, "redirect_uri", None + "@test_user:test", request, "redirect_uri", None, new_user=False ) # Subsequent calls should map to the same mxid. @@ -94,7 +98,7 @@ class CasHandlerTestCase(HomeserverTestCase): self.handler._handle_cas_response(request, cas_response, "redirect_uri", "") ) auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", request, "redirect_uri", None + "@test_user:test", request, "redirect_uri", None, new_user=False ) def test_map_cas_user_to_invalid_localpart(self): @@ -112,10 +116,54 @@ class CasHandlerTestCase(HomeserverTestCase): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@f=c3=b6=c3=b6:test", request, "redirect_uri", None + "@f=c3=b6=c3=b6:test", request, "redirect_uri", None, new_user=True + ) + + @override_config( + { + "cas_config": { + "required_attributes": {"userGroup": "staff", "department": None} + } + } + ) + def test_required_attributes(self): + """The required attributes must be met from the CAS response.""" + + # stub out the auth handler + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() + + # The response doesn't have the proper userGroup or department. + cas_response = CasResponse("test_user", {}) + request = _mock_request() + self.get_success( + self.handler._handle_cas_response(request, cas_response, "redirect_uri", "") + ) + auth_handler.complete_sso_login.assert_not_called() + + # The response doesn't have any department. + cas_response = CasResponse("test_user", {"userGroup": "staff"}) + request.reset_mock() + self.get_success( + self.handler._handle_cas_response(request, cas_response, "redirect_uri", "") + ) + auth_handler.complete_sso_login.assert_not_called() + + # Add the proper attributes and it should succeed. + cas_response = CasResponse( + "test_user", {"userGroup": ["staff", "admin"], "department": ["sales"]} + ) + request.reset_mock() + self.get_success( + self.handler._handle_cas_response(request, cas_response, "redirect_uri", "") + ) + + # check that the auth handler got called as expected + auth_handler.complete_sso_login.assert_called_once_with( + "@test_user:test", request, "redirect_uri", None, new_user=True ) def _mock_request(): """Returns a mock which will stand in as a SynapseRequest""" - return Mock(spec=["getClientIP", "getHeader"]) + return Mock(spec=["getClientIP", "getHeader", "_disconnected"]) diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index 5dfeccfeb6..821629bc38 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -260,7 +260,9 @@ class DehydrationTestCase(unittest.HomeserverTestCase): # Create a new login for the user and dehydrated the device device_id, access_token = self.get_success( self.registration.register_device( - user_id=user_id, device_id=None, initial_display_name="new device", + user_id=user_id, + device_id=None, + initial_display_name="new device", ) ) diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index a39f898608..863d8737b2 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -131,7 +131,9 @@ class TestCreateAlias(unittest.HomeserverTestCase): """A user can create an alias for a room they're in.""" self.get_success( self.handler.create_association( - create_requester(self.test_user), self.room_alias, self.room_id, + create_requester(self.test_user), + self.room_alias, + self.room_id, ) ) @@ -143,7 +145,9 @@ class TestCreateAlias(unittest.HomeserverTestCase): self.get_failure( self.handler.create_association( - create_requester(self.test_user), self.room_alias, other_room_id, + create_requester(self.test_user), + self.room_alias, + other_room_id, ), synapse.api.errors.SynapseError, ) @@ -156,7 +160,9 @@ class TestCreateAlias(unittest.HomeserverTestCase): self.get_success( self.handler.create_association( - create_requester(self.admin_user), self.room_alias, other_room_id, + create_requester(self.admin_user), + self.room_alias, + other_room_id, ) ) @@ -275,8 +281,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase): class CanonicalAliasTestCase(unittest.HomeserverTestCase): - """Test modifications of the canonical alias when delete aliases. - """ + """Test modifications of the canonical alias when delete aliases.""" servlets = [ synapse.rest.admin.register_servlets, @@ -317,7 +322,10 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): def _set_canonical_alias(self, content): """Configure the canonical alias state on the room.""" self.helper.send_state( - self.room_id, "m.room.canonical_alias", content, tok=self.admin_user_tok, + self.room_id, + "m.room.canonical_alias", + content, + tok=self.admin_user_tok, ) def _get_canonical_alias(self): diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 924f29f051..5e86c5e56b 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -18,42 +18,26 @@ import mock from signedjson import key as key, sign as sign -from twisted.internet import defer - -import synapse.handlers.e2e_keys -import synapse.storage -from synapse.api import errors from synapse.api.constants import RoomEncryptionAlgorithms +from synapse.api.errors import Codes, SynapseError -from tests import unittest, utils +from tests import unittest -class E2eKeysHandlerTestCase(unittest.TestCase): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.hs = None # type: synapse.server.HomeServer - self.handler = None # type: synapse.handlers.e2e_keys.E2eKeysHandler - self.store = None # type: synapse.storage.Storage +class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): + def make_homeserver(self, reactor, clock): + return self.setup_test_homeserver(federation_client=mock.Mock()) - @defer.inlineCallbacks - def setUp(self): - self.hs = yield utils.setup_test_homeserver( - self.addCleanup, federation_client=mock.Mock() - ) - self.handler = synapse.handlers.e2e_keys.E2eKeysHandler(self.hs) + def prepare(self, reactor, clock, hs): + self.handler = hs.get_e2e_keys_handler() self.store = self.hs.get_datastore() - @defer.inlineCallbacks def test_query_local_devices_no_devices(self): - """If the user has no devices, we expect an empty list. - """ + """If the user has no devices, we expect an empty list.""" local_user = "@boris:" + self.hs.hostname - res = yield defer.ensureDeferred( - self.handler.query_local_devices({local_user: None}) - ) + res = self.get_success(self.handler.query_local_devices({local_user: None})) self.assertDictEqual(res, {local_user: {}}) - @defer.inlineCallbacks def test_reupload_one_time_keys(self): """we should be able to re-upload the same keys""" local_user = "@boris:" + self.hs.hostname @@ -64,7 +48,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase): "alg2:k3": {"key": "key3"}, } - res = yield defer.ensureDeferred( + res = self.get_success( self.handler.upload_keys_for_user( local_user, device_id, {"one_time_keys": keys} ) @@ -73,14 +57,13 @@ class E2eKeysHandlerTestCase(unittest.TestCase): # we should be able to change the signature without a problem keys["alg2:k2"]["signatures"]["k1"] = "sig2" - res = yield defer.ensureDeferred( + res = self.get_success( self.handler.upload_keys_for_user( local_user, device_id, {"one_time_keys": keys} ) ) self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}}) - @defer.inlineCallbacks def test_change_one_time_keys(self): """attempts to change one-time-keys should be rejected""" @@ -92,75 +75,66 @@ class E2eKeysHandlerTestCase(unittest.TestCase): "alg2:k3": {"key": "key3"}, } - res = yield defer.ensureDeferred( + res = self.get_success( self.handler.upload_keys_for_user( local_user, device_id, {"one_time_keys": keys} ) ) self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}}) - try: - yield defer.ensureDeferred( - self.handler.upload_keys_for_user( - local_user, device_id, {"one_time_keys": {"alg1:k1": "key2"}} - ) - ) - self.fail("No error when changing string key") - except errors.SynapseError: - pass - - try: - yield defer.ensureDeferred( - self.handler.upload_keys_for_user( - local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}} - ) - ) - self.fail("No error when replacing dict key with string") - except errors.SynapseError: - pass - - try: - yield defer.ensureDeferred( - self.handler.upload_keys_for_user( - local_user, - device_id, - {"one_time_keys": {"alg1:k1": {"key": "key"}}}, - ) - ) - self.fail("No error when replacing string key with dict") - except errors.SynapseError: - pass - - try: - yield defer.ensureDeferred( - self.handler.upload_keys_for_user( - local_user, - device_id, - { - "one_time_keys": { - "alg2:k2": {"key": "key3", "signatures": {"k1": "sig1"}} - } - }, - ) - ) - self.fail("No error when replacing dict key") - except errors.SynapseError: - pass + # Error when changing string key + self.get_failure( + self.handler.upload_keys_for_user( + local_user, device_id, {"one_time_keys": {"alg1:k1": "key2"}} + ), + SynapseError, + ) + + # Error when replacing dict key with strin + self.get_failure( + self.handler.upload_keys_for_user( + local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}} + ), + SynapseError, + ) + + # Error when replacing string key with dict + self.get_failure( + self.handler.upload_keys_for_user( + local_user, + device_id, + {"one_time_keys": {"alg1:k1": {"key": "key"}}}, + ), + SynapseError, + ) + + # Error when replacing dict key + self.get_failure( + self.handler.upload_keys_for_user( + local_user, + device_id, + { + "one_time_keys": { + "alg2:k2": {"key": "key3", "signatures": {"k1": "sig1"}} + } + }, + ), + SynapseError, + ) - @defer.inlineCallbacks def test_claim_one_time_key(self): local_user = "@boris:" + self.hs.hostname device_id = "xyz" keys = {"alg1:k1": "key1"} - res = yield defer.ensureDeferred( + res = self.get_success( self.handler.upload_keys_for_user( local_user, device_id, {"one_time_keys": keys} ) ) self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1}}) - res2 = yield defer.ensureDeferred( + res2 = self.get_success( self.handler.claim_one_time_keys( {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None ) @@ -173,7 +147,6 @@ class E2eKeysHandlerTestCase(unittest.TestCase): }, ) - @defer.inlineCallbacks def test_fallback_key(self): local_user = "@boris:" + self.hs.hostname device_id = "xyz" @@ -181,12 +154,12 @@ class E2eKeysHandlerTestCase(unittest.TestCase): otk = {"alg1:k2": "key2"} # we shouldn't have any unused fallback keys yet - res = yield defer.ensureDeferred( + res = self.get_success( self.store.get_e2e_unused_fallback_key_types(local_user, device_id) ) self.assertEqual(res, []) - yield defer.ensureDeferred( + self.get_success( self.handler.upload_keys_for_user( local_user, device_id, @@ -195,14 +168,14 @@ class E2eKeysHandlerTestCase(unittest.TestCase): ) # we should now have an unused alg1 key - res = yield defer.ensureDeferred( + res = self.get_success( self.store.get_e2e_unused_fallback_key_types(local_user, device_id) ) self.assertEqual(res, ["alg1"]) # claiming an OTK when no OTKs are available should return the fallback # key - res = yield defer.ensureDeferred( + res = self.get_success( self.handler.claim_one_time_keys( {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None ) @@ -213,13 +186,13 @@ class E2eKeysHandlerTestCase(unittest.TestCase): ) # we shouldn't have any unused fallback keys again - res = yield defer.ensureDeferred( + res = self.get_success( self.store.get_e2e_unused_fallback_key_types(local_user, device_id) ) self.assertEqual(res, []) # claiming an OTK again should return the same fallback key - res = yield defer.ensureDeferred( + res = self.get_success( self.handler.claim_one_time_keys( {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None ) @@ -231,22 +204,23 @@ class E2eKeysHandlerTestCase(unittest.TestCase): # if the user uploads a one-time key, the next claim should fetch the # one-time key, and then go back to the fallback - yield defer.ensureDeferred( + self.get_success( self.handler.upload_keys_for_user( local_user, device_id, {"one_time_keys": otk} ) ) - res = yield defer.ensureDeferred( + res = self.get_success( self.handler.claim_one_time_keys( {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None ) ) self.assertEqual( - res, {"failures": {}, "one_time_keys": {local_user: {device_id: otk}}}, + res, + {"failures": {}, "one_time_keys": {local_user: {device_id: otk}}}, ) - res = yield defer.ensureDeferred( + res = self.get_success( self.handler.claim_one_time_keys( {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None ) @@ -256,7 +230,6 @@ class E2eKeysHandlerTestCase(unittest.TestCase): {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}}, ) - @defer.inlineCallbacks def test_replace_master_key(self): """uploading a new signing key should make the old signing key unavailable""" local_user = "@boris:" + self.hs.hostname @@ -270,9 +243,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase): }, } } - yield defer.ensureDeferred( - self.handler.upload_signing_keys_for_user(local_user, keys1) - ) + self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys1)) keys2 = { "master_key": { @@ -284,16 +255,13 @@ class E2eKeysHandlerTestCase(unittest.TestCase): }, } } - yield defer.ensureDeferred( - self.handler.upload_signing_keys_for_user(local_user, keys2) - ) + self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys2)) - devices = yield defer.ensureDeferred( + devices = self.get_success( self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user) ) self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]}) - @defer.inlineCallbacks def test_reupload_signatures(self): """re-uploading a signature should not fail""" local_user = "@boris:" + self.hs.hostname @@ -326,9 +294,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase): "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk", "2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0", ) - yield defer.ensureDeferred( - self.handler.upload_signing_keys_for_user(local_user, keys1) - ) + self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys1)) # upload two device keys, which will be signed later by the self-signing key device_key_1 = { @@ -358,12 +324,12 @@ class E2eKeysHandlerTestCase(unittest.TestCase): "signatures": {local_user: {"ed25519:def": "base64+signature"}}, } - yield defer.ensureDeferred( + self.get_success( self.handler.upload_keys_for_user( local_user, "abc", {"device_keys": device_key_1} ) ) - yield defer.ensureDeferred( + self.get_success( self.handler.upload_keys_for_user( local_user, "def", {"device_keys": device_key_2} ) @@ -372,7 +338,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase): # sign the first device key and upload it del device_key_1["signatures"] sign.sign_json(device_key_1, local_user, signing_key) - yield defer.ensureDeferred( + self.get_success( self.handler.upload_signatures_for_device_keys( local_user, {local_user: {"abc": device_key_1}} ) @@ -383,7 +349,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase): # signature for it del device_key_2["signatures"] sign.sign_json(device_key_2, local_user, signing_key) - yield defer.ensureDeferred( + self.get_success( self.handler.upload_signatures_for_device_keys( local_user, {local_user: {"abc": device_key_1, "def": device_key_2}} ) @@ -391,7 +357,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase): device_key_1["signatures"][local_user]["ed25519:abc"] = "base64+signature" device_key_2["signatures"][local_user]["ed25519:def"] = "base64+signature" - devices = yield defer.ensureDeferred( + devices = self.get_success( self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user) ) del devices["device_keys"][local_user]["abc"]["unsigned"] @@ -399,7 +365,6 @@ class E2eKeysHandlerTestCase(unittest.TestCase): self.assertDictEqual(devices["device_keys"][local_user]["abc"], device_key_1) self.assertDictEqual(devices["device_keys"][local_user]["def"], device_key_2) - @defer.inlineCallbacks def test_self_signing_key_doesnt_show_up_as_device(self): """signing keys should be hidden when fetching a user's devices""" local_user = "@boris:" + self.hs.hostname @@ -413,29 +378,22 @@ class E2eKeysHandlerTestCase(unittest.TestCase): }, } } - yield defer.ensureDeferred( - self.handler.upload_signing_keys_for_user(local_user, keys1) - ) - - res = None - try: - yield defer.ensureDeferred( - self.hs.get_device_handler().check_device_registered( - user_id=local_user, - device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk", - initial_device_display_name="new display name", - ) - ) - except errors.SynapseError as e: - res = e.code - self.assertEqual(res, 400) + self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys1)) - res = yield defer.ensureDeferred( - self.handler.query_local_devices({local_user: None}) + e = self.get_failure( + self.hs.get_device_handler().check_device_registered( + user_id=local_user, + device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk", + initial_device_display_name="new display name", + ), + SynapseError, ) + res = e.value.code + self.assertEqual(res, 400) + + res = self.get_success(self.handler.query_local_devices({local_user: None})) self.assertDictEqual(res, {local_user: {}}) - @defer.inlineCallbacks def test_upload_signatures(self): """should check signatures that are uploaded""" # set up a user with cross-signing keys and a device. This user will @@ -458,7 +416,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase): "ed25519", "xyz", "OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA" ) - yield defer.ensureDeferred( + self.get_success( self.handler.upload_keys_for_user( local_user, device_id, {"device_keys": device_key} ) @@ -501,7 +459,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase): "user_signing_key": usersigning_key, "self_signing_key": selfsigning_key, } - yield defer.ensureDeferred( + self.get_success( self.handler.upload_signing_keys_for_user(local_user, cross_signing_keys) ) @@ -515,14 +473,14 @@ class E2eKeysHandlerTestCase(unittest.TestCase): "usage": ["master"], "keys": {"ed25519:" + other_master_pubkey: other_master_pubkey}, } - yield defer.ensureDeferred( + self.get_success( self.handler.upload_signing_keys_for_user( other_user, {"master_key": other_master_key} ) ) # test various signature failures (see below) - ret = yield defer.ensureDeferred( + ret = self.get_success( self.handler.upload_signatures_for_device_keys( local_user, { @@ -602,20 +560,16 @@ class E2eKeysHandlerTestCase(unittest.TestCase): ) user_failures = ret["failures"][local_user] + self.assertEqual(user_failures[device_id]["errcode"], Codes.INVALID_SIGNATURE) self.assertEqual( - user_failures[device_id]["errcode"], errors.Codes.INVALID_SIGNATURE + user_failures[master_pubkey]["errcode"], Codes.INVALID_SIGNATURE ) - self.assertEqual( - user_failures[master_pubkey]["errcode"], errors.Codes.INVALID_SIGNATURE - ) - self.assertEqual(user_failures["unknown"]["errcode"], errors.Codes.NOT_FOUND) + self.assertEqual(user_failures["unknown"]["errcode"], Codes.NOT_FOUND) other_user_failures = ret["failures"][other_user] + self.assertEqual(other_user_failures["unknown"]["errcode"], Codes.NOT_FOUND) self.assertEqual( - other_user_failures["unknown"]["errcode"], errors.Codes.NOT_FOUND - ) - self.assertEqual( - other_user_failures[other_master_pubkey]["errcode"], errors.Codes.UNKNOWN + other_user_failures[other_master_pubkey]["errcode"], Codes.UNKNOWN ) # test successful signatures @@ -623,7 +577,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase): sign.sign_json(device_key, local_user, selfsigning_signing_key) sign.sign_json(master_key, local_user, device_signing_key) sign.sign_json(other_master_key, local_user, usersigning_signing_key) - ret = yield defer.ensureDeferred( + ret = self.get_success( self.handler.upload_signatures_for_device_keys( local_user, { @@ -636,7 +590,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase): self.assertEqual(ret["failures"], {}) # fetch the signed keys/devices and make sure that the signatures are there - ret = yield defer.ensureDeferred( + ret = self.get_success( self.handler.query_devices( {"device_keys": {local_user: [], other_user: []}}, 0, local_user ) diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py index 45f201a399..d7498aa51a 100644 --- a/tests/handlers/test_e2e_room_keys.py +++ b/tests/handlers/test_e2e_room_keys.py @@ -19,14 +19,9 @@ import copy import mock -from twisted.internet import defer +from synapse.api.errors import SynapseError -import synapse.api.errors -import synapse.handlers.e2e_room_keys -import synapse.storage -from synapse.api import errors - -from tests import unittest, utils +from tests import unittest # sample room_key data for use in the tests room_keys = { @@ -45,51 +40,38 @@ room_keys = { } -class E2eRoomKeysHandlerTestCase(unittest.TestCase): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.hs = None # type: synapse.server.HomeServer - self.handler = None # type: synapse.handlers.e2e_keys.E2eRoomKeysHandler +class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): + def make_homeserver(self, reactor, clock): + return self.setup_test_homeserver(replication_layer=mock.Mock()) - @defer.inlineCallbacks - def setUp(self): - self.hs = yield utils.setup_test_homeserver( - self.addCleanup, replication_layer=mock.Mock() - ) - self.handler = synapse.handlers.e2e_room_keys.E2eRoomKeysHandler(self.hs) - self.local_user = "@boris:" + self.hs.hostname + def prepare(self, reactor, clock, hs): + self.handler = hs.get_e2e_room_keys_handler() + self.local_user = "@boris:" + hs.hostname - @defer.inlineCallbacks def test_get_missing_current_version_info(self): """Check that we get a 404 if we ask for info about the current version if there is no version. """ - res = None - try: - yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) - except errors.SynapseError as e: - res = e.code + e = self.get_failure( + self.handler.get_version_info(self.local_user), SynapseError + ) + res = e.value.code self.assertEqual(res, 404) - @defer.inlineCallbacks def test_get_missing_version_info(self): """Check that we get a 404 if we ask for info about a specific version if it doesn't exist. """ - res = None - try: - yield defer.ensureDeferred( - self.handler.get_version_info(self.local_user, "bogus_version") - ) - except errors.SynapseError as e: - res = e.code + e = self.get_failure( + self.handler.get_version_info(self.local_user, "bogus_version"), + SynapseError, + ) + res = e.value.code self.assertEqual(res, 404) - @defer.inlineCallbacks def test_create_version(self): - """Check that we can create and then retrieve versions. - """ - res = yield defer.ensureDeferred( + """Check that we can create and then retrieve versions.""" + res = self.get_success( self.handler.create_version( self.local_user, { @@ -101,7 +83,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): self.assertEqual(res, "1") # check we can retrieve it as the current version - res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) + res = self.get_success(self.handler.get_version_info(self.local_user)) version_etag = res["etag"] self.assertIsInstance(version_etag, str) del res["etag"] @@ -116,9 +98,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): ) # check we can retrieve it as a specific version - res = yield defer.ensureDeferred( - self.handler.get_version_info(self.local_user, "1") - ) + res = self.get_success(self.handler.get_version_info(self.local_user, "1")) self.assertEqual(res["etag"], version_etag) del res["etag"] self.assertDictEqual( @@ -132,7 +112,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): ) # upload a new one... - res = yield defer.ensureDeferred( + res = self.get_success( self.handler.create_version( self.local_user, { @@ -144,7 +124,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): self.assertEqual(res, "2") # check we can retrieve it as the current version - res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) + res = self.get_success(self.handler.get_version_info(self.local_user)) del res["etag"] self.assertDictEqual( res, @@ -156,11 +136,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): }, ) - @defer.inlineCallbacks def test_update_version(self): - """Check that we can update versions. - """ - version = yield defer.ensureDeferred( + """Check that we can update versions.""" + version = self.get_success( self.handler.create_version( self.local_user, { @@ -171,7 +149,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): ) self.assertEqual(version, "1") - res = yield defer.ensureDeferred( + res = self.get_success( self.handler.update_version( self.local_user, version, @@ -185,7 +163,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): self.assertDictEqual(res, {}) # check we can retrieve it as the current version - res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) + res = self.get_success(self.handler.get_version_info(self.local_user)) del res["etag"] self.assertDictEqual( res, @@ -197,32 +175,26 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): }, ) - @defer.inlineCallbacks def test_update_missing_version(self): - """Check that we get a 404 on updating nonexistent versions - """ - res = None - try: - yield defer.ensureDeferred( - self.handler.update_version( - self.local_user, - "1", - { - "algorithm": "m.megolm_backup.v1", - "auth_data": "revised_first_version_auth_data", - "version": "1", - }, - ) - ) - except errors.SynapseError as e: - res = e.code + """Check that we get a 404 on updating nonexistent versions""" + e = self.get_failure( + self.handler.update_version( + self.local_user, + "1", + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "revised_first_version_auth_data", + "version": "1", + }, + ), + SynapseError, + ) + res = e.value.code self.assertEqual(res, 404) - @defer.inlineCallbacks def test_update_omitted_version(self): - """Check that the update succeeds if the version is missing from the body - """ - version = yield defer.ensureDeferred( + """Check that the update succeeds if the version is missing from the body""" + version = self.get_success( self.handler.create_version( self.local_user, { @@ -233,7 +205,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): ) self.assertEqual(version, "1") - yield defer.ensureDeferred( + self.get_success( self.handler.update_version( self.local_user, version, @@ -245,7 +217,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): ) # check we can retrieve it as the current version - res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) + res = self.get_success(self.handler.get_version_info(self.local_user)) del res["etag"] # etag is opaque, so don't test its contents self.assertDictEqual( res, @@ -257,11 +229,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): }, ) - @defer.inlineCallbacks def test_update_bad_version(self): - """Check that we get a 400 if the version in the body doesn't match - """ - version = yield defer.ensureDeferred( + """Check that we get a 400 if the version in the body doesn't match""" + version = self.get_success( self.handler.create_version( self.local_user, { @@ -272,52 +242,38 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): ) self.assertEqual(version, "1") - res = None - try: - yield defer.ensureDeferred( - self.handler.update_version( - self.local_user, - version, - { - "algorithm": "m.megolm_backup.v1", - "auth_data": "revised_first_version_auth_data", - "version": "incorrect", - }, - ) - ) - except errors.SynapseError as e: - res = e.code + e = self.get_failure( + self.handler.update_version( + self.local_user, + version, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "revised_first_version_auth_data", + "version": "incorrect", + }, + ), + SynapseError, + ) + res = e.value.code self.assertEqual(res, 400) - @defer.inlineCallbacks def test_delete_missing_version(self): - """Check that we get a 404 on deleting nonexistent versions - """ - res = None - try: - yield defer.ensureDeferred( - self.handler.delete_version(self.local_user, "1") - ) - except errors.SynapseError as e: - res = e.code + """Check that we get a 404 on deleting nonexistent versions""" + e = self.get_failure( + self.handler.delete_version(self.local_user, "1"), SynapseError + ) + res = e.value.code self.assertEqual(res, 404) - @defer.inlineCallbacks def test_delete_missing_current_version(self): - """Check that we get a 404 on deleting nonexistent current version - """ - res = None - try: - yield defer.ensureDeferred(self.handler.delete_version(self.local_user)) - except errors.SynapseError as e: - res = e.code + """Check that we get a 404 on deleting nonexistent current version""" + e = self.get_failure(self.handler.delete_version(self.local_user), SynapseError) + res = e.value.code self.assertEqual(res, 404) - @defer.inlineCallbacks def test_delete_version(self): - """Check that we can create and then delete versions. - """ - res = yield defer.ensureDeferred( + """Check that we can create and then delete versions.""" + res = self.get_success( self.handler.create_version( self.local_user, { @@ -329,36 +285,26 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): self.assertEqual(res, "1") # check we can delete it - yield defer.ensureDeferred(self.handler.delete_version(self.local_user, "1")) + self.get_success(self.handler.delete_version(self.local_user, "1")) # check that it's gone - res = None - try: - yield defer.ensureDeferred( - self.handler.get_version_info(self.local_user, "1") - ) - except errors.SynapseError as e: - res = e.code + e = self.get_failure( + self.handler.get_version_info(self.local_user, "1"), SynapseError + ) + res = e.value.code self.assertEqual(res, 404) - @defer.inlineCallbacks def test_get_missing_backup(self): - """Check that we get a 404 on querying missing backup - """ - res = None - try: - yield defer.ensureDeferred( - self.handler.get_room_keys(self.local_user, "bogus_version") - ) - except errors.SynapseError as e: - res = e.code + """Check that we get a 404 on querying missing backup""" + e = self.get_failure( + self.handler.get_room_keys(self.local_user, "bogus_version"), SynapseError + ) + res = e.value.code self.assertEqual(res, 404) - @defer.inlineCallbacks def test_get_missing_room_keys(self): - """Check we get an empty response from an empty backup - """ - version = yield defer.ensureDeferred( + """Check we get an empty response from an empty backup""" + version = self.get_success( self.handler.create_version( self.local_user, { @@ -369,33 +315,26 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): ) self.assertEqual(version, "1") - res = yield defer.ensureDeferred( - self.handler.get_room_keys(self.local_user, version) - ) + res = self.get_success(self.handler.get_room_keys(self.local_user, version)) self.assertDictEqual(res, {"rooms": {}}) # TODO: test the locking semantics when uploading room_keys, # although this is probably best done in sytest - @defer.inlineCallbacks def test_upload_room_keys_no_versions(self): - """Check that we get a 404 on uploading keys when no versions are defined - """ - res = None - try: - yield defer.ensureDeferred( - self.handler.upload_room_keys(self.local_user, "no_version", room_keys) - ) - except errors.SynapseError as e: - res = e.code + """Check that we get a 404 on uploading keys when no versions are defined""" + e = self.get_failure( + self.handler.upload_room_keys(self.local_user, "no_version", room_keys), + SynapseError, + ) + res = e.value.code self.assertEqual(res, 404) - @defer.inlineCallbacks def test_upload_room_keys_bogus_version(self): """Check that we get a 404 on uploading keys when an nonexistent version is specified """ - version = yield defer.ensureDeferred( + version = self.get_success( self.handler.create_version( self.local_user, { @@ -406,22 +345,16 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): ) self.assertEqual(version, "1") - res = None - try: - yield defer.ensureDeferred( - self.handler.upload_room_keys( - self.local_user, "bogus_version", room_keys - ) - ) - except errors.SynapseError as e: - res = e.code + e = self.get_failure( + self.handler.upload_room_keys(self.local_user, "bogus_version", room_keys), + SynapseError, + ) + res = e.value.code self.assertEqual(res, 404) - @defer.inlineCallbacks def test_upload_room_keys_wrong_version(self): - """Check that we get a 403 on uploading keys for an old version - """ - version = yield defer.ensureDeferred( + """Check that we get a 403 on uploading keys for an old version""" + version = self.get_success( self.handler.create_version( self.local_user, { @@ -432,7 +365,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): ) self.assertEqual(version, "1") - version = yield defer.ensureDeferred( + version = self.get_success( self.handler.create_version( self.local_user, { @@ -443,20 +376,15 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): ) self.assertEqual(version, "2") - res = None - try: - yield defer.ensureDeferred( - self.handler.upload_room_keys(self.local_user, "1", room_keys) - ) - except errors.SynapseError as e: - res = e.code + e = self.get_failure( + self.handler.upload_room_keys(self.local_user, "1", room_keys), SynapseError + ) + res = e.value.code self.assertEqual(res, 403) - @defer.inlineCallbacks def test_upload_room_keys_insert(self): - """Check that we can insert and retrieve keys for a session - """ - version = yield defer.ensureDeferred( + """Check that we can insert and retrieve keys for a session""" + version = self.get_success( self.handler.create_version( self.local_user, { @@ -467,17 +395,15 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): ) self.assertEqual(version, "1") - yield defer.ensureDeferred( + self.get_success( self.handler.upload_room_keys(self.local_user, version, room_keys) ) - res = yield defer.ensureDeferred( - self.handler.get_room_keys(self.local_user, version) - ) + res = self.get_success(self.handler.get_room_keys(self.local_user, version)) self.assertDictEqual(res, room_keys) # check getting room_keys for a given room - res = yield defer.ensureDeferred( + res = self.get_success( self.handler.get_room_keys( self.local_user, version, room_id="!abc:matrix.org" ) @@ -485,18 +411,17 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): self.assertDictEqual(res, room_keys) # check getting room_keys for a given session_id - res = yield defer.ensureDeferred( + res = self.get_success( self.handler.get_room_keys( self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" ) ) self.assertDictEqual(res, room_keys) - @defer.inlineCallbacks def test_upload_room_keys_merge(self): """Check that we can upload a new room_key for an existing session and have it correctly merged""" - version = yield defer.ensureDeferred( + version = self.get_success( self.handler.create_version( self.local_user, { @@ -507,12 +432,12 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): ) self.assertEqual(version, "1") - yield defer.ensureDeferred( + self.get_success( self.handler.upload_room_keys(self.local_user, version, room_keys) ) # get the etag to compare to future versions - res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) + res = self.get_success(self.handler.get_version_info(self.local_user)) backup_etag = res["etag"] self.assertEqual(res["count"], 1) @@ -522,37 +447,33 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): # test that increasing the message_index doesn't replace the existing session new_room_key["first_message_index"] = 2 new_room_key["session_data"] = "new" - yield defer.ensureDeferred( + self.get_success( self.handler.upload_room_keys(self.local_user, version, new_room_keys) ) - res = yield defer.ensureDeferred( - self.handler.get_room_keys(self.local_user, version) - ) + res = self.get_success(self.handler.get_room_keys(self.local_user, version)) self.assertEqual( res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "SSBBTSBBIEZJU0gK", ) # the etag should be the same since the session did not change - res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) + res = self.get_success(self.handler.get_version_info(self.local_user)) self.assertEqual(res["etag"], backup_etag) # test that marking the session as verified however /does/ replace it new_room_key["is_verified"] = True - yield defer.ensureDeferred( + self.get_success( self.handler.upload_room_keys(self.local_user, version, new_room_keys) ) - res = yield defer.ensureDeferred( - self.handler.get_room_keys(self.local_user, version) - ) + res = self.get_success(self.handler.get_room_keys(self.local_user, version)) self.assertEqual( res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new" ) # the etag should NOT be equal now, since the key changed - res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) + res = self.get_success(self.handler.get_version_info(self.local_user)) self.assertNotEqual(res["etag"], backup_etag) backup_etag = res["etag"] @@ -560,28 +481,24 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): # with a lower forwarding count new_room_key["forwarded_count"] = 2 new_room_key["session_data"] = "other" - yield defer.ensureDeferred( + self.get_success( self.handler.upload_room_keys(self.local_user, version, new_room_keys) ) - res = yield defer.ensureDeferred( - self.handler.get_room_keys(self.local_user, version) - ) + res = self.get_success(self.handler.get_room_keys(self.local_user, version)) self.assertEqual( res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new" ) # the etag should be the same since the session did not change - res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) + res = self.get_success(self.handler.get_version_info(self.local_user)) self.assertEqual(res["etag"], backup_etag) # TODO: check edge cases as well as the common variations here - @defer.inlineCallbacks def test_delete_room_keys(self): - """Check that we can insert and delete keys for a session - """ - version = yield defer.ensureDeferred( + """Check that we can insert and delete keys for a session""" + version = self.get_success( self.handler.create_version( self.local_user, { @@ -593,13 +510,11 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): self.assertEqual(version, "1") # check for bulk-delete - yield defer.ensureDeferred( + self.get_success( self.handler.upload_room_keys(self.local_user, version, room_keys) ) - yield defer.ensureDeferred( - self.handler.delete_room_keys(self.local_user, version) - ) - res = yield defer.ensureDeferred( + self.get_success(self.handler.delete_room_keys(self.local_user, version)) + res = self.get_success( self.handler.get_room_keys( self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" ) @@ -607,15 +522,15 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): self.assertDictEqual(res, {"rooms": {}}) # check for bulk-delete per room - yield defer.ensureDeferred( + self.get_success( self.handler.upload_room_keys(self.local_user, version, room_keys) ) - yield defer.ensureDeferred( + self.get_success( self.handler.delete_room_keys( self.local_user, version, room_id="!abc:matrix.org" ) ) - res = yield defer.ensureDeferred( + res = self.get_success( self.handler.get_room_keys( self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" ) @@ -623,15 +538,15 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): self.assertDictEqual(res, {"rooms": {}}) # check for bulk-delete per session - yield defer.ensureDeferred( + self.get_success( self.handler.upload_room_keys(self.local_user, version, room_keys) ) - yield defer.ensureDeferred( + self.get_success( self.handler.delete_room_keys( self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" ) ) - res = yield defer.ensureDeferred( + res = self.get_success( self.handler.get_room_keys( self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" ) diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index 0b24b89a2e..3af361195b 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -16,7 +16,7 @@ import logging from unittest import TestCase from synapse.api.constants import EventTypes -from synapse.api.errors import AuthError, Codes, SynapseError +from synapse.api.errors import AuthError, Codes, LimitExceededError, SynapseError from synapse.api.room_versions import RoomVersions from synapse.events import EventBase from synapse.federation.federation_base import event_from_pdu_json @@ -191,6 +191,58 @@ class FederationTestCase(unittest.HomeserverTestCase): self.assertEqual(sg, sg2) + @unittest.override_config( + {"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}} + ) + def test_invite_by_user_ratelimit(self): + """Tests that invites from federation to a particular user are + actually rate-limited. + """ + other_server = "otherserver" + other_user = "@otheruser:" + other_server + + # create the room + user_id = self.register_user("kermit", "test") + tok = self.login("kermit", "test") + + def create_invite(): + room_id = self.helper.create_room_as(room_creator=user_id, tok=tok) + room_version = self.get_success(self.store.get_room_version(room_id)) + return event_from_pdu_json( + { + "type": EventTypes.Member, + "content": {"membership": "invite"}, + "room_id": room_id, + "sender": other_user, + "state_key": "@user:test", + "depth": 32, + "prev_events": [], + "auth_events": [], + "origin_server_ts": self.clock.time_msec(), + }, + room_version, + ) + + for i in range(3): + event = create_invite() + self.get_success( + self.handler.on_invite_request( + other_server, + event, + event.room_version, + ) + ) + + event = create_invite() + self.get_failure( + self.handler.on_invite_request( + other_server, + event, + event.room_version, + ), + exc=LimitExceededError, + ) + def _build_and_send_join_event(self, other_server, other_user, room_id): join_event = self.get_success( self.handler.on_make_join_request(other_server, room_id, other_user) diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py index f955dfa490..a0d1ebdbe3 100644 --- a/tests/handlers/test_message.py +++ b/tests/handlers/test_message.py @@ -44,7 +44,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase): self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token) self.info = self.get_success( - self.hs.get_datastore().get_user_by_access_token(self.access_token,) + self.hs.get_datastore().get_user_by_access_token( + self.access_token, + ) ) self.token_id = self.info.token_id @@ -169,8 +171,7 @@ class ServerAclValidationTestCase(unittest.HomeserverTestCase): self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token) def test_allow_server_acl(self): - """Test that sending an ACL that blocks everyone but ourselves works. - """ + """Test that sending an ACL that blocks everyone but ourselves works.""" self.helper.send_state( self.room_id, @@ -181,8 +182,7 @@ class ServerAclValidationTestCase(unittest.HomeserverTestCase): ) def test_deny_server_acl_block_outselves(self): - """Test that sending an ACL that blocks ourselves does not work. - """ + """Test that sending an ACL that blocks ourselves does not work.""" self.helper.send_state( self.room_id, EventTypes.ServerACL, @@ -192,8 +192,7 @@ class ServerAclValidationTestCase(unittest.HomeserverTestCase): ) def test_deny_redact_server_acl(self): - """Test that attempting to redact an ACL is blocked. - """ + """Test that attempting to redact an ACL is blocked.""" body = self.helper.send_state( self.room_id, diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index b3dfa40d25..cf1de28fa9 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -24,7 +24,7 @@ from synapse.handlers.sso import MappingException from synapse.server import HomeServer from synapse.types import UserID -from tests.test_utils import FakeResponse, simple_async_mock +from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock from tests.unittest import HomeserverTestCase, override_config try: @@ -40,7 +40,7 @@ ISSUER = "https://issuer/" CLIENT_ID = "test-client-id" CLIENT_SECRET = "test-client-secret" BASE_URL = "https://synapse/" -CALLBACK_URL = BASE_URL + "_synapse/oidc/callback" +CALLBACK_URL = BASE_URL + "_synapse/client/oidc/callback" SCOPES = ["openid"] AUTHORIZATION_ENDPOINT = ISSUER + "authorize" @@ -58,12 +58,6 @@ COMMON_CONFIG = { } -# The cookie name and path don't really matter, just that it has to be coherent -# between the callback & redirect handlers. -COOKIE_NAME = b"oidc_session" -COOKIE_PATH = "/_synapse/oidc" - - class TestMappingProvider: @staticmethod def parse_config(config): @@ -137,7 +131,6 @@ class OidcHandlerTestCase(HomeserverTestCase): return config def make_homeserver(self, reactor, clock): - self.http_client = Mock(spec=["get_json"]) self.http_client.get_json.side_effect = get_json self.http_client.user_agent = "Synapse Test" @@ -157,7 +150,15 @@ class OidcHandlerTestCase(HomeserverTestCase): return hs def metadata_edit(self, values): - return patch.dict(self.provider._provider_metadata, values) + """Modify the result that will be returned by the well-known query""" + + async def patched_get_json(uri): + res = await get_json(uri) + if uri == WELL_KNOWN: + res.update(values) + return res + + return patch.object(self.http_client, "get_json", patched_get_json) def assertRenderedError(self, error, error_description=None): self.render_error.assert_called_once() @@ -218,7 +219,14 @@ class OidcHandlerTestCase(HomeserverTestCase): self.http_client.get_json.assert_called_once_with(JWKS_URI) # Throw if the JWKS uri is missing - with self.metadata_edit({"jwks_uri": None}): + original = self.provider.load_metadata + + async def patched_load_metadata(): + m = (await original()).copy() + m.update({"jwks_uri": None}) + return m + + with patch.object(self.provider, "load_metadata", patched_load_metadata): self.get_failure(self.provider.load_jwks(force=True), RuntimeError) # Return empty key set if JWKS are not used @@ -228,55 +236,60 @@ class OidcHandlerTestCase(HomeserverTestCase): self.http_client.get_json.assert_not_called() self.assertEqual(jwks, {"keys": []}) - @override_config({"oidc_config": COMMON_CONFIG}) def test_validate_config(self): """Provider metadatas are extensively validated.""" h = self.provider + def force_load_metadata(): + async def force_load(): + return await h.load_metadata(force=True) + + return get_awaitable_result(force_load()) + # Default test config does not throw - h._validate_metadata() + force_load_metadata() with self.metadata_edit({"issuer": None}): - self.assertRaisesRegex(ValueError, "issuer", h._validate_metadata) + self.assertRaisesRegex(ValueError, "issuer", force_load_metadata) with self.metadata_edit({"issuer": "http://insecure/"}): - self.assertRaisesRegex(ValueError, "issuer", h._validate_metadata) + self.assertRaisesRegex(ValueError, "issuer", force_load_metadata) with self.metadata_edit({"issuer": "https://invalid/?because=query"}): - self.assertRaisesRegex(ValueError, "issuer", h._validate_metadata) + self.assertRaisesRegex(ValueError, "issuer", force_load_metadata) with self.metadata_edit({"authorization_endpoint": None}): self.assertRaisesRegex( - ValueError, "authorization_endpoint", h._validate_metadata + ValueError, "authorization_endpoint", force_load_metadata ) with self.metadata_edit({"authorization_endpoint": "http://insecure/auth"}): self.assertRaisesRegex( - ValueError, "authorization_endpoint", h._validate_metadata + ValueError, "authorization_endpoint", force_load_metadata ) with self.metadata_edit({"token_endpoint": None}): - self.assertRaisesRegex(ValueError, "token_endpoint", h._validate_metadata) + self.assertRaisesRegex(ValueError, "token_endpoint", force_load_metadata) with self.metadata_edit({"token_endpoint": "http://insecure/token"}): - self.assertRaisesRegex(ValueError, "token_endpoint", h._validate_metadata) + self.assertRaisesRegex(ValueError, "token_endpoint", force_load_metadata) with self.metadata_edit({"jwks_uri": None}): - self.assertRaisesRegex(ValueError, "jwks_uri", h._validate_metadata) + self.assertRaisesRegex(ValueError, "jwks_uri", force_load_metadata) with self.metadata_edit({"jwks_uri": "http://insecure/jwks.json"}): - self.assertRaisesRegex(ValueError, "jwks_uri", h._validate_metadata) + self.assertRaisesRegex(ValueError, "jwks_uri", force_load_metadata) with self.metadata_edit({"response_types_supported": ["id_token"]}): self.assertRaisesRegex( - ValueError, "response_types_supported", h._validate_metadata + ValueError, "response_types_supported", force_load_metadata ) with self.metadata_edit( {"token_endpoint_auth_methods_supported": ["client_secret_basic"]} ): # should not throw, as client_secret_basic is the default auth method - h._validate_metadata() + force_load_metadata() with self.metadata_edit( {"token_endpoint_auth_methods_supported": ["client_secret_post"]} @@ -284,7 +297,7 @@ class OidcHandlerTestCase(HomeserverTestCase): self.assertRaisesRegex( ValueError, "token_endpoint_auth_methods_supported", - h._validate_metadata, + force_load_metadata, ) # Tests for configs that require the userinfo endpoint @@ -293,28 +306,30 @@ class OidcHandlerTestCase(HomeserverTestCase): h._user_profile_method = "userinfo_endpoint" self.assertTrue(h._uses_userinfo) - # Revert the profile method and do not request the "openid" scope. + # Revert the profile method and do not request the "openid" scope: this should + # mean that we check for a userinfo endpoint h._user_profile_method = "auto" h._scopes = [] self.assertTrue(h._uses_userinfo) - self.assertRaisesRegex(ValueError, "userinfo_endpoint", h._validate_metadata) + with self.metadata_edit({"userinfo_endpoint": None}): + self.assertRaisesRegex(ValueError, "userinfo_endpoint", force_load_metadata) - with self.metadata_edit( - {"userinfo_endpoint": USERINFO_ENDPOINT, "jwks_uri": None} - ): - # Shouldn't raise with a valid userinfo, even without - h._validate_metadata() + with self.metadata_edit({"jwks_uri": None}): + # Shouldn't raise with a valid userinfo, even without jwks + force_load_metadata() @override_config({"oidc_config": {"skip_verification": True}}) def test_skip_verification(self): """Provider metadata validation can be disabled by config.""" with self.metadata_edit({"issuer": "http://insecure"}): # This should not throw - self.provider._validate_metadata() + get_awaitable_result(self.provider.load_metadata()) def test_redirect_request(self): """The redirect request has the right arguments & generates a valid session cookie.""" - req = Mock(spec=["addCookie"]) + req = Mock(spec=["cookies"]) + req.cookies = [] + url = self.get_success( self.provider.handle_redirect_request(req, b"http://client/redirect") ) @@ -333,16 +348,16 @@ class OidcHandlerTestCase(HomeserverTestCase): self.assertEqual(len(params["state"]), 1) self.assertEqual(len(params["nonce"]), 1) - # Check what is in the cookie - # note: python3.5 mock does not have the .called_once() method - calls = req.addCookie.call_args_list - self.assertEqual(len(calls), 1) # called once - # For some reason, call.args does not work with python3.5 - args = calls[0][0] - kwargs = calls[0][1] - self.assertEqual(args[0], COOKIE_NAME) - self.assertEqual(kwargs["path"], COOKIE_PATH) - cookie = args[1] + # Check what is in the cookies + self.assertEqual(len(req.cookies), 2) # two cookies + cookie_header = req.cookies[0] + + # The cookie name and path don't really matter, just that it has to be coherent + # between the callback & redirect handlers. + parts = [p.strip() for p in cookie_header.split(b";")] + self.assertIn(b"Path=/_synapse/client/oidc", parts) + name, cookie = parts[0].split(b"=") + self.assertEqual(name, b"oidc_session") macaroon = pymacaroons.Macaroon.deserialize(cookie) state = self.handler._token_generator._get_value_from_macaroon( @@ -419,7 +434,7 @@ class OidcHandlerTestCase(HomeserverTestCase): self.get_success(self.handler.handle_oidc_callback(request)) auth_handler.complete_sso_login.assert_called_once_with( - expected_user_id, request, client_redirect_url, None, + expected_user_id, request, client_redirect_url, None, new_user=True ) self.provider._exchange_code.assert_called_once_with(code) self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce) @@ -450,7 +465,7 @@ class OidcHandlerTestCase(HomeserverTestCase): self.get_success(self.handler.handle_oidc_callback(request)) auth_handler.complete_sso_login.assert_called_once_with( - expected_user_id, request, client_redirect_url, None, + expected_user_id, request, client_redirect_url, None, new_user=False ) self.provider._exchange_code.assert_called_once_with(code) self.provider._parse_id_token.assert_not_called() @@ -473,7 +488,7 @@ class OidcHandlerTestCase(HomeserverTestCase): def test_callback_session(self): """The callback verifies the session presence and validity""" - request = Mock(spec=["args", "getCookie", "addCookie"]) + request = Mock(spec=["args", "getCookie", "cookies"]) # Missing cookie request.args = {} @@ -496,7 +511,9 @@ class OidcHandlerTestCase(HomeserverTestCase): # Mismatching session session = self._generate_oidc_session_token( - state="state", nonce="nonce", client_redirect_url="http://client/redirect", + state="state", + nonce="nonce", + client_redirect_url="http://client/redirect", ) request.args = {} request.args[b"state"] = [b"mismatching state"] @@ -551,7 +568,9 @@ class OidcHandlerTestCase(HomeserverTestCase): # Internal server error with no JSON body self.http_client.request = simple_async_mock( return_value=FakeResponse( - code=500, phrase=b"Internal Server Error", body=b"Not JSON", + code=500, + phrase=b"Internal Server Error", + body=b"Not JSON", ) ) exc = self.get_failure(self.provider._exchange_code(code), OidcError) @@ -571,7 +590,11 @@ class OidcHandlerTestCase(HomeserverTestCase): # 4xx error without "error" field self.http_client.request = simple_async_mock( - return_value=FakeResponse(code=400, phrase=b"Bad request", body=b"{}",) + return_value=FakeResponse( + code=400, + phrase=b"Bad request", + body=b"{}", + ) ) exc = self.get_failure(self.provider._exchange_code(code), OidcError) self.assertEqual(exc.value.error, "server_error") @@ -579,7 +602,9 @@ class OidcHandlerTestCase(HomeserverTestCase): # 2xx error with "error" field self.http_client.request = simple_async_mock( return_value=FakeResponse( - code=200, phrase=b"OK", body=b'{"error": "some_error"}', + code=200, + phrase=b"OK", + body=b'{"error": "some_error"}', ) ) exc = self.get_failure(self.provider._exchange_code(code), OidcError) @@ -616,14 +641,20 @@ class OidcHandlerTestCase(HomeserverTestCase): state = "state" client_redirect_url = "http://client/redirect" session = self._generate_oidc_session_token( - state=state, nonce="nonce", client_redirect_url=client_redirect_url, + state=state, + nonce="nonce", + client_redirect_url=client_redirect_url, ) request = _build_callback_request("code", state, session) self.get_success(self.handler.handle_oidc_callback(request)) auth_handler.complete_sso_login.assert_called_once_with( - "@foo:test", request, client_redirect_url, {"phone": "1234567"}, + "@foo:test", + request, + client_redirect_url, + {"phone": "1234567"}, + new_user=True, ) def test_map_userinfo_to_user(self): @@ -637,7 +668,7 @@ class OidcHandlerTestCase(HomeserverTestCase): } self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", ANY, ANY, None, + "@test_user:test", ANY, ANY, None, new_user=True ) auth_handler.complete_sso_login.reset_mock() @@ -648,7 +679,7 @@ class OidcHandlerTestCase(HomeserverTestCase): } self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - "@test_user_2:test", ANY, ANY, None, + "@test_user_2:test", ANY, ANY, None, new_user=True ) auth_handler.complete_sso_login.reset_mock() @@ -685,14 +716,14 @@ class OidcHandlerTestCase(HomeserverTestCase): } self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - user.to_string(), ANY, ANY, None, + user.to_string(), ANY, ANY, None, new_user=False ) auth_handler.complete_sso_login.reset_mock() # Subsequent calls should map to the same mxid. self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - user.to_string(), ANY, ANY, None, + user.to_string(), ANY, ANY, None, new_user=False ) auth_handler.complete_sso_login.reset_mock() @@ -707,7 +738,7 @@ class OidcHandlerTestCase(HomeserverTestCase): } self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - user.to_string(), ANY, ANY, None, + user.to_string(), ANY, ANY, None, new_user=False ) auth_handler.complete_sso_login.reset_mock() @@ -743,7 +774,7 @@ class OidcHandlerTestCase(HomeserverTestCase): self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - "@TEST_USER_2:test", ANY, ANY, None, + "@TEST_USER_2:test", ANY, ANY, None, new_user=False ) def test_map_userinfo_to_invalid_localpart(self): @@ -779,7 +810,7 @@ class OidcHandlerTestCase(HomeserverTestCase): # test_user is already taken, so test_user1 gets registered instead. auth_handler.complete_sso_login.assert_called_once_with( - "@test_user1:test", ANY, ANY, None, + "@test_user1:test", ANY, ANY, None, new_user=True ) auth_handler.complete_sso_login.reset_mock() @@ -875,7 +906,9 @@ async def _make_callback_with_userinfo( session = handler._token_generator.generate_oidc_session_token( state=state, session_data=OidcSessionData( - idp_id="oidc", nonce="nonce", client_redirect_url=client_redirect_url, + idp_id="oidc", + nonce="nonce", + client_redirect_url=client_redirect_url, ), ) request = _build_callback_request("code", state, session) @@ -909,13 +942,14 @@ def _build_callback_request( spec=[ "args", "getCookie", - "addCookie", + "cookies", "requestHeaders", "getClientIP", "getHeader", ] ) + request.cookies = [] request.getCookie.return_value = session request.args = {} request.args[b"code"] = [code.encode("utf-8")] diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py index f816594ee4..a98a65ae67 100644 --- a/tests/handlers/test_password_providers.py +++ b/tests/handlers/test_password_providers.py @@ -231,8 +231,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): } ) def test_no_local_user_fallback_login(self): - """localdb_enabled can block login with the local password - """ + """localdb_enabled can block login with the local password""" self.register_user("localuser", "localpass") # check_password must return an awaitable @@ -251,8 +250,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): } ) def test_no_local_user_fallback_ui_auth(self): - """localdb_enabled can block ui auth with the local password - """ + """localdb_enabled can block ui auth with the local password""" self.register_user("localuser", "localpass") # allow login via the auth provider @@ -594,7 +592,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): ) def _delete_device( - self, access_token: str, device: str, body: Union[JsonDict, bytes] = b"", + self, + access_token: str, + device: str, + body: Union[JsonDict, bytes] = b"", ) -> FakeChannel: """Delete an individual device.""" channel = self.make_request( diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index 0794b32c9c..be2ee26f07 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -589,8 +589,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): ) def _add_new_user(self, room_id, user_id): - """Add new user to the room by creating an event and poking the federation API. - """ + """Add new user to the room by creating an event and poking the federation API.""" hostname = get_domain_from_id(user_id) diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 022943a10a..18ca8b84f5 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -13,25 +13,20 @@ # See the License for the specific language governing permissions and # limitations under the License. - from mock import Mock -from twisted.internet import defer - import synapse.types from synapse.api.errors import AuthError, SynapseError from synapse.types import UserID from tests import unittest from tests.test_utils import make_awaitable -from tests.utils import setup_test_homeserver -class ProfileTestCase(unittest.TestCase): +class ProfileTestCase(unittest.HomeserverTestCase): """ Tests profile management. """ - @defer.inlineCallbacks - def setUp(self): + def make_homeserver(self, reactor, clock): self.mock_federation = Mock() self.mock_registry = Mock() @@ -42,39 +37,35 @@ class ProfileTestCase(unittest.TestCase): self.mock_registry.register_query_handler = register_query_handler - hs = yield setup_test_homeserver( - self.addCleanup, + hs = self.setup_test_homeserver( federation_client=self.mock_federation, federation_server=Mock(), federation_registry=self.mock_registry, ) + return hs + def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() self.frank = UserID.from_string("@1234ABCD:test") self.bob = UserID.from_string("@4567:test") self.alice = UserID.from_string("@alice:remote") - yield defer.ensureDeferred(self.store.create_profile(self.frank.localpart)) + self.get_success(self.store.create_profile(self.frank.localpart)) self.handler = hs.get_profile_handler() - self.hs = hs - @defer.inlineCallbacks def test_get_my_name(self): - yield defer.ensureDeferred( + self.get_success( self.store.set_profile_displayname(self.frank.localpart, "Frank") ) - displayname = yield defer.ensureDeferred( - self.handler.get_displayname(self.frank) - ) + displayname = self.get_success(self.handler.get_displayname(self.frank)) self.assertEquals("Frank", displayname) - @defer.inlineCallbacks def test_set_my_name(self): - yield defer.ensureDeferred( + self.get_success( self.handler.set_displayname( self.frank, synapse.types.create_requester(self.frank), "Frank Jr." ) @@ -82,7 +73,7 @@ class ProfileTestCase(unittest.TestCase): self.assertEquals( ( - yield defer.ensureDeferred( + self.get_success( self.store.get_profile_displayname(self.frank.localpart) ) ), @@ -90,7 +81,7 @@ class ProfileTestCase(unittest.TestCase): ) # Set displayname again - yield defer.ensureDeferred( + self.get_success( self.handler.set_displayname( self.frank, synapse.types.create_requester(self.frank), "Frank" ) @@ -98,7 +89,7 @@ class ProfileTestCase(unittest.TestCase): self.assertEquals( ( - yield defer.ensureDeferred( + self.get_success( self.store.get_profile_displayname(self.frank.localpart) ) ), @@ -106,32 +97,27 @@ class ProfileTestCase(unittest.TestCase): ) # Set displayname to an empty string - yield defer.ensureDeferred( + self.get_success( self.handler.set_displayname( self.frank, synapse.types.create_requester(self.frank), "" ) ) self.assertIsNone( - ( - yield defer.ensureDeferred( - self.store.get_profile_displayname(self.frank.localpart) - ) - ) + (self.get_success(self.store.get_profile_displayname(self.frank.localpart))) ) - @defer.inlineCallbacks def test_set_my_name_if_disabled(self): self.hs.config.enable_set_displayname = False # Setting displayname for the first time is allowed - yield defer.ensureDeferred( + self.get_success( self.store.set_profile_displayname(self.frank.localpart, "Frank") ) self.assertEquals( ( - yield defer.ensureDeferred( + self.get_success( self.store.get_profile_displayname(self.frank.localpart) ) ), @@ -139,33 +125,27 @@ class ProfileTestCase(unittest.TestCase): ) # Setting displayname a second time is forbidden - d = defer.ensureDeferred( + self.get_failure( self.handler.set_displayname( self.frank, synapse.types.create_requester(self.frank), "Frank Jr." - ) + ), + SynapseError, ) - yield self.assertFailure(d, SynapseError) - - @defer.inlineCallbacks def test_set_my_name_noauth(self): - d = defer.ensureDeferred( + self.get_failure( self.handler.set_displayname( self.frank, synapse.types.create_requester(self.bob), "Frank Jr." - ) + ), + AuthError, ) - yield self.assertFailure(d, AuthError) - - @defer.inlineCallbacks def test_get_other_name(self): self.mock_federation.make_query.return_value = make_awaitable( {"displayname": "Alice"} ) - displayname = yield defer.ensureDeferred( - self.handler.get_displayname(self.alice) - ) + displayname = self.get_success(self.handler.get_displayname(self.alice)) self.assertEquals(displayname, "Alice") self.mock_federation.make_query.assert_called_with( @@ -175,14 +155,11 @@ class ProfileTestCase(unittest.TestCase): ignore_backoff=True, ) - @defer.inlineCallbacks def test_incoming_fed_query(self): - yield defer.ensureDeferred(self.store.create_profile("caroline")) - yield defer.ensureDeferred( - self.store.set_profile_displayname("caroline", "Caroline") - ) + self.get_success(self.store.create_profile("caroline")) + self.get_success(self.store.set_profile_displayname("caroline", "Caroline")) - response = yield defer.ensureDeferred( + response = self.get_success( self.query_handlers["profile"]( {"user_id": "@caroline:test", "field": "displayname"} ) @@ -190,20 +167,18 @@ class ProfileTestCase(unittest.TestCase): self.assertEquals({"displayname": "Caroline"}, response) - @defer.inlineCallbacks def test_get_my_avatar(self): - yield defer.ensureDeferred( + self.get_success( self.store.set_profile_avatar_url( self.frank.localpart, "http://my.server/me.png" ) ) - avatar_url = yield defer.ensureDeferred(self.handler.get_avatar_url(self.frank)) + avatar_url = self.get_success(self.handler.get_avatar_url(self.frank)) self.assertEquals("http://my.server/me.png", avatar_url) - @defer.inlineCallbacks def test_set_my_avatar(self): - yield defer.ensureDeferred( + self.get_success( self.handler.set_avatar_url( self.frank, synapse.types.create_requester(self.frank), @@ -212,16 +187,12 @@ class ProfileTestCase(unittest.TestCase): ) self.assertEquals( - ( - yield defer.ensureDeferred( - self.store.get_profile_avatar_url(self.frank.localpart) - ) - ), + (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))), "http://my.server/pic.gif", ) # Set avatar again - yield defer.ensureDeferred( + self.get_success( self.handler.set_avatar_url( self.frank, synapse.types.create_requester(self.frank), @@ -230,56 +201,44 @@ class ProfileTestCase(unittest.TestCase): ) self.assertEquals( - ( - yield defer.ensureDeferred( - self.store.get_profile_avatar_url(self.frank.localpart) - ) - ), + (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))), "http://my.server/me.png", ) # Set avatar to an empty string - yield defer.ensureDeferred( + self.get_success( self.handler.set_avatar_url( - self.frank, synapse.types.create_requester(self.frank), "", + self.frank, + synapse.types.create_requester(self.frank), + "", ) ) self.assertIsNone( - ( - yield defer.ensureDeferred( - self.store.get_profile_avatar_url(self.frank.localpart) - ) - ), + (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))), ) - @defer.inlineCallbacks def test_set_my_avatar_if_disabled(self): self.hs.config.enable_set_avatar_url = False # Setting displayname for the first time is allowed - yield defer.ensureDeferred( + self.get_success( self.store.set_profile_avatar_url( self.frank.localpart, "http://my.server/me.png" ) ) self.assertEquals( - ( - yield defer.ensureDeferred( - self.store.get_profile_avatar_url(self.frank.localpart) - ) - ), + (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))), "http://my.server/me.png", ) # Set avatar a second time is forbidden - d = defer.ensureDeferred( + self.get_failure( self.handler.set_avatar_url( self.frank, synapse.types.create_requester(self.frank), "http://my.server/pic.gif", - ) + ), + SynapseError, ) - - yield self.assertFailure(d, SynapseError) diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py index 261c7083d1..029af2853e 100644 --- a/tests/handlers/test_saml.py +++ b/tests/handlers/test_saml.py @@ -131,7 +131,7 @@ class SamlHandlerTestCase(HomeserverTestCase): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", request, "redirect_uri", None + "@test_user:test", request, "redirect_uri", None, new_user=True ) @override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}}) @@ -157,7 +157,7 @@ class SamlHandlerTestCase(HomeserverTestCase): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", request, "", None + "@test_user:test", request, "", None, new_user=False ) # Subsequent calls should map to the same mxid. @@ -166,7 +166,7 @@ class SamlHandlerTestCase(HomeserverTestCase): self.handler._handle_authn_response(request, saml_response, "") ) auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", request, "", None + "@test_user:test", request, "", None, new_user=False ) def test_map_saml_response_to_invalid_localpart(self): @@ -214,7 +214,7 @@ class SamlHandlerTestCase(HomeserverTestCase): # test_user is already taken, so test_user1 gets registered instead. auth_handler.complete_sso_login.assert_called_once_with( - "@test_user1:test", request, "", None + "@test_user1:test", request, "", None, new_user=True ) auth_handler.complete_sso_login.reset_mock() @@ -259,7 +259,61 @@ class SamlHandlerTestCase(HomeserverTestCase): ) self.assertEqual(e.value.location, b"https://custom-saml-redirect/") + @override_config( + { + "saml2_config": { + "attribute_requirements": [ + {"attribute": "userGroup", "value": "staff"}, + {"attribute": "department", "value": "sales"}, + ], + }, + } + ) + def test_attribute_requirements(self): + """The required attributes must be met from the SAML response.""" + + # stub out the auth handler + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() + + # The response doesn't have the proper userGroup or department. + saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"}) + request = _mock_request() + self.get_success( + self.handler._handle_authn_response(request, saml_response, "redirect_uri") + ) + auth_handler.complete_sso_login.assert_not_called() + + # The response doesn't have the proper department. + saml_response = FakeAuthnResponse( + {"uid": "test_user", "username": "test_user", "userGroup": ["staff"]} + ) + request = _mock_request() + self.get_success( + self.handler._handle_authn_response(request, saml_response, "redirect_uri") + ) + auth_handler.complete_sso_login.assert_not_called() + + # Add the proper attributes and it should succeed. + saml_response = FakeAuthnResponse( + { + "uid": "test_user", + "username": "test_user", + "userGroup": ["staff", "admin"], + "department": ["sales"], + } + ) + request.reset_mock() + self.get_success( + self.handler._handle_authn_response(request, saml_response, "redirect_uri") + ) + + # check that the auth handler got called as expected + auth_handler.complete_sso_login.assert_called_once_with( + "@test_user:test", request, "redirect_uri", None, new_user=True + ) + def _mock_request(): """Returns a mock which will stand in as a SynapseRequest""" - return Mock(spec=["getClientIP", "getHeader"]) + return Mock(spec=["getClientIP", "getHeader", "_disconnected"]) diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 96e5bdac4a..24e7138196 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -143,14 +143,14 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.datastore.get_current_state_deltas = Mock(return_value=(0, None)) self.datastore.get_to_device_stream_token = lambda: 0 - self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: make_awaitable( - ([], 0) + self.datastore.get_new_device_msgs_for_remote = ( + lambda *args, **kargs: make_awaitable(([], 0)) ) - self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: make_awaitable( - None + self.datastore.delete_device_msgs_for_remote = ( + lambda *args, **kargs: make_awaitable(None) ) - self.datastore.set_received_txn_response = lambda *args, **kwargs: make_awaitable( - None + self.datastore.set_received_txn_response = ( + lambda *args, **kwargs: make_awaitable(None) ) def test_started_typing_local(self): diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 9c886d671a..3572e54c5d 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -200,7 +200,9 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): # Check that the room has an encryption state event event_content = self.helper.get_state( - room_id=room_id, event_type=EventTypes.RoomEncryption, tok=user_token, + room_id=room_id, + event_type=EventTypes.RoomEncryption, + tok=user_token, ) self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT}) @@ -209,7 +211,9 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): # Check that the room has an encryption state event event_content = self.helper.get_state( - room_id=room_id, event_type=EventTypes.RoomEncryption, tok=user_token, + room_id=room_id, + event_type=EventTypes.RoomEncryption, + tok=user_token, ) self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT}) @@ -227,7 +231,9 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): # Check that the room has an encryption state event event_content = self.helper.get_state( - room_id=room_id, event_type=EventTypes.RoomEncryption, tok=user_token, + room_id=room_id, + event_type=EventTypes.RoomEncryption, + tok=user_token, ) self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT}) |