summary refs log tree commit diff
path: root/tests/rest/client/test_register.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest/client/test_register.py')
-rw-r--r--tests/rest/client/test_register.py746
1 files changed, 746 insertions, 0 deletions
diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py
new file mode 100644
index 0000000000..fecda037a5
--- /dev/null
+++ b/tests/rest/client/test_register.py
@@ -0,0 +1,746 @@
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import datetime
+import json
+import os
+
+import pkg_resources
+
+import synapse.rest.admin
+from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType
+from synapse.api.errors import Codes
+from synapse.appservice import ApplicationService
+from synapse.rest.client import account, account_validity, login, logout, register, sync
+
+from tests import unittest
+from tests.unittest import override_config
+
+
+class RegisterRestServletTestCase(unittest.HomeserverTestCase):
+
+    servlets = [
+        login.register_servlets,
+        register.register_servlets,
+        synapse.rest.admin.register_servlets,
+    ]
+    url = b"/_matrix/client/r0/register"
+
+    def default_config(self):
+        config = super().default_config()
+        config["allow_guest_access"] = True
+        return config
+
+    def test_POST_appservice_registration_valid(self):
+        user_id = "@as_user_kermit:test"
+        as_token = "i_am_an_app_service"
+
+        appservice = ApplicationService(
+            as_token,
+            self.hs.config.server_name,
+            id="1234",
+            namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
+            sender="@as:test",
+        )
+
+        self.hs.get_datastore().services_cache.append(appservice)
+        request_data = json.dumps(
+            {"username": "as_user_kermit", "type": APP_SERVICE_REGISTRATION_TYPE}
+        )
+
+        channel = self.make_request(
+            b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
+        )
+
+        self.assertEquals(channel.result["code"], b"200", channel.result)
+        det_data = {"user_id": user_id, "home_server": self.hs.hostname}
+        self.assertDictContainsSubset(det_data, channel.json_body)
+
+    def test_POST_appservice_registration_no_type(self):
+        as_token = "i_am_an_app_service"
+
+        appservice = ApplicationService(
+            as_token,
+            self.hs.config.server_name,
+            id="1234",
+            namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
+            sender="@as:test",
+        )
+
+        self.hs.get_datastore().services_cache.append(appservice)
+        request_data = json.dumps({"username": "as_user_kermit"})
+
+        channel = self.make_request(
+            b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
+        )
+
+        self.assertEquals(channel.result["code"], b"400", channel.result)
+
+    def test_POST_appservice_registration_invalid(self):
+        self.appservice = None  # no application service exists
+        request_data = json.dumps(
+            {"username": "kermit", "type": APP_SERVICE_REGISTRATION_TYPE}
+        )
+        channel = self.make_request(
+            b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
+        )
+
+        self.assertEquals(channel.result["code"], b"401", channel.result)
+
+    def test_POST_bad_password(self):
+        request_data = json.dumps({"username": "kermit", "password": 666})
+        channel = self.make_request(b"POST", self.url, request_data)
+
+        self.assertEquals(channel.result["code"], b"400", channel.result)
+        self.assertEquals(channel.json_body["error"], "Invalid password")
+
+    def test_POST_bad_username(self):
+        request_data = json.dumps({"username": 777, "password": "monkey"})
+        channel = self.make_request(b"POST", self.url, request_data)
+
+        self.assertEquals(channel.result["code"], b"400", channel.result)
+        self.assertEquals(channel.json_body["error"], "Invalid username")
+
+    def test_POST_user_valid(self):
+        user_id = "@kermit:test"
+        device_id = "frogfone"
+        params = {
+            "username": "kermit",
+            "password": "monkey",
+            "device_id": device_id,
+            "auth": {"type": LoginType.DUMMY},
+        }
+        request_data = json.dumps(params)
+        channel = self.make_request(b"POST", self.url, request_data)
+
+        det_data = {
+            "user_id": user_id,
+            "home_server": self.hs.hostname,
+            "device_id": device_id,
+        }
+        self.assertEquals(channel.result["code"], b"200", channel.result)
+        self.assertDictContainsSubset(det_data, channel.json_body)
+
+    @override_config({"enable_registration": False})
+    def test_POST_disabled_registration(self):
+        request_data = json.dumps({"username": "kermit", "password": "monkey"})
+        self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None)
+
+        channel = self.make_request(b"POST", self.url, request_data)
+
+        self.assertEquals(channel.result["code"], b"403", channel.result)
+        self.assertEquals(channel.json_body["error"], "Registration has been disabled")
+        self.assertEquals(channel.json_body["errcode"], "M_FORBIDDEN")
+
+    def test_POST_guest_registration(self):
+        self.hs.config.macaroon_secret_key = "test"
+        self.hs.config.allow_guest_access = True
+
+        channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
+
+        det_data = {"home_server": self.hs.hostname, "device_id": "guest_device"}
+        self.assertEquals(channel.result["code"], b"200", channel.result)
+        self.assertDictContainsSubset(det_data, channel.json_body)
+
+    def test_POST_disabled_guest_registration(self):
+        self.hs.config.allow_guest_access = False
+
+        channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
+
+        self.assertEquals(channel.result["code"], b"403", channel.result)
+        self.assertEquals(channel.json_body["error"], "Guest access is disabled")
+
+    @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}})
+    def test_POST_ratelimiting_guest(self):
+        for i in range(0, 6):
+            url = self.url + b"?kind=guest"
+            channel = self.make_request(b"POST", url, b"{}")
+
+            if i == 5:
+                self.assertEquals(channel.result["code"], b"429", channel.result)
+                retry_after_ms = int(channel.json_body["retry_after_ms"])
+            else:
+                self.assertEquals(channel.result["code"], b"200", channel.result)
+
+        self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
+
+        channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
+
+        self.assertEquals(channel.result["code"], b"200", channel.result)
+
+    @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}})
+    def test_POST_ratelimiting(self):
+        for i in range(0, 6):
+            params = {
+                "username": "kermit" + str(i),
+                "password": "monkey",
+                "device_id": "frogfone",
+                "auth": {"type": LoginType.DUMMY},
+            }
+            request_data = json.dumps(params)
+            channel = self.make_request(b"POST", self.url, request_data)
+
+            if i == 5:
+                self.assertEquals(channel.result["code"], b"429", channel.result)
+                retry_after_ms = int(channel.json_body["retry_after_ms"])
+            else:
+                self.assertEquals(channel.result["code"], b"200", channel.result)
+
+        self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
+
+        channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
+
+        self.assertEquals(channel.result["code"], b"200", channel.result)
+
+    def test_advertised_flows(self):
+        channel = self.make_request(b"POST", self.url, b"{}")
+        self.assertEquals(channel.result["code"], b"401", channel.result)
+        flows = channel.json_body["flows"]
+
+        # with the stock config, we only expect the dummy flow
+        self.assertCountEqual([["m.login.dummy"]], (f["stages"] for f in flows))
+
+    @unittest.override_config(
+        {
+            "public_baseurl": "https://test_server",
+            "enable_registration_captcha": True,
+            "user_consent": {
+                "version": "1",
+                "template_dir": "/",
+                "require_at_registration": True,
+            },
+            "account_threepid_delegates": {
+                "email": "https://id_server",
+                "msisdn": "https://id_server",
+            },
+        }
+    )
+    def test_advertised_flows_captcha_and_terms_and_3pids(self):
+        channel = self.make_request(b"POST", self.url, b"{}")
+        self.assertEquals(channel.result["code"], b"401", channel.result)
+        flows = channel.json_body["flows"]
+
+        self.assertCountEqual(
+            [
+                ["m.login.recaptcha", "m.login.terms", "m.login.dummy"],
+                ["m.login.recaptcha", "m.login.terms", "m.login.email.identity"],
+                ["m.login.recaptcha", "m.login.terms", "m.login.msisdn"],
+                [
+                    "m.login.recaptcha",
+                    "m.login.terms",
+                    "m.login.msisdn",
+                    "m.login.email.identity",
+                ],
+            ],
+            (f["stages"] for f in flows),
+        )
+
+    @unittest.override_config(
+        {
+            "public_baseurl": "https://test_server",
+            "registrations_require_3pid": ["email"],
+            "disable_msisdn_registration": True,
+            "email": {
+                "smtp_host": "mail_server",
+                "smtp_port": 2525,
+                "notif_from": "sender@host",
+            },
+        }
+    )
+    def test_advertised_flows_no_msisdn_email_required(self):
+        channel = self.make_request(b"POST", self.url, b"{}")
+        self.assertEquals(channel.result["code"], b"401", channel.result)
+        flows = channel.json_body["flows"]
+
+        # with the stock config, we expect all four combinations of 3pid
+        self.assertCountEqual(
+            [["m.login.email.identity"]], (f["stages"] for f in flows)
+        )
+
+    @unittest.override_config(
+        {
+            "request_token_inhibit_3pid_errors": True,
+            "public_baseurl": "https://test_server",
+            "email": {
+                "smtp_host": "mail_server",
+                "smtp_port": 2525,
+                "notif_from": "sender@host",
+            },
+        }
+    )
+    def test_request_token_existing_email_inhibit_error(self):
+        """Test that requesting a token via this endpoint doesn't leak existing
+        associations if configured that way.
+        """
+        user_id = self.register_user("kermit", "monkey")
+        self.login("kermit", "monkey")
+
+        email = "test@example.com"
+
+        # Add a threepid
+        self.get_success(
+            self.hs.get_datastore().user_add_threepid(
+                user_id=user_id,
+                medium="email",
+                address=email,
+                validated_at=0,
+                added_at=0,
+            )
+        )
+
+        channel = self.make_request(
+            "POST",
+            b"register/email/requestToken",
+            {"client_secret": "foobar", "email": email, "send_attempt": 1},
+        )
+        self.assertEquals(200, channel.code, channel.result)
+
+        self.assertIsNotNone(channel.json_body.get("sid"))
+
+    @unittest.override_config(
+        {
+            "public_baseurl": "https://test_server",
+            "email": {
+                "smtp_host": "mail_server",
+                "smtp_port": 2525,
+                "notif_from": "sender@host",
+            },
+        }
+    )
+    def test_reject_invalid_email(self):
+        """Check that bad emails are rejected"""
+
+        # Test for email with multiple @
+        channel = self.make_request(
+            "POST",
+            b"register/email/requestToken",
+            {"client_secret": "foobar", "email": "email@@email", "send_attempt": 1},
+        )
+        self.assertEquals(400, channel.code, channel.result)
+        # Check error to ensure that we're not erroring due to a bug in the test.
+        self.assertEquals(
+            channel.json_body,
+            {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"},
+        )
+
+        # Test for email with no @
+        channel = self.make_request(
+            "POST",
+            b"register/email/requestToken",
+            {"client_secret": "foobar", "email": "email", "send_attempt": 1},
+        )
+        self.assertEquals(400, channel.code, channel.result)
+        self.assertEquals(
+            channel.json_body,
+            {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"},
+        )
+
+        # Test for super long email
+        email = "a@" + "a" * 1000
+        channel = self.make_request(
+            "POST",
+            b"register/email/requestToken",
+            {"client_secret": "foobar", "email": email, "send_attempt": 1},
+        )
+        self.assertEquals(400, channel.code, channel.result)
+        self.assertEquals(
+            channel.json_body,
+            {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"},
+        )
+
+
+class AccountValidityTestCase(unittest.HomeserverTestCase):
+
+    servlets = [
+        register.register_servlets,
+        synapse.rest.admin.register_servlets_for_client_rest_resource,
+        login.register_servlets,
+        sync.register_servlets,
+        logout.register_servlets,
+        account_validity.register_servlets,
+    ]
+
+    def make_homeserver(self, reactor, clock):
+        config = self.default_config()
+        # Test for account expiring after a week.
+        config["enable_registration"] = True
+        config["account_validity"] = {
+            "enabled": True,
+            "period": 604800000,  # Time in ms for 1 week
+        }
+        self.hs = self.setup_test_homeserver(config=config)
+
+        return self.hs
+
+    def test_validity_period(self):
+        self.register_user("kermit", "monkey")
+        tok = self.login("kermit", "monkey")
+
+        # The specific endpoint doesn't matter, all we need is an authenticated
+        # endpoint.
+        channel = self.make_request(b"GET", "/sync", access_token=tok)
+
+        self.assertEquals(channel.result["code"], b"200", channel.result)
+
+        self.reactor.advance(datetime.timedelta(weeks=1).total_seconds())
+
+        channel = self.make_request(b"GET", "/sync", access_token=tok)
+
+        self.assertEquals(channel.result["code"], b"403", channel.result)
+        self.assertEquals(
+            channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result
+        )
+
+    def test_manual_renewal(self):
+        user_id = self.register_user("kermit", "monkey")
+        tok = self.login("kermit", "monkey")
+
+        self.reactor.advance(datetime.timedelta(weeks=1).total_seconds())
+
+        # If we register the admin user at the beginning of the test, it will
+        # expire at the same time as the normal user and the renewal request
+        # will be denied.
+        self.register_user("admin", "adminpassword", admin=True)
+        admin_tok = self.login("admin", "adminpassword")
+
+        url = "/_synapse/admin/v1/account_validity/validity"
+        params = {"user_id": user_id}
+        request_data = json.dumps(params)
+        channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
+        self.assertEquals(channel.result["code"], b"200", channel.result)
+
+        # The specific endpoint doesn't matter, all we need is an authenticated
+        # endpoint.
+        channel = self.make_request(b"GET", "/sync", access_token=tok)
+        self.assertEquals(channel.result["code"], b"200", channel.result)
+
+    def test_manual_expire(self):
+        user_id = self.register_user("kermit", "monkey")
+        tok = self.login("kermit", "monkey")
+
+        self.register_user("admin", "adminpassword", admin=True)
+        admin_tok = self.login("admin", "adminpassword")
+
+        url = "/_synapse/admin/v1/account_validity/validity"
+        params = {
+            "user_id": user_id,
+            "expiration_ts": 0,
+            "enable_renewal_emails": False,
+        }
+        request_data = json.dumps(params)
+        channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
+        self.assertEquals(channel.result["code"], b"200", channel.result)
+
+        # The specific endpoint doesn't matter, all we need is an authenticated
+        # endpoint.
+        channel = self.make_request(b"GET", "/sync", access_token=tok)
+        self.assertEquals(channel.result["code"], b"403", channel.result)
+        self.assertEquals(
+            channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result
+        )
+
+    def test_logging_out_expired_user(self):
+        user_id = self.register_user("kermit", "monkey")
+        tok = self.login("kermit", "monkey")
+
+        self.register_user("admin", "adminpassword", admin=True)
+        admin_tok = self.login("admin", "adminpassword")
+
+        url = "/_synapse/admin/v1/account_validity/validity"
+        params = {
+            "user_id": user_id,
+            "expiration_ts": 0,
+            "enable_renewal_emails": False,
+        }
+        request_data = json.dumps(params)
+        channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
+        self.assertEquals(channel.result["code"], b"200", channel.result)
+
+        # Try to log the user out
+        channel = self.make_request(b"POST", "/logout", access_token=tok)
+        self.assertEquals(channel.result["code"], b"200", channel.result)
+
+        # Log the user in again (allowed for expired accounts)
+        tok = self.login("kermit", "monkey")
+
+        # Try to log out all of the user's sessions
+        channel = self.make_request(b"POST", "/logout/all", access_token=tok)
+        self.assertEquals(channel.result["code"], b"200", channel.result)
+
+
+class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
+
+    servlets = [
+        register.register_servlets,
+        synapse.rest.admin.register_servlets_for_client_rest_resource,
+        login.register_servlets,
+        sync.register_servlets,
+        account_validity.register_servlets,
+        account.register_servlets,
+    ]
+
+    def make_homeserver(self, reactor, clock):
+        config = self.default_config()
+
+        # Test for account expiring after a week and renewal emails being sent 2
+        # days before expiry.
+        config["enable_registration"] = True
+        config["account_validity"] = {
+            "enabled": True,
+            "period": 604800000,  # Time in ms for 1 week
+            "renew_at": 172800000,  # Time in ms for 2 days
+            "renew_by_email_enabled": True,
+            "renew_email_subject": "Renew your account",
+            "account_renewed_html_path": "account_renewed.html",
+            "invalid_token_html_path": "invalid_token.html",
+        }
+
+        # Email config.
+
+        config["email"] = {
+            "enable_notifs": True,
+            "template_dir": os.path.abspath(
+                pkg_resources.resource_filename("synapse", "res/templates")
+            ),
+            "expiry_template_html": "notice_expiry.html",
+            "expiry_template_text": "notice_expiry.txt",
+            "notif_template_html": "notif_mail.html",
+            "notif_template_text": "notif_mail.txt",
+            "smtp_host": "127.0.0.1",
+            "smtp_port": 20,
+            "require_transport_security": False,
+            "smtp_user": None,
+            "smtp_pass": None,
+            "notif_from": "test@example.com",
+        }
+        config["public_baseurl"] = "aaa"
+
+        self.hs = self.setup_test_homeserver(config=config)
+
+        async def sendmail(*args, **kwargs):
+            self.email_attempts.append((args, kwargs))
+
+        self.email_attempts = []
+        self.hs.get_send_email_handler()._sendmail = sendmail
+
+        self.store = self.hs.get_datastore()
+
+        return self.hs
+
+    def test_renewal_email(self):
+        self.email_attempts = []
+
+        (user_id, tok) = self.create_user()
+
+        # Move 5 days forward. This should trigger a renewal email to be sent.
+        self.reactor.advance(datetime.timedelta(days=5).total_seconds())
+        self.assertEqual(len(self.email_attempts), 1)
+
+        # Retrieving the URL from the email is too much pain for now, so we
+        # retrieve the token from the DB.
+        renewal_token = self.get_success(self.store.get_renewal_token_for_user(user_id))
+        url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token
+        channel = self.make_request(b"GET", url)
+        self.assertEquals(channel.result["code"], b"200", channel.result)
+
+        # Check that we're getting HTML back.
+        content_type = channel.headers.getRawHeaders(b"Content-Type")
+        self.assertEqual(content_type, [b"text/html; charset=utf-8"], channel.result)
+
+        # Check that the HTML we're getting is the one we expect on a successful renewal.
+        expiration_ts = self.get_success(self.store.get_expiration_ts_for_user(user_id))
+        expected_html = self.hs.config.account_validity.account_validity_account_renewed_template.render(
+            expiration_ts=expiration_ts
+        )
+        self.assertEqual(
+            channel.result["body"], expected_html.encode("utf8"), channel.result
+        )
+
+        # Move 1 day forward. Try to renew with the same token again.
+        url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token
+        channel = self.make_request(b"GET", url)
+        self.assertEquals(channel.result["code"], b"200", channel.result)
+
+        # Check that we're getting HTML back.
+        content_type = channel.headers.getRawHeaders(b"Content-Type")
+        self.assertEqual(content_type, [b"text/html; charset=utf-8"], channel.result)
+
+        # Check that the HTML we're getting is the one we expect when reusing a
+        # token. The account expiration date should not have changed.
+        expected_html = self.hs.config.account_validity.account_validity_account_previously_renewed_template.render(
+            expiration_ts=expiration_ts
+        )
+        self.assertEqual(
+            channel.result["body"], expected_html.encode("utf8"), channel.result
+        )
+
+        # Move 3 days forward. If the renewal failed, every authed request with
+        # our access token should be denied from now, otherwise they should
+        # succeed.
+        self.reactor.advance(datetime.timedelta(days=3).total_seconds())
+        channel = self.make_request(b"GET", "/sync", access_token=tok)
+        self.assertEquals(channel.result["code"], b"200", channel.result)
+
+    def test_renewal_invalid_token(self):
+        # Hit the renewal endpoint with an invalid token and check that it behaves as
+        # expected, i.e. that it responds with 404 Not Found and the correct HTML.
+        url = "/_matrix/client/unstable/account_validity/renew?token=123"
+        channel = self.make_request(b"GET", url)
+        self.assertEquals(channel.result["code"], b"404", channel.result)
+
+        # Check that we're getting HTML back.
+        content_type = channel.headers.getRawHeaders(b"Content-Type")
+        self.assertEqual(content_type, [b"text/html; charset=utf-8"], channel.result)
+
+        # Check that the HTML we're getting is the one we expect when using an
+        # invalid/unknown token.
+        expected_html = (
+            self.hs.config.account_validity.account_validity_invalid_token_template.render()
+        )
+        self.assertEqual(
+            channel.result["body"], expected_html.encode("utf8"), channel.result
+        )
+
+    def test_manual_email_send(self):
+        self.email_attempts = []
+
+        (user_id, tok) = self.create_user()
+        channel = self.make_request(
+            b"POST",
+            "/_matrix/client/unstable/account_validity/send_mail",
+            access_token=tok,
+        )
+        self.assertEquals(channel.result["code"], b"200", channel.result)
+
+        self.assertEqual(len(self.email_attempts), 1)
+
+    def test_deactivated_user(self):
+        self.email_attempts = []
+
+        (user_id, tok) = self.create_user()
+
+        request_data = json.dumps(
+            {
+                "auth": {
+                    "type": "m.login.password",
+                    "user": user_id,
+                    "password": "monkey",
+                },
+                "erase": False,
+            }
+        )
+        channel = self.make_request(
+            "POST", "account/deactivate", request_data, access_token=tok
+        )
+        self.assertEqual(channel.code, 200)
+
+        self.reactor.advance(datetime.timedelta(days=8).total_seconds())
+
+        self.assertEqual(len(self.email_attempts), 0)
+
+    def create_user(self):
+        user_id = self.register_user("kermit", "monkey")
+        tok = self.login("kermit", "monkey")
+        # We need to manually add an email address otherwise the handler will do
+        # nothing.
+        now = self.hs.get_clock().time_msec()
+        self.get_success(
+            self.store.user_add_threepid(
+                user_id=user_id,
+                medium="email",
+                address="kermit@example.com",
+                validated_at=now,
+                added_at=now,
+            )
+        )
+        return user_id, tok
+
+    def test_manual_email_send_expired_account(self):
+        user_id = self.register_user("kermit", "monkey")
+        tok = self.login("kermit", "monkey")
+
+        # We need to manually add an email address otherwise the handler will do
+        # nothing.
+        now = self.hs.get_clock().time_msec()
+        self.get_success(
+            self.store.user_add_threepid(
+                user_id=user_id,
+                medium="email",
+                address="kermit@example.com",
+                validated_at=now,
+                added_at=now,
+            )
+        )
+
+        # Make the account expire.
+        self.reactor.advance(datetime.timedelta(days=8).total_seconds())
+
+        # Ignore all emails sent by the automatic background task and only focus on the
+        # ones sent manually.
+        self.email_attempts = []
+
+        # Test that we're still able to manually trigger a mail to be sent.
+        channel = self.make_request(
+            b"POST",
+            "/_matrix/client/unstable/account_validity/send_mail",
+            access_token=tok,
+        )
+        self.assertEquals(channel.result["code"], b"200", channel.result)
+
+        self.assertEqual(len(self.email_attempts), 1)
+
+
+class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
+
+    servlets = [synapse.rest.admin.register_servlets_for_client_rest_resource]
+
+    def make_homeserver(self, reactor, clock):
+        self.validity_period = 10
+        self.max_delta = self.validity_period * 10.0 / 100.0
+
+        config = self.default_config()
+
+        config["enable_registration"] = True
+        config["account_validity"] = {"enabled": False}
+
+        self.hs = self.setup_test_homeserver(config=config)
+
+        # We need to set these directly, instead of in the homeserver config dict above.
+        # This is due to account validity-related config options not being read by
+        # Synapse when account_validity.enabled is False.
+        self.hs.get_datastore()._account_validity_period = self.validity_period
+        self.hs.get_datastore()._account_validity_startup_job_max_delta = self.max_delta
+
+        self.store = self.hs.get_datastore()
+
+        return self.hs
+
+    def test_background_job(self):
+        """
+        Tests the same thing as test_background_job, except that it sets the
+        startup_job_max_delta parameter and checks that the expiration date is within the
+        allowed range.
+        """
+        user_id = self.register_user("kermit_delta", "user")
+
+        self.hs.config.account_validity.startup_job_max_delta = self.max_delta
+
+        now_ms = self.hs.get_clock().time_msec()
+        self.get_success(self.store._set_expiration_date_when_missing())
+
+        res = self.get_success(self.store.get_expiration_ts_for_user(user_id))
+
+        self.assertGreaterEqual(res, now_ms + self.validity_period - self.max_delta)
+        self.assertLessEqual(res, now_ms + self.validity_period)