diff options
author | Brendan Abolivier <babolivier@matrix.org> | 2020-06-10 11:42:30 +0100 |
---|---|---|
committer | Brendan Abolivier <babolivier@matrix.org> | 2020-06-10 11:42:30 +0100 |
commit | ec0a7b9034806d6b2ba086bae58f5c6b0fd14672 (patch) | |
tree | f2af547b1342795e10548f8fb7a9cfc93e03df37 /tests/handlers | |
parent | changelog (diff) | |
parent | 1.15.0rc1 (diff) | |
download | synapse-ec0a7b9034806d6b2ba086bae58f5c6b0fd14672.tar.xz |
Merge branch 'develop' into babolivier/mark_unread
Diffstat (limited to 'tests/handlers')
-rw-r--r-- | tests/handlers/test_auth.py | 103 | ||||
-rw-r--r-- | tests/handlers/test_device.py | 18 | ||||
-rw-r--r-- | tests/handlers/test_directory.py | 331 | ||||
-rw-r--r-- | tests/handlers/test_e2e_keys.py | 360 | ||||
-rw-r--r-- | tests/handlers/test_e2e_room_keys.py | 78 | ||||
-rw-r--r-- | tests/handlers/test_federation.py | 274 | ||||
-rw-r--r-- | tests/handlers/test_oidc.py | 570 | ||||
-rw-r--r-- | tests/handlers/test_presence.py | 67 | ||||
-rw-r--r-- | tests/handlers/test_profile.py | 107 | ||||
-rw-r--r-- | tests/handlers/test_register.py | 41 | ||||
-rw-r--r-- | tests/handlers/test_roomlist.py | 39 | ||||
-rw-r--r-- | tests/handlers/test_stats.py | 90 | ||||
-rw-r--r-- | tests/handlers/test_sync.py | 50 | ||||
-rw-r--r-- | tests/handlers/test_typing.py | 93 | ||||
-rw-r--r-- | tests/handlers/test_user_directory.py | 138 |
15 files changed, 2088 insertions, 271 deletions
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index b03103d96f..c01b04e1dc 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -39,8 +39,13 @@ class AuthTestCase(unittest.TestCase): self.hs.handlers = AuthHandlers(self.hs) self.auth_handler = self.hs.handlers.auth_handler self.macaroon_generator = self.hs.get_macaroon_generator() + # MAU tests - self.hs.config.max_mau_value = 50 + # AuthBlocking reads from the hs' config on initialization. We need to + # modify its config instead of the hs' + self.auth_blocking = self.hs.get_auth()._auth_blocking + self.auth_blocking._max_mau_value = 50 + self.small_number_of_users = 1 self.large_number_of_users = 100 @@ -82,16 +87,16 @@ class AuthTestCase(unittest.TestCase): self.hs.clock.now = 1000 token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000) - user_id = yield self.auth_handler.validate_short_term_login_token_and_get_user_id( - token + user_id = yield defer.ensureDeferred( + self.auth_handler.validate_short_term_login_token_and_get_user_id(token) ) self.assertEqual("a_user", user_id) # when we advance the clock, the token should be rejected self.hs.clock.now = 6000 with self.assertRaises(synapse.api.errors.AuthError): - yield self.auth_handler.validate_short_term_login_token_and_get_user_id( - token + yield defer.ensureDeferred( + self.auth_handler.validate_short_term_login_token_and_get_user_id(token) ) @defer.inlineCallbacks @@ -99,8 +104,10 @@ class AuthTestCase(unittest.TestCase): token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000) macaroon = pymacaroons.Macaroon.deserialize(token) - user_id = yield self.auth_handler.validate_short_term_login_token_and_get_user_id( - macaroon.serialize() + user_id = yield defer.ensureDeferred( + self.auth_handler.validate_short_term_login_token_and_get_user_id( + macaroon.serialize() + ) ) self.assertEqual("a_user", user_id) @@ -109,99 +116,121 @@ class AuthTestCase(unittest.TestCase): macaroon.add_first_party_caveat("user_id = b_user") with self.assertRaises(synapse.api.errors.AuthError): - yield self.auth_handler.validate_short_term_login_token_and_get_user_id( - macaroon.serialize() + yield defer.ensureDeferred( + self.auth_handler.validate_short_term_login_token_and_get_user_id( + macaroon.serialize() + ) ) @defer.inlineCallbacks def test_mau_limits_disabled(self): - self.hs.config.limit_usage_by_mau = False + self.auth_blocking._limit_usage_by_mau = False # Ensure does not throw exception - yield self.auth_handler.get_access_token_for_user_id( - "user_a", device_id=None, valid_until_ms=None + yield defer.ensureDeferred( + self.auth_handler.get_access_token_for_user_id( + "user_a", device_id=None, valid_until_ms=None + ) ) - yield self.auth_handler.validate_short_term_login_token_and_get_user_id( - self._get_macaroon().serialize() + yield defer.ensureDeferred( + self.auth_handler.validate_short_term_login_token_and_get_user_id( + self._get_macaroon().serialize() + ) ) @defer.inlineCallbacks def test_mau_limits_exceeded_large(self): - self.hs.config.limit_usage_by_mau = True + self.auth_blocking._limit_usage_by_mau = True self.hs.get_datastore().get_monthly_active_count = Mock( return_value=defer.succeed(self.large_number_of_users) ) with self.assertRaises(ResourceLimitError): - yield self.auth_handler.get_access_token_for_user_id( - "user_a", device_id=None, valid_until_ms=None + yield defer.ensureDeferred( + self.auth_handler.get_access_token_for_user_id( + "user_a", device_id=None, valid_until_ms=None + ) ) self.hs.get_datastore().get_monthly_active_count = Mock( return_value=defer.succeed(self.large_number_of_users) ) with self.assertRaises(ResourceLimitError): - yield self.auth_handler.validate_short_term_login_token_and_get_user_id( - self._get_macaroon().serialize() + yield defer.ensureDeferred( + self.auth_handler.validate_short_term_login_token_and_get_user_id( + self._get_macaroon().serialize() + ) ) @defer.inlineCallbacks def test_mau_limits_parity(self): - self.hs.config.limit_usage_by_mau = True + self.auth_blocking._limit_usage_by_mau = True # If not in monthly active cohort self.hs.get_datastore().get_monthly_active_count = Mock( - return_value=defer.succeed(self.hs.config.max_mau_value) + return_value=defer.succeed(self.auth_blocking._max_mau_value) ) with self.assertRaises(ResourceLimitError): - yield self.auth_handler.get_access_token_for_user_id( - "user_a", device_id=None, valid_until_ms=None + yield defer.ensureDeferred( + self.auth_handler.get_access_token_for_user_id( + "user_a", device_id=None, valid_until_ms=None + ) ) self.hs.get_datastore().get_monthly_active_count = Mock( - return_value=defer.succeed(self.hs.config.max_mau_value) + return_value=defer.succeed(self.auth_blocking._max_mau_value) ) with self.assertRaises(ResourceLimitError): - yield self.auth_handler.validate_short_term_login_token_and_get_user_id( - self._get_macaroon().serialize() + yield defer.ensureDeferred( + self.auth_handler.validate_short_term_login_token_and_get_user_id( + self._get_macaroon().serialize() + ) ) # If in monthly active cohort self.hs.get_datastore().user_last_seen_monthly_active = Mock( return_value=defer.succeed(self.hs.get_clock().time_msec()) ) self.hs.get_datastore().get_monthly_active_count = Mock( - return_value=defer.succeed(self.hs.config.max_mau_value) + return_value=defer.succeed(self.auth_blocking._max_mau_value) ) - yield self.auth_handler.get_access_token_for_user_id( - "user_a", device_id=None, valid_until_ms=None + yield defer.ensureDeferred( + self.auth_handler.get_access_token_for_user_id( + "user_a", device_id=None, valid_until_ms=None + ) ) self.hs.get_datastore().user_last_seen_monthly_active = Mock( return_value=defer.succeed(self.hs.get_clock().time_msec()) ) self.hs.get_datastore().get_monthly_active_count = Mock( - return_value=defer.succeed(self.hs.config.max_mau_value) + return_value=defer.succeed(self.auth_blocking._max_mau_value) ) - yield self.auth_handler.validate_short_term_login_token_and_get_user_id( - self._get_macaroon().serialize() + yield defer.ensureDeferred( + self.auth_handler.validate_short_term_login_token_and_get_user_id( + self._get_macaroon().serialize() + ) ) @defer.inlineCallbacks def test_mau_limits_not_exceeded(self): - self.hs.config.limit_usage_by_mau = True + self.auth_blocking._limit_usage_by_mau = True self.hs.get_datastore().get_monthly_active_count = Mock( return_value=defer.succeed(self.small_number_of_users) ) # Ensure does not raise exception - yield self.auth_handler.get_access_token_for_user_id( - "user_a", device_id=None, valid_until_ms=None + yield defer.ensureDeferred( + self.auth_handler.get_access_token_for_user_id( + "user_a", device_id=None, valid_until_ms=None + ) ) self.hs.get_datastore().get_monthly_active_count = Mock( return_value=defer.succeed(self.small_number_of_users) ) - yield self.auth_handler.validate_short_term_login_token_and_get_user_id( - self._get_macaroon().serialize() + yield defer.ensureDeferred( + self.auth_handler.validate_short_term_login_token_and_get_user_id( + self._get_macaroon().serialize() + ) ) def _get_macaroon(self): diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index a3aa0a1cf2..62b47f6574 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -160,6 +160,24 @@ class DeviceTestCase(unittest.HomeserverTestCase): res = self.get_success(self.handler.get_device(user1, "abc")) self.assertEqual(res["display_name"], "new display") + def test_update_device_too_long_display_name(self): + """Update a device with a display name that is invalid (too long).""" + self._record_users() + + # Request to update a device display name with a new value that is longer than allowed. + update = { + "display_name": "a" + * (synapse.handlers.device.MAX_DEVICE_DISPLAY_NAME_LEN + 1) + } + self.get_failure( + self.handler.update_device(user1, "abc", update), + synapse.api.errors.SynapseError, + ) + + # Ensure the display name was not updated. + res = self.get_success(self.handler.get_device(user1, "abc")) + self.assertEqual(res["display_name"], "display 2") + def test_update_unknown_device(self): update = {"display_name": "new_display"} res = self.handler.update_device("user_id", "unknown_device_id", update) diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 91c7a17070..00bb776271 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -18,25 +18,20 @@ from mock import Mock from twisted.internet import defer +import synapse +import synapse.api.errors +from synapse.api.constants import EventTypes from synapse.config.room_directory import RoomDirectoryConfig -from synapse.handlers.directory import DirectoryHandler -from synapse.rest.client.v1 import directory, room -from synapse.types import RoomAlias +from synapse.rest.client.v1 import directory, login, room +from synapse.types import RoomAlias, create_requester from tests import unittest -from tests.utils import setup_test_homeserver -class DirectoryHandlers(object): - def __init__(self, hs): - self.directory_handler = DirectoryHandler(hs) - - -class DirectoryTestCase(unittest.TestCase): +class DirectoryTestCase(unittest.HomeserverTestCase): """ Tests the directory service. """ - @defer.inlineCallbacks - def setUp(self): + def make_homeserver(self, reactor, clock): self.mock_federation = Mock() self.mock_registry = Mock() @@ -47,14 +42,12 @@ class DirectoryTestCase(unittest.TestCase): self.mock_registry.register_query_handler = register_query_handler - hs = yield setup_test_homeserver( - self.addCleanup, + hs = self.setup_test_homeserver( http_client=None, resource_for_federation=Mock(), federation_client=self.mock_federation, federation_registry=self.mock_registry, ) - hs.handlers = DirectoryHandlers(hs) self.handler = hs.get_handlers().directory_handler @@ -64,23 +57,25 @@ class DirectoryTestCase(unittest.TestCase): self.your_room = RoomAlias.from_string("#your-room:test") self.remote_room = RoomAlias.from_string("#another:remote") - @defer.inlineCallbacks + return hs + def test_get_local_association(self): - yield self.store.create_room_alias_association( - self.my_room, "!8765qwer:test", ["test"] + self.get_success( + self.store.create_room_alias_association( + self.my_room, "!8765qwer:test", ["test"] + ) ) - result = yield self.handler.get_association(self.my_room) + result = self.get_success(self.handler.get_association(self.my_room)) self.assertEquals({"room_id": "!8765qwer:test", "servers": ["test"]}, result) - @defer.inlineCallbacks def test_get_remote_association(self): self.mock_federation.make_query.return_value = defer.succeed( {"room_id": "!8765qwer:test", "servers": ["test", "remote"]} ) - result = yield self.handler.get_association(self.remote_room) + result = self.get_success(self.handler.get_association(self.remote_room)) self.assertEquals( {"room_id": "!8765qwer:test", "servers": ["test", "remote"]}, result @@ -93,19 +88,303 @@ class DirectoryTestCase(unittest.TestCase): ignore_backoff=True, ) - @defer.inlineCallbacks def test_incoming_fed_query(self): - yield self.store.create_room_alias_association( - self.your_room, "!8765asdf:test", ["test"] + self.get_success( + self.store.create_room_alias_association( + self.your_room, "!8765asdf:test", ["test"] + ) ) - response = yield self.query_handlers["directory"]( - {"room_alias": "#your-room:test"} + response = self.get_success( + self.handler.on_directory_query({"room_alias": "#your-room:test"}) ) self.assertEquals({"room_id": "!8765asdf:test", "servers": ["test"]}, response) +class TestCreateAlias(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + directory.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.handler = hs.get_handlers().directory_handler + + # Create user + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + # Create a test room + self.room_id = self.helper.create_room_as( + self.admin_user, tok=self.admin_user_tok + ) + + self.test_alias = "#test:test" + self.room_alias = RoomAlias.from_string(self.test_alias) + + # Create a test user. + self.test_user = self.register_user("user", "pass", admin=False) + self.test_user_tok = self.login("user", "pass") + self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok) + + def test_create_alias_joined_room(self): + """A user can create an alias for a room they're in.""" + self.get_success( + self.handler.create_association( + create_requester(self.test_user), self.room_alias, self.room_id, + ) + ) + + def test_create_alias_other_room(self): + """A user cannot create an alias for a room they're NOT in.""" + other_room_id = self.helper.create_room_as( + self.admin_user, tok=self.admin_user_tok + ) + + self.get_failure( + self.handler.create_association( + create_requester(self.test_user), self.room_alias, other_room_id, + ), + synapse.api.errors.SynapseError, + ) + + def test_create_alias_admin(self): + """An admin can create an alias for a room they're NOT in.""" + other_room_id = self.helper.create_room_as( + self.test_user, tok=self.test_user_tok + ) + + self.get_success( + self.handler.create_association( + create_requester(self.admin_user), self.room_alias, other_room_id, + ) + ) + + +class TestDeleteAlias(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + directory.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.handler = hs.get_handlers().directory_handler + self.state_handler = hs.get_state_handler() + + # Create user + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + # Create a test room + self.room_id = self.helper.create_room_as( + self.admin_user, tok=self.admin_user_tok + ) + + self.test_alias = "#test:test" + self.room_alias = RoomAlias.from_string(self.test_alias) + + # Create a test user. + self.test_user = self.register_user("user", "pass", admin=False) + self.test_user_tok = self.login("user", "pass") + self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok) + + def _create_alias(self, user): + # Create a new alias to this room. + self.get_success( + self.store.create_room_alias_association( + self.room_alias, self.room_id, ["test"], user + ) + ) + + def test_delete_alias_not_allowed(self): + """A user that doesn't meet the expected guidelines cannot delete an alias.""" + self._create_alias(self.admin_user) + self.get_failure( + self.handler.delete_association( + create_requester(self.test_user), self.room_alias + ), + synapse.api.errors.AuthError, + ) + + def test_delete_alias_creator(self): + """An alias creator can delete their own alias.""" + # Create an alias from a different user. + self._create_alias(self.test_user) + + # Delete the user's alias. + result = self.get_success( + self.handler.delete_association( + create_requester(self.test_user), self.room_alias + ) + ) + self.assertEquals(self.room_id, result) + + # Confirm the alias is gone. + self.get_failure( + self.handler.get_association(self.room_alias), + synapse.api.errors.SynapseError, + ) + + def test_delete_alias_admin(self): + """A server admin can delete an alias created by another user.""" + # Create an alias from a different user. + self._create_alias(self.test_user) + + # Delete the user's alias as the admin. + result = self.get_success( + self.handler.delete_association( + create_requester(self.admin_user), self.room_alias + ) + ) + self.assertEquals(self.room_id, result) + + # Confirm the alias is gone. + self.get_failure( + self.handler.get_association(self.room_alias), + synapse.api.errors.SynapseError, + ) + + def test_delete_alias_sufficient_power(self): + """A user with a sufficient power level should be able to delete an alias.""" + self._create_alias(self.admin_user) + + # Increase the user's power level. + self.helper.send_state( + self.room_id, + "m.room.power_levels", + {"users": {self.test_user: 100}}, + tok=self.admin_user_tok, + ) + + # They can now delete the alias. + result = self.get_success( + self.handler.delete_association( + create_requester(self.test_user), self.room_alias + ) + ) + self.assertEquals(self.room_id, result) + + # Confirm the alias is gone. + self.get_failure( + self.handler.get_association(self.room_alias), + synapse.api.errors.SynapseError, + ) + + +class CanonicalAliasTestCase(unittest.HomeserverTestCase): + """Test modifications of the canonical alias when delete aliases. + """ + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + directory.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.handler = hs.get_handlers().directory_handler + self.state_handler = hs.get_state_handler() + + # Create user + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + # Create a test room + self.room_id = self.helper.create_room_as( + self.admin_user, tok=self.admin_user_tok + ) + + self.test_alias = "#test:test" + self.room_alias = self._add_alias(self.test_alias) + + def _add_alias(self, alias: str) -> RoomAlias: + """Add an alias to the test room.""" + room_alias = RoomAlias.from_string(alias) + + # Create a new alias to this room. + self.get_success( + self.store.create_room_alias_association( + room_alias, self.room_id, ["test"], self.admin_user + ) + ) + return room_alias + + def _set_canonical_alias(self, content): + """Configure the canonical alias state on the room.""" + self.helper.send_state( + self.room_id, "m.room.canonical_alias", content, tok=self.admin_user_tok, + ) + + def _get_canonical_alias(self): + """Get the canonical alias state of the room.""" + return self.get_success( + self.state_handler.get_current_state( + self.room_id, EventTypes.CanonicalAlias, "" + ) + ) + + def test_remove_alias(self): + """Removing an alias that is the canonical alias should remove it there too.""" + # Set this new alias as the canonical alias for this room + self._set_canonical_alias( + {"alias": self.test_alias, "alt_aliases": [self.test_alias]} + ) + + data = self._get_canonical_alias() + self.assertEqual(data["content"]["alias"], self.test_alias) + self.assertEqual(data["content"]["alt_aliases"], [self.test_alias]) + + # Finally, delete the alias. + self.get_success( + self.handler.delete_association( + create_requester(self.admin_user), self.room_alias + ) + ) + + data = self._get_canonical_alias() + self.assertNotIn("alias", data["content"]) + self.assertNotIn("alt_aliases", data["content"]) + + def test_remove_other_alias(self): + """Removing an alias listed as in alt_aliases should remove it there too.""" + # Create a second alias. + other_test_alias = "#test2:test" + other_room_alias = self._add_alias(other_test_alias) + + # Set the alias as the canonical alias for this room. + self._set_canonical_alias( + { + "alias": self.test_alias, + "alt_aliases": [self.test_alias, other_test_alias], + } + ) + + data = self._get_canonical_alias() + self.assertEqual(data["content"]["alias"], self.test_alias) + self.assertEqual( + data["content"]["alt_aliases"], [self.test_alias, other_test_alias] + ) + + # Delete the second alias. + self.get_success( + self.handler.delete_association( + create_requester(self.admin_user), other_room_alias + ) + ) + + data = self._get_canonical_alias() + self.assertEqual(data["content"]["alias"], self.test_alias) + self.assertEqual(data["content"]["alt_aliases"], [self.test_alias]) + + class TestCreateAliasACL(unittest.HomeserverTestCase): user_id = "@test:test" diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 8dccc6826e..e1e144b2e7 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -1,5 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd +# Copyright 2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,9 +17,11 @@ import mock +import signedjson.key as key +import signedjson.sign as sign + from twisted.internet import defer -import synapse.api.errors import synapse.handlers.e2e_keys import synapse.storage from synapse.api import errors @@ -145,3 +149,357 @@ class E2eKeysHandlerTestCase(unittest.TestCase): "one_time_keys": {local_user: {device_id: {"alg1:k1": "key1"}}}, }, ) + + @defer.inlineCallbacks + def test_replace_master_key(self): + """uploading a new signing key should make the old signing key unavailable""" + local_user = "@boris:" + self.hs.hostname + keys1 = { + "master_key": { + # private key: 2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0 + "user_id": local_user, + "usage": ["master"], + "keys": { + "ed25519:nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk": "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk" + }, + } + } + yield self.handler.upload_signing_keys_for_user(local_user, keys1) + + keys2 = { + "master_key": { + # private key: 4TL4AjRYwDVwD3pqQzcor+ez/euOB1/q78aTJ+czDNs + "user_id": local_user, + "usage": ["master"], + "keys": { + "ed25519:Hq6gL+utB4ET+UvD5ci0kgAwsX6qP/zvf8v6OInU5iw": "Hq6gL+utB4ET+UvD5ci0kgAwsX6qP/zvf8v6OInU5iw" + }, + } + } + yield self.handler.upload_signing_keys_for_user(local_user, keys2) + + devices = yield self.handler.query_devices( + {"device_keys": {local_user: []}}, 0, local_user + ) + self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]}) + + @defer.inlineCallbacks + def test_reupload_signatures(self): + """re-uploading a signature should not fail""" + local_user = "@boris:" + self.hs.hostname + keys1 = { + "master_key": { + # private key: HvQBbU+hc2Zr+JP1sE0XwBe1pfZZEYtJNPJLZJtS+F8 + "user_id": local_user, + "usage": ["master"], + "keys": { + "ed25519:EmkqvokUn8p+vQAGZitOk4PWjp7Ukp3txV2TbMPEiBQ": "EmkqvokUn8p+vQAGZitOk4PWjp7Ukp3txV2TbMPEiBQ" + }, + }, + "self_signing_key": { + # private key: 2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0 + "user_id": local_user, + "usage": ["self_signing"], + "keys": { + "ed25519:nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk": "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk" + }, + }, + } + master_signing_key = key.decode_signing_key_base64( + "ed25519", + "EmkqvokUn8p+vQAGZitOk4PWjp7Ukp3txV2TbMPEiBQ", + "HvQBbU+hc2Zr+JP1sE0XwBe1pfZZEYtJNPJLZJtS+F8", + ) + sign.sign_json(keys1["self_signing_key"], local_user, master_signing_key) + signing_key = key.decode_signing_key_base64( + "ed25519", + "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk", + "2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0", + ) + yield self.handler.upload_signing_keys_for_user(local_user, keys1) + + # upload two device keys, which will be signed later by the self-signing key + device_key_1 = { + "user_id": local_user, + "device_id": "abc", + "algorithms": ["m.olm.curve25519-aes-sha2", "m.megolm.v1.aes-sha2"], + "keys": { + "ed25519:abc": "base64+ed25519+key", + "curve25519:abc": "base64+curve25519+key", + }, + "signatures": {local_user: {"ed25519:abc": "base64+signature"}}, + } + device_key_2 = { + "user_id": local_user, + "device_id": "def", + "algorithms": ["m.olm.curve25519-aes-sha2", "m.megolm.v1.aes-sha2"], + "keys": { + "ed25519:def": "base64+ed25519+key", + "curve25519:def": "base64+curve25519+key", + }, + "signatures": {local_user: {"ed25519:def": "base64+signature"}}, + } + + yield self.handler.upload_keys_for_user( + local_user, "abc", {"device_keys": device_key_1} + ) + yield self.handler.upload_keys_for_user( + local_user, "def", {"device_keys": device_key_2} + ) + + # sign the first device key and upload it + del device_key_1["signatures"] + sign.sign_json(device_key_1, local_user, signing_key) + yield self.handler.upload_signatures_for_device_keys( + local_user, {local_user: {"abc": device_key_1}} + ) + + # sign the second device key and upload both device keys. The server + # should ignore the first device key since it already has a valid + # signature for it + del device_key_2["signatures"] + sign.sign_json(device_key_2, local_user, signing_key) + yield self.handler.upload_signatures_for_device_keys( + local_user, {local_user: {"abc": device_key_1, "def": device_key_2}} + ) + + device_key_1["signatures"][local_user]["ed25519:abc"] = "base64+signature" + device_key_2["signatures"][local_user]["ed25519:def"] = "base64+signature" + devices = yield self.handler.query_devices( + {"device_keys": {local_user: []}}, 0, local_user + ) + del devices["device_keys"][local_user]["abc"]["unsigned"] + del devices["device_keys"][local_user]["def"]["unsigned"] + self.assertDictEqual(devices["device_keys"][local_user]["abc"], device_key_1) + self.assertDictEqual(devices["device_keys"][local_user]["def"], device_key_2) + + @defer.inlineCallbacks + def test_self_signing_key_doesnt_show_up_as_device(self): + """signing keys should be hidden when fetching a user's devices""" + local_user = "@boris:" + self.hs.hostname + keys1 = { + "master_key": { + # private key: 2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0 + "user_id": local_user, + "usage": ["master"], + "keys": { + "ed25519:nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk": "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk" + }, + } + } + yield self.handler.upload_signing_keys_for_user(local_user, keys1) + + res = None + try: + yield self.hs.get_device_handler().check_device_registered( + user_id=local_user, + device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk", + initial_device_display_name="new display name", + ) + except errors.SynapseError as e: + res = e.code + self.assertEqual(res, 400) + + res = yield self.handler.query_local_devices({local_user: None}) + self.assertDictEqual(res, {local_user: {}}) + + @defer.inlineCallbacks + def test_upload_signatures(self): + """should check signatures that are uploaded""" + # set up a user with cross-signing keys and a device. This user will + # try uploading signatures + local_user = "@boris:" + self.hs.hostname + device_id = "xyz" + # private key: OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA + device_pubkey = "NnHhnqiMFQkq969szYkooLaBAXW244ZOxgukCvm2ZeY" + device_key = { + "user_id": local_user, + "device_id": device_id, + "algorithms": ["m.olm.curve25519-aes-sha2", "m.megolm.v1.aes-sha2"], + "keys": {"curve25519:xyz": "curve25519+key", "ed25519:xyz": device_pubkey}, + "signatures": {local_user: {"ed25519:xyz": "something"}}, + } + device_signing_key = key.decode_signing_key_base64( + "ed25519", "xyz", "OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA" + ) + + yield self.handler.upload_keys_for_user( + local_user, device_id, {"device_keys": device_key} + ) + + # private key: 2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0 + master_pubkey = "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk" + master_key = { + "user_id": local_user, + "usage": ["master"], + "keys": {"ed25519:" + master_pubkey: master_pubkey}, + } + master_signing_key = key.decode_signing_key_base64( + "ed25519", master_pubkey, "2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0" + ) + usersigning_pubkey = "Hq6gL+utB4ET+UvD5ci0kgAwsX6qP/zvf8v6OInU5iw" + usersigning_key = { + # private key: 4TL4AjRYwDVwD3pqQzcor+ez/euOB1/q78aTJ+czDNs + "user_id": local_user, + "usage": ["user_signing"], + "keys": {"ed25519:" + usersigning_pubkey: usersigning_pubkey}, + } + usersigning_signing_key = key.decode_signing_key_base64( + "ed25519", usersigning_pubkey, "4TL4AjRYwDVwD3pqQzcor+ez/euOB1/q78aTJ+czDNs" + ) + sign.sign_json(usersigning_key, local_user, master_signing_key) + # private key: HvQBbU+hc2Zr+JP1sE0XwBe1pfZZEYtJNPJLZJtS+F8 + selfsigning_pubkey = "EmkqvokUn8p+vQAGZitOk4PWjp7Ukp3txV2TbMPEiBQ" + selfsigning_key = { + "user_id": local_user, + "usage": ["self_signing"], + "keys": {"ed25519:" + selfsigning_pubkey: selfsigning_pubkey}, + } + selfsigning_signing_key = key.decode_signing_key_base64( + "ed25519", selfsigning_pubkey, "HvQBbU+hc2Zr+JP1sE0XwBe1pfZZEYtJNPJLZJtS+F8" + ) + sign.sign_json(selfsigning_key, local_user, master_signing_key) + cross_signing_keys = { + "master_key": master_key, + "user_signing_key": usersigning_key, + "self_signing_key": selfsigning_key, + } + yield self.handler.upload_signing_keys_for_user(local_user, cross_signing_keys) + + # set up another user with a master key. This user will be signed by + # the first user + other_user = "@otherboris:" + self.hs.hostname + other_master_pubkey = "fHZ3NPiKxoLQm5OoZbKa99SYxprOjNs4TwJUKP+twCM" + other_master_key = { + # private key: oyw2ZUx0O4GifbfFYM0nQvj9CL0b8B7cyN4FprtK8OI + "user_id": other_user, + "usage": ["master"], + "keys": {"ed25519:" + other_master_pubkey: other_master_pubkey}, + } + yield self.handler.upload_signing_keys_for_user( + other_user, {"master_key": other_master_key} + ) + + # test various signature failures (see below) + ret = yield self.handler.upload_signatures_for_device_keys( + local_user, + { + local_user: { + # fails because the signature is invalid + # should fail with INVALID_SIGNATURE + device_id: { + "user_id": local_user, + "device_id": device_id, + "algorithms": [ + "m.olm.curve25519-aes-sha2", + "m.megolm.v1.aes-sha2", + ], + "keys": { + "curve25519:xyz": "curve25519+key", + # private key: OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA + "ed25519:xyz": device_pubkey, + }, + "signatures": { + local_user: {"ed25519:" + selfsigning_pubkey: "something"} + }, + }, + # fails because device is unknown + # should fail with NOT_FOUND + "unknown": { + "user_id": local_user, + "device_id": "unknown", + "signatures": { + local_user: {"ed25519:" + selfsigning_pubkey: "something"} + }, + }, + # fails because the signature is invalid + # should fail with INVALID_SIGNATURE + master_pubkey: { + "user_id": local_user, + "usage": ["master"], + "keys": {"ed25519:" + master_pubkey: master_pubkey}, + "signatures": { + local_user: {"ed25519:" + device_pubkey: "something"} + }, + }, + }, + other_user: { + # fails because the device is not the user's master-signing key + # should fail with NOT_FOUND + "unknown": { + "user_id": other_user, + "device_id": "unknown", + "signatures": { + local_user: {"ed25519:" + usersigning_pubkey: "something"} + }, + }, + other_master_pubkey: { + # fails because the key doesn't match what the server has + # should fail with UNKNOWN + "user_id": other_user, + "usage": ["master"], + "keys": {"ed25519:" + other_master_pubkey: other_master_pubkey}, + "something": "random", + "signatures": { + local_user: {"ed25519:" + usersigning_pubkey: "something"} + }, + }, + }, + }, + ) + + user_failures = ret["failures"][local_user] + self.assertEqual( + user_failures[device_id]["errcode"], errors.Codes.INVALID_SIGNATURE + ) + self.assertEqual( + user_failures[master_pubkey]["errcode"], errors.Codes.INVALID_SIGNATURE + ) + self.assertEqual(user_failures["unknown"]["errcode"], errors.Codes.NOT_FOUND) + + other_user_failures = ret["failures"][other_user] + self.assertEqual( + other_user_failures["unknown"]["errcode"], errors.Codes.NOT_FOUND + ) + self.assertEqual( + other_user_failures[other_master_pubkey]["errcode"], errors.Codes.UNKNOWN + ) + + # test successful signatures + del device_key["signatures"] + sign.sign_json(device_key, local_user, selfsigning_signing_key) + sign.sign_json(master_key, local_user, device_signing_key) + sign.sign_json(other_master_key, local_user, usersigning_signing_key) + ret = yield self.handler.upload_signatures_for_device_keys( + local_user, + { + local_user: {device_id: device_key, master_pubkey: master_key}, + other_user: {other_master_pubkey: other_master_key}, + }, + ) + + self.assertEqual(ret["failures"], {}) + + # fetch the signed keys/devices and make sure that the signatures are there + ret = yield self.handler.query_devices( + {"device_keys": {local_user: [], other_user: []}}, 0, local_user + ) + + self.assertEqual( + ret["device_keys"][local_user]["xyz"]["signatures"][local_user][ + "ed25519:" + selfsigning_pubkey + ], + device_key["signatures"][local_user]["ed25519:" + selfsigning_pubkey], + ) + self.assertEqual( + ret["master_keys"][local_user]["signatures"][local_user][ + "ed25519:" + device_id + ], + master_key["signatures"][local_user]["ed25519:" + device_id], + ) + self.assertEqual( + ret["master_keys"][other_user]["signatures"][local_user][ + "ed25519:" + usersigning_pubkey + ], + other_master_key["signatures"][local_user]["ed25519:" + usersigning_pubkey], + ) diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py index c4503c1611..70f172eb02 100644 --- a/tests/handlers/test_e2e_room_keys.py +++ b/tests/handlers/test_e2e_room_keys.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd # Copyright 2017 New Vector Ltd +# Copyright 2019 Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -94,23 +95,29 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): # check we can retrieve it as the current version res = yield self.handler.get_version_info(self.local_user) + version_etag = res["etag"] + del res["etag"] self.assertDictEqual( res, { "version": "1", "algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data", + "count": 0, }, ) # check we can retrieve it as a specific version res = yield self.handler.get_version_info(self.local_user, "1") + self.assertEqual(res["etag"], version_etag) + del res["etag"] self.assertDictEqual( res, { "version": "1", "algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data", + "count": 0, }, ) @@ -126,12 +133,14 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): # check we can retrieve it as the current version res = yield self.handler.get_version_info(self.local_user) + del res["etag"] self.assertDictEqual( res, { "version": "2", "algorithm": "m.megolm_backup.v1", "auth_data": "second_version_auth_data", + "count": 0, }, ) @@ -158,12 +167,14 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): # check we can retrieve it as the current version res = yield self.handler.get_version_info(self.local_user) + del res["etag"] self.assertDictEqual( res, { "algorithm": "m.megolm_backup.v1", "auth_data": "revised_first_version_auth_data", "version": version, + "count": 0, }, ) @@ -187,9 +198,8 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): self.assertEqual(res, 404) @defer.inlineCallbacks - def test_update_bad_version(self): - """Check that we get a 400 if the version in the body is missing or - doesn't match + def test_update_omitted_version(self): + """Check that the update succeeds if the version is missing from the body """ version = yield self.handler.create_version( self.local_user, @@ -197,19 +207,37 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): ) self.assertEqual(version, "1") - res = None - try: - yield self.handler.update_version( - self.local_user, - version, - { - "algorithm": "m.megolm_backup.v1", - "auth_data": "revised_first_version_auth_data", - }, - ) - except errors.SynapseError as e: - res = e.code - self.assertEqual(res, 400) + yield self.handler.update_version( + self.local_user, + version, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "revised_first_version_auth_data", + }, + ) + + # check we can retrieve it as the current version + res = yield self.handler.get_version_info(self.local_user) + del res["etag"] # etag is opaque, so don't test its contents + self.assertDictEqual( + res, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "revised_first_version_auth_data", + "version": version, + "count": 0, + }, + ) + + @defer.inlineCallbacks + def test_update_bad_version(self): + """Check that we get a 400 if the version in the body doesn't match + """ + version = yield self.handler.create_version( + self.local_user, + {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + ) + self.assertEqual(version, "1") res = None try: @@ -394,6 +422,11 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): yield self.handler.upload_room_keys(self.local_user, version, room_keys) + # get the etag to compare to future versions + res = yield self.handler.get_version_info(self.local_user) + backup_etag = res["etag"] + self.assertEqual(res["count"], 1) + new_room_keys = copy.deepcopy(room_keys) new_room_key = new_room_keys["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"] @@ -408,6 +441,10 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): "SSBBTSBBIEZJU0gK", ) + # the etag should be the same since the session did not change + res = yield self.handler.get_version_info(self.local_user) + self.assertEqual(res["etag"], backup_etag) + # test that marking the session as verified however /does/ replace it new_room_key["is_verified"] = True yield self.handler.upload_room_keys(self.local_user, version, new_room_keys) @@ -417,6 +454,11 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new" ) + # the etag should NOT be equal now, since the key changed + res = yield self.handler.get_version_info(self.local_user) + self.assertNotEqual(res["etag"], backup_etag) + backup_etag = res["etag"] + # test that a session with a higher forwarded_count doesn't replace one # with a lower forwarding count new_room_key["forwarded_count"] = 2 @@ -428,6 +470,10 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new" ) + # the etag should be the same since the session did not change + res = yield self.handler.get_version_info(self.local_user) + self.assertEqual(res["etag"], backup_etag) + # TODO: check edge cases as well as the common variations here @defer.inlineCallbacks diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py new file mode 100644 index 0000000000..96fea58673 --- /dev/null +++ b/tests/handlers/test_federation.py @@ -0,0 +1,274 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from unittest import TestCase + +from synapse.api.constants import EventTypes +from synapse.api.errors import AuthError, Codes, SynapseError +from synapse.api.room_versions import RoomVersions +from synapse.events import EventBase +from synapse.federation.federation_base import event_from_pdu_json +from synapse.logging.context import LoggingContext, run_in_background +from synapse.rest import admin +from synapse.rest.client.v1 import login, room + +from tests import unittest + +logger = logging.getLogger(__name__) + + +class FederationTestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + hs = self.setup_test_homeserver(http_client=None) + self.handler = hs.get_handlers().federation_handler + self.store = hs.get_datastore() + return hs + + def test_exchange_revoked_invite(self): + user_id = self.register_user("kermit", "test") + tok = self.login("kermit", "test") + + room_id = self.helper.create_room_as(room_creator=user_id, tok=tok) + + # Send a 3PID invite event with an empty body so it's considered as a revoked one. + invite_token = "sometoken" + self.helper.send_state( + room_id=room_id, + event_type=EventTypes.ThirdPartyInvite, + state_key=invite_token, + body={}, + tok=tok, + ) + + d = self.handler.on_exchange_third_party_invite_request( + room_id=room_id, + event_dict={ + "type": EventTypes.Member, + "room_id": room_id, + "sender": user_id, + "state_key": "@someone:example.org", + "content": { + "membership": "invite", + "third_party_invite": { + "display_name": "alice", + "signed": { + "mxid": "@alice:localhost", + "token": invite_token, + "signatures": { + "magic.forest": { + "ed25519:3": "fQpGIW1Snz+pwLZu6sTy2aHy/DYWWTspTJRPyNp0PKkymfIsNffysMl6ObMMFdIJhk6g6pwlIqZ54rxo8SLmAg" + } + }, + }, + }, + }, + }, + ) + + failure = self.get_failure(d, AuthError).value + + self.assertEqual(failure.code, 403, failure) + self.assertEqual(failure.errcode, Codes.FORBIDDEN, failure) + self.assertEqual(failure.msg, "You are not invited to this room.") + + def test_rejected_message_event_state(self): + """ + Check that we store the state group correctly for rejected non-state events. + + Regression test for #6289. + """ + OTHER_SERVER = "otherserver" + OTHER_USER = "@otheruser:" + OTHER_SERVER + + # create the room + user_id = self.register_user("kermit", "test") + tok = self.login("kermit", "test") + room_id = self.helper.create_room_as(room_creator=user_id, tok=tok) + room_version = self.get_success(self.store.get_room_version(room_id)) + + # pretend that another server has joined + join_event = self._build_and_send_join_event(OTHER_SERVER, OTHER_USER, room_id) + + # check the state group + sg = self.successResultOf( + self.store._get_state_group_for_event(join_event.event_id) + ) + + # build and send an event which will be rejected + ev = event_from_pdu_json( + { + "type": EventTypes.Message, + "content": {}, + "room_id": room_id, + "sender": "@yetanotheruser:" + OTHER_SERVER, + "depth": join_event["depth"] + 1, + "prev_events": [join_event.event_id], + "auth_events": [], + "origin_server_ts": self.clock.time_msec(), + }, + room_version, + ) + + with LoggingContext(request="send_rejected"): + d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev) + self.get_success(d) + + # that should have been rejected + e = self.get_success(self.store.get_event(ev.event_id, allow_rejected=True)) + self.assertIsNotNone(e.rejected_reason) + + # ... and the state group should be the same as before + sg2 = self.successResultOf(self.store._get_state_group_for_event(ev.event_id)) + + self.assertEqual(sg, sg2) + + def test_rejected_state_event_state(self): + """ + Check that we store the state group correctly for rejected state events. + + Regression test for #6289. + """ + OTHER_SERVER = "otherserver" + OTHER_USER = "@otheruser:" + OTHER_SERVER + + # create the room + user_id = self.register_user("kermit", "test") + tok = self.login("kermit", "test") + room_id = self.helper.create_room_as(room_creator=user_id, tok=tok) + room_version = self.get_success(self.store.get_room_version(room_id)) + + # pretend that another server has joined + join_event = self._build_and_send_join_event(OTHER_SERVER, OTHER_USER, room_id) + + # check the state group + sg = self.successResultOf( + self.store._get_state_group_for_event(join_event.event_id) + ) + + # build and send an event which will be rejected + ev = event_from_pdu_json( + { + "type": "org.matrix.test", + "state_key": "test_key", + "content": {}, + "room_id": room_id, + "sender": "@yetanotheruser:" + OTHER_SERVER, + "depth": join_event["depth"] + 1, + "prev_events": [join_event.event_id], + "auth_events": [], + "origin_server_ts": self.clock.time_msec(), + }, + room_version, + ) + + with LoggingContext(request="send_rejected"): + d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev) + self.get_success(d) + + # that should have been rejected + e = self.get_success(self.store.get_event(ev.event_id, allow_rejected=True)) + self.assertIsNotNone(e.rejected_reason) + + # ... and the state group should be the same as before + sg2 = self.successResultOf(self.store._get_state_group_for_event(ev.event_id)) + + self.assertEqual(sg, sg2) + + def _build_and_send_join_event(self, other_server, other_user, room_id): + join_event = self.get_success( + self.handler.on_make_join_request(other_server, room_id, other_user) + ) + # the auth code requires that a signature exists, but doesn't check that + # signature... go figure. + join_event.signatures[other_server] = {"x": "y"} + with LoggingContext(request="send_join"): + d = run_in_background( + self.handler.on_send_join_request, other_server, join_event + ) + self.get_success(d) + + # sanity-check: the room should show that the new user is a member + r = self.get_success(self.store.get_current_state_ids(room_id)) + self.assertEqual(r[(EventTypes.Member, other_user)], join_event.event_id) + + return join_event + + +class EventFromPduTestCase(TestCase): + def test_valid_json(self): + """Valid JSON should be turned into an event.""" + ev = event_from_pdu_json( + { + "type": EventTypes.Message, + "content": {"bool": True, "null": None, "int": 1, "str": "foobar"}, + "room_id": "!room:test", + "sender": "@user:test", + "depth": 1, + "prev_events": [], + "auth_events": [], + "origin_server_ts": 1234, + }, + RoomVersions.V6, + ) + + self.assertIsInstance(ev, EventBase) + + def test_invalid_numbers(self): + """Invalid values for an integer should be rejected, all floats should be rejected.""" + for value in [ + -(2 ** 53), + 2 ** 53, + 1.0, + float("inf"), + float("-inf"), + float("nan"), + ]: + with self.assertRaises(SynapseError): + event_from_pdu_json( + { + "type": EventTypes.Message, + "content": {"foo": value}, + "room_id": "!room:test", + "sender": "@user:test", + "depth": 1, + "prev_events": [], + "auth_events": [], + "origin_server_ts": 1234, + }, + RoomVersions.V6, + ) + + def test_invalid_nested(self): + """List and dictionaries are recursively searched.""" + with self.assertRaises(SynapseError): + event_from_pdu_json( + { + "type": EventTypes.Message, + "content": {"foo": [{"bar": 2 ** 56}]}, + "room_id": "!room:test", + "sender": "@user:test", + "depth": 1, + "prev_events": [], + "auth_events": [], + "origin_server_ts": 1234, + }, + RoomVersions.V6, + ) diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py new file mode 100644 index 0000000000..1bb25ab684 --- /dev/null +++ b/tests/handlers/test_oidc.py @@ -0,0 +1,570 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Quentin Gliech +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from urllib.parse import parse_qs, urlparse + +from mock import Mock, patch + +import attr +import pymacaroons + +from twisted.internet import defer +from twisted.python.failure import Failure +from twisted.web._newclient import ResponseDone + +from synapse.handlers.oidc_handler import ( + MappingException, + OidcError, + OidcHandler, + OidcMappingProvider, +) +from synapse.types import UserID + +from tests.unittest import HomeserverTestCase, override_config + + +@attr.s +class FakeResponse: + code = attr.ib() + body = attr.ib() + phrase = attr.ib() + + def deliverBody(self, protocol): + protocol.dataReceived(self.body) + protocol.connectionLost(Failure(ResponseDone())) + + +# These are a few constants that are used as config parameters in the tests. +ISSUER = "https://issuer/" +CLIENT_ID = "test-client-id" +CLIENT_SECRET = "test-client-secret" +BASE_URL = "https://synapse/" +CALLBACK_URL = BASE_URL + "_synapse/oidc/callback" +SCOPES = ["openid"] + +AUTHORIZATION_ENDPOINT = ISSUER + "authorize" +TOKEN_ENDPOINT = ISSUER + "token" +USERINFO_ENDPOINT = ISSUER + "userinfo" +WELL_KNOWN = ISSUER + ".well-known/openid-configuration" +JWKS_URI = ISSUER + ".well-known/jwks.json" + +# config for common cases +COMMON_CONFIG = { + "discover": False, + "authorization_endpoint": AUTHORIZATION_ENDPOINT, + "token_endpoint": TOKEN_ENDPOINT, + "jwks_uri": JWKS_URI, +} + + +# The cookie name and path don't really matter, just that it has to be coherent +# between the callback & redirect handlers. +COOKIE_NAME = b"oidc_session" +COOKIE_PATH = "/_synapse/oidc" + +MockedMappingProvider = Mock(OidcMappingProvider) + + +def simple_async_mock(return_value=None, raises=None): + # AsyncMock is not available in python3.5, this mimics part of its behaviour + async def cb(*args, **kwargs): + if raises: + raise raises + return return_value + + return Mock(side_effect=cb) + + +async def get_json(url): + # Mock get_json calls to handle jwks & oidc discovery endpoints + if url == WELL_KNOWN: + # Minimal discovery document, as defined in OpenID.Discovery + # https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata + return { + "issuer": ISSUER, + "authorization_endpoint": AUTHORIZATION_ENDPOINT, + "token_endpoint": TOKEN_ENDPOINT, + "jwks_uri": JWKS_URI, + "userinfo_endpoint": USERINFO_ENDPOINT, + "response_types_supported": ["code"], + "subject_types_supported": ["public"], + "id_token_signing_alg_values_supported": ["RS256"], + } + elif url == JWKS_URI: + return {"keys": []} + + +class OidcHandlerTestCase(HomeserverTestCase): + def make_homeserver(self, reactor, clock): + + self.http_client = Mock(spec=["get_json"]) + self.http_client.get_json.side_effect = get_json + self.http_client.user_agent = "Synapse Test" + + config = self.default_config() + config["public_baseurl"] = BASE_URL + oidc_config = config.get("oidc_config", {}) + oidc_config["enabled"] = True + oidc_config["client_id"] = CLIENT_ID + oidc_config["client_secret"] = CLIENT_SECRET + oidc_config["issuer"] = ISSUER + oidc_config["scopes"] = SCOPES + oidc_config["user_mapping_provider"] = { + "module": __name__ + ".MockedMappingProvider" + } + config["oidc_config"] = oidc_config + + hs = self.setup_test_homeserver( + http_client=self.http_client, + proxied_http_client=self.http_client, + config=config, + ) + + self.handler = OidcHandler(hs) + + return hs + + def metadata_edit(self, values): + return patch.dict(self.handler._provider_metadata, values) + + def assertRenderedError(self, error, error_description=None): + args = self.handler._render_error.call_args[0] + self.assertEqual(args[1], error) + if error_description is not None: + self.assertEqual(args[2], error_description) + # Reset the render_error mock + self.handler._render_error.reset_mock() + + def test_config(self): + """Basic config correctly sets up the callback URL and client auth correctly.""" + self.assertEqual(self.handler._callback_url, CALLBACK_URL) + self.assertEqual(self.handler._client_auth.client_id, CLIENT_ID) + self.assertEqual(self.handler._client_auth.client_secret, CLIENT_SECRET) + + @override_config({"oidc_config": {"discover": True}}) + @defer.inlineCallbacks + def test_discovery(self): + """The handler should discover the endpoints from OIDC discovery document.""" + # This would throw if some metadata were invalid + metadata = yield defer.ensureDeferred(self.handler.load_metadata()) + self.http_client.get_json.assert_called_once_with(WELL_KNOWN) + + self.assertEqual(metadata.issuer, ISSUER) + self.assertEqual(metadata.authorization_endpoint, AUTHORIZATION_ENDPOINT) + self.assertEqual(metadata.token_endpoint, TOKEN_ENDPOINT) + self.assertEqual(metadata.jwks_uri, JWKS_URI) + # FIXME: it seems like authlib does not have that defined in its metadata models + # self.assertEqual(metadata.userinfo_endpoint, USERINFO_ENDPOINT) + + # subsequent calls should be cached + self.http_client.reset_mock() + yield defer.ensureDeferred(self.handler.load_metadata()) + self.http_client.get_json.assert_not_called() + + @override_config({"oidc_config": COMMON_CONFIG}) + @defer.inlineCallbacks + def test_no_discovery(self): + """When discovery is disabled, it should not try to load from discovery document.""" + yield defer.ensureDeferred(self.handler.load_metadata()) + self.http_client.get_json.assert_not_called() + + @override_config({"oidc_config": COMMON_CONFIG}) + @defer.inlineCallbacks + def test_load_jwks(self): + """JWKS loading is done once (then cached) if used.""" + jwks = yield defer.ensureDeferred(self.handler.load_jwks()) + self.http_client.get_json.assert_called_once_with(JWKS_URI) + self.assertEqual(jwks, {"keys": []}) + + # subsequent calls should be cached… + self.http_client.reset_mock() + yield defer.ensureDeferred(self.handler.load_jwks()) + self.http_client.get_json.assert_not_called() + + # …unless forced + self.http_client.reset_mock() + yield defer.ensureDeferred(self.handler.load_jwks(force=True)) + self.http_client.get_json.assert_called_once_with(JWKS_URI) + + # Throw if the JWKS uri is missing + with self.metadata_edit({"jwks_uri": None}): + with self.assertRaises(RuntimeError): + yield defer.ensureDeferred(self.handler.load_jwks(force=True)) + + # Return empty key set if JWKS are not used + self.handler._scopes = [] # not asking the openid scope + self.http_client.get_json.reset_mock() + jwks = yield defer.ensureDeferred(self.handler.load_jwks(force=True)) + self.http_client.get_json.assert_not_called() + self.assertEqual(jwks, {"keys": []}) + + @override_config({"oidc_config": COMMON_CONFIG}) + def test_validate_config(self): + """Provider metadatas are extensively validated.""" + h = self.handler + + # Default test config does not throw + h._validate_metadata() + + with self.metadata_edit({"issuer": None}): + self.assertRaisesRegex(ValueError, "issuer", h._validate_metadata) + + with self.metadata_edit({"issuer": "http://insecure/"}): + self.assertRaisesRegex(ValueError, "issuer", h._validate_metadata) + + with self.metadata_edit({"issuer": "https://invalid/?because=query"}): + self.assertRaisesRegex(ValueError, "issuer", h._validate_metadata) + + with self.metadata_edit({"authorization_endpoint": None}): + self.assertRaisesRegex( + ValueError, "authorization_endpoint", h._validate_metadata + ) + + with self.metadata_edit({"authorization_endpoint": "http://insecure/auth"}): + self.assertRaisesRegex( + ValueError, "authorization_endpoint", h._validate_metadata + ) + + with self.metadata_edit({"token_endpoint": None}): + self.assertRaisesRegex(ValueError, "token_endpoint", h._validate_metadata) + + with self.metadata_edit({"token_endpoint": "http://insecure/token"}): + self.assertRaisesRegex(ValueError, "token_endpoint", h._validate_metadata) + + with self.metadata_edit({"jwks_uri": None}): + self.assertRaisesRegex(ValueError, "jwks_uri", h._validate_metadata) + + with self.metadata_edit({"jwks_uri": "http://insecure/jwks.json"}): + self.assertRaisesRegex(ValueError, "jwks_uri", h._validate_metadata) + + with self.metadata_edit({"response_types_supported": ["id_token"]}): + self.assertRaisesRegex( + ValueError, "response_types_supported", h._validate_metadata + ) + + with self.metadata_edit( + {"token_endpoint_auth_methods_supported": ["client_secret_basic"]} + ): + # should not throw, as client_secret_basic is the default auth method + h._validate_metadata() + + with self.metadata_edit( + {"token_endpoint_auth_methods_supported": ["client_secret_post"]} + ): + self.assertRaisesRegex( + ValueError, + "token_endpoint_auth_methods_supported", + h._validate_metadata, + ) + + # Tests for configs that the userinfo endpoint + self.assertFalse(h._uses_userinfo) + h._scopes = [] # do not request the openid scope + self.assertTrue(h._uses_userinfo) + self.assertRaisesRegex(ValueError, "userinfo_endpoint", h._validate_metadata) + + with self.metadata_edit( + {"userinfo_endpoint": USERINFO_ENDPOINT, "jwks_uri": None} + ): + # Shouldn't raise with a valid userinfo, even without + h._validate_metadata() + + @override_config({"oidc_config": {"skip_verification": True}}) + def test_skip_verification(self): + """Provider metadata validation can be disabled by config.""" + with self.metadata_edit({"issuer": "http://insecure"}): + # This should not throw + self.handler._validate_metadata() + + @defer.inlineCallbacks + def test_redirect_request(self): + """The redirect request has the right arguments & generates a valid session cookie.""" + req = Mock(spec=["addCookie"]) + url = yield defer.ensureDeferred( + self.handler.handle_redirect_request(req, b"http://client/redirect") + ) + url = urlparse(url) + auth_endpoint = urlparse(AUTHORIZATION_ENDPOINT) + + self.assertEqual(url.scheme, auth_endpoint.scheme) + self.assertEqual(url.netloc, auth_endpoint.netloc) + self.assertEqual(url.path, auth_endpoint.path) + + params = parse_qs(url.query) + self.assertEqual(params["redirect_uri"], [CALLBACK_URL]) + self.assertEqual(params["response_type"], ["code"]) + self.assertEqual(params["scope"], [" ".join(SCOPES)]) + self.assertEqual(params["client_id"], [CLIENT_ID]) + self.assertEqual(len(params["state"]), 1) + self.assertEqual(len(params["nonce"]), 1) + + # Check what is in the cookie + # note: python3.5 mock does not have the .called_once() method + calls = req.addCookie.call_args_list + self.assertEqual(len(calls), 1) # called once + # For some reason, call.args does not work with python3.5 + args = calls[0][0] + kwargs = calls[0][1] + self.assertEqual(args[0], COOKIE_NAME) + self.assertEqual(kwargs["path"], COOKIE_PATH) + cookie = args[1] + + macaroon = pymacaroons.Macaroon.deserialize(cookie) + state = self.handler._get_value_from_macaroon(macaroon, "state") + nonce = self.handler._get_value_from_macaroon(macaroon, "nonce") + redirect = self.handler._get_value_from_macaroon( + macaroon, "client_redirect_url" + ) + + self.assertEqual(params["state"], [state]) + self.assertEqual(params["nonce"], [nonce]) + self.assertEqual(redirect, "http://client/redirect") + + @defer.inlineCallbacks + def test_callback_error(self): + """Errors from the provider returned in the callback are displayed.""" + self.handler._render_error = Mock() + request = Mock(args={}) + request.args[b"error"] = [b"invalid_client"] + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.assertRenderedError("invalid_client", "") + + request.args[b"error_description"] = [b"some description"] + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.assertRenderedError("invalid_client", "some description") + + @defer.inlineCallbacks + def test_callback(self): + """Code callback works and display errors if something went wrong. + + A lot of scenarios are tested here: + - when the callback works, with userinfo from ID token + - when the user mapping fails + - when ID token verification fails + - when the callback works, with userinfo fetched from the userinfo endpoint + - when the userinfo fetching fails + - when the code exchange fails + """ + token = { + "type": "bearer", + "id_token": "id_token", + "access_token": "access_token", + } + userinfo = { + "sub": "foo", + "preferred_username": "bar", + } + user_id = UserID("foo", "domain.org") + self.handler._render_error = Mock(return_value=None) + self.handler._exchange_code = simple_async_mock(return_value=token) + self.handler._parse_id_token = simple_async_mock(return_value=userinfo) + self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo) + self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id) + self.handler._auth_handler.complete_sso_login = simple_async_mock() + request = Mock(spec=["args", "getCookie", "addCookie"]) + + code = "code" + state = "state" + nonce = "nonce" + client_redirect_url = "http://client/redirect" + session = self.handler._generate_oidc_session_token( + state=state, + nonce=nonce, + client_redirect_url=client_redirect_url, + ui_auth_session_id=None, + ) + request.getCookie.return_value = session + + request.args = {} + request.args[b"code"] = [code.encode("utf-8")] + request.args[b"state"] = [state.encode("utf-8")] + + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + + self.handler._auth_handler.complete_sso_login.assert_called_once_with( + user_id, request, client_redirect_url, + ) + self.handler._exchange_code.assert_called_once_with(code) + self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce) + self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token) + self.handler._fetch_userinfo.assert_not_called() + self.handler._render_error.assert_not_called() + + # Handle mapping errors + self.handler._map_userinfo_to_user = simple_async_mock( + raises=MappingException() + ) + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.assertRenderedError("mapping_error") + self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id) + + # Handle ID token errors + self.handler._parse_id_token = simple_async_mock(raises=Exception()) + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.assertRenderedError("invalid_token") + + self.handler._auth_handler.complete_sso_login.reset_mock() + self.handler._exchange_code.reset_mock() + self.handler._parse_id_token.reset_mock() + self.handler._map_userinfo_to_user.reset_mock() + self.handler._fetch_userinfo.reset_mock() + + # With userinfo fetching + self.handler._scopes = [] # do not ask the "openid" scope + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + + self.handler._auth_handler.complete_sso_login.assert_called_once_with( + user_id, request, client_redirect_url, + ) + self.handler._exchange_code.assert_called_once_with(code) + self.handler._parse_id_token.assert_not_called() + self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token) + self.handler._fetch_userinfo.assert_called_once_with(token) + self.handler._render_error.assert_not_called() + + # Handle userinfo fetching error + self.handler._fetch_userinfo = simple_async_mock(raises=Exception()) + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.assertRenderedError("fetch_error") + + # Handle code exchange failure + self.handler._exchange_code = simple_async_mock( + raises=OidcError("invalid_request") + ) + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.assertRenderedError("invalid_request") + + @defer.inlineCallbacks + def test_callback_session(self): + """The callback verifies the session presence and validity""" + self.handler._render_error = Mock(return_value=None) + request = Mock(spec=["args", "getCookie", "addCookie"]) + + # Missing cookie + request.args = {} + request.getCookie.return_value = None + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.assertRenderedError("missing_session", "No session cookie found") + + # Missing session parameter + request.args = {} + request.getCookie.return_value = "session" + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.assertRenderedError("invalid_request", "State parameter is missing") + + # Invalid cookie + request.args = {} + request.args[b"state"] = [b"state"] + request.getCookie.return_value = "session" + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.assertRenderedError("invalid_session") + + # Mismatching session + session = self.handler._generate_oidc_session_token( + state="state", + nonce="nonce", + client_redirect_url="http://client/redirect", + ui_auth_session_id=None, + ) + request.args = {} + request.args[b"state"] = [b"mismatching state"] + request.getCookie.return_value = session + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.assertRenderedError("mismatching_session") + + # Valid session + request.args = {} + request.args[b"state"] = [b"state"] + request.getCookie.return_value = session + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.assertRenderedError("invalid_request") + + @override_config({"oidc_config": {"client_auth_method": "client_secret_post"}}) + @defer.inlineCallbacks + def test_exchange_code(self): + """Code exchange behaves correctly and handles various error scenarios.""" + token = {"type": "bearer"} + token_json = json.dumps(token).encode("utf-8") + self.http_client.request = simple_async_mock( + return_value=FakeResponse(code=200, phrase=b"OK", body=token_json) + ) + code = "code" + ret = yield defer.ensureDeferred(self.handler._exchange_code(code)) + kwargs = self.http_client.request.call_args[1] + + self.assertEqual(ret, token) + self.assertEqual(kwargs["method"], "POST") + self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT) + + args = parse_qs(kwargs["data"].decode("utf-8")) + self.assertEqual(args["grant_type"], ["authorization_code"]) + self.assertEqual(args["code"], [code]) + self.assertEqual(args["client_id"], [CLIENT_ID]) + self.assertEqual(args["client_secret"], [CLIENT_SECRET]) + self.assertEqual(args["redirect_uri"], [CALLBACK_URL]) + + # Test error handling + self.http_client.request = simple_async_mock( + return_value=FakeResponse( + code=400, + phrase=b"Bad Request", + body=b'{"error": "foo", "error_description": "bar"}', + ) + ) + with self.assertRaises(OidcError) as exc: + yield defer.ensureDeferred(self.handler._exchange_code(code)) + self.assertEqual(exc.exception.error, "foo") + self.assertEqual(exc.exception.error_description, "bar") + + # Internal server error with no JSON body + self.http_client.request = simple_async_mock( + return_value=FakeResponse( + code=500, phrase=b"Internal Server Error", body=b"Not JSON", + ) + ) + with self.assertRaises(OidcError) as exc: + yield defer.ensureDeferred(self.handler._exchange_code(code)) + self.assertEqual(exc.exception.error, "server_error") + + # Internal server error with JSON body + self.http_client.request = simple_async_mock( + return_value=FakeResponse( + code=500, + phrase=b"Internal Server Error", + body=b'{"error": "internal_server_error"}', + ) + ) + with self.assertRaises(OidcError) as exc: + yield defer.ensureDeferred(self.handler._exchange_code(code)) + self.assertEqual(exc.exception.error, "internal_server_error") + + # 4xx error without "error" field + self.http_client.request = simple_async_mock( + return_value=FakeResponse(code=400, phrase=b"Bad request", body=b"{}",) + ) + with self.assertRaises(OidcError) as exc: + yield defer.ensureDeferred(self.handler._exchange_code(code)) + self.assertEqual(exc.exception.error, "server_error") + + # 2xx error with "error" field + self.http_client.request = simple_async_mock( + return_value=FakeResponse( + code=200, phrase=b"OK", body=b'{"error": "some_error"}', + ) + ) + with self.assertRaises(OidcError) as exc: + yield defer.ensureDeferred(self.handler._exchange_code(code)) + self.assertEqual(exc.exception.error, "some_error") diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index f70c6e7d65..05ea40a7de 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -19,9 +19,10 @@ from mock import Mock, call from signedjson.key import generate_signing_key from synapse.api.constants import EventTypes, Membership, PresenceState -from synapse.events import room_version_to_event_format +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events.builder import EventBuilder from synapse.handlers.presence import ( + EXTERNAL_PROCESS_EXPIRY, FEDERATION_PING_INTERVAL, FEDERATION_TIMEOUT, IDLE_TIMER, @@ -337,7 +338,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): ) new_state = handle_timeout( - state, is_mine=True, syncing_user_ids=set([user_id]), now=now + state, is_mine=True, syncing_user_ids={user_id}, now=now ) self.assertIsNotNone(new_state) @@ -413,6 +414,44 @@ class PresenceTimeoutTestCase(unittest.TestCase): self.assertEquals(state, new_state) +class PresenceHandlerTestCase(unittest.HomeserverTestCase): + def prepare(self, reactor, clock, hs): + self.presence_handler = hs.get_presence_handler() + self.clock = hs.get_clock() + + def test_external_process_timeout(self): + """Test that if an external process doesn't update the records for a while + we time out their syncing users presence. + """ + process_id = 1 + user_id = "@test:server" + + # Notify handler that a user is now syncing. + self.get_success( + self.presence_handler.update_external_syncs_row( + process_id, user_id, True, self.clock.time_msec() + ) + ) + + # Check that if we wait a while without telling the handler the user has + # stopped syncing that their presence state doesn't get timed out. + self.reactor.advance(EXTERNAL_PROCESS_EXPIRY / 2) + + state = self.get_success( + self.presence_handler.get_state(UserID.from_string(user_id)) + ) + self.assertEqual(state.state, PresenceState.ONLINE) + + # Check that if the external process timeout fires, then the syncing + # user gets timed out + self.reactor.advance(EXTERNAL_PROCESS_EXPIRY) + + state = self.get_success( + self.presence_handler.get_state(UserID.from_string(user_id)) + ) + self.assertEqual(state.state, PresenceState.OFFLINE) + + class PresenceJoinTestCase(unittest.HomeserverTestCase): """Tests remote servers get told about presence of users in the room when they join and when new local users join. @@ -455,8 +494,10 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): self.helper.join(room_id, "@test2:server") # Mark test2 as online, test will be offline with a last_active of 0 - self.presence_handler.set_state( - UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE} + self.get_success( + self.presence_handler.set_state( + UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE} + ) ) self.reactor.pump([0]) # Wait for presence updates to be handled @@ -504,14 +545,18 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): room_id = self.helper.create_room_as(self.user_id) # Mark test as online - self.presence_handler.set_state( - UserID.from_string("@test:server"), {"presence": PresenceState.ONLINE} + self.get_success( + self.presence_handler.set_state( + UserID.from_string("@test:server"), {"presence": PresenceState.ONLINE} + ) ) # Mark test2 as online, test will be offline with a last_active of 0. # Note we don't join them to the room yet - self.presence_handler.set_state( - UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE} + self.get_success( + self.presence_handler.set_state( + UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE} + ) ) # Add servers to the room @@ -540,7 +585,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): ) self.assertEqual(expected_state.state, PresenceState.ONLINE) self.federation_sender.send_presence_to_destinations.assert_called_once_with( - destinations=set(("server2", "server3")), states=[expected_state] + destinations={"server2", "server3"}, states=[expected_state] ) def _add_new_user(self, room_id, user_id): @@ -549,7 +594,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): hostname = get_domain_from_id(user_id) - room_version = self.get_success(self.store.get_room_version(room_id)) + room_version = self.get_success(self.store.get_room_version_id(room_id)) builder = EventBuilder( state=self.state, @@ -558,7 +603,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): clock=self.clock, hostname=hostname, signing_key=self.random_signing_key, - format_version=room_version_to_event_format(room_version), + room_version=KNOWN_ROOM_VERSIONS[room_version], room_id=room_id, type=EventTypes.Member, sender=user_id, diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index d60c124eec..29dd7d9c6e 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -14,12 +14,12 @@ # limitations under the License. -from mock import Mock, NonCallableMock +from mock import Mock from twisted.internet import defer import synapse.types -from synapse.api.errors import AuthError +from synapse.api.errors import AuthError, SynapseError from synapse.handlers.profile import MasterProfileHandler from synapse.types import UserID @@ -55,12 +55,8 @@ class ProfileTestCase(unittest.TestCase): federation_client=self.mock_federation, federation_server=Mock(), federation_registry=self.mock_registry, - ratelimiter=NonCallableMock(spec_set=["can_do_action"]), ) - self.ratelimiter = hs.get_ratelimiter() - self.ratelimiter.can_do_action.return_value = (True, 0) - self.store = hs.get_datastore() self.frank = UserID.from_string("@1234ABCD:test") @@ -70,6 +66,7 @@ class ProfileTestCase(unittest.TestCase): yield self.store.create_profile(self.frank.localpart) self.handler = hs.get_profile_handler() + self.hs = hs @defer.inlineCallbacks def test_get_my_name(self): @@ -81,19 +78,58 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_set_my_name(self): - yield self.handler.set_displayname( - self.frank, synapse.types.create_requester(self.frank), "Frank Jr." + yield defer.ensureDeferred( + self.handler.set_displayname( + self.frank, synapse.types.create_requester(self.frank), "Frank Jr." + ) ) self.assertEquals( - (yield self.store.get_profile_displayname(self.frank.localpart)), + ( + yield defer.ensureDeferred( + self.store.get_profile_displayname(self.frank.localpart) + ) + ), "Frank Jr.", ) + # Set displayname again + yield defer.ensureDeferred( + self.handler.set_displayname( + self.frank, synapse.types.create_requester(self.frank), "Frank" + ) + ) + + self.assertEquals( + (yield self.store.get_profile_displayname(self.frank.localpart)), "Frank", + ) + + @defer.inlineCallbacks + def test_set_my_name_if_disabled(self): + self.hs.config.enable_set_displayname = False + + # Setting displayname for the first time is allowed + yield self.store.set_profile_displayname(self.frank.localpart, "Frank") + + self.assertEquals( + (yield self.store.get_profile_displayname(self.frank.localpart)), "Frank", + ) + + # Setting displayname a second time is forbidden + d = defer.ensureDeferred( + self.handler.set_displayname( + self.frank, synapse.types.create_requester(self.frank), "Frank Jr." + ) + ) + + yield self.assertFailure(d, SynapseError) + @defer.inlineCallbacks def test_set_my_name_noauth(self): - d = self.handler.set_displayname( - self.frank, synapse.types.create_requester(self.bob), "Frank Jr." + d = defer.ensureDeferred( + self.handler.set_displayname( + self.frank, synapse.types.create_requester(self.bob), "Frank Jr." + ) ) yield self.assertFailure(d, AuthError) @@ -137,13 +173,54 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_set_my_avatar(self): - yield self.handler.set_avatar_url( - self.frank, - synapse.types.create_requester(self.frank), - "http://my.server/pic.gif", + yield defer.ensureDeferred( + self.handler.set_avatar_url( + self.frank, + synapse.types.create_requester(self.frank), + "http://my.server/pic.gif", + ) ) self.assertEquals( (yield self.store.get_profile_avatar_url(self.frank.localpart)), "http://my.server/pic.gif", ) + + # Set avatar again + yield defer.ensureDeferred( + self.handler.set_avatar_url( + self.frank, + synapse.types.create_requester(self.frank), + "http://my.server/me.png", + ) + ) + + self.assertEquals( + (yield self.store.get_profile_avatar_url(self.frank.localpart)), + "http://my.server/me.png", + ) + + @defer.inlineCallbacks + def test_set_my_avatar_if_disabled(self): + self.hs.config.enable_set_avatar_url = False + + # Setting displayname for the first time is allowed + yield self.store.set_profile_avatar_url( + self.frank.localpart, "http://my.server/me.png" + ) + + self.assertEquals( + (yield self.store.get_profile_avatar_url(self.frank.localpart)), + "http://my.server/me.png", + ) + + # Set avatar a second time is forbidden + d = defer.ensureDeferred( + self.handler.set_avatar_url( + self.frank, + synapse.types.create_requester(self.frank), + "http://my.server/pic.gif", + ) + ) + + yield self.assertFailure(d, SynapseError) diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 1e9ba3a201..ca32f993a3 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -34,7 +34,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): """ Tests the RegistrationHandler. """ def make_homeserver(self, reactor, clock): - hs_config = self.default_config("test") + hs_config = self.default_config() # some of the tests rely on us having a user consent version hs_config["user_consent"] = { @@ -135,6 +135,16 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.handler.register_user(localpart="local_part"), ResourceLimitError ) + def test_auto_join_rooms_for_guests(self): + room_alias_str = "#room:test" + self.hs.config.auto_join_rooms = [room_alias_str] + self.hs.config.auto_join_rooms_for_guests = False + user_id = self.get_success( + self.handler.register_user(localpart="jeff", make_guest=True), + ) + rooms = self.get_success(self.store.get_rooms_for_user(user_id)) + self.assertEqual(len(rooms), 0) + def test_auto_create_auto_join_rooms(self): room_alias_str = "#room:test" self.hs.config.auto_join_rooms = [room_alias_str] @@ -175,7 +185,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): room_alias_str = "#room:test" self.hs.config.auto_join_rooms = [room_alias_str] - self.store.is_real_user = Mock(return_value=False) + self.store.is_real_user = Mock(return_value=defer.succeed(False)) user_id = self.get_success(self.handler.register_user(localpart="support")) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) self.assertEqual(len(rooms), 0) @@ -187,8 +197,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase): room_alias_str = "#room:test" self.hs.config.auto_join_rooms = [room_alias_str] - self.store.count_real_users = Mock(return_value=1) - self.store.is_real_user = Mock(return_value=True) + self.store.count_real_users = Mock(return_value=defer.succeed(1)) + self.store.is_real_user = Mock(return_value=defer.succeed(True)) user_id = self.get_success(self.handler.register_user(localpart="real")) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) directory_handler = self.hs.get_handlers().directory_handler @@ -202,8 +212,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase): room_alias_str = "#room:test" self.hs.config.auto_join_rooms = [room_alias_str] - self.store.count_real_users = Mock(return_value=2) - self.store.is_real_user = Mock(return_value=True) + self.store.count_real_users = Mock(return_value=defer.succeed(2)) + self.store.is_real_user = Mock(return_value=defer.succeed(True)) user_id = self.get_success(self.handler.register_user(localpart="real")) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) self.assertEqual(len(rooms), 0) @@ -256,8 +266,9 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.handler.register_user(localpart=invalid_user_id), SynapseError ) - @defer.inlineCallbacks - def get_or_create_user(self, requester, localpart, displayname, password_hash=None): + async def get_or_create_user( + self, requester, localpart, displayname, password_hash=None + ): """Creates a new user if the user does not exist, else revokes all previous access tokens and generates a new one. @@ -269,16 +280,14 @@ class RegistrationTestCase(unittest.HomeserverTestCase): one will be randomly generated. Returns: A tuple of (user_id, access_token). - Raises: - RegistrationError if there was a problem registering. """ if localpart is None: raise SynapseError(400, "Request must include user id") - yield self.hs.get_auth().check_auth_blocking() + await self.hs.get_auth().check_auth_blocking() need_register = True try: - yield self.handler.check_username(localpart) + await self.handler.check_username(localpart) except SynapseError as e: if e.errcode == Codes.USER_IN_USE: need_register = False @@ -290,21 +299,21 @@ class RegistrationTestCase(unittest.HomeserverTestCase): token = self.macaroon_generator.generate_access_token(user_id) if need_register: - yield self.handler.register_with_store( + await self.handler.register_with_store( user_id=user_id, password_hash=password_hash, create_profile_with_displayname=user.localpart, ) else: - yield self.hs.get_auth_handler().delete_access_tokens_for_user(user_id) + await self.hs.get_auth_handler().delete_access_tokens_for_user(user_id) - yield self.store.add_access_token_to_user( + await self.store.add_access_token_to_user( user_id=user_id, token=token, device_id=None, valid_until_ms=None ) if displayname is not None: # logger.info("setting user display name: %s -> %s", user_id, displayname) - yield self.hs.get_profile_handler().set_displayname( + await self.hs.get_profile_handler().set_displayname( user, requester, displayname, by_admin=True ) diff --git a/tests/handlers/test_roomlist.py b/tests/handlers/test_roomlist.py deleted file mode 100644 index 61eebb6985..0000000000 --- a/tests/handlers/test_roomlist.py +++ /dev/null @@ -1,39 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2018 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from synapse.handlers.room_list import RoomListNextBatch - -import tests.unittest -import tests.utils - - -class RoomListTestCase(tests.unittest.TestCase): - """ Tests RoomList's RoomListNextBatch. """ - - def setUp(self): - pass - - def test_check_read_batch_tokens(self): - batch_token = RoomListNextBatch( - stream_ordering="abcdef", - public_room_stream_id="123", - current_limit=20, - direction_is_forward=True, - ).to_token() - next_batch = RoomListNextBatch.from_token(batch_token) - self.assertEquals(next_batch.stream_ordering, "abcdef") - self.assertEquals(next_batch.public_room_stream_id, "123") - self.assertEquals(next_batch.current_limit, 20) - self.assertEquals(next_batch.direction_is_forward, True) diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py index 7569b6fab5..d9d312f0fb 100644 --- a/tests/handlers/test_stats.py +++ b/tests/handlers/test_stats.py @@ -13,9 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse import storage from synapse.rest import admin from synapse.rest.client.v1 import login, room +from synapse.storage.data_stores.main import stats from tests import unittest @@ -42,16 +42,16 @@ class StatsRoomTests(unittest.HomeserverTestCase): Add the background updates we need to run. """ # Ugh, have to reset this flag - self.store._all_done = False + self.store.db.updates._all_done = False self.get_success( - self.store._simple_insert( + self.store.db.simple_insert( "background_updates", {"update_name": "populate_stats_prepare", "progress_json": "{}"}, ) ) self.get_success( - self.store._simple_insert( + self.store.db.simple_insert( "background_updates", { "update_name": "populate_stats_process_rooms", @@ -61,7 +61,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) self.get_success( - self.store._simple_insert( + self.store.db.simple_insert( "background_updates", { "update_name": "populate_stats_process_users", @@ -71,7 +71,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) self.get_success( - self.store._simple_insert( + self.store.db.simple_insert( "background_updates", { "update_name": "populate_stats_cleanup", @@ -82,21 +82,21 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) def get_all_room_state(self): - return self.store._simple_select_list( + return self.store.db.simple_select_list( "room_stats_state", None, retcols=("name", "topic", "canonical_alias") ) def _get_current_stats(self, stats_type, stat_id): - table, id_col = storage.stats.TYPE_TO_TABLE[stats_type] + table, id_col = stats.TYPE_TO_TABLE[stats_type] - cols = list(storage.stats.ABSOLUTE_STATS_FIELDS[stats_type]) + list( - storage.stats.PER_SLICE_FIELDS[stats_type] + cols = list(stats.ABSOLUTE_STATS_FIELDS[stats_type]) + list( + stats.PER_SLICE_FIELDS[stats_type] ) end_ts = self.store.quantise_stats_time(self.reactor.seconds() * 1000) return self.get_success( - self.store._simple_select_one( + self.store.db.simple_select_one( table + "_historical", {id_col: stat_id, end_ts: end_ts}, cols, @@ -108,8 +108,12 @@ class StatsRoomTests(unittest.HomeserverTestCase): # Do the initial population of the stats via the background update self._add_background_updates() - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) def test_initial_room(self): """ @@ -141,8 +145,12 @@ class StatsRoomTests(unittest.HomeserverTestCase): # Do the initial population of the user directory via the background update self._add_background_updates() - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) r = self.get_success(self.get_all_room_state()) @@ -178,9 +186,9 @@ class StatsRoomTests(unittest.HomeserverTestCase): # the position that the deltas should begin at, once they take over. self.hs.config.stats_enabled = True self.handler.stats_enabled = True - self.store._all_done = False + self.store.db.updates._all_done = False self.get_success( - self.store._simple_update_one( + self.store.db.simple_update_one( table="stats_incremental_position", keyvalues={}, updatevalues={"stream_id": 0}, @@ -188,14 +196,18 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) self.get_success( - self.store._simple_insert( + self.store.db.simple_insert( "background_updates", {"update_name": "populate_stats_prepare", "progress_json": "{}"}, ) ) - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) # Now, before the table is actually ingested, add some more events. self.helper.invite(room=room_1, src=u1, targ=u2, tok=u1_token) @@ -205,13 +217,13 @@ class StatsRoomTests(unittest.HomeserverTestCase): # Now do the initial ingestion. self.get_success( - self.store._simple_insert( + self.store.db.simple_insert( "background_updates", {"update_name": "populate_stats_process_rooms", "progress_json": "{}"}, ) ) self.get_success( - self.store._simple_insert( + self.store.db.simple_insert( "background_updates", { "update_name": "populate_stats_cleanup", @@ -221,9 +233,13 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) - self.store._all_done = False - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + self.store.db.updates._all_done = False + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) self.reactor.advance(86401) @@ -607,6 +623,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): """ self.hs.config.stats_enabled = False + self.handler.stats_enabled = False u1 = self.register_user("u1", "pass") u1token = self.login("u1", "pass") @@ -618,6 +635,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.assertIsNone(self._get_current_stats("user", u1)) self.hs.config.stats_enabled = True + self.handler.stats_enabled = True self._perform_background_initial_update() @@ -651,15 +669,15 @@ class StatsRoomTests(unittest.HomeserverTestCase): # preparation stage of the initial background update # Ugh, have to reset this flag - self.store._all_done = False + self.store.db.updates._all_done = False self.get_success( - self.store._simple_delete( + self.store.db.simple_delete( "room_stats_current", {"1": 1}, "test_delete_stats" ) ) self.get_success( - self.store._simple_delete( + self.store.db.simple_delete( "user_stats_current", {"1": 1}, "test_delete_stats" ) ) @@ -671,9 +689,9 @@ class StatsRoomTests(unittest.HomeserverTestCase): # now do the background updates - self.store._all_done = False + self.store.db.updates._all_done = False self.get_success( - self.store._simple_insert( + self.store.db.simple_insert( "background_updates", { "update_name": "populate_stats_process_rooms", @@ -683,7 +701,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) self.get_success( - self.store._simple_insert( + self.store.db.simple_insert( "background_updates", { "update_name": "populate_stats_process_users", @@ -693,7 +711,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) self.get_success( - self.store._simple_insert( + self.store.db.simple_insert( "background_updates", { "update_name": "populate_stats_cleanup", @@ -703,8 +721,12 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) r1stats_complete = self._get_current_stats("room", r1) u1stats_complete = self._get_current_stats("user", u1) diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index 31f54bbd7d..e178d7765b 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -12,54 +12,56 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer from synapse.api.errors import Codes, ResourceLimitError from synapse.api.filtering import DEFAULT_FILTER_COLLECTION -from synapse.handlers.sync import SyncConfig, SyncHandler +from synapse.handlers.sync import SyncConfig from synapse.types import UserID import tests.unittest import tests.utils -from tests.utils import setup_test_homeserver -class SyncTestCase(tests.unittest.TestCase): +class SyncTestCase(tests.unittest.HomeserverTestCase): """ Tests Sync Handler. """ - @defer.inlineCallbacks - def setUp(self): - self.hs = yield setup_test_homeserver(self.addCleanup) - self.sync_handler = SyncHandler(self.hs) + def prepare(self, reactor, clock, hs): + self.hs = hs + self.sync_handler = self.hs.get_sync_handler() self.store = self.hs.get_datastore() - @defer.inlineCallbacks - def test_wait_for_sync_for_user_auth_blocking(self): + # AuthBlocking reads from the hs' config on initialization. We need to + # modify its config instead of the hs' + self.auth_blocking = self.hs.get_auth()._auth_blocking - user_id1 = "@user1:server" - user_id2 = "@user2:server" + def test_wait_for_sync_for_user_auth_blocking(self): + user_id1 = "@user1:test" + user_id2 = "@user2:test" sync_config = self._generate_sync_config(user_id1) - self.hs.config.limit_usage_by_mau = True - self.hs.config.max_mau_value = 1 + self.reactor.advance(100) # So we get not 0 time + self.auth_blocking._limit_usage_by_mau = True + self.auth_blocking._max_mau_value = 1 # Check that the happy case does not throw errors - yield self.store.upsert_monthly_active_user(user_id1) - yield self.sync_handler.wait_for_sync_for_user(sync_config) + self.get_success(self.store.upsert_monthly_active_user(user_id1)) + self.get_success(self.sync_handler.wait_for_sync_for_user(sync_config)) # Test that global lock works - self.hs.config.hs_disabled = True - with self.assertRaises(ResourceLimitError) as e: - yield self.sync_handler.wait_for_sync_for_user(sync_config) - self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + self.auth_blocking._hs_disabled = True + e = self.get_failure( + self.sync_handler.wait_for_sync_for_user(sync_config), ResourceLimitError + ) + self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) - self.hs.config.hs_disabled = False + self.auth_blocking._hs_disabled = False sync_config = self._generate_sync_config(user_id2) - with self.assertRaises(ResourceLimitError) as e: - yield self.sync_handler.wait_for_sync_for_user(sync_config) - self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + e = self.get_failure( + self.sync_handler.wait_for_sync_for_user(sync_config), ResourceLimitError + ) + self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) def _generate_sync_config(self, user_id): return SyncConfig( diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 1f2ef5d01f..2fa8d4739b 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -24,6 +24,7 @@ from synapse.api.errors import AuthError from synapse.types import UserID from tests import unittest +from tests.unittest import override_config from tests.utils import register_federation_servlets # Some local users to test with @@ -63,34 +64,39 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): mock_federation_client = Mock(spec=["put_json"]) mock_federation_client.put_json.return_value = defer.succeed((200, "OK")) + datastores = Mock() + datastores.main = Mock( + spec=[ + # Bits that Federation needs + "prep_send_transaction", + "delivered_txn", + "get_received_txn_response", + "set_received_txn_response", + "get_destination_retry_timings", + "get_devices_by_remote", + "maybe_store_room_on_invite", + # Bits that user_directory needs + "get_user_directory_stream_pos", + "get_current_state_deltas", + "get_device_updates_by_remote", + ] + ) + + # the tests assume that we are starting at unix time 1000 + reactor.pump((1000,)) + hs = self.setup_test_homeserver( - datastore=( - Mock( - spec=[ - # Bits that Federation needs - "prep_send_transaction", - "delivered_txn", - "get_received_txn_response", - "set_received_txn_response", - "get_destination_retry_timings", - "get_devices_by_remote", - # Bits that user_directory needs - "get_user_directory_stream_pos", - "get_current_state_deltas", - ] - ) - ), notifier=Mock(), http_client=mock_federation_client, keyring=mock_keyring, + replication_streams={}, ) + hs.datastores = datastores + return hs def prepare(self, reactor, clock, hs): - # the tests assume that we are starting at unix time 1000 - reactor.pump((1000,)) - mock_notifier = hs.get_notifier() self.on_new_event = mock_notifier.on_new_event @@ -109,7 +115,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): retry_timings_res ) - self.datastore.get_devices_by_remote.return_value = (0, []) + self.datastore.get_device_updates_by_remote.return_value = defer.succeed( + (0, []) + ) def get_received_txn_response(*args): return defer.succeed(None) @@ -118,19 +126,19 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.room_members = [] - def check_joined_room(room_id, user_id): + def check_user_in_room(room_id, user_id): if user_id not in [u.to_string() for u in self.room_members]: raise AuthError(401, "User is not in the room") - hs.get_auth().check_joined_room = check_joined_room + hs.get_auth().check_user_in_room = check_user_in_room def get_joined_hosts_for_room(room_id): - return set(member.domain for member in self.room_members) + return {member.domain for member in self.room_members} self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room def get_current_users_in_room(room_id): - return set(str(u) for u in self.room_members) + return {str(u) for u in self.room_members} hs.get_state_handler().get_current_users_in_room = get_current_users_in_room @@ -139,11 +147,16 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): defer.succeed(1) ) - self.datastore.get_current_state_deltas.return_value = None + self.datastore.get_current_state_deltas.return_value = (0, None) self.datastore.get_to_device_stream_token = lambda: 0 - self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: ([], 0) + self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: defer.succeed( + ([], 0) + ) self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: None + self.datastore.set_received_txn_response = lambda *args, **kwargs: defer.succeed( + None + ) def test_started_typing_local(self): self.room_members = [U_APPLE, U_BANANA] @@ -159,7 +172,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])]) self.assertEquals(self.event_source.get_current_key(), 1) - events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) + events = self.get_success( + self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) + ) self.assertEquals( events[0], [ @@ -171,6 +186,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ], ) + @override_config({"send_federation": True}) def test_started_typing_remote_send(self): self.room_members = [U_APPLE, U_ONION] @@ -222,7 +238,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])]) self.assertEquals(self.event_source.get_current_key(), 1) - events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) + events = self.get_success( + self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) + ) self.assertEquals( events[0], [ @@ -234,6 +252,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ], ) + @override_config({"send_federation": True}) def test_stopped_typing(self): self.room_members = [U_APPLE, U_BANANA, U_ONION] @@ -242,7 +261,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): member = RoomMember(ROOM_ID, U_APPLE.to_string()) self.handler._member_typing_until[member] = 1002000 - self.handler._room_typing[ROOM_ID] = set([U_APPLE.to_string()]) + self.handler._room_typing[ROOM_ID] = {U_APPLE.to_string()} self.assertEquals(self.event_source.get_current_key(), 0) @@ -273,7 +292,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ) self.assertEquals(self.event_source.get_current_key(), 1) - events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) + events = self.get_success( + self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) + ) self.assertEquals( events[0], [{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}], @@ -294,7 +315,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.on_new_event.reset_mock() self.assertEquals(self.event_source.get_current_key(), 1) - events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) + events = self.get_success( + self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) + ) self.assertEquals( events[0], [ @@ -311,7 +334,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.on_new_event.assert_has_calls([call("typing_key", 2, rooms=[ROOM_ID])]) self.assertEquals(self.event_source.get_current_key(), 2) - events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=1) + events = self.get_success( + self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=1) + ) self.assertEquals( events[0], [{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}], @@ -329,7 +354,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.on_new_event.reset_mock() self.assertEquals(self.event_source.get_current_key(), 3) - events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) + events = self.get_success( + self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) + ) self.assertEquals( events[0], [ diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index c5e91a8c41..c15bce5bef 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -14,6 +14,8 @@ # limitations under the License. from mock import Mock +from twisted.internet import defer + import synapse.rest.admin from synapse.api.constants import UserTypes from synapse.rest.client.v1 import login, room @@ -75,18 +77,16 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): ) ) - self.store.remove_from_user_dir = Mock() - self.store.remove_from_user_in_public_room = Mock() + self.store.remove_from_user_dir = Mock(return_value=defer.succeed(None)) self.get_success(self.handler.handle_user_deactivated(s_user_id)) self.store.remove_from_user_dir.not_called() - self.store.remove_from_user_in_public_room.not_called() def test_handle_user_deactivated_regular_user(self): r_user_id = "@regular:test" self.get_success( self.store.register_user(user_id=r_user_id, password_hash=None) ) - self.store.remove_from_user_dir = Mock() + self.store.remove_from_user_dir = Mock(return_value=defer.succeed(None)) self.get_success(self.handler.handle_user_deactivated(r_user_id)) self.store.remove_from_user_dir.called_once_with(r_user_id) @@ -114,7 +114,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): public_users = self.get_users_in_public_rooms() self.assertEqual( - self._compress_shared(shares_private), set([(u1, u2, room), (u2, u1, room)]) + self._compress_shared(shares_private), {(u1, u2, room), (u2, u1, room)} ) self.assertEqual(public_users, []) @@ -147,6 +147,98 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): s = self.get_success(self.handler.search_users(u1, "user3", 10)) self.assertEqual(len(s["results"]), 0) + def test_spam_checker(self): + """ + A user which fails to the spam checks will not appear in search results. + """ + u1 = self.register_user("user1", "pass") + u1_token = self.login(u1, "pass") + u2 = self.register_user("user2", "pass") + u2_token = self.login(u2, "pass") + + # We do not add users to the directory until they join a room. + s = self.get_success(self.handler.search_users(u1, "user2", 10)) + self.assertEqual(len(s["results"]), 0) + + room = self.helper.create_room_as(u1, is_public=False, tok=u1_token) + self.helper.invite(room, src=u1, targ=u2, tok=u1_token) + self.helper.join(room, user=u2, tok=u2_token) + + # Check we have populated the database correctly. + shares_private = self.get_users_who_share_private_rooms() + public_users = self.get_users_in_public_rooms() + + self.assertEqual( + self._compress_shared(shares_private), {(u1, u2, room), (u2, u1, room)} + ) + self.assertEqual(public_users, []) + + # We get one search result when searching for user2 by user1. + s = self.get_success(self.handler.search_users(u1, "user2", 10)) + self.assertEqual(len(s["results"]), 1) + + # Configure a spam checker that does not filter any users. + spam_checker = self.hs.get_spam_checker() + + class AllowAll(object): + def check_username_for_spam(self, user_profile): + # Allow all users. + return False + + spam_checker.spam_checkers = [AllowAll()] + + # The results do not change: + # We get one search result when searching for user2 by user1. + s = self.get_success(self.handler.search_users(u1, "user2", 10)) + self.assertEqual(len(s["results"]), 1) + + # Configure a spam checker that filters all users. + class BlockAll(object): + def check_username_for_spam(self, user_profile): + # All users are spammy. + return True + + spam_checker.spam_checkers = [BlockAll()] + + # User1 now gets no search results for any of the other users. + s = self.get_success(self.handler.search_users(u1, "user2", 10)) + self.assertEqual(len(s["results"]), 0) + + def test_legacy_spam_checker(self): + """ + A spam checker without the expected method should be ignored. + """ + u1 = self.register_user("user1", "pass") + u1_token = self.login(u1, "pass") + u2 = self.register_user("user2", "pass") + u2_token = self.login(u2, "pass") + + # We do not add users to the directory until they join a room. + s = self.get_success(self.handler.search_users(u1, "user2", 10)) + self.assertEqual(len(s["results"]), 0) + + room = self.helper.create_room_as(u1, is_public=False, tok=u1_token) + self.helper.invite(room, src=u1, targ=u2, tok=u1_token) + self.helper.join(room, user=u2, tok=u2_token) + + # Check we have populated the database correctly. + shares_private = self.get_users_who_share_private_rooms() + public_users = self.get_users_in_public_rooms() + + self.assertEqual( + self._compress_shared(shares_private), {(u1, u2, room), (u2, u1, room)} + ) + self.assertEqual(public_users, []) + + # Configure a spam checker. + spam_checker = self.hs.get_spam_checker() + # The spam checker doesn't need any methods, so create a bare object. + spam_checker.spam_checker = object() + + # We get one search result when searching for user2 by user1. + s = self.get_success(self.handler.search_users(u1, "user2", 10)) + self.assertEqual(len(s["results"]), 1) + def _compress_shared(self, shared): """ Compress a list of users who share rooms dicts to a list of tuples. @@ -158,7 +250,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): def get_users_in_public_rooms(self): r = self.get_success( - self.store._simple_select_list( + self.store.db.simple_select_list( "users_in_public_rooms", None, ("user_id", "room_id") ) ) @@ -169,7 +261,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): def get_users_who_share_private_rooms(self): return self.get_success( - self.store._simple_select_list( + self.store.db.simple_select_list( "users_who_share_private_rooms", None, ["user_id", "other_user_id", "room_id"], @@ -181,10 +273,10 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): Add the background updates we need to run. """ # Ugh, have to reset this flag - self.store._all_done = False + self.store.db.updates._all_done = False self.get_success( - self.store._simple_insert( + self.store.db.simple_insert( "background_updates", { "update_name": "populate_user_directory_createtables", @@ -193,7 +285,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): ) ) self.get_success( - self.store._simple_insert( + self.store.db.simple_insert( "background_updates", { "update_name": "populate_user_directory_process_rooms", @@ -203,7 +295,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): ) ) self.get_success( - self.store._simple_insert( + self.store.db.simple_insert( "background_updates", { "update_name": "populate_user_directory_process_users", @@ -213,7 +305,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): ) ) self.get_success( - self.store._simple_insert( + self.store.db.simple_insert( "background_updates", { "update_name": "populate_user_directory_cleanup", @@ -255,19 +347,23 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): # Do the initial population of the user directory via the background update self._add_background_updates() - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) shares_private = self.get_users_who_share_private_rooms() public_users = self.get_users_in_public_rooms() # User 1 and User 2 are in the same public room - self.assertEqual(set(public_users), set([(u1, room), (u2, room)])) + self.assertEqual(set(public_users), {(u1, room), (u2, room)}) # User 1 and User 3 share private rooms self.assertEqual( self._compress_shared(shares_private), - set([(u1, u3, private_room), (u3, u1, private_room)]), + {(u1, u3, private_room), (u3, u1, private_room)}, ) def test_initial_share_all_users(self): @@ -290,15 +386,19 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): # Do the initial population of the user directory via the background update self._add_background_updates() - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) shares_private = self.get_users_who_share_private_rooms() public_users = self.get_users_in_public_rooms() # No users share rooms self.assertEqual(public_users, []) - self.assertEqual(self._compress_shared(shares_private), set([])) + self.assertEqual(self._compress_shared(shares_private), set()) # Despite not sharing a room, search_all_users means we get a search # result. |