summary refs log tree commit diff
path: root/tests/rest/client/v1
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest/client/v1')
-rw-r--r--tests/rest/client/v1/test_login.py124
-rw-r--r--tests/rest/client/v1/test_rooms.py35
2 files changed, 144 insertions, 15 deletions
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index 2d25490374..66dfdaffbc 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -29,8 +29,7 @@ from synapse.appservice import ApplicationService
 from synapse.rest.client.v1 import login, logout
 from synapse.rest.client.v2_alpha import devices, register
 from synapse.rest.client.v2_alpha.account import WhoamiRestServlet
-from synapse.rest.synapse.client.pick_idp import PickIdpResource
-from synapse.rest.synapse.client.pick_username import pick_username_resource
+from synapse.rest.synapse.client import build_synapse_client_resource_tree
 from synapse.types import create_requester
 
 from tests import unittest
@@ -75,6 +74,10 @@ TEST_CLIENT_REDIRECT_URL = 'https://x?<ab c>&q"+%3D%2B"="fö%26=o"'
 # the query params in TEST_CLIENT_REDIRECT_URL
 EXPECTED_CLIENT_REDIRECT_URL_PARAMS = [("<ab c>", ""), ('q" =+"', '"fö&=o"')]
 
+# (possibly experimental) login flows we expect to appear in the list after the normal
+# ones
+ADDITIONAL_LOGIN_FLOWS = [{"type": "uk.half-shot.msc2778.login.application_service"}]
+
 
 class LoginRestServletTestCase(unittest.HomeserverTestCase):
 
@@ -419,13 +422,61 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
         return config
 
     def create_resource_dict(self) -> Dict[str, Resource]:
-        from synapse.rest.oidc import OIDCResource
-
         d = super().create_resource_dict()
-        d["/_synapse/client/pick_idp"] = PickIdpResource(self.hs)
-        d["/_synapse/oidc"] = OIDCResource(self.hs)
+        d.update(build_synapse_client_resource_tree(self.hs))
         return d
 
+    def test_get_login_flows(self):
+        """GET /login should return password and SSO flows"""
+        channel = self.make_request("GET", "/_matrix/client/r0/login")
+        self.assertEqual(channel.code, 200, channel.result)
+
+        expected_flows = [
+            {"type": "m.login.cas"},
+            {"type": "m.login.sso"},
+            {"type": "m.login.token"},
+            {"type": "m.login.password"},
+        ] + ADDITIONAL_LOGIN_FLOWS
+
+        self.assertCountEqual(channel.json_body["flows"], expected_flows)
+
+    @override_config({"experimental_features": {"msc2858_enabled": True}})
+    def test_get_msc2858_login_flows(self):
+        """The SSO flow should include IdP info if MSC2858 is enabled"""
+        channel = self.make_request("GET", "/_matrix/client/r0/login")
+        self.assertEqual(channel.code, 200, channel.result)
+
+        # stick the flows results in a dict by type
+        flow_results = {}  # type: Dict[str, Any]
+        for f in channel.json_body["flows"]:
+            flow_type = f["type"]
+            self.assertNotIn(
+                flow_type, flow_results, "duplicate flow type %s" % (flow_type,)
+            )
+            flow_results[flow_type] = f
+
+        self.assertIn("m.login.sso", flow_results, "m.login.sso was not returned")
+        sso_flow = flow_results.pop("m.login.sso")
+        # we should have a set of IdPs
+        self.assertCountEqual(
+            sso_flow["org.matrix.msc2858.identity_providers"],
+            [
+                {"id": "cas", "name": "CAS"},
+                {"id": "saml", "name": "SAML"},
+                {"id": "oidc-idp1", "name": "IDP1"},
+                {"id": "oidc", "name": "OIDC"},
+            ],
+        )
+
+        # the rest of the flows are simple
+        expected_flows = [
+            {"type": "m.login.cas"},
+            {"type": "m.login.token"},
+            {"type": "m.login.password"},
+        ] + ADDITIONAL_LOGIN_FLOWS
+
+        self.assertCountEqual(flow_results.values(), expected_flows)
+
     def test_multi_sso_redirect(self):
         """/login/sso/redirect should redirect to an identity picker"""
         # first hit the redirect url, which should redirect to our idp picker
@@ -446,7 +497,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
         p.feed(channel.result["body"].decode("utf-8"))
         p.close()
 
-        self.assertCountEqual(p.radios["idp"], ["cas", "oidc", "idp1", "saml"])
+        self.assertCountEqual(p.radios["idp"], ["cas", "oidc", "oidc-idp1", "saml"])
 
         self.assertEqual(p.hiddens["redirectUrl"], TEST_CLIENT_REDIRECT_URL)
 
@@ -564,6 +615,43 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
         )
         self.assertEqual(channel.code, 400, channel.result)
 
