summary refs log tree commit diff
path: root/tests/rest/client/test_register.py
diff options
context:
space:
mode:
authorCallum Brown <callum@calcuode.com>2021-08-21 22:14:43 +0100
committerGitHub <noreply@github.com>2021-08-21 22:14:43 +0100
commit947dbbdfd1e0029da66f956d277b7c089928e1e7 (patch)
tree57cf53bcbb1f02e75e114cde5d0aa77662163038 /tests/rest/client/test_register.py
parentFlatten tests/rest/client/{v1,v2_alpha} too (#10667) (diff)
downloadsynapse-947dbbdfd1e0029da66f956d277b7c089928e1e7.tar.xz
Implement MSC3231: Token authenticated registration (#10142)
Signed-off-by: Callum Brown <callum@calcuode.com>

This is part of my GSoC project implementing [MSC3231](https://github.com/matrix-org/matrix-doc/pull/3231).
Diffstat (limited to 'tests/rest/client/test_register.py')
-rw-r--r--tests/rest/client/test_register.py434
1 files changed, 434 insertions, 0 deletions
diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py
index fecda037a5..9f3ab2c985 100644
--- a/tests/rest/client/test_register.py
+++ b/tests/rest/client/test_register.py
@@ -24,6 +24,7 @@ from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType
 from synapse.api.errors import Codes
 from synapse.appservice import ApplicationService
 from synapse.rest.client import account, account_validity, login, logout, register, sync
+from synapse.storage._base import db_to_json
 
 from tests import unittest
 from tests.unittest import override_config
@@ -204,6 +205,371 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
 
         self.assertEquals(channel.result["code"], b"200", channel.result)
 
+    @override_config({"registration_requires_token": True})
+    def test_POST_registration_requires_token(self):
+        username = "kermit"
+        device_id = "frogfone"
+        token = "abcd"
+        store = self.hs.get_datastore()
+        self.get_success(
+            store.db_pool.simple_insert(
+                "registration_tokens",
+                {
+                    "token": token,
+                    "uses_allowed": None,
+                    "pending": 0,
+                    "completed": 0,
+                    "expiry_time": None,
+                },
+            )
+        )
+        params = {
+            "username": username,
+            "password": "monkey",
+            "device_id": device_id,
+        }
+
+        # Request without auth to get flows and session
+        channel = self.make_request(b"POST", self.url, json.dumps(params))
+        self.assertEquals(channel.result["code"], b"401", channel.result)
+        flows = channel.json_body["flows"]
+        # Synapse adds a dummy stage to differentiate flows where otherwise one
+        # flow would be a subset of another flow.
+        self.assertCountEqual(
+            [[LoginType.REGISTRATION_TOKEN, LoginType.DUMMY]],
+            (f["stages"] for f in flows),
+        )
+        session = channel.json_body["session"]
+
+        # Do the registration token stage and check it has completed
+        params["auth"] = {
+            "type": LoginType.REGISTRATION_TOKEN,
+            "token": token,
+            "session": session,
+        }
+        request_data = json.dumps(params)
+        channel = self.make_request(b"POST", self.url, request_data)
+        self.assertEquals(channel.result["code"], b"401", channel.result)
+        completed = channel.json_body["completed"]
+        self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed)
+
+        # Do the m.login.dummy stage and check registration was successful
+        params["auth"] = {
+            "type": LoginType.DUMMY,
+            "session": session,
+        }
+        request_data = json.dumps(params)
+        channel = self.make_request(b"POST", self.url, request_data)
+        det_data = {
+            "user_id": f"@{username}:{self.hs.hostname}",
+            "home_server": self.hs.hostname,
+            "device_id": device_id,
+        }
+        self.assertEquals(channel.result["code"], b"200", channel.result)
+        self.assertDictContainsSubset(det_data, channel.json_body)
+
+        # Check the `completed` counter has been incremented and pending is 0
+        res = self.get_success(
+            store.db_pool.simple_select_one(
+                "registration_tokens",
+                keyvalues={"token": token},
+                retcols=["pending", "completed"],
+            )
+        )
+        self.assertEquals(res["completed"], 1)
+        self.assertEquals(res["pending"], 0)
+
+    @override_config({"registration_requires_token": True})
+    def test_POST_registration_token_invalid(self):
+        params = {
+            "username": "kermit",
+            "password": "monkey",
+        }
+        # Request without auth to get session
+        channel = self.make_request(b"POST", self.url, json.dumps(params))
+        session = channel.json_body["session"]
+
+        # Test with token param missing (invalid)
+        params["auth"] = {
+            "type": LoginType.REGISTRATION_TOKEN,
+            "session": session,
+        }
+        channel = self.make_request(b"POST", self.url, json.dumps(params))
+        self.assertEquals(channel.result["code"], b"401", channel.result)
+        self.assertEquals(channel.json_body["errcode"], Codes.MISSING_PARAM)
+        self.assertEquals(channel.json_body["completed"], [])
+
+        # Test with non-string (invalid)
+        params["auth"]["token"] = 1234
+        channel = self.make_request(b"POST", self.url, json.dumps(params))
+        self.assertEquals(channel.result["code"], b"401", channel.result)
+        self.assertEquals(channel.json_body["errcode"], Codes.INVALID_PARAM)
+        self.assertEquals(channel.json_body["completed"], [])
+
+        # Test with unknown token (invalid)
+        params["auth"]["token"] = "1234"
+        channel = self.make_request(b"POST", self.url, json.dumps(params))
+        self.assertEquals(channel.result["code"], b"401", channel.result)
+        self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
+        self.assertEquals(channel.json_body["completed"], [])
+
+    @override_config({"registration_requires_token": True})
+    def test_POST_registration_token_limit_uses(self):
+        token = "abcd"
+        store = self.hs.get_datastore()
+        # Create token that can be used once
+        self.get_success(
+            store.db_pool.simple_insert(
+                "registration_tokens",
+                {
+                    "token": token,
+                    "uses_allowed": 1,
+                    "pending": 0,
+                    "completed": 0,
+                    "expiry_time": None,
+                },
+            )
+        )
+        params1 = {"username": "bert", "password": "monkey"}
+        params2 = {"username": "ernie", "password": "monkey"}
+        # Do 2 requests without auth to get two session IDs
+        channel1 = self.make_request(b"POST", self.url, json.dumps(params1))
+        session1 = channel1.json_body["session"]
+        channel2 = self.make_request(b"POST", self.url, json.dumps(params2))
+        session2 = channel2.json_body["session"]
+
+        # Use token with session1 and check `pending` is 1
+        params1["auth"] = {
+            "type": LoginType.REGISTRATION_TOKEN,
+            "token": token,
+            "session": session1,
+        }
+        self.make_request(b"POST", self.url, json.dumps(params1))
+        # Repeat request to make sure pending isn't increased again
+        self.make_request(b"POST", self.url, json.dumps(params1))
+        pending = self.get_success(
+            store.db_pool.simple_select_one_onecol(
+                "registration_tokens",
+                keyvalues={"token": token},
+                retcol="pending",
+            )
+        )
+        self.assertEquals(pending, 1)
+
+        # Check auth fails when using token with session2
+        params2["auth"] = {
+            "type": LoginType.REGISTRATION_TOKEN,
+            "token": token,
+            "session": session2,
+        }
+        channel = self.make_request(b"POST", self.url, json.dumps(params2))
+        self.assertEquals(channel.result["code"], b"401", channel.result)
+        self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
+        self.assertEquals(channel.json_body["completed"], [])
+
+        # Complete registration with session1
+        params1["auth"]["type"] = LoginType.DUMMY
+        self.make_request(b"POST", self.url, json.dumps(params1))
+        # Check pending=0 and completed=1
+        res = self.get_success(
+            store.db_pool.simple_select_one(
+                "registration_tokens",
+                keyvalues={"token": token},
+                retcols=["pending", "completed"],
+            )
+        )
+        self.assertEquals(res["pending"], 0)
+        self.assertEquals(res["completed"], 1)
+
+        # Check auth still fails when using token with session2
+        channel = self.make_request(b"POST", self.url, json.dumps(params2))
+        self.assertEquals(channel.result["code"], b"401", channel.result)
+        self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
+        self.assertEquals(channel.json_body["completed"], [])
+
+    @override_config({"registration_requires_token": True})
+    def test_POST_registration_token_expiry(self):
+        token = "abcd"
+        now = self.hs.get_clock().time_msec()
+        store = self.hs.get_datastore()
+        # Create token that expired yesterday
+        self.get_success(
+            store.db_pool.simple_insert(
+                "registration_tokens",
+                {
+                    "token": token,
+                    "uses_allowed": None,
+                    "pending": 0,
+                    "completed": 0,
+                    "expiry_time": now - 24 * 60 * 60 * 1000,
+                },
+            )
+        )
+        params = {"username": "kermit", "password": "monkey"}
+        # Request without auth to get session
+        channel = self.make_request(b"POST", self.url, json.dumps(params))
+        session = channel.json_body["session"]
+
+        # Check authentication fails with expired token
+        params["auth"] = {
+            "type": LoginType.REGISTRATION_TOKEN,
+            "token": token,
+            "session": session,
+        }
+        channel = self.make_request(b"POST", self.url, json.dumps(params))
+        self.assertEquals(channel.result["code"], b"401", channel.result)
+        self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
+        self.assertEquals(channel.json_body["completed"], [])
+
+        # Update token so it expires tomorrow
+        self.get_success(
+            store.db_pool.simple_update_one(
+                "registration_tokens",
+                keyvalues={"token": token},
+                updatevalues={"expiry_time": now + 24 * 60 * 60 * 1000},
+            )
+        )
+
+        # Check authentication succeeds
+        channel = self.make_request(b"POST", self.url, json.dumps(params))
+        completed = channel.json_body["completed"]
+        self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed)
+
+    @override_config({"registration_requires_token": True})
+    def test_POST_registration_token_session_expiry(self):
+        """Test `pending` is decremented when an uncompleted session expires."""
+        token = "abcd"
+        store = self.hs.get_datastore()
+        self.get_success(
+            store.db_pool.simple_insert(
+                "registration_tokens",
+                {
+                    "token": token,
+                    "uses_allowed": None,
+                    "pending": 0,
+                    "completed": 0,
+                    "expiry_time": None,
+                },
+            )
+        )
+
+        # Do 2 requests without auth to get two session IDs
+        params1 = {"username": "bert", "password": "monkey"}
+        params2 = {"username": "ernie", "password": "monkey"}
+        channel1 = self.make_request(b"POST", self.url, json.dumps(params1))
+        session1 = channel1.json_body["session"]
+        channel2 = self.make_request(b"POST", self.url, json.dumps(params2))
+        session2 = channel2.json_body["session"]
+
+        # Use token with both sessions
+        params1["auth"] = {
+            "type": LoginType.REGISTRATION_TOKEN,
+            "token": token,
+            "session": session1,
+        }
+        self.make_request(b"POST", self.url, json.dumps(params1))
+
+        params2["auth"] = {
+            "type": LoginType.REGISTRATION_TOKEN,
+            "token": token,
+            "session": session2,
+        }
+        self.make_request(b"POST", self.url, json.dumps(params2))
+
+        # Complete registration with session1
+        params1["auth"]["type"] = LoginType.DUMMY
+        self.make_request(b"POST", self.url, json.dumps(params1))
+
+        # Check `result` of registration token stage for session1 is `True`
+        result1 = self.get_success(
+            store.db_pool.simple_select_one_onecol(
+                "ui_auth_sessions_credentials",
+                keyvalues={
+                    "session_id": session1,
+                    "stage_type": LoginType.REGISTRATION_TOKEN,
+                },
+                retcol="result",
+            )
+        )
+        self.assertTrue(db_to_json(result1))
+
+        # Check `result` for session2 is the token used
+        result2 = self.get_success(
+            store.db_pool.simple_select_one_onecol(
+                "ui_auth_sessions_credentials",
+                keyvalues={
+                    "session_id": session2,
+                    "stage_type": LoginType.REGISTRATION_TOKEN,
+                },
+                retcol="result",
+            )
+        )
+        self.assertEquals(db_to_json(result2), token)
+
+        # Delete both sessions (mimics expiry)
+        self.get_success(
+            store.delete_old_ui_auth_sessions(self.hs.get_clock().time_msec())
+        )
+
+        # Check pending is now 0
+        pending = self.get_success(
+            store.db_pool.simple_select_one_onecol(
+                "registration_tokens",
+                keyvalues={"token": token},
+                retcol="pending",
+            )
+        )
+        self.assertEquals(pending, 0)
+
+    @override_config({"registration_requires_token": True})
+    def test_POST_registration_token_session_expiry_deleted_token(self):
+        """Test session expiry doesn't break when the token is deleted.
+
+        1. Start but don't complete UIA with a registration token
+        2. Delete the token from the database
+        3. Expire the session
+        """
+        token = "abcd"
+        store = self.hs.get_datastore()
+        self.get_success(
+            store.db_pool.simple_insert(
+                "registration_tokens",
+                {
+                    "token": token,
+                    "uses_allowed": None,
+                    "pending": 0,
+                    "completed": 0,
+                    "expiry_time": None,
+                },
+            )
+        )
+
+        # Do request without auth to get a session ID
+        params = {"username": "kermit", "password": "monkey"}
+        channel = self.make_request(b"POST", self.url, json.dumps(params))
+        session = channel.json_body["session"]
+
+        # Use token
+        params["auth"] = {
+            "type": LoginType.REGISTRATION_TOKEN,
+            "token": token,
+            "session": session,
+        }
+        self.make_request(b"POST", self.url, json.dumps(params))
+
+        # Delete token
+        self.get_success(
+            store.db_pool.simple_delete_one(
+                "registration_tokens",
+                keyvalues={"token": token},
+            )
+        )
+
+        # Delete session (mimics expiry)
+        self.get_success(
+            store.delete_old_ui_auth_sessions(self.hs.get_clock().time_msec())
+        )
+
     def test_advertised_flows(self):
         channel = self.make_request(b"POST", self.url, b"{}")
         self.assertEquals(channel.result["code"], b"401", channel.result)
@@ -744,3 +1110,71 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
 
         self.assertGreaterEqual(res, now_ms + self.validity_period - self.max_delta)
         self.assertLessEqual(res, now_ms + self.validity_period)
+
+
+class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
+    servlets = [register.register_servlets]
+    url = "/_matrix/client/unstable/org.matrix.msc3231/register/org.matrix.msc3231.login.registration_token/validity"
+
+    def default_config(self):
+        config = super().default_config()
+        config["registration_requires_token"] = True
+        return config
+
+    def test_GET_token_valid(self):
+        token = "abcd"
+        store = self.hs.get_datastore()
+        self.get_success(
+            store.db_pool.simple_insert(
+                "registration_tokens",
+                {
+                    "token": token,
+                    "uses_allowed": None,
+                    "pending": 0,
+                    "completed": 0,
+                    "expiry_time": None,
+                },
+            )
+        )
+
+        channel = self.make_request(
+            b"GET",
+            f"{self.url}?token={token}",
+        )
+        self.assertEquals(channel.result["code"], b"200", channel.result)
+        self.assertEquals(channel.json_body["valid"], True)
+
+    def test_GET_token_invalid(self):
+        token = "1234"
+        channel = self.make_request(
+            b"GET",
+            f"{self.url}?token={token}",
+        )
+        self.assertEquals(channel.result["code"], b"200", channel.result)
+        self.assertEquals(channel.json_body["valid"], False)
+
+    @override_config(
+        {"rc_registration_token_validity": {"per_second": 0.1, "burst_count": 5}}
+    )
+    def test_GET_ratelimiting(self):
+        token = "1234"
+
+        for i in range(0, 6):
+            channel = self.make_request(
+                b"GET",
+                f"{self.url}?token={token}",
+            )
+
+            if i == 5:
+                self.assertEquals(channel.result["code"], b"429", channel.result)
+                retry_after_ms = int(channel.json_body["retry_after_ms"])
+            else:
+                self.assertEquals(channel.result["code"], b"200", channel.result)
+
+        self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
+
+        channel = self.make_request(
+            b"GET",
+            f"{self.url}?token={token}",
+        )
+        self.assertEquals(channel.result["code"], b"200", channel.result)