diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py
index 278808abb5..dac79bd745 100644
--- a/tests/rest/admin/test_media.py
+++ b/tests/rest/admin/test_media.py
@@ -642,7 +642,7 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests):
media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None
- self.assertFalse(media_info["quarantined_by"])
+ self.assertFalse(media_info.quarantined_by)
# quarantining
channel = self.make_request(
@@ -656,7 +656,7 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests):
media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None
- self.assertTrue(media_info["quarantined_by"])
+ self.assertTrue(media_info.quarantined_by)
# remove from quarantine
channel = self.make_request(
@@ -670,7 +670,7 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests):
media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None
- self.assertFalse(media_info["quarantined_by"])
+ self.assertFalse(media_info.quarantined_by)
def test_quarantine_protected_media(self) -> None:
"""
@@ -683,7 +683,7 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests):
# verify protection
media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None
- self.assertTrue(media_info["safe_from_quarantine"])
+ self.assertTrue(media_info.safe_from_quarantine)
# quarantining
channel = self.make_request(
@@ -698,7 +698,7 @@ class QuarantineMediaByIDTestCase(_AdminMediaTests):
# verify that is not in quarantine
media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None
- self.assertFalse(media_info["quarantined_by"])
+ self.assertFalse(media_info.quarantined_by)
class ProtectMediaByIDTestCase(_AdminMediaTests):
@@ -756,7 +756,7 @@ class ProtectMediaByIDTestCase(_AdminMediaTests):
media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None
- self.assertFalse(media_info["safe_from_quarantine"])
+ self.assertFalse(media_info.safe_from_quarantine)
# protect
channel = self.make_request(
@@ -770,7 +770,7 @@ class ProtectMediaByIDTestCase(_AdminMediaTests):
media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None
- self.assertTrue(media_info["safe_from_quarantine"])
+ self.assertTrue(media_info.safe_from_quarantine)
# unprotect
channel = self.make_request(
@@ -784,7 +784,7 @@ class ProtectMediaByIDTestCase(_AdminMediaTests):
media_info = self.get_success(self.store.get_local_media(self.media_id))
assert media_info is not None
- self.assertFalse(media_info["safe_from_quarantine"])
+ self.assertFalse(media_info.safe_from_quarantine)
class PurgeMediaCacheTestCase(_AdminMediaTests):
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 37f37a09d8..42b065d883 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -2706,7 +2706,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# is in user directory
profile = self.get_success(self.store._get_user_in_directory(self.other_user))
assert profile is not None
- self.assertTrue(profile["display_name"] == "User")
+ self.assertEqual(profile[0], "User")
# Deactivate user
channel = self.make_request(
diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py
index cffbda9a7d..bd59bb50cf 100644
--- a/tests/rest/client/test_account.py
+++ b/tests/rest/client/test_account.py
@@ -139,12 +139,12 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
#
# Note that we don't have the UI Auth session ID, so just pull out the single
# row.
- ui_auth_data = self.get_success(
- self.store.db_pool.simple_select_one(
- "ui_auth_sessions", keyvalues={}, retcols=("clientdict",)
+ result = self.get_success(
+ self.store.db_pool.simple_select_one_onecol(
+ "ui_auth_sessions", keyvalues={}, retcol="clientdict"
)
)
- client_dict = db_to_json(ui_auth_data["clientdict"])
+ client_dict = db_to_json(result)
self.assertNotIn("new_password", client_dict)
@override_config({"rc_3pid_validation": {"burst_count": 3}})
diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py
index ba4e017a0e..b04094b7b3 100644
--- a/tests/rest/client/test_register.py
+++ b/tests/rest/client/test_register.py
@@ -270,15 +270,15 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertLessEqual(det_data.items(), channel.json_body.items())
# Check the `completed` counter has been incremented and pending is 0
- res = self.get_success(
+ pending, completed = self.get_success(
store.db_pool.simple_select_one(
"registration_tokens",
keyvalues={"token": token},
retcols=["pending", "completed"],
)
)
- self.assertEqual(res["completed"], 1)
- self.assertEqual(res["pending"], 0)
+ self.assertEqual(completed, 1)
+ self.assertEqual(pending, 0)
@override_config({"registration_requires_token": True})
def test_POST_registration_token_invalid(self) -> None:
@@ -372,15 +372,15 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
params1["auth"]["type"] = LoginType.DUMMY
self.make_request(b"POST", self.url, params1)
# Check pending=0 and completed=1
- res = self.get_success(
+ pending, completed = self.get_success(
store.db_pool.simple_select_one(
"registration_tokens",
keyvalues={"token": token},
retcols=["pending", "completed"],
)
)
- self.assertEqual(res["pending"], 0)
- self.assertEqual(res["completed"], 1)
+ self.assertEqual(pending, 0)
+ self.assertEqual(completed, 1)
# Check auth still fails when using token with session2
channel = self.make_request(b"POST", self.url, params2)
diff --git a/tests/rest/media/test_media_retention.py b/tests/rest/media/test_media_retention.py
index b59d9dfd4d..27a663a23b 100644
--- a/tests/rest/media/test_media_retention.py
+++ b/tests/rest/media/test_media_retention.py
@@ -267,23 +267,23 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase):
def _assert_mxc_uri_purge_state(mxc_uri: MXCUri, expect_purged: bool) -> None:
"""Given an MXC URI, assert whether it has been purged or not."""
if mxc_uri.server_name == self.hs.config.server.server_name:
- found_media_dict = self.get_success(
- self.store.get_local_media(mxc_uri.media_id)
+ found_media = bool(
+ self.get_success(self.store.get_local_media(mxc_uri.media_id))
)
else:
- found_media_dict = self.get_success(
- self.store.get_cached_remote_media(
- mxc_uri.server_name, mxc_uri.media_id
+ found_media = bool(
+ self.get_success(
+ self.store.get_cached_remote_media(
+ mxc_uri.server_name, mxc_uri.media_id
+ )
)
)
if expect_purged:
- self.assertIsNone(
- found_media_dict, msg=f"{mxc_uri} unexpectedly not purged"
- )
+ self.assertFalse(found_media, msg=f"{mxc_uri} unexpectedly not purged")
else:
- self.assertIsNotNone(
- found_media_dict,
+ self.assertTrue(
+ found_media,
msg=f"{mxc_uri} unexpectedly purged",
)
|