From a72d5f39db55dfecb48291acdd6566c6556e0b0b Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 27 May 2020 19:41:06 +0100 Subject: Add test for Linearizer.is_queued(..) --- tests/util/test_linearizer.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) (limited to 'tests') diff --git a/tests/util/test_linearizer.py b/tests/util/test_linearizer.py index 852ef23185..ca3858b184 100644 --- a/tests/util/test_linearizer.py +++ b/tests/util/test_linearizer.py @@ -45,6 +45,38 @@ class LinearizerTestCase(unittest.TestCase): with (yield d2): pass + @defer.inlineCallbacks + def test_linearizer_is_queued(self): + linearizer = Linearizer() + + key = object() + + d1 = linearizer.queue(key) + cm1 = yield d1 + + # Since d1 gets called immediately, "is_queued" should return false. + self.assertFalse(linearizer.is_queued(key)) + + d2 = linearizer.queue(key) + self.assertFalse(d2.called) + + # Now d2 is queued up behind successful completion of cm1 + self.assertTrue(linearizer.is_queued(key)) + + with cm1: + self.assertFalse(d2.called) + + # cm1 still not done, so d2 still queued. + self.assertTrue(linearizer.is_queued(key)) + + # And now d2 is called and nothing is in the queue again + self.assertFalse(linearizer.is_queued(key)) + + with (yield d2): + self.assertFalse(linearizer.is_queued(key)) + + self.assertFalse(linearizer.is_queued(key)) + def test_lots_of_queued_things(self): # we have one slow thing, and lots of fast things queued up behind it. # it should *not* explode the stack. -- cgit 1.5.1 From 901b1fa561e3cc661d78aa96d59802cf2078cb0d Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Mon, 1 Jun 2020 16:34:33 +0200 Subject: Email notifications for new users when creating via the Admin API. (#7267) --- changelog.d/7267.bugfix | 1 + synapse/rest/admin/users.py | 16 +++++++++ tests/rest/admin/test_user.py | 75 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 92 insertions(+) create mode 100644 changelog.d/7267.bugfix (limited to 'tests') diff --git a/changelog.d/7267.bugfix b/changelog.d/7267.bugfix new file mode 100644 index 0000000000..0af316c1a2 --- /dev/null +++ b/changelog.d/7267.bugfix @@ -0,0 +1 @@ +Fix email notifications not being enabled for new users when created via the Admin API. 
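For context, the users.py change below is exercised by an Admin API request of roughly the following shape (a sketch outside the test harness further down; the homeserver URL and admin token are placeholders, and only the endpoint, method and request body come from this patch):

    import json

    import requests  # any HTTP client works; requests is used here for brevity

    BASE_URL = "https://homeserver.example.com"  # placeholder
    ADMIN_TOKEN = "<admin access token>"         # placeholder

    # Create a user with an email threepid via the v2 admin API. With
    # email_enable_notifs and email_notif_for_new_users both set, the fix
    # below additionally registers an "m.email" pusher for the new account,
    # so it receives notification emails like normally registered users.
    resp = requests.put(
        BASE_URL + "/_synapse/admin/v2/users/@bob:test",
        headers={"Authorization": "Bearer " + ADMIN_TOKEN},
        data=json.dumps(
            {
                "password": "abc123",
                "threepids": [{"medium": "email", "address": "bob@bob.bob"}],
            }
        ),
    )
    print(resp.status_code)  # 201 on first creation, as asserted in the tests
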
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index e7f6928c85..82251dbe5f 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -142,6 +142,7 @@ class UserRestServletV2(RestServlet): self.set_password_handler = hs.get_set_password_handler() self.deactivate_account_handler = hs.get_deactivate_account_handler() self.registration_handler = hs.get_registration_handler() + self.pusher_pool = hs.get_pusherpool() async def on_GET(self, request, user_id): await assert_requester_is_admin(self.auth, request) @@ -281,6 +282,21 @@ class UserRestServletV2(RestServlet): await self.auth_handler.add_threepid( user_id, threepid["medium"], threepid["address"], current_time ) + if ( + self.hs.config.email_enable_notifs + and self.hs.config.email_notif_for_new_users + ): + await self.pusher_pool.add_pusher( + user_id=user_id, + access_token=None, + kind="email", + app_id="m.email", + app_display_name="Email Notifications", + device_display_name=threepid["address"], + pushkey=threepid["address"], + lang=None, # We don't know a user's language here + data={}, + ) if "avatar_url" in body and type(body["avatar_url"]) == str: await self.profile_handler.set_avatar_url( diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 6c88ab06e2..e29cc24a8a 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -516,6 +516,81 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(False, channel.json_body["is_guest"]) self.assertEqual(False, channel.json_body["deactivated"]) + def test_create_user_email_notif_for_new_users(self): + """ + Check that a new regular user is created successfully and + got an email pusher. + """ + self.hs.config.registration_shared_secret = None + self.hs.config.email_enable_notifs = True + self.hs.config.email_notif_for_new_users = True + url = "/_synapse/admin/v2/users/@bob:test" + + # Create user + body = json.dumps( + { + "password": "abc123", + "threepids": [{"medium": "email", "address": "bob@bob.bob"}], + } + ) + + request, channel = self.make_request( + "PUT", + url, + access_token=self.admin_user_tok, + content=body.encode(encoding="utf_8"), + ) + self.render(request) + + self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@bob:test", channel.json_body["name"]) + self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) + self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) + + pushers = self.get_success( + self.store.get_pushers_by({"user_name": "@bob:test"}) + ) + pushers = list(pushers) + self.assertEqual(len(pushers), 1) + self.assertEqual("@bob:test", pushers[0]["user_name"]) + + def test_create_user_email_no_notif_for_new_users(self): + """ + Check that a new regular user is created successfully and + got not an email pusher. 
+ """ + self.hs.config.registration_shared_secret = None + self.hs.config.email_enable_notifs = False + self.hs.config.email_notif_for_new_users = False + url = "/_synapse/admin/v2/users/@bob:test" + + # Create user + body = json.dumps( + { + "password": "abc123", + "threepids": [{"medium": "email", "address": "bob@bob.bob"}], + } + ) + + request, channel = self.make_request( + "PUT", + url, + access_token=self.admin_user_tok, + content=body.encode(encoding="utf_8"), + ) + self.render(request) + + self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@bob:test", channel.json_body["name"]) + self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) + self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) + + pushers = self.get_success( + self.store.get_pushers_by({"user_name": "@bob:test"}) + ) + pushers = list(pushers) + self.assertEqual(len(pushers), 0) + def test_set_password(self): """ Test setting a new password for another user. -- cgit 1.5.1 From 33c39ab93c5cb9b82faa2717c62a5ffaf6780a86 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Mon, 1 Jun 2020 17:47:30 +0200 Subject: Process cross-signing keys when resyncing device lists (#7594) It looks like `user_device_resync` was ignoring cross-signing keys from the results received from the remote server. This patch fixes this, by processing these keys using the same process `_handle_signing_key_updates` does (and effectively factor that part out of that function). --- changelog.d/7594.bugfix | 1 + synapse/handlers/device.py | 58 +++++++++++++++++++++++++++++++++++++++++++- synapse/handlers/e2e_keys.py | 22 ++++------------- tests/test_federation.py | 56 ++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 119 insertions(+), 18 deletions(-) create mode 100644 changelog.d/7594.bugfix (limited to 'tests') diff --git a/changelog.d/7594.bugfix b/changelog.d/7594.bugfix new file mode 100644 index 0000000000..f0c067e184 --- /dev/null +++ b/changelog.d/7594.bugfix @@ -0,0 +1 @@ +Fix a bug causing the cross-signing keys to be ignored when resyncing a device list. diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 2cbb695bb1..230d170258 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import Any, Dict, Optional from six import iteritems, itervalues @@ -30,7 +31,11 @@ from synapse.api.errors import ( ) from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.types import RoomStreamToken, get_domain_from_id +from synapse.types import ( + RoomStreamToken, + get_domain_from_id, + get_verify_key_from_cross_signing_key, +) from synapse.util import stringutils from synapse.util.async_helpers import Linearizer from synapse.util.caches.expiringcache import ExpiringCache @@ -795,6 +800,13 @@ class DeviceListUpdater(object): stream_id = result["stream_id"] devices = result["devices"] + # Get the master key and the self-signing key for this user if provided in the + # response (None if not in the response). + # The response will not contain the user signing key, as this key is only used by + # its owner, thus it doesn't make sense to send it over federation. 
+ master_key = result.get("master_key") + self_signing_key = result.get("self_signing_key") + # If the remote server has more than ~1000 devices for this user # we assume that something is going horribly wrong (e.g. a bot # that logs in and creates a new device every time it tries to @@ -824,6 +836,13 @@ class DeviceListUpdater(object): yield self.store.update_remote_device_list_cache(user_id, devices, stream_id) device_ids = [device["device_id"] for device in devices] + + # Handle cross-signing keys. + cross_signing_device_ids = yield self.process_cross_signing_key_update( + user_id, master_key, self_signing_key, + ) + device_ids = device_ids + cross_signing_device_ids + yield self.device_handler.notify_device_update(user_id, device_ids) # We clobber the seen updates since we've re-synced from a given @@ -831,3 +850,40 @@ class DeviceListUpdater(object): self._seen_updates[user_id] = {stream_id} defer.returnValue(result) + + @defer.inlineCallbacks + def process_cross_signing_key_update( + self, + user_id: str, + master_key: Optional[Dict[str, Any]], + self_signing_key: Optional[Dict[str, Any]], + ) -> list: + """Process the given new master and self-signing key for the given remote user. + + Args: + user_id: The ID of the user these keys are for. + master_key: The dict of the cross-signing master key as returned by the + remote server. + self_signing_key: The dict of the cross-signing self-signing key as returned + by the remote server. + + Return: + The device IDs for the given keys. + """ + device_ids = [] + + if master_key: + yield self.store.set_e2e_cross_signing_key(user_id, "master", master_key) + _, verify_key = get_verify_key_from_cross_signing_key(master_key) + # verify_key is a VerifyKey from signedjson, which uses + # .version to denote the portion of the key ID after the + # algorithm and colon, which is the device ID + device_ids.append(verify_key.version) + if self_signing_key: + yield self.store.set_e2e_cross_signing_key( + user_id, "self_signing", self_signing_key + ) + _, verify_key = get_verify_key_from_cross_signing_key(self_signing_key) + device_ids.append(verify_key.version) + + return device_ids diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 8f1bc0323c..774a252619 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -1291,6 +1291,7 @@ class SigningKeyEduUpdater(object): """ device_handler = self.e2e_keys_handler.device_handler + device_list_updater = device_handler.device_list_updater with (yield self._remote_edu_linearizer.queue(user_id)): pending_updates = self._pending_updates.pop(user_id, []) @@ -1303,22 +1304,9 @@ class SigningKeyEduUpdater(object): logger.info("pending updates: %r", pending_updates) for master_key, self_signing_key in pending_updates: - if master_key: - yield self.store.set_e2e_cross_signing_key( - user_id, "master", master_key - ) - _, verify_key = get_verify_key_from_cross_signing_key(master_key) - # verify_key is a VerifyKey from signedjson, which uses - # .version to denote the portion of the key ID after the - # algorithm and colon, which is the device ID - device_ids.append(verify_key.version) - if self_signing_key: - yield self.store.set_e2e_cross_signing_key( - user_id, "self_signing", self_signing_key - ) - _, verify_key = get_verify_key_from_cross_signing_key( - self_signing_key - ) - device_ids.append(verify_key.version) + new_device_ids = yield device_list_updater.process_cross_signing_key_update( + user_id, master_key, self_signing_key, + ) + device_ids = device_ids + 
new_device_ids yield device_handler.notify_device_update(user_id, device_ids) diff --git a/tests/test_federation.py b/tests/test_federation.py index c5099dd039..c662195eec 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -206,3 +206,59 @@ class MessageAcceptTests(unittest.HomeserverTestCase): # list. self.reactor.advance(30) self.assertEqual(self.resync_attempts, 2) + + def test_cross_signing_keys_retry(self): + """Tests that resyncing a device list correctly processes cross-signing keys from + the remote server. + """ + remote_user_id = "@john:test_remote" + remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY" + remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ" + + # Register mock device list retrieval on the federation client. + federation_client = self.homeserver.get_federation_client() + federation_client.query_user_devices = Mock( + return_value={ + "user_id": remote_user_id, + "stream_id": 1, + "devices": [], + "master_key": { + "user_id": remote_user_id, + "usage": ["master"], + "keys": {"ed25519:" + remote_master_key: remote_master_key}, + }, + "self_signing_key": { + "user_id": remote_user_id, + "usage": ["self_signing"], + "keys": { + "ed25519:" + remote_self_signing_key: remote_self_signing_key + }, + }, + } + ) + + # Resync the device list. + device_handler = self.homeserver.get_device_handler() + self.get_success( + device_handler.device_list_updater.user_device_resync(remote_user_id), + ) + + # Retrieve the cross-signing keys for this user. + keys = self.get_success( + self.store.get_e2e_cross_signing_keys_bulk(user_ids=[remote_user_id]), + ) + self.assertTrue(remote_user_id in keys) + + # Check that the master key is the one returned by the mock. + master_key = keys[remote_user_id]["master"] + self.assertEqual(len(master_key["keys"]), 1) + self.assertTrue("ed25519:" + remote_master_key in master_key["keys"].keys()) + self.assertTrue(remote_master_key in master_key["keys"].values()) + + # Check that the self-signing key is the one returned by the mock. + self_signing_key = keys[remote_user_id]["self_signing"] + self.assertEqual(len(self_signing_key["keys"]), 1) + self.assertTrue( + "ed25519:" + remote_self_signing_key in self_signing_key["keys"].keys(), + ) + self.assertTrue(remote_self_signing_key in self_signing_key["keys"].values()) -- cgit 1.5.1 From fe434cd3c94dfc98954aea908e188e5d97df60db Mon Sep 17 00:00:00 2001 From: Olof Johansson Date: Mon, 1 Jun 2020 18:55:07 +0200 Subject: Fix a bug in automatic user creation with m.login.jwt. (#7585) --- changelog.d/7585.bugfix | 1 + synapse/rest/client/v1/login.py | 15 ++-- tests/rest/client/v1/test_login.py | 153 +++++++++++++++++++++++++++++++++++++ 3 files changed, 162 insertions(+), 7 deletions(-) create mode 100644 changelog.d/7585.bugfix (limited to 'tests') diff --git a/changelog.d/7585.bugfix b/changelog.d/7585.bugfix new file mode 100644 index 0000000000..263295599d --- /dev/null +++ b/changelog.d/7585.bugfix @@ -0,0 +1 @@ +Fix a bug in automatic user creation during first time login with `m.login.jwt`. Regression in v1.6.0. Contributed by @olof. 
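The regression fixed here is in the `m.login.jwt` flow: for a subject that had never logged in before, the old code overwrote `user_id` with the `None` returned by `check_user_exists` and then tried to derive a localpart from it. A sketch of the client-side request that exercises this path (the homeserver URL is a placeholder, the standard login endpoint is assumed, and the payload and claims mirror the new tests below):

    import json

    import jwt
    import requests

    # Sign a token whose "sub" claim becomes the localpart. "secret" must match
    # the server's jwt_secret, and HS256 its jwt_algorithm.
    token = jwt.encode({"sub": "frog"}, "secret", algorithm="HS256")
    if isinstance(token, bytes):  # PyJWT 1.x returns bytes, 2.x returns str
        token = token.decode("ascii")

    resp = requests.post(
        "https://homeserver.example.com/_matrix/client/r0/login",  # placeholder
        data=json.dumps({"type": "m.login.jwt", "token": token}),
    )
    # With the fix, a first-time login registers the account and returns its
    # canonical user ID, e.g. "@frog:<server_name>".
    print(resp.json().get("user_id"))
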
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index d89b2e5532..36aca82346 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -299,7 +299,7 @@ class LoginRestServlet(RestServlet): return result async def _complete_login( - self, user_id, login_submission, callback=None, create_non_existant_users=False + self, user_id, login_submission, callback=None, create_non_existent_users=False ): """Called when we've successfully authed the user and now need to actually login them in (e.g. create devices). This gets called on @@ -312,7 +312,7 @@ class LoginRestServlet(RestServlet): user_id (str): ID of the user to register. login_submission (dict): Dictionary of login information. callback (func|None): Callback function to run after registration. - create_non_existant_users (bool): Whether to create the user if + create_non_existent_users (bool): Whether to create the user if they don't exist. Defaults to False. Returns: @@ -331,12 +331,13 @@ class LoginRestServlet(RestServlet): update=True, ) - if create_non_existant_users: - user_id = await self.auth_handler.check_user_exists(user_id) - if not user_id: - user_id = await self.registration_handler.register_user( + if create_non_existent_users: + canonical_uid = await self.auth_handler.check_user_exists(user_id) + if not canonical_uid: + canonical_uid = await self.registration_handler.register_user( localpart=UserID.from_string(user_id).localpart ) + user_id = canonical_uid device_id = login_submission.get("device_id") initial_display_name = login_submission.get("initial_device_display_name") @@ -391,7 +392,7 @@ class LoginRestServlet(RestServlet): user_id = UserID(user, self.hs.hostname).to_string() result = await self._complete_login( - user_id, login_submission, create_non_existant_users=True + user_id, login_submission, create_non_existent_users=True ) return result diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index eb8f6264fd..0f0f7ca72d 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -1,8 +1,11 @@ import json +import time import urllib.parse from mock import Mock +import jwt + import synapse.rest.admin from synapse.rest.client.v1 import login, logout from synapse.rest.client.v2_alpha import devices @@ -473,3 +476,153 @@ class CASTestCase(unittest.HomeserverTestCase): # Because the user is deactivated they are served an error template. 
self.assertEqual(channel.code, 403) self.assertIn(b"SSO account deactivated", channel.result["body"]) + + +class JWTTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + ] + + jwt_secret = "secret" + + def make_homeserver(self, reactor, clock): + self.hs = self.setup_test_homeserver() + self.hs.config.jwt_enabled = True + self.hs.config.jwt_secret = self.jwt_secret + self.hs.config.jwt_algorithm = "HS256" + return self.hs + + def jwt_encode(self, token, secret=jwt_secret): + return jwt.encode(token, secret, "HS256").decode("ascii") + + def jwt_login(self, *args): + params = json.dumps({"type": "m.login.jwt", "token": self.jwt_encode(*args)}) + request, channel = self.make_request(b"POST", LOGIN_URL, params) + self.render(request) + return channel + + def test_login_jwt_valid_registered(self): + self.register_user("kermit", "monkey") + channel = self.jwt_login({"sub": "kermit"}) + self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.json_body["user_id"], "@kermit:test") + + def test_login_jwt_valid_unregistered(self): + channel = self.jwt_login({"sub": "frog"}) + self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.json_body["user_id"], "@frog:test") + + def test_login_jwt_invalid_signature(self): + channel = self.jwt_login({"sub": "frog"}, "notsecret") + self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED") + self.assertEqual(channel.json_body["error"], "Invalid JWT") + + def test_login_jwt_expired(self): + channel = self.jwt_login({"sub": "frog", "exp": 864000}) + self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED") + self.assertEqual(channel.json_body["error"], "JWT expired") + + def test_login_jwt_not_before(self): + now = int(time.time()) + channel = self.jwt_login({"sub": "frog", "nbf": now + 3600}) + self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED") + self.assertEqual(channel.json_body["error"], "Invalid JWT") + + def test_login_no_sub(self): + channel = self.jwt_login({"username": "root"}) + self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED") + self.assertEqual(channel.json_body["error"], "Invalid JWT") + + def test_login_no_token(self): + params = json.dumps({"type": "m.login.jwt"}) + request, channel = self.make_request(b"POST", LOGIN_URL, params) + self.render(request) + self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED") + self.assertEqual(channel.json_body["error"], "Token field for JWT is missing") + + +# The JWTPubKeyTestCase is a complement to JWTTestCase where we instead use +# RSS256, with a public key configured in synapse as "jwt_secret", and tokens +# signed by the private key. +class JWTPubKeyTestCase(unittest.HomeserverTestCase): + servlets = [ + login.register_servlets, + ] + + # This key's pubkey is used as the jwt_secret setting of synapse. Valid + # tokens are signed by this and validated using the pubkey. It is generated + # with `openssl genrsa 512` (not a secure way to generate real keys, but + # good enough for tests!) 
+ jwt_privatekey = "\n".join( + [ + "-----BEGIN RSA PRIVATE KEY-----", + "MIIBPAIBAAJBAM50f1Q5gsdmzifLstzLHb5NhfajiOt7TKO1vSEWdq7u9x8SMFiB", + "492RM9W/XFoh8WUfL9uL6Now6tPRDsWv3xsCAwEAAQJAUv7OOSOtiU+wzJq82rnk", + "yR4NHqt7XX8BvkZPM7/+EjBRanmZNSp5kYZzKVaZ/gTOM9+9MwlmhidrUOweKfB/", + "kQIhAPZwHazbjo7dYlJs7wPQz1vd+aHSEH+3uQKIysebkmm3AiEA1nc6mDdmgiUq", + "TpIN8A4MBKmfZMWTLq6z05y/qjKyxb0CIQDYJxCwTEenIaEa4PdoJl+qmXFasVDN", + "ZU0+XtNV7yul0wIhAMI9IhiStIjS2EppBa6RSlk+t1oxh2gUWlIh+YVQfZGRAiEA", + "tqBR7qLZGJ5CVKxWmNhJZGt1QHoUtOch8t9C4IdOZ2g=", + "-----END RSA PRIVATE KEY-----", + ] + ) + + # Generated with `openssl rsa -in foo.key -pubout`, with the the above + # private key placed in foo.key (jwt_privatekey). + jwt_pubkey = "\n".join( + [ + "-----BEGIN PUBLIC KEY-----", + "MFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBAM50f1Q5gsdmzifLstzLHb5NhfajiOt7", + "TKO1vSEWdq7u9x8SMFiB492RM9W/XFoh8WUfL9uL6Now6tPRDsWv3xsCAwEAAQ==", + "-----END PUBLIC KEY-----", + ] + ) + + # This key is used to sign tokens that shouldn't be accepted by synapse. + # Generated just like jwt_privatekey. + bad_privatekey = "\n".join( + [ + "-----BEGIN RSA PRIVATE KEY-----", + "MIIBOgIBAAJBAL//SQrKpKbjCCnv/FlasJCv+t3k/MPsZfniJe4DVFhsktF2lwQv", + "gLjmQD3jBUTz+/FndLSBvr3F4OHtGL9O/osCAwEAAQJAJqH0jZJW7Smzo9ShP02L", + "R6HRZcLExZuUrWI+5ZSP7TaZ1uwJzGFspDrunqaVoPobndw/8VsP8HFyKtceC7vY", + "uQIhAPdYInDDSJ8rFKGiy3Ajv5KWISBicjevWHF9dbotmNO9AiEAxrdRJVU+EI9I", + "eB4qRZpY6n4pnwyP0p8f/A3NBaQPG+cCIFlj08aW/PbxNdqYoBdeBA0xDrXKfmbb", + "iwYxBkwL0JCtAiBYmsi94sJn09u2Y4zpuCbJeDPKzWkbuwQh+W1fhIWQJQIhAKR0", + "KydN6cRLvphNQ9c/vBTdlzWxzcSxREpguC7F1J1m", + "-----END RSA PRIVATE KEY-----", + ] + ) + + def make_homeserver(self, reactor, clock): + self.hs = self.setup_test_homeserver() + self.hs.config.jwt_enabled = True + self.hs.config.jwt_secret = self.jwt_pubkey + self.hs.config.jwt_algorithm = "RS256" + return self.hs + + def jwt_encode(self, token, secret=jwt_privatekey): + return jwt.encode(token, secret, "RS256").decode("ascii") + + def jwt_login(self, *args): + params = json.dumps({"type": "m.login.jwt", "token": self.jwt_encode(*args)}) + request, channel = self.make_request(b"POST", LOGIN_URL, params) + self.render(request) + return channel + + def test_login_jwt_valid(self): + channel = self.jwt_login({"sub": "kermit"}) + self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.json_body["user_id"], "@kermit:test") + + def test_login_jwt_invalid_signature(self): + channel = self.jwt_login({"sub": "frog"}, self.bad_privatekey) + self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED") + self.assertEqual(channel.json_body["error"], "Invalid JWT") -- cgit 1.5.1 From 0188daf32c5d978c33b2bba9eb2b0b63262ca5fe Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Wed, 3 Jun 2020 16:39:30 +0100 Subject: Replace instances of reactor pumping with get_success. (#7619) Calls `self.get_success` on all deferred methods instead of abusing `self.pump()`. This has the benefit of working with coroutines, as well as checking that method execution completed successfully. There are also a few small cleanups that I made in the process. 
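The mechanical pattern applied throughout the diff below is small but easy to lose in the noise; roughly (a sketch against the test base class used elsewhere in this tree, with the datastore wiring included only to make it self-contained):

    from tests import unittest


    class ExampleTestCase(unittest.HomeserverTestCase):
        """Illustrates replacing fire-and-forget calls plus pump() with get_success()."""

        def prepare(self, reactor, clock, hs):
            self.store = hs.get_datastore()

        def test_example(self):
            # Previously tests tended to do:
            #     self.store.upsert_monthly_active_user("@user:server")
            #     self.pump()
            # which discards the returned Deferred, so a failure inside the call
            # goes unnoticed, and a coroutine-based method would never run at all.
            #
            # get_success() drives the call to completion and fails the test if
            # it raised:
            self.get_success(self.store.upsert_monthly_active_user("@user:server"))

            count = self.get_success(self.store.get_monthly_active_count())
            self.assertEqual(count, 1)
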
--- changelog.d/7619.misc | 1 + tests/storage/test_monthly_active_users.py | 267 ++++++++++++++++------------- tests/test_mau.py | 5 +- 3 files changed, 152 insertions(+), 121 deletions(-) create mode 100644 changelog.d/7619.misc (limited to 'tests') diff --git a/changelog.d/7619.misc b/changelog.d/7619.misc new file mode 100644 index 0000000000..23a8b30b19 --- /dev/null +++ b/changelog.d/7619.misc @@ -0,0 +1 @@ +Check that all asynchronous tasks succeed and general cleanup of `MonthlyActiveUsersTestCase` and `TestMauLimit`. diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index 447fcb3a1c..9c04e92577 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -61,21 +61,27 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): user2_email = threepids[1]["address"] user3 = "@user3:server" - self.store.register_user(user_id=user1) - self.store.register_user(user_id=user2) - self.store.register_user(user_id=user3, user_type=UserTypes.SUPPORT) - self.pump() + self.get_success(self.store.register_user(user_id=user1)) + self.get_success(self.store.register_user(user_id=user2)) + self.get_success( + self.store.register_user(user_id=user3, user_type=UserTypes.SUPPORT) + ) now = int(self.hs.get_clock().time_msec()) - self.store.user_add_threepid(user1, "email", user1_email, now, now) - self.store.user_add_threepid(user2, "email", user2_email, now, now) + self.get_success( + self.store.user_add_threepid(user1, "email", user1_email, now, now) + ) + self.get_success( + self.store.user_add_threepid(user2, "email", user2_email, now, now) + ) # XXX why are we doing this here? this function is only run at startup # so it is odd to re-run it here. - self.store.db.runInteraction( - "initialise", self.store._initialise_reserved_users, threepids + self.get_success( + self.store.db.runInteraction( + "initialise", self.store._initialise_reserved_users, threepids + ) ) - self.pump() # the number of users we expect will be counted against the mau limit # -1 because user3 is a support user and does not count @@ -83,13 +89,13 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): # Check the number of active users. Ensure user3 (support user) is not counted active_count = self.get_success(self.store.get_monthly_active_count()) - self.assertEquals(active_count, user_num) + self.assertEqual(active_count, user_num) # Test each of the registered users is marked as active - timestamp = self.store.user_last_seen_monthly_active(user1) - self.assertTrue(self.get_success(timestamp)) - timestamp = self.store.user_last_seen_monthly_active(user2) - self.assertTrue(self.get_success(timestamp)) + timestamp = self.get_success(self.store.user_last_seen_monthly_active(user1)) + self.assertGreater(timestamp, 0) + timestamp = self.get_success(self.store.user_last_seen_monthly_active(user2)) + self.assertGreater(timestamp, 0) # Test that users with reserved 3pids are not removed from the MAU table # XXX some of this is redundant. 
poking things into the config shouldn't @@ -98,77 +104,79 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.hs.config.max_mau_value = 0 self.reactor.advance(FORTY_DAYS) self.hs.config.max_mau_value = 5 - self.store.reap_monthly_active_users() - self.pump() - active_count = self.store.get_monthly_active_count() - self.assertEquals(self.get_success(active_count), user_num) + self.get_success(self.store.reap_monthly_active_users()) + + active_count = self.get_success(self.store.get_monthly_active_count()) + self.assertEqual(active_count, user_num) # Add some more users and check they are counted as active ru_count = 2 - self.store.upsert_monthly_active_user("@ru1:server") - self.store.upsert_monthly_active_user("@ru2:server") - self.pump() - active_count = self.store.get_monthly_active_count() - self.assertEqual(self.get_success(active_count), user_num + ru_count) + + self.get_success(self.store.upsert_monthly_active_user("@ru1:server")) + self.get_success(self.store.upsert_monthly_active_user("@ru2:server")) + + active_count = self.get_success(self.store.get_monthly_active_count()) + self.assertEqual(active_count, user_num + ru_count) # now run the reaper and check that the number of active users is reduced # to max_mau_value - self.store.reap_monthly_active_users() - self.pump() + self.get_success(self.store.reap_monthly_active_users()) - active_count = self.store.get_monthly_active_count() - self.assertEquals(self.get_success(active_count), 3) + active_count = self.get_success(self.store.get_monthly_active_count()) + self.assertEqual(active_count, 3) def test_can_insert_and_count_mau(self): - count = self.store.get_monthly_active_count() - self.assertEqual(0, self.get_success(count)) + count = self.get_success(self.store.get_monthly_active_count()) + self.assertEqual(count, 0) - self.store.upsert_monthly_active_user("@user:server") - self.pump() + d = self.store.upsert_monthly_active_user("@user:server") + self.get_success(d) - count = self.store.get_monthly_active_count() - self.assertEqual(1, self.get_success(count)) + count = self.get_success(self.store.get_monthly_active_count()) + self.assertEqual(count, 1) def test_user_last_seen_monthly_active(self): user_id1 = "@user1:server" user_id2 = "@user2:server" user_id3 = "@user3:server" - result = self.store.user_last_seen_monthly_active(user_id1) - self.assertFalse(self.get_success(result) == 0) + result = self.get_success(self.store.user_last_seen_monthly_active(user_id1)) + self.assertNotEqual(result, 0) - self.store.upsert_monthly_active_user(user_id1) - self.store.upsert_monthly_active_user(user_id2) - self.pump() + self.get_success(self.store.upsert_monthly_active_user(user_id1)) + self.get_success(self.store.upsert_monthly_active_user(user_id2)) - result = self.store.user_last_seen_monthly_active(user_id1) - self.assertGreater(self.get_success(result), 0) + result = self.get_success(self.store.user_last_seen_monthly_active(user_id1)) + self.assertGreater(result, 0) - result = self.store.user_last_seen_monthly_active(user_id3) - self.assertNotEqual(self.get_success(result), 0) + result = self.get_success(self.store.user_last_seen_monthly_active(user_id3)) + self.assertNotEqual(result, 0) @override_config({"max_mau_value": 5}) def test_reap_monthly_active_users(self): initial_users = 10 for i in range(initial_users): - self.store.upsert_monthly_active_user("@user%d:server" % i) - self.pump() + self.get_success( + self.store.upsert_monthly_active_user("@user%d:server" % i) + ) - count = 
self.store.get_monthly_active_count() - self.assertTrue(self.get_success(count), initial_users) + count = self.get_success(self.store.get_monthly_active_count()) + self.assertEqual(count, initial_users) - self.store.reap_monthly_active_users() - self.pump() - count = self.store.get_monthly_active_count() - self.assertEquals(self.get_success(count), self.hs.config.max_mau_value) + d = self.store.reap_monthly_active_users() + self.get_success(d) + + count = self.get_success(self.store.get_monthly_active_count()) + self.assertEqual(count, self.hs.config.max_mau_value) self.reactor.advance(FORTY_DAYS) - self.store.reap_monthly_active_users() - self.pump() - count = self.store.get_monthly_active_count() - self.assertEquals(self.get_success(count), 0) + d = self.store.reap_monthly_active_users() + self.get_success(d) + + count = self.get_success(self.store.get_monthly_active_count()) + self.assertEqual(count, 0) # Note that below says mau_limit (no s), this is the name of the config # value, although it gets stored on the config object as mau_limits. @@ -182,7 +190,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): for i in range(initial_users): user = "@user%d:server" % i email = "user%d@matrix.org" % i + self.get_success(self.store.upsert_monthly_active_user(user)) + # Need to ensure that the most recent entries in the # monthly_active_users table are reserved now = int(self.hs.get_clock().time_msec()) @@ -194,26 +204,37 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.store.user_add_threepid(user, "email", email, now, now) ) - self.store.db.runInteraction( + d = self.store.db.runInteraction( "initialise", self.store._initialise_reserved_users, threepids ) - count = self.store.get_monthly_active_count() - self.assertTrue(self.get_success(count), initial_users) + self.get_success(d) - users = self.store.get_registered_reserved_users() - self.assertEquals(len(self.get_success(users)), reserved_user_number) + count = self.get_success(self.store.get_monthly_active_count()) + self.assertEqual(count, initial_users) - self.get_success(self.store.reap_monthly_active_users()) - count = self.store.get_monthly_active_count() - self.assertEquals(self.get_success(count), self.hs.config.max_mau_value) + users = self.get_success(self.store.get_registered_reserved_users()) + self.assertEqual(len(users), reserved_user_number) + + d = self.store.reap_monthly_active_users() + self.get_success(d) + + count = self.get_success(self.store.get_monthly_active_count()) + self.assertEqual(count, self.hs.config.max_mau_value) def test_populate_monthly_users_is_guest(self): # Test that guest users are not added to mau list user_id = "@user_id:host" - self.store.register_user(user_id=user_id, password_hash=None, make_guest=True) + + d = self.store.register_user( + user_id=user_id, password_hash=None, make_guest=True + ) + self.get_success(d) + self.store.upsert_monthly_active_user = Mock() - self.store.populate_monthly_active_users(user_id) - self.pump() + + d = self.store.populate_monthly_active_users(user_id) + self.get_success(d) + self.store.upsert_monthly_active_user.assert_not_called() def test_populate_monthly_users_should_update(self): @@ -224,8 +245,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.store.user_last_seen_monthly_active = Mock( return_value=defer.succeed(None) ) - self.store.populate_monthly_active_users("user_id") - self.pump() + d = self.store.populate_monthly_active_users("user_id") + self.get_success(d) + 
self.store.upsert_monthly_active_user.assert_called_once() def test_populate_monthly_users_should_not_update(self): @@ -235,16 +257,18 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.store.user_last_seen_monthly_active = Mock( return_value=defer.succeed(self.hs.get_clock().time_msec()) ) - self.store.populate_monthly_active_users("user_id") - self.pump() + + d = self.store.populate_monthly_active_users("user_id") + self.get_success(d) + self.store.upsert_monthly_active_user.assert_not_called() def test_get_reserved_real_user_account(self): # Test no reserved users, or reserved threepids users = self.get_success(self.store.get_registered_reserved_users()) - self.assertEquals(len(users), 0) - # Test reserved users but no registered users + self.assertEqual(len(users), 0) + # Test reserved users but no registered users user1 = "@user1:example.com" user2 = "@user2:example.com" @@ -254,63 +278,64 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): {"medium": "email", "address": user1_email}, {"medium": "email", "address": user2_email}, ] + self.hs.config.mau_limits_reserved_threepids = threepids - self.store.db.runInteraction( + d = self.store.db.runInteraction( "initialise", self.store._initialise_reserved_users, threepids ) + self.get_success(d) - self.pump() users = self.get_success(self.store.get_registered_reserved_users()) - self.assertEquals(len(users), 0) + self.assertEqual(len(users), 0) - # Test reserved registed users - self.store.register_user(user_id=user1, password_hash=None) - self.store.register_user(user_id=user2, password_hash=None) - self.pump() + # Test reserved registered users + self.get_success(self.store.register_user(user_id=user1, password_hash=None)) + self.get_success(self.store.register_user(user_id=user2, password_hash=None)) now = int(self.hs.get_clock().time_msec()) self.store.user_add_threepid(user1, "email", user1_email, now, now) self.store.user_add_threepid(user2, "email", user2_email, now, now) users = self.get_success(self.store.get_registered_reserved_users()) - self.assertEquals(len(users), len(threepids)) + self.assertEqual(len(users), len(threepids)) def test_support_user_not_add_to_mau_limits(self): support_user_id = "@support:test" - count = self.store.get_monthly_active_count() - self.pump() - self.assertEqual(self.get_success(count), 0) - self.store.register_user( + count = self.get_success(self.store.get_monthly_active_count()) + self.assertEqual(count, 0) + + d = self.store.register_user( user_id=support_user_id, password_hash=None, user_type=UserTypes.SUPPORT ) + self.get_success(d) - self.store.upsert_monthly_active_user(support_user_id) - count = self.store.get_monthly_active_count() - self.pump() - self.assertEqual(self.get_success(count), 0) + d = self.store.upsert_monthly_active_user(support_user_id) + self.get_success(d) + + d = self.store.get_monthly_active_count() + count = self.get_success(d) + self.assertEqual(count, 0) # Note that the max_mau_value setting should not matter. 
@override_config( {"limit_usage_by_mau": False, "mau_stats_only": True, "max_mau_value": 1} ) def test_track_monthly_users_without_cap(self): - count = self.store.get_monthly_active_count() - self.assertEqual(0, self.get_success(count)) + count = self.get_success(self.store.get_monthly_active_count()) + self.assertEqual(0, count) - self.store.upsert_monthly_active_user("@user1:server") - self.store.upsert_monthly_active_user("@user2:server") - self.pump() + self.get_success(self.store.upsert_monthly_active_user("@user1:server")) + self.get_success(self.store.upsert_monthly_active_user("@user2:server")) - count = self.store.get_monthly_active_count() - self.assertEqual(2, self.get_success(count)) + count = self.get_success(self.store.get_monthly_active_count()) + self.assertEqual(2, count) @override_config({"limit_usage_by_mau": False, "mau_stats_only": False}) def test_no_users_when_not_tracking(self): self.store.upsert_monthly_active_user = Mock() - self.store.populate_monthly_active_users("@user:sever") - self.pump() + self.get_success(self.store.populate_monthly_active_users("@user:sever")) self.store.upsert_monthly_active_user.assert_not_called() @@ -325,33 +350,39 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): service2 = "service2" native = "native" - self.store.register_user( - user_id=appservice1_user1, password_hash=None, appservice_id=service1 + self.get_success( + self.store.register_user( + user_id=appservice1_user1, password_hash=None, appservice_id=service1 + ) + ) + self.get_success( + self.store.register_user( + user_id=appservice1_user2, password_hash=None, appservice_id=service1 + ) ) - self.store.register_user( - user_id=appservice1_user2, password_hash=None, appservice_id=service1 + self.get_success( + self.store.register_user( + user_id=appservice2_user1, password_hash=None, appservice_id=service2 + ) ) - self.store.register_user( - user_id=appservice2_user1, password_hash=None, appservice_id=service2 + self.get_success( + self.store.register_user(user_id=native_user1, password_hash=None) ) - self.store.register_user(user_id=native_user1, password_hash=None) - self.pump() - count = self.store.get_monthly_active_count_by_service() - self.assertEqual({}, self.get_success(count)) + count = self.get_success(self.store.get_monthly_active_count_by_service()) + self.assertEqual(count, {}) - self.store.upsert_monthly_active_user(native_user1) - self.store.upsert_monthly_active_user(appservice1_user1) - self.store.upsert_monthly_active_user(appservice1_user2) - self.store.upsert_monthly_active_user(appservice2_user1) - self.pump() + self.get_success(self.store.upsert_monthly_active_user(native_user1)) + self.get_success(self.store.upsert_monthly_active_user(appservice1_user1)) + self.get_success(self.store.upsert_monthly_active_user(appservice1_user2)) + self.get_success(self.store.upsert_monthly_active_user(appservice2_user1)) - count = self.store.get_monthly_active_count() - self.assertEqual(4, self.get_success(count)) + count = self.get_success(self.store.get_monthly_active_count()) + self.assertEqual(count, 4) - count = self.store.get_monthly_active_count_by_service() - result = self.get_success(count) + d = self.store.get_monthly_active_count_by_service() + result = self.get_success(d) - self.assertEqual(2, result[service1]) - self.assertEqual(1, result[service2]) - self.assertEqual(1, result[native]) + self.assertEqual(result[service1], 2) + self.assertEqual(result[service2], 1) + self.assertEqual(result[native], 1) diff --git a/tests/test_mau.py 
b/tests/test_mau.py index 8a97f0998d..49667ed7f4 100644 --- a/tests/test_mau.py +++ b/tests/test_mau.py @@ -85,7 +85,7 @@ class TestMauLimit(unittest.HomeserverTestCase): # Advance time by 31 days self.reactor.advance(31 * 24 * 60 * 60) - self.store.reap_monthly_active_users() + self.get_success(self.store.reap_monthly_active_users()) self.reactor.advance(0) @@ -147,8 +147,7 @@ class TestMauLimit(unittest.HomeserverTestCase): # Advance by 2 months so everyone falls out of MAU self.reactor.advance(60 * 24 * 60 * 60) - self.store.reap_monthly_active_users() - self.reactor.advance(0) + self.get_success(self.store.reap_monthly_active_users()) # We can create as many new users as we want token4 = self.create_user("kermit4") -- cgit 1.5.1 From c389bfb6eac421b84d3f91b344e8f3a91e421e83 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Thu, 4 Jun 2020 20:03:40 +0100 Subject: Fix encryption algorithm typos in tests/comments (#7637) @uhoreg has confirmed these were both typos. They are only in comments and tests though, rather than anything critical. Introduced in: * https://github.com/matrix-org/synapse/pull/7157 * https://github.com/matrix-org/synapse/pull/5726 --- changelog.d/7637.misc | 1 + synapse/rest/client/v2_alpha/keys.py | 8 ++++---- tests/federation/test_federation_sender.py | 2 +- tests/handlers/test_e2e_keys.py | 10 +++++----- 4 files changed, 11 insertions(+), 10 deletions(-) create mode 100644 changelog.d/7637.misc (limited to 'tests') diff --git a/changelog.d/7637.misc b/changelog.d/7637.misc new file mode 100644 index 0000000000..90d41fa775 --- /dev/null +++ b/changelog.d/7637.misc @@ -0,0 +1 @@ +Fix typos of `m.olm.curve25519-aes-sha2` and `m.megolm.v1.aes-sha2` in comments, test files. \ No newline at end of file diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index 8f41a3edbf..24bb090822 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -42,7 +42,7 @@ class KeyUploadServlet(RestServlet): "device_id": "", "valid_until_ts": , "algorithms": [ - "m.olm.curve25519-aes-sha256", + "m.olm.curve25519-aes-sha2", ] "keys": { ":": "", @@ -124,7 +124,7 @@ class KeyQueryServlet(RestServlet): "device_id": "", // Duplicated to be signed "valid_until_ts": , "algorithms": [ // List of supported algorithms - "m.olm.curve25519-aes-sha256", + "m.olm.curve25519-aes-sha2", ], "keys": { // Must include a ed25519 signing key ":": "", @@ -285,8 +285,8 @@ class SignaturesUploadServlet(RestServlet): "user_id": "", "device_id": "", "algorithms": [ - "m.olm.curve25519-aes-sha256", - "m.megolm.v1.aes-sha" + "m.olm.curve25519-aes-sha2", + "m.megolm.v1.aes-sha2" ], "keys": { ":": "", diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index 33105576af..ff12539041 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -536,7 +536,7 @@ def build_device_dict(user_id: str, device_id: str, sk: SigningKey): return { "user_id": user_id, "device_id": device_id, - "algorithms": ["m.olm.curve25519-aes-sha256", "m.megolm.v1.aes-sha"], + "algorithms": ["m.olm.curve25519-aes-sha2", "m.megolm.v1.aes-sha2"], "keys": { "curve25519:" + device_id: "curve25519+key", key_id(sk): encode_pubkey(sk), diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 854eb6c024..e1e144b2e7 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -222,7 
+222,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase): device_key_1 = { "user_id": local_user, "device_id": "abc", - "algorithms": ["m.olm.curve25519-aes-sha256", "m.megolm.v1.aes-sha"], + "algorithms": ["m.olm.curve25519-aes-sha2", "m.megolm.v1.aes-sha2"], "keys": { "ed25519:abc": "base64+ed25519+key", "curve25519:abc": "base64+curve25519+key", @@ -232,7 +232,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase): device_key_2 = { "user_id": local_user, "device_id": "def", - "algorithms": ["m.olm.curve25519-aes-sha256", "m.megolm.v1.aes-sha"], + "algorithms": ["m.olm.curve25519-aes-sha2", "m.megolm.v1.aes-sha2"], "keys": { "ed25519:def": "base64+ed25519+key", "curve25519:def": "base64+curve25519+key", @@ -315,7 +315,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase): device_key = { "user_id": local_user, "device_id": device_id, - "algorithms": ["m.olm.curve25519-aes-sha256", "m.megolm.v1.aes-sha"], + "algorithms": ["m.olm.curve25519-aes-sha2", "m.megolm.v1.aes-sha2"], "keys": {"curve25519:xyz": "curve25519+key", "ed25519:xyz": device_pubkey}, "signatures": {local_user: {"ed25519:xyz": "something"}}, } @@ -391,8 +391,8 @@ class E2eKeysHandlerTestCase(unittest.TestCase): "user_id": local_user, "device_id": device_id, "algorithms": [ - "m.olm.curve25519-aes-sha256", - "m.megolm.v1.aes-sha", + "m.olm.curve25519-aes-sha2", + "m.megolm.v1.aes-sha2", ], "keys": { "curve25519:xyz": "curve25519+key", -- cgit 1.5.1 From f4e6495b5d3267976f34088fa7459b388b801eb6 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Fri, 5 Jun 2020 10:47:20 +0100 Subject: Performance improvements and refactor of Ratelimiter (#7595) While working on https://github.com/matrix-org/synapse/issues/5665 I found myself digging into the `Ratelimiter` class and seeing that it was both: * Rather undocumented, and * causing a *lot* of config checks This PR attempts to refactor and comment the `Ratelimiter` class, as well as encourage config file accesses to only be done at instantiation. Best to be reviewed commit-by-commit. --- changelog.d/7595.misc | 1 + synapse/api/ratelimiting.py | 153 +++++++++++++++++++++------- synapse/config/ratelimiting.py | 8 +- synapse/handlers/_base.py | 60 ++++++----- synapse/handlers/auth.py | 24 ++--- synapse/handlers/message.py | 1 - synapse/handlers/register.py | 9 +- synapse/rest/client/v1/login.py | 65 ++++-------- synapse/rest/client/v2_alpha/register.py | 16 +-- synapse/server.py | 17 ++-- synapse/util/ratelimitutils.py | 2 +- tests/api/test_ratelimiting.py | 96 +++++++++++++---- tests/handlers/test_profile.py | 6 +- tests/replication/slave/storage/_base.py | 9 +- tests/rest/client/v1/test_events.py | 8 +- tests/rest/client/v1/test_login.py | 49 +++++++-- tests/rest/client/v1/test_rooms.py | 9 +- tests/rest/client/v1/test_typing.py | 10 +- tests/rest/client/v2_alpha/test_register.py | 9 +- 19 files changed, 322 insertions(+), 230 deletions(-) create mode 100644 changelog.d/7595.misc (limited to 'tests') diff --git a/changelog.d/7595.misc b/changelog.d/7595.misc new file mode 100644 index 0000000000..7a0646b1a3 --- /dev/null +++ b/changelog.d/7595.misc @@ -0,0 +1 @@ +Refactor `Ratelimiter` to limit the amount of expensive config value accesses. diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index 7a049b3af7..ec6b3a69a2 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -1,4 +1,5 @@ # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,75 +17,157 @@ from collections import OrderedDict from typing import Any, Optional, Tuple from synapse.api.errors import LimitExceededError +from synapse.util import Clock class Ratelimiter(object): """ - Ratelimit message sending by user. + Ratelimit actions marked by arbitrary keys. + + Args: + clock: A homeserver clock, for retrieving the current time + rate_hz: The long term number of actions that can be performed in a second. + burst_count: How many actions that can be performed before being limited. """ - def __init__(self): - self.message_counts = ( - OrderedDict() - ) # type: OrderedDict[Any, Tuple[float, int, Optional[float]]] + def __init__(self, clock: Clock, rate_hz: float, burst_count: int): + self.clock = clock + self.rate_hz = rate_hz + self.burst_count = burst_count + + # A ordered dictionary keeping track of actions, when they were last + # performed and how often. Each entry is a mapping from a key of arbitrary type + # to a tuple representing: + # * How many times an action has occurred since a point in time + # * The point in time + # * The rate_hz of this particular entry. This can vary per request + self.actions = OrderedDict() # type: OrderedDict[Any, Tuple[float, int, float]] - def can_do_action(self, key, time_now_s, rate_hz, burst_count, update=True): + def can_do_action( + self, + key: Any, + rate_hz: Optional[float] = None, + burst_count: Optional[int] = None, + update: bool = True, + _time_now_s: Optional[int] = None, + ) -> Tuple[bool, float]: """Can the entity (e.g. user or IP address) perform the action? + Args: key: The key we should use when rate limiting. Can be a user ID (when sending events), an IP address, etc. - time_now_s: The time now. - rate_hz: The long term number of messages a user can send in a - second. - burst_count: How many messages the user can send before being - limited. - update (bool): Whether to update the message rates or not. This is - useful to check if a message would be allowed to be sent before - its ready to be actually sent. + rate_hz: The long term number of actions that can be performed in a second. + Overrides the value set during instantiation if set. + burst_count: How many actions that can be performed before being limited. + Overrides the value set during instantiation if set. + update: Whether to count this check as performing the action + _time_now_s: The current time. Optional, defaults to the current time according + to self.clock. Only used by tests. + Returns: - A pair of a bool indicating if they can send a message now and a - time in seconds of when they can next send a message. + A tuple containing: + * A bool indicating if they can perform the action now + * The reactor timestamp for when the action can be performed next. 
+ -1 if rate_hz is less than or equal to zero """ - self.prune_message_counts(time_now_s) - message_count, time_start, _ignored = self.message_counts.get( - key, (0.0, time_now_s, None) - ) + # Override default values if set + time_now_s = _time_now_s if _time_now_s is not None else self.clock.time() + rate_hz = rate_hz if rate_hz is not None else self.rate_hz + burst_count = burst_count if burst_count is not None else self.burst_count + + # Remove any expired entries + self._prune_message_counts(time_now_s) + + # Check if there is an existing count entry for this key + action_count, time_start, _ = self.actions.get(key, (0.0, time_now_s, 0.0)) + + # Check whether performing another action is allowed time_delta = time_now_s - time_start - sent_count = message_count - time_delta * rate_hz - if sent_count < 0: + performed_count = action_count - time_delta * rate_hz + if performed_count < 0: + # Allow, reset back to count 1 allowed = True time_start = time_now_s - message_count = 1.0 - elif sent_count > burst_count - 1.0: + action_count = 1.0 + elif performed_count > burst_count - 1.0: + # Deny, we have exceeded our burst count allowed = False else: + # We haven't reached our limit yet allowed = True - message_count += 1 + action_count += 1.0 if update: - self.message_counts[key] = (message_count, time_start, rate_hz) + self.actions[key] = (action_count, time_start, rate_hz) if rate_hz > 0: - time_allowed = time_start + (message_count - burst_count + 1) / rate_hz + # Find out when the count of existing actions expires + time_allowed = time_start + (action_count - burst_count + 1) / rate_hz + + # Don't give back a time in the past if time_allowed < time_now_s: time_allowed = time_now_s + else: + # XXX: Why is this -1? This seems to only be used in + # self.ratelimit. I guess so that clients get a time in the past and don't + # feel afraid to try again immediately time_allowed = -1 return allowed, time_allowed - def prune_message_counts(self, time_now_s): - for key in list(self.message_counts.keys()): - message_count, time_start, rate_hz = self.message_counts[key] + def _prune_message_counts(self, time_now_s: int): + """Remove message count entries that have not exceeded their defined + rate_hz limit + + Args: + time_now_s: The current time + """ + # We create a copy of the key list here as the dictionary is modified during + # the loop + for key in list(self.actions.keys()): + action_count, time_start, rate_hz = self.actions[key] + + # Rate limit = "seconds since we started limiting this action" * rate_hz + # If this limit has not been exceeded, wipe our record of this action time_delta = time_now_s - time_start - if message_count - time_delta * rate_hz > 0: - break + if action_count - time_delta * rate_hz > 0: + continue else: - del self.message_counts[key] + del self.actions[key] + + def ratelimit( + self, + key: Any, + rate_hz: Optional[float] = None, + burst_count: Optional[int] = None, + update: bool = True, + _time_now_s: Optional[int] = None, + ): + """Checks if an action can be performed. If not, raises a LimitExceededError + + Args: + key: An arbitrary key used to classify an action + rate_hz: The long term number of actions that can be performed in a second. + Overrides the value set during instantiation if set. + burst_count: How many actions that can be performed before being limited. + Overrides the value set during instantiation if set. + update: Whether to count this check as performing the action + _time_now_s: The current time. 
Optional, defaults to the current time according + to self.clock. Only used by tests. + + Raises: + LimitExceededError: If an action could not be performed, along with the time in + milliseconds until the action can be performed again + """ + time_now_s = _time_now_s if _time_now_s is not None else self.clock.time() - def ratelimit(self, key, time_now_s, rate_hz, burst_count, update=True): allowed, time_allowed = self.can_do_action( - key, time_now_s, rate_hz, burst_count, update + key, + rate_hz=rate_hz, + burst_count=burst_count, + update=update, + _time_now_s=time_now_s, ) if not allowed: diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py index 4a3bfc4354..2dd94bae2b 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py @@ -12,11 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Dict + from ._base import Config class RateLimitConfig(object): - def __init__(self, config, defaults={"per_second": 0.17, "burst_count": 3.0}): + def __init__( + self, + config: Dict[str, float], + defaults={"per_second": 0.17, "burst_count": 3.0}, + ): self.per_second = config.get("per_second", defaults["per_second"]) self.burst_count = config.get("burst_count", defaults["burst_count"]) diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index 3b781d9836..61dc4beafe 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -19,7 +19,7 @@ from twisted.internet import defer import synapse.types from synapse.api.constants import EventTypes, Membership -from synapse.api.errors import LimitExceededError +from synapse.api.ratelimiting import Ratelimiter from synapse.types import UserID logger = logging.getLogger(__name__) @@ -44,11 +44,26 @@ class BaseHandler(object): self.notifier = hs.get_notifier() self.state_handler = hs.get_state_handler() self.distributor = hs.get_distributor() - self.ratelimiter = hs.get_ratelimiter() - self.admin_redaction_ratelimiter = hs.get_admin_redaction_ratelimiter() self.clock = hs.get_clock() self.hs = hs + # The rate_hz and burst_count are overridden on a per-user basis + self.request_ratelimiter = Ratelimiter( + clock=self.clock, rate_hz=0, burst_count=0 + ) + self._rc_message = self.hs.config.rc_message + + # Check whether ratelimiting room admin message redaction is enabled + # by the presence of rate limits in the config + if self.hs.config.rc_admin_redaction: + self.admin_redaction_ratelimiter = Ratelimiter( + clock=self.clock, + rate_hz=self.hs.config.rc_admin_redaction.per_second, + burst_count=self.hs.config.rc_admin_redaction.burst_count, + ) + else: + self.admin_redaction_ratelimiter = None + self.server_name = hs.hostname self.event_builder_factory = hs.get_event_builder_factory() @@ -70,7 +85,6 @@ class BaseHandler(object): Raises: LimitExceededError if the request should be ratelimited """ - time_now = self.clock.time() user_id = requester.user.to_string() # The AS user itself is never rate limited. @@ -83,48 +97,32 @@ class BaseHandler(object): if requester.app_service and not requester.app_service.is_rate_limited(): return + messages_per_second = self._rc_message.per_second + burst_count = self._rc_message.burst_count + # Check if there is a per user override in the DB. 
override = yield self.store.get_ratelimit_for_user(user_id) if override: - # If overriden with a null Hz then ratelimiting has been entirely + # If overridden with a null Hz then ratelimiting has been entirely # disabled for the user if not override.messages_per_second: return messages_per_second = override.messages_per_second burst_count = override.burst_count + + if is_admin_redaction and self.admin_redaction_ratelimiter: + # If we have separate config for admin redactions, use a separate + # ratelimiter as to not have user_ids clash + self.admin_redaction_ratelimiter.ratelimit(user_id, update=update) else: - # We default to different values if this is an admin redaction and - # the config is set - if is_admin_redaction and self.hs.config.rc_admin_redaction: - messages_per_second = self.hs.config.rc_admin_redaction.per_second - burst_count = self.hs.config.rc_admin_redaction.burst_count - else: - messages_per_second = self.hs.config.rc_message.per_second - burst_count = self.hs.config.rc_message.burst_count - - if is_admin_redaction and self.hs.config.rc_admin_redaction: - # If we have separate config for admin redactions we use a separate - # ratelimiter - allowed, time_allowed = self.admin_redaction_ratelimiter.can_do_action( - user_id, - time_now, - rate_hz=messages_per_second, - burst_count=burst_count, - update=update, - ) - else: - allowed, time_allowed = self.ratelimiter.can_do_action( + # Override rate and burst count per-user + self.request_ratelimiter.ratelimit( user_id, - time_now, rate_hz=messages_per_second, burst_count=burst_count, update=update, ) - if not allowed: - raise LimitExceededError( - retry_after_ms=int(1000 * (time_allowed - time_now)) - ) async def maybe_kick_guest_users(self, event, context=None): # Technically this function invalidates current_state by changing it. diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 75b39e878c..119678e67b 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -108,7 +108,11 @@ class AuthHandler(BaseHandler): # Ratelimiter for failed auth during UIA. Uses same ratelimit config # as per `rc_login.failed_attempts`. - self._failed_uia_attempts_ratelimiter = Ratelimiter() + self._failed_uia_attempts_ratelimiter = Ratelimiter( + clock=self.clock, + rate_hz=self.hs.config.rc_login_failed_attempts.per_second, + burst_count=self.hs.config.rc_login_failed_attempts.burst_count, + ) self._clock = self.hs.get_clock() @@ -196,13 +200,7 @@ class AuthHandler(BaseHandler): user_id = requester.user.to_string() # Check if we should be ratelimited due to too many previous failed attempts - self._failed_uia_attempts_ratelimiter.ratelimit( - user_id, - time_now_s=self._clock.time(), - rate_hz=self.hs.config.rc_login_failed_attempts.per_second, - burst_count=self.hs.config.rc_login_failed_attempts.burst_count, - update=False, - ) + self._failed_uia_attempts_ratelimiter.ratelimit(user_id, update=False) # build a list of supported flows flows = [[login_type] for login_type in self._supported_ui_auth_types] @@ -212,14 +210,8 @@ class AuthHandler(BaseHandler): flows, request, request_body, clientip, description ) except LoginError: - # Update the ratelimite to say we failed (`can_do_action` doesn't raise). 
- self._failed_uia_attempts_ratelimiter.can_do_action( - user_id, - time_now_s=self._clock.time(), - rate_hz=self.hs.config.rc_login_failed_attempts.per_second, - burst_count=self.hs.config.rc_login_failed_attempts.burst_count, - update=True, - ) + # Update the ratelimiter to say we failed (`can_do_action` doesn't raise). + self._failed_uia_attempts_ratelimiter.can_do_action(user_id) raise # find the completed login type diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 681f92cafd..649ca1f08a 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -362,7 +362,6 @@ class EventCreationHandler(object): self.profile_handler = hs.get_profile_handler() self.event_builder_factory = hs.get_event_builder_factory() self.server_name = hs.hostname - self.ratelimiter = hs.get_ratelimiter() self.notifier = hs.get_notifier() self.config = hs.config self.require_membership_for_aliases = hs.config.require_membership_for_aliases diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 55a03e53ea..cd746be7c8 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -425,14 +425,7 @@ class RegistrationHandler(BaseHandler): if not address: return - time_now = self.clock.time() - - self.ratelimiter.ratelimit( - address, - time_now_s=time_now, - rate_hz=self.hs.config.rc_registration.per_second, - burst_count=self.hs.config.rc_registration.burst_count, - ) + self.ratelimiter.ratelimit(address) def register_with_store( self, diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 6ac7c5142b..dceb2792fa 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -87,11 +87,22 @@ class LoginRestServlet(RestServlet): self.auth_handler = self.hs.get_auth_handler() self.registration_handler = hs.get_registration_handler() self.handlers = hs.get_handlers() - self._clock = hs.get_clock() self._well_known_builder = WellKnownBuilder(hs) - self._address_ratelimiter = Ratelimiter() - self._account_ratelimiter = Ratelimiter() - self._failed_attempts_ratelimiter = Ratelimiter() + self._address_ratelimiter = Ratelimiter( + clock=hs.get_clock(), + rate_hz=self.hs.config.rc_login_address.per_second, + burst_count=self.hs.config.rc_login_address.burst_count, + ) + self._account_ratelimiter = Ratelimiter( + clock=hs.get_clock(), + rate_hz=self.hs.config.rc_login_account.per_second, + burst_count=self.hs.config.rc_login_account.burst_count, + ) + self._failed_attempts_ratelimiter = Ratelimiter( + clock=hs.get_clock(), + rate_hz=self.hs.config.rc_login_failed_attempts.per_second, + burst_count=self.hs.config.rc_login_failed_attempts.burst_count, + ) def on_GET(self, request): flows = [] @@ -124,13 +135,7 @@ class LoginRestServlet(RestServlet): return 200, {} async def on_POST(self, request): - self._address_ratelimiter.ratelimit( - request.getClientIP(), - time_now_s=self.hs.clock.time(), - rate_hz=self.hs.config.rc_login_address.per_second, - burst_count=self.hs.config.rc_login_address.burst_count, - update=True, - ) + self._address_ratelimiter.ratelimit(request.getClientIP()) login_submission = parse_json_object_from_request(request) try: @@ -198,13 +203,7 @@ class LoginRestServlet(RestServlet): # We also apply account rate limiting using the 3PID as a key, as # otherwise using 3PID bypasses the ratelimiting based on user ID. 
- self._failed_attempts_ratelimiter.ratelimit( - (medium, address), - time_now_s=self._clock.time(), - rate_hz=self.hs.config.rc_login_failed_attempts.per_second, - burst_count=self.hs.config.rc_login_failed_attempts.burst_count, - update=False, - ) + self._failed_attempts_ratelimiter.ratelimit((medium, address), update=False) # Check for login providers that support 3pid login types ( @@ -238,13 +237,7 @@ class LoginRestServlet(RestServlet): # If it returned None but the 3PID was bound then we won't hit # this code path, which is fine as then the per-user ratelimit # will kick in below. - self._failed_attempts_ratelimiter.can_do_action( - (medium, address), - time_now_s=self._clock.time(), - rate_hz=self.hs.config.rc_login_failed_attempts.per_second, - burst_count=self.hs.config.rc_login_failed_attempts.burst_count, - update=True, - ) + self._failed_attempts_ratelimiter.can_do_action((medium, address)) raise LoginError(403, "", errcode=Codes.FORBIDDEN) identifier = {"type": "m.id.user", "user": user_id} @@ -263,11 +256,7 @@ class LoginRestServlet(RestServlet): # Check if we've hit the failed ratelimit (but don't update it) self._failed_attempts_ratelimiter.ratelimit( - qualified_user_id.lower(), - time_now_s=self._clock.time(), - rate_hz=self.hs.config.rc_login_failed_attempts.per_second, - burst_count=self.hs.config.rc_login_failed_attempts.burst_count, - update=False, + qualified_user_id.lower(), update=False ) try: @@ -279,13 +268,7 @@ class LoginRestServlet(RestServlet): # limiter. Using `can_do_action` avoids us raising a ratelimit # exception and masking the LoginError. The actual ratelimiting # should have happened above. - self._failed_attempts_ratelimiter.can_do_action( - qualified_user_id.lower(), - time_now_s=self._clock.time(), - rate_hz=self.hs.config.rc_login_failed_attempts.per_second, - burst_count=self.hs.config.rc_login_failed_attempts.burst_count, - update=True, - ) + self._failed_attempts_ratelimiter.can_do_action(qualified_user_id.lower()) raise result = await self._complete_login( @@ -318,13 +301,7 @@ class LoginRestServlet(RestServlet): # Before we actually log them in we check if they've already logged in # too often. This happens here rather than before as we don't # necessarily know the user before now. 
- self._account_ratelimiter.ratelimit( - user_id.lower(), - time_now_s=self._clock.time(), - rate_hz=self.hs.config.rc_login_account.per_second, - burst_count=self.hs.config.rc_login_account.burst_count, - update=True, - ) + self._account_ratelimiter.ratelimit(user_id.lower()) if create_non_existent_users: canonical_uid = await self.auth_handler.check_user_exists(user_id) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index addd4cae19..b9ffe86b2a 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -26,7 +26,6 @@ import synapse.types from synapse.api.constants import LoginType from synapse.api.errors import ( Codes, - LimitExceededError, SynapseError, ThreepidValidationError, UnrecognizedRequestError, @@ -396,20 +395,7 @@ class RegisterRestServlet(RestServlet): client_addr = request.getClientIP() - time_now = self.clock.time() - - allowed, time_allowed = self.ratelimiter.can_do_action( - client_addr, - time_now_s=time_now, - rate_hz=self.hs.config.rc_registration.per_second, - burst_count=self.hs.config.rc_registration.burst_count, - update=False, - ) - - if not allowed: - raise LimitExceededError( - retry_after_ms=int(1000 * (time_allowed - time_now)) - ) + self.ratelimiter.ratelimit(client_addr, update=False) kind = b"user" if b"kind" in request.args: diff --git a/synapse/server.py b/synapse/server.py index ca2deb49bb..fe94836a2c 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -242,9 +242,12 @@ class HomeServer(object): self.clock = Clock(reactor) self.distributor = Distributor() - self.ratelimiter = Ratelimiter() - self.admin_redaction_ratelimiter = Ratelimiter() - self.registration_ratelimiter = Ratelimiter() + + self.registration_ratelimiter = Ratelimiter( + clock=self.clock, + rate_hz=config.rc_registration.per_second, + burst_count=config.rc_registration.burst_count, + ) self.datastores = None @@ -314,15 +317,9 @@ class HomeServer(object): def get_distributor(self): return self.distributor - def get_ratelimiter(self): - return self.ratelimiter - - def get_registration_ratelimiter(self): + def get_registration_ratelimiter(self) -> Ratelimiter: return self.registration_ratelimiter - def get_admin_redaction_ratelimiter(self): - return self.admin_redaction_ratelimiter - def build_federation_client(self): return FederationClient(self) diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py index 5ca4521ce3..e5efdfcd02 100644 --- a/synapse/util/ratelimitutils.py +++ b/synapse/util/ratelimitutils.py @@ -43,7 +43,7 @@ class FederationRateLimiter(object): self.ratelimiters = collections.defaultdict(new_limiter) def ratelimit(self, host): - """Used to ratelimit an incoming request from given host + """Used to ratelimit an incoming request from a given host Example usage: diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py index dbdd427cac..d580e729c5 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py @@ -1,39 +1,97 @@ -from synapse.api.ratelimiting import Ratelimiter +from synapse.api.ratelimiting import LimitExceededError, Ratelimiter from tests import unittest class TestRatelimiter(unittest.TestCase): - def test_allowed(self): - limiter = Ratelimiter() - allowed, time_allowed = limiter.can_do_action( - key="test_id", time_now_s=0, rate_hz=0.1, burst_count=1 - ) + def test_allowed_via_can_do_action(self): + limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) + allowed, time_allowed = 
limiter.can_do_action(key="test_id", _time_now_s=0) self.assertTrue(allowed) self.assertEquals(10.0, time_allowed) - allowed, time_allowed = limiter.can_do_action( - key="test_id", time_now_s=5, rate_hz=0.1, burst_count=1 - ) + allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=5) self.assertFalse(allowed) self.assertEquals(10.0, time_allowed) - allowed, time_allowed = limiter.can_do_action( - key="test_id", time_now_s=10, rate_hz=0.1, burst_count=1 - ) + allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=10) self.assertTrue(allowed) self.assertEquals(20.0, time_allowed) - def test_pruning(self): - limiter = Ratelimiter() + def test_allowed_via_ratelimit(self): + limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) + + # Shouldn't raise + limiter.ratelimit(key="test_id", _time_now_s=0) + + # Should raise + with self.assertRaises(LimitExceededError) as context: + limiter.ratelimit(key="test_id", _time_now_s=5) + self.assertEqual(context.exception.retry_after_ms, 5000) + + # Shouldn't raise + limiter.ratelimit(key="test_id", _time_now_s=10) + + def test_allowed_via_can_do_action_and_overriding_parameters(self): + """Test that we can override options of can_do_action that would otherwise fail + an action + """ + # Create a Ratelimiter with a very low allowed rate_hz and burst_count + limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) + + # First attempt should be allowed + allowed, time_allowed = limiter.can_do_action(("test_id",), _time_now_s=0,) + self.assertTrue(allowed) + self.assertEqual(10.0, time_allowed) + + # Second attempt, 1s later, will fail + allowed, time_allowed = limiter.can_do_action(("test_id",), _time_now_s=1,) + self.assertFalse(allowed) + self.assertEqual(10.0, time_allowed) + + # But, if we allow 10 actions/sec for this request, we should be allowed + # to continue. allowed, time_allowed = limiter.can_do_action( - key="test_id_1", time_now_s=0, rate_hz=0.1, burst_count=1 + ("test_id",), _time_now_s=1, rate_hz=10.0 ) + self.assertTrue(allowed) + self.assertEqual(1.1, time_allowed) - self.assertIn("test_id_1", limiter.message_counts) - + # Similarly if we allow a burst of 10 actions allowed, time_allowed = limiter.can_do_action( - key="test_id_2", time_now_s=10, rate_hz=0.1, burst_count=1 + ("test_id",), _time_now_s=1, burst_count=10 ) + self.assertTrue(allowed) + self.assertEqual(1.0, time_allowed) + + def test_allowed_via_ratelimit_and_overriding_parameters(self): + """Test that we can override options of the ratelimit method that would otherwise + fail an action + """ + # Create a Ratelimiter with a very low allowed rate_hz and burst_count + limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) + + # First attempt should be allowed + limiter.ratelimit(key=("test_id",), _time_now_s=0) + + # Second attempt, 1s later, will fail + with self.assertRaises(LimitExceededError) as context: + limiter.ratelimit(key=("test_id",), _time_now_s=1) + self.assertEqual(context.exception.retry_after_ms, 9000) + + # But, if we allow 10 actions/sec for this request, we should be allowed + # to continue. 
+ limiter.ratelimit(key=("test_id",), _time_now_s=1, rate_hz=10.0) + + # Similarly if we allow a burst of 10 actions + limiter.ratelimit(key=("test_id",), _time_now_s=1, burst_count=10) + + def test_pruning(self): + limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) + limiter.can_do_action(key="test_id_1", _time_now_s=0) + + self.assertIn("test_id_1", limiter.actions) + + limiter.can_do_action(key="test_id_2", _time_now_s=10) - self.assertNotIn("test_id_1", limiter.message_counts) + self.assertNotIn("test_id_1", limiter.actions) diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 8aa56f1496..29dd7d9c6e 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -14,7 +14,7 @@ # limitations under the License. -from mock import Mock, NonCallableMock +from mock import Mock from twisted.internet import defer @@ -55,12 +55,8 @@ class ProfileTestCase(unittest.TestCase): federation_client=self.mock_federation, federation_server=Mock(), federation_registry=self.mock_registry, - ratelimiter=NonCallableMock(spec_set=["can_do_action"]), ) - self.ratelimiter = hs.get_ratelimiter() - self.ratelimiter.can_do_action.return_value = (True, 0) - self.store = hs.get_datastore() self.frank = UserID.from_string("@1234ABCD:test") diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 32cb04645f..56497b8476 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from mock import Mock, NonCallableMock +from mock import Mock from tests.replication._base import BaseStreamTestCase @@ -21,12 +21,7 @@ from tests.replication._base import BaseStreamTestCase class BaseSlavedStoreTestCase(BaseStreamTestCase): def make_homeserver(self, reactor, clock): - hs = self.setup_test_homeserver( - federation_client=Mock(), - ratelimiter=NonCallableMock(spec_set=["can_do_action"]), - ) - - hs.get_ratelimiter().can_do_action.return_value = (True, 0) + hs = self.setup_test_homeserver(federation_client=Mock()) return hs diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py index b54b06482b..f75520877f 100644 --- a/tests/rest/client/v1/test_events.py +++ b/tests/rest/client/v1/test_events.py @@ -15,7 +15,7 @@ """ Tests REST events for /events paths.""" -from mock import Mock, NonCallableMock +from mock import Mock import synapse.rest.admin from synapse.rest.client.v1 import events, login, room @@ -40,11 +40,7 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): config["enable_registration"] = True config["auto_join_rooms"] = [] - hs = self.setup_test_homeserver( - config=config, ratelimiter=NonCallableMock(spec_set=["can_do_action"]) - ) - self.ratelimiter = hs.get_ratelimiter() - self.ratelimiter.can_do_action.return_value = (True, 0) + hs = self.setup_test_homeserver(config=config) hs.get_handlers().federation_handler = Mock() diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index 0f0f7ca72d..9033f09fd2 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -29,7 +29,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): ] def make_homeserver(self, reactor, clock): - self.hs = self.setup_test_homeserver() self.hs.config.enable_registration = True self.hs.config.registrations_require_3pid = [] @@ -38,10 +37,20 @@ class 
LoginRestServletTestCase(unittest.HomeserverTestCase): return self.hs + @override_config( + { + "rc_login": { + "address": {"per_second": 0.17, "burst_count": 5}, + # Prevent the account login ratelimiter from raising first + # + # This is normally covered by the default test homeserver config + # which sets these values to 10000, but as we're overriding the entire + # rc_login dict here, we need to set this manually as well + "account": {"per_second": 10000, "burst_count": 10000}, + } + } + ) def test_POST_ratelimiting_per_address(self): - self.hs.config.rc_login_address.burst_count = 5 - self.hs.config.rc_login_address.per_second = 0.17 - # Create different users so we're sure not to be bothered by the per-user # ratelimiter. for i in range(0, 6): @@ -80,10 +89,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"200", channel.result) + @override_config( + { + "rc_login": { + "account": {"per_second": 0.17, "burst_count": 5}, + # Prevent the address login ratelimiter from raising first + # + # This is normally covered by the default test homeserver config + # which sets these values to 10000, but as we're overriding the entire + # rc_login dict here, we need to set this manually as well + "address": {"per_second": 10000, "burst_count": 10000}, + } + } + ) def test_POST_ratelimiting_per_account(self): - self.hs.config.rc_login_account.burst_count = 5 - self.hs.config.rc_login_account.per_second = 0.17 - self.register_user("kermit", "monkey") for i in range(0, 6): @@ -119,10 +138,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"200", channel.result) + @override_config( + { + "rc_login": { + # Prevent the address login ratelimiter from raising first + # + # This is normally covered by the default test homeserver config + # which sets these values to 10000, but as we're overriding the entire + # rc_login dict here, we need to set this manually as well + "address": {"per_second": 10000, "burst_count": 10000}, + "failed_attempts": {"per_second": 0.17, "burst_count": 5}, + } + } + ) def test_POST_ratelimiting_per_account_failed_attempts(self): - self.hs.config.rc_login_failed_attempts.burst_count = 5 - self.hs.config.rc_login_failed_attempts.per_second = 0.17 - self.register_user("kermit", "monkey") for i in range(0, 6): diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 7dd86d0c27..4886bbb401 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -20,7 +20,7 @@ import json -from mock import Mock, NonCallableMock +from mock import Mock from six.moves.urllib import parse as urlparse from twisted.internet import defer @@ -46,13 +46,8 @@ class RoomBase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): self.hs = self.setup_test_homeserver( - "red", - http_client=None, - federation_client=Mock(), - ratelimiter=NonCallableMock(spec_set=["can_do_action"]), + "red", http_client=None, federation_client=Mock(), ) - self.ratelimiter = self.hs.get_ratelimiter() - self.ratelimiter.can_do_action.return_value = (True, 0) self.hs.get_federation_handler = Mock(return_value=Mock()) diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index 4bc3aaf02d..18260bb90e 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -16,7 +16,7 @@ """Tests REST events for /rooms paths.""" -from mock import Mock, NonCallableMock +from mock 
import Mock from twisted.internet import defer @@ -39,17 +39,11 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver( - "red", - http_client=None, - federation_client=Mock(), - ratelimiter=NonCallableMock(spec_set=["can_do_action"]), + "red", http_client=None, federation_client=Mock(), ) self.event_source = hs.get_event_sources().sources["typing"] - self.ratelimiter = hs.get_ratelimiter() - self.ratelimiter.can_do_action.return_value = (True, 0) - hs.get_handlers().federation_handler = Mock() def get_user_by_access_token(token=None, allow_guest=False): diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index 5637ce2090..7deaf5b24a 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -29,6 +29,7 @@ from synapse.rest.client.v1 import login, logout from synapse.rest.client.v2_alpha import account, account_validity, register, sync from tests import unittest +from tests.unittest import override_config class RegisterRestServletTestCase(unittest.HomeserverTestCase): @@ -146,10 +147,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"403", channel.result) self.assertEquals(channel.json_body["error"], "Guest access is disabled") + @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}}) def test_POST_ratelimiting_guest(self): - self.hs.config.rc_registration.burst_count = 5 - self.hs.config.rc_registration.per_second = 0.17 - for i in range(0, 6): url = self.url + b"?kind=guest" request, channel = self.make_request(b"POST", url, b"{}") @@ -168,10 +167,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"200", channel.result) + @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}}) def test_POST_ratelimiting(self): - self.hs.config.rc_registration.burst_count = 5 - self.hs.config.rc_registration.per_second = 0.17 - for i in range(0, 6): params = { "username": "kermit" + str(i), -- cgit 1.5.1 From e55ee7c32fb3d958fb9da4e34133f53a6c57d2ea Mon Sep 17 00:00:00 2001 From: WGH Date: Fri, 5 Jun 2020 13:54:27 +0300 Subject: Add support for webp thumbnailing (#7586) Closes #4382 Signed-off-by: Maxim Plotnikov --- changelog.d/7586.feature | 1 + synapse/config/repository.py | 1 + tests/rest/media/v1/test_media_storage.py | 135 ++++++++++++++++++++++-------- 3 files changed, 101 insertions(+), 36 deletions(-) create mode 100644 changelog.d/7586.feature (limited to 'tests') diff --git a/changelog.d/7586.feature b/changelog.d/7586.feature new file mode 100644 index 0000000000..ef0231d823 --- /dev/null +++ b/changelog.d/7586.feature @@ -0,0 +1 @@ +Add support for generating thumbnails for WebP images. Previously, users would see an empty box instead of preview image. 
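[Editor's note, not part of the patch: the thumbnailing tests below embed a tiny lossless WebP image as a hex literal. The following stand-alone sketch shows how such a fixture might be regenerated, assuming Pillow is installed with WebP support; the exact bytes can differ between libwebp/Pillow versions, so it is illustrative rather than a byte-for-byte recipe.]

    from binascii import hexlify
    from io import BytesIO

    import PIL.Image as Image

    def tiny_lossless_webp() -> bytes:
        # Encode a 1x1 transparent RGBA image as lossless WebP, analogous to
        # the "small lossless webp" hex literal used by the parameterized test.
        buf = BytesIO()
        Image.new("RGBA", (1, 1)).save(buf, format="WEBP", lossless=True)
        return buf.getvalue()

    if __name__ == "__main__":
        # Output depends on the libwebp version Pillow was built against, so it
        # may not match the committed fixture exactly.
        print(hexlify(tiny_lossless_webp()))
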
diff --git a/synapse/config/repository.py b/synapse/config/repository.py index 9d2ce20220..b751d02d37 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -70,6 +70,7 @@ def parse_thumbnail_requirements(thumbnail_sizes): jpeg_thumbnail = ThumbnailRequirement(width, height, method, "image/jpeg") png_thumbnail = ThumbnailRequirement(width, height, method, "image/png") requirements.setdefault("image/jpeg", []).append(jpeg_thumbnail) + requirements.setdefault("image/webp", []).append(jpeg_thumbnail) requirements.setdefault("image/gif", []).append(png_thumbnail) requirements.setdefault("image/png", []).append(png_thumbnail) return { diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index 1809ceb839..1ca648ef2b 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -18,10 +18,16 @@ import os import shutil import tempfile from binascii import unhexlify +from io import BytesIO +from typing import Optional from mock import Mock from six.moves.urllib import parse +import attr +import PIL.Image as Image +from parameterized import parameterized_class + from twisted.internet.defer import Deferred from synapse.logging.context import make_deferred_yieldable @@ -94,6 +100,68 @@ class MediaStorageTests(unittest.HomeserverTestCase): self.assertEqual(test_body, body) +@attr.s +class _TestImage: + """An image for testing thumbnailing with the expected results + + Attributes: + data: The raw image to thumbnail + content_type: The type of the image as a content type, e.g. "image/png" + extension: The extension associated with the format, e.g. ".png" + expected_cropped: The expected bytes from cropped thumbnailing, or None if + test should just check for success. + expected_scaled: The expected bytes from scaled thumbnailing, or None if + test should just check for a valid image returned. 
+ """ + + data = attr.ib(type=bytes) + content_type = attr.ib(type=bytes) + extension = attr.ib(type=bytes) + expected_cropped = attr.ib(type=Optional[bytes]) + expected_scaled = attr.ib(type=Optional[bytes]) + + +@parameterized_class( + ("test_image",), + [ + # smol png + ( + _TestImage( + unhexlify( + b"89504e470d0a1a0a0000000d4948445200000001000000010806" + b"0000001f15c4890000000a49444154789c63000100000500010d" + b"0a2db40000000049454e44ae426082" + ), + b"image/png", + b".png", + unhexlify( + b"89504e470d0a1a0a0000000d4948445200000020000000200806" + b"000000737a7af40000001a49444154789cedc101010000008220" + b"ffaf6e484001000000ef0610200001194334ee0000000049454e" + b"44ae426082" + ), + unhexlify( + b"89504e470d0a1a0a0000000d4948445200000001000000010806" + b"0000001f15c4890000000d49444154789c636060606000000005" + b"0001a5f645400000000049454e44ae426082" + ), + ), + ), + # small lossless webp + ( + _TestImage( + unhexlify( + b"524946461a000000574542505650384c0d0000002f0000001007" + b"1011118888fe0700" + ), + b"image/webp", + b".webp", + None, + None, + ), + ), + ], +) class MediaRepoTests(unittest.HomeserverTestCase): hijack_auth = True @@ -151,13 +219,6 @@ class MediaRepoTests(unittest.HomeserverTestCase): self.download_resource = self.media_repo.children[b"download"] self.thumbnail_resource = self.media_repo.children[b"thumbnail"] - # smol png - self.end_content = unhexlify( - b"89504e470d0a1a0a0000000d4948445200000001000000010806" - b"0000001f15c4890000000a49444154789c63000100000500010d" - b"0a2db40000000049454e44ae426082" - ) - self.media_id = "example.com/12345" def _req(self, content_disposition): @@ -176,14 +237,14 @@ class MediaRepoTests(unittest.HomeserverTestCase): self.assertEqual(self.fetches[0][3], {"allow_remote": "false"}) headers = { - b"Content-Length": [b"%d" % (len(self.end_content))], - b"Content-Type": [b"image/png"], + b"Content-Length": [b"%d" % (len(self.test_image.data))], + b"Content-Type": [self.test_image.content_type], } if content_disposition: headers[b"Content-Disposition"] = [content_disposition] self.fetches[0][0].callback( - (self.end_content, (len(self.end_content), headers)) + (self.test_image.data, (len(self.test_image.data), headers)) ) self.pump() @@ -196,12 +257,15 @@ class MediaRepoTests(unittest.HomeserverTestCase): If the filename is filename= then Synapse will decode it as an ASCII string, and use filename= in the response. """ - channel = self._req(b"inline; filename=out.png") + channel = self._req(b"inline; filename=out" + self.test_image.extension) headers = channel.headers - self.assertEqual(headers.getRawHeaders(b"Content-Type"), [b"image/png"]) self.assertEqual( - headers.getRawHeaders(b"Content-Disposition"), [b"inline; filename=out.png"] + headers.getRawHeaders(b"Content-Type"), [self.test_image.content_type] + ) + self.assertEqual( + headers.getRawHeaders(b"Content-Disposition"), + [b"inline; filename=out" + self.test_image.extension], ) def test_disposition_filenamestar_utf8escaped(self): @@ -211,13 +275,17 @@ class MediaRepoTests(unittest.HomeserverTestCase): response. 
""" filename = parse.quote("\u2603".encode("utf8")).encode("ascii") - channel = self._req(b"inline; filename*=utf-8''" + filename + b".png") + channel = self._req( + b"inline; filename*=utf-8''" + filename + self.test_image.extension + ) headers = channel.headers - self.assertEqual(headers.getRawHeaders(b"Content-Type"), [b"image/png"]) + self.assertEqual( + headers.getRawHeaders(b"Content-Type"), [self.test_image.content_type] + ) self.assertEqual( headers.getRawHeaders(b"Content-Disposition"), - [b"inline; filename*=utf-8''" + filename + b".png"], + [b"inline; filename*=utf-8''" + filename + self.test_image.extension], ) def test_disposition_none(self): @@ -228,27 +296,16 @@ class MediaRepoTests(unittest.HomeserverTestCase): channel = self._req(None) headers = channel.headers - self.assertEqual(headers.getRawHeaders(b"Content-Type"), [b"image/png"]) + self.assertEqual( + headers.getRawHeaders(b"Content-Type"), [self.test_image.content_type] + ) self.assertEqual(headers.getRawHeaders(b"Content-Disposition"), None) def test_thumbnail_crop(self): - expected_body = unhexlify( - b"89504e470d0a1a0a0000000d4948445200000020000000200806" - b"000000737a7af40000001a49444154789cedc101010000008220" - b"ffaf6e484001000000ef0610200001194334ee0000000049454e" - b"44ae426082" - ) - - self._test_thumbnail("crop", expected_body) + self._test_thumbnail("crop", self.test_image.expected_cropped) def test_thumbnail_scale(self): - expected_body = unhexlify( - b"89504e470d0a1a0a0000000d4948445200000001000000010806" - b"0000001f15c4890000000d49444154789c636060606000000005" - b"0001a5f645400000000049454e44ae426082" - ) - - self._test_thumbnail("scale", expected_body) + self._test_thumbnail("scale", self.test_image.expected_scaled) def _test_thumbnail(self, method, expected_body): params = "?width=32&height=32&method=" + method @@ -259,13 +316,19 @@ class MediaRepoTests(unittest.HomeserverTestCase): self.pump() headers = { - b"Content-Length": [b"%d" % (len(self.end_content))], - b"Content-Type": [b"image/png"], + b"Content-Length": [b"%d" % (len(self.test_image.data))], + b"Content-Type": [self.test_image.content_type], } self.fetches[0][0].callback( - (self.end_content, (len(self.end_content), headers)) + (self.test_image.data, (len(self.test_image.data), headers)) ) self.pump() self.assertEqual(channel.code, 200) - self.assertEqual(channel.result["body"], expected_body, channel.result["body"]) + if expected_body is not None: + self.assertEqual( + channel.result["body"], expected_body, channel.result["body"] + ) + else: + # ensure that the result is at least some valid image + Image.open(BytesIO(channel.result["body"])) -- cgit 1.5.1 From 2970ce83674a4d910ebc46b505c9dcb83a15a1b9 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Fri, 5 Jun 2020 14:07:22 +0200 Subject: Add device management to admin API (#7481) - Admin is able to - change displaynames - delete devices - list devices - get device informations Fixes #7330 --- changelog.d/7481.feature | 1 + docs/admin_api/user_admin_api.rst | 209 +++++++++++++++ synapse/rest/admin/__init__.py | 8 + synapse/rest/admin/devices.py | 161 ++++++++++++ tests/rest/admin/test_device.py | 541 ++++++++++++++++++++++++++++++++++++++ 5 files changed, 920 insertions(+) create mode 100644 changelog.d/7481.feature create mode 100644 synapse/rest/admin/devices.py create mode 100644 tests/rest/admin/test_device.py (limited to 'tests') diff --git a/changelog.d/7481.feature b/changelog.d/7481.feature new file mode 100644 index 
0000000000..f167f3632c --- /dev/null +++ b/changelog.d/7481.feature @@ -0,0 +1 @@ +Add admin APIs to allow server admins to manage users' devices. Contributed by @dklimpel. diff --git a/docs/admin_api/user_admin_api.rst b/docs/admin_api/user_admin_api.rst index 776e71ec04..a3d52b282b 100644 --- a/docs/admin_api/user_admin_api.rst +++ b/docs/admin_api/user_admin_api.rst @@ -1,3 +1,5 @@ +.. contents:: + Create or modify Account ======================== @@ -245,3 +247,210 @@ with a body of: } including an ``access_token`` of a server admin. + + +User devices +============ + +List all devices +---------------- +Gets information about all devices for a specific ``user_id``. + +**Usage** + +A standard request to query the devices of an user: + +:: + + GET /_synapse/admin/v2/users//devices + + {} + +Response: + +.. code:: json + + { + "devices": [ + { + "device_id": "QBUAZIFURK", + "display_name": "android", + "last_seen_ip": "1.2.3.4", + "last_seen_ts": 1474491775024, + "user_id": "" + }, + { + "device_id": "AUIECTSRND", + "display_name": "ios", + "last_seen_ip": "1.2.3.5", + "last_seen_ts": 1474491775025, + "user_id": "" + } + ] + } + +**Parameters** + +The following query parameters are available: + +- ``user_id`` - fully qualified: for example, ``@user:server.com``. + +The following fields are possible in the JSON response body: + +- ``devices`` - An array of objects, each containing information about a device. + Device objects contain the following fields: + + - ``device_id`` - Identifier of device. + - ``display_name`` - Display name set by the user for this device. + Absent if no name has been set. + - ``last_seen_ip`` - The IP address where this device was last seen. + (May be a few minutes out of date, for efficiency reasons). + - ``last_seen_ts`` - The timestamp (in milliseconds since the unix epoch) when this + devices was last seen. (May be a few minutes out of date, for efficiency reasons). + - ``user_id`` - Owner of device. + +Delete multiple devices +------------------ +Deletes the given devices for a specific ``user_id``, and invalidates +any access token associated with them. + +**Usage** + +A standard request to delete devices: + +:: + + POST /_synapse/admin/v2/users//delete_devices + + { + "devices": [ + "QBUAZIFURK", + "AUIECTSRND" + ], + } + + +Response: + +.. code:: json + + {} + +**Parameters** + +The following query parameters are available: + +- ``user_id`` - fully qualified: for example, ``@user:server.com``. + +The following fields are required in the JSON request body: + +- ``devices`` - The list of device IDs to delete. + +Show a device +--------------- +Gets information on a single device, by ``device_id`` for a specific ``user_id``. + +**Usage** + +A standard request to get a device: + +:: + + GET /_synapse/admin/v2/users//devices/ + + {} + + +Response: + +.. code:: json + + { + "device_id": "", + "display_name": "android", + "last_seen_ip": "1.2.3.4", + "last_seen_ts": 1474491775024, + "user_id": "" + } + +**Parameters** + +The following query parameters are available: + +- ``user_id`` - fully qualified: for example, ``@user:server.com``. +- ``device_id`` - The device to retrieve. + +The following fields are possible in the JSON response body: + +- ``device_id`` - Identifier of device. +- ``display_name`` - Display name set by the user for this device. + Absent if no name has been set. +- ``last_seen_ip`` - The IP address where this device was last seen. + (May be a few minutes out of date, for efficiency reasons). 
+- ``last_seen_ts`` - The timestamp (in milliseconds since the unix epoch) when this + devices was last seen. (May be a few minutes out of date, for efficiency reasons). +- ``user_id`` - Owner of device. + +Update a device +--------------- +Updates the metadata on the given ``device_id`` for a specific ``user_id``. + +**Usage** + +A standard request to update a device: + +:: + + PUT /_synapse/admin/v2/users//devices/ + + { + "display_name": "My other phone" + } + + +Response: + +.. code:: json + + {} + +**Parameters** + +The following query parameters are available: + +- ``user_id`` - fully qualified: for example, ``@user:server.com``. +- ``device_id`` - The device to update. + +The following fields are required in the JSON request body: + +- ``display_name`` - The new display name for this device. If not given, + the display name is unchanged. + +Delete a device +--------------- +Deletes the given ``device_id`` for a specific ``user_id``, +and invalidates any access token associated with it. + +**Usage** + +A standard request for delete a device: + +:: + + DELETE /_synapse/admin/v2/users//devices/ + + {} + + +Response: + +.. code:: json + + {} + +**Parameters** + +The following query parameters are available: + +- ``user_id`` - fully qualified: for example, ``@user:server.com``. +- ``device_id`` - The device to delete. diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 6b85148a32..9eda592de9 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -26,6 +26,11 @@ from synapse.rest.admin._base import ( assert_requester_is_admin, historical_admin_path_patterns, ) +from synapse.rest.admin.devices import ( + DeleteDevicesRestServlet, + DeviceRestServlet, + DevicesRestServlet, +) from synapse.rest.admin.groups import DeleteGroupAdminRestServlet from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet @@ -202,6 +207,9 @@ def register_servlets(hs, http_server): UserAdminServlet(hs).register(http_server) UserRestServletV2(hs).register(http_server) UsersRestServletV2(hs).register(http_server) + DeviceRestServlet(hs).register(http_server) + DevicesRestServlet(hs).register(http_server) + DeleteDevicesRestServlet(hs).register(http_server) def register_servlets_for_client_rest_resource(hs, http_server): diff --git a/synapse/rest/admin/devices.py b/synapse/rest/admin/devices.py new file mode 100644 index 0000000000..8d32677339 --- /dev/null +++ b/synapse/rest/admin/devices.py @@ -0,0 +1,161 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Dirk Klimpel +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import logging +import re + +from synapse.api.errors import NotFoundError, SynapseError +from synapse.http.servlet import ( + RestServlet, + assert_params_in_dict, + parse_json_object_from_request, +) +from synapse.rest.admin._base import assert_requester_is_admin +from synapse.types import UserID + +logger = logging.getLogger(__name__) + + +class DeviceRestServlet(RestServlet): + """ + Get, update or delete the given user's device + """ + + PATTERNS = ( + re.compile( + "^/_synapse/admin/v2/users/(?P[^/]*)/devices/(?P[^/]*)$" + ), + ) + + def __init__(self, hs): + super(DeviceRestServlet, self).__init__() + self.hs = hs + self.auth = hs.get_auth() + self.device_handler = hs.get_device_handler() + self.store = hs.get_datastore() + + async def on_GET(self, request, user_id, device_id): + await assert_requester_is_admin(self.auth, request) + + target_user = UserID.from_string(user_id) + if not self.hs.is_mine(target_user): + raise SynapseError(400, "Can only lookup local users") + + u = await self.store.get_user_by_id(target_user.to_string()) + if u is None: + raise NotFoundError("Unknown user") + + device = await self.device_handler.get_device( + target_user.to_string(), device_id + ) + return 200, device + + async def on_DELETE(self, request, user_id, device_id): + await assert_requester_is_admin(self.auth, request) + + target_user = UserID.from_string(user_id) + if not self.hs.is_mine(target_user): + raise SynapseError(400, "Can only lookup local users") + + u = await self.store.get_user_by_id(target_user.to_string()) + if u is None: + raise NotFoundError("Unknown user") + + await self.device_handler.delete_device(target_user.to_string(), device_id) + return 200, {} + + async def on_PUT(self, request, user_id, device_id): + await assert_requester_is_admin(self.auth, request) + + target_user = UserID.from_string(user_id) + if not self.hs.is_mine(target_user): + raise SynapseError(400, "Can only lookup local users") + + u = await self.store.get_user_by_id(target_user.to_string()) + if u is None: + raise NotFoundError("Unknown user") + + body = parse_json_object_from_request(request, allow_empty_body=True) + await self.device_handler.update_device( + target_user.to_string(), device_id, body + ) + return 200, {} + + +class DevicesRestServlet(RestServlet): + """ + Retrieve the given user's devices + """ + + PATTERNS = (re.compile("^/_synapse/admin/v2/users/(?P[^/]*)/devices$"),) + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + self.hs = hs + self.auth = hs.get_auth() + self.device_handler = hs.get_device_handler() + self.store = hs.get_datastore() + + async def on_GET(self, request, user_id): + await assert_requester_is_admin(self.auth, request) + + target_user = UserID.from_string(user_id) + if not self.hs.is_mine(target_user): + raise SynapseError(400, "Can only lookup local users") + + u = await self.store.get_user_by_id(target_user.to_string()) + if u is None: + raise NotFoundError("Unknown user") + + devices = await self.device_handler.get_devices_by_user(target_user.to_string()) + return 200, {"devices": devices} + + +class DeleteDevicesRestServlet(RestServlet): + """ + API for bulk deletion of devices. Accepts a JSON object with a devices + key which lists the device_ids to delete. 
+ """ + + PATTERNS = ( + re.compile("^/_synapse/admin/v2/users/(?P[^/]*)/delete_devices$"), + ) + + def __init__(self, hs): + self.hs = hs + self.auth = hs.get_auth() + self.device_handler = hs.get_device_handler() + self.store = hs.get_datastore() + + async def on_POST(self, request, user_id): + await assert_requester_is_admin(self.auth, request) + + target_user = UserID.from_string(user_id) + if not self.hs.is_mine(target_user): + raise SynapseError(400, "Can only lookup local users") + + u = await self.store.get_user_by_id(target_user.to_string()) + if u is None: + raise NotFoundError("Unknown user") + + body = parse_json_object_from_request(request, allow_empty_body=False) + assert_params_in_dict(body, ["devices"]) + + await self.device_handler.delete_devices( + target_user.to_string(), body["devices"] + ) + return 200, {} diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py new file mode 100644 index 0000000000..faa7f381a9 --- /dev/null +++ b/tests/rest/admin/test_device.py @@ -0,0 +1,541 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Dirk Klimpel +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import urllib.parse + +import synapse.rest.admin +from synapse.api.errors import Codes +from synapse.rest.client.v1 import login + +from tests import unittest + + +class DeviceRestTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.handler = hs.get_device_handler() + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.other_user_token = self.login("user", "pass") + res = self.get_success(self.handler.get_devices_by_user(self.other_user)) + self.other_user_device_id = res[0]["device_id"] + + self.url = "/_synapse/admin/v2/users/%s/devices/%s" % ( + urllib.parse.quote(self.other_user), + self.other_user_device_id, + ) + + def test_no_auth(self): + """ + Try to get a device of an user without authentication. + """ + request, channel = self.make_request("GET", self.url, b"{}") + self.render(request) + + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + request, channel = self.make_request("PUT", self.url, b"{}") + self.render(request) + + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + request, channel = self.make_request("DELETE", self.url, b"{}") + self.render(request) + + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_requester_is_no_admin(self): + """ + If the user is not a server admin, an error is returned. 
+ """ + request, channel = self.make_request( + "GET", self.url, access_token=self.other_user_token, + ) + self.render(request) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + request, channel = self.make_request( + "PUT", self.url, access_token=self.other_user_token, + ) + self.render(request) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + request, channel = self.make_request( + "DELETE", self.url, access_token=self.other_user_token, + ) + self.render(request) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_user_does_not_exist(self): + """ + Tests that a lookup for a user that does not exist returns a 404 + """ + url = ( + "/_synapse/admin/v2/users/@unknown_person:test/devices/%s" + % self.other_user_device_id + ) + + request, channel = self.make_request( + "GET", url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + request, channel = self.make_request( + "PUT", url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + request, channel = self.make_request( + "DELETE", url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + def test_user_is_not_local(self): + """ + Tests that a lookup for a user that is not a local returns a 400 + """ + url = ( + "/_synapse/admin/v2/users/@unknown_person:unknown_domain/devices/%s" + % self.other_user_device_id + ) + + request, channel = self.make_request( + "GET", url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual("Can only lookup local users", channel.json_body["error"]) + + request, channel = self.make_request( + "PUT", url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual("Can only lookup local users", channel.json_body["error"]) + + request, channel = self.make_request( + "DELETE", url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual("Can only lookup local users", channel.json_body["error"]) + + def test_unknown_device(self): + """ + Tests that a lookup for a device that does not exist returns either 404 or 200. 
+ """ + url = "/_synapse/admin/v2/users/%s/devices/unknown_device" % urllib.parse.quote( + self.other_user + ) + + request, channel = self.make_request( + "GET", url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + request, channel = self.make_request( + "PUT", url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, channel.code, msg=channel.json_body) + + request, channel = self.make_request( + "DELETE", url, access_token=self.admin_user_tok, + ) + self.render(request) + + # Delete unknown device returns status 200 + self.assertEqual(200, channel.code, msg=channel.json_body) + + def test_update_device_too_long_display_name(self): + """ + Update a device with a display name that is invalid (too long). + """ + # Set iniital display name. + update = {"display_name": "new display"} + self.get_success( + self.handler.update_device( + self.other_user, self.other_user_device_id, update + ) + ) + + # Request to update a device display name with a new value that is longer than allowed. + update = { + "display_name": "a" + * (synapse.handlers.device.MAX_DEVICE_DISPLAY_NAME_LEN + 1) + } + + body = json.dumps(update) + request, channel = self.make_request( + "PUT", + self.url, + access_token=self.admin_user_tok, + content=body.encode(encoding="utf_8"), + ) + self.render(request) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + + # Ensure the display name was not updated. + request, channel = self.make_request( + "GET", self.url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("new display", channel.json_body["display_name"]) + + def test_update_no_display_name(self): + """ + Tests that a update for a device without JSON returns a 200 + """ + # Set iniital display name. + update = {"display_name": "new display"} + self.get_success( + self.handler.update_device( + self.other_user, self.other_user_device_id, update + ) + ) + + request, channel = self.make_request( + "PUT", self.url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, channel.code, msg=channel.json_body) + + # Ensure the display name was not updated. 
+ request, channel = self.make_request( + "GET", self.url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("new display", channel.json_body["display_name"]) + + def test_update_display_name(self): + """ + Tests a normal successful update of display name + """ + # Set new display_name + body = json.dumps({"display_name": "new displayname"}) + request, channel = self.make_request( + "PUT", + self.url, + access_token=self.admin_user_tok, + content=body.encode(encoding="utf_8"), + ) + self.render(request) + + self.assertEqual(200, channel.code, msg=channel.json_body) + + # Check new display_name + request, channel = self.make_request( + "GET", self.url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("new displayname", channel.json_body["display_name"]) + + def test_get_device(self): + """ + Tests that a normal lookup for a device is successfully + """ + request, channel = self.make_request( + "GET", self.url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(self.other_user, channel.json_body["user_id"]) + # Check that all fields are available + self.assertIn("user_id", channel.json_body) + self.assertIn("device_id", channel.json_body) + self.assertIn("display_name", channel.json_body) + self.assertIn("last_seen_ip", channel.json_body) + self.assertIn("last_seen_ts", channel.json_body) + + def test_delete_device(self): + """ + Tests that a remove of a device is successfully + """ + # Count number of devies of an user. + res = self.get_success(self.handler.get_devices_by_user(self.other_user)) + number_devices = len(res) + self.assertEqual(1, number_devices) + + # Delete device + request, channel = self.make_request( + "DELETE", self.url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, channel.code, msg=channel.json_body) + + # Ensure that the number of devices is decreased + res = self.get_success(self.handler.get_devices_by_user(self.other_user)) + self.assertEqual(number_devices - 1, len(res)) + + +class DevicesRestTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + + self.url = "/_synapse/admin/v2/users/%s/devices" % urllib.parse.quote( + self.other_user + ) + + def test_no_auth(self): + """ + Try to list devices of an user without authentication. + """ + request, channel = self.make_request("GET", self.url, b"{}") + self.render(request) + + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_requester_is_no_admin(self): + """ + If the user is not a server admin, an error is returned. 
+ """ + other_user_token = self.login("user", "pass") + + request, channel = self.make_request( + "GET", self.url, access_token=other_user_token, + ) + self.render(request) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_user_does_not_exist(self): + """ + Tests that a lookup for a user that does not exist returns a 404 + """ + url = "/_synapse/admin/v2/users/@unknown_person:test/devices" + request, channel = self.make_request( + "GET", url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + def test_user_is_not_local(self): + """ + Tests that a lookup for a user that is not a local returns a 400 + """ + url = "/_synapse/admin/v2/users/@unknown_person:unknown_domain/devices" + + request, channel = self.make_request( + "GET", url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual("Can only lookup local users", channel.json_body["error"]) + + def test_get_devices(self): + """ + Tests that a normal lookup for devices is successfully + """ + # Create devices + number_devices = 5 + for n in range(number_devices): + self.login("user", "pass") + + # Get devices + request, channel = self.make_request( + "GET", self.url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(number_devices, len(channel.json_body["devices"])) + self.assertEqual(self.other_user, channel.json_body["devices"][0]["user_id"]) + # Check that all fields are available + for d in channel.json_body["devices"]: + self.assertIn("user_id", d) + self.assertIn("device_id", d) + self.assertIn("display_name", d) + self.assertIn("last_seen_ip", d) + self.assertIn("last_seen_ts", d) + + +class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.handler = hs.get_device_handler() + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + + self.url = "/_synapse/admin/v2/users/%s/delete_devices" % urllib.parse.quote( + self.other_user + ) + + def test_no_auth(self): + """ + Try to delete devices of an user without authentication. + """ + request, channel = self.make_request("POST", self.url, b"{}") + self.render(request) + + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_requester_is_no_admin(self): + """ + If the user is not a server admin, an error is returned. 
+ """ + other_user_token = self.login("user", "pass") + + request, channel = self.make_request( + "POST", self.url, access_token=other_user_token, + ) + self.render(request) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_user_does_not_exist(self): + """ + Tests that a lookup for a user that does not exist returns a 404 + """ + url = "/_synapse/admin/v2/users/@unknown_person:test/delete_devices" + request, channel = self.make_request( + "POST", url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + def test_user_is_not_local(self): + """ + Tests that a lookup for a user that is not a local returns a 400 + """ + url = "/_synapse/admin/v2/users/@unknown_person:unknown_domain/delete_devices" + + request, channel = self.make_request( + "POST", url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual("Can only lookup local users", channel.json_body["error"]) + + def test_unknown_devices(self): + """ + Tests that a remove of a device that does not exist returns 200. + """ + body = json.dumps({"devices": ["unknown_device1", "unknown_device2"]}) + request, channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content=body.encode(encoding="utf_8"), + ) + self.render(request) + + # Delete unknown devices returns status 200 + self.assertEqual(200, channel.code, msg=channel.json_body) + + def test_delete_devices(self): + """ + Tests that a remove of devices is successfully + """ + + # Create devices + number_devices = 5 + for n in range(number_devices): + self.login("user", "pass") + + # Get devices + res = self.get_success(self.handler.get_devices_by_user(self.other_user)) + self.assertEqual(number_devices, len(res)) + + # Create list of device IDs + device_ids = [] + for d in res: + device_ids.append(str(d["device_id"])) + + # Delete devices + body = json.dumps({"devices": device_ids}) + request, channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content=body.encode(encoding="utf_8"), + ) + self.render(request) + + self.assertEqual(200, channel.code, msg=channel.json_body) + + res = self.get_success(self.handler.get_devices_by_user(self.other_user)) + self.assertEqual(0, len(res)) -- cgit 1.5.1 From 908f9e2d24617a62f5e2fe52aa68941c64b0fde3 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Fri, 5 Jun 2020 14:08:49 +0200 Subject: Allow new users to be registered via the admin API even if the monthly active user limit has been reached (#7263) --- changelog.d/7263.bugfix | 1 + synapse/handlers/register.py | 7 +- synapse/rest/admin/users.py | 2 + tests/rest/admin/test_user.py | 178 ++++++++++++++++++++++++++++++++++++++---- 4 files changed, 172 insertions(+), 16 deletions(-) create mode 100644 changelog.d/7263.bugfix (limited to 'tests') diff --git a/changelog.d/7263.bugfix b/changelog.d/7263.bugfix new file mode 100644 index 0000000000..0b4739261c --- /dev/null +++ b/changelog.d/7263.bugfix @@ -0,0 +1 @@ +Allow new users to be registered via the admin API even if the monthly active user limit has been reached. Contributed by @dkimpel. 
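
The fix works by threading a by_admin flag from the admin servlets into the registration handler so that the monthly-active-user (MAU) block in check_auth_blocking is skipped for admin-initiated registrations, as the register.py hunk below shows. For context, the endpoint the new tests exercise can be driven from outside the homeserver roughly as follows; this is only a hedged sketch, and the homeserver URL, admin token, and user details are placeholders rather than anything taken from the patch.

    import requests

    HOMESERVER = "https://synapse.example.com"  # placeholder base URL
    ADMIN_TOKEN = "<admin access token>"        # placeholder token for a server admin

    def create_or_update_user(user_id: str, password: str) -> dict:
        """PUT /_synapse/admin/v2/users/<user_id>; creates the account if it is missing."""
        resp = requests.put(
            "%s/_synapse/admin/v2/users/%s" % (HOMESERVER, user_id),
            headers={"Authorization": "Bearer %s" % ADMIN_TOKEN},
            json={"password": password, "admin": False},
        )
        # With this change applied, creation is expected to succeed (HTTP 201)
        # even when the monthly active user limit has already been reached.
        resp.raise_for_status()
        return resp.json()

    if __name__ == "__main__":
        print(create_or_update_user("@bob:example.com", "abc123"))

The new test_create_user_mau_limit_reached_active_admin and test_create_user_mau_limit_reached_passive_admin tests further down send the same request shape in-process via make_request.
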
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index cd746be7c8..ffda09226c 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -150,6 +150,7 @@ class RegistrationHandler(BaseHandler): default_display_name=None, address=None, bind_emails=[], + by_admin=False, ): """Registers a new client on the server. @@ -165,6 +166,8 @@ class RegistrationHandler(BaseHandler): will be set to this. Defaults to 'localpart'. address (str|None): the IP address used to perform the registration. bind_emails (List[str]): list of emails to bind to this account. + by_admin (bool): True if this registration is being made via the + admin api, otherwise False. Returns: Deferred[str]: user_id Raises: @@ -172,7 +175,9 @@ class RegistrationHandler(BaseHandler): """ yield self.check_registration_ratelimit(address) - yield self.auth.check_auth_blocking(threepid=threepid) + # do not check_auth_blocking if the call is coming through the Admin API + if not by_admin: + yield self.auth.check_auth_blocking(threepid=threepid) if localpart is not None: yield self.check_username(localpart, guest_access_token=guest_access_token) diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 82251dbe5f..fefc8f71fa 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -270,6 +270,7 @@ class UserRestServletV2(RestServlet): admin=bool(admin), default_display_name=displayname, user_type=user_type, + by_admin=True, ) if "threepids" in body: @@ -432,6 +433,7 @@ class UserRegisterServlet(RestServlet): password_hash=password_hash, admin=bool(admin), user_type=user_type, + by_admin=True, ) result = await register._create_registration_details(user_id, body) diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index e29cc24a8a..cca5f548e6 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -22,9 +22,12 @@ from mock import Mock import synapse.rest.admin from synapse.api.constants import UserTypes +from synapse.api.errors import HttpResponseException, ResourceLimitError from synapse.rest.client.v1 import login +from synapse.rest.client.v2_alpha import sync from tests import unittest +from tests.unittest import override_config class UserRegisterTestCase(unittest.HomeserverTestCase): @@ -320,6 +323,52 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("Invalid user type", channel.json_body["error"]) + @override_config( + {"limit_usage_by_mau": True, "max_mau_value": 2, "mau_trial_days": 0} + ) + def test_register_mau_limit_reached(self): + """ + Check we can register a user via the shared secret registration API + even if the MAU limit is reached. 
+ """ + handler = self.hs.get_registration_handler() + store = self.hs.get_datastore() + + # Set monthly active users to the limit + store.get_monthly_active_count = Mock(return_value=self.hs.config.max_mau_value) + # Check that the blocking of monthly active users is working as expected + # The registration of a new user fails due to the limit + self.get_failure( + handler.register_user(localpart="local_part"), ResourceLimitError + ) + + # Register new user with admin API + request, channel = self.make_request("GET", self.url) + self.render(request) + nonce = channel.json_body["nonce"] + + want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) + want_mac.update( + nonce.encode("ascii") + b"\x00bob\x00abc123\x00admin\x00support" + ) + want_mac = want_mac.hexdigest() + + body = json.dumps( + { + "nonce": nonce, + "username": "bob", + "password": "abc123", + "admin": True, + "user_type": UserTypes.SUPPORT, + "mac": want_mac, + } + ) + request, channel = self.make_request("POST", self.url, body.encode("utf8")) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@bob:test", channel.json_body["user_id"]) + class UsersListTestCase(unittest.HomeserverTestCase): @@ -368,6 +417,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, + sync.register_servlets, ] def prepare(self, reactor, clock, hs): @@ -386,7 +436,6 @@ class UserRestTestCase(unittest.HomeserverTestCase): """ If the user is not a server admin, an error is returned. """ - self.hs.config.registration_shared_secret = None url = "/_synapse/admin/v2/users/@bob:test" request, channel = self.make_request( @@ -409,7 +458,6 @@ class UserRestTestCase(unittest.HomeserverTestCase): """ Tests that a lookup for a user that does not exist returns a 404 """ - self.hs.config.registration_shared_secret = None request, channel = self.make_request( "GET", @@ -425,7 +473,6 @@ class UserRestTestCase(unittest.HomeserverTestCase): """ Check that a new admin user is created successfully. """ - self.hs.config.registration_shared_secret = None url = "/_synapse/admin/v2/users/@bob:test" # Create user (server admin) @@ -473,7 +520,6 @@ class UserRestTestCase(unittest.HomeserverTestCase): """ Check that a new regular user is created successfully. """ - self.hs.config.registration_shared_secret = None url = "/_synapse/admin/v2/users/@bob:test" # Create user @@ -516,14 +562,114 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(False, channel.json_body["is_guest"]) self.assertEqual(False, channel.json_body["deactivated"]) + @override_config( + {"limit_usage_by_mau": True, "max_mau_value": 2, "mau_trial_days": 0} + ) + def test_create_user_mau_limit_reached_active_admin(self): + """ + Check that an admin can register a new user via the admin API + even if the MAU limit is reached. + Admin user was active before creating user. 
+ """ + + handler = self.hs.get_registration_handler() + + # Sync to set admin user to active + # before limit of monthly active users is reached + request, channel = self.make_request( + "GET", "/sync", access_token=self.admin_user_tok + ) + self.render(request) + + if channel.code != 200: + raise HttpResponseException( + channel.code, channel.result["reason"], channel.result["body"] + ) + + # Set monthly active users to the limit + self.store.get_monthly_active_count = Mock( + return_value=self.hs.config.max_mau_value + ) + # Check that the blocking of monthly active users is working as expected + # The registration of a new user fails due to the limit + self.get_failure( + handler.register_user(localpart="local_part"), ResourceLimitError + ) + + # Register new user with admin API + url = "/_synapse/admin/v2/users/@bob:test" + + # Create user + body = json.dumps({"password": "abc123", "admin": False}) + + request, channel = self.make_request( + "PUT", + url, + access_token=self.admin_user_tok, + content=body.encode(encoding="utf_8"), + ) + self.render(request) + + self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@bob:test", channel.json_body["name"]) + self.assertEqual(False, channel.json_body["admin"]) + + @override_config( + {"limit_usage_by_mau": True, "max_mau_value": 2, "mau_trial_days": 0} + ) + def test_create_user_mau_limit_reached_passive_admin(self): + """ + Check that an admin can register a new user via the admin API + even if the MAU limit is reached. + Admin user was not active before creating user. + """ + + handler = self.hs.get_registration_handler() + + # Set monthly active users to the limit + self.store.get_monthly_active_count = Mock( + return_value=self.hs.config.max_mau_value + ) + # Check that the blocking of monthly active users is working as expected + # The registration of a new user fails due to the limit + self.get_failure( + handler.register_user(localpart="local_part"), ResourceLimitError + ) + + # Register new user with admin API + url = "/_synapse/admin/v2/users/@bob:test" + + # Create user + body = json.dumps({"password": "abc123", "admin": False}) + + request, channel = self.make_request( + "PUT", + url, + access_token=self.admin_user_tok, + content=body.encode(encoding="utf_8"), + ) + self.render(request) + + # Admin user is not blocked by mau anymore + self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@bob:test", channel.json_body["name"]) + self.assertEqual(False, channel.json_body["admin"]) + + @override_config( + { + "email": { + "enable_notifs": True, + "notif_for_new_users": True, + "notif_from": "test@example.com", + }, + "public_baseurl": "https://example.com", + } + ) def test_create_user_email_notif_for_new_users(self): """ Check that a new regular user is created successfully and got an email pusher. 
""" - self.hs.config.registration_shared_secret = None - self.hs.config.email_enable_notifs = True - self.hs.config.email_notif_for_new_users = True url = "/_synapse/admin/v2/users/@bob:test" # Create user @@ -554,14 +700,21 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(len(pushers), 1) self.assertEqual("@bob:test", pushers[0]["user_name"]) + @override_config( + { + "email": { + "enable_notifs": False, + "notif_for_new_users": False, + "notif_from": "test@example.com", + }, + "public_baseurl": "https://example.com", + } + ) def test_create_user_email_no_notif_for_new_users(self): """ Check that a new regular user is created successfully and got not an email pusher. """ - self.hs.config.registration_shared_secret = None - self.hs.config.email_enable_notifs = False - self.hs.config.email_notif_for_new_users = False url = "/_synapse/admin/v2/users/@bob:test" # Create user @@ -595,7 +748,6 @@ class UserRestTestCase(unittest.HomeserverTestCase): """ Test setting a new password for another user. """ - self.hs.config.registration_shared_secret = None # Change password body = json.dumps({"password": "hahaha"}) @@ -614,7 +766,6 @@ class UserRestTestCase(unittest.HomeserverTestCase): """ Test setting the displayname of another user. """ - self.hs.config.registration_shared_secret = None # Modify user body = json.dumps({"displayname": "foobar"}) @@ -645,7 +796,6 @@ class UserRestTestCase(unittest.HomeserverTestCase): """ Test setting threepid for an other user. """ - self.hs.config.registration_shared_secret = None # Delete old and add new threepid to user body = json.dumps( @@ -711,7 +861,6 @@ class UserRestTestCase(unittest.HomeserverTestCase): """ Test setting the admin flag on a user. """ - self.hs.config.registration_shared_secret = None # Set a user as an admin body = json.dumps({"admin": True}) @@ -743,7 +892,6 @@ class UserRestTestCase(unittest.HomeserverTestCase): Ensure an account can't accidentally be deactivated by using a str value for the deactivated body parameter """ - self.hs.config.registration_shared_secret = None url = "/_synapse/admin/v2/users/@bob:test" # Create user -- cgit 1.5.1 From 09099313e6d527938013bb46640efc3768960d21 Mon Sep 17 00:00:00 2001 From: Travis Ralston Date: Fri, 5 Jun 2020 11:18:15 -0600 Subject: Add an option to disable autojoin for guest accounts (#6637) Fixes https://github.com/matrix-org/synapse/issues/3177 --- changelog.d/6637.feature | 1 + docs/sample_config.yaml | 7 +++++++ synapse/config/registration.py | 8 ++++++++ synapse/handlers/register.py | 8 +++++++- tests/handlers/test_register.py | 10 ++++++++++ 5 files changed, 33 insertions(+), 1 deletion(-) create mode 100644 changelog.d/6637.feature (limited to 'tests') diff --git a/changelog.d/6637.feature b/changelog.d/6637.feature new file mode 100644 index 0000000000..5228ebc1e5 --- /dev/null +++ b/changelog.d/6637.feature @@ -0,0 +1 @@ +Add an option to disable autojoining rooms for guest accounts. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index b06394a2bd..94e1ec698f 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1223,6 +1223,13 @@ account_threepid_delegates: # #autocreate_auto_join_rooms: true +# When auto_join_rooms is specified, setting this flag to false prevents +# guest accounts from being automatically joined to the rooms. +# +# Defaults to true. 
+# +#auto_join_rooms_for_guests: false + ## Metrics ### diff --git a/synapse/config/registration.py b/synapse/config/registration.py index a9aa8c3737..fecced2d57 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -128,6 +128,7 @@ class RegistrationConfig(Config): if not RoomAlias.is_valid(room_alias): raise ConfigError("Invalid auto_join_rooms entry %s" % (room_alias,)) self.autocreate_auto_join_rooms = config.get("autocreate_auto_join_rooms", True) + self.auto_join_rooms_for_guests = config.get("auto_join_rooms_for_guests", True) self.enable_set_displayname = config.get("enable_set_displayname", True) self.enable_set_avatar_url = config.get("enable_set_avatar_url", True) @@ -368,6 +369,13 @@ class RegistrationConfig(Config): # users cannot be auto-joined since they do not exist. # #autocreate_auto_join_rooms: true + + # When auto_join_rooms is specified, setting this flag to false prevents + # guest accounts from being automatically joined to the rooms. + # + # Defaults to true. + # + #auto_join_rooms_for_guests: false """ % locals() ) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index ffda09226c..5c7113a3bb 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -244,7 +244,13 @@ class RegistrationHandler(BaseHandler): fail_count += 1 if not self.hs.config.user_consent_at_registration: - yield defer.ensureDeferred(self._auto_join_rooms(user_id)) + if not self.hs.config.auto_join_rooms_for_guests and make_guest: + logger.info( + "Skipping auto-join for %s because auto-join for guests is disabled", + user_id, + ) + else: + yield defer.ensureDeferred(self._auto_join_rooms(user_id)) else: logger.info( "Skipping auto-join for %s because consent is required at registration", diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 1b7935cef2..ca32f993a3 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -135,6 +135,16 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.handler.register_user(localpart="local_part"), ResourceLimitError ) + def test_auto_join_rooms_for_guests(self): + room_alias_str = "#room:test" + self.hs.config.auto_join_rooms = [room_alias_str] + self.hs.config.auto_join_rooms_for_guests = False + user_id = self.get_success( + self.handler.register_user(localpart="jeff", make_guest=True), + ) + rooms = self.get_success(self.store.get_rooms_for_user(user_id)) + self.assertEqual(len(rooms), 0) + def test_auto_create_auto_join_rooms(self): room_alias_str = "#room:test" self.hs.config.auto_join_rooms = [room_alias_str] -- cgit 1.5.1 From 737b4a936e35daaa839e7296d888041941546b47 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 5 Jun 2020 14:42:55 -0400 Subject: Convert user directory handler and related classes to async/await. (#7640) --- changelog.d/7640.misc | 1 + synapse/handlers/register.py | 6 +- synapse/handlers/state_deltas.py | 9 +-- synapse/handlers/stats.py | 47 ++++++-------- synapse/handlers/user_directory.py | 118 +++++++++++++--------------------- tests/handlers/test_user_directory.py | 8 +-- 6 files changed, 78 insertions(+), 111 deletions(-) create mode 100644 changelog.d/7640.misc (limited to 'tests') diff --git a/changelog.d/7640.misc b/changelog.d/7640.misc new file mode 100644 index 0000000000..55edc1c781 --- /dev/null +++ b/changelog.d/7640.misc @@ -0,0 +1 @@ +Convert user directory, state deltas, and stats handlers to async/await. 
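
The hunks that follow replace @defer.inlineCallbacks generators with native coroutines: yield becomes await, the twisted.internet.defer import disappears where it is no longer needed, and callers that still run under inlineCallbacks (such as register.py below) bridge into the new coroutines with defer.ensureDeferred. A minimal, hedged sketch of that pattern follows; FakeStore and fetch_name are made up purely for illustration, only the Twisted APIs are real.

    from twisted.internet import defer

    class FakeStore:
        def get_profile(self, user_id):
            # Returns an already-fired Deferred, so no reactor is needed to run this sketch.
            return defer.succeed({"display_name": "Alice"})

    class OldStyleHandler:
        @defer.inlineCallbacks
        def fetch_name(self, store):
            # Before: a generator driven by inlineCallbacks, yielding Deferreds.
            row = yield store.get_profile("@alice:example.com")
            return row["display_name"]

    class NewStyleHandler:
        async def fetch_name(self, store):
            # After: a native coroutine that awaits the Deferred directly.
            row = await store.get_profile("@alice:example.com")
            return row["display_name"]

    @defer.inlineCallbacks
    def legacy_caller(store):
        # Code still written with inlineCallbacks wraps the coroutine in a
        # Deferred first; this is the defer.ensureDeferred(...) bridge that
        # register.py uses in the hunk below.
        name = yield defer.ensureDeferred(NewStyleHandler().fetch_name(store))
        return name

    if __name__ == "__main__":
        legacy_caller(FakeStore()).addCallback(print)  # prints "Alice" synchronously
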
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 5c7113a3bb..af812dbda9 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -207,8 +207,10 @@ class RegistrationHandler(BaseHandler): if self.hs.config.user_directory_search_all_users: profile = yield self.store.get_profileinfo(localpart) - yield self.user_directory_handler.handle_local_profile_change( - user_id, profile + yield defer.ensureDeferred( + self.user_directory_handler.handle_local_profile_change( + user_id, profile + ) ) else: diff --git a/synapse/handlers/state_deltas.py b/synapse/handlers/state_deltas.py index f065970c40..8590c1eff4 100644 --- a/synapse/handlers/state_deltas.py +++ b/synapse/handlers/state_deltas.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - logger = logging.getLogger(__name__) @@ -24,8 +22,7 @@ class StateDeltasHandler(object): def __init__(self, hs): self.store = hs.get_datastore() - @defer.inlineCallbacks - def _get_key_change(self, prev_event_id, event_id, key_name, public_value): + async def _get_key_change(self, prev_event_id, event_id, key_name, public_value): """Given two events check if the `key_name` field in content changed from not matching `public_value` to doing so. @@ -41,10 +38,10 @@ class StateDeltasHandler(object): prev_event = None event = None if prev_event_id: - prev_event = yield self.store.get_event(prev_event_id, allow_none=True) + prev_event = await self.store.get_event(prev_event_id, allow_none=True) if event_id: - event = yield self.store.get_event(event_id, allow_none=True) + event = await self.store.get_event(event_id, allow_none=True) if not event and not prev_event: logger.debug("Neither event exists: %r %r", prev_event_id, event_id) diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py index d93a276693..149f861239 100644 --- a/synapse/handlers/stats.py +++ b/synapse/handlers/stats.py @@ -16,17 +16,14 @@ import logging from collections import Counter -from twisted.internet import defer - from synapse.api.constants import EventTypes, Membership -from synapse.handlers.state_deltas import StateDeltasHandler from synapse.metrics import event_processing_positions from synapse.metrics.background_process_metrics import run_as_background_process logger = logging.getLogger(__name__) -class StatsHandler(StateDeltasHandler): +class StatsHandler: """Handles keeping the *_stats tables updated with a simple time-series of information about the users, rooms and media on the server, such that admins have some idea of who is consuming their resources. 
@@ -35,7 +32,6 @@ class StatsHandler(StateDeltasHandler): """ def __init__(self, hs): - super(StatsHandler, self).__init__(hs) self.hs = hs self.store = hs.get_datastore() self.state = hs.get_state_handler() @@ -68,20 +64,18 @@ class StatsHandler(StateDeltasHandler): self._is_processing = True - @defer.inlineCallbacks - def process(): + async def process(): try: - yield self._unsafe_process() + await self._unsafe_process() finally: self._is_processing = False run_as_background_process("stats.notify_new_event", process) - @defer.inlineCallbacks - def _unsafe_process(self): + async def _unsafe_process(self): # If self.pos is None then means we haven't fetched it from DB if self.pos is None: - self.pos = yield self.store.get_stats_positions() + self.pos = await self.store.get_stats_positions() # Loop round handling deltas until we're up to date @@ -96,13 +90,13 @@ class StatsHandler(StateDeltasHandler): logger.debug( "Processing room stats %s->%s", self.pos, room_max_stream_ordering ) - max_pos, deltas = yield self.store.get_current_state_deltas( + max_pos, deltas = await self.store.get_current_state_deltas( self.pos, room_max_stream_ordering ) if deltas: logger.debug("Handling %d state deltas", len(deltas)) - room_deltas, user_deltas = yield self._handle_deltas(deltas) + room_deltas, user_deltas = await self._handle_deltas(deltas) else: room_deltas = {} user_deltas = {} @@ -111,7 +105,7 @@ class StatsHandler(StateDeltasHandler): ( room_count, user_count, - ) = yield self.store.get_changes_room_total_events_and_bytes( + ) = await self.store.get_changes_room_total_events_and_bytes( self.pos, max_pos ) @@ -125,7 +119,7 @@ class StatsHandler(StateDeltasHandler): logger.debug("user_deltas: %s", user_deltas) # Always call this so that we update the stats position. - yield self.store.bulk_update_stats_delta( + await self.store.bulk_update_stats_delta( self.clock.time_msec(), updates={"room": room_deltas, "user": user_deltas}, stream_id=max_pos, @@ -137,13 +131,12 @@ class StatsHandler(StateDeltasHandler): self.pos = max_pos - @defer.inlineCallbacks - def _handle_deltas(self, deltas): + async def _handle_deltas(self, deltas): """Called with the state deltas to process Returns: - Deferred[tuple[dict[str, Counter], dict[str, counter]]] - Resovles to two dicts, the room deltas and the user deltas, + tuple[dict[str, Counter], dict[str, counter]] + Two dicts: the room deltas and the user deltas, mapping from room/user ID to changes in the various fields. """ @@ -162,7 +155,7 @@ class StatsHandler(StateDeltasHandler): logger.debug("Handling: %r, %r %r, %s", room_id, typ, state_key, event_id) - token = yield self.store.get_earliest_token_for_stats("room", room_id) + token = await self.store.get_earliest_token_for_stats("room", room_id) # If the earliest token to begin from is larger than our current # stream ID, skip processing this delta. @@ -184,7 +177,7 @@ class StatsHandler(StateDeltasHandler): sender = None if event_id is not None: - event = yield self.store.get_event(event_id, allow_none=True) + event = await self.store.get_event(event_id, allow_none=True) if event: event_content = event.content or {} sender = event.sender @@ -200,16 +193,16 @@ class StatsHandler(StateDeltasHandler): room_stats_delta["current_state_events"] += 1 if typ == EventTypes.Member: - # we could use _get_key_change here but it's a bit inefficient - # given we're not testing for a specific result; might as well - # just grab the prev_membership and membership strings and - # compare them. 
+ # we could use StateDeltasHandler._get_key_change here but it's + # a bit inefficient given we're not testing for a specific + # result; might as well just grab the prev_membership and + # membership strings and compare them. # We take None rather than leave as a previous membership # in the absence of a previous event because we do not want to # reduce the leave count when a new-to-the-room user joins. prev_membership = None if prev_event_id is not None: - prev_event = yield self.store.get_event( + prev_event = await self.store.get_event( prev_event_id, allow_none=True ) if prev_event: @@ -301,6 +294,6 @@ class StatsHandler(StateDeltasHandler): for room_id, state in room_to_state_updates.items(): logger.debug("Updating room_stats_state for %s: %s", room_id, state) - yield self.store.update_room_state(room_id, state) + await self.store.update_room_state(room_id, state) return room_to_stats_deltas, user_to_stats_deltas diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index 722760c59d..12423b909a 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -17,14 +17,11 @@ import logging from six import iteritems, iterkeys -from twisted.internet import defer - import synapse.metrics from synapse.api.constants import EventTypes, JoinRules, Membership from synapse.handlers.state_deltas import StateDeltasHandler from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.roommember import ProfileInfo -from synapse.types import get_localpart_from_id from synapse.util.metrics import Measure logger = logging.getLogger(__name__) @@ -103,43 +100,39 @@ class UserDirectoryHandler(StateDeltasHandler): if self._is_processing: return - @defer.inlineCallbacks - def process(): + async def process(): try: - yield self._unsafe_process() + await self._unsafe_process() finally: self._is_processing = False self._is_processing = True run_as_background_process("user_directory.notify_new_event", process) - @defer.inlineCallbacks - def handle_local_profile_change(self, user_id, profile): + async def handle_local_profile_change(self, user_id, profile): """Called to update index of our local user profiles when they change irrespective of any rooms the user may be in. """ # FIXME(#3714): We should probably do this in the same worker as all # the other changes. - is_support = yield self.store.is_support_user(user_id) + is_support = await self.store.is_support_user(user_id) # Support users are for diagnostics and should not appear in the user directory. if not is_support: - yield self.store.update_profile_in_user_dir( + await self.store.update_profile_in_user_dir( user_id, profile.display_name, profile.avatar_url ) - @defer.inlineCallbacks - def handle_user_deactivated(self, user_id): + async def handle_user_deactivated(self, user_id): """Called when a user ID is deactivated """ # FIXME(#3714): We should probably do this in the same worker as all # the other changes. 
- yield self.store.remove_from_user_dir(user_id) + await self.store.remove_from_user_dir(user_id) - @defer.inlineCallbacks - def _unsafe_process(self): + async def _unsafe_process(self): # If self.pos is None then means we haven't fetched it from DB if self.pos is None: - self.pos = yield self.store.get_user_directory_stream_pos() + self.pos = await self.store.get_user_directory_stream_pos() # If still None then the initial background update hasn't happened yet if self.pos is None: @@ -155,12 +148,12 @@ class UserDirectoryHandler(StateDeltasHandler): logger.debug( "Processing user stats %s->%s", self.pos, room_max_stream_ordering ) - max_pos, deltas = yield self.store.get_current_state_deltas( + max_pos, deltas = await self.store.get_current_state_deltas( self.pos, room_max_stream_ordering ) logger.debug("Handling %d state deltas", len(deltas)) - yield self._handle_deltas(deltas) + await self._handle_deltas(deltas) self.pos = max_pos @@ -169,10 +162,9 @@ class UserDirectoryHandler(StateDeltasHandler): max_pos ) - yield self.store.update_user_directory_stream_pos(max_pos) + await self.store.update_user_directory_stream_pos(max_pos) - @defer.inlineCallbacks - def _handle_deltas(self, deltas): + async def _handle_deltas(self, deltas): """Called with the state deltas to process """ for delta in deltas: @@ -187,11 +179,11 @@ class UserDirectoryHandler(StateDeltasHandler): # For join rule and visibility changes we need to check if the room # may have become public or not and add/remove the users in said room if typ in (EventTypes.RoomHistoryVisibility, EventTypes.JoinRules): - yield self._handle_room_publicity_change( + await self._handle_room_publicity_change( room_id, prev_event_id, event_id, typ ) elif typ == EventTypes.Member: - change = yield self._get_key_change( + change = await self._get_key_change( prev_event_id, event_id, key_name="membership", @@ -201,7 +193,7 @@ class UserDirectoryHandler(StateDeltasHandler): if change is False: # Need to check if the server left the room entirely, if so # we might need to remove all the users in that room - is_in_room = yield self.store.is_host_joined( + is_in_room = await self.store.is_host_joined( room_id, self.server_name ) if not is_in_room: @@ -209,40 +201,41 @@ class UserDirectoryHandler(StateDeltasHandler): # Fetch all the users that we marked as being in user # directory due to being in the room and then check if # need to remove those users or not - user_ids = yield self.store.get_users_in_dir_due_to_room( + user_ids = await self.store.get_users_in_dir_due_to_room( room_id ) for user_id in user_ids: - yield self._handle_remove_user(room_id, user_id) + await self._handle_remove_user(room_id, user_id) return else: logger.debug("Server is still in room: %r", room_id) - is_support = yield self.store.is_support_user(state_key) + is_support = await self.store.is_support_user(state_key) if not is_support: if change is None: # Handle any profile changes - yield self._handle_profile_change( + await self._handle_profile_change( state_key, room_id, prev_event_id, event_id ) continue if change: # The user joined - event = yield self.store.get_event(event_id, allow_none=True) + event = await self.store.get_event(event_id, allow_none=True) profile = ProfileInfo( avatar_url=event.content.get("avatar_url"), display_name=event.content.get("displayname"), ) - yield self._handle_new_user(room_id, state_key, profile) + await self._handle_new_user(room_id, state_key, profile) else: # The user left - yield self._handle_remove_user(room_id, state_key) + await 
self._handle_remove_user(room_id, state_key) else: logger.debug("Ignoring irrelevant type: %r", typ) - @defer.inlineCallbacks - def _handle_room_publicity_change(self, room_id, prev_event_id, event_id, typ): + async def _handle_room_publicity_change( + self, room_id, prev_event_id, event_id, typ + ): """Handle a room having potentially changed from/to world_readable/publically joinable. @@ -255,14 +248,14 @@ class UserDirectoryHandler(StateDeltasHandler): logger.debug("Handling change for %s: %s", typ, room_id) if typ == EventTypes.RoomHistoryVisibility: - change = yield self._get_key_change( + change = await self._get_key_change( prev_event_id, event_id, key_name="history_visibility", public_value="world_readable", ) elif typ == EventTypes.JoinRules: - change = yield self._get_key_change( + change = await self._get_key_change( prev_event_id, event_id, key_name="join_rule", @@ -278,7 +271,7 @@ class UserDirectoryHandler(StateDeltasHandler): # There's been a change to or from being world readable. - is_public = yield self.store.is_room_world_readable_or_publicly_joinable( + is_public = await self.store.is_room_world_readable_or_publicly_joinable( room_id ) @@ -293,11 +286,11 @@ class UserDirectoryHandler(StateDeltasHandler): # ignore the change return - users_with_profile = yield self.state.get_current_users_in_room(room_id) + users_with_profile = await self.state.get_current_users_in_room(room_id) # Remove every user from the sharing tables for that room. for user_id in iterkeys(users_with_profile): - yield self.store.remove_user_who_share_room(user_id, room_id) + await self.store.remove_user_who_share_room(user_id, room_id) # Then, re-add them to the tables. # NOTE: this is not the most efficient method, as handle_new_user sets @@ -306,26 +299,9 @@ class UserDirectoryHandler(StateDeltasHandler): # being added multiple times. The batching upserts shouldn't make this # too bad, though. for user_id, profile in iteritems(users_with_profile): - yield self._handle_new_user(room_id, user_id, profile) - - @defer.inlineCallbacks - def _handle_local_user(self, user_id): - """Adds a new local roomless user into the user_directory_search table. - Used to populate up the user index when we have an - user_directory_search_all_users specified. - """ - logger.debug("Adding new local user to dir, %r", user_id) - - profile = yield self.store.get_profileinfo(get_localpart_from_id(user_id)) - - row = yield self.store.get_user_in_directory(user_id) - if not row: - yield self.store.update_profile_in_user_dir( - user_id, profile.display_name, profile.avatar_url - ) + await self._handle_new_user(room_id, user_id, profile) - @defer.inlineCallbacks - def _handle_new_user(self, room_id, user_id, profile): + async def _handle_new_user(self, room_id, user_id, profile): """Called when we might need to add user to directory Args: @@ -334,18 +310,18 @@ class UserDirectoryHandler(StateDeltasHandler): """ logger.debug("Adding new user to dir, %r", user_id) - yield self.store.update_profile_in_user_dir( + await self.store.update_profile_in_user_dir( user_id, profile.display_name, profile.avatar_url ) - is_public = yield self.store.is_room_world_readable_or_publicly_joinable( + is_public = await self.store.is_room_world_readable_or_publicly_joinable( room_id ) # Now we update users who share rooms with users. 
- users_with_profile = yield self.state.get_current_users_in_room(room_id) + users_with_profile = await self.state.get_current_users_in_room(room_id) if is_public: - yield self.store.add_users_in_public_rooms(room_id, (user_id,)) + await self.store.add_users_in_public_rooms(room_id, (user_id,)) else: to_insert = set() @@ -376,10 +352,9 @@ class UserDirectoryHandler(StateDeltasHandler): to_insert.add((other_user_id, user_id)) if to_insert: - yield self.store.add_users_who_share_private_room(room_id, to_insert) + await self.store.add_users_who_share_private_room(room_id, to_insert) - @defer.inlineCallbacks - def _handle_remove_user(self, room_id, user_id): + async def _handle_remove_user(self, room_id, user_id): """Called when we might need to remove user from directory Args: @@ -389,24 +364,23 @@ class UserDirectoryHandler(StateDeltasHandler): logger.debug("Removing user %r", user_id) # Remove user from sharing tables - yield self.store.remove_user_who_share_room(user_id, room_id) + await self.store.remove_user_who_share_room(user_id, room_id) # Are they still in any rooms? If not, remove them entirely. - rooms_user_is_in = yield self.store.get_user_dir_rooms_user_is_in(user_id) + rooms_user_is_in = await self.store.get_user_dir_rooms_user_is_in(user_id) if len(rooms_user_is_in) == 0: - yield self.store.remove_from_user_dir(user_id) + await self.store.remove_from_user_dir(user_id) - @defer.inlineCallbacks - def _handle_profile_change(self, user_id, room_id, prev_event_id, event_id): + async def _handle_profile_change(self, user_id, room_id, prev_event_id, event_id): """Check member event changes for any profile changes and update the database if there are. """ if not prev_event_id or not event_id: return - prev_event = yield self.store.get_event(prev_event_id, allow_none=True) - event = yield self.store.get_event(event_id, allow_none=True) + prev_event = await self.store.get_event(prev_event_id, allow_none=True) + event = await self.store.get_event(event_id, allow_none=True) if not prev_event or not event: return @@ -421,4 +395,4 @@ class UserDirectoryHandler(StateDeltasHandler): new_avatar = event.content.get("avatar_url") if prev_name != new_name or prev_avatar != new_avatar: - yield self.store.update_profile_in_user_dir(user_id, new_name, new_avatar) + await self.store.update_profile_in_user_dir(user_id, new_name, new_avatar) diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 572df8d80b..c15bce5bef 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -14,6 +14,8 @@ # limitations under the License. 
from mock import Mock +from twisted.internet import defer + import synapse.rest.admin from synapse.api.constants import UserTypes from synapse.rest.client.v1 import login, room @@ -75,18 +77,16 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): ) ) - self.store.remove_from_user_dir = Mock() - self.store.remove_from_user_in_public_room = Mock() + self.store.remove_from_user_dir = Mock(return_value=defer.succeed(None)) self.get_success(self.handler.handle_user_deactivated(s_user_id)) self.store.remove_from_user_dir.not_called() - self.store.remove_from_user_in_public_room.not_called() def test_handle_user_deactivated_regular_user(self): r_user_id = "@regular:test" self.get_success( self.store.register_user(user_id=r_user_id, password_hash=None) ) - self.store.remove_from_user_dir = Mock() + self.store.remove_from_user_dir = Mock(return_value=defer.succeed(None)) self.get_success(self.handler.handle_user_deactivated(r_user_id)) self.store.remove_from_user_dir.called_once_with(r_user_id) -- cgit 1.5.1
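
The last hunk above is the knock-on effect of the conversion in the tests: once handle_user_deactivated awaits the store call, a bare Mock() is no longer enough, because calling it hands the coroutine a plain Mock to await; giving the mock return_value=defer.succeed(None) keeps the await satisfied. Note also that not_called() and called_once_with() on a Mock are ordinary attribute look-ups rather than assertions; the checking forms are assert_not_called() and assert_called_once_with(). Below is a small, self-contained sketch of the mocking pattern, using the standard-library unittest.mock instead of the mock backport the tests import, with TinyHandler invented for illustration.

    from unittest.mock import Mock

    from twisted.internet import defer

    class TinyHandler:
        """Mirrors the shape of the real handler after the async conversion."""

        def __init__(self, store):
            self.store = store

        async def handle_user_deactivated(self, user_id):
            await self.store.remove_from_user_dir(user_id)

    store = Mock()
    # A bare Mock() here would make the await above fail with
    # "object Mock can't be used in 'await' expression".
    store.remove_from_user_dir = Mock(return_value=defer.succeed(None))

    d = defer.ensureDeferred(TinyHandler(store).handle_user_deactivated("@user:test"))
    assert d.called  # the pre-fired Deferred lets the coroutine complete synchronously
    store.remove_from_user_dir.assert_called_once_with("@user:test")
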