diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py
index 5d89ba94ad..2ee343d8a4 100644
--- a/tests/app/test_openid_listener.py
+++ b/tests/app/test_openid_listener.py
@@ -67,7 +67,9 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase):
}
# Listen with the config
- self.hs._listen_http(parse_listener_def(0, config))
+ hs = self.hs
+ assert isinstance(hs, GenericWorkerServer)
+ hs._listen_http(parse_listener_def(0, config))
# Grab the resource from the site that was told to listen
site = self.reactor.tcpServers[0][1]
@@ -115,7 +117,9 @@ class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase):
}
# Listen with the config
- self.hs._listener_http(self.hs.config, parse_listener_def(0, config))
+ hs = self.hs
+ assert isinstance(hs, SynapseHomeServer)
+ hs._listener_http(self.hs.config, parse_listener_def(0, config))
# Grab the resource from the site that was told to listen
site = self.reactor.tcpServers[0][1]
diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py
index febcc1499d..e2a3bad065 100644
--- a/tests/appservice/test_scheduler.py
+++ b/tests/appservice/test_scheduler.py
@@ -11,12 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import TYPE_CHECKING, List, Optional, Sequence, Tuple, cast
+from typing import List, Optional, Sequence, Tuple, cast
from unittest.mock import Mock
from typing_extensions import TypeAlias
from twisted.internet import defer
+from twisted.test.proto_helpers import MemoryReactor
from synapse.appservice import (
ApplicationService,
@@ -40,9 +41,6 @@ from tests.test_utils import simple_async_mock
from ..utils import MockClock
-if TYPE_CHECKING:
- from twisted.internet.testing import MemoryReactor
-
class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
def setUp(self) -> None:
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 0e8af2da54..1b9696748f 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -192,7 +192,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
key1 = signedjson.key.generate_signing_key("1")
r = self.hs.get_datastores().main.store_server_verify_keys(
"server9",
- time.time() * 1000,
+ int(time.time() * 1000),
[("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), 1000))],
)
self.get_success(r)
@@ -287,7 +287,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
key1 = signedjson.key.generate_signing_key("1")
r = self.hs.get_datastores().main.store_server_verify_keys(
"server9",
- time.time() * 1000,
+ int(time.time() * 1000),
# None is not a valid value in FetchKeyResult, but we're abusing this
# API to insert null values into the database. The nulls get converted
# to 0 when fetched in KeyStore.get_server_verify_keys.
@@ -466,9 +466,9 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
key_json = self.get_success(
self.hs.get_datastores().main.get_server_keys_json([lookup_triplet])
)
- res = key_json[lookup_triplet]
- self.assertEqual(len(res), 1)
- res = res[0]
+ res_keys = key_json[lookup_triplet]
+ self.assertEqual(len(res_keys), 1)
+ res = res_keys[0]
self.assertEqual(res["key_id"], testverifykey_id)
self.assertEqual(res["from_server"], SERVER_NAME)
self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
@@ -584,9 +584,9 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
key_json = self.get_success(
self.hs.get_datastores().main.get_server_keys_json([lookup_triplet])
)
- res = key_json[lookup_triplet]
- self.assertEqual(len(res), 1)
- res = res[0]
+ res_keys = key_json[lookup_triplet]
+ self.assertEqual(len(res_keys), 1)
+ res = res_keys[0]
self.assertEqual(res["key_id"], testverifykey_id)
self.assertEqual(res["from_server"], self.mock_perspective_server.server_name)
self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
@@ -705,9 +705,9 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
key_json = self.get_success(
self.hs.get_datastores().main.get_server_keys_json([lookup_triplet])
)
- res = key_json[lookup_triplet]
- self.assertEqual(len(res), 1)
- res = res[0]
+ res_keys = key_json[lookup_triplet]
+ self.assertEqual(len(res_keys), 1)
+ res = res_keys[0]
self.assertEqual(res["key_id"], testverifykey_id)
self.assertEqual(res["from_server"], self.mock_perspective_server.server_name)
self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py
index a9893def74..6fb1f1bd6e 100644
--- a/tests/events/test_presence_router.py
+++ b/tests/events/test_presence_router.py
@@ -31,7 +31,11 @@ from synapse.util import Clock
from tests.handlers.test_sync import generate_sync_config
from tests.test_utils import simple_async_mock
-from tests.unittest import FederatingHomeserverTestCase, override_config
+from tests.unittest import (
+ FederatingHomeserverTestCase,
+ HomeserverTestCase,
+ override_config,
+)
@attr.s
@@ -152,11 +156,11 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
# Mock out the calls over federation.
- fed_transport_client = Mock(spec=["send_transaction"])
- fed_transport_client.send_transaction = simple_async_mock({})
+ self.fed_transport_client = Mock(spec=["send_transaction"])
+ self.fed_transport_client.send_transaction = simple_async_mock({})
hs = self.setup_test_homeserver(
- federation_transport_client=fed_transport_client,
+ federation_transport_client=self.fed_transport_client,
)
load_legacy_presence_router(hs)
@@ -418,7 +422,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
#
# Thus we reset the mock, and try sending all online local user
# presence again
- self.hs.get_federation_transport_client().send_transaction.reset_mock()
+ self.fed_transport_client.send_transaction.reset_mock()
# Broadcast local user online presence
self.get_success(
@@ -443,9 +447,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
}
found_users = set()
- calls = (
- self.hs.get_federation_transport_client().send_transaction.call_args_list
- )
+ calls = self.fed_transport_client.send_transaction.call_args_list
for call in calls:
call_args = call[0]
federation_transaction: Transaction = call_args[0]
@@ -470,7 +472,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
def send_presence_update(
- testcase: FederatingHomeserverTestCase,
+ testcase: HomeserverTestCase,
user_id: str,
access_token: str,
presence_state: str,
@@ -491,7 +493,7 @@ def send_presence_update(
def sync_presence(
- testcase: FederatingHomeserverTestCase,
+ testcase: HomeserverTestCase,
user_id: str,
since_token: Optional[StreamToken] = None,
) -> Tuple[List[UserPresenceState], StreamToken]:
diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py
index d667dd27bf..35dd9a20df 100644
--- a/tests/federation/test_complexity.py
+++ b/tests/federation/test_complexity.py
@@ -17,7 +17,7 @@ from unittest.mock import Mock
from synapse.api.errors import Codes, SynapseError
from synapse.rest import admin
from synapse.rest.client import login, room
-from synapse.types import JsonDict, UserID
+from synapse.types import JsonDict, UserID, create_requester
from tests import unittest
from tests.test_utils import make_awaitable
@@ -56,7 +56,11 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
# Artificially raise the complexity
store = self.hs.get_datastores().main
- store.get_current_state_event_counts = lambda x: make_awaitable(500 * 1.23)
+
+ async def get_current_state_event_counts(room_id: str) -> int:
+ return int(500 * 1.23)
+
+ store.get_current_state_event_counts = get_current_state_event_counts # type: ignore[assignment]
# Get the room complexity again -- make sure it's our artificial value
channel = self.make_signed_federation_request(
@@ -75,12 +79,12 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
- handler.federation_handler.do_invite_join = Mock(
+ handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
return_value=make_awaitable(("", 1))
)
d = handler._remote_join(
- None,
+ create_requester(u1),
["other.example.com"],
"roomid",
UserID.from_string(u1),
@@ -106,12 +110,12 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
- handler.federation_handler.do_invite_join = Mock(
+ handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
return_value=make_awaitable(("", 1))
)
d = handler._remote_join(
- None,
+ create_requester(u1),
["other.example.com"],
"roomid",
UserID.from_string(u1),
@@ -144,17 +148,18 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=make_awaitable(None))
- handler.federation_handler.do_invite_join = Mock(
+ handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
return_value=make_awaitable(("", 1))
)
# Artificially raise the complexity
- self.hs.get_datastores().main.get_current_state_event_counts = (
- lambda x: make_awaitable(600)
- )
+ async def get_current_state_event_counts(room_id: str) -> int:
+ return 600
+
+ self.hs.get_datastores().main.get_current_state_event_counts = get_current_state_event_counts # type: ignore[assignment]
d = handler._remote_join(
- None,
+ create_requester(u1),
["other.example.com"],
room_1,
UserID.from_string(u1),
@@ -200,12 +205,12 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
- handler.federation_handler.do_invite_join = Mock(
+ handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
return_value=make_awaitable(("", 1))
)
d = handler._remote_join(
- None,
+ create_requester(u1),
["other.example.com"],
"roomid",
UserID.from_string(u1),
@@ -230,12 +235,12 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
- handler.federation_handler.do_invite_join = Mock(
+ handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
return_value=make_awaitable(("", 1))
)
d = handler._remote_join(
- None,
+ create_requester(u1),
["other.example.com"],
"roomid",
UserID.from_string(u1),
diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py
index a986b15f0a..6381583c24 100644
--- a/tests/federation/test_federation_catch_up.py
+++ b/tests/federation/test_federation_catch_up.py
@@ -5,7 +5,11 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes
from synapse.events import EventBase
-from synapse.federation.sender import PerDestinationQueue, TransactionManager
+from synapse.federation.sender import (
+ FederationSender,
+ PerDestinationQueue,
+ TransactionManager,
+)
from synapse.federation.units import Edu, Transaction
from synapse.rest import admin
from synapse.rest.client import login, room
@@ -33,8 +37,9 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+ self.federation_transport_client = Mock(spec=["send_transaction"])
return self.setup_test_homeserver(
- federation_transport_client=Mock(spec=["send_transaction"]),
+ federation_transport_client=self.federation_transport_client,
)
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
@@ -52,10 +57,14 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
self.pdus: List[JsonDict] = []
self.failed_pdus: List[JsonDict] = []
self.is_online = True
- self.hs.get_federation_transport_client().send_transaction.side_effect = (
+ self.federation_transport_client.send_transaction.side_effect = (
self.record_transaction
)
+ federation_sender = hs.get_federation_sender()
+ assert isinstance(federation_sender, FederationSender)
+ self.federation_sender = federation_sender
+
def default_config(self) -> JsonDict:
config = super().default_config()
config["federation_sender_instances"] = None
@@ -229,11 +238,11 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
# let's delete the federation transmission queue
# (this pretends we are starting up fresh.)
self.assertFalse(
- self.hs.get_federation_sender()
- ._per_destination_queues["host2"]
- .transmission_loop_running
+ self.federation_sender._per_destination_queues[
+ "host2"
+ ].transmission_loop_running
)
- del self.hs.get_federation_sender()._per_destination_queues["host2"]
+ del self.federation_sender._per_destination_queues["host2"]
# let's also clear any backoffs
self.get_success(
@@ -322,6 +331,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
# also fetch event 5 so we know its last_successful_stream_ordering later
event_5 = self.get_success(self.hs.get_datastores().main.get_event(event_id_5))
+ assert event_2.internal_metadata.stream_ordering is not None
self.get_success(
self.hs.get_datastores().main.set_destination_last_successful_stream_ordering(
"host2", event_2.internal_metadata.stream_ordering
@@ -425,15 +435,16 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
def wake_destination_track(destination: str) -> None:
woken.append(destination)
- self.hs.get_federation_sender().wake_destination = wake_destination_track
+ self.federation_sender.wake_destination = wake_destination_track # type: ignore[assignment]
# cancel the pre-existing timer for _wake_destinations_needing_catchup
# this is because we are calling it manually rather than waiting for it
# to be called automatically
- self.hs.get_federation_sender()._catchup_after_startup_timer.cancel()
+ assert self.federation_sender._catchup_after_startup_timer is not None
+ self.federation_sender._catchup_after_startup_timer.cancel()
self.get_success(
- self.hs.get_federation_sender()._wake_destinations_needing_catchup(), by=5.0
+ self.federation_sender._wake_destinations_needing_catchup(), by=5.0
)
# ASSERT (_wake_destinations_needing_catchup):
@@ -475,6 +486,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
)
)
+ assert event_1.internal_metadata.stream_ordering is not None
self.get_success(
self.hs.get_datastores().main.set_destination_last_successful_stream_ordering(
"host2", event_1.internal_metadata.stream_ordering
diff --git a/tests/federation/test_federation_client.py b/tests/federation/test_federation_client.py
index 86e1236501..91694e4fca 100644
--- a/tests/federation/test_federation_client.py
+++ b/tests/federation/test_federation_client.py
@@ -178,7 +178,7 @@ class FederationClientTest(FederatingHomeserverTestCase):
RoomVersions.V9,
)
)
- self.assertIsNotNone(pulled_pdu_info2)
+ assert pulled_pdu_info2 is not None
remote_pdu2 = pulled_pdu_info2.pdu
# Sanity check that we are working against the same event
@@ -226,7 +226,7 @@ class FederationClientTest(FederatingHomeserverTestCase):
RoomVersions.V9,
)
)
- self.assertIsNotNone(pulled_pdu_info)
+ assert pulled_pdu_info is not None
remote_pdu = pulled_pdu_info.pdu
# check the right call got made to the agent
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index ddeffe1ad5..9e104fd96a 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -22,6 +22,7 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EduTypes, RoomEncryptionAlgorithms
from synapse.federation.units import Transaction
+from synapse.handlers.device import DeviceHandler
from synapse.rest import admin
from synapse.rest.client import login
from synapse.server import HomeServer
@@ -41,8 +42,9 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
"""
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+ self.federation_transport_client = Mock(spec=["send_transaction"])
hs = self.setup_test_homeserver(
- federation_transport_client=Mock(spec=["send_transaction"]),
+ federation_transport_client=self.federation_transport_client,
)
hs.get_storage_controllers().state.get_current_hosts_in_room = Mock( # type: ignore[assignment]
@@ -61,9 +63,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
return config
def test_send_receipts(self) -> None:
- mock_send_transaction = (
- self.hs.get_federation_transport_client().send_transaction
- )
+ mock_send_transaction = self.federation_transport_client.send_transaction
mock_send_transaction.return_value = make_awaitable({})
sender = self.hs.get_federation_sender()
@@ -103,9 +103,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
)
def test_send_receipts_thread(self) -> None:
- mock_send_transaction = (
- self.hs.get_federation_transport_client().send_transaction
- )
+ mock_send_transaction = self.federation_transport_client.send_transaction
mock_send_transaction.return_value = make_awaitable({})
# Create receipts for:
@@ -181,9 +179,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
def test_send_receipts_with_backoff(self) -> None:
"""Send two receipts in quick succession; the second should be flushed, but
only after 20ms"""
- mock_send_transaction = (
- self.hs.get_federation_transport_client().send_transaction
- )
+ mock_send_transaction = self.federation_transport_client.send_transaction
mock_send_transaction.return_value = make_awaitable({})
sender = self.hs.get_federation_sender()
@@ -277,10 +273,11 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+ self.federation_transport_client = Mock(
+ spec=["send_transaction", "query_user_devices"]
+ )
return self.setup_test_homeserver(
- federation_transport_client=Mock(
- spec=["send_transaction", "query_user_devices"]
- ),
+ federation_transport_client=self.federation_transport_client,
)
def default_config(self) -> JsonDict:
@@ -310,9 +307,13 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
hs.get_datastores().main.get_current_hosts_in_room = get_current_hosts_in_room # type: ignore[assignment]
+ device_handler = hs.get_device_handler()
+ assert isinstance(device_handler, DeviceHandler)
+ self.device_handler = device_handler
+
# whenever send_transaction is called, record the edu data
self.edus: List[JsonDict] = []
- self.hs.get_federation_transport_client().send_transaction.side_effect = (
+ self.federation_transport_client.send_transaction.side_effect = (
self.record_transaction
)
@@ -353,7 +354,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
# Send the server a device list EDU for the other user, this will cause
# it to try and resync the device lists.
- self.hs.get_federation_transport_client().query_user_devices.return_value = (
+ self.federation_transport_client.query_user_devices.return_value = (
make_awaitable(
{
"stream_id": "1",
@@ -364,7 +365,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
)
self.get_success(
- self.hs.get_device_handler().device_list_updater.incoming_device_list_update(
+ self.device_handler.device_list_updater.incoming_device_list_update(
"host2",
{
"user_id": "@user2:host2",
@@ -507,9 +508,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D3", stream_id)
# delete them again
- self.get_success(
- self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"])
- )
+ self.get_success(self.device_handler.delete_devices(u1, ["D1", "D2", "D3"]))
# We queue up device list updates to be sent over federation, so we
# advance to clear the queue.
@@ -533,7 +532,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
"""If the destination server is unreachable, all the updates should get sent on
recovery
"""
- mock_send_txn = self.hs.get_federation_transport_client().send_transaction
+ mock_send_txn = self.federation_transport_client.send_transaction
mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail"))
# create devices
@@ -543,9 +542,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
self.login("user", "pass", device_id="D3")
# delete them again
- self.get_success(
- self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"])
- )
+ self.get_success(self.device_handler.delete_devices(u1, ["D1", "D2", "D3"]))
# We queue up device list updates to be sent over federation, so we
# advance to clear the queue.
@@ -580,7 +577,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
This case tests the behaviour when the server has never been reachable.
"""
- mock_send_txn = self.hs.get_federation_transport_client().send_transaction
+ mock_send_txn = self.federation_transport_client.send_transaction
mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail"))
# create devices
@@ -590,9 +587,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
self.login("user", "pass", device_id="D3")
# delete them again
- self.get_success(
- self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"])
- )
+ self.get_success(self.device_handler.delete_devices(u1, ["D1", "D2", "D3"]))
# We queue up device list updates to be sent over federation, so we
# advance to clear the queue.
@@ -640,7 +635,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
self.check_device_update_edu(self.edus.pop(0), u1, "D1", None)
# now the server goes offline
- mock_send_txn = self.hs.get_federation_transport_client().send_transaction
+ mock_send_txn = self.federation_transport_client.send_transaction
mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail"))
self.login("user", "pass", device_id="D2")
@@ -651,9 +646,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
self.reactor.advance(1)
# delete them again
- self.get_success(
- self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"])
- )
+ self.get_success(self.device_handler.delete_devices(u1, ["D1", "D2", "D3"]))
self.assertGreaterEqual(mock_send_txn.call_count, 3)
diff --git a/tests/handlers/test_admin.py b/tests/handlers/test_admin.py
index 6f300b8e11..1b97aaeed1 100644
--- a/tests/handlers/test_admin.py
+++ b/tests/handlers/test_admin.py
@@ -296,3 +296,30 @@ class ExfiltrateData(unittest.HomeserverTestCase):
self.assertEqual(args[0][0]["user_agent"], "user_agent")
self.assertGreater(args[0][0]["last_seen"], 0)
self.assertNotIn("access_token", args[0][0])
+
+ def test_account_data(self) -> None:
+ """Tests that user account data get exported."""
+ # add account data
+ self.get_success(
+ self._store.add_account_data_for_user(self.user2, "m.global", {"a": 1})
+ )
+ self.get_success(
+ self._store.add_account_data_to_room(
+ self.user2, "test_room", "m.per_room", {"b": 2}
+ )
+ )
+
+ writer = Mock()
+
+ self.get_success(self.admin_handler.export_user_data(self.user2, writer))
+
+ # two calls, one call for user data and one call for room data
+ writer.write_account_data.assert_called()
+
+ args = writer.write_account_data.call_args_list[0][0]
+ self.assertEqual(args[0], "global")
+ self.assertEqual(args[1]["m.global"]["a"], 1)
+
+ args = writer.write_account_data.call_args_list[1][0]
+ self.assertEqual(args[0], "test_room")
+ self.assertEqual(args[1]["m.per_room"]["b"], 2)
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index a7495ab21a..9014e60577 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -899,7 +899,7 @@ class ApplicationServicesHandlerDeviceListsTestCase(unittest.HomeserverTestCase)
# Mock out application services, and allow defining our own in tests
self._services: List[ApplicationService] = []
- self.hs.get_datastores().main.get_app_services = Mock(
+ self.hs.get_datastores().main.get_app_services = Mock( # type: ignore[assignment]
return_value=self._services
)
diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py
index 2733719d82..63aad0d10c 100644
--- a/tests/handlers/test_cas.py
+++ b/tests/handlers/test_cas.py
@@ -61,7 +61,7 @@ class CasHandlerTestCase(HomeserverTestCase):
# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock()
+ auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
cas_response = CasResponse("test_user", {})
request = _mock_request()
@@ -89,7 +89,7 @@ class CasHandlerTestCase(HomeserverTestCase):
# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock()
+ auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
# Map a user via SSO.
cas_response = CasResponse("test_user", {})
@@ -129,7 +129,7 @@ class CasHandlerTestCase(HomeserverTestCase):
# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock()
+ auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
cas_response = CasResponse("föö", {})
request = _mock_request()
@@ -160,7 +160,7 @@ class CasHandlerTestCase(HomeserverTestCase):
# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock()
+ auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
# The response doesn't have the proper userGroup or department.
cas_response = CasResponse("test_user", {})
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 95698bc275..6b4cba65d0 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -23,6 +23,7 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import RoomEncryptionAlgorithms
from synapse.api.errors import Codes, SynapseError
+from synapse.handlers.device import DeviceHandler
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock
@@ -187,37 +188,37 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
)
# we should now have an unused alg1 key
- res = self.get_success(
+ fallback_res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
)
- self.assertEqual(res, ["alg1"])
+ self.assertEqual(fallback_res, ["alg1"])
# claiming an OTK when no OTKs are available should return the fallback
# key
- res = self.get_success(
+ claim_res = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
)
)
self.assertEqual(
- res,
+ claim_res,
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
)
# we shouldn't have any unused fallback keys again
- res = self.get_success(
+ unused_res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
)
- self.assertEqual(res, [])
+ self.assertEqual(unused_res, [])
# claiming an OTK again should return the same fallback key
- res = self.get_success(
+ claim_res = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
)
)
self.assertEqual(
- res,
+ claim_res,
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
)
@@ -231,10 +232,10 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
)
)
- res = self.get_success(
+ unused_res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
)
- self.assertEqual(res, [])
+ self.assertEqual(unused_res, [])
# uploading a new fallback key should result in an unused fallback key
self.get_success(
@@ -245,10 +246,10 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
)
)
- res = self.get_success(
+ unused_res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
)
- self.assertEqual(res, ["alg1"])
+ self.assertEqual(unused_res, ["alg1"])
# if the user uploads a one-time key, the next claim should fetch the
# one-time key, and then go back to the fallback
@@ -258,23 +259,23 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
)
)
- res = self.get_success(
+ claim_res = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
)
)
self.assertEqual(
- res,
+ claim_res,
{"failures": {}, "one_time_keys": {local_user: {device_id: otk}}},
)
- res = self.get_success(
+ claim_res = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
)
)
self.assertEqual(
- res,
+ claim_res,
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key2}}},
)
@@ -287,13 +288,13 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
)
)
- res = self.get_success(
+ claim_res = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
)
)
self.assertEqual(
- res,
+ claim_res,
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}},
)
@@ -366,7 +367,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys1))
# upload two device keys, which will be signed later by the self-signing key
- device_key_1 = {
+ device_key_1: JsonDict = {
"user_id": local_user,
"device_id": "abc",
"algorithms": [
@@ -379,7 +380,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
},
"signatures": {local_user: {"ed25519:abc": "base64+signature"}},
}
- device_key_2 = {
+ device_key_2: JsonDict = {
"user_id": local_user,
"device_id": "def",
"algorithms": [
@@ -451,8 +452,10 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
}
self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys1))
+ device_handler = self.hs.get_device_handler()
+ assert isinstance(device_handler, DeviceHandler)
e = self.get_failure(
- self.hs.get_device_handler().check_device_registered(
+ device_handler.check_device_registered(
user_id=local_user,
device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk",
initial_device_display_name="new display name",
@@ -475,7 +478,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
device_id = "xyz"
# private key: OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA
device_pubkey = "NnHhnqiMFQkq969szYkooLaBAXW244ZOxgukCvm2ZeY"
- device_key = {
+ device_key: JsonDict = {
"user_id": local_user,
"device_id": device_id,
"algorithms": [
@@ -497,7 +500,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
# private key: 2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0
master_pubkey = "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk"
- master_key = {
+ master_key: JsonDict = {
"user_id": local_user,
"usage": ["master"],
"keys": {"ed25519:" + master_pubkey: master_pubkey},
@@ -540,7 +543,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
# the first user
other_user = "@otherboris:" + self.hs.hostname
other_master_pubkey = "fHZ3NPiKxoLQm5OoZbKa99SYxprOjNs4TwJUKP+twCM"
- other_master_key = {
+ other_master_key: JsonDict = {
# private key: oyw2ZUx0O4GifbfFYM0nQvj9CL0b8B7cyN4FprtK8OI
"user_id": other_user,
"usage": ["master"],
@@ -702,7 +705,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY"
remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
- self.hs.get_federation_client().query_client_keys = mock.Mock(
+ self.hs.get_federation_client().query_client_keys = mock.Mock( # type: ignore[assignment]
return_value=make_awaitable(
{
"device_keys": {remote_user_id: {}},
@@ -782,7 +785,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY"
remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
- self.hs.get_federation_client().query_user_devices = mock.Mock(
+ self.hs.get_federation_client().query_user_devices = mock.Mock( # type: ignore[assignment]
return_value=make_awaitable(
{
"user_id": remote_user_id,
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index 57675fa407..bf0862ed54 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -371,14 +371,14 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
# We mock out the FederationClient.backfill method, to pretend that a remote
# server has returned our fake event.
federation_client_backfill_mock = Mock(return_value=make_awaitable([event]))
- self.hs.get_federation_client().backfill = federation_client_backfill_mock
+ self.hs.get_federation_client().backfill = federation_client_backfill_mock # type: ignore[assignment]
# We also mock the persist method with a side effect of itself. This allows us
# to track when it has been called while preserving its function.
persist_events_and_notify_mock = Mock(
side_effect=self.hs.get_federation_event_handler().persist_events_and_notify
)
- self.hs.get_federation_event_handler().persist_events_and_notify = (
+ self.hs.get_federation_event_handler().persist_events_and_notify = ( # type: ignore[assignment]
persist_events_and_notify_mock
)
@@ -575,26 +575,6 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
fed_client = fed_handler.federation_client
room_id = "!room:example.com"
- membership_event = make_event_from_dict(
- {
- "room_id": room_id,
- "type": "m.room.member",
- "sender": "@alice:test",
- "state_key": "@alice:test",
- "content": {"membership": "join"},
- },
- RoomVersions.V10,
- )
-
- mock_make_membership_event = Mock(
- return_value=make_awaitable(
- (
- "example.com",
- membership_event,
- RoomVersions.V10,
- )
- )
- )
EVENT_CREATE = make_event_from_dict(
{
@@ -640,6 +620,26 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
},
room_version=RoomVersions.V10,
)
+ membership_event = make_event_from_dict(
+ {
+ "room_id": room_id,
+ "type": "m.room.member",
+ "sender": "@alice:test",
+ "state_key": "@alice:test",
+ "content": {"membership": "join"},
+ "prev_events": [EVENT_INVITATION_MEMBERSHIP.event_id],
+ },
+ RoomVersions.V10,
+ )
+ mock_make_membership_event = Mock(
+ return_value=make_awaitable(
+ (
+ "example.com",
+ membership_event,
+ RoomVersions.V10,
+ )
+ )
+ )
mock_send_join = Mock(
return_value=make_awaitable(
SendJoinResult(
@@ -712,12 +712,12 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room
), patch.object(store, "is_partial_state_room", mock_is_partial_state_room):
# Start the partial state sync.
- fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id")
+ fed_handler._start_partial_state_room_sync("hs1", {"hs2"}, "room_id")
self.assertEqual(mock_sync_partial_state_room.call_count, 1)
# Try to start another partial state sync.
# Nothing should happen.
- fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id")
+ fed_handler._start_partial_state_room_sync("hs3", {"hs2"}, "room_id")
self.assertEqual(mock_sync_partial_state_room.call_count, 1)
# End the partial state sync
@@ -729,7 +729,7 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
# The next attempt to start the partial state sync should work.
is_partial_state = True
- fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id")
+ fed_handler._start_partial_state_room_sync("hs3", {"hs2"}, "room_id")
self.assertEqual(mock_sync_partial_state_room.call_count, 2)
def test_partial_state_room_sync_restart(self) -> None:
@@ -764,7 +764,7 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room
), patch.object(store, "is_partial_state_room", mock_is_partial_state_room):
# Start the partial state sync.
- fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id")
+ fed_handler._start_partial_state_room_sync("hs1", {"hs2"}, "room_id")
self.assertEqual(mock_sync_partial_state_room.call_count, 1)
# Fail the partial state sync.
@@ -773,11 +773,11 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
self.assertEqual(mock_sync_partial_state_room.call_count, 1)
# Start the partial state sync again.
- fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id")
+ fed_handler._start_partial_state_room_sync("hs1", {"hs2"}, "room_id")
self.assertEqual(mock_sync_partial_state_room.call_count, 2)
# Deduplicate another partial state sync.
- fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id")
+ fed_handler._start_partial_state_room_sync("hs3", {"hs2"}, "room_id")
self.assertEqual(mock_sync_partial_state_room.call_count, 2)
# Fail the partial state sync.
@@ -786,6 +786,6 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
self.assertEqual(mock_sync_partial_state_room.call_count, 3)
mock_sync_partial_state_room.assert_called_with(
initial_destination="hs3",
- other_destinations=["hs2"],
+ other_destinations={"hs2"},
room_id="room_id",
)
diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py
index 70ea4d15d4..c067e5bfe3 100644
--- a/tests/handlers/test_federation_event.py
+++ b/tests/handlers/test_federation_event.py
@@ -29,6 +29,7 @@ from synapse.logging.context import LoggingContext
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
+from synapse.state import StateResolutionStore
from synapse.state.v2 import _mainline_sort, _reverse_topological_power_sort
from synapse.types import JsonDict
from synapse.util import Clock
@@ -161,6 +162,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
if prev_exists_as_outlier:
prev_event.internal_metadata.outlier = True
persistence = self.hs.get_storage_controllers().persistence
+ assert persistence is not None
self.get_success(
persistence.persist_event(
prev_event,
@@ -861,7 +863,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
bert_member_event.event_id: bert_member_event,
rejected_kick_event.event_id: rejected_kick_event,
},
- state_res_store=main_store,
+ state_res_store=StateResolutionStore(main_store),
)
),
[bert_member_event.event_id, rejected_kick_event.event_id],
@@ -906,7 +908,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
rejected_power_levels_event.event_id,
],
event_map={},
- state_res_store=main_store,
+ state_res_store=StateResolutionStore(main_store),
full_conflicted_set=set(),
)
),
diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py
index c4727ab917..69d384442f 100644
--- a/tests/handlers/test_message.py
+++ b/tests/handlers/test_message.py
@@ -41,20 +41,21 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.handler = self.hs.get_event_creation_handler()
- self._persist_event_storage_controller = (
- self.hs.get_storage_controllers().persistence
- )
+ persistence = self.hs.get_storage_controllers().persistence
+ assert persistence is not None
+ self._persist_event_storage_controller = persistence
self.user_id = self.register_user("tester", "foobar")
self.access_token = self.login("tester", "foobar")
self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token)
- self.info = self.get_success(
+ info = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(
self.access_token,
)
)
- self.token_id = self.info.token_id
+ assert info is not None
+ self.token_id = info.token_id
self.requester = create_requester(self.user_id, access_token_id=self.token_id)
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index adddbd002f..951caaa6b3 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -150,7 +150,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
hs = self.setup_test_homeserver()
self.hs_patcher = self.fake_server.patch_homeserver(hs=hs)
- self.hs_patcher.start()
+ self.hs_patcher.start() # type: ignore[attr-defined]
self.handler = hs.get_oidc_handler()
self.provider = self.handler._providers["oidc"]
@@ -170,7 +170,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
return hs
def tearDown(self) -> None:
- self.hs_patcher.stop()
+ self.hs_patcher.stop() # type: ignore[attr-defined]
return super().tearDown()
def reset_mocks(self) -> None:
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index 0916de64f5..aa91bc0a3d 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -852,7 +852,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
username: The username to use for the test.
registration: Whether to test with registration URLs.
"""
- self.hs.get_identity_handler().send_threepid_validation = Mock(
+ self.hs.get_identity_handler().send_threepid_validation = Mock( # type: ignore[assignment]
return_value=make_awaitable(0),
)
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index b9332d97dc..1db99b3c00 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -62,7 +62,7 @@ class TestSpamChecker:
request_info: Collection[Tuple[str, str]],
auth_provider_id: Optional[str],
) -> RegistrationBehaviour:
- pass
+ return RegistrationBehaviour.ALLOW
class DenyAll(TestSpamChecker):
@@ -111,7 +111,7 @@ class TestLegacyRegistrationSpamChecker:
username: Optional[str],
request_info: Collection[Tuple[str, str]],
) -> RegistrationBehaviour:
- pass
+ return RegistrationBehaviour.ALLOW
class LegacyAllowAll(TestLegacyRegistrationSpamChecker):
@@ -203,7 +203,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
@override_config({"limit_usage_by_mau": True})
def test_get_or_create_user_mau_not_blocked(self) -> None:
- self.store.count_monthly_users = Mock(
+ self.store.count_monthly_users = Mock( # type: ignore[assignment]
return_value=make_awaitable(self.hs.config.server.max_mau_value - 1)
)
# Ensure does not throw exception
@@ -304,7 +304,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
def test_auto_create_auto_join_rooms_when_user_is_the_first_real_user(self) -> None:
room_alias_str = "#room:test"
- self.store.count_real_users = Mock(return_value=make_awaitable(1))
+ self.store.count_real_users = Mock(return_value=make_awaitable(1)) # type: ignore[assignment]
self.store.is_real_user = Mock(return_value=make_awaitable(True))
user_id = self.get_success(self.handler.register_user(localpart="real"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
@@ -319,7 +319,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
def test_auto_create_auto_join_rooms_when_user_is_not_the_first_real_user(
self,
) -> None:
- self.store.count_real_users = Mock(return_value=make_awaitable(2))
+ self.store.count_real_users = Mock(return_value=make_awaitable(2)) # type: ignore[assignment]
self.store.is_real_user = Mock(return_value=make_awaitable(True))
user_id = self.get_success(self.handler.register_user(localpart="real"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
@@ -346,6 +346,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
# Ensure the room is properly not federated.
room = self.get_success(self.store.get_room_with_stats(room_id["room_id"]))
+ assert room is not None
self.assertFalse(room["federatable"])
self.assertFalse(room["public"])
self.assertEqual(room["join_rules"], "public")
@@ -375,6 +376,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
# Ensure the room is properly a public room.
room = self.get_success(self.store.get_room_with_stats(room_id["room_id"]))
+ assert room is not None
self.assertEqual(room["join_rules"], "public")
# Both users should be in the room.
@@ -413,6 +415,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
# Ensure the room is properly a private room.
room = self.get_success(self.store.get_room_with_stats(room_id["room_id"]))
+ assert room is not None
self.assertFalse(room["public"])
self.assertEqual(room["join_rules"], "invite")
self.assertEqual(room["guest_access"], "can_join")
@@ -456,6 +459,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
# Ensure the room is properly a private room.
room = self.get_success(self.store.get_room_with_stats(room_id["room_id"]))
+ assert room is not None
self.assertFalse(room["public"])
self.assertEqual(room["join_rules"], "invite")
self.assertEqual(room["guest_access"], "can_join")
diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py
index 9b1b8b9f13..b5c772a7ae 100644
--- a/tests/handlers/test_saml.py
+++ b/tests/handlers/test_saml.py
@@ -134,7 +134,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock()
+ auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
# send a mocked-up SAML response to the callback
saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})
@@ -164,7 +164,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock()
+ auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
# Map a user via SSO.
saml_response = FakeAuthnResponse(
@@ -206,11 +206,11 @@ class SamlHandlerTestCase(HomeserverTestCase):
# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock()
+ auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
# mock out the error renderer too
sso_handler = self.hs.get_sso_handler()
- sso_handler.render_error = Mock(return_value=None)
+ sso_handler.render_error = Mock(return_value=None) # type: ignore[assignment]
saml_response = FakeAuthnResponse({"uid": "test", "username": "föö"})
request = _mock_request()
@@ -227,9 +227,9 @@ class SamlHandlerTestCase(HomeserverTestCase):
# stub out the auth handler and error renderer
auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock()
+ auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
sso_handler = self.hs.get_sso_handler()
- sso_handler.render_error = Mock(return_value=None)
+ sso_handler.render_error = Mock(return_value=None) # type: ignore[assignment]
# register a user to occupy the first-choice MXID
store = self.hs.get_datastores().main
@@ -312,7 +312,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock()
+ auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
# The response doesn't have the proper userGroup or department.
saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 1fe9563c98..94518a7196 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -74,8 +74,8 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
mock_keyring.verify_json_for_server.return_value = make_awaitable(True)
# we mock out the federation client too
- mock_federation_client = Mock(spec=["put_json"])
- mock_federation_client.put_json.return_value = make_awaitable((200, "OK"))
+ self.mock_federation_client = Mock(spec=["put_json"])
+ self.mock_federation_client.put_json.return_value = make_awaitable((200, "OK"))
# the tests assume that we are starting at unix time 1000
reactor.pump((1000,))
@@ -83,7 +83,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.mock_hs_notifier = Mock()
hs = self.setup_test_homeserver(
notifier=self.mock_hs_notifier,
- federation_http_client=mock_federation_client,
+ federation_http_client=self.mock_federation_client,
keyring=mock_keyring,
replication_streams={},
)
@@ -233,8 +233,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
)
)
- put_json = self.hs.get_federation_http_client().put_json
- put_json.assert_called_once_with(
+ self.mock_federation_client.put_json.assert_called_once_with(
"farm",
path="/_matrix/federation/v1/send/1000000",
data=_expect_edu_transaction(
@@ -349,8 +348,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])])
- put_json = self.hs.get_federation_http_client().put_json
- put_json.assert_called_once_with(
+ self.mock_federation_client.put_json.assert_called_once_with(
"farm",
path="/_matrix/federation/v1/send/1000000",
data=_expect_edu_transaction(
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 75fc5a17a4..a02c1c6227 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Tuple
+from typing import Any, Tuple
from unittest.mock import Mock, patch
from urllib.parse import quote
@@ -24,7 +24,7 @@ from synapse.appservice import ApplicationService
from synapse.rest.client import login, register, room, user_directory
from synapse.server import HomeServer
from synapse.storage.roommember import ProfileInfo
-from synapse.types import create_requester
+from synapse.types import UserProfile, create_requester
from synapse.util import Clock
from tests import unittest
@@ -34,6 +34,12 @@ from tests.test_utils.event_injection import inject_member_event
from tests.unittest import override_config
+# A spam checker which doesn't implement anything, so create a bare object.
+class UselessSpamChecker:
+ def __init__(self, config: Any):
+ pass
+
+
class UserDirectoryTestCase(unittest.HomeserverTestCase):
"""Tests the UserDirectoryHandler.
@@ -186,6 +192,13 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.helper.join(room, self.appservice.sender, tok=self.appservice.token)
self._check_only_one_user_in_directory(user, room)
+ def test_search_term_with_colon_in_it_does_not_raise(self) -> None:
+ """
+ Regression test: Test that search terms with colons in them are acceptable.
+ """
+ u1 = self.register_user("user1", "pass")
+ self.get_success(self.handler.search_users(u1, "haha:paamayim-nekudotayim", 10))
+
def test_user_not_in_users_table(self) -> None:
"""Unclear how it happens, but on matrix.org we've seen join events
for users who aren't in the users table. Test that we don't fall over
@@ -773,7 +786,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
s = self.get_success(self.handler.search_users(u1, "user2", 10))
self.assertEqual(len(s["results"]), 1)
- async def allow_all(user_profile: ProfileInfo) -> bool:
+ async def allow_all(user_profile: UserProfile) -> bool:
# Allow all users.
return False
@@ -787,7 +800,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(s["results"]), 1)
# Configure a spam checker that filters all users.
- async def block_all(user_profile: ProfileInfo) -> bool:
+ async def block_all(user_profile: UserProfile) -> bool:
# All users are spammy.
return True
@@ -797,6 +810,13 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
s = self.get_success(self.handler.search_users(u1, "user2", 10))
self.assertEqual(len(s["results"]), 0)
+ @override_config(
+ {
+ "spam_checker": {
+ "module": "tests.handlers.test_user_directory.UselessSpamChecker"
+ }
+ }
+ )
def test_legacy_spam_checker(self) -> None:
"""
A spam checker without the expected method should be ignored.
@@ -825,11 +845,6 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.assertEqual(shares_private, {(u1, u2, room), (u2, u1, room)})
self.assertEqual(public_users, set())
- # Configure a spam checker.
- spam_checker = self.hs.get_spam_checker()
- # The spam checker doesn't need any methods, so create a bare object.
- spam_checker.spam_checker = object()
-
# We get one search result when searching for user2 by user1.
s = self.get_success(self.handler.search_users(u1, "user2", 10))
self.assertEqual(len(s["results"]), 1)
@@ -949,13 +964,14 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
},
)
- event, context = self.get_success(
+ event, unpersisted_context = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(
- self.hs.get_storage_controllers().persistence.persist_event(event, context)
- )
+ context = self.get_success(unpersisted_context.persist(event))
+ persistence = self.hs.get_storage_controllers().persistence
+ assert persistence is not None
+ self.get_success(persistence.persist_event(event, context))
def test_local_user_leaving_room_remains_in_user_directory(self) -> None:
"""We've chosen to simplify the user directory's implementation by
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index acfdcd3bca..eb7f53fee5 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -30,7 +30,7 @@ from twisted.internet.interfaces import (
IOpenSSLClientConnectionCreator,
IProtocolFactory,
)
-from twisted.internet.protocol import Factory
+from twisted.internet.protocol import Factory, Protocol
from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
from twisted.web._newclient import ResponseNeverReceived
from twisted.web.client import Agent
@@ -63,7 +63,7 @@ from tests.http import (
get_test_ca_cert_file,
)
from tests.server import FakeTransport, ThreadedMemoryReactorClock
-from tests.utils import default_config
+from tests.utils import checked_cast, default_config
logger = logging.getLogger(__name__)
@@ -146,8 +146,10 @@ class MatrixFederationAgentTests(unittest.TestCase):
#
# Normally this would be done by the TCP socket code in Twisted, but we are
# stubbing that out here.
- client_protocol = client_factory.buildProtocol(dummy_address)
- assert isinstance(client_protocol, _WrappingProtocol)
+ # NB: we use a checked_cast here to workaround https://github.com/Shoobx/mypy-zope/issues/91)
+ client_protocol = checked_cast(
+ _WrappingProtocol, client_factory.buildProtocol(dummy_address)
+ )
client_protocol.makeConnection(
FakeTransport(server_protocol, self.reactor, client_protocol)
)
@@ -446,7 +448,6 @@ class MatrixFederationAgentTests(unittest.TestCase):
server_ssl_protocol = _wrap_server_factory_for_tls(
_get_test_protocol_factory()
).buildProtocol(dummy_address)
- assert isinstance(server_ssl_protocol, TLSMemoryBIOProtocol)
# Tell the HTTP server to send outgoing traffic back via the proxy's transport.
proxy_server_transport = proxy_server.transport
@@ -465,7 +466,8 @@ class MatrixFederationAgentTests(unittest.TestCase):
else:
assert isinstance(proxy_server_transport, FakeTransport)
client_protocol = proxy_server_transport.other
- c2s_transport = client_protocol.transport
+ assert isinstance(client_protocol, Protocol)
+ c2s_transport = checked_cast(FakeTransport, client_protocol.transport)
c2s_transport.other = server_ssl_protocol
self.reactor.advance(0)
@@ -1529,7 +1531,7 @@ def _check_logcontext(context: LoggingContextOrSentinel) -> None:
def _wrap_server_factory_for_tls(
factory: IProtocolFactory, sanlist: Optional[List[bytes]] = None
-) -> IProtocolFactory:
+) -> TLSMemoryBIOFactory:
"""Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory
The resultant factory will create a TLS server which presents a certificate
signed by our test CA, valid for the domains in `sanlist`
diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py
index a817940730..cc175052ac 100644
--- a/tests/http/test_proxyagent.py
+++ b/tests/http/test_proxyagent.py
@@ -28,7 +28,7 @@ from twisted.internet.endpoints import (
_WrappingProtocol,
)
from twisted.internet.interfaces import IProtocol, IProtocolFactory
-from twisted.internet.protocol import Factory
+from twisted.internet.protocol import Factory, Protocol
from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
from twisted.web.http import HTTPChannel
@@ -43,6 +43,7 @@ from tests.http import (
)
from tests.server import FakeTransport, ThreadedMemoryReactorClock
from tests.unittest import TestCase
+from tests.utils import checked_cast
logger = logging.getLogger(__name__)
@@ -620,7 +621,6 @@ class MatrixFederationAgentTests(TestCase):
server_ssl_protocol = _wrap_server_factory_for_tls(
_get_test_protocol_factory()
).buildProtocol(dummy_address)
- assert isinstance(server_ssl_protocol, TLSMemoryBIOProtocol)
# Tell the HTTP server to send outgoing traffic back via the proxy's transport.
proxy_server_transport = proxy_server.transport
@@ -644,7 +644,8 @@ class MatrixFederationAgentTests(TestCase):
else:
assert isinstance(proxy_server_transport, FakeTransport)
client_protocol = proxy_server_transport.other
- c2s_transport = client_protocol.transport
+ assert isinstance(client_protocol, Protocol)
+ c2s_transport = checked_cast(FakeTransport, client_protocol.transport)
c2s_transport.other = server_ssl_protocol
self.reactor.advance(0)
@@ -757,12 +758,14 @@ class MatrixFederationAgentTests(TestCase):
assert isinstance(proxy_server, HTTPChannel)
# fish the transports back out so that we can do the old switcheroo
- s2c_transport = proxy_server.transport
- assert isinstance(s2c_transport, FakeTransport)
- client_protocol = s2c_transport.other
- assert isinstance(client_protocol, _WrappingProtocol)
- c2s_transport = client_protocol.transport
- assert isinstance(c2s_transport, FakeTransport)
+ # To help mypy out with the various Protocols and wrappers and mocks, we do
+ # some explicit casting. Without the casts, we hit the bug I reported at
+ # https://github.com/Shoobx/mypy-zope/issues/91 .
+ # We also double-checked these casts at runtime (test-time) because I found it
+ # quite confusing to deduce these types in the first place!
+ s2c_transport = checked_cast(FakeTransport, proxy_server.transport)
+ client_protocol = checked_cast(_WrappingProtocol, s2c_transport.other)
+ c2s_transport = checked_cast(FakeTransport, client_protocol.transport)
# the FakeTransport is async, so we need to pump the reactor
self.reactor.advance(0)
@@ -822,9 +825,9 @@ class MatrixFederationAgentTests(TestCase):
@patch.dict(os.environ, {"http_proxy": "proxy.com:8888"})
def test_proxy_with_no_scheme(self) -> None:
http_proxy_agent = ProxyAgent(self.reactor, use_proxy=True)
- assert isinstance(http_proxy_agent.http_proxy_endpoint, HostnameEndpoint)
- self.assertEqual(http_proxy_agent.http_proxy_endpoint._hostStr, "proxy.com")
- self.assertEqual(http_proxy_agent.http_proxy_endpoint._port, 8888)
+ proxy_ep = checked_cast(HostnameEndpoint, http_proxy_agent.http_proxy_endpoint)
+ self.assertEqual(proxy_ep._hostStr, "proxy.com")
+ self.assertEqual(proxy_ep._port, 8888)
@patch.dict(os.environ, {"http_proxy": "socks://proxy.com:8888"})
def test_proxy_with_unsupported_scheme(self) -> None:
@@ -834,25 +837,21 @@ class MatrixFederationAgentTests(TestCase):
@patch.dict(os.environ, {"http_proxy": "http://proxy.com:8888"})
def test_proxy_with_http_scheme(self) -> None:
http_proxy_agent = ProxyAgent(self.reactor, use_proxy=True)
- assert isinstance(http_proxy_agent.http_proxy_endpoint, HostnameEndpoint)
- self.assertEqual(http_proxy_agent.http_proxy_endpoint._hostStr, "proxy.com")
- self.assertEqual(http_proxy_agent.http_proxy_endpoint._port, 8888)
+ proxy_ep = checked_cast(HostnameEndpoint, http_proxy_agent.http_proxy_endpoint)
+ self.assertEqual(proxy_ep._hostStr, "proxy.com")
+ self.assertEqual(proxy_ep._port, 8888)
@patch.dict(os.environ, {"http_proxy": "https://proxy.com:8888"})
def test_proxy_with_https_scheme(self) -> None:
https_proxy_agent = ProxyAgent(self.reactor, use_proxy=True)
- assert isinstance(https_proxy_agent.http_proxy_endpoint, _WrapperEndpoint)
- self.assertEqual(
- https_proxy_agent.http_proxy_endpoint._wrappedEndpoint._hostStr, "proxy.com"
- )
- self.assertEqual(
- https_proxy_agent.http_proxy_endpoint._wrappedEndpoint._port, 8888
- )
+ proxy_ep = checked_cast(_WrapperEndpoint, https_proxy_agent.http_proxy_endpoint)
+ self.assertEqual(proxy_ep._wrappedEndpoint._hostStr, "proxy.com")
+ self.assertEqual(proxy_ep._wrappedEndpoint._port, 8888)
def _wrap_server_factory_for_tls(
factory: IProtocolFactory, sanlist: Optional[List[bytes]] = None
-) -> IProtocolFactory:
+) -> TLSMemoryBIOFactory:
"""Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory
The resultant factory will create a TLS server which presents a certificate
diff --git a/tests/logging/test_remote_handler.py b/tests/logging/test_remote_handler.py
index c08954d887..5191e31a8a 100644
--- a/tests/logging/test_remote_handler.py
+++ b/tests/logging/test_remote_handler.py
@@ -21,6 +21,7 @@ from synapse.logging import RemoteHandler
from tests.logging import LoggerCleanupMixin
from tests.server import FakeTransport, get_clock
from tests.unittest import TestCase
+from tests.utils import checked_cast
def connect_logging_client(
@@ -56,8 +57,8 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase):
client, server = connect_logging_client(self.reactor, 0)
# Trigger data being sent
- assert isinstance(client.transport, FakeTransport)
- client.transport.flush()
+ client_transport = checked_cast(FakeTransport, client.transport)
+ client_transport.flush()
# One log message, with a single trailing newline
logs = server.data.decode("utf8").splitlines()
@@ -89,8 +90,8 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase):
# Allow the reconnection
client, server = connect_logging_client(self.reactor, 0)
- assert isinstance(client.transport, FakeTransport)
- client.transport.flush()
+ client_transport = checked_cast(FakeTransport, client.transport)
+ client_transport.flush()
# Only the 7 infos made it through, the debugs were elided
logs = server.data.splitlines()
@@ -123,8 +124,8 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase):
# Allow the reconnection
client, server = connect_logging_client(self.reactor, 0)
- assert isinstance(client.transport, FakeTransport)
- client.transport.flush()
+ client_transport = checked_cast(FakeTransport, client.transport)
+ client_transport.flush()
# The 10 warnings made it through, the debugs and infos were elided
logs = server.data.splitlines()
@@ -148,8 +149,8 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase):
# Allow the reconnection
client, server = connect_logging_client(self.reactor, 0)
- assert isinstance(client.transport, FakeTransport)
- client.transport.flush()
+ client_transport = checked_cast(FakeTransport, client.transport)
+ client_transport.flush()
# The first five and last five warnings made it through, the debugs and
# infos were elided
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index 8f88c0117d..3a1929691e 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -11,9 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Any, Dict
from unittest.mock import Mock
from twisted.internet import defer
+from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EduTypes, EventTypes
from synapse.api.errors import NotFoundError
@@ -21,9 +23,12 @@ from synapse.events import EventBase
from synapse.federation.units import Transaction
from synapse.handlers.presence import UserPresenceState
from synapse.handlers.push_rules import InvalidRuleException
+from synapse.module_api import ModuleApi
from synapse.rest import admin
from synapse.rest.client import login, notifications, presence, profile, room
-from synapse.types import create_requester
+from synapse.server import HomeServer
+from synapse.types import JsonDict, create_requester
+from synapse.util import Clock
from tests.events.test_presence_router import send_presence_update, sync_presence
from tests.replication._base import BaseMultiWorkerStreamTestCase
@@ -32,7 +37,19 @@ from tests.test_utils.event_injection import inject_member_event
from tests.unittest import HomeserverTestCase, override_config
-class ModuleApiTestCase(HomeserverTestCase):
+class BaseModuleApiTestCase(HomeserverTestCase):
+ """Common properties of the two test case classes."""
+
+ module_api: ModuleApi
+
+ # These are all written by _test_sending_local_online_presence_to_local_user.
+ presence_receiver_id: str
+ presence_receiver_tok: str
+ presence_sender_id: str
+ presence_sender_tok: str
+
+
+class ModuleApiTestCase(BaseModuleApiTestCase):
servlets = [
admin.register_servlets,
login.register_servlets,
@@ -42,23 +59,23 @@ class ModuleApiTestCase(HomeserverTestCase):
notifications.register_servlets,
]
- def prepare(self, reactor, clock, homeserver):
- self.store = homeserver.get_datastores().main
- self.module_api = homeserver.get_module_api()
- self.event_creation_handler = homeserver.get_event_creation_handler()
- self.sync_handler = homeserver.get_sync_handler()
- self.auth_handler = homeserver.get_auth_handler()
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.store = hs.get_datastores().main
+ self.module_api = hs.get_module_api()
+ self.event_creation_handler = hs.get_event_creation_handler()
+ self.sync_handler = hs.get_sync_handler()
+ self.auth_handler = hs.get_auth_handler()
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
# Mock out the calls over federation.
- fed_transport_client = Mock(spec=["send_transaction"])
- fed_transport_client.send_transaction = simple_async_mock({})
+ self.fed_transport_client = Mock(spec=["send_transaction"])
+ self.fed_transport_client.send_transaction = simple_async_mock({})
return self.setup_test_homeserver(
- federation_transport_client=fed_transport_client,
+ federation_transport_client=self.fed_transport_client,
)
- def test_can_register_user(self):
+ def test_can_register_user(self) -> None:
"""Tests that an external module can register a user"""
# Register a new user
user_id, access_token = self.get_success(
@@ -88,16 +105,17 @@ class ModuleApiTestCase(HomeserverTestCase):
displayname = self.get_success(self.store.get_profile_displayname("bob"))
self.assertEqual(displayname, "Bobberino")
- def test_can_register_admin_user(self):
+ def test_can_register_admin_user(self) -> None:
user_id = self.register_user(
"bob_module_admin", "1234", displayname="Bobberino Admin", admin=True
)
found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id))
+ assert found_user is not None
self.assertEqual(found_user.user_id.to_string(), user_id)
self.assertIdentical(found_user.is_admin, True)
- def test_can_set_admin(self):
+ def test_can_set_admin(self) -> None:
user_id = self.register_user(
"alice_wants_admin",
"1234",
@@ -107,16 +125,17 @@ class ModuleApiTestCase(HomeserverTestCase):
self.get_success(self.module_api.set_user_admin(user_id, True))
found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id))
+ assert found_user is not None
self.assertEqual(found_user.user_id.to_string(), user_id)
self.assertIdentical(found_user.is_admin, True)
- def test_can_set_displayname(self):
+ def test_can_set_displayname(self) -> None:
localpart = "alice_wants_a_new_displayname"
user_id = self.register_user(
localpart, "1234", displayname="Alice", admin=False
)
found_userinfo = self.get_success(self.module_api.get_userinfo_by_id(user_id))
-
+ assert found_userinfo is not None
self.get_success(
self.module_api.set_displayname(
found_userinfo.user_id, "Bob", deactivation=False
@@ -128,17 +147,18 @@ class ModuleApiTestCase(HomeserverTestCase):
self.assertEqual(found_profile.display_name, "Bob")
- def test_get_userinfo_by_id(self):
+ def test_get_userinfo_by_id(self) -> None:
user_id = self.register_user("alice", "1234")
found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id))
+ assert found_user is not None
self.assertEqual(found_user.user_id.to_string(), user_id)
self.assertIdentical(found_user.is_admin, False)
- def test_get_userinfo_by_id__no_user_found(self):
+ def test_get_userinfo_by_id__no_user_found(self) -> None:
found_user = self.get_success(self.module_api.get_userinfo_by_id("@alice:test"))
self.assertIsNone(found_user)
- def test_get_user_ip_and_agents(self):
+ def test_get_user_ip_and_agents(self) -> None:
user_id = self.register_user("test_get_user_ip_and_agents_user", "1234")
# Initially, we should have no ip/agent for our user.
@@ -185,7 +205,7 @@ class ModuleApiTestCase(HomeserverTestCase):
# we should only find the second ip, agent.
info = self.get_success(
self.module_api.get_user_ip_and_agents(
- user_id, (last_seen_1 + last_seen_2) / 2
+ user_id, (last_seen_1 + last_seen_2) // 2
)
)
self.assertEqual(len(info), 1)
@@ -200,7 +220,7 @@ class ModuleApiTestCase(HomeserverTestCase):
)
self.assertEqual(info, [])
- def test_get_user_ip_and_agents__no_user_found(self):
+ def test_get_user_ip_and_agents__no_user_found(self) -> None:
info = self.get_success(
self.module_api.get_user_ip_and_agents(
"@test_get_user_ip_and_agents_user_nonexistent:example.com"
@@ -208,10 +228,10 @@ class ModuleApiTestCase(HomeserverTestCase):
)
self.assertEqual(info, [])
- def test_sending_events_into_room(self):
+ def test_sending_events_into_room(self) -> None:
"""Tests that a module can send events into a room"""
# Mock out create_and_send_nonmember_event to check whether events are being sent
- self.event_creation_handler.create_and_send_nonmember_event = Mock(
+ self.event_creation_handler.create_and_send_nonmember_event = Mock( # type: ignore[assignment]
spec=[],
side_effect=self.event_creation_handler.create_and_send_nonmember_event,
)
@@ -222,7 +242,7 @@ class ModuleApiTestCase(HomeserverTestCase):
room_id = self.helper.create_room_as(user_id, tok=tok)
# Create and send a non-state event
- content = {"body": "I am a puppet", "msgtype": "m.text"}
+ content: JsonDict = {"body": "I am a puppet", "msgtype": "m.text"}
event_dict = {
"room_id": room_id,
"type": "m.room.message",
@@ -265,7 +285,7 @@ class ModuleApiTestCase(HomeserverTestCase):
"sender": user_id,
"state_key": "",
}
- event: EventBase = self.get_success(
+ event = self.get_success(
self.module_api.create_and_send_event_into_room(event_dict)
)
self.assertEqual(event.sender, user_id)
@@ -303,7 +323,7 @@ class ModuleApiTestCase(HomeserverTestCase):
self.module_api.create_and_send_event_into_room(event_dict), Exception
)
- def test_public_rooms(self):
+ def test_public_rooms(self) -> None:
"""Tests that a room can be added and removed from the public rooms list,
as well as have its public rooms directory state queried.
"""
@@ -350,13 +370,13 @@ class ModuleApiTestCase(HomeserverTestCase):
)
self.assertFalse(is_in_public_rooms)
- def test_send_local_online_presence_to(self):
+ def test_send_local_online_presence_to(self) -> None:
# Test sending local online presence to users from the main process
_test_sending_local_online_presence_to_local_user(self, test_with_workers=False)
# Enable federation sending on the main process.
@override_config({"federation_sender_instances": None})
- def test_send_local_online_presence_to_federation(self):
+ def test_send_local_online_presence_to_federation(self) -> None:
"""Tests that send_local_presence_to_users sends local online presence to remote users."""
# Create a user who will send presence updates
self.presence_sender_id = self.register_user("presence_sender1", "monkey")
@@ -397,7 +417,7 @@ class ModuleApiTestCase(HomeserverTestCase):
#
# Thus we reset the mock, and try sending online local user
# presence again
- self.hs.get_federation_transport_client().send_transaction.reset_mock()
+ self.fed_transport_client.send_transaction.reset_mock()
# Broadcast local user online presence
self.get_success(
@@ -409,9 +429,7 @@ class ModuleApiTestCase(HomeserverTestCase):
# Check that a presence update was sent as part of a federation transaction
found_update = False
- calls = (
- self.hs.get_federation_transport_client().send_transaction.call_args_list
- )
+ calls = self.fed_transport_client.send_transaction.call_args_list
for call in calls:
call_args = call[0]
federation_transaction: Transaction = call_args[0]
@@ -431,7 +449,7 @@ class ModuleApiTestCase(HomeserverTestCase):
self.assertTrue(found_update)
- def test_update_membership(self):
+ def test_update_membership(self) -> None:
"""Tests that the module API can update the membership of a user in a room."""
peter = self.register_user("peter", "hackme")
lesley = self.register_user("lesley", "hackme")
@@ -554,14 +572,14 @@ class ModuleApiTestCase(HomeserverTestCase):
self.assertEqual(res["displayname"], "simone")
self.assertIsNone(res["avatar_url"])
- def test_update_room_membership_remote_join(self):
+ def test_update_room_membership_remote_join(self) -> None:
"""Test that the module API can join a remote room."""
# Necessary to fake a remote join.
fake_stream_id = 1
mocked_remote_join = simple_async_mock(
return_value=("fake-event-id", fake_stream_id)
)
- self.hs.get_room_member_handler()._remote_join = mocked_remote_join
+ self.hs.get_room_member_handler()._remote_join = mocked_remote_join # type: ignore[assignment]
fake_remote_host = f"{self.module_api.server_name}-remote"
# Given that the join is to be faked, we expect the relevant join event not to
@@ -582,7 +600,7 @@ class ModuleApiTestCase(HomeserverTestCase):
# Check that a remote join was attempted.
self.assertEqual(mocked_remote_join.call_count, 1)
- def test_get_room_state(self):
+ def test_get_room_state(self) -> None:
"""Tests that a module can retrieve the state of a room through the module API."""
user_id = self.register_user("peter", "hackme")
tok = self.login("peter", "hackme")
@@ -677,7 +695,7 @@ class ModuleApiTestCase(HomeserverTestCase):
self.module_api.check_push_rule_actions(["foo"])
with self.assertRaises(InvalidRuleException):
- self.module_api.check_push_rule_actions({"foo": "bar"})
+ self.module_api.check_push_rule_actions([{"foo": "bar"}])
self.module_api.check_push_rule_actions(["notify"])
@@ -756,7 +774,7 @@ class ModuleApiTestCase(HomeserverTestCase):
self.assertIsNone(room_alias)
-class ModuleApiWorkerTestCase(BaseMultiWorkerStreamTestCase):
+class ModuleApiWorkerTestCase(BaseModuleApiTestCase, BaseMultiWorkerStreamTestCase):
"""For testing ModuleApi functionality in a multi-worker setup"""
servlets = [
@@ -766,7 +784,7 @@ class ModuleApiWorkerTestCase(BaseMultiWorkerStreamTestCase):
presence.register_servlets,
]
- def default_config(self):
+ def default_config(self) -> Dict[str, Any]:
conf = super().default_config()
conf["stream_writers"] = {"presence": ["presence_writer"]}
conf["instance_map"] = {
@@ -774,18 +792,18 @@ class ModuleApiWorkerTestCase(BaseMultiWorkerStreamTestCase):
}
return conf
- def prepare(self, reactor, clock, homeserver):
- self.module_api = homeserver.get_module_api()
- self.sync_handler = homeserver.get_sync_handler()
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.module_api = hs.get_module_api()
+ self.sync_handler = hs.get_sync_handler()
- def test_send_local_online_presence_to_workers(self):
+ def test_send_local_online_presence_to_workers(self) -> None:
# Test sending local online presence to users from a worker process
_test_sending_local_online_presence_to_local_user(self, test_with_workers=True)
def _test_sending_local_online_presence_to_local_user(
- test_case: HomeserverTestCase, test_with_workers: bool = False
-):
+ test_case: BaseModuleApiTestCase, test_with_workers: bool = False
+) -> None:
"""Tests that send_local_presence_to_users sends local online presence to local users.
This simultaneously tests two different usecases:
@@ -852,6 +870,7 @@ def _test_sending_local_online_presence_to_local_user(
# Replicate the current sync presence token from the main process to the worker process.
# We need to do this so that the worker process knows the current presence stream ID to
# insert into the database when we call ModuleApi.send_local_online_presence_to.
+ assert isinstance(test_case, BaseMultiWorkerStreamTestCase)
test_case.replicate()
# Syncing again should result in no presence updates
@@ -868,6 +887,7 @@ def _test_sending_local_online_presence_to_local_user(
# Determine on which process (main or worker) to call ModuleApi.send_local_online_presence_to on
if test_with_workers:
+ assert isinstance(test_case, BaseMultiWorkerStreamTestCase)
module_api_to_use = worker_hs.get_module_api()
else:
module_api_to_use = test_case.module_api
@@ -875,12 +895,11 @@ def _test_sending_local_online_presence_to_local_user(
# Trigger sending local online presence. We expect this information
# to be saved to the database where all processes can access it.
# Note that we're syncing via the master.
- d = module_api_to_use.send_local_online_presence_to(
- [
- test_case.presence_receiver_id,
- ]
+ d = defer.ensureDeferred(
+ module_api_to_use.send_local_online_presence_to(
+ [test_case.presence_receiver_id],
+ )
)
- d = defer.ensureDeferred(d)
if test_with_workers:
# In order for the required presence_set_state replication request to occur between the
@@ -897,7 +916,7 @@ def _test_sending_local_online_presence_to_local_user(
)
test_case.assertEqual(len(presence_updates), 1)
- presence_update: UserPresenceState = presence_updates[0]
+ presence_update = presence_updates[0]
test_case.assertEqual(presence_update.user_id, test_case.presence_sender_id)
test_case.assertEqual(presence_update.state, "online")
@@ -908,7 +927,7 @@ def _test_sending_local_online_presence_to_local_user(
)
test_case.assertEqual(len(presence_updates), 1)
- presence_update: UserPresenceState = presence_updates[0]
+ presence_update = presence_updates[0]
test_case.assertEqual(presence_update.user_id, test_case.presence_sender_id)
test_case.assertEqual(presence_update.state, "online")
@@ -936,12 +955,13 @@ def _test_sending_local_online_presence_to_local_user(
test_case.assertEqual(len(presence_updates), 1)
# Now trigger sending local online presence.
- d = module_api_to_use.send_local_online_presence_to(
- [
- test_case.presence_receiver_id,
- ]
+ d = defer.ensureDeferred(
+ module_api_to_use.send_local_online_presence_to(
+ [
+ test_case.presence_receiver_id,
+ ]
+ )
)
- d = defer.ensureDeferred(d)
if test_with_workers:
# In order for the required presence_set_state replication request to occur between the
diff --git a/tests/push/test_bulk_push_rule_evaluator.py b/tests/push/test_bulk_push_rule_evaluator.py
index 7567756135..199e3d7b70 100644
--- a/tests/push/test_bulk_push_rule_evaluator.py
+++ b/tests/push/test_bulk_push_rule_evaluator.py
@@ -227,7 +227,14 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
)
return len(result) > 0
- @override_config({"experimental_features": {"msc3952_intentional_mentions": True}})
+ @override_config(
+ {
+ "experimental_features": {
+ "msc3758_exact_event_match": True,
+ "msc3952_intentional_mentions": True,
+ }
+ }
+ )
def test_user_mentions(self) -> None:
"""Test the behavior of an event which includes invalid user mentions."""
bulk_evaluator = BulkPushRuleEvaluator(self.hs)
@@ -323,7 +330,14 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
)
)
- @override_config({"experimental_features": {"msc3952_intentional_mentions": True}})
+ @override_config(
+ {
+ "experimental_features": {
+ "msc3758_exact_event_match": True,
+ "msc3952_intentional_mentions": True,
+ }
+ }
+ )
def test_room_mentions(self) -> None:
"""Test the behavior of an event which includes invalid room mentions."""
bulk_evaluator = BulkPushRuleEvaluator(self.hs)
diff --git a/tests/push/test_email.py b/tests/push/test_email.py
index ab8bb417e7..7563f33fdc 100644
--- a/tests/push/test_email.py
+++ b/tests/push/test_email.py
@@ -23,6 +23,7 @@ from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
from synapse.api.errors import Codes, SynapseError
+from synapse.push.emailpusher import EmailPusher
from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.util import Clock
@@ -105,6 +106,7 @@ class EmailPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(self.access_token)
)
+ assert user_tuple is not None
self.token_id = user_tuple.token_id
# We need to add email to account before we can create a pusher.
@@ -114,7 +116,7 @@ class EmailPusherTests(HomeserverTestCase):
)
)
- self.pusher = self.get_success(
+ pusher = self.get_success(
self.hs.get_pusherpool().add_or_update_pusher(
user_id=self.user_id,
access_token=self.token_id,
@@ -127,6 +129,8 @@ class EmailPusherTests(HomeserverTestCase):
data={},
)
)
+ assert isinstance(pusher, EmailPusher)
+ self.pusher = pusher
self.auth_handler = hs.get_auth_handler()
self.store = hs.get_datastores().main
@@ -375,10 +379,13 @@ class EmailPusherTests(HomeserverTestCase):
)
# check that the pusher for that email address has been deleted
- pushers = self.get_success(
- self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id})
+ pushers = list(
+ self.get_success(
+ self.hs.get_datastores().main.get_pushers_by(
+ {"user_name": self.user_id}
+ )
+ )
)
- pushers = list(pushers)
self.assertEqual(len(pushers), 0)
def test_remove_unlinked_pushers_background_job(self) -> None:
@@ -413,10 +420,13 @@ class EmailPusherTests(HomeserverTestCase):
self.wait_for_background_updates()
# Check that all pushers with unlinked addresses were deleted
- pushers = self.get_success(
- self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id})
+ pushers = list(
+ self.get_success(
+ self.hs.get_datastores().main.get_pushers_by(
+ {"user_name": self.user_id}
+ )
+ )
)
- pushers = list(pushers)
self.assertEqual(len(pushers), 0)
def _check_for_mail(self) -> Tuple[Sequence, Dict]:
@@ -428,10 +438,13 @@ class EmailPusherTests(HomeserverTestCase):
that notification.
"""
# Get the stream ordering before it gets sent
- pushers = self.get_success(
- self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id})
+ pushers = list(
+ self.get_success(
+ self.hs.get_datastores().main.get_pushers_by(
+ {"user_name": self.user_id}
+ )
+ )
)
- pushers = list(pushers)
self.assertEqual(len(pushers), 1)
last_stream_ordering = pushers[0].last_stream_ordering
@@ -439,10 +452,13 @@ class EmailPusherTests(HomeserverTestCase):
self.pump(10)
# It hasn't succeeded yet, so the stream ordering shouldn't have moved
- pushers = self.get_success(
- self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id})
+ pushers = list(
+ self.get_success(
+ self.hs.get_datastores().main.get_pushers_by(
+ {"user_name": self.user_id}
+ )
+ )
)
- pushers = list(pushers)
self.assertEqual(len(pushers), 1)
self.assertEqual(last_stream_ordering, pushers[0].last_stream_ordering)
@@ -458,10 +474,13 @@ class EmailPusherTests(HomeserverTestCase):
self.assertEqual(len(self.email_attempts), 1)
# The stream ordering has increased
- pushers = self.get_success(
- self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id})
+ pushers = list(
+ self.get_success(
+ self.hs.get_datastores().main.get_pushers_by(
+ {"user_name": self.user_id}
+ )
+ )
)
- pushers = list(pushers)
self.assertEqual(len(pushers), 1)
self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering)
diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index 23447cc310..c280ddcdf6 100644
--- a/tests/push/test_http.py
+++ b/tests/push/test_http.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import List, Optional, Tuple
+from typing import Any, List, Tuple
from unittest.mock import Mock
from twisted.internet.defer import Deferred
@@ -22,7 +22,6 @@ from synapse.logging.context import make_deferred_yieldable
from synapse.push import PusherConfig, PusherConfigException
from synapse.rest.client import login, push_rule, pusher, receipts, room
from synapse.server import HomeServer
-from synapse.storage.databases.main.registration import TokenLookupResult
from synapse.types import JsonDict
from synapse.util import Clock
@@ -67,9 +66,10 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(access_token)
)
+ assert user_tuple is not None
token_id = user_tuple.token_id
- def test_data(data: Optional[JsonDict]) -> None:
+ def test_data(data: Any) -> None:
self.get_failure(
self.hs.get_pusherpool().add_or_update_pusher(
user_id=user_id,
@@ -113,6 +113,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(access_token)
)
+ assert user_tuple is not None
token_id = user_tuple.token_id
self.get_success(
@@ -140,10 +141,11 @@ class HTTPPusherTests(HomeserverTestCase):
self.helper.send(room, body="There!", tok=other_access_token)
# Get the stream ordering before it gets sent
- pushers = self.get_success(
- self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
+ pushers = list(
+ self.get_success(
+ self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
+ )
)
- pushers = list(pushers)
self.assertEqual(len(pushers), 1)
last_stream_ordering = pushers[0].last_stream_ordering
@@ -151,10 +153,11 @@ class HTTPPusherTests(HomeserverTestCase):
self.pump()
# It hasn't succeeded yet, so the stream ordering shouldn't have moved
- pushers = self.get_success(
- self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
+ pushers = list(
+ self.get_success(
+ self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
+ )
)
- pushers = list(pushers)
self.assertEqual(len(pushers), 1)
self.assertEqual(last_stream_ordering, pushers[0].last_stream_ordering)
@@ -172,10 +175,11 @@ class HTTPPusherTests(HomeserverTestCase):
self.pump()
# The stream ordering has increased
- pushers = self.get_success(
- self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
+ pushers = list(
+ self.get_success(
+ self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
+ )
)
- pushers = list(pushers)
self.assertEqual(len(pushers), 1)
self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering)
last_stream_ordering = pushers[0].last_stream_ordering
@@ -194,10 +198,11 @@ class HTTPPusherTests(HomeserverTestCase):
self.pump()
# The stream ordering has increased, again
- pushers = self.get_success(
- self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
+ pushers = list(
+ self.get_success(
+ self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
+ )
)
- pushers = list(pushers)
self.assertEqual(len(pushers), 1)
self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering)
@@ -229,6 +234,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(access_token)
)
+ assert user_tuple is not None
token_id = user_tuple.token_id
self.get_success(
@@ -349,6 +355,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(access_token)
)
+ assert user_tuple is not None
token_id = user_tuple.token_id
self.get_success(
@@ -435,6 +442,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(access_token)
)
+ assert user_tuple is not None
token_id = user_tuple.token_id
self.get_success(
@@ -512,6 +520,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(access_token)
)
+ assert user_tuple is not None
token_id = user_tuple.token_id
self.get_success(
@@ -618,6 +627,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(access_token)
)
+ assert user_tuple is not None
token_id = user_tuple.token_id
self.get_success(
@@ -753,6 +763,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(access_token)
)
+ assert user_tuple is not None
token_id = user_tuple.token_id
self.get_success(
@@ -895,6 +906,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(access_token)
)
+ assert user_tuple is not None
token_id = user_tuple.token_id
device_id = user_tuple.device_id
@@ -941,9 +953,10 @@ class HTTPPusherTests(HomeserverTestCase):
)
# Look up the user info for the access token so we can compare the device ID.
- lookup_result: TokenLookupResult = self.get_success(
+ lookup_result = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(access_token)
)
+ assert lookup_result is not None
# Get the user's devices and check it has the correct device ID.
channel = self.make_request("GET", "/pushers", access_token=access_token)
diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py
index da33423871..d320a12f96 100644
--- a/tests/push/test_push_rule_evaluator.py
+++ b/tests/push/test_push_rule_evaluator.py
@@ -32,6 +32,7 @@ from synapse.storage.databases.main.appservice import _make_exclusive_regex
from synapse.synapse_rust.push import PushRuleEvaluator
from synapse.types import JsonDict, JsonMapping, UserID
from synapse.util import Clock
+from synapse.util.frozenutils import freeze
from tests import unittest
from tests.test_utils.event_injection import create_event, inject_member_event
@@ -48,17 +49,34 @@ class FlattenDictTestCase(unittest.TestCase):
input = {"foo": {"bar": "abc"}}
self.assertEqual({"foo.bar": "abc"}, _flatten_dict(input))
+ # If a field has a dot in it, escape it.
+ input = {"m.foo": {"b\\ar": "abc"}}
+ self.assertEqual({"m.foo.b\\ar": "abc"}, _flatten_dict(input))
+ self.assertEqual(
+ {"m\\.foo.b\\\\ar": "abc"},
+ _flatten_dict(input, msc3783_escape_event_match_key=True),
+ )
+
def test_non_string(self) -> None:
- """Non-string items are dropped."""
+ """String, booleans, ints, nulls and list of those should be kept while other items are dropped."""
input: Dict[str, Any] = {
"woo": "woo",
"foo": True,
"bar": 1,
"baz": None,
- "fuzz": [],
+ "fuzz": ["woo", True, 1, None, [], {}],
"boo": {},
}
- self.assertEqual({"woo": "woo"}, _flatten_dict(input))
+ self.assertEqual(
+ {
+ "woo": "woo",
+ "foo": True,
+ "bar": 1,
+ "baz": None,
+ "fuzz": ["woo", True, 1, None],
+ },
+ _flatten_dict(input),
+ )
def test_event(self) -> None:
"""Events can also be flattened."""
@@ -78,9 +96,9 @@ class FlattenDictTestCase(unittest.TestCase):
)
expected = {
"content.msgtype": "m.text",
- "content.body": "hello world!",
+ "content.body": "Hello world!",
"content.format": "org.matrix.custom.html",
- "content.formatted_body": "<h1>hello world!</h1>",
+ "content.formatted_body": "<h1>Hello world!</h1>",
"room_id": "!test:test",
"sender": "@alice:test",
"type": "m.room.message",
@@ -107,6 +125,7 @@ class FlattenDictTestCase(unittest.TestCase):
"room_id": "!test:test",
"sender": "@alice:test",
"type": "m.room.message",
+ "content.org.matrix.msc1767.markup": [],
}
self.assertEqual(expected, _flatten_dict(event))
@@ -118,6 +137,7 @@ class FlattenDictTestCase(unittest.TestCase):
"room_id": "!test:test",
"sender": "@alice:test",
"type": "m.room.message",
+ "content.org.matrix.msc1767.markup": [],
}
self.assertEqual(expected, _flatten_dict(event))
@@ -129,7 +149,6 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
*,
has_mentions: bool = False,
user_mentions: Optional[Set[str]] = None,
- room_mention: bool = False,
related_events: Optional[JsonDict] = None,
) -> PushRuleEvaluator:
event = FrozenEvent(
@@ -150,7 +169,6 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
_flatten_dict(event),
has_mentions,
user_mentions or set(),
- room_mention,
room_member_count,
sender_power_level,
cast(Dict[str, int], power_levels.get("notifications", {})),
@@ -158,6 +176,8 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
related_event_match_enabled=True,
room_version_feature_flags=event.room_version.msc3931_push_features,
msc3931_enabled=True,
+ msc3758_exact_event_match=True,
+ msc3966_exact_event_property_contains=True,
)
def test_display_name(self) -> None:
@@ -210,27 +230,6 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
# Note that invalid data is tested at tests.push.test_bulk_push_rule_evaluator.TestBulkPushRuleEvaluator.test_mentions
# since the BulkPushRuleEvaluator is what handles data sanitisation.
- def test_room_mentions(self) -> None:
- """Check for room mentions."""
- condition = {"kind": "org.matrix.msc3952.is_room_mention"}
-
- # No room mention shouldn't match.
- evaluator = self._get_evaluator({}, has_mentions=True)
- self.assertFalse(evaluator.matches(condition, None, None))
-
- # Room mention should match.
- evaluator = self._get_evaluator({}, has_mentions=True, room_mention=True)
- self.assertTrue(evaluator.matches(condition, None, None))
-
- # A room mention and user mention is valid.
- evaluator = self._get_evaluator(
- {}, has_mentions=True, user_mentions={"@another:test"}, room_mention=True
- )
- self.assertTrue(evaluator.matches(condition, None, None))
-
- # Note that invalid data is tested at tests.push.test_bulk_push_rule_evaluator.TestBulkPushRuleEvaluator.test_mentions
- # since the BulkPushRuleEvaluator is what handles data sanitisation.
-
def _assert_matches(
self, condition: JsonDict, content: JsonMapping, msg: Optional[str] = None
) -> None:
@@ -402,6 +401,178 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
"pattern should not match before a newline",
)
+ def test_exact_event_match_string(self) -> None:
+ """Check that exact_event_match conditions work as expected for strings."""
+
+ # Test against a string value.
+ condition = {
+ "kind": "com.beeper.msc3758.exact_event_match",
+ "key": "content.value",
+ "value": "foobaz",
+ }
+ self._assert_matches(
+ condition,
+ {"value": "foobaz"},
+ "exact value should match",
+ )
+ self._assert_not_matches(
+ condition,
+ {"value": "FoobaZ"},
+ "values should match and be case-sensitive",
+ )
+ self._assert_not_matches(
+ condition,
+ {"value": "test foobaz test"},
+ "values must exactly match",
+ )
+ value: Any
+ for value in (True, False, 1, 1.1, None, [], {}):
+ self._assert_not_matches(
+ condition,
+ {"value": value},
+ "incorrect types should not match",
+ )
+
+ # it should work on frozendicts too
+ self._assert_matches(
+ condition,
+ frozendict.frozendict({"value": "foobaz"}),
+ "values should match on frozendicts",
+ )
+
+ def test_exact_event_match_boolean(self) -> None:
+ """Check that exact_event_match conditions work as expected for booleans."""
+
+ # Test against a True boolean value.
+ condition = {
+ "kind": "com.beeper.msc3758.exact_event_match",
+ "key": "content.value",
+ "value": True,
+ }
+ self._assert_matches(
+ condition,
+ {"value": True},
+ "exact value should match",
+ )
+ self._assert_not_matches(
+ condition,
+ {"value": False},
+ "incorrect values should not match",
+ )
+ for value in ("foobaz", 1, 1.1, None, [], {}):
+ self._assert_not_matches(
+ condition,
+ {"value": value},
+ "incorrect types should not match",
+ )
+
+ # Test against a False boolean value.
+ condition = {
+ "kind": "com.beeper.msc3758.exact_event_match",
+ "key": "content.value",
+ "value": False,
+ }
+ self._assert_matches(
+ condition,
+ {"value": False},
+ "exact value should match",
+ )
+ self._assert_not_matches(
+ condition,
+ {"value": True},
+ "incorrect values should not match",
+ )
+ # Choose false-y values to ensure there's no type coercion.
+ for value in ("", 0, 1.1, None, [], {}):
+ self._assert_not_matches(
+ condition,
+ {"value": value},
+ "incorrect types should not match",
+ )
+
+ def test_exact_event_match_null(self) -> None:
+ """Check that exact_event_match conditions work as expected for null."""
+
+ condition = {
+ "kind": "com.beeper.msc3758.exact_event_match",
+ "key": "content.value",
+ "value": None,
+ }
+ self._assert_matches(
+ condition,
+ {"value": None},
+ "exact value should match",
+ )
+ for value in ("foobaz", True, False, 1, 1.1, [], {}):
+ self._assert_not_matches(
+ condition,
+ {"value": value},
+ "incorrect types should not match",
+ )
+
+ def test_exact_event_match_integer(self) -> None:
+ """Check that exact_event_match conditions work as expected for integers."""
+
+ condition = {
+ "kind": "com.beeper.msc3758.exact_event_match",
+ "key": "content.value",
+ "value": 1,
+ }
+ self._assert_matches(
+ condition,
+ {"value": 1},
+ "exact value should match",
+ )
+ value: Any
+ for value in (1.1, -1, 0):
+ self._assert_not_matches(
+ condition,
+ {"value": value},
+ "incorrect values should not match",
+ )
+ for value in ("1", True, False, None, [], {}):
+ self._assert_not_matches(
+ condition,
+ {"value": value},
+ "incorrect types should not match",
+ )
+
+ def test_exact_event_property_contains(self) -> None:
+ """Check that exact_event_property_contains conditions work as expected."""
+
+ condition = {
+ "kind": "org.matrix.msc3966.exact_event_property_contains",
+ "key": "content.value",
+ "value": "foobaz",
+ }
+ self._assert_matches(
+ condition,
+ {"value": ["foobaz"]},
+ "exact value should match",
+ )
+ self._assert_matches(
+ condition,
+ {"value": ["foobaz", "bugz"]},
+ "extra values should match",
+ )
+ self._assert_not_matches(
+ condition,
+ {"value": ["FoobaZ"]},
+ "values should match and be case-sensitive",
+ )
+ self._assert_not_matches(
+ condition,
+ {"value": "foobaz"},
+ "does not search in a string",
+ )
+
+ # it should work on frozendicts too
+ self._assert_matches(
+ condition,
+ freeze({"value": ["foobaz"]}),
+ "values should match on frozendicts",
+ )
+
def test_no_body(self) -> None:
"""Not having a body shouldn't break the evaluator."""
evaluator = self._get_evaluator({})
diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py
index 043dbe76af..65ef4bb160 100644
--- a/tests/replication/tcp/streams/test_events.py
+++ b/tests/replication/tcp/streams/test_events.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, List, Optional
+from typing import Any, List, Optional, Sequence
from twisted.test.proto_helpers import MemoryReactor
@@ -139,7 +139,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
)
# this is the point in the DAG where we make a fork
- fork_point: List[str] = self.get_success(
+ fork_point: Sequence[str] = self.get_success(
self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id)
)
@@ -168,7 +168,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
pl_event = self.get_success(
inject_event(
self.hs,
- prev_event_ids=prev_events,
+ prev_event_ids=list(prev_events),
type=EventTypes.PowerLevels,
state_key="",
sender=self.user_id,
@@ -294,7 +294,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
)
# this is the point in the DAG where we make a fork
- fork_point: List[str] = self.get_success(
+ fork_point: Sequence[str] = self.get_success(
self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id)
)
@@ -323,7 +323,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
e = self.get_success(
inject_event(
self.hs,
- prev_event_ids=prev_events,
+ prev_event_ids=list(prev_events),
type=EventTypes.PowerLevels,
state_key="",
sender=self.user_id,
diff --git a/tests/replication/tcp/streams/test_partial_state.py b/tests/replication/tcp/streams/test_partial_state.py
index 38b5020ce0..452ac85069 100644
--- a/tests/replication/tcp/streams/test_partial_state.py
+++ b/tests/replication/tcp/streams/test_partial_state.py
@@ -37,7 +37,7 @@ class PartialStateStreamsTestCase(BaseMultiWorkerStreamTestCase):
room_id = self.helper.create_room_as("@bob:test")
# Mark the room as partial-stated.
self.get_success(
- self.store.store_partial_state_room(room_id, ["serv1", "serv2"], 0, "serv1")
+ self.store.store_partial_state_room(room_id, {"serv1", "serv2"}, 0, "serv1")
)
worker = self.make_worker_hs("synapse.app.generic_worker")
diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py
index 68de5d1cc2..5a38ac831f 100644
--- a/tests/replication/tcp/streams/test_typing.py
+++ b/tests/replication/tcp/streams/test_typing.py
@@ -13,7 +13,7 @@
# limitations under the License.
from unittest.mock import Mock
-from synapse.handlers.typing import RoomMember
+from synapse.handlers.typing import RoomMember, TypingWriterHandler
from synapse.replication.tcp.streams import TypingStream
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -33,6 +33,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
def test_typing(self) -> None:
typing = self.hs.get_typing_handler()
+ assert isinstance(typing, TypingWriterHandler)
self.reconnect()
@@ -88,6 +89,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
sends the proper position and RDATA).
"""
typing = self.hs.get_typing_handler()
+ assert isinstance(typing, TypingWriterHandler)
self.reconnect()
diff --git a/tests/replication/tcp/test_handler.py b/tests/replication/tcp/test_handler.py
index 6e4055cc21..bf927beb6a 100644
--- a/tests/replication/tcp/test_handler.py
+++ b/tests/replication/tcp/test_handler.py
@@ -127,6 +127,7 @@ class ChannelsTestCase(BaseMultiWorkerStreamTestCase):
# ... updating the cache ID gen on the master still shouldn't cause the
# deferred to wake up.
+ assert store._cache_id_gen is not None
ctx = store._cache_id_gen.get_next()
self.get_success(ctx.__aenter__())
self.get_success(ctx.__aexit__(None, None, None))
diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py
index 89380e25b5..08703206a9 100644
--- a/tests/replication/test_federation_sender_shard.py
+++ b/tests/replication/test_federation_sender_shard.py
@@ -16,6 +16,7 @@ from unittest.mock import Mock
from synapse.api.constants import EventTypes, Membership
from synapse.events.builder import EventBuilderFactory
+from synapse.handlers.typing import TypingWriterHandler
from synapse.rest.admin import register_servlets_for_client_rest_resource
from synapse.rest.client import login, room
from synapse.types import UserID, create_requester
@@ -174,6 +175,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
token = self.login("user3", "pass")
typing_handler = self.hs.get_typing_handler()
+ assert isinstance(typing_handler, TypingWriterHandler)
sent_on_1 = False
sent_on_2 = False
diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py
index 9345cfbeb2..0798b021c3 100644
--- a/tests/replication/test_pusher_shard.py
+++ b/tests/replication/test_pusher_shard.py
@@ -50,6 +50,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
user_dict = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(access_token)
)
+ assert user_dict is not None
token_id = user_dict.token_id
self.get_success(
diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py
index aadb31ca83..db77a45ae3 100644
--- a/tests/rest/admin/test_media.py
+++ b/tests/rest/admin/test_media.py
@@ -213,7 +213,8 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.admin_user_tok = self.login("admin", "pass")
self.filepaths = MediaFilePaths(hs.config.media.media_store_path)
- self.url = "/_synapse/admin/v1/media/%s/delete" % self.server_name
+ self.url = "/_synapse/admin/v1/media/delete"
+ self.legacy_url = "/_synapse/admin/v1/media/%s/delete" % self.server_name
# Move clock up to somewhat realistic time
self.reactor.advance(1000000000)
@@ -332,11 +333,13 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
channel.json_body["error"],
)
- def test_delete_media_never_accessed(self) -> None:
+ @parameterized.expand([(True,), (False,)])
+ def test_delete_media_never_accessed(self, use_legacy_url: bool) -> None:
"""
Tests that media deleted if it is older than `before_ts` and never accessed
`last_access_ts` is `NULL` and `created_ts` < `before_ts`
"""
+ url = self.legacy_url if use_legacy_url else self.url
# upload and do not access
server_and_media_id = self._create_media()
@@ -351,7 +354,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
now_ms = self.clock.time_msec()
channel = self.make_request(
"POST",
- self.url + "?before_ts=" + str(now_ms),
+ url + "?before_ts=" + str(now_ms),
access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
diff --git a/tests/rest/admin/test_server_notice.py b/tests/rest/admin/test_server_notice.py
index a2f347f666..f71ff46d87 100644
--- a/tests/rest/admin/test_server_notice.py
+++ b/tests/rest/admin/test_server_notice.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import List
+from typing import List, Sequence
from twisted.test.proto_helpers import MemoryReactor
@@ -558,7 +558,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
def _check_invite_and_join_status(
self, user_id: str, expected_invites: int, expected_memberships: int
- ) -> List[RoomsForUser]:
+ ) -> Sequence[RoomsForUser]:
"""Check invite and room membership status of a user.
Args
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 5c1ced355f..f5b213219f 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -2913,7 +2913,8 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
other_user_tok = self.login("user", "pass")
event_builder_factory = self.hs.get_event_builder_factory()
event_creation_handler = self.hs.get_event_creation_handler()
- storage_controllers = self.hs.get_storage_controllers()
+ persistence = self.hs.get_storage_controllers().persistence
+ assert persistence is not None
# Create two rooms, one with a local user only and one with both a local
# and remote user.
@@ -2934,11 +2935,13 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
},
)
- event, context = self.get_success(
+ event, unpersisted_context = self.get_success(
event_creation_handler.create_new_client_event(builder)
)
- self.get_success(storage_controllers.persistence.persist_event(event, context))
+ context = self.get_success(unpersisted_context.persist(event))
+
+ self.get_success(persistence.persist_event(event, context))
# Now get rooms
url = "/_synapse/admin/v1/users/@joiner:remote_hs/joined_rooms"
diff --git a/tests/rest/admin/test_username_available.py b/tests/rest/admin/test_username_available.py
index 30f12f1bff..6c04e6c56c 100644
--- a/tests/rest/admin/test_username_available.py
+++ b/tests/rest/admin/test_username_available.py
@@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Optional
+
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
@@ -33,9 +35,14 @@ class UsernameAvailableTestCase(unittest.HomeserverTestCase):
self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
- async def check_username(username: str) -> bool:
- if username == "allowed":
- return True
+ async def check_username(
+ localpart: str,
+ guest_access_token: Optional[str] = None,
+ assigned_user_id: Optional[str] = None,
+ inhibit_user_in_use_error: bool = False,
+ ) -> None:
+ if localpart == "allowed":
+ return
raise SynapseError(
400,
"User ID already taken.",
@@ -43,7 +50,7 @@ class UsernameAvailableTestCase(unittest.HomeserverTestCase):
)
handler = self.hs.get_registration_handler()
- handler.check_username = check_username
+ handler.check_username = check_username # type: ignore[assignment]
def test_username_available(self) -> None:
"""
diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py
index 88f255c9ee..e2ee1a1766 100644
--- a/tests/rest/client/test_account.py
+++ b/tests/rest/client/test_account.py
@@ -1193,7 +1193,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
return {}
# Register a mock that will return the expected result depending on the remote.
- self.hs.get_federation_http_client().post_json = Mock(side_effect=post_json)
+ self.hs.get_federation_http_client().post_json = Mock(side_effect=post_json) # type: ignore[assignment]
# Check that we've got the correct response from the client-side endpoint.
self._test_status(
diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py
index 208ec44829..a144610078 100644
--- a/tests/rest/client/test_auth.py
+++ b/tests/rest/client/test_auth.py
@@ -34,7 +34,7 @@ from synapse.util import Clock
from tests import unittest
from tests.handlers.test_oidc import HAS_OIDC
from tests.rest.client.utils import TEST_OIDC_CONFIG, TEST_OIDC_ISSUER
-from tests.server import FakeChannel, make_request
+from tests.server import FakeChannel
from tests.unittest import override_config, skip_unless
@@ -43,6 +43,9 @@ class DummyRecaptchaChecker(UserInteractiveAuthChecker):
super().__init__(hs)
self.recaptcha_attempts: List[Tuple[dict, str]] = []
+ def is_enabled(self) -> bool:
+ return True
+
def check_auth(self, authdict: dict, clientip: str) -> Any:
self.recaptcha_attempts.append((authdict, clientip))
return succeed(True)
@@ -1319,16 +1322,8 @@ class OidcBackchannelLogoutTests(unittest.HomeserverTestCase):
channel = self.submit_logout_token(logout_token)
self.assertEqual(channel.code, 200)
- # Now try to exchange the login token
- channel = make_request(
- self.hs.get_reactor(),
- self.site,
- "POST",
- "/login",
- content={"type": "m.login.token", "token": login_token},
- )
- # It should have failed
- self.assertEqual(channel.code, 403)
+ # Now try to exchange the login token, it should fail.
+ self.helper.login_via_token(login_token, 403)
@override_config(
{
diff --git a/tests/rest/client/test_filter.py b/tests/rest/client/test_filter.py
index afc8d641be..830762fd53 100644
--- a/tests/rest/client/test_filter.py
+++ b/tests/rest/client/test_filter.py
@@ -63,14 +63,14 @@ class FilterTestCase(unittest.HomeserverTestCase):
def test_add_filter_non_local_user(self) -> None:
_is_mine = self.hs.is_mine
- self.hs.is_mine = lambda target_user: False
+ self.hs.is_mine = lambda target_user: False # type: ignore[assignment]
channel = self.make_request(
"POST",
"/_matrix/client/r0/user/%s/filter" % (self.user_id),
self.EXAMPLE_FILTER_JSON,
)
- self.hs.is_mine = _is_mine
+ self.hs.is_mine = _is_mine # type: ignore[assignment]
self.assertEqual(channel.code, 403)
self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN)
diff --git a/tests/rest/client/test_presence.py b/tests/rest/client/test_presence.py
index b3738a0304..67e16880e6 100644
--- a/tests/rest/client/test_presence.py
+++ b/tests/rest/client/test_presence.py
@@ -36,14 +36,14 @@ class PresenceTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
- presence_handler = Mock(spec=PresenceHandler)
- presence_handler.set_state.return_value = make_awaitable(None)
+ self.presence_handler = Mock(spec=PresenceHandler)
+ self.presence_handler.set_state.return_value = make_awaitable(None)
hs = self.setup_test_homeserver(
"red",
federation_http_client=None,
federation_client=Mock(),
- presence_handler=presence_handler,
+ presence_handler=self.presence_handler,
)
return hs
@@ -61,7 +61,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, HTTPStatus.OK)
- self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 1)
+ self.assertEqual(self.presence_handler.set_state.call_count, 1)
@unittest.override_config({"use_presence": False})
def test_put_presence_disabled(self) -> None:
@@ -76,4 +76,4 @@ class PresenceTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, HTTPStatus.OK)
- self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 0)
+ self.assertEqual(self.presence_handler.set_state.call_count, 0)
diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py
index 11cf3939d8..4c561f9525 100644
--- a/tests/rest/client/test_register.py
+++ b/tests/rest/client/test_register.py
@@ -151,7 +151,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
def test_POST_guest_registration(self) -> None:
- self.hs.config.key.macaroon_secret_key = "test"
+ self.hs.config.key.macaroon_secret_key = b"test"
self.hs.config.registration.allow_guest_access = True
channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
@@ -1166,12 +1166,15 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
"""
user_id = self.register_user("kermit_delta", "user")
- self.hs.config.account_validity.startup_job_max_delta = self.max_delta
+ self.hs.config.account_validity.account_validity_startup_job_max_delta = (
+ self.max_delta
+ )
now_ms = self.hs.get_clock().time_msec()
self.get_success(self.store._set_expiration_date_when_missing())
res = self.get_success(self.store.get_expiration_ts_for_user(user_id))
+ assert res is not None
self.assertGreaterEqual(res, now_ms + self.validity_period - self.max_delta)
self.assertLessEqual(res, now_ms + self.validity_period)
diff --git a/tests/rest/client/test_report_event.py b/tests/rest/client/test_report_event.py
index 7cb1017a4a..1250685d39 100644
--- a/tests/rest/client/test_report_event.py
+++ b/tests/rest/client/test_report_event.py
@@ -73,6 +73,18 @@ class ReportEventTestCase(unittest.HomeserverTestCase):
data = {"reason": None, "score": None}
self._assert_status(400, data)
+ def test_cannot_report_nonexistent_event(self) -> None:
+ """
+ Tests that we don't accept event reports for events which do not exist.
+ """
+ channel = self.make_request(
+ "POST",
+ f"rooms/{self.room_id}/report/$nonsenseeventid:test",
+ {"reason": "i am very sad"},
+ access_token=self.other_user_tok,
+ )
+ self.assertEqual(404, channel.code, msg=channel.result["body"])
+
def _assert_status(self, response_status: int, data: JsonDict) -> None:
channel = self.make_request(
"POST", self.report_path, data, access_token=self.other_user_tok
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index 9c8c1889d3..d3e06bf6b3 100644
--- a/tests/rest/client/test_retention.py
+++ b/tests/rest/client/test_retention.py
@@ -136,6 +136,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
# Send a first event, which should be filtered out at the end of the test.
resp = self.helper.send(room_id=room_id, body="1", tok=self.token)
first_event_id = resp.get("event_id")
+ assert isinstance(first_event_id, str)
# Advance the time by 2 days. We're using the default retention policy, therefore
# after this the first event will still be valid.
@@ -144,6 +145,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
# Send another event, which shouldn't get filtered out.
resp = self.helper.send(room_id=room_id, body="2", tok=self.token)
valid_event_id = resp.get("event_id")
+ assert isinstance(valid_event_id, str)
# Advance the time by another 2 days. After this, the first event should be
# outdated but not the second one.
@@ -229,7 +231,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
# Check that we can still access state events that were sent before the event that
# has been purged.
- self.get_event(room_id, create_event.event_id)
+ self.get_event(room_id, bool(create_event))
def get_event(self, event_id: str, expect_none: bool = False) -> JsonDict:
event = self.get_success(self.store.get_event(event_id, allow_none=True))
@@ -238,7 +240,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
self.assertIsNone(event)
return {}
- self.assertIsNotNone(event)
+ assert event is not None
time_now = self.clock.time_msec()
serialized = self.serializer.serialize_event(event, time_now)
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index 9222cab198..cfad182b2f 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -3382,8 +3382,8 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
# a remote IS. We keep the mock for make_and_store_3pid_invite around so we
# can check its call_count later on during the test.
make_invite_mock = Mock(return_value=make_awaitable((Mock(event_id="abc"), 0)))
- self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock
- self.hs.get_identity_handler().lookup_3pid = Mock(
+ self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock # type: ignore[assignment]
+ self.hs.get_identity_handler().lookup_3pid = Mock( # type: ignore[assignment]
return_value=make_awaitable(None),
)
@@ -3443,8 +3443,8 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
# a remote IS. We keep the mock for make_and_store_3pid_invite around so we
# can check its call_count later on during the test.
make_invite_mock = Mock(return_value=make_awaitable((Mock(event_id="abc"), 0)))
- self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock
- self.hs.get_identity_handler().lookup_3pid = Mock(
+ self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock # type: ignore[assignment]
+ self.hs.get_identity_handler().lookup_3pid = Mock( # type: ignore[assignment]
return_value=make_awaitable(None),
)
@@ -3563,8 +3563,10 @@ class TimestampLookupTestCase(unittest.HomeserverTestCase):
)
event.internal_metadata.outlier = True
+ persistence = self._storage_controllers.persistence
+ assert persistence is not None
self.get_success(
- self._storage_controllers.persistence.persist_event(
+ persistence.persist_event(
event, EventContext.for_outlier(self._storage_controllers)
)
)
diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py
index c807a37bc2..8d2cdf8751 100644
--- a/tests/rest/client/test_shadow_banned.py
+++ b/tests/rest/client/test_shadow_banned.py
@@ -84,7 +84,7 @@ class RoomTestCase(_ShadowBannedBase):
def test_invite_3pid(self) -> None:
"""Ensure that a 3PID invite does not attempt to contact the identity server."""
identity_handler = self.hs.get_identity_handler()
- identity_handler.lookup_3pid = Mock(
+ identity_handler.lookup_3pid = Mock( # type: ignore[assignment]
side_effect=AssertionError("This should not get called")
)
@@ -222,7 +222,7 @@ class RoomTestCase(_ShadowBannedBase):
event_source.get_new_events(
user=UserID.from_string(self.other_user_id),
from_key=0,
- limit=None,
+ limit=10,
room_ids=[room_id],
is_guest=False,
)
@@ -286,6 +286,7 @@ class ProfileTestCase(_ShadowBannedBase):
self.banned_user_id,
)
)
+ assert event is not None
self.assertEqual(
event.content, {"membership": "join", "displayname": original_display_name}
)
@@ -321,6 +322,7 @@ class ProfileTestCase(_ShadowBannedBase):
self.banned_user_id,
)
)
+ assert event is not None
self.assertEqual(
event.content, {"membership": "join", "displayname": original_display_name}
)
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index 3325d43a2f..5fa3440691 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -425,7 +425,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
async def test_fn(
event: EventBase, state_events: StateMap[EventBase]
) -> Tuple[bool, Optional[JsonDict]]:
- if event.is_state and event.type == EventTypes.PowerLevels:
+ if event.is_state() and event.type == EventTypes.PowerLevels:
await api.create_and_send_event_into_room(
{
"room_id": event.room_id,
diff --git a/tests/rest/client/test_upgrade_room.py b/tests/rest/client/test_upgrade_room.py
index 5ec343dd7f..0b4c691318 100644
--- a/tests/rest/client/test_upgrade_room.py
+++ b/tests/rest/client/test_upgrade_room.py
@@ -84,7 +84,7 @@ class UpgradeRoomTest(unittest.HomeserverTestCase):
self.room_id, EventTypes.Tombstone, ""
)
)
- self.assertIsNotNone(tombstone_event)
+ assert tombstone_event is not None
self.assertEqual(new_room_id, tombstone_event.content["replacement_room"])
# Check that the new room exists.
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index 8d6f2b6ff9..9532e5ddc1 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -36,6 +36,7 @@ from urllib.parse import urlencode
import attr
from typing_extensions import Literal
+from twisted.test.proto_helpers import MemoryReactorClock
from twisted.web.resource import Resource
from twisted.web.server import Site
@@ -67,6 +68,7 @@ class RestHelper:
"""
hs: HomeServer
+ reactor: MemoryReactorClock
site: Site
auth_user_id: Optional[str]
@@ -142,7 +144,7 @@ class RestHelper:
path = path + "?access_token=%s" % tok
channel = make_request(
- self.hs.get_reactor(),
+ self.reactor,
self.site,
"POST",
path,
@@ -216,7 +218,7 @@ class RestHelper:
data["reason"] = reason
channel = make_request(
- self.hs.get_reactor(),
+ self.reactor,
self.site,
"POST",
path,
@@ -313,7 +315,7 @@ class RestHelper:
data.update(extra_data or {})
channel = make_request(
- self.hs.get_reactor(),
+ self.reactor,
self.site,
"PUT",
path,
@@ -394,7 +396,7 @@ class RestHelper:
path = path + "?access_token=%s" % tok
channel = make_request(
- self.hs.get_reactor(),
+ self.reactor,
self.site,
"PUT",
path,
@@ -433,7 +435,7 @@ class RestHelper:
path = path + f"?access_token={tok}"
channel = make_request(
- self.hs.get_reactor(),
+ self.reactor,
self.site,
"GET",
path,
@@ -488,7 +490,7 @@ class RestHelper:
if body is not None:
content = json.dumps(body).encode("utf8")
- channel = make_request(self.hs.get_reactor(), self.site, method, path, content)
+ channel = make_request(self.reactor, self.site, method, path, content)
assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
expect_code,
@@ -573,8 +575,8 @@ class RestHelper:
image_length = len(image_data)
path = "/_matrix/media/r0/upload?filename=%s" % (filename,)
channel = make_request(
- self.hs.get_reactor(),
- FakeSite(resource, self.hs.get_reactor()),
+ self.reactor,
+ FakeSite(resource, self.reactor),
"POST",
path,
content=image_data,
@@ -603,7 +605,7 @@ class RestHelper:
expect_code: The return code to expect from attempting the whoami request
"""
channel = make_request(
- self.hs.get_reactor(),
+ self.reactor,
self.site,
"GET",
"account/whoami",
@@ -642,7 +644,7 @@ class RestHelper:
) -> Tuple[JsonDict, FakeAuthorizationGrant]:
"""Log in (as a new user) via OIDC
- Returns the result of the final token login.
+ Returns the result of the final token login and the fake authorization grant.
Requires that "oidc_config" in the homeserver config be set appropriately
(TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a
@@ -672,10 +674,28 @@ class RestHelper:
assert m, channel.text_body
login_token = m.group(1)
- # finally, submit the matrix login token to the login API, which gives us our
- # matrix access token and device id.
+ return self.login_via_token(login_token, expected_status), grant
+
+ def login_via_token(
+ self,
+ login_token: str,
+ expected_status: int = 200,
+ ) -> JsonDict:
+ """Submit the matrix login token to the login API, which gives us our
+ matrix access token and device id.Log in (as a new user) via OIDC
+
+ Returns the result of the token login.
+
+ Requires that "oidc_config" in the homeserver config be set appropriately
+ (TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a
+ "public_base_url".
+
+ Also requires the login servlet and the OIDC callback resource to be mounted at
+ the normal places.
+ """
+
channel = make_request(
- self.hs.get_reactor(),
+ self.reactor,
self.site,
"POST",
"/login",
@@ -684,7 +704,7 @@ class RestHelper:
assert (
channel.code == expected_status
), f"unexpected status in response: {channel.code}"
- return channel.json_body, grant
+ return channel.json_body
def auth_via_oidc(
self,
@@ -805,7 +825,7 @@ class RestHelper:
with fake_serer.patch_homeserver(hs=self.hs):
# now hit the callback URI with the right params and a made-up code
channel = make_request(
- self.hs.get_reactor(),
+ self.reactor,
self.site,
"GET",
callback_uri,
@@ -849,7 +869,7 @@ class RestHelper:
# is the easiest way of figuring out what the Host header ought to be set to
# to keep Synapse happy.
channel = make_request(
- self.hs.get_reactor(),
+ self.reactor,
self.site,
"GET",
uri,
@@ -867,7 +887,7 @@ class RestHelper:
location = get_location(channel)
parts = urllib.parse.urlsplit(location)
channel = make_request(
- self.hs.get_reactor(),
+ self.reactor,
self.site,
"GET",
urllib.parse.urlunsplit(("", "") + parts[2:]),
@@ -900,9 +920,7 @@ class RestHelper:
+ urllib.parse.urlencode({"session": ui_auth_session_id})
)
# hit the redirect url (which will issue a cookie and state)
- channel = make_request(
- self.hs.get_reactor(), self.site, "GET", sso_redirect_endpoint
- )
+ channel = make_request(self.reactor, self.site, "GET", sso_redirect_endpoint)
# that should serve a confirmation page
assert channel.code == HTTPStatus.OK, channel.text_body
channel.extract_cookies(cookies)
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index d18fc13c21..17a3b06a8e 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -16,7 +16,7 @@ import shutil
import tempfile
from binascii import unhexlify
from io import BytesIO
-from typing import Any, BinaryIO, Dict, List, Optional, Union
+from typing import Any, BinaryIO, ClassVar, Dict, List, Optional, Tuple, Union
from unittest.mock import Mock
from urllib import parse
@@ -32,6 +32,7 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.api.errors import Codes
from synapse.events import EventBase
from synapse.events.spamcheck import load_legacy_spam_checkers
+from synapse.http.types import QueryParams
from synapse.logging.context import make_deferred_yieldable
from synapse.module_api import ModuleApi
from synapse.rest import admin
@@ -41,7 +42,7 @@ from synapse.rest.media.v1.filepath import MediaFilePaths
from synapse.rest.media.v1.media_storage import MediaStorage, ReadableFileWrapper
from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend
from synapse.server import HomeServer
-from synapse.types import RoomAlias
+from synapse.types import JsonDict, RoomAlias
from synapse.util import Clock
from tests import unittest
@@ -201,36 +202,46 @@ class _TestImage:
],
)
class MediaRepoTests(unittest.HomeserverTestCase):
-
+ test_image: ClassVar[_TestImage]
hijack_auth = True
user_id = "@test:user"
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
- self.fetches = []
+ self.fetches: List[
+ Tuple[
+ "Deferred[Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]]",
+ str,
+ str,
+ Optional[QueryParams],
+ ]
+ ] = []
def get_file(
destination: str,
path: str,
output_stream: BinaryIO,
- args: Optional[Dict[str, Union[str, List[str]]]] = None,
+ args: Optional[QueryParams] = None,
+ retry_on_dns_fail: bool = True,
max_size: Optional[int] = None,
- ) -> Deferred:
- """
- Returns tuple[int,dict,str,int] of file length, response headers,
- absolute URI, and response code.
- """
+ ignore_backoff: bool = False,
+ ) -> "Deferred[Tuple[int, Dict[bytes, List[bytes]]]]":
+ """A mock for MatrixFederationHttpClient.get_file."""
- def write_to(r):
+ def write_to(
+ r: Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]
+ ) -> Tuple[int, Dict[bytes, List[bytes]]]:
data, response = r
output_stream.write(data)
return response
- d = Deferred()
- d.addCallback(write_to)
+ d: Deferred[Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]] = Deferred()
self.fetches.append((d, destination, path, args))
- return make_deferred_yieldable(d)
+ # Note that this callback changes the value held by d.
+ d_after_callback = d.addCallback(write_to)
+ return make_deferred_yieldable(d_after_callback)
+ # Mock out the homeserver's MatrixFederationHttpClient
client = Mock()
client.get_file = get_file
@@ -461,6 +472,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
# Synapse should regenerate missing thumbnails.
origin, media_id = self.media_id.split("/")
info = self.get_success(self.store.get_cached_remote_media(origin, media_id))
+ assert info is not None
file_id = info["filesystem_id"]
thumbnail_dir = self.media_repo.filepaths.remote_media_thumbnail_dir(
@@ -581,7 +593,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
"thumbnail_method": method,
"thumbnail_type": self.test_image.content_type,
"thumbnail_length": 256,
- "filesystem_id": f"thumbnail1{self.test_image.extension}",
+ "filesystem_id": f"thumbnail1{self.test_image.extension.decode()}",
},
{
"thumbnail_width": 32,
@@ -589,10 +601,10 @@ class MediaRepoTests(unittest.HomeserverTestCase):
"thumbnail_method": method,
"thumbnail_type": self.test_image.content_type,
"thumbnail_length": 256,
- "filesystem_id": f"thumbnail2{self.test_image.extension}",
+ "filesystem_id": f"thumbnail2{self.test_image.extension.decode()}",
},
],
- file_id=f"image{self.test_image.extension}",
+ file_id=f"image{self.test_image.extension.decode()}",
url_cache=None,
server_name=None,
)
@@ -637,6 +649,7 @@ class TestSpamCheckerLegacy:
self.config = config
self.api = api
+ @staticmethod
def parse_config(config: Dict[str, Any]) -> Dict[str, Any]:
return config
@@ -748,7 +761,7 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase):
async def check_media_file_for_spam(
self, file_wrapper: ReadableFileWrapper, file_info: FileInfo
- ) -> Union[Codes, Literal["NOT_SPAM"]]:
+ ) -> Union[Codes, Literal["NOT_SPAM"], Tuple[Codes, JsonDict]]:
buf = BytesIO()
await file_wrapper.write_chunks_to(buf.write)
diff --git a/tests/scripts/test_new_matrix_user.py b/tests/scripts/test_new_matrix_user.py
index 22f99c6ab1..3285f2433c 100644
--- a/tests/scripts/test_new_matrix_user.py
+++ b/tests/scripts/test_new_matrix_user.py
@@ -12,29 +12,33 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import List
+from typing import List, Optional
from unittest.mock import Mock, patch
from synapse._scripts.register_new_matrix_user import request_registration
+from synapse.types import JsonDict
from tests.unittest import TestCase
class RegisterTestCase(TestCase):
- def test_success(self):
+ def test_success(self) -> None:
"""
The script will fetch a nonce, and then generate a MAC with it, and then
post that MAC.
"""
- def get(url, verify=None):
+ def get(url: str, verify: Optional[bool] = None) -> Mock:
r = Mock()
r.status_code = 200
r.json = lambda: {"nonce": "a"}
return r
- def post(url, json=None, verify=None):
+ def post(
+ url: str, json: Optional[JsonDict] = None, verify: Optional[bool] = None
+ ) -> Mock:
# Make sure we are sent the correct info
+ assert json is not None
self.assertEqual(json["username"], "user")
self.assertEqual(json["password"], "pass")
self.assertEqual(json["nonce"], "a")
@@ -70,12 +74,12 @@ class RegisterTestCase(TestCase):
# sys.exit shouldn't have been called.
self.assertEqual(err_code, [])
- def test_failure_nonce(self):
+ def test_failure_nonce(self) -> None:
"""
If the script fails to fetch a nonce, it throws an error and quits.
"""
- def get(url, verify=None):
+ def get(url: str, verify: Optional[bool] = None) -> Mock:
r = Mock()
r.status_code = 404
r.reason = "Not Found"
@@ -107,20 +111,23 @@ class RegisterTestCase(TestCase):
self.assertIn("ERROR! Received 404 Not Found", out)
self.assertNotIn("Success!", out)
- def test_failure_post(self):
+ def test_failure_post(self) -> None:
"""
The script will fetch a nonce, and then if the final POST fails, will
report an error and quit.
"""
- def get(url, verify=None):
+ def get(url: str, verify: Optional[bool] = None) -> Mock:
r = Mock()
r.status_code = 200
r.json = lambda: {"nonce": "a"}
return r
- def post(url, json=None, verify=None):
+ def post(
+ url: str, json: Optional[JsonDict] = None, verify: Optional[bool] = None
+ ) -> Mock:
# Make sure we are sent the correct info
+ assert json is not None
self.assertEqual(json["username"], "user")
self.assertEqual(json["password"], "pass")
self.assertEqual(json["nonce"], "a")
diff --git a/tests/server.py b/tests/server.py
index 237bcad8ba..5de9722766 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -22,20 +22,25 @@ import warnings
from collections import deque
from io import SEEK_END, BytesIO
from typing import (
+ Any,
+ Awaitable,
Callable,
Dict,
Iterable,
List,
MutableMapping,
Optional,
+ Sequence,
Tuple,
Type,
+ TypeVar,
Union,
+ cast,
)
from unittest.mock import Mock
import attr
-from typing_extensions import Deque
+from typing_extensions import Deque, ParamSpec
from zope.interface import implementer
from twisted.internet import address, threads, udp
@@ -44,8 +49,10 @@ from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed
from twisted.internet.error import DNSLookupError
from twisted.internet.interfaces import (
IAddress,
+ IConnector,
IConsumer,
IHostnameResolver,
+ IProducer,
IProtocol,
IPullProducer,
IPushProducer,
@@ -54,6 +61,8 @@ from twisted.internet.interfaces import (
IResolverSimple,
ITransport,
)
+from twisted.internet.protocol import ClientFactory, DatagramProtocol
+from twisted.python import threadpool
from twisted.python.failure import Failure
from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
from twisted.web.http_headers import Headers
@@ -61,6 +70,7 @@ from twisted.web.resource import IResource
from twisted.web.server import Request, Site
from synapse.config.database import DatabaseConnectionConfig
+from synapse.config.homeserver import HomeServerConfig
from synapse.events.presence_router import load_legacy_presence_router
from synapse.events.spamcheck import load_legacy_spam_checkers
from synapse.events.third_party_rules import load_legacy_third_party_event_rules
@@ -88,6 +98,9 @@ from tests.utils import (
logger = logging.getLogger(__name__)
+R = TypeVar("R")
+P = ParamSpec("P")
+
# the type of thing that can be passed into `make_request` in the headers list
CustomHeaderType = Tuple[Union[str, bytes], Union[str, bytes]]
@@ -98,12 +111,14 @@ class TimedOutException(Exception):
"""
-@implementer(IConsumer)
+@implementer(ITransport, IPushProducer, IConsumer)
@attr.s(auto_attribs=True)
class FakeChannel:
"""
A fake Twisted Web Channel (the part that interfaces with the
wire).
+
+ See twisted.web.http.HTTPChannel.
"""
site: Union[Site, "FakeSite"]
@@ -142,7 +157,7 @@ class FakeChannel:
Raises an exception if the request has not yet completed.
"""
- if not self.is_finished:
+ if not self.is_finished():
raise Exception("Request not yet completed")
return self.result["body"].decode("utf8")
@@ -165,27 +180,36 @@ class FakeChannel:
h.addRawHeader(*i)
return h
- def writeHeaders(self, version, code, reason, headers):
+ def writeHeaders(
+ self, version: bytes, code: bytes, reason: bytes, headers: Headers
+ ) -> None:
self.result["version"] = version
self.result["code"] = code
self.result["reason"] = reason
self.result["headers"] = headers
- def write(self, content: bytes) -> None:
- assert isinstance(content, bytes), "Should be bytes! " + repr(content)
+ def write(self, data: bytes) -> None:
+ assert isinstance(data, bytes), "Should be bytes! " + repr(data)
if "body" not in self.result:
self.result["body"] = b""
- self.result["body"] += content
+ self.result["body"] += data
+
+ def writeSequence(self, data: Iterable[bytes]) -> None:
+ for x in data:
+ self.write(x)
+
+ def loseConnection(self) -> None:
+ self.unregisterProducer()
+ self.transport.loseConnection()
# Type ignore: mypy doesn't like the fact that producer isn't an IProducer.
- def registerProducer( # type: ignore[override]
- self,
- producer: Union[IPullProducer, IPushProducer],
- streaming: bool,
- ) -> None:
- self._producer = producer
+ def registerProducer(self, producer: IProducer, streaming: bool) -> None:
+ # TODO This should ensure that the IProducer is an IPushProducer or
+ # IPullProducer, unfortunately twisted.protocols.basic.FileSender does
+ # implement those, but doesn't declare it.
+ self._producer = cast(Union[IPushProducer, IPullProducer], producer)
self.producerStreaming = streaming
def _produce() -> None:
@@ -202,6 +226,16 @@ class FakeChannel:
self._producer = None
+ def stopProducing(self) -> None:
+ if self._producer is not None:
+ self._producer.stopProducing()
+
+ def pauseProducing(self) -> None:
+ raise NotImplementedError()
+
+ def resumeProducing(self) -> None:
+ raise NotImplementedError()
+
def requestDone(self, _self: Request) -> None:
self.result["done"] = True
if isinstance(_self, SynapseRequest):
@@ -281,12 +315,12 @@ class FakeSite:
self.reactor = reactor
self.experimental_cors_msc3886 = experimental_cors_msc3886
- def getResourceFor(self, request):
+ def getResourceFor(self, request: Request) -> IResource:
return self._resource
def make_request(
- reactor,
+ reactor: MemoryReactorClock,
site: Union[Site, FakeSite],
method: Union[bytes, str],
path: Union[bytes, str],
@@ -409,19 +443,21 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
A MemoryReactorClock that supports callFromThread.
"""
- def __init__(self):
+ def __init__(self) -> None:
self.threadpool = ThreadPool(self)
self._tcp_callbacks: Dict[Tuple[str, int], Callable] = {}
- self._udp = []
+ self._udp: List[udp.Port] = []
self.lookups: Dict[str, str] = {}
- self._thread_callbacks: Deque[Callable[[], None]] = deque()
+ self._thread_callbacks: Deque[Callable[..., R]] = deque()
lookups = self.lookups
@implementer(IResolverSimple)
class FakeResolver:
- def getHostByName(self, name, timeout=None):
+ def getHostByName(
+ self, name: str, timeout: Optional[Sequence[int]] = None
+ ) -> "Deferred[str]":
if name not in lookups:
return fail(DNSLookupError("OH NO: unknown %s" % (name,)))
return succeed(lookups[name])
@@ -432,25 +468,44 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
def installNameResolver(self, resolver: IHostnameResolver) -> IHostnameResolver:
raise NotImplementedError()
- def listenUDP(self, port, protocol, interface="", maxPacketSize=8196):
+ def listenUDP(
+ self,
+ port: int,
+ protocol: DatagramProtocol,
+ interface: str = "",
+ maxPacketSize: int = 8196,
+ ) -> udp.Port:
p = udp.Port(port, protocol, interface, maxPacketSize, self)
p.startListening()
self._udp.append(p)
return p
- def callFromThread(self, callback, *args, **kwargs):
+ def callFromThread(
+ self, callable: Callable[..., Any], *args: object, **kwargs: object
+ ) -> None:
"""
Make the callback fire in the next reactor iteration.
"""
- cb = lambda: callback(*args, **kwargs)
+ cb = lambda: callable(*args, **kwargs)
# it's not safe to call callLater() here, so we append the callback to a
# separate queue.
self._thread_callbacks.append(cb)
- def getThreadPool(self):
- return self.threadpool
+ def callInThread(
+ self, callable: Callable[..., Any], *args: object, **kwargs: object
+ ) -> None:
+ raise NotImplementedError()
+
+ def suggestThreadPoolSize(self, size: int) -> None:
+ raise NotImplementedError()
+
+ def getThreadPool(self) -> "threadpool.ThreadPool":
+ # Cast to match super-class.
+ return cast(threadpool.ThreadPool, self.threadpool)
- def add_tcp_client_callback(self, host: str, port: int, callback: Callable):
+ def add_tcp_client_callback(
+ self, host: str, port: int, callback: Callable[[], None]
+ ) -> None:
"""Add a callback that will be invoked when we receive a connection
attempt to the given IP/port using `connectTCP`.
@@ -459,7 +514,14 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
"""
self._tcp_callbacks[(host, port)] = callback
- def connectTCP(self, host: str, port: int, factory, timeout=30, bindAddress=None):
+ def connectTCP(
+ self,
+ host: str,
+ port: int,
+ factory: ClientFactory,
+ timeout: float = 30,
+ bindAddress: Optional[Tuple[str, int]] = None,
+ ) -> IConnector:
"""Fake L{IReactorTCP.connectTCP}."""
conn = super().connectTCP(
@@ -472,7 +534,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
return conn
- def advance(self, amount):
+ def advance(self, amount: float) -> None:
# first advance our reactor's time, and run any "callLater" callbacks that
# makes ready
super().advance(amount)
@@ -500,25 +562,33 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
class ThreadPool:
"""
Threadless thread pool.
+
+ See twisted.python.threadpool.ThreadPool
"""
- def __init__(self, reactor):
+ def __init__(self, reactor: IReactorTime):
self._reactor = reactor
- def start(self):
+ def start(self) -> None:
pass
- def stop(self):
+ def stop(self) -> None:
pass
- def callInThreadWithCallback(self, onResult, function, *args, **kwargs):
- def _(res):
+ def callInThreadWithCallback(
+ self,
+ onResult: Callable[[bool, Union[Failure, R]], None],
+ function: Callable[P, R],
+ *args: P.args,
+ **kwargs: P.kwargs,
+ ) -> "Deferred[None]":
+ def _(res: Any) -> None:
if isinstance(res, Failure):
onResult(False, res)
else:
onResult(True, res)
- d = Deferred()
+ d: "Deferred[None]" = Deferred()
d.addCallback(lambda x: function(*args, **kwargs))
d.addBoth(_)
self._reactor.callLater(0, d.callback, True)
@@ -535,7 +605,9 @@ def _make_test_homeserver_synchronous(server: HomeServer) -> None:
for database in server.get_datastores().databases:
pool = database._db_pool
- def runWithConnection(func, *args, **kwargs):
+ def runWithConnection(
+ func: Callable[..., R], *args: Any, **kwargs: Any
+ ) -> Awaitable[R]:
return threads.deferToThreadPool(
pool._reactor,
pool.threadpool,
@@ -545,20 +617,23 @@ def _make_test_homeserver_synchronous(server: HomeServer) -> None:
**kwargs,
)
- def runInteraction(interaction, *args, **kwargs):
+ def runInteraction(
+ desc: str, func: Callable[..., R], *args: Any, **kwargs: Any
+ ) -> Awaitable[R]:
return threads.deferToThreadPool(
pool._reactor,
pool.threadpool,
pool._runInteraction,
- interaction,
+ desc,
+ func,
*args,
**kwargs,
)
- pool.runWithConnection = runWithConnection
- pool.runInteraction = runInteraction
+ pool.runWithConnection = runWithConnection # type: ignore[assignment]
+ pool.runInteraction = runInteraction # type: ignore[assignment]
# Replace the thread pool with a threadless 'thread' pool
- pool.threadpool = ThreadPool(clock._reactor)
+ pool.threadpool = ThreadPool(clock._reactor) # type: ignore[assignment]
pool.running = True
# We've just changed the Databases to run DB transactions on the same
@@ -573,7 +648,7 @@ def get_clock() -> Tuple[ThreadedMemoryReactorClock, Clock]:
@implementer(ITransport)
-@attr.s(cmp=False)
+@attr.s(cmp=False, auto_attribs=True)
class FakeTransport:
"""
A twisted.internet.interfaces.ITransport implementation which sends all its data
@@ -588,48 +663,50 @@ class FakeTransport:
If you want bidirectional communication, you'll need two instances.
"""
- other = attr.ib()
+ other: IProtocol
"""The Protocol object which will receive any data written to this transport.
-
- :type: twisted.internet.interfaces.IProtocol
"""
- _reactor = attr.ib()
+ _reactor: IReactorTime
"""Test reactor
-
- :type: twisted.internet.interfaces.IReactorTime
"""
- _protocol = attr.ib(default=None)
+ _protocol: Optional[IProtocol] = None
"""The Protocol which is producing data for this transport. Optional, but if set
will get called back for connectionLost() notifications etc.
"""
- _peer_address: Optional[IAddress] = attr.ib(default=None)
+ _peer_address: IAddress = attr.Factory(
+ lambda: address.IPv4Address("TCP", "127.0.0.1", 5678)
+ )
"""The value to be returned by getPeer"""
- _host_address: Optional[IAddress] = attr.ib(default=None)
+ _host_address: IAddress = attr.Factory(
+ lambda: address.IPv4Address("TCP", "127.0.0.1", 1234)
+ )
"""The value to be returned by getHost"""
disconnecting = False
disconnected = False
connected = True
- buffer = attr.ib(default=b"")
- producer = attr.ib(default=None)
- autoflush = attr.ib(default=True)
+ buffer: bytes = b""
+ producer: Optional[IPushProducer] = None
+ autoflush: bool = True
- def getPeer(self) -> Optional[IAddress]:
+ def getPeer(self) -> IAddress:
return self._peer_address
- def getHost(self) -> Optional[IAddress]:
+ def getHost(self) -> IAddress:
return self._host_address
- def loseConnection(self, reason=None):
+ def loseConnection(self) -> None:
if not self.disconnecting:
- logger.info("FakeTransport: loseConnection(%s)", reason)
+ logger.info("FakeTransport: loseConnection()")
self.disconnecting = True
if self._protocol:
- self._protocol.connectionLost(reason)
+ self._protocol.connectionLost(
+ Failure(RuntimeError("FakeTransport.loseConnection()"))
+ )
# if we still have data to write, delay until that is done
if self.buffer:
@@ -640,38 +717,38 @@ class FakeTransport:
self.connected = False
self.disconnected = True
- def abortConnection(self):
+ def abortConnection(self) -> None:
logger.info("FakeTransport: abortConnection()")
if not self.disconnecting:
self.disconnecting = True
if self._protocol:
- self._protocol.connectionLost(None)
+ self._protocol.connectionLost(None) # type: ignore[arg-type]
self.disconnected = True
- def pauseProducing(self):
+ def pauseProducing(self) -> None:
if not self.producer:
return
self.producer.pauseProducing()
- def resumeProducing(self):
+ def resumeProducing(self) -> None:
if not self.producer:
return
self.producer.resumeProducing()
- def unregisterProducer(self):
+ def unregisterProducer(self) -> None:
if not self.producer:
return
self.producer = None
- def registerProducer(self, producer, streaming):
+ def registerProducer(self, producer: IPushProducer, streaming: bool) -> None:
self.producer = producer
self.producerStreaming = streaming
- def _produce():
+ def _produce() -> None:
if not self.producer:
# we've been unregistered
return
@@ -683,7 +760,7 @@ class FakeTransport:
if not streaming:
self._reactor.callLater(0.0, _produce)
- def write(self, byt):
+ def write(self, byt: bytes) -> None:
if self.disconnecting:
raise Exception("Writing to disconnecting FakeTransport")
@@ -695,11 +772,11 @@ class FakeTransport:
if self.autoflush:
self._reactor.callLater(0.0, self.flush)
- def writeSequence(self, seq):
+ def writeSequence(self, seq: Iterable[bytes]) -> None:
for x in seq:
self.write(x)
- def flush(self, maxbytes=None):
+ def flush(self, maxbytes: Optional[int] = None) -> None:
if not self.buffer:
# nothing to do. Don't write empty buffers: it upsets the
# TLSMemoryBIOProtocol
@@ -750,17 +827,17 @@ def connect_client(
class TestHomeServer(HomeServer):
- DATASTORE_CLASS = DataStore
+ DATASTORE_CLASS = DataStore # type: ignore[assignment]
def setup_test_homeserver(
- cleanup_func,
- name="test",
- config=None,
- reactor=None,
+ cleanup_func: Callable[[Callable[[], None]], None],
+ name: str = "test",
+ config: Optional[HomeServerConfig] = None,
+ reactor: Optional[ISynapseReactor] = None,
homeserver_to_use: Type[HomeServer] = TestHomeServer,
- **kwargs,
-):
+ **kwargs: Any,
+) -> HomeServer:
"""
Setup a homeserver suitable for running tests against. Keyword arguments
are passed to the Homeserver constructor.
@@ -775,13 +852,14 @@ def setup_test_homeserver(
HomeserverTestCase.
"""
if reactor is None:
- from twisted.internet import reactor
+ from twisted.internet import reactor as _reactor
+
+ reactor = cast(ISynapseReactor, _reactor)
if config is None:
config = default_config(name, parse=True)
config.caches.resize_all_caches()
- config.ldap_enabled = False
if "clock" not in kwargs:
kwargs["clock"] = MockClock()
@@ -832,6 +910,8 @@ def setup_test_homeserver(
# Create the database before we actually try and connect to it, based off
# the template database we generate in setupdb()
if isinstance(db_engine, PostgresEngine):
+ import psycopg2.extensions
+
db_conn = db_engine.module.connect(
database=POSTGRES_BASE_DB,
user=POSTGRES_USER,
@@ -839,6 +919,7 @@ def setup_test_homeserver(
port=POSTGRES_PORT,
password=POSTGRES_PASSWORD,
)
+ assert isinstance(db_conn, psycopg2.extensions.connection)
db_conn.autocommit = True
cur = db_conn.cursor()
cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
@@ -867,14 +948,15 @@ def setup_test_homeserver(
hs.setup_background_tasks()
if isinstance(db_engine, PostgresEngine):
- database = hs.get_datastores().databases[0]
+ database_pool = hs.get_datastores().databases[0]
# We need to do cleanup on PostgreSQL
- def cleanup():
+ def cleanup() -> None:
import psycopg2
+ import psycopg2.extensions
# Close all the db pools
- database._db_pool.close()
+ database_pool._db_pool.close()
dropped = False
@@ -886,6 +968,7 @@ def setup_test_homeserver(
port=POSTGRES_PORT,
password=POSTGRES_PASSWORD,
)
+ assert isinstance(db_conn, psycopg2.extensions.connection)
db_conn.autocommit = True
cur = db_conn.cursor()
@@ -918,23 +1001,23 @@ def setup_test_homeserver(
# Need to let the HS build an auth handler and then mess with it
# because AuthHandler's constructor requires the HS, so we can't make one
# beforehand and pass it in to the HS's constructor (chicken / egg)
- async def hash(p):
+ async def hash(p: str) -> str:
return hashlib.md5(p.encode("utf8")).hexdigest()
- hs.get_auth_handler().hash = hash
+ hs.get_auth_handler().hash = hash # type: ignore[assignment]
- async def validate_hash(p, h):
+ async def validate_hash(p: str, h: str) -> bool:
return hashlib.md5(p.encode("utf8")).hexdigest() == h
- hs.get_auth_handler().validate_hash = validate_hash
+ hs.get_auth_handler().validate_hash = validate_hash # type: ignore[assignment]
# Make the threadpool and database transactions synchronous for testing.
_make_test_homeserver_synchronous(hs)
# Load any configured modules into the homeserver
module_api = hs.get_module_api()
- for module, config in hs.config.modules.loaded_modules:
- module(config=config, api=module_api)
+ for module, module_config in hs.config.modules.loaded_modules:
+ module(config=module_config, api=module_api)
load_legacy_spam_checkers(hs)
load_legacy_third_party_event_rules(hs)
diff --git a/tests/server_notices/test_consent.py b/tests/server_notices/test_consent.py
index 58b399a043..6540ed53f1 100644
--- a/tests/server_notices/test_consent.py
+++ b/tests/server_notices/test_consent.py
@@ -14,8 +14,12 @@
import os
+from twisted.test.proto_helpers import MemoryReactor
+
import synapse.rest.admin
from synapse.rest.client import login, room, sync
+from synapse.server import HomeServer
+from synapse.util import Clock
from tests import unittest
@@ -29,7 +33,7 @@ class ConsentNoticesTests(unittest.HomeserverTestCase):
room.register_servlets,
]
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
tmpdir = self.mktemp()
os.mkdir(tmpdir)
@@ -53,15 +57,13 @@ class ConsentNoticesTests(unittest.HomeserverTestCase):
"room_name": "Server Notices",
}
- hs = self.setup_test_homeserver(config=config)
-
- return hs
+ return self.setup_test_homeserver(config=config)
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.user_id = self.register_user("bob", "abc123")
self.access_token = self.login("bob", "abc123")
- def test_get_sync_message(self):
+ def test_get_sync_message(self) -> None:
"""
When user consent server notices are enabled, a sync will cause a notice
to fire (in a room which the user is invited to). The notice contains
diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index dadc6efcbf..d2bfa53eda 100644
--- a/tests/server_notices/test_resource_limits_server_notices.py
+++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -24,6 +24,8 @@ from synapse.server import HomeServer
from synapse.server_notices.resource_limits_server_notices import (
ResourceLimitsServerNotices,
)
+from synapse.server_notices.server_notices_sender import ServerNoticesSender
+from synapse.types import JsonDict
from synapse.util import Clock
from tests import unittest
@@ -33,7 +35,7 @@ from tests.utils import default_config
class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
- def default_config(self):
+ def default_config(self) -> JsonDict:
config = default_config("test")
config.update(
@@ -57,14 +59,15 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
return config
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.server_notices_sender = self.hs.get_server_notices_sender()
+ server_notices_sender = self.hs.get_server_notices_sender()
+ assert isinstance(server_notices_sender, ServerNoticesSender)
# relying on [1] is far from ideal, but the only case where
# ResourceLimitsServerNotices class needs to be isolated is this test,
# general code should never have a reason to do so ...
- self._rlsn = self.server_notices_sender._server_notices[1]
- if not isinstance(self._rlsn, ResourceLimitsServerNotices):
- raise Exception("Failed to find reference to ResourceLimitsServerNotices")
+ rlsn = list(server_notices_sender._server_notices)[1]
+ assert isinstance(rlsn, ResourceLimitsServerNotices)
+ self._rlsn = rlsn
self._rlsn._store.user_last_seen_monthly_active = Mock(
return_value=make_awaitable(1000)
@@ -86,39 +89,43 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
self._rlsn._store.get_tags_for_room = Mock(return_value=make_awaitable({})) # type: ignore[assignment]
@override_config({"hs_disabled": True})
- def test_maybe_send_server_notice_disabled_hs(self):
+ def test_maybe_send_server_notice_disabled_hs(self) -> None:
"""If the HS is disabled, we should not send notices"""
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
self._send_notice.assert_not_called()
@override_config({"limit_usage_by_mau": False})
- def test_maybe_send_server_notice_to_user_flag_off(self):
+ def test_maybe_send_server_notice_to_user_flag_off(self) -> None:
"""If mau limiting is disabled, we should not send notices"""
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
self._send_notice.assert_not_called()
- def test_maybe_send_server_notice_to_user_remove_blocked_notice(self):
+ def test_maybe_send_server_notice_to_user_remove_blocked_notice(self) -> None:
"""Test when user has blocked notice, but should have it removed"""
- self._rlsn._auth_blocking.check_auth_blocking = Mock(
+ self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment]
return_value=make_awaitable(None)
)
mock_event = Mock(
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
)
- self._rlsn._store.get_events = Mock(
+ self._rlsn._store.get_events = Mock( # type: ignore[assignment]
return_value=make_awaitable({"123": mock_event})
)
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
# Would be better to check the content, but once == remove blocking event
- self._rlsn._server_notices_manager.maybe_get_notice_room_for_user.assert_called_once()
+ maybe_get_notice_room_for_user = (
+ self._rlsn._server_notices_manager.maybe_get_notice_room_for_user
+ )
+ assert isinstance(maybe_get_notice_room_for_user, Mock)
+ maybe_get_notice_room_for_user.assert_called_once()
self._send_notice.assert_called_once()
- def test_maybe_send_server_notice_to_user_remove_blocked_notice_noop(self):
+ def test_maybe_send_server_notice_to_user_remove_blocked_notice_noop(self) -> None:
"""
Test when user has blocked notice, but notice ought to be there (NOOP)
"""
- self._rlsn._auth_blocking.check_auth_blocking = Mock(
+ self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment]
return_value=make_awaitable(None),
side_effect=ResourceLimitError(403, "foo"),
)
@@ -126,7 +133,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
mock_event = Mock(
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
)
- self._rlsn._store.get_events = Mock(
+ self._rlsn._store.get_events = Mock( # type: ignore[assignment]
return_value=make_awaitable({"123": mock_event})
)
@@ -134,11 +141,11 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
self._send_notice.assert_not_called()
- def test_maybe_send_server_notice_to_user_add_blocked_notice(self):
+ def test_maybe_send_server_notice_to_user_add_blocked_notice(self) -> None:
"""
Test when user does not have blocked notice, but should have one
"""
- self._rlsn._auth_blocking.check_auth_blocking = Mock(
+ self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment]
return_value=make_awaitable(None),
side_effect=ResourceLimitError(403, "foo"),
)
@@ -147,11 +154,11 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
# Would be better to check contents, but 2 calls == set blocking event
self.assertEqual(self._send_notice.call_count, 2)
- def test_maybe_send_server_notice_to_user_add_blocked_notice_noop(self):
+ def test_maybe_send_server_notice_to_user_add_blocked_notice_noop(self) -> None:
"""
Test when user does not have blocked notice, nor should they (NOOP)
"""
- self._rlsn._auth_blocking.check_auth_blocking = Mock(
+ self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment]
return_value=make_awaitable(None)
)
@@ -159,12 +166,12 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
self._send_notice.assert_not_called()
- def test_maybe_send_server_notice_to_user_not_in_mau_cohort(self):
+ def test_maybe_send_server_notice_to_user_not_in_mau_cohort(self) -> None:
"""
Test when user is not part of the MAU cohort - this should not ever
happen - but ...
"""
- self._rlsn._auth_blocking.check_auth_blocking = Mock(
+ self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment]
return_value=make_awaitable(None)
)
self._rlsn._store.user_last_seen_monthly_active = Mock(
@@ -175,12 +182,14 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
self._send_notice.assert_not_called()
@override_config({"mau_limit_alerting": False})
- def test_maybe_send_server_notice_when_alerting_suppressed_room_unblocked(self):
+ def test_maybe_send_server_notice_when_alerting_suppressed_room_unblocked(
+ self,
+ ) -> None:
"""
Test that when server is over MAU limit and alerting is suppressed, then
an alert message is not sent into the room
"""
- self._rlsn._auth_blocking.check_auth_blocking = Mock(
+ self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment]
return_value=make_awaitable(None),
side_effect=ResourceLimitError(
403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
@@ -191,11 +200,11 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
self.assertEqual(self._send_notice.call_count, 0)
@override_config({"mau_limit_alerting": False})
- def test_check_hs_disabled_unaffected_by_mau_alert_suppression(self):
+ def test_check_hs_disabled_unaffected_by_mau_alert_suppression(self) -> None:
"""
Test that when a server is disabled, that MAU limit alerting is ignored.
"""
- self._rlsn._auth_blocking.check_auth_blocking = Mock(
+ self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment]
return_value=make_awaitable(None),
side_effect=ResourceLimitError(
403, "foo", limit_type=LimitBlockingTypes.HS_DISABLED
@@ -207,26 +216,28 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
self.assertEqual(self._send_notice.call_count, 2)
@override_config({"mau_limit_alerting": False})
- def test_maybe_send_server_notice_when_alerting_suppressed_room_blocked(self):
+ def test_maybe_send_server_notice_when_alerting_suppressed_room_blocked(
+ self,
+ ) -> None:
"""
When the room is already in a blocked state, test that when alerting
is suppressed that the room is returned to an unblocked state.
"""
- self._rlsn._auth_blocking.check_auth_blocking = Mock(
+ self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment]
return_value=make_awaitable(None),
side_effect=ResourceLimitError(
403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
),
)
- self._rlsn._server_notices_manager.__is_room_currently_blocked = Mock(
+ self._rlsn._is_room_currently_blocked = Mock( # type: ignore[assignment]
return_value=make_awaitable((True, []))
)
mock_event = Mock(
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
)
- self._rlsn._store.get_events = Mock(
+ self._rlsn._store.get_events = Mock( # type: ignore[assignment]
return_value=make_awaitable({"123": mock_event})
)
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@@ -242,7 +253,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
sync.register_servlets,
]
- def default_config(self):
+ def default_config(self) -> JsonDict:
c = super().default_config()
c["server_notices"] = {
"system_mxid_localpart": "server",
@@ -257,20 +268,22 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = self.hs.get_datastores().main
- self.server_notices_sender = self.hs.get_server_notices_sender()
self.server_notices_manager = self.hs.get_server_notices_manager()
self.event_source = self.hs.get_event_sources()
+ server_notices_sender = self.hs.get_server_notices_sender()
+ assert isinstance(server_notices_sender, ServerNoticesSender)
+
# relying on [1] is far from ideal, but the only case where
# ResourceLimitsServerNotices class needs to be isolated is this test,
# general code should never have a reason to do so ...
- self._rlsn = self.server_notices_sender._server_notices[1]
- if not isinstance(self._rlsn, ResourceLimitsServerNotices):
- raise Exception("Failed to find reference to ResourceLimitsServerNotices")
+ rlsn = list(server_notices_sender._server_notices)[1]
+ assert isinstance(rlsn, ResourceLimitsServerNotices)
+ self._rlsn = rlsn
self.user_id = "@user_id:test"
- def test_server_notice_only_sent_once(self):
+ def test_server_notice_only_sent_once(self) -> None:
self.store.get_monthly_active_count = Mock(return_value=make_awaitable(1000))
self.store.user_last_seen_monthly_active = Mock(
@@ -306,7 +319,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
self.assertEqual(count, 1)
- def test_no_invite_without_notice(self):
+ def test_no_invite_without_notice(self) -> None:
"""Tests that a user doesn't get invited to a server notices room without a
server notice being sent.
@@ -328,7 +341,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
m.assert_called_once_with(user_id)
- def test_invite_with_notice(self):
+ def test_invite_with_notice(self) -> None:
"""Tests that, if the MAU limit is hit, the server notices user invites each user
to a room in which it has sent a notice.
"""
diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py
index 9f33afcca0..9606ecc43b 100644
--- a/tests/storage/databases/main/test_events_worker.py
+++ b/tests/storage/databases/main/test_events_worker.py
@@ -120,6 +120,7 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
# Persist the event which should invalidate or prefill the
# `have_seen_event` cache so we don't return stale values.
persistence = self.hs.get_storage_controllers().persistence
+ assert persistence is not None
self.get_success(
persistence.persist_event(
event,
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index c070278db8..a10e5fa8b1 100644
--- a/tests/storage/test_event_chain.py
+++ b/tests/storage/test_event_chain.py
@@ -389,6 +389,7 @@ class EventChainStoreTestCase(HomeserverTestCase):
"""
persist_events_store = self.hs.get_datastores().persist_events
+ assert persist_events_store is not None
for e in events:
e.internal_metadata.stream_ordering = self._next_stream_ordering
@@ -397,6 +398,7 @@ class EventChainStoreTestCase(HomeserverTestCase):
def _persist(txn: LoggingTransaction) -> None:
# We need to persist the events to the events and state_events
# tables.
+ assert persist_events_store is not None
persist_events_store._store_event_txn(
txn,
[(e, EventContext(self.hs.get_storage_controllers())) for e in events],
@@ -540,7 +542,9 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
self.requester, events_and_context=[(event, context)]
)
)
- state1 = set(self.get_success(context.get_current_state_ids()).values())
+ state_ids1 = self.get_success(context.get_current_state_ids())
+ assert state_ids1 is not None
+ state1 = set(state_ids1.values())
event, context = self.get_success(
event_handler.create_event(
@@ -560,7 +564,9 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
self.requester, events_and_context=[(event, context)]
)
)
- state2 = set(self.get_success(context.get_current_state_ids()).values())
+ state_ids2 = self.get_success(context.get_current_state_ids())
+ assert state_ids2 is not None
+ state2 = set(state_ids2.values())
# Delete the chain cover info.
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 7fd3e01364..8fc7936ab0 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -54,6 +54,9 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
+ persist_events = hs.get_datastores().persist_events
+ assert persist_events is not None
+ self.persist_events = persist_events
def test_get_prev_events_for_room(self) -> None:
room_id = "@ROOM:local"
@@ -226,7 +229,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
},
)
- self.hs.datastores.persist_events._persist_event_auth_chain_txn(
+ self.persist_events._persist_event_auth_chain_txn(
txn,
[
cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id]))
@@ -445,7 +448,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
)
# Insert all events apart from 'B'
- self.hs.datastores.persist_events._persist_event_auth_chain_txn(
+ self.persist_events._persist_event_auth_chain_txn(
txn,
[
cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id]))
@@ -464,7 +467,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
updatevalues={"has_auth_chain_index": False},
)
- self.hs.datastores.persist_events._persist_event_auth_chain_txn(
+ self.persist_events._persist_event_auth_chain_txn(
txn,
[cast(EventBase, FakeEvent("b", room_id, auth_graph["b"]))],
)
diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py
index 05661a537d..e67dd0589d 100644
--- a/tests/storage/test_events.py
+++ b/tests/storage/test_events.py
@@ -40,7 +40,9 @@ class ExtremPruneTestCase(HomeserverTestCase):
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
) -> None:
self.state = self.hs.get_state_handler()
- self._persistence = self.hs.get_storage_controllers().persistence
+ persistence = self.hs.get_storage_controllers().persistence
+ assert persistence is not None
+ self._persistence = persistence
self._state_storage_controller = self.hs.get_storage_controllers().state
self.store = self.hs.get_datastores().main
@@ -374,7 +376,9 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase):
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
) -> None:
self.state = self.hs.get_state_handler()
- self._persistence = self.hs.get_storage_controllers().persistence
+ persistence = self.hs.get_storage_controllers().persistence
+ assert persistence is not None
+ self._persistence = persistence
self.store = self.hs.get_datastores().main
def test_remote_user_rooms_cache_invalidated(self) -> None:
diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py
index aa4b5bd3b1..ba68171ad7 100644
--- a/tests/storage/test_keys.py
+++ b/tests/storage/test_keys.py
@@ -16,8 +16,6 @@ import signedjson.key
import signedjson.types
import unpaddedbase64
-from twisted.internet.defer import Deferred
-
from synapse.storage.keys import FetchKeyResult
import tests.unittest
@@ -44,20 +42,26 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
key_id_1 = "ed25519:key1"
key_id_2 = "ed25519:KEY_ID_2"
- d = store.store_server_verify_keys(
- "from_server",
- 10,
- [
- ("server1", key_id_1, FetchKeyResult(KEY_1, 100)),
- ("server1", key_id_2, FetchKeyResult(KEY_2, 200)),
- ],
+ self.get_success(
+ store.store_server_verify_keys(
+ "from_server",
+ 10,
+ [
+ ("server1", key_id_1, FetchKeyResult(KEY_1, 100)),
+ ("server1", key_id_2, FetchKeyResult(KEY_2, 200)),
+ ],
+ )
)
- self.get_success(d)
- d = store.get_server_verify_keys(
- [("server1", key_id_1), ("server1", key_id_2), ("server1", "ed25519:key3")]
+ res = self.get_success(
+ store.get_server_verify_keys(
+ [
+ ("server1", key_id_1),
+ ("server1", key_id_2),
+ ("server1", "ed25519:key3"),
+ ]
+ )
)
- res = self.get_success(d)
self.assertEqual(len(res.keys()), 3)
res1 = res[("server1", key_id_1)]
@@ -82,18 +86,20 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
key_id_1 = "ed25519:key1"
key_id_2 = "ed25519:key2"
- d = store.store_server_verify_keys(
- "from_server",
- 0,
- [
- ("srv1", key_id_1, FetchKeyResult(KEY_1, 100)),
- ("srv1", key_id_2, FetchKeyResult(KEY_2, 200)),
- ],
+ self.get_success(
+ store.store_server_verify_keys(
+ "from_server",
+ 0,
+ [
+ ("srv1", key_id_1, FetchKeyResult(KEY_1, 100)),
+ ("srv1", key_id_2, FetchKeyResult(KEY_2, 200)),
+ ],
+ )
)
- self.get_success(d)
- d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
- res = self.get_success(d)
+ res = self.get_success(
+ store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
+ )
self.assertEqual(len(res.keys()), 2)
res1 = res[("srv1", key_id_1)]
@@ -105,9 +111,7 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
self.assertEqual(res2.valid_until_ts, 200)
# we should be able to look up the same thing again without a db hit
- res = store.get_server_verify_keys([("srv1", key_id_1)])
- if isinstance(res, Deferred):
- res = self.successResultOf(res)
+ res = self.get_success(store.get_server_verify_keys([("srv1", key_id_1)]))
self.assertEqual(len(res.keys()), 1)
self.assertEqual(res[("srv1", key_id_1)].verify_key, KEY_1)
@@ -119,8 +123,9 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
)
self.get_success(d)
- d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
- res = self.get_success(d)
+ res = self.get_success(
+ store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
+ )
self.assertEqual(len(res.keys()), 2)
res1 = res[("srv1", key_id_1)]
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index 010cc74c31..d8f42c5d05 100644
--- a/tests/storage/test_purge.py
+++ b/tests/storage/test_purge.py
@@ -112,7 +112,7 @@ class PurgeTests(HomeserverTestCase):
self.room_id, "m.room.create", ""
)
)
- self.assertIsNotNone(create_event)
+ assert create_event is not None
# Purge everything before this topological token
self.get_success(
diff --git a/tests/storage/test_receipts.py b/tests/storage/test_receipts.py
index d8d84152dc..12c17f1073 100644
--- a/tests/storage/test_receipts.py
+++ b/tests/storage/test_receipts.py
@@ -37,9 +37,9 @@ class ReceiptTestCase(HomeserverTestCase):
self.store = homeserver.get_datastores().main
self.room_creator = homeserver.get_room_creation_handler()
- self.persist_event_storage_controller = (
- self.hs.get_storage_controllers().persistence
- )
+ persist_event_storage_controller = self.hs.get_storage_controllers().persistence
+ assert persist_event_storage_controller is not None
+ self.persist_event_storage_controller = persist_event_storage_controller
# Create a test user
self.ourUser = UserID.from_string(OUR_USER_ID)
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index df4740f9d9..0100f7da14 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -74,10 +74,12 @@ class RedactionTestCase(unittest.HomeserverTestCase):
},
)
- event, context = self.get_success(
+ event, unpersisted_context = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
+ context = self.get_success(unpersisted_context.persist(event))
+
self.get_success(self._persistence.persist_event(event, context))
return event
@@ -96,10 +98,12 @@ class RedactionTestCase(unittest.HomeserverTestCase):
},
)
- event, context = self.get_success(
+ event, unpersisted_context = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
+ context = self.get_success(unpersisted_context.persist(event))
+
self.get_success(self._persistence.persist_event(event, context))
return event
@@ -119,10 +123,12 @@ class RedactionTestCase(unittest.HomeserverTestCase):
},
)
- event, context = self.get_success(
+ event, unpersisted_context = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
+ context = self.get_success(unpersisted_context.persist(event))
+
self.get_success(self._persistence.persist_event(event, context))
return event
@@ -259,7 +265,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
def internal_metadata(self) -> _EventInternalMetadata:
return self._base_builder.internal_metadata
- event_1, context_1 = self.get_success(
+ event_1, unpersisted_context_1 = self.get_success(
self.event_creation_handler.create_new_client_event(
cast(
EventBuilder,
@@ -280,9 +286,11 @@ class RedactionTestCase(unittest.HomeserverTestCase):
)
)
+ context_1 = self.get_success(unpersisted_context_1.persist(event_1))
+
self.get_success(self._persistence.persist_event(event_1, context_1))
- event_2, context_2 = self.get_success(
+ event_2, unpersisted_context_2 = self.get_success(
self.event_creation_handler.create_new_client_event(
cast(
EventBuilder,
@@ -302,6 +310,8 @@ class RedactionTestCase(unittest.HomeserverTestCase):
)
)
)
+
+ context_2 = self.get_success(unpersisted_context_2.persist(event_2))
self.get_success(self._persistence.persist_event(event_2, context_2))
# fetch one of the redactions
@@ -421,10 +431,12 @@ class RedactionTestCase(unittest.HomeserverTestCase):
},
)
- redaction_event, context = self.get_success(
+ redaction_event, unpersisted_context = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
+ context = self.get_success(unpersisted_context.persist(redaction_event))
+
self.get_success(self._persistence.persist_event(redaction_event, context))
# Now lets jump to the future where we have censored the redaction event
diff --git a/tests/storage/test_room_search.py b/tests/storage/test_room_search.py
index 14d872514d..f183c38477 100644
--- a/tests/storage/test_room_search.py
+++ b/tests/storage/test_room_search.py
@@ -119,7 +119,6 @@ class EventSearchInsertionTest(HomeserverTestCase):
"content": {"msgtype": "m.text", "body": 2},
"room_id": room_id,
"sender": user_id,
- "depth": prev_event.depth + 1,
"prev_events": prev_event_ids,
"origin_server_ts": self.clock.time_msec(),
}
@@ -134,7 +133,7 @@ class EventSearchInsertionTest(HomeserverTestCase):
prev_state_map,
for_verification=False,
),
- depth=event_dict["depth"],
+ depth=prev_event.depth + 1,
)
)
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index bad7f0bc60..f730b888f7 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -67,10 +67,12 @@ class StateStoreTestCase(HomeserverTestCase):
},
)
- event, context = self.get_success(
+ event, unpersisted_context = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
+ context = self.get_success(unpersisted_context.persist(event))
+
assert self.storage.persistence is not None
self.get_success(self.storage.persistence.persist_event(event, context))
diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py
index bc090ebce0..05dc4f64b8 100644
--- a/tests/storage/test_stream.py
+++ b/tests/storage/test_stream.py
@@ -16,7 +16,7 @@ from typing import List
from twisted.test.proto_helpers import MemoryReactor
-from synapse.api.constants import EventTypes, RelationTypes
+from synapse.api.constants import Direction, EventTypes, RelationTypes
from synapse.api.filtering import Filter
from synapse.rest import admin
from synapse.rest.client import login, room
@@ -128,7 +128,7 @@ class PaginationTestCase(HomeserverTestCase):
room_id=self.room_id,
from_key=self.from_token.room_key,
to_key=None,
- direction="f",
+ direction=Direction.FORWARDS,
limit=10,
event_filter=Filter(self.hs, filter),
)
diff --git a/tests/storage/test_unsafe_locale.py b/tests/storage/test_unsafe_locale.py
index ba53c22818..19da8a9b09 100644
--- a/tests/storage/test_unsafe_locale.py
+++ b/tests/storage/test_unsafe_locale.py
@@ -14,6 +14,7 @@
from unittest.mock import MagicMock, patch
from synapse.storage.database import make_conn
+from synapse.storage.engines import PostgresEngine
from synapse.storage.engines._base import IncorrectDatabaseSetup
from tests.unittest import HomeserverTestCase
@@ -38,6 +39,7 @@ class UnsafeLocaleTest(HomeserverTestCase):
def test_safe_locale(self) -> None:
database = self.hs.get_datastores().databases[0]
+ assert isinstance(database.engine, PostgresEngine)
db_conn = make_conn(database._database_config, database.engine, "test_unsafe")
with db_conn.cursor() as txn:
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index f1ca523d23..2d169684cf 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -25,6 +25,11 @@ from synapse.rest.client import login, register, room
from synapse.server import HomeServer
from synapse.storage import DataStore
from synapse.storage.background_updates import _BackgroundUpdateHandler
+from synapse.storage.databases.main import user_directory
+from synapse.storage.databases.main.user_directory import (
+ _parse_words_with_icu,
+ _parse_words_with_regex,
+)
from synapse.storage.roommember import ProfileInfo
from synapse.util import Clock
@@ -42,7 +47,7 @@ ALICE = "@alice:a"
BOB = "@bob:b"
BOBBY = "@bobby:a"
# The localpart isn't 'Bela' on purpose so we can test looking up display names.
-BELA = "@somenickname:a"
+BELA = "@somenickname:example.org"
class GetUserDirectoryTables:
@@ -423,6 +428,8 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase):
class UserDirectoryStoreTestCase(HomeserverTestCase):
+ use_icu = False
+
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
@@ -434,6 +441,12 @@ class UserDirectoryStoreTestCase(HomeserverTestCase):
self.get_success(self.store.update_profile_in_user_dir(BELA, "Bela", None))
self.get_success(self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB)))
+ self._restore_use_icu = user_directory.USE_ICU
+ user_directory.USE_ICU = self.use_icu
+
+ def tearDown(self) -> None:
+ user_directory.USE_ICU = self._restore_use_icu
+
def test_search_user_dir(self) -> None:
# normally when alice searches the directory she should just find
# bob because bobby doesn't share a room with her.
@@ -478,6 +491,26 @@ class UserDirectoryStoreTestCase(HomeserverTestCase):
{"user_id": BELA, "display_name": "Bela", "avatar_url": None},
)
+ @override_config({"user_directory": {"search_all_users": True}})
+ def test_search_user_dir_start_of_user_id(self) -> None:
+ """Tests that a user can look up another user by searching for the start
+ of their user ID.
+ """
+ r = self.get_success(self.store.search_user_dir(ALICE, "somenickname:exa", 10))
+ self.assertFalse(r["limited"])
+ self.assertEqual(1, len(r["results"]))
+ self.assertDictEqual(
+ r["results"][0],
+ {"user_id": BELA, "display_name": "Bela", "avatar_url": None},
+ )
+
+
+class UserDirectoryStoreTestCaseWithIcu(UserDirectoryStoreTestCase):
+ use_icu = True
+
+ if not icu:
+ skip = "Requires PyICU"
+
class UserDirectoryICUTestCase(HomeserverTestCase):
if not icu:
@@ -513,3 +546,31 @@ class UserDirectoryICUTestCase(HomeserverTestCase):
r["results"][0],
{"user_id": ALICE, "display_name": display_name, "avatar_url": None},
)
+
+ def test_icu_word_boundary_punctuation(self) -> None:
+ """
+ Tests the behaviour of punctuation with the ICU tokeniser.
+
+ Seems to depend on underlying version of ICU.
+ """
+
+ # Note: either tokenisation is fine, because Postgres actually splits
+ # words itself afterwards.
+ self.assertIn(
+ _parse_words_with_icu("lazy'fox jumped:over the.dog"),
+ (
+ # ICU 66 on Ubuntu 20.04
+ ["lazy'fox", "jumped", "over", "the", "dog"],
+ # ICU 70 on Ubuntu 22.04
+ ["lazy'fox", "jumped:over", "the.dog"],
+ ),
+ )
+
+ def test_regex_word_boundary_punctuation(self) -> None:
+ """
+ Tests the behaviour of punctuation with the non-ICU tokeniser
+ """
+ self.assertEqual(
+ _parse_words_with_regex("lazy'fox jumped:over the.dog"),
+ ["lazy", "fox", "jumped", "over", "the", "dog"],
+ )
diff --git a/tests/test_distributor.py b/tests/test_distributor.py
index 31546ea52b..a248f1d277 100644
--- a/tests/test_distributor.py
+++ b/tests/test_distributor.py
@@ -21,10 +21,10 @@ from . import unittest
class DistributorTestCase(unittest.TestCase):
- def setUp(self):
+ def setUp(self) -> None:
self.dist = Distributor()
- def test_signal_dispatch(self):
+ def test_signal_dispatch(self) -> None:
self.dist.declare("alert")
observer = Mock()
@@ -33,7 +33,7 @@ class DistributorTestCase(unittest.TestCase):
self.dist.fire("alert", 1, 2, 3)
observer.assert_called_with(1, 2, 3)
- def test_signal_catch(self):
+ def test_signal_catch(self) -> None:
self.dist.declare("alarm")
observers = [Mock() for i in (1, 2)]
@@ -51,7 +51,7 @@ class DistributorTestCase(unittest.TestCase):
self.assertEqual(mock_logger.warning.call_count, 1)
self.assertIsInstance(mock_logger.warning.call_args[0][0], str)
- def test_signal_prereg(self):
+ def test_signal_prereg(self) -> None:
observer = Mock()
self.dist.observe("flare", observer)
@@ -60,8 +60,8 @@ class DistributorTestCase(unittest.TestCase):
observer.assert_called_with(4, 5)
- def test_signal_undeclared(self):
- def code():
+ def test_signal_undeclared(self) -> None:
+ def code() -> None:
self.dist.fire("notification")
self.assertRaises(KeyError, code)
diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py
index 0a7937f1cc..2860564afc 100644
--- a/tests/test_event_auth.py
+++ b/tests/test_event_auth.py
@@ -31,13 +31,13 @@ from tests.test_utils import get_awaitable_result
class _StubEventSourceStore:
"""A stub implementation of the EventSourceStore"""
- def __init__(self):
+ def __init__(self) -> None:
self._store: Dict[str, EventBase] = {}
- def add_event(self, event: EventBase):
+ def add_event(self, event: EventBase) -> None:
self._store[event.event_id] = event
- def add_events(self, events: Iterable[EventBase]):
+ def add_events(self, events: Iterable[EventBase]) -> None:
for event in events:
self._store[event.event_id] = event
@@ -59,7 +59,7 @@ class _StubEventSourceStore:
class EventAuthTestCase(unittest.TestCase):
- def test_rejected_auth_events(self):
+ def test_rejected_auth_events(self) -> None:
"""
Events that refer to rejected events in their auth events are rejected
"""
@@ -109,7 +109,7 @@ class EventAuthTestCase(unittest.TestCase):
)
)
- def test_create_event_with_prev_events(self):
+ def test_create_event_with_prev_events(self) -> None:
"""A create event with prev_events should be rejected
https://spec.matrix.org/v1.3/rooms/v9/#authorization-rules
@@ -150,7 +150,7 @@ class EventAuthTestCase(unittest.TestCase):
event_auth.check_state_independent_auth_rules(event_store, bad_event)
)
- def test_duplicate_auth_events(self):
+ def test_duplicate_auth_events(self) -> None:
"""Events with duplicate auth_events should be rejected
https://spec.matrix.org/v1.3/rooms/v9/#authorization-rules
@@ -196,7 +196,7 @@ class EventAuthTestCase(unittest.TestCase):
event_auth.check_state_independent_auth_rules(event_store, bad_event2)
)
- def test_unexpected_auth_events(self):
+ def test_unexpected_auth_events(self) -> None:
"""Events with excess auth_events should be rejected
https://spec.matrix.org/v1.3/rooms/v9/#authorization-rules
@@ -236,7 +236,7 @@ class EventAuthTestCase(unittest.TestCase):
event_auth.check_state_independent_auth_rules(event_store, bad_event)
)
- def test_random_users_cannot_send_state_before_first_pl(self):
+ def test_random_users_cannot_send_state_before_first_pl(self) -> None:
"""
Check that, before the first PL lands, the creator is the only user
that can send a state event.
@@ -263,7 +263,7 @@ class EventAuthTestCase(unittest.TestCase):
auth_events,
)
- def test_state_default_level(self):
+ def test_state_default_level(self) -> None:
"""
Check that users above the state_default level can send state and
those below cannot
@@ -298,7 +298,7 @@ class EventAuthTestCase(unittest.TestCase):
auth_events,
)
- def test_alias_event(self):
+ def test_alias_event(self) -> None:
"""Alias events have special behavior up through room version 6."""
creator = "@creator:example.com"
other = "@other:example.com"
@@ -333,7 +333,7 @@ class EventAuthTestCase(unittest.TestCase):
auth_events,
)
- def test_msc2432_alias_event(self):
+ def test_msc2432_alias_event(self) -> None:
"""After MSC2432, alias events have no special behavior."""
creator = "@creator:example.com"
other = "@other:example.com"
@@ -366,7 +366,9 @@ class EventAuthTestCase(unittest.TestCase):
)
@parameterized.expand([(RoomVersions.V1, True), (RoomVersions.V6, False)])
- def test_notifications(self, room_version: RoomVersion, allow_modification: bool):
+ def test_notifications(
+ self, room_version: RoomVersion, allow_modification: bool
+ ) -> None:
"""
Notifications power levels get checked due to MSC2209.
"""
@@ -395,7 +397,7 @@ class EventAuthTestCase(unittest.TestCase):
with self.assertRaises(AuthError):
event_auth.check_state_dependent_auth_rules(pl_event, auth_events)
- def test_join_rules_public(self):
+ def test_join_rules_public(self) -> None:
"""
Test joining a public room.
"""
@@ -460,7 +462,7 @@ class EventAuthTestCase(unittest.TestCase):
auth_events.values(),
)
- def test_join_rules_invite(self):
+ def test_join_rules_invite(self) -> None:
"""
Test joining an invite only room.
"""
@@ -835,7 +837,7 @@ def _power_levels_event(
)
-def _alias_event(room_version: RoomVersion, sender: str, **kwargs) -> EventBase:
+def _alias_event(room_version: RoomVersion, sender: str, **kwargs: Any) -> EventBase:
data = {
"room_id": TEST_ROOM_ID,
**_maybe_get_event_id_dict_for_room_version(room_version),
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 80e5c590d8..82dfd88b99 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -12,53 +12,48 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Collection, List, Optional, Union
from unittest.mock import Mock
-from twisted.internet.defer import succeed
+from twisted.test.proto_helpers import MemoryReactor
from synapse.api.errors import FederationError
-from synapse.api.room_versions import RoomVersions
-from synapse.events import make_event_from_dict
+from synapse.api.room_versions import RoomVersion, RoomVersions
+from synapse.events import EventBase, make_event_from_dict
+from synapse.events.snapshot import EventContext
from synapse.federation.federation_base import event_from_pdu_json
+from synapse.handlers.device import DeviceListUpdater
+from synapse.http.types import QueryParams
from synapse.logging.context import LoggingContext
-from synapse.types import UserID, create_requester
+from synapse.server import HomeServer
+from synapse.types import JsonDict, UserID, create_requester
from synapse.util import Clock
from synapse.util.retryutils import NotRetryingDestination
from tests import unittest
-from tests.server import ThreadedMemoryReactorClock, setup_test_homeserver
from tests.test_utils import make_awaitable
class MessageAcceptTests(unittest.HomeserverTestCase):
- def setUp(self):
-
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.http_client = Mock()
- self.reactor = ThreadedMemoryReactorClock()
- self.hs_clock = Clock(self.reactor)
- self.homeserver = setup_test_homeserver(
- self.addCleanup,
- federation_http_client=self.http_client,
- clock=self.hs_clock,
- reactor=self.reactor,
- )
+ return self.setup_test_homeserver(federation_http_client=self.http_client)
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
user_id = UserID("us", "test")
our_user = create_requester(user_id)
- room_creator = self.homeserver.get_room_creation_handler()
+ room_creator = self.hs.get_room_creation_handler()
self.room_id = self.get_success(
room_creator.create_room(
our_user, room_creator._presets_dict["public_chat"], ratelimit=False
)
)[0]["room_id"]
- self.store = self.homeserver.get_datastores().main
+ self.store = self.hs.get_datastores().main
# Figure out what the most recent event is
most_recent = self.get_success(
- self.homeserver.get_datastores().main.get_latest_event_ids_in_room(
- self.room_id
- )
+ self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id)
)[0]
join_event = make_event_from_dict(
@@ -78,17 +73,23 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
}
)
- self.handler = self.homeserver.get_federation_handler()
- federation_event_handler = self.homeserver.get_federation_event_handler()
+ self.handler = self.hs.get_federation_handler()
+ federation_event_handler = self.hs.get_federation_event_handler()
- async def _check_event_auth(origin, event, context):
+ async def _check_event_auth(
+ origin: Optional[str], event: EventBase, context: EventContext
+ ) -> None:
pass
- federation_event_handler._check_event_auth = _check_event_auth
- self.client = self.homeserver.get_federation_client()
- self.client._check_sigs_and_hash_for_pulled_events_and_fetch = (
- lambda dest, pdus, **k: succeed(pdus)
- )
+ federation_event_handler._check_event_auth = _check_event_auth # type: ignore[assignment]
+ self.client = self.hs.get_federation_client()
+
+ async def _check_sigs_and_hash_for_pulled_events_and_fetch(
+ dest: str, pdus: Collection[EventBase], room_version: RoomVersion
+ ) -> List[EventBase]:
+ return list(pdus)
+
+ self.client._check_sigs_and_hash_for_pulled_events_and_fetch = _check_sigs_and_hash_for_pulled_events_and_fetch # type: ignore[assignment]
# Send the join, it should return None (which is not an error)
self.assertEqual(
@@ -104,16 +105,25 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
"$join:test.serv",
)
- def test_cant_hide_direct_ancestors(self):
+ def test_cant_hide_direct_ancestors(self) -> None:
"""
If you send a message, you must be able to provide the direct
prev_events that said event references.
"""
- async def post_json(destination, path, data, headers=None, timeout=0):
+ async def post_json(
+ destination: str,
+ path: str,
+ data: Optional[JsonDict] = None,
+ long_retries: bool = False,
+ timeout: Optional[int] = None,
+ ignore_backoff: bool = False,
+ args: Optional[QueryParams] = None,
+ ) -> Union[JsonDict, list]:
# If it asks us for new missing events, give them NOTHING
if path.startswith("/_matrix/federation/v1/get_missing_events/"):
return {"events": []}
+ return {}
self.http_client.post_json = post_json
@@ -138,7 +148,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
}
)
- federation_event_handler = self.homeserver.get_federation_event_handler()
+ federation_event_handler = self.hs.get_federation_event_handler()
with LoggingContext("test-context"):
failure = self.get_failure(
federation_event_handler.on_receive_pdu("test.serv", lying_event),
@@ -158,7 +168,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
extrem = self.get_success(self.store.get_latest_event_ids_in_room(self.room_id))
self.assertEqual(extrem[0], "$join:test.serv")
- def test_retry_device_list_resync(self):
+ def test_retry_device_list_resync(self) -> None:
"""Tests that device lists are marked as stale if they couldn't be synced, and
that stale device lists are retried periodically.
"""
@@ -171,24 +181,27 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
# When this function is called, increment the number of resync attempts (only if
# we're querying devices for the right user ID), then raise a
# NotRetryingDestination error to fail the resync gracefully.
- def query_user_devices(destination, user_id):
+ def query_user_devices(
+ destination: str, user_id: str, timeout: int = 30000
+ ) -> JsonDict:
if user_id == remote_user_id:
self.resync_attempts += 1
raise NotRetryingDestination(0, 0, destination)
# Register the mock on the federation client.
- federation_client = self.homeserver.get_federation_client()
- federation_client.query_user_devices = Mock(side_effect=query_user_devices)
+ federation_client = self.hs.get_federation_client()
+ federation_client.query_user_devices = Mock(side_effect=query_user_devices) # type: ignore[assignment]
# Register a mock on the store so that the incoming update doesn't fail because
# we don't share a room with the user.
- store = self.homeserver.get_datastores().main
+ store = self.hs.get_datastores().main
store.get_rooms_for_user = Mock(return_value=make_awaitable(["!someroom:test"]))
# Manually inject a fake device list update. We need this update to include at
# least one prev_id so that the user's device list will need to be retried.
- device_list_updater = self.homeserver.get_device_handler().device_list_updater
+ device_list_updater = self.hs.get_device_handler().device_list_updater
+ assert isinstance(device_list_updater, DeviceListUpdater)
self.get_success(
device_list_updater.incoming_device_list_update(
origin=remote_origin,
@@ -218,7 +231,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
self.reactor.advance(30)
self.assertEqual(self.resync_attempts, 2)
- def test_cross_signing_keys_retry(self):
+ def test_cross_signing_keys_retry(self) -> None:
"""Tests that resyncing a device list correctly processes cross-signing keys from
the remote server.
"""
@@ -227,8 +240,8 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
# Register mock device list retrieval on the federation client.
- federation_client = self.homeserver.get_federation_client()
- federation_client.query_user_devices = Mock(
+ federation_client = self.hs.get_federation_client()
+ federation_client.query_user_devices = Mock( # type: ignore[assignment]
return_value=make_awaitable(
{
"user_id": remote_user_id,
@@ -252,7 +265,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
)
# Resync the device list.
- device_handler = self.homeserver.get_device_handler()
+ device_handler = self.hs.get_device_handler()
self.get_success(
device_handler.device_list_updater.user_device_resync(remote_user_id),
)
@@ -261,16 +274,18 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
keys = self.get_success(
self.store.get_e2e_cross_signing_keys_bulk(user_ids=[remote_user_id]),
)
- self.assertTrue(remote_user_id in keys)
+ self.assertIn(remote_user_id, keys)
+ key = keys[remote_user_id]
+ assert key is not None
# Check that the master key is the one returned by the mock.
- master_key = keys[remote_user_id]["master"]
+ master_key = key["master"]
self.assertEqual(len(master_key["keys"]), 1)
self.assertTrue("ed25519:" + remote_master_key in master_key["keys"].keys())
self.assertTrue(remote_master_key in master_key["keys"].values())
# Check that the self-signing key is the one returned by the mock.
- self_signing_key = keys[remote_user_id]["self_signing"]
+ self_signing_key = key["self_signing"]
self.assertEqual(len(self_signing_key["keys"]), 1)
self.assertTrue(
"ed25519:" + remote_self_signing_key in self_signing_key["keys"].keys(),
@@ -279,7 +294,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
class StripUnsignedFromEventsTestCase(unittest.TestCase):
- def test_strip_unauthorized_unsigned_values(self):
+ def test_strip_unauthorized_unsigned_values(self) -> None:
event1 = {
"sender": "@baduser:test.serv",
"state_key": "@baduser:test.serv",
@@ -296,7 +311,7 @@ class StripUnsignedFromEventsTestCase(unittest.TestCase):
# Make sure unauthorized fields are stripped from unsigned
self.assertNotIn("more warez", filtered_event.unsigned)
- def test_strip_event_maintains_allowed_fields(self):
+ def test_strip_event_maintains_allowed_fields(self) -> None:
event2 = {
"sender": "@baduser:test.serv",
"state_key": "@baduser:test.serv",
@@ -323,7 +338,7 @@ class StripUnsignedFromEventsTestCase(unittest.TestCase):
self.assertIn("invite_room_state", filtered_event2.unsigned)
self.assertEqual([], filtered_event2.unsigned["invite_room_state"])
- def test_strip_event_removes_fields_based_on_event_type(self):
+ def test_strip_event_removes_fields_based_on_event_type(self) -> None:
event3 = {
"sender": "@baduser:test.serv",
"state_key": "@baduser:test.serv",
diff --git a/tests/test_mau.py b/tests/test_mau.py
index f14fcb7db9..4e7665a22b 100644
--- a/tests/test_mau.py
+++ b/tests/test_mau.py
@@ -14,12 +14,17 @@
"""Tests REST events for /rooms paths."""
-from typing import List
+from typing import List, Optional
+
+from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType
from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.appservice import ApplicationService
from synapse.rest.client import register, sync
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util import Clock
from tests import unittest
from tests.unittest import override_config
@@ -30,7 +35,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
servlets = [register.register_servlets, sync.register_servlets]
- def default_config(self):
+ def default_config(self) -> JsonDict:
config = default_config("test")
config.update(
@@ -53,10 +58,12 @@ class TestMauLimit(unittest.HomeserverTestCase):
return config
- def prepare(self, reactor, clock, homeserver):
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+ ) -> None:
self.store = homeserver.get_datastores().main
- def test_simple_deny_mau(self):
+ def test_simple_deny_mau(self) -> None:
# Create and sync so that the MAU counts get updated
token1 = self.create_user("kermit1")
self.do_sync_for_user(token1)
@@ -75,7 +82,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
self.assertEqual(e.code, 403)
self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
- def test_as_ignores_mau(self):
+ def test_as_ignores_mau(self) -> None:
"""Test that application services can still create users when the MAU
limit has been reached. This only works when application service
user ip tracking is disabled.
@@ -113,7 +120,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
self.create_user("as_kermit4", token=as_token, appservice=True)
- def test_allowed_after_a_month_mau(self):
+ def test_allowed_after_a_month_mau(self) -> None:
# Create and sync so that the MAU counts get updated
token1 = self.create_user("kermit1")
self.do_sync_for_user(token1)
@@ -132,7 +139,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
self.do_sync_for_user(token3)
@override_config({"mau_trial_days": 1})
- def test_trial_delay(self):
+ def test_trial_delay(self) -> None:
# We should be able to register more than the limit initially
token1 = self.create_user("kermit1")
self.do_sync_for_user(token1)
@@ -165,7 +172,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
@override_config({"mau_trial_days": 1})
- def test_trial_users_cant_come_back(self):
+ def test_trial_users_cant_come_back(self) -> None:
self.hs.config.server.mau_trial_days = 1
# We should be able to register more than the limit initially
@@ -216,7 +223,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
# max_mau_value should not matter
{"max_mau_value": 1, "limit_usage_by_mau": False, "mau_stats_only": True}
)
- def test_tracked_but_not_limited(self):
+ def test_tracked_but_not_limited(self) -> None:
# Simply being able to create 2 users indicates that the
# limit was not reached.
token1 = self.create_user("kermit1")
@@ -236,10 +243,10 @@ class TestMauLimit(unittest.HomeserverTestCase):
"mau_appservice_trial_days": {"SomeASID": 1, "AnotherASID": 2},
}
)
- def test_as_trial_days(self):
+ def test_as_trial_days(self) -> None:
user_tokens: List[str] = []
- def advance_time_and_sync():
+ def advance_time_and_sync() -> None:
self.reactor.advance(24 * 60 * 61)
for token in user_tokens:
self.do_sync_for_user(token)
@@ -300,7 +307,9 @@ class TestMauLimit(unittest.HomeserverTestCase):
},
)
- def create_user(self, localpart, token=None, appservice=False):
+ def create_user(
+ self, localpart: str, token: Optional[str] = None, appservice: bool = False
+ ) -> str:
request_data = {
"username": localpart,
"password": "monkey",
@@ -326,7 +335,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
return access_token
- def do_sync_for_user(self, token):
+ def do_sync_for_user(self, token: str) -> None:
channel = self.make_request("GET", "/sync", access_token=token)
if channel.code != 200:
diff --git a/tests/test_phone_home.py b/tests/test_phone_home.py
index cc1a98f1c4..3f899b0d91 100644
--- a/tests/test_phone_home.py
+++ b/tests/test_phone_home.py
@@ -33,7 +33,7 @@ class PhoneHomeStatsTestCase(HomeserverTestCase):
If time doesn't move, don't error out.
"""
past_stats = [
- (self.hs.get_clock().time(), resource.getrusage(resource.RUSAGE_SELF))
+ (int(self.hs.get_clock().time()), resource.getrusage(resource.RUSAGE_SELF))
]
stats: JsonDict = {}
self.get_success(phone_stats_home(self.hs, stats, past_stats))
diff --git a/tests/test_rust.py b/tests/test_rust.py
index 55d8b6b28c..67443b6280 100644
--- a/tests/test_rust.py
+++ b/tests/test_rust.py
@@ -6,6 +6,6 @@ from tests import unittest
class RustTestCase(unittest.TestCase):
"""Basic tests to ensure that we can call into Rust code."""
- def test_basic(self):
+ def test_basic(self) -> None:
result = sum_as_string(1, 2)
self.assertEqual("3", result)
diff --git a/tests/test_test_utils.py b/tests/test_test_utils.py
index d04bcae0fa..5cd698147e 100644
--- a/tests/test_test_utils.py
+++ b/tests/test_test_utils.py
@@ -17,25 +17,25 @@ from tests.utils import MockClock
class MockClockTestCase(unittest.TestCase):
- def setUp(self):
+ def setUp(self) -> None:
self.clock = MockClock()
- def test_advance_time(self):
+ def test_advance_time(self) -> None:
start_time = self.clock.time()
self.clock.advance_time(20)
self.assertEqual(20, self.clock.time() - start_time)
- def test_later(self):
+ def test_later(self) -> None:
invoked = [0, 0]
- def _cb0():
+ def _cb0() -> None:
invoked[0] = 1
self.clock.call_later(10, _cb0)
- def _cb1():
+ def _cb1() -> None:
invoked[1] = 1
self.clock.call_later(20, _cb1)
@@ -51,15 +51,15 @@ class MockClockTestCase(unittest.TestCase):
self.assertTrue(invoked[1])
- def test_cancel_later(self):
+ def test_cancel_later(self) -> None:
invoked = [0, 0]
- def _cb0():
+ def _cb0() -> None:
invoked[0] = 1
t0 = self.clock.call_later(10, _cb0)
- def _cb1():
+ def _cb1() -> None:
invoked[1] = 1
self.clock.call_later(20, _cb1)
diff --git a/tests/test_types.py b/tests/test_types.py
index 1111169384..c491cc9a96 100644
--- a/tests/test_types.py
+++ b/tests/test_types.py
@@ -43,34 +43,34 @@ class IsMineIDTests(unittest.HomeserverTestCase):
class UserIDTestCase(unittest.HomeserverTestCase):
- def test_parse(self):
+ def test_parse(self) -> None:
user = UserID.from_string("@1234abcd:test")
self.assertEqual("1234abcd", user.localpart)
self.assertEqual("test", user.domain)
self.assertEqual(True, self.hs.is_mine(user))
- def test_parse_rejects_empty_id(self):
+ def test_parse_rejects_empty_id(self) -> None:
with self.assertRaises(SynapseError):
UserID.from_string("")
- def test_parse_rejects_missing_sigil(self):
+ def test_parse_rejects_missing_sigil(self) -> None:
with self.assertRaises(SynapseError):
UserID.from_string("alice:example.com")
- def test_parse_rejects_missing_separator(self):
+ def test_parse_rejects_missing_separator(self) -> None:
with self.assertRaises(SynapseError):
UserID.from_string("@alice.example.com")
- def test_validation_rejects_missing_domain(self):
+ def test_validation_rejects_missing_domain(self) -> None:
self.assertFalse(UserID.is_valid("@alice:"))
- def test_build(self):
+ def test_build(self) -> None:
user = UserID("5678efgh", "my.domain")
self.assertEqual(user.to_string(), "@5678efgh:my.domain")
- def test_compare(self):
+ def test_compare(self) -> None:
userA = UserID.from_string("@userA:my.domain")
userAagain = UserID.from_string("@userA:my.domain")
userB = UserID.from_string("@userB:my.domain")
@@ -80,43 +80,43 @@ class UserIDTestCase(unittest.HomeserverTestCase):
class RoomAliasTestCase(unittest.HomeserverTestCase):
- def test_parse(self):
+ def test_parse(self) -> None:
room = RoomAlias.from_string("#channel:test")
self.assertEqual("channel", room.localpart)
self.assertEqual("test", room.domain)
self.assertEqual(True, self.hs.is_mine(room))
- def test_build(self):
+ def test_build(self) -> None:
room = RoomAlias("channel", "my.domain")
self.assertEqual(room.to_string(), "#channel:my.domain")
- def test_validate(self):
+ def test_validate(self) -> None:
id_string = "#test:domain,test"
self.assertFalse(RoomAlias.is_valid(id_string))
class MapUsernameTestCase(unittest.TestCase):
- def testPassThrough(self):
+ def test_pass_througuh(self) -> None:
self.assertEqual(map_username_to_mxid_localpart("test1234"), "test1234")
- def testUpperCase(self):
+ def test_upper_case(self) -> None:
self.assertEqual(map_username_to_mxid_localpart("tEST_1234"), "test_1234")
self.assertEqual(
map_username_to_mxid_localpart("tEST_1234", case_sensitive=True),
"t_e_s_t__1234",
)
- def testSymbols(self):
+ def test_symbols(self) -> None:
self.assertEqual(
map_username_to_mxid_localpart("test=$?_1234"), "test=3d=24=3f_1234"
)
- def testLeadingUnderscore(self):
+ def test_leading_underscore(self) -> None:
self.assertEqual(map_username_to_mxid_localpart("_test_1234"), "=5ftest_1234")
- def testNonAscii(self):
+ def test_non_ascii(self) -> None:
# this should work with either a unicode or a bytes
self.assertEqual(map_username_to_mxid_localpart("têst"), "t=c3=aast")
self.assertEqual(map_username_to_mxid_localpart("têst".encode()), "t=c3=aast")
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index e62ebcc6a5..e5dae670a7 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -20,12 +20,13 @@ import sys
import warnings
from asyncio import Future
from binascii import unhexlify
-from typing import Awaitable, Callable, Tuple, TypeVar
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple, TypeVar
from unittest.mock import Mock
import attr
import zope.interface
+from twisted.internet.interfaces import IProtocol
from twisted.python.failure import Failure
from twisted.web.client import ResponseDone
from twisted.web.http import RESPONSES
@@ -34,6 +35,9 @@ from twisted.web.iweb import IResponse
from synapse.types import JsonDict
+if TYPE_CHECKING:
+ from sys import UnraisableHookArgs
+
TV = TypeVar("TV")
@@ -78,25 +82,29 @@ def setup_awaitable_errors() -> Callable[[], None]:
unraisable_exceptions = []
orig_unraisablehook = sys.unraisablehook
- def unraisablehook(unraisable):
+ def unraisablehook(unraisable: "UnraisableHookArgs") -> None:
unraisable_exceptions.append(unraisable.exc_value)
- def cleanup():
+ def cleanup() -> None:
"""
A method to be used as a clean-up that fails a test-case if there are any new unraisable exceptions.
"""
sys.unraisablehook = orig_unraisablehook
if unraisable_exceptions:
- raise unraisable_exceptions.pop()
+ exc = unraisable_exceptions.pop()
+ assert exc is not None
+ raise exc
sys.unraisablehook = unraisablehook
return cleanup
-def simple_async_mock(return_value=None, raises=None) -> Mock:
+def simple_async_mock(
+ return_value: Optional[TV] = None, raises: Optional[Exception] = None
+) -> Mock:
# AsyncMock is not available in python3.5, this mimics part of its behaviour
- async def cb(*args, **kwargs):
+ async def cb(*args: Any, **kwargs: Any) -> Optional[TV]:
if raises:
raise raises
return return_value
@@ -125,14 +133,14 @@ class FakeResponse: # type: ignore[misc]
headers: Headers = attr.Factory(Headers)
@property
- def phrase(self):
+ def phrase(self) -> bytes:
return RESPONSES.get(self.code, b"Unknown Status")
@property
- def length(self):
+ def length(self) -> int:
return len(self.body)
- def deliverBody(self, protocol):
+ def deliverBody(self, protocol: IProtocol) -> None:
protocol.dataReceived(self.body)
protocol.connectionLost(Failure(ResponseDone()))
diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py
index 8027c7a856..a6330ed840 100644
--- a/tests/test_utils/event_injection.py
+++ b/tests/test_utils/event_injection.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import List, Optional, Tuple
+from typing import Any, List, Optional, Tuple
import synapse.server
from synapse.api.constants import EventTypes
@@ -32,7 +32,7 @@ async def inject_member_event(
membership: str,
target: Optional[str] = None,
extra_content: Optional[dict] = None,
- **kwargs,
+ **kwargs: Any,
) -> EventBase:
"""Inject a membership event into a room."""
if target is None:
@@ -57,7 +57,7 @@ async def inject_event(
hs: synapse.server.HomeServer,
room_version: Optional[str] = None,
prev_event_ids: Optional[List[str]] = None,
- **kwargs,
+ **kwargs: Any,
) -> EventBase:
"""Inject a generic event into a room
@@ -82,7 +82,7 @@ async def create_event(
hs: synapse.server.HomeServer,
room_version: Optional[str] = None,
prev_event_ids: Optional[List[str]] = None,
- **kwargs,
+ **kwargs: Any,
) -> Tuple[EventBase, EventContext]:
if room_version is None:
room_version = await hs.get_datastores().main.get_room_version_id(
@@ -92,8 +92,13 @@ async def create_event(
builder = hs.get_event_builder_factory().for_room_version(
KNOWN_ROOM_VERSIONS[room_version], kwargs
)
- event, context = await hs.get_event_creation_handler().create_new_client_event(
+ (
+ event,
+ unpersisted_context,
+ ) = await hs.get_event_creation_handler().create_new_client_event(
builder, prev_event_ids=prev_event_ids
)
+ context = await unpersisted_context.persist(event)
+
return event, context
diff --git a/tests/test_utils/html_parsers.py b/tests/test_utils/html_parsers.py
index e878af5f12..189c697efb 100644
--- a/tests/test_utils/html_parsers.py
+++ b/tests/test_utils/html_parsers.py
@@ -13,13 +13,13 @@
# limitations under the License.
from html.parser import HTMLParser
-from typing import Dict, Iterable, List, Optional, Tuple
+from typing import Dict, Iterable, List, NoReturn, Optional, Tuple
class TestHtmlParser(HTMLParser):
"""A generic HTML page parser which extracts useful things from the HTML"""
- def __init__(self):
+ def __init__(self) -> None:
super().__init__()
# a list of links found in the doc
@@ -48,5 +48,5 @@ class TestHtmlParser(HTMLParser):
assert input_name
self.hiddens[input_name] = attr_dict["value"]
- def error(_, message):
+ def error(self, message: str) -> NoReturn:
raise AssertionError(message)
diff --git a/tests/test_utils/logging_setup.py b/tests/test_utils/logging_setup.py
index 304c7b98c5..b522163a34 100644
--- a/tests/test_utils/logging_setup.py
+++ b/tests/test_utils/logging_setup.py
@@ -25,7 +25,7 @@ class ToTwistedHandler(logging.Handler):
tx_log = twisted.logger.Logger()
- def emit(self, record):
+ def emit(self, record: logging.LogRecord) -> None:
log_entry = self.format(record)
log_level = record.levelname.lower().replace("warning", "warn")
self.tx_log.emit(
@@ -33,7 +33,7 @@ class ToTwistedHandler(logging.Handler):
)
-def setup_logging():
+def setup_logging() -> None:
"""Configure the python logging appropriately for the tests.
(Logs will end up in _trial_temp.)
diff --git a/tests/test_utils/oidc.py b/tests/test_utils/oidc.py
index 1461d23ee8..d555b24255 100644
--- a/tests/test_utils/oidc.py
+++ b/tests/test_utils/oidc.py
@@ -14,7 +14,7 @@
import json
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, ContextManager, Dict, List, Optional, Tuple
from unittest.mock import Mock, patch
from urllib.parse import parse_qs
@@ -77,14 +77,14 @@ class FakeOidcServer:
self._id_token_overrides: Dict[str, Any] = {}
- def reset_mocks(self):
+ def reset_mocks(self) -> None:
self.request.reset_mock()
self.get_jwks_handler.reset_mock()
self.get_metadata_handler.reset_mock()
self.get_userinfo_handler.reset_mock()
self.post_token_handler.reset_mock()
- def patch_homeserver(self, hs: HomeServer):
+ def patch_homeserver(self, hs: HomeServer) -> ContextManager[Mock]:
"""Patch the ``HomeServer`` HTTP client to handle requests through the ``FakeOidcServer``.
This patch should be used whenever the HS is expected to perform request to the
@@ -188,7 +188,7 @@ class FakeOidcServer:
return self._sign(logout_token)
- def id_token_override(self, overrides: dict):
+ def id_token_override(self, overrides: dict) -> ContextManager[dict]:
"""Temporarily patch the ID token generated by the token endpoint."""
return patch.object(self, "_id_token_overrides", overrides)
@@ -247,7 +247,7 @@ class FakeOidcServer:
metadata: bool = False,
token: bool = False,
userinfo: bool = False,
- ):
+ ) -> ContextManager[Dict[str, Mock]]:
"""A context which makes a set of endpoints return a 500 error.
Args:
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index d0b9ad5454..2801a950a8 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -35,6 +35,8 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
self.event_creation_handler = self.hs.get_event_creation_handler()
self.event_builder_factory = self.hs.get_event_builder_factory()
self._storage_controllers = self.hs.get_storage_controllers()
+ assert self._storage_controllers.persistence is not None
+ self._persistence = self._storage_controllers.persistence
self.get_success(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM"))
@@ -175,12 +177,11 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
},
)
- event, context = self.get_success(
+ event, unpersisted_context = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(
- self._storage_controllers.persistence.persist_event(event, context)
- )
+ context = self.get_success(unpersisted_context.persist(event))
+ self.get_success(self._persistence.persist_event(event, context))
return event
def _inject_room_member(
@@ -202,13 +203,12 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
},
)
- event, context = self.get_success(
+ event, unpersisted_context = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
+ context = self.get_success(unpersisted_context.persist(event))
- self.get_success(
- self._storage_controllers.persistence.persist_event(event, context)
- )
+ self.get_success(self._persistence.persist_event(event, context))
return event
def _inject_message(
@@ -226,13 +226,12 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
},
)
- event, context = self.get_success(
+ event, unpersisted_context = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
+ context = self.get_success(unpersisted_context.persist(event))
- self.get_success(
- self._storage_controllers.persistence.persist_event(event, context)
- )
+ self.get_success(self._persistence.persist_event(event, context))
return event
def _inject_outlier(self) -> EventBase:
@@ -250,7 +249,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
event = self.get_success(builder.build(prev_event_ids=[], auth_event_ids=[]))
event.internal_metadata.outlier = True
self.get_success(
- self._storage_controllers.persistence.persist_event(
+ self._persistence.persist_event(
event, EventContext.for_outlier(self._storage_controllers)
)
)
@@ -258,7 +257,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase):
- def test_out_of_band_invite_rejection(self):
+ def test_out_of_band_invite_rejection(self) -> None:
# this is where we have received an invite event over federation, and then
# rejected it.
invite_pdu = {
diff --git a/tests/unittest.py b/tests/unittest.py
index fa92dd94eb..b21e7f1221 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -45,7 +45,7 @@ from typing_extensions import Concatenate, ParamSpec, Protocol
from twisted.internet.defer import Deferred, ensureDeferred
from twisted.python.failure import Failure
from twisted.python.threadpool import ThreadPool
-from twisted.test.proto_helpers import MemoryReactor
+from twisted.test.proto_helpers import MemoryReactor, MemoryReactorClock
from twisted.trial import unittest
from twisted.web.resource import Resource
from twisted.web.server import Request
@@ -82,7 +82,7 @@ from tests.server import (
)
from tests.test_utils import event_injection, setup_awaitable_errors
from tests.test_utils.logging_setup import setup_logging
-from tests.utils import default_config, setupdb
+from tests.utils import checked_cast, default_config, setupdb
setupdb()
setup_logging()
@@ -296,7 +296,12 @@ class HomeserverTestCase(TestCase):
from tests.rest.client.utils import RestHelper
- self.helper = RestHelper(self.hs, self.site, getattr(self, "user_id", None))
+ self.helper = RestHelper(
+ self.hs,
+ checked_cast(MemoryReactorClock, self.hs.get_reactor()),
+ self.site,
+ getattr(self, "user_id", None),
+ )
if hasattr(self, "user_id"):
if self.hijack_auth:
@@ -315,7 +320,7 @@ class HomeserverTestCase(TestCase):
# This has to be a function and not just a Mock, because
# `self.helper.auth_user_id` is temporarily reassigned in some tests
- async def get_requester(*args, **kwargs) -> Requester:
+ async def get_requester(*args: Any, **kwargs: Any) -> Requester:
assert self.helper.auth_user_id is not None
return create_requester(
user_id=UserID.from_string(self.helper.auth_user_id),
@@ -361,7 +366,9 @@ class HomeserverTestCase(TestCase):
store.db_pool.updates.do_next_background_update(False), by=0.1
)
- def make_homeserver(self, reactor: ThreadedMemoryReactorClock, clock: Clock):
+ def make_homeserver(
+ self, reactor: ThreadedMemoryReactorClock, clock: Clock
+ ) -> HomeServer:
"""
Make and return a homeserver.
diff --git a/tests/util/test_retryutils.py b/tests/util/test_retryutils.py
index 9529ee53c8..5f8f4e76b5 100644
--- a/tests/util/test_retryutils.py
+++ b/tests/util/test_retryutils.py
@@ -54,6 +54,7 @@ class RetryLimiterTestCase(HomeserverTestCase):
self.pump()
new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
+ assert new_timings is not None
self.assertEqual(new_timings.failure_ts, failure_ts)
self.assertEqual(new_timings.retry_last_ts, failure_ts)
self.assertEqual(new_timings.retry_interval, MIN_RETRY_INTERVAL)
@@ -82,6 +83,7 @@ class RetryLimiterTestCase(HomeserverTestCase):
self.pump()
new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
+ assert new_timings is not None
self.assertEqual(new_timings.failure_ts, failure_ts)
self.assertEqual(new_timings.retry_last_ts, retry_ts)
self.assertGreaterEqual(
diff --git a/tests/utils.py b/tests/utils.py
index d76bf9716a..a0ac11bc5c 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -15,7 +15,7 @@
import atexit
import os
-from typing import Any, Callable, Dict, List, Tuple, Union, overload
+from typing import Any, Callable, Dict, List, Tuple, Type, TypeVar, Union, overload
import attr
from typing_extensions import Literal, ParamSpec
@@ -335,6 +335,33 @@ async def create_room(hs: HomeServer, room_id: str, creator_id: str) -> None:
},
)
- event, context = await event_creation_handler.create_new_client_event(builder)
+ event, unpersisted_context = await event_creation_handler.create_new_client_event(
+ builder
+ )
+ context = await unpersisted_context.persist(event)
await persistence_store.persist_event(event, context)
+
+
+T = TypeVar("T")
+
+
+def checked_cast(type: Type[T], x: object) -> T:
+ """A version of typing.cast that is checked at runtime.
+
+ We have our own function for this for two reasons:
+
+ 1. typing.cast itself is deliberately a no-op at runtime, see
+ https://docs.python.org/3/library/typing.html#typing.cast
+ 2. To help workaround a mypy-zope bug https://github.com/Shoobx/mypy-zope/issues/91
+ where mypy would erroneously consider `isinstance(x, type)` to be false in all
+ circumstances.
+
+ For this to make sense, `T` needs to be something that `isinstance` can check; see
+ https://docs.python.org/3/library/functions.html?highlight=isinstance#isinstance
+ https://docs.python.org/3/glossary.html#term-abstract-base-class
+ https://docs.python.org/3/library/typing.html#typing.runtime_checkable
+ for more details.
+ """
+ assert isinstance(x, type)
+ return x
|