+    def test_client_idp_redirect_msc2858_disabled(self):
+        """If the client tries to pick an IdP but MSC2858 is disabled, return a 400"""
+        channel = self.make_request(
+            "GET",
+            "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/oidc?redirectUrl="
+            + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL),
+        )
+        self.assertEqual(channel.code, 400, channel.result)
+        self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
+
+    @override_config({"experimental_features": {"msc2858_enabled": True}})
+    def test_client_idp_redirect_to_unknown(self):
+        """If the client tries to pick an unknown IdP, return a 404"""
+        channel = self.make_request(
+            "GET",
+            "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/xxx?redirectUrl="
+            + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL),
+        )
+        self.assertEqual(channel.code, 404, channel.result)
+        self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")
+
+    @override_config({"experimental_features": {"msc2858_enabled": True}})
+    def test_client_idp_redirect_to_oidc(self):
+        """If the client pick a known IdP, redirect to it"""
+        channel = self.make_request(
+            "GET",
+            "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/oidc?redirectUrl="
+            + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL),
+        )
+
+        self.assertEqual(channel.code, 302, channel.result)
+        oidc_uri = channel.headers.getRawHeaders("Location")[0]
+        oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)
+
+        # it should redirect us to the auth page of the OIDC server
+        self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)
+
     @staticmethod
     def _get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str:
         prefix = key + " = "
@@ -1119,11 +1207,8 @@ class UsernamePickerTestCase(HomeserverTestCase):
         return config
 
     def create_resource_dict(self) -> Dict[str, Resource]:
-        from synapse.rest.oidc import OIDCResource
-
         d = super().create_resource_dict()
-        d["/_synapse/client/pick_username"] = pick_username_resource(self.hs)
-        d["/_synapse/oidc"] = OIDCResource(self.hs)
+        d.update(build_synapse_client_resource_tree(self.hs))
         return d
 
     def test_username_picker(self):
@@ -1137,7 +1222,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
         # that should redirect to the username picker
         self.assertEqual(channel.code, 302, channel.result)
         picker_url = channel.headers.getRawHeaders("Location")[0]
-        self.assertEqual(picker_url, "/_synapse/client/pick_username")
+        self.assertEqual(picker_url, "/_synapse/client/pick_username/account_details")
 
         # ... with a username_mapping_session cookie
         cookies = {}  # type: Dict[str,str]
@@ -1161,12 +1246,11 @@ class UsernamePickerTestCase(HomeserverTestCase):
         self.assertApproximates(session.expiry_time_ms, expected_expiry, tolerance=1000)
 
         # Now, submit a username to the username picker, which should serve a redirect
-        # back to the client
-        submit_path = picker_url + "/submit"
+        # to the completion page
         content = urlencode({b"username": b"bobby"}).encode("utf8")
         chan = self.make_request(
             "POST",
-            path=submit_path,
+            path=picker_url,
             content=content,
             content_is_form=True,
             custom_headers=[
@@ -1178,6 +1262,16 @@ class UsernamePickerTestCase(HomeserverTestCase):
         )
         self.assertEqual(chan.code, 302, chan.result)
         location_headers = chan.headers.getRawHeaders("Location")
+
+        # send a request to the completion page, which should 302 to the client redirectUrl
+        chan = self.make_request(
+            "GET",
+            path=location_headers[0],
+            custom_headers=[("Cookie", "username_mapping_session=" + session_id)],
+        )
+        self.assertEqual(chan.code, 302, chan.result)
+        location_headers = chan.headers.getRawHeaders("Location")
+
         # ensure that the returned location matches the requested redirect URL
         path, query = location_headers[0].split("?", 1)
         self.assertEqual(path, "https://x")
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index d4e3165436..2548b3a80c 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -616,6 +616,41 @@ class RoomMemberStateTestCase(RoomBase):
         self.assertEquals(json.loads(content), channel.json_body)
 
 
+class RoomInviteRatelimitTestCase(RoomBase):
+    user_id = "@sid1:red"
+
+    servlets = [
+        admin.register_servlets,
+        profile.register_servlets,
+        room.register_servlets,
+    ]
+
+    @unittest.override_config(
+        {"rc_invites": {"per_room": {"per_second": 0.5, "burst_count": 3}}}
+    )
+    def test_invites_by_rooms_ratelimit(self):
+        """Tests that invites in a room are actually rate-limited."""
+        room_id = self.helper.create_room_as(self.user_id)
+
+        for i in range(3):
+            self.helper.invite(room_id, self.user_id, "@user-%s:red" % (i,))
+
+        self.helper.invite(room_id, self.user_id, "@user-4:red", expect_code=429)
+
+    @unittest.override_config(
+        {"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}}
+    )
+    def test_invites_by_users_ratelimit(self):
+        """Tests that invites to a specific user are actually rate-limited."""
+
+        for i in range(3):
+            room_id = self.helper.create_room_as(self.user_id)
+            self.helper.invite(room_id, self.user_id, "@other-users:red")
+
+        room_id = self.helper.create_room_as(self.user_id)
+        self.helper.invite(room_id, self.user_id, "@other-users:red", expect_code=429)
+
+
 class RoomJoinRatelimitTestCase(RoomBase):
     user_id = "@sid1:red"