diff options
Diffstat (limited to 'tests/rest/client/test_register.py')
-rw-r--r-- | tests/rest/client/test_register.py | 746 |
1 files changed, 746 insertions, 0 deletions
diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py new file mode 100644 index 0000000000..fecda037a5 --- /dev/null +++ b/tests/rest/client/test_register.py @@ -0,0 +1,746 @@ +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2017-2018 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import datetime +import json +import os + +import pkg_resources + +import synapse.rest.admin +from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType +from synapse.api.errors import Codes +from synapse.appservice import ApplicationService +from synapse.rest.client import account, account_validity, login, logout, register, sync + +from tests import unittest +from tests.unittest import override_config + + +class RegisterRestServletTestCase(unittest.HomeserverTestCase): + + servlets = [ + login.register_servlets, + register.register_servlets, + synapse.rest.admin.register_servlets, + ] + url = b"/_matrix/client/r0/register" + + def default_config(self): + config = super().default_config() + config["allow_guest_access"] = True + return config + + def test_POST_appservice_registration_valid(self): + user_id = "@as_user_kermit:test" + as_token = "i_am_an_app_service" + + appservice = ApplicationService( + as_token, + self.hs.config.server_name, + id="1234", + namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, + sender="@as:test", + ) + + self.hs.get_datastore().services_cache.append(appservice) + request_data = json.dumps( + {"username": "as_user_kermit", "type": APP_SERVICE_REGISTRATION_TYPE} + ) + + channel = self.make_request( + b"POST", self.url + b"?access_token=i_am_an_app_service", request_data + ) + + self.assertEquals(channel.result["code"], b"200", channel.result) + det_data = {"user_id": user_id, "home_server": self.hs.hostname} + self.assertDictContainsSubset(det_data, channel.json_body) + + def test_POST_appservice_registration_no_type(self): + as_token = "i_am_an_app_service" + + appservice = ApplicationService( + as_token, + self.hs.config.server_name, + id="1234", + namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, + sender="@as:test", + ) + + self.hs.get_datastore().services_cache.append(appservice) + request_data = json.dumps({"username": "as_user_kermit"}) + + channel = self.make_request( + b"POST", self.url + b"?access_token=i_am_an_app_service", request_data + ) + + self.assertEquals(channel.result["code"], b"400", channel.result) + + def test_POST_appservice_registration_invalid(self): + self.appservice = None # no application service exists + request_data = json.dumps( + {"username": "kermit", "type": APP_SERVICE_REGISTRATION_TYPE} + ) + channel = self.make_request( + b"POST", self.url + b"?access_token=i_am_an_app_service", request_data + ) + + self.assertEquals(channel.result["code"], b"401", channel.result) + + def test_POST_bad_password(self): + request_data = json.dumps({"username": "kermit", "password": 666}) + channel = self.make_request(b"POST", self.url, request_data) + + self.assertEquals(channel.result["code"], b"400", channel.result) + self.assertEquals(channel.json_body["error"], "Invalid password") + + def test_POST_bad_username(self): + request_data = json.dumps({"username": 777, "password": "monkey"}) + channel = self.make_request(b"POST", self.url, request_data) + + self.assertEquals(channel.result["code"], b"400", channel.result) + self.assertEquals(channel.json_body["error"], "Invalid username") + + def test_POST_user_valid(self): + user_id = "@kermit:test" + device_id = "frogfone" + params = { + "username": "kermit", + "password": "monkey", + "device_id": device_id, + "auth": {"type": LoginType.DUMMY}, + } + request_data = json.dumps(params) + channel = self.make_request(b"POST", self.url, request_data) + + det_data = { + "user_id": user_id, + "home_server": self.hs.hostname, + "device_id": device_id, + } + self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertDictContainsSubset(det_data, channel.json_body) + + @override_config({"enable_registration": False}) + def test_POST_disabled_registration(self): + request_data = json.dumps({"username": "kermit", "password": "monkey"}) + self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None) + + channel = self.make_request(b"POST", self.url, request_data) + + self.assertEquals(channel.result["code"], b"403", channel.result) + self.assertEquals(channel.json_body["error"], "Registration has been disabled") + self.assertEquals(channel.json_body["errcode"], "M_FORBIDDEN") + + def test_POST_guest_registration(self): + self.hs.config.macaroon_secret_key = "test" + self.hs.config.allow_guest_access = True + + channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") + + det_data = {"home_server": self.hs.hostname, "device_id": "guest_device"} + self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertDictContainsSubset(det_data, channel.json_body) + + def test_POST_disabled_guest_registration(self): + self.hs.config.allow_guest_access = False + + channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") + + self.assertEquals(channel.result["code"], b"403", channel.result) + self.assertEquals(channel.json_body["error"], "Guest access is disabled") + + @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}}) + def test_POST_ratelimiting_guest(self): + for i in range(0, 6): + url = self.url + b"?kind=guest" + channel = self.make_request(b"POST", url, b"{}") + + if i == 5: + self.assertEquals(channel.result["code"], b"429", channel.result) + retry_after_ms = int(channel.json_body["retry_after_ms"]) + else: + self.assertEquals(channel.result["code"], b"200", channel.result) + + self.reactor.advance(retry_after_ms / 1000.0 + 1.0) + + channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") + + self.assertEquals(channel.result["code"], b"200", channel.result) + + @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}}) + def test_POST_ratelimiting(self): + for i in range(0, 6): + params = { + "username": "kermit" + str(i), + "password": "monkey", + "device_id": "frogfone", + "auth": {"type": LoginType.DUMMY}, + } + request_data = json.dumps(params) + channel = self.make_request(b"POST", self.url, request_data) + + if i == 5: + self.assertEquals(channel.result["code"], b"429", channel.result) + retry_after_ms = int(channel.json_body["retry_after_ms"]) + else: + self.assertEquals(channel.result["code"], b"200", channel.result) + + self.reactor.advance(retry_after_ms / 1000.0 + 1.0) + + channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") + + self.assertEquals(channel.result["code"], b"200", channel.result) + + def test_advertised_flows(self): + channel = self.make_request(b"POST", self.url, b"{}") + self.assertEquals(channel.result["code"], b"401", channel.result) + flows = channel.json_body["flows"] + + # with the stock config, we only expect the dummy flow + self.assertCountEqual([["m.login.dummy"]], (f["stages"] for f in flows)) + + @unittest.override_config( + { + "public_baseurl": "https://test_server", + "enable_registration_captcha": True, + "user_consent": { + "version": "1", + "template_dir": "/", + "require_at_registration": True, + }, + "account_threepid_delegates": { + "email": "https://id_server", + "msisdn": "https://id_server", + }, + } + ) + def test_advertised_flows_captcha_and_terms_and_3pids(self): + channel = self.make_request(b"POST", self.url, b"{}") + self.assertEquals(channel.result["code"], b"401", channel.result) + flows = channel.json_body["flows"] + + self.assertCountEqual( + [ + ["m.login.recaptcha", "m.login.terms", "m.login.dummy"], + ["m.login.recaptcha", "m.login.terms", "m.login.email.identity"], + ["m.login.recaptcha", "m.login.terms", "m.login.msisdn"], + [ + "m.login.recaptcha", + "m.login.terms", + "m.login.msisdn", + "m.login.email.identity", + ], + ], + (f["stages"] for f in flows), + ) + + @unittest.override_config( + { + "public_baseurl": "https://test_server", + "registrations_require_3pid": ["email"], + "disable_msisdn_registration": True, + "email": { + "smtp_host": "mail_server", + "smtp_port": 2525, + "notif_from": "sender@host", + }, + } + ) + def test_advertised_flows_no_msisdn_email_required(self): + channel = self.make_request(b"POST", self.url, b"{}") + self.assertEquals(channel.result["code"], b"401", channel.result) + flows = channel.json_body["flows"] + + # with the stock config, we expect all four combinations of 3pid + self.assertCountEqual( + [["m.login.email.identity"]], (f["stages"] for f in flows) + ) + + @unittest.override_config( + { + "request_token_inhibit_3pid_errors": True, + "public_baseurl": "https://test_server", + "email": { + "smtp_host": "mail_server", + "smtp_port": 2525, + "notif_from": "sender@host", + }, + } + ) + def test_request_token_existing_email_inhibit_error(self): + """Test that requesting a token via this endpoint doesn't leak existing + associations if configured that way. + """ + user_id = self.register_user("kermit", "monkey") + self.login("kermit", "monkey") + + email = "test@example.com" + + # Add a threepid + self.get_success( + self.hs.get_datastore().user_add_threepid( + user_id=user_id, + medium="email", + address=email, + validated_at=0, + added_at=0, + ) + ) + + channel = self.make_request( + "POST", + b"register/email/requestToken", + {"client_secret": "foobar", "email": email, "send_attempt": 1}, + ) + self.assertEquals(200, channel.code, channel.result) + + self.assertIsNotNone(channel.json_body.get("sid")) + + @unittest.override_config( + { + "public_baseurl": "https://test_server", + "email": { + "smtp_host": "mail_server", + "smtp_port": 2525, + "notif_from": "sender@host", + }, + } + ) + def test_reject_invalid_email(self): + """Check that bad emails are rejected""" + + # Test for email with multiple @ + channel = self.make_request( + "POST", + b"register/email/requestToken", + {"client_secret": "foobar", "email": "email@@email", "send_attempt": 1}, + ) + self.assertEquals(400, channel.code, channel.result) + # Check error to ensure that we're not erroring due to a bug in the test. + self.assertEquals( + channel.json_body, + {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"}, + ) + + # Test for email with no @ + channel = self.make_request( + "POST", + b"register/email/requestToken", + {"client_secret": "foobar", "email": "email", "send_attempt": 1}, + ) + self.assertEquals(400, channel.code, channel.result) + self.assertEquals( + channel.json_body, + {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"}, + ) + + # Test for super long email + email = "a@" + "a" * 1000 + channel = self.make_request( + "POST", + b"register/email/requestToken", + {"client_secret": "foobar", "email": email, "send_attempt": 1}, + ) + self.assertEquals(400, channel.code, channel.result) + self.assertEquals( + channel.json_body, + {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"}, + ) + + +class AccountValidityTestCase(unittest.HomeserverTestCase): + + servlets = [ + register.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + sync.register_servlets, + logout.register_servlets, + account_validity.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + # Test for account expiring after a week. + config["enable_registration"] = True + config["account_validity"] = { + "enabled": True, + "period": 604800000, # Time in ms for 1 week + } + self.hs = self.setup_test_homeserver(config=config) + + return self.hs + + def test_validity_period(self): + self.register_user("kermit", "monkey") + tok = self.login("kermit", "monkey") + + # The specific endpoint doesn't matter, all we need is an authenticated + # endpoint. + channel = self.make_request(b"GET", "/sync", access_token=tok) + + self.assertEquals(channel.result["code"], b"200", channel.result) + + self.reactor.advance(datetime.timedelta(weeks=1).total_seconds()) + + channel = self.make_request(b"GET", "/sync", access_token=tok) + + self.assertEquals(channel.result["code"], b"403", channel.result) + self.assertEquals( + channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result + ) + + def test_manual_renewal(self): + user_id = self.register_user("kermit", "monkey") + tok = self.login("kermit", "monkey") + + self.reactor.advance(datetime.timedelta(weeks=1).total_seconds()) + + # If we register the admin user at the beginning of the test, it will + # expire at the same time as the normal user and the renewal request + # will be denied. + self.register_user("admin", "adminpassword", admin=True) + admin_tok = self.login("admin", "adminpassword") + + url = "/_synapse/admin/v1/account_validity/validity" + params = {"user_id": user_id} + request_data = json.dumps(params) + channel = self.make_request(b"POST", url, request_data, access_token=admin_tok) + self.assertEquals(channel.result["code"], b"200", channel.result) + + # The specific endpoint doesn't matter, all we need is an authenticated + # endpoint. + channel = self.make_request(b"GET", "/sync", access_token=tok) + self.assertEquals(channel.result["code"], b"200", channel.result) + + def test_manual_expire(self): + user_id = self.register_user("kermit", "monkey") + tok = self.login("kermit", "monkey") + + self.register_user("admin", "adminpassword", admin=True) + admin_tok = self.login("admin", "adminpassword") + + url = "/_synapse/admin/v1/account_validity/validity" + params = { + "user_id": user_id, + "expiration_ts": 0, + "enable_renewal_emails": False, + } + request_data = json.dumps(params) + channel = self.make_request(b"POST", url, request_data, access_token=admin_tok) + self.assertEquals(channel.result["code"], b"200", channel.result) + + # The specific endpoint doesn't matter, all we need is an authenticated + # endpoint. + channel = self.make_request(b"GET", "/sync", access_token=tok) + self.assertEquals(channel.result["code"], b"403", channel.result) + self.assertEquals( + channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result + ) + + def test_logging_out_expired_user(self): + user_id = self.register_user("kermit", "monkey") + tok = self.login("kermit", "monkey") + + self.register_user("admin", "adminpassword", admin=True) + admin_tok = self.login("admin", "adminpassword") + + url = "/_synapse/admin/v1/account_validity/validity" + params = { + "user_id": user_id, + "expiration_ts": 0, + "enable_renewal_emails": False, + } + request_data = json.dumps(params) + channel = self.make_request(b"POST", url, request_data, access_token=admin_tok) + self.assertEquals(channel.result["code"], b"200", channel.result) + + # Try to log the user out + channel = self.make_request(b"POST", "/logout", access_token=tok) + self.assertEquals(channel.result["code"], b"200", channel.result) + + # Log the user in again (allowed for expired accounts) + tok = self.login("kermit", "monkey") + + # Try to log out all of the user's sessions + channel = self.make_request(b"POST", "/logout/all", access_token=tok) + self.assertEquals(channel.result["code"], b"200", channel.result) + + +class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): + + servlets = [ + register.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + sync.register_servlets, + account_validity.register_servlets, + account.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + + # Test for account expiring after a week and renewal emails being sent 2 + # days before expiry. + config["enable_registration"] = True + config["account_validity"] = { + "enabled": True, + "period": 604800000, # Time in ms for 1 week + "renew_at": 172800000, # Time in ms for 2 days + "renew_by_email_enabled": True, + "renew_email_subject": "Renew your account", + "account_renewed_html_path": "account_renewed.html", + "invalid_token_html_path": "invalid_token.html", + } + + # Email config. + + config["email"] = { + "enable_notifs": True, + "template_dir": os.path.abspath( + pkg_resources.resource_filename("synapse", "res/templates") + ), + "expiry_template_html": "notice_expiry.html", + "expiry_template_text": "notice_expiry.txt", + "notif_template_html": "notif_mail.html", + "notif_template_text": "notif_mail.txt", + "smtp_host": "127.0.0.1", + "smtp_port": 20, + "require_transport_security": False, + "smtp_user": None, + "smtp_pass": None, + "notif_from": "test@example.com", + } + config["public_baseurl"] = "aaa" + + self.hs = self.setup_test_homeserver(config=config) + + async def sendmail(*args, **kwargs): + self.email_attempts.append((args, kwargs)) + + self.email_attempts = [] + self.hs.get_send_email_handler()._sendmail = sendmail + + self.store = self.hs.get_datastore() + + return self.hs + + def test_renewal_email(self): + self.email_attempts = [] + + (user_id, tok) = self.create_user() + + # Move 5 days forward. This should trigger a renewal email to be sent. + self.reactor.advance(datetime.timedelta(days=5).total_seconds()) + self.assertEqual(len(self.email_attempts), 1) + + # Retrieving the URL from the email is too much pain for now, so we + # retrieve the token from the DB. + renewal_token = self.get_success(self.store.get_renewal_token_for_user(user_id)) + url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token + channel = self.make_request(b"GET", url) + self.assertEquals(channel.result["code"], b"200", channel.result) + + # Check that we're getting HTML back. + content_type = channel.headers.getRawHeaders(b"Content-Type") + self.assertEqual(content_type, [b"text/html; charset=utf-8"], channel.result) + + # Check that the HTML we're getting is the one we expect on a successful renewal. + expiration_ts = self.get_success(self.store.get_expiration_ts_for_user(user_id)) + expected_html = self.hs.config.account_validity.account_validity_account_renewed_template.render( + expiration_ts=expiration_ts + ) + self.assertEqual( + channel.result["body"], expected_html.encode("utf8"), channel.result + ) + + # Move 1 day forward. Try to renew with the same token again. + url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token + channel = self.make_request(b"GET", url) + self.assertEquals(channel.result["code"], b"200", channel.result) + + # Check that we're getting HTML back. + content_type = channel.headers.getRawHeaders(b"Content-Type") + self.assertEqual(content_type, [b"text/html; charset=utf-8"], channel.result) + + # Check that the HTML we're getting is the one we expect when reusing a + # token. The account expiration date should not have changed. + expected_html = self.hs.config.account_validity.account_validity_account_previously_renewed_template.render( + expiration_ts=expiration_ts + ) + self.assertEqual( + channel.result["body"], expected_html.encode("utf8"), channel.result + ) + + # Move 3 days forward. If the renewal failed, every authed request with + # our access token should be denied from now, otherwise they should + # succeed. + self.reactor.advance(datetime.timedelta(days=3).total_seconds()) + channel = self.make_request(b"GET", "/sync", access_token=tok) + self.assertEquals(channel.result["code"], b"200", channel.result) + + def test_renewal_invalid_token(self): + # Hit the renewal endpoint with an invalid token and check that it behaves as + # expected, i.e. that it responds with 404 Not Found and the correct HTML. + url = "/_matrix/client/unstable/account_validity/renew?token=123" + channel = self.make_request(b"GET", url) + self.assertEquals(channel.result["code"], b"404", channel.result) + + # Check that we're getting HTML back. + content_type = channel.headers.getRawHeaders(b"Content-Type") + self.assertEqual(content_type, [b"text/html; charset=utf-8"], channel.result) + + # Check that the HTML we're getting is the one we expect when using an + # invalid/unknown token. + expected_html = ( + self.hs.config.account_validity.account_validity_invalid_token_template.render() + ) + self.assertEqual( + channel.result["body"], expected_html.encode("utf8"), channel.result + ) + + def test_manual_email_send(self): + self.email_attempts = [] + + (user_id, tok) = self.create_user() + channel = self.make_request( + b"POST", + "/_matrix/client/unstable/account_validity/send_mail", + access_token=tok, + ) + self.assertEquals(channel.result["code"], b"200", channel.result) + + self.assertEqual(len(self.email_attempts), 1) + + def test_deactivated_user(self): + self.email_attempts = [] + + (user_id, tok) = self.create_user() + + request_data = json.dumps( + { + "auth": { + "type": "m.login.password", + "user": user_id, + "password": "monkey", + }, + "erase": False, + } + ) + channel = self.make_request( + "POST", "account/deactivate", request_data, access_token=tok + ) + self.assertEqual(channel.code, 200) + + self.reactor.advance(datetime.timedelta(days=8).total_seconds()) + + self.assertEqual(len(self.email_attempts), 0) + + def create_user(self): + user_id = self.register_user("kermit", "monkey") + tok = self.login("kermit", "monkey") + # We need to manually add an email address otherwise the handler will do + # nothing. + now = self.hs.get_clock().time_msec() + self.get_success( + self.store.user_add_threepid( + user_id=user_id, + medium="email", + address="kermit@example.com", + validated_at=now, + added_at=now, + ) + ) + return user_id, tok + + def test_manual_email_send_expired_account(self): + user_id = self.register_user("kermit", "monkey") + tok = self.login("kermit", "monkey") + + # We need to manually add an email address otherwise the handler will do + # nothing. + now = self.hs.get_clock().time_msec() + self.get_success( + self.store.user_add_threepid( + user_id=user_id, + medium="email", + address="kermit@example.com", + validated_at=now, + added_at=now, + ) + ) + + # Make the account expire. + self.reactor.advance(datetime.timedelta(days=8).total_seconds()) + + # Ignore all emails sent by the automatic background task and only focus on the + # ones sent manually. + self.email_attempts = [] + + # Test that we're still able to manually trigger a mail to be sent. + channel = self.make_request( + b"POST", + "/_matrix/client/unstable/account_validity/send_mail", + access_token=tok, + ) + self.assertEquals(channel.result["code"], b"200", channel.result) + + self.assertEqual(len(self.email_attempts), 1) + + +class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase): + + servlets = [synapse.rest.admin.register_servlets_for_client_rest_resource] + + def make_homeserver(self, reactor, clock): + self.validity_period = 10 + self.max_delta = self.validity_period * 10.0 / 100.0 + + config = self.default_config() + + config["enable_registration"] = True + config["account_validity"] = {"enabled": False} + + self.hs = self.setup_test_homeserver(config=config) + + # We need to set these directly, instead of in the homeserver config dict above. + # This is due to account validity-related config options not being read by + # Synapse when account_validity.enabled is False. + self.hs.get_datastore()._account_validity_period = self.validity_period + self.hs.get_datastore()._account_validity_startup_job_max_delta = self.max_delta + + self.store = self.hs.get_datastore() + + return self.hs + + def test_background_job(self): + """ + Tests the same thing as test_background_job, except that it sets the + startup_job_max_delta parameter and checks that the expiration date is within the + allowed range. + """ + user_id = self.register_user("kermit_delta", "user") + + self.hs.config.account_validity.startup_job_max_delta = self.max_delta + + now_ms = self.hs.get_clock().time_msec() + self.get_success(self.store._set_expiration_date_when_missing()) + + res = self.get_success(self.store.get_expiration_ts_for_user(user_id)) + + self.assertGreaterEqual(res, now_ms + self.validity_period - self.max_delta) + self.assertLessEqual(res, now_ms + self.validity_period) |