diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 11681d030b..0f15f0ec57 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -939,7 +939,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
"""
channel = self.make_request("POST", self.url, b"{}")
- self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_not_admin(self):
@@ -950,7 +950,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
channel = self.make_request("POST", url, access_token=self.other_user_token)
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual("You are not a server admin", channel.json_body["error"])
channel = self.make_request(
@@ -960,7 +960,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
content=b"{}",
)
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual("You are not a server admin", channel.json_body["error"])
def test_user_does_not_exist(self):
@@ -990,7 +990,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
def test_user_is_not_local(self):
@@ -1006,7 +1006,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
def test_deactivate_user_erase_true(self):
"""
- Test deactivating an user and set `erase` to `true`
+ Test deactivating a user and set `erase` to `true`
"""
# Get user
@@ -1016,24 +1016,22 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(False, channel.json_body["deactivated"])
self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
self.assertEqual("User1", channel.json_body["displayname"])
- # Deactivate user
- body = json.dumps({"erase": True})
-
+ # Deactivate and erase user
channel = self.make_request(
"POST",
self.url,
access_token=self.admin_user_tok,
- content=body.encode(encoding="utf_8"),
+ content={"erase": True},
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
# Get user
channel = self.make_request(
@@ -1042,7 +1040,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(True, channel.json_body["deactivated"])
self.assertEqual(0, len(channel.json_body["threepids"]))
@@ -1053,7 +1051,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
def test_deactivate_user_erase_false(self):
"""
- Test deactivating an user and set `erase` to `false`
+ Test deactivating a user and set `erase` to `false`
"""
# Get user
@@ -1063,7 +1061,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(False, channel.json_body["deactivated"])
self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
@@ -1071,13 +1069,11 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
self.assertEqual("User1", channel.json_body["displayname"])
# Deactivate user
- body = json.dumps({"erase": False})
-
channel = self.make_request(
"POST",
self.url,
access_token=self.admin_user_tok,
- content=body.encode(encoding="utf_8"),
+ content={"erase": False},
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -1089,7 +1085,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(True, channel.json_body["deactivated"])
self.assertEqual(0, len(channel.json_body["threepids"]))
@@ -1103,6 +1099,60 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
self._is_erased("@user:test", False)
+ def test_deactivate_user_erase_true_no_profile(self):
+ """
+ Test deactivating a user and set `erase` to `true`
+ if user has no profile information (stored in the database table `profiles`).
+ """
+
+ # Users normally have an entry in `profiles`, but occasionally they are created without one.
+ # To test deactivation for users without a profile, we delete the profile information for our user.
+ self.get_success(
+ self.store.db_pool.simple_delete_one(
+ table="profiles", keyvalues={"user_id": "user"}
+ )
+ )
+
+ # Get user
+ channel = self.make_request(
+ "GET",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertEqual(False, channel.json_body["deactivated"])
+ self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
+ self.assertIsNone(channel.json_body["avatar_url"])
+ self.assertIsNone(channel.json_body["displayname"])
+
+ # Deactivate and erase user
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={"erase": True},
+ )
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ # Get user
+ channel = self.make_request(
+ "GET",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertEqual(True, channel.json_body["deactivated"])
+ self.assertEqual(0, len(channel.json_body["threepids"]))
+ self.assertIsNone(channel.json_body["avatar_url"])
+ self.assertIsNone(channel.json_body["displayname"])
+
+ self._is_erased("@user:test", True)
+
def _is_erased(self, user_id: str, expect: bool) -> None:
"""Assert that the user is erased or not"""
d = self.store.is_user_erased(user_id)
@@ -1155,7 +1205,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.other_user_token,
)
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual("You are not a server admin", channel.json_body["error"])
channel = self.make_request(
@@ -1165,7 +1215,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content=b"{}",
)
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual("You are not a server admin", channel.json_body["error"])
def test_user_does_not_exist(self):
@@ -1182,6 +1232,58 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual("M_NOT_FOUND", channel.json_body["errcode"])
+ def test_get_user(self):
+ """
+ Test a simple get of a user.
+ """
+ channel = self.make_request(
+ "GET",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertEqual("User", channel.json_body["displayname"])
+ self._check_fields(channel.json_body)
+
+ def test_get_user_with_sso(self):
+ """
+ Test get a user with SSO details.
+ """
+ self.get_success(
+ self.store.record_user_external_id(
+ "auth_provider1", "external_id1", self.other_user
+ )
+ )
+ self.get_success(
+ self.store.record_user_external_id(
+ "auth_provider2", "external_id2", self.other_user
+ )
+ )
+
+ channel = self.make_request(
+ "GET",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertEqual(
+ "external_id1", channel.json_body["external_ids"][0]["external_id"]
+ )
+ self.assertEqual(
+ "auth_provider1", channel.json_body["external_ids"][0]["auth_provider"]
+ )
+ self.assertEqual(
+ "external_id2", channel.json_body["external_ids"][1]["external_id"]
+ )
+ self.assertEqual(
+ "auth_provider2", channel.json_body["external_ids"][1]["auth_provider"]
+ )
+ self._check_fields(channel.json_body)
+
def test_create_server_admin(self):
"""
Check that a new admin user is created successfully.
@@ -1189,30 +1291,29 @@ class UserRestTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v2/users/@bob:test"
# Create user (server admin)
- body = json.dumps(
- {
- "password": "abc123",
- "admin": True,
- "displayname": "Bob's name",
- "threepids": [{"medium": "email", "address": "bob@bob.bob"}],
- "avatar_url": "mxc://fibble/wibble",
- }
- )
+ body = {
+ "password": "abc123",
+ "admin": True,
+ "displayname": "Bob's name",
+ "threepids": [{"medium": "email", "address": "bob@bob.bob"}],
+ "avatar_url": "mxc://fibble/wibble",
+ }
channel = self.make_request(
"PUT",
url,
access_token=self.admin_user_tok,
- content=body.encode(encoding="utf_8"),
+ content=body,
)
- self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
self.assertTrue(channel.json_body["admin"])
self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
+ self._check_fields(channel.json_body)
# Get user
channel = self.make_request(
@@ -1221,7 +1322,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
@@ -1230,6 +1331,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertFalse(channel.json_body["is_guest"])
self.assertFalse(channel.json_body["deactivated"])
self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
+ self._check_fields(channel.json_body)
def test_create_user(self):
"""
@@ -1238,30 +1340,29 @@ class UserRestTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v2/users/@bob:test"
# Create user
- body = json.dumps(
- {
- "password": "abc123",
- "admin": False,
- "displayname": "Bob's name",
- "threepids": [{"medium": "email", "address": "bob@bob.bob"}],
- "avatar_url": "mxc://fibble/wibble",
- }
- )
+ body = {
+ "password": "abc123",
+ "admin": False,
+ "displayname": "Bob's name",
+ "threepids": [{"medium": "email", "address": "bob@bob.bob"}],
+ "avatar_url": "mxc://fibble/wibble",
+ }
channel = self.make_request(
"PUT",
url,
access_token=self.admin_user_tok,
- content=body.encode(encoding="utf_8"),
+ content=body,
)
- self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
self.assertFalse(channel.json_body["admin"])
self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
+ self._check_fields(channel.json_body)
# Get user
channel = self.make_request(
@@ -1270,7 +1371,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
@@ -1280,6 +1381,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertFalse(channel.json_body["deactivated"])
self.assertFalse(channel.json_body["shadow_banned"])
self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
+ self._check_fields(channel.json_body)
@override_config(
{"limit_usage_by_mau": True, "max_mau_value": 2, "mau_trial_days": 0}
@@ -1316,16 +1418,14 @@ class UserRestTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v2/users/@bob:test"
# Create user
- body = json.dumps({"password": "abc123", "admin": False})
-
channel = self.make_request(
"PUT",
url,
access_token=self.admin_user_tok,
- content=body.encode(encoding="utf_8"),
+ content={"password": "abc123", "admin": False},
)
- self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertFalse(channel.json_body["admin"])
@@ -1355,17 +1455,15 @@ class UserRestTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v2/users/@bob:test"
# Create user
- body = json.dumps({"password": "abc123", "admin": False})
-
channel = self.make_request(
"PUT",
url,
access_token=self.admin_user_tok,
- content=body.encode(encoding="utf_8"),
+ content={"password": "abc123", "admin": False},
)
# Admin user is not blocked by mau anymore
- self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertFalse(channel.json_body["admin"])
@@ -1387,21 +1485,19 @@ class UserRestTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v2/users/@bob:test"
# Create user
- body = json.dumps(
- {
- "password": "abc123",
- "threepids": [{"medium": "email", "address": "bob@bob.bob"}],
- }
- )
+ body = {
+ "password": "abc123",
+ "threepids": [{"medium": "email", "address": "bob@bob.bob"}],
+ }
channel = self.make_request(
"PUT",
url,
access_token=self.admin_user_tok,
- content=body.encode(encoding="utf_8"),
+ content=body,
)
- self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
@@ -1431,21 +1527,19 @@ class UserRestTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v2/users/@bob:test"
# Create user
- body = json.dumps(
- {
- "password": "abc123",
- "threepids": [{"medium": "email", "address": "bob@bob.bob"}],
- }
- )
+ body = {
+ "password": "abc123",
+ "threepids": [{"medium": "email", "address": "bob@bob.bob"}],
+ }
channel = self.make_request(
"PUT",
url,
access_token=self.admin_user_tok,
- content=body.encode(encoding="utf_8"),
+ content=body,
)
- self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
@@ -1462,16 +1556,15 @@ class UserRestTestCase(unittest.HomeserverTestCase):
"""
# Change password
- body = json.dumps({"password": "hahaha"})
-
channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
- content=body.encode(encoding="utf_8"),
+ content={"password": "hahaha"},
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self._check_fields(channel.json_body)
def test_set_displayname(self):
"""
@@ -1479,16 +1572,14 @@ class UserRestTestCase(unittest.HomeserverTestCase):
"""
# Modify user
- body = json.dumps({"displayname": "foobar"})
-
channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
- content=body.encode(encoding="utf_8"),
+ content={"displayname": "foobar"},
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual("foobar", channel.json_body["displayname"])
@@ -1499,7 +1590,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual("foobar", channel.json_body["displayname"])
@@ -1509,18 +1600,14 @@ class UserRestTestCase(unittest.HomeserverTestCase):
"""
# Delete old and add new threepid to user
- body = json.dumps(
- {"threepids": [{"medium": "email", "address": "bob3@bob.bob"}]}
- )
-
channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
- content=body.encode(encoding="utf_8"),
+ content={"threepids": [{"medium": "email", "address": "bob3@bob.bob"}]},
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"])
@@ -1532,7 +1619,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"])
@@ -1557,7 +1644,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertFalse(channel.json_body["deactivated"])
self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
@@ -1577,7 +1664,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"deactivated": True},
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertTrue(channel.json_body["deactivated"])
self.assertIsNone(channel.json_body["password_hash"])
@@ -1598,7 +1685,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertTrue(channel.json_body["deactivated"])
self.assertIsNone(channel.json_body["password_hash"])
@@ -1625,7 +1712,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"deactivated": True},
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertTrue(channel.json_body["deactivated"])
@@ -1641,7 +1728,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"displayname": "Foobar"},
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertTrue(channel.json_body["deactivated"])
self.assertEqual("Foobar", channel.json_body["displayname"])
@@ -1665,7 +1752,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"deactivated": False},
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(400, channel.code, msg=channel.json_body)
# Reactivate the user.
channel = self.make_request(
@@ -1674,7 +1761,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"deactivated": False, "password": "foo"},
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertFalse(channel.json_body["deactivated"])
self.assertIsNotNone(channel.json_body["password_hash"])
@@ -1696,7 +1783,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"deactivated": False, "password": "foo"},
)
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
# Reactivate the user without a password.
@@ -1706,7 +1793,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"deactivated": False},
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertFalse(channel.json_body["deactivated"])
self.assertIsNone(channel.json_body["password_hash"])
@@ -1728,7 +1815,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"deactivated": False, "password": "foo"},
)
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
# Reactivate the user without a password.
@@ -1738,7 +1825,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"deactivated": False},
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertFalse(channel.json_body["deactivated"])
self.assertIsNone(channel.json_body["password_hash"])
@@ -1757,7 +1844,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"admin": True},
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertTrue(channel.json_body["admin"])
@@ -1768,7 +1855,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertTrue(channel.json_body["admin"])
@@ -1787,7 +1874,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"password": "abc123"},
)
- self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("bob", channel.json_body["displayname"])
@@ -1798,7 +1885,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("bob", channel.json_body["displayname"])
self.assertEqual(0, channel.json_body["deactivated"])
@@ -1811,7 +1898,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"password": "abc123", "deactivated": "false"},
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(400, channel.code, msg=channel.json_body)
# Check user is not deactivated
channel = self.make_request(
@@ -1820,7 +1907,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("bob", channel.json_body["displayname"])
@@ -1845,7 +1932,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
content={"deactivated": True},
)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertTrue(channel.json_body["deactivated"])
self.assertIsNone(channel.json_body["password_hash"])
self._is_erased(user_id, False)
@@ -1853,6 +1940,25 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertIsNone(self.get_success(d))
self._is_erased(user_id, True)
+ def _check_fields(self, content: JsonDict):
+ """Checks that the expected user attributes are present in content
+
+ Args:
+ content: Content dictionary to check
+ """
+ self.assertIn("displayname", content)
+ self.assertIn("threepids", content)
+ self.assertIn("avatar_url", content)
+ self.assertIn("admin", content)
+ self.assertIn("deactivated", content)
+ self.assertIn("shadow_banned", content)
+ self.assertIn("password_hash", content)
+ self.assertIn("creation_ts", content)
+ self.assertIn("appservice_id", content)
+ self.assertIn("consent_server_notice_sent", content)
+ self.assertIn("consent_version", content)
+ self.assertIn("external_ids", content)
+
class UserMembershipRestTestCase(unittest.HomeserverTestCase):
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index ed55a640af..69798e95c3 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -52,6 +52,7 @@ class RestHelper:
room_version: str = None,
tok: str = None,
expect_code: int = 200,
+ extra_content: Optional[Dict] = None,
) -> str:
"""
Create a room.
@@ -72,7 +73,7 @@ class RestHelper:
temp_id = self.auth_user_id
self.auth_user_id = room_creator
path = "/_matrix/client/r0/createRoom"
- content = {}
+ content = extra_content or {}
if not is_public:
content["visibility"] = "private"
if room_version:
diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py
index 485e3650c3..6b90f838b6 100644
--- a/tests/rest/client/v2_alpha/test_auth.py
+++ b/tests/rest/client/v2_alpha/test_auth.py
@@ -20,7 +20,7 @@ import synapse.rest.admin
from synapse.api.constants import LoginType
from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
from synapse.rest.client.v1 import login
-from synapse.rest.client.v2_alpha import auth, devices, register
+from synapse.rest.client.v2_alpha import account, auth, devices, register
from synapse.rest.synapse.client import build_synapse_client_resource_tree
from synapse.types import JsonDict, UserID
@@ -498,3 +498,221 @@ class UIAuthTests(unittest.HomeserverTestCase):
self.delete_device(
self.user_tok, self.device_id, 403, body={"auth": {"session": session_id}}
)
+
+
+class RefreshAuthTests(unittest.HomeserverTestCase):
+ servlets = [
+ auth.register_servlets,
+ account.register_servlets,
+ login.register_servlets,
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ register.register_servlets,
+ ]
+ hijack_auth = False
+
+ def prepare(self, reactor, clock, hs):
+ self.user_pass = "pass"
+ self.user = self.register_user("test", self.user_pass)
+
+ def test_login_issue_refresh_token(self):
+ """
+ A login response should include a refresh_token only if asked.
+ """
+ # Test login
+ body = {"type": "m.login.password", "user": "test", "password": self.user_pass}
+
+ login_without_refresh = self.make_request(
+ "POST", "/_matrix/client/r0/login", body
+ )
+ self.assertEqual(login_without_refresh.code, 200, login_without_refresh.result)
+ self.assertNotIn("refresh_token", login_without_refresh.json_body)
+
+ login_with_refresh = self.make_request(
+ "POST",
+ "/_matrix/client/r0/login?org.matrix.msc2918.refresh_token=true",
+ body,
+ )
+ self.assertEqual(login_with_refresh.code, 200, login_with_refresh.result)
+ self.assertIn("refresh_token", login_with_refresh.json_body)
+ self.assertIn("expires_in_ms", login_with_refresh.json_body)
+
+ def test_register_issue_refresh_token(self):
+ """
+ A register response should include a refresh_token only if asked.
+ """
+ register_without_refresh = self.make_request(
+ "POST",
+ "/_matrix/client/r0/register",
+ {
+ "username": "test2",
+ "password": self.user_pass,
+ "auth": {"type": LoginType.DUMMY},
+ },
+ )
+ self.assertEqual(
+ register_without_refresh.code, 200, register_without_refresh.result
+ )
+ self.assertNotIn("refresh_token", register_without_refresh.json_body)
+
+ register_with_refresh = self.make_request(
+ "POST",
+ "/_matrix/client/r0/register?org.matrix.msc2918.refresh_token=true",
+ {
+ "username": "test3",
+ "password": self.user_pass,
+ "auth": {"type": LoginType.DUMMY},
+ },
+ )
+ self.assertEqual(register_with_refresh.code, 200, register_with_refresh.result)
+ self.assertIn("refresh_token", register_with_refresh.json_body)
+ self.assertIn("expires_in_ms", register_with_refresh.json_body)
+
+ def test_token_refresh(self):
+ """
+ A refresh token can be used to issue a new access token.
+ """
+ body = {"type": "m.login.password", "user": "test", "password": self.user_pass}
+ login_response = self.make_request(
+ "POST",
+ "/_matrix/client/r0/login?org.matrix.msc2918.refresh_token=true",
+ body,
+ )
+ self.assertEqual(login_response.code, 200, login_response.result)
+
+ refresh_response = self.make_request(
+ "POST",
+ "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh",
+ {"refresh_token": login_response.json_body["refresh_token"]},
+ )
+ self.assertEqual(refresh_response.code, 200, refresh_response.result)
+ self.assertIn("access_token", refresh_response.json_body)
+ self.assertIn("refresh_token", refresh_response.json_body)
+ self.assertIn("expires_in_ms", refresh_response.json_body)
+
+ # The access and refresh tokens should be different from the original ones after refresh
+ self.assertNotEqual(
+ login_response.json_body["access_token"],
+ refresh_response.json_body["access_token"],
+ )
+ self.assertNotEqual(
+ login_response.json_body["refresh_token"],
+ refresh_response.json_body["refresh_token"],
+ )
+
+ @override_config({"access_token_lifetime": "1m"})
+ def test_refresh_token_expiration(self):
+ """
+ The access token should have some time as specified in the config.
+ """
+ body = {"type": "m.login.password", "user": "test", "password": self.user_pass}
+ login_response = self.make_request(
+ "POST",
+ "/_matrix/client/r0/login?org.matrix.msc2918.refresh_token=true",
+ body,
+ )
+ self.assertEqual(login_response.code, 200, login_response.result)
+ self.assertApproximates(
+ login_response.json_body["expires_in_ms"], 60 * 1000, 100
+ )
+
+ refresh_response = self.make_request(
+ "POST",
+ "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh",
+ {"refresh_token": login_response.json_body["refresh_token"]},
+ )
+ self.assertEqual(refresh_response.code, 200, refresh_response.result)
+ self.assertApproximates(
+ refresh_response.json_body["expires_in_ms"], 60 * 1000, 100
+ )
+
+ def test_refresh_token_invalidation(self):
+ """Refresh tokens are invalidated after first use of the next token.
+
+ A refresh token is considered invalid if:
+ - it was already used at least once
+ - and either
+ - the next access token was used
+ - the next refresh token was used
+
+ The chain of tokens goes like this:
+
+ login -|-> first_refresh -> third_refresh (fails)
+ |-> second_refresh -> fifth_refresh
+ |-> fourth_refresh (fails)
+ """
+
+ body = {"type": "m.login.password", "user": "test", "password": self.user_pass}
+ login_response = self.make_request(
+ "POST",
+ "/_matrix/client/r0/login?org.matrix.msc2918.refresh_token=true",
+ body,
+ )
+ self.assertEqual(login_response.code, 200, login_response.result)
+
+ # This first refresh should work properly
+ first_refresh_response = self.make_request(
+ "POST",
+ "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh",
+ {"refresh_token": login_response.json_body["refresh_token"]},
+ )
+ self.assertEqual(
+ first_refresh_response.code, 200, first_refresh_response.result
+ )
+
+ # This one as well, since the token in the first one was never used
+ second_refresh_response = self.make_request(
+ "POST",
+ "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh",
+ {"refresh_token": login_response.json_body["refresh_token"]},
+ )
+ self.assertEqual(
+ second_refresh_response.code, 200, second_refresh_response.result
+ )
+
+ # This one should not, since the token from the first refresh is not valid anymore
+ third_refresh_response = self.make_request(
+ "POST",
+ "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh",
+ {"refresh_token": first_refresh_response.json_body["refresh_token"]},
+ )
+ self.assertEqual(
+ third_refresh_response.code, 401, third_refresh_response.result
+ )
+
+ # The associated access token should also be invalid
+ whoami_response = self.make_request(
+ "GET",
+ "/_matrix/client/r0/account/whoami",
+ access_token=first_refresh_response.json_body["access_token"],
+ )
+ self.assertEqual(whoami_response.code, 401, whoami_response.result)
+
+ # But all other tokens should work (they will expire after some time)
+ for access_token in [
+ second_refresh_response.json_body["access_token"],
+ login_response.json_body["access_token"],
+ ]:
+ whoami_response = self.make_request(
+ "GET", "/_matrix/client/r0/account/whoami", access_token=access_token
+ )
+ self.assertEqual(whoami_response.code, 200, whoami_response.result)
+
+ # Now that the access token from the last valid refresh was used once, refreshing with the N-1 token should fail
+ fourth_refresh_response = self.make_request(
+ "POST",
+ "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh",
+ {"refresh_token": login_response.json_body["refresh_token"]},
+ )
+ self.assertEqual(
+ fourth_refresh_response.code, 403, fourth_refresh_response.result
+ )
+
+ # But refreshing from the last valid refresh token still works
+ fifth_refresh_response = self.make_request(
+ "POST",
+ "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh",
+ {"refresh_token": second_refresh_response.json_body["refresh_token"]},
+ )
+ self.assertEqual(
+ fifth_refresh_response.code, 200, fifth_refresh_response.result
+ )
diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py
index 726a22f90c..1ac23e1769 100644
--- a/tests/rest/client/v2_alpha/test_sync.py
+++ b/tests/rest/client/v2_alpha/test_sync.py
@@ -47,35 +47,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
channel = self.make_request("GET", "/sync")
self.assertEqual(channel.code, 200)
- self.assertTrue(
- {
- "next_batch",
- "rooms",
- "presence",
- "account_data",
- "to_device",
- "device_lists",
- }.issubset(set(channel.json_body.keys()))
- )
-
- def test_sync_presence_disabled(self):
- """
- When presence is disabled, the key does not appear in /sync.
- """
- self.hs.config.use_presence = False
-
- channel = self.make_request("GET", "/sync")
-
- self.assertEqual(channel.code, 200)
- self.assertTrue(
- {
- "next_batch",
- "rooms",
- "account_data",
- "to_device",
- "device_lists",
- }.issubset(set(channel.json_body.keys()))
- )
+ self.assertIn("next_batch", channel.json_body)
class SyncFilterTestCase(unittest.HomeserverTestCase):
|