diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 8fb5140a05..3d44667489 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -201,6 +201,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
admin.register_servlets,
login.register_servlets,
sync.register_servlets,
+ account_validity.register_servlets,
]
def make_homeserver(self, reactor, clock):
@@ -238,6 +239,68 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result,
)
+ def test_manual_renewal(self):
+ user_id = self.register_user("kermit", "monkey")
+ tok = self.login("kermit", "monkey")
+
+ self.reactor.advance(datetime.timedelta(weeks=1).total_seconds())
+
+ # If we register the admin user at the beginning of the test, it will
+ # expire at the same time as the normal user and the renewal request
+ # will be denied.
+ self.register_user("admin", "adminpassword", admin=True)
+ admin_tok = self.login("admin", "adminpassword")
+
+ url = "/_matrix/client/unstable/admin/account_validity/validity"
+ params = {
+ "user_id": user_id,
+ }
+ request_data = json.dumps(params)
+ request, channel = self.make_request(
+ b"POST", url, request_data, access_token=admin_tok,
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ # The specific endpoint doesn't matter, all we need is an authenticated
+ # endpoint.
+ request, channel = self.make_request(
+ b"GET", "/sync", access_token=tok,
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ def test_manual_expire(self):
+ user_id = self.register_user("kermit", "monkey")
+ tok = self.login("kermit", "monkey")
+
+ self.register_user("admin", "adminpassword", admin=True)
+ admin_tok = self.login("admin", "adminpassword")
+
+ url = "/_matrix/client/unstable/admin/account_validity/validity"
+ params = {
+ "user_id": user_id,
+ "expiration_ts": 0,
+ "enable_renewal_emails": False,
+ }
+ request_data = json.dumps(params)
+ request, channel = self.make_request(
+ b"POST", url, request_data, access_token=admin_tok,
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ # The specific endpoint doesn't matter, all we need is an authenticated
+ # endpoint.
+ request, channel = self.make_request(
+ b"GET", "/sync", access_token=tok,
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"403", channel.result)
+ self.assertEquals(
+ channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result,
+ )
+
class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
@@ -287,6 +350,8 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
return self.hs
def test_renewal_email(self):
+ self.email_attempts = []
+
user_id = self.register_user("kermit", "monkey")
tok = self.login("kermit", "monkey")
# We need to manually add an email address otherwise the handler will do
@@ -297,14 +362,6 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
validated_at=now, added_at=now,
))
- # The specific endpoint doesn't matter, all we need is an authenticated
- # endpoint.
- request, channel = self.make_request(
- b"GET", "/sync", access_token=tok,
- )
- self.render(request)
- self.assertEquals(channel.result["code"], b"200", channel.result)
-
# Move 6 days forward. This should trigger a renewal email to be sent.
self.reactor.advance(datetime.timedelta(days=6).total_seconds())
self.assertEqual(len(self.email_attempts), 1)
@@ -326,3 +383,25 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
)
self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ def test_manual_email_send(self):
+ self.email_attempts = []
+
+ user_id = self.register_user("kermit", "monkey")
+ tok = self.login("kermit", "monkey")
+ # We need to manually add an email address otherwise the handler will do
+ # nothing.
+ now = self.hs.clock.time_msec()
+ self.get_success(self.store.user_add_threepid(
+ user_id=user_id, medium="email", address="kermit@example.com",
+ validated_at=now, added_at=now,
+ ))
+
+ request, channel = self.make_request(
+ b"POST", "/_matrix/client/unstable/account_validity/send_mail",
+ access_token=tok,
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ self.assertEqual(len(self.email_attempts), 1)
|