diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py
index 0b22afdc75..0a1ae83a2b 100644
--- a/tests/appservice/test_scheduler.py
+++ b/tests/appservice/test_scheduler.py
@@ -69,7 +69,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
events=events,
ephemeral=[],
to_device_messages=[], # txn made and saved
- one_time_key_counts={},
+ one_time_keys_count={},
unused_fallback_keys={},
device_list_summary=DeviceListUpdates(),
)
@@ -96,7 +96,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
events=events,
ephemeral=[],
to_device_messages=[], # txn made and saved
- one_time_key_counts={},
+ one_time_keys_count={},
unused_fallback_keys={},
device_list_summary=DeviceListUpdates(),
)
@@ -125,7 +125,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
events=events,
ephemeral=[],
to_device_messages=[],
- one_time_key_counts={},
+ one_time_keys_count={},
unused_fallback_keys={},
device_list_summary=DeviceListUpdates(),
)
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 63628aa6b0..f7c309cad0 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -433,7 +433,7 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
async def get_json(destination, path, **kwargs):
self.assertEqual(destination, SERVER_NAME)
- self.assertEqual(path, "/_matrix/key/v2/server/key1")
+ self.assertEqual(path, "/_matrix/key/v2/server")
return response
self.http_client.get_json.side_effect = get_json
@@ -469,18 +469,6 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
keys = self.get_success(fetcher.get_keys(SERVER_NAME, ["key1"], 0))
self.assertEqual(keys, {})
- def test_keyid_containing_forward_slash(self) -> None:
- """We should url-encode any url unsafe chars in key ids.
-
- Detects https://github.com/matrix-org/synapse/issues/14488.
- """
- fetcher = ServerKeyFetcher(self.hs)
- self.get_success(fetcher.get_keys("example.com", ["key/potato"], 0))
-
- self.http_client.get_json.assert_called_once()
- args, kwargs = self.http_client.get_json.call_args
- self.assertEqual(kwargs["path"], "/_matrix/key/v2/server/key%2Fpotato")
-
class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py
index 685a9a6d52..b703e4472e 100644
--- a/tests/events/test_presence_router.py
+++ b/tests/events/test_presence_router.py
@@ -126,6 +126,13 @@ class PresenceRouterTestModule:
class PresenceRouterTestCase(FederatingHomeserverTestCase):
+ """
+ Test cases using a custom PresenceRouter
+
+ By default in test cases, federation sending is disabled. This class re-enables it
+ for the main process by setting `federation_sender_instances` to None.
+ """
+
servlets = [
admin.register_servlets,
login.register_servlets,
@@ -150,6 +157,11 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
self.sync_handler = self.hs.get_sync_handler()
self.module_api = homeserver.get_module_api()
+ def default_config(self) -> JsonDict:
+ config = super().default_config()
+ config["federation_sender_instances"] = None
+ return config
+
@override_config(
{
"presence": {
@@ -162,7 +174,6 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
},
}
},
- "send_federation": True,
}
)
def test_receiving_all_presence_legacy(self):
@@ -180,7 +191,6 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
},
},
],
- "send_federation": True,
}
)
def test_receiving_all_presence(self):
@@ -290,7 +300,6 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
},
}
},
- "send_federation": True,
}
)
def test_send_local_online_presence_to_with_module_legacy(self):
@@ -310,7 +319,6 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
},
},
],
- "send_federation": True,
}
)
def test_send_local_online_presence_to_with_module(self):
diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py
index 2873b4d430..b8fee72898 100644
--- a/tests/federation/test_federation_catch_up.py
+++ b/tests/federation/test_federation_catch_up.py
@@ -7,13 +7,21 @@ from synapse.federation.sender import PerDestinationQueue, TransactionManager
from synapse.federation.units import Edu
from synapse.rest import admin
from synapse.rest.client import login, room
+from synapse.types import JsonDict
from synapse.util.retryutils import NotRetryingDestination
from tests.test_utils import event_injection, make_awaitable
-from tests.unittest import FederatingHomeserverTestCase, override_config
+from tests.unittest import FederatingHomeserverTestCase
class FederationCatchUpTestCases(FederatingHomeserverTestCase):
+ """
+ Tests cases of catching up over federation.
+
+ By default for test cases federation sending is disabled. This Test class has it
+ re-enabled for the main process.
+ """
+
servlets = [
admin.register_servlets,
room.register_servlets,
@@ -42,6 +50,11 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
self.record_transaction
)
+ def default_config(self) -> JsonDict:
+ config = super().default_config()
+ config["federation_sender_instances"] = None
+ return config
+
async def record_transaction(self, txn, json_cb):
if self.is_online:
data = json_cb()
@@ -79,7 +92,6 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
)[0]
return {"event_id": event_id, "stream_ordering": stream_ordering}
- @override_config({"send_federation": True})
def test_catch_up_destination_rooms_tracking(self):
"""
Tests that we populate the `destination_rooms` table as needed.
@@ -105,7 +117,6 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
self.assertEqual(row_2["event_id"], event_id_2)
self.assertEqual(row_1["stream_ordering"], row_2["stream_ordering"] - 1)
- @override_config({"send_federation": True})
def test_catch_up_last_successful_stream_ordering_tracking(self):
"""
Tests that we populate the `destination_rooms` table as needed.
@@ -163,7 +174,6 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
"Send succeeded but not marked as last_successful_stream_ordering",
)
- @override_config({"send_federation": True}) # critical to federate
def test_catch_up_from_blank_state(self):
"""
Runs an overall test of federation catch-up from scratch.
@@ -260,7 +270,6 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
return per_dest_queue, results_list
- @override_config({"send_federation": True})
def test_catch_up_loop(self):
"""
Tests the behaviour of _catch_up_transmission_loop.
@@ -325,7 +334,6 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
event_5.internal_metadata.stream_ordering,
)
- @override_config({"send_federation": True})
def test_catch_up_on_synapse_startup(self):
"""
Tests the behaviour of get_catch_up_outstanding_destinations and
@@ -424,7 +432,6 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
# - all destinations are woken exactly once; they appear once in woken.
self.assertCountEqual(woken, server_names[:-1])
- @override_config({"send_federation": True})
def test_not_latest_event(self):
"""Test that we send the latest event in the room even if its not ours."""
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index f1e357764f..8692d8190f 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -25,10 +25,17 @@ from synapse.rest.client import login
from synapse.types import JsonDict, ReadReceipt
from tests.test_utils import make_awaitable
-from tests.unittest import HomeserverTestCase, override_config
+from tests.unittest import HomeserverTestCase
class FederationSenderReceiptsTestCases(HomeserverTestCase):
+ """
+ Test federation sending to update receipts.
+
+ By default for test cases federation sending is disabled. This Test class has it
+ re-enabled for the main process.
+ """
+
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(
federation_transport_client=Mock(spec=["send_transaction"]),
@@ -38,9 +45,17 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
return_value=make_awaitable({"test", "host2"})
)
+ hs.get_storage_controllers().state.get_current_hosts_in_room_or_partial_state_approximation = (
+ hs.get_storage_controllers().state.get_current_hosts_in_room
+ )
+
return hs
- @override_config({"send_federation": True})
+ def default_config(self) -> JsonDict:
+ config = super().default_config()
+ config["federation_sender_instances"] = None
+ return config
+
def test_send_receipts(self):
mock_send_transaction = (
self.hs.get_federation_transport_client().send_transaction
@@ -83,7 +98,82 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
],
)
- @override_config({"send_federation": True})
+ def test_send_receipts_thread(self):
+ mock_send_transaction = (
+ self.hs.get_federation_transport_client().send_transaction
+ )
+ mock_send_transaction.return_value = make_awaitable({})
+
+ # Create receipts for:
+ #
+ # * The same room / user on multiple threads.
+ # * A different user in the same room.
+ sender = self.hs.get_federation_sender()
+ for user, thread in (
+ ("alice", None),
+ ("alice", "thread"),
+ ("bob", None),
+ ("bob", "diff-thread"),
+ ):
+ receipt = ReadReceipt(
+ "room_id",
+ "m.read",
+ user,
+ ["event_id"],
+ thread_id=thread,
+ data={"ts": 1234},
+ )
+ self.successResultOf(
+ defer.ensureDeferred(sender.send_read_receipt(receipt))
+ )
+
+ self.pump()
+
+ # expect a call to send_transaction with two EDUs to separate threads.
+ mock_send_transaction.assert_called_once()
+ json_cb = mock_send_transaction.call_args[0][1]
+ data = json_cb()
+ # Note that the ordering of the EDUs doesn't matter.
+ self.assertCountEqual(
+ data["edus"],
+ [
+ {
+ "edu_type": EduTypes.RECEIPT,
+ "content": {
+ "room_id": {
+ "m.read": {
+ "alice": {
+ "event_ids": ["event_id"],
+ "data": {"ts": 1234, "thread_id": "thread"},
+ },
+ "bob": {
+ "event_ids": ["event_id"],
+ "data": {"ts": 1234, "thread_id": "diff-thread"},
+ },
+ }
+ }
+ },
+ },
+ {
+ "edu_type": EduTypes.RECEIPT,
+ "content": {
+ "room_id": {
+ "m.read": {
+ "alice": {
+ "event_ids": ["event_id"],
+ "data": {"ts": 1234},
+ },
+ "bob": {
+ "event_ids": ["event_id"],
+ "data": {"ts": 1234},
+ },
+ }
+ }
+ },
+ },
+ ],
+ )
+
def test_send_receipts_with_backoff(self):
"""Send two receipts in quick succession; the second should be flushed, but
only after 20ms"""
@@ -170,6 +260,13 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
class FederationSenderDevicesTestCases(HomeserverTestCase):
+ """
+ Test federation sending to update devices.
+
+ By default for test cases federation sending is disabled. This Test class has it
+ re-enabled for the main process.
+ """
+
servlets = [
admin.register_servlets,
login.register_servlets,
@@ -184,7 +281,8 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
def default_config(self):
c = super().default_config()
- c["send_federation"] = True
+ # Enable federation sending on the main process.
+ c["federation_sender_instances"] = None
return c
def prepare(self, reactor, clock, hs):
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index 144e49d0fd..57bfbd7734 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -25,7 +25,7 @@ import synapse.storage
from synapse.api.constants import EduTypes, EventTypes
from synapse.appservice import (
ApplicationService,
- TransactionOneTimeKeyCounts,
+ TransactionOneTimeKeysCount,
TransactionUnusedFallbackKeys,
)
from synapse.handlers.appservice import ApplicationServicesHandler
@@ -765,7 +765,12 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
fake_device_ids = [f"device_{num}" for num in range(number_of_messages - 1)]
messages = {
self.exclusive_as_user: {
- device_id: to_device_message_content for device_id in fake_device_ids
+ device_id: {
+ "type": "test_to_device_message",
+ "sender": "@some:sender",
+ "content": to_device_message_content,
+ }
+ for device_id in fake_device_ids
}
}
@@ -1123,7 +1128,7 @@ class ApplicationServicesHandlerOtkCountsTestCase(unittest.HomeserverTestCase):
# Capture what was sent as an AS transaction.
self.send_mock.assert_called()
last_args, _last_kwargs = self.send_mock.call_args
- otks: Optional[TransactionOneTimeKeyCounts] = last_args[self.ARG_OTK_COUNTS]
+ otks: Optional[TransactionOneTimeKeysCount] = last_args[self.ARG_OTK_COUNTS]
unused_fallbacks: Optional[TransactionUnusedFallbackKeys] = last_args[
self.ARG_FALLBACK_KEYS
]
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index c5981ff965..584e7b8971 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -992,7 +992,8 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
def default_config(self):
config = super().default_config()
- config["send_federation"] = True
+ # Enable federation sending on the main process.
+ config["federation_sender_instances"] = None
return config
def prepare(self, reactor, clock, hs):
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 9c821b3042..efbb5a8dbb 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -200,7 +200,8 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
],
)
- @override_config({"send_federation": True})
+ # Enable federation sending on the main process.
+ @override_config({"federation_sender_instances": None})
def test_started_typing_remote_send(self) -> None:
self.room_members = [U_APPLE, U_ONION]
@@ -305,7 +306,8 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.assertEqual(events[0], [])
self.assertEqual(events[1], 0)
- @override_config({"send_federation": True})
+ # Enable federation sending on the main process.
+ @override_config({"federation_sender_instances": None})
def test_stopped_typing(self) -> None:
self.room_members = [U_APPLE, U_BANANA, U_ONION]
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 9e39cd97e5..75fc5a17a4 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -56,7 +56,8 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
- config["update_user_directory"] = True
+ # Re-enables updating the user directory, as that function is needed below.
+ config["update_user_directory_from_worker"] = None
self.appservice = ApplicationService(
token="i_am_an_app_service",
@@ -1045,7 +1046,9 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
- config["update_user_directory"] = True
+ # Re-enables updating the user directory, as that function is needed below. It
+ # will be force disabled later
+ config["update_user_directory_from_worker"] = None
hs = self.setup_test_homeserver(config=config)
self.config = hs.config
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index 058ca57e55..b0f3f4374d 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -336,7 +336,8 @@ class ModuleApiTestCase(HomeserverTestCase):
# Test sending local online presence to users from the main process
_test_sending_local_online_presence_to_local_user(self, test_with_workers=False)
- @override_config({"send_federation": True})
+ # Enable federation sending on the main process.
+ @override_config({"federation_sender_instances": None})
def test_send_local_online_presence_to_federation(self):
"""Tests that send_local_presence_to_users sends local online presence to remote users."""
# Create a user who will send presence updates
diff --git a/tests/push/test_bulk_push_rule_evaluator.py b/tests/push/test_bulk_push_rule_evaluator.py
index 594e7937a8..1cd453248e 100644
--- a/tests/push/test_bulk_push_rule_evaluator.py
+++ b/tests/push/test_bulk_push_rule_evaluator.py
@@ -6,10 +6,11 @@ from synapse.rest import admin
from synapse.rest.client import login, register, room
from synapse.types import create_requester
-from tests import unittest
+from tests.test_utils import simple_async_mock
+from tests.unittest import HomeserverTestCase, override_config
-class TestBulkPushRuleEvaluator(unittest.HomeserverTestCase):
+class TestBulkPushRuleEvaluator(HomeserverTestCase):
servlets = [
admin.register_servlets_for_client_rest_resource,
@@ -72,3 +73,43 @@ class TestBulkPushRuleEvaluator(unittest.HomeserverTestCase):
bulk_evaluator = BulkPushRuleEvaluator(self.hs)
# should not raise
self.get_success(bulk_evaluator.action_for_events_by_user([(event, context)]))
+
+ @override_config({"push": {"enabled": False}})
+ def test_action_for_event_by_user_disabled_by_config(self) -> None:
+ """Ensure that push rules are not calculated when disabled in the config"""
+ # Create a new user and room.
+ alice = self.register_user("alice", "pass")
+ token = self.login(alice, "pass")
+
+ room_id = self.helper.create_room_as(
+ alice, room_version=RoomVersions.V9.identifier, tok=token
+ )
+
+ # Alter the power levels in that room to include stringy and floaty levels.
+ # We need to suppress the validation logic or else it will reject these dodgy
+ # values. (Presumably this validation was not always present.)
+ event_creation_handler = self.hs.get_event_creation_handler()
+ requester = create_requester(alice)
+
+ # Create a new message event, and try to evaluate it under the dodgy
+ # power level event.
+ event, context = self.get_success(
+ event_creation_handler.create_event(
+ requester,
+ {
+ "type": "m.room.message",
+ "room_id": room_id,
+ "content": {
+ "msgtype": "m.text",
+ "body": "helo",
+ },
+ "sender": alice,
+ },
+ )
+ )
+
+ bulk_evaluator = BulkPushRuleEvaluator(self.hs)
+ bulk_evaluator._action_for_event_by_user = simple_async_mock() # type: ignore[assignment]
+ # should not raise
+ self.get_success(bulk_evaluator.action_for_events_by_user([(event, context)]))
+ bulk_evaluator._action_for_event_by_user.assert_not_called()
diff --git a/tests/push/test_email.py b/tests/push/test_email.py
index fd14568f55..57b2f0536e 100644
--- a/tests/push/test_email.py
+++ b/tests/push/test_email.py
@@ -66,7 +66,6 @@ class EmailPusherTests(HomeserverTestCase):
"riot_base_url": None,
}
config["public_baseurl"] = "http://aaa"
- config["start_pushers"] = True
hs = self.setup_test_homeserver(config=config)
diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index b383b8401f..afaafe79aa 100644
--- a/tests/push/test_http.py
+++ b/tests/push/test_http.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, List, Optional, Tuple
+from typing import List, Optional, Tuple
from unittest.mock import Mock
from twisted.internet.defer import Deferred
@@ -41,11 +41,6 @@ class HTTPPusherTests(HomeserverTestCase):
user_id = True
hijack_auth = False
- def default_config(self) -> Dict[str, Any]:
- config = super().default_config()
- config["start_pushers"] = True
- return config
-
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.push_attempts: List[Tuple[Deferred, str, dict]] = []
diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py
index fe7c145840..5ababe6a39 100644
--- a/tests/push/test_push_rule_evaluator.py
+++ b/tests/push/test_push_rule_evaluator.py
@@ -62,6 +62,8 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
power_levels.get("notifications", {}),
{} if related_events is None else related_events,
True,
+ event.room_version.msc3931_push_features,
+ True,
)
def test_display_name(self) -> None:
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 3029a16dda..6a7174b333 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -307,7 +307,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
stream to the master HS.
Args:
- worker_app: Type of worker, e.g. `synapse.app.federation_sender`.
+ worker_app: Type of worker, e.g. `synapse.app.generic_worker`.
extra_config: Any extra config to use for this instances.
**kwargs: Options that get passed to `self.setup_test_homeserver`,
useful to e.g. pass some mocks for things like `federation_http_client`
diff --git a/tests/replication/tcp/streams/test_federation.py b/tests/replication/tcp/streams/test_federation.py
index ffec06a0d6..bcb82c9c80 100644
--- a/tests/replication/tcp/streams/test_federation.py
+++ b/tests/replication/tcp/streams/test_federation.py
@@ -22,9 +22,8 @@ class FederationStreamTestCase(BaseStreamTestCase):
def _get_worker_hs_config(self) -> dict:
# enable federation sending on the worker
config = super()._get_worker_hs_config()
- # TODO: make it so we don't need both of these
- config["send_federation"] = False
- config["worker_app"] = "synapse.app.federation_sender"
+ config["worker_name"] = "federation_sender1"
+ config["federation_sender_instances"] = ["federation_sender1"]
return config
def test_catchup(self):
diff --git a/tests/replication/tcp/streams/test_partial_state.py b/tests/replication/tcp/streams/test_partial_state.py
new file mode 100644
index 0000000000..2c10eab4db
--- /dev/null
+++ b/tests/replication/tcp/streams/test_partial_state.py
@@ -0,0 +1,65 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from twisted.internet.defer import ensureDeferred
+
+from synapse.rest.client import room
+
+from tests.replication._base import BaseMultiWorkerStreamTestCase
+
+
+class PartialStateStreamsTestCase(BaseMultiWorkerStreamTestCase):
+ servlets = [room.register_servlets]
+ hijack_auth = True
+ user_id = "@bob:test"
+
+ def setUp(self):
+ super().setUp()
+ self.store = self.hs.get_datastores().main
+
+ def test_un_partial_stated_room_unblocks_over_replication(self) -> None:
+ """
+ Tests that, when a room is un-partial-stated on another worker,
+ pending calls to `await_full_state` get unblocked.
+ """
+
+ # Make a room.
+ room_id = self.helper.create_room_as("@bob:test")
+ # Mark the room as partial-stated.
+ self.get_success(
+ self.store.store_partial_state_room(room_id, ["serv1", "serv2"], 0, "serv1")
+ )
+
+ worker = self.make_worker_hs("synapse.app.generic_worker")
+
+ # On the worker, attempt to get the current hosts in the room
+ d = ensureDeferred(
+ worker.get_storage_controllers().state.get_current_hosts_in_room(room_id)
+ )
+
+ self.reactor.advance(0.1)
+
+ # This should block
+ self.assertFalse(
+ d.called, "get_current_hosts_in_room/await_full_state did not block"
+ )
+
+ # On the master, clear the partial state flag.
+ self.get_success(self.store.clear_partial_state_room(room_id))
+
+ self.reactor.advance(0.1)
+
+ # The worker should have unblocked
+ self.assertTrue(
+ d.called, "get_current_hosts_in_room/await_full_state did not unblock"
+ )
diff --git a/tests/replication/test_auth.py b/tests/replication/test_auth.py
index 43a16bb141..5d7a89e0c7 100644
--- a/tests/replication/test_auth.py
+++ b/tests/replication/test_auth.py
@@ -38,7 +38,7 @@ class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase):
def _get_worker_hs_config(self) -> dict:
config = self.default_config()
- config["worker_app"] = "synapse.app.client_reader"
+ config["worker_app"] = "synapse.app.generic_worker"
config["worker_replication_host"] = "testserv"
config["worker_replication_http_port"] = "8765"
@@ -53,7 +53,7 @@ class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase):
4. Return the final request.
"""
- worker_hs = self.make_worker_hs("synapse.app.client_reader")
+ worker_hs = self.make_worker_hs("synapse.app.generic_worker")
site = self._hs_to_site[worker_hs]
channel_1 = make_request(
diff --git a/tests/replication/test_client_reader_shard.py b/tests/replication/test_client_reader_shard.py
index 995097d72c..eb5b376534 100644
--- a/tests/replication/test_client_reader_shard.py
+++ b/tests/replication/test_client_reader_shard.py
@@ -22,20 +22,20 @@ logger = logging.getLogger(__name__)
class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
- """Test using one or more client readers for registration."""
+ """Test using one or more generic workers for registration."""
servlets = [register.register_servlets]
def _get_worker_hs_config(self) -> dict:
config = self.default_config()
- config["worker_app"] = "synapse.app.client_reader"
+ config["worker_app"] = "synapse.app.generic_worker"
config["worker_replication_host"] = "testserv"
config["worker_replication_http_port"] = "8765"
return config
def test_register_single_worker(self):
- """Test that registration works when using a single client reader worker."""
- worker_hs = self.make_worker_hs("synapse.app.client_reader")
+ """Test that registration works when using a single generic worker."""
+ worker_hs = self.make_worker_hs("synapse.app.generic_worker")
site = self._hs_to_site[worker_hs]
channel_1 = make_request(
@@ -64,9 +64,9 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
self.assertEqual(channel_2.json_body["user_id"], "@user:test")
def test_register_multi_worker(self):
- """Test that registration works when using multiple client reader workers."""
- worker_hs_1 = self.make_worker_hs("synapse.app.client_reader")
- worker_hs_2 = self.make_worker_hs("synapse.app.client_reader")
+ """Test that registration works when using multiple generic workers."""
+ worker_hs_1 = self.make_worker_hs("synapse.app.generic_worker")
+ worker_hs_2 = self.make_worker_hs("synapse.app.generic_worker")
site_1 = self._hs_to_site[worker_hs_1]
channel_1 = make_request(
diff --git a/tests/replication/test_federation_ack.py b/tests/replication/test_federation_ack.py
index 26b8bd512a..63b1dd40b5 100644
--- a/tests/replication/test_federation_ack.py
+++ b/tests/replication/test_federation_ack.py
@@ -25,8 +25,9 @@ from tests.unittest import HomeserverTestCase
class FederationAckTestCase(HomeserverTestCase):
def default_config(self) -> dict:
config = super().default_config()
- config["worker_app"] = "synapse.app.federation_sender"
- config["send_federation"] = False
+ config["worker_app"] = "synapse.app.generic_worker"
+ config["worker_name"] = "federation_sender1"
+ config["federation_sender_instances"] = ["federation_sender1"]
return config
def make_homeserver(self, reactor, clock):
diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py
index 6104a55aa1..c28073b8f7 100644
--- a/tests/replication/test_federation_sender_shard.py
+++ b/tests/replication/test_federation_sender_shard.py
@@ -27,17 +27,19 @@ logger = logging.getLogger(__name__)
class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
+ """
+ Various tests for federation sending on workers.
+
+ Federation sending is disabled by default, it will be enabled in each test by
+ updating 'federation_sender_instances'.
+ """
+
servlets = [
login.register_servlets,
register_servlets_for_client_rest_resource,
room.register_servlets,
]
- def default_config(self):
- conf = super().default_config()
- conf["send_federation"] = False
- return conf
-
def test_send_event_single_sender(self):
"""Test that using a single federation sender worker correctly sends a
new event.
@@ -46,8 +48,11 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
mock_client.put_json.return_value = make_awaitable({})
self.make_worker_hs(
- "synapse.app.federation_sender",
- {"send_federation": False},
+ "synapse.app.generic_worker",
+ {
+ "worker_name": "federation_sender1",
+ "federation_sender_instances": ["federation_sender1"],
+ },
federation_http_client=mock_client,
)
@@ -73,11 +78,13 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
mock_client1 = Mock(spec=["put_json"])
mock_client1.put_json.return_value = make_awaitable({})
self.make_worker_hs(
- "synapse.app.federation_sender",
+ "synapse.app.generic_worker",
{
- "send_federation": True,
- "worker_name": "sender1",
- "federation_sender_instances": ["sender1", "sender2"],
+ "worker_name": "federation_sender1",
+ "federation_sender_instances": [
+ "federation_sender1",
+ "federation_sender2",
+ ],
},
federation_http_client=mock_client1,
)
@@ -85,11 +92,13 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
mock_client2 = Mock(spec=["put_json"])
mock_client2.put_json.return_value = make_awaitable({})
self.make_worker_hs(
- "synapse.app.federation_sender",
+ "synapse.app.generic_worker",
{
- "send_federation": True,
- "worker_name": "sender2",
- "federation_sender_instances": ["sender1", "sender2"],
+ "worker_name": "federation_sender2",
+ "federation_sender_instances": [
+ "federation_sender1",
+ "federation_sender2",
+ ],
},
federation_http_client=mock_client2,
)
@@ -136,11 +145,13 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
mock_client1 = Mock(spec=["put_json"])
mock_client1.put_json.return_value = make_awaitable({})
self.make_worker_hs(
- "synapse.app.federation_sender",
+ "synapse.app.generic_worker",
{
- "send_federation": True,
- "worker_name": "sender1",
- "federation_sender_instances": ["sender1", "sender2"],
+ "worker_name": "federation_sender1",
+ "federation_sender_instances": [
+ "federation_sender1",
+ "federation_sender2",
+ ],
},
federation_http_client=mock_client1,
)
@@ -148,11 +159,13 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
mock_client2 = Mock(spec=["put_json"])
mock_client2.put_json.return_value = make_awaitable({})
self.make_worker_hs(
- "synapse.app.federation_sender",
+ "synapse.app.generic_worker",
{
- "send_federation": True,
- "worker_name": "sender2",
- "federation_sender_instances": ["sender1", "sender2"],
+ "worker_name": "federation_sender2",
+ "federation_sender_instances": [
+ "federation_sender1",
+ "federation_sender2",
+ ],
},
federation_http_client=mock_client2,
)
diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py
index 59fea93e49..ca18ad6553 100644
--- a/tests/replication/test_pusher_shard.py
+++ b/tests/replication/test_pusher_shard.py
@@ -38,11 +38,6 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
self.other_user_id = self.register_user("otheruser", "pass")
self.other_access_token = self.login("otheruser", "pass")
- def default_config(self):
- conf = super().default_config()
- conf["start_pushers"] = False
- return conf
-
def _create_pusher_and_send_msg(self, localpart):
# Create a user that will get push notifications
user_id = self.register_user(localpart, "pass")
@@ -92,8 +87,8 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
)
self.make_worker_hs(
- "synapse.app.pusher",
- {"start_pushers": False},
+ "synapse.app.generic_worker",
+ {"worker_name": "pusher1", "pusher_instances": ["pusher1"]},
proxied_blacklisted_http_client=http_client_mock,
)
@@ -122,9 +117,8 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
)
self.make_worker_hs(
- "synapse.app.pusher",
+ "synapse.app.generic_worker",
{
- "start_pushers": True,
"worker_name": "pusher1",
"pusher_instances": ["pusher1", "pusher2"],
},
@@ -137,9 +131,8 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
)
self.make_worker_hs(
- "synapse.app.pusher",
+ "synapse.app.generic_worker",
{
- "start_pushers": True,
"worker_name": "pusher2",
"pusher_instances": ["pusher1", "pusher2"],
},
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index e8c9457794..5c1ced355f 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -3994,7 +3994,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
"""
Tests that shadow-banning for a user that is not a local returns a 400
"""
- url = "/_synapse/admin/v1/whois/@unknown_person:unknown_domain"
+ url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/shadow_ban"
channel = self.make_request(method, url, access_token=self.admin_user_tok)
self.assertEqual(400, channel.code, msg=channel.json_body)
diff --git a/tests/rest/client/test_login_token_request.py b/tests/rest/client/test_login_token_request.py
index c2e1e08811..6aedc1a11c 100644
--- a/tests/rest/client/test_login_token_request.py
+++ b/tests/rest/client/test_login_token_request.py
@@ -48,13 +48,13 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
def test_disabled(self) -> None:
channel = self.make_request("POST", endpoint, {}, access_token=None)
- self.assertEqual(channel.code, 400)
+ self.assertEqual(channel.code, 404)
self.register_user(self.user, self.password)
token = self.login(self.user, self.password)
channel = self.make_request("POST", endpoint, {}, access_token=token)
- self.assertEqual(channel.code, 400)
+ self.assertEqual(channel.code, 404)
@override_config({"experimental_features": {"msc3882_enabled": True}})
def test_require_auth(self) -> None:
diff --git a/tests/rest/client/test_receipts.py b/tests/rest/client/test_receipts.py
new file mode 100644
index 0000000000..2a7fcea386
--- /dev/null
+++ b/tests/rest/client/test_receipts.py
@@ -0,0 +1,76 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from twisted.test.proto_helpers import MemoryReactor
+
+import synapse.rest.admin
+from synapse.rest.client import login, receipts, register
+from synapse.server import HomeServer
+from synapse.util import Clock
+
+from tests import unittest
+
+
+class ReceiptsTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ login.register_servlets,
+ register.register_servlets,
+ receipts.register_servlets,
+ synapse.rest.admin.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.owner = self.register_user("owner", "pass")
+ self.owner_tok = self.login("owner", "pass")
+
+ def test_send_receipt(self) -> None:
+ channel = self.make_request(
+ "POST",
+ "/rooms/!abc:beep/receipt/m.read/$def",
+ content={},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ def test_send_receipt_invalid_room_id(self) -> None:
+ channel = self.make_request(
+ "POST",
+ "/rooms/not-a-room-id/receipt/m.read/$def",
+ content={},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(
+ channel.json_body["error"], "A valid room ID and event ID must be specified"
+ )
+
+ def test_send_receipt_invalid_event_id(self) -> None:
+ channel = self.make_request(
+ "POST",
+ "/rooms/!abc:beep/receipt/m.read/not-an-event-id",
+ content={},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(
+ channel.json_body["error"], "A valid room ID and event ID must be specified"
+ )
+
+ def test_send_receipt_invalid_receipt_type(self) -> None:
+ channel = self.make_request(
+ "POST",
+ "/rooms/!abc:beep/receipt/invalid-receipt-type/$def",
+ content={},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 400, channel.result)
diff --git a/tests/rest/client/test_rendezvous.py b/tests/rest/client/test_rendezvous.py
index ad00a476e1..c0eb5d01a6 100644
--- a/tests/rest/client/test_rendezvous.py
+++ b/tests/rest/client/test_rendezvous.py
@@ -36,7 +36,7 @@ class RendezvousServletTestCase(unittest.HomeserverTestCase):
def test_disabled(self) -> None:
channel = self.make_request("POST", endpoint, {}, access_token=None)
- self.assertEqual(channel.code, 400)
+ self.assertEqual(channel.code, 404)
@override_config({"experimental_features": {"msc3886_endpoint": "/asd"}})
def test_redirect(self) -> None:
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index e919e089cb..b4daace556 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -3546,11 +3546,6 @@ class TimestampLookupTestCase(unittest.HomeserverTestCase):
login.register_servlets,
]
- def default_config(self) -> JsonDict:
- config = super().default_config()
- config["experimental_features"] = {"msc3030_enabled": True}
- return config
-
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self._storage_controllers = self.hs.get_storage_controllers()
@@ -3592,7 +3587,7 @@ class TimestampLookupTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"GET",
- f"/_matrix/client/unstable/org.matrix.msc3030/rooms/{room_id}/timestamp_to_event?dir=b&ts={outlier_event.origin_server_ts}",
+ f"/_matrix/client/v1/rooms/{room_id}/timestamp_to_event?dir=b&ts={outlier_event.origin_server_ts}",
access_token=self.room_owner_tok,
)
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py
index 7f1fba1086..2bb6e27d94 100644
--- a/tests/rest/key/v2/test_remote_key_resource.py
+++ b/tests/rest/key/v2/test_remote_key_resource.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import urllib.parse
from io import BytesIO, StringIO
from typing import Any, Dict, Optional, Union
from unittest.mock import Mock
@@ -65,9 +64,7 @@ class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase):
self.assertTrue(ignore_backoff)
self.assertEqual(destination, server_name)
key_id = "%s:%s" % (signing_key.alg, signing_key.version)
- self.assertEqual(
- path, "/_matrix/key/v2/server/%s" % (urllib.parse.quote(key_id),)
- )
+ self.assertEqual(path, "/_matrix/key/v2/server")
response = {
"server_name": server_name,
diff --git a/tests/storage/databases/main/test_deviceinbox.py b/tests/storage/databases/main/test_deviceinbox.py
index 50c20c5b92..373707b275 100644
--- a/tests/storage/databases/main/test_deviceinbox.py
+++ b/tests/storage/databases/main/test_deviceinbox.py
@@ -12,8 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.rest import admin
from synapse.rest.client import devices
+from synapse.server import HomeServer
+from synapse.util import Clock
from tests.unittest import HomeserverTestCase
@@ -25,11 +29,11 @@ class DeviceInboxBackgroundUpdateStoreTestCase(HomeserverTestCase):
devices.register_servlets,
]
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.user_id = self.register_user("foo", "pass")
- def test_background_remove_deleted_devices_from_device_inbox(self):
+ def test_background_remove_deleted_devices_from_device_inbox(self) -> None:
"""Test that the background task to delete old device_inboxes works properly."""
# create a valid device
@@ -89,7 +93,7 @@ class DeviceInboxBackgroundUpdateStoreTestCase(HomeserverTestCase):
self.assertEqual(1, len(res))
self.assertEqual(res[0], "cur_device")
- def test_background_remove_hidden_devices_from_device_inbox(self):
+ def test_background_remove_hidden_devices_from_device_inbox(self) -> None:
"""Test that the background task to delete hidden devices
from device_inboxes works properly."""
diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py
index 5773172ab8..9f33afcca0 100644
--- a/tests/storage/databases/main/test_events_worker.py
+++ b/tests/storage/databases/main/test_events_worker.py
@@ -45,7 +45,7 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
login.register_servlets,
]
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.hs = hs
self.store: EventsWorkerStore = hs.get_datastores().main
@@ -68,7 +68,7 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
self.event_ids.append(event.event_id)
- def test_simple(self):
+ def test_simple(self) -> None:
with LoggingContext(name="test") as ctx:
res = self.get_success(
self.store.have_seen_events(
@@ -90,7 +90,7 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
self.assertEqual(res, {self.event_ids[0]})
self.assertEqual(ctx.get_resource_usage().db_txn_count, 0)
- def test_persisting_event_invalidates_cache(self):
+ def test_persisting_event_invalidates_cache(self) -> None:
"""
Test to make sure that the `have_seen_event` cache
is invalidated after we persist an event and returns
@@ -138,7 +138,7 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
# That should result in a single db query to lookup
self.assertEqual(ctx.get_resource_usage().db_txn_count, 1)
- def test_invalidate_cache_by_room_id(self):
+ def test_invalidate_cache_by_room_id(self) -> None:
"""
Test to make sure that all events associated with the given `(room_id,)`
are invalidated in the `have_seen_event` cache.
@@ -175,7 +175,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase):
login.register_servlets,
]
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store: EventsWorkerStore = hs.get_datastores().main
self.user = self.register_user("user", "pass")
@@ -189,7 +189,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase):
# Reset the event cache so the tests start with it empty
self.get_success(self.store._get_event_cache.clear())
- def test_simple(self):
+ def test_simple(self) -> None:
"""Test that we cache events that we pull from the DB."""
with LoggingContext("test") as ctx:
@@ -198,7 +198,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase):
# We should have fetched the event from the DB
self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1)
- def test_event_ref(self):
+ def test_event_ref(self) -> None:
"""Test that we reuse events that are still in memory but have fallen
out of the cache, rather than requesting them from the DB.
"""
@@ -223,7 +223,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase):
# from the DB
self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 0)
- def test_dedupe(self):
+ def test_dedupe(self) -> None:
"""Test that if we request the same event multiple times we only pull it
out once.
"""
@@ -241,7 +241,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase):
class DatabaseOutageTestCase(unittest.HomeserverTestCase):
"""Test event fetching during a database outage."""
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store: EventsWorkerStore = hs.get_datastores().main
self.room_id = f"!room:{hs.hostname}"
@@ -377,7 +377,7 @@ class GetEventCancellationTestCase(unittest.HomeserverTestCase):
login.register_servlets,
]
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store: EventsWorkerStore = hs.get_datastores().main
self.user = self.register_user("user", "pass")
@@ -412,7 +412,8 @@ class GetEventCancellationTestCase(unittest.HomeserverTestCase):
unblock: "Deferred[None]" = Deferred()
original_runWithConnection = self.store.db_pool.runWithConnection
- async def runWithConnection(*args, **kwargs):
+ # Don't bother with the types here, we just pass into the original function.
+ async def runWithConnection(*args, **kwargs): # type: ignore[no-untyped-def]
await unblock
return await original_runWithConnection(*args, **kwargs)
@@ -441,7 +442,7 @@ class GetEventCancellationTestCase(unittest.HomeserverTestCase):
self.assertEqual(ctx1.get_resource_usage().evt_db_fetch_count, 1)
self.assertEqual(ctx2.get_resource_usage().evt_db_fetch_count, 0)
- def test_first_get_event_cancelled(self):
+ def test_first_get_event_cancelled(self) -> None:
"""Test cancellation of the first `get_event` call sharing a database fetch.
The first `get_event` call is the one which initiates the fetch. We expect the
@@ -467,7 +468,7 @@ class GetEventCancellationTestCase(unittest.HomeserverTestCase):
# The second `get_event` call should complete successfully.
self.get_success(get_event2)
- def test_second_get_event_cancelled(self):
+ def test_second_get_event_cancelled(self) -> None:
"""Test cancellation of the second `get_event` call sharing a database fetch."""
with self.blocking_get_event_calls() as (unblock, get_event1, get_event2):
# Cancel the second `get_event` call.
diff --git a/tests/storage/databases/main/test_lock.py b/tests/storage/databases/main/test_lock.py
index 3cc2a58d8d..56cb49d9b5 100644
--- a/tests/storage/databases/main/test_lock.py
+++ b/tests/storage/databases/main/test_lock.py
@@ -15,18 +15,20 @@
from twisted.internet import defer, reactor
from twisted.internet.base import ReactorBase
from twisted.internet.defer import Deferred
+from twisted.test.proto_helpers import MemoryReactor
from synapse.server import HomeServer
from synapse.storage.databases.main.lock import _LOCK_TIMEOUT_MS
+from synapse.util import Clock
from tests import unittest
class LockTestCase(unittest.HomeserverTestCase):
- def prepare(self, reactor, clock, hs: HomeServer):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
- def test_acquire_contention(self):
+ def test_acquire_contention(self) -> None:
# Track the number of tasks holding the lock.
# Should be at most 1.
in_lock = 0
@@ -34,7 +36,7 @@ class LockTestCase(unittest.HomeserverTestCase):
release_lock: "Deferred[None]" = Deferred()
- async def task():
+ async def task() -> None:
nonlocal in_lock
nonlocal max_in_lock
@@ -76,7 +78,7 @@ class LockTestCase(unittest.HomeserverTestCase):
# At most one task should have held the lock at a time.
self.assertEqual(max_in_lock, 1)
- def test_simple_lock(self):
+ def test_simple_lock(self) -> None:
"""Test that we can take out a lock and that while we hold it nobody
else can take it out.
"""
@@ -103,7 +105,7 @@ class LockTestCase(unittest.HomeserverTestCase):
self.get_success(lock3.__aenter__())
self.get_success(lock3.__aexit__(None, None, None))
- def test_maintain_lock(self):
+ def test_maintain_lock(self) -> None:
"""Test that we don't time out locks while they're still active"""
lock = self.get_success(self.store.try_acquire_lock("name", "key"))
@@ -119,7 +121,7 @@ class LockTestCase(unittest.HomeserverTestCase):
self.get_success(lock.__aexit__(None, None, None))
- def test_timeout_lock(self):
+ def test_timeout_lock(self) -> None:
"""Test that we time out locks if they're not updated for ages"""
lock = self.get_success(self.store.try_acquire_lock("name", "key"))
@@ -139,7 +141,7 @@ class LockTestCase(unittest.HomeserverTestCase):
self.assertFalse(self.get_success(lock.is_still_valid()))
- def test_drop(self):
+ def test_drop(self) -> None:
"""Test that dropping the context manager means we stop renewing the lock"""
lock = self.get_success(self.store.try_acquire_lock("name", "key"))
@@ -153,7 +155,7 @@ class LockTestCase(unittest.HomeserverTestCase):
lock2 = self.get_success(self.store.try_acquire_lock("name", "key"))
self.assertIsNotNone(lock2)
- def test_shutdown(self):
+ def test_shutdown(self) -> None:
"""Test that shutting down Synapse releases the locks"""
# Acquire two locks
lock = self.get_success(self.store.try_acquire_lock("name", "key1"))
diff --git a/tests/storage/databases/main/test_receipts.py b/tests/storage/databases/main/test_receipts.py
index c4f12d81d7..68026e2830 100644
--- a/tests/storage/databases/main/test_receipts.py
+++ b/tests/storage/databases/main/test_receipts.py
@@ -33,7 +33,7 @@ class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase):
login.register_servlets,
]
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.user_id = self.register_user("foo", "pass")
self.token = self.login("foo", "pass")
@@ -47,7 +47,7 @@ class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase):
table: str,
receipts: Dict[Tuple[str, str, str], Sequence[Dict[str, Any]]],
expected_unique_receipts: Dict[Tuple[str, str, str], Optional[Dict[str, Any]]],
- ):
+ ) -> None:
"""Test that the background update to uniqueify non-thread receipts in
the given receipts table works properly.
@@ -154,7 +154,7 @@ class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase):
f"Background update did not remove all duplicate receipts from {table}",
)
- def test_background_receipts_linearized_unique_index(self):
+ def test_background_receipts_linearized_unique_index(self) -> None:
"""Test that the background update to uniqueify non-thread receipts in
`receipts_linearized` works properly.
"""
@@ -177,7 +177,7 @@ class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase):
},
)
- def test_background_receipts_graph_unique_index(self):
+ def test_background_receipts_graph_unique_index(self) -> None:
"""Test that the background update to uniqueify non-thread receipts in
`receipts_graph` works properly.
"""
diff --git a/tests/storage/databases/main/test_room.py b/tests/storage/databases/main/test_room.py
index 1edb619630..7d961fac64 100644
--- a/tests/storage/databases/main/test_room.py
+++ b/tests/storage/databases/main/test_room.py
@@ -14,10 +14,14 @@
import json
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.api.constants import RoomTypes
from synapse.rest import admin
from synapse.rest.client import login, room
+from synapse.server import HomeServer
from synapse.storage.databases.main.room import _BackgroundUpdates
+from synapse.util import Clock
from tests.unittest import HomeserverTestCase
@@ -30,7 +34,7 @@ class RoomBackgroundUpdateStoreTestCase(HomeserverTestCase):
login.register_servlets,
]
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.user_id = self.register_user("foo", "pass")
self.token = self.login("foo", "pass")
@@ -40,7 +44,7 @@ class RoomBackgroundUpdateStoreTestCase(HomeserverTestCase):
return room_id
- def test_background_populate_rooms_creator_column(self):
+ def test_background_populate_rooms_creator_column(self) -> None:
"""Test that the background update to populate the rooms creator column
works properly.
"""
@@ -95,7 +99,7 @@ class RoomBackgroundUpdateStoreTestCase(HomeserverTestCase):
)
self.assertEqual(room_creator_after, self.user_id)
- def test_background_add_room_type_column(self):
+ def test_background_add_room_type_column(self) -> None:
"""Test that the background update to populate the `room_type` column in
`room_stats_state` works properly.
"""
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index 09cb06d614..8bbf936ae9 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -106,7 +106,7 @@ class UpdateUpsertManyTests(unittest.HomeserverTestCase):
{(1, "user1", "hello"), (2, "user2", "bleb")},
)
- def test_simple_update_many(self):
+ def test_simple_update_many(self) -> None:
"""
simple_update_many performs many updates at once.
"""
diff --git a/tests/storage/test_account_data.py b/tests/storage/test_account_data.py
index 72bf5b3d31..1bfd11ceae 100644
--- a/tests/storage/test_account_data.py
+++ b/tests/storage/test_account_data.py
@@ -14,13 +14,17 @@
from typing import Iterable, Optional, Set
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.api.constants import AccountDataTypes
+from synapse.server import HomeServer
+from synapse.util import Clock
from tests import unittest
class IgnoredUsersTestCase(unittest.HomeserverTestCase):
- def prepare(self, hs, reactor, clock):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = self.hs.get_datastores().main
self.user = "@user:test"
@@ -55,7 +59,7 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase):
expected_ignored_user_ids,
)
- def test_ignoring_users(self):
+ def test_ignoring_users(self) -> None:
"""Basic adding/removing of users from the ignore list."""
self._update_ignore_list("@other:test", "@another:remote")
self.assert_ignored(self.user, {"@other:test", "@another:remote"})
@@ -82,7 +86,7 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase):
# Check the removed user.
self.assert_ignorers("@another:remote", {self.user})
- def test_caching(self):
+ def test_caching(self) -> None:
"""Ensure that caching works properly between different users."""
# The first user ignores a user.
self._update_ignore_list("@other:test")
@@ -99,7 +103,7 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase):
self.assert_ignored(self.user, set())
self.assert_ignorers("@other:test", {"@second:test"})
- def test_invalid_data(self):
+ def test_invalid_data(self) -> None:
"""Invalid data ends up clearing out the ignored users list."""
# Add some data and ensure it is there.
self._update_ignore_list("@other:test")
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 1047ed09c8..5e1324a169 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -26,7 +26,7 @@ from synapse.appservice import ApplicationService, ApplicationServiceState
from synapse.config._base import ConfigError
from synapse.events import EventBase
from synapse.server import HomeServer
-from synapse.storage.database import DatabasePool, make_conn
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection, make_conn
from synapse.storage.databases.main.appservice import (
ApplicationServiceStore,
ApplicationServiceTransactionStore,
@@ -39,7 +39,7 @@ from tests.test_utils import make_awaitable
class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase):
- def setUp(self):
+ def setUp(self) -> None:
super(ApplicationServiceStoreTestCase, self).setUp()
self.as_yaml_files: List[str] = []
@@ -73,7 +73,9 @@ class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase):
super(ApplicationServiceStoreTestCase, self).tearDown()
- def _add_appservice(self, as_token, id, url, hs_token, sender) -> None:
+ def _add_appservice(
+ self, as_token: str, id: str, url: str, hs_token: str, sender: str
+ ) -> None:
as_yaml = {
"url": url,
"as_token": as_token,
@@ -135,7 +137,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
database, make_conn(db_config, self.engine, "test"), self.hs
)
- def _add_service(self, url, as_token, id) -> None:
+ def _add_service(self, url: str, as_token: str, id: str) -> None:
as_yaml = {
"url": url,
"as_token": as_token,
@@ -149,7 +151,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
outfile.write(yaml.dump(as_yaml))
self.as_yaml_files.append(as_token)
- def _set_state(self, id: str, state: ApplicationServiceState):
+ def _set_state(self, id: str, state: ApplicationServiceState) -> defer.Deferred:
return self.db_pool.runOperation(
self.engine.convert_param_style(
"INSERT INTO application_services_state(as_id, state) VALUES(?,?)"
@@ -157,7 +159,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
(id, state.value),
)
- def _insert_txn(self, as_id, txn_id, events):
+ def _insert_txn(
+ self, as_id: str, txn_id: int, events: List[Mock]
+ ) -> "defer.Deferred[None]":
return self.db_pool.runOperation(
self.engine.convert_param_style(
"INSERT INTO application_services_txns(as_id, txn_id, event_ids) "
@@ -448,12 +452,14 @@ class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase):
# required for ApplicationServiceTransactionStoreTestCase tests
class TestTransactionStore(ApplicationServiceTransactionStore, ApplicationServiceStore):
- def __init__(self, database: DatabasePool, db_conn, hs) -> None:
+ def __init__(
+ self, database: DatabasePool, db_conn: LoggingDatabaseConnection, hs: HomeServer
+ ) -> None:
super().__init__(database, db_conn, hs)
class ApplicationServiceStoreConfigTestCase(unittest.HomeserverTestCase):
- def _write_config(self, suffix, **kwargs) -> str:
+ def _write_config(self, suffix: str, **kwargs: str) -> str:
vals = {
"id": "id" + suffix,
"url": "url" + suffix,
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index 40e58f8199..256d28e4c9 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
from collections import OrderedDict
+from typing import Generator
from unittest.mock import Mock
from twisted.internet import defer
@@ -30,7 +30,7 @@ from tests.utils import default_config
class SQLBaseStoreTestCase(unittest.TestCase):
"""Test the "simple" SQL generating methods in SQLBaseStore."""
- def setUp(self):
+ def setUp(self) -> None:
self.db_pool = Mock(spec=["runInteraction"])
self.mock_txn = Mock()
self.mock_conn = Mock(spec_set=["cursor", "rollback", "commit"])
@@ -38,12 +38,12 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_conn.rollback.return_value = None
# Our fake runInteraction just runs synchronously inline
- def runInteraction(func, *args, **kwargs):
+ def runInteraction(func, *args, **kwargs) -> defer.Deferred: # type: ignore[no-untyped-def]
return defer.succeed(func(self.mock_txn, *args, **kwargs))
self.db_pool.runInteraction = runInteraction
- def runWithConnection(func, *args, **kwargs):
+ def runWithConnection(func, *args, **kwargs): # type: ignore[no-untyped-def]
return defer.succeed(func(self.mock_conn, *args, **kwargs))
self.db_pool.runWithConnection = runWithConnection
@@ -62,7 +62,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.datastore = SQLBaseStore(db, None, hs) # type: ignore[arg-type]
@defer.inlineCallbacks
- def test_insert_1col(self):
+ def test_insert_1col(self) -> Generator["defer.Deferred[object]", object, None]:
self.mock_txn.rowcount = 1
yield defer.ensureDeferred(
@@ -76,7 +76,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
)
@defer.inlineCallbacks
- def test_insert_3cols(self):
+ def test_insert_3cols(self) -> Generator["defer.Deferred[object]", object, None]:
self.mock_txn.rowcount = 1
yield defer.ensureDeferred(
@@ -92,7 +92,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
)
@defer.inlineCallbacks
- def test_select_one_1col(self):
+ def test_select_one_1col(self) -> Generator["defer.Deferred[object]", object, None]:
self.mock_txn.rowcount = 1
self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)]))
@@ -108,7 +108,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
)
@defer.inlineCallbacks
- def test_select_one_3col(self):
+ def test_select_one_3col(self) -> Generator["defer.Deferred[object]", object, None]:
self.mock_txn.rowcount = 1
self.mock_txn.fetchone.return_value = (1, 2, 3)
@@ -126,7 +126,9 @@ class SQLBaseStoreTestCase(unittest.TestCase):
)
@defer.inlineCallbacks
- def test_select_one_missing(self):
+ def test_select_one_missing(
+ self,
+ ) -> Generator["defer.Deferred[object]", object, None]:
self.mock_txn.rowcount = 0
self.mock_txn.fetchone.return_value = None
@@ -142,7 +144,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.assertFalse(ret)
@defer.inlineCallbacks
- def test_select_list(self):
+ def test_select_list(self) -> Generator["defer.Deferred[object]", object, None]:
self.mock_txn.rowcount = 3
self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)]))
self.mock_txn.description = (("colA", None, None, None, None, None, None),)
@@ -159,7 +161,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
)
@defer.inlineCallbacks
- def test_update_one_1col(self):
+ def test_update_one_1col(self) -> Generator["defer.Deferred[object]", object, None]:
self.mock_txn.rowcount = 1
yield defer.ensureDeferred(
@@ -176,7 +178,9 @@ class SQLBaseStoreTestCase(unittest.TestCase):
)
@defer.inlineCallbacks
- def test_update_one_4cols(self):
+ def test_update_one_4cols(
+ self,
+ ) -> Generator["defer.Deferred[object]", object, None]:
self.mock_txn.rowcount = 1
yield defer.ensureDeferred(
@@ -193,7 +197,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
)
@defer.inlineCallbacks
- def test_delete_one(self):
+ def test_delete_one(self) -> Generator["defer.Deferred[object]", object, None]:
self.mock_txn.rowcount = 1
yield defer.ensureDeferred(
diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index b998ad42d9..d570684c99 100644
--- a/tests/storage/test_cleanup_extrems.py
+++ b/tests/storage/test_cleanup_extrems.py
@@ -15,11 +15,16 @@
import os.path
from unittest.mock import Mock, patch
+from twisted.test.proto_helpers import MemoryReactor
+
import synapse.rest.admin
from synapse.api.constants import EventTypes
from synapse.rest.client import login, room
+from synapse.server import HomeServer
from synapse.storage import prepare_database
+from synapse.storage.types import Cursor
from synapse.types import UserID, create_requester
+from synapse.util import Clock
from tests.unittest import HomeserverTestCase
@@ -29,7 +34,9 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
Test the background update to clean forward extremities table.
"""
- def prepare(self, reactor, clock, homeserver):
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+ ) -> None:
self.store = homeserver.get_datastores().main
self.room_creator = homeserver.get_room_creation_handler()
@@ -39,7 +46,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
info, _ = self.get_success(self.room_creator.create_room(self.requester, {}))
self.room_id = info["room_id"]
- def run_background_update(self):
+ def run_background_update(self) -> None:
"""Re run the background update to clean up the extremities."""
# Make sure we don't clash with in progress updates.
self.assertTrue(
@@ -54,7 +61,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
"delete_forward_extremities.sql",
)
- def run_delta_file(txn):
+ def run_delta_file(txn: Cursor) -> None:
prepare_database.executescript(txn, schema_path)
self.get_success(
@@ -84,7 +91,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
(room_id,)
)
- def test_soft_failed_extremities_handled_correctly(self):
+ def test_soft_failed_extremities_handled_correctly(self) -> None:
"""Test that extremities are correctly calculated in the presence of
soft failed events.
@@ -114,7 +121,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
self.assertEqual(latest_event_ids, [event_id_4])
- def test_basic_cleanup(self):
+ def test_basic_cleanup(self) -> None:
"""Test that extremities are correctly calculated in the presence of
soft failed events.
@@ -149,7 +156,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
)
self.assertEqual(latest_event_ids, [event_id_b])
- def test_chain_of_fail_cleanup(self):
+ def test_chain_of_fail_cleanup(self) -> None:
"""Test that extremities are correctly calculated in the presence of
soft failed events.
@@ -187,7 +194,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
)
self.assertEqual(latest_event_ids, [event_id_b])
- def test_forked_graph_cleanup(self):
+ def test_forked_graph_cleanup(self) -> None:
r"""Test that extremities are correctly calculated in the presence of
soft failed events.
@@ -252,12 +259,14 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
room.register_servlets,
]
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
config["cleanup_extremities_with_dummy_events"] = True
return self.setup_test_homeserver(config=config)
- def prepare(self, reactor, clock, homeserver):
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+ ) -> None:
self.store = homeserver.get_datastores().main
self.room_creator = homeserver.get_room_creation_handler()
self.event_creator_handler = homeserver.get_event_creation_handler()
@@ -273,7 +282,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
self.event_creator = homeserver.get_event_creation_handler()
homeserver.config.consent.user_consent_version = self.CONSENT_VERSION
- def test_send_dummy_event(self):
+ def test_send_dummy_event(self) -> None:
self._create_extremity_rich_graph()
# Pump the reactor repeatedly so that the background updates have a
@@ -286,7 +295,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
self.assertTrue(len(latest_event_ids) < 10, len(latest_event_ids))
@patch("synapse.handlers.message._DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY", new=0)
- def test_send_dummy_events_when_insufficient_power(self):
+ def test_send_dummy_events_when_insufficient_power(self) -> None:
self._create_extremity_rich_graph()
# Criple power levels
self.helper.send_state(
@@ -317,7 +326,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
self.assertTrue(len(latest_event_ids) < 10, len(latest_event_ids))
@patch("synapse.handlers.message._DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY", new=250)
- def test_expiry_logic(self):
+ def test_expiry_logic(self) -> None:
"""Simple test to ensure that _expire_rooms_to_exclude_from_dummy_event_insertion()
expires old entries correctly.
"""
@@ -357,7 +366,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
0,
)
- def _create_extremity_rich_graph(self):
+ def _create_extremity_rich_graph(self) -> None:
"""Helper method to create bushy graph on demand"""
event_id_start = self.create_and_send_event(self.room_id, self.user)
@@ -372,7 +381,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
)
self.assertEqual(len(latest_event_ids), 50)
- def _enable_consent_checking(self):
+ def _enable_consent_checking(self) -> None:
"""Helper method to enable consent checking"""
self.event_creator._block_events_without_consent_error = "No consent from user"
consent_uri_builder = Mock()
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index 49ad3c1324..7f7f4ef892 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -13,15 +13,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Any, Dict
from unittest.mock import Mock
from parameterized import parameterized
+from twisted.test.proto_helpers import MemoryReactor
+
import synapse.rest.admin
from synapse.http.site import XForwardedForRequest
from synapse.rest.client import login
+from synapse.server import HomeServer
from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY
from synapse.types import UserID
+from synapse.util import Clock
from tests import unittest
from tests.server import make_request
@@ -30,14 +35,10 @@ from tests.unittest import override_config
class ClientIpStoreTestCase(unittest.HomeserverTestCase):
- def make_homeserver(self, reactor, clock):
- hs = self.setup_test_homeserver()
- return hs
-
- def prepare(self, hs, reactor, clock):
- self.store = self.hs.get_datastores().main
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.store = hs.get_datastores().main
- def test_insert_new_client_ip(self):
+ def test_insert_new_client_ip(self) -> None:
self.reactor.advance(12345678)
user_id = "@user:id"
@@ -76,7 +77,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
r,
)
- def test_insert_new_client_ip_none_device_id(self):
+ def test_insert_new_client_ip_none_device_id(self) -> None:
"""
An insert with a device ID of NULL will not create a new entry, but
update an existing entry in the user_ips table.
@@ -148,7 +149,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
)
@parameterized.expand([(False,), (True,)])
- def test_get_last_client_ip_by_device(self, after_persisting: bool):
+ def test_get_last_client_ip_by_device(self, after_persisting: bool) -> None:
"""Test `get_last_client_ip_by_device` for persisted and unpersisted data"""
self.reactor.advance(12345678)
@@ -211,7 +212,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
},
)
- def test_get_last_client_ip_by_device_combined_data(self):
+ def test_get_last_client_ip_by_device_combined_data(self) -> None:
"""Test that `get_last_client_ip_by_device` combines persisted and unpersisted
data together correctly
"""
@@ -310,7 +311,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
)
@parameterized.expand([(False,), (True,)])
- def test_get_user_ip_and_agents(self, after_persisting: bool):
+ def test_get_user_ip_and_agents(self, after_persisting: bool) -> None:
"""Test `get_user_ip_and_agents` for persisted and unpersisted data"""
self.reactor.advance(12345678)
@@ -350,7 +351,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
],
)
- def test_get_user_ip_and_agents_combined_data(self):
+ def test_get_user_ip_and_agents_combined_data(self) -> None:
"""Test that `get_user_ip_and_agents` combines persisted and unpersisted data
together correctly
"""
@@ -427,7 +428,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
)
@override_config({"limit_usage_by_mau": False, "max_mau_value": 50})
- def test_disabled_monthly_active_user(self):
+ def test_disabled_monthly_active_user(self) -> None:
user_id = "@user:server"
self.get_success(
self.store.insert_client_ip(
@@ -438,7 +439,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.assertFalse(active)
@override_config({"limit_usage_by_mau": True, "max_mau_value": 50})
- def test_adding_monthly_active_user_when_full(self):
+ def test_adding_monthly_active_user_when_full(self) -> None:
lots_of_users = 100
user_id = "@user:server"
@@ -454,7 +455,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.assertFalse(active)
@override_config({"limit_usage_by_mau": True, "max_mau_value": 50})
- def test_adding_monthly_active_user_when_space(self):
+ def test_adding_monthly_active_user_when_space(self) -> None:
user_id = "@user:server"
active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
self.assertFalse(active)
@@ -471,7 +472,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.assertTrue(active)
@override_config({"limit_usage_by_mau": True, "max_mau_value": 50})
- def test_updating_monthly_active_user_when_space(self):
+ def test_updating_monthly_active_user_when_space(self) -> None:
user_id = "@user:server"
self.get_success(self.store.register_user(user_id=user_id, password_hash=None))
@@ -489,7 +490,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
self.assertTrue(active)
- def test_devices_last_seen_bg_update(self):
+ def test_devices_last_seen_bg_update(self) -> None:
# First make sure we have completed all updates.
self.wait_for_background_updates()
@@ -574,7 +575,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
r,
)
- def test_old_user_ips_pruned(self):
+ def test_old_user_ips_pruned(self) -> None:
# First make sure we have completed all updates.
self.wait_for_background_updates()
@@ -637,11 +638,11 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.assertEqual(result, [])
# But we should still get the correct values for the device
- result = self.get_success(
+ result2 = self.get_success(
self.store.get_last_client_ip_by_device(user_id, device_id)
)
- r = result[(user_id, device_id)]
+ r = result2[(user_id, device_id)]
self.assertDictContainsSubset(
{
"user_id": user_id,
@@ -661,15 +662,11 @@ class ClientIpAuthTestCase(unittest.HomeserverTestCase):
login.register_servlets,
]
- def make_homeserver(self, reactor, clock):
- hs = self.setup_test_homeserver()
- return hs
-
- def prepare(self, hs, reactor, clock):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = self.hs.get_datastores().main
self.user_id = self.register_user("bob", "abc123", True)
- def test_request_with_xforwarded(self):
+ def test_request_with_xforwarded(self) -> None:
"""
The IP in X-Forwarded-For is entered into the client IPs table.
"""
@@ -679,14 +676,19 @@ class ClientIpAuthTestCase(unittest.HomeserverTestCase):
{"request": XForwardedForRequest},
)
- def test_request_from_getPeer(self):
+ def test_request_from_getPeer(self) -> None:
"""
The IP returned by getPeer is entered into the client IPs table, if
there's no X-Forwarded-For header.
"""
self._runtest({}, "127.0.0.1", {})
- def _runtest(self, headers, expected_ip, make_request_args):
+ def _runtest(
+ self,
+ headers: Dict[bytes, bytes],
+ expected_ip: str,
+ make_request_args: Dict[str, Any],
+ ) -> None:
device_id = "bleb"
access_token = self.login("bob", "abc123", device_id=device_id)
diff --git a/tests/storage/test_database.py b/tests/storage/test_database.py
index a40fc20ef9..543cce6b3e 100644
--- a/tests/storage/test_database.py
+++ b/tests/storage/test_database.py
@@ -31,7 +31,7 @@ from tests import unittest
class TupleComparisonClauseTestCase(unittest.TestCase):
- def test_native_tuple_comparison(self):
+ def test_native_tuple_comparison(self) -> None:
clause, args = make_tuple_comparison_clause([("a", 1), ("b", 2)])
self.assertEqual(clause, "(a,b) > (?,?)")
self.assertEqual(args, [1, 2])
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index 8e7db2c4ec..f03807c8f9 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -12,17 +12,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Collection, List, Tuple
+
+from twisted.test.proto_helpers import MemoryReactor
+
import synapse.api.errors
from synapse.api.constants import EduTypes
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util import Clock
from tests.unittest import HomeserverTestCase
class DeviceStoreTestCase(HomeserverTestCase):
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
- def add_device_change(self, user_id, device_ids, host):
+ def add_device_change(self, user_id: str, device_ids: List[str], host: str) -> None:
"""Add a device list change for the given device to
`device_lists_outbound_pokes` table.
"""
@@ -44,12 +51,13 @@ class DeviceStoreTestCase(HomeserverTestCase):
)
)
- def test_store_new_device(self):
+ def test_store_new_device(self) -> None:
self.get_success(
self.store.store_device("user_id", "device_id", "display_name")
)
res = self.get_success(self.store.get_device("user_id", "device_id"))
+ assert res is not None
self.assertDictContainsSubset(
{
"user_id": "user_id",
@@ -59,7 +67,7 @@ class DeviceStoreTestCase(HomeserverTestCase):
res,
)
- def test_get_devices_by_user(self):
+ def test_get_devices_by_user(self) -> None:
self.get_success(
self.store.store_device("user_id", "device1", "display_name 1")
)
@@ -89,7 +97,7 @@ class DeviceStoreTestCase(HomeserverTestCase):
res["device2"],
)
- def test_count_devices_by_users(self):
+ def test_count_devices_by_users(self) -> None:
self.get_success(
self.store.store_device("user_id", "device1", "display_name 1")
)
@@ -114,7 +122,7 @@ class DeviceStoreTestCase(HomeserverTestCase):
)
self.assertEqual(3, res)
- def test_get_device_updates_by_remote(self):
+ def test_get_device_updates_by_remote(self) -> None:
device_ids = ["device_id1", "device_id2"]
# Add two device updates with sequential `stream_id`s
@@ -128,7 +136,7 @@ class DeviceStoreTestCase(HomeserverTestCase):
# Check original device_ids are contained within these updates
self._check_devices_in_updates(device_ids, device_updates)
- def test_get_device_updates_by_remote_can_limit_properly(self):
+ def test_get_device_updates_by_remote_can_limit_properly(self) -> None:
"""
Tests that `get_device_updates_by_remote` returns an appropriate
stream_id to resume fetching from (without skipping any results).
@@ -280,7 +288,11 @@ class DeviceStoreTestCase(HomeserverTestCase):
)
self.assertEqual(device_updates, [])
- def _check_devices_in_updates(self, expected_device_ids, device_updates):
+ def _check_devices_in_updates(
+ self,
+ expected_device_ids: Collection[str],
+ device_updates: List[Tuple[str, JsonDict]],
+ ) -> None:
"""Check that an specific device ids exist in a list of device update EDUs"""
self.assertEqual(len(device_updates), len(expected_device_ids))
@@ -289,17 +301,19 @@ class DeviceStoreTestCase(HomeserverTestCase):
}
self.assertEqual(received_device_ids, set(expected_device_ids))
- def test_update_device(self):
+ def test_update_device(self) -> None:
self.get_success(
self.store.store_device("user_id", "device_id", "display_name 1")
)
res = self.get_success(self.store.get_device("user_id", "device_id"))
+ assert res is not None
self.assertEqual("display_name 1", res["display_name"])
# do a no-op first
self.get_success(self.store.update_device("user_id", "device_id"))
res = self.get_success(self.store.get_device("user_id", "device_id"))
+ assert res is not None
self.assertEqual("display_name 1", res["display_name"])
# do the update
@@ -311,9 +325,10 @@ class DeviceStoreTestCase(HomeserverTestCase):
# check it worked
res = self.get_success(self.store.get_device("user_id", "device_id"))
+ assert res is not None
self.assertEqual("display_name 2", res["display_name"])
- def test_update_unknown_device(self):
+ def test_update_unknown_device(self) -> None:
exc = self.get_failure(
self.store.update_device(
"user_id", "unknown_device_id", new_display_name="display_name 2"
diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py
index 20bf3ca17b..8bedc6bdf3 100644
--- a/tests/storage/test_directory.py
+++ b/tests/storage/test_directory.py
@@ -12,19 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.server import HomeServer
from synapse.types import RoomAlias, RoomID
+from synapse.util import Clock
from tests.unittest import HomeserverTestCase
class DirectoryStoreTestCase(HomeserverTestCase):
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.room = RoomID.from_string("!abcde:test")
self.alias = RoomAlias.from_string("#my-room:test")
- def test_room_to_alias(self):
+ def test_room_to_alias(self) -> None:
self.get_success(
self.store.create_room_alias_association(
room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
@@ -36,7 +40,7 @@ class DirectoryStoreTestCase(HomeserverTestCase):
(self.get_success(self.store.get_aliases_for_room(self.room.to_string()))),
)
- def test_alias_to_room(self):
+ def test_alias_to_room(self) -> None:
self.get_success(
self.store.create_room_alias_association(
room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
@@ -48,7 +52,7 @@ class DirectoryStoreTestCase(HomeserverTestCase):
(self.get_success(self.store.get_association_from_room_alias(self.alias))),
)
- def test_delete_alias(self):
+ def test_delete_alias(self) -> None:
self.get_success(
self.store.create_room_alias_association(
room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
diff --git a/tests/storage/test_e2e_room_keys.py b/tests/storage/test_e2e_room_keys.py
index fb96ab3a2f..9cb326d90a 100644
--- a/tests/storage/test_e2e_room_keys.py
+++ b/tests/storage/test_e2e_room_keys.py
@@ -12,7 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.server import HomeServer
from synapse.storage.databases.main.e2e_room_keys import RoomKey
+from synapse.util import Clock
from tests import unittest
@@ -26,12 +30,12 @@ room_key: RoomKey = {
class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver("server", federation_http_client=None)
self.store = hs.get_datastores().main
return hs
- def test_room_keys_version_delete(self):
+ def test_room_keys_version_delete(self) -> None:
# test that deleting a room key backup deletes the keys
version1 = self.get_success(
self.store.create_e2e_room_keys_version(
diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py
index 0f04493ad0..5fde3b9c78 100644
--- a/tests/storage/test_end_to_end_keys.py
+++ b/tests/storage/test_end_to_end_keys.py
@@ -12,14 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.server import HomeServer
+from synapse.util import Clock
+
from tests.unittest import HomeserverTestCase
class EndToEndKeyStoreTestCase(HomeserverTestCase):
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
- def test_key_without_device_name(self):
+ def test_key_without_device_name(self) -> None:
now = 1470174257070
json = {"key": "value"}
@@ -35,7 +40,7 @@ class EndToEndKeyStoreTestCase(HomeserverTestCase):
dev = res["user"]["device"]
self.assertDictContainsSubset(json, dev)
- def test_reupload_key(self):
+ def test_reupload_key(self) -> None:
now = 1470174257070
json = {"key": "value"}
@@ -53,7 +58,7 @@ class EndToEndKeyStoreTestCase(HomeserverTestCase):
)
self.assertFalse(changed)
- def test_get_key_with_device_name(self):
+ def test_get_key_with_device_name(self) -> None:
now = 1470174257070
json = {"key": "value"}
@@ -70,7 +75,7 @@ class EndToEndKeyStoreTestCase(HomeserverTestCase):
{"key": "value", "unsigned": {"device_display_name": "display_name"}}, dev
)
- def test_multiple_devices(self):
+ def test_multiple_devices(self) -> None:
now = 1470174257070
self.get_success(self.store.store_device("user1", "device1", None))
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index de9f4af2de..c070278db8 100644
--- a/tests/storage/test_event_chain.py
+++ b/tests/storage/test_event_chain.py
@@ -14,6 +14,7 @@
from typing import Dict, List, Set, Tuple
+from twisted.test.proto_helpers import MemoryReactor
from twisted.trial import unittest
from synapse.api.constants import EventTypes
@@ -22,18 +23,22 @@ from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.rest import admin
from synapse.rest.client import login, room
+from synapse.server import HomeServer
+from synapse.storage.database import LoggingTransaction
from synapse.storage.databases.main.events import _LinkMap
+from synapse.storage.types import Cursor
from synapse.types import create_requester
+from synapse.util import Clock
from tests.unittest import HomeserverTestCase
class EventChainStoreTestCase(HomeserverTestCase):
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self._next_stream_ordering = 1
- def test_simple(self):
+ def test_simple(self) -> None:
"""Test that the example in `docs/auth_chain_difference_algorithm.md`
works.
"""
@@ -232,7 +237,7 @@ class EventChainStoreTestCase(HomeserverTestCase):
),
)
- def test_out_of_order_events(self):
+ def test_out_of_order_events(self) -> None:
"""Test that we handle persisting events that we don't have the full
auth chain for yet (which should only happen for out of band memberships).
"""
@@ -378,7 +383,7 @@ class EventChainStoreTestCase(HomeserverTestCase):
def persist(
self,
events: List[EventBase],
- ):
+ ) -> None:
"""Persist the given events and check that the links generated match
those given.
"""
@@ -389,7 +394,7 @@ class EventChainStoreTestCase(HomeserverTestCase):
e.internal_metadata.stream_ordering = self._next_stream_ordering
self._next_stream_ordering += 1
- def _persist(txn):
+ def _persist(txn: LoggingTransaction) -> None:
# We need to persist the events to the events and state_events
# tables.
persist_events_store._store_event_txn(
@@ -456,7 +461,7 @@ class EventChainStoreTestCase(HomeserverTestCase):
class LinkMapTestCase(unittest.TestCase):
- def test_simple(self):
+ def test_simple(self) -> None:
"""Basic tests for the LinkMap."""
link_map = _LinkMap()
@@ -492,7 +497,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
login.register_servlets,
]
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.user_id = self.register_user("foo", "pass")
self.token = self.login("foo", "pass")
@@ -559,7 +564,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
# Delete the chain cover info.
- def _delete_tables(txn):
+ def _delete_tables(txn: Cursor) -> None:
txn.execute("DELETE FROM event_auth_chains")
txn.execute("DELETE FROM event_auth_chain_links")
@@ -567,7 +572,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
return room_id, [state1, state2]
- def test_background_update_single_room(self):
+ def test_background_update_single_room(self) -> None:
"""Test that the background update to calculate auth chains for historic
rooms works correctly.
"""
@@ -602,7 +607,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
)
)
- def test_background_update_multiple_rooms(self):
+ def test_background_update_multiple_rooms(self) -> None:
"""Test that the background update to calculate auth chains for historic
rooms works correctly.
"""
@@ -640,7 +645,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
)
)
- def test_background_update_single_large_room(self):
+ def test_background_update_single_large_room(self) -> None:
"""Test that the background update to calculate auth chains for historic
rooms works correctly.
"""
@@ -693,7 +698,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
)
)
- def test_background_update_multiple_large_room(self):
+ def test_background_update_multiple_large_room(self) -> None:
"""Test that the background update to calculate auth chains for historic
rooms works correctly.
"""
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 853db930d6..7fd3e01364 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -13,7 +13,7 @@
# limitations under the License.
import datetime
-from typing import Dict, List, Tuple, Union
+from typing import Dict, List, Tuple, Union, cast
import attr
from parameterized import parameterized
@@ -26,11 +26,12 @@ from synapse.api.room_versions import (
EventFormatVersions,
RoomVersion,
)
-from synapse.events import _EventInternalMetadata
+from synapse.events import EventBase, _EventInternalMetadata
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.storage.database import LoggingTransaction
+from synapse.storage.types import Cursor
from synapse.types import JsonDict
from synapse.util import Clock, json_encoder
@@ -54,11 +55,11 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
- def test_get_prev_events_for_room(self):
+ def test_get_prev_events_for_room(self) -> None:
room_id = "@ROOM:local"
# add a bunch of events and hashes to act as forward extremities
- def insert_event(txn, i):
+ def insert_event(txn: Cursor, i: int) -> None:
event_id = "$event_%i:local" % i
txn.execute(
@@ -90,12 +91,12 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
for i in range(0, 10):
self.assertEqual("$event_%i:local" % (19 - i), r[i])
- def test_get_rooms_with_many_extremities(self):
+ def test_get_rooms_with_many_extremities(self) -> None:
room1 = "#room1"
room2 = "#room2"
room3 = "#room3"
- def insert_event(txn, i, room_id):
+ def insert_event(txn: Cursor, i: int, room_id: str) -> None:
event_id = "$event_%i:local" % i
txn.execute(
(
@@ -155,7 +156,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
# | |
# K J
- auth_graph = {
+ auth_graph: Dict[str, List[str]] = {
"a": ["e"],
"b": ["e"],
"c": ["g", "i"],
@@ -185,7 +186,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
# Mark the room as maybe having a cover index.
- def store_room(txn):
+ def store_room(txn: LoggingTransaction) -> None:
self.store.db_pool.simple_insert_txn(
txn,
"rooms",
@@ -203,7 +204,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
# We rudely fiddle with the appropriate tables directly, as that's much
# easier than constructing events properly.
- def insert_event(txn):
+ def insert_event(txn: LoggingTransaction) -> None:
stream_ordering = 0
for event_id in auth_graph:
@@ -228,7 +229,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
self.hs.datastores.persist_events._persist_event_auth_chain_txn(
txn,
[
- FakeEvent(event_id, room_id, auth_graph[event_id])
+ cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id]))
for event_id in auth_graph
],
)
@@ -243,7 +244,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
return room_id
@parameterized.expand([(True,), (False,)])
- def test_auth_chain_ids(self, use_chain_cover_index: bool):
+ def test_auth_chain_ids(self, use_chain_cover_index: bool) -> None:
room_id = self._setup_auth_chain(use_chain_cover_index)
# a and b have the same auth chain.
@@ -308,7 +309,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
self.assertCountEqual(auth_chain_ids, ["i", "j"])
@parameterized.expand([(True,), (False,)])
- def test_auth_difference(self, use_chain_cover_index: bool):
+ def test_auth_difference(self, use_chain_cover_index: bool) -> None:
room_id = self._setup_auth_chain(use_chain_cover_index)
# Now actually test that various combinations give the right result:
@@ -353,7 +354,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
)
self.assertSetEqual(difference, set())
- def test_auth_difference_partial_cover(self):
+ def test_auth_difference_partial_cover(self) -> None:
"""Test that we correctly handle rooms where not all events have a chain
cover calculated. This can happen in some obscure edge cases, including
during the background update that calculates the chain cover for old
@@ -377,7 +378,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
# | |
# K J
- auth_graph = {
+ auth_graph: Dict[str, List[str]] = {
"a": ["e"],
"b": ["e"],
"c": ["g", "i"],
@@ -408,7 +409,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
# We rudely fiddle with the appropriate tables directly, as that's much
# easier than constructing events properly.
- def insert_event(txn):
+ def insert_event(txn: LoggingTransaction) -> None:
# First insert the room and mark it as having a chain cover.
self.store.db_pool.simple_insert_txn(
txn,
@@ -447,7 +448,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
self.hs.datastores.persist_events._persist_event_auth_chain_txn(
txn,
[
- FakeEvent(event_id, room_id, auth_graph[event_id])
+ cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id]))
for event_id in auth_graph
if event_id != "b"
],
@@ -465,7 +466,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
self.hs.datastores.persist_events._persist_event_auth_chain_txn(
txn,
- [FakeEvent("b", room_id, auth_graph["b"])],
+ [cast(EventBase, FakeEvent("b", room_id, auth_graph["b"]))],
)
self.store.db_pool.simple_update_txn(
@@ -527,7 +528,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
@parameterized.expand(
[(room_version,) for room_version in KNOWN_ROOM_VERSIONS.values()]
)
- def test_prune_inbound_federation_queue(self, room_version: RoomVersion):
+ def test_prune_inbound_federation_queue(self, room_version: RoomVersion) -> None:
"""Test that pruning of inbound federation queues work"""
room_id = "some_room_id"
@@ -686,7 +687,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
stream_ordering += 1
- def populate_db(txn: LoggingTransaction):
+ def populate_db(txn: LoggingTransaction) -> None:
# Insert the room to satisfy the foreign key constraint of
# `event_failed_pull_attempts`
self.store.db_pool.simple_insert_txn(
@@ -760,7 +761,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
return _BackfillSetupInfo(room_id=room_id, depth_map=depth_map)
- def test_get_backfill_points_in_room(self):
+ def test_get_backfill_points_in_room(self) -> None:
"""
Test to make sure only backfill points that are older and come before
the `current_depth` are returned.
@@ -787,7 +788,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
def test_get_backfill_points_in_room_excludes_events_we_have_attempted(
self,
- ):
+ ) -> None:
"""
Test to make sure that events we have attempted to backfill (and within
backoff timeout duration) do not show up as an event to backfill again.
@@ -824,7 +825,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
def test_get_backfill_points_in_room_attempted_event_retry_after_backoff_duration(
self,
- ):
+ ) -> None:
"""
Test to make sure after we fake attempt to backfill event "b3" many times,
we can see retry and see the "b3" again after the backoff timeout duration
@@ -941,7 +942,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
"5": 7,
}
- def populate_db(txn: LoggingTransaction):
+ def populate_db(txn: LoggingTransaction) -> None:
# Insert the room to satisfy the foreign key constraint of
# `event_failed_pull_attempts`
self.store.db_pool.simple_insert_txn(
@@ -996,7 +997,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
return _BackfillSetupInfo(room_id=room_id, depth_map=depth_map)
- def test_get_insertion_event_backward_extremities_in_room(self):
+ def test_get_insertion_event_backward_extremities_in_room(self) -> None:
"""
Test to make sure only insertion event backward extremities that are
older and come before the `current_depth` are returned.
@@ -1027,7 +1028,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
def test_get_insertion_event_backward_extremities_in_room_excludes_events_we_have_attempted(
self,
- ):
+ ) -> None:
"""
Test to make sure that insertion events we have attempted to backfill
(and within backoff timeout duration) do not show up as an event to
@@ -1060,7 +1061,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
def test_get_insertion_event_backward_extremities_in_room_attempted_event_retry_after_backoff_duration(
self,
- ):
+ ) -> None:
"""
Test to make sure after we fake attempt to backfill event
"insertion_eventA" many times, we can see retry and see the
@@ -1130,9 +1131,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
self.assertEqual(backfill_event_ids, ["insertion_eventA"])
- def test_get_event_ids_to_not_pull_from_backoff(
- self,
- ):
+ def test_get_event_ids_to_not_pull_from_backoff(self) -> None:
"""
Test to make sure only event IDs we should backoff from are returned.
"""
@@ -1157,7 +1156,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
def test_get_event_ids_to_not_pull_from_backoff_retry_after_backoff_duration(
self,
- ):
+ ) -> None:
"""
Test to make sure no event IDs are returned after the backoff duration has
elapsed.
@@ -1187,19 +1186,19 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
self.assertEqual(event_ids_to_backoff, [])
-@attr.s
+@attr.s(auto_attribs=True)
class FakeEvent:
- event_id = attr.ib()
- room_id = attr.ib()
- auth_events = attr.ib()
+ event_id: str
+ room_id: str
+ auth_events: List[str]
type = "foo"
state_key = "foo"
internal_metadata = _EventInternalMetadata({})
- def auth_event_ids(self):
+ def auth_event_ids(self) -> List[str]:
return self.auth_events
- def is_state(self):
+ def is_state(self) -> bool:
return True
diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py
index 6f1135eef4..a91411168c 100644
--- a/tests/storage/test_event_metrics.py
+++ b/tests/storage/test_event_metrics.py
@@ -20,7 +20,7 @@ from tests.unittest import HomeserverTestCase
class ExtremStatisticsTestCase(HomeserverTestCase):
- def test_exposed_to_prometheus(self):
+ def test_exposed_to_prometheus(self) -> None:
"""
Forward extremity counts are exposed via Prometheus.
"""
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index ee48920f84..5fa8bd2d98 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -156,7 +156,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
last_event_id: str
- def _assert_counts(noitf_count: int, highlight_count: int) -> None:
+ def _assert_counts(notif_count: int, highlight_count: int) -> None:
counts = self.get_success(
self.store.db_pool.runInteraction(
"get-unread-counts",
@@ -168,13 +168,22 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
self.assertEqual(
counts.main_timeline,
NotifCounts(
- notify_count=noitf_count,
+ notify_count=notif_count,
unread_count=0,
highlight_count=highlight_count,
),
)
self.assertEqual(counts.threads, {})
+ aggregate_counts = self.get_success(
+ self.store.db_pool.runInteraction(
+ "get-aggregate-unread-counts",
+ self.store._get_unread_counts_by_room_for_user_txn,
+ user_id,
+ )
+ )
+ self.assertEqual(aggregate_counts[room_id], notif_count)
+
def _create_event(highlight: bool = False) -> str:
result = self.helper.send_event(
room_id,
@@ -283,7 +292,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
last_event_id: str
def _assert_counts(
- noitf_count: int,
+ notif_count: int,
highlight_count: int,
thread_notif_count: int,
thread_highlight_count: int,
@@ -299,7 +308,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
self.assertEqual(
counts.main_timeline,
NotifCounts(
- notify_count=noitf_count,
+ notify_count=notif_count,
unread_count=0,
highlight_count=highlight_count,
),
@@ -318,6 +327,17 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
else:
self.assertEqual(counts.threads, {})
+ aggregate_counts = self.get_success(
+ self.store.db_pool.runInteraction(
+ "get-aggregate-unread-counts",
+ self.store._get_unread_counts_by_room_for_user_txn,
+ user_id,
+ )
+ )
+ self.assertEqual(
+ aggregate_counts[room_id], notif_count + thread_notif_count
+ )
+
def _create_event(
highlight: bool = False, thread_id: Optional[str] = None
) -> str:
@@ -454,7 +474,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
last_event_id: str
def _assert_counts(
- noitf_count: int,
+ notif_count: int,
highlight_count: int,
thread_notif_count: int,
thread_highlight_count: int,
@@ -470,7 +490,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
self.assertEqual(
counts.main_timeline,
NotifCounts(
- notify_count=noitf_count,
+ notify_count=notif_count,
unread_count=0,
highlight_count=highlight_count,
),
@@ -489,6 +509,17 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
else:
self.assertEqual(counts.threads, {})
+ aggregate_counts = self.get_success(
+ self.store.db_pool.runInteraction(
+ "get-aggregate-unread-counts",
+ self.store._get_unread_counts_by_room_for_user_txn,
+ user_id,
+ )
+ )
+ self.assertEqual(
+ aggregate_counts[room_id], notif_count + thread_notif_count
+ )
+
def _create_event(
highlight: bool = False, thread_id: Optional[str] = None
) -> str:
@@ -646,7 +677,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
)
return result["event_id"]
- def _assert_counts(noitf_count: int, thread_notif_count: int) -> None:
+ def _assert_counts(notif_count: int, thread_notif_count: int) -> None:
counts = self.get_success(
self.store.db_pool.runInteraction(
"get-unread-counts",
@@ -658,7 +689,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
self.assertEqual(
counts.main_timeline,
NotifCounts(
- notify_count=noitf_count, unread_count=0, highlight_count=0
+ notify_count=notif_count, unread_count=0, highlight_count=0
),
)
if thread_notif_count:
diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py
index 3ce4f35cb7..05661a537d 100644
--- a/tests/storage/test_events.py
+++ b/tests/storage/test_events.py
@@ -12,12 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import List, Optional
+
+from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions
+from synapse.events import EventBase
from synapse.federation.federation_base import event_from_pdu_json
from synapse.rest import admin
from synapse.rest.client import login, room
+from synapse.server import HomeServer
+from synapse.types import StateMap
+from synapse.util import Clock
from tests.unittest import HomeserverTestCase
@@ -29,7 +36,9 @@ class ExtremPruneTestCase(HomeserverTestCase):
login.register_servlets,
]
- def prepare(self, reactor, clock, homeserver):
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+ ) -> None:
self.state = self.hs.get_state_handler()
self._persistence = self.hs.get_storage_controllers().persistence
self._state_storage_controller = self.hs.get_storage_controllers().state
@@ -67,7 +76,9 @@ class ExtremPruneTestCase(HomeserverTestCase):
# Check that the current extremities is the remote event.
self.assert_extremities([self.remote_event_1.event_id])
- def persist_event(self, event, state=None):
+ def persist_event(
+ self, event: EventBase, state: Optional[StateMap[str]] = None
+ ) -> None:
"""Persist the event, with optional state"""
context = self.get_success(
self.state.compute_event_context(
@@ -78,14 +89,14 @@ class ExtremPruneTestCase(HomeserverTestCase):
)
self.get_success(self._persistence.persist_event(event, context))
- def assert_extremities(self, expected_extremities):
+ def assert_extremities(self, expected_extremities: List[str]) -> None:
"""Assert the current extremities for the room"""
extremities = self.get_success(
self.store.get_prev_events_for_room(self.room_id)
)
self.assertCountEqual(extremities, expected_extremities)
- def test_prune_gap(self):
+ def test_prune_gap(self) -> None:
"""Test that we drop extremities after a gap when we see an event from
the same domain.
"""
@@ -117,7 +128,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
# Check the new extremity is just the new remote event.
self.assert_extremities([remote_event_2.event_id])
- def test_do_not_prune_gap_if_state_different(self):
+ def test_do_not_prune_gap_if_state_different(self) -> None:
"""Test that we don't prune extremities after a gap if the resolved
state is different.
"""
@@ -161,7 +172,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
# Check that we haven't dropped the old extremity.
self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id])
- def test_prune_gap_if_old(self):
+ def test_prune_gap_if_old(self) -> None:
"""Test that we drop extremities after a gap when the previous extremity
is "old"
"""
@@ -197,7 +208,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
# Check the new extremity is just the new remote event.
self.assert_extremities([remote_event_2.event_id])
- def test_do_not_prune_gap_if_other_server(self):
+ def test_do_not_prune_gap_if_other_server(self) -> None:
"""Test that we do not drop extremities after a gap when we see an event
from a different domain.
"""
@@ -229,7 +240,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
# Check the new extremity is just the new remote event.
self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id])
- def test_prune_gap_if_dummy_remote(self):
+ def test_prune_gap_if_dummy_remote(self) -> None:
"""Test that we drop extremities after a gap when the previous extremity
is a local dummy event and only points to remote events.
"""
@@ -271,7 +282,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
# Check the new extremity is just the new remote event.
self.assert_extremities([remote_event_2.event_id])
- def test_prune_gap_if_dummy_local(self):
+ def test_prune_gap_if_dummy_local(self) -> None:
"""Test that we don't drop extremities after a gap when the previous
extremity is a local dummy event and points to local events.
"""
@@ -315,7 +326,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
# Check the new extremity is just the new remote event.
self.assert_extremities([remote_event_2.event_id, local_message_event_id])
- def test_do_not_prune_gap_if_not_dummy(self):
+ def test_do_not_prune_gap_if_not_dummy(self) -> None:
"""Test that we do not drop extremities after a gap when the previous extremity
is not a dummy event.
"""
@@ -359,12 +370,14 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase):
login.register_servlets,
]
- def prepare(self, reactor, clock, homeserver):
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+ ) -> None:
self.state = self.hs.get_state_handler()
self._persistence = self.hs.get_storage_controllers().persistence
self.store = self.hs.get_datastores().main
- def test_remote_user_rooms_cache_invalidated(self):
+ def test_remote_user_rooms_cache_invalidated(self) -> None:
"""Test that if the server leaves a room the `get_rooms_for_user` cache
is invalidated for remote users.
"""
@@ -411,7 +424,7 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase):
rooms = self.get_success(self.store.get_rooms_for_user(remote_user))
self.assertEqual(set(rooms), set())
- def test_room_remote_user_cache_invalidated(self):
+ def test_room_remote_user_cache_invalidated(self) -> None:
"""Test that if the server leaves a room the `get_users_in_room` cache
is invalidated for remote users.
"""
diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py
index 9059095525..aa4b5bd3b1 100644
--- a/tests/storage/test_keys.py
+++ b/tests/storage/test_keys.py
@@ -13,6 +13,7 @@
# limitations under the License.
import signedjson.key
+import signedjson.types
import unpaddedbase64
from twisted.internet.defer import Deferred
@@ -22,7 +23,9 @@ from synapse.storage.keys import FetchKeyResult
import tests.unittest
-def decode_verify_key_base64(key_id: str, key_base64: str):
+def decode_verify_key_base64(
+ key_id: str, key_base64: str
+) -> signedjson.types.VerifyKey:
key_bytes = unpaddedbase64.decode_base64(key_base64)
return signedjson.key.decode_verify_key_bytes(key_id, key_bytes)
@@ -36,7 +39,7 @@ KEY_2 = decode_verify_key_base64(
class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
- def test_get_server_verify_keys(self):
+ def test_get_server_verify_keys(self) -> None:
store = self.hs.get_datastores().main
key_id_1 = "ed25519:key1"
@@ -71,7 +74,7 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
# non-existent result gives None
self.assertIsNone(res[("server1", "ed25519:key3")])
- def test_cache(self):
+ def test_cache(self) -> None:
"""Check that updates correctly invalidate the cache."""
store = self.hs.get_datastores().main
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index c55c4db970..2827738379 100644
--- a/tests/storage/test_monthly_active_users.py
+++ b/tests/storage/test_monthly_active_users.py
@@ -53,7 +53,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.reactor.advance(FORTY_DAYS)
@override_config({"max_mau_value": 3, "mau_limit_reserved_threepids": gen_3pids(3)})
- def test_initialise_reserved_users(self):
+ def test_initialise_reserved_users(self) -> None:
threepids = self.hs.config.server.mau_limits_reserved_threepids
# register three users, of which two have reserved 3pids, and a third
@@ -133,7 +133,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
active_count = self.get_success(self.store.get_monthly_active_count())
self.assertEqual(active_count, 3)
- def test_can_insert_and_count_mau(self):
+ def test_can_insert_and_count_mau(self) -> None:
count = self.get_success(self.store.get_monthly_active_count())
self.assertEqual(count, 0)
@@ -143,7 +143,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
count = self.get_success(self.store.get_monthly_active_count())
self.assertEqual(count, 1)
- def test_appservice_user_not_counted_in_mau(self):
+ def test_appservice_user_not_counted_in_mau(self) -> None:
self.get_success(
self.store.register_user(
user_id="@appservice_user:server", appservice_id="wibble"
@@ -158,7 +158,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
count = self.get_success(self.store.get_monthly_active_count())
self.assertEqual(count, 0)
- def test_user_last_seen_monthly_active(self):
+ def test_user_last_seen_monthly_active(self) -> None:
user_id1 = "@user1:server"
user_id2 = "@user2:server"
user_id3 = "@user3:server"
@@ -177,7 +177,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.assertIsNone(result)
@override_config({"max_mau_value": 5})
- def test_reap_monthly_active_users(self):
+ def test_reap_monthly_active_users(self) -> None:
initial_users = 10
for i in range(initial_users):
self.get_success(
@@ -204,7 +204,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
# Note that below says mau_limit (no s), this is the name of the config
# value, although it gets stored on the config object as mau_limits.
@override_config({"max_mau_value": 5, "mau_limit_reserved_threepids": gen_3pids(5)})
- def test_reap_monthly_active_users_reserved_users(self):
+ def test_reap_monthly_active_users_reserved_users(self) -> None:
"""Tests that reaping correctly handles reaping where reserved users are
present"""
threepids = self.hs.config.server.mau_limits_reserved_threepids
@@ -244,7 +244,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
count = self.get_success(self.store.get_monthly_active_count())
self.assertEqual(count, self.hs.config.server.max_mau_value)
- def test_populate_monthly_users_is_guest(self):
+ def test_populate_monthly_users_is_guest(self) -> None:
# Test that guest users are not added to mau list
user_id = "@user_id:host"
@@ -260,7 +260,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.store.upsert_monthly_active_user.assert_not_called()
- def test_populate_monthly_users_should_update(self):
+ def test_populate_monthly_users_should_update(self) -> None:
self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment]
self.store.is_trial_user = Mock(return_value=make_awaitable(False)) # type: ignore[assignment]
@@ -273,7 +273,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.store.upsert_monthly_active_user.assert_called_once()
- def test_populate_monthly_users_should_not_update(self):
+ def test_populate_monthly_users_should_not_update(self) -> None:
self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment]
self.store.is_trial_user = Mock(return_value=make_awaitable(False)) # type: ignore[assignment]
@@ -286,7 +286,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.store.upsert_monthly_active_user.assert_not_called()
- def test_get_reserved_real_user_account(self):
+ def test_get_reserved_real_user_account(self) -> None:
# Test no reserved users, or reserved threepids
users = self.get_success(self.store.get_registered_reserved_users())
self.assertEqual(len(users), 0)
@@ -326,7 +326,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
users = self.get_success(self.store.get_registered_reserved_users())
self.assertEqual(len(users), len(threepids))
- def test_support_user_not_add_to_mau_limits(self):
+ def test_support_user_not_add_to_mau_limits(self) -> None:
support_user_id = "@support:test"
count = self.get_success(self.store.get_monthly_active_count())
@@ -347,7 +347,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
@override_config(
{"limit_usage_by_mau": False, "mau_stats_only": True, "max_mau_value": 1}
)
- def test_track_monthly_users_without_cap(self):
+ def test_track_monthly_users_without_cap(self) -> None:
count = self.get_success(self.store.get_monthly_active_count())
self.assertEqual(0, count)
@@ -358,14 +358,14 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.assertEqual(2, count)
@override_config({"limit_usage_by_mau": False, "mau_stats_only": False})
- def test_no_users_when_not_tracking(self):
+ def test_no_users_when_not_tracking(self) -> None:
self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment]
self.get_success(self.store.populate_monthly_active_users("@user:sever"))
self.store.upsert_monthly_active_user.assert_not_called()
- def test_get_monthly_active_count_by_service(self):
+ def test_get_monthly_active_count_by_service(self) -> None:
appservice1_user1 = "@appservice1_user1:example.com"
appservice1_user2 = "@appservice1_user2:example.com"
@@ -413,7 +413,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.assertEqual(result[service2], 1)
self.assertEqual(result[native], 1)
- def test_get_monthly_active_users_by_service(self):
+ def test_get_monthly_active_users_by_service(self) -> None:
# (No users, no filtering) -> empty result
result = self.get_success(self.store.get_monthly_active_users_by_service())
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index 9c1182ed16..010cc74c31 100644
--- a/tests/storage/test_purge.py
+++ b/tests/storage/test_purge.py
@@ -12,8 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.api.errors import NotFoundError, SynapseError
from synapse.rest.client import room
+from synapse.server import HomeServer
+from synapse.util import Clock
from tests.unittest import HomeserverTestCase
@@ -23,17 +27,17 @@ class PurgeTests(HomeserverTestCase):
user_id = "@red:server"
servlets = [room.register_servlets]
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver("server", federation_http_client=None)
return hs
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.room_id = self.helper.create_room_as(self.user_id)
self.store = hs.get_datastores().main
self._storage_controllers = self.hs.get_storage_controllers()
- def test_purge_history(self):
+ def test_purge_history(self) -> None:
"""
Purging a room history will delete everything before the topological point.
"""
@@ -63,7 +67,7 @@ class PurgeTests(HomeserverTestCase):
self.get_failure(self.store.get_event(third["event_id"]), NotFoundError)
self.get_success(self.store.get_event(last["event_id"]))
- def test_purge_history_wont_delete_extrems(self):
+ def test_purge_history_wont_delete_extrems(self) -> None:
"""
Purging a room history will delete everything before the topological point.
"""
@@ -77,6 +81,7 @@ class PurgeTests(HomeserverTestCase):
token = self.get_success(
self.store.get_topological_token_for_event(last["event_id"])
)
+ assert token.topological is not None
event = f"t{token.topological + 1}-{token.stream + 1}"
# Purge everything before this topological token
@@ -94,7 +99,7 @@ class PurgeTests(HomeserverTestCase):
self.get_success(self.store.get_event(third["event_id"]))
self.get_success(self.store.get_event(last["event_id"]))
- def test_purge_room(self):
+ def test_purge_room(self) -> None:
"""
Purging a room will delete everything about it.
"""
diff --git a/tests/storage/test_receipts.py b/tests/storage/test_receipts.py
index 81253d0361..d8d84152dc 100644
--- a/tests/storage/test_receipts.py
+++ b/tests/storage/test_receipts.py
@@ -14,8 +14,12 @@
from typing import Collection, Optional
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.api.constants import ReceiptTypes
+from synapse.server import HomeServer
from synapse.types import UserID, create_requester
+from synapse.util import Clock
from tests.test_utils.event_injection import create_event
from tests.unittest import HomeserverTestCase
@@ -25,7 +29,9 @@ OUR_USER_ID = "@our:test"
class ReceiptTestCase(HomeserverTestCase):
- def prepare(self, reactor, clock, homeserver) -> None:
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+ ) -> None:
super().prepare(reactor, clock, homeserver)
self.store = homeserver.get_datastores().main
@@ -135,11 +141,11 @@ class ReceiptTestCase(HomeserverTestCase):
)
self.assertEqual(res, {})
- res = self.get_last_unthreaded_receipt(
+ res2 = self.get_last_unthreaded_receipt(
[ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE]
)
- self.assertEqual(res, None)
+ self.assertIsNone(res2)
def test_get_receipts_for_user(self) -> None:
# Send some events into the first room
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index 6c4e63b77c..df4740f9d9 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -11,27 +11,35 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import List, Optional
+from typing import List, Optional, cast
from canonicaljson import json
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions
-from synapse.types import RoomID, UserID
+from synapse.events import EventBase, _EventInternalMetadata
+from synapse.events.builder import EventBuilder
+from synapse.server import HomeServer
+from synapse.types import JsonDict, RoomID, UserID
+from synapse.util import Clock
from tests import unittest
from tests.utils import create_room
class RedactionTestCase(unittest.HomeserverTestCase):
- def default_config(self):
+ def default_config(self) -> JsonDict:
config = super().default_config()
config["redaction_retention_period"] = "30d"
return config
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
- self._storage = hs.get_storage_controllers()
+ storage = hs.get_storage_controllers()
+ assert storage.persistence is not None
+ self._persistence = storage.persistence
self.event_builder_factory = hs.get_event_builder_factory()
self.event_creation_handler = hs.get_event_creation_handler()
@@ -46,14 +54,13 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.depth = 1
- def inject_room_member(
+ def inject_room_member( # type: ignore[override]
self,
- room,
- user,
- membership,
- replaces_state=None,
- extra_content: Optional[dict] = None,
- ):
+ room: RoomID,
+ user: UserID,
+ membership: str,
+ extra_content: Optional[JsonDict] = None,
+ ) -> EventBase:
content = {"membership": membership}
content.update(extra_content or {})
builder = self.event_builder_factory.for_room_version(
@@ -71,11 +78,11 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(self._storage.persistence.persist_event(event, context))
+ self.get_success(self._persistence.persist_event(event, context))
return event
- def inject_message(self, room, user, body):
+ def inject_message(self, room: RoomID, user: UserID, body: str) -> EventBase:
self.depth += 1
builder = self.event_builder_factory.for_room_version(
@@ -93,11 +100,13 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(self._storage.persistence.persist_event(event, context))
+ self.get_success(self._persistence.persist_event(event, context))
return event
- def inject_redaction(self, room, event_id, user, reason):
+ def inject_redaction(
+ self, room: RoomID, event_id: str, user: UserID, reason: str
+ ) -> EventBase:
builder = self.event_builder_factory.for_room_version(
RoomVersions.V1,
{
@@ -114,11 +123,11 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(self._storage.persistence.persist_event(event, context))
+ self.get_success(self._persistence.persist_event(event, context))
return event
- def test_redact(self):
+ def test_redact(self) -> None:
self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
msg_event = self.inject_message(self.room1, self.u_alice, "t")
@@ -165,7 +174,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
event.unsigned["redacted_because"],
)
- def test_redact_join(self):
+ def test_redact_join(self) -> None:
self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
msg_event = self.inject_room_member(
@@ -213,12 +222,12 @@ class RedactionTestCase(unittest.HomeserverTestCase):
event.unsigned["redacted_because"],
)
- def test_circular_redaction(self):
+ def test_circular_redaction(self) -> None:
redaction_event_id1 = "$redaction1_id:test"
redaction_event_id2 = "$redaction2_id:test"
class EventIdManglingBuilder:
- def __init__(self, base_builder, event_id):
+ def __init__(self, base_builder: EventBuilder, event_id: str):
self._base_builder = base_builder
self._event_id = event_id
@@ -227,67 +236,73 @@ class RedactionTestCase(unittest.HomeserverTestCase):
prev_event_ids: List[str],
auth_event_ids: Optional[List[str]],
depth: Optional[int] = None,
- ):
+ ) -> EventBase:
built_event = await self._base_builder.build(
prev_event_ids=prev_event_ids, auth_event_ids=auth_event_ids
)
- built_event._event_id = self._event_id
+ built_event._event_id = self._event_id # type: ignore[attr-defined]
built_event._dict["event_id"] = self._event_id
assert built_event.event_id == self._event_id
return built_event
@property
- def room_id(self):
+ def room_id(self) -> str:
return self._base_builder.room_id
@property
- def type(self):
+ def type(self) -> str:
return self._base_builder.type
@property
- def internal_metadata(self):
+ def internal_metadata(self) -> _EventInternalMetadata:
return self._base_builder.internal_metadata
event_1, context_1 = self.get_success(
self.event_creation_handler.create_new_client_event(
- EventIdManglingBuilder(
- self.event_builder_factory.for_room_version(
- RoomVersions.V1,
- {
- "type": EventTypes.Redaction,
- "sender": self.u_alice.to_string(),
- "room_id": self.room1.to_string(),
- "content": {"reason": "test"},
- "redacts": redaction_event_id2,
- },
+ cast(
+ EventBuilder,
+ EventIdManglingBuilder(
+ self.event_builder_factory.for_room_version(
+ RoomVersions.V1,
+ {
+ "type": EventTypes.Redaction,
+ "sender": self.u_alice.to_string(),
+ "room_id": self.room1.to_string(),
+ "content": {"reason": "test"},
+ "redacts": redaction_event_id2,
+ },
+ ),
+ redaction_event_id1,
),
- redaction_event_id1,
)
)
)
- self.get_success(self._storage.persistence.persist_event(event_1, context_1))
+ self.get_success(self._persistence.persist_event(event_1, context_1))
event_2, context_2 = self.get_success(
self.event_creation_handler.create_new_client_event(
- EventIdManglingBuilder(
- self.event_builder_factory.for_room_version(
- RoomVersions.V1,
- {
- "type": EventTypes.Redaction,
- "sender": self.u_alice.to_string(),
- "room_id": self.room1.to_string(),
- "content": {"reason": "test"},
- "redacts": redaction_event_id1,
- },
+ cast(
+ EventBuilder,
+ EventIdManglingBuilder(
+ self.event_builder_factory.for_room_version(
+ RoomVersions.V1,
+ {
+ "type": EventTypes.Redaction,
+ "sender": self.u_alice.to_string(),
+ "room_id": self.room1.to_string(),
+ "content": {"reason": "test"},
+ "redacts": redaction_event_id1,
+ },
+ ),
+ redaction_event_id2,
),
- redaction_event_id2,
)
)
)
- self.get_success(self._storage.persistence.persist_event(event_2, context_2))
+ self.get_success(self._persistence.persist_event(event_2, context_2))
# fetch one of the redactions
fetched = self.get_success(self.store.get_event(redaction_event_id1))
@@ -298,7 +313,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
fetched.unsigned["redacted_because"].event_id, redaction_event_id2
)
- def test_redact_censor(self):
+ def test_redact_censor(self) -> None:
"""Test that a redacted event gets censored in the DB after a month"""
self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
@@ -364,7 +379,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.assert_dict({"content": {}}, json.loads(event_json))
- def test_redact_redaction(self):
+ def test_redact_redaction(self) -> None:
"""Tests that we can redact a redaction and can fetch it again."""
self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
@@ -391,7 +406,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.store.get_event(first_redact_event.event_id, allow_none=True)
)
- def test_store_redacted_redaction(self):
+ def test_store_redacted_redaction(self) -> None:
"""Tests that we can store a redacted redaction."""
self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
@@ -410,9 +425,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(
- self._storage.persistence.persist_event(redaction_event, context)
- )
+ self.get_success(self._persistence.persist_event(redaction_event, context))
# Now lets jump to the future where we have censored the redaction event
# in the DB.
diff --git a/tests/storage/test_rollback_worker.py b/tests/storage/test_rollback_worker.py
index 0baa54312e..966aafea6f 100644
--- a/tests/storage/test_rollback_worker.py
+++ b/tests/storage/test_rollback_worker.py
@@ -14,10 +14,15 @@
from typing import List
from unittest import mock
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.app.generic_worker import GenericWorkerServer
+from synapse.server import HomeServer
from synapse.storage.database import LoggingDatabaseConnection
from synapse.storage.prepare_database import PrepareDatabaseException, prepare_database
from synapse.storage.schema import SCHEMA_VERSION
+from synapse.types import JsonDict
+from synapse.util import Clock
from tests.unittest import HomeserverTestCase
@@ -39,13 +44,13 @@ def fake_listdir(filepath: str) -> List[str]:
class WorkerSchemaTests(HomeserverTestCase):
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver(
federation_http_client=None, homeserver_to_use=GenericWorkerServer
)
return hs
- def default_config(self):
+ def default_config(self) -> JsonDict:
conf = super().default_config()
# Mark this as a worker app.
@@ -53,7 +58,7 @@ class WorkerSchemaTests(HomeserverTestCase):
return conf
- def test_rolling_back(self):
+ def test_rolling_back(self) -> None:
"""Test that workers can start if the DB is a newer schema version"""
db_pool = self.hs.get_datastores().main.db_pool
@@ -70,7 +75,7 @@ class WorkerSchemaTests(HomeserverTestCase):
prepare_database(db_conn, db_pool.engine, self.hs.config)
- def test_not_upgraded_old_schema_version(self):
+ def test_not_upgraded_old_schema_version(self) -> None:
"""Test that workers don't start if the DB has an older schema version"""
db_pool = self.hs.get_datastores().main.db_pool
db_conn = LoggingDatabaseConnection(
@@ -87,7 +92,7 @@ class WorkerSchemaTests(HomeserverTestCase):
with self.assertRaises(PrepareDatabaseException):
prepare_database(db_conn, db_pool.engine, self.hs.config)
- def test_not_upgraded_current_schema_version_with_outstanding_deltas(self):
+ def test_not_upgraded_current_schema_version_with_outstanding_deltas(self) -> None:
"""
Test that workers don't start if the DB is on the current schema version,
but there are still outstanding delta migrations to run.
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index 3405efb6a8..71ec74eadc 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -12,14 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.api.room_versions import RoomVersions
+from synapse.server import HomeServer
from synapse.types import RoomAlias, RoomID, UserID
+from synapse.util import Clock
from tests.unittest import HomeserverTestCase
class RoomStoreTestCase(HomeserverTestCase):
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# We can't test RoomStore on its own without the DirectoryStore, for
# management of the 'room_aliases' table
self.store = hs.get_datastores().main
@@ -37,30 +41,34 @@ class RoomStoreTestCase(HomeserverTestCase):
)
)
- def test_get_room(self):
+ def test_get_room(self) -> None:
+ res = self.get_success(self.store.get_room(self.room.to_string()))
+ assert res is not None
self.assertDictContainsSubset(
{
"room_id": self.room.to_string(),
"creator": self.u_creator.to_string(),
"is_public": True,
},
- (self.get_success(self.store.get_room(self.room.to_string()))),
+ res,
)
- def test_get_room_unknown_room(self):
+ def test_get_room_unknown_room(self) -> None:
self.assertIsNone(self.get_success(self.store.get_room("!uknown:test")))
- def test_get_room_with_stats(self):
+ def test_get_room_with_stats(self) -> None:
+ res = self.get_success(self.store.get_room_with_stats(self.room.to_string()))
+ assert res is not None
self.assertDictContainsSubset(
{
"room_id": self.room.to_string(),
"creator": self.u_creator.to_string(),
"public": True,
},
- (self.get_success(self.store.get_room_with_stats(self.room.to_string()))),
+ res,
)
- def test_get_room_with_stats_unknown_room(self):
+ def test_get_room_with_stats_unknown_room(self) -> None:
self.assertIsNone(
- (self.get_success(self.store.get_room_with_stats("!uknown:test"))),
+ self.get_success(self.store.get_room_with_stats("!uknown:test"))
)
diff --git a/tests/storage/test_room_search.py b/tests/storage/test_room_search.py
index ef850daa73..14d872514d 100644
--- a/tests/storage/test_room_search.py
+++ b/tests/storage/test_room_search.py
@@ -39,7 +39,7 @@ class EventSearchInsertionTest(HomeserverTestCase):
room.register_servlets,
]
- def test_null_byte(self):
+ def test_null_byte(self) -> None:
"""
Postgres/SQLite don't like null bytes going into the search tables. Internally
we replace those with a space.
@@ -86,7 +86,7 @@ class EventSearchInsertionTest(HomeserverTestCase):
if isinstance(store.database_engine, PostgresEngine):
self.assertIn("alice", result.get("highlights"))
- def test_non_string(self):
+ def test_non_string(self) -> None:
"""Test that non-string `value`s are not inserted into `event_search`.
This is particularly important when using sqlite, since a sqlite column can hold
@@ -157,7 +157,7 @@ class EventSearchInsertionTest(HomeserverTestCase):
self.assertEqual(f.value.code, 404)
@skip_unless(not USE_POSTGRES_FOR_TESTS, "requires sqlite")
- def test_sqlite_non_string_deletion_background_update(self):
+ def test_sqlite_non_string_deletion_background_update(self) -> None:
"""Test the background update to delete bad rows from `event_search`."""
store = self.hs.get_datastores().main
@@ -350,7 +350,7 @@ class MessageSearchTest(HomeserverTestCase):
"results array length should match count",
)
- def test_postgres_web_search_for_phrase(self):
+ def test_postgres_web_search_for_phrase(self) -> None:
"""
Test searching for phrases using typical web search syntax, as per postgres' websearch_to_tsquery.
This test is skipped unless the postgres instance supports websearch_to_tsquery.
@@ -364,7 +364,7 @@ class MessageSearchTest(HomeserverTestCase):
self._check_test_cases(store, self.COMMON_CASES + self.POSTGRES_CASES)
- def test_sqlite_search(self):
+ def test_sqlite_search(self) -> None:
"""
Test sqlite searching for phrases.
"""
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 5564161750..a433e70870 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -16,10 +16,15 @@ import logging
from frozendict import frozendict
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions
-from synapse.storage.state import StateFilter
-from synapse.types import RoomID, UserID
+from synapse.events import EventBase
+from synapse.server import HomeServer
+from synapse.types import JsonDict, RoomID, StateMap, UserID
+from synapse.types.state import StateFilter
+from synapse.util import Clock
from tests.unittest import HomeserverTestCase, TestCase
@@ -27,7 +32,7 @@ logger = logging.getLogger(__name__)
class StateStoreTestCase(HomeserverTestCase):
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.storage = hs.get_storage_controllers()
self.state_datastore = self.storage.state.stores.state
@@ -48,7 +53,9 @@ class StateStoreTestCase(HomeserverTestCase):
)
)
- def inject_state_event(self, room, sender, typ, state_key, content):
+ def inject_state_event(
+ self, room: RoomID, sender: UserID, typ: str, state_key: str, content: JsonDict
+ ) -> EventBase:
builder = self.event_builder_factory.for_room_version(
RoomVersions.V1,
{
@@ -64,24 +71,29 @@ class StateStoreTestCase(HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)
+ assert self.storage.persistence is not None
self.get_success(self.storage.persistence.persist_event(event, context))
return event
- def assertStateMapEqual(self, s1, s2):
+ def assertStateMapEqual(
+ self, s1: StateMap[EventBase], s2: StateMap[EventBase]
+ ) -> None:
for t in s1:
# just compare event IDs for simplicity
self.assertEqual(s1[t].event_id, s2[t].event_id)
self.assertEqual(len(s1), len(s2))
- def test_get_state_groups_ids(self):
+ def test_get_state_groups_ids(self) -> None:
e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
e2 = self.inject_state_event(
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
)
state_group_map = self.get_success(
- self.storage.state.get_state_groups_ids(self.room, [e2.event_id])
+ self.storage.state.get_state_groups_ids(
+ self.room.to_string(), [e2.event_id]
+ )
)
self.assertEqual(len(state_group_map), 1)
state_map = list(state_group_map.values())[0]
@@ -90,21 +102,21 @@ class StateStoreTestCase(HomeserverTestCase):
{(EventTypes.Create, ""): e1.event_id, (EventTypes.Name, ""): e2.event_id},
)
- def test_get_state_groups(self):
+ def test_get_state_groups(self) -> None:
e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
e2 = self.inject_state_event(
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
)
state_group_map = self.get_success(
- self.storage.state.get_state_groups(self.room, [e2.event_id])
+ self.storage.state.get_state_groups(self.room.to_string(), [e2.event_id])
)
self.assertEqual(len(state_group_map), 1)
state_list = list(state_group_map.values())[0]
self.assertEqual({ev.event_id for ev in state_list}, {e1.event_id, e2.event_id})
- def test_get_state_for_event(self):
+ def test_get_state_for_event(self) -> None:
# this defaults to a linear DAG as each new injection defaults to whatever
# forward extremities are currently in the DB for this room.
e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
@@ -487,14 +499,16 @@ class StateStoreTestCase(HomeserverTestCase):
class StateFilterDifferenceTestCase(TestCase):
def assert_difference(
self, minuend: StateFilter, subtrahend: StateFilter, expected: StateFilter
- ):
+ ) -> None:
self.assertEqual(
minuend.approx_difference(subtrahend),
expected,
f"StateFilter difference not correct:\n\n\t{minuend!r}\nminus\n\t{subtrahend!r}\nwas\n\t{minuend.approx_difference(subtrahend)}\nexpected\n\t{expected}",
)
- def test_state_filter_difference_no_include_other_minus_no_include_other(self):
+ def test_state_filter_difference_no_include_other_minus_no_include_other(
+ self,
+ ) -> None:
"""
Tests the StateFilter.approx_difference method
where, in a.approx_difference(b), both a and b do not have the
@@ -610,7 +624,7 @@ class StateFilterDifferenceTestCase(TestCase):
),
)
- def test_state_filter_difference_include_other_minus_no_include_other(self):
+ def test_state_filter_difference_include_other_minus_no_include_other(self) -> None:
"""
Tests the StateFilter.approx_difference method
where, in a.approx_difference(b), only a has the include_others flag set.
@@ -739,7 +753,7 @@ class StateFilterDifferenceTestCase(TestCase):
),
)
- def test_state_filter_difference_include_other_minus_include_other(self):
+ def test_state_filter_difference_include_other_minus_include_other(self) -> None:
"""
Tests the StateFilter.approx_difference method
where, in a.approx_difference(b), both a and b have the include_others
@@ -864,7 +878,7 @@ class StateFilterDifferenceTestCase(TestCase):
),
)
- def test_state_filter_difference_no_include_other_minus_include_other(self):
+ def test_state_filter_difference_no_include_other_minus_include_other(self) -> None:
"""
Tests the StateFilter.approx_difference method
where, in a.approx_difference(b), only b has the include_others flag set.
@@ -979,7 +993,7 @@ class StateFilterDifferenceTestCase(TestCase):
),
)
- def test_state_filter_difference_simple_cases(self):
+ def test_state_filter_difference_simple_cases(self) -> None:
"""
Tests some very simple cases of the StateFilter approx_difference,
that are not explicitly tested by the more in-depth tests.
@@ -995,7 +1009,7 @@ class StateFilterDifferenceTestCase(TestCase):
class StateFilterTestCase(TestCase):
- def test_return_expanded(self):
+ def test_return_expanded(self) -> None:
"""
Tests the behaviour of the return_expanded() function that expands
StateFilters to include more state types (for the sake of cache hit rate).
diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py
index 34fa810cf6..bc090ebce0 100644
--- a/tests/storage/test_stream.py
+++ b/tests/storage/test_stream.py
@@ -14,11 +14,15 @@
from typing import List
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.api.constants import EventTypes, RelationTypes
from synapse.api.filtering import Filter
from synapse.rest import admin
from synapse.rest.client import login, room
+from synapse.server import HomeServer
from synapse.types import JsonDict
+from synapse.util import Clock
from tests.unittest import HomeserverTestCase
@@ -37,12 +41,14 @@ class PaginationTestCase(HomeserverTestCase):
login.register_servlets,
]
- def default_config(self):
+ def default_config(self) -> JsonDict:
config = super().default_config()
config["experimental_features"] = {"msc3874_enabled": True}
return config
- def prepare(self, reactor, clock, homeserver):
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+ ) -> None:
self.user_id = self.register_user("test", "test")
self.tok = self.login("test", "test")
self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
@@ -130,7 +136,7 @@ class PaginationTestCase(HomeserverTestCase):
return [ev.event_id for ev in events]
- def test_filter_relation_senders(self):
+ def test_filter_relation_senders(self) -> None:
# Messages which second user reacted to.
filter = {"related_by_senders": [self.second_user_id]}
chunk = self._filter_messages(filter)
@@ -146,7 +152,7 @@ class PaginationTestCase(HomeserverTestCase):
chunk = self._filter_messages(filter)
self.assertCountEqual(chunk, [self.event_id_1, self.event_id_2])
- def test_filter_relation_type(self):
+ def test_filter_relation_type(self) -> None:
# Messages which have annotations.
filter = {"related_by_rel_types": [RelationTypes.ANNOTATION]}
chunk = self._filter_messages(filter)
@@ -167,7 +173,7 @@ class PaginationTestCase(HomeserverTestCase):
chunk = self._filter_messages(filter)
self.assertCountEqual(chunk, [self.event_id_1, self.event_id_2])
- def test_filter_relation_senders_and_type(self):
+ def test_filter_relation_senders_and_type(self) -> None:
# Messages which second user reacted to.
filter = {
"related_by_senders": [self.second_user_id],
@@ -176,7 +182,7 @@ class PaginationTestCase(HomeserverTestCase):
chunk = self._filter_messages(filter)
self.assertEqual(chunk, [self.event_id_1])
- def test_duplicate_relation(self):
+ def test_duplicate_relation(self) -> None:
"""An event should only be returned once if there are multiple relations to it."""
self.helper.send_event(
room_id=self.room_id,
diff --git a/tests/storage/test_transactions.py b/tests/storage/test_transactions.py
index e05daa285e..db9ee9955e 100644
--- a/tests/storage/test_transactions.py
+++ b/tests/storage/test_transactions.py
@@ -12,17 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.server import HomeServer
from synapse.storage.databases.main.transactions import DestinationRetryTimings
+from synapse.util import Clock
from synapse.util.retryutils import MAX_RETRY_INTERVAL
from tests.unittest import HomeserverTestCase
class TransactionStoreTestCase(HomeserverTestCase):
- def prepare(self, reactor, clock, homeserver):
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+ ) -> None:
self.store = homeserver.get_datastores().main
- def test_get_set_transactions(self):
+ def test_get_set_transactions(self) -> None:
"""Tests that we can successfully get a non-existent entry for
destination retries, as well as testing tht we can set and get
correctly.
@@ -44,18 +50,18 @@ class TransactionStoreTestCase(HomeserverTestCase):
r,
)
- def test_initial_set_transactions(self):
+ def test_initial_set_transactions(self) -> None:
"""Tests that we can successfully set the destination retries (there
was a bug around invalidating the cache that broke this)
"""
d = self.store.set_destination_retry_timings("example.com", 1000, 50, 100)
self.get_success(d)
- def test_large_destination_retry(self):
+ def test_large_destination_retry(self) -> None:
d = self.store.set_destination_retry_timings(
"example.com", MAX_RETRY_INTERVAL, MAX_RETRY_INTERVAL, MAX_RETRY_INTERVAL
)
self.get_success(d)
- d = self.store.get_destination_retry_timings("example.com")
- self.get_success(d)
+ d2 = self.store.get_destination_retry_timings("example.com")
+ self.get_success(d2)
diff --git a/tests/storage/test_txn_limit.py b/tests/storage/test_txn_limit.py
index ace82cbf42..15ea4770bd 100644
--- a/tests/storage/test_txn_limit.py
+++ b/tests/storage/test_txn_limit.py
@@ -12,21 +12,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.server import HomeServer
+from synapse.storage.types import Cursor
+from synapse.util import Clock
+
from tests import unittest
class SQLTransactionLimitTestCase(unittest.HomeserverTestCase):
"""Test SQL transaction limit doesn't break transactions."""
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
return self.setup_test_homeserver(db_txn_limit=1000)
- def test_config(self):
+ def test_config(self) -> None:
db_config = self.hs.config.database.get_single_database()
self.assertEqual(db_config.config["txn_limit"], 1000)
- def test_select(self):
- def do_select(txn):
+ def test_select(self) -> None:
+ def do_select(txn: Cursor) -> None:
txn.execute("SELECT 1")
db_pool = self.hs.get_datastores().databases[0]
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index 5b60cf5285..3ba896ecf3 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import re
from typing import Any, Dict, Set, Tuple
from unittest import mock
from unittest.mock import Mock, patch
@@ -30,6 +31,12 @@ from synapse.util import Clock
from tests.test_utils.event_injection import inject_member_event
from tests.unittest import HomeserverTestCase, override_config
+try:
+ import icu
+except ImportError:
+ icu = None # type: ignore
+
+
ALICE = "@alice:a"
BOB = "@bob:b"
BOBBY = "@bobby:a"
@@ -449,6 +456,12 @@ class UserDirectoryStoreTestCase(HomeserverTestCase):
)
@override_config({"user_directory": {"search_all_users": True}})
+ def test_search_user_limit_correct(self) -> None:
+ r = self.get_success(self.store.search_user_dir(ALICE, "bob", 1))
+ self.assertTrue(r["limited"])
+ self.assertEqual(1, len(r["results"]))
+
+ @override_config({"user_directory": {"search_all_users": True}})
def test_search_user_dir_stop_words(self) -> None:
"""Tests that a user can look up another user by searching for the start if its
display name even if that name happens to be a common English word that would
@@ -461,3 +474,39 @@ class UserDirectoryStoreTestCase(HomeserverTestCase):
r["results"][0],
{"user_id": BELA, "display_name": "Bela", "avatar_url": None},
)
+
+
+class UserDirectoryICUTestCase(HomeserverTestCase):
+ if not icu:
+ skip = "Requires PyICU"
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.store = hs.get_datastores().main
+ self.user_dir_helper = GetUserDirectoryTables(self.store)
+
+ def test_icu_word_boundary(self) -> None:
+ """Tests that we correctly detect word boundaries when ICU (International
+ Components for Unicode) support is available.
+ """
+
+ display_name = "Gáo"
+
+ # This word is not broken down correctly by Python's regular expressions,
+ # likely because á is actually a lowercase a followed by a U+0301 combining
+ # acute accent. This is specifically something that ICU support fixes.
+ matches = re.findall(r"([\w\-]+)", display_name, re.UNICODE)
+ self.assertEqual(len(matches), 2)
+
+ self.get_success(
+ self.store.update_profile_in_user_dir(ALICE, display_name, None)
+ )
+ self.get_success(self.store.add_users_in_public_rooms("!room:id", (ALICE,)))
+
+ # Check that searching for this user yields the correct result.
+ r = self.get_success(self.store.search_user_dir(BOB, display_name, 10))
+ self.assertFalse(r["limited"])
+ self.assertEqual(len(r["results"]), 1)
+ self.assertDictEqual(
+ r["results"][0],
+ {"user_id": ALICE, "display_name": display_name, "avatar_url": None},
+ )
diff --git a/tests/storage/util/test_partial_state_events_tracker.py b/tests/storage/util/test_partial_state_events_tracker.py
index cae14151c0..0e3fc2a77f 100644
--- a/tests/storage/util/test_partial_state_events_tracker.py
+++ b/tests/storage/util/test_partial_state_events_tracker.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Dict
+from typing import Collection, Dict
from unittest import mock
from twisted.internet.defer import CancelledError, ensureDeferred
@@ -31,7 +31,7 @@ class PartialStateEventsTrackerTestCase(TestCase):
# the results to be returned by the mocked get_partial_state_events
self._events_dict: Dict[str, bool] = {}
- async def get_partial_state_events(events):
+ async def get_partial_state_events(events: Collection[str]) -> Dict[str, bool]:
return {e: self._events_dict[e] for e in events}
self.mock_store = mock.Mock(spec_set=["get_partial_state_events"])
@@ -39,7 +39,7 @@ class PartialStateEventsTrackerTestCase(TestCase):
self.tracker = PartialStateEventsTracker(self.mock_store)
- def test_does_not_block_for_full_state_events(self):
+ def test_does_not_block_for_full_state_events(self) -> None:
self._events_dict = {"event1": False, "event2": False}
self.successResultOf(
@@ -50,7 +50,7 @@ class PartialStateEventsTrackerTestCase(TestCase):
["event1", "event2"]
)
- def test_blocks_for_partial_state_events(self):
+ def test_blocks_for_partial_state_events(self) -> None:
self._events_dict = {"event1": True, "event2": False}
d = ensureDeferred(self.tracker.await_full_state(["event1", "event2"]))
@@ -62,12 +62,12 @@ class PartialStateEventsTrackerTestCase(TestCase):
self.tracker.notify_un_partial_stated("event1")
self.successResultOf(d)
- def test_un_partial_state_race(self):
+ def test_un_partial_state_race(self) -> None:
# if the event is un-partial-stated between the initial check and the
# registration of the listener, it should not block.
self._events_dict = {"event1": True, "event2": False}
- async def get_partial_state_events(events):
+ async def get_partial_state_events(events: Collection[str]) -> Dict[str, bool]:
res = {e: self._events_dict[e] for e in events}
# change the result for next time
self._events_dict = {"event1": False, "event2": False}
@@ -79,19 +79,19 @@ class PartialStateEventsTrackerTestCase(TestCase):
ensureDeferred(self.tracker.await_full_state(["event1", "event2"]))
)
- def test_un_partial_state_during_get_partial_state_events(self):
+ def test_un_partial_state_during_get_partial_state_events(self) -> None:
# we should correctly handle a call to notify_un_partial_stated during the
# second call to get_partial_state_events.
self._events_dict = {"event1": True, "event2": False}
- async def get_partial_state_events1(events):
+ async def get_partial_state_events1(events: Collection[str]) -> Dict[str, bool]:
self.mock_store.get_partial_state_events.side_effect = (
get_partial_state_events2
)
return {e: self._events_dict[e] for e in events}
- async def get_partial_state_events2(events):
+ async def get_partial_state_events2(events: Collection[str]) -> Dict[str, bool]:
self.tracker.notify_un_partial_stated("event1")
self._events_dict["event1"] = False
return {e: self._events_dict[e] for e in events}
@@ -102,7 +102,7 @@ class PartialStateEventsTrackerTestCase(TestCase):
ensureDeferred(self.tracker.await_full_state(["event1", "event2"]))
)
- def test_cancellation(self):
+ def test_cancellation(self) -> None:
self._events_dict = {"event1": True, "event2": False}
d1 = ensureDeferred(self.tracker.await_full_state(["event1", "event2"]))
@@ -127,12 +127,12 @@ class PartialCurrentStateTrackerTestCase(TestCase):
self.tracker = PartialCurrentStateTracker(self.mock_store)
- def test_does_not_block_for_full_state_rooms(self):
+ def test_does_not_block_for_full_state_rooms(self) -> None:
self.mock_store.is_partial_state_room.return_value = make_awaitable(False)
self.successResultOf(ensureDeferred(self.tracker.await_full_state("room_id")))
- def test_blocks_for_partial_room_state(self):
+ def test_blocks_for_partial_room_state(self) -> None:
self.mock_store.is_partial_state_room.return_value = make_awaitable(True)
d = ensureDeferred(self.tracker.await_full_state("room_id"))
@@ -144,10 +144,10 @@ class PartialCurrentStateTrackerTestCase(TestCase):
self.tracker.notify_un_partial_stated("room_id")
self.successResultOf(d)
- def test_un_partial_state_race(self):
+ def test_un_partial_state_race(self) -> None:
# We should correctly handle race between awaiting the state and us
# un-partialling the state
- async def is_partial_state_room(events):
+ async def is_partial_state_room(room_id: str) -> bool:
self.tracker.notify_un_partial_stated("room_id")
return True
@@ -155,7 +155,7 @@ class PartialCurrentStateTrackerTestCase(TestCase):
self.successResultOf(ensureDeferred(self.tracker.await_full_state("room_id")))
- def test_cancellation(self):
+ def test_cancellation(self) -> None:
self.mock_store.is_partial_state_room.return_value = make_awaitable(True)
d1 = ensureDeferred(self.tracker.await_full_state("room_id"))
diff --git a/tests/test_server.py b/tests/test_server.py
index 2d9a0257d4..d67d7722a4 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -174,7 +174,7 @@ class JsonResourceTests(unittest.TestCase):
self.reactor, FakeSite(res, self.reactor), b"GET", b"/_matrix/foobar"
)
- self.assertEqual(channel.code, 400)
+ self.assertEqual(channel.code, 404)
self.assertEqual(channel.json_body["error"], "Unrecognized request")
self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
diff --git a/tests/util/test_async_helpers.py b/tests/util/test_async_helpers.py
index 9d5010bf92..91cac9822a 100644
--- a/tests/util/test_async_helpers.py
+++ b/tests/util/test_async_helpers.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import traceback
+from typing import Generator, List, NoReturn, Optional
from parameterized import parameterized_class
@@ -41,8 +42,8 @@ from tests.unittest import TestCase
class ObservableDeferredTest(TestCase):
- def test_succeed(self):
- origin_d = Deferred()
+ def test_succeed(self) -> None:
+ origin_d: "Deferred[int]" = Deferred()
observable = ObservableDeferred(origin_d)
observer1 = observable.observe()
@@ -52,16 +53,18 @@ class ObservableDeferredTest(TestCase):
self.assertFalse(observer2.called)
# check the first observer is called first
- def check_called_first(res):
+ def check_called_first(res: int) -> int:
self.assertFalse(observer2.called)
return res
observer1.addBoth(check_called_first)
# store the results
- results = [None, None]
+ results: List[Optional[ObservableDeferred[int]]] = [None, None]
- def check_val(res, idx):
+ def check_val(
+ res: ObservableDeferred[int], idx: int
+ ) -> ObservableDeferred[int]:
results[idx] = res
return res
@@ -72,8 +75,8 @@ class ObservableDeferredTest(TestCase):
self.assertEqual(results[0], 123, "observer 1 callback result")
self.assertEqual(results[1], 123, "observer 2 callback result")
- def test_failure(self):
- origin_d = Deferred()
+ def test_failure(self) -> None:
+ origin_d: Deferred = Deferred()
observable = ObservableDeferred(origin_d, consumeErrors=True)
observer1 = observable.observe()
@@ -83,16 +86,16 @@ class ObservableDeferredTest(TestCase):
self.assertFalse(observer2.called)
# check the first observer is called first
- def check_called_first(res):
+ def check_called_first(res: int) -> int:
self.assertFalse(observer2.called)
return res
observer1.addBoth(check_called_first)
# store the results
- results = [None, None]
+ results: List[Optional[ObservableDeferred[str]]] = [None, None]
- def check_val(res, idx):
+ def check_val(res: ObservableDeferred[str], idx: int) -> None:
results[idx] = res
return None
@@ -103,10 +106,12 @@ class ObservableDeferredTest(TestCase):
raise Exception("gah!")
except Exception as e:
origin_d.errback(e)
+ assert results[0] is not None
self.assertEqual(str(results[0].value), "gah!", "observer 1 errback result")
+ assert results[1] is not None
self.assertEqual(str(results[1].value), "gah!", "observer 2 errback result")
- def test_cancellation(self):
+ def test_cancellation(self) -> None:
"""Test that cancelling an observer does not affect other observers."""
origin_d: "Deferred[int]" = Deferred()
observable = ObservableDeferred(origin_d, consumeErrors=True)
@@ -136,37 +141,38 @@ class ObservableDeferredTest(TestCase):
class TimeoutDeferredTest(TestCase):
- def setUp(self):
+ def setUp(self) -> None:
self.clock = Clock()
- def test_times_out(self):
+ def test_times_out(self) -> None:
"""Basic test case that checks that the original deferred is cancelled and that
the timing-out deferred is errbacked
"""
- cancelled = [False]
+ cancelled = False
- def canceller(_d):
- cancelled[0] = True
+ def canceller(_d: Deferred) -> None:
+ nonlocal cancelled
+ cancelled = True
- non_completing_d = Deferred(canceller)
+ non_completing_d: Deferred = Deferred(canceller)
timing_out_d = timeout_deferred(non_completing_d, 1.0, self.clock)
self.assertNoResult(timing_out_d)
- self.assertFalse(cancelled[0], "deferred was cancelled prematurely")
+ self.assertFalse(cancelled, "deferred was cancelled prematurely")
self.clock.pump((1.0,))
- self.assertTrue(cancelled[0], "deferred was not cancelled by timeout")
+ self.assertTrue(cancelled, "deferred was not cancelled by timeout")
self.failureResultOf(timing_out_d, defer.TimeoutError)
- def test_times_out_when_canceller_throws(self):
+ def test_times_out_when_canceller_throws(self) -> None:
"""Test that we have successfully worked around
https://twistedmatrix.com/trac/ticket/9534"""
- def canceller(_d):
+ def canceller(_d: Deferred) -> None:
raise Exception("can't cancel this deferred")
- non_completing_d = Deferred(canceller)
+ non_completing_d: Deferred = Deferred(canceller)
timing_out_d = timeout_deferred(non_completing_d, 1.0, self.clock)
self.assertNoResult(timing_out_d)
@@ -175,22 +181,24 @@ class TimeoutDeferredTest(TestCase):
self.failureResultOf(timing_out_d, defer.TimeoutError)
- def test_logcontext_is_preserved_on_cancellation(self):
- blocking_was_cancelled = [False]
+ def test_logcontext_is_preserved_on_cancellation(self) -> None:
+ blocking_was_cancelled = False
@defer.inlineCallbacks
- def blocking():
- non_completing_d = Deferred()
+ def blocking() -> Generator["Deferred[object]", object, None]:
+ nonlocal blocking_was_cancelled
+
+ non_completing_d: Deferred = Deferred()
with PreserveLoggingContext():
try:
yield non_completing_d
except CancelledError:
- blocking_was_cancelled[0] = True
+ blocking_was_cancelled = True
raise
with LoggingContext("one") as context_one:
# the errbacks should be run in the test logcontext
- def errback(res, deferred_name):
+ def errback(res: Failure, deferred_name: str) -> Failure:
self.assertIs(
current_context(),
context_one,
@@ -209,7 +217,7 @@ class TimeoutDeferredTest(TestCase):
self.clock.pump((1.0,))
self.assertTrue(
- blocking_was_cancelled[0], "non-completing deferred was not cancelled"
+ blocking_was_cancelled, "non-completing deferred was not cancelled"
)
self.failureResultOf(timing_out_d, defer.TimeoutError)
self.assertIs(current_context(), context_one)
@@ -220,13 +228,13 @@ class _TestException(Exception):
class ConcurrentlyExecuteTest(TestCase):
- def test_limits_runners(self):
+ def test_limits_runners(self) -> None:
"""If we have more tasks than runners, we should get the limit of runners"""
started = 0
waiters = []
processed = []
- async def callback(v):
+ async def callback(v: int) -> None:
# when we first enter, bump the start count
nonlocal started
started += 1
@@ -235,7 +243,7 @@ class ConcurrentlyExecuteTest(TestCase):
processed.append(v)
# wait for the goahead before returning
- d2 = Deferred()
+ d2: "Deferred[int]" = Deferred()
waiters.append(d2)
await d2
@@ -265,16 +273,16 @@ class ConcurrentlyExecuteTest(TestCase):
self.assertCountEqual(processed, [1, 2, 3, 4, 5])
self.successResultOf(d2)
- def test_preserves_stacktraces(self):
+ def test_preserves_stacktraces(self) -> None:
"""Test that the stacktrace from an exception thrown in the callback is preserved"""
- d1 = Deferred()
+ d1: "Deferred[int]" = Deferred()
- async def callback(v):
+ async def callback(v: int) -> None:
# alas, this doesn't work at all without an await here
await d1
raise _TestException("bah")
- async def caller():
+ async def caller() -> None:
try:
await concurrently_execute(callback, [1], 2)
except _TestException as e:
@@ -290,17 +298,17 @@ class ConcurrentlyExecuteTest(TestCase):
d1.callback(0)
self.successResultOf(d2)
- def test_preserves_stacktraces_on_preformed_failure(self):
+ def test_preserves_stacktraces_on_preformed_failure(self) -> None:
"""Test that the stacktrace on a Failure returned by the callback is preserved"""
- d1 = Deferred()
+ d1: "Deferred[int]" = Deferred()
f = Failure(_TestException("bah"))
- async def callback(v):
+ async def callback(v: int) -> None:
# alas, this doesn't work at all without an await here
await d1
await defer.fail(f)
- async def caller():
+ async def caller() -> None:
try:
await concurrently_execute(callback, [1], 2)
except _TestException as e:
@@ -336,7 +344,7 @@ class CancellationWrapperTests(TestCase):
else:
raise ValueError(f"Unsupported wrapper type: {self.wrapper}")
- def test_succeed(self):
+ def test_succeed(self) -> None:
"""Test that the new `Deferred` receives the result."""
deferred: "Deferred[str]" = Deferred()
wrapper_deferred = self.wrap_deferred(deferred)
@@ -346,7 +354,7 @@ class CancellationWrapperTests(TestCase):
self.assertTrue(wrapper_deferred.called)
self.assertEqual("success", self.successResultOf(wrapper_deferred))
- def test_failure(self):
+ def test_failure(self) -> None:
"""Test that the new `Deferred` receives the `Failure`."""
deferred: "Deferred[str]" = Deferred()
wrapper_deferred = self.wrap_deferred(deferred)
@@ -361,7 +369,7 @@ class CancellationWrapperTests(TestCase):
class StopCancellationTests(TestCase):
"""Tests for the `stop_cancellation` function."""
- def test_cancellation(self):
+ def test_cancellation(self) -> None:
"""Test that cancellation of the new `Deferred` leaves the original running."""
deferred: "Deferred[str]" = Deferred()
wrapper_deferred = stop_cancellation(deferred)
@@ -384,7 +392,7 @@ class StopCancellationTests(TestCase):
class DelayCancellationTests(TestCase):
"""Tests for the `delay_cancellation` function."""
- def test_deferred_cancellation(self):
+ def test_deferred_cancellation(self) -> None:
"""Test that cancellation of the new `Deferred` waits for the original."""
deferred: "Deferred[str]" = Deferred()
wrapper_deferred = delay_cancellation(deferred)
@@ -405,12 +413,12 @@ class DelayCancellationTests(TestCase):
# Now that the original `Deferred` has failed, we should get a `CancelledError`.
self.failureResultOf(wrapper_deferred, CancelledError)
- def test_coroutine_cancellation(self):
+ def test_coroutine_cancellation(self) -> None:
"""Test that cancellation of the new `Deferred` waits for the original."""
blocking_deferred: "Deferred[None]" = Deferred()
completion_deferred: "Deferred[None]" = Deferred()
- async def task():
+ async def task() -> NoReturn:
await blocking_deferred
completion_deferred.callback(None)
# Raise an exception. Twisted should consume it, otherwise unwanted
@@ -434,7 +442,7 @@ class DelayCancellationTests(TestCase):
# Now that the original coroutine has failed, we should get a `CancelledError`.
self.failureResultOf(wrapper_deferred, CancelledError)
- def test_suppresses_second_cancellation(self):
+ def test_suppresses_second_cancellation(self) -> None:
"""Test that a second cancellation is suppressed.
Identical to `test_cancellation` except the new `Deferred` is cancelled twice.
@@ -459,7 +467,7 @@ class DelayCancellationTests(TestCase):
# Now that the original `Deferred` has failed, we should get a `CancelledError`.
self.failureResultOf(wrapper_deferred, CancelledError)
- def test_propagates_cancelled_error(self):
+ def test_propagates_cancelled_error(self) -> None:
"""Test that a `CancelledError` from the original `Deferred` gets propagated."""
deferred: "Deferred[str]" = Deferred()
wrapper_deferred = delay_cancellation(deferred)
@@ -472,14 +480,14 @@ class DelayCancellationTests(TestCase):
self.assertTrue(wrapper_deferred.called)
self.assertIs(cancelled_error, self.failureResultOf(wrapper_deferred).value)
- def test_preserves_logcontext(self):
+ def test_preserves_logcontext(self) -> None:
"""Test that logging contexts are preserved."""
blocking_d: "Deferred[None]" = Deferred()
- async def inner():
+ async def inner() -> None:
await make_deferred_yieldable(blocking_d)
- async def outer():
+ async def outer() -> None:
with LoggingContext("c") as c:
try:
await delay_cancellation(inner())
@@ -503,7 +511,7 @@ class DelayCancellationTests(TestCase):
class AwakenableSleeperTests(TestCase):
"Tests AwakenableSleeper"
- def test_sleep(self):
+ def test_sleep(self) -> None:
reactor, _ = get_clock()
sleeper = AwakenableSleeper(reactor)
@@ -518,7 +526,7 @@ class AwakenableSleeperTests(TestCase):
reactor.advance(0.6)
self.assertTrue(d.called)
- def test_explicit_wake(self):
+ def test_explicit_wake(self) -> None:
reactor, _ = get_clock()
sleeper = AwakenableSleeper(reactor)
@@ -535,7 +543,7 @@ class AwakenableSleeperTests(TestCase):
reactor.advance(0.6)
- def test_multiple_sleepers_timeout(self):
+ def test_multiple_sleepers_timeout(self) -> None:
reactor, _ = get_clock()
sleeper = AwakenableSleeper(reactor)
@@ -555,7 +563,7 @@ class AwakenableSleeperTests(TestCase):
reactor.advance(0.6)
self.assertTrue(d2.called)
- def test_multiple_sleepers_wake(self):
+ def test_multiple_sleepers_wake(self) -> None:
reactor, _ = get_clock()
sleeper = AwakenableSleeper(reactor)
diff --git a/tests/util/test_batching_queue.py b/tests/util/test_batching_queue.py
index 07be57d72c..94ef91f645 100644
--- a/tests/util/test_batching_queue.py
+++ b/tests/util/test_batching_queue.py
@@ -11,6 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import List, Tuple
+
+from prometheus_client import Gauge
+
from twisted.internet import defer
from synapse.logging.context import make_deferred_yieldable
@@ -26,7 +30,7 @@ from tests.unittest import TestCase
class BatchingQueueTestCase(TestCase):
- def setUp(self):
+ def setUp(self) -> None:
self.clock, hs_clock = get_clock()
# We ensure that we remove any existing metrics for "test_queue".
@@ -37,25 +41,27 @@ class BatchingQueueTestCase(TestCase):
except KeyError:
pass
- self._pending_calls = []
- self.queue = BatchingQueue("test_queue", hs_clock, self._process_queue)
+ self._pending_calls: List[Tuple[List[str], defer.Deferred]] = []
+ self.queue: BatchingQueue[str, str] = BatchingQueue(
+ "test_queue", hs_clock, self._process_queue
+ )
- async def _process_queue(self, values):
- d = defer.Deferred()
+ async def _process_queue(self, values: List[str]) -> str:
+ d: "defer.Deferred[str]" = defer.Deferred()
self._pending_calls.append((values, d))
return await make_deferred_yieldable(d)
- def _get_sample_with_name(self, metric, name) -> int:
+ def _get_sample_with_name(self, metric: Gauge, name: str) -> float:
"""For a prometheus metric get the value of the sample that has a
matching "name" label.
"""
- for sample in metric.collect()[0].samples:
+ for sample in next(iter(metric.collect())).samples:
if sample.labels.get("name") == name:
return sample.value
self.fail("Found no matching sample")
- def _assert_metrics(self, queued, keys, in_flight):
+ def _assert_metrics(self, queued: int, keys: int, in_flight: int) -> None:
"""Assert that the metrics are correct"""
sample = self._get_sample_with_name(number_queued, self.queue._name)
@@ -75,7 +81,7 @@ class BatchingQueueTestCase(TestCase):
"number_in_flight",
)
- def test_simple(self):
+ def test_simple(self) -> None:
"""Tests the basic case of calling `add_to_queue` once and having
`_process_queue` return.
"""
@@ -106,7 +112,7 @@ class BatchingQueueTestCase(TestCase):
self._assert_metrics(queued=0, keys=0, in_flight=0)
- def test_batching(self):
+ def test_batching(self) -> None:
"""Test that multiple calls at the same time get batched up into one
call to `_process_queue`.
"""
@@ -134,7 +140,7 @@ class BatchingQueueTestCase(TestCase):
self.assertEqual(self.successResultOf(queue_d2), "bar")
self._assert_metrics(queued=0, keys=0, in_flight=0)
- def test_queuing(self):
+ def test_queuing(self) -> None:
"""Test that we queue up requests while a `_process_queue` is being
called.
"""
@@ -184,7 +190,7 @@ class BatchingQueueTestCase(TestCase):
self.assertEqual(self.successResultOf(queue_d3), "bar2")
self._assert_metrics(queued=0, keys=0, in_flight=0)
- def test_different_keys(self):
+ def test_different_keys(self) -> None:
"""Test that calls to different keys get processed in parallel."""
self.assertFalse(self._pending_calls)
diff --git a/tests/util/test_check_dependencies.py b/tests/util/test_check_dependencies.py
index 6913de24b9..aa20fe6780 100644
--- a/tests/util/test_check_dependencies.py
+++ b/tests/util/test_check_dependencies.py
@@ -1,5 +1,20 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from contextlib import contextmanager
-from typing import Generator, Optional
+from os import PathLike
+from typing import Generator, Optional, Union
from unittest.mock import patch
from synapse.util.check_dependencies import (
@@ -12,17 +27,17 @@ from tests.unittest import TestCase
class DummyDistribution(metadata.Distribution):
- def __init__(self, version: object):
+ def __init__(self, version: str):
self._version = version
@property
- def version(self):
+ def version(self) -> str:
return self._version
- def locate_file(self, path):
+ def locate_file(self, path: Union[str, PathLike]) -> PathLike:
raise NotImplementedError()
- def read_text(self, filename):
+ def read_text(self, filename: str) -> None:
raise NotImplementedError()
@@ -30,7 +45,7 @@ old = DummyDistribution("0.1.2")
old_release_candidate = DummyDistribution("0.1.2rc3")
new = DummyDistribution("1.2.3")
new_release_candidate = DummyDistribution("1.2.3rc4")
-distribution_with_no_version = DummyDistribution(None)
+distribution_with_no_version = DummyDistribution(None) # type: ignore[arg-type]
# could probably use stdlib TestCase --- no need for twisted here
@@ -45,7 +60,7 @@ class TestDependencyChecker(TestCase):
If `distribution = None`, we pretend that the package is not installed.
"""
- def mock_distribution(name: str):
+ def mock_distribution(name: str) -> DummyDistribution:
if distribution is None:
raise metadata.PackageNotFoundError
else:
diff --git a/tests/util/test_dict_cache.py b/tests/util/test_dict_cache.py
index e8b6246ab5..acb251bfea 100644
--- a/tests/util/test_dict_cache.py
+++ b/tests/util/test_dict_cache.py
@@ -19,10 +19,12 @@ from tests import unittest
class DictCacheTestCase(unittest.TestCase):
- def setUp(self):
- self.cache = DictionaryCache("foobar", max_entries=10)
+ def setUp(self) -> None:
+ self.cache: DictionaryCache[str, str, str] = DictionaryCache(
+ "foobar", max_entries=10
+ )
- def test_simple_cache_hit_full(self):
+ def test_simple_cache_hit_full(self) -> None:
key = "test_simple_cache_hit_full"
v = self.cache.get(key)
@@ -37,7 +39,7 @@ class DictCacheTestCase(unittest.TestCase):
c = self.cache.get(key)
self.assertEqual(test_value, c.value)
- def test_simple_cache_hit_partial(self):
+ def test_simple_cache_hit_partial(self) -> None:
key = "test_simple_cache_hit_partial"
seq = self.cache.sequence
@@ -47,7 +49,7 @@ class DictCacheTestCase(unittest.TestCase):
c = self.cache.get(key, ["test"])
self.assertEqual(test_value, c.value)
- def test_simple_cache_miss_partial(self):
+ def test_simple_cache_miss_partial(self) -> None:
key = "test_simple_cache_miss_partial"
seq = self.cache.sequence
@@ -57,7 +59,7 @@ class DictCacheTestCase(unittest.TestCase):
c = self.cache.get(key, ["test2"])
self.assertEqual({}, c.value)
- def test_simple_cache_hit_miss_partial(self):
+ def test_simple_cache_hit_miss_partial(self) -> None:
key = "test_simple_cache_hit_miss_partial"
seq = self.cache.sequence
@@ -71,7 +73,7 @@ class DictCacheTestCase(unittest.TestCase):
c = self.cache.get(key, ["test2"])
self.assertEqual({"test2": "test_simple_cache_hit_miss_partial2"}, c.value)
- def test_multi_insert(self):
+ def test_multi_insert(self) -> None:
key = "test_simple_cache_hit_miss_partial"
seq = self.cache.sequence
@@ -92,7 +94,7 @@ class DictCacheTestCase(unittest.TestCase):
)
self.assertEqual(c.full, False)
- def test_invalidation(self):
+ def test_invalidation(self) -> None:
"""Test that the partial dict and full dicts get invalidated
separately.
"""
@@ -106,7 +108,7 @@ class DictCacheTestCase(unittest.TestCase):
# entry for "a" warm.
for i in range(20):
self.cache.get(key, ["a"])
- self.cache.update(seq, f"key{i}", {1: 2})
+ self.cache.update(seq, f"key{i}", {"1": "2"})
# We should have evicted the full dict...
r = self.cache.get(key)
diff --git a/tests/util/test_expiring_cache.py b/tests/util/test_expiring_cache.py
index 7f60aae5ba..9cf920daf8 100644
--- a/tests/util/test_expiring_cache.py
+++ b/tests/util/test_expiring_cache.py
@@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import List, cast
+from synapse.util import Clock
from synapse.util.caches.expiringcache import ExpiringCache
from tests.utils import MockClock
@@ -21,17 +23,21 @@ from .. import unittest
class ExpiringCacheTestCase(unittest.HomeserverTestCase):
- def test_get_set(self):
+ def test_get_set(self) -> None:
clock = MockClock()
- cache = ExpiringCache("test", clock, max_len=1)
+ cache: ExpiringCache[str, str] = ExpiringCache(
+ "test", cast(Clock, clock), max_len=1
+ )
cache["key"] = "value"
self.assertEqual(cache.get("key"), "value")
self.assertEqual(cache["key"], "value")
- def test_eviction(self):
+ def test_eviction(self) -> None:
clock = MockClock()
- cache = ExpiringCache("test", clock, max_len=2)
+ cache: ExpiringCache[str, str] = ExpiringCache(
+ "test", cast(Clock, clock), max_len=2
+ )
cache["key"] = "value"
cache["key2"] = "value2"
@@ -43,9 +49,11 @@ class ExpiringCacheTestCase(unittest.HomeserverTestCase):
self.assertEqual(cache.get("key2"), "value2")
self.assertEqual(cache.get("key3"), "value3")
- def test_iterable_eviction(self):
+ def test_iterable_eviction(self) -> None:
clock = MockClock()
- cache = ExpiringCache("test", clock, max_len=5, iterable=True)
+ cache: ExpiringCache[str, List[int]] = ExpiringCache(
+ "test", cast(Clock, clock), max_len=5, iterable=True
+ )
cache["key"] = [1]
cache["key2"] = [2, 3]
@@ -61,9 +69,11 @@ class ExpiringCacheTestCase(unittest.HomeserverTestCase):
self.assertEqual(cache.get("key3"), [4, 5])
self.assertEqual(cache.get("key4"), [6, 7])
- def test_time_eviction(self):
+ def test_time_eviction(self) -> None:
clock = MockClock()
- cache = ExpiringCache("test", clock, expiry_ms=1000)
+ cache: ExpiringCache[str, int] = ExpiringCache(
+ "test", cast(Clock, clock), expiry_ms=1000
+ )
cache["key"] = 1
clock.advance_time(0.5)
diff --git a/tests/util/test_file_consumer.py b/tests/util/test_file_consumer.py
index 3bb4695405..4f3c983c15 100644
--- a/tests/util/test_file_consumer.py
+++ b/tests/util/test_file_consumer.py
@@ -12,22 +12,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
import threading
-from io import StringIO
+from io import BytesIO
+from typing import BinaryIO, Generator, Optional, cast
from unittest.mock import NonCallableMock
-from twisted.internet import defer, reactor
+from zope.interface import implementer
+
+from twisted.internet import defer, reactor as _reactor
+from twisted.internet.interfaces import IPullProducer
+from synapse.types import ISynapseReactor
from synapse.util.file_consumer import BackgroundFileConsumer
from tests import unittest
+reactor = cast(ISynapseReactor, _reactor)
+
class FileConsumerTests(unittest.TestCase):
@defer.inlineCallbacks
- def test_pull_consumer(self):
- string_file = StringIO()
+ def test_pull_consumer(self) -> Generator["defer.Deferred[object]", object, None]:
+ string_file = BytesIO()
consumer = BackgroundFileConsumer(string_file, reactor=reactor)
try:
@@ -35,55 +41,57 @@ class FileConsumerTests(unittest.TestCase):
yield producer.register_with_consumer(consumer)
- yield producer.write_and_wait("Foo")
+ yield producer.write_and_wait(b"Foo")
- self.assertEqual(string_file.getvalue(), "Foo")
+ self.assertEqual(string_file.getvalue(), b"Foo")
- yield producer.write_and_wait("Bar")
+ yield producer.write_and_wait(b"Bar")
- self.assertEqual(string_file.getvalue(), "FooBar")
+ self.assertEqual(string_file.getvalue(), b"FooBar")
finally:
consumer.unregisterProducer()
- yield consumer.wait()
+ yield consumer.wait() # type: ignore[misc]
self.assertTrue(string_file.closed)
@defer.inlineCallbacks
- def test_push_consumer(self):
- string_file = BlockingStringWrite()
- consumer = BackgroundFileConsumer(string_file, reactor=reactor)
+ def test_push_consumer(self) -> Generator["defer.Deferred[object]", object, None]:
+ string_file = BlockingBytesWrite()
+ consumer = BackgroundFileConsumer(cast(BinaryIO, string_file), reactor=reactor)
try:
producer = NonCallableMock(spec_set=[])
consumer.registerProducer(producer, True)
- consumer.write("Foo")
- yield string_file.wait_for_n_writes(1)
+ consumer.write(b"Foo")
+ yield string_file.wait_for_n_writes(1) # type: ignore[misc]
- self.assertEqual(string_file.buffer, "Foo")
+ self.assertEqual(string_file.buffer, b"Foo")
- consumer.write("Bar")
- yield string_file.wait_for_n_writes(2)
+ consumer.write(b"Bar")
+ yield string_file.wait_for_n_writes(2) # type: ignore[misc]
- self.assertEqual(string_file.buffer, "FooBar")
+ self.assertEqual(string_file.buffer, b"FooBar")
finally:
consumer.unregisterProducer()
- yield consumer.wait()
+ yield consumer.wait() # type: ignore[misc]
self.assertTrue(string_file.closed)
@defer.inlineCallbacks
- def test_push_producer_feedback(self):
- string_file = BlockingStringWrite()
- consumer = BackgroundFileConsumer(string_file, reactor=reactor)
+ def test_push_producer_feedback(
+ self,
+ ) -> Generator["defer.Deferred[object]", object, None]:
+ string_file = BlockingBytesWrite()
+ consumer = BackgroundFileConsumer(cast(BinaryIO, string_file), reactor=reactor)
try:
producer = NonCallableMock(spec_set=["pauseProducing", "resumeProducing"])
- resume_deferred = defer.Deferred()
+ resume_deferred: defer.Deferred = defer.Deferred()
producer.resumeProducing.side_effect = lambda: resume_deferred.callback(
None
)
@@ -93,65 +101,72 @@ class FileConsumerTests(unittest.TestCase):
number_writes = 0
with string_file.write_lock:
for _ in range(consumer._PAUSE_ON_QUEUE_SIZE):
- consumer.write("Foo")
+ consumer.write(b"Foo")
number_writes += 1
producer.pauseProducing.assert_called_once()
- yield string_file.wait_for_n_writes(number_writes)
+ yield string_file.wait_for_n_writes(number_writes) # type: ignore[misc]
yield resume_deferred
producer.resumeProducing.assert_called_once()
finally:
consumer.unregisterProducer()
- yield consumer.wait()
+ yield consumer.wait() # type: ignore[misc]
self.assertTrue(string_file.closed)
+@implementer(IPullProducer)
class DummyPullProducer:
- def __init__(self):
- self.consumer = None
- self.deferred = defer.Deferred()
+ def __init__(self) -> None:
+ self.consumer: Optional[BackgroundFileConsumer] = None
+ self.deferred: "defer.Deferred[object]" = defer.Deferred()
- def resumeProducing(self):
+ def resumeProducing(self) -> None:
d = self.deferred
self.deferred = defer.Deferred()
d.callback(None)
- def write_and_wait(self, bytes):
+ def stopProducing(self) -> None:
+ raise RuntimeError("Unexpected call")
+
+ def write_and_wait(self, write_bytes: bytes) -> "defer.Deferred[object]":
+ assert self.consumer is not None
d = self.deferred
- self.consumer.write(bytes)
+ self.consumer.write(write_bytes)
return d
- def register_with_consumer(self, consumer):
+ def register_with_consumer(
+ self, consumer: BackgroundFileConsumer
+ ) -> "defer.Deferred[object]":
d = self.deferred
self.consumer = consumer
self.consumer.registerProducer(self, False)
return d
-class BlockingStringWrite:
- def __init__(self):
- self.buffer = ""
+class BlockingBytesWrite:
+ def __init__(self) -> None:
+ self.buffer = b""
self.closed = False
self.write_lock = threading.Lock()
- self._notify_write_deferred = None
+ self._notify_write_deferred: Optional[defer.Deferred] = None
self._number_of_writes = 0
- def write(self, bytes):
+ def write(self, write_bytes: bytes) -> None:
with self.write_lock:
- self.buffer += bytes
+ self.buffer += write_bytes
self._number_of_writes += 1
reactor.callFromThread(self._notify_write)
- def close(self):
+ def close(self) -> None:
self.closed = True
- def _notify_write(self):
+ def _notify_write(self) -> None:
"Called by write to indicate a write happened"
with self.write_lock:
if not self._notify_write_deferred:
@@ -161,7 +176,9 @@ class BlockingStringWrite:
d.callback(None)
@defer.inlineCallbacks
- def wait_for_n_writes(self, n):
+ def wait_for_n_writes(
+ self, n: int
+ ) -> Generator["defer.Deferred[object]", object, None]:
"Wait for n writes to have happened"
while True:
with self.write_lock:
diff --git a/tests/util/test_itertools.py b/tests/util/test_itertools.py
index 3c0ddd4f18..406c16cdcf 100644
--- a/tests/util/test_itertools.py
+++ b/tests/util/test_itertools.py
@@ -19,7 +19,7 @@ from tests.unittest import TestCase
class ChunkSeqTests(TestCase):
- def test_short_seq(self):
+ def test_short_seq(self) -> None:
parts = chunk_seq("123", 8)
self.assertEqual(
@@ -27,7 +27,7 @@ class ChunkSeqTests(TestCase):
["123"],
)
- def test_long_seq(self):
+ def test_long_seq(self) -> None:
parts = chunk_seq("abcdefghijklmnop", 8)
self.assertEqual(
@@ -35,7 +35,7 @@ class ChunkSeqTests(TestCase):
["abcdefgh", "ijklmnop"],
)
- def test_uneven_parts(self):
+ def test_uneven_parts(self) -> None:
parts = chunk_seq("abcdefghijklmnop", 5)
self.assertEqual(
@@ -43,7 +43,7 @@ class ChunkSeqTests(TestCase):
["abcde", "fghij", "klmno", "p"],
)
- def test_empty_input(self):
+ def test_empty_input(self) -> None:
parts: Iterable[Sequence] = chunk_seq([], 5)
self.assertEqual(
@@ -53,13 +53,13 @@ class ChunkSeqTests(TestCase):
class SortTopologically(TestCase):
- def test_empty(self):
+ def test_empty(self) -> None:
"Test that an empty graph works correctly"
graph: Dict[int, List[int]] = {}
self.assertEqual(list(sorted_topologically([], graph)), [])
- def test_handle_empty_graph(self):
+ def test_handle_empty_graph(self) -> None:
"Test that a graph where a node doesn't have an entry is treated as empty"
graph: Dict[int, List[int]] = {}
@@ -67,7 +67,7 @@ class SortTopologically(TestCase):
# For disconnected nodes the output is simply sorted.
self.assertEqual(list(sorted_topologically([1, 2], graph)), [1, 2])
- def test_disconnected(self):
+ def test_disconnected(self) -> None:
"Test that a graph with no edges work"
graph: Dict[int, List[int]] = {1: [], 2: []}
@@ -75,20 +75,20 @@ class SortTopologically(TestCase):
# For disconnected nodes the output is simply sorted.
self.assertEqual(list(sorted_topologically([1, 2], graph)), [1, 2])
- def test_linear(self):
+ def test_linear(self) -> None:
"Test that a simple `4 -> 3 -> 2 -> 1` graph works"
graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [2], 4: [3]}
self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])
- def test_subset(self):
+ def test_subset(self) -> None:
"Test that only sorting a subset of the graph works"
graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [2], 4: [3]}
self.assertEqual(list(sorted_topologically([4, 3], graph)), [3, 4])
- def test_fork(self):
+ def test_fork(self) -> None:
"Test that a forked graph works"
graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [1], 4: [2, 3]}
@@ -96,13 +96,13 @@ class SortTopologically(TestCase):
# always get the same one.
self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])
- def test_duplicates(self):
+ def test_duplicates(self) -> None:
"Test that a graph with duplicate edges work"
graph: Dict[int, List[int]] = {1: [], 2: [1, 1], 3: [2, 2], 4: [3]}
self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])
- def test_multiple_paths(self):
+ def test_multiple_paths(self) -> None:
"Test that a graph with multiple paths between two nodes work"
graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [2], 4: [3, 2, 1]}
diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py
index 2ad321e184..d64c162e1d 100644
--- a/tests/util/test_logcontext.py
+++ b/tests/util/test_logcontext.py
@@ -1,5 +1,21 @@
+# Copyright 2014-2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Callable, Generator, cast
+
import twisted.python.failure
-from twisted.internet import defer, reactor
+from twisted.internet import defer, reactor as _reactor
from synapse.logging.context import (
SENTINEL_CONTEXT,
@@ -10,25 +26,30 @@ from synapse.logging.context import (
nested_logging_context,
run_in_background,
)
+from synapse.types import ISynapseReactor
from synapse.util import Clock
from .. import unittest
+reactor = cast(ISynapseReactor, _reactor)
+
class LoggingContextTestCase(unittest.TestCase):
- def _check_test_key(self, value):
- self.assertEqual(current_context().name, value)
+ def _check_test_key(self, value: str) -> None:
+ context = current_context()
+ assert isinstance(context, LoggingContext)
+ self.assertEqual(context.name, value)
- def test_with_context(self):
+ def test_with_context(self) -> None:
with LoggingContext("test"):
self._check_test_key("test")
@defer.inlineCallbacks
- def test_sleep(self):
+ def test_sleep(self) -> Generator["defer.Deferred[object]", object, None]:
clock = Clock(reactor)
@defer.inlineCallbacks
- def competing_callback():
+ def competing_callback() -> Generator["defer.Deferred[object]", object, None]:
with LoggingContext("competing"):
yield clock.sleep(0)
self._check_test_key("competing")
@@ -39,17 +60,18 @@ class LoggingContextTestCase(unittest.TestCase):
yield clock.sleep(0)
self._check_test_key("one")
- def _test_run_in_background(self, function):
+ def _test_run_in_background(self, function: Callable[[], object]) -> defer.Deferred:
sentinel_context = current_context()
- callback_completed = [False]
+ callback_completed = False
with LoggingContext("one"):
# fire off function, but don't wait on it.
d2 = run_in_background(function)
- def cb(res):
- callback_completed[0] = True
+ def cb(res: object) -> object:
+ nonlocal callback_completed
+ callback_completed = True
return res
d2.addCallback(cb)
@@ -60,8 +82,8 @@ class LoggingContextTestCase(unittest.TestCase):
# the logcontext is left in a sane state.
d2 = defer.Deferred()
- def check_logcontext():
- if not callback_completed[0]:
+ def check_logcontext() -> None:
+ if not callback_completed:
reactor.callLater(0.01, check_logcontext)
return
@@ -78,31 +100,31 @@ class LoggingContextTestCase(unittest.TestCase):
# test is done once d2 finishes
return d2
- def test_run_in_background_with_blocking_fn(self):
+ def test_run_in_background_with_blocking_fn(self) -> defer.Deferred:
@defer.inlineCallbacks
- def blocking_function():
+ def blocking_function() -> Generator["defer.Deferred[object]", object, None]:
yield Clock(reactor).sleep(0)
return self._test_run_in_background(blocking_function)
- def test_run_in_background_with_non_blocking_fn(self):
+ def test_run_in_background_with_non_blocking_fn(self) -> defer.Deferred:
@defer.inlineCallbacks
- def nonblocking_function():
+ def nonblocking_function() -> Generator["defer.Deferred[object]", object, None]:
with PreserveLoggingContext():
yield defer.succeed(None)
return self._test_run_in_background(nonblocking_function)
- def test_run_in_background_with_chained_deferred(self):
+ def test_run_in_background_with_chained_deferred(self) -> defer.Deferred:
# a function which returns a deferred which looks like it has been
# called, but is actually paused
- def testfunc():
+ def testfunc() -> defer.Deferred:
return make_deferred_yieldable(_chained_deferred_function())
return self._test_run_in_background(testfunc)
- def test_run_in_background_with_coroutine(self):
- async def testfunc():
+ def test_run_in_background_with_coroutine(self) -> defer.Deferred:
+ async def testfunc() -> None:
self._check_test_key("one")
d = Clock(reactor).sleep(0)
self.assertIs(current_context(), SENTINEL_CONTEXT)
@@ -111,18 +133,20 @@ class LoggingContextTestCase(unittest.TestCase):
return self._test_run_in_background(testfunc)
- def test_run_in_background_with_nonblocking_coroutine(self):
- async def testfunc():
+ def test_run_in_background_with_nonblocking_coroutine(self) -> defer.Deferred:
+ async def testfunc() -> None:
self._check_test_key("one")
return self._test_run_in_background(testfunc)
@defer.inlineCallbacks
- def test_make_deferred_yieldable(self):
+ def test_make_deferred_yieldable(
+ self,
+ ) -> Generator["defer.Deferred[object]", object, None]:
# a function which returns an incomplete deferred, but doesn't follow
# the synapse rules.
- def blocking_function():
- d = defer.Deferred()
+ def blocking_function() -> defer.Deferred:
+ d: defer.Deferred = defer.Deferred()
reactor.callLater(0, d.callback, None)
return d
@@ -139,7 +163,9 @@ class LoggingContextTestCase(unittest.TestCase):
self._check_test_key("one")
@defer.inlineCallbacks
- def test_make_deferred_yieldable_with_chained_deferreds(self):
+ def test_make_deferred_yieldable_with_chained_deferreds(
+ self,
+ ) -> Generator["defer.Deferred[object]", object, None]:
sentinel_context = current_context()
with LoggingContext("one"):
@@ -152,7 +178,7 @@ class LoggingContextTestCase(unittest.TestCase):
# now it should be restored
self._check_test_key("one")
- def test_nested_logging_context(self):
+ def test_nested_logging_context(self) -> None:
with LoggingContext("foo"):
nested_context = nested_logging_context(suffix="bar")
self.assertEqual(nested_context.name, "foo-bar")
@@ -161,11 +187,11 @@ class LoggingContextTestCase(unittest.TestCase):
# a function which returns a deferred which has been "called", but
# which had a function which returned another incomplete deferred on
# its callback list, so won't yet call any other new callbacks.
-def _chained_deferred_function():
+def _chained_deferred_function() -> defer.Deferred:
d = defer.succeed(None)
- def cb(res):
- d2 = defer.Deferred()
+ def cb(res: object) -> defer.Deferred:
+ d2: defer.Deferred = defer.Deferred()
reactor.callLater(0, d2.callback, res)
return d2
diff --git a/tests/util/test_logformatter.py b/tests/util/test_logformatter.py
index a2e08281e6..0dee69a6fe 100644
--- a/tests/util/test_logformatter.py
+++ b/tests/util/test_logformatter.py
@@ -23,7 +23,7 @@ class TestException(Exception):
class LogFormatterTestCase(unittest.TestCase):
- def test_formatter(self):
+ def test_formatter(self) -> None:
formatter = LogFormatter()
try:
diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py
index 67173a4f5b..1fc5a473f0 100644
--- a/tests/util/test_lrucache.py
+++ b/tests/util/test_lrucache.py
@@ -13,10 +13,11 @@
# limitations under the License.
-from typing import List
+from typing import List, Tuple
from unittest.mock import Mock, patch
from synapse.metrics.jemalloc import JemallocStats
+from synapse.types import JsonDict
from synapse.util.caches.lrucache import LruCache, setup_expire_lru_cache_entries
from synapse.util.caches.treecache import TreeCache
@@ -25,14 +26,14 @@ from tests.unittest import override_config
class LruCacheTestCase(unittest.HomeserverTestCase):
- def test_get_set(self):
- cache = LruCache(1)
+ def test_get_set(self) -> None:
+ cache: LruCache[str, str] = LruCache(1)
cache["key"] = "value"
self.assertEqual(cache.get("key"), "value")
self.assertEqual(cache["key"], "value")
- def test_eviction(self):
- cache = LruCache(2)
+ def test_eviction(self) -> None:
+ cache: LruCache[int, int] = LruCache(2)
cache[1] = 1
cache[2] = 2
@@ -45,8 +46,8 @@ class LruCacheTestCase(unittest.HomeserverTestCase):
self.assertEqual(cache.get(2), 2)
self.assertEqual(cache.get(3), 3)
- def test_setdefault(self):
- cache = LruCache(1)
+ def test_setdefault(self) -> None:
+ cache: LruCache[str, int] = LruCache(1)
self.assertEqual(cache.setdefault("key", 1), 1)
self.assertEqual(cache.get("key"), 1)
self.assertEqual(cache.setdefault("key", 2), 1)
@@ -54,14 +55,15 @@ class LruCacheTestCase(unittest.HomeserverTestCase):
cache["key"] = 2 # Make sure overriding works.
self.assertEqual(cache.get("key"), 2)
- def test_pop(self):
- cache = LruCache(1)
+ def test_pop(self) -> None:
+ cache: LruCache[str, int] = LruCache(1)
cache["key"] = 1
self.assertEqual(cache.pop("key"), 1)
self.assertEqual(cache.pop("key"), None)
- def test_del_multi(self):
- cache = LruCache(4, cache_type=TreeCache)
+ def test_del_multi(self) -> None:
+ # The type here isn't quite correct as they don't handle TreeCache well.
+ cache: LruCache[Tuple[str, str], str] = LruCache(4, cache_type=TreeCache)
cache[("animal", "cat")] = "mew"
cache[("animal", "dog")] = "woof"
cache[("vehicles", "car")] = "vroom"
@@ -71,7 +73,7 @@ class LruCacheTestCase(unittest.HomeserverTestCase):
self.assertEqual(cache.get(("animal", "cat")), "mew")
self.assertEqual(cache.get(("vehicles", "car")), "vroom")
- cache.del_multi(("animal",))
+ cache.del_multi(("animal",)) # type: ignore[arg-type]
self.assertEqual(len(cache), 2)
self.assertEqual(cache.get(("animal", "cat")), None)
self.assertEqual(cache.get(("animal", "dog")), None)
@@ -79,22 +81,22 @@ class LruCacheTestCase(unittest.HomeserverTestCase):
self.assertEqual(cache.get(("vehicles", "train")), "chuff")
# Man from del_multi say "Yes".
- def test_clear(self):
- cache = LruCache(1)
+ def test_clear(self) -> None:
+ cache: LruCache[str, int] = LruCache(1)
cache["key"] = 1
cache.clear()
self.assertEqual(len(cache), 0)
@override_config({"caches": {"per_cache_factors": {"mycache": 10}}})
- def test_special_size(self):
- cache = LruCache(10, "mycache")
+ def test_special_size(self) -> None:
+ cache: LruCache = LruCache(10, "mycache")
self.assertEqual(cache.max_size, 100)
class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
- def test_get(self):
+ def test_get(self) -> None:
m = Mock()
- cache = LruCache(1)
+ cache: LruCache[str, str] = LruCache(1)
cache.set("key", "value")
self.assertFalse(m.called)
@@ -111,9 +113,9 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
cache.set("key", "value")
self.assertEqual(m.call_count, 1)
- def test_multi_get(self):
+ def test_multi_get(self) -> None:
m = Mock()
- cache = LruCache(1)
+ cache: LruCache[str, str] = LruCache(1)
cache.set("key", "value")
self.assertFalse(m.called)
@@ -130,9 +132,9 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
cache.set("key", "value")
self.assertEqual(m.call_count, 1)
- def test_set(self):
+ def test_set(self) -> None:
m = Mock()
- cache = LruCache(1)
+ cache: LruCache[str, str] = LruCache(1)
cache.set("key", "value", callbacks=[m])
self.assertFalse(m.called)
@@ -146,9 +148,9 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
cache.set("key", "value")
self.assertEqual(m.call_count, 1)
- def test_pop(self):
+ def test_pop(self) -> None:
m = Mock()
- cache = LruCache(1)
+ cache: LruCache[str, str] = LruCache(1)
cache.set("key", "value", callbacks=[m])
self.assertFalse(m.called)
@@ -162,12 +164,13 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
cache.pop("key")
self.assertEqual(m.call_count, 1)
- def test_del_multi(self):
+ def test_del_multi(self) -> None:
m1 = Mock()
m2 = Mock()
m3 = Mock()
m4 = Mock()
- cache = LruCache(4, cache_type=TreeCache)
+ # The type here isn't quite correct as they don't handle TreeCache well.
+ cache: LruCache[Tuple[str, str], str] = LruCache(4, cache_type=TreeCache)
cache.set(("a", "1"), "value", callbacks=[m1])
cache.set(("a", "2"), "value", callbacks=[m2])
@@ -179,17 +182,17 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
self.assertEqual(m3.call_count, 0)
self.assertEqual(m4.call_count, 0)
- cache.del_multi(("a",))
+ cache.del_multi(("a",)) # type: ignore[arg-type]
self.assertEqual(m1.call_count, 1)
self.assertEqual(m2.call_count, 1)
self.assertEqual(m3.call_count, 0)
self.assertEqual(m4.call_count, 0)
- def test_clear(self):
+ def test_clear(self) -> None:
m1 = Mock()
m2 = Mock()
- cache = LruCache(5)
+ cache: LruCache[str, str] = LruCache(5)
cache.set("key1", "value", callbacks=[m1])
cache.set("key2", "value", callbacks=[m2])
@@ -202,11 +205,11 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
self.assertEqual(m1.call_count, 1)
self.assertEqual(m2.call_count, 1)
- def test_eviction(self):
+ def test_eviction(self) -> None:
m1 = Mock(name="m1")
m2 = Mock(name="m2")
m3 = Mock(name="m3")
- cache = LruCache(2)
+ cache: LruCache[str, str] = LruCache(2)
cache.set("key1", "value", callbacks=[m1])
cache.set("key2", "value", callbacks=[m2])
@@ -241,8 +244,8 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
class LruCacheSizedTestCase(unittest.HomeserverTestCase):
- def test_evict(self):
- cache = LruCache(5, size_callback=len)
+ def test_evict(self) -> None:
+ cache: LruCache[str, List[int]] = LruCache(5, size_callback=len)
cache["key1"] = [0]
cache["key2"] = [1, 2]
cache["key3"] = [3]
@@ -269,6 +272,7 @@ class LruCacheSizedTestCase(unittest.HomeserverTestCase):
cache["key1"] = []
self.assertEqual(len(cache), 0)
+ assert isinstance(cache.cache, dict)
cache.cache["key1"].drop_from_cache()
self.assertIsNone(
cache.pop("key1"), "Cache entry should have been evicted but wasn't"
@@ -278,17 +282,17 @@ class LruCacheSizedTestCase(unittest.HomeserverTestCase):
class TimeEvictionTestCase(unittest.HomeserverTestCase):
"""Test that time based eviction works correctly."""
- def default_config(self):
+ def default_config(self) -> JsonDict:
config = super().default_config()
config.setdefault("caches", {})["expiry_time"] = "30m"
return config
- def test_evict(self):
+ def test_evict(self) -> None:
setup_expire_lru_cache_entries(self.hs)
- cache = LruCache(5, clock=self.hs.get_clock())
+ cache: LruCache[str, int] = LruCache(5, clock=self.hs.get_clock())
# Check that we evict entries we haven't accessed for 30 minutes.
cache["key1"] = 1
@@ -332,7 +336,7 @@ class MemoryEvictionTestCase(unittest.HomeserverTestCase):
}
)
@patch("synapse.util.caches.lrucache.get_jemalloc_stats")
- def test_evict_memory(self, jemalloc_interface) -> None:
+ def test_evict_memory(self, jemalloc_interface: Mock) -> None:
mock_jemalloc_class = Mock(spec=JemallocStats)
jemalloc_interface.return_value = mock_jemalloc_class
@@ -340,7 +344,7 @@ class MemoryEvictionTestCase(unittest.HomeserverTestCase):
mock_jemalloc_class.get_stat.return_value = 924288000
setup_expire_lru_cache_entries(self.hs)
- cache = LruCache(4, clock=self.hs.get_clock())
+ cache: LruCache[str, int] = LruCache(4, clock=self.hs.get_clock())
cache["key1"] = 1
cache["key2"] = 2
diff --git a/tests/util/test_macaroons.py b/tests/util/test_macaroons.py
index 40754a4711..f68377a05a 100644
--- a/tests/util/test_macaroons.py
+++ b/tests/util/test_macaroons.py
@@ -21,14 +21,14 @@ from tests.unittest import TestCase
class MacaroonGeneratorTestCase(TestCase):
- def setUp(self):
+ def setUp(self) -> None:
self.reactor, hs_clock = get_clock()
self.macaroon_generator = MacaroonGenerator(hs_clock, "tesths", b"verysecret")
self.other_macaroon_generator = MacaroonGenerator(
hs_clock, "tesths", b"anothersecretkey"
)
- def test_guest_access_token(self):
+ def test_guest_access_token(self) -> None:
"""Test the generation and verification of guest access tokens"""
token = self.macaroon_generator.generate_guest_access_token("@user:tesths")
user_id = self.macaroon_generator.verify_guest_token(token)
@@ -47,7 +47,7 @@ class MacaroonGeneratorTestCase(TestCase):
with self.assertRaises(MacaroonVerificationFailedException):
self.macaroon_generator.verify_guest_token(token)
- def test_delete_pusher_token(self):
+ def test_delete_pusher_token(self) -> None:
"""Test the generation and verification of delete_pusher tokens"""
token = self.macaroon_generator.generate_delete_pusher_token(
"@user:tesths", "m.mail", "john@example.com"
@@ -84,7 +84,7 @@ class MacaroonGeneratorTestCase(TestCase):
)
self.assertEqual(user_id, "@user:tesths")
- def test_oidc_session_token(self):
+ def test_oidc_session_token(self) -> None:
"""Test the generation and verification of OIDC session cookies"""
state = "arandomstate"
session_data = OidcSessionData(
diff --git a/tests/util/test_ratelimitutils.py b/tests/util/test_ratelimitutils.py
index 89d8656634..5b327b390e 100644
--- a/tests/util/test_ratelimitutils.py
+++ b/tests/util/test_ratelimitutils.py
@@ -13,16 +13,19 @@
# limitations under the License.
from typing import Optional
+from twisted.internet.defer import Deferred
+
from synapse.config.homeserver import HomeServerConfig
+from synapse.config.ratelimiting import FederationRatelimitSettings
from synapse.util.ratelimitutils import FederationRateLimiter
-from tests.server import get_clock
+from tests.server import ThreadedMemoryReactorClock, get_clock
from tests.unittest import TestCase
from tests.utils import default_config
class FederationRateLimiterTestCase(TestCase):
- def test_ratelimit(self):
+ def test_ratelimit(self) -> None:
"""A simple test with the default values"""
reactor, clock = get_clock()
rc_config = build_rc_config()
@@ -32,7 +35,7 @@ class FederationRateLimiterTestCase(TestCase):
# shouldn't block
self.successResultOf(d1)
- def test_concurrent_limit(self):
+ def test_concurrent_limit(self) -> None:
"""Test what happens when we hit the concurrent limit"""
reactor, clock = get_clock()
rc_config = build_rc_config({"rc_federation": {"concurrent": 2}})
@@ -56,7 +59,7 @@ class FederationRateLimiterTestCase(TestCase):
cm2.__exit__(None, None, None)
self.successResultOf(d3)
- def test_sleep_limit(self):
+ def test_sleep_limit(self) -> None:
"""Test what happens when we hit the sleep limit"""
reactor, clock = get_clock()
rc_config = build_rc_config(
@@ -79,7 +82,7 @@ class FederationRateLimiterTestCase(TestCase):
self.assertAlmostEqual(sleep_time, 500, places=3)
-def _await_resolution(reactor, d):
+def _await_resolution(reactor: ThreadedMemoryReactorClock, d: Deferred) -> float:
"""advance the clock until the deferred completes.
Returns the number of milliseconds it took to complete.
@@ -90,7 +93,7 @@ def _await_resolution(reactor, d):
return (reactor.seconds() - start_time) * 1000
-def build_rc_config(settings: Optional[dict] = None):
+def build_rc_config(settings: Optional[dict] = None) -> FederationRatelimitSettings:
config_dict = default_config("test")
config_dict.update(settings or {})
config = HomeServerConfig()
diff --git a/tests/util/test_retryutils.py b/tests/util/test_retryutils.py
index 26cb71c640..9529ee53c8 100644
--- a/tests/util/test_retryutils.py
+++ b/tests/util/test_retryutils.py
@@ -22,7 +22,7 @@ from tests.unittest import HomeserverTestCase
class RetryLimiterTestCase(HomeserverTestCase):
- def test_new_destination(self):
+ def test_new_destination(self) -> None:
"""A happy-path case with a new destination and a successful operation"""
store = self.hs.get_datastores().main
limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store))
@@ -36,7 +36,7 @@ class RetryLimiterTestCase(HomeserverTestCase):
new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
self.assertIsNone(new_timings)
- def test_limiter(self):
+ def test_limiter(self) -> None:
"""General test case which walks through the process of a failing request"""
store = self.hs.get_datastores().main
diff --git a/tests/util/test_rwlock.py b/tests/util/test_rwlock.py
index 5da04362a9..bc93de62eb 100644
--- a/tests/util/test_rwlock.py
+++ b/tests/util/test_rwlock.py
@@ -49,7 +49,7 @@ class ReadWriteLockTestCase(unittest.TestCase):
acquired_d: "Deferred[None]" = Deferred()
unblock_d: "Deferred[None]" = Deferred()
- async def reader_or_writer():
+ async def reader_or_writer() -> str:
async with read_or_write(key):
acquired_d.callback(None)
await unblock_d
@@ -134,7 +134,7 @@ class ReadWriteLockTestCase(unittest.TestCase):
d.called, msg="deferred %d was unexpectedly resolved" % (i + n)
)
- def test_rwlock(self):
+ def test_rwlock(self) -> None:
rwlock = ReadWriteLock()
key = "key"
@@ -197,7 +197,7 @@ class ReadWriteLockTestCase(unittest.TestCase):
_, acquired_d = self._start_nonblocking_reader(rwlock, key, "last reader")
self.assertTrue(acquired_d.called)
- def test_lock_handoff_to_nonblocking_writer(self):
+ def test_lock_handoff_to_nonblocking_writer(self) -> None:
"""Test a writer handing the lock to another writer that completes instantly."""
rwlock = ReadWriteLock()
key = "key"
@@ -216,7 +216,7 @@ class ReadWriteLockTestCase(unittest.TestCase):
d3, _ = self._start_nonblocking_writer(rwlock, key, "write 3 completed")
self.assertTrue(d3.called)
- def test_cancellation_while_holding_read_lock(self):
+ def test_cancellation_while_holding_read_lock(self) -> None:
"""Test cancellation while holding a read lock.
A waiting writer should be given the lock when the reader holding the lock is
@@ -242,7 +242,7 @@ class ReadWriteLockTestCase(unittest.TestCase):
)
self.assertEqual("write completed", self.successResultOf(writer_d))
- def test_cancellation_while_holding_write_lock(self):
+ def test_cancellation_while_holding_write_lock(self) -> None:
"""Test cancellation while holding a write lock.
A waiting reader should be given the lock when the writer holding the lock is
@@ -268,7 +268,7 @@ class ReadWriteLockTestCase(unittest.TestCase):
)
self.assertEqual("read completed", self.successResultOf(reader_d))
- def test_cancellation_while_waiting_for_read_lock(self):
+ def test_cancellation_while_waiting_for_read_lock(self) -> None:
"""Test cancellation while waiting for a read lock.
Tests that cancelling a waiting reader:
@@ -319,7 +319,7 @@ class ReadWriteLockTestCase(unittest.TestCase):
)
self.assertEqual("write 2 completed", self.successResultOf(writer2_d))
- def test_cancellation_while_waiting_for_write_lock(self):
+ def test_cancellation_while_waiting_for_write_lock(self) -> None:
"""Test cancellation while waiting for a write lock.
Tests that cancelling a waiting writer:
diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py
index 9ed01f7e0c..3df053493b 100644
--- a/tests/util/test_stream_change_cache.py
+++ b/tests/util/test_stream_change_cache.py
@@ -8,7 +8,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
Tests for StreamChangeCache.
"""
- def test_prefilled_cache(self):
+ def test_prefilled_cache(self) -> None:
"""
Providing a prefilled cache to StreamChangeCache will result in a cache
with the prefilled-cache entered in.
@@ -16,7 +16,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
cache = StreamChangeCache("#test", 1, prefilled_cache={"user@foo.com": 2})
self.assertTrue(cache.has_entity_changed("user@foo.com", 1))
- def test_has_entity_changed(self):
+ def test_has_entity_changed(self) -> None:
"""
StreamChangeCache.entity_has_changed will mark entities as changed, and
has_entity_changed will observe the changed entities.
@@ -51,8 +51,10 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
# return True, whether it's a known entity or not.
self.assertTrue(cache.has_entity_changed("user@foo.com", 0))
self.assertTrue(cache.has_entity_changed("not@here.website", 0))
+ self.assertTrue(cache.has_entity_changed("user@foo.com", 3))
+ self.assertTrue(cache.has_entity_changed("not@here.website", 3))
- def test_entity_has_changed_pops_off_start(self):
+ def test_entity_has_changed_pops_off_start(self) -> None:
"""
StreamChangeCache.entity_has_changed will respect the max size and
purge the oldest items upon reaching that max size.
@@ -65,15 +67,16 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
# The cache is at the max size, 2
self.assertEqual(len(cache._cache), 2)
+ # The cache's earliest known position is 2.
+ self.assertEqual(cache._earliest_known_stream_pos, 2)
# The oldest item has been popped off
self.assertTrue("user@foo.com" not in cache._entity_to_key)
self.assertEqual(
- cache.get_all_entities_changed(2),
- ["bar@baz.net", "user@elsewhere.org"],
+ cache.get_all_entities_changed(3).entities, ["user@elsewhere.org"]
)
- self.assertIsNone(cache.get_all_entities_changed(1))
+ self.assertFalse(cache.get_all_entities_changed(2).hit)
# If we update an existing entity, it keeps the two existing entities
cache.entity_has_changed("bar@baz.net", 5)
@@ -81,12 +84,12 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
{"bar@baz.net", "user@elsewhere.org"}, set(cache._entity_to_key)
)
self.assertEqual(
- cache.get_all_entities_changed(2),
+ cache.get_all_entities_changed(3).entities,
["user@elsewhere.org", "bar@baz.net"],
)
- self.assertIsNone(cache.get_all_entities_changed(1))
+ self.assertFalse(cache.get_all_entities_changed(2).hit)
- def test_get_all_entities_changed(self):
+ def test_get_all_entities_changed(self) -> None:
"""
StreamChangeCache.get_all_entities_changed will return all changed
entities since the given position. If the position is before the start
@@ -99,28 +102,17 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
cache.entity_has_changed("anotheruser@foo.com", 3)
cache.entity_has_changed("user@elsewhere.org", 4)
- r = cache.get_all_entities_changed(1)
-
- # either of these are valid
- ok1 = [
- "user@foo.com",
- "bar@baz.net",
- "anotheruser@foo.com",
- "user@elsewhere.org",
- ]
- ok2 = [
- "user@foo.com",
- "anotheruser@foo.com",
- "bar@baz.net",
- "user@elsewhere.org",
- ]
- self.assertTrue(r == ok1 or r == ok2)
-
r = cache.get_all_entities_changed(2)
- self.assertTrue(r == ok1[1:] or r == ok2[1:])
- self.assertEqual(cache.get_all_entities_changed(3), ["user@elsewhere.org"])
- self.assertEqual(cache.get_all_entities_changed(0), None)
+ # Results are ordered so either of these are valid.
+ ok1 = ["bar@baz.net", "anotheruser@foo.com", "user@elsewhere.org"]
+ ok2 = ["anotheruser@foo.com", "bar@baz.net", "user@elsewhere.org"]
+ self.assertTrue(r.entities == ok1 or r.entities == ok2)
+
+ self.assertEqual(
+ cache.get_all_entities_changed(3).entities, ["user@elsewhere.org"]
+ )
+ self.assertFalse(cache.get_all_entities_changed(1).hit)
# ... later, things gest more updates
cache.entity_has_changed("user@foo.com", 5)
@@ -140,9 +132,9 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
"anotheruser@foo.com",
]
r = cache.get_all_entities_changed(3)
- self.assertTrue(r == ok1 or r == ok2)
+ self.assertTrue(r.entities == ok1 or r.entities == ok2)
- def test_has_any_entity_changed(self):
+ def test_has_any_entity_changed(self) -> None:
"""
StreamChangeCache.has_any_entity_changed will return True if any
entities have been changed since the provided stream position, and
@@ -152,9 +144,10 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
"""
cache = StreamChangeCache("#test", 1)
- # With no entities, it returns False for the past, present, and future.
- self.assertFalse(cache.has_any_entity_changed(0))
- self.assertFalse(cache.has_any_entity_changed(1))
+ # With no entities, it returns True for the past, present, and False for
+ # the future.
+ self.assertTrue(cache.has_any_entity_changed(0))
+ self.assertTrue(cache.has_any_entity_changed(1))
self.assertFalse(cache.has_any_entity_changed(2))
# We add an entity
@@ -168,7 +161,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
self.assertFalse(cache.has_any_entity_changed(2))
self.assertFalse(cache.has_any_entity_changed(3))
- def test_get_entities_changed(self):
+ def test_get_entities_changed(self) -> None:
"""
StreamChangeCache.get_entities_changed will return the entities in the
given list that have changed since the provided stream ID. If the
@@ -228,7 +221,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
{"bar@baz.net"},
)
- def test_max_pos(self):
+ def test_max_pos(self) -> None:
"""
StreamChangeCache.get_max_pos_of_last_change will return the most
recent point where the entity could have changed. If the entity is not
diff --git a/tests/util/test_stringutils.py b/tests/util/test_stringutils.py
index ad4dd7f007..f137e05191 100644
--- a/tests/util/test_stringutils.py
+++ b/tests/util/test_stringutils.py
@@ -19,7 +19,7 @@ from .. import unittest
class StringUtilsTestCase(unittest.TestCase):
- def test_client_secret_regex(self):
+ def test_client_secret_regex(self) -> None:
"""Ensure that client_secret does not contain illegal characters"""
good = [
"abcde12345",
@@ -46,7 +46,7 @@ class StringUtilsTestCase(unittest.TestCase):
with self.assertRaises(SynapseError):
assert_valid_client_secret(client_secret)
- def test_base62_encode(self):
+ def test_base62_encode(self) -> None:
self.assertEqual("0", base62_encode(0))
self.assertEqual("10", base62_encode(62))
self.assertEqual("1c", base62_encode(100))
diff --git a/tests/util/test_threepids.py b/tests/util/test_threepids.py
index d957b953bb..3b35b8e4ec 100644
--- a/tests/util/test_threepids.py
+++ b/tests/util/test_threepids.py
@@ -18,31 +18,31 @@ from tests.unittest import HomeserverTestCase
class CanonicaliseEmailTests(HomeserverTestCase):
- def test_no_at(self):
+ def test_no_at(self) -> None:
with self.assertRaises(ValueError):
canonicalise_email("address-without-at.bar")
- def test_two_at(self):
+ def test_two_at(self) -> None:
with self.assertRaises(ValueError):
canonicalise_email("foo@foo@test.bar")
- def test_bad_format(self):
+ def test_bad_format(self) -> None:
with self.assertRaises(ValueError):
canonicalise_email("user@bad.example.net@good.example.com")
- def test_valid_format(self):
+ def test_valid_format(self) -> None:
self.assertEqual(canonicalise_email("foo@test.bar"), "foo@test.bar")
- def test_domain_to_lower(self):
+ def test_domain_to_lower(self) -> None:
self.assertEqual(canonicalise_email("foo@TEST.BAR"), "foo@test.bar")
- def test_domain_with_umlaut(self):
+ def test_domain_with_umlaut(self) -> None:
self.assertEqual(canonicalise_email("foo@Öumlaut.com"), "foo@öumlaut.com")
- def test_address_casefold(self):
+ def test_address_casefold(self) -> None:
self.assertEqual(
canonicalise_email("Strauß@Example.com"), "strauss@example.com"
)
- def test_address_trim(self):
+ def test_address_trim(self) -> None:
self.assertEqual(canonicalise_email(" foo@test.bar "), "foo@test.bar")
diff --git a/tests/util/test_treecache.py b/tests/util/test_treecache.py
index 567cb18468..fe3b4dc6a4 100644
--- a/tests/util/test_treecache.py
+++ b/tests/util/test_treecache.py
@@ -19,7 +19,7 @@ from .. import unittest
class TreeCacheTestCase(unittest.TestCase):
- def test_get_set_onelevel(self):
+ def test_get_set_onelevel(self) -> None:
cache = TreeCache()
cache[("a",)] = "A"
cache[("b",)] = "B"
@@ -27,7 +27,7 @@ class TreeCacheTestCase(unittest.TestCase):
self.assertEqual(cache.get(("b",)), "B")
self.assertEqual(len(cache), 2)
- def test_pop_onelevel(self):
+ def test_pop_onelevel(self) -> None:
cache = TreeCache()
cache[("a",)] = "A"
cache[("b",)] = "B"
@@ -36,7 +36,7 @@ class TreeCacheTestCase(unittest.TestCase):
self.assertEqual(cache.get(("b",)), "B")
self.assertEqual(len(cache), 1)
- def test_get_set_twolevel(self):
+ def test_get_set_twolevel(self) -> None:
cache = TreeCache()
cache[("a", "a")] = "AA"
cache[("a", "b")] = "AB"
@@ -46,7 +46,7 @@ class TreeCacheTestCase(unittest.TestCase):
self.assertEqual(cache.get(("b", "a")), "BA")
self.assertEqual(len(cache), 3)
- def test_pop_twolevel(self):
+ def test_pop_twolevel(self) -> None:
cache = TreeCache()
cache[("a", "a")] = "AA"
cache[("a", "b")] = "AB"
@@ -58,7 +58,7 @@ class TreeCacheTestCase(unittest.TestCase):
self.assertEqual(cache.pop(("b", "a")), None)
self.assertEqual(len(cache), 1)
- def test_pop_mixedlevel(self):
+ def test_pop_mixedlevel(self) -> None:
cache = TreeCache()
cache[("a", "a")] = "AA"
cache[("a", "b")] = "AB"
@@ -72,14 +72,14 @@ class TreeCacheTestCase(unittest.TestCase):
self.assertEqual({"AA", "AB"}, set(iterate_tree_cache_entry(popped)))
- def test_clear(self):
+ def test_clear(self) -> None:
cache = TreeCache()
cache[("a",)] = "A"
cache[("b",)] = "B"
cache.clear()
self.assertEqual(len(cache), 0)
- def test_contains(self):
+ def test_contains(self) -> None:
cache = TreeCache()
cache[("a",)] = "A"
self.assertTrue(("a",) in cache)
diff --git a/tests/util/test_wheel_timer.py b/tests/util/test_wheel_timer.py
index 0d5039de04..c9d22b6d8c 100644
--- a/tests/util/test_wheel_timer.py
+++ b/tests/util/test_wheel_timer.py
@@ -18,8 +18,8 @@ from .. import unittest
class WheelTimerTestCase(unittest.TestCase):
- def test_single_insert_fetch(self):
- wheel = WheelTimer(bucket_size=5)
+ def test_single_insert_fetch(self) -> None:
+ wheel: WheelTimer[object] = WheelTimer(bucket_size=5)
obj = object()
wheel.insert(100, obj, 150)
@@ -32,8 +32,8 @@ class WheelTimerTestCase(unittest.TestCase):
self.assertListEqual(wheel.fetch(156), [obj])
self.assertListEqual(wheel.fetch(170), [])
- def test_multi_insert(self):
- wheel = WheelTimer(bucket_size=5)
+ def test_multi_insert(self) -> None:
+ wheel: WheelTimer[object] = WheelTimer(bucket_size=5)
obj1 = object()
obj2 = object()
@@ -50,15 +50,15 @@ class WheelTimerTestCase(unittest.TestCase):
self.assertListEqual(wheel.fetch(200), [obj3])
self.assertListEqual(wheel.fetch(210), [])
- def test_insert_past(self):
- wheel = WheelTimer(bucket_size=5)
+ def test_insert_past(self) -> None:
+ wheel: WheelTimer[object] = WheelTimer(bucket_size=5)
obj = object()
wheel.insert(100, obj, 50)
self.assertListEqual(wheel.fetch(120), [obj])
- def test_insert_past_multi(self):
- wheel = WheelTimer(bucket_size=5)
+ def test_insert_past_multi(self) -> None:
+ wheel: WheelTimer[object] = WheelTimer(bucket_size=5)
obj1 = object()
obj2 = object()
diff --git a/tests/utils.py b/tests/utils.py
index 045a8b5fa7..d76bf9716a 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -125,7 +125,8 @@ def default_config(
"""
config_dict = {
"server_name": name,
- "send_federation": False,
+ # Setting this to an empty list turns off federation sending.
+ "federation_sender_instances": [],
"media_store_path": "media",
# the test signing key is just an arbitrary ed25519 key to keep the config
# parser happy
@@ -183,8 +184,9 @@ def default_config(
# rooms will fail.
"default_room_version": DEFAULT_ROOM_VERSION,
# disable user directory updates, because they get done in the
- # background, which upsets the test runner.
- "update_user_directory": False,
+ # background, which upsets the test runner. Setting this to an
+ # (obviously) fake worker name disables updating the user directory.
+ "update_user_directory_from_worker": "does_not_exist_worker_name",
"caches": {"global_factor": 1, "sync_response_cache_duration": 0},
"listeners": [{"port": 0, "type": "http"}],
}
|