diff options
Diffstat (limited to 'tests/rest/client/test_account.py')
-rw-r--r-- | tests/rest/client/test_account.py | 290 |
1 files changed, 157 insertions, 133 deletions
diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py index 6c4462e74a..def836054d 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py @@ -15,11 +15,12 @@ import json import os import re from email.parser import Parser -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from unittest.mock import Mock import pkg_resources +from twisted.internet.interfaces import IReactorTCP from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin @@ -30,6 +31,7 @@ from synapse.rest import admin from synapse.rest.client import account, login, register, room from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource from synapse.server import HomeServer +from synapse.types import JsonDict from synapse.util import Clock from tests import unittest @@ -46,7 +48,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): login.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() # Email config. @@ -67,20 +69,27 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): hs = self.setup_test_homeserver(config=config) async def sendmail( - reactor, smtphost, smtpport, from_addr, to_addrs, msg, **kwargs - ): - self.email_attempts.append(msg) - - self.email_attempts = [] + reactor: IReactorTCP, + smtphost: str, + smtpport: int, + from_addr: str, + to_addr: str, + msg_bytes: bytes, + *args: Any, + **kwargs: Any, + ) -> None: + self.email_attempts.append(msg_bytes) + + self.email_attempts: List[bytes] = [] hs.get_send_email_handler()._sendmail = sendmail return hs - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main self.submit_token_resource = PasswordResetSubmitTokenResource(hs) - def test_basic_password_reset(self): + def test_basic_password_reset(self) -> None: """Test basic password reset flow""" old_password = "monkey" new_password = "kangeroo" @@ -118,7 +127,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): self.attempt_wrong_password_login("kermit", old_password) @override_config({"rc_3pid_validation": {"burst_count": 3}}) - def test_ratelimit_by_email(self): + def test_ratelimit_by_email(self) -> None: """Test that we ratelimit /requestToken for the same email.""" old_password = "monkey" new_password = "kangeroo" @@ -139,7 +148,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): ) ) - def reset(ip): + def reset(ip: str) -> None: client_secret = "foobar" session_id = self._request_token(email, client_secret, ip) @@ -166,7 +175,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): self.assertEqual(cm.exception.code, 429) - def test_basic_password_reset_canonicalise_email(self): + def test_basic_password_reset_canonicalise_email(self) -> None: """Test basic password reset flow Request password reset with different spelling """ @@ -206,7 +215,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): # Assert we can't log in with the old password self.attempt_wrong_password_login("kermit", old_password) - def test_cant_reset_password_without_clicking_link(self): + def test_cant_reset_password_without_clicking_link(self) -> None: """Test that we do actually need to click the link in the email""" old_password = "monkey" new_password = "kangeroo" @@ -241,7 +250,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): # Assert we can't log in with the new password self.attempt_wrong_password_login("kermit", new_password) - def test_no_valid_token(self): + def test_no_valid_token(self) -> None: """Test that we do actually need to request a token and can't just make a session up. """ @@ -277,7 +286,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): self.attempt_wrong_password_login("kermit", new_password) @unittest.override_config({"request_token_inhibit_3pid_errors": True}) - def test_password_reset_bad_email_inhibit_error(self): + def test_password_reset_bad_email_inhibit_error(self) -> None: """Test that triggering a password reset with an email address that isn't bound to an account doesn't leak the lack of binding for that address if configured that way. @@ -292,7 +301,12 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): self.assertIsNotNone(session_id) - def _request_token(self, email, client_secret, ip="127.0.0.1"): + def _request_token( + self, + email: str, + client_secret: str, + ip: str = "127.0.0.1", + ) -> str: channel = self.make_request( "POST", b"account/password/email/requestToken", @@ -309,7 +323,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): return channel.json_body["sid"] - def _validate_token(self, link): + def _validate_token(self, link: str) -> None: # Remove the host path = link.replace("https://example.com", "") @@ -339,7 +353,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): ) self.assertEqual(200, channel.code, channel.result) - def _get_link_from_email(self): + def _get_link_from_email(self) -> str: assert self.email_attempts, "No emails have been sent" raw_msg = self.email_attempts[-1].decode("UTF-8") @@ -354,14 +368,19 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): if not text: self.fail("Could not find text portion of email to parse") + assert text is not None match = re.search(r"https://example.com\S+", text) assert match, "Could not find link in email" return match.group(0) def _reset_password( - self, new_password, session_id, client_secret, expected_code=200 - ): + self, + new_password: str, + session_id: str, + client_secret: str, + expected_code: int = 200, + ) -> None: channel = self.make_request( "POST", b"account/password", @@ -388,11 +407,11 @@ class DeactivateTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.hs = self.setup_test_homeserver() return self.hs - def test_deactivate_account(self): + def test_deactivate_account(self) -> None: user_id = self.register_user("kermit", "test") tok = self.login("kermit", "test") @@ -407,7 +426,7 @@ class DeactivateTestCase(unittest.HomeserverTestCase): channel = self.make_request("GET", "account/whoami", access_token=tok) self.assertEqual(channel.code, 401) - def test_pending_invites(self): + def test_pending_invites(self) -> None: """Tests that deactivating a user rejects every pending invite for them.""" store = self.hs.get_datastores().main @@ -448,7 +467,7 @@ class DeactivateTestCase(unittest.HomeserverTestCase): self.assertEqual(len(memberships), 1, memberships) self.assertEqual(memberships[0].room_id, room_id, memberships) - def deactivate(self, user_id, tok): + def deactivate(self, user_id: str, tok: str) -> None: request_data = json.dumps( { "auth": { @@ -474,12 +493,12 @@ class WhoamiTestCase(unittest.HomeserverTestCase): register.register_servlets, ] - def default_config(self): + def default_config(self) -> Dict[str, Any]: config = super().default_config() config["allow_guest_access"] = True return config - def test_GET_whoami(self): + def test_GET_whoami(self) -> None: device_id = "wouldgohere" user_id = self.register_user("kermit", "test") tok = self.login("kermit", "test", device_id=device_id) @@ -496,7 +515,7 @@ class WhoamiTestCase(unittest.HomeserverTestCase): }, ) - def test_GET_whoami_guests(self): + def test_GET_whoami_guests(self) -> None: channel = self.make_request( b"POST", b"/_matrix/client/r0/register?kind=guest", b"{}" ) @@ -516,7 +535,7 @@ class WhoamiTestCase(unittest.HomeserverTestCase): }, ) - def test_GET_whoami_appservices(self): + def test_GET_whoami_appservices(self) -> None: user_id = "@as:test" as_token = "i_am_an_app_service" @@ -541,7 +560,7 @@ class WhoamiTestCase(unittest.HomeserverTestCase): ) self.assertFalse(hasattr(whoami, "device_id")) - def _whoami(self, tok): + def _whoami(self, tok: str) -> JsonDict: channel = self.make_request("GET", "account/whoami", {}, access_token=tok) self.assertEqual(channel.code, 200) return channel.json_body @@ -555,7 +574,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): synapse.rest.admin.register_servlets_for_client_rest_resource, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() # Email config. @@ -576,16 +595,23 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): self.hs = self.setup_test_homeserver(config=config) async def sendmail( - reactor, smtphost, smtpport, from_addr, to_addrs, msg, **kwargs - ): - self.email_attempts.append(msg) - - self.email_attempts = [] + reactor: IReactorTCP, + smtphost: str, + smtpport: int, + from_addr: str, + to_addr: str, + msg_bytes: bytes, + *args: Any, + **kwargs: Any, + ) -> None: + self.email_attempts.append(msg_bytes) + + self.email_attempts: List[bytes] = [] self.hs.get_send_email_handler()._sendmail = sendmail return self.hs - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main self.user_id = self.register_user("kermit", "test") @@ -593,83 +619,73 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): self.email = "test@example.com" self.url_3pid = b"account/3pid" - def test_add_valid_email(self): - self.get_success(self._add_email(self.email, self.email)) + def test_add_valid_email(self) -> None: + self._add_email(self.email, self.email) - def test_add_valid_email_second_time(self): - self.get_success(self._add_email(self.email, self.email)) - self.get_success( - self._request_token_invalid_email( - self.email, - expected_errcode=Codes.THREEPID_IN_USE, - expected_error="Email is already in use", - ) + def test_add_valid_email_second_time(self) -> None: + self._add_email(self.email, self.email) + self._request_token_invalid_email( + self.email, + expected_errcode=Codes.THREEPID_IN_USE, + expected_error="Email is already in use", ) - def test_add_valid_email_second_time_canonicalise(self): - self.get_success(self._add_email(self.email, self.email)) - self.get_success( - self._request_token_invalid_email( - "TEST@EXAMPLE.COM", - expected_errcode=Codes.THREEPID_IN_USE, - expected_error="Email is already in use", - ) + def test_add_valid_email_second_time_canonicalise(self) -> None: + self._add_email(self.email, self.email) + self._request_token_invalid_email( + "TEST@EXAMPLE.COM", + expected_errcode=Codes.THREEPID_IN_USE, + expected_error="Email is already in use", ) - def test_add_email_no_at(self): - self.get_success( - self._request_token_invalid_email( - "address-without-at.bar", - expected_errcode=Codes.UNKNOWN, - expected_error="Unable to parse email address", - ) + def test_add_email_no_at(self) -> None: + self._request_token_invalid_email( + "address-without-at.bar", + expected_errcode=Codes.UNKNOWN, + expected_error="Unable to parse email address", ) - def test_add_email_two_at(self): - self.get_success( - self._request_token_invalid_email( - "foo@foo@test.bar", - expected_errcode=Codes.UNKNOWN, - expected_error="Unable to parse email address", - ) + def test_add_email_two_at(self) -> None: + self._request_token_invalid_email( + "foo@foo@test.bar", + expected_errcode=Codes.UNKNOWN, + expected_error="Unable to parse email address", ) - def test_add_email_bad_format(self): - self.get_success( - self._request_token_invalid_email( - "user@bad.example.net@good.example.com", - expected_errcode=Codes.UNKNOWN, - expected_error="Unable to parse email address", - ) + def test_add_email_bad_format(self) -> None: + self._request_token_invalid_email( + "user@bad.example.net@good.example.com", + expected_errcode=Codes.UNKNOWN, + expected_error="Unable to parse email address", ) - def test_add_email_domain_to_lower(self): - self.get_success(self._add_email("foo@TEST.BAR", "foo@test.bar")) + def test_add_email_domain_to_lower(self) -> None: + self._add_email("foo@TEST.BAR", "foo@test.bar") - def test_add_email_domain_with_umlaut(self): - self.get_success(self._add_email("foo@Öumlaut.com", "foo@öumlaut.com")) + def test_add_email_domain_with_umlaut(self) -> None: + self._add_email("foo@Öumlaut.com", "foo@öumlaut.com") - def test_add_email_address_casefold(self): - self.get_success(self._add_email("Strauß@Example.com", "strauss@example.com")) + def test_add_email_address_casefold(self) -> None: + self._add_email("Strauß@Example.com", "strauss@example.com") - def test_address_trim(self): - self.get_success(self._add_email(" foo@test.bar ", "foo@test.bar")) + def test_address_trim(self) -> None: + self._add_email(" foo@test.bar ", "foo@test.bar") @override_config({"rc_3pid_validation": {"burst_count": 3}}) - def test_ratelimit_by_ip(self): + def test_ratelimit_by_ip(self) -> None: """Tests that adding emails is ratelimited by IP""" # We expect to be able to set three emails before getting ratelimited. - self.get_success(self._add_email("foo1@test.bar", "foo1@test.bar")) - self.get_success(self._add_email("foo2@test.bar", "foo2@test.bar")) - self.get_success(self._add_email("foo3@test.bar", "foo3@test.bar")) + self._add_email("foo1@test.bar", "foo1@test.bar") + self._add_email("foo2@test.bar", "foo2@test.bar") + self._add_email("foo3@test.bar", "foo3@test.bar") with self.assertRaises(HttpResponseException) as cm: - self.get_success(self._add_email("foo4@test.bar", "foo4@test.bar")) + self._add_email("foo4@test.bar", "foo4@test.bar") self.assertEqual(cm.exception.code, 429) - def test_add_email_if_disabled(self): + def test_add_email_if_disabled(self) -> None: """Test adding email to profile when doing so is disallowed""" self.hs.config.registration.enable_3pid_changes = False @@ -695,7 +711,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): }, access_token=self.user_id_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) # Get user @@ -705,10 +721,10 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): access_token=self.user_id_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) self.assertFalse(channel.json_body["threepids"]) - def test_delete_email(self): + def test_delete_email(self) -> None: """Test deleting an email from profile""" # Add a threepid self.get_success( @@ -727,7 +743,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): {"medium": "email", "address": self.email}, access_token=self.user_id_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) # Get user channel = self.make_request( @@ -736,10 +752,10 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): access_token=self.user_id_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) self.assertFalse(channel.json_body["threepids"]) - def test_delete_email_if_disabled(self): + def test_delete_email_if_disabled(self) -> None: """Test deleting an email from profile when disallowed""" self.hs.config.registration.enable_3pid_changes = False @@ -761,7 +777,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): access_token=self.user_id_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) # Get user @@ -771,11 +787,11 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): access_token=self.user_id_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual(self.email, channel.json_body["threepids"][0]["address"]) - def test_cant_add_email_without_clicking_link(self): + def test_cant_add_email_without_clicking_link(self) -> None: """Test that we do actually need to click the link in the email""" client_secret = "foobar" session_id = self._request_token(self.email, client_secret) @@ -797,7 +813,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): }, access_token=self.user_id_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"]) # Get user @@ -807,10 +823,10 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): access_token=self.user_id_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) self.assertFalse(channel.json_body["threepids"]) - def test_no_valid_token(self): + def test_no_valid_token(self) -> None: """Test that we do actually need to request a token and can't just make a session up. """ @@ -832,7 +848,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): }, access_token=self.user_id_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"]) # Get user @@ -842,11 +858,11 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): access_token=self.user_id_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) self.assertFalse(channel.json_body["threepids"]) @override_config({"next_link_domain_whitelist": None}) - def test_next_link(self): + def test_next_link(self) -> None: """Tests a valid next_link parameter value with no whitelist (good case)""" self._request_token( "something@example.com", @@ -856,7 +872,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): ) @override_config({"next_link_domain_whitelist": None}) - def test_next_link_exotic_protocol(self): + def test_next_link_exotic_protocol(self) -> None: """Tests using a esoteric protocol as a next_link parameter value. Someone may be hosting a client on IPFS etc. """ @@ -868,7 +884,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): ) @override_config({"next_link_domain_whitelist": None}) - def test_next_link_file_uri(self): + def test_next_link_file_uri(self) -> None: """Tests next_link parameters cannot be file URI""" # Attempt to use a next_link value that points to the local disk self._request_token( @@ -879,7 +895,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): ) @override_config({"next_link_domain_whitelist": ["example.com", "example.org"]}) - def test_next_link_domain_whitelist(self): + def test_next_link_domain_whitelist(self) -> None: """Tests next_link parameters must fit the whitelist if provided""" # Ensure not providing a next_link parameter still works @@ -912,7 +928,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): ) @override_config({"next_link_domain_whitelist": []}) - def test_empty_next_link_domain_whitelist(self): + def test_empty_next_link_domain_whitelist(self) -> None: """Tests an empty next_lint_domain_whitelist value, meaning next_link is essentially disallowed """ @@ -962,28 +978,28 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): def _request_token_invalid_email( self, - email, - expected_errcode, - expected_error, - client_secret="foobar", - ): + email: str, + expected_errcode: str, + expected_error: str, + client_secret: str = "foobar", + ) -> None: channel = self.make_request( "POST", b"account/3pid/email/requestToken", {"client_secret": client_secret, "email": email, "send_attempt": 1}, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.result["body"]) self.assertEqual(expected_errcode, channel.json_body["errcode"]) self.assertEqual(expected_error, channel.json_body["error"]) - def _validate_token(self, link): + def _validate_token(self, link: str) -> None: # Remove the host path = link.replace("https://example.com", "") channel = self.make_request("GET", path, shorthand=False) self.assertEqual(200, channel.code, channel.result) - def _get_link_from_email(self): + def _get_link_from_email(self) -> str: assert self.email_attempts, "No emails have been sent" raw_msg = self.email_attempts[-1].decode("UTF-8") @@ -998,12 +1014,13 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): if not text: self.fail("Could not find text portion of email to parse") + assert text is not None match = re.search(r"https://example.com\S+", text) assert match, "Could not find link in email" return match.group(0) - def _add_email(self, request_email, expected_email): + def _add_email(self, request_email: str, expected_email: str) -> None: """Test adding an email to profile""" previous_email_attempts = len(self.email_attempts) @@ -1030,7 +1047,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): access_token=self.user_id_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) # Get user channel = self.make_request( @@ -1039,7 +1056,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): access_token=self.user_id_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.result["body"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) threepids = {threepid["address"] for threepid in channel.json_body["threepids"]} @@ -1055,18 +1072,18 @@ class AccountStatusTestCase(unittest.HomeserverTestCase): url = "/_matrix/client/unstable/org.matrix.msc3720/account_status" - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: config = self.default_config() config["experimental_features"] = {"msc3720_enabled": True} return self.setup_test_homeserver(config=config) - def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.requester = self.register_user("requester", "password") self.requester_tok = self.login("requester", "password") - self.server_name = homeserver.config.server.server_name + self.server_name = hs.config.server.server_name - def test_missing_mxid(self): + def test_missing_mxid(self) -> None: """Tests that not providing any MXID raises an error.""" self._test_status( users=None, @@ -1074,7 +1091,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase): expected_errcode=Codes.MISSING_PARAM, ) - def test_invalid_mxid(self): + def test_invalid_mxid(self) -> None: """Tests that providing an invalid MXID raises an error.""" self._test_status( users=["bad:test"], @@ -1082,7 +1099,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase): expected_errcode=Codes.INVALID_PARAM, ) - def test_local_user_not_exists(self): + def test_local_user_not_exists(self) -> None: """Tests that the account status endpoints correctly reports that a user doesn't exist. """ @@ -1098,7 +1115,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase): expected_failures=[], ) - def test_local_user_exists(self): + def test_local_user_exists(self) -> None: """Tests that the account status endpoint correctly reports that a user doesn't exist. """ @@ -1115,7 +1132,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase): expected_failures=[], ) - def test_local_user_deactivated(self): + def test_local_user_deactivated(self) -> None: """Tests that the account status endpoint correctly reports a deactivated user.""" user = self.register_user("someuser", "password") self.get_success( @@ -1135,7 +1152,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase): expected_failures=[], ) - def test_mixed_local_and_remote_users(self): + def test_mixed_local_and_remote_users(self) -> None: """Tests that if some users are remote the account status endpoint correctly merges the remote responses with the local result. """ @@ -1150,7 +1167,13 @@ class AccountStatusTestCase(unittest.HomeserverTestCase): "@bad:badremote", ] - async def post_json(destination, path, data, *a, **kwa): + async def post_json( + destination: str, + path: str, + data: Optional[JsonDict] = None, + *a: Any, + **kwa: Any, + ) -> Union[JsonDict, list]: if destination == "remote": return { "account_statuses": { @@ -1160,9 +1183,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase): }, } } - if destination == "otherremote": - return {} - if destination == "badremote": + elif destination == "badremote": # badremote tries to overwrite the status of a user that doesn't belong # to it (i.e. users[1]) with false data, which Synapse is expected to # ignore. @@ -1176,6 +1197,9 @@ class AccountStatusTestCase(unittest.HomeserverTestCase): }, } } + # if destination == "otherremote" + else: + return {} # Register a mock that will return the expected result depending on the remote. self.hs.get_federation_http_client().post_json = Mock(side_effect=post_json) @@ -1205,7 +1229,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase): expected_statuses: Optional[Dict[str, Dict[str, bool]]] = None, expected_failures: Optional[List[str]] = None, expected_errcode: Optional[str] = None, - ): + ) -> None: """Send a request to the account status endpoint and check that the response matches with what's expected. |