summary refs log tree commit diff
path: root/tests/rest
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2023-02-17 13:19:38 -0500
committerGitHub <noreply@github.com>2023-02-17 18:19:38 +0000
commitc9b9143655d39ef691bd21ebc03434feb0dc377f (patch)
tree7337697cd6d0fcecace3a7344960e28430635e63 /tests/rest
parentAdd account data to export command (#14969) (diff)
downloadsynapse-c9b9143655d39ef691bd21ebc03434feb0dc377f.tar.xz
Fix-up type hints in tests/server.py. (#15084)
This file was being ignored by mypy, we remove that
and add the missing type hints & deal with any fallout.
Diffstat (limited to 'tests/rest')
-rw-r--r--tests/rest/client/test_auth.py14
-rw-r--r--tests/rest/client/utils.py58
2 files changed, 41 insertions, 31 deletions
diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py
index f4e1e7de43..a144610078 100644
--- a/tests/rest/client/test_auth.py
+++ b/tests/rest/client/test_auth.py
@@ -34,7 +34,7 @@ from synapse.util import Clock
 from tests import unittest
 from tests.handlers.test_oidc import HAS_OIDC
 from tests.rest.client.utils import TEST_OIDC_CONFIG, TEST_OIDC_ISSUER
-from tests.server import FakeChannel, make_request
+from tests.server import FakeChannel
 from tests.unittest import override_config, skip_unless
 
 
@@ -1322,16 +1322,8 @@ class OidcBackchannelLogoutTests(unittest.HomeserverTestCase):
         channel = self.submit_logout_token(logout_token)
         self.assertEqual(channel.code, 200)
 
-        # Now try to exchange the login token
-        channel = make_request(
-            self.hs.get_reactor(),
-            self.site,
-            "POST",
-            "/login",
-            content={"type": "m.login.token", "token": login_token},
-        )
-        # It should have failed
-        self.assertEqual(channel.code, 403)
+        # Now try to exchange the login token, it should fail.
+        self.helper.login_via_token(login_token, 403)
 
     @override_config(
         {
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index 8d6f2b6ff9..9532e5ddc1 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -36,6 +36,7 @@ from urllib.parse import urlencode
 import attr
 from typing_extensions import Literal
 
+from twisted.test.proto_helpers import MemoryReactorClock
 from twisted.web.resource import Resource
 from twisted.web.server import Site
 
@@ -67,6 +68,7 @@ class RestHelper:
     """
 
     hs: HomeServer
+    reactor: MemoryReactorClock
     site: Site
     auth_user_id: Optional[str]
 
@@ -142,7 +144,7 @@ class RestHelper:
             path = path + "?access_token=%s" % tok
 
         channel = make_request(
-            self.hs.get_reactor(),
+            self.reactor,
             self.site,
             "POST",
             path,
@@ -216,7 +218,7 @@ class RestHelper:
             data["reason"] = reason
 
         channel = make_request(
-            self.hs.get_reactor(),
+            self.reactor,
             self.site,
             "POST",
             path,
@@ -313,7 +315,7 @@ class RestHelper:
         data.update(extra_data or {})
 
         channel = make_request(
-            self.hs.get_reactor(),
+            self.reactor,
             self.site,
             "PUT",
             path,
@@ -394,7 +396,7 @@ class RestHelper:
             path = path + "?access_token=%s" % tok
 
         channel = make_request(
-            self.hs.get_reactor(),
+            self.reactor,
             self.site,
             "PUT",
             path,
@@ -433,7 +435,7 @@ class RestHelper:
             path = path + f"?access_token={tok}"
 
         channel = make_request(
-            self.hs.get_reactor(),
+            self.reactor,
             self.site,
             "GET",
             path,
@@ -488,7 +490,7 @@ class RestHelper:
         if body is not None:
             content = json.dumps(body).encode("utf8")
 
-        channel = make_request(self.hs.get_reactor(), self.site, method, path, content)
+        channel = make_request(self.reactor, self.site, method, path, content)
 
         assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
             expect_code,
@@ -573,8 +575,8 @@ class RestHelper:
         image_length = len(image_data)
         path = "/_matrix/media/r0/upload?filename=%s" % (filename,)
         channel = make_request(
-            self.hs.get_reactor(),
-            FakeSite(resource, self.hs.get_reactor()),
+            self.reactor,
+            FakeSite(resource, self.reactor),
             "POST",
             path,
             content=image_data,
@@ -603,7 +605,7 @@ class RestHelper:
             expect_code: The return code to expect from attempting the whoami request
         """
         channel = make_request(
-            self.hs.get_reactor(),
+            self.reactor,
             self.site,
             "GET",
             "account/whoami",
@@ -642,7 +644,7 @@ class RestHelper:
     ) -> Tuple[JsonDict, FakeAuthorizationGrant]:
         """Log in (as a new user) via OIDC
 
-        Returns the result of the final token login.
+        Returns the result of the final token login and the fake authorization grant.
 
         Requires that "oidc_config" in the homeserver config be set appropriately
         (TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a
@@ -672,10 +674,28 @@ class RestHelper:
         assert m, channel.text_body
         login_token = m.group(1)
 
-        # finally, submit the matrix login token to the login API, which gives us our
-        # matrix access token and device id.
+        return self.login_via_token(login_token, expected_status), grant
+
+    def login_via_token(
+        self,
+        login_token: str,
+        expected_status: int = 200,
+    ) -> JsonDict:
+        """Submit the matrix login token to the login API, which gives us our
+        matrix access token and device id.Log in (as a new user) via OIDC
+
+        Returns the result of the token login.
+
+        Requires that "oidc_config" in the homeserver config be set appropriately
+        (TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a
+        "public_base_url".
+
+        Also requires the login servlet and the OIDC callback resource to be mounted at
+        the normal places.
+        """
+
         channel = make_request(
-            self.hs.get_reactor(),
+            self.reactor,
             self.site,
             "POST",
             "/login",
@@ -684,7 +704,7 @@ class RestHelper:
         assert (
             channel.code == expected_status
         ), f"unexpected status in response: {channel.code}"
-        return channel.json_body, grant
+        return channel.json_body
 
     def auth_via_oidc(
         self,
@@ -805,7 +825,7 @@ class RestHelper:
         with fake_serer.patch_homeserver(hs=self.hs):
             # now hit the callback URI with the right params and a made-up code
             channel = make_request(
-                self.hs.get_reactor(),
+                self.reactor,
                 self.site,
                 "GET",
                 callback_uri,
@@ -849,7 +869,7 @@ class RestHelper:
         # is the easiest way of figuring out what the Host header ought to be set to
         # to keep Synapse happy.
         channel = make_request(
-            self.hs.get_reactor(),
+            self.reactor,
             self.site,
             "GET",
             uri,
@@ -867,7 +887,7 @@ class RestHelper:
         location = get_location(channel)
         parts = urllib.parse.urlsplit(location)
         channel = make_request(
-            self.hs.get_reactor(),
+            self.reactor,
             self.site,
             "GET",
             urllib.parse.urlunsplit(("", "") + parts[2:]),
@@ -900,9 +920,7 @@ class RestHelper:
             + urllib.parse.urlencode({"session": ui_auth_session_id})
         )
         # hit the redirect url (which will issue a cookie and state)
-        channel = make_request(
-            self.hs.get_reactor(), self.site, "GET", sso_redirect_endpoint
-        )
+        channel = make_request(self.reactor, self.site, "GET", sso_redirect_endpoint)
         # that should serve a confirmation page
         assert channel.code == HTTPStatus.OK, channel.text_body
         channel.extract_cookies(cookies)