diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/handlers/test_cas.py | 19 | ||||
-rw-r--r-- | tests/handlers/test_federation.py | 36 | ||||
-rw-r--r-- | tests/handlers/test_presence.py | 13 | ||||
-rw-r--r-- | tests/push/test_http.py | 40 |
4 files changed, 65 insertions, 43 deletions
diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py index a267228846..a54aa29cf1 100644 --- a/tests/handlers/test_cas.py +++ b/tests/handlers/test_cas.py @@ -11,9 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any, Dict from unittest.mock import Mock +from twisted.test.proto_helpers import MemoryReactor + from synapse.handlers.cas import CasResponse +from synapse.server import HomeServer +from synapse.util import Clock from tests.test_utils import simple_async_mock from tests.unittest import HomeserverTestCase, override_config @@ -24,7 +29,7 @@ SERVER_URL = "https://issuer/" class CasHandlerTestCase(HomeserverTestCase): - def default_config(self): + def default_config(self) -> Dict[str, Any]: config = super().default_config() config["public_baseurl"] = BASE_URL cas_config = { @@ -40,7 +45,7 @@ class CasHandlerTestCase(HomeserverTestCase): return config - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: hs = self.setup_test_homeserver() self.handler = hs.get_cas_handler() @@ -51,7 +56,7 @@ class CasHandlerTestCase(HomeserverTestCase): return hs - def test_map_cas_user_to_user(self): + def test_map_cas_user_to_user(self) -> None: """Ensure that mapping the CAS user returned from a provider to an MXID works properly.""" # stub out the auth handler @@ -75,7 +80,7 @@ class CasHandlerTestCase(HomeserverTestCase): auth_provider_session_id=None, ) - def test_map_cas_user_to_existing_user(self): + def test_map_cas_user_to_existing_user(self) -> None: """Existing users can log in with CAS account.""" store = self.hs.get_datastores().main self.get_success( @@ -119,7 +124,7 @@ class CasHandlerTestCase(HomeserverTestCase): auth_provider_session_id=None, ) - def test_map_cas_user_to_invalid_localpart(self): + def test_map_cas_user_to_invalid_localpart(self) -> None: """CAS automaps invalid characters to base-64 encoding.""" # stub out the auth handler @@ -150,7 +155,7 @@ class CasHandlerTestCase(HomeserverTestCase): } } ) - def test_required_attributes(self): + def test_required_attributes(self) -> None: """The required attributes must be met from the CAS response.""" # stub out the auth handler @@ -166,7 +171,7 @@ class CasHandlerTestCase(HomeserverTestCase): auth_handler.complete_sso_login.assert_not_called() # The response doesn't have any department. - cas_response = CasResponse("test_user", {"userGroup": "staff"}) + cas_response = CasResponse("test_user", {"userGroup": ["staff"]}) request.reset_mock() self.get_success( self.handler._handle_cas_response(request, cas_response, "redirect_uri", "") diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index e8b4e39d1a..89078fc637 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -12,9 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import List +from typing import List, cast from unittest import TestCase +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.constants import EventTypes from synapse.api.errors import AuthError, Codes, LimitExceededError, SynapseError from synapse.api.room_versions import RoomVersions @@ -23,7 +25,9 @@ from synapse.federation.federation_base import event_from_pdu_json from synapse.logging.context import LoggingContext, run_in_background from synapse.rest import admin from synapse.rest.client import login, room +from synapse.server import HomeServer from synapse.types import create_requester +from synapse.util import Clock from synapse.util.stringutils import random_string from tests import unittest @@ -42,7 +46,7 @@ class FederationTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: hs = self.setup_test_homeserver(federation_http_client=None) self.handler = hs.get_federation_handler() self.store = hs.get_datastores().main @@ -50,7 +54,7 @@ class FederationTestCase(unittest.HomeserverTestCase): self._event_auth_handler = hs.get_event_auth_handler() return hs - def test_exchange_revoked_invite(self): + def test_exchange_revoked_invite(self) -> None: user_id = self.register_user("kermit", "test") tok = self.login("kermit", "test") @@ -96,7 +100,7 @@ class FederationTestCase(unittest.HomeserverTestCase): self.assertEqual(failure.errcode, Codes.FORBIDDEN, failure) self.assertEqual(failure.msg, "You are not invited to this room.") - def test_rejected_message_event_state(self): + def test_rejected_message_event_state(self) -> None: """ Check that we store the state group correctly for rejected non-state events. @@ -126,7 +130,7 @@ class FederationTestCase(unittest.HomeserverTestCase): "content": {}, "room_id": room_id, "sender": "@yetanotheruser:" + OTHER_SERVER, - "depth": join_event["depth"] + 1, + "depth": cast(int, join_event["depth"]) + 1, "prev_events": [join_event.event_id], "auth_events": [], "origin_server_ts": self.clock.time_msec(), @@ -149,7 +153,7 @@ class FederationTestCase(unittest.HomeserverTestCase): self.assertEqual(sg, sg2) - def test_rejected_state_event_state(self): + def test_rejected_state_event_state(self) -> None: """ Check that we store the state group correctly for rejected state events. @@ -180,7 +184,7 @@ class FederationTestCase(unittest.HomeserverTestCase): "content": {}, "room_id": room_id, "sender": "@yetanotheruser:" + OTHER_SERVER, - "depth": join_event["depth"] + 1, + "depth": cast(int, join_event["depth"]) + 1, "prev_events": [join_event.event_id], "auth_events": [], "origin_server_ts": self.clock.time_msec(), @@ -203,7 +207,7 @@ class FederationTestCase(unittest.HomeserverTestCase): self.assertEqual(sg, sg2) - def test_backfill_with_many_backward_extremities(self): + def test_backfill_with_many_backward_extremities(self) -> None: """ Check that we can backfill with many backward extremities. The goal is to make sure that when we only use a portion @@ -262,7 +266,7 @@ class FederationTestCase(unittest.HomeserverTestCase): ) self.get_success(d) - def test_backfill_floating_outlier_membership_auth(self): + def test_backfill_floating_outlier_membership_auth(self) -> None: """ As the local homeserver, check that we can properly process a federated event from the OTHER_SERVER with auth_events that include a floating @@ -377,7 +381,7 @@ class FederationTestCase(unittest.HomeserverTestCase): for ae in auth_events ] - self.handler.federation_client.get_event_auth = get_event_auth + self.handler.federation_client.get_event_auth = get_event_auth # type: ignore[assignment] with LoggingContext("receive_pdu"): # Fake the OTHER_SERVER federating the message event over to our local homeserver @@ -397,7 +401,7 @@ class FederationTestCase(unittest.HomeserverTestCase): @unittest.override_config( {"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}} ) - def test_invite_by_user_ratelimit(self): + def test_invite_by_user_ratelimit(self) -> None: """Tests that invites from federation to a particular user are actually rate-limited. """ @@ -446,7 +450,9 @@ class FederationTestCase(unittest.HomeserverTestCase): exc=LimitExceededError, ) - def _build_and_send_join_event(self, other_server, other_user, room_id): + def _build_and_send_join_event( + self, other_server: str, other_user: str, room_id: str + ) -> EventBase: join_event = self.get_success( self.handler.on_make_join_request(other_server, room_id, other_user) ) @@ -469,7 +475,7 @@ class FederationTestCase(unittest.HomeserverTestCase): class EventFromPduTestCase(TestCase): - def test_valid_json(self): + def test_valid_json(self) -> None: """Valid JSON should be turned into an event.""" ev = event_from_pdu_json( { @@ -487,7 +493,7 @@ class EventFromPduTestCase(TestCase): self.assertIsInstance(ev, EventBase) - def test_invalid_numbers(self): + def test_invalid_numbers(self) -> None: """Invalid values for an integer should be rejected, all floats should be rejected.""" for value in [ -(2 ** 53), @@ -512,7 +518,7 @@ class EventFromPduTestCase(TestCase): RoomVersions.V6, ) - def test_invalid_nested(self): + def test_invalid_nested(self) -> None: """List and dictionaries are recursively searched.""" with self.assertRaises(SynapseError): event_from_pdu_json( diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index 6ddec9ecf1..b2ed9cbe37 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -331,11 +331,11 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): # Extract presence update user ID and state information into lists of tuples db_presence_states = [(ps[0], ps[1]) for _, ps in db_presence_states[0]] - presence_states = [(ps.user_id, ps.state) for ps in presence_states] + presence_states_compare = [(ps.user_id, ps.state) for ps in presence_states] # Compare what we put into the storage with what we got out. # They should be identical. - self.assertEqual(presence_states, db_presence_states) + self.assertEqual(presence_states_compare, db_presence_states) class PresenceTimeoutTestCase(unittest.TestCase): @@ -357,6 +357,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) self.assertIsNotNone(new_state) + assert new_state is not None self.assertEqual(new_state.state, PresenceState.UNAVAILABLE) self.assertEqual(new_state.status_msg, status_msg) @@ -380,6 +381,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) self.assertIsNotNone(new_state) + assert new_state is not None self.assertEqual(new_state.state, PresenceState.BUSY) self.assertEqual(new_state.status_msg, status_msg) @@ -399,6 +401,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) self.assertIsNotNone(new_state) + assert new_state is not None self.assertEqual(new_state.state, PresenceState.OFFLINE) self.assertEqual(new_state.status_msg, status_msg) @@ -420,6 +423,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): ) self.assertIsNotNone(new_state) + assert new_state is not None self.assertEqual(new_state.state, PresenceState.ONLINE) self.assertEqual(new_state.status_msg, status_msg) @@ -477,6 +481,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): ) self.assertIsNotNone(new_state) + assert new_state is not None self.assertEqual(new_state.state, PresenceState.OFFLINE) self.assertEqual(new_state.status_msg, status_msg) @@ -653,13 +658,13 @@ class PresenceHandlerTestCase(unittest.HomeserverTestCase): self._set_presencestate_with_status_msg(user_id, PresenceState.ONLINE, None) def _set_presencestate_with_status_msg( - self, user_id: str, state: PresenceState, status_msg: Optional[str] + self, user_id: str, state: str, status_msg: Optional[str] ): """Set a PresenceState and status_msg and check the result. Args: user_id: User for that the status is to be set. - PresenceState: The new PresenceState. + state: The new PresenceState. status_msg: Status message that is to be set. """ self.get_success( diff --git a/tests/push/test_http.py b/tests/push/test_http.py index 6691e07128..ba158f5d93 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py @@ -11,15 +11,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Tuple +from typing import Any, Dict, List, Optional, Tuple from unittest.mock import Mock from twisted.internet.defer import Deferred +from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin from synapse.logging.context import make_deferred_yieldable from synapse.push import PusherConfigException from synapse.rest.client import login, push_rule, receipts, room +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util import Clock from tests.unittest import HomeserverTestCase, override_config @@ -35,13 +39,13 @@ class HTTPPusherTests(HomeserverTestCase): user_id = True hijack_auth = False - def default_config(self): + def default_config(self) -> Dict[str, Any]: config = super().default_config() config["start_pushers"] = True return config - def make_homeserver(self, reactor, clock): - self.push_attempts: List[tuple[Deferred, str, dict]] = [] + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + self.push_attempts: List[Tuple[Deferred, str, dict]] = [] m = Mock() @@ -56,7 +60,7 @@ class HTTPPusherTests(HomeserverTestCase): return hs - def test_invalid_configuration(self): + def test_invalid_configuration(self) -> None: """Invalid push configurations should be rejected.""" # Register the user who gets notified user_id = self.register_user("user", "pass") @@ -68,7 +72,7 @@ class HTTPPusherTests(HomeserverTestCase): ) token_id = user_tuple.token_id - def test_data(data): + def test_data(data: Optional[JsonDict]) -> None: self.get_failure( self.hs.get_pusherpool().add_pusher( user_id=user_id, @@ -95,7 +99,7 @@ class HTTPPusherTests(HomeserverTestCase): # A url with an incorrect path isn't accepted. test_data({"url": "http://example.com/foo"}) - def test_sends_http(self): + def test_sends_http(self) -> None: """ The HTTP pusher will send pushes for each message to a HTTP endpoint when configured to do so. @@ -200,7 +204,7 @@ class HTTPPusherTests(HomeserverTestCase): self.assertEqual(len(pushers), 1) self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering) - def test_sends_high_priority_for_encrypted(self): + def test_sends_high_priority_for_encrypted(self) -> None: """ The HTTP pusher will send pushes at high priority if they correspond to an encrypted message. @@ -321,7 +325,7 @@ class HTTPPusherTests(HomeserverTestCase): ) self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "high") - def test_sends_high_priority_for_one_to_one_only(self): + def test_sends_high_priority_for_one_to_one_only(self) -> None: """ The HTTP pusher will send pushes at high priority if they correspond to a message in a one-to-one room. @@ -404,7 +408,7 @@ class HTTPPusherTests(HomeserverTestCase): # check that this is low-priority self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low") - def test_sends_high_priority_for_mention(self): + def test_sends_high_priority_for_mention(self) -> None: """ The HTTP pusher will send pushes at high priority if they correspond to a message containing the user's display name. @@ -480,7 +484,7 @@ class HTTPPusherTests(HomeserverTestCase): # check that this is low-priority self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low") - def test_sends_high_priority_for_atroom(self): + def test_sends_high_priority_for_atroom(self) -> None: """ The HTTP pusher will send pushes at high priority if they correspond to a message that contains @room. @@ -563,7 +567,7 @@ class HTTPPusherTests(HomeserverTestCase): # check that this is low-priority self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low") - def test_push_unread_count_group_by_room(self): + def test_push_unread_count_group_by_room(self) -> None: """ The HTTP pusher will group unread count by number of unread rooms. """ @@ -576,7 +580,7 @@ class HTTPPusherTests(HomeserverTestCase): self._check_push_attempt(6, 1) @override_config({"push": {"group_unread_count_by_room": False}}) - def test_push_unread_count_message_count(self): + def test_push_unread_count_message_count(self) -> None: """ The HTTP pusher will send the total unread message count. """ @@ -589,7 +593,7 @@ class HTTPPusherTests(HomeserverTestCase): # last read receipt self._check_push_attempt(6, 3) - def _test_push_unread_count(self): + def _test_push_unread_count(self) -> None: """ Tests that the correct unread count appears in sent push notifications @@ -681,7 +685,7 @@ class HTTPPusherTests(HomeserverTestCase): self.helper.send(room_id, body="HELLO???", tok=other_access_token) - def _advance_time_and_make_push_succeed(self, expected_push_attempts): + def _advance_time_and_make_push_succeed(self, expected_push_attempts: int) -> None: self.pump() self.push_attempts[expected_push_attempts - 1][0].callback({}) @@ -708,7 +712,9 @@ class HTTPPusherTests(HomeserverTestCase): expected_unread_count_last_push, ) - def _send_read_request(self, access_token, message_event_id, room_id): + def _send_read_request( + self, access_token: str, message_event_id: str, room_id: str + ) -> None: # Now set the user's read receipt position to the first event # # This will actually trigger a new notification to be sent out so that @@ -748,7 +754,7 @@ class HTTPPusherTests(HomeserverTestCase): return user_id, access_token - def test_dont_notify_rule_overrides_message(self): + def test_dont_notify_rule_overrides_message(self) -> None: """ The override push rule will suppress notification """ |