summary refs log tree commit diff
path: root/tests/rest/client/test_login.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest/client/test_login.py')
-rw-r--r--tests/rest/client/test_login.py55
1 files changed, 28 insertions, 27 deletions
diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py
index f6efa5fe37..e7f5517e34 100644
--- a/tests/rest/client/test_login.py
+++ b/tests/rest/client/test_login.py
@@ -14,6 +14,7 @@
 import json
 import time
 import urllib.parse
+from http import HTTPStatus
 from typing import Any, Dict, List, Optional
 from unittest.mock import Mock
 from urllib.parse import urlencode
@@ -261,20 +262,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
         }
         channel = self.make_request(b"POST", LOGIN_URL, params)
 
-        self.assertEqual(channel.code, 200, channel.result)
+        self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
         access_token = channel.json_body["access_token"]
         device_id = channel.json_body["device_id"]
 
         # we should now be able to make requests with the access token
         channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
-        self.assertEqual(channel.code, 200, channel.result)
+        self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
 
         # time passes
         self.reactor.advance(24 * 3600)
 
         # ... and we should be soft-logouted
         channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
-        self.assertEqual(channel.code, 401, channel.result)
+        self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
         self.assertEqual(channel.json_body["soft_logout"], True)
 
@@ -288,7 +289,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
         # more requests with the expired token should still return a soft-logout
         self.reactor.advance(3600)
         channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
-        self.assertEqual(channel.code, 401, channel.result)
+        self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
         self.assertEqual(channel.json_body["soft_logout"], True)
 
@@ -296,7 +297,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
         self._delete_device(access_token_2, "kermit", "monkey", device_id)
 
         channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
-        self.assertEqual(channel.code, 401, channel.result)
+        self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
         self.assertEqual(channel.json_body["soft_logout"], False)
 
@@ -307,7 +308,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
         channel = self.make_request(
             b"DELETE", "devices/" + device_id, access_token=access_token
         )
-        self.assertEqual(channel.code, 401, channel.result)
+        self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
         # check it's a UI-Auth fail
         self.assertEqual(
             set(channel.json_body.keys()),
@@ -330,7 +331,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
             access_token=access_token,
             content={"auth": auth},
         )
-        self.assertEqual(channel.code, 200, channel.result)
+        self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
 
     @override_config({"session_lifetime": "24h"})
     def test_session_can_hard_logout_after_being_soft_logged_out(self) -> None:
@@ -341,14 +342,14 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
 
         # we should now be able to make requests with the access token
         channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
-        self.assertEqual(channel.code, 200, channel.result)
+        self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
 
         # time passes
         self.reactor.advance(24 * 3600)
 
         # ... and we should be soft-logouted
         channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
-        self.assertEqual(channel.code, 401, channel.result)
+        self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
         self.assertEqual(channel.json_body["soft_logout"], True)
 
@@ -367,14 +368,14 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
 
         # we should now be able to make requests with the access token
         channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
-        self.assertEqual(channel.code, 200, channel.result)
+        self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
 
         # time passes
         self.reactor.advance(24 * 3600)
 
         # ... and we should be soft-logouted
         channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
-        self.assertEqual(channel.code, 401, channel.result)
+        self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
         self.assertEqual(channel.json_body["soft_logout"], True)
 
@@ -466,7 +467,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
     def test_get_login_flows(self) -> None:
         """GET /login should return password and SSO flows"""
         channel = self.make_request("GET", "/_matrix/client/r0/login")
-        self.assertEqual(channel.code, 200, channel.result)
+        self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
 
         expected_flow_types = [
             "m.login.cas",
@@ -494,14 +495,14 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
         """/login/sso/redirect should redirect to an identity picker"""
         # first hit the redirect url, which should redirect to our idp picker
         channel = self._make_sso_redirect_request(None)
-        self.assertEqual(channel.code, 302, channel.result)
+        self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
         location_headers = channel.headers.getRawHeaders("Location")
         assert location_headers
         uri = location_headers[0]
 
         # hitting that picker should give us some HTML
         channel = self.make_request("GET", uri)
-        self.assertEqual(channel.code, 200, channel.result)
+        self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
 
         # parse the form to check it has fields assumed elsewhere in this class
         html = channel.result["body"].decode("utf-8")
@@ -530,7 +531,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
             + "&idp=cas",
             shorthand=False,
         )
