diff --git a/tests/handlers/test_admin.py b/tests/handlers/test_admin.py
index 6f300b8e11..5569ccef8a 100644
--- a/tests/handlers/test_admin.py
+++ b/tests/handlers/test_admin.py
@@ -23,6 +23,7 @@ from synapse.api.constants import EventTypes, JoinRules
from synapse.api.room_versions import RoomVersions
from synapse.rest.client import knock, login, room
from synapse.server import HomeServer
+from synapse.types import UserID
from synapse.util import Clock
from tests import unittest
@@ -296,3 +297,58 @@ class ExfiltrateData(unittest.HomeserverTestCase):
self.assertEqual(args[0][0]["user_agent"], "user_agent")
self.assertGreater(args[0][0]["last_seen"], 0)
self.assertNotIn("access_token", args[0][0])
+
+ def test_account_data(self) -> None:
+ """Tests that user account data get exported."""
+ # add account data
+ self.get_success(
+ self._store.add_account_data_for_user(self.user2, "m.global", {"a": 1})
+ )
+ self.get_success(
+ self._store.add_account_data_to_room(
+ self.user2, "test_room", "m.per_room", {"b": 2}
+ )
+ )
+
+ writer = Mock()
+
+ self.get_success(self.admin_handler.export_user_data(self.user2, writer))
+
+ # two calls, one call for user data and one call for room data
+ writer.write_account_data.assert_called()
+
+ args = writer.write_account_data.call_args_list[0][0]
+ self.assertEqual(args[0], "global")
+ self.assertEqual(args[1]["m.global"]["a"], 1)
+
+ args = writer.write_account_data.call_args_list[1][0]
+ self.assertEqual(args[0], "test_room")
+ self.assertEqual(args[1]["m.per_room"]["b"], 2)
+
+ def test_media_ids(self) -> None:
+ """Tests that media's metadata get exported."""
+
+ self.get_success(
+ self._store.store_local_media(
+ media_id="media_1",
+ media_type="image/png",
+ time_now_ms=self.clock.time_msec(),
+ upload_name=None,
+ media_length=50,
+ user_id=UserID.from_string(self.user2),
+ )
+ )
+
+ writer = Mock()
+
+ self.get_success(self.admin_handler.export_user_data(self.user2, writer))
+
+ writer.write_media_id.assert_called_once()
+
+ args = writer.write_media_id.call_args[0]
+ self.assertEqual(args[0], "media_1")
+ self.assertEqual(args[1]["media_id"], "media_1")
+ self.assertEqual(args[1]["media_length"], 50)
+ self.assertGreater(args[1]["created_ts"], 0)
+ self.assertIsNone(args[1]["upload_name"])
+ self.assertIsNone(args[1]["last_access_ts"])
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index a7495ab21a..9014e60577 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -899,7 +899,7 @@ class ApplicationServicesHandlerDeviceListsTestCase(unittest.HomeserverTestCase)
# Mock out application services, and allow defining our own in tests
self._services: List[ApplicationService] = []
- self.hs.get_datastores().main.get_app_services = Mock(
+ self.hs.get_datastores().main.get_app_services = Mock( # type: ignore[assignment]
return_value=self._services
)
diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py
index 2733719d82..63aad0d10c 100644
--- a/tests/handlers/test_cas.py
+++ b/tests/handlers/test_cas.py
@@ -61,7 +61,7 @@ class CasHandlerTestCase(HomeserverTestCase):
# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock()
+ auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
cas_response = CasResponse("test_user", {})
request = _mock_request()
@@ -89,7 +89,7 @@ class CasHandlerTestCase(HomeserverTestCase):
# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock()
+ auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
# Map a user via SSO.
cas_response = CasResponse("test_user", {})
@@ -129,7 +129,7 @@ class CasHandlerTestCase(HomeserverTestCase):
# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock()
+ auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
cas_response = CasResponse("föö", {})
request = _mock_request()
@@ -160,7 +160,7 @@ class CasHandlerTestCase(HomeserverTestCase):
# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock()
+ auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
# The response doesn't have the proper userGroup or department.
cas_response = CasResponse("test_user", {})
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 95698bc275..6b4cba65d0 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -23,6 +23,7 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import RoomEncryptionAlgorithms
from synapse.api.errors import Codes, SynapseError
+from synapse.handlers.device import DeviceHandler
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock
@@ -187,37 +188,37 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
)
# we should now have an unused alg1 key
- res = self.get_success(
+ fallback_res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
)
- self.assertEqual(res, ["alg1"])
+ self.assertEqual(fallback_res, ["alg1"])
# claiming an OTK when no OTKs are available should return the fallback
# key
- res = self.get_success(
+ claim_res = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
)
)
self.assertEqual(
- res,
+ claim_res,
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
)
# we shouldn't have any unused fallback keys again
- res = self.get_success(
+ unused_res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
)
- self.assertEqual(res, [])
+ self.assertEqual(unused_res, [])
# claiming an OTK again should return the same fallback key
- res = self.get_success(
+ claim_res = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
)
)
self.assertEqual(
- res,
+ claim_res,
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
)
@@ -231,10 +232,10 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
)
)
- res = self.get_success(
+ unused_res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
)
- self.assertEqual(res, [])
+ self.assertEqual(unused_res, [])
# uploading a new fallback key should result in an unused fallback key
self.get_success(
@@ -245,10 +246,10 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
)
)
- res = self.get_success(
+ unused_res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
)
- self.assertEqual(res, ["alg1"])
+ self.assertEqual(unused_res, ["alg1"])
# if the user uploads a one-time key, the next claim should fetch the
# one-time key, and then go back to the fallback
@@ -258,23 +259,23 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
)
)
- res = self.get_success(
+ claim_res = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
)
)
self.assertEqual(
- res,
+ claim_res,
{"failures": {}, "one_time_keys": {local_user: {device_id: otk}}},
)
- res = self.get_success(
+ claim_res = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
)
)
self.assertEqual(
- res,
+ claim_res,
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key2}}},
)
@@ -287,13 +288,13 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
)
)
- res = self.get_success(
+ claim_res = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
)
)
self.assertEqual(
- res,
+ claim_res,
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}},
)
@@ -366,7 +367,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys1))
# upload two device keys, which will be signed later by the self-signing key
- device_key_1 = {
+ device_key_1: JsonDict = {
"user_id": local_user,
"device_id": "abc",
"algorithms": [
@@ -379,7 +380,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
},
"signatures": {local_user: {"ed25519:abc": "base64+signature"}},
}
- device_key_2 = {
+ device_key_2: JsonDict = {
"user_id": local_user,
"device_id": "def",
"algorithms": [
@@ -451,8 +452,10 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
}
self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys1))
+ device_handler = self.hs.get_device_handler()
+ assert isinstance(device_handler, DeviceHandler)
e = self.get_failure(
- self.hs.get_device_handler().check_device_registered(
+ device_handler.check_device_registered(
user_id=local_user,
device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk",
initial_device_display_name="new display name",
@@ -475,7 +478,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
device_id = "xyz"
# private key: OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA
device_pubkey = "NnHhnqiMFQkq969szYkooLaBAXW244ZOxgukCvm2ZeY"
- device_key = {
+ device_key: JsonDict = {
"user_id": local_user,
"device_id": device_id,
"algorithms": [
@@ -497,7 +500,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
# private key: 2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0
master_pubkey = "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk"
- master_key = {
+ master_key: JsonDict = {
"user_id": local_user,
"usage": ["master"],
"keys": {"ed25519:" + master_pubkey: master_pubkey},
@@ -540,7 +543,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
# the first user
other_user = "@otherboris:" + self.hs.hostname
other_master_pubkey = "fHZ3NPiKxoLQm5OoZbKa99SYxprOjNs4TwJUKP+twCM"
- other_master_key = {
+ other_master_key: JsonDict = {
# private key: oyw2ZUx0O4GifbfFYM0nQvj9CL0b8B7cyN4FprtK8OI
"user_id": other_user,
"usage": ["master"],
@@ -702,7 +705,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY"
remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
- self.hs.get_federation_client().query_client_keys = mock.Mock(
+ self.hs.get_federation_client().query_client_keys = mock.Mock( # type: ignore[assignment]
return_value=make_awaitable(
{
"device_keys": {remote_user_id: {}},
@@ -782,7 +785,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY"
remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
- self.hs.get_federation_client().query_user_devices = mock.Mock(
+ self.hs.get_federation_client().query_user_devices = mock.Mock( # type: ignore[assignment]
return_value=make_awaitable(
{
"user_id": remote_user_id,
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index 57675fa407..bf0862ed54 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -371,14 +371,14 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
# We mock out the FederationClient.backfill method, to pretend that a remote
# server has returned our fake event.
federation_client_backfill_mock = Mock(return_value=make_awaitable([event]))
- self.hs.get_federation_client().backfill = federation_client_backfill_mock
+ self.hs.get_federation_client().backfill = federation_client_backfill_mock # type: ignore[assignment]
# We also mock the persist method with a side effect of itself. This allows us
# to track when it has been called while preserving its function.
persist_events_and_notify_mock = Mock(
side_effect=self.hs.get_federation_event_handler().persist_events_and_notify
)
- self.hs.get_federation_event_handler().persist_events_and_notify = (
+ self.hs.get_federation_event_handler().persist_events_and_notify = ( # type: ignore[assignment]
persist_events_and_notify_mock
)
@@ -575,26 +575,6 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
fed_client = fed_handler.federation_client
room_id = "!room:example.com"
- membership_event = make_event_from_dict(
- {
- "room_id": room_id,
- "type": "m.room.member",
- "sender": "@alice:test",
- "state_key": "@alice:test",
- "content": {"membership": "join"},
- },
- RoomVersions.V10,
- )
-
- mock_make_membership_event = Mock(
- return_value=make_awaitable(
- (
- "example.com",
- membership_event,
- RoomVersions.V10,
- )
- )
- )
EVENT_CREATE = make_event_from_dict(
{
@@ -640,6 +620,26 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
},
room_version=RoomVersions.V10,
)
+ membership_event = make_event_from_dict(
+ {
+ "room_id": room_id,
+ "type": "m.room.member",
+ "sender": "@alice:test",
+ "state_key": "@alice:test",
+ "content": {"membership": "join"},
+ "prev_events": [EVENT_INVITATION_MEMBERSHIP.event_id],
+ },
+ RoomVersions.V10,
+ )
+ mock_make_membership_event = Mock(
+ return_value=make_awaitable(
+ (
+ "example.com",
+ membership_event,
+ RoomVersions.V10,
+ )
+ )
+ )
mock_send_join = Mock(
return_value=make_awaitable(
SendJoinResult(
@@ -712,12 +712,12 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room
), patch.object(store, "is_partial_state_room", mock_is_partial_state_room):
# Start the partial state sync.
- fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id")
+ fed_handler._start_partial_state_room_sync("hs1", {"hs2"}, "room_id")
self.assertEqual(mock_sync_partial_state_room.call_count, 1)
# Try to start another partial state sync.
# Nothing should happen.
- fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id")
+ fed_handler._start_partial_state_room_sync("hs3", {"hs2"}, "room_id")
self.assertEqual(mock_sync_partial_state_room.call_count, 1)
# End the partial state sync
@@ -729,7 +729,7 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
# The next attempt to start the partial state sync should work.
is_partial_state = True
- fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id")
+ fed_handler._start_partial_state_room_sync("hs3", {"hs2"}, "room_id")
self.assertEqual(mock_sync_partial_state_room.call_count, 2)
def test_partial_state_room_sync_restart(self) -> None:
@@ -764,7 +764,7 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room
), patch.object(store, "is_partial_state_room", mock_is_partial_state_room):
# Start the partial state sync.
- fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id")
+ fed_handler._start_partial_state_room_sync("hs1", {"hs2"}, "room_id")
self.assertEqual(mock_sync_partial_state_room.call_count, 1)
# Fail the partial state sync.
@@ -773,11 +773,11 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
self.assertEqual(mock_sync_partial_state_room.call_count, 1)
# Start the partial state sync again.
- fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id")
+ fed_handler._start_partial_state_room_sync("hs1", {"hs2"}, "room_id")
self.assertEqual(mock_sync_partial_state_room.call_count, 2)
# Deduplicate another partial state sync.
- fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id")
+ fed_handler._start_partial_state_room_sync("hs3", {"hs2"}, "room_id")
self.assertEqual(mock_sync_partial_state_room.call_count, 2)
# Fail the partial state sync.
@@ -786,6 +786,6 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
self.assertEqual(mock_sync_partial_state_room.call_count, 3)
mock_sync_partial_state_room.assert_called_with(
initial_destination="hs3",
- other_destinations=["hs2"],
+ other_destinations={"hs2"},
room_id="room_id",
)
diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py
index 70ea4d15d4..c067e5bfe3 100644
--- a/tests/handlers/test_federation_event.py
+++ b/tests/handlers/test_federation_event.py
@@ -29,6 +29,7 @@ from synapse.logging.context import LoggingContext
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
+from synapse.state import StateResolutionStore
from synapse.state.v2 import _mainline_sort, _reverse_topological_power_sort
from synapse.types import JsonDict
from synapse.util import Clock
@@ -161,6 +162,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
if prev_exists_as_outlier:
prev_event.internal_metadata.outlier = True
persistence = self.hs.get_storage_controllers().persistence
+ assert persistence is not None
self.get_success(
persistence.persist_event(
prev_event,
@@ -861,7 +863,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
bert_member_event.event_id: bert_member_event,
rejected_kick_event.event_id: rejected_kick_event,
},
- state_res_store=main_store,
+ state_res_store=StateResolutionStore(main_store),
)
),
[bert_member_event.event_id, rejected_kick_event.event_id],
@@ -906,7 +908,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
rejected_power_levels_event.event_id,
],
event_map={},
- state_res_store=main_store,
+ state_res_store=StateResolutionStore(main_store),
full_conflicted_set=set(),
)
),
diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py
index c4727ab917..9691d66b48 100644
--- a/tests/handlers/test_message.py
+++ b/tests/handlers/test_message.py
@@ -18,7 +18,7 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes
from synapse.events import EventBase
-from synapse.events.snapshot import EventContext
+from synapse.events.snapshot import EventContext, UnpersistedEventContextBase
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
@@ -41,20 +41,21 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.handler = self.hs.get_event_creation_handler()
- self._persist_event_storage_controller = (
- self.hs.get_storage_controllers().persistence
- )
+ persistence = self.hs.get_storage_controllers().persistence
+ assert persistence is not None
+ self._persist_event_storage_controller = persistence
self.user_id = self.register_user("tester", "foobar")
self.access_token = self.login("tester", "foobar")
self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token)
- self.info = self.get_success(
+ info = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(
self.access_token,
)
)
- self.token_id = self.info.token_id
+ assert info is not None
+ self.token_id = info.token_id
self.requester = create_requester(self.user_id, access_token_id=self.token_id)
@@ -78,7 +79,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
return memberEvent, memberEventContext
- def _create_duplicate_event(self, txn_id: str) -> Tuple[EventBase, EventContext]:
+ def _create_duplicate_event(
+ self, txn_id: str
+ ) -> Tuple[EventBase, UnpersistedEventContextBase]:
"""Create a new event with the given transaction ID. All events produced
by this method will be considered duplicates.
"""
@@ -106,7 +109,8 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
txn_id = "something_suitably_random"
- event1, context = self._create_duplicate_event(txn_id)
+ event1, unpersisted_context = self._create_duplicate_event(txn_id)
+ context = self.get_success(unpersisted_context.persist(event1))
ret_event1 = self.get_success(
self.handler.handle_new_client_event(
@@ -118,7 +122,8 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
self.assertEqual(event1.event_id, ret_event1.event_id)
- event2, context = self._create_duplicate_event(txn_id)
+ event2, unpersisted_context = self._create_duplicate_event(txn_id)
+ context = self.get_success(unpersisted_context.persist(event2))
# We want to test that the deduplication at the persit event end works,
# so we want to make sure we test with different events.
@@ -139,7 +144,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
# Let's test that calling `persist_event` directly also does the right
# thing.
- event3, context = self._create_duplicate_event(txn_id)
+ event3, unpersisted_context = self._create_duplicate_event(txn_id)
+ context = self.get_success(unpersisted_context.persist(event3))
+
self.assertNotEqual(event1.event_id, event3.event_id)
ret_event3, event_pos3, _ = self.get_success(
@@ -153,7 +160,8 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
# Let's test that calling `persist_events` directly also does the right
# thing.
- event4, context = self._create_duplicate_event(txn_id)
+ event4, unpersisted_context = self._create_duplicate_event(txn_id)
+ context = self.get_success(unpersisted_context.persist(event4))
self.assertNotEqual(event1.event_id, event3.event_id)
events, _ = self.get_success(
@@ -173,8 +181,10 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
txn_id = "something_else_suitably_random"
# Create two duplicate events to persist at the same time
- event1, context1 = self._create_duplicate_event(txn_id)
- event2, context2 = self._create_duplicate_event(txn_id)
+ event1, unpersisted_context1 = self._create_duplicate_event(txn_id)
+ context1 = self.get_success(unpersisted_context1.persist(event1))
+ event2, unpersisted_context2 = self._create_duplicate_event(txn_id)
+ context2 = self.get_success(unpersisted_context2.persist(event2))
# Ensure their event IDs are different to start with
self.assertNotEqual(event1.event_id, event2.event_id)
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index ef5311ce64..bb52b3b1af 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -150,7 +150,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
hs = self.setup_test_homeserver()
self.hs_patcher = self.fake_server.patch_homeserver(hs=hs)
- self.hs_patcher.start()
+ self.hs_patcher.start() # type: ignore[attr-defined]
self.handler = hs.get_oidc_handler()
self.provider = self.handler._providers["oidc"]
@@ -170,7 +170,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
return hs
def tearDown(self) -> None:
- self.hs_patcher.stop()
+ self.hs_patcher.stop() # type: ignore[attr-defined]
return super().tearDown()
def reset_mocks(self) -> None:
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index 0916de64f5..aa91bc0a3d 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -852,7 +852,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
username: The username to use for the test.
registration: Whether to test with registration URLs.
"""
- self.hs.get_identity_handler().send_threepid_validation = Mock(
+ self.hs.get_identity_handler().send_threepid_validation = Mock( # type: ignore[assignment]
return_value=make_awaitable(0),
)
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index b9332d97dc..aff1ec4758 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -62,7 +62,7 @@ class TestSpamChecker:
request_info: Collection[Tuple[str, str]],
auth_provider_id: Optional[str],
) -> RegistrationBehaviour:
- pass
+ return RegistrationBehaviour.ALLOW
class DenyAll(TestSpamChecker):
@@ -111,7 +111,7 @@ class TestLegacyRegistrationSpamChecker:
username: Optional[str],
request_info: Collection[Tuple[str, str]],
) -> RegistrationBehaviour:
- pass
+ return RegistrationBehaviour.ALLOW
class LegacyAllowAll(TestLegacyRegistrationSpamChecker):
@@ -203,7 +203,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
@override_config({"limit_usage_by_mau": True})
def test_get_or_create_user_mau_not_blocked(self) -> None:
- self.store.count_monthly_users = Mock(
+ self.store.count_monthly_users = Mock( # type: ignore[assignment]
return_value=make_awaitable(self.hs.config.server.max_mau_value - 1)
)
# Ensure does not throw exception
@@ -304,7 +304,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
def test_auto_create_auto_join_rooms_when_user_is_the_first_real_user(self) -> None:
room_alias_str = "#room:test"
- self.store.count_real_users = Mock(return_value=make_awaitable(1))
+ self.store.count_real_users = Mock(return_value=make_awaitable(1)) # type: ignore[assignment]
self.store.is_real_user = Mock(return_value=make_awaitable(True))
user_id = self.get_success(self.handler.register_user(localpart="real"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
@@ -319,7 +319,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
def test_auto_create_auto_join_rooms_when_user_is_not_the_first_real_user(
self,
) -> None:
- self.store.count_real_users = Mock(return_value=make_awaitable(2))
+ self.store.count_real_users = Mock(return_value=make_awaitable(2)) # type: ignore[assignment]
self.store.is_real_user = Mock(return_value=make_awaitable(True))
user_id = self.get_success(self.handler.register_user(localpart="real"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
@@ -346,6 +346,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
# Ensure the room is properly not federated.
room = self.get_success(self.store.get_room_with_stats(room_id["room_id"]))
+ assert room is not None
self.assertFalse(room["federatable"])
self.assertFalse(room["public"])
self.assertEqual(room["join_rules"], "public")
@@ -375,6 +376,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
# Ensure the room is properly a public room.
room = self.get_success(self.store.get_room_with_stats(room_id["room_id"]))
+ assert room is not None
self.assertEqual(room["join_rules"], "public")
# Both users should be in the room.
@@ -413,6 +415,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
# Ensure the room is properly a private room.
room = self.get_success(self.store.get_room_with_stats(room_id["room_id"]))
+ assert room is not None
self.assertFalse(room["public"])
self.assertEqual(room["join_rules"], "invite")
self.assertEqual(room["guest_access"], "can_join")
@@ -456,6 +459,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
# Ensure the room is properly a private room.
room = self.get_success(self.store.get_room_with_stats(room_id["room_id"]))
+ assert room is not None
self.assertFalse(room["public"])
self.assertEqual(room["join_rules"], "invite")
self.assertEqual(room["guest_access"], "can_join")
@@ -503,7 +507,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
# Lower the permissions of the inviter.
event_creation_handler = self.hs.get_event_creation_handler()
requester = create_requester(inviter)
- event, context = self.get_success(
+ event, unpersisted_context = self.get_success(
event_creation_handler.create_event(
requester,
{
@@ -515,6 +519,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
},
)
)
+ context = self.get_success(unpersisted_context.persist(event))
self.get_success(
event_creation_handler.handle_new_client_event(
requester, events_and_context=[(event, context)]
diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py
index 9b1b8b9f13..b5c772a7ae 100644
--- a/tests/handlers/test_saml.py
+++ b/tests/handlers/test_saml.py
@@ -134,7 +134,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock()
+ auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
# send a mocked-up SAML response to the callback
saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})
@@ -164,7 +164,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock()
+ auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
# Map a user via SSO.
saml_response = FakeAuthnResponse(
@@ -206,11 +206,11 @@ class SamlHandlerTestCase(HomeserverTestCase):
# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock()
+ auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
# mock out the error renderer too
sso_handler = self.hs.get_sso_handler()
- sso_handler.render_error = Mock(return_value=None)
+ sso_handler.render_error = Mock(return_value=None) # type: ignore[assignment]
saml_response = FakeAuthnResponse({"uid": "test", "username": "föö"})
request = _mock_request()
@@ -227,9 +227,9 @@ class SamlHandlerTestCase(HomeserverTestCase):
# stub out the auth handler and error renderer
auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock()
+ auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
sso_handler = self.hs.get_sso_handler()
- sso_handler.render_error = Mock(return_value=None)
+ sso_handler.render_error = Mock(return_value=None) # type: ignore[assignment]
# register a user to occupy the first-choice MXID
store = self.hs.get_datastores().main
@@ -312,7 +312,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock()
+ auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
# The response doesn't have the proper userGroup or department.
saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})
diff --git a/tests/handlers/test_sso.py b/tests/handlers/test_sso.py
index 137deab138..d6f43a98fc 100644
--- a/tests/handlers/test_sso.py
+++ b/tests/handlers/test_sso.py
@@ -113,7 +113,6 @@ async def mock_get_file(
headers: Optional[RawHeaders] = None,
is_allowed_content_type: Optional[Callable[[str], bool]] = None,
) -> Tuple[int, Dict[bytes, List[bytes]], str, int]:
-
fake_response = FakeResponse(code=404)
if url == "http://my.server/me.png":
fake_response = FakeResponse(
diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py
index f1a50c5bcb..d11ded6c5b 100644
--- a/tests/handlers/test_stats.py
+++ b/tests/handlers/test_stats.py
@@ -31,7 +31,6 @@ EXPT_NUM_STATE_EVTS_IN_FRESH_PRIVATE_ROOM = 6
class StatsRoomTests(unittest.HomeserverTestCase):
-
servlets = [
admin.register_servlets_for_client_rest_resource,
room.register_servlets,
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 1fe9563c98..94518a7196 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -74,8 +74,8 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
mock_keyring.verify_json_for_server.return_value = make_awaitable(True)
# we mock out the federation client too
- mock_federation_client = Mock(spec=["put_json"])
- mock_federation_client.put_json.return_value = make_awaitable((200, "OK"))
+ self.mock_federation_client = Mock(spec=["put_json"])
+ self.mock_federation_client.put_json.return_value = make_awaitable((200, "OK"))
# the tests assume that we are starting at unix time 1000
reactor.pump((1000,))
@@ -83,7 +83,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.mock_hs_notifier = Mock()
hs = self.setup_test_homeserver(
notifier=self.mock_hs_notifier,
- federation_http_client=mock_federation_client,
+ federation_http_client=self.mock_federation_client,
keyring=mock_keyring,
replication_streams={},
)
@@ -233,8 +233,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
)
)
- put_json = self.hs.get_federation_http_client().put_json
- put_json.assert_called_once_with(
+ self.mock_federation_client.put_json.assert_called_once_with(
"farm",
path="/_matrix/federation/v1/send/1000000",
data=_expect_edu_transaction(
@@ -349,8 +348,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])])
- put_json = self.hs.get_federation_http_client().put_json
- put_json.assert_called_once_with(
+ self.mock_federation_client.put_json.assert_called_once_with(
"farm",
path="/_matrix/federation/v1/send/1000000",
data=_expect_edu_transaction(
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 75fc5a17a4..a02c1c6227 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Tuple
+from typing import Any, Tuple
from unittest.mock import Mock, patch
from urllib.parse import quote
@@ -24,7 +24,7 @@ from synapse.appservice import ApplicationService
from synapse.rest.client import login, register, room, user_directory
from synapse.server import HomeServer
from synapse.storage.roommember import ProfileInfo
-from synapse.types import create_requester
+from synapse.types import UserProfile, create_requester
from synapse.util import Clock
from tests import unittest
@@ -34,6 +34,12 @@ from tests.test_utils.event_injection import inject_member_event
from tests.unittest import override_config
+# A spam checker which doesn't implement anything, so create a bare object.
+class UselessSpamChecker:
+ def __init__(self, config: Any):
+ pass
+
+
class UserDirectoryTestCase(unittest.HomeserverTestCase):
"""Tests the UserDirectoryHandler.
@@ -186,6 +192,13 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.helper.join(room, self.appservice.sender, tok=self.appservice.token)
self._check_only_one_user_in_directory(user, room)
+ def test_search_term_with_colon_in_it_does_not_raise(self) -> None:
+ """
+ Regression test: Test that search terms with colons in them are acceptable.
+ """
+ u1 = self.register_user("user1", "pass")
+ self.get_success(self.handler.search_users(u1, "haha:paamayim-nekudotayim", 10))
+
def test_user_not_in_users_table(self) -> None:
"""Unclear how it happens, but on matrix.org we've seen join events
for users who aren't in the users table. Test that we don't fall over
@@ -773,7 +786,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
s = self.get_success(self.handler.search_users(u1, "user2", 10))
self.assertEqual(len(s["results"]), 1)
- async def allow_all(user_profile: ProfileInfo) -> bool:
+ async def allow_all(user_profile: UserProfile) -> bool:
# Allow all users.
return False
@@ -787,7 +800,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(s["results"]), 1)
# Configure a spam checker that filters all users.
- async def block_all(user_profile: ProfileInfo) -> bool:
+ async def block_all(user_profile: UserProfile) -> bool:
# All users are spammy.
return True
@@ -797,6 +810,13 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
s = self.get_success(self.handler.search_users(u1, "user2", 10))
self.assertEqual(len(s["results"]), 0)
+ @override_config(
+ {
+ "spam_checker": {
+ "module": "tests.handlers.test_user_directory.UselessSpamChecker"
+ }
+ }
+ )
def test_legacy_spam_checker(self) -> None:
"""
A spam checker without the expected method should be ignored.
@@ -825,11 +845,6 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.assertEqual(shares_private, {(u1, u2, room), (u2, u1, room)})
self.assertEqual(public_users, set())
- # Configure a spam checker.
- spam_checker = self.hs.get_spam_checker()
- # The spam checker doesn't need any methods, so create a bare object.
- spam_checker.spam_checker = object()
-
# We get one search result when searching for user2 by user1.
s = self.get_success(self.handler.search_users(u1, "user2", 10))
self.assertEqual(len(s["results"]), 1)
@@ -949,13 +964,14 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
},
)
- event, context = self.get_success(
+ event, unpersisted_context = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(
- self.hs.get_storage_controllers().persistence.persist_event(event, context)
- )
+ context = self.get_success(unpersisted_context.persist(event))
+ persistence = self.hs.get_storage_controllers().persistence
+ assert persistence is not None
+ self.get_success(persistence.persist_event(event, context))
def test_local_user_leaving_room_remains_in_user_directory(self) -> None:
"""We've chosen to simplify the user directory's implementation by
|