diff --git a/tests/handlers/test_admin.py b/tests/handlers/test_admin.py
index abf2a0fe0d..c1579dac61 100644
--- a/tests/handlers/test_admin.py
+++ b/tests/handlers/test_admin.py
@@ -15,11 +15,15 @@
from collections import Counter
from unittest.mock import Mock
+from twisted.test.proto_helpers import MemoryReactor
+
import synapse.rest.admin
import synapse.storage
from synapse.api.constants import EventTypes, JoinRules
from synapse.api.room_versions import RoomVersions
from synapse.rest.client import knock, login, room
+from synapse.server import HomeServer
+from synapse.util import Clock
from tests import unittest
@@ -32,7 +36,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
knock.register_servlets,
]
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.admin_handler = hs.get_admin_handler()
self.user1 = self.register_user("user1", "password")
@@ -41,7 +45,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
self.user2 = self.register_user("user2", "password")
self.token2 = self.login("user2", "password")
- def test_single_public_joined_room(self):
+ def test_single_public_joined_room(self) -> None:
"""Test that we write *all* events for a public room"""
room_id = self.helper.create_room_as(
self.user1, tok=self.token1, is_public=True
@@ -74,7 +78,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
self.assertEqual(counter[(EventTypes.Member, self.user1)], 1)
self.assertEqual(counter[(EventTypes.Member, self.user2)], 1)
- def test_single_private_joined_room(self):
+ def test_single_private_joined_room(self) -> None:
"""Tests that we correctly write state when we can't see all events in
a room.
"""
@@ -112,7 +116,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
self.assertEqual(counter[(EventTypes.Member, self.user1)], 1)
self.assertEqual(counter[(EventTypes.Member, self.user2)], 1)
- def test_single_left_room(self):
+ def test_single_left_room(self) -> None:
"""Tests that we don't see events in the room after we leave."""
room_id = self.helper.create_room_as(self.user1, tok=self.token1)
self.helper.send(room_id, body="Hello!", tok=self.token1)
@@ -144,7 +148,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
self.assertEqual(counter[(EventTypes.Member, self.user1)], 1)
self.assertEqual(counter[(EventTypes.Member, self.user2)], 2)
- def test_single_left_rejoined_private_room(self):
+ def test_single_left_rejoined_private_room(self) -> None:
"""Tests that see the correct events in private rooms when we
repeatedly join and leave.
"""
@@ -185,7 +189,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
self.assertEqual(counter[(EventTypes.Member, self.user1)], 1)
self.assertEqual(counter[(EventTypes.Member, self.user2)], 3)
- def test_invite(self):
+ def test_invite(self) -> None:
"""Tests that pending invites get handled correctly."""
room_id = self.helper.create_room_as(self.user1, tok=self.token1)
self.helper.send(room_id, body="Hello!", tok=self.token1)
@@ -204,7 +208,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
self.assertEqual(args[1].content["membership"], "invite")
self.assertTrue(args[2]) # Assert there is at least one bit of state
- def test_knock(self):
+ def test_knock(self) -> None:
"""Tests that knock get handled correctly."""
# create a knockable v7 room
room_id = self.helper.create_room_as(
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index 072e6bbcdd..cead9f90df 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -59,11 +59,11 @@ class AppServiceHandlerTestCase(unittest.TestCase):
self.event_source = hs.get_event_sources()
def test_notify_interested_services(self):
- interested_service = self._mkservice(is_interested=True)
+ interested_service = self._mkservice(is_interested_in_event=True)
services = [
- self._mkservice(is_interested=False),
+ self._mkservice(is_interested_in_event=False),
interested_service,
- self._mkservice(is_interested=False),
+ self._mkservice(is_interested_in_event=False),
]
self.mock_as_api.query_user.return_value = make_awaitable(True)
@@ -85,7 +85,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
def test_query_user_exists_unknown_user(self):
user_id = "@someone:anywhere"
- services = [self._mkservice(is_interested=True)]
+ services = [self._mkservice(is_interested_in_event=True)]
services[0].is_interested_in_user.return_value = True
self.mock_store.get_app_services.return_value = services
self.mock_store.get_user_by_id.return_value = make_awaitable(None)
@@ -102,7 +102,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
def test_query_user_exists_known_user(self):
user_id = "@someone:anywhere"
- services = [self._mkservice(is_interested=True)]
+ services = [self._mkservice(is_interested_in_event=True)]
services[0].is_interested_in_user.return_value = True
self.mock_store.get_app_services.return_value = services
self.mock_store.get_user_by_id.return_value = make_awaitable({"name": user_id})
@@ -127,11 +127,11 @@ class AppServiceHandlerTestCase(unittest.TestCase):
room_id = "!alpha:bet"
servers = ["aperture"]
- interested_service = self._mkservice_alias(is_interested_in_alias=True)
+ interested_service = self._mkservice_alias(is_room_alias_in_namespace=True)
services = [
- self._mkservice_alias(is_interested_in_alias=False),
+ self._mkservice_alias(is_room_alias_in_namespace=False),
interested_service,
- self._mkservice_alias(is_interested_in_alias=False),
+ self._mkservice_alias(is_room_alias_in_namespace=False),
]
self.mock_as_api.query_alias.return_value = make_awaitable(True)
@@ -275,7 +275,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
to be pushed out to interested appservices, and that the stream ID is
updated accordingly.
"""
- interested_service = self._mkservice(is_interested=True)
+ interested_service = self._mkservice(is_interested_in_event=True)
services = [interested_service]
self.mock_store.get_app_services.return_value = services
self.mock_store.get_type_stream_id_for_appservice.return_value = make_awaitable(
@@ -304,7 +304,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
Test sending out of order ephemeral events to the appservice handler
are ignored.
"""
- interested_service = self._mkservice(is_interested=True)
+ interested_service = self._mkservice(is_interested_in_event=True)
services = [interested_service]
self.mock_store.get_app_services.return_value = services
@@ -325,17 +325,45 @@ class AppServiceHandlerTestCase(unittest.TestCase):
interested_service, ephemeral=[]
)
- def _mkservice(self, is_interested, protocols=None):
+ def _mkservice(
+ self, is_interested_in_event: bool, protocols: Optional[Iterable] = None
+ ) -> Mock:
+ """
+ Create a new mock representing an ApplicationService.
+
+ Args:
+ is_interested_in_event: Whether this application service will be considered
+ interested in all events.
+ protocols: The third-party protocols that this application service claims to
+ support.
+
+ Returns:
+ A mock representing the ApplicationService.
+ """
service = Mock()
- service.is_interested.return_value = make_awaitable(is_interested)
+ service.is_interested_in_event.return_value = make_awaitable(
+ is_interested_in_event
+ )
service.token = "mock_service_token"
service.url = "mock_service_url"
service.protocols = protocols
return service
- def _mkservice_alias(self, is_interested_in_alias):
+ def _mkservice_alias(self, is_room_alias_in_namespace: bool) -> Mock:
+ """
+ Create a new mock representing an ApplicationService that is or is not interested
+ any given room aliase.
+
+ Args:
+ is_room_alias_in_namespace: If true, the application service will be interested
+ in all room aliases that are queried against it. If false, the application
+ service will not be interested in any room aliases.
+
+ Returns:
+ A mock representing the ApplicationService.
+ """
service = Mock()
- service.is_interested_in_alias.return_value = is_interested_in_alias
+ service.is_room_alias_in_namespace.return_value = is_room_alias_in_namespace
service.token = "mock_service_token"
service.url = "mock_service_url"
return service
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index 0c6e55e725..67a7829769 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -15,8 +15,12 @@ from unittest.mock import Mock
import pymacaroons
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.api.errors import AuthError, ResourceLimitError
from synapse.rest import admin
+from synapse.server import HomeServer
+from synapse.util import Clock
from tests import unittest
from tests.test_utils import make_awaitable
@@ -27,7 +31,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
admin.register_servlets,
]
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.auth_handler = hs.get_auth_handler()
self.macaroon_generator = hs.get_macaroon_generator()
@@ -42,23 +46,23 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.user1 = self.register_user("a_user", "pass")
- def test_macaroon_caveats(self):
+ def test_macaroon_caveats(self) -> None:
token = self.macaroon_generator.generate_guest_access_token("a_user")
macaroon = pymacaroons.Macaroon.deserialize(token)
- def verify_gen(caveat):
+ def verify_gen(caveat: str) -> bool:
return caveat == "gen = 1"
- def verify_user(caveat):
+ def verify_user(caveat: str) -> bool:
return caveat == "user_id = a_user"
- def verify_type(caveat):
+ def verify_type(caveat: str) -> bool:
return caveat == "type = access"
- def verify_nonce(caveat):
+ def verify_nonce(caveat: str) -> bool:
return caveat.startswith("nonce =")
- def verify_guest(caveat):
+ def verify_guest(caveat: str) -> bool:
return caveat == "guest = true"
v = pymacaroons.Verifier()
@@ -69,7 +73,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
v.satisfy_general(verify_guest)
v.verify(macaroon, self.hs.config.key.macaroon_secret_key)
- def test_short_term_login_token_gives_user_id(self):
+ def test_short_term_login_token_gives_user_id(self) -> None:
token = self.macaroon_generator.generate_short_term_login_token(
self.user1, "", duration_in_ms=5000
)
@@ -84,7 +88,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
AuthError,
)
- def test_short_term_login_token_gives_auth_provider(self):
+ def test_short_term_login_token_gives_auth_provider(self) -> None:
token = self.macaroon_generator.generate_short_term_login_token(
self.user1, auth_provider_id="my_idp"
)
@@ -92,7 +96,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.assertEqual(self.user1, res.user_id)
self.assertEqual("my_idp", res.auth_provider_id)
- def test_short_term_login_token_cannot_replace_user_id(self):
+ def test_short_term_login_token_cannot_replace_user_id(self) -> None:
token = self.macaroon_generator.generate_short_term_login_token(
self.user1, "", duration_in_ms=5000
)
@@ -112,7 +116,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
AuthError,
)
- def test_mau_limits_disabled(self):
+ def test_mau_limits_disabled(self) -> None:
self.auth_blocking._limit_usage_by_mau = False
# Ensure does not throw exception
self.get_success(
@@ -127,7 +131,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
)
)
- def test_mau_limits_exceeded_large(self):
+ def test_mau_limits_exceeded_large(self) -> None:
self.auth_blocking._limit_usage_by_mau = True
self.hs.get_datastores().main.get_monthly_active_count = Mock(
return_value=make_awaitable(self.large_number_of_users)
@@ -150,7 +154,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
ResourceLimitError,
)
- def test_mau_limits_parity(self):
+ def test_mau_limits_parity(self) -> None:
# Ensure we're not at the unix epoch.
self.reactor.advance(1)
self.auth_blocking._limit_usage_by_mau = True
@@ -189,7 +193,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
)
)
- def test_mau_limits_not_exceeded(self):
+ def test_mau_limits_not_exceeded(self) -> None:
self.auth_blocking._limit_usage_by_mau = True
self.hs.get_datastores().main.get_monthly_active_count = Mock(
@@ -211,7 +215,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
)
)
- def _get_macaroon(self):
+ def _get_macaroon(self) -> pymacaroons.Macaroon:
token = self.macaroon_generator.generate_short_term_login_token(
self.user1, "", duration_in_ms=5000
)
diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py
index a267228846..a54aa29cf1 100644
--- a/tests/handlers/test_cas.py
+++ b/tests/handlers/test_cas.py
@@ -11,9 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Any, Dict
from unittest.mock import Mock
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.handlers.cas import CasResponse
+from synapse.server import HomeServer
+from synapse.util import Clock
from tests.test_utils import simple_async_mock
from tests.unittest import HomeserverTestCase, override_config
@@ -24,7 +29,7 @@ SERVER_URL = "https://issuer/"
class CasHandlerTestCase(HomeserverTestCase):
- def default_config(self):
+ def default_config(self) -> Dict[str, Any]:
config = super().default_config()
config["public_baseurl"] = BASE_URL
cas_config = {
@@ -40,7 +45,7 @@ class CasHandlerTestCase(HomeserverTestCase):
return config
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver()
self.handler = hs.get_cas_handler()
@@ -51,7 +56,7 @@ class CasHandlerTestCase(HomeserverTestCase):
return hs
- def test_map_cas_user_to_user(self):
+ def test_map_cas_user_to_user(self) -> None:
"""Ensure that mapping the CAS user returned from a provider to an MXID works properly."""
# stub out the auth handler
@@ -75,7 +80,7 @@ class CasHandlerTestCase(HomeserverTestCase):
auth_provider_session_id=None,
)
- def test_map_cas_user_to_existing_user(self):
+ def test_map_cas_user_to_existing_user(self) -> None:
"""Existing users can log in with CAS account."""
store = self.hs.get_datastores().main
self.get_success(
@@ -119,7 +124,7 @@ class CasHandlerTestCase(HomeserverTestCase):
auth_provider_session_id=None,
)
- def test_map_cas_user_to_invalid_localpart(self):
+ def test_map_cas_user_to_invalid_localpart(self) -> None:
"""CAS automaps invalid characters to base-64 encoding."""
# stub out the auth handler
@@ -150,7 +155,7 @@ class CasHandlerTestCase(HomeserverTestCase):
}
}
)
- def test_required_attributes(self):
+ def test_required_attributes(self) -> None:
"""The required attributes must be met from the CAS response."""
# stub out the auth handler
@@ -166,7 +171,7 @@ class CasHandlerTestCase(HomeserverTestCase):
auth_handler.complete_sso_login.assert_not_called()
# The response doesn't have any department.
- cas_response = CasResponse("test_user", {"userGroup": "staff"})
+ cas_response = CasResponse("test_user", {"userGroup": ["staff"]})
request.reset_mock()
self.get_success(
self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
diff --git a/tests/handlers/test_deactivate_account.py b/tests/handlers/test_deactivate_account.py
index ddda36c5a9..3a10791226 100644
--- a/tests/handlers/test_deactivate_account.py
+++ b/tests/handlers/test_deactivate_account.py
@@ -39,7 +39,7 @@ class DeactivateAccountTestCase(HomeserverTestCase):
self.user = self.register_user("user", "pass")
self.token = self.login("user", "pass")
- def _deactivate_my_account(self):
+ def _deactivate_my_account(self) -> None:
"""
Deactivates the account `self.user` using `self.token` and asserts
that it returns a 200 success code.
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index 683677fd07..01ea7d2a42 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -14,9 +14,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import synapse.api.errors
-import synapse.handlers.device
-import synapse.storage
+from typing import Optional
+
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.api.errors import NotFoundError, SynapseError
+from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN
+from synapse.server import HomeServer
+from synapse.util import Clock
from tests import unittest
@@ -25,28 +30,27 @@ user2 = "@theresa:bbb"
class DeviceTestCase(unittest.HomeserverTestCase):
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver("server", federation_http_client=None)
self.handler = hs.get_device_handler()
self.store = hs.get_datastores().main
return hs
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# These tests assume that it starts 1000 seconds in.
self.reactor.advance(1000)
- def test_device_is_created_with_invalid_name(self):
+ def test_device_is_created_with_invalid_name(self) -> None:
self.get_failure(
self.handler.check_device_registered(
user_id="@boris:foo",
device_id="foo",
- initial_device_display_name="a"
- * (synapse.handlers.device.MAX_DEVICE_DISPLAY_NAME_LEN + 1),
+ initial_device_display_name="a" * (MAX_DEVICE_DISPLAY_NAME_LEN + 1),
),
- synapse.api.errors.SynapseError,
+ SynapseError,
)
- def test_device_is_created_if_doesnt_exist(self):
+ def test_device_is_created_if_doesnt_exist(self) -> None:
res = self.get_success(
self.handler.check_device_registered(
user_id="@boris:foo",
@@ -59,7 +63,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco"))
self.assertEqual(dev["display_name"], "display name")
- def test_device_is_preserved_if_exists(self):
+ def test_device_is_preserved_if_exists(self) -> None:
res1 = self.get_success(
self.handler.check_device_registered(
user_id="@boris:foo",
@@ -81,7 +85,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco"))
self.assertEqual(dev["display_name"], "display name")
- def test_device_id_is_made_up_if_unspecified(self):
+ def test_device_id_is_made_up_if_unspecified(self) -> None:
device_id = self.get_success(
self.handler.check_device_registered(
user_id="@theresa:foo",
@@ -93,7 +97,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
dev = self.get_success(self.handler.store.get_device("@theresa:foo", device_id))
self.assertEqual(dev["display_name"], "display")
- def test_get_devices_by_user(self):
+ def test_get_devices_by_user(self) -> None:
self._record_users()
res = self.get_success(self.handler.get_devices_by_user(user1))
@@ -131,7 +135,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
device_map["abc"],
)
- def test_get_device(self):
+ def test_get_device(self) -> None:
self._record_users()
res = self.get_success(self.handler.get_device(user1, "abc"))
@@ -146,21 +150,19 @@ class DeviceTestCase(unittest.HomeserverTestCase):
res,
)
- def test_delete_device(self):
+ def test_delete_device(self) -> None:
self._record_users()
# delete the device
self.get_success(self.handler.delete_device(user1, "abc"))
# check the device was deleted
- self.get_failure(
- self.handler.get_device(user1, "abc"), synapse.api.errors.NotFoundError
- )
+ self.get_failure(self.handler.get_device(user1, "abc"), NotFoundError)
# we'd like to check the access token was invalidated, but that's a
# bit of a PITA.
- def test_delete_device_and_device_inbox(self):
+ def test_delete_device_and_device_inbox(self) -> None:
self._record_users()
# add an device_inbox
@@ -191,7 +193,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
)
self.assertIsNone(res)
- def test_update_device(self):
+ def test_update_device(self) -> None:
self._record_users()
update = {"display_name": "new display"}
@@ -200,32 +202,29 @@ class DeviceTestCase(unittest.HomeserverTestCase):
res = self.get_success(self.handler.get_device(user1, "abc"))
self.assertEqual(res["display_name"], "new display")
- def test_update_device_too_long_display_name(self):
+ def test_update_device_too_long_display_name(self) -> None:
"""Update a device with a display name that is invalid (too long)."""
self._record_users()
# Request to update a device display name with a new value that is longer than allowed.
- update = {
- "display_name": "a"
- * (synapse.handlers.device.MAX_DEVICE_DISPLAY_NAME_LEN + 1)
- }
+ update = {"display_name": "a" * (MAX_DEVICE_DISPLAY_NAME_LEN + 1)}
self.get_failure(
self.handler.update_device(user1, "abc", update),
- synapse.api.errors.SynapseError,
+ SynapseError,
)
# Ensure the display name was not updated.
res = self.get_success(self.handler.get_device(user1, "abc"))
self.assertEqual(res["display_name"], "display 2")
- def test_update_unknown_device(self):
+ def test_update_unknown_device(self) -> None:
update = {"display_name": "new_display"}
self.get_failure(
self.handler.update_device("user_id", "unknown_device_id", update),
- synapse.api.errors.NotFoundError,
+ NotFoundError,
)
- def _record_users(self):
+ def _record_users(self) -> None:
# check this works for both devices which have a recorded client_ip,
# and those which don't.
self._record_user(user1, "xyz", "display 0")
@@ -238,8 +237,13 @@ class DeviceTestCase(unittest.HomeserverTestCase):
self.reactor.advance(10000)
def _record_user(
- self, user_id, device_id, display_name, access_token=None, ip=None
- ):
+ self,
+ user_id: str,
+ device_id: str,
+ display_name: str,
+ access_token: Optional[str] = None,
+ ip: Optional[str] = None,
+ ) -> None:
device_id = self.get_success(
self.handler.check_device_registered(
user_id=user_id,
@@ -248,7 +252,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
)
)
- if ip is not None:
+ if access_token is not None and ip is not None:
self.get_success(
self.store.insert_client_ip(
user_id, access_token, ip, "user_agent", device_id
@@ -258,7 +262,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
class DehydrationTestCase(unittest.HomeserverTestCase):
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver("server", federation_http_client=None)
self.handler = hs.get_device_handler()
self.registration = hs.get_registration_handler()
@@ -266,7 +270,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
self.store = hs.get_datastores().main
return hs
- def test_dehydrate_and_rehydrate_device(self):
+ def test_dehydrate_and_rehydrate_device(self) -> None:
user_id = "@boris:dehydration"
self.get_success(self.store.register_user(user_id, "foobar"))
@@ -303,7 +307,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
access_token=access_token,
device_id="not the right device ID",
),
- synapse.api.errors.NotFoundError,
+ NotFoundError,
)
# dehydrating the right devices should succeed and change our device ID
@@ -331,7 +335,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
# make sure that the device ID that we were initially assigned no longer exists
self.get_failure(
self.handler.get_device(user_id, device_id),
- synapse.api.errors.NotFoundError,
+ NotFoundError,
)
# make sure that there's no device available for dehydrating now
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index 6e403a87c5..11ad44223d 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -12,14 +12,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+from typing import Any, Awaitable, Callable, Dict
from unittest.mock import Mock
+from twisted.test.proto_helpers import MemoryReactor
+
import synapse.api.errors
import synapse.rest.admin
from synapse.api.constants import EventTypes
from synapse.rest.client import directory, login, room
-from synapse.types import RoomAlias, create_requester
+from synapse.server import HomeServer
+from synapse.types import JsonDict, RoomAlias, create_requester
+from synapse.util import Clock
from tests import unittest
from tests.test_utils import make_awaitable
@@ -28,13 +32,15 @@ from tests.test_utils import make_awaitable
class DirectoryTestCase(unittest.HomeserverTestCase):
"""Tests the directory service."""
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.mock_federation = Mock()
self.mock_registry = Mock()
- self.query_handlers = {}
+ self.query_handlers: Dict[str, Callable[[dict], Awaitable[JsonDict]]] = {}
- def register_query_handler(query_type, handler):
+ def register_query_handler(
+ query_type: str, handler: Callable[[dict], Awaitable[JsonDict]]
+ ) -> None:
self.query_handlers[query_type] = handler
self.mock_registry.register_query_handler = register_query_handler
@@ -54,7 +60,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
return hs
- def test_get_local_association(self):
+ def test_get_local_association(self) -> None:
self.get_success(
self.store.create_room_alias_association(
self.my_room, "!8765qwer:test", ["test"]
@@ -65,7 +71,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
self.assertEqual({"room_id": "!8765qwer:test", "servers": ["test"]}, result)
- def test_get_remote_association(self):
+ def test_get_remote_association(self) -> None:
self.mock_federation.make_query.return_value = make_awaitable(
{"room_id": "!8765qwer:test", "servers": ["test", "remote"]}
)
@@ -83,7 +89,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
ignore_backoff=True,
)
- def test_incoming_fed_query(self):
+ def test_incoming_fed_query(self) -> None:
self.get_success(
self.store.create_room_alias_association(
self.your_room, "!8765asdf:test", ["test"]
@@ -105,7 +111,7 @@ class TestCreateAlias(unittest.HomeserverTestCase):
directory.register_servlets,
]
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.handler = hs.get_directory_handler()
# Create user
@@ -125,7 +131,7 @@ class TestCreateAlias(unittest.HomeserverTestCase):
self.test_user_tok = self.login("user", "pass")
self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok)
- def test_create_alias_joined_room(self):
+ def test_create_alias_joined_room(self) -> None:
"""A user can create an alias for a room they're in."""
self.get_success(
self.handler.create_association(
@@ -135,7 +141,7 @@ class TestCreateAlias(unittest.HomeserverTestCase):
)
)
- def test_create_alias_other_room(self):
+ def test_create_alias_other_room(self) -> None:
"""A user cannot create an alias for a room they're NOT in."""
other_room_id = self.helper.create_room_as(
self.admin_user, tok=self.admin_user_tok
@@ -150,7 +156,7 @@ class TestCreateAlias(unittest.HomeserverTestCase):
synapse.api.errors.SynapseError,
)
- def test_create_alias_admin(self):
+ def test_create_alias_admin(self) -> None:
"""An admin can create an alias for a room they're NOT in."""
other_room_id = self.helper.create_room_as(
self.test_user, tok=self.test_user_tok
@@ -173,7 +179,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
directory.register_servlets,
]
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.handler = hs.get_directory_handler()
self.state_handler = hs.get_state_handler()
@@ -195,7 +201,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
self.test_user_tok = self.login("user", "pass")
self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok)
- def _create_alias(self, user):
+ def _create_alias(self, user) -> None:
# Create a new alias to this room.
self.get_success(
self.store.create_room_alias_association(
@@ -203,7 +209,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
)
)
- def test_delete_alias_not_allowed(self):
+ def test_delete_alias_not_allowed(self) -> None:
"""A user that doesn't meet the expected guidelines cannot delete an alias."""
self._create_alias(self.admin_user)
self.get_failure(
@@ -213,7 +219,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
synapse.api.errors.AuthError,
)
- def test_delete_alias_creator(self):
+ def test_delete_alias_creator(self) -> None:
"""An alias creator can delete their own alias."""
# Create an alias from a different user.
self._create_alias(self.test_user)
@@ -232,7 +238,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
synapse.api.errors.SynapseError,
)
- def test_delete_alias_admin(self):
+ def test_delete_alias_admin(self) -> None:
"""A server admin can delete an alias created by another user."""
# Create an alias from a different user.
self._create_alias(self.test_user)
@@ -251,7 +257,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
synapse.api.errors.SynapseError,
)
- def test_delete_alias_sufficient_power(self):
+ def test_delete_alias_sufficient_power(self) -> None:
"""A user with a sufficient power level should be able to delete an alias."""
self._create_alias(self.admin_user)
@@ -288,7 +294,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
directory.register_servlets,
]
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.handler = hs.get_directory_handler()
self.state_handler = hs.get_state_handler()
@@ -317,7 +323,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
)
return room_alias
- def _set_canonical_alias(self, content):
+ def _set_canonical_alias(self, content) -> None:
"""Configure the canonical alias state on the room."""
self.helper.send_state(
self.room_id,
@@ -334,7 +340,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
)
)
- def test_remove_alias(self):
+ def test_remove_alias(self) -> None:
"""Removing an alias that is the canonical alias should remove it there too."""
# Set this new alias as the canonical alias for this room
self._set_canonical_alias(
@@ -356,7 +362,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
self.assertNotIn("alias", data["content"])
self.assertNotIn("alt_aliases", data["content"])
- def test_remove_other_alias(self):
+ def test_remove_other_alias(self) -> None:
"""Removing an alias listed as in alt_aliases should remove it there too."""
# Create a second alias.
other_test_alias = "#test2:test"
@@ -393,7 +399,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
servlets = [directory.register_servlets, room.register_servlets]
- def default_config(self):
+ def default_config(self) -> Dict[str, Any]:
config = super().default_config()
# Add custom alias creation rules to the config.
@@ -403,7 +409,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
return config
- def test_denied(self):
+ def test_denied(self) -> None:
room_id = self.helper.create_room_as(self.user_id)
channel = self.make_request(
@@ -413,7 +419,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
)
self.assertEqual(403, channel.code, channel.result)
- def test_allowed(self):
+ def test_allowed(self) -> None:
room_id = self.helper.create_room_as(self.user_id)
channel = self.make_request(
@@ -423,7 +429,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
)
self.assertEqual(200, channel.code, channel.result)
- def test_denied_during_creation(self):
+ def test_denied_during_creation(self) -> None:
"""A room alias that is not allowed should be rejected during creation."""
# Invalid room alias.
self.helper.create_room_as(
@@ -432,7 +438,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
extra_content={"room_alias_name": "foo"},
)
- def test_allowed_during_creation(self):
+ def test_allowed_during_creation(self) -> None:
"""A valid room alias should be allowed during creation."""
room_id = self.helper.create_room_as(
self.user_id,
@@ -459,7 +465,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
data = {"room_alias_name": "unofficial_test"}
allowed_localpart = "allowed"
- def default_config(self):
+ def default_config(self) -> Dict[str, Any]:
config = super().default_config()
# Add custom room list publication rules to the config.
@@ -474,7 +480,9 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
return config
- def prepare(self, reactor, clock, hs):
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, hs: HomeServer
+ ) -> HomeServer:
self.allowed_user_id = self.register_user(self.allowed_localpart, "pass")
self.allowed_access_token = self.login(self.allowed_localpart, "pass")
@@ -483,7 +491,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
return hs
- def test_denied_without_publication_permission(self):
+ def test_denied_without_publication_permission(self) -> None:
"""
Try to create a room, register an alias for it, and publish it,
as a user without permission to publish rooms.
@@ -497,7 +505,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
expect_code=403,
)
- def test_allowed_when_creating_private_room(self):
+ def test_allowed_when_creating_private_room(self) -> None:
"""
Try to create a room, register an alias for it, and NOT publish it,
as a user without permission to publish rooms.
@@ -511,7 +519,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
expect_code=200,
)
- def test_allowed_with_publication_permission(self):
+ def test_allowed_with_publication_permission(self) -> None:
"""
Try to create a room, register an alias for it, and publish it,
as a user WITH permission to publish rooms.
@@ -525,7 +533,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
expect_code=200,
)
- def test_denied_publication_with_invalid_alias(self):
+ def test_denied_publication_with_invalid_alias(self) -> None:
"""
Try to create a room, register an alias for it, and publish it,
as a user WITH permission to publish rooms.
@@ -538,7 +546,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
expect_code=403,
)
- def test_can_create_as_private_room_after_rejection(self):
+ def test_can_create_as_private_room_after_rejection(self) -> None:
"""
After failing to publish a room with an alias as a user without publish permission,
retry as the same user, but without publishing the room.
@@ -549,7 +557,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
self.test_denied_without_publication_permission()
self.test_allowed_when_creating_private_room()
- def test_can_create_with_permission_after_rejection(self):
+ def test_can_create_with_permission_after_rejection(self) -> None:
"""
After failing to publish a room with an alias as a user without publish permission,
retry as someone with permission, using the same alias.
@@ -566,7 +574,9 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
servlets = [directory.register_servlets, room.register_servlets]
- def prepare(self, reactor, clock, hs):
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, hs: HomeServer
+ ) -> HomeServer:
room_id = self.helper.create_room_as(self.user_id)
channel = self.make_request(
@@ -579,7 +589,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
return hs
- def test_disabling_room_list(self):
+ def test_disabling_room_list(self) -> None:
self.room_list_handler.enable_room_list_search = True
self.directory_handler.enable_room_list_search = True
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 9338ab92e9..ac21a28c43 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -20,33 +20,37 @@ from parameterized import parameterized
from signedjson import key as key, sign as sign
from twisted.internet import defer
+from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import RoomEncryptionAlgorithms
from synapse.api.errors import Codes, SynapseError
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util import Clock
from tests import unittest
from tests.test_utils import make_awaitable
class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
return self.setup_test_homeserver(federation_client=mock.Mock())
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.handler = hs.get_e2e_keys_handler()
self.store = self.hs.get_datastores().main
- def test_query_local_devices_no_devices(self):
+ def test_query_local_devices_no_devices(self) -> None:
"""If the user has no devices, we expect an empty list."""
local_user = "@boris:" + self.hs.hostname
res = self.get_success(self.handler.query_local_devices({local_user: None}))
self.assertDictEqual(res, {local_user: {}})
- def test_reupload_one_time_keys(self):
+ def test_reupload_one_time_keys(self) -> None:
"""we should be able to re-upload the same keys"""
local_user = "@boris:" + self.hs.hostname
device_id = "xyz"
- keys = {
+ keys: JsonDict = {
"alg1:k1": "key1",
"alg2:k2": {"key": "key2", "signatures": {"k1": "sig1"}},
"alg2:k3": {"key": "key3"},
@@ -74,7 +78,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
res, {"one_time_key_counts": {"alg1": 1, "alg2": 2, "signed_curve25519": 0}}
)
- def test_change_one_time_keys(self):
+ def test_change_one_time_keys(self) -> None:
"""attempts to change one-time-keys should be rejected"""
local_user = "@boris:" + self.hs.hostname
@@ -134,7 +138,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
SynapseError,
)
- def test_claim_one_time_key(self):
+ def test_claim_one_time_key(self) -> None:
local_user = "@boris:" + self.hs.hostname
device_id = "xyz"
keys = {"alg1:k1": "key1"}
@@ -161,7 +165,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
},
)
- def test_fallback_key(self):
+ def test_fallback_key(self) -> None:
local_user = "@boris:" + self.hs.hostname
device_id = "xyz"
fallback_key = {"alg1:k1": "fallback_key1"}
@@ -294,7 +298,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}},
)
- def test_replace_master_key(self):
+ def test_replace_master_key(self) -> None:
"""uploading a new signing key should make the old signing key unavailable"""
local_user = "@boris:" + self.hs.hostname
keys1 = {
@@ -328,7 +332,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
)
self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]})
- def test_reupload_signatures(self):
+ def test_reupload_signatures(self) -> None:
"""re-uploading a signature should not fail"""
local_user = "@boris:" + self.hs.hostname
keys1 = {
@@ -433,7 +437,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
self.assertDictEqual(devices["device_keys"][local_user]["abc"], device_key_1)
self.assertDictEqual(devices["device_keys"][local_user]["def"], device_key_2)
- def test_self_signing_key_doesnt_show_up_as_device(self):
+ def test_self_signing_key_doesnt_show_up_as_device(self) -> None:
"""signing keys should be hidden when fetching a user's devices"""
local_user = "@boris:" + self.hs.hostname
keys1 = {
@@ -462,7 +466,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
res = self.get_success(self.handler.query_local_devices({local_user: None}))
self.assertDictEqual(res, {local_user: {}})
- def test_upload_signatures(self):
+ def test_upload_signatures(self) -> None:
"""should check signatures that are uploaded"""
# set up a user with cross-signing keys and a device. This user will
# try uploading signatures
@@ -686,7 +690,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
other_master_key["signatures"][local_user]["ed25519:" + usersigning_pubkey],
)
- def test_query_devices_remote_no_sync(self):
+ def test_query_devices_remote_no_sync(self) -> None:
"""Tests that querying keys for a remote user that we don't share a room
with returns the cross signing keys correctly.
"""
@@ -759,7 +763,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
},
)
- def test_query_devices_remote_sync(self):
+ def test_query_devices_remote_sync(self) -> None:
"""Tests that querying keys for a remote user that we share a room with,
but haven't yet fetched the keys for, returns the cross signing keys
correctly.
@@ -845,7 +849,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
(["device_1", "device_2"],),
]
)
- def test_query_all_devices_caches_result(self, device_ids: Iterable[str]):
+ def test_query_all_devices_caches_result(self, device_ids: Iterable[str]) -> None:
"""Test that requests for all of a remote user's devices are cached.
We do this by asserting that only one call over federation was made, and that
@@ -853,7 +857,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
"""
local_user_id = "@test:test"
remote_user_id = "@test:other"
- request_body = {"device_keys": {remote_user_id: []}}
+ request_body: JsonDict = {"device_keys": {remote_user_id: []}}
response_devices = [
{
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index e8b4e39d1a..89078fc637 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -12,9 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import List
+from typing import List, cast
from unittest import TestCase
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError, Codes, LimitExceededError, SynapseError
from synapse.api.room_versions import RoomVersions
@@ -23,7 +25,9 @@ from synapse.federation.federation_base import event_from_pdu_json
from synapse.logging.context import LoggingContext, run_in_background
from synapse.rest import admin
from synapse.rest.client import login, room
+from synapse.server import HomeServer
from synapse.types import create_requester
+from synapse.util import Clock
from synapse.util.stringutils import random_string
from tests import unittest
@@ -42,7 +46,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
room.register_servlets,
]
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver(federation_http_client=None)
self.handler = hs.get_federation_handler()
self.store = hs.get_datastores().main
@@ -50,7 +54,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
self._event_auth_handler = hs.get_event_auth_handler()
return hs
- def test_exchange_revoked_invite(self):
+ def test_exchange_revoked_invite(self) -> None:
user_id = self.register_user("kermit", "test")
tok = self.login("kermit", "test")
@@ -96,7 +100,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
self.assertEqual(failure.errcode, Codes.FORBIDDEN, failure)
self.assertEqual(failure.msg, "You are not invited to this room.")
- def test_rejected_message_event_state(self):
+ def test_rejected_message_event_state(self) -> None:
"""
Check that we store the state group correctly for rejected non-state events.
@@ -126,7 +130,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
"content": {},
"room_id": room_id,
"sender": "@yetanotheruser:" + OTHER_SERVER,
- "depth": join_event["depth"] + 1,
+ "depth": cast(int, join_event["depth"]) + 1,
"prev_events": [join_event.event_id],
"auth_events": [],
"origin_server_ts": self.clock.time_msec(),
@@ -149,7 +153,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
self.assertEqual(sg, sg2)
- def test_rejected_state_event_state(self):
+ def test_rejected_state_event_state(self) -> None:
"""
Check that we store the state group correctly for rejected state events.
@@ -180,7 +184,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
"content": {},
"room_id": room_id,
"sender": "@yetanotheruser:" + OTHER_SERVER,
- "depth": join_event["depth"] + 1,
+ "depth": cast(int, join_event["depth"]) + 1,
"prev_events": [join_event.event_id],
"auth_events": [],
"origin_server_ts": self.clock.time_msec(),
@@ -203,7 +207,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
self.assertEqual(sg, sg2)
- def test_backfill_with_many_backward_extremities(self):
+ def test_backfill_with_many_backward_extremities(self) -> None:
"""
Check that we can backfill with many backward extremities.
The goal is to make sure that when we only use a portion
@@ -262,7 +266,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
)
self.get_success(d)
- def test_backfill_floating_outlier_membership_auth(self):
+ def test_backfill_floating_outlier_membership_auth(self) -> None:
"""
As the local homeserver, check that we can properly process a federated
event from the OTHER_SERVER with auth_events that include a floating
@@ -377,7 +381,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
for ae in auth_events
]
- self.handler.federation_client.get_event_auth = get_event_auth
+ self.handler.federation_client.get_event_auth = get_event_auth # type: ignore[assignment]
with LoggingContext("receive_pdu"):
# Fake the OTHER_SERVER federating the message event over to our local homeserver
@@ -397,7 +401,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
@unittest.override_config(
{"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}}
)
- def test_invite_by_user_ratelimit(self):
+ def test_invite_by_user_ratelimit(self) -> None:
"""Tests that invites from federation to a particular user are
actually rate-limited.
"""
@@ -446,7 +450,9 @@ class FederationTestCase(unittest.HomeserverTestCase):
exc=LimitExceededError,
)
- def _build_and_send_join_event(self, other_server, other_user, room_id):
+ def _build_and_send_join_event(
+ self, other_server: str, other_user: str, room_id: str
+ ) -> EventBase:
join_event = self.get_success(
self.handler.on_make_join_request(other_server, room_id, other_user)
)
@@ -469,7 +475,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
class EventFromPduTestCase(TestCase):
- def test_valid_json(self):
+ def test_valid_json(self) -> None:
"""Valid JSON should be turned into an event."""
ev = event_from_pdu_json(
{
@@ -487,7 +493,7 @@ class EventFromPduTestCase(TestCase):
self.assertIsInstance(ev, EventBase)
- def test_invalid_numbers(self):
+ def test_invalid_numbers(self) -> None:
"""Invalid values for an integer should be rejected, all floats should be rejected."""
for value in [
-(2 ** 53),
@@ -512,7 +518,7 @@ class EventFromPduTestCase(TestCase):
RoomVersions.V6,
)
- def test_invalid_nested(self):
+ def test_invalid_nested(self) -> None:
"""List and dictionaries are recursively searched."""
with self.assertRaises(SynapseError):
event_from_pdu_json(
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index e8418b6638..014815db6e 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -13,14 +13,18 @@
# limitations under the License.
import json
import os
+from typing import Any, Dict
from unittest.mock import ANY, Mock, patch
from urllib.parse import parse_qs, urlparse
import pymacaroons
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.handlers.sso import MappingException
from synapse.server import HomeServer
-from synapse.types import UserID
+from synapse.types import JsonDict, UserID
+from synapse.util import Clock
from synapse.util.macaroons import get_value_from_macaroon
from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock
@@ -98,7 +102,7 @@ class TestMappingProviderFailures(TestMappingProvider):
}
-async def get_json(url):
+async def get_json(url: str) -> JsonDict:
# Mock get_json calls to handle jwks & oidc discovery endpoints
if url == WELL_KNOWN:
# Minimal discovery document, as defined in OpenID.Discovery
@@ -116,6 +120,8 @@ async def get_json(url):
elif url == JWKS_URI:
return {"keys": []}
+ return {}
+
def _key_file_path() -> str:
"""path to a file containing the private half of a test key"""
@@ -147,12 +153,12 @@ class OidcHandlerTestCase(HomeserverTestCase):
if not HAS_OIDC:
skip = "requires OIDC"
- def default_config(self):
+ def default_config(self) -> Dict[str, Any]:
config = super().default_config()
config["public_baseurl"] = BASE_URL
return config
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.http_client = Mock(spec=["get_json"])
self.http_client.get_json.side_effect = get_json
self.http_client.user_agent = b"Synapse Test"
@@ -164,7 +170,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
sso_handler = hs.get_sso_handler()
# Mock the render error method.
self.render_error = Mock(return_value=None)
- sso_handler.render_error = self.render_error
+ sso_handler.render_error = self.render_error # type: ignore[assignment]
# Reduce the number of attempts when generating MXIDs.
sso_handler._MAP_USERNAME_RETRIES = 3
@@ -193,14 +199,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
return args
@override_config({"oidc_config": DEFAULT_CONFIG})
- def test_config(self):
+ def test_config(self) -> None:
"""Basic config correctly sets up the callback URL and client auth correctly."""
self.assertEqual(self.provider._callback_url, CALLBACK_URL)
self.assertEqual(self.provider._client_auth.client_id, CLIENT_ID)
self.assertEqual(self.provider._client_auth.client_secret, CLIENT_SECRET)
@override_config({"oidc_config": {**DEFAULT_CONFIG, "discover": True}})
- def test_discovery(self):
+ def test_discovery(self) -> None:
"""The handler should discover the endpoints from OIDC discovery document."""
# This would throw if some metadata were invalid
metadata = self.get_success(self.provider.load_metadata())
@@ -219,13 +225,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.http_client.get_json.assert_not_called()
@override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
- def test_no_discovery(self):
+ def test_no_discovery(self) -> None:
"""When discovery is disabled, it should not try to load from discovery document."""
self.get_success(self.provider.load_metadata())
self.http_client.get_json.assert_not_called()
@override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
- def test_load_jwks(self):
+ def test_load_jwks(self) -> None:
"""JWKS loading is done once (then cached) if used."""
jwks = self.get_success(self.provider.load_jwks())
self.http_client.get_json.assert_called_once_with(JWKS_URI)
@@ -253,7 +259,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.get_failure(self.provider.load_jwks(force=True), RuntimeError)
@override_config({"oidc_config": DEFAULT_CONFIG})
- def test_validate_config(self):
+ def test_validate_config(self) -> None:
"""Provider metadatas are extensively validated."""
h = self.provider
@@ -336,14 +342,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
force_load_metadata()
@override_config({"oidc_config": {**DEFAULT_CONFIG, "skip_verification": True}})
- def test_skip_verification(self):
+ def test_skip_verification(self) -> None:
"""Provider metadata validation can be disabled by config."""
with self.metadata_edit({"issuer": "http://insecure"}):
# This should not throw
get_awaitable_result(self.provider.load_metadata())
@override_config({"oidc_config": DEFAULT_CONFIG})
- def test_redirect_request(self):
+ def test_redirect_request(self) -> None:
"""The redirect request has the right arguments & generates a valid session cookie."""
req = Mock(spec=["cookies"])
req.cookies = []
@@ -387,7 +393,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertEqual(redirect, "http://client/redirect")
@override_config({"oidc_config": DEFAULT_CONFIG})
- def test_callback_error(self):
+ def test_callback_error(self) -> None:
"""Errors from the provider returned in the callback are displayed."""
request = Mock(args={})
request.args[b"error"] = [b"invalid_client"]
@@ -399,7 +405,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertRenderedError("invalid_client", "some description")
@override_config({"oidc_config": DEFAULT_CONFIG})
- def test_callback(self):
+ def test_callback(self) -> None:
"""Code callback works and display errors if something went wrong.
A lot of scenarios are tested here:
@@ -428,9 +434,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
"username": username,
}
expected_user_id = "@%s:%s" % (username, self.hs.hostname)
- self.provider._exchange_code = simple_async_mock(return_value=token)
- self.provider._parse_id_token = simple_async_mock(return_value=userinfo)
- self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
+ self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment]
+ self.provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
+ self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
@@ -468,7 +474,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertRenderedError("mapping_error")
# Handle ID token errors
- self.provider._parse_id_token = simple_async_mock(raises=Exception())
+ self.provider._parse_id_token = simple_async_mock(raises=Exception()) # type: ignore[assignment]
self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_token")
@@ -483,7 +489,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
"type": "bearer",
"access_token": "access_token",
}
- self.provider._exchange_code = simple_async_mock(return_value=token)
+ self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment]
self.get_success(self.handler.handle_oidc_callback(request))
auth_handler.complete_sso_login.assert_called_once_with(
@@ -510,8 +516,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
id_token = {
"sid": "abcdefgh",
}
- self.provider._parse_id_token = simple_async_mock(return_value=id_token)
- self.provider._exchange_code = simple_async_mock(return_value=token)
+ self.provider._parse_id_token = simple_async_mock(return_value=id_token) # type: ignore[assignment]
+ self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment]
auth_handler.complete_sso_login.reset_mock()
self.provider._fetch_userinfo.reset_mock()
self.get_success(self.handler.handle_oidc_callback(request))
@@ -531,21 +537,21 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.render_error.assert_not_called()
# Handle userinfo fetching error
- self.provider._fetch_userinfo = simple_async_mock(raises=Exception())
+ self.provider._fetch_userinfo = simple_async_mock(raises=Exception()) # type: ignore[assignment]
self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("fetch_error")
# Handle code exchange failure
from synapse.handlers.oidc import OidcError
- self.provider._exchange_code = simple_async_mock(
+ self.provider._exchange_code = simple_async_mock( # type: ignore[assignment]
raises=OidcError("invalid_request")
)
self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_request")
@override_config({"oidc_config": DEFAULT_CONFIG})
- def test_callback_session(self):
+ def test_callback_session(self) -> None:
"""The callback verifies the session presence and validity"""
request = Mock(spec=["args", "getCookie", "cookies"])
@@ -590,7 +596,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
@override_config(
{"oidc_config": {**DEFAULT_CONFIG, "client_auth_method": "client_secret_post"}}
)
- def test_exchange_code(self):
+ def test_exchange_code(self) -> None:
"""Code exchange behaves correctly and handles various error scenarios."""
token = {"type": "bearer"}
token_json = json.dumps(token).encode("utf-8")
@@ -686,7 +692,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
}
)
- def test_exchange_code_jwt_key(self):
+ def test_exchange_code_jwt_key(self) -> None:
"""Test that code exchange works with a JWK client secret."""
from authlib.jose import jwt
@@ -741,7 +747,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
}
)
- def test_exchange_code_no_auth(self):
+ def test_exchange_code_no_auth(self) -> None:
"""Test that code exchange works with no client secret."""
token = {"type": "bearer"}
self.http_client.request = simple_async_mock(
@@ -776,7 +782,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
}
)
- def test_extra_attributes(self):
+ def test_extra_attributes(self) -> None:
"""
Login while using a mapping provider that implements get_extra_attributes.
"""
@@ -790,8 +796,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
"username": "foo",
"phone": "1234567",
}
- self.provider._exchange_code = simple_async_mock(return_value=token)
- self.provider._parse_id_token = simple_async_mock(return_value=userinfo)
+ self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment]
+ self.provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
@@ -817,12 +823,12 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
@override_config({"oidc_config": DEFAULT_CONFIG})
- def test_map_userinfo_to_user(self):
+ def test_map_userinfo_to_user(self) -> None:
"""Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
- userinfo = {
+ userinfo: dict = {
"sub": "test_user",
"username": "test_user",
}
@@ -870,7 +876,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
@override_config({"oidc_config": {**DEFAULT_CONFIG, "allow_existing_users": True}})
- def test_map_userinfo_to_existing_user(self):
+ def test_map_userinfo_to_existing_user(self) -> None:
"""Existing users can log in with OpenID Connect when allow_existing_users is True."""
store = self.hs.get_datastores().main
user = UserID.from_string("@test_user:test")
@@ -974,7 +980,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
@override_config({"oidc_config": DEFAULT_CONFIG})
- def test_map_userinfo_to_invalid_localpart(self):
+ def test_map_userinfo_to_invalid_localpart(self) -> None:
"""If the mapping provider generates an invalid localpart it should be rejected."""
self.get_success(
_make_callback_with_userinfo(self.hs, {"sub": "test2", "username": "föö"})
@@ -991,7 +997,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
}
)
- def test_map_userinfo_to_user_retries(self):
+ def test_map_userinfo_to_user_retries(self) -> None:
"""The mapping provider can retry generating an MXID if the MXID is already in use."""
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
@@ -1039,7 +1045,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
@override_config({"oidc_config": DEFAULT_CONFIG})
- def test_empty_localpart(self):
+ def test_empty_localpart(self) -> None:
"""Attempts to map onto an empty localpart should be rejected."""
userinfo = {
"sub": "tester",
@@ -1058,7 +1064,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
}
)
- def test_null_localpart(self):
+ def test_null_localpart(self) -> None:
"""Mapping onto a null localpart via an empty OIDC attribute should be rejected"""
userinfo = {
"sub": "tester",
@@ -1075,7 +1081,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
}
)
- def test_attribute_requirements(self):
+ def test_attribute_requirements(self) -> None:
"""The required attributes must be met from the OIDC userinfo response."""
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
@@ -1115,7 +1121,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
}
)
- def test_attribute_requirements_contains(self):
+ def test_attribute_requirements_contains(self) -> None:
"""Test that auth succeeds if userinfo attribute CONTAINS required value"""
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
@@ -1146,7 +1152,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
}
)
- def test_attribute_requirements_mismatch(self):
+ def test_attribute_requirements_mismatch(self) -> None:
"""
Test that auth fails if attributes exist but don't match,
or are non-string values.
@@ -1154,7 +1160,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
# userinfo with "test": "not_foobar" attribute should fail
- userinfo = {
+ userinfo: dict = {
"sub": "tester",
"username": "tester",
"test": "not_foobar",
@@ -1248,9 +1254,9 @@ async def _make_callback_with_userinfo(
handler = hs.get_oidc_handler()
provider = handler._providers["oidc"]
- provider._exchange_code = simple_async_mock(return_value={"id_token": ""})
- provider._parse_id_token = simple_async_mock(return_value=userinfo)
- provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
+ provider._exchange_code = simple_async_mock(return_value={"id_token": ""}) # type: ignore[assignment]
+ provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
+ provider._fetch_userinfo = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
state = "state"
session = handler._token_generator.generate_oidc_session_token(
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index 49d832de81..d401fda938 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -124,7 +124,6 @@ class PasswordCustomAuthProvider:
("m.login.password", ("password",)): self.check_auth,
}
)
- pass
def check_auth(self, *args):
return mock_password_provider.check_auth(*args)
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index 6ddec9ecf1..b2ed9cbe37 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -331,11 +331,11 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
# Extract presence update user ID and state information into lists of tuples
db_presence_states = [(ps[0], ps[1]) for _, ps in db_presence_states[0]]
- presence_states = [(ps.user_id, ps.state) for ps in presence_states]
+ presence_states_compare = [(ps.user_id, ps.state) for ps in presence_states]
# Compare what we put into the storage with what we got out.
# They should be identical.
- self.assertEqual(presence_states, db_presence_states)
+ self.assertEqual(presence_states_compare, db_presence_states)
class PresenceTimeoutTestCase(unittest.TestCase):
@@ -357,6 +357,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
self.assertIsNotNone(new_state)
+ assert new_state is not None
self.assertEqual(new_state.state, PresenceState.UNAVAILABLE)
self.assertEqual(new_state.status_msg, status_msg)
@@ -380,6 +381,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
self.assertIsNotNone(new_state)
+ assert new_state is not None
self.assertEqual(new_state.state, PresenceState.BUSY)
self.assertEqual(new_state.status_msg, status_msg)
@@ -399,6 +401,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
self.assertIsNotNone(new_state)
+ assert new_state is not None
self.assertEqual(new_state.state, PresenceState.OFFLINE)
self.assertEqual(new_state.status_msg, status_msg)
@@ -420,6 +423,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
)
self.assertIsNotNone(new_state)
+ assert new_state is not None
self.assertEqual(new_state.state, PresenceState.ONLINE)
self.assertEqual(new_state.status_msg, status_msg)
@@ -477,6 +481,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
)
self.assertIsNotNone(new_state)
+ assert new_state is not None
self.assertEqual(new_state.state, PresenceState.OFFLINE)
self.assertEqual(new_state.status_msg, status_msg)
@@ -653,13 +658,13 @@ class PresenceHandlerTestCase(unittest.HomeserverTestCase):
self._set_presencestate_with_status_msg(user_id, PresenceState.ONLINE, None)
def _set_presencestate_with_status_msg(
- self, user_id: str, state: PresenceState, status_msg: Optional[str]
+ self, user_id: str, state: str, status_msg: Optional[str]
):
"""Set a PresenceState and status_msg and check the result.
Args:
user_id: User for that the status is to be set.
- PresenceState: The new PresenceState.
+ state: The new PresenceState.
status_msg: Status message that is to be set.
"""
self.get_success(
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 972cbac6e4..1ec105c373 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -11,14 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict
+from typing import Any, Awaitable, Callable, Dict
from unittest.mock import Mock
+from twisted.test.proto_helpers import MemoryReactor
+
import synapse.types
from synapse.api.errors import AuthError, SynapseError
from synapse.rest import admin
from synapse.server import HomeServer
-from synapse.types import UserID
+from synapse.types import JsonDict, UserID
+from synapse.util import Clock
from tests import unittest
from tests.test_utils import make_awaitable
@@ -29,13 +32,15 @@ class ProfileTestCase(unittest.HomeserverTestCase):
servlets = [admin.register_servlets]
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.mock_federation = Mock()
self.mock_registry = Mock()
- self.query_handlers = {}
+ self.query_handlers: Dict[str, Callable[[dict], Awaitable[JsonDict]]] = {}
- def register_query_handler(query_type, handler):
+ def register_query_handler(
+ query_type: str, handler: Callable[[dict], Awaitable[JsonDict]]
+ ) -> None:
self.query_handlers[query_type] = handler
self.mock_registry.register_query_handler = register_query_handler
@@ -47,7 +52,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
)
return hs
- def prepare(self, reactor, clock, hs: HomeServer):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.frank = UserID.from_string("@1234abcd:test")
@@ -58,7 +63,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.handler = hs.get_profile_handler()
- def test_get_my_name(self):
+ def test_get_my_name(self) -> None:
self.get_success(
self.store.set_profile_displayname(self.frank.localpart, "Frank")
)
@@ -67,7 +72,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertEqual("Frank", displayname)
- def test_set_my_name(self):
+ def test_set_my_name(self) -> None:
self.get_success(
self.handler.set_displayname(
self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
@@ -110,7 +115,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.get_success(self.store.get_profile_displayname(self.frank.localpart))
)
- def test_set_my_name_if_disabled(self):
+ def test_set_my_name_if_disabled(self) -> None:
self.hs.config.registration.enable_set_displayname = False
# Setting displayname for the first time is allowed
@@ -135,7 +140,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
SynapseError,
)
- def test_set_my_name_noauth(self):
+ def test_set_my_name_noauth(self) -> None:
self.get_failure(
self.handler.set_displayname(
self.frank, synapse.types.create_requester(self.bob), "Frank Jr."
@@ -143,7 +148,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
AuthError,
)
- def test_get_other_name(self):
+ def test_get_other_name(self) -> None:
self.mock_federation.make_query.return_value = make_awaitable(
{"displayname": "Alice"}
)
@@ -158,7 +163,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
ignore_backoff=True,
)
- def test_incoming_fed_query(self):
+ def test_incoming_fed_query(self) -> None:
self.get_success(self.store.create_profile("caroline"))
self.get_success(self.store.set_profile_displayname("caroline", "Caroline"))
@@ -174,7 +179,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertEqual({"displayname": "Caroline"}, response)
- def test_get_my_avatar(self):
+ def test_get_my_avatar(self) -> None:
self.get_success(
self.store.set_profile_avatar_url(
self.frank.localpart, "http://my.server/me.png"
@@ -184,7 +189,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertEqual("http://my.server/me.png", avatar_url)
- def test_set_my_avatar(self):
+ def test_set_my_avatar(self) -> None:
self.get_success(
self.handler.set_avatar_url(
self.frank,
@@ -225,7 +230,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
(self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
)
- def test_set_my_avatar_if_disabled(self):
+ def test_set_my_avatar_if_disabled(self) -> None:
self.hs.config.registration.enable_set_avatar_url = False
# Setting displayname for the first time is allowed
@@ -250,7 +255,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
SynapseError,
)
- def test_avatar_constraints_no_config(self):
+ def test_avatar_constraints_no_config(self) -> None:
"""Tests that the method to check an avatar against configured constraints skips
all of its check if no constraint is configured.
"""
@@ -263,7 +268,13 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertTrue(res)
@unittest.override_config({"max_avatar_size": 50})
- def test_avatar_constraints_missing(self):
+ def test_avatar_constraints_allow_empty_avatar_url(self) -> None:
+ """An empty avatar is always permitted."""
+ res = self.get_success(self.handler.check_avatar_size_and_mime_type(""))
+ self.assertTrue(res)
+
+ @unittest.override_config({"max_avatar_size": 50})
+ def test_avatar_constraints_missing(self) -> None:
"""Tests that an avatar isn't allowed if the file at the given MXC URI couldn't
be found.
"""
@@ -273,7 +284,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertFalse(res)
@unittest.override_config({"max_avatar_size": 50})
- def test_avatar_constraints_file_size(self):
+ def test_avatar_constraints_file_size(self) -> None:
"""Tests that a file that's above the allowed file size is forbidden but one
that's below it is allowed.
"""
@@ -295,7 +306,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.assertFalse(res)
@unittest.override_config({"allowed_avatar_mimetypes": ["image/png"]})
- def test_avatar_constraint_mime_type(self):
+ def test_avatar_constraint_mime_type(self) -> None:
"""Tests that a file with an unauthorised MIME type is forbidden but one with
an authorised content type is allowed.
"""
diff --git a/tests/handlers/test_room_summary.py b/tests/handlers/test_room_summary.py
index cff07a8973..d37292ce13 100644
--- a/tests/handlers/test_room_summary.py
+++ b/tests/handlers/test_room_summary.py
@@ -172,6 +172,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
result_room_ids = []
result_children_ids = []
for result_room in result["rooms"]:
+ # Ensure federation results are not leaking over the client-server API.
+ self.assertNotIn("allowed_room_ids", result_room)
+
result_room_ids.append(result_room["room_id"])
result_children_ids.append(
[
diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py
index 23941abed8..8d4404eda1 100644
--- a/tests/handlers/test_saml.py
+++ b/tests/handlers/test_saml.py
@@ -12,12 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Optional
+from typing import Any, Dict, Optional
from unittest.mock import Mock
import attr
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.api.errors import RedirectException
+from synapse.server import HomeServer
+from synapse.util import Clock
from tests.test_utils import simple_async_mock
from tests.unittest import HomeserverTestCase, override_config
@@ -81,10 +85,10 @@ class TestRedirectMappingProvider(TestMappingProvider):
class SamlHandlerTestCase(HomeserverTestCase):
- def default_config(self):
+ def default_config(self) -> Dict[str, Any]:
config = super().default_config()
config["public_baseurl"] = BASE_URL
- saml_config = {
+ saml_config: Dict[str, Any] = {
"sp_config": {"metadata": {}},
# Disable grandfathering.
"grandfathered_mxid_source_attribute": None,
@@ -98,7 +102,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
return config
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver()
self.handler = hs.get_saml_handler()
@@ -114,7 +118,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
elif not has_xmlsec1:
skip = "Requires xmlsec1"
- def test_map_saml_response_to_user(self):
+ def test_map_saml_response_to_user(self) -> None:
"""Ensure that mapping the SAML response returned from a provider to an MXID works properly."""
# stub out the auth handler
@@ -140,7 +144,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
)
@override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
- def test_map_saml_response_to_existing_user(self):
+ def test_map_saml_response_to_existing_user(self) -> None:
"""Existing users can log in with SAML account."""
store = self.hs.get_datastores().main
self.get_success(
@@ -186,7 +190,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
auth_provider_session_id=None,
)
- def test_map_saml_response_to_invalid_localpart(self):
+ def test_map_saml_response_to_invalid_localpart(self) -> None:
"""If the mapping provider generates an invalid localpart it should be rejected."""
# stub out the auth handler
@@ -207,7 +211,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
)
auth_handler.complete_sso_login.assert_not_called()
- def test_map_saml_response_to_user_retries(self):
+ def test_map_saml_response_to_user_retries(self) -> None:
"""The mapping provider can retry generating an MXID if the MXID is already in use."""
# stub out the auth handler and error renderer
@@ -271,7 +275,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
}
}
)
- def test_map_saml_response_redirect(self):
+ def test_map_saml_response_redirect(self) -> None:
"""Test a mapping provider that raises a RedirectException"""
saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"})
@@ -292,7 +296,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
},
}
)
- def test_attribute_requirements(self):
+ def test_attribute_requirements(self) -> None:
"""The required attributes must be met from the SAML response."""
# stub out the auth handler
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index f91a80b9fa..ffd5c4cb93 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -18,11 +18,14 @@ from typing import Dict
from unittest.mock import ANY, Mock, call
from twisted.internet import defer
+from twisted.test.proto_helpers import MemoryReactor
from twisted.web.resource import Resource
from synapse.api.errors import AuthError
from synapse.federation.transport.server import TransportLayerServer
-from synapse.types import UserID, create_requester
+from synapse.server import HomeServer
+from synapse.types import JsonDict, UserID, create_requester
+from synapse.util import Clock
from tests import unittest
from tests.test_utils import make_awaitable
@@ -42,7 +45,9 @@ ROOM_ID = "a-room"
OTHER_ROOM_ID = "another-room"
-def _expect_edu_transaction(edu_type, content, origin="test"):
+def _expect_edu_transaction(
+ edu_type: str, content: JsonDict, origin: str = "test"
+) -> JsonDict:
return {
"origin": origin,
"origin_server_ts": 1000000,
@@ -51,12 +56,12 @@ def _expect_edu_transaction(edu_type, content, origin="test"):
}
-def _make_edu_transaction_json(edu_type, content):
+def _make_edu_transaction_json(edu_type: str, content: JsonDict) -> bytes:
return json.dumps(_expect_edu_transaction(edu_type, content)).encode("utf8")
class TypingNotificationsTestCase(unittest.HomeserverTestCase):
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
# we mock out the keyring so as to skip the authentication check on the
# federation API call.
mock_keyring = Mock(spec=["verify_json_for_server"])
@@ -83,7 +88,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
d["/_matrix/federation"] = TransportLayerServer(self.hs)
return d
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
mock_notifier = hs.get_notifier()
self.on_new_event = mock_notifier.on_new_event
@@ -111,24 +116,24 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.room_members = []
- async def check_user_in_room(room_id, user_id):
+ async def check_user_in_room(room_id: str, user_id: str) -> None:
if user_id not in [u.to_string() for u in self.room_members]:
raise AuthError(401, "User is not in the room")
return None
hs.get_auth().check_user_in_room = check_user_in_room
- async def check_host_in_room(room_id, server_name):
+ async def check_host_in_room(room_id: str, server_name: str) -> bool:
return room_id == ROOM_ID
hs.get_event_auth_handler().check_host_in_room = check_host_in_room
- def get_joined_hosts_for_room(room_id):
+ def get_joined_hosts_for_room(room_id: str):
return {member.domain for member in self.room_members}
self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room
- async def get_users_in_room(room_id):
+ async def get_users_in_room(room_id: str):
return {str(u) for u in self.room_members}
self.datastore.get_users_in_room = get_users_in_room
@@ -153,7 +158,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
lambda *args, **kwargs: make_awaitable(None)
)
- def test_started_typing_local(self):
+ def test_started_typing_local(self) -> None:
self.room_members = [U_APPLE, U_BANANA]
self.assertEqual(self.event_source.get_current_key(), 0)
@@ -187,7 +192,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
)
@override_config({"send_federation": True})
- def test_started_typing_remote_send(self):
+ def test_started_typing_remote_send(self) -> None:
self.room_members = [U_APPLE, U_ONION]
self.get_success(
@@ -217,7 +222,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
try_trailing_slash_on_400=True,
)
- def test_started_typing_remote_recv(self):
+ def test_started_typing_remote_recv(self) -> None:
self.room_members = [U_APPLE, U_ONION]
self.assertEqual(self.event_source.get_current_key(), 0)
@@ -256,7 +261,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
],
)
- def test_started_typing_remote_recv_not_in_room(self):
+ def test_started_typing_remote_recv_not_in_room(self) -> None:
self.room_members = [U_APPLE, U_ONION]
self.assertEqual(self.event_source.get_current_key(), 0)
@@ -292,7 +297,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.assertEqual(events[1], 0)
@override_config({"send_federation": True})
- def test_stopped_typing(self):
+ def test_stopped_typing(self) -> None:
self.room_members = [U_APPLE, U_BANANA, U_ONION]
# Gut-wrenching
@@ -343,7 +348,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
[{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}],
)
- def test_typing_timeout(self):
+ def test_typing_timeout(self) -> None:
self.room_members = [U_APPLE, U_BANANA]
self.assertEqual(self.event_source.get_current_key(), 0)
|