-        self.assertEqual(channel.code, 302, channel.result)
+        self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
         location_headers = channel.headers.getRawHeaders("Location")
         assert location_headers
         cas_uri = location_headers[0]
@@ -555,7 +556,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
             + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
             + "&idp=saml",
         )
-        self.assertEqual(channel.code, 302, channel.result)
+        self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
         location_headers = channel.headers.getRawHeaders("Location")
         assert location_headers
         saml_uri = location_headers[0]
@@ -579,7 +580,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
             + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
             + "&idp=oidc",
         )
-        self.assertEqual(channel.code, 302, channel.result)
+        self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
         location_headers = channel.headers.getRawHeaders("Location")
         assert location_headers
         oidc_uri = location_headers[0]
@@ -606,7 +607,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
         channel = self.helper.complete_oidc_auth(oidc_uri, cookies, {"sub": "user1"})
 
         # that should serve a confirmation page
-        self.assertEqual(channel.code, 200, channel.result)
+        self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
         content_type_headers = channel.headers.getRawHeaders("Content-Type")
         assert content_type_headers
         self.assertTrue(content_type_headers[-1].startswith("text/html"))
@@ -634,7 +635,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
             "/login",
             content={"type": "m.login.token", "token": login_token},
         )
-        self.assertEqual(chan.code, 200, chan.result)
+        self.assertEqual(chan.code, HTTPStatus.OK, chan.result)
         self.assertEqual(chan.json_body["user_id"], "@user1:test")
 
     def test_multi_sso_redirect_to_unknown(self) -> None:
@@ -643,18 +644,18 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
             "GET",
             "/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz",
         )
-        self.assertEqual(channel.code, 400, channel.result)
+        self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
 
     def test_client_idp_redirect_to_unknown(self) -> None:
         """If the client tries to pick an unknown IdP, return a 404"""
         channel = self._make_sso_redirect_request("xxx")
-        self.assertEqual(channel.code, 404, channel.result)
+        self.assertEqual(channel.code, HTTPStatus.NOT_FOUND, channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")
 
     def test_client_idp_redirect_to_oidc(self) -> None:
         """If the client pick a known IdP, redirect to it"""
         channel = self._make_sso_redirect_request("oidc")
-        self.assertEqual(channel.code, 302, channel.result)
+        self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
         location_headers = channel.headers.getRawHeaders("Location")
         assert location_headers
         oidc_uri = location_headers[0]
@@ -765,7 +766,7 @@ class CASTestCase(unittest.HomeserverTestCase):
         channel = self.make_request("GET", cas_ticket_url)
 
         # Test that the response is HTML.
-        self.assertEqual(channel.code, 200, channel.result)
+        self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
         content_type_header_value = ""
         for header in channel.result.get("headers", []):
             if header[0] == b"Content-Type":
@@ -1246,7 +1247,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
         )
 
         # that should redirect to the username picker
-        self.assertEqual(channel.code, 302, channel.result)
+        self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
         location_headers = channel.headers.getRawHeaders("Location")
         assert location_headers
         picker_url = location_headers[0]
@@ -1290,7 +1291,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
                 ("Content-Length", str(len(content))),
             ],
         )
-        self.assertEqual(chan.code, 302, chan.result)
+        self.assertEqual(chan.code, HTTPStatus.FOUND, chan.result)
         location_headers = chan.headers.getRawHeaders("Location")
         assert location_headers
 
@@ -1300,7 +1301,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
             path=location_headers[0],
             custom_headers=[("Cookie", "username_mapping_session=" + session_id)],
         )
-        self.assertEqual(chan.code, 302, chan.result)
+        self.assertEqual(chan.code, HTTPStatus.FOUND, chan.result)
         location_headers = chan.headers.getRawHeaders("Location")
         assert location_headers
 
@@ -1325,5 +1326,5 @@ class UsernamePickerTestCase(HomeserverTestCase):
             "/login",
             content={"type": "m.login.token", "token": login_token},
         )
-        self.assertEqual(chan.code, 200, chan.result)
+        self.assertEqual(chan.code, HTTPStatus.OK, chan.result)
         self.assertEqual(chan.json_body["user_id"], "@bobby:test")