From f0b03186d96305fd44d74a89bf4230beec0c5c31 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 1 Apr 2022 17:04:16 +0100 Subject: Add type hints for `tests/unittest.py`. (#12347) In particular, add type hints for get_success and friends, which are then helpful in a bunch of places. --- tests/rest/admin/test_media.py | 8 ++++++++ tests/rest/admin/test_user.py | 15 +++++++++------ 2 files changed, 17 insertions(+), 6 deletions(-) (limited to 'tests/rest') diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py index 0d47dd0aff..e909e444ac 100644 --- a/tests/rest/admin/test_media.py +++ b/tests/rest/admin/test_media.py @@ -702,6 +702,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): """ media_info = self.get_success(self.store.get_local_media(self.media_id)) + assert media_info is not None self.assertFalse(media_info["quarantined_by"]) # quarantining @@ -715,6 +716,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): self.assertFalse(channel.json_body) media_info = self.get_success(self.store.get_local_media(self.media_id)) + assert media_info is not None self.assertTrue(media_info["quarantined_by"]) # remove from quarantine @@ -728,6 +730,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): self.assertFalse(channel.json_body) media_info = self.get_success(self.store.get_local_media(self.media_id)) + assert media_info is not None self.assertFalse(media_info["quarantined_by"]) def test_quarantine_protected_media(self) -> None: @@ -740,6 +743,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): # verify protection media_info = self.get_success(self.store.get_local_media(self.media_id)) + assert media_info is not None self.assertTrue(media_info["safe_from_quarantine"]) # quarantining @@ -754,6 +758,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): # verify that is not in quarantine media_info = self.get_success(self.store.get_local_media(self.media_id)) + assert media_info is not None self.assertFalse(media_info["quarantined_by"]) @@ -830,6 +835,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): """ media_info = self.get_success(self.store.get_local_media(self.media_id)) + assert media_info is not None self.assertFalse(media_info["safe_from_quarantine"]) # protect @@ -843,6 +849,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): self.assertFalse(channel.json_body) media_info = self.get_success(self.store.get_local_media(self.media_id)) + assert media_info is not None self.assertTrue(media_info["safe_from_quarantine"]) # unprotect @@ -856,6 +863,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): self.assertFalse(channel.json_body) media_info = self.get_success(self.store.get_local_media(self.media_id)) + assert media_info is not None self.assertFalse(media_info["safe_from_quarantine"]) diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index bef911d5df..0cdf1dec40 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -1590,10 +1590,9 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) - pushers = self.get_success( - self.store.get_pushers_by({"user_name": "@bob:test"}) + pushers = list( + self.get_success(self.store.get_pushers_by({"user_name": "@bob:test"})) ) - pushers = list(pushers) self.assertEqual(len(pushers), 1) self.assertEqual("@bob:test", pushers[0].user_name) @@ -1632,10 +1631,9 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) - pushers = self.get_success( - self.store.get_pushers_by({"user_name": "@bob:test"}) + pushers = list( + self.get_success(self.store.get_pushers_by({"user_name": "@bob:test"})) ) - pushers = list(pushers) self.assertEqual(len(pushers), 0) def test_set_password(self) -> None: @@ -2144,6 +2142,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): # is in user directory profile = self.get_success(self.store.get_user_in_directory(self.other_user)) + assert profile is not None self.assertTrue(profile["display_name"] == "User") # Deactivate user @@ -2711,6 +2710,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase): user_tuple = self.get_success( self.store.get_user_by_access_token(other_user_token) ) + assert user_tuple is not None token_id = user_tuple.token_id self.get_success( @@ -3676,6 +3676,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase): # The user starts off as not shadow-banned. other_user_token = self.login("user", "pass") result = self.get_success(self.store.get_user_by_access_token(other_user_token)) + assert result is not None self.assertFalse(result.shadow_banned) channel = self.make_request("POST", self.url, access_token=self.admin_user_tok) @@ -3684,6 +3685,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase): # Ensure the user is shadow-banned (and the cache was cleared). result = self.get_success(self.store.get_user_by_access_token(other_user_token)) + assert result is not None self.assertTrue(result.shadow_banned) # Un-shadow-ban the user. @@ -3695,6 +3697,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase): # Ensure the user is no longer shadow-banned (and the cache was cleared). result = self.get_success(self.store.get_user_by_access_token(other_user_token)) + assert result is not None self.assertFalse(result.shadow_banned) -- cgit 1.4.1